source: src/main/java/weka/classifiers/trees/RandomForest.java @ 21

Last change on this file since 21 was 4, checked in by gnappo, 14 years ago

Import di weka.

File size: 13.8 KB
Line 
1/*
2 *    This program is free software; you can redistribute it and/or modify
3 *    it under the terms of the GNU General Public License as published by
4 *    the Free Software Foundation; either version 2 of the License, or
5 *    (at your option) any later version.
6 *
7 *    This program is distributed in the hope that it will be useful,
8 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
9 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
10 *    GNU General Public License for more details.
11 *
12 *    You should have received a copy of the GNU General Public License
13 *    along with this program; if not, write to the Free Software
14 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
15 */
16
17/*
18 *    RandomForest.java
19 *    Copyright (C) 2001 University of Waikato, Hamilton, New Zealand
20 *
21 */
22
23package weka.classifiers.trees;
24
25import weka.classifiers.Classifier;
26import weka.classifiers.AbstractClassifier;
27import weka.classifiers.meta.Bagging;
28import weka.core.AdditionalMeasureProducer;
29import weka.core.Capabilities;
30import weka.core.Instance;
31import weka.core.Instances;
32import weka.core.Option;
33import weka.core.OptionHandler;
34import weka.core.Randomizable;
35import weka.core.RevisionUtils;
36import weka.core.TechnicalInformation;
37import weka.core.TechnicalInformationHandler;
38import weka.core.Utils;
39import weka.core.WeightedInstancesHandler;
40import weka.core.TechnicalInformation.Field;
41import weka.core.TechnicalInformation.Type;
42
43import java.util.Enumeration;
44import java.util.Vector;
45
46/**
47 <!-- globalinfo-start -->
48 * Class for constructing a forest of random trees.<br/>
49 * <br/>
50 * For more information see: <br/>
51 * <br/>
52 * Leo Breiman (2001). Random Forests. Machine Learning. 45(1):5-32.
53 * <p/>
54 <!-- globalinfo-end -->
55 *
56 <!-- technical-bibtex-start -->
57 * BibTeX:
58 * <pre>
59 * &#64;article{Breiman2001,
60 *    author = {Leo Breiman},
61 *    journal = {Machine Learning},
62 *    number = {1},
63 *    pages = {5-32},
64 *    title = {Random Forests},
65 *    volume = {45},
66 *    year = {2001}
67 * }
68 * </pre>
69 * <p/>
70 <!-- technical-bibtex-end -->
71 *
72 <!-- options-start -->
73 * Valid options are: <p/>
74 *
75 * <pre> -I &lt;number of trees&gt;
76 *  Number of trees to build.</pre>
77 *
78 * <pre> -K &lt;number of features&gt;
79 *  Number of features to consider (&lt;1=int(logM+1)).</pre>
80 *
81 * <pre> -S
82 *  Seed for random number generator.
83 *  (default 1)</pre>
84 *
85 * <pre> -depth &lt;num&gt;
86 *  The maximum depth of the trees, 0 for unlimited.
87 *  (default 0)</pre>
88 *
89 * <pre> -D
90 *  If set, classifier is run in debug mode and
91 *  may output additional info to the console</pre>
92 *
93 <!-- options-end -->
94 *
95 * @author Richard Kirkby (rkirkby@cs.waikato.ac.nz)
96 * @version $Revision: 5928 $
97 */
98public class RandomForest 
99  extends AbstractClassifier
100  implements OptionHandler, Randomizable, WeightedInstancesHandler, 
101             AdditionalMeasureProducer, TechnicalInformationHandler {
102
103  /** for serialization */
104  static final long serialVersionUID = 4216839470751428698L;
105 
106  /** Number of trees in forest. */
107  protected int m_numTrees = 10;
108
109  /** Number of features to consider in random feature selection.
110      If less than 1 will use int(logM+1) ) */
111  protected int m_numFeatures = 0;
112
113  /** The random seed. */
114  protected int m_randomSeed = 1; 
115
116  /** Final number of features that were considered in last build. */
117  protected int m_KValue = 0;
118
119  /** The bagger. */
120  protected Bagging m_bagger = null;
121 
122  /** The maximum depth of the trees (0 = unlimited) */
123  protected int m_MaxDepth = 0;
124
125  /**
126   * Returns a string describing classifier
127   * @return a description suitable for
128   * displaying in the explorer/experimenter gui
129   */
130  public String globalInfo() {
131
132    return 
133        "Class for constructing a forest of random trees.\n\n"
134      + "For more information see: \n\n"
135      + getTechnicalInformation().toString();
136  }
137
138  /**
139   * Returns an instance of a TechnicalInformation object, containing
140   * detailed information about the technical background of this class,
141   * e.g., paper reference or book this class is based on.
142   *
143   * @return the technical information about this class
144   */
145  public TechnicalInformation getTechnicalInformation() {
146    TechnicalInformation        result;
147   
148    result = new TechnicalInformation(Type.ARTICLE);
149    result.setValue(Field.AUTHOR, "Leo Breiman");
150    result.setValue(Field.YEAR, "2001");
151    result.setValue(Field.TITLE, "Random Forests");
152    result.setValue(Field.JOURNAL, "Machine Learning");
153    result.setValue(Field.VOLUME, "45");
154    result.setValue(Field.NUMBER, "1");
155    result.setValue(Field.PAGES, "5-32");
156   
157    return result;
158  }
159 
160  /**
161   * Returns the tip text for this property
162   * @return tip text for this property suitable for
163   * displaying in the explorer/experimenter gui
164   */
165  public String numTreesTipText() {
166    return "The number of trees to be generated.";
167  }
168
169  /**
170   * Get the value of numTrees.
171   *
172   * @return Value of numTrees.
173   */
174  public int getNumTrees() {
175   
176    return m_numTrees;
177  }
178 
179  /**
180   * Set the value of numTrees.
181   *
182   * @param newNumTrees Value to assign to numTrees.
183   */
184  public void setNumTrees(int newNumTrees) {
185   
186    m_numTrees = newNumTrees;
187  }
188 
189  /**
190   * Returns the tip text for this property
191   * @return tip text for this property suitable for
192   * displaying in the explorer/experimenter gui
193   */
194  public String numFeaturesTipText() {
195    return "The number of attributes to be used in random selection (see RandomTree).";
196  }
197
198  /**
199   * Get the number of features used in random selection.
200   *
201   * @return Value of numFeatures.
202   */
203  public int getNumFeatures() {
204   
205    return m_numFeatures;
206  }
207 
208  /**
209   * Set the number of features to use in random selection.
210   *
211   * @param newNumFeatures Value to assign to numFeatures.
212   */
213  public void setNumFeatures(int newNumFeatures) {
214   
215    m_numFeatures = newNumFeatures;
216  }
217 
218  /**
219   * Returns the tip text for this property
220   * @return tip text for this property suitable for
221   * displaying in the explorer/experimenter gui
222   */
223  public String seedTipText() {
224    return "The random number seed to be used.";
225  }
226
227  /**
228   * Set the seed for random number generation.
229   *
230   * @param seed the seed
231   */
232  public void setSeed(int seed) {
233
234    m_randomSeed = seed;
235  }
236 
237  /**
238   * Gets the seed for the random number generations
239   *
240   * @return the seed for the random number generation
241   */
242  public int getSeed() {
243
244    return m_randomSeed;
245  }
246 
247  /**
248   * Returns the tip text for this property
249   *
250   * @return            tip text for this property suitable for
251   *                    displaying in the explorer/experimenter gui
252   */
253  public String maxDepthTipText() {
254    return "The maximum depth of the trees, 0 for unlimited.";
255  }
256
257  /**
258   * Get the maximum depth of trh tree, 0 for unlimited.
259   *
260   * @return            the maximum depth.
261   */
262  public int getMaxDepth() {
263    return m_MaxDepth;
264  }
265 
266  /**
267   * Set the maximum depth of the tree, 0 for unlimited.
268   *
269   * @param value       the maximum depth.
270   */
271  public void setMaxDepth(int value) {
272    m_MaxDepth = value;
273  }
274
275  /**
276   * Gets the out of bag error that was calculated as the classifier was built.
277   *
278   * @return the out of bag error
279   */
280  public double measureOutOfBagError() {
281   
282    if (m_bagger != null) {
283      return m_bagger.measureOutOfBagError();
284    } else return Double.NaN;
285  }
286 
287  /**
288   * Returns an enumeration of the additional measure names.
289   *
290   * @return an enumeration of the measure names
291   */
292  public Enumeration enumerateMeasures() {
293   
294    Vector newVector = new Vector(1);
295    newVector.addElement("measureOutOfBagError");
296    return newVector.elements();
297  }
298 
299  /**
300   * Returns the value of the named measure.
301   *
302   * @param additionalMeasureName the name of the measure to query for its value
303   * @return the value of the named measure
304   * @throws IllegalArgumentException if the named measure is not supported
305   */
306  public double getMeasure(String additionalMeasureName) {
307   
308    if (additionalMeasureName.equalsIgnoreCase("measureOutOfBagError")) {
309      return measureOutOfBagError();
310    }
311    else {throw new IllegalArgumentException(additionalMeasureName
312                                             + " not supported (RandomForest)");
313    }
314  }
315
316  /**
317   * Returns an enumeration describing the available options.
318   *
319   * @return an enumeration of all the available options
320   */
321  public Enumeration listOptions() {
322   
323    Vector newVector = new Vector();
324
325    newVector.addElement(new Option(
326        "\tNumber of trees to build.",
327        "I", 1, "-I <number of trees>"));
328   
329    newVector.addElement(new Option(
330        "\tNumber of features to consider (<1=int(logM+1)).",
331        "K", 1, "-K <number of features>"));
332   
333    newVector.addElement(new Option(
334        "\tSeed for random number generator.\n"
335        + "\t(default 1)",
336        "S", 1, "-S"));
337
338    newVector.addElement(new Option(
339        "\tThe maximum depth of the trees, 0 for unlimited.\n"
340        + "\t(default 0)",
341        "depth", 1, "-depth <num>"));
342
343    Enumeration enu = super.listOptions();
344    while (enu.hasMoreElements()) {
345      newVector.addElement(enu.nextElement());
346    }
347
348    return newVector.elements();
349  }
350
351  /**
352   * Gets the current settings of the forest.
353   *
354   * @return an array of strings suitable for passing to setOptions()
355   */
356  public String[] getOptions() {
357    Vector        result;
358    String[]      options;
359    int           i;
360   
361    result = new Vector();
362   
363    result.add("-I");
364    result.add("" + getNumTrees());
365   
366    result.add("-K");
367    result.add("" + getNumFeatures());
368   
369    result.add("-S");
370    result.add("" + getSeed());
371   
372    if (getMaxDepth() > 0) {
373      result.add("-depth");
374      result.add("" + getMaxDepth());
375    }
376   
377    options = super.getOptions();
378    for (i = 0; i < options.length; i++)
379      result.add(options[i]);
380   
381    return (String[]) result.toArray(new String[result.size()]);
382  }
383
384  /**
385   * Parses a given list of options. <p/>
386   *
387   <!-- options-start -->
388   * Valid options are: <p/>
389   *
390   * <pre> -I &lt;number of trees&gt;
391   *  Number of trees to build.</pre>
392   *
393   * <pre> -K &lt;number of features&gt;
394   *  Number of features to consider (&lt;1=int(logM+1)).</pre>
395   *
396   * <pre> -S
397   *  Seed for random number generator.
398   *  (default 1)</pre>
399   *
400   * <pre> -depth &lt;num&gt;
401   *  The maximum depth of the trees, 0 for unlimited.
402   *  (default 0)</pre>
403   *
404   * <pre> -D
405   *  If set, classifier is run in debug mode and
406   *  may output additional info to the console</pre>
407   *
408   <!-- options-end -->
409   *
410   * @param options the list of options as an array of strings
411   * @throws Exception if an option is not supported
412   */
413  public void setOptions(String[] options) throws Exception{
414    String      tmpStr;
415   
416    tmpStr = Utils.getOption('I', options);
417    if (tmpStr.length() != 0) {
418      m_numTrees = Integer.parseInt(tmpStr);
419    } else {
420      m_numTrees = 10;
421    }
422   
423    tmpStr = Utils.getOption('K', options);
424    if (tmpStr.length() != 0) {
425      m_numFeatures = Integer.parseInt(tmpStr);
426    } else {
427      m_numFeatures = 0;
428    }
429   
430    tmpStr = Utils.getOption('S', options);
431    if (tmpStr.length() != 0) {
432      setSeed(Integer.parseInt(tmpStr));
433    } else {
434      setSeed(1);
435    }
436   
437    tmpStr = Utils.getOption("depth", options);
438    if (tmpStr.length() != 0) {
439      setMaxDepth(Integer.parseInt(tmpStr));
440    } else {
441      setMaxDepth(0);
442    }
443   
444    super.setOptions(options);
445   
446    Utils.checkForRemainingOptions(options);
447  } 
448
449  /**
450   * Returns default capabilities of the classifier.
451   *
452   * @return      the capabilities of this classifier
453   */
454  public Capabilities getCapabilities() {
455    return new RandomTree().getCapabilities();
456  }
457
458  /**
459   * Builds a classifier for a set of instances.
460   *
461   * @param data the instances to train the classifier with
462   * @throws Exception if something goes wrong
463   */
464  public void buildClassifier(Instances data) throws Exception {
465
466    // can classifier handle the data?
467    getCapabilities().testWithFail(data);
468
469    // remove instances with missing class
470    data = new Instances(data);
471    data.deleteWithMissingClass();
472   
473    m_bagger = new Bagging();
474    RandomTree rTree = new RandomTree();
475
476    // set up the random tree options
477    m_KValue = m_numFeatures;
478    if (m_KValue < 1) m_KValue = (int) Utils.log2(data.numAttributes())+1;
479    rTree.setKValue(m_KValue);
480    rTree.setMaxDepth(getMaxDepth());
481
482    // set up the bagger and build the forest
483    m_bagger.setClassifier(rTree);
484    m_bagger.setSeed(m_randomSeed);
485    m_bagger.setNumIterations(m_numTrees);
486    m_bagger.setCalcOutOfBag(true);
487    m_bagger.buildClassifier(data);
488  }
489
490  /**
491   * Returns the class probability distribution for an instance.
492   *
493   * @param instance the instance to be classified
494   * @return the distribution the forest generates for the instance
495   * @throws Exception if computation fails
496   */
497  public double[] distributionForInstance(Instance instance) throws Exception {
498
499    return m_bagger.distributionForInstance(instance);
500  }
501
502  /**
503   * Outputs a description of this classifier.
504   *
505   * @return a string containing a description of the classifier
506   */
507  public String toString() {
508
509    if (m_bagger == null) 
510      return "Random forest not built yet";
511    else 
512      return "Random forest of " + m_numTrees
513           + " trees, each constructed while considering "
514           + m_KValue + " random feature" + (m_KValue==1 ? "" : "s") + ".\n"
515           + "Out of bag error: "
516           + Utils.doubleToString(m_bagger.measureOutOfBagError(), 4) + "\n"
517           + (getMaxDepth() > 0 ? ("Max. depth of trees: " + getMaxDepth() + "\n") : (""))
518           + "\n";
519  }
520 
521  /**
522   * Returns the revision string.
523   *
524   * @return            the revision
525   */
526  public String getRevision() {
527    return RevisionUtils.extract("$Revision: 5928 $");
528  }
529
530  /**
531   * Main method for this class.
532   *
533   * @param argv the options
534   */
535  public static void main(String[] argv) {
536    runClassifier(new RandomForest(), argv);
537  }
538}
Note: See TracBrowser for help on using the repository browser.