source: src/main/java/weka/classifiers/meta/EnsembleSelection.java @ 7

Last change on this file since 7 was 4, checked in by gnappo, 14 years ago

Import di weka.

File size: 56.7 KB
RevLine 
[4]1/*
2 *    This program is free software; you can redistribute it and/or modify
3 *    it under the terms of the GNU General Public License as published by
4 *    the Free Software Foundation; either version 2 of the License, or
5 *    (at your option) any later version.
6 *
7 *    This program is distributed in the hope that it will be useful,
8 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
9 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
10 *    GNU General Public License for more details.
11 *
12 *    You should have received a copy of the GNU General Public License
13 *    along with this program; if not, write to the Free Software
14 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
15 */
16
17/*
18 *    EnsembleSelection.java
19 *    Copyright (C) 2006 David Michael
20 *
21 */
22
23package weka.classifiers.meta;
24
25import weka.classifiers.Evaluation;
26import weka.classifiers.RandomizableClassifier;
27import weka.classifiers.meta.ensembleSelection.EnsembleMetricHelper;
28import weka.classifiers.meta.ensembleSelection.EnsembleSelectionLibrary;
29import weka.classifiers.meta.ensembleSelection.EnsembleSelectionLibraryModel;
30import weka.classifiers.meta.ensembleSelection.ModelBag;
31import weka.classifiers.trees.REPTree;
32import weka.classifiers.xml.XMLClassifier;
33import weka.core.Capabilities;
34import weka.core.Instance;
35import weka.core.Instances;
36import weka.core.Option;
37import weka.core.RevisionUtils;
38import weka.core.SelectedTag;
39import weka.core.Tag;
40import weka.core.TechnicalInformation;
41import weka.core.TechnicalInformationHandler;
42import weka.core.Utils;
43import weka.core.Capabilities.Capability;
44import weka.core.TechnicalInformation.Field;
45import weka.core.TechnicalInformation.Type;
46import weka.core.xml.KOML;
47import weka.core.xml.XMLOptions;
48import weka.core.xml.XMLSerialization;
49
50import java.io.BufferedInputStream;
51import java.io.BufferedOutputStream;
52import java.io.BufferedReader;
53import java.io.File;
54import java.io.FileInputStream;
55import java.io.FileOutputStream;
56import java.io.FileReader;
57import java.io.InputStream;
58import java.io.ObjectInputStream;
59import java.io.ObjectOutputStream;
60import java.io.OutputStream;
61import java.util.Date;
62import java.util.Enumeration;
63import java.util.HashMap;
64import java.util.Iterator;
65import java.util.Map;
66import java.util.Random;
67import java.util.Set;
68import java.util.Vector;
69import java.util.zip.GZIPInputStream;
70import java.util.zip.GZIPOutputStream;
71
72/**
73 <!-- globalinfo-start -->
74 * Combines several classifiers using the ensemble selection method. For more information, see: Caruana, Rich, Niculescu, Alex, Crew, Geoff, and Ksikes, Alex, Ensemble Selection from Libraries of Models, The International Conference on Machine Learning (ICML'04), 2004.  Implemented in Weka by Bob Jung and David Michael.
75 * <p/>
76 <!-- globalinfo-end -->
77 *
78 <!-- technical-bibtex-start -->
79 * BibTeX:
80 * <pre>
81 * &#64;inproceedings{RichCaruana2004,
82 *    author = {Rich Caruana, Alex Niculescu, Geoff Crew, and Alex Ksikes},
83 *    booktitle = {21st International Conference on Machine Learning},
84 *    title = {Ensemble Selection from Libraries of Models},
85 *    year = {2004}
86 * }
87 * </pre>
88 * <p/>
89 <!-- technical-bibtex-end -->
90 *
91 * Our implementation of ensemble selection is a bit different from the other
92 * classifiers because we assume that the list of models to be trained is too
93 * large to fit in memory and that our base classifiers will need to be
 * serialized to the file system (in the directory listed in the "workingDirectory"
 * option).  We have adopted the term "model library" for this large set of
96 * classifiers keeping in line with the original paper.
97 * <p/>
98 *
99 * If you are planning to use this classifier, we highly recommend you take a
100 * quick look at our FAQ/tutorial on the WIKI.  There are a few things that
101 * are unique to this classifier that could trip you up.  Otherwise, this
102 * method is a great way to get really great classifier performance without
103 * having to do too much parameter tuning.  What is nice is that in the worst
 * case you get a nice summary of how a large number of diverse models
105 * performed on your data set. 
106 * <p/>
107 *
108 * This class relies on the package weka.classifiers.meta.ensembleSelection.
109 * <p/>
110 *
111 * When run from the Explorer or another GUI, the classifier depends on the
112 * package weka.gui.libraryEditor.
113 * <p/>
114 *
115 <!-- options-start -->
116 * Valid options are: <p/>
117 *
118 * <pre> -L &lt;/path/to/modelLibrary&gt;
 *  Specifies the Model Library File, containing the list of all models.</pre>
120 *
121 * <pre> -W &lt;/path/to/working/directory&gt;
122 *  Specifies the Working Directory, where all models will be stored.</pre>
123 *
124 * <pre> -B &lt;numModelBags&gt;
125 *  Set the number of bags, i.e., number of iterations to run
126 *  the ensemble selection algorithm.</pre>
127 *
128 * <pre> -E &lt;modelRatio&gt;
129 *  Set the ratio of library models that will be randomly chosen
130 *  to populate each bag of models.</pre>
131 *
132 * <pre> -V &lt;validationRatio&gt;
133 *  Set the ratio of the training data set that will be reserved
134 *  for validation.</pre>
135 *
136 * <pre> -H &lt;hillClimbIterations&gt;
137 *  Set the number of hillclimbing iterations to be performed
138 *  on each model bag.</pre>
139 *
140 * <pre> -I &lt;sortInitialization&gt;
 *  Set the ratio of the ensemble library that the sort
142 *  initialization algorithm will be able to choose from while
143 *  initializing the ensemble for each model bag</pre>
144 *
145 * <pre> -X &lt;numFolds&gt;
146 *  Sets the number of cross-validation folds.</pre>
147 *
 * <pre> -P &lt;hillclimbMetric&gt;
149 *  Specify the metric that will be used for model selection
150 *  during the hillclimbing algorithm.
151 *  Valid metrics are:
152 *   accuracy, rmse, roc, precision, recall, fscore, all</pre>
153 *
154 * <pre> -A &lt;algorithm&gt;
155 *  Specifies the algorithm to be used for ensemble selection.
156 *  Valid algorithms are:
157 *   "forward" (default) for forward selection.
158 *   "backward" for backward elimination.
159 *   "both" for both forward and backward elimination.
160 *   "best" to simply print out top performer from the
161 *      ensemble library
162 *   "library" to only train the models in the ensemble
163 *      library</pre>
164 *
165 * <pre> -R
166 *  Flag whether or not models can be selected more than once
167 *  for an ensemble.</pre>
168 *
169 * <pre> -G
170 *  Whether sort initialization greedily stops adding models
171 *  when performance degrades.</pre>
172 *
173 * <pre> -O
174 *  Flag for verbose output. Prints out performance of all
175 *  selected models.</pre>
176 *
177 * <pre> -S &lt;num&gt;
178 *  Random number seed.
179 *  (default 1)</pre>
180 *
181 * <pre> -D
182 *  If set, classifier is run in debug mode and
183 *  may output additional info to the console</pre>
184 *
185 <!-- options-end -->
186 *
187 * @author Robert Jung
188 * @author David Michael
189 * @version $Revision: 5480 $
190 */
191public class EnsembleSelection 
192  extends RandomizableClassifier
193  implements TechnicalInformationHandler {
194
195  /** for serialization */
196  private static final long serialVersionUID = -1744155148765058511L;
197
198  /**
199   * The Library of models, from which we can select our ensemble. Usually
200   * loaded from a model list file (.mlf or .model.xml) using the -L
201   * command-line option.
202   */
203  protected EnsembleSelectionLibrary m_library = new EnsembleSelectionLibrary();
204 
205  /**
206   * List of models chosen by EnsembleSelection. Populated by buildClassifier.
207   */
208  protected EnsembleSelectionLibraryModel[] m_chosen_models = null;
209 
210  /**
211   * An array of weights for the chosen models. Elements are parallel to those
212   * in m_chosen_models. That is, m_chosen_model_weights[i] is the weight
213   * associated with the model at m_chosen_models[i].
214   */
215  protected int[] m_chosen_model_weights = null;
216 
217  /** Total weight of all chosen models. */
218  protected int m_total_weight = 0;
219 
220  /**
221   * ratio of library models that will be randomly chosen to be used for each
222   * model bag
223   */
224  protected double m_modelRatio = 0.5;
225 
226  /**
227   * Indicates the fraction of the given training set that should be used for
228   * hillclimbing/validation. This fraction is set aside and not used for
229   * training. It is assumed that any loaded models were also not trained on
230   * set-aside data. (If the same percentage and random seed were used
231   * previously to train the models in the library, this will work as expected -
232   * i.e., those models will be valid)
233   */
234  protected double m_validationRatio = 0.25;
235 
236  /** defines metrics that can be chosen for hillclimbing */
237  public static final Tag[] TAGS_METRIC = {
238    new Tag(EnsembleMetricHelper.METRIC_ACCURACY, "Optimize with Accuracy"),
239    new Tag(EnsembleMetricHelper.METRIC_RMSE, "Optimize with RMSE"),
240    new Tag(EnsembleMetricHelper.METRIC_ROC, "Optimize with ROC"),
241    new Tag(EnsembleMetricHelper.METRIC_PRECISION, "Optimize with precision"),
242    new Tag(EnsembleMetricHelper.METRIC_RECALL, "Optimize with recall"),
243    new Tag(EnsembleMetricHelper.METRIC_FSCORE, "Optimize with fscore"),
244    new Tag(EnsembleMetricHelper.METRIC_ALL, "Optimize with all metrics"), };
245 
246  /**
247   * The "enumeration" of the algorithms we can use. Forward - forward
248   * selection. For hillclimb iterations,
249   */
250  public static final int ALGORITHM_FORWARD = 0;
251 
252  public static final int ALGORITHM_BACKWARD = 1;
253 
254  public static final int ALGORITHM_FORWARD_BACKWARD = 2;
255 
256  public static final int ALGORITHM_BEST = 3;
257 
258  public static final int ALGORITHM_BUILD_LIBRARY = 4;
259 
260  /** defines metrics that can be chosen for hillclimbing */
261  public static final Tag[] TAGS_ALGORITHM = {
262    new Tag(ALGORITHM_FORWARD, "Forward selection"),
263    new Tag(ALGORITHM_BACKWARD, "Backward elimation"),
264    new Tag(ALGORITHM_FORWARD_BACKWARD, "Forward Selection + Backward Elimination"),
265    new Tag(ALGORITHM_BEST, "Best model"),
266    new Tag(ALGORITHM_BUILD_LIBRARY, "Build Library Only") };
267 
268  /**
269   * this specifies the number of "Ensembl-X" directories that are allowed to
270   * be created in the users home directory where X is the number of the
271   * ensemble
272   */
273  private static final int MAX_DEFAULT_DIRECTORIES = 1000;
274 
275  /**
276   * The name of the Model Library File (if one is specified) which lists
277   * models from which ensemble selection will choose. This is only used when
278   * run from the command-line, as otherwise m_library is responsible for
279   * this.
280   */
281  protected String m_modelLibraryFileName = null;
282 
283  /**
284   * The number of "model bags". Using 1 is equivalent to no bagging at all.
285   */
286  protected int m_numModelBags = 10;
287 
288  /** The metric for which the ensemble will be optimized. */
289  protected int m_hillclimbMetric = EnsembleMetricHelper.METRIC_RMSE;
290 
291  /** The algorithm used for ensemble selection. */
292  protected int m_algorithm = ALGORITHM_FORWARD;
293 
294  /**
295   * number of hillclimbing iterations for the ensemble selection algorithm
296   */
297  protected int m_hillclimbIterations = 100;
298 
299  /** ratio of library models to be used for sort initialization */
300  protected double m_sortInitializationRatio = 1.0;
301 
302  /**
303   * specifies whether or not the ensemble algorithm is allowed to include a
304   * specific model in the library more than once in each ensemble
305   */
306  protected boolean m_replacement = true;
307 
308  /**
309   * specifies whether we use "greedy" sort initialization. If false, we
310   * simply add the best m_sortInitializationRatio models of the bag blindly.
311   * If true, we add the best models in order up to m_sortInitializationRatio
312   * until adding the next model would not help performance.
313   */
314  protected boolean m_greedySortInitialization = true;
315 
316  /**
317   * Specifies whether or not we will output metrics for all models
318   */
319  protected boolean m_verboseOutput = false;
320 
321  /**
322   * Hash map of cached predictions. The key is a stringified Instance. Each
323   * entry is a 2d array, first indexed by classifier index (i.e., the one
324   * used in m_chosen_model). The second index is the usual "distribution"
325   * index across classes.
326   */
327  protected Map m_cachedPredictions = null;
328 
329  /**
330   * This string will store the working directory where all models , temporary
331   * prediction values, and modellist logs are to be built and stored.
332   */
333  protected File m_workingDirectory = new File(getDefaultWorkingDirectory());
334 
335  /**
336   * Indicates the number of folds for cross-validation. A value of 1
337   * indicates there is no cross-validation. Cross validation is done in the
338   * "embedded" fashion described by Caruana, Niculescu, and Munson
339   * (unpublished work - tech report forthcoming)
340   */
341  protected int m_NumFolds = 1;
342 
343  /**
344   * Returns a string describing classifier
345   *
346   * @return a description suitable for displaying in the
347   *         explorer/experimenter gui
348   */
349  public String globalInfo() {
350   
351    return "Combines several classifiers using the ensemble "
352    + "selection method. For more information, see: "
353    + "Caruana, Rich, Niculescu, Alex, Crew, Geoff, and Ksikes, Alex, "
354    + "Ensemble Selection from Libraries of Models, "
355    + "The International Conference on Machine Learning (ICML'04), 2004.  "
356    + "Implemented in Weka by Bob Jung and David Michael.";
357  }
358 
359  /**
360   * Returns an enumeration describing the available options.
361   *
362   * @return an enumeration of all the available options.
363   */
364  public Enumeration listOptions() {
365    Vector result = new Vector();
366   
367    result.addElement(new Option(
368        "\tSpecifies the Model Library File, continuing the list of all models.",
369        "L", 1, "-L </path/to/modelLibrary>"));
370   
371    result.addElement(new Option(
372        "\tSpecifies the Working Directory, where all models will be stored.",
373        "W", 1, "-W </path/to/working/directory>"));
374   
375    result.addElement(new Option(
376        "\tSet the number of bags, i.e., number of iterations to run \n"
377        + "\tthe ensemble selection algorithm.",
378        "B", 1, "-B <numModelBags>"));
379   
380    result.addElement(new Option(
381        "\tSet the ratio of library models that will be randomly chosen \n"
382        + "\tto populate each bag of models.",
383        "E", 1, "-E <modelRatio>"));
384   
385    result.addElement(new Option(
386        "\tSet the ratio of the training data set that will be reserved \n"
387        + "\tfor validation.",
388        "V", 1, "-V <validationRatio>"));
389   
390    result.addElement(new Option(
391        "\tSet the number of hillclimbing iterations to be performed \n"
392        + "\ton each model bag.",
393        "H", 1, "-H <hillClimbIterations>"));
394   
395    result.addElement(new Option(
396        "\tSet the the ratio of the ensemble library that the sort \n"
397        + "\tinitialization algorithm will be able to choose from while \n"
398        + "\tinitializing the ensemble for each model bag",
399        "I", 1, "-I <sortInitialization>"));
400   
401    result.addElement(new Option(
402        "\tSets the number of cross-validation folds.", 
403        "X", 1, "-X <numFolds>"));
404   
405    result.addElement(new Option(
406        "\tSpecify the metric that will be used for model selection \n"
407        + "\tduring the hillclimbing algorithm.\n"
408        + "\tValid metrics are: \n"
409        + "\t\taccuracy, rmse, roc, precision, recall, fscore, all",
410        "P", 1, "-P <hillclimbMettric>"));
411   
412    result.addElement(new Option(
413        "\tSpecifies the algorithm to be used for ensemble selection. \n"
414        + "\tValid algorithms are:\n"
415        + "\t\t\"forward\" (default) for forward selection.\n"
416        + "\t\t\"backward\" for backward elimination.\n"
417        + "\t\t\"both\" for both forward and backward elimination.\n"
418        + "\t\t\"best\" to simply print out top performer from the \n"
419        + "\t\t   ensemble library\n"
420        + "\t\t\"library\" to only train the models in the ensemble \n"
421        + "\t\t   library",
422        "A", 1, "-A <algorithm>"));
423   
424    result.addElement(new Option(
425        "\tFlag whether or not models can be selected more than once \n"
426        + "\tfor an ensemble.",
427        "R", 0, "-R"));
428   
429    result.addElement(new Option(
430        "\tWhether sort initialization greedily stops adding models \n"
431        + "\twhen performance degrades.",
432        "G", 0, "-G"));
433   
434    result.addElement(new Option(
435        "\tFlag for verbose output. Prints out performance of all \n"
436        + "\tselected models.",
437        "O", 0, "-O"));
438   
439    // TODO - Add more options here
440    Enumeration enu = super.listOptions();
441    while (enu.hasMoreElements()) {
442      result.addElement(enu.nextElement());
443    }
444   
445    return result.elements();
446  }
447 
448  /**
449   * We return true for basically everything except for Missing class values,
450   * because we can't really answer for all the models in our library. If any of
451   * them don't work with the supplied data then we just trap the exception.
452   *
453   * @return      the capabilities of this classifier
454   */
455  public Capabilities getCapabilities() {
456    Capabilities result = super.getCapabilities(); // returns the object
457    result.disableAll();
458    // from
459    // weka.classifiers.Classifier
460   
461    // attributes
462    result.enable(Capability.NOMINAL_ATTRIBUTES);
463    result.enable(Capability.NUMERIC_ATTRIBUTES);
464    result.enable(Capability.DATE_ATTRIBUTES);
465    result.enable(Capability.MISSING_VALUES);
466    result.enable(Capability.BINARY_ATTRIBUTES);
467   
468    // class
469    result.enable(Capability.NOMINAL_CLASS);
470    result.enable(Capability.NUMERIC_CLASS);
471    result.enable(Capability.BINARY_CLASS);
472   
473    return result;
474  }
475 
476  /**
477   <!-- options-start -->
478   * Valid options are: <p/>
479   *
480   * <pre> -L &lt;/path/to/modelLibrary&gt;
481   *  Specifies the Model Library File, continuing the list of all models.</pre>
482   *
483   * <pre> -W &lt;/path/to/working/directory&gt;
484   *  Specifies the Working Directory, where all models will be stored.</pre>
485   *
486   * <pre> -B &lt;numModelBags&gt;
487   *  Set the number of bags, i.e., number of iterations to run
488   *  the ensemble selection algorithm.</pre>
489   *
490   * <pre> -E &lt;modelRatio&gt;
491   *  Set the ratio of library models that will be randomly chosen
492   *  to populate each bag of models.</pre>
493   *
494   * <pre> -V &lt;validationRatio&gt;
495   *  Set the ratio of the training data set that will be reserved
496   *  for validation.</pre>
497   *
498   * <pre> -H &lt;hillClimbIterations&gt;
499   *  Set the number of hillclimbing iterations to be performed
500   *  on each model bag.</pre>
501   *
502   * <pre> -I &lt;sortInitialization&gt;
503   *  Set the the ratio of the ensemble library that the sort
504   *  initialization algorithm will be able to choose from while
505   *  initializing the ensemble for each model bag</pre>
506   *
507   * <pre> -X &lt;numFolds&gt;
508   *  Sets the number of cross-validation folds.</pre>
509   *
510   * <pre> -P &lt;hillclimbMettric&gt;
511   *  Specify the metric that will be used for model selection
512   *  during the hillclimbing algorithm.
513   *  Valid metrics are:
514   *   accuracy, rmse, roc, precision, recall, fscore, all</pre>
515   *
516   * <pre> -A &lt;algorithm&gt;
517   *  Specifies the algorithm to be used for ensemble selection.
518   *  Valid algorithms are:
519   *   "forward" (default) for forward selection.
520   *   "backward" for backward elimination.
521   *   "both" for both forward and backward elimination.
522   *   "best" to simply print out top performer from the
523   *      ensemble library
524   *   "library" to only train the models in the ensemble
525   *      library</pre>
526   *
527   * <pre> -R
528   *  Flag whether or not models can be selected more than once
529   *  for an ensemble.</pre>
530   *
531   * <pre> -G
532   *  Whether sort initialization greedily stops adding models
533   *  when performance degrades.</pre>
534   *
535   * <pre> -O
536   *  Flag for verbose output. Prints out performance of all
537   *  selected models.</pre>
538   *
539   * <pre> -S &lt;num&gt;
540   *  Random number seed.
541   *  (default 1)</pre>
542   *
543   * <pre> -D
544   *  If set, classifier is run in debug mode and
545   *  may output additional info to the console</pre>
546   *
547   <!-- options-end -->
548   *
549   * @param options
550   *            the list of options as an array of strings
551   * @throws Exception
552   *                if an option is not supported
553   */
554  public void setOptions(String[] options) throws Exception {
555    String      tmpStr;
556   
557    tmpStr = Utils.getOption('L', options);
558    if (tmpStr.length() != 0) {
559      m_modelLibraryFileName = tmpStr;
560      m_library = new EnsembleSelectionLibrary(m_modelLibraryFileName);
561    } else {
562      setLibrary(new EnsembleSelectionLibrary());
563      // setLibrary(new Library(super.m_Classifiers));
564    }
565   
566    tmpStr = Utils.getOption('W', options);
567    if (tmpStr.length() != 0 && validWorkingDirectory(tmpStr)) {
568      m_workingDirectory = new File(tmpStr);
569    } else {
570      m_workingDirectory = new File(getDefaultWorkingDirectory());
571    }
572    m_library.setWorkingDirectory(m_workingDirectory);
573   
574    tmpStr = Utils.getOption('E', options);
575    if (tmpStr.length() != 0) {
576      setModelRatio(Double.parseDouble(tmpStr));
577    } else {
578      setModelRatio(1.0);
579    }
580   
581    tmpStr = Utils.getOption('V', options);
582    if (tmpStr.length() != 0) {
583      setValidationRatio(Double.parseDouble(tmpStr));
584    } else {
585      setValidationRatio(0.25);
586    }
587   
588    tmpStr = Utils.getOption('B', options);
589    if (tmpStr.length() != 0) {
590      setNumModelBags(Integer.parseInt(tmpStr));
591    } else {
592      setNumModelBags(10);
593    }
594   
595    tmpStr = Utils.getOption('H', options);
596    if (tmpStr.length() != 0) {
597      setHillclimbIterations(Integer.parseInt(tmpStr));
598    } else {
599      setHillclimbIterations(100);
600    }
601   
602    tmpStr = Utils.getOption('I', options);
603    if (tmpStr.length() != 0) {
604      setSortInitializationRatio(Double.parseDouble(tmpStr));
605    } else {
606      setSortInitializationRatio(1.0);
607    }
608   
609    tmpStr = Utils.getOption('X', options);
610    if (tmpStr.length() != 0) {
611      setNumFolds(Integer.parseInt(tmpStr));
612    } else {
613      setNumFolds(10);
614    }
615   
616    setReplacement(Utils.getFlag('R', options));
617   
618    setGreedySortInitialization(Utils.getFlag('G', options));
619   
620    setVerboseOutput(Utils.getFlag('O', options));
621   
622    tmpStr = Utils.getOption('P', options);
623    // if (hillclimbMetricString.length() != 0) {
624   
625    if (tmpStr.toLowerCase().equals("accuracy")) {
626      setHillclimbMetric(new SelectedTag(
627          EnsembleMetricHelper.METRIC_ACCURACY, TAGS_METRIC));
628    } else if (tmpStr.toLowerCase().equals("rmse")) {
629      setHillclimbMetric(new SelectedTag(
630          EnsembleMetricHelper.METRIC_RMSE, TAGS_METRIC));
631    } else if (tmpStr.toLowerCase().equals("roc")) {
632      setHillclimbMetric(new SelectedTag(
633          EnsembleMetricHelper.METRIC_ROC, TAGS_METRIC));
634    } else if (tmpStr.toLowerCase().equals("precision")) {
635      setHillclimbMetric(new SelectedTag(
636          EnsembleMetricHelper.METRIC_PRECISION, TAGS_METRIC));
637    } else if (tmpStr.toLowerCase().equals("recall")) {
638      setHillclimbMetric(new SelectedTag(
639          EnsembleMetricHelper.METRIC_RECALL, TAGS_METRIC));
640    } else if (tmpStr.toLowerCase().equals("fscore")) {
641      setHillclimbMetric(new SelectedTag(
642          EnsembleMetricHelper.METRIC_FSCORE, TAGS_METRIC));
643    } else if (tmpStr.toLowerCase().equals("all")) {
644      setHillclimbMetric(new SelectedTag(
645          EnsembleMetricHelper.METRIC_ALL, TAGS_METRIC));
646    } else {
647      setHillclimbMetric(new SelectedTag(
648          EnsembleMetricHelper.METRIC_RMSE, TAGS_METRIC));
649    }
650   
651    tmpStr = Utils.getOption('A', options);
652    if (tmpStr.toLowerCase().equals("forward")) {
653      setAlgorithm(new SelectedTag(ALGORITHM_FORWARD, TAGS_ALGORITHM));
654    } else if (tmpStr.toLowerCase().equals("backward")) {
655      setAlgorithm(new SelectedTag(ALGORITHM_BACKWARD, TAGS_ALGORITHM));
656    } else if (tmpStr.toLowerCase().equals("both")) {
657      setAlgorithm(new SelectedTag(ALGORITHM_FORWARD_BACKWARD, TAGS_ALGORITHM));
658    } else if (tmpStr.toLowerCase().equals("forward")) {
659      setAlgorithm(new SelectedTag(ALGORITHM_FORWARD, TAGS_ALGORITHM));
660    } else if (tmpStr.toLowerCase().equals("best")) {
661      setAlgorithm(new SelectedTag(ALGORITHM_BEST, TAGS_ALGORITHM));
662    } else if (tmpStr.toLowerCase().equals("library")) {
663      setAlgorithm(new SelectedTag(ALGORITHM_BUILD_LIBRARY, TAGS_ALGORITHM));
664    } else {
665      setAlgorithm(new SelectedTag(ALGORITHM_FORWARD, TAGS_ALGORITHM));
666    }
667   
668    super.setOptions(options);
669   
670    m_library.setDebug(m_Debug);
671  }
672 
673 
674  /**
675   * Gets the current settings of the Classifier.
676   *
677   * @return an array of strings suitable for passing to setOptions
678   */
679  public String[] getOptions() {
680    Vector        result;
681    String[]      options;
682    int           i;
683   
684    result  = new Vector();
685   
686    if (m_library.getModelListFile() != null) {
687      result.add("-L");
688      result.add("" + m_library.getModelListFile());
689    }
690   
691    if (!m_workingDirectory.equals("")) {
692      result.add("-W");
693      result.add("" + getWorkingDirectory());
694    }
695   
696    result.add("-P");
697    switch (getHillclimbMetric().getSelectedTag().getID()) {
698      case (EnsembleMetricHelper.METRIC_ACCURACY):
699        result.add("accuracy");
700      break;
701      case (EnsembleMetricHelper.METRIC_RMSE):
702        result.add("rmse");
703      break;
704      case (EnsembleMetricHelper.METRIC_ROC):
705        result.add("roc");
706      break;
707      case (EnsembleMetricHelper.METRIC_PRECISION):
708        result.add("precision");
709      break;
710      case (EnsembleMetricHelper.METRIC_RECALL):
711        result.add("recall");
712      break;
713      case (EnsembleMetricHelper.METRIC_FSCORE):
714        result.add("fscore");
715      break;
716      case (EnsembleMetricHelper.METRIC_ALL):
717        result.add("all");
718      break;
719    }
720   
721    result.add("-A");
722    switch (getAlgorithm().getSelectedTag().getID()) {
723      case (ALGORITHM_FORWARD):
724        result.add("forward");
725      break;
726      case (ALGORITHM_BACKWARD):
727        result.add("backward");
728      break;
729      case (ALGORITHM_FORWARD_BACKWARD):
730        result.add("both");
731      break;
732      case (ALGORITHM_BEST):
733        result.add("best");
734      break;
735      case (ALGORITHM_BUILD_LIBRARY):
736        result.add("library");
737      break;
738    }
739   
740    result.add("-B");
741    result.add("" + getNumModelBags());
742    result.add("-V");
743    result.add("" + getValidationRatio());
744    result.add("-E");
745    result.add("" + getModelRatio());
746    result.add("-H");
747    result.add("" + getHillclimbIterations());
748    result.add("-I");
749    result.add("" + getSortInitializationRatio());
750    result.add("-X");
751    result.add("" + getNumFolds());
752   
753    if (m_replacement)
754      result.add("-R");
755    if (m_greedySortInitialization)
756      result.add("-G");
757    if (m_verboseOutput)
758      result.add("-O");
759   
760    options = super.getOptions();
761    for (i = 0; i < options.length; i++)
762      result.add(options[i]);
763   
764    return (String[]) result.toArray(new String[result.size()]);
765  }
766 
767  /**
768   * Returns the tip text for this property
769   *
770   * @return tip text for this property suitable for displaying in the
771   *         explorer/experimenter gui
772   */
773  public String numFoldsTipText() {
774    return "The number of folds used for cross-validation.";
775  }
776 
777  /**
778   * Gets the number of folds for the cross-validation.
779   *
780   * @return the number of folds for the cross-validation
781   */
782  public int getNumFolds() {
783    return m_NumFolds;
784  }
785 
786  /**
787   * Sets the number of folds for the cross-validation.
788   *
789   * @param numFolds
790   *            the number of folds for the cross-validation
791   * @throws Exception
792   *                if parameter illegal
793   */
794  public void setNumFolds(int numFolds) throws Exception {
795    if (numFolds < 0) {
796      throw new IllegalArgumentException(
797          "EnsembleSelection: Number of cross-validation "
798          + "folds must be positive.");
799    }
800    m_NumFolds = numFolds;
801  }
802 
803  /**
804   * Returns the tip text for this property
805   *
806   * @return tip text for this property suitable for displaying in the
807   *         explorer/experimenter gui
808   */
809  public String libraryTipText() {
810    return "An ensemble library.";
811  }
812 
813  /**
814   * Gets the ensemble library.
815   *
816   * @return the ensemble library
817   */
818  public EnsembleSelectionLibrary getLibrary() {
819    return m_library;
820  }
821 
822  /**
823   * Sets the ensemble library.
824   *
825   * @param newLibrary
826   *            the ensemble library
827   */
828  public void setLibrary(EnsembleSelectionLibrary newLibrary) {
829    m_library = newLibrary;
830    m_library.setDebug(m_Debug);
831  }
832 
833  /**
834   * Returns the tip text for this property
835   *
836   * @return tip text for this property suitable for displaying in the
837   *         explorer/experimenter gui
838   */
839  public String modelRatioTipText() {
840    return "The ratio of library models that will be randomly chosen to be used for each iteration.";
841  }
842 
843  /**
844   * Get the value of modelRatio.
845   *
846   * @return Value of modelRatio.
847   */
848  public double getModelRatio() {
849    return m_modelRatio;
850  }
851 
852  /**
853   * Set the value of modelRatio.
854   *
855   * @param v
856   *            Value to assign to modelRatio.
857   */
858  public void setModelRatio(double v) {
859    m_modelRatio = v;
860  }
861 
862  /**
863   * Returns the tip text for this property
864   *
865   * @return tip text for this property suitable for displaying in the
866   *         explorer/experimenter gui
867   */
868  public String validationRatioTipText() {
869    return "The ratio of the training data set that will be reserved for validation.";
870  }
871 
872  /**
873   * Get the value of validationRatio.
874   *
875   * @return Value of validationRatio.
876   */
877  public double getValidationRatio() {
878    return m_validationRatio;
879  }
880 
881  /**
882   * Set the value of validationRatio.
883   *
884   * @param v
885   *            Value to assign to validationRatio.
886   */
887  public void setValidationRatio(double v) {
888    m_validationRatio = v;
889  }
890 
891  /**
892   * Returns the tip text for this property
893   *
894   * @return tip text for this property suitable for displaying in the
895   *         explorer/experimenter gui
896   */
897  public String hillclimbMetricTipText() {
898    return "the metric that will be used to optimizer the chosen ensemble..";
899  }
900 
901  /**
902   * Gets the hill climbing metric. Will be one of METRIC_ACCURACY,
903   * METRIC_RMSE, METRIC_ROC, METRIC_PRECISION, METRIC_RECALL, METRIC_FSCORE,
904   * METRIC_ALL
905   *
906   * @return the hillclimbMetric
907   */
908  public SelectedTag getHillclimbMetric() {
909    return new SelectedTag(m_hillclimbMetric, TAGS_METRIC);
910  }
911 
912  /**
913   * Sets the hill climbing metric. Will be one of METRIC_ACCURACY,
914   * METRIC_RMSE, METRIC_ROC, METRIC_PRECISION, METRIC_RECALL, METRIC_FSCORE,
915   * METRIC_ALL
916   *
917   * @param newType
918   *            the new hillclimbMetric
919   */
920  public void setHillclimbMetric(SelectedTag newType) {
921    if (newType.getTags() == TAGS_METRIC) {
922      m_hillclimbMetric = newType.getSelectedTag().getID();
923    }
924  }
925 
926  /**
927   * Returns the tip text for this property
928   *
929   * @return tip text for this property suitable for displaying in the
930   *         explorer/experimenter gui
931   */
932  public String algorithmTipText() {
933    return "the algorithm used to optimizer the ensemble";
934  }
935 
936  /**
937   * Gets the algorithm
938   *
939   * @return the algorithm
940   */
941  public SelectedTag getAlgorithm() {
942    return new SelectedTag(m_algorithm, TAGS_ALGORITHM);
943  }
944 
945  /**
946   * Sets the Algorithm to use
947   *
948   * @param newType
949   *            the new algorithm
950   */
951  public void setAlgorithm(SelectedTag newType) {
952    if (newType.getTags() == TAGS_ALGORITHM) {
953      m_algorithm = newType.getSelectedTag().getID();
954    }
955  }
956 
957  /**
958   * Returns the tip text for this property
959   *
960   * @return tip text for this property suitable for displaying in the
961   *         explorer/experimenter gui
962   */
963  public String hillclimbIterationsTipText() {
964    return "The number of hillclimbing iterations for the ensemble selection algorithm.";
965  }
966 
967  /**
968   * Gets the number of hillclimbIterations.
969   *
970   * @return the number of hillclimbIterations
971   */
972  public int getHillclimbIterations() {
973    return m_hillclimbIterations;
974  }
975 
976  /**
977   * Sets the number of hillclimbIterations.
978   *
979   * @param n
980   *            the number of hillclimbIterations
981   * @throws Exception
982   *                if parameter illegal
983   */
984  public void setHillclimbIterations(int n) throws Exception {
985    if (n < 0) {
986      throw new IllegalArgumentException(
987          "EnsembleSelection: Number of hillclimb iterations "
988          + "must be positive.");
989    }
990    m_hillclimbIterations = n;
991  }
992 
993  /**
994   * Returns the tip text for this property
995   *
996   * @return tip text for this property suitable for displaying in the
997   *         explorer/experimenter gui
998   */
999  public String numModelBagsTipText() {
1000    return "The number of \"model bags\" used in the ensemble selection algorithm.";
1001  }
1002 
1003  /**
1004   * Gets numModelBags.
1005   *
1006   * @return numModelBags
1007   */
1008  public int getNumModelBags() {
1009    return m_numModelBags;
1010  }
1011 
1012  /**
1013   * Sets numModelBags.
1014   *
1015   * @param n
1016   *            the new value for numModelBags
1017   * @throws Exception
1018   *                if parameter illegal
1019   */
1020  public void setNumModelBags(int n) throws Exception {
1021    if (n <= 0) {
1022      throw new IllegalArgumentException(
1023          "EnsembleSelection: Number of model bags "
1024          + "must be positive.");
1025    }
1026    m_numModelBags = n;
1027  }
1028 
1029  /**
1030   * Returns the tip text for this property
1031   *
1032   * @return tip text for this property suitable for displaying in the
1033   *         explorer/experimenter gui
1034   */
1035  public String sortInitializationRatioTipText() {
1036    return "The ratio of library models to be used for sort initialization.";
1037  }
1038 
1039  /**
1040   * Get the value of sortInitializationRatio.
1041   *
1042   * @return Value of sortInitializationRatio.
1043   */
1044  public double getSortInitializationRatio() {
1045    return m_sortInitializationRatio;
1046  }
1047 
1048  /**
1049   * Set the value of sortInitializationRatio.
1050   *
1051   * @param v
1052   *            Value to assign to sortInitializationRatio.
1053   */
1054  public void setSortInitializationRatio(double v) {
1055    m_sortInitializationRatio = v;
1056  }
1057 
1058  /**
1059   * Returns the tip text for this property
1060   *
1061   * @return tip text for this property suitable for displaying in the
1062   *         explorer/experimenter gui
1063   */
1064  public String replacementTipText() {
1065    return "Whether models in the library can be included more than once in an ensemble.";
1066  }
1067 
1068  /**
1069   * Get the value of replacement.
1070   *
1071   * @return Value of replacement.
1072   */
1073  public boolean getReplacement() {
1074    return m_replacement;
1075  }
1076 
1077  /**
1078   * Set the value of replacement.
1079   *
1080   * @param newReplacement
1081   *            Value to assign to replacement.
1082   */
1083  public void setReplacement(boolean newReplacement) {
1084    m_replacement = newReplacement;
1085  }
1086 
1087  /**
1088   * Returns the tip text for this property
1089   *
1090   * @return tip text for this property suitable for displaying in the
1091   *         explorer/experimenter gui
1092   */
1093  public String greedySortInitializationTipText() {
1094    return "Whether sort initialization greedily stops adding models when performance degrades.";
1095  }
1096 
1097  /**
1098   * Get the value of greedySortInitialization.
1099   *
1100   * @return Value of replacement.
1101   */
1102  public boolean getGreedySortInitialization() {
1103    return m_greedySortInitialization;
1104  }
1105 
1106  /**
1107   * Set the value of greedySortInitialization.
1108   *
1109   * @param newGreedySortInitialization
1110   *            Value to assign to replacement.
1111   */
1112  public void setGreedySortInitialization(boolean newGreedySortInitialization) {
1113    m_greedySortInitialization = newGreedySortInitialization;
1114  }
1115 
1116  /**
1117   * Returns the tip text for this property
1118   *
1119   * @return tip text for this property suitable for displaying in the
1120   *         explorer/experimenter gui
1121   */
1122  public String verboseOutputTipText() {
1123    return "Whether metrics are printed for each model.";
1124  }
1125 
1126  /**
1127   * Get the value of verboseOutput.
1128   *
1129   * @return Value of verboseOutput.
1130   */
1131  public boolean getVerboseOutput() {
1132    return m_verboseOutput;
1133  }
1134 
1135  /**
1136   * Set the value of verboseOutput.
1137   *
1138   * @param newVerboseOutput
1139   *            Value to assign to verboseOutput.
1140   */
1141  public void setVerboseOutput(boolean newVerboseOutput) {
1142    m_verboseOutput = newVerboseOutput;
1143  }
1144 
1145  /**
1146   * Returns the tip text for this property
1147   *
1148   * @return tip text for this property suitable for displaying in the
1149   *         explorer/experimenter gui
1150   */
1151  public String workingDirectoryTipText() {
1152    return "The working directory of the ensemble - where trained models will be stored.";
1153  }
1154 
1155  /**
1156   * Get the value of working directory.
1157   *
1158   * @return Value of working directory.
1159   */
1160  public File getWorkingDirectory() {
1161    return m_workingDirectory;
1162  }
1163 
1164  /**
1165   * Set the value of working directory.
1166   *
1167   * @param newWorkingDirectory directory Value.
1168   */
1169  public void setWorkingDirectory(File newWorkingDirectory) {
1170    if (m_Debug) {
1171      System.out.println("working directory changed to: "
1172          + newWorkingDirectory);
1173    }
1174    m_library.setWorkingDirectory(newWorkingDirectory);
1175   
1176    m_workingDirectory = newWorkingDirectory;
1177  }
1178 
1179  /**
1180   * Buildclassifier selects a classifier from the set of classifiers by
1181   * minimising error on the training data.
1182   *
1183   * @param trainData   the training data to be used for generating the boosted
1184   *                    classifier.
1185   * @throws Exception  if the classifier could not be built successfully
1186   */
1187  public void buildClassifier(Instances trainData) throws Exception {
1188   
1189    getCapabilities().testWithFail(trainData);
1190   
1191    // First we need to make sure that some library models
1192    // were specified. If not, then use the default list
1193    if (m_library.m_Models.size() == 0) {
1194     
1195      System.out
1196      .println("WARNING: No library file specified.  Using some default models.");
1197      System.out
1198      .println("You should specify a model list with -L <file> from the command line.");
1199      System.out
1200      .println("Or edit the list directly with the LibraryEditor from the GUI");
1201     
1202      for (int i = 0; i < 10; i++) {
1203       
1204        REPTree tree = new REPTree();
1205        tree.setSeed(i);
1206        m_library.addModel(new EnsembleSelectionLibraryModel(tree));
1207       
1208      }
1209     
1210    }
1211   
1212    if (m_library == null) {
1213      m_library = new EnsembleSelectionLibrary();
1214      m_library.setDebug(m_Debug);
1215    }
1216   
1217    m_library.setNumFolds(getNumFolds());
1218    m_library.setValidationRatio(getValidationRatio());
1219    // train all untrained models, and set "data" to the hillclimbing set.
1220    Instances data = m_library.trainAll(trainData, m_workingDirectory.getAbsolutePath(),
1221        m_algorithm);
1222    // We cache the hillclimb predictions from all of the models in
1223    // the library so that we can evaluate their performances when we
1224    // combine them
1225    // in various ways (without needing to keep the classifiers in memory).
1226    double predictions[][][] = m_library.getHillclimbPredictions();
1227    int numModels = predictions.length;
1228    int modelWeights[] = new int[numModels];
1229    m_total_weight = 0;
1230    Random rand = new Random(m_Seed);
1231   
1232    if (m_algorithm == ALGORITHM_BUILD_LIBRARY) {
1233      return;
1234     
1235    } else if (m_algorithm == ALGORITHM_BEST) {
1236      // If we want to choose the best model, just make a model bag that
1237      // includes all the models, then sort initialize to find the 1 that
1238      // performs best.
1239      ModelBag model_bag = new ModelBag(predictions, 1.0, m_Debug);
1240      int[] modelPicked = model_bag.sortInitialize(1, false, data,
1241          m_hillclimbMetric);
1242      // Then give it a weight of 1, while all others remain 0.
1243      modelWeights[modelPicked[0]] = 1;
1244    } else {
1245     
1246      if (m_Debug)
1247        System.out.println("Starting hillclimbing algorithm: "
1248            + m_algorithm);
1249     
1250      for (int i = 0; i < getNumModelBags(); ++i) {
1251        // For the number of bags,
1252        if (m_Debug)
1253          System.out.println("Starting on ensemble bag: " + i);
1254        // Create a new bag of the appropriate size
1255        ModelBag modelBag = new ModelBag(predictions, getModelRatio(),
1256            m_Debug);
1257        // And shuffle it.
1258        modelBag.shuffle(rand);
1259        if (getSortInitializationRatio() > 0.0) {
1260          // Sort initialize, if the ratio greater than 0.
1261          modelBag.sortInitialize((int) (getSortInitializationRatio()
1262              * getModelRatio() * numModels),
1263              getGreedySortInitialization(), data,
1264              m_hillclimbMetric);
1265        }
1266       
1267        if (m_algorithm == ALGORITHM_BACKWARD) {
1268          // If we're doing backwards elimination, we just give all
1269          // models
1270          // a weight of 1 initially. If the # of hillclimb iterations
1271          // is too high, we'll end up with just one model in the end
1272          // (we never delete all models from a bag). TODO - it might
1273          // be
1274          // smarter to base this weight off of how many models we
1275          // have.
1276          modelBag.weightAll(1); // for now at least, I'm just
1277          // assuming 1.
1278        }
1279        // Now the bag is initialized, and we're ready to hillclimb.
1280        for (int j = 0; j < getHillclimbIterations(); ++j) {
1281          if (m_algorithm == ALGORITHM_FORWARD) {
1282            modelBag.forwardSelect(getReplacement(), data,
1283                m_hillclimbMetric);
1284          } else if (m_algorithm == ALGORITHM_BACKWARD) {
1285            modelBag.backwardEliminate(data, m_hillclimbMetric);
1286          } else if (m_algorithm == ALGORITHM_FORWARD_BACKWARD) {
1287            modelBag.forwardSelectOrBackwardEliminate(
1288                getReplacement(), data, m_hillclimbMetric);
1289          }
1290        }
1291        // Now that we've done all the hillclimbing steps, we can just
1292        // get
1293        // the model weights that the bag determined, and add them to
1294        // our
1295        // running total.
1296        int[] bagWeights = modelBag.getModelWeights();
1297        for (int j = 0; j < bagWeights.length; ++j) {
1298          modelWeights[j] += bagWeights[j];
1299        }
1300      }
1301    }
1302    // Now we've done the hard work of actually learning the ensemble. Now
1303    // we set up the appropriate data structures so that Ensemble Selection
1304    // can
1305    // make predictions for future test examples.
1306    Set modelNames = m_library.getModelNames();
1307    String[] modelNamesArray = new String[m_library.size()];
1308    Iterator iter = modelNames.iterator();
1309    // libraryIndex indexes over all the models in the library (not just
1310    // those
1311    // which we chose for the ensemble).
1312    int libraryIndex = 0;
1313    // chosenModels will count the total number of models which were
1314    // selected
1315    // by EnsembleSelection (those that have non-zero weight).
1316    int chosenModels = 0;
1317    while (iter.hasNext()) {
1318      // Note that we have to be careful of order. Our model_weights array
1319      // is in the same order as our list of models in m_library.
1320     
1321      // Get the name of the model,
1322      modelNamesArray[libraryIndex] = (String) iter.next();
1323      // and its weight.
1324      int weightOfModel = modelWeights[libraryIndex++];
1325      m_total_weight += weightOfModel;
1326      if (weightOfModel > 0) {
1327        // If the model was chosen at least once, increment the
1328        // number of chosen models.
1329        ++chosenModels;
1330      }
1331    }
1332    if (m_verboseOutput) {
1333      // Output every model and its performance with respect to the
1334      // validation
1335      // data.
1336      ModelBag bag = new ModelBag(predictions, 1.0, m_Debug);
1337      int modelIndexes[] = bag.sortInitialize(modelNamesArray.length,
1338          false, data, m_hillclimbMetric);
1339      double modelPerformance[] = bag.getIndividualPerformance(data,
1340          m_hillclimbMetric);
1341      for (int i = 0; i < modelIndexes.length; ++i) {
1342        // TODO - Could do this in a more readable way.
1343        System.out.println("" + modelPerformance[i] + " "
1344            + modelNamesArray[modelIndexes[i]]);
1345      }
1346    }
1347    // We're now ready to build our array of the models which were chosen
1348    // and there associated weights.
1349    m_chosen_models = new EnsembleSelectionLibraryModel[chosenModels];
1350    m_chosen_model_weights = new int[chosenModels];
1351   
1352    libraryIndex = 0;
1353    // chosenIndex indexes over the models which were chosen by
1354    // EnsembleSelection
1355    // (those which have non-zero weight).
1356    int chosenIndex = 0;
1357    iter = m_library.getModels().iterator();
1358    while (iter.hasNext()) {
1359      int weightOfModel = modelWeights[libraryIndex++];
1360     
1361      EnsembleSelectionLibraryModel model = (EnsembleSelectionLibraryModel) iter
1362      .next();
1363     
1364      if (weightOfModel > 0) {
1365        // If the model was chosen at least once, add it to our array
1366        // of chosen models and weights.
1367        m_chosen_models[chosenIndex] = model;
1368        m_chosen_model_weights[chosenIndex] = weightOfModel;
1369        // Note that the EnsembleSelectionLibraryModel may not be
1370        // "loaded" -
1371        // that is, its classifier(s) may be null pointers. That's okay
1372        // -
1373        // we'll "rehydrate" them later, if and when we need to.
1374        ++chosenIndex;
1375      }
1376    }
1377  }
1378 
1379  /**
1380   * Calculates the class membership probabilities for the given test instance.
1381   *
1382   * @param instance the instance to be classified
1383   * @return predicted class probability distribution
1384   * @throws Exception if instance could not be classified
1385   * successfully
1386   */
1387  public double[] distributionForInstance(Instance instance) throws Exception {
1388    String stringInstance = instance.toString();
1389    double cachedPreds[][] = null;
1390   
1391    if (m_cachedPredictions != null) {
1392      // If we have any cached predictions (i.e., if cachePredictions was
1393      // called), look for a cached set of predictions for this instance.
1394      if (m_cachedPredictions.containsKey(stringInstance)) {
1395        cachedPreds = (double[][]) m_cachedPredictions.get(stringInstance);
1396      }
1397    }
1398    double[] prediction = new double[instance.numClasses()];
1399    for (int i = 0; i < prediction.length; ++i) {
1400      prediction[i] = 0.0;
1401    }
1402   
1403    // Now do a weighted average of the predictions of each of our models.
1404    for (int i = 0; i < m_chosen_models.length; ++i) {
1405      double[] predictionForThisModel = null;
1406      if (cachedPreds == null) {
1407        // If there are no predictions cached, we'll load the model's
1408        // classifier(s) in to memory and get the predictions.
1409        m_chosen_models[i].rehydrateModel(m_workingDirectory.getAbsolutePath());
1410        predictionForThisModel = m_chosen_models[i].getAveragePrediction(instance);
1411        // We could release the model here to save memory, but we assume
1412        // that there is enough available since we're not using the
1413        // prediction caching functionality. If we load and release a
1414        // model
1415        // every time we need to get a prediction for an instance, it
1416        // can be
1417        // prohibitively slow.
1418      } else {
1419        // If it's cached, just get it from the array of cached preds
1420        // for this instance.
1421        predictionForThisModel = cachedPreds[i];
1422      }
1423      // We have encountered a bug where MultilayerPerceptron returns a
1424      // null
1425      // prediction array. If that happens, we just don't count that model
1426      // in
1427      // our ensemble prediction.
1428      if (predictionForThisModel != null) {
1429        // Okay, the model returned a valid prediction array, so we'll
1430        // add the appropriate fraction of this model's prediction.
1431        for (int j = 0; j < prediction.length; ++j) {
1432          prediction[j] += m_chosen_model_weights[i] * predictionForThisModel[j] / m_total_weight;
1433        }
1434      }
1435    }
1436    // normalize to add up to 1.
1437    if (instance.classAttribute().isNominal()) {
1438      if (Utils.sum(prediction) > 0)
1439        Utils.normalize(prediction);
1440    }
1441    return prediction;
1442  }
1443 
1444  /**
1445   * This function tests whether or not a given path is appropriate for being
1446   * the working directory. Specifically, we care that we can write to the
1447   * path and that it doesn't point to a "non-directory" file handle.
1448   *
1449   * @param dir         the directory to test
1450   * @return            true if the directory is valid
1451   */
1452  private boolean validWorkingDirectory(String dir) {
1453   
1454    boolean valid = false;
1455   
1456    File f = new File((dir));
1457   
1458    if (f.exists()) {
1459      if (f.isDirectory() && f.canWrite())
1460        valid = true;
1461    } else {
1462      if (f.canWrite())
1463        valid = true;
1464    }
1465   
1466    return valid;
1467   
1468  }
1469 
1470  /**
1471   * This method tries to find a reasonable path name for the ensemble working
1472   * directory where models and files will be stored.
1473   *
1474   *
1475   * @return true if m_workingDirectory now has a valid file name
1476   */
1477  public static String getDefaultWorkingDirectory() {
1478   
1479    String defaultDirectory = new String("");
1480   
1481    boolean success = false;
1482   
1483    int i = 1;
1484   
1485    while (i < MAX_DEFAULT_DIRECTORIES && !success) {
1486     
1487      File f = new File(System.getProperty("user.home"), "Ensemble-" + i);
1488     
1489      if (!f.exists() && f.getParentFile().canWrite()) {
1490        defaultDirectory = f.getPath();
1491        success = true;
1492      }
1493      i++;
1494     
1495    }
1496   
1497    if (!success) {
1498      defaultDirectory = new String("");
1499      // should we print an error or something?
1500    }
1501   
1502    return defaultDirectory;
1503  }
1504 
1505  /**
1506   * Output a representation of this classifier
1507   *
1508   * @return    a string representation of the classifier
1509   */
1510  public String toString() {
1511    // We just print out the models which were selected, and the number
1512    // of times each was selected.
1513    String result = new String();
1514    if (m_chosen_models != null) {
1515      for (int i = 0; i < m_chosen_models.length; ++i) {
1516        result += m_chosen_model_weights[i];
1517        result += " " + m_chosen_models[i].getStringRepresentation()
1518        + "\n";
1519      }
1520    } else {
1521      result = "No models selected.";
1522    }
1523    return result;
1524  }
1525 
1526  /**
1527   * Cache predictions for the individual base classifiers in the ensemble
1528   * with respect to the given dataset. This is used so that when testing a
1529   * large ensemble on a test set, we don't have to keep the models in memory.
1530   *
1531   * @param test        The instances for which to cache predictions.
1532   * @throws Exception  if somethng goes wrong
1533   */
1534  private void cachePredictions(Instances test) throws Exception {
1535    m_cachedPredictions = new HashMap();
1536    Evaluation evalModel = null;
1537    Instances originalInstances = null;
1538    // If the verbose flag is set, we'll also print out the performances of
1539    // all the individual models w.r.t. this test set while we're at it.
1540    boolean printModelPerformances = getVerboseOutput();
1541    if (printModelPerformances) {
1542      // To get performances, we need to keep the class attribute.
1543      originalInstances = new Instances(test);
1544    }
1545   
1546    // For each model, we'll go through the dataset and get predictions.
1547    // The idea is we want to only have one model in memory at a time, so
1548    // we'll
1549    // load one model in to memory, get all its predictions, and add them to
1550    // the
1551    // hash map. Then we can release it from memory and move on to the next.
1552    for (int i = 0; i < m_chosen_models.length; ++i) {
1553      if (printModelPerformances) {
1554        // If we're going to print predictions, we need to make a new
1555        // Evaluation object.
1556        evalModel = new Evaluation(originalInstances);
1557      }
1558     
1559      Date startTime = new Date();
1560     
1561      // Load the model in to memory.
1562      m_chosen_models[i].rehydrateModel(m_workingDirectory.getAbsolutePath());
1563      // Now loop through all the instances and get the model's
1564      // predictions.
1565      for (int j = 0; j < test.numInstances(); ++j) {
1566        Instance currentInstance = test.instance(j);
1567        // When we're looking for a cached prediction later, we'll only
1568        // have the non-class attributes, so we set the class missing
1569        // here
1570        // in order to make the string match up properly.
1571        currentInstance.setClassMissing();
1572        String stringInstance = currentInstance.toString();
1573       
1574        // When we come in here with the first model, the instance will
1575        // not
1576        // yet be part of the map.
1577        if (!m_cachedPredictions.containsKey(stringInstance)) {
1578          // The instance isn't in the map yet, so add it.
1579          // For each instance, we store a two-dimensional array - the
1580          // first
1581          // index is over all the models in the ensemble, and the
1582          // second
1583          // index is over the (i.e., typical prediction array).
1584          int predSize = test.classAttribute().isNumeric() ? 1 : test
1585              .classAttribute().numValues();
1586          double predictionArray[][] = new double[m_chosen_models.length][predSize];
1587          m_cachedPredictions.put(stringInstance, predictionArray);
1588        }
1589        // Get the array from the map which is associated with this
1590        // instance
1591        double predictions[][] = (double[][]) m_cachedPredictions
1592        .get(stringInstance);
1593        // And add our model's prediction for it.
1594        predictions[i] = m_chosen_models[i].getAveragePrediction(test
1595            .instance(j));
1596       
1597        if (printModelPerformances) {
1598          evalModel.evaluateModelOnceAndRecordPrediction(
1599              predictions[i], originalInstances.instance(j));
1600        }
1601      }
1602      // Now we're done with model #i, so we can release it.
1603      m_chosen_models[i].releaseModel();
1604     
1605      Date endTime = new Date();
1606      long diff = endTime.getTime() - startTime.getTime();
1607     
1608      if (m_Debug)
1609        System.out.println("Test time for "
1610            + m_chosen_models[i].getStringRepresentation()
1611            + " was: " + diff);
1612     
1613      if (printModelPerformances) {
1614        String output = new String(m_chosen_models[i]
1615                                                   .getStringRepresentation()
1616                                                   + ": ");
1617        output += "\tRMSE:" + evalModel.rootMeanSquaredError();
1618        output += "\tACC:" + evalModel.pctCorrect();
1619        if (test.numClasses() == 2) {
1620          // For multiclass problems, we could print these too, but
1621          // it's
1622          // not clear which class we should use in that case... so
1623          // instead
1624          // we only print these metrics for binary classification
1625          // problems.
1626          output += "\tROC:" + evalModel.areaUnderROC(1);
1627          output += "\tPREC:" + evalModel.precision(1);
1628          output += "\tFSCR:" + evalModel.fMeasure(1);
1629        }
1630        System.out.println(output);
1631      }
1632    }
1633  }
1634 
1635  /**
1636   * Return the technical information.  There is actually another
1637   * paper that describes our current method of CV for this classifier
1638   * TODO: Cite Technical report when published
1639   *
1640   * @return the technical information about this class
1641   */
1642  public TechnicalInformation getTechnicalInformation() {
1643   
1644    TechnicalInformation result;
1645   
1646    result = new TechnicalInformation(Type.INPROCEEDINGS);
1647    result.setValue(Field.AUTHOR, "Rich Caruana, Alex Niculescu, Geoff Crew, and Alex Ksikes");
1648    result.setValue(Field.TITLE, "Ensemble Selection from Libraries of Models");
1649    result.setValue(Field.BOOKTITLE, "21st International Conference on Machine Learning");
1650    result.setValue(Field.YEAR, "2004");
1651   
1652    return result;
1653  }
1654 
1655  /**
1656   * Returns the revision string.
1657   *
1658   * @return            the revision
1659   */
1660  public String getRevision() {
1661    return RevisionUtils.extract("$Revision: 5480 $");
1662  }
1663 
1664  /**
1665   * Executes the classifier from commandline.
1666   *
1667   * @param argv
1668   *            should contain the following arguments: -t training file [-T
1669   *            test file] [-c class index]
1670   */
1671  public static void main(String[] argv) {
1672   
1673    try {
1674     
1675      String options[] = (String[]) argv.clone();
1676     
1677      // do we get the input from XML instead of normal parameters?
1678      String xml = Utils.getOption("xml", options);
1679      if (!xml.equals(""))
1680        options = new XMLOptions(xml).toArray();
1681     
1682      String trainFileName = Utils.getOption('t', options);
1683      String objectInputFileName = Utils.getOption('l', options);
1684      String testFileName = Utils.getOption('T', options);
1685     
1686      if (testFileName.length() != 0 && objectInputFileName.length() != 0
1687          && trainFileName.length() == 0) {
1688       
1689        System.out.println("Caching predictions");
1690       
1691        EnsembleSelection classifier = null;
1692       
1693        BufferedReader testReader = new BufferedReader(new FileReader(
1694            testFileName));
1695       
1696        // Set up the Instances Object
1697        Instances test;
1698        int classIndex = -1;
1699        String classIndexString = Utils.getOption('c', options);
1700        if (classIndexString.length() != 0) {
1701          classIndex = Integer.parseInt(classIndexString);
1702        }
1703       
1704        test = new Instances(testReader, 1);
1705        if (classIndex != -1) {
1706          test.setClassIndex(classIndex - 1);
1707        } else {
1708          test.setClassIndex(test.numAttributes() - 1);
1709        }
1710        if (classIndex > test.numAttributes()) {
1711          throw new Exception("Index of class attribute too large.");
1712        }
1713       
1714        while (test.readInstance(testReader)) {
1715         
1716        }
1717        testReader.close();
1718       
1719        // Now yoink the EnsembleSelection Object from the fileSystem
1720       
1721        InputStream is = new FileInputStream(objectInputFileName);
1722        if (objectInputFileName.endsWith(".gz")) {
1723          is = new GZIPInputStream(is);
1724        }
1725       
1726        // load from KOML?
1727        if (!(objectInputFileName.endsWith("UpdateableClassifier.koml") && KOML
1728            .isPresent())) {
1729          ObjectInputStream objectInputStream = new ObjectInputStream(
1730              is);
1731          classifier = (EnsembleSelection) objectInputStream
1732          .readObject();
1733          objectInputStream.close();
1734        } else {
1735          BufferedInputStream xmlInputStream = new BufferedInputStream(
1736              is);
1737          classifier = (EnsembleSelection) KOML.read(xmlInputStream);
1738          xmlInputStream.close();
1739        }
1740       
1741        String workingDir = Utils.getOption('W', argv);
1742        if (!workingDir.equals("")) {
1743          classifier.setWorkingDirectory(new File(workingDir));
1744        }
1745       
1746        classifier.setDebug(Utils.getFlag('D', argv));
1747        classifier.setVerboseOutput(Utils.getFlag('O', argv));
1748       
1749        classifier.cachePredictions(test);
1750       
1751        // Now we write the model back out to the file system.
1752        String objectOutputFileName = objectInputFileName;
1753        OutputStream os = new FileOutputStream(objectOutputFileName);
1754        // binary
1755        if (!(objectOutputFileName.endsWith(".xml") || (objectOutputFileName
1756            .endsWith(".koml") && KOML.isPresent()))) {
1757          if (objectOutputFileName.endsWith(".gz")) {
1758            os = new GZIPOutputStream(os);
1759          }
1760          ObjectOutputStream objectOutputStream = new ObjectOutputStream(
1761              os);
1762          objectOutputStream.writeObject(classifier);
1763          objectOutputStream.flush();
1764          objectOutputStream.close();
1765        }
1766        // KOML/XML
1767        else {
1768          BufferedOutputStream xmlOutputStream = new BufferedOutputStream(
1769              os);
1770          if (objectOutputFileName.endsWith(".xml")) {
1771            XMLSerialization xmlSerial = new XMLClassifier();
1772            xmlSerial.write(xmlOutputStream, classifier);
1773          } else
1774            // whether KOML is present has already been checked
1775            // if not present -> ".koml" is interpreted as binary - see
1776            // above
1777            if (objectOutputFileName.endsWith(".koml")) {
1778              KOML.write(xmlOutputStream, classifier);
1779            }
1780          xmlOutputStream.close();
1781        }
1782       
1783      }
1784     
1785      System.out.println(Evaluation.evaluateModel(
1786          new EnsembleSelection(), argv));
1787     
1788    } catch (Exception e) {
1789      if (    (e.getMessage() != null)
1790           && (e.getMessage().indexOf("General options") == -1) )
1791        e.printStackTrace();
1792      else
1793        System.err.println(e.getMessage());
1794    }
1795  }
1796}
Note: See TracBrowser for help on using the repository browser.