source: src/main/java/weka/classifiers/meta/Vote.java @ 25

Last change on this file since 25 was 4, checked in by gnappo, 14 years ago

Import di weka.

File size: 19.8 KB
Line 
1/*
2 *    This program is free software; you can redistribute it and/or modify
3 *    it under the terms of the GNU General Public License as published by
4 *    the Free Software Foundation; either version 2 of the License, or
5 *    (at your option) any later version.
6 *
7 *    This program is distributed in the hope that it will be useful,
8 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
9 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
10 *    GNU General Public License for more details.
11 *
12 *    You should have received a copy of the GNU General Public License
13 *    along with this program; if not, write to the Free Software
14 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
15 */
16
17/*
18 *    Vote.java
19 *    Copyright (C) 2000 University of Waikato
20 *    Copyright (C) 2006 Roberto Perdisci
21 *
22 */
23
24package weka.classifiers.meta;
25
26import weka.classifiers.RandomizableMultipleClassifiersCombiner;
27import weka.core.Capabilities;
28import weka.core.Instance;
29import weka.core.Instances;
30import weka.core.Option;
31import weka.core.RevisionUtils;
32import weka.core.SelectedTag;
33import weka.core.Tag;
34import weka.core.TechnicalInformation;
35import weka.core.TechnicalInformationHandler;
36import weka.core.Utils;
37import weka.core.Capabilities.Capability;
38import weka.core.TechnicalInformation.Field;
39import weka.core.TechnicalInformation.Type;
40
41import java.util.Enumeration;
42import java.util.Random;
43import java.util.Vector;
44
45/**
46 <!-- globalinfo-start -->
47 * Class for combining classifiers. Different combinations of probability estimates for classification are available.<br/>
48 * <br/>
49 * For more information see:<br/>
50 * <br/>
51 * Ludmila I. Kuncheva (2004). Combining Pattern Classifiers: Methods and Algorithms. John Wiley and Sons, Inc..<br/>
52 * <br/>
53 * J. Kittler, M. Hatef, Robert P.W. Duin, J. Matas (1998). On combining classifiers. IEEE Transactions on Pattern Analysis and Machine Intelligence. 20(3):226-239.
54 * <p/>
55 <!-- globalinfo-end -->
56 *
57 <!-- options-start -->
58 * Valid options are: <p/>
59 *
60 * <pre> -S &lt;num&gt;
61 *  Random number seed.
62 *  (default 1)</pre>
63 *
64 * <pre> -B &lt;classifier specification&gt;
65 *  Full class name of classifier to include, followed
66 *  by scheme options. May be specified multiple times.
67 *  (default: "weka.classifiers.rules.ZeroR")</pre>
68 *
69 * <pre> -D
70 *  If set, classifier is run in debug mode and
71 *  may output additional info to the console</pre>
72 *
73 * <pre> -R &lt;AVG|PROD|MAJ|MIN|MAX|MED&gt;
74 *  The combination rule to use
75 *  (default: AVG)</pre>
76 *
77 <!-- options-end -->
78 *
79 <!-- technical-bibtex-start -->
80 * BibTeX:
81 * <pre>
82 * &#64;book{Kuncheva2004,
83 *    author = {Ludmila I. Kuncheva},
84 *    publisher = {John Wiley and Sons, Inc.},
85 *    title = {Combining Pattern Classifiers: Methods and Algorithms},
86 *    year = {2004}
87 * }
88 *
89 * &#64;article{Kittler1998,
90 *    author = {J. Kittler and M. Hatef and Robert P.W. Duin and J. Matas},
91 *    journal = {IEEE Transactions on Pattern Analysis and Machine Intelligence},
92 *    number = {3},
93 *    pages = {226-239},
94 *    title = {On combining classifiers},
95 *    volume = {20},
96 *    year = {1998}
97 * }
98 * </pre>
99 * <p/>
100 <!-- technical-bibtex-end -->
101 *
102 * @author Alexander K. Seewald (alex@seewald.at)
103 * @author Eibe Frank (eibe@cs.waikato.ac.nz)
104 * @author Roberto Perdisci (roberto.perdisci@gmail.com)
105 * @version $Revision: 5987 $
106 */
107public class Vote
108  extends RandomizableMultipleClassifiersCombiner
109  implements TechnicalInformationHandler {
110   
111  /** for serialization */
112  static final long serialVersionUID = -637891196294399624L;
113 
114  /** combination rule: Average of Probabilities */
115  public static final int AVERAGE_RULE = 1;
116  /** combination rule: Product of Probabilities (only nominal classes) */
117  public static final int PRODUCT_RULE = 2;
118  /** combination rule: Majority Voting (only nominal classes) */
119  public static final int MAJORITY_VOTING_RULE = 3;
120  /** combination rule: Minimum Probability */
121  public static final int MIN_RULE = 4;
122  /** combination rule: Maximum Probability */
123  public static final int MAX_RULE = 5;
124  /** combination rule: Median Probability (only numeric class) */
125  public static final int MEDIAN_RULE = 6;
126  /** combination rules */
127  public static final Tag[] TAGS_RULES = {
128    new Tag(AVERAGE_RULE, "AVG", "Average of Probabilities"),
129    new Tag(PRODUCT_RULE, "PROD", "Product of Probabilities"),
130    new Tag(MAJORITY_VOTING_RULE, "MAJ", "Majority Voting"),
131    new Tag(MIN_RULE, "MIN", "Minimum Probability"),
132    new Tag(MAX_RULE, "MAX", "Maximum Probability"),
133    new Tag(MEDIAN_RULE, "MED", "Median")
134  };
135 
136  /** Combination Rule variable */
137  protected int m_CombinationRule = AVERAGE_RULE;
138 
139  /** the random number generator used for breaking ties in majority voting
140   * @see #distributionForInstanceMajorityVoting(Instance) */
141  protected Random m_Random;
142 
143  /**
144   * Returns a string describing classifier
145   * @return a description suitable for
146   * displaying in the explorer/experimenter gui
147   */
148  public String globalInfo() {
149    return 
150        "Class for combining classifiers. Different combinations of "
151      + "probability estimates for classification are available.\n\n"
152      + "For more information see:\n\n"
153      + getTechnicalInformation().toString();
154  }
155 
156  /**
157   * Returns an enumeration describing the available options.
158   *
159   * @return an enumeration of all the available options.
160   */
161  public Enumeration listOptions() {
162    Enumeration         enm;
163    Vector              result;
164   
165    result = new Vector();
166   
167    enm = super.listOptions();
168    while (enm.hasMoreElements())
169      result.addElement(enm.nextElement());
170
171    result.addElement(new Option(
172        "\tThe combination rule to use\n"
173        + "\t(default: AVG)",
174        "R", 1, "-R " + Tag.toOptionList(TAGS_RULES)));
175   
176    return result.elements();
177  }
178 
179  /**
180   * Gets the current settings of Vote.
181   *
182   * @return an array of strings suitable for passing to setOptions()
183   */
184  public String [] getOptions() {
185    int         i;
186    Vector      result;
187    String[]    options;
188
189    result = new Vector();
190
191    options = super.getOptions();
192    for (i = 0; i < options.length; i++)
193      result.add(options[i]);
194
195    result.add("-R");
196    result.add("" + getCombinationRule());
197
198    return (String[]) result.toArray(new String[result.size()]);
199  }
200 
201  /**
202   * Parses a given list of options. <p/>
203   *
204   <!-- options-start -->
205   * Valid options are: <p/>
206   *
207   * <pre> -S &lt;num&gt;
208   *  Random number seed.
209   *  (default 1)</pre>
210   *
211   * <pre> -B &lt;classifier specification&gt;
212   *  Full class name of classifier to include, followed
213   *  by scheme options. May be specified multiple times.
214   *  (default: "weka.classifiers.rules.ZeroR")</pre>
215   *
216   * <pre> -D
217   *  If set, classifier is run in debug mode and
218   *  may output additional info to the console</pre>
219   *
220   * <pre> -R &lt;AVG|PROD|MAJ|MIN|MAX|MED&gt;
221   *  The combination rule to use
222   *  (default: AVG)</pre>
223   *
224   <!-- options-end -->
225   *
226   * @param options the list of options as an array of strings
227   * @throws Exception if an option is not supported
228   */
229  public void setOptions(String[] options) throws Exception {
230    String      tmpStr;
231   
232    tmpStr = Utils.getOption('R', options);
233    if (tmpStr.length() != 0) 
234      setCombinationRule(new SelectedTag(tmpStr, TAGS_RULES));
235    else
236      setCombinationRule(new SelectedTag(AVERAGE_RULE, TAGS_RULES));
237
238    super.setOptions(options);
239  }
240
241  /**
242   * Returns an instance of a TechnicalInformation object, containing
243   * detailed information about the technical background of this class,
244   * e.g., paper reference or book this class is based on.
245   *
246   * @return the technical information about this class
247   */
248  public TechnicalInformation getTechnicalInformation() {
249    TechnicalInformation        result;
250    TechnicalInformation        additional;
251   
252    result = new TechnicalInformation(Type.BOOK);
253    result.setValue(Field.AUTHOR, "Ludmila I. Kuncheva");
254    result.setValue(Field.TITLE, "Combining Pattern Classifiers: Methods and Algorithms");
255    result.setValue(Field.YEAR, "2004");
256    result.setValue(Field.PUBLISHER, "John Wiley and Sons, Inc.");
257
258    additional = result.add(Type.ARTICLE);
259    additional.setValue(Field.AUTHOR, "J. Kittler and M. Hatef and Robert P.W. Duin and J. Matas");
260    additional.setValue(Field.YEAR, "1998");
261    additional.setValue(Field.TITLE, "On combining classifiers");
262    additional.setValue(Field.JOURNAL, "IEEE Transactions on Pattern Analysis and Machine Intelligence");
263    additional.setValue(Field.VOLUME, "20");
264    additional.setValue(Field.NUMBER, "3");
265    additional.setValue(Field.PAGES, "226-239");
266   
267    return result;
268  }
269 
270  /**
271   * Returns default capabilities of the classifier.
272   *
273   * @return      the capabilities of this classifier
274   */
275  public Capabilities getCapabilities() {
276    Capabilities result = super.getCapabilities();
277
278    // class
279    if (    (m_CombinationRule == PRODUCT_RULE) 
280         || (m_CombinationRule == MAJORITY_VOTING_RULE) ) {
281      result.disableAllClasses();
282      result.disableAllClassDependencies();
283      result.enable(Capability.NOMINAL_CLASS);
284      result.enableDependency(Capability.NOMINAL_CLASS);
285    }
286    else if (m_CombinationRule == MEDIAN_RULE) {
287      result.disableAllClasses();
288      result.disableAllClassDependencies();
289      result.enable(Capability.NUMERIC_CLASS);
290      result.enableDependency(Capability.NUMERIC_CLASS);
291    }
292   
293    return result;
294  }
295
296  /**
297   * Buildclassifier selects a classifier from the set of classifiers
298   * by minimising error on the training data.
299   *
300   * @param data the training data to be used for generating the
301   * boosted classifier.
302   * @throws Exception if the classifier could not be built successfully
303   */
304  public void buildClassifier(Instances data) throws Exception {
305
306    // can classifier handle the data?
307    getCapabilities().testWithFail(data);
308
309    // remove instances with missing class
310    Instances newData = new Instances(data);
311    newData.deleteWithMissingClass();
312
313    m_Random = new Random(getSeed());
314   
315    for (int i = 0; i < m_Classifiers.length; i++) {
316      getClassifier(i).buildClassifier(newData);
317    }
318  }
319
320  /**
321   * Classifies the given test instance.
322   *
323   * @param instance the instance to be classified
324   * @return the predicted most likely class for the instance or
325   * Utils.missingValue() if no prediction is made
326   * @throws Exception if an error occurred during the prediction
327   */
328  public double classifyInstance(Instance instance) throws Exception {
329    double result;
330    double[] dist;
331    int index;
332   
333    switch (m_CombinationRule) {
334      case AVERAGE_RULE:
335      case PRODUCT_RULE:
336      case MAJORITY_VOTING_RULE:
337      case MIN_RULE:
338      case MAX_RULE:
339        dist = distributionForInstance(instance);
340        if (instance.classAttribute().isNominal()) {
341          index = Utils.maxIndex(dist);
342          if (dist[index] == 0)
343            result = Utils.missingValue();
344          else
345            result = index;
346        }
347        else if (instance.classAttribute().isNumeric()){
348          result = dist[0];
349        }
350        else {
351          result = Utils.missingValue();
352        }
353        break;
354      case MEDIAN_RULE:
355        result = classifyInstanceMedian(instance);
356        break;
357      default:
358        throw new IllegalStateException("Unknown combination rule '" + m_CombinationRule + "'!");
359    }
360   
361    return result;
362  }
363
364  /**
365   * Classifies the given test instance, returning the median from all
366   * classifiers.
367   *
368   * @param instance the instance to be classified
369   * @return the predicted most likely class for the instance or
370   * Utils.missingValue() if no prediction is made
371   * @throws Exception if an error occurred during the prediction
372   */
373  protected double classifyInstanceMedian(Instance instance) throws Exception {
374    double[] results = new double[m_Classifiers.length];
375    double result;
376
377    for (int i = 0; i < results.length; i++)
378      results[i] = m_Classifiers[i].classifyInstance(instance);
379   
380    if (results.length == 0)
381      result = 0;
382    else if (results.length == 1)
383      result = results[0];
384    else
385      result = Utils.kthSmallestValue(results, results.length / 2);
386   
387    return result;
388  }
389
390  /**
391   * Classifies a given instance using the selected combination rule.
392   *
393   * @param instance the instance to be classified
394   * @return the distribution
395   * @throws Exception if instance could not be classified
396   * successfully
397   */
398  public double[] distributionForInstance(Instance instance) throws Exception {
399    double[] result = new double[instance.numClasses()];
400   
401    switch (m_CombinationRule) {
402      case AVERAGE_RULE:
403        result = distributionForInstanceAverage(instance);
404        break;
405      case PRODUCT_RULE:
406        result = distributionForInstanceProduct(instance);
407        break;
408      case MAJORITY_VOTING_RULE:
409        result = distributionForInstanceMajorityVoting(instance);
410        break;
411      case MIN_RULE:
412        result = distributionForInstanceMin(instance);
413        break;
414      case MAX_RULE:
415        result = distributionForInstanceMax(instance);
416        break;
417      case MEDIAN_RULE:
418        result[0] = classifyInstance(instance);
419        break;
420      default:
421        throw new IllegalStateException("Unknown combination rule '" + m_CombinationRule + "'!");
422    }
423   
424    if (!instance.classAttribute().isNumeric() && (Utils.sum(result) > 0))
425      Utils.normalize(result);
426   
427    return result;
428  }
429 
430  /**
431   * Classifies a given instance using the Average of Probabilities
432   * combination rule.
433   *
434   * @param instance the instance to be classified
435   * @return the distribution
436   * @throws Exception if instance could not be classified
437   * successfully
438   */
439  protected double[] distributionForInstanceAverage(Instance instance) throws Exception {
440
441    double[] probs = getClassifier(0).distributionForInstance(instance);
442    for (int i = 1; i < m_Classifiers.length; i++) {
443      double[] dist = getClassifier(i).distributionForInstance(instance);
444      for (int j = 0; j < dist.length; j++) {
445          probs[j] += dist[j];
446      }
447    }
448    for (int j = 0; j < probs.length; j++) {
449      probs[j] /= (double)m_Classifiers.length;
450    }
451    return probs;
452  }
453 
454  /**
455   * Classifies a given instance using the Product of Probabilities
456   * combination rule.
457   *
458   * @param instance the instance to be classified
459   * @return the distribution
460   * @throws Exception if instance could not be classified
461   * successfully
462   */
463  protected double[] distributionForInstanceProduct(Instance instance) throws Exception {
464
465    double[] probs = getClassifier(0).distributionForInstance(instance);
466    for (int i = 1; i < m_Classifiers.length; i++) {
467      double[] dist = getClassifier(i).distributionForInstance(instance);
468      for (int j = 0; j < dist.length; j++) {
469          probs[j] *= dist[j];
470      }
471    }
472   
473    return probs;
474  }
475 
476  /**
477   * Classifies a given instance using the Majority Voting combination rule.
478   *
479   * @param instance the instance to be classified
480   * @return the distribution
481   * @throws Exception if instance could not be classified
482   * successfully
483   */
484  protected double[] distributionForInstanceMajorityVoting(Instance instance) throws Exception {
485
486    double[] probs = new double[instance.classAttribute().numValues()];
487    double[] votes = new double[probs.length];
488   
489    for (int i = 0; i < m_Classifiers.length; i++) {
490      probs = getClassifier(i).distributionForInstance(instance);
491      int maxIndex = 0;
492      for(int j = 0; j<probs.length; j++) {
493          if(probs[j] > probs[maxIndex])
494                  maxIndex = j;
495      }
496     
497      // Consider the cases when multiple classes happen to have the same probability
498      for (int j=0; j<probs.length; j++) {
499        if (probs[j] == probs[maxIndex])
500          votes[j]++;
501      }
502    }
503   
504    int tmpMajorityIndex = 0;
505    for (int k = 1; k < votes.length; k++) {
506      if (votes[k] > votes[tmpMajorityIndex])
507        tmpMajorityIndex = k;
508    }
509   
510    // Consider the cases when multiple classes receive the same amount of votes
511    Vector<Integer> majorityIndexes = new Vector<Integer>();
512    for (int k = 0; k < votes.length; k++) {
513      if (votes[k] == votes[tmpMajorityIndex])
514        majorityIndexes.add(k);
515     }
516    // Resolve the ties according to a uniform random distribution
517    int majorityIndex = majorityIndexes.get(m_Random.nextInt(majorityIndexes.size()));
518   
519    //set probs to 0
520    for (int k = 0; k<probs.length; k++)
521      probs[k] = 0;
522    probs[majorityIndex] = 1; //the class that have been voted the most receives 1
523   
524    return probs;
525  }
526 
527  /**
528   * Classifies a given instance using the Maximum Probability combination rule.
529   *
530   * @param instance the instance to be classified
531   * @return the distribution
532   * @throws Exception if instance could not be classified
533   * successfully
534   */
535  protected double[] distributionForInstanceMax(Instance instance) throws Exception {
536
537    double[] max = getClassifier(0).distributionForInstance(instance);
538    for (int i = 1; i < m_Classifiers.length; i++) {
539      double[] dist = getClassifier(i).distributionForInstance(instance);
540      for (int j = 0; j < dist.length; j++) {
541          if(max[j]<dist[j])
542                  max[j]=dist[j];
543      }
544    }
545   
546    return max;
547  }
548 
549  /**
550   * Classifies a given instance using the Minimum Probability combination rule.
551   *
552   * @param instance the instance to be classified
553   * @return the distribution
554   * @throws Exception if instance could not be classified
555   * successfully
556   */
557  protected double[] distributionForInstanceMin(Instance instance) throws Exception {
558
559    double[] min = getClassifier(0).distributionForInstance(instance);
560    for (int i = 1; i < m_Classifiers.length; i++) {
561      double[] dist = getClassifier(i).distributionForInstance(instance);
562      for (int j = 0; j < dist.length; j++) {
563          if(dist[j]<min[j])
564                  min[j]=dist[j];
565      }
566    }
567   
568    return min;
569  } 
570 
571  /**
572   * Returns the tip text for this property
573   *
574   * @return            tip text for this property suitable for
575   *                    displaying in the explorer/experimenter gui
576   */
577  public String combinationRuleTipText() {
578    return "The combination rule used.";
579  }
580 
581  /**
582   * Gets the combination rule used
583   *
584   * @return            the combination rule used
585   */
586  public SelectedTag getCombinationRule() {
587    return new SelectedTag(m_CombinationRule, TAGS_RULES);
588  }
589
590  /**
591   * Sets the combination rule to use. Values other than
592   *
593   * @param newRule     the combination rule method to use
594   */
595  public void setCombinationRule(SelectedTag newRule) {
596    if (newRule.getTags() == TAGS_RULES)
597      m_CombinationRule = newRule.getSelectedTag().getID();
598  }
599 
600  /**
601   * Output a representation of this classifier
602   *
603   * @return a string representation of the classifier
604   */
605  public String toString() {
606
607    if (m_Classifiers == null) {
608      return "Vote: No model built yet.";
609    }
610
611    String result = "Vote combines";
612    result += " the probability distributions of these base learners:\n";
613    for (int i = 0; i < m_Classifiers.length; i++) {
614      result += '\t' + getClassifierSpec(i) + '\n';
615    }
616    result += "using the '";
617   
618    switch (m_CombinationRule) {
619      case AVERAGE_RULE:
620        result += "Average of Probabilities";
621        break;
622       
623      case PRODUCT_RULE:
624        result += "Product of Probabilities";
625        break;
626       
627      case MAJORITY_VOTING_RULE:
628        result += "Majority Voting";
629        break;
630       
631      case MIN_RULE:
632        result += "Minimum Probability";
633        break;
634       
635      case MAX_RULE:
636        result += "Maximum Probability";
637        break;
638       
639      case MEDIAN_RULE:
640        result += "Median Probability";
641        break;
642       
643      default:
644        throw new IllegalStateException("Unknown combination rule '" + m_CombinationRule + "'!");
645    }
646   
647    result += "' combination rule \n";
648
649    return result;
650  }
651 
652  /**
653   * Returns the revision string.
654   *
655   * @return            the revision
656   */
657  public String getRevision() {
658    return RevisionUtils.extract("$Revision: 5987 $");
659  }
660
661  /**
662   * Main method for testing this class.
663   *
664   * @param argv should contain the following arguments:
665   * -t training file [-T test file] [-c class index]
666   */
667  public static void main(String [] argv) {
668    runClassifier(new Vote(), argv);
669  }
670}
Note: See TracBrowser for help on using the repository browser.