source: src/main/java/weka/classifiers/evaluation/EvaluationUtils.java @ 21

Last change on this file since 21 was 4, checked in by gnappo, 14 years ago

Import di weka.

File size: 4.9 KB
Line 
1/*
2 *    This program is free software; you can redistribute it and/or modify
3 *    it under the terms of the GNU General Public License as published by
4 *    the Free Software Foundation; either version 2 of the License, or
5 *    (at your option) any later version.
6 *
7 *    This program is distributed in the hope that it will be useful,
8 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
9 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
10 *    GNU General Public License for more details.
11 *
12 *    You should have received a copy of the GNU General Public License
13 *    along with this program; if not, write to the Free Software
14 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
15 */
16
17/*
18 *    EvaluationUtils.java
19 *    Copyright (C) 2002 University of Waikato, Hamilton, New Zealand
20 *
21 */
22
23package weka.classifiers.evaluation;
24
25import weka.classifiers.Classifier;
26import weka.classifiers.AbstractClassifier;
27import weka.core.FastVector;
28import weka.core.Instance;
29import weka.core.Instances;
30import weka.core.RevisionHandler;
31import weka.core.RevisionUtils;
32
33import java.util.Random;
34
35/**
36 * Contains utility functions for generating lists of predictions in
37 * various manners.
38 *
39 * @author Len Trigg (len@reeltwo.com)
40 * @version $Revision: 5928 $
41 */
42public class EvaluationUtils
43  implements RevisionHandler {
44
45  /** Seed used to randomize data in cross-validation */
46  private int m_Seed = 1;
47
48  /** Sets the seed for randomization during cross-validation */
49  public void setSeed(int seed) { m_Seed = seed; }
50
51  /** Gets the seed for randomization during cross-validation */
52  public int getSeed() { return m_Seed; }
53 
54  /**
55   * Generate a bunch of predictions ready for processing, by performing a
56   * cross-validation on the supplied dataset.
57   *
58   * @param classifier the Classifier to evaluate
59   * @param data the dataset
60   * @param numFolds the number of folds in the cross-validation.
61   * @exception Exception if an error occurs
62   */
63  public FastVector getCVPredictions(Classifier classifier, 
64                                     Instances data, 
65                                     int numFolds) 
66    throws Exception {
67
68    FastVector predictions = new FastVector();
69    Instances runInstances = new Instances(data);
70    Random random = new Random(m_Seed);
71    runInstances.randomize(random);
72    if (runInstances.classAttribute().isNominal() && (numFolds > 1)) {
73      runInstances.stratify(numFolds);
74    }
75    int inst = 0;
76    for (int fold = 0; fold < numFolds; fold++) {
77      Instances train = runInstances.trainCV(numFolds, fold, random);
78      Instances test = runInstances.testCV(numFolds, fold);
79      FastVector foldPred = getTrainTestPredictions(classifier, train, test);
80      predictions.appendElements(foldPred);
81    } 
82    return predictions;
83  }
84
85  /**
86   * Generate a bunch of predictions ready for processing, by performing a
87   * evaluation on a test set after training on the given training set.
88   *
89   * @param classifier the Classifier to evaluate
90   * @param train the training dataset
91   * @param test the test dataset
92   * @exception Exception if an error occurs
93   */
94  public FastVector getTrainTestPredictions(Classifier classifier, 
95                                            Instances train, Instances test) 
96    throws Exception {
97   
98    classifier.buildClassifier(train);
99    return getTestPredictions(classifier, test);
100  }
101
102  /**
103   * Generate a bunch of predictions ready for processing, by performing a
104   * evaluation on a test set assuming the classifier is already trained.
105   *
106   * @param classifier the pre-trained Classifier to evaluate
107   * @param test the test dataset
108   * @exception Exception if an error occurs
109   */
110  public FastVector getTestPredictions(Classifier classifier, 
111                                       Instances test) 
112    throws Exception {
113   
114    FastVector predictions = new FastVector();
115    for (int i = 0; i < test.numInstances(); i++) {
116      if (!test.instance(i).classIsMissing()) {
117        predictions.addElement(getPrediction(classifier, test.instance(i)));
118      }
119    }
120    return predictions;
121  }
122
123 
124  /**
125   * Generate a single prediction for a test instance given the pre-trained
126   * classifier.
127   *
128   * @param classifier the pre-trained Classifier to evaluate
129   * @param test the test instance
130   * @exception Exception if an error occurs
131   */
132  public Prediction getPrediction(Classifier classifier,
133                                  Instance test)
134    throws Exception {
135   
136    double actual = test.classValue();
137    double [] dist = classifier.distributionForInstance(test);
138    if (test.classAttribute().isNominal()) {
139      return new NominalPrediction(actual, dist, test.weight());
140    } else {
141      return new NumericPrediction(actual, dist[0], test.weight());
142    }
143  }
144 
145  /**
146   * Returns the revision string.
147   *
148   * @return            the revision
149   */
150  public String getRevision() {
151    return RevisionUtils.extract("$Revision: 5928 $");
152  }
153}
154
Note: See TracBrowser for help on using the repository browser.