source: src/main/java/weka/classifiers/evaluation/CostCurve.java @ 15

Last change on this file since 15 was 4, checked in by gnappo, 14 years ago

Import di weka.

File size: 5.8 KB
RevLine 
[4]1/*
2 *    This program is free software; you can redistribute it and/or modify
3 *    it under the terms of the GNU General Public License as published by
4 *    the Free Software Foundation; either version 2 of the License, or
5 *    (at your option) any later version.
6 *
7 *    This program is distributed in the hope that it will be useful,
8 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
9 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
10 *    GNU General Public License for more details.
11 *
12 *    You should have received a copy of the GNU General Public License
13 *    along with this program; if not, write to the Free Software
14 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
15 */
16
17/*
18 *    CostCurve.java
19 *    Copyright (C) 2001 University of Waikato, Hamilton, New Zealand
20 *
21 */
22
23package weka.classifiers.evaluation;
24
25import weka.classifiers.Classifier;
26import weka.classifiers.AbstractClassifier;
27import weka.core.Attribute;
28import weka.core.FastVector;
29import weka.core.Instance;
30import weka.core.DenseInstance;
31import weka.core.Instances;
32import weka.core.RevisionHandler;
33import weka.core.RevisionUtils;
34
35/**
36 * Generates points illustrating probability cost tradeoffs that can be
37 * obtained by varying the threshold value between classes. For example,
38 * the typical threshold value of 0.5 means the predicted probability of
39 * "positive" must be higher than 0.5 for the instance to be predicted as
40 * "positive".
41 *
42 * @author Mark Hall (mhall@cs.waikato.ac.nz)
43 * @version $Revision: 5987 $
44 */
45
46public class CostCurve 
47  implements RevisionHandler {
48
49  /** The name of the relation used in cost curve datasets */
50  public static final String RELATION_NAME = "CostCurve";
51
52  /** attribute name: Probability Cost Function */
53  public static final String PROB_COST_FUNC_NAME = "Probability Cost Function";
54  /** attribute name: Normalized Expected Cost */
55  public static final String NORM_EXPECTED_COST_NAME = "Normalized Expected Cost";
56  /** attribute name: Threshold */
57  public static final String THRESHOLD_NAME = "Threshold";
58
59  /**
60   * Calculates the performance stats for the default class and return
61   * results as a set of Instances. The
62   * structure of these Instances is as follows:<p> <ul>
63   * <li> <b>Probability Cost Function </b>
64   * <li> <b>Normalized Expected Cost</b>
65   * <li> <b>Threshold</b> contains the probability threshold that gives
66   * rise to the previous performance values.
67   * </ul> <p>
68   *
69   * @see TwoClassStats
70   * @param predictions the predictions to base the curve on
71   * @return datapoints as a set of instances, null if no predictions
72   * have been made.
73   */
74  public Instances getCurve(FastVector predictions) {
75
76    if (predictions.size() == 0) {
77      return null;
78    }
79    return getCurve(predictions, 
80                    ((NominalPrediction)predictions.elementAt(0))
81                    .distribution().length - 1);
82  }
83
84  /**
85   * Calculates the performance stats for the desired class and return
86   * results as a set of Instances.
87   *
88   * @param predictions the predictions to base the curve on
89   * @param classIndex index of the class of interest.
90   * @return datapoints as a set of instances.
91   */
92  public Instances getCurve(FastVector predictions, int classIndex) {
93
94    if ((predictions.size() == 0) ||
95        (((NominalPrediction)predictions.elementAt(0))
96         .distribution().length <= classIndex)) {
97      return null;
98    }
99   
100    ThresholdCurve tc = new ThresholdCurve();
101    Instances threshInst = tc.getCurve(predictions, classIndex);
102
103    Instances insts = makeHeader();
104    int fpind = threshInst.attribute(ThresholdCurve.FP_RATE_NAME).index();
105    int tpind = threshInst.attribute(ThresholdCurve.TP_RATE_NAME).index();
106    int threshind = threshInst.attribute(ThresholdCurve.THRESHOLD_NAME).index();
107   
108    double [] vals;
109    double fpval, tpval, thresh;
110    for (int i = 0; i< threshInst.numInstances(); i++) {
111      fpval = threshInst.instance(i).value(fpind);
112      tpval = threshInst.instance(i).value(tpind);
113      thresh = threshInst.instance(i).value(threshind);
114      vals = new double [3];
115      vals[0] = 0; vals[1] = fpval; vals[2] = thresh;
116      insts.add(new DenseInstance(1.0, vals));
117      vals = new double [3];
118      vals[0] = 1; vals[1] = 1.0 - tpval; vals[2] = thresh;
119      insts.add(new DenseInstance(1.0, vals));
120    }
121   
122    return insts;
123  }
124
125  /**
126   * generates the header
127   *
128   * @return the header
129   */
130  private Instances makeHeader() {
131
132    FastVector fv = new FastVector();
133    fv.addElement(new Attribute(PROB_COST_FUNC_NAME));
134    fv.addElement(new Attribute(NORM_EXPECTED_COST_NAME));
135    fv.addElement(new Attribute(THRESHOLD_NAME));
136    return new Instances(RELATION_NAME, fv, 100);
137  }
138 
139  /**
140   * Returns the revision string.
141   *
142   * @return            the revision
143   */
144  public String getRevision() {
145    return RevisionUtils.extract("$Revision: 5987 $");
146  }
147
148  /**
149   * Tests the CostCurve generation from the command line.
150   * The classifier is currently hardcoded. Pipe in an arff file.
151   *
152   * @param args currently ignored
153   */
154  public static void main(String [] args) {
155
156    try {
157     
158      Instances inst = new Instances(new java.io.InputStreamReader(System.in));
159     
160      inst.setClassIndex(inst.numAttributes() - 1);
161      CostCurve cc = new CostCurve();
162      EvaluationUtils eu = new EvaluationUtils();
163      Classifier classifier = new weka.classifiers.functions.Logistic();
164      FastVector predictions = new FastVector();
165      for (int i = 0; i < 2; i++) { // Do two runs.
166        eu.setSeed(i);
167        predictions.appendElements(eu.getCVPredictions(classifier, inst, 10));
168        //System.out.println("\n\n\n");
169      }
170      Instances result = cc.getCurve(predictions);
171      System.out.println(result);
172     
173    } catch (Exception ex) {
174      ex.printStackTrace();
175    }
176  }
177}
Note: See TracBrowser for help on using the repository browser.