source: src/main/java/weka/classifiers/meta/Stacking.java @ 4

Last change on this file: revision 4, checked in by gnappo, 14 years ago

Import of weka.

File size: 15.8 KB
/*
 *    This program is free software; you can redistribute it and/or modify
 *    it under the terms of the GNU General Public License as published by
 *    the Free Software Foundation; either version 2 of the License, or
 *    (at your option) any later version.
 *
 *    This program is distributed in the hope that it will be useful,
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *    GNU General Public License for more details.
 *
 *    You should have received a copy of the GNU General Public License
 *    along with this program; if not, write to the Free Software
 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 */

/*
 *    Stacking.java
 *    Copyright (C) 1999 University of Waikato, Hamilton, New Zealand
 *
 */

package weka.classifiers.meta;

import weka.classifiers.Classifier;
import weka.classifiers.AbstractClassifier;
import weka.classifiers.RandomizableMultipleClassifiersCombiner;
import weka.classifiers.RandomizableParallelMultipleClassifiersCombiner;
import weka.classifiers.rules.ZeroR;
import weka.core.Attribute;
import weka.core.Capabilities;
import weka.core.FastVector;
import weka.core.Instance;
import weka.core.DenseInstance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.OptionHandler;
import weka.core.RevisionUtils;
import weka.core.TechnicalInformation;
import weka.core.TechnicalInformationHandler;
import weka.core.Utils;
import weka.core.TechnicalInformation.Field;
import weka.core.TechnicalInformation.Type;

import java.util.Enumeration;
import java.util.Random;
import java.util.Vector;

/**
 <!-- globalinfo-start -->
 * Combines several classifiers using the stacking method. Can do classification or regression.<br/>
 * <br/>
 * For more information, see<br/>
 * <br/>
 * David H. Wolpert (1992). Stacked generalization. Neural Networks. 5:241-259.
 * <p/>
 <!-- globalinfo-end -->
 *
 <!-- technical-bibtex-start -->
 * BibTeX:
 * <pre>
 * &#64;article{Wolpert1992,
 *    author = {David H. Wolpert},
 *    journal = {Neural Networks},
 *    pages = {241-259},
 *    publisher = {Pergamon Press},
 *    title = {Stacked generalization},
 *    volume = {5},
 *    year = {1992}
 * }
 * </pre>
 * <p/>
 <!-- technical-bibtex-end -->
 *
 <!-- options-start -->
 * Valid options are: <p/>
 *
 * <pre> -M &lt;scheme specification&gt;
 *  Full name of meta classifier, followed by options.
 *  (default: "weka.classifiers.rules.ZeroR")</pre>
 *
 * <pre> -X &lt;number of folds&gt;
 *  Sets the number of cross-validation folds.</pre>
 *
 * <pre> -S &lt;num&gt;
 *  Random number seed.
 *  (default 1)</pre>
 *
 * <pre> -B &lt;classifier specification&gt;
 *  Full class name of classifier to include, followed
 *  by scheme options. May be specified multiple times.
 *  (default: "weka.classifiers.rules.ZeroR")</pre>
 *
 * <pre> -D
 *  If set, classifier is run in debug mode and
 *  may output additional info to the console</pre>
 *
 <!-- options-end -->
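 *
 * A minimal sketch of programmatic use (the base and meta learners named here
 * are ordinary Weka schemes picked for illustration; loading the training data
 * and handling the declared exceptions are assumed to happen elsewhere): <p/>
 *
 * <pre>
 * Stacking stacker = new Stacking();
 * stacker.setClassifiers(new Classifier[] {
 *     new weka.classifiers.trees.J48(),
 *     new weka.classifiers.bayes.NaiveBayes() });
 * stacker.setMetaClassifier(new weka.classifiers.functions.Logistic());
 * stacker.setNumFolds(10);
 * stacker.buildClassifier(train);   // train: Instances with the class index set
 * </pre>
 * <p/>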
 *
 * @author Eibe Frank (eibe@cs.waikato.ac.nz)
 * @version $Revision: 5987 $
 */
public class Stacking 
  extends RandomizableParallelMultipleClassifiersCombiner
  implements TechnicalInformationHandler {

  /** for serialization */
  static final long serialVersionUID = 5134738557155845452L;

  /** The meta classifier */
  protected Classifier m_MetaClassifier = new ZeroR();

  /** Format for meta data */
  protected Instances m_MetaFormat = null;

  /** Format for base data */
  protected Instances m_BaseFormat = null;

  /** The number of folds for the cross-validation */
  protected int m_NumFolds = 10;

  /**
   * Returns a string describing the classifier
   * @return a description suitable for
   * displaying in the explorer/experimenter gui
   */
  public String globalInfo() {

    return "Combines several classifiers using the stacking method. "
      + "Can do classification or regression.\n\n"
      + "For more information, see\n\n"
      + getTechnicalInformation().toString();
  }

  /**
   * Returns an instance of a TechnicalInformation object, containing
   * detailed information about the technical background of this class,
   * e.g., paper reference or book this class is based on.
   *
   * @return the technical information about this class
   */
  public TechnicalInformation getTechnicalInformation() {
    TechnicalInformation        result;

    result = new TechnicalInformation(Type.ARTICLE);
    result.setValue(Field.AUTHOR, "David H. Wolpert");
    result.setValue(Field.YEAR, "1992");
    result.setValue(Field.TITLE, "Stacked generalization");
    result.setValue(Field.JOURNAL, "Neural Networks");
    result.setValue(Field.VOLUME, "5");
    result.setValue(Field.PAGES, "241-259");
    result.setValue(Field.PUBLISHER, "Pergamon Press");

    return result;
  }

  /**
   * Returns an enumeration describing the available options.
   *
   * @return an enumeration of all the available options.
   */
  public Enumeration listOptions() {

    Vector newVector = new Vector(2);
    newVector.addElement(new Option(
              metaOption(),
              "M", 0, "-M <scheme specification>"));
    newVector.addElement(new Option(
              "\tSets the number of cross-validation folds.",
              "X", 1, "-X <number of folds>"));

    Enumeration enu = super.listOptions();
    while (enu.hasMoreElements()) {
      newVector.addElement(enu.nextElement());
    }
    return newVector.elements();
  }

  /**
   * String describing the option for setting the meta classifier
   *
   * @return the string describing the option
   */
  protected String metaOption() {

    return "\tFull name of meta classifier, followed by options.\n" +
      "\t(default: \"weka.classifiers.rules.ZeroR\")";
  }

  /**
   * Parses a given list of options. <p/>
   *
   <!-- options-start -->
   * Valid options are: <p/>
   *
   * <pre> -M &lt;scheme specification&gt;
   *  Full name of meta classifier, followed by options.
   *  (default: "weka.classifiers.rules.ZeroR")</pre>
   *
   * <pre> -X &lt;number of folds&gt;
   *  Sets the number of cross-validation folds.</pre>
   *
   * <pre> -S &lt;num&gt;
   *  Random number seed.
   *  (default 1)</pre>
   *
   * <pre> -B &lt;classifier specification&gt;
   *  Full class name of classifier to include, followed
   *  by scheme options. May be specified multiple times.
   *  (default: "weka.classifiers.rules.ZeroR")</pre>
   *
   * <pre> -D
   *  If set, classifier is run in debug mode and
   *  may output additional info to the console</pre>
   *
   <!-- options-end -->
   *
   * @param options the list of options as an array of strings
   * @throws Exception if an option is not supported
   */
  public void setOptions(String[] options) throws Exception {

    String numFoldsString = Utils.getOption('X', options);
    if (numFoldsString.length() != 0) {
      setNumFolds(Integer.parseInt(numFoldsString));
    } else {
      setNumFolds(10);
    }
    processMetaOptions(options);
    super.setOptions(options);
  }

  /**
   * Process options setting meta classifier.
   *
   * @param options the options to parse
   * @throws Exception if the parsing fails
   */
  protected void processMetaOptions(String[] options) throws Exception {

    String classifierString = Utils.getOption('M', options);
    String [] classifierSpec = Utils.splitOptions(classifierString);
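    // Illustrative value for -M (not a default): "weka.classifiers.functions.Logistic -R 1.0E-8".
    // splitOptions() returns the class name in element 0, followed by that scheme's own options.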
    String classifierName;
    if (classifierSpec.length == 0) {
      classifierName = "weka.classifiers.rules.ZeroR";
    } else {
      classifierName = classifierSpec[0];
      classifierSpec[0] = "";
    }
    setMetaClassifier(AbstractClassifier.forName(classifierName, classifierSpec));
  }

  /**
   * Gets the current settings of the Classifier.
   *
   * @return an array of strings suitable for passing to setOptions
   */
  public String [] getOptions() {

    String [] superOptions = super.getOptions();
    String [] options = new String [superOptions.length + 4];

    int current = 0;
    options[current++] = "-X"; options[current++] = "" + getNumFolds();
    options[current++] = "-M";
    options[current++] = getMetaClassifier().getClass().getName() + " "
      + Utils.joinOptions(((OptionHandler)getMetaClassifier()).getOptions());

    System.arraycopy(superOptions, 0, options, current, 
                     superOptions.length);
    return options;
  }

  /**
   * Returns the tip text for this property
   * @return tip text for this property suitable for
   * displaying in the explorer/experimenter gui
   */
  public String numFoldsTipText() {
    return "The number of folds used for cross-validation.";
  }

  /**
   * Gets the number of folds for the cross-validation.
   *
   * @return the number of folds for the cross-validation
   */
  public int getNumFolds() {

    return m_NumFolds;
  }

  /**
   * Sets the number of folds for the cross-validation.
   *
   * @param numFolds the number of folds for the cross-validation
   * @throws Exception if the parameter is illegal
   */
  public void setNumFolds(int numFolds) throws Exception {

    if (numFolds < 0) {
      throw new IllegalArgumentException("Stacking: Number of cross-validation " +
                                         "folds must be positive.");
    }
    m_NumFolds = numFolds;
  }

  /**
   * Returns the tip text for this property
   * @return tip text for this property suitable for
   * displaying in the explorer/experimenter gui
   */
  public String metaClassifierTipText() {
    return "The meta classifier to be used.";
  }

  /**
   * Sets the meta classifier.
   *
   * @param classifier the classifier with all options set.
   */
  public void setMetaClassifier(Classifier classifier) {

    m_MetaClassifier = classifier;
  }

  /**
   * Gets the meta classifier.
   *
   * @return the meta classifier
   */
  public Classifier getMetaClassifier() {

    return m_MetaClassifier;
  }

  /**
   * Returns combined capabilities of the base classifiers, i.e., the
   * capabilities all of them have in common.
   *
   * @return      the capabilities of the base classifiers
   */
  public Capabilities getCapabilities() {
    Capabilities      result;

    result = super.getCapabilities();
    result.setMinimumNumberInstances(getNumFolds());

    return result;
  }

  /**
   * Builds the stacked ensemble: generates the meta level data via
   * cross-validation, trains the meta classifier on it, and then rebuilds
   * the base classifiers on the full training data.
   *
   * @param data the training data to be used for generating the
   * stacked classifier.
   * @throws Exception if the classifier could not be built successfully
   */
  public void buildClassifier(Instances data) throws Exception {

    if (m_MetaClassifier == null) {
      throw new IllegalArgumentException("No meta classifier has been set");
    }

    // can classifier handle the data?
    getCapabilities().testWithFail(data);

    // remove instances with missing class
    Instances newData = new Instances(data);
    m_BaseFormat = new Instances(data, 0);
    newData.deleteWithMissingClass();

    Random random = new Random(m_Seed);
    newData.randomize(random);
    if (newData.classAttribute().isNominal()) {
      newData.stratify(m_NumFolds);
    }

    // Create meta level
    generateMetaLevel(newData, random);

    // restart the executor pool because at the end of processing
    // a set of classifiers it gets shut down to prevent the program
    // executing as a server
    super.buildClassifier(newData);

    // Rebuild all the base classifiers on the full training data
    buildClassifiers(newData);
  }
  /**
   * Generates the meta data.
   *
   * @param newData the data to work on
   * @param random the random number generator to use for cross-validation
   * @throws Exception if generation fails
   */
  protected void generateMetaLevel(Instances newData, Random random) 
    throws Exception {

    Instances metaData = metaFormat(newData);
    m_MetaFormat = new Instances(metaData, 0);
    for (int j = 0; j < m_NumFolds; j++) {
      Instances train = newData.trainCV(m_NumFolds, j, random);

      // start the executor pool (if necessary)
      // has to be done after each set of classifiers as the
      // executor pool gets shut down in order to prevent the
      // program executing as a server (and not returning to
      // the command prompt when run from the command line)
      super.buildClassifier(train);

      // construct the actual classifiers
      buildClassifiers(train);

      // Classify test instances and add to meta data
      Instances test = newData.testCV(m_NumFolds, j);
      for (int i = 0; i < test.numInstances(); i++) {
        metaData.add(metaInstance(test.instance(i)));
      }
    }

    m_MetaClassifier.buildClassifier(metaData);
  }

  /**
   * Returns class probabilities.
   *
   * @param instance the instance to be classified
   * @return the distribution
   * @throws Exception if instance could not be classified
   * successfully
   */
  public double[] distributionForInstance(Instance instance) throws Exception {

    return m_MetaClassifier.distributionForInstance(metaInstance(instance));
  }

  /**
   * Outputs a representation of this classifier
   *
   * @return a string representation of the classifier
   */
  public String toString() {

    if (m_Classifiers.length == 0) {
      return "Stacking: No base schemes entered.";
    }
    if (m_MetaClassifier == null) {
      return "Stacking: No meta scheme selected.";
    }
    if (m_MetaFormat == null) {
      return "Stacking: No model built yet.";
    }
    String result = "Stacking\n\nBase classifiers\n\n";
    for (int i = 0; i < m_Classifiers.length; i++) {
      result += getClassifier(i).toString() +"\n\n";
    }

    result += "\n\nMeta classifier\n\n";
    result += m_MetaClassifier.toString();

    return result;
  }

  /**
   * Makes the format for the level-1 data.
   *
   * @param instances the level-0 format
   * @return the format for the meta data
   * @throws Exception if the format generation fails
   */
  protected Instances metaFormat(Instances instances) throws Exception {

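    // The level-1 header holds, per base classifier, either one numeric prediction
    // attribute (numeric class) or one attribute per class value (nominal class),
    // e.g. "weka.classifiers.trees.J48:yes" (illustrative name), followed by the
    // original class attribute as the last attribute.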
    FastVector attributes = new FastVector();
    Instances metaFormat;

    for (int k = 0; k < m_Classifiers.length; k++) {
      Classifier classifier = (Classifier) getClassifier(k);
      String name = classifier.getClass().getName();
      if (m_BaseFormat.classAttribute().isNumeric()) {
        attributes.addElement(new Attribute(name));
      } else {
        for (int j = 0; j < m_BaseFormat.classAttribute().numValues(); j++) {
          attributes.addElement(new Attribute(name + ":" + 
                                              m_BaseFormat
                                              .classAttribute().value(j)));
        }
      }
    }
    attributes.addElement(m_BaseFormat.classAttribute().copy());
    metaFormat = new Instances("Meta format", attributes, 0);
    metaFormat.setClassIndex(metaFormat.numAttributes() - 1);
    return metaFormat;
  }

  /**
   * Makes a level-1 instance from the given instance.
   *
   * @param instance the instance to be transformed
   * @return the level-1 instance
   * @throws Exception if the instance generation fails
   */
  protected Instance metaInstance(Instance instance) throws Exception {

    double[] values = new double[m_MetaFormat.numAttributes()];
    Instance metaInstance;
    int i = 0;
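    // Fill the value array in the same order as the attributes built by metaFormat():
    // one prediction (numeric class) or one full class distribution (nominal class)
    // per base classifier, then the instance's class value in the last slot.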
    for (int k = 0; k < m_Classifiers.length; k++) {
      Classifier classifier = getClassifier(k);
      if (m_BaseFormat.classAttribute().isNumeric()) {
        values[i++] = classifier.classifyInstance(instance);
      } else {
        double[] dist = classifier.distributionForInstance(instance);
        for (int j = 0; j < dist.length; j++) {
          values[i++] = dist[j];
        }
      }
    }
    values[i] = instance.classValue();
    metaInstance = new DenseInstance(1, values);
    metaInstance.setDataset(m_MetaFormat);
    return metaInstance;
  }

  /**
   * Returns the revision string.
   *
   * @return            the revision
   */
  public String getRevision() {
    return RevisionUtils.extract("$Revision: 5987 $");
  }

  /**
   * Main method for testing this class.
   *
   * @param argv should contain the following arguments:
   * -t training file [-T test file] [-c class index]
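   * <p/>
   * An illustrative invocation (the file name and the schemes shown are
   * placeholders, not defaults):<br/>
   * <code>java weka.classifiers.meta.Stacking -t train.arff -X 10 -S 1
   *  -B "weka.classifiers.trees.J48" -B "weka.classifiers.bayes.NaiveBayes"
   *  -M "weka.classifiers.functions.Logistic"</code>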
   */
  public static void main(String [] argv) {
    runClassifier(new Stacking(), argv);
  }
}