source: src/main/java/weka/classifiers/meta/ThresholdSelector.java @ 16

Last change on this file since 16 was 4, checked in by gnappo, 14 years ago

Import di weka.

File size: 35.2 KB
Line 
1/*
2 *    This program is free software; you can redistribute it and/or modify
3 *    it under the terms of the GNU General Public License as published by
4 *    the Free Software Foundation; either version 2 of the License, or
5 *    (at your option) any later version.
6 *
7 *    This program is distributed in the hope that it will be useful,
8 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
9 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
10 *    GNU General Public License for more details.
11 *
12 *    You should have received a copy of the GNU General Public License
13 *    along with this program; if not, write to the Free Software
14 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
15 */
16
17/*
18 *    ThresholdSelector.java
19 *    Copyright (C) 1999 University of Waikato, Hamilton, New Zealand
20 *
21 */
22
23package weka.classifiers.meta;
24
25import weka.classifiers.RandomizableSingleClassifierEnhancer;
26import weka.classifiers.evaluation.EvaluationUtils;
27import weka.classifiers.evaluation.ThresholdCurve;
28import weka.core.Attribute;
29import weka.core.AttributeStats;
30import weka.core.Capabilities;
31import weka.core.Drawable;
32import weka.core.FastVector;
33import weka.core.Instance;
34import weka.core.Instances;
35import weka.core.Option;
36import weka.core.OptionHandler;
37import weka.core.RevisionUtils;
38import weka.core.SelectedTag;
39import weka.core.Tag;
40import weka.core.Utils;
41import weka.core.Capabilities.Capability;
42
43import java.util.Enumeration;
44import java.util.Random;
45import java.util.Vector;
46
47/**
48 <!-- globalinfo-start -->
 * A metaclassifier that selects a mid-point threshold on the probability output by a Classifier. The midpoint threshold is set so that a given performance measure is optimized. By default this is the F-measure. Performance is measured either on the training data, a hold-out set or using cross-validation. In addition, the probabilities returned by the base learner can have their range expanded so that the output probabilities will reside between 0 and 1 (this is useful if the scheme normally produces probabilities in a very narrow range).
50 * <p/>
51 <!-- globalinfo-end -->
52 *
53 <!-- options-start -->
54 * Valid options are: <p/>
55 *
56 * <pre> -C &lt;integer&gt;
57 *  The class for which threshold is determined. Valid values are:
58 *  1, 2 (for first and second classes, respectively), 3 (for whichever
59 *  class is least frequent), and 4 (for whichever class value is most
60 *  frequent), and 5 (for the first class named any of "yes","pos(itive)"
61 *  "1", or method 3 if no matches). (default 5).</pre>
62 *
63 * <pre> -X &lt;number of folds&gt;
64 *  Number of folds used for cross validation. If just a
65 *  hold-out set is used, this determines the size of the hold-out set
66 *  (default 3).</pre>
67 *
68 * <pre> -R &lt;integer&gt;
69 *  Sets whether confidence range correction is applied. This
70 *  can be used to ensure the confidences range from 0 to 1.
71 *  Use 0 for no range correction, 1 for correction based on
72 *  the min/max values seen during threshold selection
73 *  (default 0).</pre>
74 *
75 * <pre> -E &lt;integer&gt;
76 *  Sets the evaluation mode. Use 0 for
77 *  evaluation using cross-validation,
78 *  1 for evaluation using hold-out set,
79 *  and 2 for evaluation on the
80 *  training data (default 1).</pre>
81 *
82 * <pre> -M [FMEASURE|ACCURACY|TRUE_POS|TRUE_NEG|TP_RATE|PRECISION|RECALL]
83 *  Measure used for evaluation (default is FMEASURE).
84 * </pre>
85 *
86 * <pre> -manual &lt;real&gt;
87 *  Set a manual threshold to use. This option overrides
88 *  automatic selection and options pertaining to
89 *  automatic selection will be ignored.
90 *  (default -1, i.e. do not use a manual threshold).</pre>
91 *
92 * <pre> -S &lt;num&gt;
93 *  Random number seed.
94 *  (default 1)</pre>
95 *
96 * <pre> -D
97 *  If set, classifier is run in debug mode and
98 *  may output additional info to the console</pre>
99 *
100 * <pre> -W
101 *  Full name of base classifier.
102 *  (default: weka.classifiers.functions.Logistic)</pre>
103 *
104 * <pre>
105 * Options specific to classifier weka.classifiers.functions.Logistic:
106 * </pre>
107 *
108 * <pre> -D
109 *  Turn on debugging output.</pre>
110 *
111 * <pre> -R &lt;ridge&gt;
112 *  Set the ridge in the log-likelihood.</pre>
113 *
114 * <pre> -M &lt;number&gt;
115 *  Set the maximum number of iterations (default -1, until convergence).</pre>
116 *
117 <!-- options-end -->
118 *
119 * Options after -- are passed to the designated sub-classifier. <p>
120 *
121 * @author Eibe Frank (eibe@cs.waikato.ac.nz)
122 * @version $Revision: 1.43 $
123 */
124public class ThresholdSelector 
125  extends RandomizableSingleClassifierEnhancer
126  implements OptionHandler, Drawable {
127
128  /** for serialization */
129  static final long serialVersionUID = -1795038053239867444L;
130
131  /** no range correction */
132  public static final int RANGE_NONE = 0;
133  /** Correct based on min/max observed */
134  public static final int RANGE_BOUNDS = 1;
135  /** Type of correction applied to threshold range */ 
136  public static final Tag [] TAGS_RANGE = {
137    new Tag(RANGE_NONE, "No range correction"),
138    new Tag(RANGE_BOUNDS, "Correct based on min/max observed")
139  };
140
141  /** entire training set */
142  public static final int EVAL_TRAINING_SET = 2;
143  /** single tuned fold */
144  public static final int EVAL_TUNED_SPLIT = 1;
145  /** n-fold cross-validation */
146  public static final int EVAL_CROSS_VALIDATION = 0;
147  /** The evaluation modes */
148  public static final Tag [] TAGS_EVAL = {
149    new Tag(EVAL_TRAINING_SET, "Entire training set"),
150    new Tag(EVAL_TUNED_SPLIT, "Single tuned fold"),
151    new Tag(EVAL_CROSS_VALIDATION, "N-Fold cross validation")
152  };
153
154  /** first class value */
155  public static final int OPTIMIZE_0     = 0;
156  /** second class value */
157  public static final int OPTIMIZE_1     = 1;
158  /** least frequent class value */
159  public static final int OPTIMIZE_LFREQ = 2;
160  /** most frequent class value */
161  public static final int OPTIMIZE_MFREQ = 3;
162  /** class value name, either 'yes' or 'pos(itive)' */
163  public static final int OPTIMIZE_POS_NAME = 4;
164  /** How to determine which class value to optimize for */
165  public static final Tag [] TAGS_OPTIMIZE = {
166    new Tag(OPTIMIZE_0, "First class value"),
167    new Tag(OPTIMIZE_1, "Second class value"),
168    new Tag(OPTIMIZE_LFREQ, "Least frequent class value"),
169    new Tag(OPTIMIZE_MFREQ, "Most frequent class value"),
170    new Tag(OPTIMIZE_POS_NAME, "Class value named: \"yes\", \"pos(itive)\",\"1\"")
171  };
172
173  /** F-measure */
174  public static final int FMEASURE  = 1;
175  /** accuracy */
176  public static final int ACCURACY  = 2;
177  /** true-positive */
178  public static final int TRUE_POS  = 3;
179  /** true-negative */
180  public static final int TRUE_NEG  = 4;
181  /** true-positive rate */
182  public static final int TP_RATE   = 5;
183  /** precision */
184  public static final int PRECISION = 6;
185  /** recall */
186  public static final int RECALL    = 7;
187  /** the measure to use */
188  public static final Tag[] TAGS_MEASURE = {
189    new Tag(FMEASURE,  "FMEASURE"),
190    new Tag(ACCURACY,  "ACCURACY"),
191    new Tag(TRUE_POS,  "TRUE_POS"),
192    new Tag(TRUE_NEG,  "TRUE_NEG"), 
193    new Tag(TP_RATE,   "TP_RATE"),   
194    new Tag(PRECISION, "PRECISION"), 
195    new Tag(RECALL,    "RECALL")
196  };
197
198  /** The upper threshold used as the basis of correction */
199  protected double m_HighThreshold = 1;
200
201  /** The lower threshold used as the basis of correction */
202  protected double m_LowThreshold = 0;
203
204  /** The threshold that lead to the best performance */
205  protected double m_BestThreshold = -Double.MAX_VALUE;
206
207  /** The best value that has been observed */
208  protected double m_BestValue = - Double.MAX_VALUE;
209 
210  /** The number of folds used in cross-validation */
211  protected int m_NumXValFolds = 3;
212
213  /** Designated class value, determined during building */
214  protected int m_DesignatedClass = 0;
215
216  /** Method to determine which class to optimize for */
217  protected int m_ClassMode = OPTIMIZE_POS_NAME;
218
219  /** The evaluation mode */
220  protected int m_EvalMode = EVAL_TUNED_SPLIT;
221
222  /** The range correction mode */
223  protected int m_RangeMode = RANGE_NONE;
224
225  /** evaluation measure used for determining threshold **/
226  int m_nMeasure = FMEASURE;
227
228  /** True if a manually set threshold is being used */
229  protected boolean m_manualThreshold = false;
230  /** -1 = not used by default */
231  protected double m_manualThresholdValue = -1;
232
233  /** The minimum value for the criterion. If threshold adjustment
234      yields less than that, the default threshold of 0.5 is used. */
235  protected static final double MIN_VALUE = 0.05;
236   
237  /**
238   * Constructor.
239   */
240  public ThresholdSelector() {
241   
242    m_Classifier = new weka.classifiers.functions.Logistic();
243  }
244
245  /**
246   * String describing default classifier.
247   *
248   * @return the default classifier classname
249   */
250  protected String defaultClassifierString() {
251   
252    return "weka.classifiers.functions.Logistic";
253  }
254
255  /**
256   * Collects the classifier predictions using the specified evaluation method.
257   *
258   * @param instances the set of <code>Instances</code> to generate
259   * predictions for.
260   * @param mode the evaluation mode.
261   * @param numFolds the number of folds to use if not evaluating on the
262   * full training set.
263   * @return a <code>FastVector</code> containing the predictions.
264   * @throws Exception if an error occurs generating the predictions.
265   */
266  protected FastVector getPredictions(Instances instances, int mode, int numFolds) 
267    throws Exception {
268
269    EvaluationUtils eu = new EvaluationUtils();
270    eu.setSeed(m_Seed);
271   
272    switch (mode) {
273    case EVAL_TUNED_SPLIT:
274      Instances trainData = null, evalData = null;
275      Instances data = new Instances(instances);
276      Random random = new Random(m_Seed);
277      data.randomize(random);
278      data.stratify(numFolds);
279     
280      // Make sure that both subsets contain at least one positive instance
281      for (int subsetIndex = 0; subsetIndex < numFolds; subsetIndex++) {
282        trainData = data.trainCV(numFolds, subsetIndex, random);
283        evalData = data.testCV(numFolds, subsetIndex);
284        if (checkForInstance(trainData) && checkForInstance(evalData)) {
285          break;
286        }
287      }
288      return eu.getTrainTestPredictions(m_Classifier, trainData, evalData);
289    case EVAL_TRAINING_SET:
290      return eu.getTrainTestPredictions(m_Classifier, instances, instances);
291    case EVAL_CROSS_VALIDATION:
292      return eu.getCVPredictions(m_Classifier, instances, numFolds);
293    default:
294      throw new RuntimeException("Unrecognized evaluation mode");
295    }
296  }
297
298  /**
299   * Tooltip for this property.
300   *
301   * @return    tip text for this property suitable for
302   *            displaying in the explorer/experimenter gui
303   */
304  public String measureTipText() {
305    return "Sets the measure for determining the threshold.";
306  }
307
308  /**
309   * set measure used for determining threshold
310   *
311   * @param newMeasure Tag representing measure to be used
312   */
313  public void setMeasure(SelectedTag newMeasure) {
314    if (newMeasure.getTags() == TAGS_MEASURE) {
315      m_nMeasure = newMeasure.getSelectedTag().getID();
316    }
317  }
318
319  /**
320   * get measure used for determining threshold
321   *
322   * @return Tag representing measure used
323   */
324  public SelectedTag getMeasure() {
325    return new SelectedTag(m_nMeasure, TAGS_MEASURE);
326  }
327
328
329  /**
330   * Finds the best threshold, this implementation searches for the
331   * highest FMeasure. If no FMeasure higher than MIN_VALUE is found,
332   * the default threshold of 0.5 is used.
333   *
334   * @param predictions a <code>FastVector</code> containing the predictions.
335   */
336  protected void findThreshold(FastVector predictions) {
337
338    Instances curve = (new ThresholdCurve()).getCurve(predictions, m_DesignatedClass);
339
340    double low = 1.0;
341    double high = 0.0;
342
343    //System.err.println(curve);
344    if (curve.numInstances() > 0) {
345      Instance maxInst = curve.instance(0);
346      double maxValue = 0; 
347      int index1 = 0;
348      int index2 = 0;
349      switch (m_nMeasure) {
350        case FMEASURE:
351          index1 = curve.attribute(ThresholdCurve.FMEASURE_NAME).index();
352          maxValue = maxInst.value(index1);
353          break;
354        case TRUE_POS:
355          index1 = curve.attribute(ThresholdCurve.TRUE_POS_NAME).index();
356          maxValue = maxInst.value(index1);
357          break;
358        case TRUE_NEG:
359          index1 = curve.attribute(ThresholdCurve.TRUE_NEG_NAME).index();
360          maxValue = maxInst.value(index1);
361          break;
362        case TP_RATE:
363          index1 = curve.attribute(ThresholdCurve.TP_RATE_NAME).index();
364          maxValue = maxInst.value(index1);
365          break;
366        case PRECISION:
367          index1 = curve.attribute(ThresholdCurve.PRECISION_NAME).index();
368          maxValue = maxInst.value(index1);
369          break;
370        case RECALL:
371          index1 = curve.attribute(ThresholdCurve.RECALL_NAME).index();
372          maxValue = maxInst.value(index1);
373          break;
374        case ACCURACY:
375          index1 = curve.attribute(ThresholdCurve.TRUE_POS_NAME).index();
376          index2 = curve.attribute(ThresholdCurve.TRUE_NEG_NAME).index();
377          maxValue = maxInst.value(index1) + maxInst.value(index2);
378          break;
379      }
380      int indexThreshold = curve.attribute(ThresholdCurve.THRESHOLD_NAME).index();
381      for (int i = 1; i < curve.numInstances(); i++) {
382        Instance current = curve.instance(i);
383        double currentValue = 0;
384        if (m_nMeasure ==  ACCURACY) {
385          currentValue= current.value(index1) + current.value(index2);
386          } else {
387              currentValue= current.value(index1);
388          }
389
390          if (currentValue> maxValue) {
391              maxInst = current;
392              maxValue = currentValue;
393          }
394          if (m_RangeMode == RANGE_BOUNDS) {
395              double thresh = current.value(indexThreshold);
396              if (thresh < low) {
397                  low = thresh;
398              }
399              if (thresh > high) {
400                  high = thresh;
401              }
402          }
403      }
404      if (maxValue > MIN_VALUE) {
405        m_BestThreshold = maxInst.value(indexThreshold);
406        m_BestValue = maxValue;
407        //System.err.println("maxFM: " + maxFM);
408      }
409      if (m_RangeMode == RANGE_BOUNDS) {
410          m_LowThreshold = low;
411          m_HighThreshold = high;
412        //System.err.println("Threshold range: " + low + " - " + high);
413      }
414    }
415
416  }
417
418  /**
419   * Returns an enumeration describing the available options.
420   *
421   * @return an enumeration of all the available options.
422   */
423  public Enumeration listOptions() {
424
425    Vector newVector = new Vector(5);
426
427    newVector.addElement(new Option(
428        "\tThe class for which threshold is determined. Valid values are:\n" +
429        "\t1, 2 (for first and second classes, respectively), 3 (for whichever\n" +
430        "\tclass is least frequent), and 4 (for whichever class value is most\n" +
431        "\tfrequent), and 5 (for the first class named any of \"yes\",\"pos(itive)\"\n" +
432        "\t\"1\", or method 3 if no matches). (default 5).",
433        "C", 1, "-C <integer>"));
434   
435    newVector.addElement(new Option(
436              "\tNumber of folds used for cross validation. If just a\n" +
437              "\thold-out set is used, this determines the size of the hold-out set\n" +
438              "\t(default 3).",
439              "X", 1, "-X <number of folds>"));
440   
441    newVector.addElement(new Option(
442        "\tSets whether confidence range correction is applied. This\n" +
443        "\tcan be used to ensure the confidences range from 0 to 1.\n" +
444        "\tUse 0 for no range correction, 1 for correction based on\n" +
445        "\tthe min/max values seen during threshold selection\n"+
446        "\t(default 0).",
447        "R", 1, "-R <integer>"));
448   
449    newVector.addElement(new Option(
450              "\tSets the evaluation mode. Use 0 for\n" +
451              "\tevaluation using cross-validation,\n" +
452              "\t1 for evaluation using hold-out set,\n" +
453              "\tand 2 for evaluation on the\n" +
454              "\ttraining data (default 1).",
455              "E", 1, "-E <integer>"));
456
457    newVector.addElement(new Option(
458              "\tMeasure used for evaluation (default is FMEASURE).\n",
459              "M", 1, "-M [FMEASURE|ACCURACY|TRUE_POS|TRUE_NEG|TP_RATE|PRECISION|RECALL]"));
460   
461    newVector.addElement(new Option(
462              "\tSet a manual threshold to use. This option overrides\n"
463              + "\tautomatic selection and options pertaining to\n"
464              + "\tautomatic selection will be ignored.\n"
465              + "\t(default -1, i.e. do not use a manual threshold).",
466              "manual", 1, "-manual <real>"));
467
468    Enumeration enu = super.listOptions();
469    while (enu.hasMoreElements()) {
470      newVector.addElement(enu.nextElement());
471    }
472    return newVector.elements();
473  }
474
475  /**
476   * Parses a given list of options. <p/>
477   *
478   <!-- options-start -->
479   * Valid options are: <p/>
480   *
481   * <pre> -C &lt;integer&gt;
482   *  The class for which threshold is determined. Valid values are:
483   *  1, 2 (for first and second classes, respectively), 3 (for whichever
484   *  class is least frequent), and 4 (for whichever class value is most
485   *  frequent), and 5 (for the first class named any of "yes","pos(itive)"
486   *  "1", or method 3 if no matches). (default 5).</pre>
487   *
488   * <pre> -X &lt;number of folds&gt;
489   *  Number of folds used for cross validation. If just a
490   *  hold-out set is used, this determines the size of the hold-out set
491   *  (default 3).</pre>
492   *
493   * <pre> -R &lt;integer&gt;
494   *  Sets whether confidence range correction is applied. This
495   *  can be used to ensure the confidences range from 0 to 1.
496   *  Use 0 for no range correction, 1 for correction based on
497   *  the min/max values seen during threshold selection
498   *  (default 0).</pre>
499   *
500   * <pre> -E &lt;integer&gt;
501   *  Sets the evaluation mode. Use 0 for
502   *  evaluation using cross-validation,
503   *  1 for evaluation using hold-out set,
504   *  and 2 for evaluation on the
505   *  training data (default 1).</pre>
506   *
507   * <pre> -M [FMEASURE|ACCURACY|TRUE_POS|TRUE_NEG|TP_RATE|PRECISION|RECALL]
508   *  Measure used for evaluation (default is FMEASURE).
509   * </pre>
510   *
511   * <pre> -manual &lt;real&gt;
512   *  Set a manual threshold to use. This option overrides
513   *  automatic selection and options pertaining to
514   *  automatic selection will be ignored.
515   *  (default -1, i.e. do not use a manual threshold).</pre>
516   *
517   * <pre> -S &lt;num&gt;
518   *  Random number seed.
519   *  (default 1)</pre>
520   *
521   * <pre> -D
522   *  If set, classifier is run in debug mode and
523   *  may output additional info to the console</pre>
524   *
525   * <pre> -W
526   *  Full name of base classifier.
527   *  (default: weka.classifiers.functions.Logistic)</pre>
528   *
529   * <pre>
530   * Options specific to classifier weka.classifiers.functions.Logistic:
531   * </pre>
532   *
533   * <pre> -D
534   *  Turn on debugging output.</pre>
535   *
536   * <pre> -R &lt;ridge&gt;
537   *  Set the ridge in the log-likelihood.</pre>
538   *
539   * <pre> -M &lt;number&gt;
540   *  Set the maximum number of iterations (default -1, until convergence).</pre>
541   *
542   <!-- options-end -->
543   *
544   * Options after -- are passed to the designated sub-classifier. <p>
545   *
546   * @param options the list of options as an array of strings
547   * @throws Exception if an option is not supported
548   */
549  public void setOptions(String[] options) throws Exception {
550   
551    String manualS = Utils.getOption("manual", options);
552    if (manualS.length() > 0) {
553      double val = Double.parseDouble(manualS);
554      if (val >= 0.0) {
555        setManualThresholdValue(val);
556      } 
557    }
558
559    String classString = Utils.getOption('C', options);
560    if (classString.length() != 0) {
561      setDesignatedClass(new SelectedTag(Integer.parseInt(classString) - 1, 
562                                         TAGS_OPTIMIZE));
563    } else {
564      setDesignatedClass(new SelectedTag(OPTIMIZE_POS_NAME, TAGS_OPTIMIZE));
565    }
566
567    String modeString = Utils.getOption('E', options);
568    if (modeString.length() != 0) {
569      setEvaluationMode(new SelectedTag(Integer.parseInt(modeString), 
570                                         TAGS_EVAL));
571    } else {
572      setEvaluationMode(new SelectedTag(EVAL_TUNED_SPLIT, TAGS_EVAL));
573    }
574
575    String rangeString = Utils.getOption('R', options);
576    if (rangeString.length() != 0) {
577      setRangeCorrection(new SelectedTag(Integer.parseInt(rangeString), 
578                                         TAGS_RANGE));
579    } else {
580      setRangeCorrection(new SelectedTag(RANGE_NONE, TAGS_RANGE));
581    }
582
583    String measureString = Utils.getOption('M', options);
584    if (measureString.length() != 0) {
585      setMeasure(new SelectedTag(measureString, TAGS_MEASURE));
586    } else {
587      setMeasure(new SelectedTag(FMEASURE, TAGS_MEASURE));
588    }
589
590    String foldsString = Utils.getOption('X', options);
591    if (foldsString.length() != 0) {
592      setNumXValFolds(Integer.parseInt(foldsString));
593    } else {
594      setNumXValFolds(3);
595    }
596
597    super.setOptions(options);
598  }
599
600  /**
601   * Gets the current settings of the Classifier.
602   *
603   * @return an array of strings suitable for passing to setOptions
604   */
605  public String [] getOptions() {
606
607    String [] superOptions = super.getOptions();
608    String [] options = new String [superOptions.length + 12];
609
610    int current = 0;
611
612    if (m_manualThreshold) {
613      options[current++] = "-manual"; options[current++] = "" + getManualThresholdValue();
614    }
615    options[current++] = "-C"; options[current++] = "" + (m_ClassMode + 1);
616    options[current++] = "-X"; options[current++] = "" + getNumXValFolds();
617    options[current++] = "-E"; options[current++] = "" + m_EvalMode;
618    options[current++] = "-R"; options[current++] = "" + m_RangeMode;
619    options[current++] = "-M"; options[current++] = "" + getMeasure().getSelectedTag().getReadable();
620
621    System.arraycopy(superOptions, 0, options, current, 
622                     superOptions.length);
623
624    current += superOptions.length;
625    while (current < options.length) {
626      options[current++] = "";
627    }
628    return options;
629  }
630
631  /**
632   * Returns default capabilities of the classifier.
633   *
634   * @return      the capabilities of this classifier
635   */
636  public Capabilities getCapabilities() {
637    Capabilities result = super.getCapabilities();
638
639    // class
640    result.disableAllClasses();
641    result.disableAllClassDependencies();
642    result.enable(Capability.BINARY_CLASS);
643   
644    return result;
645  }
646
647  /**
648   * Generates the classifier.
649   *
650   * @param instances set of instances serving as training data
651   * @throws Exception if the classifier has not been generated successfully
652   */
653  public void buildClassifier(Instances instances) 
654    throws Exception {
655
656    // can classifier handle the data?
657    getCapabilities().testWithFail(instances);
658
659    // remove instances with missing class
660    instances = new Instances(instances);
661    instances.deleteWithMissingClass();
662   
663    AttributeStats stats = instances.attributeStats(instances.classIndex());
664    if (m_manualThreshold) {
665      m_BestThreshold = m_manualThresholdValue;
666    } else {
667      m_BestThreshold = 0.5;
668    }
669    m_BestValue = MIN_VALUE;
670    m_HighThreshold = 1;
671    m_LowThreshold = 0;
672
673    // If data contains only one instance of positive data
674    // optimize on training data
675    if (stats.distinctCount != 2) {
676      System.err.println("Couldn't find examples of both classes. No adjustment.");
677      m_Classifier.buildClassifier(instances);
678    } else {
679     
680      // Determine which class value to look for
681      switch (m_ClassMode) {
682      case OPTIMIZE_0:
683        m_DesignatedClass = 0;
684        break;
685      case OPTIMIZE_1:
686        m_DesignatedClass = 1;
687        break;
688      case OPTIMIZE_POS_NAME:
689        Attribute cAtt = instances.classAttribute();
690        boolean found = false;
691        for (int i = 0; i < cAtt.numValues() && !found; i++) {
692          String name = cAtt.value(i).toLowerCase();
693          if (name.startsWith("yes") || name.equals("1") || 
694              name.startsWith("pos")) {
695            found = true;
696            m_DesignatedClass = i;
697          }
698        }
699        if (found) {
700          break;
701        }
702        // No named class found, so fall through to default of least frequent
703      case OPTIMIZE_LFREQ:
704        m_DesignatedClass = (stats.nominalCounts[0] > stats.nominalCounts[1]) ? 1 : 0;
705        break;
706      case OPTIMIZE_MFREQ:
707        m_DesignatedClass = (stats.nominalCounts[0] > stats.nominalCounts[1]) ? 0 : 1;
708        break;
709      default:
710        throw new Exception("Unrecognized class value selection mode");
711      }
712     
713      /*
714        System.err.println("ThresholdSelector: Using mode="
715        + TAGS_OPTIMIZE[m_ClassMode].getReadable());
716        System.err.println("ThresholdSelector: Optimizing using class "
717        + m_DesignatedClass + "/"
718        + instances.classAttribute().value(m_DesignatedClass));
719      */
720     
721      if (m_manualThreshold) {
722        m_Classifier.buildClassifier(instances);
723        return;
724      }
725
726      if (stats.nominalCounts[m_DesignatedClass] == 1) {
727        System.err.println("Only 1 positive found: optimizing on training data");
728        findThreshold(getPredictions(instances, EVAL_TRAINING_SET, 0));
729      } else {
730        int numFolds = Math.min(m_NumXValFolds, stats.nominalCounts[m_DesignatedClass]);
731        //System.err.println("Number of folds for threshold selector: " + numFolds);
732        findThreshold(getPredictions(instances, m_EvalMode, numFolds));
733        if (m_EvalMode != EVAL_TRAINING_SET) {
734          m_Classifier.buildClassifier(instances);
735        }
736      }
737    }
738  }
739
740  /**
741   * Checks whether instance of designated class is in subset.
742   *
743   * @param data the data to check for instance
744   * @return true if the instance is in the subset
745   * @throws Exception if checking fails
746   */
747  private boolean checkForInstance(Instances data) throws Exception {
748
749    for (int i = 0; i < data.numInstances(); i++) {
750      if (((int)data.instance(i).classValue()) == m_DesignatedClass) {
751        return true;
752      }
753    }
754    return false;
755  }
756
757
758  /**
759   * Calculates the class membership probabilities for the given test instance.
760   *
761   * @param instance the instance to be classified
762   * @return predicted class probability distribution
763   * @throws Exception if instance could not be classified
764   * successfully
765   */
766  public double [] distributionForInstance(Instance instance) 
767    throws Exception {
768   
769    double [] pred = m_Classifier.distributionForInstance(instance);
770    double prob = pred[m_DesignatedClass];
771
772    // Warp probability
773    if (prob > m_BestThreshold) {
774      prob = 0.5 + (prob - m_BestThreshold) / 
775        ((m_HighThreshold - m_BestThreshold) * 2);
776    } else {
777      prob = (prob - m_LowThreshold) / 
778        ((m_BestThreshold - m_LowThreshold) * 2);
779    }
780    if (prob < 0) {
781      prob = 0.0;
782    } else if (prob > 1) {
783      prob = 1.0;
784    }
785
786    // Alter the distribution
787    pred[m_DesignatedClass] = prob;
788    if (pred.length == 2) { // Handle case when there's only one class
789      pred[(m_DesignatedClass + 1) % 2] = 1.0 - prob;
790    }
791    return pred;
792  }
793
794  /**
795   * @return a description of the classifier suitable for
796   * displaying in the explorer/experimenter gui
797   */
798  public String globalInfo() {
799
800    return "A metaclassifier that selecting a mid-point threshold on the "
801      + "probability output by a Classifier. The midpoint "
802      + "threshold is set so that a given performance measure is optimized. "
803      + "Currently this is the F-measure. Performance is measured either on "
804      + "the training data, a hold-out set or using cross-validation. In "
805      + "addition, the probabilities returned by the base learner can "
806      + "have their range expanded so that the output probabilities will "
807      + "reside between 0 and 1 (this is useful if the scheme normally "
808      + "produces probabilities in a very narrow range).";
809  }
810   
811  /**
812   * @return tip text for this property suitable for
813   * displaying in the explorer/experimenter gui
814   */
815  public String designatedClassTipText() {
816
817    return "Sets the class value for which the optimization is performed. "
818      + "The options are: pick the first class value; pick the second "
819      + "class value; pick whichever class is least frequent; pick whichever "
820      + "class value is most frequent; pick the first class named any of "
821      + "\"yes\",\"pos(itive)\", \"1\", or the least frequent if no matches).";
822  }
823
824  /**
825   * Gets the method to determine which class value to optimize. Will
826   * be one of OPTIMIZE_0, OPTIMIZE_1, OPTIMIZE_LFREQ, OPTIMIZE_MFREQ,
827   * OPTIMIZE_POS_NAME.
828   *
829   * @return the class selection mode.
830   */
831  public SelectedTag getDesignatedClass() {
832
833    return new SelectedTag(m_ClassMode, TAGS_OPTIMIZE);
834  }
835 
836  /**
837   * Sets the method to determine which class value to optimize. Will
838   * be one of OPTIMIZE_0, OPTIMIZE_1, OPTIMIZE_LFREQ, OPTIMIZE_MFREQ,
839   * OPTIMIZE_POS_NAME.
840   *
841   * @param newMethod the new class selection mode.
842   */
843  public void setDesignatedClass(SelectedTag newMethod) {
844   
845    if (newMethod.getTags() == TAGS_OPTIMIZE) {
846      m_ClassMode = newMethod.getSelectedTag().getID();
847    }
848  }
849
850  /**
851   * @return tip text for this property suitable for
852   * displaying in the explorer/experimenter gui
853   */
854  public String evaluationModeTipText() {
855
856    return "Sets the method used to determine the threshold/performance "
857      + "curve. The options are: perform optimization based on the entire "
858      + "training set (may result in overfitting); perform an n-fold "
859      + "cross-validation (may be time consuming); perform one fold of "
860      + "an n-fold cross-validation (faster but likely less accurate).";
861  }
862
863  /**
864   * Sets the evaluation mode used. Will be one of
865   * EVAL_TRAINING, EVAL_TUNED_SPLIT, or EVAL_CROSS_VALIDATION
866   *
867   * @param newMethod the new evaluation mode.
868   */
869  public void setEvaluationMode(SelectedTag newMethod) {
870   
871    if (newMethod.getTags() == TAGS_EVAL) {
872      m_EvalMode = newMethod.getSelectedTag().getID();
873    }
874  }
875
876  /**
877   * Gets the evaluation mode used. Will be one of
878   * EVAL_TRAINING, EVAL_TUNED_SPLIT, or EVAL_CROSS_VALIDATION
879   *
880   * @return the evaluation mode.
881   */
882  public SelectedTag getEvaluationMode() {
883
884    return new SelectedTag(m_EvalMode, TAGS_EVAL);
885  }
886
887  /**
888   * @return tip text for this property suitable for
889   * displaying in the explorer/experimenter gui
890   */
891  public String rangeCorrectionTipText() {
892
893    return "Sets the type of prediction range correction performed. "
894      + "The options are: do not do any range correction; "
895      + "expand predicted probabilities so that the minimum probability "
896      + "observed during the optimization maps to 0, and the maximum "
897      + "maps to 1 (values outside this range are clipped to 0 and 1).";
898  }
899
900  /**
901   * Sets the confidence range correction mode used. Will be one of
902   * RANGE_NONE, or RANGE_BOUNDS
903   *
904   * @param newMethod the new correciton mode.
905   */
906  public void setRangeCorrection(SelectedTag newMethod) {
907   
908    if (newMethod.getTags() == TAGS_RANGE) {
909      m_RangeMode = newMethod.getSelectedTag().getID();
910    }
911  }
912
913  /**
914   * Gets the confidence range correction mode used. Will be one of
915   * RANGE_NONE, or RANGE_BOUNDS
916   *
917   * @return the confidence correction mode.
918   */
919  public SelectedTag getRangeCorrection() {
920
921    return new SelectedTag(m_RangeMode, TAGS_RANGE);
922  }
923
924  /**
925   * @return tip text for this property suitable for
926   * displaying in the explorer/experimenter gui
927   */
928  public String numXValFoldsTipText() {
929
930    return "Sets the number of folds used during full cross-validation "
931      + "and tuned fold evaluation. This number will be automatically "
932      + "reduced if there are insufficient positive examples.";
933  }
934
935  /**
936   * Get the number of folds used for cross-validation.
937   *
938   * @return the number of folds used for cross-validation.
939   */
940  public int getNumXValFolds() {
941   
942    return m_NumXValFolds;
943  }
944 
945  /**
946   * Set the number of folds used for cross-validation.
947   *
948   * @param newNumFolds the number of folds used for cross-validation.
949   */
950  public void setNumXValFolds(int newNumFolds) {
951   
952    if (newNumFolds < 2) {
953      throw new IllegalArgumentException("Number of folds must be greater than 1");
954    }
955    m_NumXValFolds = newNumFolds;
956  }
957
958  /**
959   * Returns the type of graph this classifier
960   * represents.
961   * 
962   * @return the type of graph this classifier represents
963   */   
964  public int graphType() {
965   
966    if (m_Classifier instanceof Drawable)
967      return ((Drawable)m_Classifier).graphType();
968    else 
969      return Drawable.NOT_DRAWABLE;
970  }
971
972  /**
973   * Returns graph describing the classifier (if possible).
974   *
975   * @return the graph of the classifier in dotty format
976   * @throws Exception if the classifier cannot be graphed
977   */
978  public String graph() throws Exception {
979   
980    if (m_Classifier instanceof Drawable)
981      return ((Drawable)m_Classifier).graph();
982    else throw new Exception("Classifier: " + getClassifierSpec()
983                             + " cannot be graphed");
984  }
985
986  /**
987   * @return tip text for this property suitable for
988   * displaying in the explorer/experimenter gui
989   */
990  public String manualThresholdValueTipText() {
991
992    return "Sets a manual threshold value to use. "
993      + "If this is set (non-negative value between 0 and 1), then "
994      + "all options pertaining to automatic threshold selection are "
995      + "ignored. ";
996  }
997
998  /**
999   * Sets the value for a manual threshold. If this option
1000   * is set (non-negative value between 0 and 1), then options
1001   * pertaining to automatic threshold selection are ignored.
1002   *
1003   * @param threshold the manual threshold to use
1004   */
1005  public void setManualThresholdValue(double threshold) throws Exception {
1006    m_manualThresholdValue = threshold;
1007    if (threshold >= 0.0 && threshold <= 1.0) {
1008      m_manualThreshold = true;
1009    } else {
1010      m_manualThreshold = false;
1011      if (threshold >= 0) {
1012        throw new IllegalArgumentException("Threshold must be in the "
1013                                           + "range 0..1.");
1014      }
1015    }
1016  }
1017
1018  /**
1019   * Returns the value of the manual threshold. (a negative
1020   * value indicates that no manual threshold is being used.
1021   *
1022   * @return the value of the manual threshold.
1023   */
1024  public double getManualThresholdValue() {
1025    return m_manualThresholdValue;
1026  }
1027 
1028  /**
1029   * Returns description of the cross-validated classifier.
1030   *
1031   * @return description of the cross-validated classifier as a string
1032   */
1033  public String toString() {
1034
1035    if (m_BestValue == -Double.MAX_VALUE)
1036      return "ThresholdSelector: No model built yet.";
1037
1038    String result = "Threshold Selector.\n"
1039    + "Classifier: " + m_Classifier.getClass().getName() + "\n";
1040
1041    result += "Index of designated class: " + m_DesignatedClass + "\n";
1042
1043    if (m_manualThreshold) {
1044      result += "User supplied threshold: " + m_BestThreshold + "\n";
1045    } else {
1046      result += "Evaluation mode: ";
1047      switch (m_EvalMode) {
1048      case EVAL_CROSS_VALIDATION:
1049        result += m_NumXValFolds + "-fold cross-validation";
1050        break;
1051      case EVAL_TUNED_SPLIT:
1052        result += "tuning on 1/" + m_NumXValFolds + " of the data";
1053        break;
1054      case EVAL_TRAINING_SET:
1055      default:
1056        result += "tuning on the training data";
1057      }
1058      result += "\n";
1059
1060      result += "Threshold: " + m_BestThreshold + "\n";
1061      result += "Best value: " + m_BestValue + "\n";
1062      if (m_RangeMode == RANGE_BOUNDS) {
1063        result += "Expanding range [" + m_LowThreshold + "," + m_HighThreshold
1064          + "] to [0, 1]\n";
1065      }
1066      result += "Measure: " + getMeasure().getSelectedTag().getReadable() + "\n";
1067    }
1068    result += m_Classifier.toString();
1069    return result;
1070  }
1071 
1072  /**
1073   * Returns the revision string.
1074   *
1075   * @return            the revision
1076   */
1077  public String getRevision() {
1078    return RevisionUtils.extract("$Revision: 1.43 $");
1079  }
1080 
1081  /**
1082   * Main method for testing this class.
1083   *
1084   * @param argv the options
1085   */
1086  public static void main(String [] argv) {
1087    runClassifier(new ThresholdSelector(), argv);
1088  }
1089}
1090
Note: See TracBrowser for help on using the repository browser.