[4] | 1 | /* |
---|
| 2 | * This program is free software; you can redistribute it and/or modify |
---|
| 3 | * it under the terms of the GNU General Public License as published by |
---|
| 4 | * the Free Software Foundation; either version 2 of the License, or |
---|
| 5 | * (at your option) any later version. |
---|
| 6 | * |
---|
| 7 | * This program is distributed in the hope that it will be useful, |
---|
| 8 | * but WITHOUT ANY WARRANTY; without even the implied warranty of |
---|
| 9 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the |
---|
| 10 | * GNU General Public License for more details. |
---|
| 11 | * |
---|
| 12 | * You should have received a copy of the GNU General Public License |
---|
| 13 | * along with this program; if not, write to the Free Software |
---|
| 14 | * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA. |
---|
| 15 | */ |
---|
| 16 | |
---|
| 17 | /* |
---|
| 18 | * ThresholdCurve.java |
---|
| 19 | * Copyright (C) 2002 University of Waikato, Hamilton, New Zealand |
---|
| 20 | * |
---|
| 21 | */ |
---|
| 22 | |
---|
| 23 | package weka.classifiers.evaluation; |
---|
| 24 | |
---|
| 25 | import weka.classifiers.Classifier; |
---|
| 26 | import weka.classifiers.AbstractClassifier; |
---|
| 27 | import weka.core.Attribute; |
---|
| 28 | import weka.core.FastVector; |
---|
| 29 | import weka.core.Instance; |
---|
| 30 | import weka.core.DenseInstance; |
---|
| 31 | import weka.core.Instances; |
---|
| 32 | import weka.core.RevisionHandler; |
---|
| 33 | import weka.core.RevisionUtils; |
---|
| 34 | import weka.core.Utils; |
---|
| 35 | |
---|
| 36 | /** |
---|
| 37 | * Generates points illustrating prediction tradeoffs that can be obtained |
---|
| 38 | * by varying the threshold value between classes. For example, the typical |
---|
| 39 | * threshold value of 0.5 means the predicted probability of "positive" must be |
---|
| 40 | * higher than 0.5 for the instance to be predicted as "positive". The |
---|
| 41 | * resulting dataset can be used to visualize precision/recall tradeoff, or |
---|
| 42 | * for ROC curve analysis (true positive rate vs false positive rate). |
---|
| 43 | * Weka just varies the threshold on the class probability estimates in each |
---|
| 44 | * case. The Mann Whitney statistic is used to calculate the AUC. |
---|
| 45 | * |
---|
| 46 | * @author Len Trigg (len@reeltwo.com) |
---|
| 47 | * @version $Revision: 5987 $ |
---|
| 48 | */ |
---|
| 49 | public class ThresholdCurve |
---|
| 50 | implements RevisionHandler { |
---|
| 51 | |
---|
| 52 | /** The name of the relation used in threshold curve datasets */ |
---|
| 53 | public static final String RELATION_NAME = "ThresholdCurve"; |
---|
| 54 | |
---|
| 55 | /** attribute name: True Positives */ |
---|
| 56 | public static final String TRUE_POS_NAME = "True Positives"; |
---|
| 57 | /** attribute name: False Negatives */ |
---|
| 58 | public static final String FALSE_NEG_NAME = "False Negatives"; |
---|
| 59 | /** attribute name: False Positives */ |
---|
| 60 | public static final String FALSE_POS_NAME = "False Positives"; |
---|
| 61 | /** attribute name: True Negatives */ |
---|
| 62 | public static final String TRUE_NEG_NAME = "True Negatives"; |
---|
| 63 | /** attribute name: False Positive Rate" */ |
---|
| 64 | public static final String FP_RATE_NAME = "False Positive Rate"; |
---|
| 65 | /** attribute name: True Positive Rate */ |
---|
| 66 | public static final String TP_RATE_NAME = "True Positive Rate"; |
---|
| 67 | /** attribute name: Precision */ |
---|
| 68 | public static final String PRECISION_NAME = "Precision"; |
---|
| 69 | /** attribute name: Recall */ |
---|
| 70 | public static final String RECALL_NAME = "Recall"; |
---|
| 71 | /** attribute name: Fallout */ |
---|
| 72 | public static final String FALLOUT_NAME = "Fallout"; |
---|
| 73 | /** attribute name: FMeasure */ |
---|
| 74 | public static final String FMEASURE_NAME = "FMeasure"; |
---|
| 75 | /** attribute name: Sample Size */ |
---|
| 76 | public static final String SAMPLE_SIZE_NAME = "Sample Size"; |
---|
| 77 | /** attribute name: Lift */ |
---|
| 78 | public static final String LIFT_NAME = "Lift"; |
---|
| 79 | /** attribute name: Threshold */ |
---|
| 80 | public static final String THRESHOLD_NAME = "Threshold"; |
---|
| 81 | |
---|
| 82 | /** |
---|
| 83 | * Calculates the performance stats for the default class and return |
---|
| 84 | * results as a set of Instances. The |
---|
| 85 | * structure of these Instances is as follows:<p> <ul> |
---|
| 86 | * <li> <b>True Positives </b> |
---|
| 87 | * <li> <b>False Negatives</b> |
---|
| 88 | * <li> <b>False Positives</b> |
---|
| 89 | * <li> <b>True Negatives</b> |
---|
| 90 | * <li> <b>False Positive Rate</b> |
---|
| 91 | * <li> <b>True Positive Rate</b> |
---|
| 92 | * <li> <b>Precision</b> |
---|
| 93 | * <li> <b>Recall</b> |
---|
| 94 | * <li> <b>Fallout</b> |
---|
| 95 | * <li> <b>Threshold</b> contains the probability threshold that gives |
---|
| 96 | * rise to the previous performance values. |
---|
| 97 | * </ul> <p> |
---|
| 98 | * For the definitions of these measures, see TwoClassStats <p> |
---|
| 99 | * |
---|
| 100 | * @see TwoClassStats |
---|
| 101 | * @param predictions the predictions to base the curve on |
---|
| 102 | * @return datapoints as a set of instances, null if no predictions |
---|
| 103 | * have been made. |
---|
| 104 | */ |
---|
| 105 | public Instances getCurve(FastVector predictions) { |
---|
| 106 | |
---|
| 107 | if (predictions.size() == 0) { |
---|
| 108 | return null; |
---|
| 109 | } |
---|
| 110 | return getCurve(predictions, |
---|
| 111 | ((NominalPrediction)predictions.elementAt(0)) |
---|
| 112 | .distribution().length - 1); |
---|
| 113 | } |
---|
| 114 | |
---|
| 115 | /** |
---|
| 116 | * Calculates the performance stats for the desired class and return |
---|
| 117 | * results as a set of Instances. |
---|
| 118 | * |
---|
| 119 | * @param predictions the predictions to base the curve on |
---|
| 120 | * @param classIndex index of the class of interest. |
---|
| 121 | * @return datapoints as a set of instances. |
---|
| 122 | */ |
---|
| 123 | public Instances getCurve(FastVector predictions, int classIndex) { |
---|
| 124 | |
---|
| 125 | if ((predictions.size() == 0) || |
---|
| 126 | (((NominalPrediction)predictions.elementAt(0)) |
---|
| 127 | .distribution().length <= classIndex)) { |
---|
| 128 | return null; |
---|
| 129 | } |
---|
| 130 | |
---|
| 131 | double totPos = 0, totNeg = 0; |
---|
| 132 | double [] probs = getProbabilities(predictions, classIndex); |
---|
| 133 | |
---|
| 134 | // Get distribution of positive/negatives |
---|
| 135 | for (int i = 0; i < probs.length; i++) { |
---|
| 136 | NominalPrediction pred = (NominalPrediction)predictions.elementAt(i); |
---|
| 137 | if (pred.actual() == Prediction.MISSING_VALUE) { |
---|
| 138 | System.err.println(getClass().getName() |
---|
| 139 | + " Skipping prediction with missing class value"); |
---|
| 140 | continue; |
---|
| 141 | } |
---|
| 142 | if (pred.weight() < 0) { |
---|
| 143 | System.err.println(getClass().getName() |
---|
| 144 | + " Skipping prediction with negative weight"); |
---|
| 145 | continue; |
---|
| 146 | } |
---|
| 147 | if (pred.actual() == classIndex) { |
---|
| 148 | totPos += pred.weight(); |
---|
| 149 | } else { |
---|
| 150 | totNeg += pred.weight(); |
---|
| 151 | } |
---|
| 152 | } |
---|
| 153 | |
---|
| 154 | Instances insts = makeHeader(); |
---|
| 155 | int [] sorted = Utils.sort(probs); |
---|
| 156 | TwoClassStats tc = new TwoClassStats(totPos, totNeg, 0, 0); |
---|
| 157 | double threshold = 0; |
---|
| 158 | double cumulativePos = 0; |
---|
| 159 | double cumulativeNeg = 0; |
---|
| 160 | for (int i = 0; i < sorted.length; i++) { |
---|
| 161 | |
---|
| 162 | if ((i == 0) || (probs[sorted[i]] > threshold)) { |
---|
| 163 | tc.setTruePositive(tc.getTruePositive() - cumulativePos); |
---|
| 164 | tc.setFalseNegative(tc.getFalseNegative() + cumulativePos); |
---|
| 165 | tc.setFalsePositive(tc.getFalsePositive() - cumulativeNeg); |
---|
| 166 | tc.setTrueNegative(tc.getTrueNegative() + cumulativeNeg); |
---|
| 167 | threshold = probs[sorted[i]]; |
---|
| 168 | insts.add(makeInstance(tc, threshold)); |
---|
| 169 | cumulativePos = 0; |
---|
| 170 | cumulativeNeg = 0; |
---|
| 171 | if (i == sorted.length - 1) { |
---|
| 172 | break; |
---|
| 173 | } |
---|
| 174 | } |
---|
| 175 | |
---|
| 176 | NominalPrediction pred = (NominalPrediction)predictions.elementAt(sorted[i]); |
---|
| 177 | |
---|
| 178 | if (pred.actual() == Prediction.MISSING_VALUE) { |
---|
| 179 | System.err.println(getClass().getName() |
---|
| 180 | + " Skipping prediction with missing class value"); |
---|
| 181 | continue; |
---|
| 182 | } |
---|
| 183 | if (pred.weight() < 0) { |
---|
| 184 | System.err.println(getClass().getName() |
---|
| 185 | + " Skipping prediction with negative weight"); |
---|
| 186 | continue; |
---|
| 187 | } |
---|
| 188 | if (pred.actual() == classIndex) { |
---|
| 189 | cumulativePos += pred.weight(); |
---|
| 190 | } else { |
---|
| 191 | cumulativeNeg += pred.weight(); |
---|
| 192 | } |
---|
| 193 | |
---|
| 194 | /* |
---|
| 195 | System.out.println(tc + " " + probs[sorted[i]] |
---|
| 196 | + " " + (pred.actual() == classIndex)); |
---|
| 197 | */ |
---|
| 198 | /*if ((i != (sorted.length - 1)) && |
---|
| 199 | ((i == 0) || |
---|
| 200 | (probs[sorted[i]] != probs[sorted[i - 1]]))) { |
---|
| 201 | insts.add(makeInstance(tc, probs[sorted[i]])); |
---|
| 202 | }*/ |
---|
| 203 | } |
---|
| 204 | return insts; |
---|
| 205 | } |
---|
| 206 | |
---|
| 207 | /** |
---|
| 208 | * Calculates the n point precision result, which is the precision averaged |
---|
| 209 | * over n evenly spaced (w.r.t recall) samples of the curve. |
---|
| 210 | * |
---|
| 211 | * @param tcurve a previously extracted threshold curve Instances. |
---|
| 212 | * @param n the number of points to average over. |
---|
| 213 | * @return the n-point precision. |
---|
| 214 | */ |
---|
| 215 | public static double getNPointPrecision(Instances tcurve, int n) { |
---|
| 216 | |
---|
| 217 | if (!RELATION_NAME.equals(tcurve.relationName()) |
---|
| 218 | || (tcurve.numInstances() == 0)) { |
---|
| 219 | return Double.NaN; |
---|
| 220 | } |
---|
| 221 | int recallInd = tcurve.attribute(RECALL_NAME).index(); |
---|
| 222 | int precisInd = tcurve.attribute(PRECISION_NAME).index(); |
---|
| 223 | double [] recallVals = tcurve.attributeToDoubleArray(recallInd); |
---|
| 224 | int [] sorted = Utils.sort(recallVals); |
---|
| 225 | double isize = 1.0 / (n - 1); |
---|
| 226 | double psum = 0; |
---|
| 227 | for (int i = 0; i < n; i++) { |
---|
| 228 | int pos = binarySearch(sorted, recallVals, i * isize); |
---|
| 229 | double recall = recallVals[sorted[pos]]; |
---|
| 230 | double precis = tcurve.instance(sorted[pos]).value(precisInd); |
---|
| 231 | /* |
---|
| 232 | System.err.println("Point " + (i + 1) + ": i=" + pos |
---|
| 233 | + " r=" + (i * isize) |
---|
| 234 | + " p'=" + precis |
---|
| 235 | + " r'=" + recall); |
---|
| 236 | */ |
---|
| 237 | // interpolate figures for non-endpoints |
---|
| 238 | while ((pos != 0) && (pos < sorted.length - 1)) { |
---|
| 239 | pos++; |
---|
| 240 | double recall2 = recallVals[sorted[pos]]; |
---|
| 241 | if (recall2 != recall) { |
---|
| 242 | double precis2 = tcurve.instance(sorted[pos]).value(precisInd); |
---|
| 243 | double slope = (precis2 - precis) / (recall2 - recall); |
---|
| 244 | double offset = precis - recall * slope; |
---|
| 245 | precis = isize * i * slope + offset; |
---|
| 246 | /* |
---|
| 247 | System.err.println("Point2 " + (i + 1) + ": i=" + pos |
---|
| 248 | + " r=" + (i * isize) |
---|
| 249 | + " p'=" + precis2 |
---|
| 250 | + " r'=" + recall2 |
---|
| 251 | + " p''=" + precis); |
---|
| 252 | */ |
---|
| 253 | break; |
---|
| 254 | } |
---|
| 255 | } |
---|
| 256 | psum += precis; |
---|
| 257 | } |
---|
| 258 | return psum / n; |
---|
| 259 | } |
---|
| 260 | |
---|
| 261 | /** |
---|
| 262 | * Calculates the area under the ROC curve as the Wilcoxon-Mann-Whitney statistic. |
---|
| 263 | * |
---|
| 264 | * @param tcurve a previously extracted threshold curve Instances. |
---|
| 265 | * @return the ROC area, or Double.NaN if you don't pass in |
---|
| 266 | * a ThresholdCurve generated Instances. |
---|
| 267 | */ |
---|
| 268 | public static double getROCArea(Instances tcurve) { |
---|
| 269 | |
---|
| 270 | final int n = tcurve.numInstances(); |
---|
| 271 | if (!RELATION_NAME.equals(tcurve.relationName()) |
---|
| 272 | || (n == 0)) { |
---|
| 273 | return Double.NaN; |
---|
| 274 | } |
---|
| 275 | final int tpInd = tcurve.attribute(TRUE_POS_NAME).index(); |
---|
| 276 | final int fpInd = tcurve.attribute(FALSE_POS_NAME).index(); |
---|
| 277 | final double [] tpVals = tcurve.attributeToDoubleArray(tpInd); |
---|
| 278 | final double [] fpVals = tcurve.attributeToDoubleArray(fpInd); |
---|
| 279 | |
---|
| 280 | double area = 0.0, cumNeg = 0.0; |
---|
| 281 | final double totalPos = tpVals[0]; |
---|
| 282 | final double totalNeg = fpVals[0]; |
---|
| 283 | for (int i = 0; i < n; i++) { |
---|
| 284 | double cip, cin; |
---|
| 285 | if (i < n - 1) { |
---|
| 286 | cip = tpVals[i] - tpVals[i + 1]; |
---|
| 287 | cin = fpVals[i] - fpVals[i + 1]; |
---|
| 288 | } else { |
---|
| 289 | cip = tpVals[n - 1]; |
---|
| 290 | cin = fpVals[n - 1]; |
---|
| 291 | } |
---|
| 292 | area += cip * (cumNeg + (0.5 * cin)); |
---|
| 293 | cumNeg += cin; |
---|
| 294 | } |
---|
| 295 | area /= (totalNeg * totalPos); |
---|
| 296 | |
---|
| 297 | return area; |
---|
| 298 | } |
---|
| 299 | |
---|
| 300 | /** |
---|
| 301 | * Gets the index of the instance with the closest threshold value to the |
---|
| 302 | * desired target |
---|
| 303 | * |
---|
| 304 | * @param tcurve a set of instances that have been generated by this class |
---|
| 305 | * @param threshold the target threshold |
---|
| 306 | * @return the index of the instance that has threshold closest to |
---|
| 307 | * the target, or -1 if this could not be found (i.e. no data, or |
---|
| 308 | * bad threshold target) |
---|
| 309 | */ |
---|
| 310 | public static int getThresholdInstance(Instances tcurve, double threshold) { |
---|
| 311 | |
---|
| 312 | if (!RELATION_NAME.equals(tcurve.relationName()) |
---|
| 313 | || (tcurve.numInstances() == 0) |
---|
| 314 | || (threshold < 0) |
---|
| 315 | || (threshold > 1.0)) { |
---|
| 316 | return -1; |
---|
| 317 | } |
---|
| 318 | if (tcurve.numInstances() == 1) { |
---|
| 319 | return 0; |
---|
| 320 | } |
---|
| 321 | double [] tvals = tcurve.attributeToDoubleArray(tcurve.numAttributes() - 1); |
---|
| 322 | int [] sorted = Utils.sort(tvals); |
---|
| 323 | return binarySearch(sorted, tvals, threshold); |
---|
| 324 | } |
---|
| 325 | |
---|
| 326 | /** |
---|
| 327 | * performs a binary search |
---|
| 328 | * |
---|
| 329 | * @param index the indices |
---|
| 330 | * @param vals the values |
---|
| 331 | * @param target the target to look for |
---|
| 332 | * @return the index of the target |
---|
| 333 | */ |
---|
| 334 | private static int binarySearch(int [] index, double [] vals, double target) { |
---|
| 335 | |
---|
| 336 | int lo = 0, hi = index.length - 1; |
---|
| 337 | while (hi - lo > 1) { |
---|
| 338 | int mid = lo + (hi - lo) / 2; |
---|
| 339 | double midval = vals[index[mid]]; |
---|
| 340 | if (target > midval) { |
---|
| 341 | lo = mid; |
---|
| 342 | } else if (target < midval) { |
---|
| 343 | hi = mid; |
---|
| 344 | } else { |
---|
| 345 | while ((mid > 0) && (vals[index[mid - 1]] == target)) { |
---|
| 346 | mid --; |
---|
| 347 | } |
---|
| 348 | return mid; |
---|
| 349 | } |
---|
| 350 | } |
---|
| 351 | return lo; |
---|
| 352 | } |
---|
| 353 | |
---|
| 354 | /** |
---|
| 355 | * |
---|
| 356 | * @param predictions the predictions to use |
---|
| 357 | * @param classIndex the class index |
---|
| 358 | * @return the probabilities |
---|
| 359 | */ |
---|
| 360 | private double [] getProbabilities(FastVector predictions, int classIndex) { |
---|
| 361 | |
---|
| 362 | // sort by predicted probability of the desired class. |
---|
| 363 | double [] probs = new double [predictions.size()]; |
---|
| 364 | for (int i = 0; i < probs.length; i++) { |
---|
| 365 | NominalPrediction pred = (NominalPrediction)predictions.elementAt(i); |
---|
| 366 | probs[i] = pred.distribution()[classIndex]; |
---|
| 367 | } |
---|
| 368 | return probs; |
---|
| 369 | } |
---|
| 370 | |
---|
| 371 | /** |
---|
| 372 | * generates the header |
---|
| 373 | * |
---|
| 374 | * @return the header |
---|
| 375 | */ |
---|
| 376 | private Instances makeHeader() { |
---|
| 377 | |
---|
| 378 | FastVector fv = new FastVector(); |
---|
| 379 | fv.addElement(new Attribute(TRUE_POS_NAME)); |
---|
| 380 | fv.addElement(new Attribute(FALSE_NEG_NAME)); |
---|
| 381 | fv.addElement(new Attribute(FALSE_POS_NAME)); |
---|
| 382 | fv.addElement(new Attribute(TRUE_NEG_NAME)); |
---|
| 383 | fv.addElement(new Attribute(FP_RATE_NAME)); |
---|
| 384 | fv.addElement(new Attribute(TP_RATE_NAME)); |
---|
| 385 | fv.addElement(new Attribute(PRECISION_NAME)); |
---|
| 386 | fv.addElement(new Attribute(RECALL_NAME)); |
---|
| 387 | fv.addElement(new Attribute(FALLOUT_NAME)); |
---|
| 388 | fv.addElement(new Attribute(FMEASURE_NAME)); |
---|
| 389 | fv.addElement(new Attribute(SAMPLE_SIZE_NAME)); |
---|
| 390 | fv.addElement(new Attribute(LIFT_NAME)); |
---|
| 391 | fv.addElement(new Attribute(THRESHOLD_NAME)); |
---|
| 392 | return new Instances(RELATION_NAME, fv, 100); |
---|
| 393 | } |
---|
| 394 | |
---|
| 395 | /** |
---|
| 396 | * generates an instance out of the given data |
---|
| 397 | * |
---|
| 398 | * @param tc the statistics |
---|
| 399 | * @param prob the probability |
---|
| 400 | * @return the generated instance |
---|
| 401 | */ |
---|
| 402 | private Instance makeInstance(TwoClassStats tc, double prob) { |
---|
| 403 | |
---|
| 404 | int count = 0; |
---|
| 405 | double [] vals = new double[13]; |
---|
| 406 | vals[count++] = tc.getTruePositive(); |
---|
| 407 | vals[count++] = tc.getFalseNegative(); |
---|
| 408 | vals[count++] = tc.getFalsePositive(); |
---|
| 409 | vals[count++] = tc.getTrueNegative(); |
---|
| 410 | vals[count++] = tc.getFalsePositiveRate(); |
---|
| 411 | vals[count++] = tc.getTruePositiveRate(); |
---|
| 412 | vals[count++] = tc.getPrecision(); |
---|
| 413 | vals[count++] = tc.getRecall(); |
---|
| 414 | vals[count++] = tc.getFallout(); |
---|
| 415 | vals[count++] = tc.getFMeasure(); |
---|
| 416 | double ss = (tc.getTruePositive() + tc.getFalsePositive()) / |
---|
| 417 | (tc.getTruePositive() + tc.getFalsePositive() + tc.getTrueNegative() + tc.getFalseNegative()); |
---|
| 418 | vals[count++] = ss; |
---|
| 419 | double expectedByChance = (ss * (tc.getTruePositive() + tc.getFalseNegative())); |
---|
| 420 | if (expectedByChance < 1) { |
---|
| 421 | vals[count++] = Utils.missingValue(); |
---|
| 422 | } else { |
---|
| 423 | vals[count++] = tc.getTruePositive() / expectedByChance; |
---|
| 424 | |
---|
| 425 | } |
---|
| 426 | vals[count++] = prob; |
---|
| 427 | return new DenseInstance(1.0, vals); |
---|
| 428 | } |
---|
| 429 | |
---|
| 430 | /** |
---|
| 431 | * Returns the revision string. |
---|
| 432 | * |
---|
| 433 | * @return the revision |
---|
| 434 | */ |
---|
| 435 | public String getRevision() { |
---|
| 436 | return RevisionUtils.extract("$Revision: 5987 $"); |
---|
| 437 | } |
---|
| 438 | |
---|
| 439 | /** |
---|
| 440 | * Tests the ThresholdCurve generation from the command line. |
---|
| 441 | * The classifier is currently hardcoded. Pipe in an arff file. |
---|
| 442 | * |
---|
| 443 | * @param args currently ignored |
---|
| 444 | */ |
---|
| 445 | public static void main(String [] args) { |
---|
| 446 | |
---|
| 447 | try { |
---|
| 448 | |
---|
| 449 | Instances inst = new Instances(new java.io.InputStreamReader(System.in)); |
---|
| 450 | if (false) { |
---|
| 451 | System.out.println(ThresholdCurve.getNPointPrecision(inst, 11)); |
---|
| 452 | } else { |
---|
| 453 | inst.setClassIndex(inst.numAttributes() - 1); |
---|
| 454 | ThresholdCurve tc = new ThresholdCurve(); |
---|
| 455 | EvaluationUtils eu = new EvaluationUtils(); |
---|
| 456 | Classifier classifier = new weka.classifiers.functions.Logistic(); |
---|
| 457 | FastVector predictions = new FastVector(); |
---|
| 458 | for (int i = 0; i < 2; i++) { // Do two runs. |
---|
| 459 | eu.setSeed(i); |
---|
| 460 | predictions.appendElements(eu.getCVPredictions(classifier, inst, 10)); |
---|
| 461 | //System.out.println("\n\n\n"); |
---|
| 462 | } |
---|
| 463 | Instances result = tc.getCurve(predictions); |
---|
| 464 | System.out.println(result); |
---|
| 465 | } |
---|
| 466 | } catch (Exception ex) { |
---|
| 467 | ex.printStackTrace(); |
---|
| 468 | } |
---|
| 469 | } |
---|
| 470 | } |
---|