source: src/main/java/weka/classifiers/evaluation/MarginCurve.java @ 24

Last change on this file since 24 was 4, checked in by gnappo, 14 years ago

Import di weka.

File size: 5.5 KB
Line 
1/*
2 *    This program is free software; you can redistribute it and/or modify
3 *    it under the terms of the GNU General Public License as published by
4 *    the Free Software Foundation; either version 2 of the License, or
5 *    (at your option) any later version.
6 *
7 *    This program is distributed in the hope that it will be useful,
8 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
9 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
10 *    GNU General Public License for more details.
11 *
12 *    You should have received a copy of the GNU General Public License
13 *    along with this program; if not, write to the Free Software
14 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
15 */
16
17/*
18 *    MarginCurve.java
19 *    Copyright (C) 2002 University of Waikato, Hamilton, New Zealand
20 *
21 */
22
23package weka.classifiers.evaluation;
24
25import weka.core.Attribute;
26import weka.core.FastVector;
27import weka.core.Instance;
28import weka.core.DenseInstance;
29import weka.core.Instances;
30import weka.core.RevisionHandler;
31import weka.core.RevisionUtils;
32import weka.core.Utils;
33
34/**
35 * Generates points illustrating the prediction margin. The margin is defined
36 * as the difference between the probability predicted for the actual class and
37 * the highest probability predicted for the other classes. One hypothesis
38 * as to the good performance of boosting algorithms is that they increaes the
39 * margins on the training data and this gives better performance on test data.
40 *
41 * @author Len Trigg (len@reeltwo.com)
42 * @version $Revision: 5987 $
43 */
44public class MarginCurve
45  implements RevisionHandler {
46
47  /**
48   * Calculates the cumulative margin distribution for the set of
49   * predictions, returning the result as a set of Instances. The
50   * structure of these Instances is as follows:<p> <ul>
51   * <li> <b>Margin</b> contains the margin value (which should be plotted
52   * as an x-coordinate)
53   * <li> <b>Current</b> contains the count of instances with the current
54   * margin (plot as y axis)
55   * <li> <b>Cumulative</b> contains the count of instances with margin
56   * less than or equal to the current margin (plot as y axis)
57   * </ul> <p>
58   *
59   * @return datapoints as a set of instances, null if no predictions
60   * have been made. 
61   */
62  public Instances getCurve(FastVector predictions) {
63
64    if (predictions.size() == 0) {
65      return null;
66    }
67
68    Instances insts = makeHeader();
69    double [] margins = getMargins(predictions);
70    int [] sorted = Utils.sort(margins);
71    int binMargin = 0;
72    int totalMargin = 0;
73    insts.add(makeInstance(-1, binMargin, totalMargin));
74    for (int i = 0; i < sorted.length; i++) {
75      double current = margins[sorted[i]];
76      double weight = ((NominalPrediction)predictions.elementAt(sorted[i]))
77        .weight();
78      totalMargin += weight;
79      binMargin += weight;
80      if (true) {
81        insts.add(makeInstance(current, binMargin, totalMargin));
82        binMargin = 0;
83      }
84    }
85    return insts;
86  }
87
88  /**
89   * Pulls all the margin values out of a vector of NominalPredictions.
90   *
91   * @param predictions a FastVector containing NominalPredictions
92   * @return an array of margin values.
93   */
94  private double [] getMargins(FastVector predictions) {
95
96    // sort by predicted probability of the desired class.
97    double [] margins = new double [predictions.size()];
98    for (int i = 0; i < margins.length; i++) {
99      NominalPrediction pred = (NominalPrediction)predictions.elementAt(i);
100      margins[i] = pred.margin();
101    }
102    return margins;
103  }
104
105  /**
106   * Creates an Instances object with the attributes we will be calculating.
107   *
108   * @return the Instances structure.
109   */
110  private Instances makeHeader() {
111
112    FastVector fv = new FastVector();
113    fv.addElement(new Attribute("Margin"));
114    fv.addElement(new Attribute("Current"));
115    fv.addElement(new Attribute("Cumulative"));
116    return new Instances("MarginCurve", fv, 100);
117  }
118 
119  /**
120   * Creates an Instance object with the attributes calculated.
121   *
122   * @param margin the margin for this data point.
123   * @param current the number of instances with this margin.
124   * @param cumulative the number of instances with margin less than or equal
125   * to this margin.
126   * @return the Instance object.
127   */
128  private Instance makeInstance(double margin, int current, int cumulative) {
129
130    int count = 0;
131    double [] vals = new double[3];
132    vals[count++] = margin;
133    vals[count++] = current;
134    vals[count++] = cumulative;
135    return new DenseInstance(1.0, vals);
136  }
137 
138  /**
139   * Returns the revision string.
140   *
141   * @return            the revision
142   */
143  public String getRevision() {
144    return RevisionUtils.extract("$Revision: 5987 $");
145  }
146 
147  /**
148   * Tests the MarginCurve generation from the command line.
149   * The classifier is currently hardcoded. Pipe in an arff file.
150   *
151   * @param args currently ignored
152   */
153  public static void main(String [] args) {
154
155    try {
156      Utils.SMALL = 0;
157      Instances inst = new Instances(new java.io.InputStreamReader(System.in));
158      inst.setClassIndex(inst.numAttributes() - 1);
159      MarginCurve tc = new MarginCurve();
160      EvaluationUtils eu = new EvaluationUtils();
161      weka.classifiers.meta.LogitBoost classifier
162        = new weka.classifiers.meta.LogitBoost();
163      classifier.setNumIterations(20);
164      FastVector predictions
165        = eu.getTrainTestPredictions(classifier, inst, inst);
166      Instances result = tc.getCurve(predictions);
167      System.out.println(result);
168    } catch (Exception ex) {
169      ex.printStackTrace();
170    }
171  }
172}
Note: See TracBrowser for help on using the repository browser.