source: tags/MetisMQIDemo/src/main/java/weka/classifiers/bayes/net/search/local/HillClimber.java

Last change on this file was 29, checked in by gnappo, 15 years ago

Taggata versione per la demo e aggiunto branch.

File size: 20.8 KB
Line 
1/*
2 * This program is free software; you can redistribute it and/or modify
3 * it under the terms of the GNU General Public License as published by
4 * the Free Software Foundation; either version 2 of the License, or
5 * (at your option) any later version.
6 *
7 * This program is distributed in the hope that it will be useful,
8 * but WITHOUT ANY WARRANTY; without even the implied warranty of
9 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
10 * GNU General Public License for more details.
11 *
12 * You should have received a copy of the GNU General Public License
13 * along with this program; if not, write to the Free Software
14 * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
15 */
16
17/*
18 * HillClimber.java
19 * Copyright (C) 2004 University of Waikato, Hamilton, New Zealand
20 *
21 */
22 
23package weka.classifiers.bayes.net.search.local;
24
25import weka.classifiers.bayes.BayesNet;
26import weka.classifiers.bayes.net.ParentSet;
27import weka.core.Instances;
28import weka.core.Option;
29import weka.core.RevisionHandler;
30import weka.core.RevisionUtils;
31import weka.core.Utils;
32
33import java.io.Serializable;
34import java.util.Enumeration;
35import java.util.Vector;
36
37/**
38 <!-- globalinfo-start -->
39 * This Bayes Network learning algorithm uses a hill climbing algorithm adding, deleting and reversing arcs. The search is not restricted by an order on the variables (unlike K2). The difference with B and B2 is that this hill climber also considers arrows part of the naive Bayes structure for deletion.
40 * <p/>
41 <!-- globalinfo-end -->
42 *
43 <!-- options-start -->
44 * Valid options are: <p/>
45 *
46 * <pre> -P &lt;nr of parents&gt;
47 *  Maximum number of parents</pre>
48 *
49 * <pre> -R
50 *  Use arc reversal operation.
51 *  (default false)</pre>
52 *
53 * <pre> -N
54 *  Initial structure is empty (instead of Naive Bayes)</pre>
55 *
56 * <pre> -mbc
57 *  Applies a Markov Blanket correction to the network structure,
58 *  after a network structure is learned. This ensures that all
59 *  nodes in the network are part of the Markov blanket of the
60 *  classifier node.</pre>
61 *
62 * <pre> -S [BAYES|MDL|ENTROPY|AIC|CROSS_CLASSIC|CROSS_BAYES]
63 *  Score type (BAYES, BDeu, MDL, ENTROPY and AIC)</pre>
64 *
65 <!-- options-end -->
66 *
67 * @author Remco Bouckaert (rrb@xm.co.nz)
68 * @version $Revision: 1.9 $
69 */
70public class HillClimber 
71    extends LocalScoreSearchAlgorithm {
72 
73    /** for serialization */
74    static final long serialVersionUID = 4322783593818122403L;
75
76        /** the Operation class contains info on operations performed
77         * on the current Bayesian network.
78         */
79    class Operation 
80        implements Serializable, RevisionHandler {
81     
82        /** for serialization */
83        static final long serialVersionUID = -4880888790432547895L;
84     
85        // constants indicating the type of an operation
86        final static int OPERATION_ADD = 0;
87        final static int OPERATION_DEL = 1;
88        final static int OPERATION_REVERSE = 2;
89       
90        /**
91         * c'tor
92         */
93        public Operation() {
94        }
95       
96                /** c'tor + initializers
97                 *
98                 * @param nTail
99                 * @param nHead
100                 * @param nOperation
101                 */ 
102            public Operation(int nTail, int nHead, int nOperation) {
103                        m_nHead = nHead;
104                        m_nTail = nTail;
105                        m_nOperation = nOperation;
106                }
107                /** compare this operation with another
108                 * @param other operation to compare with
109                 * @return true if operation is the same
110                 */
111                public boolean equals(Operation other) {
112                        if (other == null) {
113                                return false;
114                        }
115                        return ((       m_nOperation == other.m_nOperation) &&
116                        (m_nHead == other.m_nHead) &&
117                        (m_nTail == other.m_nTail));
118                } // equals
119               
120                /** number of the tail node **/
121        public int m_nTail;
122       
123                /** number of the head node **/
124        public int m_nHead;
125       
126                /** type of operation (ADD, DEL, REVERSE) **/
127        public int m_nOperation;
128       
129        /** change of score due to this operation **/
130        public double m_fDeltaScore = -1E100;
131       
132        /**
133         * Returns the revision string.
134         *
135         * @return              the revision
136         */
137        public String getRevision() {
138          return RevisionUtils.extract("$Revision: 1.9 $");
139        }
140    } // class Operation
141
142        /** cache for remembering the change in score for steps in the search space
143         */
144        class Cache implements RevisionHandler {
145         
146                /** change in score due to adding an arc **/
147                double [] [] m_fDeltaScoreAdd;
148                /** change in score due to deleting an arc **/
149                double [] [] m_fDeltaScoreDel;
150                /** c'tor
151                 * @param nNrOfNodes number of nodes in network, used to determine memory size to reserve
152                 */
153                Cache(int nNrOfNodes) {
154                        m_fDeltaScoreAdd = new double [nNrOfNodes][nNrOfNodes];
155                        m_fDeltaScoreDel = new double [nNrOfNodes][nNrOfNodes];
156                }
157
158                /** set cache entry
159                 * @param oOperation operation to perform
160                 * @param fValue value to put in cache
161                 */
162                public void put(Operation oOperation, double fValue) {
163                        if (oOperation.m_nOperation == Operation.OPERATION_ADD) {
164                                m_fDeltaScoreAdd[oOperation.m_nTail][oOperation.m_nHead] = fValue;
165                        } else {
166                                m_fDeltaScoreDel[oOperation.m_nTail][oOperation.m_nHead] = fValue;
167                        }
168                } // put
169
170                /** get cache entry
171                 * @param oOperation operation to perform
172                 * @return cache value
173                 */
174                public double get(Operation oOperation) {
175                        switch(oOperation.m_nOperation) {
176                                case Operation.OPERATION_ADD:
177                                        return m_fDeltaScoreAdd[oOperation.m_nTail][oOperation.m_nHead];
178                                case Operation.OPERATION_DEL:
179                                        return m_fDeltaScoreDel[oOperation.m_nTail][oOperation.m_nHead];
180                                case Operation.OPERATION_REVERSE:
181                                return m_fDeltaScoreDel[oOperation.m_nTail][oOperation.m_nHead] + 
182                                                m_fDeltaScoreAdd[oOperation.m_nHead][oOperation.m_nTail];
183                        }
184                        // should never get here
185                        return 0;
186                } // get
187
188                /**
189                 * Returns the revision string.
190                 *
191                 * @return              the revision
192                 */
193                public String getRevision() {
194                  return RevisionUtils.extract("$Revision: 1.9 $");
195                }
196        } // class Cache
197
198        /** cache for storing score differences **/
199        Cache m_Cache = null;
200       
201    /** use the arc reversal operator **/
202    boolean m_bUseArcReversal = false;
203       
204
205    /**
206     * search determines the network structure/graph of the network
207     * with the Taby algorithm.
208     *
209     * @param bayesNet the network to use
210     * @param instances the data to use
211     * @throws Exception if something goes wrong
212     */
213    protected void search(BayesNet bayesNet, Instances instances) throws Exception {
214        initCache(bayesNet, instances);
215
216        // go do the search       
217                Operation oOperation = getOptimalOperation(bayesNet, instances);
218                while ((oOperation != null) && (oOperation.m_fDeltaScore > 0)) {
219                        performOperation(bayesNet, instances, oOperation);
220                        oOperation = getOptimalOperation(bayesNet, instances);
221        }
222       
223                // free up memory
224                m_Cache = null;
225    } // search
226
227
228        /**
229         * initCache initializes the cache
230         *
231         * @param bayesNet Bayes network to be learned
232         * @param instances data set to learn from
233         * @throws Exception if something goes wrong
234         */
235    void initCache(BayesNet bayesNet, Instances instances)  throws Exception {
236       
237        // determine base scores
238                double[] fBaseScores = new double[instances.numAttributes()];
239        int nNrOfAtts = instances.numAttributes();
240
241                m_Cache = new Cache (nNrOfAtts);
242               
243                for (int iAttribute = 0; iAttribute < nNrOfAtts; iAttribute++) {
244                        updateCache(iAttribute, nNrOfAtts, bayesNet.getParentSet(iAttribute));
245                }
246
247
248        for (int iAttribute = 0; iAttribute < nNrOfAtts; iAttribute++) {
249            fBaseScores[iAttribute] = calcNodeScore(iAttribute);
250        }
251
252        for (int iAttributeHead = 0; iAttributeHead < nNrOfAtts; iAttributeHead++) {
253                for (int iAttributeTail = 0; iAttributeTail < nNrOfAtts; iAttributeTail++) {
254                        if (iAttributeHead != iAttributeTail) {
255                            Operation oOperation = new Operation(iAttributeTail, iAttributeHead, Operation.OPERATION_ADD);
256                            m_Cache.put(oOperation, calcScoreWithExtraParent(iAttributeHead, iAttributeTail) - fBaseScores[iAttributeHead]);
257                                        }
258            }
259        }
260
261    } // initCache
262
263        /** check whether the operation is not in the forbidden.
264         * For base hill climber, there are no restrictions on operations,
265         * so we always return true.
266         * @param oOperation operation to be checked
267         * @return true if operation is not in the tabu list
268         */
269        boolean isNotTabu(Operation oOperation) {
270                return true;
271        } // isNotTabu
272
273        /**
274         * getOptimalOperation finds the optimal operation that can be performed
275         * on the Bayes network that is not in the tabu list.
276         *
277         * @param bayesNet Bayes network to apply operation on
278         * @param instances data set to learn from
279         * @return optimal operation found
280         * @throws Exception if something goes wrong
281         */
282    Operation getOptimalOperation(BayesNet bayesNet, Instances instances) throws Exception {
283        Operation oBestOperation = new Operation();
284
285                // Add???
286                oBestOperation = findBestArcToAdd(bayesNet, instances, oBestOperation);
287                // Delete???
288                oBestOperation = findBestArcToDelete(bayesNet, instances, oBestOperation);
289                // Reverse???
290                if (getUseArcReversal()) {
291                        oBestOperation = findBestArcToReverse(bayesNet, instances, oBestOperation);
292                }
293
294                // did we find something?
295                if (oBestOperation.m_fDeltaScore == -1E100) {
296                        return null;
297                }
298
299        return oBestOperation;
300    } // getOptimalOperation
301
302        /**
303         * performOperation applies an operation
304         * on the Bayes network and update the cache.
305         *
306         * @param bayesNet Bayes network to apply operation on
307         * @param instances data set to learn from
308         * @param oOperation operation to perform
309         * @throws Exception if something goes wrong
310         */
311        void performOperation(BayesNet bayesNet, Instances instances, Operation oOperation) throws Exception {
312                // perform operation
313                switch (oOperation.m_nOperation) {
314                        case Operation.OPERATION_ADD:
315                                applyArcAddition(bayesNet, oOperation.m_nHead, oOperation.m_nTail, instances);
316                                if (bayesNet.getDebug()) {
317                                        System.out.print("Add " + oOperation.m_nHead + " -> " + oOperation.m_nTail);
318                                }
319                                break;
320                        case Operation.OPERATION_DEL:
321                                applyArcDeletion(bayesNet, oOperation.m_nHead, oOperation.m_nTail, instances);
322                                if (bayesNet.getDebug()) {
323                                        System.out.print("Del " + oOperation.m_nHead + " -> " + oOperation.m_nTail);
324                                }
325                                break;
326                        case Operation.OPERATION_REVERSE:
327                                applyArcDeletion(bayesNet, oOperation.m_nHead, oOperation.m_nTail, instances);
328                                applyArcAddition(bayesNet, oOperation.m_nTail, oOperation.m_nHead, instances);
329                                if (bayesNet.getDebug()) {
330                                        System.out.print("Rev " + oOperation.m_nHead+ " -> " + oOperation.m_nTail);
331                                }
332                                break;
333                }
334        } // performOperation
335
336
337        /**
338         *
339         * @param bayesNet
340         * @param iHead
341         * @param iTail
342         * @param instances
343         */
344        void applyArcAddition(BayesNet bayesNet, int iHead, int iTail, Instances instances) {
345                ParentSet bestParentSet = bayesNet.getParentSet(iHead);
346                bestParentSet.addParent(iTail, instances);
347                updateCache(iHead, instances.numAttributes(), bestParentSet);
348        } // applyArcAddition
349
350        /**
351         *
352         * @param bayesNet
353         * @param iHead
354         * @param iTail
355         * @param instances
356         */
357        void applyArcDeletion(BayesNet bayesNet, int iHead, int iTail, Instances instances) {
358                ParentSet bestParentSet = bayesNet.getParentSet(iHead);
359                bestParentSet.deleteParent(iTail, instances);
360                updateCache(iHead, instances.numAttributes(), bestParentSet);
361        } // applyArcAddition
362
363
364        /**
365         * find best (or least bad) arc addition operation
366         *
367         * @param bayesNet Bayes network to add arc to
368         * @param instances data set
369         * @param oBestOperation
370         * @return Operation containing best arc to add, or null if no arc addition is allowed
371         * (this can happen if any arc addition introduces a cycle, or all parent sets are filled
372         * up to the maximum nr of parents).
373         */
374        Operation findBestArcToAdd(BayesNet bayesNet, Instances instances, Operation oBestOperation) {
375                int nNrOfAtts = instances.numAttributes();
376                // find best arc to add
377                for (int iAttributeHead = 0; iAttributeHead < nNrOfAtts; iAttributeHead++) {
378                        if (bayesNet.getParentSet(iAttributeHead).getNrOfParents() < m_nMaxNrOfParents) {
379                                for (int iAttributeTail = 0; iAttributeTail < nNrOfAtts; iAttributeTail++) {
380                                        if (addArcMakesSense(bayesNet, instances, iAttributeHead, iAttributeTail)) {
381                                                Operation oOperation = new Operation(iAttributeTail, iAttributeHead, Operation.OPERATION_ADD);
382                                                if (m_Cache.get(oOperation) > oBestOperation.m_fDeltaScore) {
383                                                        if (isNotTabu(oOperation)) {
384                                                                oBestOperation = oOperation;
385                                                                oBestOperation.m_fDeltaScore = m_Cache.get(oOperation);
386                                                        }
387                                                }
388                                        }
389                                }
390                        }
391                }
392                return oBestOperation;
393        } // findBestArcToAdd
394
395        /**
396         * find best (or least bad) arc deletion operation
397         *
398         * @param bayesNet Bayes network to delete arc from
399         * @param instances data set
400         * @param oBestOperation
401         * @return Operation containing best arc to delete, or null if no deletion can be made
402         * (happens when there is no arc in the network yet).
403         */
404        Operation findBestArcToDelete(BayesNet bayesNet, Instances instances, Operation oBestOperation) {
405                int nNrOfAtts = instances.numAttributes();
406                // find best arc to delete
407                for (int iNode = 0; iNode < nNrOfAtts; iNode++) {
408                        ParentSet parentSet = bayesNet.getParentSet(iNode);
409                        for (int iParent = 0; iParent < parentSet.getNrOfParents(); iParent++) {
410                                Operation oOperation = new Operation(parentSet.getParent(iParent), iNode, Operation.OPERATION_DEL);
411                                if (m_Cache.get(oOperation) > oBestOperation.m_fDeltaScore) {
412                                        if (isNotTabu(oOperation)) {
413                                                oBestOperation = oOperation;
414                                                oBestOperation.m_fDeltaScore = m_Cache.get(oOperation);
415                                        }
416                                }
417                        }
418                }
419                return oBestOperation;
420        } // findBestArcToDelete
421
422        /**
423         * find best (or least bad) arc reversal operation
424         *
425         * @param bayesNet Bayes network to reverse arc in
426         * @param instances data set
427         * @param oBestOperation
428         * @return Operation containing best arc to reverse, or null if no reversal is allowed
429         * (happens if there is no arc in the network yet, or when any such reversal introduces
430         * a cycle).
431         */
432        Operation findBestArcToReverse(BayesNet bayesNet, Instances instances, Operation oBestOperation) {
433                int nNrOfAtts = instances.numAttributes();
434                // find best arc to reverse
435                for (int iNode = 0; iNode < nNrOfAtts; iNode++) {
436                        ParentSet parentSet = bayesNet.getParentSet(iNode);
437                        for (int iParent = 0; iParent < parentSet.getNrOfParents(); iParent++) {
438                                int iTail = parentSet.getParent(iParent);
439                                // is reversal allowed?
440                                if (reverseArcMakesSense(bayesNet, instances, iNode, iTail) && 
441                                    bayesNet.getParentSet(iTail).getNrOfParents() < m_nMaxNrOfParents) {
442                                        // go check if reversal results in the best step forward
443                                        Operation oOperation = new Operation(parentSet.getParent(iParent), iNode, Operation.OPERATION_REVERSE);
444                                        if (m_Cache.get(oOperation) > oBestOperation.m_fDeltaScore) {
445                                                if (isNotTabu(oOperation)) {
446                                                        oBestOperation = oOperation;
447                                                        oBestOperation.m_fDeltaScore = m_Cache.get(oOperation);
448                                                }
449                                        }
450                                }
451                        }
452                }
453                return oBestOperation;
454        } // findBestArcToReverse
455
456        /**
457         * update the cache due to change of parent set of a node
458         *
459         * @param iAttributeHead node that has its parent set changed
460         * @param nNrOfAtts number of nodes/attributes in data set
461         * @param parentSet new parents set of node iAttributeHead
462         */
463        void updateCache(int iAttributeHead, int nNrOfAtts, ParentSet parentSet) {
464                // update cache entries for arrows heading towards iAttributeHead
465                double fBaseScore = calcNodeScore(iAttributeHead);
466                int nNrOfParents = parentSet.getNrOfParents();
467                for (int iAttributeTail = 0; iAttributeTail < nNrOfAtts; iAttributeTail++) {
468                        if (iAttributeTail != iAttributeHead) {
469                                if (!parentSet.contains(iAttributeTail)) {
470                                        // add entries to cache for adding arcs
471                                        if (nNrOfParents < m_nMaxNrOfParents) {
472                                                Operation oOperation = new Operation(iAttributeTail, iAttributeHead, Operation.OPERATION_ADD);
473                                                m_Cache.put(oOperation, calcScoreWithExtraParent(iAttributeHead, iAttributeTail) - fBaseScore);
474                                        }
475                                } else {
476                                        // add entries to cache for deleting arcs
477                                        Operation oOperation = new Operation(iAttributeTail, iAttributeHead, Operation.OPERATION_DEL);
478                                        m_Cache.put(oOperation, calcScoreWithMissingParent(iAttributeHead, iAttributeTail) - fBaseScore);
479                                }
480                        }
481                }
482        } // updateCache
483       
484
485        /**
486         * Sets the max number of parents
487         *
488         * @param nMaxNrOfParents the max number of parents
489         */
490        public void setMaxNrOfParents(int nMaxNrOfParents) {
491          m_nMaxNrOfParents = nMaxNrOfParents;
492        } 
493
494        /**
495         * Gets the max number of parents.
496         *
497         * @return the max number of parents
498         */
499        public int getMaxNrOfParents() {
500          return m_nMaxNrOfParents;
501        } 
502
503        /**
504         * Returns an enumeration describing the available options.
505         *
506         * @return an enumeration of all the available options.
507         */
508        public Enumeration listOptions() {
509                Vector newVector = new Vector(2);
510
511                newVector.addElement(new Option("\tMaximum number of parents", "P", 1, "-P <nr of parents>"));
512                newVector.addElement(new Option("\tUse arc reversal operation.\n\t(default false)", "R", 0, "-R"));
513                newVector.addElement(new Option("\tInitial structure is empty (instead of Naive Bayes)", "N", 0, "-N"));
514
515                Enumeration enu = super.listOptions();
516                while (enu.hasMoreElements()) {
517                        newVector.addElement(enu.nextElement());
518                }
519                return newVector.elements();
520        } // listOptions
521
522        /**
523         * Parses a given list of options. <p/>
524         *
525         <!-- options-start -->
526         * Valid options are: <p/>
527         *
528         * <pre> -P &lt;nr of parents&gt;
529         *  Maximum number of parents</pre>
530         *
531         * <pre> -R
532         *  Use arc reversal operation.
533         *  (default false)</pre>
534         *
535         * <pre> -N
536         *  Initial structure is empty (instead of Naive Bayes)</pre>
537         *
538         * <pre> -mbc
539         *  Applies a Markov Blanket correction to the network structure,
540         *  after a network structure is learned. This ensures that all
541         *  nodes in the network are part of the Markov blanket of the
542         *  classifier node.</pre>
543         *
544         * <pre> -S [BAYES|MDL|ENTROPY|AIC|CROSS_CLASSIC|CROSS_BAYES]
545         *  Score type (BAYES, BDeu, MDL, ENTROPY and AIC)</pre>
546         *
547         <!-- options-end -->
548         *
549         * @param options the list of options as an array of strings
550         * @throws Exception if an option is not supported
551         */
552        public void setOptions(String[] options) throws Exception {
553                setUseArcReversal(Utils.getFlag('R', options));
554
555                setInitAsNaiveBayes (!(Utils.getFlag('N', options)));
556               
557                String sMaxNrOfParents = Utils.getOption('P', options);
558                if (sMaxNrOfParents.length() != 0) {
559                  setMaxNrOfParents(Integer.parseInt(sMaxNrOfParents));
560                } else {
561                  setMaxNrOfParents(100000);
562                }
563               
564                super.setOptions(options);
565        } // setOptions
566
567        /**
568         * Gets the current settings of the search algorithm.
569         *
570         * @return an array of strings suitable for passing to setOptions
571         */
572        public String[] getOptions() {
573                String[] superOptions = super.getOptions();
574                String[] options = new String[7 + superOptions.length];
575                int current = 0;
576                if (getUseArcReversal()) {
577                  options[current++] = "-R";
578                }
579               
580                if (!getInitAsNaiveBayes()) {
581                  options[current++] = "-N";
582                } 
583
584                options[current++] = "-P";
585                options[current++] = "" + m_nMaxNrOfParents;
586
587                // insert options from parent class
588                for (int iOption = 0; iOption < superOptions.length; iOption++) {
589                        options[current++] = superOptions[iOption];
590                }
591
592                // Fill up rest with empty strings, not nulls!
593                while (current < options.length) {
594                        options[current++] = "";
595                }
596                return options;
597        } // getOptions
598
599        /**
600         * Sets whether to init as naive bayes
601         *
602         * @param bInitAsNaiveBayes whether to init as naive bayes
603         */
604        public void setInitAsNaiveBayes(boolean bInitAsNaiveBayes) {
605          m_bInitAsNaiveBayes = bInitAsNaiveBayes;
606        } 
607
608        /**
609         * Gets whether to init as naive bayes
610         *
611         * @return whether to init as naive bayes
612         */
613        public boolean getInitAsNaiveBayes() {
614          return m_bInitAsNaiveBayes;
615        } 
616
617        /** get use the arc reversal operation
618         * @return whether the arc reversal operation should be used
619         */
620        public boolean getUseArcReversal() {
621                return m_bUseArcReversal;
622        } // getUseArcReversal
623
624        /** set use the arc reversal operation
625         * @param bUseArcReversal whether the arc reversal operation should be used
626         */
627        public void setUseArcReversal(boolean bUseArcReversal) {
628                m_bUseArcReversal = bUseArcReversal;
629        } // setUseArcReversal
630
631        /**
632         * This will return a string describing the search algorithm.
633         * @return The string.
634         */
635        public String globalInfo() {
636          return "This Bayes Network learning algorithm uses a hill climbing algorithm " +
637          "adding, deleting and reversing arcs. The search is not restricted by an order " +
638          "on the variables (unlike K2). The difference with B and B2 is that this hill " +       
639          "climber also considers arrows part of the naive Bayes structure for deletion.";
640        } // globalInfo
641
642        /**
643         * @return a string to describe the Use Arc Reversal option.
644         */
645        public String useArcReversalTipText() {
646          return "When set to true, the arc reversal operation is used in the search.";
647        } // useArcReversalTipText
648
649        /**
650         * Returns the revision string.
651         *
652         * @return              the revision
653         */
654        public String getRevision() {
655          return RevisionUtils.extract("$Revision: 1.9 $");
656        }
657
658} // HillClimber
Note: See TracBrowser for help on using the repository browser.