source: src/main/java/weka/classifiers/meta/RandomCommittee.java @ 4

Last change on this file since 4 was 4, checked in by gnappo, 14 years ago

Import di weka.

File size: 7.4 KB
Line 
1/*
2 *    This program is free software; you can redistribute it and/or modify
3 *    it under the terms of the GNU General Public License as published by
4 *    the Free Software Foundation; either version 2 of the License, or
5 *    (at your option) any later version.
6 *
7 *    This program is distributed in the hope that it will be useful,
8 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
9 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
10 *    GNU General Public License for more details.
11 *
12 *    You should have received a copy of the GNU General Public License
13 *    along with this program; if not, write to the Free Software
14 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
15 */
16
17/*
18 *    RandomCommittee.java
19 *    Copyright (C) 2003 University of Waikato, Hamilton, New Zealand
20 *
21 */
22
23package weka.classifiers.meta;
24
25import weka.classifiers.Classifier;
26import weka.classifiers.AbstractClassifier;
27import weka.classifiers.RandomizableIteratedSingleClassifierEnhancer;
28import weka.classifiers.RandomizableParallelIteratedSingleClassifierEnhancer;
29import weka.core.Instance;
30import weka.core.Instances;
31import weka.core.Randomizable;
32import weka.core.RevisionUtils;
33import weka.core.Utils;
34import weka.core.WeightedInstancesHandler;
35
36import java.util.Random;
37
38/**
39 <!-- globalinfo-start -->
40 * Class for building an ensemble of randomizable base classifiers. Each base classifiers is built using a different random number seed (but based one the same data). The final prediction is a straight average of the predictions generated by the individual base classifiers.
41 * <p/>
42 <!-- globalinfo-end -->
43 *
44 <!-- options-start -->
45 * Valid options are: <p/>
46 *
47 * <pre> -S &lt;num&gt;
48 *  Random number seed.
49 *  (default 1)</pre>
50 *
51 * <pre> -I &lt;num&gt;
52 *  Number of iterations.
53 *  (default 10)</pre>
54 *
55 * <pre> -D
56 *  If set, classifier is run in debug mode and
57 *  may output additional info to the console</pre>
58 *
59 * <pre> -W
60 *  Full name of base classifier.
61 *  (default: weka.classifiers.trees.RandomTree)</pre>
62 *
63 * <pre>
64 * Options specific to classifier weka.classifiers.trees.RandomTree:
65 * </pre>
66 *
67 * <pre> -K &lt;number of attributes&gt;
68 *  Number of attributes to randomly investigate
69 *  (&lt;1 = int(log(#attributes)+1)).</pre>
70 *
71 * <pre> -M &lt;minimum number of instances&gt;
72 *  Set minimum number of instances per leaf.</pre>
73 *
74 * <pre> -S &lt;num&gt;
75 *  Seed for random number generator.
76 *  (default 1)</pre>
77 *
78 * <pre> -depth &lt;num&gt;
79 *  The maximum depth of the tree, 0 for unlimited.
80 *  (default 0)</pre>
81 *
82 * <pre> -D
83 *  If set, classifier is run in debug mode and
84 *  may output additional info to the console</pre>
85 *
86 <!-- options-end -->
87 *
88 * Options after -- are passed to the designated classifier.<p>
89 *
90 * @author Eibe Frank (eibe@cs.waikato.ac.nz)
91 * @version $Revision: 5928 $
92 */
93public class RandomCommittee 
94  extends RandomizableParallelIteratedSingleClassifierEnhancer
95  implements WeightedInstancesHandler {
96   
97  /** for serialization */
98  static final long serialVersionUID = -9204394360557300092L;
99 
100  /** training data */
101  protected Instances m_data;
102 
103  /**
104   * Constructor.
105   */
106  public RandomCommittee() {
107   
108    m_Classifier = new weka.classifiers.trees.RandomTree();
109  }
110
111  /**
112   * String describing default classifier.
113   *
114   * @return the default classifier classname
115   */
116  protected String defaultClassifierString() {
117   
118    return "weka.classifiers.trees.RandomTree";
119  }
120
121  /**
122   * Returns a string describing classifier
123   * @return a description suitable for
124   * displaying in the explorer/experimenter gui
125   */
126  public String globalInfo() {
127 
128    return "Class for building an ensemble of randomizable base classifiers. Each "
129      + "base classifiers is built using a different random number seed (but based "
130      + "one the same data). The final prediction is a straight average of the "
131      + "predictions generated by the individual base classifiers.";
132  }
133
134  /**
135   * Builds the committee of randomizable classifiers.
136   *
137   * @param data the training data to be used for generating the
138   * bagged classifier.
139   * @exception Exception if the classifier could not be built successfully
140   */
141  public void buildClassifier(Instances data) throws Exception {
142
143    // can classifier handle the data?
144    getCapabilities().testWithFail(data);
145
146    // remove instances with missing class
147    m_data = new Instances(data);
148    m_data.deleteWithMissingClass();
149    super.buildClassifier(m_data);
150   
151    if (!(m_Classifier instanceof Randomizable)) {
152      throw new IllegalArgumentException("Base learner must implement Randomizable!");
153    }
154
155    m_Classifiers = AbstractClassifier.makeCopies(m_Classifier, m_NumIterations);
156
157    Random random = m_data.getRandomNumberGenerator(m_Seed);
158    for (int j = 0; j < m_Classifiers.length; j++) {
159
160      // Set the random number seed for the current classifier.
161      ((Randomizable) m_Classifiers[j]).setSeed(random.nextInt());
162     
163      // Build the classifier.
164//      m_Classifiers[j].buildClassifier(m_data);
165    }
166   
167    buildClassifiers();
168   
169    // save memory
170    m_data = null;
171  }
172 
173  /**
174   * Returns a training set for a particular iteration.
175   *
176   * @param iteration the number of the iteration for the requested training set.
177   * @return the training set for the supplied iteration number
178   * @throws Exception if something goes wrong when generating a training set.
179   */
180  protected synchronized Instances getTrainingSet(int iteration) throws Exception {
181   
182    // we don't manipulate the training data in any way.
183    return m_data;
184  }
185
186  /**
187   * Calculates the class membership probabilities for the given test
188   * instance.
189   *
190   * @param instance the instance to be classified
191   * @return preedicted class probability distribution
192   * @exception Exception if distribution can't be computed successfully
193   */
194  public double[] distributionForInstance(Instance instance) throws Exception {
195
196    double [] sums = new double [instance.numClasses()], newProbs; 
197   
198    for (int i = 0; i < m_NumIterations; i++) {
199      if (instance.classAttribute().isNumeric() == true) {
200        sums[0] += m_Classifiers[i].classifyInstance(instance);
201      } else {
202        newProbs = m_Classifiers[i].distributionForInstance(instance);
203        for (int j = 0; j < newProbs.length; j++)
204          sums[j] += newProbs[j];
205      }
206    }
207    if (instance.classAttribute().isNumeric() == true) {
208      sums[0] /= (double)m_NumIterations;
209      return sums;
210    } else if (Utils.eq(Utils.sum(sums), 0)) {
211      return sums;
212    } else {
213      Utils.normalize(sums);
214      return sums;
215    }
216  }
217
218  /**
219   * Returns description of the committee.
220   *
221   * @return description of the committee as a string
222   */
223  public String toString() {
224   
225    if (m_Classifiers == null) {
226      return "RandomCommittee: No model built yet.";
227    }
228    StringBuffer text = new StringBuffer();
229    text.append("All the base classifiers: \n\n");
230    for (int i = 0; i < m_Classifiers.length; i++)
231      text.append(m_Classifiers[i].toString() + "\n\n");
232
233    return text.toString();
234  }
235 
236  /**
237   * Returns the revision string.
238   *
239   * @return            the revision
240   */
241  public String getRevision() {
242    return RevisionUtils.extract("$Revision: 5928 $");
243  }
244
245  /**
246   * Main method for testing this class.
247   *
248   * @param argv the options
249   */
250  public static void main(String [] argv) {
251    runClassifier(new RandomCommittee(), argv);
252  }
253}
254
Note: See TracBrowser for help on using the repository browser.