source: tags/MetisMQIDemo/src/main/java/weka/classifiers/bayes/net/search/global/HillClimber.java

Last change on this file was 29, checked in by gnappo, 15 years ago

Taggata versione per la demo e aggiunto branch.

File size: 16.9 KB
Line 
1/*
2 * This program is free software; you can redistribute it and/or modify
3 * it under the terms of the GNU General Public License as published by
4 * the Free Software Foundation; either version 2 of the License, or
5 * (at your option) any later version.
6 *
7 * This program is distributed in the hope that it will be useful,
8 * but WITHOUT ANY WARRANTY; without even the implied warranty of
9 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
10 * GNU General Public License for more details.
11 *
12 * You should have received a copy of the GNU General Public License
13 * along with this program; if not, write to the Free Software
14 * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
15 */
16
17/*
18 * HillClimber.java
19 * Copyright (C) 2004 University of Waikato, Hamilton, New Zealand
20 *
21 */
22 
23package weka.classifiers.bayes.net.search.global;
24
25import weka.classifiers.bayes.BayesNet;
26import weka.classifiers.bayes.net.ParentSet;
27import weka.core.Instances;
28import weka.core.Option;
29import weka.core.RevisionHandler;
30import weka.core.RevisionUtils;
31import weka.core.Utils;
32
33import java.io.Serializable;
34import java.util.Enumeration;
35import java.util.Vector;
36
37/**
38 <!-- globalinfo-start -->
39 * This Bayes Network learning algorithm uses a hill climbing algorithm adding, deleting and reversing arcs. The search is not restricted by an order on the variables (unlike K2). The difference with B and B2 is that this hill climber also considers arrows part of the naive Bayes structure for deletion.
40 * <p/>
41 <!-- globalinfo-end -->
42 *
43 <!-- options-start -->
44 * Valid options are: <p/>
45 *
46 * <pre> -P &lt;nr of parents&gt;
47 *  Maximum number of parents</pre>
48 *
49 * <pre> -R
50 *  Use arc reversal operation.
51 *  (default false)</pre>
52 *
53 * <pre> -N
54 *  Initial structure is empty (instead of Naive Bayes)</pre>
55 *
56 * <pre> -mbc
57 *  Applies a Markov Blanket correction to the network structure,
58 *  after a network structure is learned. This ensures that all
59 *  nodes in the network are part of the Markov blanket of the
60 *  classifier node.</pre>
61 *
62 * <pre> -S [LOO-CV|k-Fold-CV|Cumulative-CV]
63 *  Score type (LOO-CV,k-Fold-CV,Cumulative-CV)</pre>
64 *
65 * <pre> -Q
66 *  Use probabilistic or 0/1 scoring.
67 *  (default probabilistic scoring)</pre>
68 *
69 <!-- options-end -->
70 *
71 * @author Remco Bouckaert (rrb@xm.co.nz)
72 * @version $Revision: 1.9 $
73 */
74public class HillClimber 
75    extends GlobalScoreSearchAlgorithm {
76
77    /** for serialization */
78    static final long serialVersionUID = -3885042888195820149L;
79 
80  /**
81   * the Operation class contains info on operations performed
82   * on the current Bayesian network.
83   */
84    class Operation 
85        implements Serializable, RevisionHandler {
86     
87        /** for serialization */
88        static final long serialVersionUID = -2934970456587374967L;
89     
90        // constants indicating the type of an operation
91        final static int OPERATION_ADD = 0;
92        final static int OPERATION_DEL = 1;
93        final static int OPERATION_REVERSE = 2;
94
95        /** c'tor **/
96        public Operation() {
97        }
98       
99                /** c'tor + initializers
100                 *
101                 * @param nTail
102                 * @param nHead
103                 * @param nOperation
104                 */ 
105            public Operation(int nTail, int nHead, int nOperation) {
106                        m_nHead = nHead;
107                        m_nTail = nTail;
108                        m_nOperation = nOperation;
109                }
110                /** compare this operation with another
111                 * @param other operation to compare with
112                 * @return true if operation is the same
113                 */
114                public boolean equals(Operation other) {
115                        if (other == null) {
116                                return false;
117                        }
118                        return ((       m_nOperation == other.m_nOperation) &&
119                        (m_nHead == other.m_nHead) &&
120                        (m_nTail == other.m_nTail));
121                } // equals
122                /** number of the tail node **/
123        public int m_nTail;
124                /** number of the head node **/
125        public int m_nHead;
126                /** type of operation (ADD, DEL, REVERSE) **/
127        public int m_nOperation;
128        /** change of score due to this operation **/
129        public double m_fScore = -1E100;
130       
131        /**
132         * Returns the revision string.
133         *
134         * @return              the revision
135         */
136        public String getRevision() {
137          return RevisionUtils.extract("$Revision: 1.9 $");
138        }
139    } // class Operation
140       
141    /** use the arc reversal operator **/
142    boolean m_bUseArcReversal = false;
143
144    /**
145     * search determines the network structure/graph of the network
146     * with the Taby algorithm.
147     *
148     * @param bayesNet the network to search
149     * @param instances the instances to work with
150     * @throws Exception if something goes wrong
151     */
152    protected void search(BayesNet bayesNet, Instances instances) throws Exception {
153        m_BayesNet = bayesNet;
154                double fScore = calcScore(bayesNet);
155        // go do the search       
156                Operation oOperation = getOptimalOperation(bayesNet, instances);
157                while ((oOperation != null) && (oOperation.m_fScore > fScore)) {
158                        performOperation(bayesNet, instances, oOperation);
159                        fScore = oOperation.m_fScore;
160                        oOperation = getOptimalOperation(bayesNet, instances);
161        }       
162    } // search
163
164
165
166        /** check whether the operation is not in the forbidden.
167         * For base hill climber, there are no restrictions on operations,
168         * so we always return true.
169         * @param oOperation operation to be checked
170         * @return true if operation is not in the tabu list
171         */
172        boolean isNotTabu(Operation oOperation) {
173                return true;
174        } // isNotTabu
175
176        /**
177         * getOptimalOperation finds the optimal operation that can be performed
178         * on the Bayes network that is not in the tabu list.
179         *
180         * @param bayesNet Bayes network to apply operation on
181         * @param instances data set to learn from
182         * @return optimal operation found
183         * @throws Exception if something goes wrong
184         */
185    Operation getOptimalOperation(BayesNet bayesNet, Instances instances) throws Exception {
186        Operation oBestOperation = new Operation();
187
188                // Add???
189                oBestOperation = findBestArcToAdd(bayesNet, instances, oBestOperation);
190                // Delete???
191                oBestOperation = findBestArcToDelete(bayesNet, instances, oBestOperation);
192                // Reverse???
193                if (getUseArcReversal()) {
194                        oBestOperation = findBestArcToReverse(bayesNet, instances, oBestOperation);
195                }
196
197                // did we find something?
198                if (oBestOperation.m_fScore == -1E100) {
199                        return null;
200                }
201
202        return oBestOperation;
203    } // getOptimalOperation
204
205        /** performOperation applies an operation
206         * on the Bayes network and update the cache.
207         *
208         * @param bayesNet Bayes network to apply operation on
209         * @param instances data set to learn from
210         * @param oOperation operation to perform
211         * @throws Exception if something goes wrong
212         */
213        void performOperation(BayesNet bayesNet, Instances instances, Operation oOperation) throws Exception {
214                // perform operation
215                switch (oOperation.m_nOperation) {
216                        case Operation.OPERATION_ADD:
217                                applyArcAddition(bayesNet, oOperation.m_nHead, oOperation.m_nTail, instances);
218                                if (bayesNet.getDebug()) {
219                                        System.out.print("Add " + oOperation.m_nHead + " -> " + oOperation.m_nTail);
220                                }
221                                break;
222                        case Operation.OPERATION_DEL:
223                                applyArcDeletion(bayesNet, oOperation.m_nHead, oOperation.m_nTail, instances);
224                                if (bayesNet.getDebug()) {
225                                        System.out.print("Del " + oOperation.m_nHead + " -> " + oOperation.m_nTail);
226                                }
227                                break;
228                        case Operation.OPERATION_REVERSE:
229                                applyArcDeletion(bayesNet, oOperation.m_nHead, oOperation.m_nTail, instances);
230                                applyArcAddition(bayesNet, oOperation.m_nTail, oOperation.m_nHead, instances);
231                                if (bayesNet.getDebug()) {
232                                        System.out.print("Rev " + oOperation.m_nHead+ " -> " + oOperation.m_nTail);
233                                }
234                                break;
235                }
236        } // performOperation
237
238        /**
239         *
240         * @param bayesNet
241         * @param iHead
242         * @param iTail
243         * @param instances
244         */
245        void applyArcAddition(BayesNet bayesNet, int iHead, int iTail, Instances instances) {
246                ParentSet bestParentSet = bayesNet.getParentSet(iHead);
247                bestParentSet.addParent(iTail, instances);
248        } // applyArcAddition
249
250        /**
251         *
252         * @param bayesNet
253         * @param iHead
254         * @param iTail
255         * @param instances
256         */
257        void applyArcDeletion(BayesNet bayesNet, int iHead, int iTail, Instances instances) {
258                ParentSet bestParentSet = bayesNet.getParentSet(iHead);
259                bestParentSet.deleteParent(iTail, instances);
260        } // applyArcAddition
261
262
263        /**
264         * find best (or least bad) arc addition operation
265         *
266         * @param bayesNet Bayes network to add arc to
267         * @param instances data set
268         * @param oBestOperation
269         * @return Operation containing best arc to add, or null if no arc addition is allowed
270         * (this can happen if any arc addition introduces a cycle, or all parent sets are filled
271         * up to the maximum nr of parents).
272         * @throws Exception if something goes wrong
273         */
274        Operation findBestArcToAdd(BayesNet bayesNet, Instances instances, Operation oBestOperation) throws Exception {
275                int nNrOfAtts = instances.numAttributes();
276                // find best arc to add
277                for (int iAttributeHead = 0; iAttributeHead < nNrOfAtts; iAttributeHead++) {
278                        if (bayesNet.getParentSet(iAttributeHead).getNrOfParents() < m_nMaxNrOfParents) {
279                                for (int iAttributeTail = 0; iAttributeTail < nNrOfAtts; iAttributeTail++) {
280                                        if (addArcMakesSense(bayesNet, instances, iAttributeHead, iAttributeTail)) {
281                                                Operation oOperation = new Operation(iAttributeTail, iAttributeHead, Operation.OPERATION_ADD);
282                                                double fScore = calcScoreWithExtraParent(oOperation.m_nHead, oOperation.m_nTail);
283                                                if (fScore > oBestOperation.m_fScore) {
284                                                        if (isNotTabu(oOperation)) {
285                                                                oBestOperation = oOperation;
286                                                                oBestOperation.m_fScore = fScore;
287                                                        }
288                                                }
289                                        }
290                                }
291                        }
292                }
293                return oBestOperation;
294        } // findBestArcToAdd
295
296        /**
297         * find best (or least bad) arc deletion operation
298         *
299         * @param bayesNet Bayes network to delete arc from
300         * @param instances data set
301         * @param oBestOperation
302         * @return Operation containing best arc to delete, or null if no deletion can be made
303         * (happens when there is no arc in the network yet).
304         * @throws Exception of something goes wrong
305         */
306        Operation findBestArcToDelete(BayesNet bayesNet, Instances instances, Operation oBestOperation) throws Exception {
307                int nNrOfAtts = instances.numAttributes();
308                // find best arc to delete
309                for (int iNode = 0; iNode < nNrOfAtts; iNode++) {
310                        ParentSet parentSet = bayesNet.getParentSet(iNode);
311                        for (int iParent = 0; iParent < parentSet.getNrOfParents(); iParent++) {
312                                Operation oOperation = new Operation(parentSet.getParent(iParent), iNode, Operation.OPERATION_DEL);
313                                double fScore = calcScoreWithMissingParent(oOperation.m_nHead, oOperation.m_nTail);
314                                if (fScore > oBestOperation.m_fScore) {
315                                        if (isNotTabu(oOperation)) {
316                                                oBestOperation = oOperation;
317                                                oBestOperation.m_fScore = fScore;
318                                        }
319                                }
320                        }
321                }
322                return oBestOperation;
323        } // findBestArcToDelete
324
325        /**
326         * find best (or least bad) arc reversal operation
327         *
328         * @param bayesNet Bayes network to reverse arc in
329         * @param instances data set
330         * @param oBestOperation
331         * @return Operation containing best arc to reverse, or null if no reversal is allowed
332         * (happens if there is no arc in the network yet, or when any such reversal introduces
333         * a cycle).
334         * @throws Exception if something goes wrong
335         */
336        Operation findBestArcToReverse(BayesNet bayesNet, Instances instances, Operation oBestOperation) throws Exception {
337                int nNrOfAtts = instances.numAttributes();
338                // find best arc to reverse
339                for (int iNode = 0; iNode < nNrOfAtts; iNode++) {
340                        ParentSet parentSet = bayesNet.getParentSet(iNode);
341                        for (int iParent = 0; iParent < parentSet.getNrOfParents(); iParent++) {
342                                int iTail = parentSet.getParent(iParent);
343                                // is reversal allowed?
344                                if (reverseArcMakesSense(bayesNet, instances, iNode, iTail) && 
345                                    bayesNet.getParentSet(iTail).getNrOfParents() < m_nMaxNrOfParents) {
346                                        // go check if reversal results in the best step forward
347                                        Operation oOperation = new Operation(parentSet.getParent(iParent), iNode, Operation.OPERATION_REVERSE);
348                                        double fScore = calcScoreWithReversedParent(oOperation.m_nHead, oOperation.m_nTail);
349                                        if (fScore > oBestOperation.m_fScore) {
350                                                if (isNotTabu(oOperation)) {
351                                                        oBestOperation = oOperation;
352                                                        oBestOperation.m_fScore = fScore;
353                                                }
354                                        }
355                                }
356                        }
357                }
358                return oBestOperation;
359        } // findBestArcToReverse
360       
361
362        /**
363         * Sets the max number of parents
364         *
365         * @param nMaxNrOfParents the max number of parents
366         */
367        public void setMaxNrOfParents(int nMaxNrOfParents) {
368          m_nMaxNrOfParents = nMaxNrOfParents;
369        } 
370
371        /**
372         * Gets the max number of parents.
373         *
374         * @return the max number of parents
375         */
376        public int getMaxNrOfParents() {
377          return m_nMaxNrOfParents;
378        } 
379
380        /**
381         * Returns an enumeration describing the available options.
382         *
383         * @return an enumeration of all the available options.
384         */
385        public Enumeration listOptions() {
386                Vector newVector = new Vector(2);
387
388                newVector.addElement(new Option("\tMaximum number of parents", "P", 1, "-P <nr of parents>"));
389                newVector.addElement(new Option("\tUse arc reversal operation.\n\t(default false)", "R", 0, "-R"));
390                newVector.addElement(new Option("\tInitial structure is empty (instead of Naive Bayes)", "N", 0, "-N"));
391
392                Enumeration enu = super.listOptions();
393                while (enu.hasMoreElements()) {
394                        newVector.addElement(enu.nextElement());
395                }
396                return newVector.elements();
397        } // listOptions
398
399        /**
400         * Parses a given list of options. <p/>
401         *
402         <!-- options-start -->
403         * Valid options are: <p/>
404         *
405         * <pre> -P &lt;nr of parents&gt;
406         *  Maximum number of parents</pre>
407         *
408         * <pre> -R
409         *  Use arc reversal operation.
410         *  (default false)</pre>
411         *
412         * <pre> -N
413         *  Initial structure is empty (instead of Naive Bayes)</pre>
414         *
415         * <pre> -mbc
416         *  Applies a Markov Blanket correction to the network structure,
417         *  after a network structure is learned. This ensures that all
418         *  nodes in the network are part of the Markov blanket of the
419         *  classifier node.</pre>
420         *
421         * <pre> -S [LOO-CV|k-Fold-CV|Cumulative-CV]
422         *  Score type (LOO-CV,k-Fold-CV,Cumulative-CV)</pre>
423         *
424         * <pre> -Q
425         *  Use probabilistic or 0/1 scoring.
426         *  (default probabilistic scoring)</pre>
427         *
428         <!-- options-end -->
429         *
430         * @param options the list of options as an array of strings
431         * @throws Exception if an option is not supported
432         */
433        public void setOptions(String[] options) throws Exception {
434                setUseArcReversal(Utils.getFlag('R', options));
435
436                setInitAsNaiveBayes (!(Utils.getFlag('N', options)));
437               
438                String sMaxNrOfParents = Utils.getOption('P', options);
439                if (sMaxNrOfParents.length() != 0) {
440                  setMaxNrOfParents(Integer.parseInt(sMaxNrOfParents));
441                } else {
442                  setMaxNrOfParents(100000);
443                }
444               
445                super.setOptions(options);
446        } // setOptions
447
448        /**
449         * Gets the current settings of the search algorithm.
450         *
451         * @return an array of strings suitable for passing to setOptions
452         */
453        public String[] getOptions() {
454                String[] superOptions = super.getOptions();
455                String[] options = new String[7 + superOptions.length];
456                int current = 0;
457                if (getUseArcReversal()) {
458                  options[current++] = "-R";
459                }
460               
461                if (!getInitAsNaiveBayes()) {
462                  options[current++] = "-N";
463                } 
464
465                options[current++] = "-P";
466                options[current++] = "" + m_nMaxNrOfParents;
467
468                // insert options from parent class
469                for (int iOption = 0; iOption < superOptions.length; iOption++) {
470                        options[current++] = superOptions[iOption];
471                }
472
473                // Fill up rest with empty strings, not nulls!
474                while (current < options.length) {
475                        options[current++] = "";
476                }
477                return options;
478        } // getOptions
479
480        /**
481         * Sets whether to init as naive bayes
482         *
483         * @param bInitAsNaiveBayes whether to init as naive bayes
484         */
485        public void setInitAsNaiveBayes(boolean bInitAsNaiveBayes) {
486          m_bInitAsNaiveBayes = bInitAsNaiveBayes;
487        } 
488
489        /**
490         * Gets whether to init as naive bayes
491         *
492         * @return whether to init as naive bayes
493         */
494        public boolean getInitAsNaiveBayes() {
495          return m_bInitAsNaiveBayes;
496        } 
497
498        /** get use the arc reversal operation
499         * @return whether the arc reversal operation should be used
500         */
501        public boolean getUseArcReversal() {
502                return m_bUseArcReversal;
503        } // getUseArcReversal
504
505        /** set use the arc reversal operation
506         * @param bUseArcReversal whether the arc reversal operation should be used
507         */
508        public void setUseArcReversal(boolean bUseArcReversal) {
509                m_bUseArcReversal = bUseArcReversal;
510        } // setUseArcReversal
511
512        /**
513         * This will return a string describing the search algorithm.
514         * @return The string.
515         */
516        public String globalInfo() {
517          return "This Bayes Network learning algorithm uses a hill climbing algorithm " +
518          "adding, deleting and reversing arcs. The search is not restricted by an order " +
519          "on the variables (unlike K2). The difference with B and B2 is that this hill " +       
520          "climber also considers arrows part of the naive Bayes structure for deletion.";
521        } // globalInfo
522
523        /**
524         * @return a string to describe the Use Arc Reversal option.
525         */
526        public String useArcReversalTipText() {
527          return "When set to true, the arc reversal operation is used in the search.";
528        } // useArcReversalTipText
529
530        /**
531         * Returns the revision string.
532         *
533         * @return              the revision
534         */
535        public String getRevision() {
536          return RevisionUtils.extract("$Revision: 1.9 $");
537        }
538} // HillClimber
Note: See TracBrowser for help on using the repository browser.