source: src/main/java/weka/classifiers/meta/ensembleSelection/EnsembleSelectionLibrary.java @ 11

Last change on this file since 11 was 4, checked in by gnappo, 14 years ago

Import of weka.

File size: 19.1 KB
/*
 *    This program is free software; you can redistribute it and/or modify
 *    it under the terms of the GNU General Public License as published by
 *    the Free Software Foundation; either version 2 of the License, or
 *    (at your option) any later version.
 *
 *    This program is distributed in the hope that it will be useful,
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *    GNU General Public License for more details.
 *
 *    You should have received a copy of the GNU General Public License
 *    along with this program; if not, write to the Free Software
 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 */

/*
 *    EnsembleSelectionLibrary.java
 *    Copyright (C) 2006 Robert Jung
 *
 */

package weka.classifiers.meta.ensembleSelection;

import weka.classifiers.Classifier;
import weka.classifiers.AbstractClassifier;
import weka.classifiers.EnsembleLibrary;
import weka.classifiers.EnsembleLibraryModel;
import weka.classifiers.meta.EnsembleSelection;
import weka.core.Instances;
import weka.core.RevisionUtils;

import java.beans.PropertyChangeListener;
import java.beans.PropertyChangeSupport;
import java.io.File;
import java.io.FileWriter;
import java.io.InputStream;
import java.io.Serializable;
import java.io.UnsupportedEncodingException;
import java.text.DateFormat;
import java.text.SimpleDateFormat;
import java.util.Date;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Set;
import java.util.TreeSet;
import java.util.zip.Adler32;

/**
 * This class represents an ensemble library, that is, a
 * collection of models that will be combined via the
 * ensemble selection algorithm.  This class is responsible for
 * tracking all of the unique model specifications in the current
 * library and training them when asked.  There are also methods
 * to save/load library model list files.
 *
 * @author  Robert Jung
 * @author  David Michael
 * @version $Revision: 5928 $
 */
public class EnsembleSelectionLibrary
  extends EnsembleLibrary
  implements Serializable {

  /** for serialization */
  private static final long serialVersionUID = -6444026512552917835L;

  /** the working ensemble library directory. */
  private File m_workingDirectory;

  /** the name of the model list file storing the list of
   * models currently being used by the model library */
  private String m_modelListFile = null;

  /** the training data used to build the library.  One per fold. */
  private Instances[] m_trainingData;

  /** the test data used for hillclimbing.  One per fold. */
  private Instances[] m_hillclimbData;

  /** the predictions of each model.  Built by trainAll.  First index is
   * for the model.  Second is for the instance.  Third is for the class
   * (we use distributionForInstance).
   */
  private double[][][] m_predictions;

  /** the random seed used to partition the training data into
   * validation and training folds */
  private int m_seed;

  /** the number of folds */
  private int m_folds;

  /** the ratio of validation data used to train the model */
  private double m_validationRatio;

  /** A helper class for notifying listeners when the working directory changes */
  private transient PropertyChangeSupport m_workingDirectoryPropertySupport = new PropertyChangeSupport(this);

  /** Whether we should print debug messages. */
  public transient boolean m_Debug = true;

  /**
   * Creates a default library, associated with the default
   * working directory.
   */
  public EnsembleSelectionLibrary() {
    super();

    m_workingDirectory = new File(EnsembleSelection.getDefaultWorkingDirectory());
  }

  /**
   * Creates a library associated with the given working directory.
   *
   * @param dir                 the working directory for the ensemble library
   * @param seed                the seed value
   * @param folds               the number of folds
   * @param validationRatio     the ratio to use
   */
  public EnsembleSelectionLibrary(String dir, int seed,
      int folds, double validationRatio) {

    super();

    if (dir != null)
      m_workingDirectory = new File(dir);
    m_seed = seed;
    m_folds = folds;
    m_validationRatio = validationRatio;
  }

  /**
   * This constructor will create a library from a model
   * list file given by the file name argument.
   *
   * @param libraryFileName     the library filename
   */
  public EnsembleSelectionLibrary(String libraryFileName) {
    super();

    File libraryFile = new File(libraryFileName);
    try {
      EnsembleLibrary.loadLibrary(libraryFile, this);
    } catch (Exception e) {
      System.err.println("Could not load specified library file: "+libraryFileName);
    }
  }

  /**
   * This constructor will create a library from the given XML stream.
   *
   * @param stream      the XML library stream
   */
  public EnsembleSelectionLibrary(InputStream stream) {
    super();

    try {
      EnsembleLibrary.loadLibrary(stream, this);
    }
    catch (Exception e) {
      System.err.println("Could not load library from XML stream: " + e);
    }
  }
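
  /*
   * A minimal construction sketch (illustrative only, not part of the
   * original source; the directory, seed, fold count, ratio and file name
   * below are made-up values):
   *
   *   EnsembleSelectionLibrary library =
   *       new EnsembleSelectionLibrary("/tmp/es_working_dir", 1, 5, 0.25);
   *
   *   // or restore a previously saved model list file:
   *   EnsembleSelectionLibrary restored =
   *       new EnsembleSelectionLibrary("/tmp/es_working_dir/models.model.xml");
   */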

  /**
   * Set debug flag for the library and all its models.  The debug flag
   * determines whether we print debugging information to stdout.
   *
   * @param debug       if true debug mode is on
   */
  public void setDebug(boolean debug) {
    m_Debug = debug;

    Iterator it = getModels().iterator();
    while (it.hasNext()) {
      ((EnsembleSelectionLibraryModel) it.next()).setDebug(m_Debug);
    }
  }

  /**
   * Sets the validation-set ratio.  This is the portion of the
   * training set that is set aside for hillclimbing.  Note that
   * this value is ignored if we are doing cross-validation
   * (indicated by the number of folds being > 1).
   *
   * @param validationRatio     the new ratio
   */
  public void setValidationRatio(double validationRatio) {
    m_validationRatio = validationRatio;
  }

  /**
   * Set the number of folds for cross validation.  If the number
   * of folds is > 1, the validation ratio is ignored.
   *
   * @param numFolds            the number of folds to use
   */
  public void setNumFolds(int numFolds) {
    m_folds = numFolds;
  }
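
  /*
   * Configuration sketch (arbitrary values): setNumFolds(5) enables the
   * embedded cross-validation mode, in which the validation ratio is
   * ignored; setNumFolds(1) together with setValidationRatio(0.25) instead
   * holds out a quarter of the training data for hillclimbing.
   */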

  /**
   * This method will iterate through the library's collection of models and
   * train all models that do not currently exist (are not
   * yet trained).
   * <p/>
   * Returns the data set which should be used for hillclimbing.
   * <p/>
   * If training a model fails then an error will
   * be sent to stdout and that model will be removed from the
   * library.   FIXME Should we maybe raise an exception instead?
   *
   * @param data        the data to work on
   * @param directory   the working directory
   * @param algorithm   the type of algorithm
   * @return            the data that should be used for hillclimbing
   * @throws Exception  if something goes wrong
   */
  public Instances trainAll(Instances data, String directory, int algorithm) throws Exception {

    createWorkingDirectory(directory);

    // create the data directory if it doesn't already exist
    String dataDirectoryName = getDataDirectoryName(data);
    File dataDirectory = new File(directory, dataDirectoryName);

    if (!dataDirectory.exists()) {
      dataDirectory.mkdirs();
    }

    // Now create a record of all the models trained.  This will be a .mlf
    // flat file with a file name based on the time/date of training
    //DateFormat formatter = new SimpleDateFormat("yyyy.MM.dd.HH.mm");
    //String dateString = formatter.format(new Date());

    // Go ahead and save in both formats just in case:
    DateFormat formatter = new SimpleDateFormat("yyyy.MM.dd.HH.mm");
    String modelListFileName = formatter.format(new Date())+"_"+size()+"_models.mlf";
    //String modelListFileName = dataDirectory.getName()+".mlf";
    File modelListFile = new File(dataDirectory.getPath(), modelListFileName);
    EnsembleLibrary.saveLibrary(modelListFile, this, null);

    //modelListFileName = dataDirectory.getName()+".model.xml";
    modelListFileName = formatter.format(new Date())+"_"+size()+"_models.model.xml";
    modelListFile = new File(dataDirectory.getPath(), modelListFileName);
    EnsembleLibrary.saveLibrary(modelListFile, this, null);

    // log the instances used just in case we need to know...
    String arf = data.toString();
    FileWriter f = new FileWriter(new File(dataDirectory.getPath(), dataDirectory.getName()+".arff"));
    f.write(arf);
    f.close();

    // m_trainingData will contain the datasets used for training models for each fold.
    m_trainingData = new Instances[m_folds];
    // m_hillclimbData will contain the dataset which we will use for hillclimbing -
    // m_hillclimbData[i] should be disjoint from m_trainingData[i].
    m_hillclimbData = new Instances[m_folds];
    // validationSet is all of the hillclimbing data from all folds, in the same
    // order as it is in m_hillclimbData
    Instances validationSet;
    if (m_folds > 1) {
      validationSet = new Instances(data, data.numInstances());  // make a new set
      // with the same capacity and header as data.
      // instances may come from CV functions in
      // different order, so we'll make sure the
      // validation set's order matches that of
      // the concatenated testCV sets
      for (int i = 0; i < m_folds; ++i) {
        m_trainingData[i] = data.trainCV(m_folds, i);
        m_hillclimbData[i] = data.testCV(m_folds, i);
      }
      // If we're doing "embedded CV" we can hillclimb on
      // the entire training set, so we just put all of the hillclimbData
      // from all folds in to validationSet (making sure it's in the appropriate
      // order).
      for (int i = 0; i < m_folds; ++i) {
        for (int j = 0; j < m_hillclimbData[i].numInstances(); ++j) {
          validationSet.add(m_hillclimbData[i].instance(j));
        }
      }
    }
    else {
      // Otherwise, we're not doing CV, we're just using a validation set.
      // Partition the data set in to a training set and a hillclimb set
      // based on the m_validationRatio.
      int validation_size = (int)(data.numInstances() * m_validationRatio);
      m_trainingData[0] = new Instances(data, 0, data.numInstances() - validation_size);
      m_hillclimbData[0] = new Instances(data, data.numInstances() - validation_size, validation_size);
      validationSet = m_hillclimbData[0];
    }

    // Now we have all the data chopped up appropriately, and we can train all models
    Iterator it = m_Models.iterator();
    int model_index = 0;
    m_predictions = new double[m_Models.size()][validationSet.numInstances()][data.numClasses()];

    // We'll keep a set of all the models which fail so that we can remove them from
    // our library.
    Set invalidModels = new HashSet();

    while (it.hasNext()) {
      // For each model,
      EnsembleSelectionLibraryModel model = (EnsembleSelectionLibraryModel) it.next();

      // set the appropriate options
      model.setDebug(m_Debug);
      model.setFolds(m_folds);
      model.setSeed(m_seed);
      model.setValidationRatio(m_validationRatio);
      model.setChecksum(getInstancesChecksum(data));

      try {
        // Create the model.  This will attempt to load the model, if it
        // already exists.  If it does not, it will train the model using
        // m_trainingData and cache the model's predictions for
        // m_hillclimbData.
        model.createModel(m_trainingData, m_hillclimbData, dataDirectory.getPath(), algorithm);
      } catch (Exception e) {
        // If the model failed, print a message and add it to our set of
        // invalid models.
        System.out.println("**Couldn't create model "+model.getStringRepresentation()
            +" because of following exception: "+e.getMessage());

        invalidModels.add(model);
        continue;
      }

      if (!invalidModels.contains(model)) {
        // If the model succeeded, add its predictions to our array
        // of predictions.  Note that the successful models' predictions
        // are packed in to the front of m_predictions.
        m_predictions[model_index] = model.getValidationPredictions();
        ++model_index;
        // We no longer need it in memory, so release it.
        model.releaseModel();
      }
    }

    // Remove all invalidModels from m_Models.
    it = invalidModels.iterator();
    while (it.hasNext()) {
      EnsembleSelectionLibraryModel model = (EnsembleSelectionLibraryModel) it.next();
      if (m_Debug) System.out.println("removing invalid library model: "+model.getStringRepresentation());
      m_Models.remove(model);
    }

    if (m_Debug) System.out.println("model index: "+model_index+" tree set size: "+m_Models.size());

    if (invalidModels.size() > 0) {
      // If we had any invalid models, we have some bad predictions in the back
      // of m_predictions, so we'll shrink it to the right size.
      double tmpPredictions[][][] = new double[m_Models.size()][][];

      for (int i = 0; i < m_Models.size(); i++) {
        tmpPredictions[i] = m_predictions[i];
      }
      m_predictions = tmpPredictions;
    }

    if (m_Debug) System.out.println("Finished remapping models");

    return validationSet;       // Give the appropriate "hillclimb" set back to ensemble
                                // selection.
  }
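
  /*
   * Hedged usage sketch for trainAll ("data" is a class-labelled Instances
   * object supplied by the caller, and ALGORITHM_FORWARD is assumed to be
   * one of the hillclimb algorithm codes defined by
   * weka.classifiers.meta.EnsembleSelection):
   *
   *   Instances hillclimbSet = library.trainAll(data,
   *       library.getWorkingDirectory().getPath(),
   *       EnsembleSelection.ALGORITHM_FORWARD);
   *   double[][][] predictions = library.getHillclimbPredictions();
   */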

  /**
   * Creates the working directory associated with this library.
   *
   * @param dirName     the new directory
   */
  public void createWorkingDirectory(String dirName) {
    File directory = new File(dirName);

    if (!directory.exists())
      directory.mkdirs();
  }

  /**
   * This will remove the model associated with the given String
   * from the model library.
   *
   * @param modelKey    the key of the model
   */
  public void removeModel(String modelKey) {
    m_Models.remove(modelKey);  //TODO - is this really all there is to it??
  }

  /**
   * This method will return a Set object containing all the
   * String representations of the models.  The iterator across
   * this Set object will return the model names in alphabetical
   * order.
   *
   * @return            all model representations
   */
  public Set getModelNames() {
    Set names = new TreeSet();

    Iterator it = m_Models.iterator();

    while (it.hasNext()) {
      names.add(((EnsembleLibraryModel) it.next()).getStringRepresentation());
    }

    return names;
  }

  /**
   * This method will get the predictions for all the models in the
   * ensemble library.  If cross validation is used, then predictions
   * will be returned for the entire training set.  If cross validation
   * is not used, then predictions will only be returned for the ratio
   * of the training set reserved for validation.
   *
   * @return            the predictions
   */
  public double[][][] getHillclimbPredictions() {
    return m_predictions;
  }

  /**
   * Gets the working directory of the ensemble library.
   *
   * @return the working directory.
   */
  public File getWorkingDirectory() {
    return m_workingDirectory;
  }

  /**
   * Sets the working directory of the ensemble library.
   *
   * @param workingDirectory    the working directory to use.
   */
  public void setWorkingDirectory(File workingDirectory) {
    m_workingDirectory = workingDirectory;
    if (m_workingDirectoryPropertySupport != null) {
      m_workingDirectoryPropertySupport.firePropertyChange(null, null, null);
    }
  }

  /**
   * Gets the model list file that holds the list of models
   * in the ensemble library.
   *
   * @return the model list file.
   */
  public String getModelListFile() {
    return m_modelListFile;
  }

  /**
   * Sets the model list file that holds the list of models
   * in the ensemble library.
   *
   * @param modelListFile       the model list file to use
   */
  public void setModelListFile(String modelListFile) {
    m_modelListFile = modelListFile;
  }

  /**
   * Creates a library model from the given classifier.
   *
   * @param classifier  the classifier to use
   * @return            the generated library model
   */
  public EnsembleLibraryModel createModel(Classifier classifier) {
    EnsembleSelectionLibraryModel model = new EnsembleSelectionLibraryModel(classifier);
    model.setDebug(m_Debug);

    return model;
  }

  /**
   * This method takes a String argument defining a classifier and
   * uses it to create a base Classifier.
   *
   * WARNING! This method is only called when trying to create models
   * from flat files (.mlf).  This method is highly untested and
   * foreseeably will cause problems when trying to nest arguments
   * within multiple meta classifiers.  To avoid any problems we
   * recommend using only XML serialization, via saving to
   * .model.xml and using only the createModel(Classifier) method
   * above.
   *
   * @param modelString         the classifier definition
   * @return                    the generated library model
   */
  public EnsembleLibraryModel createModel(String modelString) {

    String[] splitString = modelString.split("\\s+");
    String className = splitString[0];

    String argString = modelString.replaceAll(splitString[0], "");
    String[] optionStrings = argString.split("\\s+");

    EnsembleSelectionLibraryModel model = null;
    try {
      model = new EnsembleSelectionLibraryModel(AbstractClassifier.forName(className, optionStrings));
      model.setDebug(m_Debug);

    } catch (Exception e) {
      e.printStackTrace();
    }

    return model;
  }
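
  /*
   * Example of the kind of flat-file (.mlf) entry this method parses (a
   * hypothetical line; any classifier name and option string accepted by
   * AbstractClassifier.forName would do):
   *
   *   EnsembleLibraryModel model =
   *       library.createModel("weka.classifiers.trees.J48 -C 0.25 -M 2");
   */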

  /**
   * This method takes an Instances object and returns a checksum of its
   * toString method - that is, the checksum of the .arff file that would
   * be created if the Instances object were written out as an arff file
   * in the file system.
   *
   * @param instances   the data to get the checksum for
   * @return            the checksum
   */
  public static String getInstancesChecksum(Instances instances) {

    String checksumString = null;

    try {

      Adler32 checkSummer = new Adler32();

      byte[] utf8 = instances.toString().getBytes("UTF8");

      checkSummer.update(utf8);
      checksumString = Long.toHexString(checkSummer.getValue());

    } catch (UnsupportedEncodingException e) {
      e.printStackTrace();
    }

    return checksumString;
  }
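
  /*
   * Usage note (the hex value shown is made up): trainAll passes this
   * checksum to each model via setChecksum so that cached models can be
   * matched against the exact training data they were built from.
   *
   *   String checksum = getInstancesChecksum(data);   // e.g. "1f2a3b4c"
   */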

  /**
   * Returns the unique name for the set of instances supplied.  This is
   * used to create a directory for all of the models corresponding to that
   * set of instances.  This was intended as a way to keep working directories
   * "organized".
   *
   * @param instances   the data to get the directory for
   * @return            the directory
   */
  public static String getDataDirectoryName(Instances instances) {

    String directory = instances.numInstances()
        + "_instances_" + getInstancesChecksum(instances);

    //System.out.println("generated directory name: "+directory);

    return directory;
  }
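
  /*
   * This yields a name of the form "<numInstances>_instances_<adler32-hex>",
   * e.g. "150_instances_1f2a3b4c" for a 150-instance dataset (the hex digest
   * here is made up for illustration).
   */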

  /**
   * Adds an object to the list of those that wish to be informed when the
   * working directory changes.
   *
   * @param listener a new listener to add to the list
   */
  public void addWorkingDirectoryListener(PropertyChangeListener listener) {

    if (m_workingDirectoryPropertySupport != null) {
      m_workingDirectoryPropertySupport.addPropertyChangeListener(listener);
    }
  }
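
  /*
   * Listener registration sketch (assumes a caller, e.g. a GUI component,
   * that wants to react when setWorkingDirectory is called):
   *
   *   library.addWorkingDirectoryListener(new java.beans.PropertyChangeListener() {
   *     public void propertyChange(java.beans.PropertyChangeEvent evt) {
   *       // refresh anything that depends on the working directory
   *     }
   *   });
   */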

  /**
   * Returns the revision string.
   *
   * @return            the revision
   */
  public String getRevision() {
    return RevisionUtils.extract("$Revision: 5928 $");
  }
}