source: src/main/java/weka/classifiers/meta/MultiClassClassifier.java @ 4

Last change on this file since 4 was 4, checked in by gnappo, 14 years ago

Import di weka.

File size: 28.4 KB
Line 
1/*
2 *    This program is free software; you can redistribute it and/or modify
3 *    it under the terms of the GNU General Public License as published by
4 *    the Free Software Foundation; either version 2 of the License, or
5 *    (at your option) any later version.
6 *
7 *    This program is distributed in the hope that it will be useful,
8 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
9 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
10 *    GNU General Public License for more details.
11 *
12 *    You should have received a copy of the GNU General Public License
13 *    along with this program; if not, write to the Free Software
14 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
15 */
16
17/*
18 *    MultiClassClassifier.java
19 *    Copyright (C) 1999 University of Waikato, Hamilton, New Zealand
20 *
21 */
22
23package weka.classifiers.meta;
24
25import weka.classifiers.Classifier;
26import weka.classifiers.AbstractClassifier;
27import weka.classifiers.RandomizableSingleClassifierEnhancer;
28import weka.classifiers.rules.ZeroR;
29import weka.core.Attribute;
30import weka.core.Capabilities;
31import weka.core.FastVector;
32import weka.core.Instance;
33import weka.core.Instances;
34import weka.core.Option;
35import weka.core.OptionHandler;
36import weka.core.Range;
37import weka.core.RevisionHandler;
38import weka.core.RevisionUtils;
39import weka.core.SelectedTag;
40import weka.core.Tag;
41import weka.core.Utils;
42import weka.core.Capabilities.Capability;
43import weka.filters.Filter;
44import weka.filters.unsupervised.attribute.MakeIndicator;
45import weka.filters.unsupervised.instance.RemoveWithValues;
46
47import java.io.Serializable;
48import java.util.Enumeration;
49import java.util.Random;
50import java.util.Vector;
51
52/**
53 <!-- globalinfo-start -->
54 * A metaclassifier for handling multi-class datasets with 2-class classifiers. This classifier is also capable of applying error correcting output codes for increased accuracy.
55 * <p/>
56 <!-- globalinfo-end -->
57 *
58 <!-- options-start -->
59 * Valid options are: <p/>
60 *
61 * <pre> -M &lt;num&gt;
62 *  Sets the method to use. Valid values are 0 (1-against-all),
63 *  1 (random codes), 2 (exhaustive code), and 3 (1-against-1). (default 0)
64 * </pre>
65 *
66 * <pre> -R &lt;num&gt;
67 *  Sets the multiplier when using random codes. (default 2.0)</pre>
68 *
69 * <pre> -P
70 *  Use pairwise coupling (only has an effect for 1-against1)</pre>
71 *
72 * <pre> -S &lt;num&gt;
73 *  Random number seed.
74 *  (default 1)</pre>
75 *
76 * <pre> -D
77 *  If set, classifier is run in debug mode and
78 *  may output additional info to the console</pre>
79 *
80 * <pre> -W
81 *  Full name of base classifier.
82 *  (default: weka.classifiers.functions.Logistic)</pre>
83 *
84 * <pre>
85 * Options specific to classifier weka.classifiers.functions.Logistic:
86 * </pre>
87 *
88 * <pre> -D
89 *  Turn on debugging output.</pre>
90 *
91 * <pre> -R &lt;ridge&gt;
92 *  Set the ridge in the log-likelihood.</pre>
93 *
94 * <pre> -M &lt;number&gt;
95 *  Set the maximum number of iterations (default -1, until convergence).</pre>
96 *
97 <!-- options-end -->
98 *
99 * @author Eibe Frank (eibe@cs.waikato.ac.nz)
100 * @author Len Trigg (len@reeltwo.com)
101 * @author Richard Kirkby (rkirkby@cs.waikato.ac.nz)
102 * @version $Revision: 5928 $
103 */
104public class MultiClassClassifier 
105  extends RandomizableSingleClassifierEnhancer
106  implements OptionHandler {
107
108  /** for serialization */
109  static final long serialVersionUID = -3879602011542849141L;
110 
111  /** The classifiers. */
112  private Classifier [] m_Classifiers;
113
114  /** Use pairwise coupling with 1-vs-1 */
115  private boolean m_pairwiseCoupling = false;
116
117  /** Needed for pairwise coupling */
118  private double [] m_SumOfWeights;
119
120  /** The filters used to transform the class. */
121  private Filter[] m_ClassFilters;
122
123  /** ZeroR classifier for when all base classifier return zero probability. */
124  private ZeroR m_ZeroR;
125
126  /** Internal copy of the class attribute for output purposes */
127  private Attribute m_ClassAttribute;
128 
129  /** A transformed dataset header used by the  1-against-1 method */
130  private Instances m_TwoClassDataset;
131
132  /**
133   * The multiplier when generating random codes. Will generate
134   * numClasses * m_RandomWidthFactor codes
135   */
136  private double m_RandomWidthFactor = 2.0;
137
138  /** The multiclass method to use */
139  private int m_Method = METHOD_1_AGAINST_ALL;
140
141  /** 1-against-all */
142  public static final int METHOD_1_AGAINST_ALL    = 0;
143  /** random correction code */
144  public static final int METHOD_ERROR_RANDOM     = 1;
145  /** exhaustive correction code */
146  public static final int METHOD_ERROR_EXHAUSTIVE = 2;
147  /** 1-against-1 */
148  public static final int METHOD_1_AGAINST_1      = 3;
149  /** The error correction modes */
150  public static final Tag [] TAGS_METHOD = {
151    new Tag(METHOD_1_AGAINST_ALL, "1-against-all"),
152    new Tag(METHOD_ERROR_RANDOM, "Random correction code"),
153    new Tag(METHOD_ERROR_EXHAUSTIVE, "Exhaustive correction code"),
154    new Tag(METHOD_1_AGAINST_1, "1-against-1")
155  };
156   
157  /**
158   * Constructor.
159   */
160  public MultiClassClassifier() {
161   
162    m_Classifier = new weka.classifiers.functions.Logistic();
163  }
164
165  /**
166   * String describing default classifier.
167   *
168   * @return the default classifier classname
169   */
170  protected String defaultClassifierString() {
171   
172    return "weka.classifiers.functions.Logistic";
173  }
174
175  /**
176   * Interface for the code constructors
177   */
178  private abstract class Code 
179    implements Serializable, RevisionHandler {
180
181    /** for serialization */
182    static final long serialVersionUID = 418095077487120846L;
183   
184    /**
185     * Subclasses must allocate and fill these.
186     * First dimension is number of codes.
187     * Second dimension is number of classes.
188     */
189    protected boolean [][]m_Codebits;
190
191    /**
192     * Returns the number of codes.
193     * @return the number of codes
194     */
195    public int size() {
196      return m_Codebits.length;
197    }
198
199    /**
200     * Returns the indices of the values set to true for this code,
201     * using 1-based indexing (for input to Range).
202     *
203     * @param which the index
204     * @return the 1-based indices
205     */
206    public String getIndices(int which) {
207      StringBuffer sb = new StringBuffer();
208      for (int i = 0; i < m_Codebits[which].length; i++) {
209        if (m_Codebits[which][i]) {
210          if (sb.length() != 0) {
211            sb.append(',');
212          }
213          sb.append(i + 1);
214        }
215      }
216      return sb.toString();
217    }
218
219    /**
220     * Returns a human-readable representation of the codes.
221     * @return a string representation of the codes
222     */
223    public String toString() {
224      StringBuffer sb = new StringBuffer();
225      for(int i = 0; i < m_Codebits[0].length; i++) {
226        for (int j = 0; j < m_Codebits.length; j++) {
227          sb.append(m_Codebits[j][i] ? " 1" : " 0");
228        }
229        sb.append('\n');
230      }
231      return sb.toString();
232    }
233   
234    /**
235     * Returns the revision string.
236     *
237     * @return          the revision
238     */
239    public String getRevision() {
240      return RevisionUtils.extract("$Revision: 5928 $");
241    }
242  }
243
244  /**
245   * Constructs a code with no error correction
246   */
247  private class StandardCode 
248    extends Code {
249   
250    /** for serialization */
251    static final long serialVersionUID = 3707829689461467358L;
252   
253    /**
254     * constructor
255     *
256     * @param numClasses the number of classes
257     */
258    public StandardCode(int numClasses) {
259      m_Codebits = new boolean[numClasses][numClasses];
260      for (int i = 0; i < numClasses; i++) {
261        m_Codebits[i][i] = true;
262      }
263      //System.err.println("Code:\n" + this);
264    }
265   
266    /**
267     * Returns the revision string.
268     *
269     * @return          the revision
270     */
271    public String getRevision() {
272      return RevisionUtils.extract("$Revision: 5928 $");
273    }
274  }
275
276  /**
277   * Constructs a random code assignment
278   */
279  private class RandomCode 
280    extends Code {
281
282    /** for serialization */
283    static final long serialVersionUID = 4413410540703926563L;
284   
285    /** random number generator */
286    Random r = null;
287   
288    /**
289     * constructor
290     *
291     * @param numClasses the number of classes
292     * @param numCodes the number of codes
293     * @param data the data to use
294     */
295    public RandomCode(int numClasses, int numCodes, Instances data) {
296      r = data.getRandomNumberGenerator(m_Seed);
297      numCodes = Math.max(2, numCodes); // Need at least two classes
298      m_Codebits = new boolean[numCodes][numClasses];
299      int i = 0;
300      do {
301        randomize();
302        //System.err.println(this);
303      } while (!good() && (i++ < 100));
304      //System.err.println("Code:\n" + this);
305    }
306
307    private boolean good() {
308      boolean [] ninClass = new boolean[m_Codebits[0].length];
309      boolean [] ainClass = new boolean[m_Codebits[0].length];
310      for (int i = 0; i < ainClass.length; i++) {
311        ainClass[i] = true;
312      }
313
314      for (int i = 0; i < m_Codebits.length; i++) {
315        boolean ninCode = false;
316        boolean ainCode = true;
317        for (int j = 0; j < m_Codebits[i].length; j++) {
318          boolean current = m_Codebits[i][j];
319          ninCode = ninCode || current;
320          ainCode = ainCode && current;
321          ninClass[j] = ninClass[j] || current;
322          ainClass[j] = ainClass[j] && current;
323        }
324        if (!ninCode || ainCode) {
325          return false;
326        }
327      }
328      for (int j = 0; j < ninClass.length; j++) {
329        if (!ninClass[j] || ainClass[j]) {
330          return false;
331        }
332      }
333      return true;
334    }
335
336    /**
337     * randomizes
338     */
339    private void randomize() {
340      for (int i = 0; i < m_Codebits.length; i++) {
341        for (int j = 0; j < m_Codebits[i].length; j++) {
342          double temp = r.nextDouble();
343          m_Codebits[i][j] = (temp < 0.5) ? false : true;
344        }
345      }
346    }
347   
348    /**
349     * Returns the revision string.
350     *
351     * @return          the revision
352     */
353    public String getRevision() {
354      return RevisionUtils.extract("$Revision: 5928 $");
355    }
356  }
357
358  /*
359   * TODO: Constructs codes as per:
360   * Bose, R.C., Ray Chaudhuri (1960), On a class of error-correcting
361   * binary group codes, Information and Control, 3, 68-79.
362   * Hocquenghem, A. (1959) Codes corecteurs d'erreurs, Chiffres, 2, 147-156.
363   */
364  //private class BCHCode extends Code {...}
365
366  /** Constructs an exhaustive code assignment */
367  private class ExhaustiveCode 
368    extends Code {
369
370    /** for serialization */
371    static final long serialVersionUID = 8090991039670804047L;
372   
373    /**
374     * constructor
375     *
376     * @param numClasses the number of classes
377     */
378    public ExhaustiveCode(int numClasses) {
379      int width = (int)Math.pow(2, numClasses - 1) - 1;
380      m_Codebits = new boolean[width][numClasses];
381      for (int j = 0; j < width; j++) {
382        m_Codebits[j][0] = true;
383      }
384      for (int i = 1; i < numClasses; i++) {
385        int skip = (int) Math.pow(2, numClasses - (i + 1));
386        for(int j = 0; j < width; j++) {
387          m_Codebits[j][i] = ((j / skip) % 2 != 0);
388        }
389      }
390      //System.err.println("Code:\n" + this);
391    }
392   
393    /**
394     * Returns the revision string.
395     *
396     * @return          the revision
397     */
398    public String getRevision() {
399      return RevisionUtils.extract("$Revision: 5928 $");
400    }
401  }
402
403  /**
404   * Returns default capabilities of the classifier.
405   *
406   * @return      the capabilities of this classifier
407   */
408  public Capabilities getCapabilities() {
409    Capabilities result = super.getCapabilities();
410
411    // class
412    result.disableAllClasses();
413    result.disableAllClassDependencies();
414    result.enable(Capability.NOMINAL_CLASS);
415   
416    return result;
417  }
418
419  /**
420   * Builds the classifiers.
421   *
422   * @param insts the training data.
423   * @throws Exception if a classifier can't be built
424   */
425  public void buildClassifier(Instances insts) throws Exception {
426
427    Instances newInsts;
428
429    // can classifier handle the data?
430    getCapabilities().testWithFail(insts);
431
432    // remove instances with missing class
433    insts = new Instances(insts);
434    insts.deleteWithMissingClass();
435   
436    if (m_Classifier == null) {
437      throw new Exception("No base classifier has been set!");
438    }
439    m_ZeroR = new ZeroR();
440    m_ZeroR.buildClassifier(insts);
441
442    m_TwoClassDataset = null;
443
444    int numClassifiers = insts.numClasses();
445    if (numClassifiers <= 2) {
446
447      m_Classifiers = AbstractClassifier.makeCopies(m_Classifier, 1);
448      m_Classifiers[0].buildClassifier(insts);
449
450      m_ClassFilters = null;
451
452    } else if (m_Method == METHOD_1_AGAINST_1) {
453      // generate fastvector of pairs
454      FastVector pairs = new FastVector();
455      for (int i=0; i<insts.numClasses(); i++) {
456        for (int j=0; j<insts.numClasses(); j++) {
457          if (j<=i) continue;
458          int[] pair = new int[2];
459          pair[0] = i; pair[1] = j;
460          pairs.addElement(pair);
461        }
462      }
463
464      numClassifiers = pairs.size();
465      m_Classifiers = AbstractClassifier.makeCopies(m_Classifier, numClassifiers);
466      m_ClassFilters = new Filter[numClassifiers];
467      m_SumOfWeights = new double[numClassifiers];
468
469      // generate the classifiers
470      for (int i=0; i<numClassifiers; i++) {
471        RemoveWithValues classFilter = new RemoveWithValues();
472        classFilter.setAttributeIndex("" + (insts.classIndex() + 1));
473        classFilter.setModifyHeader(true);
474        classFilter.setInvertSelection(true);
475        classFilter.setNominalIndicesArr((int[])pairs.elementAt(i));
476        Instances tempInstances = new Instances(insts, 0);
477        tempInstances.setClassIndex(-1);
478        classFilter.setInputFormat(tempInstances);
479        newInsts = Filter.useFilter(insts, classFilter);
480        if (newInsts.numInstances() > 0) {
481          newInsts.setClassIndex(insts.classIndex());
482          m_Classifiers[i].buildClassifier(newInsts);
483          m_ClassFilters[i] = classFilter;
484          m_SumOfWeights[i] = newInsts.sumOfWeights();
485        } else {
486          m_Classifiers[i] = null;
487          m_ClassFilters[i] = null;
488        }
489      }
490
491      // construct a two-class header version of the dataset
492      m_TwoClassDataset = new Instances(insts, 0);
493      int classIndex = m_TwoClassDataset.classIndex();
494      m_TwoClassDataset.setClassIndex(-1);
495      m_TwoClassDataset.deleteAttributeAt(classIndex);
496      FastVector classLabels = new FastVector();
497      classLabels.addElement("class0");
498      classLabels.addElement("class1");
499      m_TwoClassDataset.insertAttributeAt(new Attribute("class", classLabels),
500                                          classIndex);
501      m_TwoClassDataset.setClassIndex(classIndex);
502
503    } else { // use error correcting code style methods
504      Code code = null;
505      switch (m_Method) {
506      case METHOD_ERROR_EXHAUSTIVE:
507        code = new ExhaustiveCode(numClassifiers);
508        break;
509      case METHOD_ERROR_RANDOM:
510        code = new RandomCode(numClassifiers, 
511                              (int)(numClassifiers * m_RandomWidthFactor),
512                              insts);
513        break;
514      case METHOD_1_AGAINST_ALL:
515        code = new StandardCode(numClassifiers);
516        break;
517      default:
518        throw new Exception("Unrecognized correction code type");
519      }
520      numClassifiers = code.size();
521      m_Classifiers = AbstractClassifier.makeCopies(m_Classifier, numClassifiers);
522      m_ClassFilters = new MakeIndicator[numClassifiers];
523      for (int i = 0; i < m_Classifiers.length; i++) {
524        m_ClassFilters[i] = new MakeIndicator();
525        MakeIndicator classFilter = (MakeIndicator) m_ClassFilters[i];
526        classFilter.setAttributeIndex("" + (insts.classIndex() + 1));
527        classFilter.setValueIndices(code.getIndices(i));
528        classFilter.setNumeric(false);
529        classFilter.setInputFormat(insts);
530        newInsts = Filter.useFilter(insts, m_ClassFilters[i]);
531        m_Classifiers[i].buildClassifier(newInsts);
532      }
533    }
534    m_ClassAttribute = insts.classAttribute();
535  }
536
537  /**
538   * Returns the individual predictions of the base classifiers
539   * for an instance. Used by StackedMultiClassClassifier.
540   * Returns the probability for the second "class" predicted
541   * by each base classifier.
542   *
543   * @param inst the instance to get the prediction for
544   * @return the individual predictions
545   * @throws Exception if the predictions can't be computed successfully
546   */
547  public double[] individualPredictions(Instance inst) throws Exception {
548   
549    double[] result = null;
550
551    if (m_Classifiers.length == 1) {
552      result = new double[1];
553      result[0] = m_Classifiers[0].distributionForInstance(inst)[1];
554    } else {
555      result = new double[m_ClassFilters.length];
556      for(int i = 0; i < m_ClassFilters.length; i++) {
557        if (m_Classifiers[i] != null) {
558          if (m_Method == METHOD_1_AGAINST_1) {   
559            Instance tempInst = (Instance)inst.copy(); 
560            tempInst.setDataset(m_TwoClassDataset);
561            result[i] = m_Classifiers[i].distributionForInstance(tempInst)[1]; 
562          } else {
563            m_ClassFilters[i].input(inst);
564            m_ClassFilters[i].batchFinished();
565            result[i] = m_Classifiers[i].
566              distributionForInstance(m_ClassFilters[i].output())[1];
567          }
568        }
569      }
570    }
571    return result;
572  }
573
574  /**
575   * Returns the distribution for an instance.
576   *
577   * @param inst the instance to get the distribution for
578   * @return the distribution
579   * @throws Exception if the distribution can't be computed successfully
580   */
581  public double[] distributionForInstance(Instance inst) throws Exception {
582   
583    if (m_Classifiers.length == 1) {
584      return m_Classifiers[0].distributionForInstance(inst);
585    }
586   
587    double[] probs = new double[inst.numClasses()];
588
589    if (m_Method == METHOD_1_AGAINST_1) {
590      double[][] r = new double[inst.numClasses()][inst.numClasses()];
591      double[][] n = new double[inst.numClasses()][inst.numClasses()];
592
593      for(int i = 0; i < m_ClassFilters.length; i++) {
594        if (m_Classifiers[i] != null) {
595          Instance tempInst = (Instance)inst.copy(); 
596          tempInst.setDataset(m_TwoClassDataset);
597          double [] current = m_Classifiers[i].distributionForInstance(tempInst); 
598          Range range = new Range(((RemoveWithValues)m_ClassFilters[i])
599                                  .getNominalIndices());
600          range.setUpper(m_ClassAttribute.numValues());
601          int[] pair = range.getSelection();
602          if (m_pairwiseCoupling && inst.numClasses() > 2) {
603            r[pair[0]][pair[1]] = current[0];
604            n[pair[0]][pair[1]] = m_SumOfWeights[i];
605          } else {
606            if (current[0] > current[1]) {
607              probs[pair[0]] += 1.0;
608            } else {
609              probs[pair[1]] += 1.0;
610            }
611          }
612        }
613      }
614      if (m_pairwiseCoupling && inst.numClasses() > 2) {
615        return pairwiseCoupling(n, r);
616      }
617    } else {
618      // error correcting style methods
619      for(int i = 0; i < m_ClassFilters.length; i++) {
620        m_ClassFilters[i].input(inst);
621        m_ClassFilters[i].batchFinished();
622        double [] current = m_Classifiers[i].
623          distributionForInstance(m_ClassFilters[i].output());
624        for (int j = 0; j < m_ClassAttribute.numValues(); j++) {
625          if (((MakeIndicator)m_ClassFilters[i]).getValueRange().isInRange(j)) {
626            probs[j] += current[1];
627          } else {
628            probs[j] += current[0];
629          }
630        }
631      }
632    }
633   
634    if (Utils.gr(Utils.sum(probs), 0)) {
635      Utils.normalize(probs);
636      return probs;
637    } else {
638      return m_ZeroR.distributionForInstance(inst);
639    }
640  }
641
642  /**
643   * Prints the classifiers.
644   *
645   * @return a string representation of the classifier
646   */
647  public String toString() {
648
649    if (m_Classifiers == null) {
650      return "MultiClassClassifier: No model built yet.";
651    }
652    StringBuffer text = new StringBuffer();
653    text.append("MultiClassClassifier\n\n");
654    for (int i = 0; i < m_Classifiers.length; i++) {
655      text.append("Classifier ").append(i + 1);
656      if (m_Classifiers[i] != null) {
657        if ((m_ClassFilters != null) && (m_ClassFilters[i] != null)) {
658          if (m_ClassFilters[i] instanceof RemoveWithValues) {
659            Range range = new Range(((RemoveWithValues)m_ClassFilters[i])
660                                    .getNominalIndices());
661            range.setUpper(m_ClassAttribute.numValues());
662            int[] pair = range.getSelection();
663            text.append(", " + (pair[0]+1) + " vs " + (pair[1]+1));
664          } else if (m_ClassFilters[i] instanceof MakeIndicator) {
665            text.append(", using indicator values: ");
666            text.append(((MakeIndicator)m_ClassFilters[i]).getValueRange());
667          }
668        }
669        text.append('\n');
670        text.append(m_Classifiers[i].toString() + "\n\n");
671      } else {
672        text.append(" Skipped (no training examples)\n");
673      }
674    }
675
676    return text.toString();
677  }
678
679  /**
680   * Returns an enumeration describing the available options
681   *
682   * @return an enumeration of all the available options
683   */
684  public Enumeration listOptions()  {
685
686    Vector vec = new Vector(4);
687   
688    vec.addElement(new Option(
689       "\tSets the method to use. Valid values are 0 (1-against-all),\n"
690       +"\t1 (random codes), 2 (exhaustive code), and 3 (1-against-1). (default 0)\n",
691       "M", 1, "-M <num>"));
692    vec.addElement(new Option(
693       "\tSets the multiplier when using random codes. (default 2.0)",
694       "R", 1, "-R <num>"));
695    vec.addElement(new Option(
696        "\tUse pairwise coupling (only has an effect for 1-against1)",
697        "P", 0, "-P"));
698
699    Enumeration enu = super.listOptions();
700    while (enu.hasMoreElements()) {
701      vec.addElement(enu.nextElement());
702    }
703    return vec.elements();
704  }
705
706  /**
707   * Parses a given list of options. <p/>
708   *
709   <!-- options-start -->
710   * Valid options are: <p/>
711   *
712   * <pre> -M &lt;num&gt;
713   *  Sets the method to use. Valid values are 0 (1-against-all),
714   *  1 (random codes), 2 (exhaustive code), and 3 (1-against-1). (default 0)
715   * </pre>
716   *
717   * <pre> -R &lt;num&gt;
718   *  Sets the multiplier when using random codes. (default 2.0)</pre>
719   *
720   * <pre> -P
721   *  Use pairwise coupling (only has an effect for 1-against1)</pre>
722   *
723   * <pre> -S &lt;num&gt;
724   *  Random number seed.
725   *  (default 1)</pre>
726   *
727   * <pre> -D
728   *  If set, classifier is run in debug mode and
729   *  may output additional info to the console</pre>
730   *
731   * <pre> -W
732   *  Full name of base classifier.
733   *  (default: weka.classifiers.functions.Logistic)</pre>
734   *
735   * <pre>
736   * Options specific to classifier weka.classifiers.functions.Logistic:
737   * </pre>
738   *
739   * <pre> -D
740   *  Turn on debugging output.</pre>
741   *
742   * <pre> -R &lt;ridge&gt;
743   *  Set the ridge in the log-likelihood.</pre>
744   *
745   * <pre> -M &lt;number&gt;
746   *  Set the maximum number of iterations (default -1, until convergence).</pre>
747   *
748   <!-- options-end -->
749   *
750   * @param options the list of options as an array of strings
751   * @throws Exception if an option is not supported
752   */
753  public void setOptions(String[] options) throws Exception {
754 
755    String errorString = Utils.getOption('M', options);
756    if (errorString.length() != 0) {
757      setMethod(new SelectedTag(Integer.parseInt(errorString), 
758                                             TAGS_METHOD));
759    } else {
760      setMethod(new SelectedTag(METHOD_1_AGAINST_ALL, TAGS_METHOD));
761    }
762
763    String rfactorString = Utils.getOption('R', options);
764    if (rfactorString.length() != 0) {
765      setRandomWidthFactor((new Double(rfactorString)).doubleValue());
766    } else {
767      setRandomWidthFactor(2.0);
768    }
769
770    setUsePairwiseCoupling(Utils.getFlag('P', options));
771
772    super.setOptions(options);
773  }
774
775  /**
776   * Gets the current settings of the Classifier.
777   *
778   * @return an array of strings suitable for passing to setOptions
779   */
780  public String [] getOptions() {
781
782    String [] superOptions = super.getOptions();
783    String [] options = new String [superOptions.length + 5];
784
785    int current = 0;
786
787
788    options[current++] = "-M";
789    options[current++] = "" + m_Method;
790
791    if (getUsePairwiseCoupling()) {
792      options[current++] = "-P";
793    }
794   
795    options[current++] = "-R";
796    options[current++] = "" + m_RandomWidthFactor;
797
798    System.arraycopy(superOptions, 0, options, current, 
799                     superOptions.length);
800
801    current += superOptions.length;
802    while (current < options.length) {
803      options[current++] = "";
804    }
805    return options;
806  }
807
808  /**
809   * @return a description of the classifier suitable for
810   * displaying in the explorer/experimenter gui
811   */
812  public String globalInfo() {
813
814    return "A metaclassifier for handling multi-class datasets with 2-class "
815      + "classifiers. This classifier is also capable of "
816      + "applying error correcting output codes for increased accuracy.";
817  }
818
819  /**
820   * @return tip text for this property suitable for
821   * displaying in the explorer/experimenter gui
822   */
823  public String randomWidthFactorTipText() {
824
825    return "Sets the width multiplier when using random codes. The number "
826      + "of codes generated will be thus number multiplied by the number of "
827      + "classes.";
828  }
829
830  /**
831   * Gets the multiplier when generating random codes. Will generate
832   * numClasses * m_RandomWidthFactor codes.
833   *
834   * @return the width multiplier
835   */
836  public double getRandomWidthFactor() {
837
838    return m_RandomWidthFactor;
839  }
840 
841  /**
842   * Sets the multiplier when generating random codes. Will generate
843   * numClasses * m_RandomWidthFactor codes.
844   *
845   * @param newRandomWidthFactor the new width multiplier
846   */
847  public void setRandomWidthFactor(double newRandomWidthFactor) {
848
849    m_RandomWidthFactor = newRandomWidthFactor;
850  }
851 
852  /**
853   * @return tip text for this property suitable for
854   * displaying in the explorer/experimenter gui
855   */
856  public String methodTipText() {
857    return "Sets the method to use for transforming the multi-class problem into "
858      + "several 2-class ones."; 
859  }
860
861  /**
862   * Gets the method used. Will be one of METHOD_1_AGAINST_ALL,
863   * METHOD_ERROR_RANDOM, METHOD_ERROR_EXHAUSTIVE, or METHOD_1_AGAINST_1.
864   *
865   * @return the current method.
866   */
867  public SelectedTag getMethod() {
868     
869    return new SelectedTag(m_Method, TAGS_METHOD);
870  }
871
872  /**
873   * Sets the method used. Will be one of METHOD_1_AGAINST_ALL,
874   * METHOD_ERROR_RANDOM, METHOD_ERROR_EXHAUSTIVE, or METHOD_1_AGAINST_1.
875   *
876   * @param newMethod the new method.
877   */
878  public void setMethod(SelectedTag newMethod) {
879   
880    if (newMethod.getTags() == TAGS_METHOD) {
881      m_Method = newMethod.getSelectedTag().getID();
882    }
883  }
884
885  /**
886   * Set whether to use pairwise coupling with 1-vs-1
887   * classification to improve probability estimates.
888   *
889   * @param p true if pairwise coupling is to be used
890   */
891  public void setUsePairwiseCoupling(boolean p) {
892    m_pairwiseCoupling = p;
893  }
894
895  /**
896   * Gets whether to use pairwise coupling with 1-vs-1
897   * classification to improve probability estimates.
898   *
899   * @return true if pairwise coupling is to be used
900   */
901  public boolean getUsePairwiseCoupling() {
902    return m_pairwiseCoupling;
903  }
904
905  /**
906   * @return tip text for this property suitable for
907   * displaying in the explorer/experimenter gui
908   */
909  public String usePairwiseCouplingTipText() {
910    return "Use pairwise coupling (only has an effect for 1-against-1).";
911  }
912
913  /**
914   * Implements pairwise coupling.
915   *
916   * @param n the sum of weights used to train each model
917   * @param r the probability estimate from each model
918   * @return the coupled estimates
919   */
920  public static double[] pairwiseCoupling(double[][] n, double[][] r) {
921
922    // Initialize p and u array
923    double[] p = new double[r.length];
924    for (int i =0; i < p.length; i++) {
925      p[i] = 1.0 / (double)p.length;
926    }
927    double[][] u = new double[r.length][r.length];
928    for (int i = 0; i < r.length; i++) {
929      for (int j = i + 1; j < r.length; j++) {
930        u[i][j] = 0.5;
931      }
932    }
933
934    // firstSum doesn't change
935    double[] firstSum = new double[p.length];
936    for (int i = 0; i < p.length; i++) {
937      for (int j = i + 1; j < p.length; j++) {
938        firstSum[i] += n[i][j] * r[i][j];
939        firstSum[j] += n[i][j] * (1 - r[i][j]);
940      }
941    }
942
943    // Iterate until convergence
944    boolean changed;
945    do {
946      changed = false;
947      double[] secondSum = new double[p.length];
948      for (int i = 0; i < p.length; i++) {
949        for (int j = i + 1; j < p.length; j++) {
950          secondSum[i] += n[i][j] * u[i][j];
951          secondSum[j] += n[i][j] * (1 - u[i][j]);
952        }
953      }
954      for (int i = 0; i < p.length; i++) {
955        if ((firstSum[i] == 0) || (secondSum[i] == 0)) {
956          if (p[i] > 0) {
957            changed = true;
958          }
959          p[i] = 0;
960        } else {
961          double factor = firstSum[i] / secondSum[i];
962          double pOld = p[i];
963          p[i] *= factor;
964          if (Math.abs(pOld - p[i]) > 1.0e-3) {
965            changed = true;
966          }
967        }
968      }
969      Utils.normalize(p);
970      for (int i = 0; i < r.length; i++) {
971        for (int j = i + 1; j < r.length; j++) {
972          u[i][j] = p[i] / (p[i] + p[j]);
973        }
974      }
975    } while (changed);
976    return p;
977  }
978 
979  /**
980   * Returns the revision string.
981   *
982   * @return            the revision
983   */
984  public String getRevision() {
985    return RevisionUtils.extract("$Revision: 5928 $");
986  }
987
988  /**
989   * Main method for testing this class.
990   *
991   * @param argv the options
992   */
993  public static void main(String [] argv) {
994    runClassifier(new MultiClassClassifier(), argv);
995  }
996}
997
Note: See TracBrowser for help on using the repository browser.