source: src/main/java/weka/classifiers/mi/MIBoost.java @ 4

Last change on this file since 4 was 4, checked in by gnappo, 14 years ago

Import di weka.

File size: 20.4 KB
Line 
1/*
2 *    This program is free software; you can redistribute it and/or modify
3 *    it under the terms of the GNU General Public License as published by
4 *    the Free Software Foundation; either version 2 of the License, or
5 *    (at your option) any later version.
6 *
7 *    This program is distributed in the hope that it will be useful,
8 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
9 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
10 *    GNU General Public License for more details.
11 *
12 *    You should have received a copy of the GNU General Public License
13 *    along with this program; if not, write to the Free Software
14 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
15 */
16
17/*
18 * MIBoost.java
19 * Copyright (C) 2005 University of Waikato, Hamilton, New Zealand
20 *
21 */
22
23package weka.classifiers.mi;
24
25import weka.classifiers.Classifier;
26import weka.classifiers.AbstractClassifier;
27import weka.classifiers.SingleClassifierEnhancer;
28import weka.core.Capabilities;
29import weka.core.Instance;
30import weka.core.Instances;
31import weka.core.MultiInstanceCapabilitiesHandler;
32import weka.core.Optimization;
33import weka.core.Option;
34import weka.core.OptionHandler;
35import weka.core.RevisionUtils;
36import weka.core.TechnicalInformation;
37import weka.core.TechnicalInformationHandler;
38import weka.core.Utils;
39import weka.core.WeightedInstancesHandler;
40import weka.core.Capabilities.Capability;
41import weka.core.TechnicalInformation.Field;
42import weka.core.TechnicalInformation.Type;
43import weka.filters.Filter;
44import weka.filters.unsupervised.attribute.Discretize;
45import weka.filters.unsupervised.attribute.MultiInstanceToPropositional;
46
47import java.util.Enumeration;
48import java.util.Vector;
49
50/**
51 <!-- globalinfo-start -->
52 * MI AdaBoost method, considers the geometric mean of posterior of instances inside a bag (arithmatic mean of log-posterior) and the expectation for a bag is taken inside the loss function.<br/>
53 * <br/>
54 * For more information about Adaboost, see:<br/>
55 * <br/>
56 * Yoav Freund, Robert E. Schapire: Experiments with a new boosting algorithm. In: Thirteenth International Conference on Machine Learning, San Francisco, 148-156, 1996.
57 * <p/>
58 <!-- globalinfo-end -->
59 *
60 <!-- technical-bibtex-start -->
61 * BibTeX:
62 * <pre>
63 * &#64;inproceedings{Freund1996,
64 *    address = {San Francisco},
65 *    author = {Yoav Freund and Robert E. Schapire},
66 *    booktitle = {Thirteenth International Conference on Machine Learning},
67 *    pages = {148-156},
68 *    publisher = {Morgan Kaufmann},
69 *    title = {Experiments with a new boosting algorithm},
70 *    year = {1996}
71 * }
72 * </pre>
73 * <p/>
74 <!-- technical-bibtex-end -->
75 *
76 <!-- options-start -->
77 * Valid options are: <p/>
78 *
79 * <pre> -D
80 *  Turn on debugging output.</pre>
81 *
82 * <pre> -B &lt;num&gt;
83 *  The number of bins in discretization
84 *  (default 0, no discretization)</pre>
85 *
86 * <pre> -R &lt;num&gt;
87 *  Maximum number of boost iterations.
88 *  (default 10)</pre>
89 *
90 * <pre> -W &lt;class name&gt;
91 *  Full name of classifier to boost.
92 *  eg: weka.classifiers.bayes.NaiveBayes</pre>
93 *
94 * <pre> -D
95 *  If set, classifier is run in debug mode and
96 *  may output additional info to the console</pre>
97 *
98 <!-- options-end -->
99 *
100 * @author Eibe Frank (eibe@cs.waikato.ac.nz)
101 * @author Xin Xu (xx5@cs.waikato.ac.nz)
102 * @version $Revision: 5928 $
103 */
104public class MIBoost 
105  extends SingleClassifierEnhancer
106  implements OptionHandler, MultiInstanceCapabilitiesHandler,
107             TechnicalInformationHandler {
108
109  /** for serialization */
110  static final long serialVersionUID = -3808427225599279539L;
111 
112  /** the models for the iterations */
113  protected Classifier[] m_Models;
114
115  /** The number of the class labels */
116  protected int m_NumClasses;
117
118  /** Class labels for each bag */
119  protected int[] m_Classes;
120
121  /** attributes name for the new dataset used to build the model  */
122  protected Instances m_Attributes;
123
124  /** Number of iterations */   
125  private int m_NumIterations = 100;
126
127  /** Voting weights of models */ 
128  protected double[] m_Beta;
129
130  /** the maximum number of boost iterations */
131  protected int m_MaxIterations = 10;
132
133  /** the number of discretization bins */
134  protected int m_DiscretizeBin = 0;
135
136  /** filter used for discretization */
137  protected Discretize m_Filter = null;
138
139  /** filter used to convert the MI dataset into single-instance dataset */
140  protected MultiInstanceToPropositional m_ConvertToSI = new MultiInstanceToPropositional();
141
142  /**
143   * Returns a string describing this filter
144   *
145   * @return a description of the filter suitable for
146   * displaying in the explorer/experimenter gui
147   */
148  public String globalInfo() {
149    return 
150        "MI AdaBoost method, considers the geometric mean of posterior "
151      + "of instances inside a bag (arithmatic mean of log-posterior) and "
152      + "the expectation for a bag is taken inside the loss function.\n\n"
153      + "For more information about Adaboost, see:\n\n"
154      + getTechnicalInformation().toString();
155  }
156
157  /**
158   * Returns an instance of a TechnicalInformation object, containing
159   * detailed information about the technical background of this class,
160   * e.g., paper reference or book this class is based on.
161   *
162   * @return the technical information about this class
163   */
164  public TechnicalInformation getTechnicalInformation() {
165    TechnicalInformation        result;
166   
167    result = new TechnicalInformation(Type.INPROCEEDINGS);
168    result.setValue(Field.AUTHOR, "Yoav Freund and Robert E. Schapire");
169    result.setValue(Field.TITLE, "Experiments with a new boosting algorithm");
170    result.setValue(Field.BOOKTITLE, "Thirteenth International Conference on Machine Learning");
171    result.setValue(Field.YEAR, "1996");
172    result.setValue(Field.PAGES, "148-156");
173    result.setValue(Field.PUBLISHER, "Morgan Kaufmann");
174    result.setValue(Field.ADDRESS, "San Francisco");
175   
176    return result;
177  }
178
179  /**
180   * Returns an enumeration describing the available options
181   *
182   * @return an enumeration of all the available options
183   */
184  public Enumeration listOptions() {
185    Vector result = new Vector();
186
187    result.addElement(new Option(
188          "\tTurn on debugging output.",
189          "D", 0, "-D"));
190
191    result.addElement(new Option(
192          "\tThe number of bins in discretization\n"
193          + "\t(default 0, no discretization)",
194          "B", 1, "-B <num>")); 
195
196    result.addElement(new Option(
197          "\tMaximum number of boost iterations.\n"
198          + "\t(default 10)",
199          "R", 1, "-R <num>")); 
200
201    result.addElement(new Option(
202          "\tFull name of classifier to boost.\n"
203          + "\teg: weka.classifiers.bayes.NaiveBayes",
204          "W", 1, "-W <class name>"));
205
206    Enumeration enu = ((OptionHandler)m_Classifier).listOptions();
207    while (enu.hasMoreElements()) {
208      result.addElement(enu.nextElement());
209    }
210
211    return result.elements();
212  }
213
214  /**
215   * Parses a given list of options. <p/>
216   *
217   <!-- options-start -->
218   * Valid options are: <p/>
219   *
220   * <pre> -D
221   *  Turn on debugging output.</pre>
222   *
223   * <pre> -B &lt;num&gt;
224   *  The number of bins in discretization
225   *  (default 0, no discretization)</pre>
226   *
227   * <pre> -R &lt;num&gt;
228   *  Maximum number of boost iterations.
229   *  (default 10)</pre>
230   *
231   * <pre> -W &lt;class name&gt;
232   *  Full name of classifier to boost.
233   *  eg: weka.classifiers.bayes.NaiveBayes</pre>
234   *
235   * <pre> -D
236   *  If set, classifier is run in debug mode and
237   *  may output additional info to the console</pre>
238   *
239   <!-- options-end -->
240   *
241   * @param options the list of options as an array of strings
242   * @throws Exception if an option is not supported
243   */
244  public void setOptions(String[] options) throws Exception {
245    setDebug(Utils.getFlag('D', options));
246
247    String bin = Utils.getOption('B', options);
248    if (bin.length() != 0) {
249      setDiscretizeBin(Integer.parseInt(bin));
250    } else {
251      setDiscretizeBin(0);
252    }
253
254    String boostIterations = Utils.getOption('R', options);
255    if (boostIterations.length() != 0) {
256      setMaxIterations(Integer.parseInt(boostIterations));
257    } else {
258      setMaxIterations(10);
259    }
260
261    super.setOptions(options);
262  }
263
264  /**
265   * Gets the current settings of the classifier.
266   *
267   * @return an array of strings suitable for passing to setOptions
268   */
269  public String[] getOptions() {
270    Vector        result;
271    String[]      options;
272    int           i;
273   
274    result  = new Vector();
275
276    result.add("-R");
277    result.add("" + getMaxIterations());
278
279    result.add("-B");
280    result.add("" + getDiscretizeBin());
281
282    options = super.getOptions();
283    for (i = 0; i < options.length; i++)
284      result.add(options[i]);
285
286    return (String[]) result.toArray(new String[result.size()]);
287  }
288
289  /**
290   * Returns the tip text for this property
291   *
292   * @return tip text for this property suitable for
293   * displaying in the explorer/experimenter gui
294   */
295  public String maxIterationsTipText() {
296    return "The maximum number of boost iterations.";
297  }
298
299  /**
300   * Set the maximum number of boost iterations
301   *
302   * @param maxIterations the maximum number of boost iterations
303   */
304  public void setMaxIterations(int maxIterations) {     
305    m_MaxIterations = maxIterations;
306  }
307
308  /**
309   * Get the maximum number of boost iterations
310   *
311   * @return the maximum number of boost iterations
312   */
313  public int getMaxIterations() {
314
315    return m_MaxIterations;
316  }
317
318  /**
319   * Returns the tip text for this property
320   *
321   * @return tip text for this property suitable for
322   * displaying in the explorer/experimenter gui
323   */
324  public String discretizeBinTipText() {
325    return "The number of bins in discretization.";
326  }
327
328  /**
329   * Set the number of bins in discretization
330   *
331   * @param bin the number of bins in discretization
332   */
333  public void setDiscretizeBin(int bin) {       
334    m_DiscretizeBin = bin;
335  }
336
337  /**
338   * Get the number of bins in discretization
339   *
340   * @return the number of bins in discretization
341   */
342  public int getDiscretizeBin() {       
343    return m_DiscretizeBin;
344  }
345
346  private class OptEng 
347    extends Optimization {
348   
349    private double[] weights, errs;
350
351    public void setWeights(double[] w){
352      weights = w;
353    }
354
355    public void setErrs(double[] e){
356      errs = e;
357    }
358
359    /**
360     * Evaluate objective function
361     * @param x the current values of variables
362     * @return the value of the objective function
363     * @throws Exception if result is NaN
364     */
365    protected double objectiveFunction(double[] x) throws Exception{
366      double obj=0;
367      for(int i=0; i<weights.length; i++){
368        obj += weights[i]*Math.exp(x[0]*(2.0*errs[i]-1.0));
369        if(Double.isNaN(obj))
370          throw new Exception("Objective function value is NaN!");
371
372      }
373      return obj;
374    }
375
376    /**
377     * Evaluate Jacobian vector
378     * @param x the current values of variables
379     * @return the gradient vector
380     * @throws Exception if gradient is NaN
381     */
382    protected double[] evaluateGradient(double[] x)  throws Exception{
383      double[] grad = new double[1];
384      for(int i=0; i<weights.length; i++){
385        grad[0] += weights[i]*(2.0*errs[i]-1.0)*Math.exp(x[0]*(2.0*errs[i]-1.0));
386        if(Double.isNaN(grad[0]))
387          throw new Exception("Gradient is NaN!");
388
389      }
390      return grad;
391    }
392   
393    /**
394     * Returns the revision string.
395     *
396     * @return          the revision
397     */
398    public String getRevision() {
399      return RevisionUtils.extract("$Revision: 5928 $");
400    }
401  }
402
403  /**
404   * Returns default capabilities of the classifier.
405   *
406   * @return      the capabilities of this classifier
407   */
408  public Capabilities getCapabilities() {
409    Capabilities result = super.getCapabilities();
410
411    // attributes
412    result.enable(Capability.NOMINAL_ATTRIBUTES);
413    result.enable(Capability.RELATIONAL_ATTRIBUTES);
414    result.enable(Capability.MISSING_VALUES);
415
416    // class
417    result.disableAllClasses();
418    result.disableAllClassDependencies();
419    if (super.getCapabilities().handles(Capability.BINARY_CLASS))
420      result.enable(Capability.BINARY_CLASS);
421    result.enable(Capability.MISSING_CLASS_VALUES);
422   
423    // other
424    result.enable(Capability.ONLY_MULTIINSTANCE);
425   
426    return result;
427  }
428
429  /**
430   * Returns the capabilities of this multi-instance classifier for the
431   * relational data.
432   *
433   * @return            the capabilities of this object
434   * @see               Capabilities
435   */
436  public Capabilities getMultiInstanceCapabilities() {
437    Capabilities result = super.getCapabilities();
438   
439    // class
440    result.disableAllClasses();
441    result.enable(Capability.NO_CLASS);
442   
443    return result;
444  }
445
446  /**
447   * Builds the classifier
448   *
449   * @param exps the training data to be used for generating the
450   * boosted classifier.
451   * @throws Exception if the classifier could not be built successfully
452   */
453  public void buildClassifier(Instances exps) throws Exception {
454
455    // can classifier handle the data?
456    getCapabilities().testWithFail(exps);
457
458    // remove instances with missing class
459    Instances train = new Instances(exps);
460    train.deleteWithMissingClass();
461
462    m_NumClasses = train.numClasses();
463    m_NumIterations = m_MaxIterations;
464
465    if (m_Classifier == null)
466      throw new Exception("A base classifier has not been specified!");
467    if(!(m_Classifier instanceof WeightedInstancesHandler))
468      throw new Exception("Base classifier cannot handle weighted instances!");
469
470    m_Models = AbstractClassifier.makeCopies(m_Classifier, getMaxIterations());
471    if(m_Debug)
472      System.err.println("Base classifier: "+m_Classifier.getClass().getName());
473
474    m_Beta = new double[m_NumIterations];
475
476    /* modified by Lin Dong. (use MIToSingleInstance filter to convert the MI datasets) */
477
478    //Initialize the bags' weights
479    double N = (double)train.numInstances(), sumNi=0;
480    for(int i=0; i<N; i++)
481      sumNi += train.instance(i).relationalValue(1).numInstances();     
482    for(int i=0; i<N; i++){
483      train.instance(i).setWeight(sumNi/N);
484    }
485
486    //convert the training dataset into single-instance dataset
487    m_ConvertToSI.setInputFormat(train);
488    Instances data = Filter.useFilter( train, m_ConvertToSI);
489    data.deleteAttributeAt(0); //remove the bagIndex attribute;
490
491
492    // Assume the order of the instances are preserved in the Discretize filter
493    if(m_DiscretizeBin > 0){
494      m_Filter = new Discretize();
495      m_Filter.setInputFormat(new Instances(data, 0));
496      m_Filter.setBins(m_DiscretizeBin);
497      data = Filter.useFilter(data, m_Filter);
498    }
499
500    // Main algorithm
501    int dataIdx;
502iterations:
503    for(int m=0; m < m_MaxIterations; m++){
504      if(m_Debug)
505        System.err.println("\nIteration "+m); 
506
507
508      // Build a model
509      m_Models[m].buildClassifier(data);
510
511      // Prediction of each bag
512      double[] err=new double[(int)N], weights=new double[(int)N];
513      boolean perfect = true, tooWrong=true;
514      dataIdx = 0;
515      for(int n=0; n<N; n++){
516        Instance exn = train.instance(n);
517        // Prediction of each instance and the predicted class distribution
518        // of the bag           
519        double nn = (double)exn.relationalValue(1).numInstances();
520        for(int p=0; p<nn; p++){
521          Instance testIns = data.instance(dataIdx++);                 
522          if((int)m_Models[m].classifyInstance(testIns) 
523              != (int)exn.classValue()) // Weighted instance-wise 0-1 errors
524            err[n] ++;                                 
525        }
526        weights[n] = exn.weight();
527        err[n] /= nn;
528        if(err[n] > 0.5)
529          perfect = false;
530        if(err[n] < 0.5)
531          tooWrong = false;
532      }
533
534      if(perfect || tooWrong){ // No or 100% classification error, cannot find beta
535        if (m == 0)
536          m_Beta[m] = 1.0;
537        else               
538          m_Beta[m] = 0;               
539        m_NumIterations = m+1;
540        if(m_Debug)  System.err.println("No errors");
541        break iterations;
542      }
543
544      double[] x = new double[1];
545      x[0] = 0;
546      double[][] b = new double[2][x.length];
547      b[0][0] = Double.NaN;
548      b[1][0] = Double.NaN;
549
550      OptEng opt = new OptEng();       
551      opt.setWeights(weights);
552      opt.setErrs(err);
553      //opt.setDebug(m_Debug);
554      if (m_Debug)
555        System.out.println("Start searching for c... ");
556      x = opt.findArgmin(x, b);
557      while(x==null){
558        x = opt.getVarbValues();
559        if (m_Debug)
560          System.out.println("200 iterations finished, not enough!");
561        x = opt.findArgmin(x, b);
562      } 
563      if (m_Debug)
564        System.out.println("Finished.");   
565      m_Beta[m] = x[0];
566
567      if(m_Debug)
568        System.err.println("c = "+m_Beta[m]);
569
570      // Stop if error too small or error too big and ignore this model
571      if (Double.isInfinite(m_Beta[m]) 
572          || Utils.smOrEq(m_Beta[m], 0)
573         ) {
574        if (m == 0)
575          m_Beta[m] = 1.0;
576        else               
577          m_Beta[m] = 0;
578        m_NumIterations = m+1;
579        if(m_Debug)
580          System.err.println("Errors out of range!");
581        break iterations;
582         }
583
584      // Update weights of data and class label of wfData
585      dataIdx=0;
586      double totWeights=0;
587      for(int r=0; r<N; r++){           
588        Instance exr = train.instance(r);
589        exr.setWeight(weights[r]*Math.exp(m_Beta[m]*(2.0*err[r]-1.0)));
590        totWeights += exr.weight();
591      }
592
593      if(m_Debug)
594        System.err.println("Total weights = "+totWeights);
595
596      for(int r=0; r<N; r++){           
597        Instance exr = train.instance(r);
598        double num = (double)exr.relationalValue(1).numInstances();
599        exr.setWeight(sumNi*exr.weight()/totWeights);
600        //if(m_Debug)
601        //    System.err.print("\nExemplar "+r+"="+exr.weight()+": \t");
602        for(int s=0; s<num; s++){
603          Instance inss = data.instance(dataIdx);       
604          inss.setWeight(exr.weight()/num);               
605          //    if(m_Debug)
606          //  System.err.print("instance "+s+"="+inss.weight()+
607          //                     "|ew*iw*sumNi="+data.instance(dataIdx).weight()+"\t");
608          if(Double.isNaN(inss.weight()))
609            throw new Exception("instance "+s+" in bag "+r+" has weight NaN!"); 
610          dataIdx++;
611        }
612        //if(m_Debug)
613        //    System.err.println();
614      }       
615    }
616  }             
617
618  /**
619   * Computes the distribution for a given exemplar
620   *
621   * @param exmp the exemplar for which distribution is computed
622   * @return the classification
623   * @throws Exception if the distribution can't be computed successfully
624   */
625  public double[] distributionForInstance(Instance exmp) 
626    throws Exception { 
627
628    double[] rt = new double[m_NumClasses];
629
630    Instances insts = new Instances(exmp.dataset(), 0);
631    insts.add(exmp);
632
633    // convert the training dataset into single-instance dataset
634    insts = Filter.useFilter( insts, m_ConvertToSI);
635    insts.deleteAttributeAt(0); //remove the bagIndex attribute
636
637    double n = insts.numInstances();
638
639    if(m_DiscretizeBin > 0)
640      insts = Filter.useFilter(insts, m_Filter);
641
642    for(int y=0; y<n; y++){
643      Instance ins = insts.instance(y); 
644      for(int x=0; x<m_NumIterations; x++){ 
645        rt[(int)m_Models[x].classifyInstance(ins)] += m_Beta[x]/n;
646      }
647    }
648
649    for(int i=0; i<rt.length; i++)
650      rt[i] = Math.exp(rt[i]);
651
652    Utils.normalize(rt);
653    return rt;
654  }
655
656  /**
657   * Gets a string describing the classifier.
658   *
659   * @return a string describing the classifer built.
660   */
661  public String toString() {
662
663    if (m_Models == null) {
664      return "No model built yet!";
665    }
666    StringBuffer text = new StringBuffer();
667    text.append("MIBoost: number of bins in discretization = "+m_DiscretizeBin+"\n");
668    if (m_NumIterations == 0) {
669      text.append("No model built yet.\n");
670    } else if (m_NumIterations == 1) {
671      text.append("No boosting possible, one classifier used: Weight = " 
672          + Utils.roundDouble(m_Beta[0], 2)+"\n");
673      text.append("Base classifiers:\n"+m_Models[0].toString());
674    } else {
675      text.append("Base classifiers and their weights: \n");
676      for (int i = 0; i < m_NumIterations ; i++) {
677        text.append("\n\n"+i+": Weight = " + Utils.roundDouble(m_Beta[i], 2)
678            +"\nBase classifier:\n"+m_Models[i].toString() );
679      }
680    }
681
682    text.append("\n\nNumber of performed Iterations: " 
683        + m_NumIterations + "\n");
684
685    return text.toString();
686  }
687 
688  /**
689   * Returns the revision string.
690   *
691   * @return            the revision
692   */
693  public String getRevision() {
694    return RevisionUtils.extract("$Revision: 5928 $");
695  }
696
697  /**
698   * Main method for testing this class.
699   *
700   * @param argv should contain the command line arguments to the
701   * scheme (see Evaluation)
702   */
703  public static void main(String[] argv) {
704    runClassifier(new MIBoost(), argv);
705  }
706}
Note: See TracBrowser for help on using the repository browser.