source: tags/MetisMQIDemo/src/main/java/weka/classifiers/bayes/net/search/global/GlobalScoreSearchAlgorithm.java

Last change on this file was 29, checked in by gnappo, 15 years ago

Taggata versione per la demo e aggiunto branch.

File size: 16.7 KB
Line 
1/*
2 * This program is free software; you can redistribute it and/or modify
3 * it under the terms of the GNU General Public License as published by
4 * the Free Software Foundation; either version 2 of the License, or
5 * (at your option) any later version.
6 *
7 * This program is distributed in the hope that it will be useful,
8 * but WITHOUT ANY WARRANTY; without even the implied warranty of
9 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
10 * GNU General Public License for more details.
11 *
12 * You should have received a copy of the GNU General Public License
13 * along with this program; if not, write to the Free Software
14 * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
15 */
16
17/*
18 * GlobalScoreSearchAlgorithm.java
19 * Copyright (C) 2004 University of Waikato, Hamilton, New Zealand
20 *
21 */
22
23package weka.classifiers.bayes.net.search.global;
24
25import weka.classifiers.bayes.BayesNet;
26import weka.classifiers.bayes.net.ParentSet;
27import weka.classifiers.bayes.net.search.SearchAlgorithm;
28import weka.core.Instance;
29import weka.core.Instances;
30import weka.core.Option;
31import weka.core.RevisionUtils;
32import weka.core.SelectedTag;
33import weka.core.Tag;
34import weka.core.Utils;
35
36import java.util.Enumeration;
37import java.util.Vector;
38
39/**
40 <!-- globalinfo-start -->
41 * This Bayes Network learning algorithm uses cross validation to estimate classification accuracy.
42 * <p/>
43 <!-- globalinfo-end -->
44 *
45 <!-- options-start -->
46 * Valid options are: <p/>
47 *
48 * <pre> -mbc
49 *  Applies a Markov Blanket correction to the network structure,
50 *  after a network structure is learned. This ensures that all
51 *  nodes in the network are part of the Markov blanket of the
52 *  classifier node.</pre>
53 *
54 * <pre> -S [LOO-CV|k-Fold-CV|Cumulative-CV]
55 *  Score type (LOO-CV,k-Fold-CV,Cumulative-CV)</pre>
56 *
57 * <pre> -Q
58 *  Use probabilistic or 0/1 scoring.
59 *  (default probabilistic scoring)</pre>
60 *
61 <!-- options-end -->
62 *
63 * @author Remco Bouckaert
64 * @version $Revision: 1.10 $
65 */
66public class GlobalScoreSearchAlgorithm 
67        extends SearchAlgorithm {
68
69        /** for serialization */
70        static final long serialVersionUID = 7341389867906199781L;
71       
72        /** points to Bayes network for which a structure is searched for **/
73        BayesNet m_BayesNet;
74       
75        /** toggle between scoring using accuracy = 0-1 loss (when false) or class probabilities (when true) **/
76        boolean m_bUseProb = true;
77       
78        /** number of folds for k-fold cross validation **/
79        int m_nNrOfFolds = 10;
80
81        /** constant for score type: LOO-CV */
82        final static int LOOCV = 0;
83        /** constant for score type: k-fold-CV */
84        final static int KFOLDCV = 1;
85        /** constant for score type: Cumulative-CV */
86        final static int CUMCV = 2;
87
88        /** the score types **/
89        public static final Tag[] TAGS_CV_TYPE =
90                {
91                        new Tag(LOOCV, "LOO-CV"),
92                        new Tag(KFOLDCV, "k-Fold-CV"),
93                        new Tag(CUMCV, "Cumulative-CV")
94                };
95        /**
96         * Holds the cross validation strategy used to measure quality of network
97         */
98        int m_nCVType = LOOCV;
99
100        /**
101         * performCV returns the accuracy calculated using cross validation. 
102         * The dataset used is m_Instances associated with the Bayes Network.
103         *
104         * @param bayesNet : Bayes Network containing structure to evaluate
105         * @return accuracy (in interval 0..1) measured using cv.
106         * @throws Exception whn m_nCVType is invalided + exceptions passed on by updateClassifier
107         */
108        public double calcScore(BayesNet bayesNet) throws Exception {
109                switch (m_nCVType) {
110                        case LOOCV: 
111                                return leaveOneOutCV(bayesNet);
112                        case CUMCV: 
113                                return cumulativeCV(bayesNet);
114                        case KFOLDCV: 
115                                return kFoldCV(bayesNet, m_nNrOfFolds);
116                        default:
117                                throw new Exception("Unrecognized cross validation type encountered: " + m_nCVType);
118                }
119        } // calcScore
120
121        /**
122         * Calc Node Score With Added Parent
123         *
124         * @param nNode node for which the score is calculate
125         * @param nCandidateParent candidate parent to add to the existing parent set
126         * @return log score
127         * @throws Exception if something goes wrong
128         */
129        public double calcScoreWithExtraParent(int nNode, int nCandidateParent) throws Exception {
130                ParentSet oParentSet = m_BayesNet.getParentSet(nNode);
131                Instances instances = m_BayesNet.m_Instances;
132
133                // sanity check: nCandidateParent should not be in parent set already
134                for (int iParent = 0; iParent < oParentSet.getNrOfParents(); iParent++) {
135                        if (oParentSet.getParent(iParent) == nCandidateParent) {
136                                return -1e100;
137                        }
138                }
139
140                // set up candidate parent
141                oParentSet.addParent(nCandidateParent, instances);
142
143                // calculate the score
144                double fAccuracy = calcScore(m_BayesNet);
145
146                // delete temporarily added parent
147                oParentSet.deleteLastParent(instances);
148
149                return fAccuracy;
150        } // calcScoreWithExtraParent
151
152
153        /**
154         * Calc Node Score With Parent Deleted
155         *
156         * @param nNode node for which the score is calculate
157         * @param nCandidateParent candidate parent to delete from the existing parent set
158         * @return log score
159         * @throws Exception if something goes wrong
160         */
161        public double calcScoreWithMissingParent(int nNode, int nCandidateParent) throws Exception {
162                ParentSet oParentSet = m_BayesNet.getParentSet(nNode);
163                Instances instances = m_BayesNet.m_Instances;
164
165                // sanity check: nCandidateParent should be in parent set already
166                if (!oParentSet.contains( nCandidateParent)) {
167                                return -1e100;
168                }
169
170                // set up candidate parent
171                int iParent = oParentSet.deleteParent(nCandidateParent, instances);
172
173                // calculate the score
174                double fAccuracy = calcScore(m_BayesNet);
175
176                // reinsert temporarily deleted parent
177                oParentSet.addParent(nCandidateParent, iParent, instances);
178
179                return fAccuracy;
180        } // calcScoreWithMissingParent
181
182        /**
183         * Calc Node Score With Arrow reversed
184         *
185         * @param nNode node for which the score is calculate
186         * @param nCandidateParent candidate parent to delete from the existing parent set
187         * @return log score
188         * @throws Exception if something goes wrong
189         */
190        public double calcScoreWithReversedParent(int nNode, int nCandidateParent) throws Exception {
191                ParentSet oParentSet = m_BayesNet.getParentSet(nNode);
192                ParentSet oParentSet2 = m_BayesNet.getParentSet(nCandidateParent);
193                Instances instances = m_BayesNet.m_Instances;
194
195                // sanity check: nCandidateParent should be in parent set already
196                if (!oParentSet.contains( nCandidateParent)) {
197                                return -1e100;
198                }
199
200                // set up candidate parent
201                int iParent = oParentSet.deleteParent(nCandidateParent, instances);
202                oParentSet2.addParent(nNode, instances);
203
204                // calculate the score
205                double fAccuracy = calcScore(m_BayesNet);
206
207                // restate temporarily reversed arrow
208                oParentSet2.deleteLastParent(instances);
209                oParentSet.addParent(nCandidateParent, iParent, instances);
210
211                return fAccuracy;
212        } // calcScoreWithReversedParent
213
214        /**
215         * LeaveOneOutCV returns the accuracy calculated using Leave One Out
216         * cross validation. The dataset used is m_Instances associated with
217         * the Bayes Network.
218         * @param bayesNet : Bayes Network containing structure to evaluate
219         * @return accuracy (in interval 0..1) measured using leave one out cv.
220         * @throws Exception passed on by updateClassifier
221         */
222        public double leaveOneOutCV(BayesNet bayesNet) throws Exception {
223                m_BayesNet = bayesNet;
224                double fAccuracy = 0.0;
225                double fWeight = 0.0;
226                Instances instances = bayesNet.m_Instances;
227                bayesNet.estimateCPTs();
228                for (int iInstance = 0; iInstance < instances.numInstances(); iInstance++) {
229                        Instance instance = instances.instance(iInstance);
230                        instance.setWeight(-instance.weight());
231                        bayesNet.updateClassifier(instance);
232                        fAccuracy += accuracyIncrease(instance);
233                        fWeight += instance.weight();
234                        instance.setWeight(-instance.weight());
235                        bayesNet.updateClassifier(instance);
236                }
237                return fAccuracy / fWeight;
238        } // LeaveOneOutCV
239
240        /**
241         * CumulativeCV returns the accuracy calculated using cumulative
242         * cross validation. The idea is to run through the data set and
243         * try to classify each of the instances based on the previously
244         * seen data.
245         * The data set used is m_Instances associated with the Bayes Network.
246         * @param bayesNet : Bayes Network containing structure to evaluate
247         * @return accuracy (in interval 0..1) measured using leave one out cv.
248         * @throws Exception passed on by updateClassifier
249         */
250        public double cumulativeCV(BayesNet bayesNet) throws Exception {
251                m_BayesNet = bayesNet;
252                double fAccuracy = 0.0;
253                double fWeight = 0.0;
254                Instances instances = bayesNet.m_Instances;
255                bayesNet.initCPTs();
256                for (int iInstance = 0; iInstance < instances.numInstances(); iInstance++) {
257                        Instance instance = instances.instance(iInstance);
258                        fAccuracy += accuracyIncrease(instance);
259                        bayesNet.updateClassifier(instance);
260                        fWeight += instance.weight();
261                }
262                return fAccuracy / fWeight;
263        } // LeaveOneOutCV
264       
265        /**
266         * kFoldCV uses k-fold cross validation to measure the accuracy of a Bayes
267         * network classifier.
268         * @param bayesNet : Bayes Network containing structure to evaluate
269         * @param nNrOfFolds : the number of folds k to perform k-fold cv
270         * @return accuracy (in interval 0..1) measured using leave one out cv.
271         * @throws Exception passed on by updateClassifier
272         */
273        public double kFoldCV(BayesNet bayesNet, int nNrOfFolds) throws Exception {
274                m_BayesNet = bayesNet;
275                double fAccuracy = 0.0;
276                double fWeight = 0.0;
277                Instances instances = bayesNet.m_Instances;
278                // estimate CPTs based on complete data set
279                bayesNet.estimateCPTs();
280                int nFoldStart = 0;
281                int nFoldEnd = instances.numInstances() / nNrOfFolds;
282                int iFold = 1;
283                while (nFoldStart < instances.numInstances()) {
284                        // remove influence of fold iFold from the probability distribution
285                        for (int iInstance = nFoldStart; iInstance < nFoldEnd; iInstance++) {
286                                Instance instance = instances.instance(iInstance);
287                                instance.setWeight(-instance.weight());
288                                bayesNet.updateClassifier(instance);
289                        }
290                       
291                        // measure accuracy on fold iFold
292                        for (int iInstance = nFoldStart; iInstance < nFoldEnd; iInstance++) {
293                                Instance instance = instances.instance(iInstance);
294                                instance.setWeight(-instance.weight());
295                                fAccuracy += accuracyIncrease(instance);
296                                instance.setWeight(-instance.weight());
297                                fWeight += instance.weight();
298                        }
299
300                        // restore influence of fold iFold from the probability distribution
301                        for (int iInstance = nFoldStart; iInstance < nFoldEnd; iInstance++) {
302                                Instance instance = instances.instance(iInstance);
303                                instance.setWeight(-instance.weight());
304                                bayesNet.updateClassifier(instance);
305                        }
306
307                        // go to next fold
308                        nFoldStart = nFoldEnd;
309                        iFold++;
310                        nFoldEnd = iFold * instances.numInstances() / nNrOfFolds;
311                }
312                return fAccuracy / fWeight;
313        } // kFoldCV
314       
315        /** accuracyIncrease determines how much the accuracy estimate should
316         * be increased due to the contribution of a single given instance.
317         *
318         * @param instance : instance for which to calculate the accuracy increase.
319         * @return increase in accuracy due to given instance.
320         * @throws Exception passed on by distributionForInstance and classifyInstance
321         */
322        double accuracyIncrease(Instance instance) throws Exception {
323                if (m_bUseProb) {
324                        double [] fProb = m_BayesNet.distributionForInstance(instance);
325                        return fProb[(int) instance.classValue()] * instance.weight();
326                } else {
327                        if (m_BayesNet.classifyInstance(instance) == instance.classValue()) {
328                                return instance.weight();
329                        }
330                }
331                return 0;
332        } // accuracyIncrease
333
334        /**
335         * @return use probabilities or not in accuracy estimate
336         */
337        public boolean getUseProb() {
338                return m_bUseProb;
339        } // getUseProb
340
341        /**
342         * @param useProb : use probabilities or not in accuracy estimate
343         */
344        public void setUseProb(boolean useProb) {
345                m_bUseProb = useProb;
346        } // setUseProb
347       
348        /**
349         * set cross validation strategy to be used in searching for networks.
350         * @param newCVType : cross validation strategy
351         */
352        public void setCVType(SelectedTag newCVType) {
353                if (newCVType.getTags() == TAGS_CV_TYPE) {
354                        m_nCVType = newCVType.getSelectedTag().getID();
355                }
356        } // setCVType
357
358        /**
359         * get cross validation strategy to be used in searching for networks.
360         * @return cross validation strategy
361         */
362        public SelectedTag getCVType() {
363                return new SelectedTag(m_nCVType, TAGS_CV_TYPE);
364        } // getCVType
365
366        /**
367         *
368         * @param bMarkovBlanketClassifier
369         */
370        public void setMarkovBlanketClassifier(boolean bMarkovBlanketClassifier) {
371          super.setMarkovBlanketClassifier(bMarkovBlanketClassifier);
372        }
373
374        /**
375         *
376         * @return
377         */
378        public boolean getMarkovBlanketClassifier() {
379          return super.getMarkovBlanketClassifier();
380        }
381
382        /**
383         * Returns an enumeration describing the available options
384         *
385         * @return an enumeration of all the available options
386         */
387        public Enumeration listOptions() {
388                Vector newVector = new Vector();
389
390                newVector.addElement(new Option(
391                    "\tApplies a Markov Blanket correction to the network structure, \n"
392                    + "\tafter a network structure is learned. This ensures that all \n"
393                    + "\tnodes in the network are part of the Markov blanket of the \n"
394                    + "\tclassifier node.",
395                    "mbc", 0, "-mbc"));
396     
397                newVector.addElement(
398                        new Option(
399                                "\tScore type (LOO-CV,k-Fold-CV,Cumulative-CV)",
400                                "S",
401                                1,
402                                "-S [LOO-CV|k-Fold-CV|Cumulative-CV]"));
403
404                newVector.addElement(new Option("\tUse probabilistic or 0/1 scoring.\n\t(default probabilistic scoring)", "Q", 0, "-Q"));
405
406                Enumeration enu = super.listOptions();
407                while (enu.hasMoreElements()) {
408                        newVector.addElement(enu.nextElement());
409                }
410                return newVector.elements();
411        } // listOptions
412
413        /**
414         * Parses a given list of options. <p/>
415         *
416         <!-- options-start -->
417         * Valid options are: <p/>
418         *
419         * <pre> -mbc
420         *  Applies a Markov Blanket correction to the network structure,
421         *  after a network structure is learned. This ensures that all
422         *  nodes in the network are part of the Markov blanket of the
423         *  classifier node.</pre>
424         *
425         * <pre> -S [LOO-CV|k-Fold-CV|Cumulative-CV]
426         *  Score type (LOO-CV,k-Fold-CV,Cumulative-CV)</pre>
427         *
428         * <pre> -Q
429         *  Use probabilistic or 0/1 scoring.
430         *  (default probabilistic scoring)</pre>
431         *
432         <!-- options-end -->
433         *
434         * @param options the list of options as an array of strings
435         * @throws Exception if an option is not supported
436         */
437        public void setOptions(String[] options) throws Exception {
438
439                setMarkovBlanketClassifier(Utils.getFlag("mbc", options));
440
441                String sScore = Utils.getOption('S', options);
442
443                if (sScore.compareTo("LOO-CV") == 0) {
444                        setCVType(new SelectedTag(LOOCV, TAGS_CV_TYPE));
445                }
446                if (sScore.compareTo("k-Fold-CV") == 0) {
447                        setCVType(new SelectedTag(KFOLDCV, TAGS_CV_TYPE));
448                }
449                if (sScore.compareTo("Cumulative-CV") == 0) {
450                        setCVType(new SelectedTag(CUMCV, TAGS_CV_TYPE));
451                }
452                setUseProb(!Utils.getFlag('Q', options));               
453                super.setOptions(options);
454        } // setOptions
455
456        /**
457         * Gets the current settings of the search algorithm.
458         *
459         * @return an array of strings suitable for passing to setOptions
460         */
461        public String[] getOptions() {
462                String[] superOptions = super.getOptions();
463                String[] options = new String[4 + superOptions.length];
464                int current = 0;
465
466                if (getMarkovBlanketClassifier())
467                  options[current++] = "-mbc";
468
469                options[current++] = "-S";
470
471                switch (m_nCVType) {
472                        case (LOOCV) :
473                                options[current++] = "LOO-CV";
474                                break;
475                        case (KFOLDCV) :
476                                options[current++] = "k-Fold-CV";
477                                break;
478                        case (CUMCV) :
479                                options[current++] = "Cumulative-CV";
480                                break;
481                }
482               
483                if (!getUseProb()) {
484                  options[current++] = "-Q";
485                }
486
487                // insert options from parent class
488                for (int iOption = 0; iOption < superOptions.length; iOption++) {
489                        options[current++] = superOptions[iOption];
490                }
491
492                // Fill up rest with empty strings, not nulls!
493                while (current < options.length) {
494                        options[current++] = "";
495                }
496                return options;
497        } // getOptions
498
499        /**
500         * @return a string to describe the CVType option.
501         */
502        public String CVTypeTipText() {
503          return "Select cross validation strategy to be used in searching for networks." +
504          "LOO-CV = Leave one out cross validation\n" +
505          "k-Fold-CV = k fold cross validation\n" +
506          "Cumulative-CV = cumulative cross validation."
507          ;
508        } // CVTypeTipText
509
510        /**
511         * @return a string to describe the UseProb option.
512         */
513        public String useProbTipText() {
514          return "If set to true, the probability of the class if returned in the estimate of the "+
515          "accuracy. If set to false, the accuracy estimate is only increased if the classifier returns " +
516          "exactly the correct class.";
517        } // useProbTipText
518
519        /**
520         * This will return a string describing the search algorithm.
521         * @return The string.
522         */
523        public String globalInfo() {
524          return "This Bayes Network learning algorithm uses cross validation to estimate " +
525          "classification accuracy.";
526        } // globalInfo
527       
528        /**
529         * @return a string to describe the MarkovBlanketClassifier option.
530         */
531        public String markovBlanketClassifierTipText() {
532          return super.markovBlanketClassifierTipText();
533        }
534
535        /**
536         * Returns the revision string.
537         *
538         * @return              the revision
539         */
540        public String getRevision() {
541          return RevisionUtils.extract("$Revision: 1.10 $");
542        }
543}
Note: See TracBrowser for help on using the repository browser.