source: src/main/java/weka/classifiers/Evaluation.java @ 12

Last change on this file since 12 was 4, checked in by gnappo, 14 years ago

Import di weka.

File size: 121.6 KB
RevLine 
[4]1/*
2 *    This program is free software; you can redistribute it and/or modify
3 *    it under the terms of the GNU General Public License as published by
4 *    the Free Software Foundation; either version 2 of the License, or
5 *    (at your option) any later version.
6 *
7 *    This program is distributed in the hope that it will be useful,
8 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
9 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
10 *    GNU General Public License for more details.
11 *
12 *    You should have received a copy of the GNU General Public License
13 *    along with this program; if not, write to the Free Software
14 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
15 */
16
17/*
18 *    Evaluation.java
19 *    Copyright (C) 1999 University of Waikato, Hamilton, New Zealand
20 *
21 */
22
23package weka.classifiers;
24
25import weka.classifiers.evaluation.NominalPrediction;
26import weka.classifiers.evaluation.NumericPrediction;
27import weka.classifiers.evaluation.ThresholdCurve;
28import weka.classifiers.evaluation.output.prediction.AbstractOutput;
29import weka.classifiers.evaluation.output.prediction.PlainText;
30import weka.classifiers.pmml.consumer.PMMLClassifier;
31import weka.classifiers.xml.XMLClassifier;
32import weka.core.Drawable;
33import weka.core.FastVector;
34import weka.core.Instance;
35import weka.core.Instances;
36import weka.core.Option;
37import weka.core.OptionHandler;
38import weka.core.RevisionHandler;
39import weka.core.RevisionUtils;
40import weka.core.Summarizable;
41import weka.core.Utils;
42import weka.core.Version;
43import weka.core.converters.ConverterUtils.DataSink;
44import weka.core.converters.ConverterUtils.DataSource;
45import weka.core.pmml.PMMLFactory;
46import weka.core.pmml.PMMLModel;
47import weka.core.xml.KOML;
48import weka.core.xml.XMLOptions;
49import weka.core.xml.XMLSerialization;
50import weka.estimators.UnivariateKernelEstimator;
51
52import java.beans.BeanInfo;
53import java.beans.Introspector;
54import java.beans.MethodDescriptor;
55import java.io.BufferedInputStream;
56import java.io.BufferedOutputStream;
57import java.io.BufferedReader;
58import java.io.FileInputStream;
59import java.io.FileOutputStream;
60import java.io.FileReader;
61import java.io.InputStream;
62import java.io.ObjectInputStream;
63import java.io.ObjectOutputStream;
64import java.io.OutputStream;
65import java.io.Reader;
66import java.lang.reflect.Method;
67import java.util.Date;
68import java.util.Enumeration;
69import java.util.Random;
70import java.util.zip.GZIPInputStream;
71import java.util.zip.GZIPOutputStream;
72
73/**
74 * Class for evaluating machine learning models. <p/>
75 *
76 * ------------------------------------------------------------------- <p/>
77 *
78 * General options when evaluating a learning scheme from the command-line: <p/>
79 *
80 * -t filename <br/>
81 * Name of the file with the training data. (required) <p/>
82 *
83 * -T filename <br/>
84 * Name of the file with the test data. If missing a cross-validation
85 * is performed. <p/>
86 *
87 * -c index <br/>
88 * Index of the class attribute (1, 2, ...; default: last). <p/>
89 *
90 * -x number <br/>
91 * The number of folds for the cross-validation (default: 10). <p/>
92 *
93 * -no-cv <br/>
94 * No cross validation.  If no test file is provided, no evaluation
95 * is done. <p/>
96 *
97 * -split-percentage percentage <br/>
98 * Sets the percentage for the train/test set split, e.g., 66. <p/>
99 *
100 * -preserve-order <br/>
101 * Preserves the order in the percentage split instead of randomizing
102 * the data first with the seed value ('-s'). <p/>
103 *
104 * -s seed <br/>
105 * Random number seed for the cross-validation and percentage split
106 * (default: 1). <p/>
107 *
108 * -m filename <br/>
109 * The name of a file containing a cost matrix. <p/>
110 *
111 * -l filename <br/>
112 * Loads classifier from the given file. In case the filename ends with ".xml",
113 * a PMML file is loaded or, if that fails, options are loaded from XML. <p/>
114 *
115 * -d filename <br/>
116 * Saves classifier built from the training data into the given file. In case
117 * the filename ends with ".xml", the options are saved as XML, not the model. <p/>
118 *
119 * -v <br/>
120 * Outputs no statistics for the training data. <p/>
121 *
122 * -o <br/>
123 * Outputs statistics only, not the classifier. <p/>
124 *
125 * -i <br/>
126 * Outputs information-retrieval statistics per class. <p/>
127 *
128 * -k <br/>
129 * Outputs information-theoretic statistics. <p/>
130 *
131 * -classifications "weka.classifiers.evaluation.output.prediction.AbstractOutput + options" <br/>
132 * Uses the specified class for generating the classification output.
133 * E.g.: weka.classifiers.evaluation.output.prediction.PlainText
134 * or  : weka.classifiers.evaluation.output.prediction.CSV <p/>
135 *
136 * -p range <br/>
137 * Outputs predictions for test instances (or the train instances if no test
138 * instances provided and -no-cv is used), along with the attributes in the specified range
139 * (and nothing else). Use '-p 0' if no attributes are desired. <p/>
140 * Deprecated: use "-classifications ..." instead. <p/>
141 *
142 * -distribution <br/>
143 * Outputs the distribution instead of only the prediction
144 * in conjunction with the '-p' option (only nominal classes). <p/>
145 * Deprecated: use "-classifications ..." instead. <p/>
146 *
147 * -r <br/>
148 * Outputs cumulative margin distribution (and nothing else). <p/>
149 *
150 * -g <br/>
151 * Only for classifiers that implement "Graphable." Outputs
152 * the graph representation of the classifier (and nothing
153 * else). <p/>
154 *
155 * -xml filename | xml-string <br/>
156 * Retrieves the options from the XML-data instead of the command line. <p/>
157 *
158 * -threshold-file file <br/>
159 * The file to save the threshold data to.
160 * The format is determined by the extension, e.g., '.arff' for ARFF
161 * format or '.csv' for CSV. <p/>
162 *
163 * -threshold-label label <br/>
164 * The class label to determine the threshold data for
165 * (default is the first label) <p/>
166 *
167 * ------------------------------------------------------------------- <p/>
168 *
169 * Example usage as the main of a classifier (called FunkyClassifier):
170 * <code> <pre>
171 * public static void main(String [] args) {
172 *   runClassifier(new FunkyClassifier(), args);
173 * }
174 * </pre> </code>
175 * <p/>
176 *
177 * ------------------------------------------------------------------ <p/>
178 *
179 * Example usage from within an application:
180 * <code> <pre>
181 * Instances trainInstances = ... instances got from somewhere
182 * Instances testInstances = ... instances got from somewhere
183 * Classifier scheme = ... scheme got from somewhere
184 *
185 * Evaluation evaluation = new Evaluation(trainInstances);
186 * evaluation.evaluateModel(scheme, testInstances);
187 * System.out.println(evaluation.toSummaryString());
188 * </pre> </code>
189 *
190 *
191 * @author   Eibe Frank (eibe@cs.waikato.ac.nz)
192 * @author   Len Trigg (trigg@cs.waikato.ac.nz)
193 * @version  $Revision: 6041 $
194 */
195public class Evaluation
196  implements Summarizable, RevisionHandler {
197
198  /** The number of classes. */
199  protected int m_NumClasses;
200
201  /** The number of folds for a cross-validation. */
202  protected int m_NumFolds;
203
204  /** The weight of all incorrectly classified instances. */
205  protected double m_Incorrect;
206
207  /** The weight of all correctly classified instances. */
208  protected double m_Correct;
209
210  /** The weight of all unclassified instances. */
211  protected double m_Unclassified;
212
213  /*** The weight of all instances that had no class assigned to them. */
214  protected double m_MissingClass;
215
216  /** The weight of all instances that had a class assigned to them. */
217  protected double m_WithClass;
218
219  /** Array for storing the confusion matrix. */
220  protected double [][] m_ConfusionMatrix;
221
222  /** The names of the classes. */
223  protected String [] m_ClassNames;
224
225  /** Is the class nominal or numeric? */
226  protected boolean m_ClassIsNominal;
227
228  /** The prior probabilities of the classes. */
229  protected double [] m_ClassPriors;
230
231  /** The sum of counts for priors. */
232  protected double m_ClassPriorsSum;
233
234  /** The cost matrix (if given). */
235  protected CostMatrix m_CostMatrix;
236
237  /** The total cost of predictions (includes instance weights). */
238  protected double m_TotalCost;
239
240  /** Sum of errors. */
241  protected double m_SumErr;
242
243  /** Sum of absolute errors. */
244  protected double m_SumAbsErr;
245
246  /** Sum of squared errors. */
247  protected double m_SumSqrErr;
248
249  /** Sum of class values. */
250  protected double m_SumClass;
251
252  /** Sum of squared class values. */
253  protected double m_SumSqrClass;
254
255  /*** Sum of predicted values. */
256  protected double m_SumPredicted;
257
258  /** Sum of squared predicted values. */
259  protected double m_SumSqrPredicted;
260
261  /** Sum of predicted * class values. */
262  protected double m_SumClassPredicted;
263
264  /** Sum of absolute errors of the prior. */
265  protected double m_SumPriorAbsErr;
266
267  /** Sum of absolute errors of the prior. */
268  protected double m_SumPriorSqrErr;
269
270  /** Total Kononenko & Bratko Information. */
271  protected double m_SumKBInfo;
272
273  /*** Resolution of the margin histogram. */
274  protected static int k_MarginResolution = 500;
275
276  /** Cumulative margin distribution. */
277  protected double m_MarginCounts [];
278
279  /** Number of non-missing class training instances seen. */
280  protected int m_NumTrainClassVals;
281
282  /** Array containing all numeric training class values seen. */
283  protected double [] m_TrainClassVals;
284
285  /** Array containing all numeric training class weights. */
286  protected double [] m_TrainClassWeights;
287
288  /** Numeric class estimator for prior. */
289  protected UnivariateKernelEstimator m_PriorEstimator;
290
291  /** Whether complexity statistics are available. */
292  protected boolean m_ComplexityStatisticsAvailable = true;
293
294  /**
295   * The minimum probablility accepted from an estimator to avoid
296   * taking log(0) in Sf calculations.
297   */
298  protected static final double MIN_SF_PROB = Double.MIN_VALUE;
299
300  /** Total entropy of prior predictions. */
301  protected double m_SumPriorEntropy;
302
303  /** Total entropy of scheme predictions. */
304  protected double m_SumSchemeEntropy;
305
306  /** Whether coverage statistics are available. */
307  protected boolean m_CoverageStatisticsAvailable = true;
308
309  /**  The confidence level used for coverage statistics. */
310  protected double m_ConfLevel = 0.95;
311
312  /** Total size of predicted regions at the given confidence level. */
313  protected double m_TotalSizeOfRegions;
314
315  /** Total coverage of test cases at the given confidence level. */
316  protected double m_TotalCoverage;
317
318  /** Minimum target value. */
319  protected double m_MinTarget;
320
321  /** Maximum target value. */
322  protected double m_MaxTarget;
323
324  /** The list of predictions that have been generated (for computing AUC). */
325  protected FastVector m_Predictions;
326
327  /** enables/disables the use of priors, e.g., if no training set is
328   * present in case of de-serialized schemes. */
329  protected boolean m_NoPriors = false;
330
331  /** The header of the training set. */
332  protected Instances m_Header;
333
334  /**
335   * Initializes all the counters for the evaluation.
336   * Use <code>useNoPriors()</code> if the dataset is the test set and you
337   * can't initialize with the priors from the training set via
338   * <code>setPriors(Instances)</code>.
339   *
340   * @param data        set of training instances, to get some header
341   *                    information and prior class distribution information
342   * @throws Exception  if the class is not defined
343   * @see               #useNoPriors()
344   * @see               #setPriors(Instances)
345   */
346  public Evaluation(Instances data) throws Exception {
347
348    this(data, null);
349  }
350
351  /**
352   * Initializes all the counters for the evaluation and also takes a
353   * cost matrix as parameter.
354   * Use <code>useNoPriors()</code> if the dataset is the test set and you
355   * can't initialize with the priors from the training set via
356   * <code>setPriors(Instances)</code>.
357   *
358   * @param data        set of training instances, to get some header
359   *                    information and prior class distribution information
360   * @param costMatrix  the cost matrix---if null, default costs will be used
361   * @throws Exception  if cost matrix is not compatible with
362   *                    data, the class is not defined or the class is numeric
363   * @see               #useNoPriors()
364   * @see               #setPriors(Instances)
365   */
366  public Evaluation(Instances data, CostMatrix costMatrix)
367  throws Exception {
368
369    m_Header = new Instances(data, 0);
370    m_NumClasses = data.numClasses();
371    m_NumFolds = 1;
372    m_ClassIsNominal = data.classAttribute().isNominal();
373
374    if (m_ClassIsNominal) {
375      m_ConfusionMatrix = new double [m_NumClasses][m_NumClasses];
376      m_ClassNames = new String [m_NumClasses];
377      for(int i = 0; i < m_NumClasses; i++) {
378        m_ClassNames[i] = data.classAttribute().value(i);
379      }
380    }
381    m_CostMatrix = costMatrix;
382    if (m_CostMatrix != null) {
383      if (!m_ClassIsNominal) {
384        throw new Exception("Class has to be nominal if cost matrix given!");
385      }
386      if (m_CostMatrix.size() != m_NumClasses) {
387        throw new Exception("Cost matrix not compatible with data!");
388      }
389    }
390    m_ClassPriors = new double [m_NumClasses];
391    setPriors(data);
392    m_MarginCounts = new double [k_MarginResolution + 1];
393  }
394
395  /**
396   * Returns the header of the underlying dataset.
397   *
398   * @return            the header information
399   */
400  public Instances getHeader() {
401    return m_Header;
402  }
403
404  /**
405   * Returns the area under ROC for those predictions that have been collected
406   * in the evaluateClassifier(Classifier, Instances) method. Returns
407   * Utils.missingValue() if the area is not available.
408   *
409   * @param classIndex the index of the class to consider as "positive"
410   * @return the area under the ROC curve or not a number
411   */
412  public double areaUnderROC(int classIndex) {
413
414    // Check if any predictions have been collected
415    if (m_Predictions == null) {
416      return Utils.missingValue();
417    } else {
418      ThresholdCurve tc = new ThresholdCurve();
419      Instances result = tc.getCurve(m_Predictions, classIndex);
420      return ThresholdCurve.getROCArea(result);
421    }
422  }
423
424  /**
425   * Calculates the weighted (by class size) AUC.
426   *
427   * @return the weighted AUC.
428   */
429  public double weightedAreaUnderROC() {
430    double[] classCounts = new double[m_NumClasses];
431    double classCountSum = 0;
432
433    for (int i = 0; i < m_NumClasses; i++) {
434      for (int j = 0; j < m_NumClasses; j++) {
435        classCounts[i] += m_ConfusionMatrix[i][j];
436      }
437      classCountSum += classCounts[i];
438    }
439
440    double aucTotal = 0;
441    for(int i = 0; i < m_NumClasses; i++) {
442      double temp = areaUnderROC(i);
443      if (!Utils.isMissingValue(temp)) {
444        aucTotal += (temp * classCounts[i]);
445      }
446    }
447
448    return aucTotal / classCountSum;
449  }
450
451  /**
452   * Returns a copy of the confusion matrix.
453   *
454   * @return a copy of the confusion matrix as a two-dimensional array
455   */
456  public double[][] confusionMatrix() {
457
458    double[][] newMatrix = new double[m_ConfusionMatrix.length][0];
459
460    for (int i = 0; i < m_ConfusionMatrix.length; i++) {
461      newMatrix[i] = new double[m_ConfusionMatrix[i].length];
462      System.arraycopy(m_ConfusionMatrix[i], 0, newMatrix[i], 0,
463          m_ConfusionMatrix[i].length);
464    }
465    return newMatrix;
466  }
467
468  /**
469   * Performs a (stratified if class is nominal) cross-validation
470   * for a classifier on a set of instances. Now performs
471   * a deep copy of the classifier before each call to
472   * buildClassifier() (just in case the classifier is not
473   * initialized properly).
474   *
475   * @param classifier the classifier with any options set.
476   * @param data the data on which the cross-validation is to be
477   * performed
478   * @param numFolds the number of folds for the cross-validation
479   * @param random random number generator for randomization
480   * @param forPredictionsPrinting varargs parameter that, if supplied, is
481   * expected to hold a weka.classifiers.evaluation.output.prediction.AbstractOutput
482   * object
483   * @throws Exception if a classifier could not be generated
484   * successfully or the class is not defined
485   */
486  public void crossValidateModel(Classifier classifier,
487                                 Instances data, int numFolds, Random random,
488                                 Object... forPredictionsPrinting)
489  throws Exception {
490
491    // Make a copy of the data we can reorder
492    data = new Instances(data);
493    data.randomize(random);
494    if (data.classAttribute().isNominal()) {
495      data.stratify(numFolds);
496    }
497
498    // We assume that the first element is a
499    // weka.classifiers.evaluation.output.prediction.AbstractOutput object
500    AbstractOutput classificationOutput = null;
501    if (forPredictionsPrinting.length > 0) {
502      // print the header first
503      classificationOutput = (AbstractOutput) forPredictionsPrinting[0];
504      classificationOutput.setHeader(data);
505      classificationOutput.printHeader();
506    }
507
508    // Do the folds
509    for (int i = 0; i < numFolds; i++) {
510      Instances train = data.trainCV(numFolds, i, random);
511      setPriors(train);
512      Classifier copiedClassifier = AbstractClassifier.makeCopy(classifier);
513      copiedClassifier.buildClassifier(train);
514      Instances test = data.testCV(numFolds, i);
515      evaluateModel(copiedClassifier, test, forPredictionsPrinting);
516    }
517    m_NumFolds = numFolds;
518
519    if (classificationOutput != null)
520      classificationOutput.printFooter();
521  }
522
523  /**
524   * Performs a (stratified if class is nominal) cross-validation
525   * for a classifier on a set of instances.
526   *
527   * @param classifierString a string naming the class of the classifier
528   * @param data the data on which the cross-validation is to be
529   * performed
530   * @param numFolds the number of folds for the cross-validation
531   * @param options the options to the classifier. Any options
532   * @param random the random number generator for randomizing the data
533   * accepted by the classifier will be removed from this array.
534   * @throws Exception if a classifier could not be generated
535   * successfully or the class is not defined
536   */
537  public void crossValidateModel(String classifierString,
538      Instances data, int numFolds,
539      String[] options, Random random)
540    throws Exception {
541
542    crossValidateModel(AbstractClassifier.forName(classifierString, options),
543        data, numFolds, random);
544  }
545
546  /**
547   * Evaluates a classifier with the options given in an array of
548   * strings. <p/>
549   *
550   * Valid options are: <p/>
551   *
552   * -t filename <br/>
553   * Name of the file with the training data. (required) <p/>
554   *
555   * -T filename <br/>
556   * Name of the file with the test data. If missing a cross-validation
557   * is performed. <p/>
558   *
559   * -c index <br/>
560   * Index of the class attribute (1, 2, ...; default: last). <p/>
561   *
562   * -x number <br/>
563   * The number of folds for the cross-validation (default: 10). <p/>
564   *
565   * -no-cv <br/>
566   * No cross validation.  If no test file is provided, no evaluation
567   * is done. <p/>
568   *
569   * -split-percentage percentage <br/>
570   * Sets the percentage for the train/test set split, e.g., 66. <p/>
571   *
572   * -preserve-order <br/>
573   * Preserves the order in the percentage split instead of randomizing
574   * the data first with the seed value ('-s'). <p/>
575   *
576   * -s seed <br/>
577   * Random number seed for the cross-validation and percentage split
578   * (default: 1). <p/>
579   *
580   * -m filename <br/>
581   * The name of a file containing a cost matrix. <p/>
582   *
583   * -l filename <br/>
584   * Loads classifier from the given file. In case the filename ends with
585   * ".xml",a PMML file is loaded or, if that fails, options are loaded from XML. <p/>
586   *
587   * -d filename <br/>
588   * Saves classifier built from the training data into the given file. In case
589   * the filename ends with ".xml" the options are saved XML, not the model. <p/>
590   *
591   * -v <br/>
592   * Outputs no statistics for the training data. <p/>
593   *
594   * -o <br/>
595   * Outputs statistics only, not the classifier. <p/>
596   *
597   * -i <br/>
598   * Outputs detailed information-retrieval statistics per class. <p/>
599   *
600   * -k <br/>
601   * Outputs information-theoretic statistics. <p/>
602   *
603   * -classifications "weka.classifiers.evaluation.output.prediction.AbstractOutput + options" <br/>
604   * Uses the specified class for generating the classification output.
605   * E.g.: weka.classifiers.evaluation.output.prediction.PlainText
606   * or  : weka.classifiers.evaluation.output.prediction.CSV
607   *
608   * -p range <br/>
609   * Outputs predictions for test instances (or the train instances if no test
610   * instances provided and -no-cv is used), along with the attributes in the specified range
611   * (and nothing else). Use '-p 0' if no attributes are desired. <p/>
612   * Deprecated: use "-classifications ..." instead. <p/>
613   *
614   * -distribution <br/>
615   * Outputs the distribution instead of only the prediction
616   * in conjunction with the '-p' option (only nominal classes). <p/>
617   * Deprecated: use "-classifications ..." instead. <p/>
618   *
619   * -r <br/>
620   * Outputs cumulative margin distribution (and nothing else). <p/>
621   *
622   * -g <br/>
623   * Only for classifiers that implement "Graphable." Outputs
624   * the graph representation of the classifier (and nothing
625   * else). <p/>
626   *
627   * -xml filename | xml-string <br/>
628   * Retrieves the options from the XML-data instead of the command line. <p/>
629   *
630   * -threshold-file file <br/>
631   * The file to save the threshold data to.
632   * The format is determined by the extensions, e.g., '.arff' for ARFF
633   * format or '.csv' for CSV. <p/>
634   *
635   * -threshold-label label <br/>
636   * The class label to determine the threshold data for
637   * (default is the first label) <p/>
638   *
639   * @param classifierString class of machine learning classifier as a string
640   * @param options the array of string containing the options
641   * @throws Exception if model could not be evaluated successfully
642   * @return a string describing the results
643   */
644  public static String evaluateModel(String classifierString,
645      String [] options) throws Exception {
646
647    Classifier classifier;
648
649    // Create classifier
650    try {
651      classifier =
652        //  (Classifier)Class.forName(classifierString).newInstance();
653        AbstractClassifier.forName(classifierString, null);
654    } catch (Exception e) {
655      throw new Exception("Can't find class with name "
656          + classifierString + '.');
657    }
658    return evaluateModel(classifier, options);
659  }
660
661  /**
662   * A test method for this class. Just extracts the first command line
663   * argument as a classifier class name and calls evaluateModel.
664   * @param args an array of command line arguments, the first of which
665   * must be the class name of a classifier.
666   */
667  public static void main(String [] args) {
668
669    try {
670      if (args.length == 0) {
671        throw new Exception("The first argument must be the class name"
672            + " of a classifier");
673      }
674      String classifier = args[0];
675      args[0] = "";
676      System.out.println(evaluateModel(classifier, args));
677    } catch (Exception ex) {
678      ex.printStackTrace();
679      System.err.println(ex.getMessage());
680    }
681  }
682
683  /**
684   * Evaluates a classifier with the options given in an array of
685   * strings. <p/>
686   *
687   * Valid options are: <p/>
688   *
689   * -t name of training file <br/>
690   * Name of the file with the training data. (required) <p/>
691   *
692   * -T name of test file <br/>
693   * Name of the file with the test data. If missing a cross-validation
694   * is performed. <p/>
695   *
696   * -c class index <br/>
697   * Index of the class attribute (1, 2, ...; default: last). <p/>
698   *
699   * -x number of folds <br/>
700   * The number of folds for the cross-validation (default: 10). <p/>
701   *
702   * -no-cv <br/>
703   * No cross validation.  If no test file is provided, no evaluation
704   * is done. <p/>
705   *
706   * -split-percentage percentage <br/>
707   * Sets the percentage for the train/test set split, e.g., 66. <p/>
708   *
709   * -preserve-order <br/>
710   * Preserves the order in the percentage split instead of randomizing
711   * the data first with the seed value ('-s'). <p/>
712   *
713   * -s seed <br/>
714   * Random number seed for the cross-validation and percentage split
715   * (default: 1). <p/>
716   *
717   * -m file with cost matrix <br/>
718   * The name of a file containing a cost matrix. <p/>
719   *
720   * -l filename <br/>
721   * Loads classifier from the given file. In case the filename ends with
722   * ".xml", a PMML file is loaded or, if that fails, options are loaded from XML. <p/>
723   *
724   * -d filename <br/>
725   * the filename ends with ".xml", the options are saved as XML, not the model. <p/>
726   * the filename ends with ".xml" the options are saved XML, not the model. <p/>
727   *
728   * -v <br/>
729   * Outputs no statistics for the training data. <p/>
730   *
731   * -o <br/>
732   * Outputs statistics only, not the classifier. <p/>
733   *
734   * -i <br/>
735   * Outputs detailed information-retrieval statistics per class. <p/>
736   *
737   * -k <br/>
738   * Outputs information-theoretic statistics. <p/>
739   *
740   * -classifications "weka.classifiers.evaluation.output.prediction.AbstractOutput + options" <br/>
741   * Uses the specified class for generating the classification output.
742   * E.g.: weka.classifiers.evaluation.output.prediction.PlainText
743   * or  : weka.classifiers.evaluation.output.prediction.CSV <p/>
744   *
745   * -p range <br/>
746   * Outputs predictions for test instances (or the train instances if no test
747   * instances provided and -no-cv is used), along with the attributes in the specified range
748   * (and nothing else). Use '-p 0' if no attributes are desired. <p/>
749   * Deprecated: use "-classifications ..." instead. <p/>
750   *
751   * -distribution <br/>
752   * Outputs the distribution instead of only the prediction
753   * in conjunction with the '-p' option (only nominal classes). <p/>
754   * Deprecated: use "-classifications ..." instead. <p/>
755   *
756   * -r <br/>
757   * Outputs cumulative margin distribution (and nothing else). <p/>
758   *
759   * -g <br/>
760   * Only for classifiers that implement "Graphable." Outputs
761   * the graph representation of the classifier (and nothing
762   * else). <p/>
763   *
764   * -xml filename | xml-string <br/>
765   * Retrieves the options from the XML-data instead of the command line. <p/>
766   *
767   * @param classifier machine learning classifier
768   * @param options the array of string containing the options
769   * @throws Exception if model could not be evaluated successfully
770   * @return a string describing the results
771   */
772  public static String evaluateModel(Classifier classifier,
773      String [] options) throws Exception {
774
775    Instances train = null, tempTrain, test = null, template = null;
776    int seed = 1, folds = 10, classIndex = -1;
777    boolean noCrossValidation = false;
778    String trainFileName, testFileName, sourceClass,
779    classIndexString, seedString, foldsString, objectInputFileName,
780    objectOutputFileName;
781    boolean noOutput = false,
782    trainStatistics = true,
783    printMargins = false, printComplexityStatistics = false,
784    printGraph = false, classStatistics = false, printSource = false;
785    StringBuffer text = new StringBuffer();
786    DataSource trainSource = null, testSource = null;
787    ObjectInputStream objectInputStream = null;
788    BufferedInputStream xmlInputStream = null;
789    CostMatrix costMatrix = null;
790    StringBuffer schemeOptionsText = null;
791    long trainTimeStart = 0, trainTimeElapsed = 0,
792    testTimeStart = 0, testTimeElapsed = 0;
793    String xml = "";
794    String[] optionsTmp = null;
795    Classifier classifierBackup;
796    Classifier classifierClassifications = null;
797    int actualClassIndex = -1;  // 0-based class index
798    String splitPercentageString = "";
799    int splitPercentage = -1;
800    boolean preserveOrder = false;
801    boolean trainSetPresent = false;
802    boolean testSetPresent = false;
803    String thresholdFile;
804    String thresholdLabel;
805    StringBuffer predsBuff = null; // predictions from cross-validation
806    AbstractOutput classificationOutput = null;
807
808    // help requested?
809    if (Utils.getFlag("h", options) || Utils.getFlag("help", options)) {
810
811      // global info requested as well?
812      boolean globalInfo = Utils.getFlag("synopsis", options) ||
813        Utils.getFlag("info", options);
814
815      throw new Exception("\nHelp requested."
816          + makeOptionString(classifier, globalInfo));
817    }
818
819    try {
820      // do we get the input from XML instead of normal parameters?
821      xml = Utils.getOption("xml", options);
822      if (!xml.equals(""))
823        options = new XMLOptions(xml).toArray();
824
825      // is the input model only the XML-Options, i.e. w/o built model?
826      optionsTmp = new String[options.length];
827      for (int i = 0; i < options.length; i++)
828        optionsTmp[i] = options[i];
829
830      String tmpO = Utils.getOption('l', optionsTmp);
831      //if (Utils.getOption('l', optionsTmp).toLowerCase().endsWith(".xml")) {
832      if (tmpO.endsWith(".xml")) {
833        // try to load file as PMML first
834        boolean success = false;
835        try {
836          PMMLModel pmmlModel = PMMLFactory.getPMMLModel(tmpO);
837          if (pmmlModel instanceof PMMLClassifier) {
838            classifier = ((PMMLClassifier)pmmlModel);
839            success = true;
840          }
841        } catch (IllegalArgumentException ex) {
842          success = false;
843        }
844        if (!success) {
845          // load options from serialized data  ('-l' is automatically erased!)
846          XMLClassifier xmlserial = new XMLClassifier();
847          OptionHandler cl = (OptionHandler) xmlserial.read(Utils.getOption('l', options));
848
849          // merge options
850          optionsTmp = new String[options.length + cl.getOptions().length];
851          System.arraycopy(cl.getOptions(), 0, optionsTmp, 0, cl.getOptions().length);
852          System.arraycopy(options, 0, optionsTmp, cl.getOptions().length, options.length);
853          options = optionsTmp;
854        }
855      }
856
857      noCrossValidation = Utils.getFlag("no-cv", options);
858      // Get basic options (options the same for all schemes)
859      classIndexString = Utils.getOption('c', options);
860      if (classIndexString.length() != 0) {
861        if (classIndexString.equals("first"))
862          classIndex = 1;
863        else if (classIndexString.equals("last"))
864          classIndex = -1;
865        else
866          classIndex = Integer.parseInt(classIndexString);
867      }
868      trainFileName = Utils.getOption('t', options);
869      objectInputFileName = Utils.getOption('l', options);
870      objectOutputFileName = Utils.getOption('d', options);
871      testFileName = Utils.getOption('T', options);
872      foldsString = Utils.getOption('x', options);
873      if (foldsString.length() != 0) {
874        folds = Integer.parseInt(foldsString);
875      }
876      seedString = Utils.getOption('s', options);
877      if (seedString.length() != 0) {
878        seed = Integer.parseInt(seedString);
879      }
880      if (trainFileName.length() == 0) {
881        if (objectInputFileName.length() == 0) {
882          throw new Exception("No training file and no object input file given.");
883        }
884        if (testFileName.length() == 0) {
885          throw new Exception("No training file and no test file given.");
886        }
887      } else if ((objectInputFileName.length() != 0) &&
888          ((!(classifier instanceof UpdateableClassifier)) ||
889           (testFileName.length() == 0))) {
890        throw new Exception("Classifier not incremental, or no " +
891            "test file provided: can't "+
892            "use both train and model file.");
893      }
894      try {
895        if (trainFileName.length() != 0) {
896          trainSetPresent = true;
897          trainSource = new DataSource(trainFileName);
898        }
899        if (testFileName.length() != 0) {
900          testSetPresent = true;
901          testSource = new DataSource(testFileName);
902        }
903        if (objectInputFileName.length() != 0) {
904          if (objectInputFileName.endsWith(".xml")) {
905            // if this is the case then it means that a PMML classifier was
906            // successfully loaded earlier in the code
907            objectInputStream = null;
908            xmlInputStream = null;
909          } else {
910            InputStream is = new FileInputStream(objectInputFileName);
911            if (objectInputFileName.endsWith(".gz")) {
912              is = new GZIPInputStream(is);
913            }
914            // load from KOML?
915            if (!(objectInputFileName.endsWith(".koml") && KOML.isPresent()) ) {
916              objectInputStream = new ObjectInputStream(is);
917              xmlInputStream    = null;
918            }
919            else {
920              objectInputStream = null;
921              xmlInputStream    = new BufferedInputStream(is);
922            }
923          }
924        }
925      } catch (Exception e) {
926        throw new Exception("Can't open file " + e.getMessage() + '.');
927      }
928      if (testSetPresent) {
929        template = test = testSource.getStructure();
930        if (classIndex != -1) {
931          test.setClassIndex(classIndex - 1);
932        } else {
933          if ( (test.classIndex() == -1) || (classIndexString.length() != 0) )
934            test.setClassIndex(test.numAttributes() - 1);
935        }
936        actualClassIndex = test.classIndex();
937      }
938      else {
939        // percentage split
940        splitPercentageString = Utils.getOption("split-percentage", options);
941        if (splitPercentageString.length() != 0) {
942          if (foldsString.length() != 0)
943            throw new Exception(
944                "Percentage split cannot be used in conjunction with "
945                + "cross-validation ('-x').");
946          splitPercentage = Integer.parseInt(splitPercentageString);
947          if ((splitPercentage <= 0) || (splitPercentage >= 100))
948            throw new Exception("Percentage split value needs be >0 and <100.");
949        }
950        else {
951          splitPercentage = -1;
952        }
953        preserveOrder = Utils.getFlag("preserve-order", options);
954        if (preserveOrder) {
955          if (splitPercentage == -1)
956            throw new Exception("Percentage split ('-percentage-split') is missing.");
957        }
958        // create new train/test sources
959        if (splitPercentage > 0) {
960          testSetPresent = true;
961          Instances tmpInst = trainSource.getDataSet(actualClassIndex);
962          if (!preserveOrder)
963            tmpInst.randomize(new Random(seed));
964          int trainSize = tmpInst.numInstances() * splitPercentage / 100;
965          int testSize  = tmpInst.numInstances() - trainSize;
966          Instances trainInst = new Instances(tmpInst, 0, trainSize);
967          Instances testInst  = new Instances(tmpInst, trainSize, testSize);
968          trainSource = new DataSource(trainInst);
969          testSource  = new DataSource(testInst);
970          template = test = testSource.getStructure();
971          if (classIndex != -1) {
972            test.setClassIndex(classIndex - 1);
973          } else {
974            if ( (test.classIndex() == -1) || (classIndexString.length() != 0) )
975              test.setClassIndex(test.numAttributes() - 1);
976          }
977          actualClassIndex = test.classIndex();
978        }
979      }
980      if (trainSetPresent) {
981        template = train = trainSource.getStructure();
982        if (classIndex != -1) {
983          train.setClassIndex(classIndex - 1);
984        } else {
985          if ( (train.classIndex() == -1) || (classIndexString.length() != 0) )
986            train.setClassIndex(train.numAttributes() - 1);
987        }
988        actualClassIndex = train.classIndex();
989        if ((testSetPresent) && !test.equalHeaders(train)) {
990          throw new IllegalArgumentException("Train and test file not compatible!\n" + test.equalHeadersMsg(train));
991        }
992      }
993      if (template == null) {
994        throw new Exception("No actual dataset provided to use as template");
995      }
996      costMatrix = handleCostOption(
997          Utils.getOption('m', options), template.numClasses());
998
999      classStatistics = Utils.getFlag('i', options);
1000      noOutput = Utils.getFlag('o', options);
1001      trainStatistics = !Utils.getFlag('v', options);
1002      printComplexityStatistics = Utils.getFlag('k', options);
1003      printMargins = Utils.getFlag('r', options);
1004      printGraph = Utils.getFlag('g', options);
1005      sourceClass = Utils.getOption('z', options);
1006      printSource = (sourceClass.length() != 0);
1007      thresholdFile = Utils.getOption("threshold-file", options);
1008      thresholdLabel = Utils.getOption("threshold-label", options);
1009
1010      String classifications = Utils.getOption("classifications", options);
1011      String classificationsOld = Utils.getOption("p", options);
1012      if (classifications.length() > 0) {
1013        noOutput = true;
1014        classificationOutput = AbstractOutput.fromCommandline(classifications);
1015        classificationOutput.setHeader(template);
1016      }
1017      // backwards compatible with old "-p range" and "-distribution" options
1018      else if (classificationsOld.length() > 0) {
1019        noOutput = true;
1020        classificationOutput = new PlainText();
1021        classificationOutput.setHeader(template);
1022        if (!classificationsOld.equals("0"))
1023          classificationOutput.setAttributes(classificationsOld);
1024        classificationOutput.setOutputDistribution(Utils.getFlag("distribution", options));
1025      }
1026      // -distribution flag needs -p option
1027      else {
1028        if (Utils.getFlag("distribution", options))
1029          throw new Exception("Cannot print distribution without '-p' option!");
1030      }
1031
1032      // if no training file given, we don't have any priors
1033      if ( (!trainSetPresent) && (printComplexityStatistics) )
1034        throw new Exception("Cannot print complexity statistics ('-k') without training file ('-t')!");
1035
1036      // If a model file is given, we can't process
1037      // scheme-specific options
1038      if (objectInputFileName.length() != 0) {
1039        Utils.checkForRemainingOptions(options);
1040      } else {
1041
1042        // Set options for classifier
1043        if (classifier instanceof OptionHandler) {
1044          for (int i = 0; i < options.length; i++) {
1045            if (options[i].length() != 0) {
1046              if (schemeOptionsText == null) {
1047                schemeOptionsText = new StringBuffer();
1048              }
1049              if (options[i].indexOf(' ') != -1) {
1050                schemeOptionsText.append('"' + options[i] + "\" ");
1051              } else {
1052                schemeOptionsText.append(options[i] + " ");
1053              }
1054            }
1055          }
1056          ((OptionHandler)classifier).setOptions(options);
1057        }
1058      }
1059
1060      Utils.checkForRemainingOptions(options);
1061    } catch (Exception e) {
1062      throw new Exception("\nWeka exception: " + e.getMessage()
1063          + makeOptionString(classifier, false));
1064    }
1065
1066    // Setup up evaluation objects
1067    Evaluation trainingEvaluation = new Evaluation(new Instances(template, 0), costMatrix);
1068    Evaluation testingEvaluation = new Evaluation(new Instances(template, 0), costMatrix);
1069
1070    // disable use of priors if no training file given
1071    if (!trainSetPresent)
1072      testingEvaluation.useNoPriors();
1073
1074    if (objectInputFileName.length() != 0) {
1075      // Load classifier from file
1076      if (objectInputStream != null) {
1077        classifier = (Classifier) objectInputStream.readObject();
1078        // try and read a header (if present)
1079        Instances savedStructure = null;
1080        try {
1081          savedStructure = (Instances) objectInputStream.readObject();
1082        } catch (Exception ex) {
1083          // don't make a fuss
1084        }
1085        if (savedStructure != null) {
1086          // test for compatibility with template
1087          if (!template.equalHeaders(savedStructure)) {
1088            throw new Exception("training and test set are not compatible\n" + template.equalHeadersMsg(savedStructure));
1089          }
1090        }
1091        objectInputStream.close();
1092      }
1093      else if (xmlInputStream != null) {
1094        // whether KOML is available has already been checked (objectInputStream would null otherwise)!
1095        classifier = (Classifier) KOML.read(xmlInputStream);
1096        xmlInputStream.close();
1097      }
1098    }
1099
1100    // backup of fully setup classifier for cross-validation
1101    classifierBackup = AbstractClassifier.makeCopy(classifier);
1102
1103    // Build the classifier if no object file provided
1104    if ((classifier instanceof UpdateableClassifier) &&
1105        (testSetPresent || noCrossValidation) &&
1106        (costMatrix == null) &&
1107        (trainSetPresent)) {
1108      // Build classifier incrementally
1109      trainingEvaluation.setPriors(train);
1110      testingEvaluation.setPriors(train);
1111      trainTimeStart = System.currentTimeMillis();
1112      if (objectInputFileName.length() == 0) {
1113        classifier.buildClassifier(train);
1114      }
1115      Instance trainInst;
1116      while (trainSource.hasMoreElements(train)) {
1117        trainInst = trainSource.nextElement(train);
1118        trainingEvaluation.updatePriors(trainInst);
1119        testingEvaluation.updatePriors(trainInst);
1120        ((UpdateableClassifier)classifier).updateClassifier(trainInst);
1121      }
1122      trainTimeElapsed = System.currentTimeMillis() - trainTimeStart;
1123    } else if (objectInputFileName.length() == 0) {
1124      // Build classifier in one go
1125      tempTrain = trainSource.getDataSet(actualClassIndex);
1126      trainingEvaluation.setPriors(tempTrain);
1127      testingEvaluation.setPriors(tempTrain);
1128      trainTimeStart = System.currentTimeMillis();
1129      classifier.buildClassifier(tempTrain);
1130      trainTimeElapsed = System.currentTimeMillis() - trainTimeStart;
1131    }
1132
1133    // backup of fully trained classifier for printing the classifications
1134    if (classificationOutput != null)
1135      classifierClassifications = AbstractClassifier.makeCopy(classifier);
1136
1137    // Save the classifier if an object output file is provided
1138    if (objectOutputFileName.length() != 0) {
1139      OutputStream os = new FileOutputStream(objectOutputFileName);
1140      // binary
1141      if (!(objectOutputFileName.endsWith(".xml") || (objectOutputFileName.endsWith(".koml") && KOML.isPresent()))) {
1142        if (objectOutputFileName.endsWith(".gz")) {
1143          os = new GZIPOutputStream(os);
1144        }
1145        ObjectOutputStream objectOutputStream = new ObjectOutputStream(os);
1146        objectOutputStream.writeObject(classifier);
1147        if (template != null) {
1148          objectOutputStream.writeObject(template);
1149        }
1150        objectOutputStream.flush();
1151        objectOutputStream.close();
1152      }
1153      // KOML/XML
1154      else {
1155        BufferedOutputStream xmlOutputStream = new BufferedOutputStream(os);
1156        if (objectOutputFileName.endsWith(".xml")) {
1157          XMLSerialization xmlSerial = new XMLClassifier();
1158          xmlSerial.write(xmlOutputStream, classifier);
1159        }
1160        else
1161          // whether KOML is present has already been checked
1162          // if not present -> ".koml" is interpreted as binary - see above
1163          if (objectOutputFileName.endsWith(".koml")) {
1164            KOML.write(xmlOutputStream, classifier);
1165          }
1166        xmlOutputStream.close();
1167      }
1168    }
1169
1170    // If classifier is drawable output string describing graph
1171    if ((classifier instanceof Drawable) && (printGraph)){
1172      return ((Drawable)classifier).graph();
1173    }
1174
1175    // Output the classifier as equivalent source
1176    if ((classifier instanceof Sourcable) && (printSource)){
1177      return wekaStaticWrapper((Sourcable) classifier, sourceClass);
1178    }
1179
1180    // Output model
1181    if (!(noOutput || printMargins)) {
1182      if (classifier instanceof OptionHandler) {
1183        if (schemeOptionsText != null) {
1184          text.append("\nOptions: "+schemeOptionsText);
1185          text.append("\n");
1186        }
1187      }
1188      text.append("\n" + classifier.toString() + "\n");
1189    }
1190
1191    if (!printMargins && (costMatrix != null)) {
1192      text.append("\n=== Evaluation Cost Matrix ===\n\n");
1193      text.append(costMatrix.toString());
1194    }
1195
1196    // Output test instance predictions only
1197    if (classificationOutput != null) {
1198      DataSource source = testSource;
1199      predsBuff = new StringBuffer();
1200      classificationOutput.setBuffer(predsBuff);
1201      // no test set -> use train set
1202      if (source == null && noCrossValidation) {
1203        source = trainSource;
1204        predsBuff.append("\n=== Predictions on training data ===\n\n");
1205      } else {
1206        predsBuff.append("\n=== Predictions on test data ===\n\n");
1207      }
1208      if (source != null)
1209        classificationOutput.print(classifierClassifications, source);
1210    }
1211
1212    // Compute error estimate from training data
1213    if ((trainStatistics) && (trainSetPresent)) {
1214
1215      if ((classifier instanceof UpdateableClassifier) &&
1216          (testSetPresent) &&
1217          (costMatrix == null)) {
1218
1219        // Classifier was trained incrementally, so we have to
1220        // reset the source.
1221        trainSource.reset();
1222
1223        // Incremental testing
1224        train = trainSource.getStructure(actualClassIndex);
1225        testTimeStart = System.currentTimeMillis();
1226        Instance trainInst;
1227        while (trainSource.hasMoreElements(train)) {
1228          trainInst = trainSource.nextElement(train);
1229          trainingEvaluation.evaluateModelOnce((Classifier)classifier, trainInst);
1230        }
1231        testTimeElapsed = System.currentTimeMillis() - testTimeStart;
1232      } else {
1233        testTimeStart = System.currentTimeMillis();
1234        trainingEvaluation.evaluateModel(
1235            classifier, trainSource.getDataSet(actualClassIndex));
1236        testTimeElapsed = System.currentTimeMillis() - testTimeStart;
1237      }
1238
1239      // Print the results of the training evaluation
1240      if (printMargins) {
1241        return trainingEvaluation.toCumulativeMarginDistributionString();
1242      } else {
1243        if (classificationOutput == null) {
1244          text.append("\nTime taken to build model: "
1245              + Utils.doubleToString(trainTimeElapsed / 1000.0,2)
1246              + " seconds");
1247
1248          if (splitPercentage > 0)
1249            text.append("\nTime taken to test model on training split: ");
1250          else
1251            text.append("\nTime taken to test model on training data: ");
1252          text.append(Utils.doubleToString(testTimeElapsed / 1000.0,2) + " seconds");
1253
1254          if (splitPercentage > 0)
1255            text.append(trainingEvaluation.toSummaryString("\n\n=== Error on training"
1256                  + " split ===\n", printComplexityStatistics));
1257          else
1258            text.append(trainingEvaluation.toSummaryString("\n\n=== Error on training"
1259                  + " data ===\n", printComplexityStatistics));
1260
1261          if (template.classAttribute().isNominal()) {
1262            if (classStatistics) {
1263              text.append("\n\n" + trainingEvaluation.toClassDetailsString());
1264            }
1265            if (!noCrossValidation)
1266              text.append("\n\n" + trainingEvaluation.toMatrixString());
1267          }
1268        }
1269      }
1270    }
1271
1272    // Compute proper error estimates
1273    if (testSource != null) {
1274      // Testing is on the supplied test data
1275      testSource.reset();
1276      test = testSource.getStructure(test.classIndex());
1277      Instance testInst;
1278      while (testSource.hasMoreElements(test)) {
1279        testInst = testSource.nextElement(test);
1280        testingEvaluation.evaluateModelOnceAndRecordPrediction(
1281            (Classifier)classifier, testInst);
1282      }
1283
1284      if (splitPercentage > 0) {
1285        if (classificationOutput == null) {
1286          text.append("\n\n" + testingEvaluation.
1287              toSummaryString("=== Error on test split ===\n",
1288                  printComplexityStatistics));
1289        }
1290      } else {
1291        if (classificationOutput == null) {
1292          text.append("\n\n" + testingEvaluation.
1293              toSummaryString("=== Error on test data ===\n",
1294                  printComplexityStatistics));
1295        }
1296      }
1297
1298    } else if (trainSource != null) {
1299      if (!noCrossValidation) {
1300        // Testing is via cross-validation on training data
1301        Random random = new Random(seed);
1302        // use untrained (!) classifier for cross-validation
1303        classifier = AbstractClassifier.makeCopy(classifierBackup);
1304        if (classificationOutput == null) {
1305          testingEvaluation.crossValidateModel(classifier,
1306                                               trainSource.getDataSet(actualClassIndex),
1307                                               folds, random);
1308          if (template.classAttribute().isNumeric()) {
1309            text.append("\n\n\n" + testingEvaluation.
1310                        toSummaryString("=== Cross-validation ===\n",
1311                                        printComplexityStatistics));
1312          } else {
1313            text.append("\n\n\n" + testingEvaluation.
1314                        toSummaryString("=== Stratified " +
1315                                        "cross-validation ===\n",
1316                                        printComplexityStatistics));
1317          }
1318        } else {
1319          predsBuff = new StringBuffer();
1320          classificationOutput.setBuffer(predsBuff);
1321          predsBuff.append("\n=== Predictions under cross-validation ===\n\n");
1322          testingEvaluation.crossValidateModel(classifier,
1323                                               trainSource.getDataSet(actualClassIndex),
1324                                               folds, random, classificationOutput);
1325        }
1326      }
1327    }
1328    if (template.classAttribute().isNominal()) {
1329      if (classStatistics && !noCrossValidation && (classificationOutput == null)) {
1330        text.append("\n\n" + testingEvaluation.toClassDetailsString());
1331      }
1332      if (!noCrossValidation && (classificationOutput == null))
1333        text.append("\n\n" + testingEvaluation.toMatrixString());
1334
1335    }
1336
1337    // predictions from cross-validation?
1338    if (predsBuff != null) {
1339      text.append("\n" + predsBuff);
1340    }
1341
1342    if ((thresholdFile.length() != 0) && template.classAttribute().isNominal()) {
1343      int labelIndex = 0;
1344      if (thresholdLabel.length() != 0)
1345        labelIndex = template.classAttribute().indexOfValue(thresholdLabel);
1346      if (labelIndex == -1)
1347        throw new IllegalArgumentException(
1348            "Class label '" + thresholdLabel + "' is unknown!");
1349      ThresholdCurve tc = new ThresholdCurve();
1350      Instances result = tc.getCurve(testingEvaluation.predictions(), labelIndex);
1351      DataSink.write(thresholdFile, result);
1352    }
1353
1354    return text.toString();
1355  }
1356
1357  /**
1358   * Attempts to load a cost matrix.
1359   *
1360   * @param costFileName the filename of the cost matrix
1361   * @param numClasses the number of classes that should be in the cost matrix
1362   * (only used if the cost file is in old format).
1363   * @return a <code>CostMatrix</code> value, or null if costFileName is empty
1364   * @throws Exception if an error occurs.
1365   */
1366  protected static CostMatrix handleCostOption(String costFileName,
1367      int numClasses)
1368    throws Exception {
1369
1370    if ((costFileName != null) && (costFileName.length() != 0)) {
1371      System.out.println(
1372          "NOTE: The behaviour of the -m option has changed between WEKA 3.0"
1373          +" and WEKA 3.1. -m now carries out cost-sensitive *evaluation*"
1374          +" only. For cost-sensitive *prediction*, use one of the"
1375          +" cost-sensitive metaschemes such as"
1376          +" weka.classifiers.meta.CostSensitiveClassifier or"
1377          +" weka.classifiers.meta.MetaCost");
1378
1379      Reader costReader = null;
1380      try {
1381        costReader = new BufferedReader(new FileReader(costFileName));
1382      } catch (Exception e) {
1383        throw new Exception("Can't open file " + e.getMessage() + '.');
1384      }
1385      try {
1386        // First try as a proper cost matrix format
1387        return new CostMatrix(costReader);
1388      } catch (Exception ex) {
1389        try {
1390          // Now try as the poxy old format :-)
1391          //System.err.println("Attempting to read old format cost file");
1392          try {
1393            costReader.close(); // Close the old one
1394            costReader = new BufferedReader(new FileReader(costFileName));
1395          } catch (Exception e) {
1396            throw new Exception("Can't open file " + e.getMessage() + '.');
1397          }
1398          CostMatrix costMatrix = new CostMatrix(numClasses);
1399          //System.err.println("Created default cost matrix");
1400          costMatrix.readOldFormat(costReader);
1401          return costMatrix;
1402          //System.err.println("Read old format");
1403        } catch (Exception e2) {
1404          // re-throw the original exception
1405          //System.err.println("Re-throwing original exception");
1406          throw ex;
1407        }
1408      }
1409    } else {
1410      return null;
1411    }
1412  }
1413
1414  /**
1415   * Evaluates the classifier on a given set of instances. Note that
1416   * the data must have exactly the same format (e.g. order of
1417   * attributes) as the data used to train the classifier! Otherwise
1418   * the results will generally be meaningless.
1419   *
1420   * @param classifier machine learning classifier
1421   * @param data set of test instances for evaluation
1422   * @param forPredictionsPrinting varargs parameter that, if supplied, is
1423   * expected to hold a weka.classifiers.evaluation.output.prediction.AbstractOutput
1424   * object
1425   * @return the predictions
1426   * @throws Exception if model could not be evaluated
1427   * successfully
1428   */
1429  public double[] evaluateModel(Classifier classifier,
1430                                Instances data,
1431                                Object... forPredictionsPrinting) throws Exception {
1432    // for predictions printing
1433    AbstractOutput classificationOutput = null;
1434
1435    double predictions[] = new double[data.numInstances()];
1436
1437    if (forPredictionsPrinting.length > 0) {
1438      classificationOutput = (AbstractOutput) forPredictionsPrinting[0];
1439    }
1440
1441    // Need to be able to collect predictions if appropriate (for AUC)
1442
1443    for (int i = 0; i < data.numInstances(); i++) {
1444      predictions[i] = evaluateModelOnceAndRecordPrediction((Classifier)classifier,
1445          data.instance(i));
1446      if (classificationOutput != null)
1447        classificationOutput.printClassification(classifier, data.instance(i), i);
1448    }
1449
1450    return predictions;
1451  }
1452
1453  /**
1454   * Evaluates the supplied distribution on a single instance.
1455   *
1456   * @param dist the supplied distribution
1457   * @param instance the test instance to be classified
1458   * @param storePredictions whether to store predictions for nominal classifier
1459   * @return the prediction
1460   * @throws Exception if model could not be evaluated successfully
1461   */
1462  public double evaluationForSingleInstance(double[] dist, Instance instance,
1463                                            boolean storePredictions) throws Exception {
1464
1465    double pred;
1466
1467    if (m_ClassIsNominal) {
1468      pred = Utils.maxIndex(dist);
1469      if (dist[(int)pred] <= 0) {
1470        pred = Utils.missingValue();
1471      }
1472      updateStatsForClassifier(dist, instance);
1473      if (storePredictions) {
1474        if (m_Predictions == null)
1475          m_Predictions = new FastVector();
1476        m_Predictions.addElement(new NominalPrediction(instance.classValue(), dist,
1477                                                       instance.weight()));
1478      }
1479    } else {
1480      pred = dist[0];
1481      updateStatsForPredictor(pred, instance);
1482      if (storePredictions) {
1483        if (m_Predictions == null)
1484          m_Predictions = new FastVector();
1485        m_Predictions.addElement(new NumericPrediction(instance.classValue(), pred,
1486                                                       instance.weight()));
1487      }
1488    }
1489
1490    return pred;
1491  }
1492
1493  /**
1494   * Evaluates the classifier on a single instance and records the
1495   * prediction.
1496   *
1497   * @param classifier machine learning classifier
1498   * @param instance the test instance to be classified
1499   * @param storePredictions whether to store predictions for nominal classifier
1500   * @return the prediction made by the clasifier
1501   * @throws Exception if model could not be evaluated
1502   * successfully or the data contains string attributes
1503   */
1504  protected double evaluationForSingleInstance(Classifier classifier,
1505                                               Instance instance,
1506                                               boolean storePredictions) throws Exception {
1507
1508    Instance classMissing = (Instance)instance.copy();
1509    classMissing.setDataset(instance.dataset());
1510    classMissing.setClassMissing();
1511    double pred = evaluationForSingleInstance(classifier.distributionForInstance(classMissing),
1512                                              instance, storePredictions);
1513
1514    // We don't need to do the following if the class is nominal because in that case
1515    // entropy and coverage statistics are always computed.
1516    if (!m_ClassIsNominal) {
1517      if (!instance.classIsMissing() && !Utils.isMissingValue(pred)) {
1518        if (classifier instanceof IntervalEstimator) {
1519          updateStatsForIntervalEstimator((IntervalEstimator)classifier, classMissing,
1520                                          instance.classValue());
1521        } else {
1522          m_CoverageStatisticsAvailable = false;
1523        }
1524        if (classifier instanceof ConditionalDensityEstimator) {
1525          updateStatsForConditionalDensityEstimator((ConditionalDensityEstimator)classifier,
1526                                                    classMissing, instance.classValue());
1527        } else {
1528          m_ComplexityStatisticsAvailable = false;
1529        }
1530      }
1531    }
1532    return pred;
1533  }
1534
1535  /**
1536   * Evaluates the classifier on a single instance and records the
1537   * prediction.
1538   *
1539   * @param classifier machine learning classifier
1540   * @param instance the test instance to be classified
1541   * @return the prediction made by the clasifier
1542   * @throws Exception if model could not be evaluated
1543   * successfully or the data contains string attributes
1544   */
1545  public double evaluateModelOnceAndRecordPrediction(Classifier classifier,
1546      Instance instance) throws Exception {
1547
1548    return evaluationForSingleInstance(classifier, instance, true);
1549  }
1550
1551  /**
1552   * Evaluates the classifier on a single instance.
1553   *
1554   * @param classifier machine learning classifier
1555   * @param instance the test instance to be classified
1556   * @return the prediction made by the clasifier
1557   * @throws Exception if model could not be evaluated
1558   * successfully or the data contains string attributes
1559   */
1560  public double evaluateModelOnce(Classifier classifier, Instance instance) throws Exception {
1561
1562    return evaluationForSingleInstance(classifier, instance, false);
1563  }
1564
1565  /**
1566   * Evaluates the supplied distribution on a single instance.
1567   *
1568   * @param dist the supplied distribution
1569   * @param instance the test instance to be classified
1570   * @return the prediction
1571   * @throws Exception if model could not be evaluated
1572   * successfully
1573   */
1574  public double evaluateModelOnce(double [] dist, Instance instance) throws Exception {
1575
1576    return evaluationForSingleInstance(dist, instance, false);
1577  }
1578
1579  /**
1580   * Evaluates the supplied distribution on a single instance.
1581   *
1582   * @param dist the supplied distribution
1583   * @param instance the test instance to be classified
1584   * @return the prediction
1585   * @throws Exception if model could not be evaluated
1586   * successfully
1587   */
1588  public double evaluateModelOnceAndRecordPrediction(double [] dist,
1589      Instance instance) throws Exception {
1590
1591    return evaluationForSingleInstance(dist, instance, true);
1592  }
1593
1594  /**
1595   * Evaluates the supplied prediction on a single instance.
1596   *
1597   * @param prediction the supplied prediction
1598   * @param instance the test instance to be classified
1599   * @throws Exception if model could not be evaluated
1600   * successfully
1601   */
1602  public void evaluateModelOnce(double prediction,
1603      Instance instance) throws Exception {
1604
1605    evaluateModelOnce(makeDistribution(prediction), instance);
1606  }
1607
1608  /**
1609   * Returns the predictions that have been collected.
1610   *
1611   * @return a reference to the FastVector containing the predictions
1612   * that have been collected. This should be null if no predictions
1613   * have been collected.
1614   */
1615  public FastVector predictions() {
1616    return m_Predictions;
1617  }
1618
1619  /**
1620   * Wraps a static classifier in enough source to test using the weka
1621   * class libraries.
1622   *
1623   * @param classifier a Sourcable Classifier
1624   * @param className the name to give to the source code class
1625   * @return the source for a static classifier that can be tested with
1626   * weka libraries.
1627   * @throws Exception if code-generation fails
1628   */
1629  public static String wekaStaticWrapper(Sourcable classifier, String className)
1630    throws Exception {
1631
1632    StringBuffer result = new StringBuffer();
1633    String staticClassifier = classifier.toSource(className);
1634
1635    result.append("// Generated with Weka " + Version.VERSION + "\n");
1636    result.append("//\n");
1637    result.append("// This code is public domain and comes with no warranty.\n");
1638    result.append("//\n");
1639    result.append("// Timestamp: " + new Date() + "\n");
1640    result.append("\n");
1641    result.append("package weka.classifiers;\n");
1642    result.append("\n");
1643    result.append("import weka.core.Attribute;\n");
1644    result.append("import weka.core.Capabilities;\n");
1645    result.append("import weka.core.Capabilities.Capability;\n");
1646    result.append("import weka.core.Instance;\n");
1647    result.append("import weka.core.Instances;\n");
1648    result.append("import weka.core.RevisionUtils;\n");
1649    result.append("import weka.classifiers.Classifier;\nimport weka.classifiers.AbstractClassifier;\n");
1650    result.append("\n");
1651    result.append("public class WekaWrapper\n");
1652    result.append("  extends AbstractClassifier {\n");
1653
1654    // globalInfo
1655    result.append("\n");
1656    result.append("  /**\n");
1657    result.append("   * Returns only the toString() method.\n");
1658    result.append("   *\n");
1659    result.append("   * @return a string describing the classifier\n");
1660    result.append("   */\n");
1661    result.append("  public String globalInfo() {\n");
1662    result.append("    return toString();\n");
1663    result.append("  }\n");
1664
1665    // getCapabilities
1666    result.append("\n");
1667    result.append("  /**\n");
1668    result.append("   * Returns the capabilities of this classifier.\n");
1669    result.append("   *\n");
1670    result.append("   * @return the capabilities\n");
1671    result.append("   */\n");
1672    result.append("  public Capabilities getCapabilities() {\n");
1673    result.append(((Classifier) classifier).getCapabilities().toSource("result", 4));
1674    result.append("    return result;\n");
1675    result.append("  }\n");
1676
1677    // buildClassifier
1678    result.append("\n");
1679    result.append("  /**\n");
1680    result.append("   * only checks the data against its capabilities.\n");
1681    result.append("   *\n");
1682    result.append("   * @param i the training data\n");
1683    result.append("   */\n");
1684    result.append("  public void buildClassifier(Instances i) throws Exception {\n");
1685    result.append("    // can classifier handle the data?\n");
1686    result.append("    getCapabilities().testWithFail(i);\n");
1687    result.append("  }\n");
1688
1689    // classifyInstance
1690    result.append("\n");
1691    result.append("  /**\n");
1692    result.append("   * Classifies the given instance.\n");
1693    result.append("   *\n");
1694    result.append("   * @param i the instance to classify\n");
1695    result.append("   * @return the classification result\n");
1696    result.append("   */\n");
1697    result.append("  public double classifyInstance(Instance i) throws Exception {\n");
1698    result.append("    Object[] s = new Object[i.numAttributes()];\n");
1699    result.append("    \n");
1700    result.append("    for (int j = 0; j < s.length; j++) {\n");
1701    result.append("      if (!i.isMissing(j)) {\n");
1702    result.append("        if (i.attribute(j).isNominal())\n");
1703    result.append("          s[j] = new String(i.stringValue(j));\n");
1704    result.append("        else if (i.attribute(j).isNumeric())\n");
1705    result.append("          s[j] = new Double(i.value(j));\n");
1706    result.append("      }\n");
1707    result.append("    }\n");
1708    result.append("    \n");
1709    result.append("    // set class value to missing\n");
1710    result.append("    s[i.classIndex()] = null;\n");
1711    result.append("    \n");
1712    result.append("    return " + className + ".classify(s);\n");
1713    result.append("  }\n");
1714
1715    // getRevision
1716    result.append("\n");
1717    result.append("  /**\n");
1718    result.append("   * Returns the revision string.\n");
1719    result.append("   * \n");
1720    result.append("   * @return        the revision\n");
1721    result.append("   */\n");
1722    result.append("  public String getRevision() {\n");
1723    result.append("    return RevisionUtils.extract(\"1.0\");\n");
1724    result.append("  }\n");
1725
1726    // toString
1727    result.append("\n");
1728    result.append("  /**\n");
1729    result.append("   * Returns only the classnames and what classifier it is based on.\n");
1730    result.append("   *\n");
1731    result.append("   * @return a short description\n");
1732    result.append("   */\n");
1733    result.append("  public String toString() {\n");
1734    result.append("    return \"Auto-generated classifier wrapper, based on "
1735        + classifier.getClass().getName() + " (generated with Weka " + Version.VERSION + ").\\n"
1736        + "\" + this.getClass().getName() + \"/" + className + "\";\n");
1737    result.append("  }\n");
1738
1739    // main
1740    result.append("\n");
1741    result.append("  /**\n");
1742    result.append("   * Runs the classfier from commandline.\n");
1743    result.append("   *\n");
1744    result.append("   * @param args the commandline arguments\n");
1745    result.append("   */\n");
1746    result.append("  public static void main(String args[]) {\n");
1747    result.append("    runClassifier(new WekaWrapper(), args);\n");
1748    result.append("  }\n");
1749    result.append("}\n");
1750
1751    // actual classifier code
1752    result.append("\n");
1753    result.append(staticClassifier);
1754
1755    return result.toString();
1756  }
1757
1758  /**
1759   * Gets the number of test instances that had a known class value
1760   * (actually the sum of the weights of test instances with known
1761   * class value).
1762   *
1763   * @return the number of test instances with known class
1764   */
1765  public final double numInstances() {
1766
1767    return m_WithClass;
1768  }
1769
1770  /**
1771   * Gets the coverage of the test cases by the predicted regions at
1772   * the confidence level specified when evaluation was performed.
1773   *
1774   * @return the coverage of the test cases by the predicted regions
1775   */
1776  public final double coverageOfTestCasesByPredictedRegions() {
1777
1778    if (!m_CoverageStatisticsAvailable)
1779      return Double.NaN;
1780
1781    return 100 * m_TotalCoverage / m_WithClass;
1782  }
1783
1784  /**
1785   * Gets the average size of the predicted regions, relative to the
1786   * range of the target in the training data, at the confidence level
1787   * specified when evaluation was performed.
1788   *
1789   * @return the average size of the predicted regions
1790   */
1791  public final double sizeOfPredictedRegions() {
1792
1793    if (m_NoPriors || !m_CoverageStatisticsAvailable)
1794      return Double.NaN;
1795
1796    return 100 * m_TotalSizeOfRegions / m_WithClass;
1797  }
1798
1799  /**
1800   * Gets the number of instances incorrectly classified (that is, for
1801   * which an incorrect prediction was made). (Actually the sum of the
1802   * weights of these instances)
1803   *
1804   * @return the number of incorrectly classified instances
1805   */
1806  public final double incorrect() {
1807
1808    return m_Incorrect;
1809  }
1810
1811  /**
1812   * Gets the percentage of instances incorrectly classified (that is,
1813   * for which an incorrect prediction was made).
1814   *
1815   * @return the percent of incorrectly classified instances
1816   * (between 0 and 100)
1817   */
1818  public final double pctIncorrect() {
1819
1820    return 100 * m_Incorrect / m_WithClass;
1821  }
1822
1823  /**
1824   * Gets the total cost, that is, the cost of each prediction times the
1825   * weight of the instance, summed over all instances.
1826   *
1827   * @return the total cost
1828   */
1829  public final double totalCost() {
1830
1831    return m_TotalCost;
1832  }
1833
1834  /**
1835   * Gets the average cost, that is, total cost of misclassifications
1836   * (incorrect plus unclassified) over the total number of instances.
1837   *
1838   * @return the average cost.
1839   */
1840  public final double avgCost() {
1841
1842    return m_TotalCost / m_WithClass;
1843  }
1844
1845  /**
1846   * Gets the number of instances correctly classified (that is, for
1847   * which a correct prediction was made). (Actually the sum of the weights
1848   * of these instances)
1849   *
1850   * @return the number of correctly classified instances
1851   */
1852  public final double correct() {
1853
1854    return m_Correct;
1855  }
1856
1857  /**
1858   * Gets the percentage of instances correctly classified (that is, for
1859   * which a correct prediction was made).
1860   *
1861   * @return the percent of correctly classified instances (between 0 and 100)
1862   */
1863  public final double pctCorrect() {
1864
1865    return 100 * m_Correct / m_WithClass;
1866  }
1867
1868  /**
1869   * Gets the number of instances not classified (that is, for
1870   * which no prediction was made by the classifier). (Actually the sum
1871   * of the weights of these instances)
1872   *
1873   * @return the number of unclassified instances
1874   */
1875  public final double unclassified() {
1876
1877    return m_Unclassified;
1878  }
1879
1880  /**
1881   * Gets the percentage of instances not classified (that is, for
1882   * which no prediction was made by the classifier).
1883   *
1884   * @return the percent of unclassified instances (between 0 and 100)
1885   */
1886  public final double pctUnclassified() {
1887
1888    return 100 * m_Unclassified / m_WithClass;
1889  }
1890
1891  /**
1892   * Returns the estimated error rate or the root mean squared error
1893   * (if the class is numeric). If a cost matrix was given this
1894   * error rate gives the average cost.
1895   *
1896   * @return the estimated error rate (between 0 and 1, or between 0 and
1897   * maximum cost)
1898   */
1899  public final double errorRate() {
1900
1901    if (!m_ClassIsNominal) {
1902      return Math.sqrt(m_SumSqrErr / (m_WithClass - m_Unclassified));
1903    }
1904    if (m_CostMatrix == null) {
1905      return m_Incorrect / m_WithClass;
1906    } else {
1907      return avgCost();
1908    }
1909  }
1910
1911  /**
1912   * Returns value of kappa statistic if class is nominal.
1913   *
1914   * @return the value of the kappa statistic
1915   */
1916  public final double kappa() {
1917
1918
1919    double[] sumRows = new double[m_ConfusionMatrix.length];
1920    double[] sumColumns = new double[m_ConfusionMatrix.length];
1921    double sumOfWeights = 0;
1922    for (int i = 0; i < m_ConfusionMatrix.length; i++) {
1923      for (int j = 0; j < m_ConfusionMatrix.length; j++) {
1924        sumRows[i] += m_ConfusionMatrix[i][j];
1925        sumColumns[j] += m_ConfusionMatrix[i][j];
1926        sumOfWeights += m_ConfusionMatrix[i][j];
1927      }
1928    }
1929    double correct = 0, chanceAgreement = 0;
1930    for (int i = 0; i < m_ConfusionMatrix.length; i++) {
1931      chanceAgreement += (sumRows[i] * sumColumns[i]);
1932      correct += m_ConfusionMatrix[i][i];
1933    }
1934    chanceAgreement /= (sumOfWeights * sumOfWeights);
1935    correct /= sumOfWeights;
1936
1937    if (chanceAgreement < 1) {
1938      return (correct - chanceAgreement) / (1 - chanceAgreement);
1939    } else {
1940      return 1;
1941    }
1942  }
1943
1944  /**
1945   * Returns the correlation coefficient if the class is numeric.
1946   *
1947   * @return the correlation coefficient
1948   * @throws Exception if class is not numeric
1949   */
1950  public final double correlationCoefficient() throws Exception {
1951
1952    if (m_ClassIsNominal) {
1953      throw
1954      new Exception("Can't compute correlation coefficient: " +
1955      "class is nominal!");
1956    }
1957
1958    double correlation = 0;
1959    double varActual =
1960      m_SumSqrClass - m_SumClass * m_SumClass /
1961      (m_WithClass - m_Unclassified);
1962    double varPredicted =
1963      m_SumSqrPredicted - m_SumPredicted * m_SumPredicted /
1964      (m_WithClass - m_Unclassified);
1965    double varProd =
1966      m_SumClassPredicted - m_SumClass * m_SumPredicted /
1967      (m_WithClass - m_Unclassified);
1968
1969    if (varActual * varPredicted <= 0) {
1970      correlation = 0.0;
1971    } else {
1972      correlation = varProd / Math.sqrt(varActual * varPredicted);
1973    }
1974
1975    return correlation;
1976  }
1977
1978  /**
1979   * Returns the mean absolute error. Refers to the error of the
1980   * predicted values for numeric classes, and the error of the
1981   * predicted probability distribution for nominal classes.
1982   *
1983   * @return the mean absolute error
1984   */
1985  public final double meanAbsoluteError() {
1986
1987    return m_SumAbsErr / (m_WithClass - m_Unclassified);
1988  }
1989
1990  /**
1991   * Returns the mean absolute error of the prior.
1992   *
1993   * @return the mean absolute error
1994   */
1995  public final double meanPriorAbsoluteError() {
1996
1997    if (m_NoPriors)
1998      return Double.NaN;
1999
2000    return m_SumPriorAbsErr / m_WithClass;
2001  }
2002
2003  /**
2004   * Returns the relative absolute error.
2005   *
2006   * @return the relative absolute error
2007   * @throws Exception if it can't be computed
2008   */
2009  public final double relativeAbsoluteError() throws Exception {
2010
2011    if (m_NoPriors)
2012      return Double.NaN;
2013
2014    return 100 * meanAbsoluteError() / meanPriorAbsoluteError();
2015  }
2016
2017  /**
2018   * Returns the root mean squared error.
2019   *
2020   * @return the root mean squared error
2021   */
2022  public final double rootMeanSquaredError() {
2023
2024    return Math.sqrt(m_SumSqrErr / (m_WithClass - m_Unclassified));
2025  }
2026
2027  /**
2028   * Returns the root mean prior squared error.
2029   *
2030   * @return the root mean prior squared error
2031   */
2032  public final double rootMeanPriorSquaredError() {
2033
2034    if (m_NoPriors)
2035      return Double.NaN;
2036
2037    return Math.sqrt(m_SumPriorSqrErr / m_WithClass);
2038  }
2039
2040  /**
2041   * Returns the root relative squared error if the class is numeric.
2042   *
2043   * @return the root relative squared error
2044   */
2045  public final double rootRelativeSquaredError() {
2046
2047    if (m_NoPriors)
2048      return Double.NaN;
2049
2050    return 100.0 * rootMeanSquaredError() / rootMeanPriorSquaredError();
2051  }
2052
2053  /**
2054   * Calculate the entropy of the prior distribution.
2055   *
2056   * @return the entropy of the prior distribution
2057   * @throws Exception if the class is not nominal
2058   */
2059  public final double priorEntropy() throws Exception {
2060
2061    if (!m_ClassIsNominal) {
2062      throw
2063      new Exception("Can't compute entropy of class prior: " +
2064      "class numeric!");
2065    }
2066
2067    if (m_NoPriors)
2068      return Double.NaN;
2069
2070    double entropy = 0;
2071    for(int i = 0; i < m_NumClasses; i++) {
2072      entropy -= m_ClassPriors[i] / m_ClassPriorsSum *
2073        Utils.log2(m_ClassPriors[i] / m_ClassPriorsSum);
2074    }
2075    return entropy;
2076  }
2077
2078  /**
2079   * Return the total Kononenko & Bratko Information score in bits.
2080   *
2081   * @return the K&B information score
2082   * @throws Exception if the class is not nominal
2083   */
2084  public final double KBInformation() throws Exception {
2085
2086    if (!m_ClassIsNominal) {
2087      throw
2088      new Exception("Can't compute K&B Info score: " +
2089      "class numeric!");
2090    }
2091
2092    if (m_NoPriors)
2093      return Double.NaN;
2094
2095    return m_SumKBInfo;
2096  }
2097
2098  /**
2099   * Return the Kononenko & Bratko Information score in bits per
2100   * instance.
2101   *
2102   * @return the K&B information score
2103   * @throws Exception if the class is not nominal
2104   */
2105  public final double KBMeanInformation() throws Exception {
2106
2107    if (!m_ClassIsNominal) {
2108      throw
2109      new Exception("Can't compute K&B Info score: class numeric!");
2110    }
2111
2112    if (m_NoPriors)
2113      return Double.NaN;
2114
2115    return m_SumKBInfo / (m_WithClass - m_Unclassified);
2116  }
2117
2118  /**
2119   * Return the Kononenko & Bratko Relative Information score.
2120   *
2121   * @return the K&B relative information score
2122   * @throws Exception if the class is not nominal
2123   */
2124  public final double KBRelativeInformation() throws Exception {
2125
2126    if (!m_ClassIsNominal) {
2127      throw
2128      new Exception("Can't compute K&B Info score: " +
2129      "class numeric!");
2130    }
2131
2132    if (m_NoPriors)
2133      return Double.NaN;
2134
2135    return 100.0 * KBInformation() / priorEntropy();
2136  }
2137
2138  /**
2139   * Returns the total entropy for the null model.
2140   *
2141   * @return the total null model entropy
2142   */
2143  public final double SFPriorEntropy() {
2144
2145    if (m_NoPriors || !m_ComplexityStatisticsAvailable)
2146      return Double.NaN;
2147
2148    return m_SumPriorEntropy;
2149  }
2150
2151  /**
2152   * Returns the entropy per instance for the null model.
2153   *
2154   * @return the null model entropy per instance
2155   */
2156  public final double SFMeanPriorEntropy() {
2157
2158    if (m_NoPriors || !m_ComplexityStatisticsAvailable)
2159      return Double.NaN;
2160
2161    return m_SumPriorEntropy / m_WithClass;
2162  }
2163
2164  /**
2165   * Returns the total entropy for the scheme.
2166   *
2167   * @return the total scheme entropy
2168   */
2169  public final double SFSchemeEntropy() {
2170
2171    if (!m_ComplexityStatisticsAvailable)
2172      return Double.NaN;
2173
2174    return m_SumSchemeEntropy;
2175  }
2176
2177  /**
2178   * Returns the entropy per instance for the scheme.
2179   *
2180   * @return the scheme entropy per instance
2181   */
2182  public final double SFMeanSchemeEntropy() {
2183
2184    if (!m_ComplexityStatisticsAvailable)
2185      return Double.NaN;
2186
2187    return m_SumSchemeEntropy / (m_WithClass - m_Unclassified);
2188  }
2189
2190  /**
2191   * Returns the total SF, which is the null model entropy minus
2192   * the scheme entropy.
2193   *
2194   * @return the total SF
2195   */
2196  public final double SFEntropyGain() {
2197
2198    if (m_NoPriors || !m_ComplexityStatisticsAvailable)
2199      return Double.NaN;
2200
2201    return m_SumPriorEntropy - m_SumSchemeEntropy;
2202  }
2203
2204  /**
2205   * Returns the SF per instance, which is the null model entropy
2206   * minus the scheme entropy, per instance.
2207   *
2208   * @return the SF per instance
2209   */
2210  public final double SFMeanEntropyGain() {
2211
2212    if (m_NoPriors || !m_ComplexityStatisticsAvailable)
2213      return Double.NaN;
2214
2215    return (m_SumPriorEntropy - m_SumSchemeEntropy) /
2216      (m_WithClass - m_Unclassified);
2217  }
2218
2219  /**
2220   * Output the cumulative margin distribution as a string suitable
2221   * for input for gnuplot or similar package.
2222   *
2223   * @return the cumulative margin distribution
2224   * @throws Exception if the class attribute is nominal
2225   */
2226  public String toCumulativeMarginDistributionString() throws Exception {
2227
2228    if (!m_ClassIsNominal) {
2229      throw new Exception("Class must be nominal for margin distributions");
2230    }
2231    String result = "";
2232    double cumulativeCount = 0;
2233    double margin;
2234    for(int i = 0; i <= k_MarginResolution; i++) {
2235      if (m_MarginCounts[i] != 0) {
2236        cumulativeCount += m_MarginCounts[i];
2237        margin = (double)i * 2.0 / k_MarginResolution - 1.0;
2238        result = result + Utils.doubleToString(margin, 7, 3) + ' '
2239          + Utils.doubleToString(cumulativeCount * 100
2240              / m_WithClass, 7, 3) + '\n';
2241      } else if (i == 0) {
2242        result = Utils.doubleToString(-1.0, 7, 3) + ' '
2243          + Utils.doubleToString(0, 7, 3) + '\n';
2244      }
2245    }
2246    return result;
2247  }
2248
2249  /**
2250   * Calls toSummaryString() with no title and no complexity stats.
2251   *
2252   * @return a summary description of the classifier evaluation
2253   */
2254  public String toSummaryString() {
2255
2256    return toSummaryString("", false);
2257  }
2258
2259  /**
2260   * Calls toSummaryString() with a default title.
2261   *
2262   * @param printComplexityStatistics if true, complexity statistics are
2263   * returned as well
2264   * @return the summary string
2265   */
2266  public String toSummaryString(boolean printComplexityStatistics) {
2267
2268    return toSummaryString("=== Summary ===\n", printComplexityStatistics);
2269  }
2270
2271  /**
2272   * Outputs the performance statistics in summary form. Lists
2273   * number (and percentage) of instances classified correctly,
2274   * incorrectly and unclassified. Outputs the total number of
2275   * instances classified, and the number of instances (if any)
2276   * that had no class value provided.
2277   *
2278   * @param title the title for the statistics
2279   * @param printComplexityStatistics if true, complexity statistics are
2280   * returned as well
2281   * @return the summary as a String
2282   */
2283  public String toSummaryString(String title,
2284      boolean printComplexityStatistics) {
2285
2286    StringBuffer text = new StringBuffer();
2287
2288    if (printComplexityStatistics && m_NoPriors) {
2289      printComplexityStatistics = false;
2290      System.err.println("Priors disabled, cannot print complexity statistics!");
2291    }
2292
2293    text.append(title + "\n");
2294    try {
2295      if (m_WithClass > 0) {
2296        if (m_ClassIsNominal) {
2297
2298          text.append("Correctly Classified Instances     ");
2299          text.append(Utils.doubleToString(correct(), 12, 4) + "     " +
2300              Utils.doubleToString(pctCorrect(),
2301                12, 4) + " %\n");
2302          text.append("Incorrectly Classified Instances   ");
2303          text.append(Utils.doubleToString(incorrect(), 12, 4) + "     " +
2304              Utils.doubleToString(pctIncorrect(),
2305                12, 4) + " %\n");
2306          text.append("Kappa statistic                    ");
2307          text.append(Utils.doubleToString(kappa(), 12, 4) + "\n");
2308
2309          if (m_CostMatrix != null) {
2310            text.append("Total Cost                         ");
2311            text.append(Utils.doubleToString(totalCost(), 12, 4) + "\n");
2312            text.append("Average Cost                       ");
2313            text.append(Utils.doubleToString(avgCost(), 12, 4) + "\n");
2314          }
2315          if (printComplexityStatistics) {
2316            text.append("K&B Relative Info Score            ");
2317            text.append(Utils.doubleToString(KBRelativeInformation(), 12, 4)
2318                + " %\n");
2319            text.append("K&B Information Score              ");
2320            text.append(Utils.doubleToString(KBInformation(), 12, 4)
2321                + " bits");
2322            text.append(Utils.doubleToString(KBMeanInformation(), 12, 4)
2323                + " bits/instance\n");
2324          }
2325        } else {
2326          text.append("Correlation coefficient            ");
2327          text.append(Utils.doubleToString(correlationCoefficient(), 12 , 4) +
2328              "\n");
2329        }
2330        if (printComplexityStatistics && m_ComplexityStatisticsAvailable) {
2331          text.append("Class complexity | order 0         ");
2332          text.append(Utils.doubleToString(SFPriorEntropy(), 12, 4)
2333              + " bits");
2334          text.append(Utils.doubleToString(SFMeanPriorEntropy(), 12, 4)
2335              + " bits/instance\n");
2336          text.append("Class complexity | scheme          ");
2337          text.append(Utils.doubleToString(SFSchemeEntropy(), 12, 4)
2338              + " bits");
2339          text.append(Utils.doubleToString(SFMeanSchemeEntropy(), 12, 4)
2340              + " bits/instance\n");
2341          text.append("Complexity improvement     (Sf)    ");
2342          text.append(Utils.doubleToString(SFEntropyGain(), 12, 4) + " bits");
2343          text.append(Utils.doubleToString(SFMeanEntropyGain(), 12, 4)
2344              + " bits/instance\n");
2345        }
2346
2347        text.append("Mean absolute error                ");
2348        text.append(Utils.doubleToString(meanAbsoluteError(), 12, 4)
2349            + "\n");
2350        text.append("Root mean squared error            ");
2351        text.append(Utils.
2352            doubleToString(rootMeanSquaredError(), 12, 4)
2353            + "\n");
2354        if (!m_NoPriors) {
2355          text.append("Relative absolute error            ");
2356          text.append(Utils.doubleToString(relativeAbsoluteError(),
2357                12, 4) + " %\n");
2358          text.append("Root relative squared error        ");
2359          text.append(Utils.doubleToString(rootRelativeSquaredError(),
2360                12, 4) + " %\n");
2361        }
2362        if (m_CoverageStatisticsAvailable) {
2363          text.append("Coverage of cases (" + Utils.doubleToString(m_ConfLevel, 4, 2) + " level)     ");
2364          text.append(Utils.doubleToString(coverageOfTestCasesByPredictedRegions(),
2365                12, 4) + " %\n");
2366          if (!m_NoPriors) {
2367            text.append("Mean rel. region size (" + Utils.doubleToString(m_ConfLevel, 4, 2) + " level) ");
2368            text.append(Utils.doubleToString(sizeOfPredictedRegions(), 12, 4) + " %\n");
2369          }
2370        }
2371      }
2372      if (Utils.gr(unclassified(), 0)) {
2373        text.append("UnClassified Instances             ");
2374        text.append(Utils.doubleToString(unclassified(), 12,4) +  "     " +
2375            Utils.doubleToString(pctUnclassified(),
2376              12, 4) + " %\n");
2377      }
2378      text.append("Total Number of Instances          ");
2379      text.append(Utils.doubleToString(m_WithClass, 12, 4) + "\n");
2380      if (m_MissingClass > 0) {
2381        text.append("Ignored Class Unknown Instances            ");
2382        text.append(Utils.doubleToString(m_MissingClass, 12, 4) + "\n");
2383      }
2384    } catch (Exception ex) {
2385      // Should never occur since the class is known to be nominal
2386      // here
2387      System.err.println("Arggh - Must be a bug in Evaluation class");
2388    }
2389
2390    return text.toString();
2391  }
2392
2393  /**
2394   * Calls toMatrixString() with a default title.
2395   *
2396   * @return the confusion matrix as a string
2397   * @throws Exception if the class is numeric
2398   */
2399  public String toMatrixString() throws Exception {
2400
2401    return toMatrixString("=== Confusion Matrix ===\n");
2402  }
2403
2404  /**
2405   * Outputs the performance statistics as a classification confusion
2406   * matrix. For each class value, shows the distribution of
2407   * predicted class values.
2408   *
2409   * @param title the title for the confusion matrix
2410   * @return the confusion matrix as a String
2411   * @throws Exception if the class is numeric
2412   */
2413  public String toMatrixString(String title) throws Exception {
2414
2415    StringBuffer text = new StringBuffer();
2416    char [] IDChars = {'a','b','c','d','e','f','g','h','i','j',
2417      'k','l','m','n','o','p','q','r','s','t',
2418      'u','v','w','x','y','z'};
2419    int IDWidth;
2420    boolean fractional = false;
2421
2422    if (!m_ClassIsNominal) {
2423      throw new Exception("Evaluation: No confusion matrix possible!");
2424    }
2425
2426    // Find the maximum value in the matrix
2427    // and check for fractional display requirement
2428    double maxval = 0;
2429    for(int i = 0; i < m_NumClasses; i++) {
2430      for(int j = 0; j < m_NumClasses; j++) {
2431        double current = m_ConfusionMatrix[i][j];
2432        if (current < 0) {
2433          current *= -10;
2434        }
2435        if (current > maxval) {
2436          maxval = current;
2437        }
2438        double fract = current - Math.rint(current);
2439        if (!fractional && ((Math.log(fract) / Math.log(10)) >= -2)) {
2440          fractional = true;
2441        }
2442      }
2443    }
2444
2445    IDWidth = 1 + Math.max((int)(Math.log(maxval) / Math.log(10)
2446          + (fractional ? 3 : 0)),
2447        (int)(Math.log(m_NumClasses) /
2448          Math.log(IDChars.length)));
2449    text.append(title).append("\n");
2450    for(int i = 0; i < m_NumClasses; i++) {
2451      if (fractional) {
2452        text.append(" ").append(num2ShortID(i,IDChars,IDWidth - 3))
2453          .append("   ");
2454      } else {
2455        text.append(" ").append(num2ShortID(i,IDChars,IDWidth));
2456      }
2457    }
2458    text.append("   <-- classified as\n");
2459    for(int i = 0; i< m_NumClasses; i++) {
2460      for(int j = 0; j < m_NumClasses; j++) {
2461        text.append(" ").append(
2462            Utils.doubleToString(m_ConfusionMatrix[i][j],
2463              IDWidth,
2464              (fractional ? 2 : 0)));
2465      }
2466      text.append(" | ").append(num2ShortID(i,IDChars,IDWidth))
2467        .append(" = ").append(m_ClassNames[i]).append("\n");
2468    }
2469    return text.toString();
2470  }
2471
2472  /**
2473   * Generates a breakdown of the accuracy for each class (with default title),
2474   * incorporating various information-retrieval statistics, such as
2475   * true/false positive rate, precision/recall/F-Measure.  Should be
2476   * useful for ROC curves, recall/precision curves.
2477   *
2478   * @return the statistics presented as a string
2479   * @throws Exception if class is not nominal
2480   */
2481  public String toClassDetailsString() throws Exception {
2482
2483    return toClassDetailsString("=== Detailed Accuracy By Class ===\n");
2484  }
2485
2486  /**
2487   * Generates a breakdown of the accuracy for each class,
2488   * incorporating various information-retrieval statistics, such as
2489   * true/false positive rate, precision/recall/F-Measure.  Should be
2490   * useful for ROC curves, recall/precision curves.
2491   *
2492   * @param title the title to prepend the stats string with
2493   * @return the statistics presented as a string
2494   * @throws Exception if class is not nominal
2495   */
2496  public String toClassDetailsString(String title) throws Exception {
2497
2498    if (!m_ClassIsNominal) {
2499      throw new Exception("Evaluation: No per class statistics possible!");
2500    }
2501
2502    StringBuffer text = new StringBuffer(title
2503        + "\n               TP Rate   FP Rate"
2504        + "   Precision   Recall"
2505        + "  F-Measure   ROC Area  Class\n");
2506    for(int i = 0; i < m_NumClasses; i++) {
2507      text.append("               " + Utils.doubleToString(truePositiveRate(i), 7, 3))
2508        .append("   ");
2509      text.append(Utils.doubleToString(falsePositiveRate(i), 7, 3))
2510        .append("    ");
2511      text.append(Utils.doubleToString(precision(i), 7, 3))
2512        .append("   ");
2513      text.append(Utils.doubleToString(recall(i), 7, 3))
2514        .append("   ");
2515      text.append(Utils.doubleToString(fMeasure(i), 7, 3))
2516        .append("    ");
2517
2518      double rocVal = areaUnderROC(i);
2519      if (Utils.isMissingValue(rocVal)) {
2520        text.append("  ?    ")
2521          .append("    ");
2522      } else {
2523        text.append(Utils.doubleToString(rocVal, 7, 3))
2524          .append("    ");
2525      }
2526      text.append(m_ClassNames[i]).append('\n');
2527    }
2528
2529    text.append("Weighted Avg.  " + Utils.doubleToString(weightedTruePositiveRate(), 7, 3));
2530    text.append("   " + Utils.doubleToString(weightedFalsePositiveRate(), 7 ,3));
2531    text.append("    " + Utils.doubleToString(weightedPrecision(), 7 ,3));
2532    text.append("   " + Utils.doubleToString(weightedRecall(), 7 ,3));
2533    text.append("   " + Utils.doubleToString(weightedFMeasure(), 7 ,3));
2534    text.append("    " + Utils.doubleToString(weightedAreaUnderROC(), 7 ,3));
2535    text.append("\n");
2536
2537    return text.toString();
2538  }
2539
2540  /**
2541   * Calculate the number of true positives with respect to a particular class.
2542   * This is defined as<p/>
2543   * <pre>
2544   * correctly classified positives
2545   * </pre>
2546   *
2547   * @param classIndex the index of the class to consider as "positive"
2548   * @return the true positive rate
2549   */
2550  public double numTruePositives(int classIndex) {
2551
2552    double correct = 0;
2553    for (int j = 0; j < m_NumClasses; j++) {
2554      if (j == classIndex) {
2555        correct += m_ConfusionMatrix[classIndex][j];
2556      }
2557    }
2558    return correct;
2559  }
2560
2561  /**
2562   * Calculate the true positive rate with respect to a particular class.
2563   * This is defined as<p/>
2564   * <pre>
2565   * correctly classified positives
2566   * ------------------------------
2567   *       total positives
2568   * </pre>
2569   *
2570   * @param classIndex the index of the class to consider as "positive"
2571   * @return the true positive rate
2572   */
2573  public double truePositiveRate(int classIndex) {
2574
2575    double correct = 0, total = 0;
2576    for (int j = 0; j < m_NumClasses; j++) {
2577      if (j == classIndex) {
2578        correct += m_ConfusionMatrix[classIndex][j];
2579      }
2580      total += m_ConfusionMatrix[classIndex][j];
2581    }
2582    if (total == 0) {
2583      return 0;
2584    }
2585    return correct / total;
2586  }
2587
2588  /**
2589   * Calculates the weighted (by class size) true positive rate.
2590   *
2591   * @return the weighted true positive rate.
2592   */
2593  public double weightedTruePositiveRate() {
2594    double[] classCounts = new double[m_NumClasses];
2595    double classCountSum = 0;
2596
2597    for (int i = 0; i < m_NumClasses; i++) {
2598      for (int j = 0; j < m_NumClasses; j++) {
2599        classCounts[i] += m_ConfusionMatrix[i][j];
2600      }
2601      classCountSum += classCounts[i];
2602    }
2603
2604    double truePosTotal = 0;
2605    for(int i = 0; i < m_NumClasses; i++) {
2606      double temp = truePositiveRate(i);
2607      truePosTotal += (temp * classCounts[i]);
2608    }
2609
2610    return truePosTotal / classCountSum;
2611  }
2612
2613  /**
2614   * Calculate the number of true negatives with respect to a particular class.
2615   * This is defined as<p/>
2616   * <pre>
2617   * correctly classified negatives
2618   * </pre>
2619   *
2620   * @param classIndex the index of the class to consider as "positive"
2621   * @return the true positive rate
2622   */
2623  public double numTrueNegatives(int classIndex) {
2624
2625    double correct = 0;
2626    for (int i = 0; i < m_NumClasses; i++) {
2627      if (i != classIndex) {
2628        for (int j = 0; j < m_NumClasses; j++) {
2629          if (j != classIndex) {
2630            correct += m_ConfusionMatrix[i][j];
2631          }
2632        }
2633      }
2634    }
2635    return correct;
2636  }
2637
2638  /**
2639   * Calculate the true negative rate with respect to a particular class.
2640   * This is defined as<p/>
2641   * <pre>
2642   * correctly classified negatives
2643   * ------------------------------
2644   *       total negatives
2645   * </pre>
2646   *
2647   * @param classIndex the index of the class to consider as "positive"
2648   * @return the true positive rate
2649   */
2650  public double trueNegativeRate(int classIndex) {
2651
2652    double correct = 0, total = 0;
2653    for (int i = 0; i < m_NumClasses; i++) {
2654      if (i != classIndex) {
2655        for (int j = 0; j < m_NumClasses; j++) {
2656          if (j != classIndex) {
2657            correct += m_ConfusionMatrix[i][j];
2658          }
2659          total += m_ConfusionMatrix[i][j];
2660        }
2661      }
2662    }
2663    if (total == 0) {
2664      return 0;
2665    }
2666    return correct / total;
2667  }
2668
2669  /**
2670   * Calculates the weighted (by class size) true negative rate.
2671   *
2672   * @return the weighted true negative rate.
2673   */
2674  public double weightedTrueNegativeRate() {
2675    double[] classCounts = new double[m_NumClasses];
2676    double classCountSum = 0;
2677
2678    for (int i = 0; i < m_NumClasses; i++) {
2679      for (int j = 0; j < m_NumClasses; j++) {
2680        classCounts[i] += m_ConfusionMatrix[i][j];
2681      }
2682      classCountSum += classCounts[i];
2683    }
2684
2685    double trueNegTotal = 0;
2686    for(int i = 0; i < m_NumClasses; i++) {
2687      double temp = trueNegativeRate(i);
2688      trueNegTotal += (temp * classCounts[i]);
2689    }
2690
2691    return trueNegTotal / classCountSum;
2692  }
2693
2694  /**
2695   * Calculate number of false positives with respect to a particular class.
2696   * This is defined as<p/>
2697   * <pre>
2698   * incorrectly classified negatives
2699   * </pre>
2700   *
2701   * @param classIndex the index of the class to consider as "positive"
2702   * @return the false positive rate
2703   */
2704  public double numFalsePositives(int classIndex) {
2705
2706    double incorrect = 0;
2707    for (int i = 0; i < m_NumClasses; i++) {
2708      if (i != classIndex) {
2709        for (int j = 0; j < m_NumClasses; j++) {
2710          if (j == classIndex) {
2711            incorrect += m_ConfusionMatrix[i][j];
2712          }
2713        }
2714      }
2715    }
2716    return incorrect;
2717  }
2718
2719  /**
2720   * Calculate the false positive rate with respect to a particular class.
2721   * This is defined as<p/>
2722   * <pre>
2723   * incorrectly classified negatives
2724   * --------------------------------
2725   *        total negatives
2726   * </pre>
2727   *
2728   * @param classIndex the index of the class to consider as "positive"
2729   * @return the false positive rate
2730   */
2731  public double falsePositiveRate(int classIndex) {
2732
2733    double incorrect = 0, total = 0;
2734    for (int i = 0; i < m_NumClasses; i++) {
2735      if (i != classIndex) {
2736        for (int j = 0; j < m_NumClasses; j++) {
2737          if (j == classIndex) {
2738            incorrect += m_ConfusionMatrix[i][j];
2739          }
2740          total += m_ConfusionMatrix[i][j];
2741        }
2742      }
2743    }
2744    if (total == 0) {
2745      return 0;
2746    }
2747    return incorrect / total;
2748  }
2749
2750  /**
2751   * Calculates the weighted (by class size) false positive rate.
2752   *
2753   * @return the weighted false positive rate.
2754   */
2755  public double weightedFalsePositiveRate() {
2756    double[] classCounts = new double[m_NumClasses];
2757    double classCountSum = 0;
2758
2759    for (int i = 0; i < m_NumClasses; i++) {
2760      for (int j = 0; j < m_NumClasses; j++) {
2761        classCounts[i] += m_ConfusionMatrix[i][j];
2762      }
2763      classCountSum += classCounts[i];
2764    }
2765
2766    double falsePosTotal = 0;
2767    for(int i = 0; i < m_NumClasses; i++) {
2768      double temp = falsePositiveRate(i);
2769      falsePosTotal += (temp * classCounts[i]);
2770    }
2771
2772    return falsePosTotal / classCountSum;
2773  }
2774
2775
2776
2777  /**
2778   * Calculate number of false negatives with respect to a particular class.
2779   * This is defined as<p/>
2780   * <pre>
2781   * incorrectly classified positives
2782   * </pre>
2783   *
2784   * @param classIndex the index of the class to consider as "positive"
2785   * @return the false positive rate
2786   */
2787  public double numFalseNegatives(int classIndex) {
2788
2789    double incorrect = 0;
2790    for (int i = 0; i < m_NumClasses; i++) {
2791      if (i == classIndex) {
2792        for (int j = 0; j < m_NumClasses; j++) {
2793          if (j != classIndex) {
2794            incorrect += m_ConfusionMatrix[i][j];
2795          }
2796        }
2797      }
2798    }
2799    return incorrect;
2800  }
2801
2802  /**
2803   * Calculate the false negative rate with respect to a particular class.
2804   * This is defined as<p/>
2805   * <pre>
2806   * incorrectly classified positives
2807   * --------------------------------
2808   *        total positives
2809   * </pre>
2810   *
2811   * @param classIndex the index of the class to consider as "positive"
2812   * @return the false positive rate
2813   */
2814  public double falseNegativeRate(int classIndex) {
2815
2816    double incorrect = 0, total = 0;
2817    for (int i = 0; i < m_NumClasses; i++) {
2818      if (i == classIndex) {
2819        for (int j = 0; j < m_NumClasses; j++) {
2820          if (j != classIndex) {
2821            incorrect += m_ConfusionMatrix[i][j];
2822          }
2823          total += m_ConfusionMatrix[i][j];
2824        }
2825      }
2826    }
2827    if (total == 0) {
2828      return 0;
2829    }
2830    return incorrect / total;
2831  }
2832
2833  /**
2834   * Calculates the weighted (by class size) false negative rate.
2835   *
2836   * @return the weighted false negative rate.
2837   */
2838  public double weightedFalseNegativeRate() {
2839    double[] classCounts = new double[m_NumClasses];
2840    double classCountSum = 0;
2841
2842    for (int i = 0; i < m_NumClasses; i++) {
2843      for (int j = 0; j < m_NumClasses; j++) {
2844        classCounts[i] += m_ConfusionMatrix[i][j];
2845      }
2846      classCountSum += classCounts[i];
2847    }
2848
2849    double falseNegTotal = 0;
2850    for(int i = 0; i < m_NumClasses; i++) {
2851      double temp = falseNegativeRate(i);
2852      falseNegTotal += (temp * classCounts[i]);
2853    }
2854
2855    return falseNegTotal / classCountSum;
2856  }
2857
2858  /**
2859   * Calculate the recall with respect to a particular class.
2860   * This is defined as<p/>
2861   * <pre>
2862   * correctly classified positives
2863   * ------------------------------
2864   *       total positives
2865   * </pre><p/>
2866   * (Which is also the same as the truePositiveRate.)
2867   *
2868   * @param classIndex the index of the class to consider as "positive"
2869   * @return the recall
2870   */
2871  public double recall(int classIndex) {
2872
2873    return truePositiveRate(classIndex);
2874  }
2875
2876  /**
2877   * Calculates the weighted (by class size) recall.
2878   *
2879   * @return the weighted recall.
2880   */
2881  public double weightedRecall() {
2882    return weightedTruePositiveRate();
2883  }
2884
2885  /**
2886   * Calculate the precision with respect to a particular class.
2887   * This is defined as<p/>
2888   * <pre>
2889   * correctly classified positives
2890   * ------------------------------
2891   *  total predicted as positive
2892   * </pre>
2893   *
2894   * @param classIndex the index of the class to consider as "positive"
2895   * @return the precision
2896   */
2897  public double precision(int classIndex) {
2898
2899    double correct = 0, total = 0;
2900    for (int i = 0; i < m_NumClasses; i++) {
2901      if (i == classIndex) {
2902        correct += m_ConfusionMatrix[i][classIndex];
2903      }
2904      total += m_ConfusionMatrix[i][classIndex];
2905    }
2906    if (total == 0) {
2907      return 0;
2908    }
2909    return correct / total;
2910  }
2911
2912  /**
2913   * Calculates the weighted (by class size) false precision.
2914   *
2915   * @return the weighted precision.
2916   */
2917  public double weightedPrecision() {
2918    double[] classCounts = new double[m_NumClasses];
2919    double classCountSum = 0;
2920
2921    for (int i = 0; i < m_NumClasses; i++) {
2922      for (int j = 0; j < m_NumClasses; j++) {
2923        classCounts[i] += m_ConfusionMatrix[i][j];
2924      }
2925      classCountSum += classCounts[i];
2926    }
2927
2928    double precisionTotal = 0;
2929    for(int i = 0; i < m_NumClasses; i++) {
2930      double temp = precision(i);
2931      precisionTotal += (temp * classCounts[i]);
2932    }
2933
2934    return precisionTotal / classCountSum;
2935  }
2936
2937  /**
2938   * Calculate the F-Measure with respect to a particular class.
2939   * This is defined as<p/>
2940   * <pre>
2941   * 2 * recall * precision
2942   * ----------------------
2943   *   recall + precision
2944   * </pre>
2945   *
2946   * @param classIndex the index of the class to consider as "positive"
2947   * @return the F-Measure
2948   */
2949  public double fMeasure(int classIndex) {
2950
2951    double precision = precision(classIndex);
2952    double recall = recall(classIndex);
2953    if ((precision + recall) == 0) {
2954      return 0;
2955    }
2956    return 2 * precision * recall / (precision + recall);
2957  }
2958
2959  /**
2960   * Calculates the macro weighted (by class size) average
2961   * F-Measure.
2962   *
2963   * @return the weighted F-Measure.
2964   */
2965  public double weightedFMeasure() {
2966    double[] classCounts = new double[m_NumClasses];
2967    double classCountSum = 0;
2968
2969    for (int i = 0; i < m_NumClasses; i++) {
2970      for (int j = 0; j < m_NumClasses; j++) {
2971        classCounts[i] += m_ConfusionMatrix[i][j];
2972      }
2973      classCountSum += classCounts[i];
2974    }
2975
2976    double fMeasureTotal = 0;
2977    for(int i = 0; i < m_NumClasses; i++) {
2978      double temp = fMeasure(i);
2979      fMeasureTotal += (temp * classCounts[i]);
2980    }
2981
2982    return fMeasureTotal / classCountSum;
2983  }
2984
2985  /**
2986   * Unweighted macro-averaged F-measure. If some classes not present in the
2987   * test set, they're just skipped (since recall is undefined there anyway) .
2988   *
2989   * @return unweighted macro-averaged F-measure.
2990   * */
2991  public double unweightedMacroFmeasure() {
2992    weka.experiment.Stats rr = new weka.experiment.Stats();
2993    for (int c = 0; c < m_NumClasses; c++) {
2994      // skip if no testing positive cases of this class
2995      if (numTruePositives(c)+numFalseNegatives(c) > 0) {
2996        rr.add(fMeasure(c));
2997      }
2998    }
2999    rr.calculateDerived();
3000    return rr.mean;
3001  }
3002
3003  /**
3004   * Unweighted micro-averaged F-measure. If some classes not present in the
3005   * test set, they have no effect.
3006   *
3007   * Note: if the test set is *single-label*, then this is the same as accuracy.
3008   *
3009   * @return unweighted micro-averaged F-measure.
3010   */
3011  public double unweightedMicroFmeasure() {
3012    double tp = 0;
3013    double fn = 0;
3014    double fp = 0;
3015    for (int c = 0; c < m_NumClasses; c++) {
3016      tp += numTruePositives(c);
3017      fn += numFalseNegatives(c);
3018      fp += numFalsePositives(c);
3019    }
3020    return 2*tp / (2*tp + fn + fp);
3021  }
3022
3023  /**
3024   * Sets the class prior probabilities.
3025   *
3026   * @param train the training instances used to determine the prior probabilities
3027   * @throws Exception if the class attribute of the instances is not set
3028   */
3029  public void setPriors(Instances train) throws Exception {
3030
3031    m_NoPriors = false;
3032
3033    if (!m_ClassIsNominal) {
3034
3035      m_NumTrainClassVals = 0;
3036      m_TrainClassVals = null;
3037      m_TrainClassWeights = null;
3038      m_PriorEstimator = null;
3039
3040      m_MinTarget = Double.MAX_VALUE;
3041      m_MaxTarget = -Double.MAX_VALUE;
3042
3043      for (int i = 0; i < train.numInstances(); i++) {
3044        Instance currentInst = train.instance(i);
3045        if (!currentInst.classIsMissing()) {
3046          addNumericTrainClass(currentInst.classValue(), currentInst.weight());
3047        }
3048      }
3049
3050      m_ClassPriors[0] = m_ClassPriorsSum = 0;
3051      for (int i = 0; i < train.numInstances(); i++) {
3052        if (!train.instance(i).classIsMissing()) {
3053          m_ClassPriors[0] += train.instance(i).classValue() * train.instance(i).weight();
3054          m_ClassPriorsSum += train.instance(i).weight();
3055        }
3056      }
3057
3058    } else {
3059      for (int i = 0; i < m_NumClasses; i++) {
3060        m_ClassPriors[i] = 1;
3061      }
3062      m_ClassPriorsSum = m_NumClasses;
3063      for (int i = 0; i < train.numInstances(); i++) {
3064        if (!train.instance(i).classIsMissing()) {
3065          m_ClassPriors[(int)train.instance(i).classValue()] +=
3066            train.instance(i).weight();
3067          m_ClassPriorsSum += train.instance(i).weight();
3068        }
3069      }
3070      m_MaxTarget = m_NumClasses;
3071      m_MinTarget = 0;
3072    }
3073  }
3074
3075  /**
3076   * Get the current weighted class counts.
3077   *
3078   * @return the weighted class counts
3079   */
3080  public double [] getClassPriors() {
3081    return m_ClassPriors;
3082  }
3083
3084  /**
3085   * Updates the class prior probabilities or the mean respectively (when incrementally
3086   * training).
3087   *
3088   * @param instance the new training instance seen
3089   * @throws Exception if the class of the instance is not set
3090   */
3091  public void updatePriors(Instance instance) throws Exception {
3092    if (!instance.classIsMissing()) {
3093      if (!m_ClassIsNominal) {
3094        addNumericTrainClass(instance.classValue(), instance.weight());
3095        m_ClassPriors[0] += instance.classValue() * instance.weight();
3096        m_ClassPriorsSum += instance.weight();
3097      } else {
3098        m_ClassPriors[(int)instance.classValue()] += instance.weight();
3099        m_ClassPriorsSum += instance.weight();
3100      }
3101    }
3102  }
3103
3104  /**
3105   * disables the use of priors, e.g., in case of de-serialized schemes
3106   * that have no access to the original training set, but are evaluated
3107   * on a set set.
3108   */
3109  public void useNoPriors() {
3110    m_NoPriors = true;
3111  }
3112
3113  /**
3114   * Tests whether the current evaluation object is equal to another
3115   * evaluation object.
3116   *
3117   * @param obj the object to compare against
3118   * @return true if the two objects are equal
3119   */
3120  public boolean equals(Object obj) {
3121
3122    if ((obj == null) || !(obj.getClass().equals(this.getClass()))) {
3123      return false;
3124    }
3125    Evaluation cmp = (Evaluation) obj;
3126    if (m_ClassIsNominal != cmp.m_ClassIsNominal) return false;
3127    if (m_NumClasses != cmp.m_NumClasses) return false;
3128
3129    if (m_Incorrect != cmp.m_Incorrect) return false;
3130    if (m_Correct != cmp.m_Correct) return false;
3131    if (m_Unclassified != cmp.m_Unclassified) return false;
3132    if (m_MissingClass != cmp.m_MissingClass) return false;
3133    if (m_WithClass != cmp.m_WithClass) return false;
3134
3135    if (m_SumErr != cmp.m_SumErr) return false;
3136    if (m_SumAbsErr != cmp.m_SumAbsErr) return false;
3137    if (m_SumSqrErr != cmp.m_SumSqrErr) return false;
3138    if (m_SumClass != cmp.m_SumClass) return false;
3139    if (m_SumSqrClass != cmp.m_SumSqrClass) return false;
3140    if (m_SumPredicted != cmp.m_SumPredicted) return false;
3141    if (m_SumSqrPredicted != cmp.m_SumSqrPredicted) return false;
3142    if (m_SumClassPredicted != cmp.m_SumClassPredicted) return false;
3143
3144    if (m_ClassIsNominal) {
3145      for (int i = 0; i < m_NumClasses; i++) {
3146        for (int j = 0; j < m_NumClasses; j++) {
3147          if (m_ConfusionMatrix[i][j] != cmp.m_ConfusionMatrix[i][j]) {
3148            return false;
3149          }
3150        }
3151      }
3152    }
3153
3154    return true;
3155  }
3156
3157  /**
3158   * Make up the help string giving all the command line options.
3159   *
3160   * @param classifier the classifier to include options for
3161   * @param globalInfo include the global information string
3162   * for the classifier (if available).
3163   * @return a string detailing the valid command line options
3164   */
3165  protected static String makeOptionString(Classifier classifier,
3166                                           boolean globalInfo) {
3167
3168    StringBuffer optionsText = new StringBuffer("");
3169
3170    // General options
3171    optionsText.append("\n\nGeneral options:\n\n");
3172    optionsText.append("-h or -help\n");
3173    optionsText.append("\tOutput help information.\n");
3174    optionsText.append("-synopsis or -info\n");
3175    optionsText.append("\tOutput synopsis for classifier (use in conjunction "
3176        + " with -h)\n");
3177    optionsText.append("-t <name of training file>\n");
3178    optionsText.append("\tSets training file.\n");
3179    optionsText.append("-T <name of test file>\n");
3180    optionsText.append("\tSets test file. If missing, a cross-validation will be performed\n");
3181    optionsText.append("\ton the training data.\n");
3182    optionsText.append("-c <class index>\n");
3183    optionsText.append("\tSets index of class attribute (default: last).\n");
3184    optionsText.append("-x <number of folds>\n");
3185    optionsText.append("\tSets number of folds for cross-validation (default: 10).\n");
3186    optionsText.append("-no-cv\n");
3187    optionsText.append("\tDo not perform any cross validation.\n");
3188    optionsText.append("-split-percentage <percentage>\n");
3189    optionsText.append("\tSets the percentage for the train/test set split, e.g., 66.\n");
3190    optionsText.append("-preserve-order\n");
3191    optionsText.append("\tPreserves the order in the percentage split.\n");
3192    optionsText.append("-s <random number seed>\n");
3193    optionsText.append("\tSets random number seed for cross-validation or percentage split\n");
3194    optionsText.append("\t(default: 1).\n");
3195    optionsText.append("-m <name of file with cost matrix>\n");
3196    optionsText.append("\tSets file with cost matrix.\n");
3197    optionsText.append("-l <name of input file>\n");
3198    optionsText.append("\tSets model input file. In case the filename ends with '.xml',\n");
3199    optionsText.append("\ta PMML file is loaded or, if that fails, options are loaded\n");
3200    optionsText.append("\tfrom the XML file.\n");
3201    optionsText.append("-d <name of output file>\n");
3202    optionsText.append("\tSets model output file. In case the filename ends with '.xml',\n");
3203    optionsText.append("\tonly the options are saved to the XML file, not the model.\n");
3204    optionsText.append("-v\n");
3205    optionsText.append("\tOutputs no statistics for training data.\n");
3206    optionsText.append("-o\n");
3207    optionsText.append("\tOutputs statistics only, not the classifier.\n");
3208    optionsText.append("-i\n");
3209    optionsText.append("\tOutputs detailed information-retrieval");
3210    optionsText.append(" statistics for each class.\n");
3211    optionsText.append("-k\n");
3212    optionsText.append("\tOutputs information-theoretic statistics.\n");
3213    optionsText.append("-classifications \"weka.classifiers.evaluation.output.prediction.AbstractOutput + options\"\n");
3214    optionsText.append("\tUses the specified class for generating the classification output.\n");
3215    optionsText.append("\tE.g.: " + PlainText.class.getName() + "\n");
3216    optionsText.append("-p range\n");
3217    optionsText.append("\tOutputs predictions for test instances (or the train instances if\n");
3218    optionsText.append("\tno test instances provided and -no-cv is used), along with the \n");
3219    optionsText.append("\tattributes in the specified range (and nothing else). \n");
3220    optionsText.append("\tUse '-p 0' if no attributes are desired.\n");
3221    optionsText.append("\tDeprecated: use \"-classifications ...\" instead.\n");
3222    optionsText.append("-distribution\n");
3223    optionsText.append("\tOutputs the distribution instead of only the prediction\n");
3224    optionsText.append("\tin conjunction with the '-p' option (only nominal classes).\n");
3225    optionsText.append("\tDeprecated: use \"-classifications ...\" instead.\n");
3226    optionsText.append("-r\n");
3227    optionsText.append("\tOnly outputs cumulative margin distribution.\n");
3228    if (classifier instanceof Sourcable) {
3229      optionsText.append("-z <class name>\n");
3230      optionsText.append("\tOnly outputs the source representation"
3231          + " of the classifier,\n\tgiving it the supplied"
3232          + " name.\n");
3233    }
3234    if (classifier instanceof Drawable) {
3235      optionsText.append("-g\n");
3236      optionsText.append("\tOnly outputs the graph representation"
3237          + " of the classifier.\n");
3238    }
3239    optionsText.append("-xml filename | xml-string\n");
3240    optionsText.append("\tRetrieves the options from the XML-data instead of the "
3241        + "command line.\n");
3242    optionsText.append("-threshold-file <file>\n");
3243    optionsText.append("\tThe file to save the threshold data to.\n"
3244        + "\tThe format is determined by the extensions, e.g., '.arff' for ARFF \n"
3245        + "\tformat or '.csv' for CSV.\n");
3246    optionsText.append("-threshold-label <label>\n");
3247    optionsText.append("\tThe class label to determine the threshold data for\n"
3248        + "\t(default is the first label)\n");
3249
3250    // Get scheme-specific options
3251    if (classifier instanceof OptionHandler) {
3252      optionsText.append("\nOptions specific to "
3253          + classifier.getClass().getName()
3254          + ":\n\n");
3255      Enumeration enu = ((OptionHandler)classifier).listOptions();
3256      while (enu.hasMoreElements()) {
3257        Option option = (Option) enu.nextElement();
3258        optionsText.append(option.synopsis() + '\n');
3259        optionsText.append(option.description() + "\n");
3260      }
3261    }
3262
3263    // Get global information (if available)
3264    if (globalInfo) {
3265      try {
3266        String gi = getGlobalInfo(classifier);
3267        optionsText.append(gi);
3268      } catch (Exception ex) {
3269        // quietly ignore
3270      }
3271    }
3272    return optionsText.toString();
3273  }
3274
3275  /**
3276   * Return the global info (if it exists) for the supplied classifier.
3277   *
3278   * @param classifier the classifier to get the global info for
3279   * @return the global info (synopsis) for the classifier
3280   * @throws Exception if there is a problem reflecting on the classifier
3281   */
3282  protected static String getGlobalInfo(Classifier classifier) throws Exception {
3283    BeanInfo bi = Introspector.getBeanInfo(classifier.getClass());
3284    MethodDescriptor[] methods;
3285    methods = bi.getMethodDescriptors();
3286    Object[] args = {};
3287    String result = "\nSynopsis for " + classifier.getClass().getName()
3288      + ":\n\n";
3289
3290    for (int i = 0; i < methods.length; i++) {
3291      String name = methods[i].getDisplayName();
3292      Method meth = methods[i].getMethod();
3293      if (name.equals("globalInfo")) {
3294        String globalInfo = (String)(meth.invoke(classifier, args));
3295        result += globalInfo;
3296        break;
3297      }
3298    }
3299
3300    return result;
3301  }
3302
3303  /**
3304   * Method for generating indices for the confusion matrix.
3305   *
3306   * @param num         integer to format
3307   * @param IDChars     the characters to use
3308   * @param IDWidth     the width of the entry
3309   * @return            the formatted integer as a string
3310   */
3311  protected String num2ShortID(int num, char[] IDChars, int IDWidth) {
3312
3313    char ID [] = new char [IDWidth];
3314    int i;
3315
3316    for(i = IDWidth - 1; i >=0; i--) {
3317      ID[i] = IDChars[num % IDChars.length];
3318      num = num / IDChars.length - 1;
3319      if (num < 0) {
3320        break;
3321      }
3322    }
3323    for(i--; i >= 0; i--) {
3324      ID[i] = ' ';
3325    }
3326
3327    return new String(ID);
3328  }
3329
3330  /**
3331   * Convert a single prediction into a probability distribution
3332   * with all zero probabilities except the predicted value which
3333   * has probability 1.0.
3334   *
3335   * @param predictedClass the index of the predicted class
3336   * @return the probability distribution
3337   */
3338  protected double [] makeDistribution(double predictedClass) {
3339
3340    double [] result = new double [m_NumClasses];
3341    if (Utils.isMissingValue(predictedClass)) {
3342      return result;
3343    }
3344    if (m_ClassIsNominal) {
3345      result[(int)predictedClass] = 1.0;
3346    } else {
3347      result[0] = predictedClass;
3348    }
3349    return result;
3350  }
3351
3352  /**
3353   * Updates all the statistics about a classifiers performance for
3354   * the current test instance.
3355   *
3356   * @param predictedDistribution the probabilities assigned to
3357   * each class
3358   * @param instance the instance to be classified
3359   * @throws Exception if the class of the instance is not
3360   * set
3361   */
3362  protected void updateStatsForClassifier(double [] predictedDistribution,
3363      Instance instance)
3364  throws Exception {
3365
3366    int actualClass = (int)instance.classValue();
3367
3368    if (!instance.classIsMissing()) {
3369      updateMargins(predictedDistribution, actualClass, instance.weight());
3370
3371      // Determine the predicted class (doesn't detect multiple
3372      // classifications)
3373      int predictedClass = -1;
3374      double bestProb = 0.0;
3375      for(int i = 0; i < m_NumClasses; i++) {
3376        if (predictedDistribution[i] > bestProb) {
3377          predictedClass = i;
3378          bestProb = predictedDistribution[i];
3379        }
3380      }
3381
3382      m_WithClass += instance.weight();
3383
3384      // Determine misclassification cost
3385      if (m_CostMatrix != null) {
3386        if (predictedClass < 0) {
3387          // For missing predictions, we assume the worst possible cost.
3388          // This is pretty harsh.
3389          // Perhaps we could take the negative of the cost of a correct
3390          // prediction (-m_CostMatrix.getElement(actualClass,actualClass)),
3391          // although often this will be zero
3392          m_TotalCost += instance.weight() * m_CostMatrix.getMaxCost(actualClass, instance);
3393        } else {
3394          m_TotalCost += instance.weight() * m_CostMatrix.getElement(actualClass, predictedClass,
3395              instance);
3396        }
3397      }
3398
3399      // Update counts when no class was predicted
3400      if (predictedClass < 0) {
3401        m_Unclassified += instance.weight();
3402        return;
3403      }
3404
3405      double predictedProb = Math.max(MIN_SF_PROB, predictedDistribution[actualClass]);
3406      double priorProb = Math.max(MIN_SF_PROB, m_ClassPriors[actualClass] / m_ClassPriorsSum);
3407      if (predictedProb >= priorProb) {
3408        m_SumKBInfo += (Utils.log2(predictedProb) - Utils.log2(priorProb)) * instance.weight();
3409      } else {
3410        m_SumKBInfo -= (Utils.log2(1.0-predictedProb) - Utils.log2(1.0-priorProb))
3411          * instance.weight();
3412      }
3413
3414      m_SumSchemeEntropy -= Utils.log2(predictedProb) * instance.weight();
3415      m_SumPriorEntropy -= Utils.log2(priorProb) * instance.weight();
3416
3417      updateNumericScores(predictedDistribution,
3418          makeDistribution(instance.classValue()),
3419          instance.weight());
3420
3421      // Update coverage stats
3422      int[] indices = Utils.sort(predictedDistribution);
3423      double sum = 0, sizeOfRegions = 0;
3424      for (int i = predictedDistribution.length - 1; i >= 0; i--) {
3425        if (sum >= m_ConfLevel) {
3426          break;
3427        }
3428        sum += predictedDistribution[indices[i]];
3429        sizeOfRegions++;
3430        if (actualClass == indices[i]) {
3431          m_TotalCoverage += instance.weight();
3432        }
3433      }
3434      m_TotalSizeOfRegions += sizeOfRegions / (m_MaxTarget - m_MinTarget);
3435
3436      // Update other stats
3437      m_ConfusionMatrix[actualClass][predictedClass] += instance.weight();
3438      if (predictedClass != actualClass) {
3439        m_Incorrect += instance.weight();
3440      } else {
3441        m_Correct += instance.weight();
3442      }
3443    } else {
3444      m_MissingClass += instance.weight();
3445    }
3446  }
3447
3448  /**
3449   * Updates stats for interval estimator based on current test instance.
3450   *
3451   * @param classifier the interval estimator
3452   * @param classMissing the instance for which the intervals are computed, without a class value
3453   * @param classValue the class value of this instance
3454   * @throws Exception if intervals could not be computed successfully
3455   */
3456  protected void updateStatsForIntervalEstimator(IntervalEstimator  classifier, Instance classMissing,
3457                                                 double classValue) throws Exception {
3458
3459    double[][] preds = classifier.predictIntervals(classMissing, m_ConfLevel);
3460    if (m_Predictions != null)
3461      ((NumericPrediction) m_Predictions.lastElement()).setPredictionIntervals(preds);
3462    for (int i = 0; i < preds.length; i++) {
3463      m_TotalSizeOfRegions += (preds[i][1] - preds[i][0]) / (m_MaxTarget - m_MinTarget);
3464    }
3465    for (int i = 0; i < preds.length; i++) {
3466      if ((preds[i][1] >= classValue) && (preds[i][0] <= classValue)) {
3467        m_TotalCoverage += classMissing.weight();
3468        break;
3469      }
3470    }
3471  }
3472
3473  /**
3474   * Updates stats for conditional density estimator based on current test instance.
3475   *
3476   * @param classifier the conditional density estimator
3477   * @param classMissing the instance for which density is to be computed, without a class value
3478   * @param classValue the class value of this instance
3479   * @throws Exception if density could not be computed successfully
3480   */
3481  protected void updateStatsForConditionalDensityEstimator(ConditionalDensityEstimator classifier,
3482                                                           Instance classMissing,
3483                                                           double classValue) throws Exception {
3484
3485    if (m_PriorEstimator == null) {
3486      setNumericPriorsFromBuffer();
3487    }
3488    m_SumSchemeEntropy -= classifier.logDensity(classMissing, classValue) * classMissing.weight() /
3489      Utils.log2;
3490    m_SumPriorEntropy -= m_PriorEstimator.logDensity(classValue) * classMissing.weight() /
3491      Utils.log2;
3492  }
3493
3494  /**
3495   * Updates all the statistics about a predictors performance for
3496   * the current test instance.
3497   *
3498   * @param predictedValue the numeric value the classifier predicts
3499   * @param instance the instance to be classified
3500   * @throws Exception if the class of the instance is not set
3501   */
3502  protected void updateStatsForPredictor(double predictedValue, Instance instance)
3503    throws Exception {
3504
3505    if (!instance.classIsMissing()){
3506
3507      // Update stats
3508      m_WithClass += instance.weight();
3509      if (Utils.isMissingValue(predictedValue)) {
3510        m_Unclassified += instance.weight();
3511        return;
3512      }
3513      m_SumClass += instance.weight() * instance.classValue();
3514      m_SumSqrClass += instance.weight() * instance.classValue() * instance.classValue();
3515      m_SumClassPredicted += instance.weight() * instance.classValue() * predictedValue;
3516      m_SumPredicted += instance.weight() * predictedValue;
3517      m_SumSqrPredicted += instance.weight() * predictedValue * predictedValue;
3518
3519      updateNumericScores(makeDistribution(predictedValue),
3520          makeDistribution(instance.classValue()),
3521          instance.weight());
3522
3523    } else
3524      m_MissingClass += instance.weight();
3525  }
3526
3527  /**
3528   * Update the cumulative record of classification margins.
3529   *
3530   * @param predictedDistribution the probability distribution predicted for
3531   * the current instance
3532   * @param actualClass the index of the actual instance class
3533   * @param weight the weight assigned to the instance
3534   */
3535  protected void updateMargins(double [] predictedDistribution,
3536      int actualClass, double weight) {
3537
3538    double probActual = predictedDistribution[actualClass];
3539    double probNext = 0;
3540
3541    for(int i = 0; i < m_NumClasses; i++)
3542      if ((i != actualClass) &&
3543          (predictedDistribution[i] > probNext))
3544        probNext = predictedDistribution[i];
3545
3546    double margin = probActual - probNext;
3547    int bin = (int)((margin + 1.0) / 2.0 * k_MarginResolution);
3548    m_MarginCounts[bin] += weight;
3549  }
3550
3551  /**
3552   * Update the numeric accuracy measures. For numeric classes, the
3553   * accuracy is between the actual and predicted class values. For
3554   * nominal classes, the accuracy is between the actual and
3555   * predicted class probabilities.
3556   *
3557   * @param predicted the predicted values
3558   * @param actual the actual value
3559   * @param weight the weight associated with this prediction
3560   */
3561  protected void updateNumericScores(double [] predicted,
3562      double [] actual, double weight) {
3563
3564    double diff;
3565    double sumErr = 0, sumAbsErr = 0, sumSqrErr = 0;
3566    double sumPriorAbsErr = 0, sumPriorSqrErr = 0;
3567    for(int i = 0; i < m_NumClasses; i++) {
3568      diff = predicted[i] - actual[i];
3569      sumErr += diff;
3570      sumAbsErr += Math.abs(diff);
3571      sumSqrErr += diff * diff;
3572      diff = (m_ClassPriors[i] / m_ClassPriorsSum) - actual[i];
3573      sumPriorAbsErr += Math.abs(diff);
3574      sumPriorSqrErr += diff * diff;
3575    }
3576    m_SumErr += weight * sumErr / m_NumClasses;
3577    m_SumAbsErr += weight * sumAbsErr / m_NumClasses;
3578    m_SumSqrErr += weight * sumSqrErr / m_NumClasses;
3579    m_SumPriorAbsErr += weight * sumPriorAbsErr / m_NumClasses;
3580    m_SumPriorSqrErr += weight * sumPriorSqrErr / m_NumClasses;
3581  }
3582
3583  /**
3584   * Adds a numeric (non-missing) training class value and weight to
3585   * the buffer of stored values. Also updates minimum and maximum target value.
3586   *
3587   * @param classValue the class value
3588   * @param weight the instance weight
3589   */
3590  protected void addNumericTrainClass(double classValue, double weight) {
3591
3592    // Update minimum and maximum target value
3593    if (classValue > m_MaxTarget) {
3594      m_MaxTarget = classValue;
3595    }
3596    if (classValue < m_MinTarget) {
3597      m_MinTarget = classValue;
3598    }
3599
3600    // Update buffer
3601    if (m_TrainClassVals == null) {
3602      m_TrainClassVals = new double [100];
3603      m_TrainClassWeights = new double [100];
3604    }
3605    if (m_NumTrainClassVals == m_TrainClassVals.length) {
3606      double [] temp = new double [m_TrainClassVals.length * 2];
3607      System.arraycopy(m_TrainClassVals, 0,
3608          temp, 0, m_TrainClassVals.length);
3609      m_TrainClassVals = temp;
3610
3611      temp = new double [m_TrainClassWeights.length * 2];
3612      System.arraycopy(m_TrainClassWeights, 0,
3613          temp, 0, m_TrainClassWeights.length);
3614      m_TrainClassWeights = temp;
3615    }
3616    m_TrainClassVals[m_NumTrainClassVals] = classValue;
3617    m_TrainClassWeights[m_NumTrainClassVals] = weight;
3618    m_NumTrainClassVals++;
3619  }
3620
3621  /**
3622   * Sets up the priors for numeric class attributes from the
3623   * training class values that have been seen so far.
3624   */
3625  protected void setNumericPriorsFromBuffer() {
3626
3627    m_PriorEstimator = new UnivariateKernelEstimator();
3628    for (int i = 0; i < m_NumTrainClassVals; i++) {
3629      m_PriorEstimator.addValue(m_TrainClassVals[i], m_TrainClassWeights[i]);
3630    }
3631  }
3632
3633  /**
3634   * Returns the revision string.
3635   *
3636   * @return            the revision
3637   */
3638  public String getRevision() {
3639    return RevisionUtils.extract("$Revision: 6041 $");
3640  }
3641}
Note: See TracBrowser for help on using the repository browser.