| 1 | /* | 
|---|
| 2 |  *    This program is free software; you can redistribute it and/or modify | 
|---|
| 3 |  *    it under the terms of the GNU General Public License as published by | 
|---|
| 4 |  *    the Free Software Foundation; either version 2 of the License, or | 
|---|
| 5 |  *    (at your option) any later version. | 
|---|
| 6 |  * | 
|---|
| 7 |  *    This program is distributed in the hope that it will be useful, | 
|---|
| 8 |  *    but WITHOUT ANY WARRANTY; without even the implied warranty of | 
|---|
| 9 |  *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the | 
|---|
| 10 |  *    GNU General Public License for more details. | 
|---|
| 11 |  * | 
|---|
| 12 |  *    You should have received a copy of the GNU General Public License | 
|---|
| 13 |  *    along with this program; if not, write to the Free Software | 
|---|
| 14 |  *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA. | 
|---|
| 15 |  */ | 
|---|
| 16 |  | 
|---|
| 17 | /* | 
|---|
| 18 |  *    CostCurve.java | 
|---|
| 19 |  *    Copyright (C) 2001 University of Waikato, Hamilton, New Zealand | 
|---|
| 20 |  * | 
|---|
| 21 |  */ | 
|---|
| 22 |  | 
|---|
| 23 | package weka.classifiers.evaluation; | 
|---|
| 24 |  | 
|---|
| 25 | import weka.classifiers.Classifier; | 
|---|
| 26 | import weka.classifiers.AbstractClassifier; | 
|---|
| 27 | import weka.core.Attribute; | 
|---|
| 28 | import weka.core.FastVector; | 
|---|
| 29 | import weka.core.Instance; | 
|---|
| 30 | import weka.core.DenseInstance; | 
|---|
| 31 | import weka.core.Instances; | 
|---|
| 32 | import weka.core.RevisionHandler; | 
|---|
| 33 | import weka.core.RevisionUtils; | 
|---|
| 34 |  | 
|---|
| 35 | /** | 
|---|
| 36 |  * Generates points illustrating probablity cost tradeoffs that can be  | 
|---|
| 37 |  * obtained by varying the threshold value between classes. For example,  | 
|---|
| 38 |  * the typical threshold value of 0.5 means the predicted probability of  | 
|---|
| 39 |  * "positive" must be higher than 0.5 for the instance to be predicted as  | 
|---|
| 40 |  * "positive". | 
|---|
| 41 |  * | 
|---|
| 42 |  * @author Mark Hall (mhall@cs.waikato.ac.nz) | 
|---|
| 43 |  * @version $Revision: 5987 $ | 
|---|
| 44 |  */ | 
|---|
| 45 |  | 
|---|
| 46 | public class CostCurve  | 
|---|
| 47 |   implements RevisionHandler { | 
|---|
| 48 |  | 
|---|
| 49 |   /** The name of the relation used in cost curve datasets */ | 
|---|
| 50 |   public static final String RELATION_NAME = "CostCurve"; | 
|---|
| 51 |  | 
|---|
| 52 |   /** attribute name: Probability Cost Function */ | 
|---|
| 53 |   public static final String PROB_COST_FUNC_NAME = "Probability Cost Function"; | 
|---|
| 54 |   /** attribute name: Normalized Expected Cost */ | 
|---|
| 55 |   public static final String NORM_EXPECTED_COST_NAME = "Normalized Expected Cost"; | 
|---|
| 56 |   /** attribute name: Threshold */ | 
|---|
| 57 |   public static final String THRESHOLD_NAME = "Threshold"; | 
|---|
| 58 |  | 
|---|
| 59 |   /** | 
|---|
| 60 |    * Calculates the performance stats for the default class and return  | 
|---|
| 61 |    * results as a set of Instances. The | 
|---|
| 62 |    * structure of these Instances is as follows:<p> <ul>  | 
|---|
| 63 |    * <li> <b>Probability Cost Function </b> | 
|---|
| 64 |    * <li> <b>Normalized Expected Cost</b> | 
|---|
| 65 |    * <li> <b>Threshold</b> contains the probability threshold that gives | 
|---|
| 66 |    * rise to the previous performance values.  | 
|---|
| 67 |    * </ul> <p> | 
|---|
| 68 |    * | 
|---|
| 69 |    * @see TwoClassStats | 
|---|
| 70 |    * @param predictions the predictions to base the curve on | 
|---|
| 71 |    * @return datapoints as a set of instances, null if no predictions | 
|---|
| 72 |    * have been made. | 
|---|
| 73 |    */ | 
|---|
| 74 |   public Instances getCurve(FastVector predictions) { | 
|---|
| 75 |  | 
|---|
| 76 |     if (predictions.size() == 0) { | 
|---|
| 77 |       return null; | 
|---|
| 78 |     } | 
|---|
| 79 |     return getCurve(predictions,  | 
|---|
| 80 |                     ((NominalPrediction)predictions.elementAt(0)) | 
|---|
| 81 |                     .distribution().length - 1); | 
|---|
| 82 |   } | 
|---|
| 83 |  | 
|---|
| 84 |   /** | 
|---|
| 85 |    * Calculates the performance stats for the desired class and return  | 
|---|
| 86 |    * results as a set of Instances. | 
|---|
| 87 |    * | 
|---|
| 88 |    * @param predictions the predictions to base the curve on | 
|---|
| 89 |    * @param classIndex index of the class of interest. | 
|---|
| 90 |    * @return datapoints as a set of instances. | 
|---|
| 91 |    */ | 
|---|
| 92 |   public Instances getCurve(FastVector predictions, int classIndex) { | 
|---|
| 93 |  | 
|---|
| 94 |     if ((predictions.size() == 0) || | 
|---|
| 95 |         (((NominalPrediction)predictions.elementAt(0)) | 
|---|
| 96 |          .distribution().length <= classIndex)) { | 
|---|
| 97 |       return null; | 
|---|
| 98 |     } | 
|---|
| 99 |      | 
|---|
| 100 |     ThresholdCurve tc = new ThresholdCurve(); | 
|---|
| 101 |     Instances threshInst = tc.getCurve(predictions, classIndex); | 
|---|
| 102 |  | 
|---|
| 103 |     Instances insts = makeHeader(); | 
|---|
| 104 |     int fpind = threshInst.attribute(ThresholdCurve.FP_RATE_NAME).index(); | 
|---|
| 105 |     int tpind = threshInst.attribute(ThresholdCurve.TP_RATE_NAME).index(); | 
|---|
| 106 |     int threshind = threshInst.attribute(ThresholdCurve.THRESHOLD_NAME).index(); | 
|---|
| 107 |      | 
|---|
| 108 |     double [] vals; | 
|---|
| 109 |     double fpval, tpval, thresh; | 
|---|
| 110 |     for (int i = 0; i< threshInst.numInstances(); i++) { | 
|---|
| 111 |       fpval = threshInst.instance(i).value(fpind); | 
|---|
| 112 |       tpval = threshInst.instance(i).value(tpind); | 
|---|
| 113 |       thresh = threshInst.instance(i).value(threshind); | 
|---|
| 114 |       vals = new double [3]; | 
|---|
| 115 |       vals[0] = 0; vals[1] = fpval; vals[2] = thresh; | 
|---|
| 116 |       insts.add(new DenseInstance(1.0, vals)); | 
|---|
| 117 |       vals = new double [3]; | 
|---|
| 118 |       vals[0] = 1; vals[1] = 1.0 - tpval; vals[2] = thresh; | 
|---|
| 119 |       insts.add(new DenseInstance(1.0, vals)); | 
|---|
| 120 |     } | 
|---|
| 121 |      | 
|---|
| 122 |     return insts; | 
|---|
| 123 |   } | 
|---|
| 124 |  | 
|---|
| 125 |   /** | 
|---|
| 126 |    * generates the header | 
|---|
| 127 |    *  | 
|---|
| 128 |    * @return the header | 
|---|
| 129 |    */ | 
|---|
| 130 |   private Instances makeHeader() { | 
|---|
| 131 |  | 
|---|
| 132 |     FastVector fv = new FastVector(); | 
|---|
| 133 |     fv.addElement(new Attribute(PROB_COST_FUNC_NAME)); | 
|---|
| 134 |     fv.addElement(new Attribute(NORM_EXPECTED_COST_NAME)); | 
|---|
| 135 |     fv.addElement(new Attribute(THRESHOLD_NAME)); | 
|---|
| 136 |     return new Instances(RELATION_NAME, fv, 100); | 
|---|
| 137 |   } | 
|---|
| 138 |    | 
|---|
| 139 |   /** | 
|---|
| 140 |    * Returns the revision string. | 
|---|
| 141 |    *  | 
|---|
| 142 |    * @return            the revision | 
|---|
| 143 |    */ | 
|---|
| 144 |   public String getRevision() { | 
|---|
| 145 |     return RevisionUtils.extract("$Revision: 5987 $"); | 
|---|
| 146 |   } | 
|---|
| 147 |  | 
|---|
| 148 |   /** | 
|---|
| 149 |    * Tests the CostCurve generation from the command line. | 
|---|
| 150 |    * The classifier is currently hardcoded. Pipe in an arff file. | 
|---|
| 151 |    * | 
|---|
| 152 |    * @param args currently ignored | 
|---|
| 153 |    */ | 
|---|
| 154 |   public static void main(String [] args) { | 
|---|
| 155 |  | 
|---|
| 156 |     try { | 
|---|
| 157 |        | 
|---|
| 158 |       Instances inst = new Instances(new java.io.InputStreamReader(System.in)); | 
|---|
| 159 |        | 
|---|
| 160 |       inst.setClassIndex(inst.numAttributes() - 1); | 
|---|
| 161 |       CostCurve cc = new CostCurve(); | 
|---|
| 162 |       EvaluationUtils eu = new EvaluationUtils(); | 
|---|
| 163 |       Classifier classifier = new weka.classifiers.functions.Logistic(); | 
|---|
| 164 |       FastVector predictions = new FastVector(); | 
|---|
| 165 |       for (int i = 0; i < 2; i++) { // Do two runs. | 
|---|
| 166 |         eu.setSeed(i); | 
|---|
| 167 |         predictions.appendElements(eu.getCVPredictions(classifier, inst, 10)); | 
|---|
| 168 |         //System.out.println("\n\n\n"); | 
|---|
| 169 |       } | 
|---|
| 170 |       Instances result = cc.getCurve(predictions); | 
|---|
| 171 |       System.out.println(result); | 
|---|
| 172 |        | 
|---|
| 173 |     } catch (Exception ex) { | 
|---|
| 174 |       ex.printStackTrace(); | 
|---|
| 175 |     } | 
|---|
| 176 |   } | 
|---|
| 177 | } | 
|---|