source: src/main/java/weka/classifiers/meta/LogitBoost.java @ 23

Last change on this file since 23 was 4, checked in by gnappo, 14 years ago

Import di weka.

File size: 33.6 KB
RevLine 
[4]1/*
2 *    This program is free software; you can redistribute it and/or modify
3 *    it under the terms of the GNU General Public License as published by
4 *    the Free Software Foundation; either version 2 of the License, or
5 *    (at your option) any later version.
6 *
7 *    This program is distributed in the hope that it will be useful,
8 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
9 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
10 *    GNU General Public License for more details.
11 *
12 *    You should have received a copy of the GNU General Public License
13 *    along with this program; if not, write to the Free Software
14 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
15 */
16
17/*
18 *    LogitBoost.java
19 *    Copyright (C) 1999, 2002 University of Waikato, Hamilton, New Zealand
20 *
21 */
22
23package weka.classifiers.meta;
24
25import weka.classifiers.Classifier;
26import weka.classifiers.AbstractClassifier;
27import weka.classifiers.Evaluation;
28import weka.classifiers.RandomizableIteratedSingleClassifierEnhancer;
29import weka.classifiers.Sourcable;
30import weka.core.Attribute;
31import weka.core.Capabilities;
32import weka.core.Instance;
33import weka.core.Instances;
34import weka.core.Option;
35import weka.core.RevisionUtils;
36import weka.core.TechnicalInformation;
37import weka.core.TechnicalInformationHandler;
38import weka.core.Utils;
39import weka.core.WeightedInstancesHandler;
40import weka.core.Capabilities.Capability;
41import weka.core.TechnicalInformation.Field;
42import weka.core.TechnicalInformation.Type;
43
44import java.util.Enumeration;
45import java.util.Random;
46import java.util.Vector;
47
48/**
49 <!-- globalinfo-start -->
50 * Class for performing additive logistic regression. <br/>
51 * This class performs classification using a regression scheme as the base learner, and can handle multi-class problems.  For more information, see<br/>
52 * <br/>
53 * J. Friedman, T. Hastie, R. Tibshirani (1998). Additive Logistic Regression: a Statistical View of Boosting. Stanford University.<br/>
54 * <br/>
55 * Can do efficient internal cross-validation to determine appropriate number of iterations.
56 * <p/>
57 <!-- globalinfo-end -->
58 *
59 <!-- technical-bibtex-start -->
60 * BibTeX:
61 * <pre>
62 * &#64;techreport{Friedman1998,
63 *    address = {Stanford University},
64 *    author = {J. Friedman and T. Hastie and R. Tibshirani},
65 *    title = {Additive Logistic Regression: a Statistical View of Boosting},
66 *    year = {1998},
67 *    PS = {http://www-stat.stanford.edu/\~jhf/ftp/boost.ps}
68 * }
69 * </pre>
70 * <p/>
71 <!-- technical-bibtex-end -->
72 *
73 <!-- options-start -->
74 * Valid options are: <p/>
75 *
76 * <pre> -Q
77 *  Use resampling instead of reweighting for boosting.</pre>
78 *
79 * <pre> -P &lt;percent&gt;
80 *  Percentage of weight mass to base training on.
81 *  (default 100, reduce to around 90 speed up)</pre>
82 *
83 * <pre> -F &lt;num&gt;
84 *  Number of folds for internal cross-validation.
85 *  (default 0 -- no cross-validation)</pre>
86 *
87 * <pre> -R &lt;num&gt;
88 *  Number of runs for internal cross-validation.
89 *  (default 1)</pre>
90 *
91 * <pre> -L &lt;num&gt;
92 *  Threshold on the improvement of the likelihood.
93 *  (default -Double.MAX_VALUE)</pre>
94 *
95 * <pre> -H &lt;num&gt;
96 *  Shrinkage parameter.
97 *  (default 1)</pre>
98 *
99 * <pre> -S &lt;num&gt;
100 *  Random number seed.
101 *  (default 1)</pre>
102 *
103 * <pre> -I &lt;num&gt;
104 *  Number of iterations.
105 *  (default 10)</pre>
106 *
107 * <pre> -D
108 *  If set, classifier is run in debug mode and
109 *  may output additional info to the console</pre>
110 *
111 * <pre> -W
112 *  Full name of base classifier.
113 *  (default: weka.classifiers.trees.DecisionStump)</pre>
114 *
115 * <pre>
116 * Options specific to classifier weka.classifiers.trees.DecisionStump:
117 * </pre>
118 *
119 * <pre> -D
120 *  If set, classifier is run in debug mode and
121 *  may output additional info to the console</pre>
122 *
123 <!-- options-end -->
124 *
125 * Options after -- are passed to the designated learner.<p>
126 *
127 * @author Len Trigg (trigg@cs.waikato.ac.nz)
128 * @author Eibe Frank (eibe@cs.waikato.ac.nz)
129 * @version $Revision: 6091 $
130 */
131public class LogitBoost 
132  extends RandomizableIteratedSingleClassifierEnhancer
133  implements Sourcable, WeightedInstancesHandler, TechnicalInformationHandler {
134
135  /** for serialization */
136  static final long serialVersionUID = -3905660358715833753L;
137 
138  /** Array for storing the generated base classifiers.
139   Note: we are hiding the variable from IteratedSingleClassifierEnhancer*/
140  protected Classifier [][] m_Classifiers;
141
142  /** The number of classes */
143  protected int m_NumClasses;
144
145  /** The number of successfully generated base classifiers. */
146  protected int m_NumGenerated;
147
148  /** The number of folds for the internal cross-validation. */
149  protected int m_NumFolds = 0;
150
151  /** The number of runs for the internal cross-validation. */
152  protected int m_NumRuns = 1;
153
154  /** Weight thresholding. The percentage of weight mass used in training */
155  protected int m_WeightThreshold = 100;
156
157  /** A threshold for responses (Friedman suggests between 2 and 4) */
158  protected static final double Z_MAX = 3;
159
160  /** Dummy dataset with a numeric class */
161  protected Instances m_NumericClassData;
162
163  /** The actual class attribute (for getting class names) */
164  protected Attribute m_ClassAttribute;
165
166  /** Use boosting with reweighting? */
167  protected boolean m_UseResampling;
168
169  /** The threshold on the improvement of the likelihood */   
170  protected double m_Precision = -Double.MAX_VALUE;
171
172  /** The value of the shrinkage parameter */
173  protected double m_Shrinkage = 1;
174
175  /** The random number generator used */
176  protected Random m_RandomInstance = null;
177
178  /** The value by which the actual target value for the
179      true class is offset. */
180  protected double m_Offset = 0.0;
181   
182  /** a ZeroR model in case no model can be built from the data */
183  protected Classifier m_ZeroR;
184   
185  /**
186   * Returns a string describing classifier
187   * @return a description suitable for
188   * displaying in the explorer/experimenter gui
189   */
190  public String globalInfo() {
191
192    return "Class for performing additive logistic regression. \n"
193      + "This class performs classification using a regression scheme as the "
194      + "base learner, and can handle multi-class problems.  For more "
195      + "information, see\n\n"
196      + getTechnicalInformation().toString() + "\n\n"
197      + "Can do efficient internal cross-validation to determine "
198      + "appropriate number of iterations.";
199  }
200   
/**
 * Creates a LogitBoost instance that uses a DecisionStump as its base
 * learner.
 */
public LogitBoost() {
  super();
  this.m_Classifier = new weka.classifiers.trees.DecisionStump();
}
208
209  /**
210   * Returns an instance of a TechnicalInformation object, containing
211   * detailed information about the technical background of this class,
212   * e.g., paper reference or book this class is based on.
213   *
214   * @return the technical information about this class
215   */
216  public TechnicalInformation getTechnicalInformation() {
217    TechnicalInformation        result;
218   
219    result = new TechnicalInformation(Type.TECHREPORT);
220    result.setValue(Field.AUTHOR, "J. Friedman and T. Hastie and R. Tibshirani");
221    result.setValue(Field.YEAR, "1998");
222    result.setValue(Field.TITLE, "Additive Logistic Regression: a Statistical View of Boosting");
223    result.setValue(Field.ADDRESS, "Stanford University");
224    result.setValue(Field.PS, "http://www-stat.stanford.edu/~jhf/ftp/boost.ps");
225   
226    return result;
227  }
228
229  /**
230   * String describing default classifier.
231   *
232   * @return the default classifier classname
233   */
234  protected String defaultClassifierString() {
235   
236    return "weka.classifiers.trees.DecisionStump";
237  }
238
239  /**
240   * Select only instances with weights that contribute to
241   * the specified quantile of the weight distribution
242   *
243   * @param data the input instances
244   * @param quantile the specified quantile eg 0.9 to select
245   * 90% of the weight mass
246   * @return the selected instances
247   */
248  protected Instances selectWeightQuantile(Instances data, double quantile) { 
249
250    int numInstances = data.numInstances();
251    Instances trainData = new Instances(data, numInstances);
252    double [] weights = new double [numInstances];
253
254    double sumOfWeights = 0;
255    for (int i = 0; i < numInstances; i++) {
256      weights[i] = data.instance(i).weight();
257      sumOfWeights += weights[i];
258    }
259    double weightMassToSelect = sumOfWeights * quantile;
260    int [] sortedIndices = Utils.sort(weights);
261
262    // Select the instances
263    sumOfWeights = 0;
264    for (int i = numInstances-1; i >= 0; i--) {
265      Instance instance = (Instance)data.instance(sortedIndices[i]).copy();
266      trainData.add(instance);
267      sumOfWeights += weights[sortedIndices[i]];
268      if ((sumOfWeights > weightMassToSelect) && 
269          (i > 0) && 
270          (weights[sortedIndices[i]] != weights[sortedIndices[i-1]])) {
271        break;
272      }
273    }
274    if (m_Debug) {
275      System.err.println("Selected " + trainData.numInstances()
276                         + " out of " + numInstances);
277    }
278    return trainData;
279  }
280
281  /**
282   * Returns an enumeration describing the available options.
283   *
284   * @return an enumeration of all the available options.
285   */
286  public Enumeration listOptions() {
287
288    Vector newVector = new Vector(6);
289
290    newVector.addElement(new Option(
291              "\tUse resampling instead of reweighting for boosting.",
292              "Q", 0, "-Q"));
293    newVector.addElement(new Option(
294              "\tPercentage of weight mass to base training on.\n"
295              +"\t(default 100, reduce to around 90 speed up)",
296              "P", 1, "-P <percent>"));
297    newVector.addElement(new Option(
298              "\tNumber of folds for internal cross-validation.\n"
299              +"\t(default 0 -- no cross-validation)",
300              "F", 1, "-F <num>"));
301    newVector.addElement(new Option(
302              "\tNumber of runs for internal cross-validation.\n"
303              +"\t(default 1)",
304              "R", 1, "-R <num>"));
305    newVector.addElement(new Option(
306              "\tThreshold on the improvement of the likelihood.\n"
307              +"\t(default -Double.MAX_VALUE)",
308              "L", 1, "-L <num>"));
309    newVector.addElement(new Option(
310              "\tShrinkage parameter.\n"
311              +"\t(default 1)",
312              "H", 1, "-H <num>"));
313
314    Enumeration enu = super.listOptions();
315    while (enu.hasMoreElements()) {
316      newVector.addElement(enu.nextElement());
317    }
318    return newVector.elements();
319  }
320
321
322  /**
323   * Parses a given list of options. <p/>
324   *
325   <!-- options-start -->
326   * Valid options are: <p/>
327   *
328   * <pre> -Q
329   *  Use resampling instead of reweighting for boosting.</pre>
330   *
331   * <pre> -P &lt;percent&gt;
332   *  Percentage of weight mass to base training on.
333   *  (default 100, reduce to around 90 speed up)</pre>
334   *
335   * <pre> -F &lt;num&gt;
336   *  Number of folds for internal cross-validation.
337   *  (default 0 -- no cross-validation)</pre>
338   *
339   * <pre> -R &lt;num&gt;
340   *  Number of runs for internal cross-validation.
341   *  (default 1)</pre>
342   *
343   * <pre> -L &lt;num&gt;
344   *  Threshold on the improvement of the likelihood.
345   *  (default -Double.MAX_VALUE)</pre>
346   *
347   * <pre> -H &lt;num&gt;
348   *  Shrinkage parameter.
349   *  (default 1)</pre>
350   *
351   * <pre> -S &lt;num&gt;
352   *  Random number seed.
353   *  (default 1)</pre>
354   *
355   * <pre> -I &lt;num&gt;
356   *  Number of iterations.
357   *  (default 10)</pre>
358   *
359   * <pre> -D
360   *  If set, classifier is run in debug mode and
361   *  may output additional info to the console</pre>
362   *
363   * <pre> -W
364   *  Full name of base classifier.
365   *  (default: weka.classifiers.trees.DecisionStump)</pre>
366   *
367   * <pre>
368   * Options specific to classifier weka.classifiers.trees.DecisionStump:
369   * </pre>
370   *
371   * <pre> -D
372   *  If set, classifier is run in debug mode and
373   *  may output additional info to the console</pre>
374   *
375   <!-- options-end -->
376   *
377   * Options after -- are passed to the designated learner.<p>
378   *
379   * @param options the list of options as an array of strings
380   * @throws Exception if an option is not supported
381   */
382  public void setOptions(String[] options) throws Exception {
383   
384    String numFolds = Utils.getOption('F', options);
385    if (numFolds.length() != 0) {
386      setNumFolds(Integer.parseInt(numFolds));
387    } else {
388      setNumFolds(0);
389    }
390   
391    String numRuns = Utils.getOption('R', options);
392    if (numRuns.length() != 0) {
393      setNumRuns(Integer.parseInt(numRuns));
394    } else {
395      setNumRuns(1);
396    }
397
398    String thresholdString = Utils.getOption('P', options);
399    if (thresholdString.length() != 0) {
400      setWeightThreshold(Integer.parseInt(thresholdString));
401    } else {
402      setWeightThreshold(100);
403    }
404
405    String precisionString = Utils.getOption('L', options);
406    if (precisionString.length() != 0) {
407      setLikelihoodThreshold(new Double(precisionString).
408        doubleValue());
409    } else {
410      setLikelihoodThreshold(-Double.MAX_VALUE);
411    }
412
413    String shrinkageString = Utils.getOption('H', options);
414    if (shrinkageString.length() != 0) {
415      setShrinkage(new Double(shrinkageString).
416        doubleValue());
417    } else {
418      setShrinkage(1.0);
419    }
420
421    setUseResampling(Utils.getFlag('Q', options));
422    if (m_UseResampling && (thresholdString.length() != 0)) {
423      throw new Exception("Weight pruning with resampling"+
424                          "not allowed.");
425    }
426
427    super.setOptions(options);
428  }
429
430  /**
431   * Gets the current settings of the Classifier.
432   *
433   * @return an array of strings suitable for passing to setOptions
434   */
435  public String [] getOptions() {
436
437    String [] superOptions = super.getOptions();
438    String [] options = new String [superOptions.length + 10];
439
440    int current = 0;
441    if (getUseResampling()) {
442      options[current++] = "-Q";
443    } else {
444      options[current++] = "-P"; 
445      options[current++] = "" + getWeightThreshold();
446    }
447    options[current++] = "-F"; options[current++] = "" + getNumFolds();
448    options[current++] = "-R"; options[current++] = "" + getNumRuns();
449    options[current++] = "-L"; options[current++] = "" + getLikelihoodThreshold();
450    options[current++] = "-H"; options[current++] = "" + getShrinkage();
451
452    System.arraycopy(superOptions, 0, options, current, 
453                     superOptions.length);
454    current += superOptions.length;
455    while (current < options.length) {
456      options[current++] = "";
457    }
458    return options;
459  }
460 
461  /**
462   * Returns the tip text for this property
463   * @return tip text for this property suitable for
464   * displaying in the explorer/experimenter gui
465   */
466  public String shrinkageTipText() {
467    return "Shrinkage parameter (use small value like 0.1 to reduce "
468      + "overfitting).";
469  }
470                         
471  /**
472   * Get the value of Shrinkage.
473   *
474   * @return Value of Shrinkage.
475   */
476  public double getShrinkage() {
477   
478    return m_Shrinkage;
479  }
480 
481  /**
482   * Set the value of Shrinkage.
483   *
484   * @param newShrinkage Value to assign to Shrinkage.
485   */
486  public void setShrinkage(double newShrinkage) {
487   
488    m_Shrinkage = newShrinkage;
489  }
490 
491  /**
492   * Returns the tip text for this property
493   * @return tip text for this property suitable for
494   * displaying in the explorer/experimenter gui
495   */
496  public String likelihoodThresholdTipText() {
497    return "Threshold on improvement in likelihood.";
498  }
499                         
500  /**
501   * Get the value of Precision.
502   *
503   * @return Value of Precision.
504   */
505  public double getLikelihoodThreshold() {
506   
507    return m_Precision;
508  }
509 
510  /**
511   * Set the value of Precision.
512   *
513   * @param newPrecision Value to assign to Precision.
514   */
515  public void setLikelihoodThreshold(double newPrecision) {
516   
517    m_Precision = newPrecision;
518  }
519 
520  /**
521   * Returns the tip text for this property
522   * @return tip text for this property suitable for
523   * displaying in the explorer/experimenter gui
524   */
525  public String numRunsTipText() {
526    return "Number of runs for internal cross-validation.";
527  }
528 
529  /**
530   * Get the value of NumRuns.
531   *
532   * @return Value of NumRuns.
533   */
534  public int getNumRuns() {
535   
536    return m_NumRuns;
537  }
538 
539  /**
540   * Set the value of NumRuns.
541   *
542   * @param newNumRuns Value to assign to NumRuns.
543   */
544  public void setNumRuns(int newNumRuns) {
545   
546    m_NumRuns = newNumRuns;
547  }
548 
549  /**
550   * Returns the tip text for this property
551   * @return tip text for this property suitable for
552   * displaying in the explorer/experimenter gui
553   */
554  public String numFoldsTipText() {
555    return "Number of folds for internal cross-validation (default 0 "
556      + "means no cross-validation is performed).";
557  }
558 
559  /**
560   * Get the value of NumFolds.
561   *
562   * @return Value of NumFolds.
563   */
564  public int getNumFolds() {
565   
566    return m_NumFolds;
567  }
568 
569  /**
570   * Set the value of NumFolds.
571   *
572   * @param newNumFolds Value to assign to NumFolds.
573   */
574  public void setNumFolds(int newNumFolds) {
575   
576    m_NumFolds = newNumFolds;
577  }
578 
579  /**
580   * Returns the tip text for this property
581   * @return tip text for this property suitable for
582   * displaying in the explorer/experimenter gui
583   */
584  public String useResamplingTipText() {
585    return "Whether resampling is used instead of reweighting.";
586  }
587 
588  /**
589   * Set resampling mode
590   *
591   * @param r true if resampling should be done
592   */
593  public void setUseResampling(boolean r) {
594   
595    m_UseResampling = r;
596  }
597
598  /**
599   * Get whether resampling is turned on
600   *
601   * @return true if resampling output is on
602   */
603  public boolean getUseResampling() {
604   
605    return m_UseResampling;
606  }
607 
608  /**
609   * Returns the tip text for this property
610   * @return tip text for this property suitable for
611   * displaying in the explorer/experimenter gui
612   */
613  public String weightThresholdTipText() {
614    return "Weight threshold for weight pruning (reduce to 90 "
615      + "for speeding up learning process).";
616  }
617
618  /**
619   * Set weight thresholding
620   *
621   * @param threshold the percentage of weight mass used for training
622   */
623  public void setWeightThreshold(int threshold) {
624
625    m_WeightThreshold = threshold;
626  }
627
628  /**
629   * Get the degree of weight thresholding
630   *
631   * @return the percentage of weight mass used for training
632   */
633  public int getWeightThreshold() {
634
635    return m_WeightThreshold;
636  }
637
638  /**
639   * Returns default capabilities of the classifier.
640   *
641   * @return      the capabilities of this classifier
642   */
643  public Capabilities getCapabilities() {
644    Capabilities result = super.getCapabilities();
645
646    // class
647    result.disableAllClasses();
648    result.disableAllClassDependencies();
649    result.enable(Capability.NOMINAL_CLASS);
650   
651    return result;
652  }
653
654  /**
655   * Builds the boosted classifier
656   *
657   * @param data the data to train the classifier with
658   * @throws Exception if building fails, e.g., can't handle data
659   */
660  public void buildClassifier(Instances data) throws Exception {
661
662    m_RandomInstance = new Random(m_Seed);
663    int classIndex = data.classIndex();
664
665    if (m_Classifier == null) {
666      throw new Exception("A base classifier has not been specified!");
667    }
668   
669    if (!(m_Classifier instanceof WeightedInstancesHandler) &&
670        !m_UseResampling) {
671      m_UseResampling = true;
672    }
673
674    // can classifier handle the data?
675    getCapabilities().testWithFail(data);
676
677    if (m_Debug) {
678      System.err.println("Creating copy of the training data");
679    }
680
681    // remove instances with missing class
682    data = new Instances(data);
683    data.deleteWithMissingClass();
684   
685    // only class? -> build ZeroR model
686    if (data.numAttributes() == 1) {
687      System.err.println(
688          "Cannot build model (only class attribute present in data!), "
689          + "using ZeroR model instead!");
690      m_ZeroR = new weka.classifiers.rules.ZeroR();
691      m_ZeroR.buildClassifier(data);
692      return;
693    }
694    else {
695      m_ZeroR = null;
696    }
697   
698    m_NumClasses = data.numClasses();
699    m_ClassAttribute = data.classAttribute();
700
701    // Create the base classifiers
702    if (m_Debug) {
703      System.err.println("Creating base classifiers");
704    }
705    m_Classifiers = new Classifier [m_NumClasses][];
706    for (int j = 0; j < m_NumClasses; j++) {
707      m_Classifiers[j] = AbstractClassifier.makeCopies(m_Classifier,
708                                               getNumIterations());
709    }
710
711    // Do we want to select the appropriate number of iterations
712    // using cross-validation?
713    int bestNumIterations = getNumIterations();
714    if (m_NumFolds > 1) {
715      if (m_Debug) {
716        System.err.println("Processing first fold.");
717      }
718
719      // Array for storing the results
720      double[] results = new double[getNumIterations()];
721
722      // Iterate throught the cv-runs
723      for (int r = 0; r < m_NumRuns; r++) {
724
725        // Stratify the data
726        data.randomize(m_RandomInstance);
727        data.stratify(m_NumFolds);
728       
729        // Perform the cross-validation
730        for (int i = 0; i < m_NumFolds; i++) {
731         
732          // Get train and test folds
733          Instances train = data.trainCV(m_NumFolds, i, m_RandomInstance);
734          Instances test = data.testCV(m_NumFolds, i);
735         
736          // Make class numeric
737          Instances trainN = new Instances(train);
738          trainN.setClassIndex(-1);
739          trainN.deleteAttributeAt(classIndex);
740          trainN.insertAttributeAt(new Attribute("'pseudo class'"), classIndex);
741          trainN.setClassIndex(classIndex);
742          m_NumericClassData = new Instances(trainN, 0);
743         
744          // Get class values
745          int numInstances = train.numInstances();
746          double [][] trainFs = new double [numInstances][m_NumClasses];
747          double [][] trainYs = new double [numInstances][m_NumClasses];
748          for (int j = 0; j < m_NumClasses; j++) {
749            for (int k = 0; k < numInstances; k++) {
750              trainYs[k][j] = (train.instance(k).classValue() == j) ? 
751                1.0 - m_Offset: 0.0 + (m_Offset / (double)m_NumClasses);
752            }
753          }
754         
755          // Perform iterations
756          double[][] probs = initialProbs(numInstances);
757          m_NumGenerated = 0;
758          double sumOfWeights = train.sumOfWeights();
759          for (int j = 0; j < getNumIterations(); j++) {
760            performIteration(trainYs, trainFs, probs, trainN, sumOfWeights);
761            Evaluation eval = new Evaluation(train);
762            eval.evaluateModel(this, test);
763            results[j] += eval.correct();
764          }
765        }
766      }
767     
768      // Find the number of iterations with the lowest error
769      double bestResult = -Double.MAX_VALUE;
770      for (int j = 0; j < getNumIterations(); j++) {
771        if (results[j] > bestResult) {
772          bestResult = results[j];
773          bestNumIterations = j;
774        }
775      }
776      if (m_Debug) {
777        System.err.println("Best result for " + 
778                           bestNumIterations + " iterations: " +
779                           bestResult);
780      }
781    }
782
783    // Build classifier on all the data
784    int numInstances = data.numInstances();
785    double [][] trainFs = new double [numInstances][m_NumClasses];
786    double [][] trainYs = new double [numInstances][m_NumClasses];
787    for (int j = 0; j < m_NumClasses; j++) {
788      for (int i = 0, k = 0; i < numInstances; i++, k++) {
789        trainYs[i][j] = (data.instance(k).classValue() == j) ? 
790          1.0 - m_Offset: 0.0 + (m_Offset / (double)m_NumClasses);
791      }
792    }
793   
794    // Make class numeric
795    data.setClassIndex(-1);
796    data.deleteAttributeAt(classIndex);
797    data.insertAttributeAt(new Attribute("'pseudo class'"), classIndex);
798    data.setClassIndex(classIndex);
799    m_NumericClassData = new Instances(data, 0);
800       
801    // Perform iterations
802    double[][] probs = initialProbs(numInstances);
803    double logLikelihood = logLikelihood(trainYs, probs);
804    m_NumGenerated = 0;
805    if (m_Debug) {
806      System.err.println("Avg. log-likelihood: " + logLikelihood);
807    }
808    double sumOfWeights = data.sumOfWeights();
809    for (int j = 0; j < bestNumIterations; j++) {
810      double previousLoglikelihood = logLikelihood;
811      performIteration(trainYs, trainFs, probs, data, sumOfWeights);
812      logLikelihood = logLikelihood(trainYs, probs);
813      if (m_Debug) {
814        System.err.println("Avg. log-likelihood: " + logLikelihood);
815      }
816      if (Math.abs(previousLoglikelihood - logLikelihood) < m_Precision) {
817        return;
818      }
819    }
820  }
821
822  /**
823   * Gets the intial class probabilities.
824   *
825   * @param numInstances the number of instances
826   * @return the initial class probabilities
827   */
828  private double[][] initialProbs(int numInstances) {
829
830    double[][] probs = new double[numInstances][m_NumClasses];
831    for (int i = 0; i < numInstances; i++) {
832      for (int j = 0 ; j < m_NumClasses; j++) {
833        probs[i][j] = 1.0 / m_NumClasses;
834      }
835    }
836    return probs;
837  }
838
839  /**
840   * Computes loglikelihood given class values
841   * and estimated probablities.
842   *
843   * @param trainYs class values
844   * @param probs estimated probabilities
845   * @return the computed loglikelihood
846   */
847  private double logLikelihood(double[][] trainYs, double[][] probs) {
848
849    double logLikelihood = 0;
850    for (int i = 0; i < trainYs.length; i++) {
851      for (int j = 0; j < m_NumClasses; j++) {
852        if (trainYs[i][j] == 1.0 - m_Offset) {
853          logLikelihood -= Math.log(probs[i][j]);
854        }
855      }
856    }
857    return logLikelihood / (double)trainYs.length;
858  }
859
860  /**
861   * Performs one boosting iteration.
862   *
863   * @param trainYs class values
864   * @param trainFs F scores
865   * @param probs probabilities
866   * @param data the data to run the iteration on
867   * @param origSumOfWeights the original sum of weights
868   * @throws Exception in case base classifiers run into problems
869   */
870  private void performIteration(double[][] trainYs,
871                                double[][] trainFs,
872                                double[][] probs,
873                                Instances data,
874                                double origSumOfWeights) throws Exception {
875
876    if (m_Debug) {
877      System.err.println("Training classifier " + (m_NumGenerated + 1));
878    }
879
880    // Build the new models
881    for (int j = 0; j < m_NumClasses; j++) {
882      if (m_Debug) {
883        System.err.println("\t...for class " + (j + 1)
884                           + " (" + m_ClassAttribute.name() 
885                           + "=" + m_ClassAttribute.value(j) + ")");
886      }
887   
888      // Make copy because we want to save the weights
889      Instances boostData = new Instances(data);
890     
891      // Set instance pseudoclass and weights
892      for (int i = 0; i < probs.length; i++) {
893
894        // Compute response and weight
895        double p = probs[i][j];
896        double z, actual = trainYs[i][j];
897        if (actual == 1 - m_Offset) {
898          z = 1.0 / p;
899          if (z > Z_MAX) { // threshold
900            z = Z_MAX;
901          }
902        } else {
903          z = -1.0 / (1.0 - p);
904          if (z < -Z_MAX) { // threshold
905            z = -Z_MAX;
906          }
907        }
908        double w = (actual - p) / z;
909
910        // Set values for instance
911        Instance current = boostData.instance(i);
912        current.setValue(boostData.classIndex(), z);
913        current.setWeight(current.weight() * w);
914      }
915     
916      // Scale the weights (helps with some base learners)
917      double sumOfWeights = boostData.sumOfWeights();
918      double scalingFactor = (double)origSumOfWeights / sumOfWeights;
919      for (int i = 0; i < probs.length; i++) {
920        Instance current = boostData.instance(i);
921        current.setWeight(current.weight() * scalingFactor);
922      }
923
924      // Select instances to train the classifier on
925      Instances trainData = boostData;
926      if (m_WeightThreshold < 100) {
927        trainData = selectWeightQuantile(boostData, 
928                                         (double)m_WeightThreshold / 100);
929      } else {
930        if (m_UseResampling) {
931          double[] weights = new double[boostData.numInstances()];
932          for (int kk = 0; kk < weights.length; kk++) {
933            weights[kk] = boostData.instance(kk).weight();
934          }
935          trainData = boostData.resampleWithWeights(m_RandomInstance, 
936                                                    weights);
937        }
938      }
939     
940      // Build the classifier
941      m_Classifiers[j][m_NumGenerated].buildClassifier(trainData);
942    }     
943   
944    // Evaluate / increment trainFs from the classifier
945    for (int i = 0; i < trainFs.length; i++) {
946      double [] pred = new double [m_NumClasses];
947      double predSum = 0;
948      for (int j = 0; j < m_NumClasses; j++) {
949        pred[j] = m_Shrinkage * m_Classifiers[j][m_NumGenerated]
950          .classifyInstance(data.instance(i));
951        predSum += pred[j];
952      }
953      predSum /= m_NumClasses;
954      for (int j = 0; j < m_NumClasses; j++) {
955        trainFs[i][j] += (pred[j] - predSum) * (m_NumClasses - 1) 
956          / m_NumClasses;
957      }
958    }
959    m_NumGenerated++;
960   
961    // Compute the current probability estimates
962    for (int i = 0; i < trainYs.length; i++) {
963      probs[i] = probs(trainFs[i]);
964    }
965  }
966
967  /**
968   * Returns the array of classifiers that have been built.
969   *
970   * @return the built classifiers
971   */
972  public Classifier[][] classifiers() {
973
974    Classifier[][] classifiers = 
975      new Classifier[m_NumClasses][m_NumGenerated];
976    for (int j = 0; j < m_NumClasses; j++) {
977      for (int i = 0; i < m_NumGenerated; i++) {
978        classifiers[j][i] = m_Classifiers[j][i];
979      }
980    }
981    return classifiers;
982  }
983
984  /**
985   * Computes probabilities from F scores
986   *
987   * @param Fs the F scores
988   * @return the computed probabilities
989   */
990  private double[] probs(double[] Fs) {
991
992    double maxF = -Double.MAX_VALUE;
993    for (int i = 0; i < Fs.length; i++) {
994      if (Fs[i] > maxF) {
995        maxF = Fs[i];
996      }
997    }
998    double sum = 0;
999    double[] probs = new double[Fs.length];
1000    for (int i = 0; i < Fs.length; i++) {
1001      probs[i] = Math.exp(Fs[i] - maxF);
1002      sum += probs[i];
1003    }
1004    Utils.normalize(probs, sum);
1005    return probs;
1006  }
1007   
1008  /**
1009   * Calculates the class membership probabilities for the given test instance.
1010   *
1011   * @param instance the instance to be classified
1012   * @return predicted class probability distribution
1013   * @throws Exception if instance could not be classified
1014   * successfully
1015   */
1016  public double [] distributionForInstance(Instance instance) 
1017    throws Exception {
1018
1019    // default model?
1020    if (m_ZeroR != null) {
1021      return m_ZeroR.distributionForInstance(instance);
1022    }
1023   
1024    instance = (Instance)instance.copy();
1025    instance.setDataset(m_NumericClassData);
1026    double [] pred = new double [m_NumClasses];
1027    double [] Fs = new double [m_NumClasses]; 
1028    for (int i = 0; i < m_NumGenerated; i++) {
1029      double predSum = 0;
1030      for (int j = 0; j < m_NumClasses; j++) {
1031        pred[j] = m_Shrinkage * m_Classifiers[j][i].classifyInstance(instance);
1032        predSum += pred[j];
1033      }
1034      predSum /= m_NumClasses;
1035      for (int j = 0; j < m_NumClasses; j++) {
1036        Fs[j] += (pred[j] - predSum) * (m_NumClasses - 1) 
1037          / m_NumClasses;
1038      }
1039    }
1040
1041    return probs(Fs);
1042  }
1043
1044  /**
1045   * Returns the boosted model as Java source code.
1046   *
1047   * @param className the classname in the generated code
1048   * @return the tree as Java source code
1049   * @throws Exception if something goes wrong
1050   */
1051  public String toSource(String className) throws Exception {
1052
1053    if (m_NumGenerated == 0) {
1054      throw new Exception("No model built yet");
1055    }
1056    if (!(m_Classifiers[0][0] instanceof Sourcable)) {
1057      throw new Exception("Base learner " + m_Classifier.getClass().getName()
1058                          + " is not Sourcable");
1059    }
1060
1061    StringBuffer text = new StringBuffer("class ");
1062    text.append(className).append(" {\n\n");
1063    text.append("  private static double RtoP(double []R, int j) {\n"+
1064                "    double Rcenter = 0;\n"+
1065                "    for (int i = 0; i < R.length; i++) {\n"+
1066                "      Rcenter += R[i];\n"+
1067                "    }\n"+
1068                "    Rcenter /= R.length;\n"+
1069                "    double Rsum = 0;\n"+
1070                "    for (int i = 0; i < R.length; i++) {\n"+
1071                "      Rsum += Math.exp(R[i] - Rcenter);\n"+
1072                "    }\n"+
1073                "    return Math.exp(R[j]) / Rsum;\n"+
1074                "  }\n\n");
1075
1076    text.append("  public static double classify(Object[] i) {\n" +
1077                "    double [] d = distribution(i);\n" +
1078                "    double maxV = d[0];\n" +
1079                "    int maxI = 0;\n"+
1080                "    for (int j = 1; j < " + m_NumClasses + "; j++) {\n"+
1081                "      if (d[j] > maxV) { maxV = d[j]; maxI = j; }\n"+
1082                "    }\n    return (double) maxI;\n  }\n\n");
1083
1084    text.append("  public static double [] distribution(Object [] i) {\n");
1085    text.append("    double [] Fs = new double [" + m_NumClasses + "];\n");
1086    text.append("    double [] Fi = new double [" + m_NumClasses + "];\n");
1087    text.append("    double Fsum;\n");
1088    for (int i = 0; i < m_NumGenerated; i++) {
1089      text.append("    Fsum = 0;\n");
1090      for (int j = 0; j < m_NumClasses; j++) {
1091        text.append("    Fi[" + j + "] = " + className + '_' +j + '_' + i
1092                    + ".classify(i); Fsum += Fi[" + j + "];\n");
1093      }
1094      text.append("    Fsum /= " + m_NumClasses + ";\n");
1095      text.append("    for (int j = 0; j < " + m_NumClasses + "; j++) {");
1096      text.append(" Fs[j] += (Fi[j] - Fsum) * "
1097                  + (m_NumClasses - 1) + " / " + m_NumClasses + "; }\n");
1098    }
1099   
1100    text.append("    double [] dist = new double [" + m_NumClasses + "];\n" +
1101                "    for (int j = 0; j < " + m_NumClasses + "; j++) {\n"+
1102                "      dist[j] = RtoP(Fs, j);\n"+
1103                "    }\n    return dist;\n");
1104    text.append("  }\n}\n");
1105
1106    for (int i = 0; i < m_Classifiers.length; i++) {
1107      for (int j = 0; j < m_Classifiers[i].length; j++) {
1108        text.append(((Sourcable)m_Classifiers[i][j])
1109                    .toSource(className + '_' + i + '_' + j));
1110      }
1111    }
1112    return text.toString();
1113  }
1114
1115  /**
1116   * Returns description of the boosted classifier.
1117   *
1118   * @return description of the boosted classifier as a string
1119   */
1120  public String toString() {
1121   
1122    // only ZeroR model?
1123    if (m_ZeroR != null) {
1124      StringBuffer buf = new StringBuffer();
1125      buf.append(this.getClass().getName().replaceAll(".*\\.", "") + "\n");
1126      buf.append(this.getClass().getName().replaceAll(".*\\.", "").replaceAll(".", "=") + "\n\n");
1127      buf.append("Warning: No model could be built, hence ZeroR model is used:\n\n");
1128      buf.append(m_ZeroR.toString());
1129      return buf.toString();
1130    }
1131   
1132    StringBuffer text = new StringBuffer();
1133   
1134    if (m_NumGenerated == 0) {
1135      text.append("LogitBoost: No model built yet.");
1136      //      text.append(m_Classifiers[0].toString()+"\n");
1137    } else {
1138      text.append("LogitBoost: Base classifiers and their weights: \n");
1139      for (int i = 0; i < m_NumGenerated; i++) {
1140        text.append("\nIteration "+(i+1));
1141        for (int j = 0; j < m_NumClasses; j++) {
1142          text.append("\n\tClass " + (j + 1) 
1143                      + " (" + m_ClassAttribute.name() 
1144                      + "=" + m_ClassAttribute.value(j) + ")\n\n"
1145                      + m_Classifiers[j][i].toString() + "\n");
1146        }
1147      }
1148      text.append("Number of performed iterations: " +
1149                    m_NumGenerated + "\n");
1150    }
1151   
1152    return text.toString();
1153  }
1154 
1155  /**
1156   * Returns the revision string.
1157   *
1158   * @return            the revision
1159   */
1160  public String getRevision() {
1161    return RevisionUtils.extract("$Revision: 6091 $");
1162  }
1163
1164  /**
1165   * Main method for testing this class.
1166   *
1167   * @param argv the options
1168   */
1169  public static void main(String [] argv) {
1170    runClassifier(new LogitBoost(), argv);
1171  }
1172}
Note: See TracBrowser for help on using the repository browser.