source: src/main/java/weka/classifiers/evaluation/ThresholdCurve.java @ 23

Last change on this file since 23 was 4, checked in by gnappo, 14 years ago

Import di weka.

File size: 15.7 KB
RevLine 
[4]1/*
2 *    This program is free software; you can redistribute it and/or modify
3 *    it under the terms of the GNU General Public License as published by
4 *    the Free Software Foundation; either version 2 of the License, or
5 *    (at your option) any later version.
6 *
7 *    This program is distributed in the hope that it will be useful,
8 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
9 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
10 *    GNU General Public License for more details.
11 *
12 *    You should have received a copy of the GNU General Public License
13 *    along with this program; if not, write to the Free Software
14 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
15 */
16
17/*
18 *    ThresholdCurve.java
19 *    Copyright (C) 2002 University of Waikato, Hamilton, New Zealand
20 *
21 */
22
23package weka.classifiers.evaluation;
24
25import weka.classifiers.Classifier;
26import weka.classifiers.AbstractClassifier;
27import weka.core.Attribute;
28import weka.core.FastVector;
29import weka.core.Instance;
30import weka.core.DenseInstance;
31import weka.core.Instances;
32import weka.core.RevisionHandler;
33import weka.core.RevisionUtils;
34import weka.core.Utils;
35
36/**
37 * Generates points illustrating prediction tradeoffs that can be obtained
38 * by varying the threshold value between classes. For example, the typical
39 * threshold value of 0.5 means the predicted probability of "positive" must be
40 * higher than 0.5 for the instance to be predicted as "positive". The
41 * resulting dataset can be used to visualize precision/recall tradeoff, or
42 * for ROC curve analysis (true positive rate vs false positive rate).
43 * Weka just varies the threshold on the class probability estimates in each
44 * case. The Mann Whitney statistic is used to calculate the AUC.
45 *
46 * @author Len Trigg (len@reeltwo.com)
47 * @version $Revision: 5987 $
48 */
49public class ThresholdCurve
50  implements RevisionHandler {
51
52  /** The name of the relation used in threshold curve datasets */
53  public static final String RELATION_NAME = "ThresholdCurve";
54
55  /** attribute name: True Positives */
56  public static final String TRUE_POS_NAME  = "True Positives";
57  /** attribute name: False Negatives */
58  public static final String FALSE_NEG_NAME = "False Negatives";
59  /** attribute name: False Positives */
60  public static final String FALSE_POS_NAME = "False Positives";
61  /** attribute name: True Negatives */
62  public static final String TRUE_NEG_NAME  = "True Negatives";
63  /** attribute name: False Positive Rate */
64  public static final String FP_RATE_NAME   = "False Positive Rate";
65  /** attribute name: True Positive Rate */
66  public static final String TP_RATE_NAME   = "True Positive Rate";
67  /** attribute name: Precision */
68  public static final String PRECISION_NAME = "Precision";
69  /** attribute name: Recall */
70  public static final String RECALL_NAME    = "Recall";
71  /** attribute name: Fallout */
72  public static final String FALLOUT_NAME   = "Fallout";
73  /** attribute name: FMeasure */
74  public static final String FMEASURE_NAME  = "FMeasure";
75  /** attribute name: Sample Size */
76  public static final String SAMPLE_SIZE_NAME = "Sample Size";
77  /** attribute name: Lift */
78  public static final String LIFT_NAME = "Lift";
79  /** attribute name: Threshold */
80  public static final String THRESHOLD_NAME = "Threshold";
81
82  /**
83   * Calculates the performance stats for the default class and return
84   * results as a set of Instances. The
85   * structure of these Instances is as follows:<p> <ul>
86   * <li> <b>True Positives </b>
87   * <li> <b>False Negatives</b>
88   * <li> <b>False Positives</b>
89   * <li> <b>True Negatives</b>
90   * <li> <b>False Positive Rate</b>
91   * <li> <b>True Positive Rate</b>
92   * <li> <b>Precision</b>
93   * <li> <b>Recall</b> 
94   * <li> <b>Fallout</b> 
95   * <li> <b>Threshold</b> contains the probability threshold that gives
96   * rise to the previous performance values.
97   * </ul> <p>
98   * For the definitions of these measures, see TwoClassStats <p>
99   *
100   * @see TwoClassStats
101   * @param predictions the predictions to base the curve on
102   * @return datapoints as a set of instances, null if no predictions
103   * have been made.
104   */
105  public Instances getCurve(FastVector predictions) {
106
107    if (predictions.size() == 0) {
108      return null;
109    }
110    return getCurve(predictions, 
111                    ((NominalPrediction)predictions.elementAt(0))
112                    .distribution().length - 1);
113  }
114
115  /**
116   * Calculates the performance stats for the desired class and return
117   * results as a set of Instances.
118   *
119   * @param predictions the predictions to base the curve on
120   * @param classIndex index of the class of interest.
121   * @return datapoints as a set of instances.
122   */
123  public Instances getCurve(FastVector predictions, int classIndex) {
124
125    if ((predictions.size() == 0) ||
126        (((NominalPrediction)predictions.elementAt(0))
127         .distribution().length <= classIndex)) {
128      return null;
129    }
130
131    double totPos = 0, totNeg = 0;
132    double [] probs = getProbabilities(predictions, classIndex);
133
134    // Get distribution of positive/negatives
135    for (int i = 0; i < probs.length; i++) {
136      NominalPrediction pred = (NominalPrediction)predictions.elementAt(i);
137      if (pred.actual() == Prediction.MISSING_VALUE) {
138        System.err.println(getClass().getName() 
139                           + " Skipping prediction with missing class value");
140        continue;
141      }
142      if (pred.weight() < 0) {
143        System.err.println(getClass().getName() 
144                           + " Skipping prediction with negative weight");
145        continue;
146      }
147      if (pred.actual() == classIndex) {
148        totPos += pred.weight();
149      } else {
150        totNeg += pred.weight();
151      }
152    }
153
154    Instances insts = makeHeader();
155    int [] sorted = Utils.sort(probs);
156    TwoClassStats tc = new TwoClassStats(totPos, totNeg, 0, 0);
157    double threshold = 0;
158    double cumulativePos = 0;
159    double cumulativeNeg = 0;
160    for (int i = 0; i < sorted.length; i++) {
161
162      if ((i == 0) || (probs[sorted[i]] > threshold)) {
163        tc.setTruePositive(tc.getTruePositive() - cumulativePos);
164        tc.setFalseNegative(tc.getFalseNegative() + cumulativePos);
165        tc.setFalsePositive(tc.getFalsePositive() - cumulativeNeg);
166        tc.setTrueNegative(tc.getTrueNegative() + cumulativeNeg);
167        threshold = probs[sorted[i]];
168        insts.add(makeInstance(tc, threshold));
169        cumulativePos = 0;
170        cumulativeNeg = 0;
171        if (i == sorted.length - 1) {
172          break;
173        }
174      }
175
176      NominalPrediction pred = (NominalPrediction)predictions.elementAt(sorted[i]);
177
178      if (pred.actual() == Prediction.MISSING_VALUE) {
179        System.err.println(getClass().getName()
180                           + " Skipping prediction with missing class value");
181        continue;
182      }
183      if (pred.weight() < 0) {
184        System.err.println(getClass().getName() 
185                           + " Skipping prediction with negative weight");
186        continue;
187      }
188      if (pred.actual() == classIndex) {
189        cumulativePos += pred.weight();
190      } else {
191        cumulativeNeg += pred.weight();
192      }
193
194      /*
195      System.out.println(tc + " " + probs[sorted[i]]
196                         + " " + (pred.actual() == classIndex));
197      */
198      /*if ((i != (sorted.length - 1)) &&
199          ((i == 0) || 
200          (probs[sorted[i]] != probs[sorted[i - 1]]))) {
201        insts.add(makeInstance(tc, probs[sorted[i]]));
202        }*/
203    }
204    return insts;
205  }
206
207  /**
208   * Calculates the n point precision result, which is the precision averaged
209   * over n evenly spaced (w.r.t recall) samples of the curve.
210   *
211   * @param tcurve a previously extracted threshold curve Instances.
212   * @param n the number of points to average over.
213   * @return the n-point precision.
214   */
215  public static double getNPointPrecision(Instances tcurve, int n) {
216
217    if (!RELATION_NAME.equals(tcurve.relationName()) 
218        || (tcurve.numInstances() == 0)) {
219      return Double.NaN;
220    }
221    int recallInd = tcurve.attribute(RECALL_NAME).index();
222    int precisInd = tcurve.attribute(PRECISION_NAME).index();
223    double [] recallVals = tcurve.attributeToDoubleArray(recallInd);
224    int [] sorted = Utils.sort(recallVals);
225    double isize = 1.0 / (n - 1);
226    double psum = 0;
227    for (int i = 0; i < n; i++) {
228      int pos = binarySearch(sorted, recallVals, i * isize);
229      double recall = recallVals[sorted[pos]];
230      double precis = tcurve.instance(sorted[pos]).value(precisInd);
231      /*
232      System.err.println("Point " + (i + 1) + ": i=" + pos
233                         + " r=" + (i * isize)
234                         + " p'=" + precis
235                         + " r'=" + recall);
236      */
237      // interpolate figures for non-endpoints
238      while ((pos != 0) && (pos < sorted.length - 1)) {
239        pos++;
240        double recall2 = recallVals[sorted[pos]];
241        if (recall2 != recall) {
242          double precis2 = tcurve.instance(sorted[pos]).value(precisInd);
243          double slope = (precis2 - precis) / (recall2 - recall);
244          double offset = precis - recall * slope;
245          precis = isize * i * slope + offset;
246          /*
247          System.err.println("Point2 " + (i + 1) + ": i=" + pos
248                             + " r=" + (i * isize)
249                             + " p'=" + precis2
250                             + " r'=" + recall2
251                             + " p''=" + precis);
252          */
253          break;
254        }
255      }
256      psum += precis;
257    }
258    return psum / n;
259  }
260
261  /**
262   * Calculates the area under the ROC curve as the Wilcoxon-Mann-Whitney statistic.
263   *
264   * @param tcurve a previously extracted threshold curve Instances.
265   * @return the ROC area, or Double.NaN if you don't pass in
266   * a ThresholdCurve generated Instances.
267   */
268  public static double getROCArea(Instances tcurve) {
269
270    final int n = tcurve.numInstances();
271    if (!RELATION_NAME.equals(tcurve.relationName()) 
272        || (n == 0)) {
273      return Double.NaN;
274    }
275    final int tpInd = tcurve.attribute(TRUE_POS_NAME).index();
276    final int fpInd = tcurve.attribute(FALSE_POS_NAME).index();
277    final double [] tpVals = tcurve.attributeToDoubleArray(tpInd);
278    final double [] fpVals = tcurve.attributeToDoubleArray(fpInd);
279
280    double area = 0.0, cumNeg = 0.0;
281    final double totalPos = tpVals[0];
282    final double totalNeg = fpVals[0];
283    for (int i = 0; i < n; i++) {
284        double cip, cin;
285        if (i < n - 1) {
286            cip = tpVals[i] - tpVals[i + 1];
287            cin = fpVals[i] - fpVals[i + 1];
288        } else {
289            cip = tpVals[n - 1];
290            cin = fpVals[n - 1];
291        }
292        area += cip * (cumNeg + (0.5 * cin));
293        cumNeg += cin;
294    }
295    area /= (totalNeg * totalPos);
296
297    return area;
298  }
299
300  /**
301   * Gets the index of the instance with the closest threshold value to the
302   * desired target
303   *
304   * @param tcurve a set of instances that have been generated by this class
305   * @param threshold the target threshold
306   * @return the index of the instance that has threshold closest to
307   * the target, or -1 if this could not be found (i.e. no data, or
308   * bad threshold target)
309   */
310  public static int getThresholdInstance(Instances tcurve, double threshold) {
311
312    if (!RELATION_NAME.equals(tcurve.relationName()) 
313        || (tcurve.numInstances() == 0)
314        || (threshold < 0)
315        || (threshold > 1.0)) {
316      return -1;
317    }
318    if (tcurve.numInstances() == 1) {
319      return 0;
320    }
321    double [] tvals = tcurve.attributeToDoubleArray(tcurve.numAttributes() - 1);
322    int [] sorted = Utils.sort(tvals);
323    return binarySearch(sorted, tvals, threshold);
324  }
325
326  /**
327   * performs a binary search
328   *
329   * @param index the indices
330   * @param vals the values
331   * @param target the target to look for
332   * @return the index of the target
333   */
334  private static int binarySearch(int [] index, double [] vals, double target) {
335   
336    int lo = 0, hi = index.length - 1;
337    while (hi - lo > 1) {
338      int mid = lo + (hi - lo) / 2;
339      double midval = vals[index[mid]];
340      if (target > midval) {
341        lo = mid;
342      } else if (target < midval) {
343        hi = mid;
344      } else {
345        while ((mid > 0) && (vals[index[mid - 1]] == target)) {
346          mid --;
347        }
348        return mid;
349      }
350    }
351    return lo;
352  }
353
354  /**
355   *
356   * @param predictions the predictions to use
357   * @param classIndex the class index
358   * @return the probabilities
359   */
360  private double [] getProbabilities(FastVector predictions, int classIndex) {
361
362    // sort by predicted probability of the desired class.
363    double [] probs = new double [predictions.size()];
364    for (int i = 0; i < probs.length; i++) {
365      NominalPrediction pred = (NominalPrediction)predictions.elementAt(i);
366      probs[i] = pred.distribution()[classIndex];
367    }
368    return probs;
369  }
370
371  /**
372   * generates the header
373   *
374   * @return the header
375   */
376  private Instances makeHeader() {
377
378    FastVector fv = new FastVector();
379    fv.addElement(new Attribute(TRUE_POS_NAME));
380    fv.addElement(new Attribute(FALSE_NEG_NAME));
381    fv.addElement(new Attribute(FALSE_POS_NAME));
382    fv.addElement(new Attribute(TRUE_NEG_NAME));
383    fv.addElement(new Attribute(FP_RATE_NAME));
384    fv.addElement(new Attribute(TP_RATE_NAME));
385    fv.addElement(new Attribute(PRECISION_NAME));
386    fv.addElement(new Attribute(RECALL_NAME));
387    fv.addElement(new Attribute(FALLOUT_NAME));
388    fv.addElement(new Attribute(FMEASURE_NAME));
389    fv.addElement(new Attribute(SAMPLE_SIZE_NAME));
390    fv.addElement(new Attribute(LIFT_NAME));
391    fv.addElement(new Attribute(THRESHOLD_NAME));     
392    return new Instances(RELATION_NAME, fv, 100);
393  }
394 
395  /**
396   * generates an instance out of the given data
397   *
398   * @param tc the statistics
399   * @param prob the probability
400   * @return the generated instance
401   */
402  private Instance makeInstance(TwoClassStats tc, double prob) {
403
404    int count = 0;
405    double [] vals = new double[13];
406    vals[count++] = tc.getTruePositive();
407    vals[count++] = tc.getFalseNegative();
408    vals[count++] = tc.getFalsePositive();
409    vals[count++] = tc.getTrueNegative();
410    vals[count++] = tc.getFalsePositiveRate();
411    vals[count++] = tc.getTruePositiveRate();
412    vals[count++] = tc.getPrecision();
413    vals[count++] = tc.getRecall();
414    vals[count++] = tc.getFallout();
415    vals[count++] = tc.getFMeasure();
416      double ss = (tc.getTruePositive() + tc.getFalsePositive()) / 
417        (tc.getTruePositive() + tc.getFalsePositive() + tc.getTrueNegative() + tc.getFalseNegative());
418    vals[count++] = ss;
419    double expectedByChance = (ss * (tc.getTruePositive() + tc.getFalseNegative()));
420    if (expectedByChance < 1) {
421      vals[count++] = Utils.missingValue();
422    } else {
423    vals[count++] = tc.getTruePositive() / expectedByChance; 
424     
425    }
426    vals[count++] = prob;
427    return new DenseInstance(1.0, vals);
428  }
429 
430  /**
431   * Returns the revision string.
432   *
433   * @return            the revision
434   */
435  public String getRevision() {
436    return RevisionUtils.extract("$Revision: 5987 $");
437  }
438 
439  /**
440   * Tests the ThresholdCurve generation from the command line.
441   * The classifier is currently hardcoded. Pipe in an arff file.
442   *
443   * @param args currently ignored
444   */
445  public static void main(String [] args) {
446
447    try {
448     
449      Instances inst = new Instances(new java.io.InputStreamReader(System.in));
450      if (false) {
451        System.out.println(ThresholdCurve.getNPointPrecision(inst, 11));
452      } else {
453        inst.setClassIndex(inst.numAttributes() - 1);
454        ThresholdCurve tc = new ThresholdCurve();
455        EvaluationUtils eu = new EvaluationUtils();
456        Classifier classifier = new weka.classifiers.functions.Logistic();
457        FastVector predictions = new FastVector();
458        for (int i = 0; i < 2; i++) { // Do two runs.
459          eu.setSeed(i);
460          predictions.appendElements(eu.getCVPredictions(classifier, inst, 10));
461          //System.out.println("\n\n\n");
462        }
463        Instances result = tc.getCurve(predictions);
464        System.out.println(result);
465      }
466    } catch (Exception ex) {
467      ex.printStackTrace();
468    }
469  }
470}
Note: See TracBrowser for help on using the repository browser.