source: src/main/java/weka/classifiers/BVDecompose.java @ 11

Last change on this file since 11 was 4, checked in by gnappo, 14 years ago

Import di weka.

File size: 20.0 KB
Line 
1/*
2 *    This program is free software; you can redistribute it and/or modify
3 *    it under the terms of the GNU General Public License as published by
4 *    the Free Software Foundation; either version 2 of the License, or
5 *    (at your option) any later version.
6 *
7 *    This program is distributed in the hope that it will be useful,
8 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
9 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
10 *    GNU General Public License for more details.
11 *
12 *    You should have received a copy of the GNU General Public License
13 *    along with this program; if not, write to the Free Software
14 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
15 */
16
17/*
18 *    BVDecompose.java
19 *    Copyright (C) 1999 University of Waikato, Hamilton, New Zealand
20 *
21 */
22
23package weka.classifiers;
24
25import weka.core.Attribute;
26import weka.core.Instance;
27import weka.core.Instances;
28import weka.core.Option;
29import weka.core.OptionHandler;
30import weka.core.RevisionHandler;
31import weka.core.RevisionUtils;
32import weka.core.TechnicalInformation;
33import weka.core.TechnicalInformationHandler;
34import weka.core.Utils;
35import weka.core.TechnicalInformation.Field;
36import weka.core.TechnicalInformation.Type;
37
38import java.io.BufferedReader;
39import java.io.FileReader;
40import java.io.Reader;
41import java.util.Enumeration;
42import java.util.Random;
43import java.util.Vector;
44
45/**
46 <!-- globalinfo-start -->
47 * Class for performing a Bias-Variance decomposition on any classifier using the method specified in:<br/>
48 * <br/>
49 * Ron Kohavi, David H. Wolpert: Bias Plus Variance Decomposition for Zero-One Loss Functions. In: Machine Learning: Proceedings of the Thirteenth International Conference, 275-283, 1996.
50 * <p/>
51 <!-- globalinfo-end -->
52 *
53 <!-- technical-bibtex-start -->
54 * BibTeX:
55 * <pre>
56 * &#64;inproceedings{Kohavi1996,
57 *    author = {Ron Kohavi and David H. Wolpert},
58 *    booktitle = {Machine Learning: Proceedings of the Thirteenth International Conference},
59 *    editor = {Lorenza Saitta},
60 *    pages = {275-283},
61 *    publisher = {Morgan Kaufmann},
62 *    title = {Bias Plus Variance Decomposition for Zero-One Loss Functions},
63 *    year = {1996},
64 *    PS = {http://robotics.stanford.edu/\~ronnyk/biasVar.ps}
65 * }
66 * </pre>
67 * <p/>
68 <!-- technical-bibtex-end -->
69 *
70 <!-- options-start -->
71 * Valid options are: <p/>
72 *
73 * <pre> -c &lt;class index&gt;
74 *  The index of the class attribute.
75 *  (default last)</pre>
76 *
77 * <pre> -t &lt;name of arff file&gt;
78 *  The name of the arff file used for the decomposition.</pre>
79 *
80 * <pre> -T &lt;training pool size&gt;
81 *  The number of instances placed in the training pool.
82 *  The remainder will be used for testing. (default 100)</pre>
83 *
84 * <pre> -s &lt;seed&gt;
85 *  The random number seed used.</pre>
86 *
87 * <pre> -x &lt;num&gt;
88 *  The number of training repetitions used.
89 *  (default 50)</pre>
90 *
91 * <pre> -D
92 *  Turn on debugging output.</pre>
93 *
94 * <pre> -W &lt;classifier class name&gt;
95 *  Full class name of the learner used in the decomposition.
96 *  eg: weka.classifiers.bayes.NaiveBayes</pre>
97 *
98 * <pre>
99 * Options specific to learner weka.classifiers.rules.ZeroR:
100 * </pre>
101 *
102 * <pre> -D
103 *  If set, classifier is run in debug mode and
104 *  may output additional info to the console</pre>
105 *
106 <!-- options-end -->
107 *
108 * Options after -- are passed to the designated sub-learner. <p>
109 *
110 * @author Len Trigg (trigg@cs.waikato.ac.nz)
111 * @version $Revision: 6041 $
112 */
113public class BVDecompose
114  implements OptionHandler, TechnicalInformationHandler, RevisionHandler {
115
116  /** Debugging mode, gives extra output if true */
117  protected boolean m_Debug;
118
119  /** An instantiated base classifier used for getting and testing options. */
120  protected Classifier m_Classifier = new weka.classifiers.rules.ZeroR();
121
122  /** The options to be passed to the base classifier. */
123  protected String [] m_ClassifierOptions;
124
125  /** The number of train iterations */
126  protected int m_TrainIterations = 50;
127
128  /** The name of the data file used for the decomposition */
129  protected String m_DataFileName;
130
131  /** The index of the class attribute */
132  protected int m_ClassIndex = -1;
133
134  /** The random number seed */
135  protected int m_Seed = 1;
136
137  /** The calculated bias (squared) */
138  protected double m_Bias;
139
140  /** The calculated variance */
141  protected double m_Variance;
142
143  /** The calculated sigma (squared) */
144  protected double m_Sigma;
145
146  /** The error rate */
147  protected double m_Error;
148
149  /** The number of instances used in the training pool */
150  protected int m_TrainPoolSize = 100;
151
152  /**
153   * Returns a string describing this object
154   * @return a description of the classifier suitable for
155   * displaying in the explorer/experimenter gui
156   */
157  public String globalInfo() {
158
159    return
160        "Class for performing a Bias-Variance decomposition on any classifier "
161      + "using the method specified in:\n\n"
162      + getTechnicalInformation().toString();
163  }
164
165  /**
166   * Returns an instance of a TechnicalInformation object, containing
167   * detailed information about the technical background of this class,
168   * e.g., paper reference or book this class is based on.
169   *
170   * @return the technical information about this class
171   */
172  public TechnicalInformation getTechnicalInformation() {
173    TechnicalInformation        result;
174
175    result = new TechnicalInformation(Type.INPROCEEDINGS);
176    result.setValue(Field.AUTHOR, "Ron Kohavi and David H. Wolpert");
177    result.setValue(Field.YEAR, "1996");
178    result.setValue(Field.TITLE, "Bias Plus Variance Decomposition for Zero-One Loss Functions");
179    result.setValue(Field.BOOKTITLE, "Machine Learning: Proceedings of the Thirteenth International Conference");
180    result.setValue(Field.PUBLISHER, "Morgan Kaufmann");
181    result.setValue(Field.EDITOR, "Lorenza Saitta");
182    result.setValue(Field.PAGES, "275-283");
183    result.setValue(Field.PS, "http://robotics.stanford.edu/~ronnyk/biasVar.ps");
184
185    return result;
186  }
187
188  /**
189   * Returns an enumeration describing the available options.
190   *
191   * @return an enumeration of all the available options.
192   */
193  public Enumeration listOptions() {
194
195    Vector newVector = new Vector(7);
196
197    newVector.addElement(new Option(
198          "\tThe index of the class attribute.\n"+
199          "\t(default last)",
200          "c", 1, "-c <class index>"));
201    newVector.addElement(new Option(
202          "\tThe name of the arff file used for the decomposition.",
203          "t", 1, "-t <name of arff file>"));
204    newVector.addElement(new Option(
205          "\tThe number of instances placed in the training pool.\n"
206          + "\tThe remainder will be used for testing. (default 100)",
207          "T", 1, "-T <training pool size>"));
208    newVector.addElement(new Option(
209          "\tThe random number seed used.",
210          "s", 1, "-s <seed>"));
211    newVector.addElement(new Option(
212          "\tThe number of training repetitions used.\n"
213          +"\t(default 50)",
214          "x", 1, "-x <num>"));
215    newVector.addElement(new Option(
216          "\tTurn on debugging output.",
217          "D", 0, "-D"));
218    newVector.addElement(new Option(
219          "\tFull class name of the learner used in the decomposition.\n"
220          +"\teg: weka.classifiers.bayes.NaiveBayes",
221          "W", 1, "-W <classifier class name>"));
222
223    if ((m_Classifier != null) &&
224        (m_Classifier instanceof OptionHandler)) {
225      newVector.addElement(new Option(
226            "",
227            "", 0, "\nOptions specific to learner "
228            + m_Classifier.getClass().getName()
229            + ":"));
230      Enumeration enu = ((OptionHandler)m_Classifier).listOptions();
231      while (enu.hasMoreElements()) {
232        newVector.addElement(enu.nextElement());
233      }
234    }
235    return newVector.elements();
236  }
237
238  /**
239   * Parses a given list of options. <p/>
240   *
241   <!-- options-start -->
242   * Valid options are: <p/>
243   *
244   * <pre> -c &lt;class index&gt;
245   *  The index of the class attribute.
246   *  (default last)</pre>
247   *
248   * <pre> -t &lt;name of arff file&gt;
249   *  The name of the arff file used for the decomposition.</pre>
250   *
251   * <pre> -T &lt;training pool size&gt;
252   *  The number of instances placed in the training pool.
253   *  The remainder will be used for testing. (default 100)</pre>
254   *
255   * <pre> -s &lt;seed&gt;
256   *  The random number seed used.</pre>
257   *
258   * <pre> -x &lt;num&gt;
259   *  The number of training repetitions used.
260   *  (default 50)</pre>
261   *
262   * <pre> -D
263   *  Turn on debugging output.</pre>
264   *
265   * <pre> -W &lt;classifier class name&gt;
266   *  Full class name of the learner used in the decomposition.
267   *  eg: weka.classifiers.bayes.NaiveBayes</pre>
268   *
269   * <pre>
270   * Options specific to learner weka.classifiers.rules.ZeroR:
271   * </pre>
272   *
273   * <pre> -D
274   *  If set, classifier is run in debug mode and
275   *  may output additional info to the console</pre>
276   *
277   <!-- options-end -->
278   *
279   * Options after -- are passed to the designated sub-learner. <p>
280   *
281   * @param options the list of options as an array of strings
282   * @throws Exception if an option is not supported
283   */
284  public void setOptions(String[] options) throws Exception {
285
286    setDebug(Utils.getFlag('D', options));
287
288    String classIndex = Utils.getOption('c', options);
289    if (classIndex.length() != 0) {
290      if (classIndex.toLowerCase().equals("last")) {
291        setClassIndex(0);
292      } else if (classIndex.toLowerCase().equals("first")) {
293        setClassIndex(1);
294      } else {
295        setClassIndex(Integer.parseInt(classIndex));
296      }
297    } else {
298      setClassIndex(0);
299    }
300
301    String trainIterations = Utils.getOption('x', options);
302    if (trainIterations.length() != 0) {
303      setTrainIterations(Integer.parseInt(trainIterations));
304    } else {
305      setTrainIterations(50);
306    }
307
308    String trainPoolSize = Utils.getOption('T', options);
309    if (trainPoolSize.length() != 0) {
310      setTrainPoolSize(Integer.parseInt(trainPoolSize));
311    } else {
312      setTrainPoolSize(100);
313    }
314
315    String seedString = Utils.getOption('s', options);
316    if (seedString.length() != 0) {
317      setSeed(Integer.parseInt(seedString));
318    } else {
319      setSeed(1);
320    }
321
322    String dataFile = Utils.getOption('t', options);
323    if (dataFile.length() == 0) {
324      throw new Exception("An arff file must be specified"
325          + " with the -t option.");
326    }
327    setDataFileName(dataFile);
328
329    String classifierName = Utils.getOption('W', options);
330    if (classifierName.length() == 0) {
331      throw new Exception("A learner must be specified with the -W option.");
332    }
333    setClassifier(AbstractClassifier.forName(classifierName,
334          Utils.partitionOptions(options)));
335  }
336
337  /**
338   * Gets the current settings of the CheckClassifier.
339   *
340   * @return an array of strings suitable for passing to setOptions
341   */
342  public String [] getOptions() {
343
344    String [] classifierOptions = new String [0];
345    if ((m_Classifier != null) &&
346        (m_Classifier instanceof OptionHandler)) {
347      classifierOptions = ((OptionHandler)m_Classifier).getOptions();
348        }
349    String [] options = new String [classifierOptions.length + 14];
350    int current = 0;
351    if (getDebug()) {
352      options[current++] = "-D";
353    }
354    options[current++] = "-c"; options[current++] = "" + getClassIndex();
355    options[current++] = "-x"; options[current++] = "" + getTrainIterations();
356    options[current++] = "-T"; options[current++] = "" + getTrainPoolSize();
357    options[current++] = "-s"; options[current++] = "" + getSeed();
358    if (getDataFileName() != null) {
359      options[current++] = "-t"; options[current++] = "" + getDataFileName();
360    }
361    if (getClassifier() != null) {
362      options[current++] = "-W";
363      options[current++] = getClassifier().getClass().getName();
364    }
365    options[current++] = "--";
366    System.arraycopy(classifierOptions, 0, options, current,
367        classifierOptions.length);
368    current += classifierOptions.length;
369    while (current < options.length) {
370      options[current++] = "";
371    }
372    return options;
373  }
374
375  /**
376   * Get the number of instances in the training pool.
377   *
378   * @return number of instances in the training pool.
379   */
380  public int getTrainPoolSize() {
381
382    return m_TrainPoolSize;
383  }
384
385  /**
386   * Set the number of instances in the training pool.
387   *
388   * @param numTrain number of instances in the training pool.
389   */
390  public void setTrainPoolSize(int numTrain) {
391
392    m_TrainPoolSize = numTrain;
393  }
394
395  /**
396   * Set the classifiers being analysed
397   *
398   * @param newClassifier the Classifier to use.
399   */
400  public void setClassifier(Classifier newClassifier) {
401
402    m_Classifier = newClassifier;
403  }
404
405  /**
406   * Gets the name of the classifier being analysed
407   *
408   * @return the classifier being analysed.
409   */
410  public Classifier getClassifier() {
411
412    return m_Classifier;
413  }
414
415  /**
416   * Sets debugging mode
417   *
418   * @param debug true if debug output should be printed
419   */
420  public void setDebug(boolean debug) {
421
422    m_Debug = debug;
423  }
424
425  /**
426   * Gets whether debugging is turned on
427   *
428   * @return true if debugging output is on
429   */
430  public boolean getDebug() {
431
432    return m_Debug;
433  }
434
435  /**
436   * Sets the random number seed
437   *
438   * @param seed the random number seed
439   */
440  public void setSeed(int seed) {
441
442    m_Seed = seed;
443  }
444
445  /**
446   * Gets the random number seed
447   *
448   * @return the random number seed
449   */
450  public int getSeed() {
451
452    return m_Seed;
453  }
454
455  /**
456   * Sets the maximum number of boost iterations
457   *
458   * @param trainIterations the number of boost iterations
459   */
460  public void setTrainIterations(int trainIterations) {
461
462    m_TrainIterations = trainIterations;
463  }
464
465  /**
466   * Gets the maximum number of boost iterations
467   *
468   * @return the maximum number of boost iterations
469   */
470  public int getTrainIterations() {
471
472    return m_TrainIterations;
473  }
474
475  /**
476   * Sets the name of the data file used for the decomposition
477   *
478   * @param dataFileName the data file to use
479   */
480  public void setDataFileName(String dataFileName) {
481
482    m_DataFileName = dataFileName;
483  }
484
485  /**
486   * Get the name of the data file used for the decomposition
487   *
488   * @return the name of the data file
489   */
490  public String getDataFileName() {
491
492    return m_DataFileName;
493  }
494
495  /**
496   * Get the index (starting from 1) of the attribute used as the class.
497   *
498   * @return the index of the class attribute
499   */
500  public int getClassIndex() {
501
502    return m_ClassIndex + 1;
503  }
504
505  /**
506   * Sets index of attribute to discretize on
507   *
508   * @param classIndex the index (starting from 1) of the class attribute
509   */
510  public void setClassIndex(int classIndex) {
511
512    m_ClassIndex = classIndex - 1;
513  }
514
515  /**
516   * Get the calculated bias squared
517   *
518   * @return the bias squared
519   */
520  public double getBias() {
521
522    return m_Bias;
523  }
524
525  /**
526   * Get the calculated variance
527   *
528   * @return the variance
529   */
530  public double getVariance() {
531
532    return m_Variance;
533  }
534
535  /**
536   * Get the calculated sigma squared
537   *
538   * @return the sigma squared
539   */
540  public double getSigma() {
541
542    return m_Sigma;
543  }
544
545  /**
546   * Get the calculated error rate
547   *
548   * @return the error rate
549   */
550  public double getError() {
551
552    return m_Error;
553  }
554
555  /**
556   * Carry out the bias-variance decomposition
557   *
558   * @throws Exception if the decomposition couldn't be carried out
559   */
560  public void decompose() throws Exception {
561
562    Reader dataReader = new BufferedReader(new FileReader(m_DataFileName));
563    Instances data = new Instances(dataReader);
564
565    if (m_ClassIndex < 0) {
566      data.setClassIndex(data.numAttributes() - 1);
567    } else {
568      data.setClassIndex(m_ClassIndex);
569    }
570    if (data.classAttribute().type() != Attribute.NOMINAL) {
571      throw new Exception("Class attribute must be nominal");
572    }
573    int numClasses = data.numClasses();
574
575    data.deleteWithMissingClass();
576    if (data.checkForStringAttributes()) {
577      throw new Exception("Can't handle string attributes!");
578    }
579
580    if (data.numInstances() < 2 * m_TrainPoolSize) {
581      throw new Exception("The dataset must contain at least "
582          + (2 * m_TrainPoolSize) + " instances");
583    }
584    Random random = new Random(m_Seed);
585    data.randomize(random);
586    Instances trainPool = new Instances(data, 0, m_TrainPoolSize);
587    Instances test = new Instances(data, m_TrainPoolSize,
588        data.numInstances() - m_TrainPoolSize);
589    int numTest = test.numInstances();
590    double [][] instanceProbs = new double [numTest][numClasses];
591
592    m_Error = 0;
593    for (int i = 0; i < m_TrainIterations; i++) {
594      if (m_Debug) {
595        System.err.println("Iteration " + (i + 1));
596      }
597      trainPool.randomize(random);
598      Instances train = new Instances(trainPool, 0, m_TrainPoolSize / 2);
599
600      Classifier current = AbstractClassifier.makeCopy(m_Classifier);
601      current.buildClassifier(train);
602
603      //// Evaluate the classifier on test, updating BVD stats
604      for (int j = 0; j < numTest; j++) {
605        int pred = (int)current.classifyInstance(test.instance(j));
606        if (pred != test.instance(j).classValue()) {
607          m_Error++;
608        }
609        instanceProbs[j][pred]++;
610      }
611    }
612    m_Error /= (m_TrainIterations * numTest);
613
614    // Average the BV over each instance in test.
615    m_Bias = 0;
616    m_Variance = 0;
617    m_Sigma = 0;
618    for (int i = 0; i < numTest; i++) {
619      Instance current = test.instance(i);
620      double [] predProbs = instanceProbs[i];
621      double pActual, pPred;
622      double bsum = 0, vsum = 0, ssum = 0;
623      for (int j = 0; j < numClasses; j++) {
624        pActual = (current.classValue() == j) ? 1 : 0; // Or via 1NN from test data?
625        pPred = predProbs[j] / m_TrainIterations;
626        bsum += (pActual - pPred) * (pActual - pPred)
627          - pPred * (1 - pPred) / (m_TrainIterations - 1);
628        vsum += pPred * pPred;
629        ssum += pActual * pActual;
630      }
631      m_Bias += bsum;
632      m_Variance += (1 - vsum);
633      m_Sigma += (1 - ssum);
634    }
635    m_Bias /= (2 * numTest);
636    m_Variance /= (2 * numTest);
637    m_Sigma /= (2 * numTest);
638
639    if (m_Debug) {
640      System.err.println("Decomposition finished");
641    }
642  }
643
644
645  /**
646   * Returns description of the bias-variance decomposition results.
647   *
648   * @return the bias-variance decomposition results as a string
649   */
650  public String toString() {
651
652    String result = "\nBias-Variance Decomposition\n";
653
654    if (getClassifier() == null) {
655      return "Invalid setup";
656    }
657
658    result += "\nClassifier   : " + getClassifier().getClass().getName();
659    if (getClassifier() instanceof OptionHandler) {
660      result += Utils.joinOptions(((OptionHandler)m_Classifier).getOptions());
661    }
662    result += "\nData File    : " + getDataFileName();
663    result += "\nClass Index  : ";
664    if (getClassIndex() == 0) {
665      result += "last";
666    } else {
667      result += getClassIndex();
668    }
669    result += "\nTraining Pool: " + getTrainPoolSize();
670    result += "\nIterations   : " + getTrainIterations();
671    result += "\nSeed         : " + getSeed();
672    result += "\nError        : " + Utils.doubleToString(getError(), 6, 4);
673    result += "\nSigma^2      : " + Utils.doubleToString(getSigma(), 6, 4);
674    result += "\nBias^2       : " + Utils.doubleToString(getBias(), 6, 4);
675    result += "\nVariance     : " + Utils.doubleToString(getVariance(), 6, 4);
676
677    return result + "\n";
678  }
679
680  /**
681   * Returns the revision string.
682   *
683   * @return            the revision
684   */
685  public String getRevision() {
686    return RevisionUtils.extract("$Revision: 6041 $");
687  }
688
689  /**
690   * Test method for this class
691   *
692   * @param args the command line arguments
693   */
694  public static void main(String [] args) {
695
696    try {
697      BVDecompose bvd = new BVDecompose();
698
699      try {
700        bvd.setOptions(args);
701        Utils.checkForRemainingOptions(args);
702      } catch (Exception ex) {
703        String result = ex.getMessage() + "\nBVDecompose Options:\n\n";
704        Enumeration enu = bvd.listOptions();
705        while (enu.hasMoreElements()) {
706          Option option = (Option) enu.nextElement();
707          result += option.synopsis() + "\n" + option.description() + "\n";
708        }
709        throw new Exception(result);
710      }
711
712      bvd.decompose();
713      System.out.println(bvd.toString());
714    } catch (Exception ex) {
715      System.err.println(ex.getMessage());
716    }
717  }
718}
Note: See TracBrowser for help on using the repository browser.