source: src/main/java/weka/classifiers/meta/RacedIncrementalLogitBoost.java @ 9

Last change on this file since 9 was 4, checked in by gnappo, 14 years ago

Import di weka.

File size: 34.7 KB
Line 
1/*
2 *    This program is free software; you can redistribute it and/or modify
3 *    it under the terms of the GNU General Public License as published by
4 *    the Free Software Foundation; either version 2 of the License, or
5 *    (at your option) any later version.
6 *
7 *    This program is distributed in the hope that it will be useful,
8 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
9 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
10 *    GNU General Public License for more details.
11 *
12 *    You should have received a copy of the GNU General Public License
13 *    along with this program; if not, write to the Free Software
14 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
15 */
16
17/*
18 *    RacedIncrementalLogitBoost.java
19 *    Copyright (C) 2002 University of Waikato, Hamilton, New Zealand
20 *
21 */
22
23package weka.classifiers.meta;
24
25import weka.classifiers.Classifier;
26import weka.classifiers.AbstractClassifier;
27import weka.classifiers.RandomizableSingleClassifierEnhancer;
28import weka.classifiers.UpdateableClassifier;
29import weka.classifiers.rules.ZeroR;
30import weka.core.Attribute;
31import weka.core.Capabilities;
32import weka.core.FastVector;
33import weka.core.Instance;
34import weka.core.Instances;
35import weka.core.Option;
36import weka.core.RevisionHandler;
37import weka.core.RevisionUtils;
38import weka.core.SelectedTag;
39import weka.core.Tag;
40import weka.core.Utils;
41import weka.core.WeightedInstancesHandler;
42import weka.core.Capabilities.Capability;
43
44import java.io.Serializable;
45import java.util.Enumeration;
46import java.util.Random;
47import java.util.Vector;
48
49/**
50 <!-- globalinfo-start -->
51 * Classifier for incremental learning of large datasets by way of racing logit-boosted committees.
52 * <p/>
53 <!-- globalinfo-end -->
54 *
55 <!-- options-start -->
56 * Valid options are: <p/>
57 *
58 * <pre> -C &lt;num&gt;
59 *  Minimum size of chunks.
60 *  (default 500)</pre>
61 *
62 * <pre> -M &lt;num&gt;
63 *  Maximum size of chunks.
64 *  (default 2000)</pre>
65 *
66 * <pre> -V &lt;num&gt;
67 *  Size of validation set.
68 *  (default 1000)</pre>
69 *
70 * <pre> -P &lt;pruning type&gt;
71 *  Committee pruning to perform.
72 *  0=none, 1=log likelihood (default)</pre>
73 *
74 * <pre> -Q
75 *  Use resampling for boosting.</pre>
76 *
77 * <pre> -S &lt;num&gt;
78 *  Random number seed.
79 *  (default 1)</pre>
80 *
81 * <pre> -D
82 *  If set, classifier is run in debug mode and
83 *  may output additional info to the console</pre>
84 *
85 * <pre> -W
86 *  Full name of base classifier.
87 *  (default: weka.classifiers.trees.DecisionStump)</pre>
88 *
89 * <pre>
90 * Options specific to classifier weka.classifiers.trees.DecisionStump:
91 * </pre>
92 *
93 * <pre> -D
94 *  If set, classifier is run in debug mode and
95 *  may output additional info to the console</pre>
96 *
97 <!-- options-end -->
98 *
99 * Options after -- are passed to the designated learner.<p>
100 *
101 * @author Richard Kirkby (rkirkby@cs.waikato.ac.nz)
102 * @author Eibe Frank (eibe@cs.waikato.ac.nz)
103 * @version $Revision: 5987 $
104 */
105public class RacedIncrementalLogitBoost 
106  extends RandomizableSingleClassifierEnhancer
107  implements UpdateableClassifier {
108 
109  /** for serialization */
110  static final long serialVersionUID = 908598343772170052L;
111
112  /** no pruning */
113  public static final int PRUNETYPE_NONE = 0;
114  /** log likelihood pruning */
115  public static final int PRUNETYPE_LOGLIKELIHOOD = 1;
116  /** The pruning types */
117  public static final Tag [] TAGS_PRUNETYPE = {
118    new Tag(PRUNETYPE_NONE, "No pruning"),
119    new Tag(PRUNETYPE_LOGLIKELIHOOD, "Log likelihood pruning")
120  };
121
122  /** The committees */   
123  protected FastVector m_committees;
124
125  /** The pruning type used */
126  protected int m_PruningType = PRUNETYPE_LOGLIKELIHOOD;
127
128  /** Whether to use resampling */
129  protected boolean m_UseResampling = false;
130
131  /** The number of classes */
132  protected int m_NumClasses;
133
134  /** A threshold for responses (Friedman suggests between 2 and 4) */
135  protected static final double Z_MAX = 4;
136
137  /** Dummy dataset with a numeric class */
138  protected Instances m_NumericClassData;
139
140  /** The actual class attribute (for getting class names) */
141  protected Attribute m_ClassAttribute; 
142
143  /** The minimum chunk size used for training */
144  protected int m_minChunkSize = 500;
145
146  /** The maimum chunk size used for training */
147  protected int m_maxChunkSize = 2000;
148
149  /** The size of the validation set */
150  protected int m_validationChunkSize = 1000;
151
152  /** The number of instances consumed */ 
153  protected int m_numInstancesConsumed;
154
155  /** The instances used for validation */   
156  protected Instances m_validationSet;
157
158  /** The instances currently in memory for training */   
159  protected Instances m_currentSet;
160
161  /** The current best committee */   
162  protected Committee m_bestCommittee;
163
164  /** The default scheme used when committees aren't ready */   
165  protected ZeroR m_zeroR = null;
166
167  /** Whether the validation set has recently been changed */ 
168  protected boolean m_validationSetChanged;
169
170  /** The maximum number of instances required for processing */   
171  protected int m_maxBatchSizeRequired;
172
173  /** The random number generator used */
174  protected Random m_RandomInstance = null;
175
176   
177  /**
178   * Constructor.
179   */
180  public RacedIncrementalLogitBoost() {
181   
182    m_Classifier = new weka.classifiers.trees.DecisionStump();
183  }
184
185  /**
186   * String describing default classifier.
187   *
188   * @return the default classifier classname
189   */
190  protected String defaultClassifierString() {
191   
192    return "weka.classifiers.trees.DecisionStump";
193  }
194
195
196  /**
197   * Class representing a committee of LogitBoosted models
198   */
199  protected class Committee 
200    implements Serializable, RevisionHandler {
201   
202    /** for serialization */
203    static final long serialVersionUID = 5559880306684082199L;
204
205    protected int m_chunkSize;
206   
207    /** number eaten from m_currentSet */
208    protected int m_instancesConsumed; 
209   
210    protected FastVector m_models;
211    protected double m_lastValidationError;
212    protected double m_lastLogLikelihood;
213    protected boolean m_modelHasChanged;
214    protected boolean m_modelHasChangedLL;
215    protected double[][] m_validationFs;
216    protected double[][] m_newValidationFs;
217
218    /**
219     * constructor
220     *
221     * @param chunkSize the size of the chunk
222     */
223    public Committee(int chunkSize) {
224
225      m_chunkSize = chunkSize;
226      m_instancesConsumed = 0;
227      m_models = new FastVector();
228      m_lastValidationError = 1.0;
229      m_lastLogLikelihood = Double.MAX_VALUE;
230      m_modelHasChanged = true;
231      m_modelHasChangedLL = true;
232      m_validationFs = new double[m_validationChunkSize][m_NumClasses];
233      m_newValidationFs = new double[m_validationChunkSize][m_NumClasses];
234    } 
235
236    /**
237     * update the committee
238     *
239     * @return true if the committee has changed
240     * @throws Exception if anything goes wrong
241     */
242    public boolean update() throws Exception {
243
244      boolean hasChanged = false;
245      while (m_currentSet.numInstances() - m_instancesConsumed >= m_chunkSize) {
246        Classifier[] newModel = boost(new Instances(m_currentSet, m_instancesConsumed, m_chunkSize));
247        for (int i=0; i<m_validationSet.numInstances(); i++) {
248          m_newValidationFs[i] = updateFS(m_validationSet.instance(i), newModel, m_validationFs[i]);
249        }
250        m_models.addElement(newModel);
251        m_instancesConsumed += m_chunkSize;
252        hasChanged = true;
253      }
254      if (hasChanged) {
255        m_modelHasChanged = true;
256        m_modelHasChangedLL = true;
257      }
258      return hasChanged;
259    }
260
261    /** reset consumation counts */
262    public void resetConsumed() {
263
264      m_instancesConsumed = 0;
265    }
266
267    /** remove the last model from the committee */
268    public void pruneLastModel() {
269
270      if (m_models.size() > 0) {
271        m_models.removeElementAt(m_models.size()-1);
272        m_modelHasChanged = true;
273        m_modelHasChangedLL = true;
274      }
275    }
276
277    /**
278     * decide to keep the last model in the committee
279     * @throws Exception if anything goes wrong
280     */
281    public void keepLastModel() throws Exception {
282
283      m_validationFs = m_newValidationFs;
284      m_newValidationFs = new double[m_validationChunkSize][m_NumClasses];
285      m_modelHasChanged = true;
286      m_modelHasChangedLL = true;
287    }
288
289    /**
290     * calculate the log likelihood on the validation data
291     * @return the log likelihood
292     * @throws Exception if computation fails
293     */       
294    public double logLikelihood() throws Exception {
295
296      if (m_modelHasChangedLL) {
297
298        Instance inst;
299        double llsum = 0.0;
300        for (int i=0; i<m_validationSet.numInstances(); i++) {
301          inst = m_validationSet.instance(i);
302          llsum += (logLikelihood(m_validationFs[i],(int) inst.classValue()));
303        }
304        m_lastLogLikelihood = llsum / (double) m_validationSet.numInstances();
305        m_modelHasChangedLL = false;
306      }
307      return m_lastLogLikelihood;
308    }
309
310    /**
311     * calculate the log likelihood on the validation data after adding the last model
312     * @return the log likelihood
313     * @throws Exception if computation fails
314     */
315    public double logLikelihoodAfter() throws Exception {
316
317        Instance inst;
318        double llsum = 0.0;
319        for (int i=0; i<m_validationSet.numInstances(); i++) {
320          inst = m_validationSet.instance(i);
321          llsum += (logLikelihood(m_newValidationFs[i],(int) inst.classValue()));
322        }
323        return llsum / (double) m_validationSet.numInstances();
324    }
325
326   
327    /**
328     * calculates the log likelihood of an instance
329     * @param Fs the Fs values
330     * @param classIndex the class index
331     * @return the log likelihood
332     * @throws Exception if computation fails
333     */
334    private double logLikelihood(double[] Fs, int classIndex) throws Exception {
335
336      return -Math.log(distributionForInstance(Fs)[classIndex]);
337    }
338
339    /**
340     * calculates the validation error of the committee
341     * @return the validation error
342     * @throws Exception if computation fails
343     */
344    public double validationError() throws Exception {
345
346      if (m_modelHasChanged) {
347
348        Instance inst;
349        int numIncorrect = 0;
350        for (int i=0; i<m_validationSet.numInstances(); i++) {
351          inst = m_validationSet.instance(i);
352          if (classifyInstance(m_validationFs[i]) != inst.classValue())
353            numIncorrect++;
354        }
355        m_lastValidationError = (double) numIncorrect / (double) m_validationSet.numInstances();
356        m_modelHasChanged = false;
357      }
358      return m_lastValidationError;
359    }
360
361    /**
362     * returns the chunk size used by the committee
363     *
364     * @return the chunk size
365     */
366    public int chunkSize() {
367
368      return m_chunkSize;
369    }
370
371    /**
372     * returns the number of models in the committee
373     *
374     * @return the committee size
375     */
376    public int committeeSize() {
377
378      return m_models.size();
379    }
380
381   
382    /**
383     * classifies an instance (given Fs values) with the committee
384     *
385     * @param Fs the Fs values
386     * @return the classification
387     * @throws Exception if anything goes wrong
388     */
389    public double classifyInstance(double[] Fs) throws Exception {
390     
391      double [] dist = distributionForInstance(Fs);
392
393      double max = 0;
394      int maxIndex = 0;
395     
396      for (int i = 0; i < dist.length; i++) {
397        if (dist[i] > max) {
398          maxIndex = i;
399          max = dist[i];
400        }
401      }
402      if (max > 0) {
403        return maxIndex;
404      } else {
405        return Utils.missingValue();
406      }
407    }
408
409    /**
410     * classifies an instance with the committee
411     *
412     * @param instance the instance to classify
413     * @return the classification
414     * @throws Exception if anything goes wrong
415     */
416    public double classifyInstance(Instance instance) throws Exception {
417     
418      double [] dist = distributionForInstance(instance);
419      switch (instance.classAttribute().type()) {
420      case Attribute.NOMINAL:
421        double max = 0;
422        int maxIndex = 0;
423       
424        for (int i = 0; i < dist.length; i++) {
425          if (dist[i] > max) {
426            maxIndex = i;
427            max = dist[i];
428          }
429        }
430        if (max > 0) {
431          return maxIndex;
432        } else {
433          return Utils.missingValue();
434        }
435      case Attribute.NUMERIC:
436        return dist[0];
437      default:
438        return Utils.missingValue();
439      }
440    }
441
442    /**
443     * returns the distribution the committee generates for an instance (given Fs values)
444     *
445     * @param Fs the Fs values
446     * @return the distribution
447     * @throws Exception if anything goes wrong
448     */
449    public double[] distributionForInstance(double[] Fs) throws Exception {
450     
451      double [] distribution = new double [m_NumClasses];
452      for (int j = 0; j < m_NumClasses; j++) {
453        distribution[j] = RtoP(Fs, j);
454      }
455      return distribution;
456    }
457   
458    /**
459     * updates the Fs values given a new model in the committee
460     *
461     * @param instance the instance to use
462     * @param newModel the new model
463     * @param Fs the Fs values to update
464     * @return the updated Fs values
465     * @throws Exception if anything goes wrong
466     */
467    public double[] updateFS(Instance instance, Classifier[] newModel, double[] Fs) throws Exception {
468     
469      instance = (Instance)instance.copy();
470      instance.setDataset(m_NumericClassData);
471     
472      double [] Fi = new double [m_NumClasses];
473      double Fsum = 0;
474      for (int j = 0; j < m_NumClasses; j++) {
475        Fi[j] = newModel[j].classifyInstance(instance);
476        Fsum += Fi[j];
477      }
478      Fsum /= m_NumClasses;
479     
480      double[] newFs = new double[Fs.length];
481      for (int j = 0; j < m_NumClasses; j++) {
482        newFs[j] = Fs[j] + ((Fi[j] - Fsum) * (m_NumClasses - 1) / m_NumClasses);
483      }
484      return newFs;
485    }
486
487    /**
488     * returns the distribution the committee generates for an instance
489     *
490     * @param instance the instance to get the distribution for
491     * @return the distribution
492     * @throws Exception if anything goes wrong
493     */
494    public double[] distributionForInstance(Instance instance) throws Exception {
495
496      instance = (Instance)instance.copy();
497      instance.setDataset(m_NumericClassData);
498      double [] Fs = new double [m_NumClasses]; 
499      for (int i = 0; i < m_models.size(); i++) {
500        double [] Fi = new double [m_NumClasses];
501        double Fsum = 0;
502        Classifier[] model = (Classifier[]) m_models.elementAt(i);
503        for (int j = 0; j < m_NumClasses; j++) {
504          Fi[j] = model[j].classifyInstance(instance);
505          Fsum += Fi[j];
506        }
507        Fsum /= m_NumClasses;
508        for (int j = 0; j < m_NumClasses; j++) {
509          Fs[j] += (Fi[j] - Fsum) * (m_NumClasses - 1) / m_NumClasses;
510        }
511      }
512      double [] distribution = new double [m_NumClasses];
513      for (int j = 0; j < m_NumClasses; j++) {
514        distribution[j] = RtoP(Fs, j);
515      }
516      return distribution;
517    }
518
519    /**
520     * performs a boosting iteration, returning a new model for the committee
521     *
522     * @param data the data to boost on
523     * @return the new model
524     * @throws Exception if anything goes wrong
525     */
526    protected Classifier[] boost(Instances data) throws Exception {
527     
528      Classifier[] newModel = AbstractClassifier.makeCopies(m_Classifier, m_NumClasses);
529     
530      // Create a copy of the data with the class transformed into numeric
531      Instances boostData = new Instances(data);
532      boostData.deleteWithMissingClass();
533      int numInstances = boostData.numInstances();
534     
535      // Temporarily unset the class index
536      int classIndex = data.classIndex();
537      boostData.setClassIndex(-1);
538      boostData.deleteAttributeAt(classIndex);
539      boostData.insertAttributeAt(new Attribute("'pseudo class'"), classIndex);
540      boostData.setClassIndex(classIndex);
541      double [][] trainFs = new double [numInstances][m_NumClasses];
542      double [][] trainYs = new double [numInstances][m_NumClasses];
543      for (int j = 0; j < m_NumClasses; j++) {
544        for (int i = 0, k = 0; i < numInstances; i++, k++) {
545          while (data.instance(k).classIsMissing()) k++;
546          trainYs[i][j] = (data.instance(k).classValue() == j) ? 1 : 0;
547        }
548      }
549     
550      // Evaluate / increment trainFs from the classifiers
551      for (int x = 0; x < m_models.size(); x++) {
552        for (int i = 0; i < numInstances; i++) {
553          double [] pred = new double [m_NumClasses];
554          double predSum = 0;
555          Classifier[] model = (Classifier[]) m_models.elementAt(x);
556          for (int j = 0; j < m_NumClasses; j++) {
557            pred[j] = model[j].classifyInstance(boostData.instance(i));
558            predSum += pred[j];
559          }
560          predSum /= m_NumClasses;
561          for (int j = 0; j < m_NumClasses; j++) {
562            trainFs[i][j] += (pred[j] - predSum) * (m_NumClasses-1) 
563              / m_NumClasses;
564          }
565        }
566      }
567
568      for (int j = 0; j < m_NumClasses; j++) {
569       
570        // Set instance pseudoclass and weights
571        for (int i = 0; i < numInstances; i++) {
572          double p = RtoP(trainFs[i], j);
573          Instance current = boostData.instance(i);
574          double z, actual = trainYs[i][j];
575          if (actual == 1) {
576            z = 1.0 / p;
577            if (z > Z_MAX) { // threshold
578              z = Z_MAX;
579            }
580          } else if (actual == 0) {
581            z = -1.0 / (1.0 - p);
582            if (z < -Z_MAX) { // threshold
583              z = -Z_MAX;
584            }
585          } else {
586            z = (actual - p) / (p * (1 - p));
587          }
588
589          double w = (actual - p) / z;
590          current.setValue(classIndex, z);
591          current.setWeight(numInstances * w);
592        }
593       
594        Instances trainData = boostData;
595        if (m_UseResampling) {
596          double[] weights = new double[boostData.numInstances()];
597          for (int kk = 0; kk < weights.length; kk++) {
598            weights[kk] = boostData.instance(kk).weight();
599          }
600          trainData = boostData.resampleWithWeights(m_RandomInstance, 
601                                                    weights);
602        }
603       
604        // Build the classifier
605        newModel[j].buildClassifier(trainData);
606      }     
607     
608      return newModel;
609    }
610
611    /**
612     * outputs description of the committee
613     *
614     * @return a string representation of the classifier
615     */
616    public String toString() {
617     
618      StringBuffer text = new StringBuffer();
619     
620      text.append("RacedIncrementalLogitBoost: Best committee on validation data\n");
621      text.append("Base classifiers: \n");
622     
623      for (int i = 0; i < m_models.size(); i++) {
624        text.append("\nModel "+(i+1));
625        Classifier[] cModels = (Classifier[]) m_models.elementAt(i);
626        for (int j = 0; j < m_NumClasses; j++) {
627          text.append("\n\tClass " + (j + 1) 
628                      + " (" + m_ClassAttribute.name() 
629                      + "=" + m_ClassAttribute.value(j) + ")\n\n"
630                      + cModels[j].toString() + "\n");
631        }
632      }
633      text.append("Number of models: " +
634                  m_models.size() + "\n");     
635      text.append("Chunk size per model: " + m_chunkSize + "\n");
636     
637      return text.toString();
638    }
639   
640    /**
641     * Returns the revision string.
642     *
643     * @return          the revision
644     */
645    public String getRevision() {
646      return RevisionUtils.extract("$Revision: 5987 $");
647    }
648  }
649
650  /**
651   * Returns default capabilities of the classifier.
652   *
653   * @return      the capabilities of this classifier
654   */
655  public Capabilities getCapabilities() {
656    Capabilities result = super.getCapabilities();
657
658    // class
659    result.disableAllClasses();
660    result.disableAllClassDependencies();
661    result.enable(Capability.NOMINAL_CLASS);
662
663    // instances
664    result.setMinimumNumberInstances(0);
665   
666    return result;
667  }
668
669 /**
670   * Builds the classifier.
671   *
672   * @param data the instances to train the classifier with
673   * @throws Exception if something goes wrong
674   */
675  public void buildClassifier(Instances data) throws Exception {
676
677    m_RandomInstance = new Random(m_Seed);
678
679    Instances boostData;
680    int classIndex = data.classIndex();
681
682    // can classifier handle the data?
683    getCapabilities().testWithFail(data);
684
685    // remove instances with missing class
686    data = new Instances(data);
687    data.deleteWithMissingClass();
688   
689    if (m_Classifier == null) {
690      throw new Exception("A base classifier has not been specified!");
691    }
692
693    if (!(m_Classifier instanceof WeightedInstancesHandler) &&
694        !m_UseResampling) {
695      m_UseResampling = true;
696    }
697
698    m_NumClasses = data.numClasses();
699    m_ClassAttribute = data.classAttribute();
700
701    // Create a copy of the data with the class transformed into numeric
702    boostData = new Instances(data);
703
704    // Temporarily unset the class index
705    boostData.setClassIndex(-1);
706    boostData.deleteAttributeAt(classIndex);
707    boostData.insertAttributeAt(new Attribute("'pseudo class'"), classIndex);
708    boostData.setClassIndex(classIndex);
709    m_NumericClassData = new Instances(boostData, 0);
710
711    data.randomize(m_RandomInstance);
712
713    // create the committees
714    int cSize = m_minChunkSize;
715    m_committees = new FastVector();
716    while (cSize <= m_maxChunkSize) {
717      m_committees.addElement(new Committee(cSize));
718      m_maxBatchSizeRequired = cSize;
719      cSize *= 2;
720    }
721
722    // set up for consumption
723    m_validationSet = new Instances(data, m_validationChunkSize);
724    m_currentSet = new Instances(data, m_maxBatchSizeRequired);
725    m_bestCommittee = null;
726    m_numInstancesConsumed = 0;
727
728    // start eating what we've been given
729    for (int i=0; i<data.numInstances(); i++) updateClassifier(data.instance(i));
730  }
731
732 /**
733   * Updates the classifier.
734   *
735   * @param instance the next instance in the stream of training data
736   * @throws Exception if something goes wrong
737   */
738  public void updateClassifier(Instance instance) throws Exception {
739
740    m_numInstancesConsumed++;
741
742    if (m_validationSet.numInstances() < m_validationChunkSize) {
743      m_validationSet.add(instance);
744      m_validationSetChanged = true;
745    } else {
746      m_currentSet.add(instance);
747      boolean hasChanged = false;
748     
749      // update each committee
750      for (int i=0; i<m_committees.size(); i++) {
751        Committee c = (Committee) m_committees.elementAt(i);
752        if (c.update()) {
753         
754          hasChanged = true;
755         
756          if (m_PruningType == PRUNETYPE_LOGLIKELIHOOD) {
757            double oldLL = c.logLikelihood();
758            double newLL = c.logLikelihoodAfter();
759            if (newLL >= oldLL && c.committeeSize() > 1) {
760              c.pruneLastModel();
761              if (m_Debug) System.out.println("Pruning " + c.chunkSize()+ " committee (" +
762                                              oldLL + " < " + newLL + ")");
763            } else c.keepLastModel();
764          } else c.keepLastModel(); // no pruning
765        } 
766      }
767      if (hasChanged) {
768
769        if (m_Debug) System.out.println("After consuming " + m_numInstancesConsumed
770                                        + " instances... (" + m_validationSet.numInstances()
771                                        + " + " + m_currentSet.numInstances()
772                                        + " instances currently in memory)");
773       
774        // find best committee
775        double lowestError = 1.0;
776        for (int i=0; i<m_committees.size(); i++) {
777          Committee c = (Committee) m_committees.elementAt(i);
778
779          if (c.committeeSize() > 0) {
780
781            double err = c.validationError();
782            double ll = c.logLikelihood();
783
784            if (m_Debug) System.out.println("Chunk size " + c.chunkSize() + " with "
785                                            + c.committeeSize() + " models, has validation error of "
786                                            + err + ", log likelihood of " + ll);
787            if (err < lowestError) {
788              lowestError = err;
789              m_bestCommittee = c;
790            }
791          }
792        }
793      }
794      if (m_currentSet.numInstances() >= m_maxBatchSizeRequired) {
795        m_currentSet = new Instances(m_currentSet, m_maxBatchSizeRequired);
796
797        // reset consumation counts
798        for (int i=0; i<m_committees.size(); i++) {
799          Committee c = (Committee) m_committees.elementAt(i);
800          c.resetConsumed();
801        }
802      }
803    }
804  }
805
806  /**
807   * Convert from function responses to probabilities
808   *
809   * @param Fs an array containing the responses from each function
810   * @param j the class value of interest
811   * @return the probability prediction for j
812   * @throws Exception if can't normalize
813   */
814  protected static double RtoP(double []Fs, int j) 
815    throws Exception {
816
817    double maxF = -Double.MAX_VALUE;
818    for (int i = 0; i < Fs.length; i++) {
819      if (Fs[i] > maxF) {
820        maxF = Fs[i];
821      }
822    }
823    double sum = 0;
824    double[] probs = new double[Fs.length];
825    for (int i = 0; i < Fs.length; i++) {
826      probs[i] = Math.exp(Fs[i] - maxF);
827      sum += probs[i];
828    }
829    if (sum == 0) {
830      throw new Exception("Can't normalize");
831    }
832    return probs[j] / sum;
833  }
834
835  /**
836   * Computes class distribution of an instance using the best committee.
837   *
838   * @param instance the instance to get the distribution for
839   * @return the distribution
840   * @throws Exception if anything goes wrong
841   */
842  public double[] distributionForInstance(Instance instance) throws Exception {
843
844    if (m_bestCommittee != null) return m_bestCommittee.distributionForInstance(instance);
845    else {
846      if (m_validationSetChanged || m_zeroR == null) {
847        m_zeroR = new ZeroR();
848        m_zeroR.buildClassifier(m_validationSet);
849        m_validationSetChanged = false;
850      }
851      return m_zeroR.distributionForInstance(instance);
852    }
853  }
854
855  /**
856   * Returns an enumeration describing the available options
857   *
858   * @return an enumeration of all the available options
859   */
860  public Enumeration listOptions() {
861
862    Vector newVector = new Vector(9);
863
864    newVector.addElement(new Option(
865              "\tMinimum size of chunks.\n"
866              +"\t(default 500)",
867              "C", 1, "-C <num>"));
868
869    newVector.addElement(new Option(
870              "\tMaximum size of chunks.\n"
871              +"\t(default 2000)",
872              "M", 1, "-M <num>"));
873
874    newVector.addElement(new Option(
875              "\tSize of validation set.\n"
876              +"\t(default 1000)",
877              "V", 1, "-V <num>"));
878
879    newVector.addElement(new Option(
880              "\tCommittee pruning to perform.\n"
881              +"\t0=none, 1=log likelihood (default)",
882              "P", 1, "-P <pruning type>"));
883
884    newVector.addElement(new Option(
885              "\tUse resampling for boosting.",
886              "Q", 0, "-Q"));
887
888
889    Enumeration enu = super.listOptions();
890    while (enu.hasMoreElements()) {
891      newVector.addElement(enu.nextElement());
892    }
893    return newVector.elements();
894  }
895
896
897  /**
898   * Parses a given list of options. <p/>
899   *
900   <!-- options-start -->
901   * Valid options are: <p/>
902   *
903   * <pre> -C &lt;num&gt;
904   *  Minimum size of chunks.
905   *  (default 500)</pre>
906   *
907   * <pre> -M &lt;num&gt;
908   *  Maximum size of chunks.
909   *  (default 2000)</pre>
910   *
911   * <pre> -V &lt;num&gt;
912   *  Size of validation set.
913   *  (default 1000)</pre>
914   *
915   * <pre> -P &lt;pruning type&gt;
916   *  Committee pruning to perform.
917   *  0=none, 1=log likelihood (default)</pre>
918   *
919   * <pre> -Q
920   *  Use resampling for boosting.</pre>
921   *
922   * <pre> -S &lt;num&gt;
923   *  Random number seed.
924   *  (default 1)</pre>
925   *
926   * <pre> -D
927   *  If set, classifier is run in debug mode and
928   *  may output additional info to the console</pre>
929   *
930   * <pre> -W
931   *  Full name of base classifier.
932   *  (default: weka.classifiers.trees.DecisionStump)</pre>
933   *
934   * <pre>
935   * Options specific to classifier weka.classifiers.trees.DecisionStump:
936   * </pre>
937   *
938   * <pre> -D
939   *  If set, classifier is run in debug mode and
940   *  may output additional info to the console</pre>
941   *
942   <!-- options-end -->
943   *
944   * @param options the list of options as an array of strings
945   * @throws Exception if an option is not supported
946   */
947  public void setOptions(String[] options) throws Exception {
948
949    String minChunkSize = Utils.getOption('C', options);
950    if (minChunkSize.length() != 0) {
951      setMinChunkSize(Integer.parseInt(minChunkSize));
952    } else {
953      setMinChunkSize(500);
954    }
955
956    String maxChunkSize = Utils.getOption('M', options);
957    if (maxChunkSize.length() != 0) {
958      setMaxChunkSize(Integer.parseInt(maxChunkSize));
959    } else {
960      setMaxChunkSize(2000);
961    }
962
963    String validationChunkSize = Utils.getOption('V', options);
964    if (validationChunkSize.length() != 0) {
965      setValidationChunkSize(Integer.parseInt(validationChunkSize));
966    } else {
967      setValidationChunkSize(1000);
968    }
969
970    String pruneType = Utils.getOption('P', options);
971    if (pruneType.length() != 0) {
972      setPruningType(new SelectedTag(Integer.parseInt(pruneType), TAGS_PRUNETYPE));
973    } else {
974      setPruningType(new SelectedTag(PRUNETYPE_LOGLIKELIHOOD, TAGS_PRUNETYPE));
975    }
976
977    setUseResampling(Utils.getFlag('Q', options));
978
979    super.setOptions(options);
980  }
981
982  /**
983   * Gets the current settings of the Classifier.
984   *
985   * @return an array of strings suitable for passing to setOptions
986   */
987  public String [] getOptions() {
988
989    String [] superOptions = super.getOptions();
990    String [] options = new String [superOptions.length + 9];
991
992    int current = 0;
993
994    if (getUseResampling()) {
995      options[current++] = "-Q";
996    }
997    options[current++] = "-C"; options[current++] = "" + getMinChunkSize();
998
999    options[current++] = "-M"; options[current++] = "" + getMaxChunkSize();
1000
1001    options[current++] = "-V"; options[current++] = "" + getValidationChunkSize();
1002
1003    options[current++] = "-P"; options[current++] = "" + m_PruningType;
1004
1005    System.arraycopy(superOptions, 0, options, current, 
1006                     superOptions.length);
1007
1008    current += superOptions.length;
1009    while (current < options.length) {
1010      options[current++] = "";
1011    }
1012    return options;
1013  }
1014
1015  /**
1016   * @return a description of the classifier suitable for
1017   * displaying in the explorer/experimenter gui
1018   */
1019  public String globalInfo() {
1020
1021    return "Classifier for incremental learning of large datasets by way of racing logit-boosted committees.";
1022  }
1023
1024  /**
1025   * Set the base learner.
1026   *
1027   * @param newClassifier               the classifier to use.
1028   * @throws IllegalArgumentException   if base classifier cannot handle numeric
1029   *                                    class
1030   */
1031  public void setClassifier(Classifier newClassifier) {
1032    Capabilities cap = newClassifier.getCapabilities();
1033   
1034    if (!cap.handles(Capability.NUMERIC_CLASS))
1035      throw new IllegalArgumentException("Base classifier cannot handle numeric class!");
1036     
1037    super.setClassifier(newClassifier);
1038  }
1039
1040  /**
1041   * @return tip text for this property suitable for
1042   * displaying in the explorer/experimenter gui
1043   */
1044  public String minChunkSizeTipText() {
1045
1046    return "The minimum number of instances to train the base learner with.";
1047  }
1048
1049  /**
1050   * Set the minimum chunk size
1051   *
1052   * @param chunkSize the minimum chunk size
1053   */
1054  public void setMinChunkSize(int chunkSize) {
1055
1056    m_minChunkSize = chunkSize;
1057  }
1058
1059  /**
1060   * Get the minimum chunk size
1061   *
1062   * @return the chunk size
1063   */
1064  public int getMinChunkSize() {
1065
1066    return m_minChunkSize;
1067  }
1068
1069  /**
1070   * @return tip text for this property suitable for
1071   * displaying in the explorer/experimenter gui
1072   */
1073  public String maxChunkSizeTipText() {
1074
1075    return "The maximum number of instances to train the base learner with. The chunk sizes used will start at minChunkSize and grow twice as large for as many times as they are less than or equal to the maximum size.";
1076  }
1077
1078  /**
1079   * Set the maximum chunk size
1080   *
1081   * @param chunkSize the maximum chunk size
1082   */
1083  public void setMaxChunkSize(int chunkSize) {
1084
1085    m_maxChunkSize = chunkSize;
1086  }
1087
1088  /**
1089   * Get the maximum chunk size
1090   *
1091   * @return the chunk size
1092   */
1093  public int getMaxChunkSize() {
1094
1095    return m_maxChunkSize;
1096  }
1097
1098  /**
1099   * @return tip text for this property suitable for
1100   * displaying in the explorer/experimenter gui
1101   */
1102  public String validationChunkSizeTipText() {
1103
1104    return "The number of instances to hold out for validation. These instances will be taken from the beginning of the stream, so learning will not start until these instances have been consumed first.";
1105  }
1106
1107  /**
1108   * Set the validation chunk size
1109   *
1110   * @param chunkSize the validation chunk size
1111   */
1112  public void setValidationChunkSize(int chunkSize) {
1113
1114    m_validationChunkSize = chunkSize;
1115  }
1116
1117  /**
1118   * Get the validation chunk size
1119   *
1120   * @return the chunk size
1121   */
1122  public int getValidationChunkSize() {
1123
1124    return m_validationChunkSize;
1125  }
1126
1127  /**
1128   * @return tip text for this property suitable for
1129   * displaying in the explorer/experimenter gui
1130   */
1131  public String pruningTypeTipText() {
1132
1133    return "The pruning method to use within each committee. Log likelihood pruning will discard new models if they have a negative effect on the log likelihood of the validation data.";
1134  }
1135
1136  /**
1137   * Set the pruning type
1138   *
1139   * @param pruneType the pruning type
1140   */
1141  public void setPruningType(SelectedTag pruneType) {
1142
1143    if (pruneType.getTags() == TAGS_PRUNETYPE) {
1144      m_PruningType = pruneType.getSelectedTag().getID();
1145    }
1146  }
1147
1148  /**
1149   * Get the pruning type
1150   *
1151   * @return the type
1152   */
1153  public SelectedTag getPruningType() {
1154
1155    return new SelectedTag(m_PruningType, TAGS_PRUNETYPE);
1156  }
1157
1158  /**
1159   * @return tip text for this property suitable for
1160   * displaying in the explorer/experimenter gui
1161   */
1162  public String useResamplingTipText() {
1163
1164    return "Force the use of resampling data rather than using the weight-handling capabilities of the base classifier. Resampling is always used if the base classifier cannot handle weighted instances.";
1165  }
1166
1167  /**
1168   * Set resampling mode
1169   *
1170   * @param r true if resampling should be done
1171   */
1172  public void setUseResampling(boolean r) {
1173   
1174    m_UseResampling = r;
1175  }
1176
1177  /**
1178   * Get whether resampling is turned on
1179   *
1180   * @return true if resampling output is on
1181   */
1182  public boolean getUseResampling() {
1183   
1184    return m_UseResampling;
1185  }
1186
1187  /**
1188   * Get the best committee chunk size
1189   *
1190   * @return the best committee chunk size
1191   */
1192  public int getBestCommitteeChunkSize() {
1193
1194    if (m_bestCommittee != null) {
1195      return m_bestCommittee.chunkSize();
1196    }
1197    else return 0;
1198  }
1199
1200  /**
1201   * Get the number of members in the best committee
1202   *
1203   * @return the number of members
1204   */
1205  public int getBestCommitteeSize() {
1206
1207    if (m_bestCommittee != null) {
1208      return m_bestCommittee.committeeSize();
1209    }
1210    else return 0;
1211  }
1212
1213  /**
1214   * Get the best committee's error on the validation data
1215   *
1216   * @return the best committee's error
1217   */
1218  public double getBestCommitteeErrorEstimate() {
1219
1220    if (m_bestCommittee != null) {
1221      try {
1222        return m_bestCommittee.validationError() * 100.0;
1223      } catch (Exception e) {
1224        System.err.println(e.getMessage());
1225        return 100.0;
1226      }
1227    }
1228    else return 100.0;
1229  }
1230
1231  /**
1232   * Get the best committee's log likelihood on the validation data
1233   *
1234   * @return best committee's log likelihood
1235   */
1236  public double getBestCommitteeLLEstimate() {
1237
1238    if (m_bestCommittee != null) {
1239      try {
1240        return m_bestCommittee.logLikelihood();
1241      } catch (Exception e) {
1242        System.err.println(e.getMessage());
1243        return Double.MAX_VALUE;
1244      }
1245    }
1246    else return Double.MAX_VALUE;
1247  }
1248 
1249  /**
1250   * Returns description of the boosted classifier.
1251   *
1252   * @return description of the boosted classifier as a string
1253   */
1254  public String toString() {
1255       
1256    if (m_bestCommittee != null) {
1257      return m_bestCommittee.toString();
1258    } else {
1259      if ((m_validationSetChanged || m_zeroR == null) && m_validationSet != null
1260          && m_validationSet.numInstances() > 0) {
1261        m_zeroR = new ZeroR();
1262        try {
1263          m_zeroR.buildClassifier(m_validationSet);
1264        } catch (Exception e) {}
1265        m_validationSetChanged = false;
1266      }
1267      if (m_zeroR != null) {
1268        return ("RacedIncrementalLogitBoost: insufficient data to build model, resorting to ZeroR:\n\n"
1269                + m_zeroR.toString());
1270      }
1271      else return ("RacedIncrementalLogitBoost: no model built yet.");
1272    }
1273  }
1274 
1275  /**
1276   * Returns the revision string.
1277   *
1278   * @return            the revision
1279   */
1280  public String getRevision() {
1281    return RevisionUtils.extract("$Revision: 5987 $");
1282  }
1283
1284  /**
1285   * Main method for this class.
1286   *
1287   * @param argv the commandline parameters
1288   */
1289  public static void main(String[] argv) {
1290    runClassifier(new RacedIncrementalLogitBoost(), argv);
1291  }
1292}
Note: See TracBrowser for help on using the repository browser.