source: src/main/java/weka/classifiers/BVDecomposeSegCVSub.java @ 9

Last change on this file since 9 was 4, checked in by gnappo, 14 years ago

Import di weka.

File size: 37.2 KB
RevLine 
[4]1/*
2 *    This program is free software; you can redistribute it and/or modify
3 *    it under the terms of the GNU General Public License as published by
4 *    the Free Software Foundation; either version 2 of the License, or
5 *    (at your option) any later version.
6 *
7 *    This program is distributed in the hope that it will be useful,
8 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
9 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
10 *    GNU General Public License for more details.
11 *
12 *    You should have received a copy of the GNU General Public License
13 *    along with this program; if not, write to the Free Software
14 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
15 */
16
17/*
18 *    BVDecomposeSegCVSub.java
19 *    Copyright (C) 2003 Paul Conilione
20 *
21 *    Based on the class: BVDecompose.java by Len Trigg (1999)
22 */
23
24
25/*
26 *    DEDICATION
27 *
28 *    Paul Conilione would like to express his deep gratitude and appreciation
29 *    to his Chinese Buddhist Taoist Master Sifu Chow Yuk Nen for the abilities
30 *    and insight that he has been taught, which have allowed him to program in
31 *    a clear and efficient manner.
32 *
33 *    Master Sifu Chow Yuk Nen's Teachings are unique and precious. They are
34 *    applicable to any field of human endeavour. Through his unique and powerful
35 *    ability to skilfully apply Chinese Buddhist Teachings, people have achieved
36 *    success in; Computing, chemical engineering, business, accounting, philosophy
37 *    and more.
38 *
39 */
40
41package weka.classifiers;
42
43import weka.core.Attribute;
44import weka.core.Instance;
45import weka.core.Instances;
46import weka.core.Option;
47import weka.core.OptionHandler;
48import weka.core.RevisionHandler;
49import weka.core.RevisionUtils;
50import weka.core.TechnicalInformation;
51import weka.core.TechnicalInformationHandler;
52import weka.core.Utils;
53import weka.core.TechnicalInformation.Field;
54import weka.core.TechnicalInformation.Type;
55
56import java.io.BufferedReader;
57import java.io.FileReader;
58import java.io.Reader;
59import java.util.Enumeration;
60import java.util.Random;
61import java.util.Vector;
62
63/**
64 <!-- globalinfo-start -->
65 * This class performs Bias-Variance decomposition on any classifier using the sub-sampled cross-validation procedure as specified in (1).<br/>
66 * The Kohavi and Wolpert definition of bias and variance is specified in (2).<br/>
67 * The Webb definition of bias and variance is specified in (3).<br/>
68 * <br/>
69 * Geoffrey I. Webb, Paul Conilione (2002). Estimating bias and variance from data. School of Computer Science and Software Engineering, Victoria, Australia.<br/>
70 * <br/>
71 * Ron Kohavi, David H. Wolpert: Bias Plus Variance Decomposition for Zero-One Loss Functions. In: Machine Learning: Proceedings of the Thirteenth International Conference, 275-283, 1996.<br/>
72 * <br/>
73 * Geoffrey I. Webb (2000). MultiBoosting: A Technique for Combining Boosting and Wagging. Machine Learning. 40(2):159-196.
74 * <p/>
75 <!-- globalinfo-end -->
76 *
77 <!-- technical-bibtex-start -->
78 * BibTeX:
79 * <pre>
80 * &#64;misc{Webb2002,
81 *    address = {School of Computer Science and Software Engineering, Victoria, Australia},
82 *    author = {Geoffrey I. Webb and Paul Conilione},
83 *    institution = {Monash University},
84 *    title = {Estimating bias and variance from data},
85 *    year = {2002},
86 *    PDF = {http://www.csse.monash.edu.au/\~webb/Files/WebbConilione04.pdf}
87 * }
88 *
89 * &#64;inproceedings{Kohavi1996,
90 *    author = {Ron Kohavi and David H. Wolpert},
91 *    booktitle = {Machine Learning: Proceedings of the Thirteenth International Conference},
92 *    editor = {Lorenza Saitta},
93 *    pages = {275-283},
94 *    publisher = {Morgan Kaufmann},
95 *    title = {Bias Plus Variance Decomposition for Zero-One Loss Functions},
96 *    year = {1996},
97 *    PS = {http://robotics.stanford.edu/\~ronnyk/biasVar.ps}
98 * }
99 *
100 * &#64;article{Webb2000,
101 *    author = {Geoffrey I. Webb},
102 *    journal = {Machine Learning},
103 *    number = {2},
104 *    pages = {159-196},
105 *    title = {MultiBoosting: A Technique for Combining Boosting and Wagging},
106 *    volume = {40},
107 *    year = {2000}
108 * }
109 * </pre>
110 * <p/>
111 <!-- technical-bibtex-end -->
112 *
113 <!-- options-start -->
114 * Valid options are: <p/>
115 *
116 * <pre> -c &lt;class index&gt;
117 *  The index of the class attribute.
118 *  (default last)</pre>
119 *
120 * <pre> -D
121 *  Turn on debugging output.</pre>
122 *
123 * <pre> -l &lt;num&gt;
124 *  The number of times each instance is classified.
125 *  (default 10)</pre>
126 *
127 * <pre> -p &lt;proportion of objects in common&gt;
128 *  The average proportion of instances common between any two training sets</pre>
129 *
130 * <pre> -s &lt;seed&gt;
131 *  The random number seed used.</pre>
132 *
133 * <pre> -t &lt;name of arff file&gt;
134 *  The name of the arff file used for the decomposition.</pre>
135 *
136 * <pre> -T &lt;number of instances in training set&gt;
137 *  The number of instances in the training set.</pre>
138 *
139 * <pre> -W &lt;classifier class name&gt;
140 *  Full class name of the learner used in the decomposition.
141 *  eg: weka.classifiers.bayes.NaiveBayes</pre>
142 *
143 * <pre>
144 * Options specific to learner weka.classifiers.rules.ZeroR:
145 * </pre>
146 *
147 * <pre> -D
148 *  If set, classifier is run in debug mode and
149 *  may output additional info to the console</pre>
150 *
151 <!-- options-end -->
152 *
153 * Options after -- are passed to the designated sub-learner. <p>
154 *
155 * @author Paul Conilione (paulc4321@yahoo.com.au)
156 * @version $Revision: 6041 $
157 */
158public class BVDecomposeSegCVSub
159    implements OptionHandler, TechnicalInformationHandler, RevisionHandler {
160
161    /** Debugging mode, gives extra output if true. */
162    protected boolean m_Debug;
163
164    /** An instantiated base classifier used for getting and testing options. */
165    protected Classifier m_Classifier = new weka.classifiers.rules.ZeroR();
166
167    /** The options to be passed to the base classifier. */
168    protected String [] m_ClassifierOptions;
169
170    /** The number of times an instance is classified*/
171    protected int m_ClassifyIterations;
172
173    /** The name of the data file used for the decomposition */
174    protected String m_DataFileName;
175
176    /** The index of the class attribute */
177    protected int m_ClassIndex = -1;
178
179    /** The random number seed */
180    protected int m_Seed = 1;
181
182    /** The calculated Kohavi & Wolpert bias (squared) */
183    protected double m_KWBias;
184
185    /** The calculated Kohavi & Wolpert variance */
186    protected double m_KWVariance;
187
188    /** The calculated Kohavi & Wolpert sigma */
189    protected double m_KWSigma;
190
191    /** The calculated Webb bias */
192    protected double m_WBias;
193
194    /** The calculated Webb variance */
195    protected double m_WVariance;
196
197    /** The error rate */
198    protected double m_Error;
199
200    /** The training set size */
201    protected int m_TrainSize;
202
203    /** Proportion of instances common between any two training sets. */
204    protected double m_P;
205
206    /**
207     * Returns a string describing this object
208     * @return a description of the classifier suitable for
209     * displaying in the explorer/experimenter gui
210     */
211    public String globalInfo() {
212      return
213          "This class performs Bias-Variance decomposion on any classifier using the "
214        + "sub-sampled cross-validation procedure as specified in (1).\n"
215        + "The Kohavi and Wolpert definition of bias and variance is specified in (2).\n"
216        + "The Webb definition of bias and variance is specified in (3).\n\n"
217        + getTechnicalInformation().toString();
218    }
219
220    /**
221     * Returns an instance of a TechnicalInformation object, containing
222     * detailed information about the technical background of this class,
223     * e.g., paper reference or book this class is based on.
224     *
225     * @return the technical information about this class
226     */
227    public TechnicalInformation getTechnicalInformation() {
228      TechnicalInformation      result;
229      TechnicalInformation      additional;
230
231      result = new TechnicalInformation(Type.MISC);
232      result.setValue(Field.AUTHOR, "Geoffrey I. Webb and Paul Conilione");
233      result.setValue(Field.YEAR, "2002");
234      result.setValue(Field.TITLE, "Estimating bias and variance from data");
235      result.setValue(Field.INSTITUTION, "Monash University");
236      result.setValue(Field.ADDRESS, "School of Computer Science and Software Engineering, Victoria, Australia");
237      result.setValue(Field.PDF, "http://www.csse.monash.edu.au/~webb/Files/WebbConilione04.pdf");
238
239      additional = result.add(Type.INPROCEEDINGS);
240      additional.setValue(Field.AUTHOR, "Ron Kohavi and David H. Wolpert");
241      additional.setValue(Field.YEAR, "1996");
242      additional.setValue(Field.TITLE, "Bias Plus Variance Decomposition for Zero-One Loss Functions");
243      additional.setValue(Field.BOOKTITLE, "Machine Learning: Proceedings of the Thirteenth International Conference");
244      additional.setValue(Field.PUBLISHER, "Morgan Kaufmann");
245      additional.setValue(Field.EDITOR, "Lorenza Saitta");
246      additional.setValue(Field.PAGES, "275-283");
247      additional.setValue(Field.PS, "http://robotics.stanford.edu/~ronnyk/biasVar.ps");
248
249      additional = result.add(Type.ARTICLE);
250      additional.setValue(Field.AUTHOR, "Geoffrey I. Webb");
251      additional.setValue(Field.YEAR, "2000");
252      additional.setValue(Field.TITLE, "MultiBoosting: A Technique for Combining Boosting and Wagging");
253      additional.setValue(Field.JOURNAL, "Machine Learning");
254      additional.setValue(Field.VOLUME, "40");
255      additional.setValue(Field.NUMBER, "2");
256      additional.setValue(Field.PAGES, "159-196");
257
258      return result;
259    }
260
261    /**
262     * Returns an enumeration describing the available options.
263     *
264     * @return an enumeration of all the available options.
265     */
266    public Enumeration listOptions() {
267
268        Vector newVector = new Vector(8);
269
270        newVector.addElement(new Option(
271        "\tThe index of the class attribute.\n"+
272        "\t(default last)",
273        "c", 1, "-c <class index>"));
274        newVector.addElement(new Option(
275        "\tTurn on debugging output.",
276        "D", 0, "-D"));
277        newVector.addElement(new Option(
278        "\tThe number of times each instance is classified.\n"
279        +"\t(default 10)",
280        "l", 1, "-l <num>"));
281        newVector.addElement(new Option(
282        "\tThe average proportion of instances common between any two training sets",
283        "p", 1, "-p <proportion of objects in common>"));
284        newVector.addElement(new Option(
285        "\tThe random number seed used.",
286        "s", 1, "-s <seed>"));
287        newVector.addElement(new Option(
288        "\tThe name of the arff file used for the decomposition.",
289        "t", 1, "-t <name of arff file>"));
290        newVector.addElement(new Option(
291        "\tThe number of instances in the training set.",
292        "T", 1, "-T <number of instances in training set>"));
293        newVector.addElement(new Option(
294        "\tFull class name of the learner used in the decomposition.\n"
295        +"\teg: weka.classifiers.bayes.NaiveBayes",
296        "W", 1, "-W <classifier class name>"));
297
298        if ((m_Classifier != null) &&
299        (m_Classifier instanceof OptionHandler)) {
300            newVector.addElement(new Option(
301            "",
302            "", 0, "\nOptions specific to learner "
303            + m_Classifier.getClass().getName()
304            + ":"));
305            Enumeration enu = ((OptionHandler)m_Classifier).listOptions();
306            while (enu.hasMoreElements()) {
307                newVector.addElement(enu.nextElement());
308            }
309        }
310        return newVector.elements();
311    }
312
313
314    /**
315     * Sets the OptionHandler's options using the given list. All options
316     * will be set (or reset) during this call (i.e. incremental setting
317     * of options is not possible). <p/>
318     *
319     <!-- options-start -->
320     * Valid options are: <p/>
321     *
322     * <pre> -c &lt;class index&gt;
323     *  The index of the class attribute.
324     *  (default last)</pre>
325     *
326     * <pre> -D
327     *  Turn on debugging output.</pre>
328     *
329     * <pre> -l &lt;num&gt;
330     *  The number of times each instance is classified.
331     *  (default 10)</pre>
332     *
333     * <pre> -p &lt;proportion of objects in common&gt;
334     *  The average proportion of instances common between any two training sets</pre>
335     *
336     * <pre> -s &lt;seed&gt;
337     *  The random number seed used.</pre>
338     *
339     * <pre> -t &lt;name of arff file&gt;
340     *  The name of the arff file used for the decomposition.</pre>
341     *
342     * <pre> -T &lt;number of instances in training set&gt;
343     *  The number of instances in the training set.</pre>
344     *
345     * <pre> -W &lt;classifier class name&gt;
346     *  Full class name of the learner used in the decomposition.
347     *  eg: weka.classifiers.bayes.NaiveBayes</pre>
348     *
349     * <pre>
350     * Options specific to learner weka.classifiers.rules.ZeroR:
351     * </pre>
352     *
353     * <pre> -D
354     *  If set, classifier is run in debug mode and
355     *  may output additional info to the console</pre>
356     *
357     <!-- options-end -->
358     *
359     * @param options the list of options as an array of strings
360     * @throws Exception if an option is not supported
361     */
362    public void setOptions(String[] options) throws Exception {
363        setDebug(Utils.getFlag('D', options));
364
365        String classIndex = Utils.getOption('c', options);
366        if (classIndex.length() != 0) {
367            if (classIndex.toLowerCase().equals("last")) {
368                setClassIndex(0);
369            } else if (classIndex.toLowerCase().equals("first")) {
370                setClassIndex(1);
371            } else {
372                setClassIndex(Integer.parseInt(classIndex));
373            }
374        } else {
375            setClassIndex(0);
376        }
377
378        String classifyIterations = Utils.getOption('l', options);
379        if (classifyIterations.length() != 0) {
380            setClassifyIterations(Integer.parseInt(classifyIterations));
381        } else {
382            setClassifyIterations(10);
383        }
384
385        String prob = Utils.getOption('p', options);
386        if (prob.length() != 0) {
387            setP( Double.parseDouble(prob));
388        } else {
389            setP(-1);
390        }
391        //throw new Exception("A proportion must be specified" + " with a -p option.");
392
393        String seedString = Utils.getOption('s', options);
394        if (seedString.length() != 0) {
395            setSeed(Integer.parseInt(seedString));
396        } else {
397            setSeed(1);
398        }
399
400        String dataFile = Utils.getOption('t', options);
401        if (dataFile.length() != 0) {
402            setDataFileName(dataFile);
403        } else {
404            throw new Exception("An arff file must be specified"
405            + " with the -t option.");
406        }
407
408        String trainSize = Utils.getOption('T', options);
409        if (trainSize.length() != 0) {
410            setTrainSize(Integer.parseInt(trainSize));
411        } else {
412            setTrainSize(-1);
413        }
414        //throw new Exception("A training set size must be specified" + " with a -T option.");
415
416        String classifierName = Utils.getOption('W', options);
417        if (classifierName.length() != 0) {
418            setClassifier(AbstractClassifier.forName(classifierName, Utils.partitionOptions(options)));
419        } else {
420            throw new Exception("A learner must be specified with the -W option.");
421        }
422    }
423
424    /**
425     * Gets the current settings of the CheckClassifier.
426     *
427     * @return an array of strings suitable for passing to setOptions
428     */
429    public String [] getOptions() {
430
431        String [] classifierOptions = new String [0];
432        if ((m_Classifier != null) &&
433        (m_Classifier instanceof OptionHandler)) {
434            classifierOptions = ((OptionHandler)m_Classifier).getOptions();
435        }
436        String [] options = new String [classifierOptions.length + 14];
437        int current = 0;
438        if (getDebug()) {
439            options[current++] = "-D";
440        }
441        options[current++] = "-c"; options[current++] = "" + getClassIndex();
442        options[current++] = "-l"; options[current++] = "" + getClassifyIterations();
443        options[current++] = "-p"; options[current++] = "" + getP();
444        options[current++] = "-s"; options[current++] = "" + getSeed();
445        if (getDataFileName() != null) {
446            options[current++] = "-t"; options[current++] = "" + getDataFileName();
447        }
448        options[current++] = "-T"; options[current++] = "" + getTrainSize();
449        if (getClassifier() != null) {
450            options[current++] = "-W";
451            options[current++] = getClassifier().getClass().getName();
452        }
453
454        options[current++] = "--";
455        System.arraycopy(classifierOptions, 0, options, current,
456        classifierOptions.length);
457        current += classifierOptions.length;
458        while (current < options.length) {
459            options[current++] = "";
460        }
461        return options;
462    }
463
464    /**
465     * Set the classifiers being analysed
466     *
467     * @param newClassifier the Classifier to use.
468     */
469    public void setClassifier(Classifier newClassifier) {
470
471        m_Classifier = newClassifier;
472    }
473
474    /**
475     * Gets the name of the classifier being analysed
476     *
477     * @return the classifier being analysed.
478     */
479    public Classifier getClassifier() {
480
481        return m_Classifier;
482    }
483
484    /**
485     * Sets debugging mode
486     *
487     * @param debug true if debug output should be printed
488     */
489    public void setDebug(boolean debug) {
490
491        m_Debug = debug;
492    }
493
494    /**
495     * Gets whether debugging is turned on
496     *
497     * @return true if debugging output is on
498     */
499    public boolean getDebug() {
500
501        return m_Debug;
502    }
503
504
505    /**
506     * Sets the random number seed
507     *
508     * @param seed the random number seed
509     */
510    public void setSeed(int seed) {
511
512        m_Seed = seed;
513    }
514
515    /**
516     * Gets the random number seed
517     *
518     * @return the random number seed
519     */
520    public int getSeed() {
521
522        return m_Seed;
523    }
524
525    /**
526     * Sets the number of times an instance is classified
527     *
528     * @param classifyIterations number of times an instance is classified
529     */
530    public void setClassifyIterations(int classifyIterations) {
531
532        m_ClassifyIterations = classifyIterations;
533    }
534
535    /**
536     * Gets the number of times an instance is classified
537     *
538     * @return the maximum number of times an instance is classified
539     */
540    public int getClassifyIterations() {
541
542        return m_ClassifyIterations;
543    }
544
545    /**
546     * Sets the name of the dataset file.
547     *
548     * @param dataFileName name of dataset file.
549     */
550    public void setDataFileName(String dataFileName) {
551
552        m_DataFileName = dataFileName;
553    }
554
555    /**
556     * Get the name of the data file used for the decomposition
557     *
558     * @return the name of the data file
559     */
560    public String getDataFileName() {
561
562        return m_DataFileName;
563    }
564
565    /**
566     * Get the index (starting from 1) of the attribute used as the class.
567     *
568     * @return the index of the class attribute
569     */
570    public int getClassIndex() {
571
572        return m_ClassIndex + 1;
573    }
574
575    /**
576     * Sets index of attribute to discretize on
577     *
578     * @param classIndex the index (starting from 1) of the class attribute
579     */
580    public void setClassIndex(int classIndex) {
581
582        m_ClassIndex = classIndex - 1;
583    }
584
585    /**
586     * Get the calculated bias squared according to the Kohavi and Wolpert definition
587     *
588     * @return the bias squared
589     */
590    public double getKWBias() {
591
592        return m_KWBias;
593    }
594
595    /**
596     * Get the calculated bias according to the Webb definition
597     *
598     * @return the bias
599     *
600     */
601    public double getWBias() {
602
603        return m_WBias;
604    }
605
606
607    /**
608     * Get the calculated variance according to the Kohavi and Wolpert definition
609     *
610     * @return the variance
611     */
612    public double getKWVariance() {
613
614        return m_KWVariance;
615    }
616
617    /**
618     * Get the calculated variance according to the Webb definition
619     *
620     * @return the variance according to Webb
621     *
622     */
623    public double getWVariance() {
624
625        return m_WVariance;
626    }
627
628    /**
629     * Get the calculated sigma according to the Kohavi and Wolpert definition
630     *
631     * @return the sigma
632     *
633     */
634    public double getKWSigma() {
635
636        return m_KWSigma;
637    }
638
639    /**
640     * Set the training size.
641     *
642     * @param size the size of the training set
643     *
644     */
645    public void setTrainSize(int size) {
646
647        m_TrainSize = size;
648    }
649
650    /**
651     * Get the training size
652     *
653     * @return the size of the training set
654     *
655     */
656    public int getTrainSize() {
657
658        return m_TrainSize;
659    }
660
661    /**
662     * Set the proportion of instances that are common between two training sets
663     * used to train a classifier.
664     *
665     * @param proportion the proportion of instances that are common between training
666     * sets.
667     *
668     */
669    public void setP(double proportion) {
670
671        m_P = proportion;
672    }
673
674    /**
675     * Get the proportion of instances that are common between two training sets.
676     *
677     * @return the proportion
678     *
679     */
680    public double getP() {
681
682        return m_P;
683    }
684
685    /**
686     * Get the calculated error rate
687     *
688     * @return the error rate
689     */
690    public double getError() {
691
692        return m_Error;
693    }
694
695    /**
696     * Carry out the bias-variance decomposition using the sub-sampled cross-validation method.
697     *
698     * @throws Exception if the decomposition couldn't be carried out
699     */
700    public void decompose() throws Exception {
701
702        Reader dataReader;
703        Instances data;
704
705        int tps; // training pool size, size of segment E.
706        int k; // number of folds in segment E.
707        int q; // number of segments of size tps.
708
709        dataReader = new BufferedReader(new FileReader(m_DataFileName)); //open file
710        data = new Instances(dataReader); // encapsulate in wrapper class called weka.Instances()
711
712        if (m_ClassIndex < 0) {
713            data.setClassIndex(data.numAttributes() - 1);
714        } else {
715            data.setClassIndex(m_ClassIndex);
716        }
717
718        if (data.classAttribute().type() != Attribute.NOMINAL) {
719            throw new Exception("Class attribute must be nominal");
720        }
721        int numClasses = data.numClasses();
722
723        data.deleteWithMissingClass();
724        if ( data.checkForStringAttributes() ) {
725            throw new Exception("Can't handle string attributes!");
726        }
727
728        // Dataset size must be greater than 2
729        if ( data.numInstances() <= 2 ){
730            throw new Exception("Dataset size must be greater than 2.");
731        }
732
733        if ( m_TrainSize == -1 ){ // default value
734            m_TrainSize = (int) Math.floor( (double) data.numInstances() / 2.0 );
735        }else  if ( m_TrainSize < 0 || m_TrainSize >= data.numInstances() - 1 ) {  // Check if 0 < training Size < D - 1
736            throw new Exception("Training set size of "+m_TrainSize+" is invalid.");
737        }
738
739        if ( m_P == -1 ){ // default value
740            m_P = (double) m_TrainSize / ( (double)data.numInstances() - 1 );
741        }else if (  m_P < ( m_TrainSize / ( (double)data.numInstances() - 1 ) ) || m_P >= 1.0  ) { //Check if p is in range: m/(|D|-1) <= p < 1.0
742            throw new Exception("Proportion is not in range: "+ (m_TrainSize / ((double) data.numInstances() - 1 )) +" <= p < 1.0 ");
743        }
744
745        //roundup tps from double to integer
746        tps = (int) Math.ceil( ((double)m_TrainSize / (double)m_P) + 1 );
747        k = (int) Math.ceil( tps / (tps - (double) m_TrainSize));
748
749        // number of folds cannot be more than the number of instances in the training pool
750        if ( k > tps ) {
751            throw new Exception("The required number of folds is too many."
752            + "Change p or the size of the training set.");
753        }
754
755        // calculate the number of segments, round down.
756        q = (int) Math.floor( (double) data.numInstances() / (double)tps );
757
758        //create confusion matrix, columns = number of instances in data set, as all will be used,  by rows = number of classes.
759        double [][] instanceProbs = new double [data.numInstances()][numClasses];
760        int [][] foldIndex = new int [ k ][ 2 ];
761        Vector segmentList = new Vector(q + 1);
762
763        //Set random seed
764        Random random = new Random(m_Seed);
765
766        data.randomize(random);
767
768        //create index arrays for different segments
769
770        int currentDataIndex = 0;
771
772        for( int count = 1; count <= (q + 1); count++ ){
773            if( count > q){
774                int [] segmentIndex = new int [ (data.numInstances() - (q * tps)) ];
775                for(int index = 0; index < segmentIndex.length; index++, currentDataIndex++){
776
777                    segmentIndex[index] = currentDataIndex;
778                }
779                segmentList.add(segmentIndex);
780            } else {
781                int [] segmentIndex = new int [ tps ];
782
783                for(int index = 0; index < segmentIndex.length; index++, currentDataIndex++){
784                    segmentIndex[index] = currentDataIndex;
785                }
786                segmentList.add(segmentIndex);
787            }
788        }
789
790        int remainder = tps % k; // remainder is used to determine when to shrink the fold size by 1.
791
792        //foldSize = ROUNDUP( tps / k ) (round up, eg 3 -> 3,  3.3->4)
793        int foldSize = (int) Math.ceil( (double)tps /(double) k); //roundup fold size double to integer
794        int index = 0;
795        int currentIndex;
796
797        for( int count = 0; count < k; count ++){
798            if( remainder != 0 && count == remainder ){
799                foldSize -= 1;
800            }
801            foldIndex[count][0] = index;
802            foldIndex[count][1] = foldSize;
803            index += foldSize;
804        }
805
806        for( int l = 0; l < m_ClassifyIterations; l++) {
807
808            for(int i = 1; i <= q; i++){
809
810                int [] currentSegment = (int[]) segmentList.get(i - 1);
811
812                randomize(currentSegment, random);
813
814                //CROSS FOLD VALIDATION for current Segment
815                for( int j = 1; j <= k; j++){
816
817                    Instances TP = null;
818                    for(int foldNum = 1; foldNum <= k; foldNum++){
819                        if( foldNum != j){
820
821                            int startFoldIndex = foldIndex[ foldNum - 1 ][ 0 ]; //start index
822                            foldSize = foldIndex[ foldNum - 1 ][ 1 ];
823                            int endFoldIndex = startFoldIndex + foldSize - 1;
824
825                            for(int currentFoldIndex = startFoldIndex; currentFoldIndex <= endFoldIndex; currentFoldIndex++){
826
827                                if( TP == null ){
828                                    TP = new Instances(data, currentSegment[ currentFoldIndex ], 1);
829                                }else{
830                                    TP.add( data.instance( currentSegment[ currentFoldIndex ] ) );
831                                }
832                            }
833                        }
834                    }
835
836                    TP.randomize(random);
837
838                    if( getTrainSize() > TP.numInstances() ){
839                        throw new Exception("The training set size of " + getTrainSize() + ", is greater than the training pool "
840                        + TP.numInstances() );
841                    }
842
843                    Instances train = new Instances(TP, 0, m_TrainSize);
844
845                    Classifier current = AbstractClassifier.makeCopy(m_Classifier);
846                    current.buildClassifier(train); // create a clssifier using the instances in train.
847
848                    int currentTestIndex = foldIndex[ j - 1 ][ 0 ]; //start index
849                    int testFoldSize = foldIndex[ j - 1 ][ 1 ]; //size
850                    int endTestIndex = currentTestIndex + testFoldSize - 1;
851
852                    while( currentTestIndex <= endTestIndex ){
853
854                        Instance testInst = data.instance( currentSegment[currentTestIndex] );
855                        int pred = (int)current.classifyInstance( testInst );
856
857
858                        if(pred != testInst.classValue()) {
859                            m_Error++; // add 1 to mis-classifications.
860                        }
861                        instanceProbs[ currentSegment[ currentTestIndex ] ][ pred ]++;
862                        currentTestIndex++;
863                    }
864
865                    if( i == 1 && j == 1){
866                        int[] segmentElast = (int[])segmentList.lastElement();
867                        for( currentIndex = 0; currentIndex < segmentElast.length; currentIndex++){
868                            Instance testInst = data.instance( segmentElast[currentIndex] );
869                            int pred = (int)current.classifyInstance( testInst );
870                            if(pred != testInst.classValue()) {
871                                m_Error++; // add 1 to mis-classifications.
872                            }
873
874                            instanceProbs[ segmentElast[ currentIndex ] ][ pred ]++;
875                        }
876                    }
877                }
878            }
879        }
880
881        m_Error /= (double)( m_ClassifyIterations * data.numInstances() );
882
883        m_KWBias = 0.0;
884        m_KWVariance = 0.0;
885        m_KWSigma = 0.0;
886
887        m_WBias = 0.0;
888        m_WVariance = 0.0;
889
890        for (int i = 0; i < data.numInstances(); i++) {
891
892            Instance current = data.instance( i );
893
894            double [] predProbs = instanceProbs[ i ];
895            double pActual, pPred;
896            double bsum = 0, vsum = 0, ssum = 0;
897            double wBSum = 0, wVSum = 0;
898
899            Vector centralTendencies = findCentralTendencies( predProbs );
900
901            if( centralTendencies == null ){
902                throw new Exception("Central tendency was null.");
903            }
904
905            for (int j = 0; j < numClasses; j++) {
906                pActual = (current.classValue() == j) ? 1 : 0;
907                pPred = predProbs[j] / m_ClassifyIterations;
908                bsum += (pActual - pPred) * (pActual - pPred) - pPred * (1 - pPred) / (m_ClassifyIterations - 1);
909                vsum += pPred * pPred;
910                ssum += pActual * pActual;
911            }
912
913            m_KWBias += bsum;
914            m_KWVariance += (1 - vsum);
915            m_KWSigma += (1 - ssum);
916
917            for( int count = 0; count < centralTendencies.size(); count++ ) {
918
919                int wB = 0, wV = 0;
920                int centralTendency = ((Integer)centralTendencies.get(count)).intValue();
921
922                // For a single instance xi, find the bias and variance.
923                for (int j = 0; j < numClasses; j++) {
924
925                    //Webb definition
926                    if( j != (int)current.classValue() && j == centralTendency ) {
927                        wB += predProbs[j];
928                    }
929                    if( j != (int)current.classValue() && j != centralTendency ) {
930                        wV += predProbs[j];
931                    }
932
933                }
934                wBSum += (double) wB;
935                wVSum += (double) wV;
936            }
937
938            // calculate bais by dividing bSum by the number of central tendencies and
939            // total number of instances. (effectively finding the average and dividing
940            // by the number of instances to get the nominalised probability).
941
942            m_WBias += ( wBSum / ((double) ( centralTendencies.size() * m_ClassifyIterations )));
943            // calculate variance by dividing vSum by the total number of interations
944            m_WVariance += ( wVSum / ((double) ( centralTendencies.size() * m_ClassifyIterations )));
945
946        }
947
948        m_KWBias /= (2.0 * (double) data.numInstances());
949        m_KWVariance /= (2.0 * (double) data.numInstances());
950        m_KWSigma /= (2.0 * (double) data.numInstances());
951
952        // bias = bias / number of data instances
953        m_WBias /= (double) data.numInstances();
954        // variance = variance / number of data instances.
955        m_WVariance /= (double) data.numInstances();
956
957        if (m_Debug) {
958            System.err.println("Decomposition finished");
959        }
960
961    }
962
963    /** Finds the central tendency, given the classifications for an instance.
964     *
965     * Where the central tendency is defined as the class that was most commonly
966     * selected for a given instance.<p>
967     *
968     * For example, instance 'x' may be classified out of 3 classes y = {1, 2, 3},
969     * so if x is classified 10 times, and is classified as follows, '1' = 2 times, '2' = 5 times
970     * and '3' = 3 times. Then the central tendency is '2'. <p>
971     *
972     * However, it is important to note that this method returns a list of all classes
973     * that have the highest number of classifications.
974     *
975     * In cases where there are several classes with the largest number of classifications, then
976     * all of these classes are returned. For example if 'x' is classified '1' = 4 times,
977     * '2' = 4 times and '3' = 2 times. Then '1' and '2' are returned.<p>
978     *
979     * @param predProbs the array of classifications for a single instance.
980     *
981     * @return a Vector containing Integer objects which store the class(s) which
982     * are the central tendency.
983     */
984    public Vector findCentralTendencies(double[] predProbs) {
985
986        int centralTValue = 0;
987        int currentValue = 0;
988        //array to store the list of classes the have the greatest number of classifictions.
989        Vector centralTClasses;
990
991        centralTClasses = new Vector(); //create an array with size of the number of classes.
992
993        // Go through array, finding the central tendency.
994        for( int i = 0; i < predProbs.length; i++) {
995            currentValue = (int) predProbs[i];
996            // if current value is greater than the central tendency value then
997            // clear vector and add new class to vector array.
998            if( currentValue > centralTValue) {
999                centralTClasses.clear();
1000                centralTClasses.addElement( new Integer(i) );
1001                centralTValue = currentValue;
1002            } else if( currentValue != 0 && currentValue == centralTValue) {
1003                centralTClasses.addElement( new Integer(i) );
1004            }
1005        }
1006        //return all classes that have the greatest number of classifications.
1007        if( centralTValue != 0){
1008            return centralTClasses;
1009        } else {
1010            return null;
1011        }
1012
1013    }
1014
1015    /**
1016     * Returns description of the bias-variance decomposition results.
1017     *
1018     * @return the bias-variance decomposition results as a string
1019     */
1020    public String toString() {
1021
1022        String result = "\nBias-Variance Decomposition Segmentation, Cross Validation\n" +
1023        "with subsampling.\n";
1024
1025        if (getClassifier() == null) {
1026            return "Invalid setup";
1027        }
1028
1029        result += "\nClassifier    : " + getClassifier().getClass().getName();
1030        if (getClassifier() instanceof OptionHandler) {
1031            result += Utils.joinOptions(((OptionHandler)m_Classifier).getOptions());
1032        }
1033        result += "\nData File     : " + getDataFileName();
1034        result += "\nClass Index   : ";
1035        if (getClassIndex() == 0) {
1036            result += "last";
1037        } else {
1038            result += getClassIndex();
1039        }
1040        result += "\nIterations    : " + getClassifyIterations();
1041        result += "\np             : " + getP();
1042        result += "\nTraining Size : " + getTrainSize();
1043        result += "\nSeed          : " + getSeed();
1044
1045        result += "\n\nDefinition   : " +"Kohavi and Wolpert";
1046        result += "\nError         :" + Utils.doubleToString(getError(), 4);
1047        result += "\nBias^2        :" + Utils.doubleToString(getKWBias(), 4);
1048        result += "\nVariance      :" + Utils.doubleToString(getKWVariance(), 4);
1049        result += "\nSigma^2       :" + Utils.doubleToString(getKWSigma(), 4);
1050
1051        result += "\n\nDefinition   : " +"Webb";
1052        result += "\nError         :" + Utils.doubleToString(getError(), 4);
1053        result += "\nBias          :" + Utils.doubleToString(getWBias(), 4);
1054        result += "\nVariance      :" + Utils.doubleToString(getWVariance(), 4);
1055
1056        return result;
1057    }
1058
1059    /**
1060     * Returns the revision string.
1061     *
1062     * @return          the revision
1063     */
1064    public String getRevision() {
1065      return RevisionUtils.extract("$Revision: 6041 $");
1066    }
1067
1068    /**
1069     * Test method for this class
1070     *
1071     * @param args the command line arguments
1072     */
1073    public static void main(String [] args) {
1074
1075        try {
1076            BVDecomposeSegCVSub bvd = new BVDecomposeSegCVSub();
1077
1078            try {
1079                bvd.setOptions(args);
1080                Utils.checkForRemainingOptions(args);
1081            } catch (Exception ex) {
1082                String result = ex.getMessage() + "\nBVDecompose Options:\n\n";
1083                Enumeration enu = bvd.listOptions();
1084                while (enu.hasMoreElements()) {
1085                    Option option = (Option) enu.nextElement();
1086                    result += option.synopsis() + "\n" + option.description() + "\n";
1087                }
1088                throw new Exception(result);
1089            }
1090
1091            bvd.decompose();
1092
1093            System.out.println(bvd.toString());
1094
1095        } catch (Exception ex) {
1096            System.err.println(ex.getMessage());
1097        }
1098
1099    }
1100
1101    /**
1102     * Accepts an array of ints and randomises the values in the array, using the
1103     * random seed.
1104     *
1105     *@param index is the array of integers
1106     *@param random is the Random seed.
1107     */
1108    public final void randomize(int[] index, Random random) {
1109        for( int j = index.length - 1; j > 0; j-- ){
1110            int k = random.nextInt( j + 1 );
1111            int temp = index[j];
1112            index[j] = index[k];
1113            index[k] = temp;
1114        }
1115    }
1116}
Note: See TracBrowser for help on using the repository browser.