source: src/main/java/weka/classifiers/meta/Bagging.java @ 4

Last change on this file since 4 was 4, checked in by gnappo, 14 years ago

Import of weka.

File size: 18.9 KB
/*
 *    This program is free software; you can redistribute it and/or modify
 *    it under the terms of the GNU General Public License as published by
 *    the Free Software Foundation; either version 2 of the License, or
 *    (at your option) any later version.
 *
 *    This program is distributed in the hope that it will be useful,
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *    GNU General Public License for more details.
 *
 *    You should have received a copy of the GNU General Public License
 *    along with this program; if not, write to the Free Software
 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 */

/*
 *    Bagging.java
 *    Copyright (C) 1999 University of Waikato, Hamilton, New Zealand
 *
 */

package weka.classifiers.meta;

import weka.classifiers.RandomizableParallelIteratedSingleClassifierEnhancer;
import weka.core.AdditionalMeasureProducer;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.Randomizable;
import weka.core.RevisionUtils;
import weka.core.TechnicalInformation;
import weka.core.TechnicalInformationHandler;
import weka.core.Utils;
import weka.core.WeightedInstancesHandler;
import weka.core.TechnicalInformation.Field;
import weka.core.TechnicalInformation.Type;

import java.util.Enumeration;
import java.util.Random;
import java.util.Vector;

/**
 <!-- globalinfo-start -->
 * Class for bagging a classifier to reduce variance. Can do classification and regression depending on the base learner. <br/>
 * <br/>
 * For more information, see<br/>
 * <br/>
 * Leo Breiman (1996). Bagging predictors. Machine Learning. 24(2):123-140.
 * <p/>
 <!-- globalinfo-end -->
 *
 <!-- technical-bibtex-start -->
 * BibTeX:
 * <pre>
 * &#64;article{Breiman1996,
 *    author = {Leo Breiman},
 *    journal = {Machine Learning},
 *    number = {2},
 *    pages = {123-140},
 *    title = {Bagging predictors},
 *    volume = {24},
 *    year = {1996}
 * }
 * </pre>
 * <p/>
 <!-- technical-bibtex-end -->
 *
 <!-- options-start -->
 * Valid options are: <p/>
 *
 * <pre> -P
 *  Size of each bag, as a percentage of the
 *  training set size. (default 100)</pre>
 *
 * <pre> -O
 *  Calculate the out of bag error.</pre>
 *
 * <pre> -S &lt;num&gt;
 *  Random number seed.
 *  (default 1)</pre>
 *
 * <pre> -I &lt;num&gt;
 *  Number of iterations.
 *  (default 10)</pre>
 *
 * <pre> -D
 *  If set, classifier is run in debug mode and
 *  may output additional info to the console</pre>
 *
 * <pre> -W
 *  Full name of base classifier.
 *  (default: weka.classifiers.trees.REPTree)</pre>
 *
 * <pre>
 * Options specific to classifier weka.classifiers.trees.REPTree:
 * </pre>
 *
 * <pre> -M &lt;minimum number of instances&gt;
 *  Set minimum number of instances per leaf (default 2).</pre>
 *
 * <pre> -V &lt;minimum variance for split&gt;
 *  Set minimum numeric class variance proportion
 *  of train variance for split (default 1e-3).</pre>
 *
 * <pre> -N &lt;number of folds&gt;
 *  Number of folds for reduced error pruning (default 3).</pre>
 *
 * <pre> -S &lt;seed&gt;
 *  Seed for random data shuffling (default 1).</pre>
 *
 * <pre> -P
 *  No pruning.</pre>
 *
 * <pre> -L
 *  Maximum tree depth (default -1, no maximum)</pre>
 *
 <!-- options-end -->
 *
 * Options after -- are passed to the designated classifier.<p>
 *
 * @author Eibe Frank (eibe@cs.waikato.ac.nz)
 * @author Len Trigg (len@reeltwo.com)
 * @author Richard Kirkby (rkirkby@cs.waikato.ac.nz)
 * @version $Revision: 5801 $
 */
public class Bagging
  extends RandomizableParallelIteratedSingleClassifierEnhancer
  implements WeightedInstancesHandler, AdditionalMeasureProducer,
             TechnicalInformationHandler {

  /** for serialization */
  static final long serialVersionUID = -505879962237199703L;

  /** The size of each bag sample, as a percentage of the training size */
  protected int m_BagSizePercent = 100;

  /** Whether to calculate the out of bag error */
  protected boolean m_CalcOutOfBag = false;

  /** The out of bag error that has been calculated */
  protected double m_OutOfBagError;

  /**
   * Constructor.
   */
  public Bagging() {

    m_Classifier = new weka.classifiers.trees.REPTree();
  }
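
  /*
   * Illustrative usage sketch (not part of the original source): building
   * a bagged REPTree ensemble programmatically and reading off the
   * out-of-bag error. setNumIterations() is inherited from the
   * IteratedSingleClassifierEnhancer hierarchy; "train.arff" is a
   * hypothetical file name.
   *
   *   Instances train = new Instances(new java.io.BufferedReader(
   *       new java.io.FileReader("train.arff")));
   *   train.setClassIndex(train.numAttributes() - 1);
   *
   *   Bagging bagger = new Bagging();
   *   bagger.setBagSizePercent(100);     // -P 100 (required for OOB below)
   *   bagger.setCalcOutOfBag(true);      // -O
   *   bagger.setNumIterations(10);       // -I 10
   *   bagger.buildClassifier(train);
   *   double oob = bagger.measureOutOfBagError();
   */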

  /**
   * Returns a string describing this classifier.
   * @return a description suitable for
   * displaying in the explorer/experimenter gui
   */
  public String globalInfo() {

    return "Class for bagging a classifier to reduce variance. Can do classification "
      + "and regression depending on the base learner. \n\n"
      + "For more information, see\n\n"
      + getTechnicalInformation().toString();
  }

  /**
   * Returns an instance of a TechnicalInformation object, containing
   * detailed information about the technical background of this class,
   * e.g., paper reference or book this class is based on.
   *
   * @return the technical information about this class
   */
  public TechnicalInformation getTechnicalInformation() {
    TechnicalInformation        result;

    result = new TechnicalInformation(Type.ARTICLE);
    result.setValue(Field.AUTHOR, "Leo Breiman");
    result.setValue(Field.YEAR, "1996");
    result.setValue(Field.TITLE, "Bagging predictors");
    result.setValue(Field.JOURNAL, "Machine Learning");
    result.setValue(Field.VOLUME, "24");
    result.setValue(Field.NUMBER, "2");
    result.setValue(Field.PAGES, "123-140");

    return result;
  }

  /**
   * String describing default classifier.
   *
   * @return the default classifier classname
   */
  protected String defaultClassifierString() {

    return "weka.classifiers.trees.REPTree";
  }

  /**
   * Returns an enumeration describing the available options.
   *
   * @return an enumeration of all the available options.
   */
  public Enumeration listOptions() {

    Vector newVector = new Vector(2);

    newVector.addElement(new Option(
              "\tSize of each bag, as a percentage of the\n"
              + "\ttraining set size. (default 100)",
              "P", 1, "-P"));
    newVector.addElement(new Option(
              "\tCalculate the out of bag error.",
              "O", 0, "-O"));

    Enumeration enu = super.listOptions();
    while (enu.hasMoreElements()) {
      newVector.addElement(enu.nextElement());
    }
    return newVector.elements();
  }

  /**
   * Parses a given list of options. <p/>
   *
   <!-- options-start -->
   * Valid options are: <p/>
   *
   * <pre> -P
   *  Size of each bag, as a percentage of the
   *  training set size. (default 100)</pre>
   *
   * <pre> -O
   *  Calculate the out of bag error.</pre>
   *
   * <pre> -S &lt;num&gt;
   *  Random number seed.
   *  (default 1)</pre>
   *
   * <pre> -I &lt;num&gt;
   *  Number of iterations.
   *  (default 10)</pre>
   *
   * <pre> -D
   *  If set, classifier is run in debug mode and
   *  may output additional info to the console</pre>
   *
   * <pre> -W
   *  Full name of base classifier.
   *  (default: weka.classifiers.trees.REPTree)</pre>
   *
   * <pre>
   * Options specific to classifier weka.classifiers.trees.REPTree:
   * </pre>
   *
   * <pre> -M &lt;minimum number of instances&gt;
   *  Set minimum number of instances per leaf (default 2).</pre>
   *
   * <pre> -V &lt;minimum variance for split&gt;
   *  Set minimum numeric class variance proportion
   *  of train variance for split (default 1e-3).</pre>
   *
   * <pre> -N &lt;number of folds&gt;
   *  Number of folds for reduced error pruning (default 3).</pre>
   *
   * <pre> -S &lt;seed&gt;
   *  Seed for random data shuffling (default 1).</pre>
   *
   * <pre> -P
   *  No pruning.</pre>
   *
   * <pre> -L
   *  Maximum tree depth (default -1, no maximum)</pre>
   *
   <!-- options-end -->
   *
   * Options after -- are passed to the designated classifier.<p>
   *
   * @param options the list of options as an array of strings
   * @throws Exception if an option is not supported
   */
  public void setOptions(String[] options) throws Exception {

    String bagSize = Utils.getOption('P', options);
    if (bagSize.length() != 0) {
      setBagSizePercent(Integer.parseInt(bagSize));
    } else {
      setBagSizePercent(100);
    }

    setCalcOutOfBag(Utils.getFlag('O', options));

    super.setOptions(options);
  }
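
  /*
   * Illustrative sketch (not part of the original source): options can
   * also be supplied as a single command-line style string;
   * Utils.splitOptions tokenizes it into the String[] form expected here.
   *
   *   Bagging bagger = new Bagging();
   *   bagger.setOptions(Utils.splitOptions("-P 80 -I 25 -S 42"));
   */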

  /**
   * Gets the current settings of the Classifier.
   *
   * @return an array of strings suitable for passing to setOptions
   */
  public String [] getOptions() {

    String [] superOptions = super.getOptions();
    String [] options = new String [superOptions.length + 3];

    int current = 0;
    options[current++] = "-P";
    options[current++] = "" + getBagSizePercent();

    if (getCalcOutOfBag()) {
      options[current++] = "-O";
    }

    System.arraycopy(superOptions, 0, options, current,
                     superOptions.length);

    current += superOptions.length;
    while (current < options.length) {
      options[current++] = "";
    }
    return options;
  }

  /**
   * Returns the tip text for this property
   * @return tip text for this property suitable for
   * displaying in the explorer/experimenter gui
   */
  public String bagSizePercentTipText() {
    return "Size of each bag, as a percentage of the training set size.";
  }

  /**
   * Gets the size of each bag, as a percentage of the training set size.
   *
   * @return the bag size, as a percentage.
   */
  public int getBagSizePercent() {

    return m_BagSizePercent;
  }

  /**
   * Sets the size of each bag, as a percentage of the training set size.
   *
   * @param newBagSizePercent the bag size, as a percentage.
   */
  public void setBagSizePercent(int newBagSizePercent) {

    m_BagSizePercent = newBagSizePercent;
  }

  /**
   * Returns the tip text for this property
   * @return tip text for this property suitable for
   * displaying in the explorer/experimenter gui
   */
  public String calcOutOfBagTipText() {
    return "Whether the out-of-bag error is calculated.";
  }

  /**
   * Set whether the out of bag error is calculated.
   *
   * @param calcOutOfBag whether to calculate the out of bag error
   */
  public void setCalcOutOfBag(boolean calcOutOfBag) {

    m_CalcOutOfBag = calcOutOfBag;
  }

  /**
   * Get whether the out of bag error is calculated.
   *
   * @return whether the out of bag error is calculated
   */
  public boolean getCalcOutOfBag() {

    return m_CalcOutOfBag;
  }

  /**
   * Gets the out of bag error that was calculated as the classifier
   * was built.
   *
   * @return the out of bag error
   */
  public double measureOutOfBagError() {

    return m_OutOfBagError;
  }

  /**
   * Returns an enumeration of the additional measure names.
   *
   * @return an enumeration of the measure names
   */
  public Enumeration enumerateMeasures() {

    Vector newVector = new Vector(1);
    newVector.addElement("measureOutOfBagError");
    return newVector.elements();
  }

  /**
   * Returns the value of the named measure.
   *
   * @param additionalMeasureName the name of the measure to query for its value
   * @return the value of the named measure
   * @throws IllegalArgumentException if the named measure is not supported
   */
  public double getMeasure(String additionalMeasureName) {

    if (additionalMeasureName.equalsIgnoreCase("measureOutOfBagError")) {
      return measureOutOfBagError();
    }
    else {
      throw new IllegalArgumentException(additionalMeasureName
                                         + " not supported (Bagging)");
    }
  }

  /**
   * Creates a new dataset of the same size using random sampling
   * with replacement according to the given weight vector. The
   * weights of the instances in the new dataset are set to one.
   * The length of the weight vector has to be the same as the
   * number of instances in the dataset, and all weights have to
   * be positive.
   *
   * @param data the data to be sampled from
   * @param random a random number generator
   * @param sampled array indicating which instances have been sampled
   * @return the new dataset
   * @throws IllegalArgumentException if the weights array is of the wrong
   * length or contains negative weights.
   */
  public final Instances resampleWithWeights(Instances data,
                                             Random random,
                                             boolean[] sampled) {

    double[] weights = new double[data.numInstances()];
    for (int i = 0; i < weights.length; i++) {
      weights[i] = data.instance(i).weight();
    }
    Instances newData = new Instances(data, data.numInstances());
    if (data.numInstances() == 0) {
      return newData;
    }
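
    // The cumulative sums of n uniform draws form a sorted sample of n
    // points; rescaled to [0, sumOfWeights] and merged against the running
    // sum of instance weights below, they select each instance with
    // probability proportional to its weight. This yields weighted
    // sampling with replacement in a single linear pass.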
    double[] probabilities = new double[data.numInstances()];
    double sumProbs = 0, sumOfWeights = Utils.sum(weights);
    for (int i = 0; i < data.numInstances(); i++) {
      sumProbs += random.nextDouble();
      probabilities[i] = sumProbs;
    }
    Utils.normalize(probabilities, sumProbs / sumOfWeights);

    // Make sure that rounding errors don't mess things up
    probabilities[data.numInstances() - 1] = sumOfWeights;
    int k = 0; int l = 0;
    sumProbs = 0;
    while ((k < data.numInstances()) && (l < data.numInstances())) {
      if (weights[l] < 0) {
        throw new IllegalArgumentException("Weights have to be positive.");
      }
      sumProbs += weights[l];
      while ((k < data.numInstances()) &&
             (probabilities[k] <= sumProbs)) {
        newData.add(data.instance(l));
        sampled[l] = true;
        newData.instance(k).setWeight(1);
        k++;
      }
      l++;
    }
    return newData;
  }

  protected Random m_random;
  protected boolean[][] m_inBag;
  protected Instances m_data;

  /**
   * Returns a training set for a particular iteration.
   *
   * @param iteration the number of the iteration for the requested training set.
   * @return the training set for the supplied iteration number
   * @throws Exception if something goes wrong when generating a training set.
   */
  protected synchronized Instances getTrainingSet(int iteration) throws Exception {
    int bagSize = m_data.numInstances() * m_BagSizePercent / 100;
    Instances bagData = null;

    // create the in-bag dataset
    if (m_CalcOutOfBag) {
      m_inBag[iteration] = new boolean[m_data.numInstances()];
      bagData = resampleWithWeights(m_data, m_random, m_inBag[iteration]);
    } else {
      bagData = m_data.resampleWithWeights(m_random);
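      // A smaller bag is obtained by resampling the full set, shuffling,
      // and keeping only the first bagSize instances.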
      if (bagSize < m_data.numInstances()) {
        bagData.randomize(m_random);
        Instances newBagData = new Instances(bagData, 0, bagSize);
        bagData = newBagData;
      }
    }

    return bagData;
  }

  /**
   * Bagging method.
   *
   * @param data the training data to be used for generating the
   * bagged classifier.
   * @throws Exception if the classifier could not be built successfully
   */
  public void buildClassifier(Instances data) throws Exception {

    // can classifier handle the data?
    getCapabilities().testWithFail(data);

    // remove instances with missing class
    m_data = new Instances(data);
    m_data.deleteWithMissingClass();

    super.buildClassifier(m_data);

    if (m_CalcOutOfBag && (m_BagSizePercent != 100)) {
      throw new IllegalArgumentException("Bag size needs to be 100% if " +
                                         "out-of-bag error is to be calculated!");
    }

    int bagSize = m_data.numInstances() * m_BagSizePercent / 100;
    m_random = new Random(m_Seed);

    m_inBag = null;
    if (m_CalcOutOfBag)
      m_inBag = new boolean[m_Classifiers.length][];

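    // Derive an individual seed for each randomizable base classifier so
    // that the ensemble members differ while the run stays reproducible.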
    for (int j = 0; j < m_Classifiers.length; j++) {
      if (m_Classifier instanceof Randomizable) {
        ((Randomizable) m_Classifiers[j]).setSeed(m_random.nextInt());
      }
    }

    buildClassifiers();

    // calc OOB error?
    if (getCalcOutOfBag()) {
      double outOfBagCount = 0.0;
      double errorSum = 0.0;
      boolean numeric = m_data.classAttribute().isNumeric();

      for (int i = 0; i < m_data.numInstances(); i++) {
        double vote;
        double[] votes;
        if (numeric)
          votes = new double[1];
        else
          votes = new double[m_data.numClasses()];

        // determine predictions for instance
        int voteCount = 0;
        for (int j = 0; j < m_Classifiers.length; j++) {
          if (m_inBag[j][i])
            continue;

          voteCount++;
          double pred = m_Classifiers[j].classifyInstance(m_data.instance(i));
          if (numeric)
            votes[0] += pred;
          else
            votes[(int) pred]++;
        }

        // "vote"
        if (numeric) {
          vote = votes[0];
          if (voteCount > 0) {
            vote /= voteCount;    // average
          }
        } else {
          vote = Utils.maxIndex(votes);   // majority vote
        }

        // error for instance
        outOfBagCount += m_data.instance(i).weight();
        if (numeric) {
          errorSum += StrictMath.abs(vote - m_data.instance(i).classValue())
            * m_data.instance(i).weight();
        }
        else {
          if (vote != m_data.instance(i).classValue())
            errorSum += m_data.instance(i).weight();
        }
      }

      m_OutOfBagError = errorSum / outOfBagCount;
    }
    else {
      m_OutOfBagError = 0;
    }

    // save memory
    m_data = null;
  }

  /**
   * Calculates the class membership probabilities for the given test
   * instance.
   *
   * @param instance the instance to be classified
   * @return predicted class probability distribution
   * @throws Exception if distribution can't be computed successfully
   */
  public double[] distributionForInstance(Instance instance) throws Exception {

    double [] sums = new double [instance.numClasses()], newProbs;

    for (int i = 0; i < m_NumIterations; i++) {
      if (instance.classAttribute().isNumeric()) {
        sums[0] += m_Classifiers[i].classifyInstance(instance);
      } else {
        newProbs = m_Classifiers[i].distributionForInstance(instance);
        for (int j = 0; j < newProbs.length; j++)
          sums[j] += newProbs[j];
      }
    }
    if (instance.classAttribute().isNumeric()) {
      sums[0] /= (double) m_NumIterations;
      return sums;
    } else if (Utils.eq(Utils.sum(sums), 0)) {
      return sums;
    } else {
      Utils.normalize(sums);
      return sums;
    }
  }
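
  /*
   * Illustrative sketch (not part of the original source): for a nominal
   * class the returned array holds the normalized, summed distributions of
   * all ensemble members; for a numeric class it has length one and holds
   * the mean prediction. "test" is a hypothetical test set.
   *
   *   double[] dist = bagger.distributionForInstance(test.instance(0));
   *   int predicted = Utils.maxIndex(dist);   // nominal class only
   */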

  /**
   * Returns description of the bagged classifier.
   *
   * @return description of the bagged classifier as a string
   */
  public String toString() {

    if (m_Classifiers == null) {
      return "Bagging: No model built yet.";
    }
    StringBuffer text = new StringBuffer();
    text.append("All the base classifiers: \n\n");
    for (int i = 0; i < m_Classifiers.length; i++)
      text.append(m_Classifiers[i].toString() + "\n\n");

    if (m_CalcOutOfBag) {
      text.append("Out of bag error: "
                  + Utils.doubleToString(m_OutOfBagError, 4)
                  + "\n\n");
    }

    return text.toString();
  }

  /**
   * Returns the revision string.
   *
   * @return            the revision
   */
  public String getRevision() {
    return RevisionUtils.extract("$Revision: 5801 $");
  }

  /**
   * Main method for testing this class.
   *
   * @param argv the options
   */
  public static void main(String [] argv) {
    runClassifier(new Bagging(), argv);
  }
}