source: src/main/java/weka/gui/explorer/ClassifierErrorsPlotInstances.java @ 7

Last change on this file since 7 was 4, checked in by gnappo, 14 years ago

Import di weka.

File size: 14.7 KB
Line 
1/*
2 *    This program is free software; you can redistribute it and/or modify
3 *    it under the terms of the GNU General Public License as published by
4 *    the Free Software Foundation; either version 2 of the License, or
5 *    (at your option) any later version.
6 *
7 *    This program is distributed in the hope that it will be useful,
8 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
9 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
10 *    GNU General Public License for more details.
11 *
12 *    You should have received a copy of the GNU General Public License
13 *    along with this program; if not, write to the Free Software
14 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
15 */
16
17/*
18 * ClassifierErrorsPlotInstances.java
19 * Copyright (C) 2009 University of Waikato, Hamilton, New Zealand
20 */
21
22package weka.gui.explorer;
23
24import weka.classifiers.Classifier;
25import weka.classifiers.Evaluation;
26import weka.classifiers.IntervalEstimator;
27import weka.classifiers.evaluation.NumericPrediction;
28import weka.core.Attribute;
29import weka.core.DenseInstance;
30import weka.core.FastVector;
31import weka.core.Instance;
32import weka.core.Instances;
33import weka.core.Utils;
34import weka.gui.visualize.Plot2D;
35import weka.gui.visualize.PlotData2D;
36
37/**
38 * A class for generating plottable visualization errors.
39 * <p/>
40 * Example usage:
41 * <pre>
42 * Instances train = ... // from somewhere
43 * Instances test = ... // from somewhere
44 * Classifier cls = ... // from somewhere
45 * // build classifier
46 * cls.buildClassifier(train);
47 * // evaluate classifier and generate plot instances
48 * ClassifierPlotInstances plotInstances = new ClassifierPlotInstances();
49 * plotInstances.setClassifier(cls);
50 * plotInstances.setInstances(train);
51 * plotInstances.setClassIndex(train.classIndex());
52 * plotInstances.setUp();
53 * Evaluation eval = new Evaluation(train);
54 * for (int i = 0; i &lt; test.numInstances(); i++)
55 *   plotInstances.process(test.instance(i), cls, eval);
56 * // generate visualization
57 * VisualizePanel visPanel = new VisualizePanel();
58 * visPanel.addPlot(plotInstances.getPlotData("plot name"));
59 * visPanel.setColourIndex(plotInstances.getPlotInstances().classIndex()+1);
60 * // clean up
61 * plotInstances.cleanUp();
62 * </pre>
63 *
64 * @author  fracpete (fracpete at waikato dot ac dot nz)
65 * @version $Revision: 6103 $
66 */
67public class ClassifierErrorsPlotInstances
68  extends AbstractPlotInstances {
69
70  /** for serialization. */
71  private static final long serialVersionUID = -3941976365792013279L;
72
73  /** the minimum plot size for numeric errors. */
74  protected int m_MinimumPlotSizeNumeric;
75
76  /** the maximum plot size for numeric errors. */
77  protected int m_MaximumPlotSizeNumeric;
78 
79  /** whether to save the instances for visualization or just evaluate the
80   * instance. */
81  protected boolean m_SaveForVisualization;
82 
83  /** for storing the plot shapes. */
84  protected FastVector m_PlotShapes;
85 
86  /** for storing the plot sizes. */
87  protected FastVector m_PlotSizes;
88 
89  /** the classifier being used. */
90  protected Classifier m_Classifier;
91
92  /** the class index. */
93  protected int m_ClassIndex;
94 
95  /** the Evaluation object to use. */
96  protected Evaluation m_Evaluation;
97 
98  /**
99   * Initializes the members.
100   */
101  protected void initialize() {
102    super.initialize();
103   
104    m_PlotShapes             = new FastVector();
105    m_PlotSizes              = new FastVector();
106    m_Classifier             = null;
107    m_ClassIndex             = -1;
108    m_Evaluation             = null;
109    m_SaveForVisualization   = true;
110    m_MinimumPlotSizeNumeric = ExplorerDefaults.getClassifierErrorsMinimumPlotSizeNumeric();
111    m_MaximumPlotSizeNumeric = ExplorerDefaults.getClassifierErrorsMaximumPlotSizeNumeric();
112  }
113 
114  /**
115   * Sets the classifier used for making the predictions.
116   *
117   * @param value       the classifier to use
118   */
119  public void setClassifier(Classifier value) {
120    m_Classifier = value;
121  }
122 
123  /**
124   * Returns the currently set classifier.
125   *
126   * @return            the classifier in use
127   */
128  public Classifier getClassifier() {
129    return m_Classifier;
130  }
131
132  /**
133   * Sets the 0-based class index.
134   *
135   * @param index       the class index
136   */
137  public void setClassIndex(int index) {
138    m_ClassIndex = index;
139  }
140 
141  /**
142   * Returns the 0-based class index.
143   *
144   * @return            the class index
145   */
146  public int getClassIndex() {
147    return m_ClassIndex;
148  }
149
150  /**
151   * Sets the Evaluation object to use.
152   *
153   * @param value       the evaluation to use
154   */
155  public void setEvaluation(Evaluation value) {
156    m_Evaluation = value;
157  }
158 
159  /**
160   * Returns the Evaluation object in use.
161   *
162   * @return            the evaluation object
163   */
164  public Evaluation getEvaluation() {
165    return m_Evaluation;
166  }
167 
168  /**
169   * Sets whether the instances are saved for visualization or only evaluation
170   * of the prediction is to happen.
171   *
172   * @param value       if true then the instances will be saved
173   */
174  public void setSaveForVisualization(boolean value) {
175    m_SaveForVisualization = value;
176  }
177 
178  /**
179   * Returns whether the instances are saved for visualization for only
180   * evaluation of the prediction is to happen.
181   *
182   * @return            true if the instances are saved
183   */
184  public boolean getSaveForVisualization() {
185    return m_SaveForVisualization;
186  }
187 
188  /**
189   * Checks whether classifier, class index and evaluation are provided.
190   */
191  protected void check() {
192    super.check();
193   
194    if (m_Classifier == null)
195      throw new IllegalStateException("No classifier set!");
196   
197    if (m_ClassIndex == -1)
198      throw new IllegalStateException("No class index set!");
199   
200    if (m_Evaluation == null)
201      throw new IllegalStateException("No evaluation set");
202  }
203 
204  /**
205   * Sets up the structure for the plot instances. Sets m_PlotInstances to null
206   * if instances are not saved for visualization.
207   *
208   * @see #getSaveForVisualization()
209   */
210  protected void determineFormat() {
211    FastVector  hv;
212    Attribute   predictedClass;
213    Attribute   classAt;
214    FastVector  attVals;
215    int         i;
216   
217    if (!m_SaveForVisualization) {
218      m_PlotInstances = null;
219      return;
220    }
221   
222    hv = new FastVector();
223
224    classAt = m_Instances.attribute(m_ClassIndex);
225    if (classAt.isNominal()) {
226      attVals = new FastVector();
227      for (i = 0; i < classAt.numValues(); i++)
228        attVals.addElement(classAt.value(i));
229      predictedClass = new Attribute("predicted" + classAt.name(), attVals);
230    }
231    else {
232      predictedClass = new Attribute("predicted" + classAt.name());
233    }
234
235    for (i = 0; i < m_Instances.numAttributes(); i++) {
236      if (i == m_Instances.classIndex())
237        hv.addElement(predictedClass);
238      hv.addElement(m_Instances.attribute(i).copy());
239    }
240   
241    m_PlotInstances = new Instances(
242        m_Instances.relationName() + "_predicted", hv, m_Instances.numInstances());
243    m_PlotInstances.setClassIndex(m_ClassIndex + 1);
244  }
245 
246  /**
247   * Process a classifier's prediction for an instance and update a
248   * set of plotting instances and additional plotting info. m_PlotShape
249   * for nominal class datasets holds shape types (actual data points have
250   * automatic shape type assignment; classifier error data points have
251   * box shape type). For numeric class datasets, the actual data points
252   * are stored in m_PlotInstances and m_PlotSize stores the error (which is
253   * later converted to shape size values).
254   *
255   * @param toPredict   the actual data point
256   * @param classifier  the classifier
257   * @param eval        the evaluation object to use for evaluating the classifier on
258   *                    the instance to predict
259   * @see               #m_PlotShapes
260   * @see               #m_PlotSizes
261   * @see               #m_PlotInstances
262   */
263  public void process(Instance toPredict, Classifier classifier, Evaluation eval) {
264    double      pred;
265    double[]    values;
266    int         i;
267   
268    try {
269      pred = eval.evaluateModelOnceAndRecordPrediction(classifier, toPredict);
270     
271      if (!m_SaveForVisualization)
272        return;
273
274      if (m_PlotInstances != null) {
275        values = new double[m_PlotInstances.numAttributes()];
276        for (i = 0; i < m_PlotInstances.numAttributes(); i++) {
277          if (i < toPredict.classIndex()) {
278            values[i] = toPredict.value(i);
279          }
280          else if (i == toPredict.classIndex()) {
281            values[i]   = pred;
282            values[i+1] = toPredict.value(i);
283            i++;
284          }
285          else {
286            values[i] = toPredict.value(i-1);
287          }
288        }
289
290        m_PlotInstances.add(new DenseInstance(1.0, values));
291       
292        if (toPredict.classAttribute().isNominal()) {
293          if (toPredict.isMissing(toPredict.classIndex()) || Utils.isMissingValue(pred)) {
294            m_PlotShapes.addElement(new Integer(Plot2D.MISSING_SHAPE));
295          }
296          else if (pred != toPredict.classValue()) {
297            // set to default error point shape
298            m_PlotShapes.addElement(new Integer(Plot2D.ERROR_SHAPE));
299          }
300          else {
301            // otherwise set to constant (automatically assigned) point shape
302            m_PlotShapes.addElement(new Integer(Plot2D.CONST_AUTOMATIC_SHAPE));
303          }
304          m_PlotSizes.addElement(new Integer(Plot2D.DEFAULT_SHAPE_SIZE));
305        }
306        else {
307          // store the error (to be converted to a point size later)
308          Double errd = null;
309          if (!toPredict.isMissing(toPredict.classIndex()) && !Utils.isMissingValue(pred)) {
310            errd = new Double(pred - toPredict.classValue());
311            m_PlotShapes.addElement(new Integer(Plot2D.CONST_AUTOMATIC_SHAPE));
312          }
313          else {
314            // missing shape if actual class not present or prediction is missing
315            m_PlotShapes.addElement(new Integer(Plot2D.MISSING_SHAPE));
316          }
317          m_PlotSizes.addElement(errd);
318        }
319      }
320    }
321    catch (Exception ex) {
322      ex.printStackTrace();
323    }
324  }
325
326  /**
327   * Scales numeric class predictions into shape sizes for plotting
328   * in the visualize panel.
329   */
330  protected void scaleNumericPredictions() {
331    double      maxErr;
332    double      minErr;
333    double      err;
334    int         i;
335    Double      errd;
336    double      temp;
337   
338    maxErr = Double.NEGATIVE_INFINITY;
339    minErr = Double.POSITIVE_INFINITY;
340
341    // find min/max errors
342    for (i = 0; i < m_PlotSizes.size(); i++) {
343      errd = (Double) m_PlotSizes.elementAt(i);
344      if (errd != null) {
345        err = Math.abs(errd.doubleValue());
346        if (err < minErr)
347          minErr = err;
348        if (err > maxErr)
349          maxErr = err;
350      }
351    }
352   
353    // scale errors
354    for (i = 0; i < m_PlotSizes.size(); i++) {
355      errd = (Double) m_PlotSizes.elementAt(i);
356      if (errd != null) {
357        err = Math.abs(errd.doubleValue());
358        if (maxErr - minErr > 0) {
359          temp = (((err - minErr) / (maxErr - minErr)) * (m_MaximumPlotSizeNumeric - m_MinimumPlotSizeNumeric + 1));
360          m_PlotSizes.setElementAt(new Integer((int) temp) + m_MinimumPlotSizeNumeric, i);
361        }
362        else {
363          m_PlotSizes.setElementAt(new Integer(m_MinimumPlotSizeNumeric), i);
364        }
365      }
366      else {
367        m_PlotSizes.setElementAt(new Integer(m_MinimumPlotSizeNumeric), i);
368      }
369    }
370  }
371 
372  /**
373   * Adds the prediction intervals as additional attributes at the end.
374   * Since classifiers can returns varying number of intervals per instance,
375   * the dataset is filled with missing values for non-existing intervals.
376   */
377  protected void addPredictionIntervals() {
378    int         maxNum;
379    int         num;
380    int         i;
381    int         n;
382    FastVector  preds;
383    FastVector  atts;
384    Instances   data;
385    Instance    inst;
386    Instance    newInst;
387    double[]    values;
388    double[][]  predInt;
389   
390    // determine the maximum number of intervals
391    maxNum = 0;
392    preds  = m_Evaluation.predictions();
393    for (i = 0; i < preds.size(); i++) {
394      num = ((NumericPrediction) preds.elementAt(i)).predictionIntervals().length;
395      if (num > maxNum)
396        maxNum = num;
397    }
398   
399    // create new header
400    atts = new FastVector();
401    for (i = 0; i < m_PlotInstances.numAttributes(); i++)
402      atts.addElement(m_PlotInstances.attribute(i));
403    for (i = 0; i < maxNum; i++) {
404      atts.addElement(new Attribute("predictionInterval_" + (i+1) + "-lowerBoundary"));
405      atts.addElement(new Attribute("predictionInterval_" + (i+1) + "-upperBoundary"));
406      atts.addElement(new Attribute("predictionInterval_" + (i+1) + "-width"));
407    }
408    data = new Instances(m_PlotInstances.relationName(), atts, m_PlotInstances.numInstances());
409    data.setClassIndex(m_PlotInstances.classIndex());
410   
411    // update data
412    for (i = 0; i < m_PlotInstances.numInstances(); i++) {
413      inst = m_PlotInstances.instance(i);
414      // copy old values
415      values = new double[data.numAttributes()];
416      System.arraycopy(inst.toDoubleArray(), 0, values, 0, inst.numAttributes());
417      // add interval data
418      predInt = ((NumericPrediction) preds.elementAt(i)).predictionIntervals();
419      for (n = 0; n < maxNum; n++) {
420        if (n < predInt.length){
421          values[m_PlotInstances.numAttributes() + n*3 + 0] = predInt[n][0];
422          values[m_PlotInstances.numAttributes() + n*3 + 1] = predInt[n][1];
423          values[m_PlotInstances.numAttributes() + n*3 + 2] = predInt[n][1] - predInt[n][0];
424        }
425        else {
426          values[m_PlotInstances.numAttributes() + n*3 + 0] = Utils.missingValue();
427          values[m_PlotInstances.numAttributes() + n*3 + 1] = Utils.missingValue();
428          values[m_PlotInstances.numAttributes() + n*3 + 2] = Utils.missingValue();
429        }
430      }
431      // create new Instance
432      newInst = new DenseInstance(inst.weight(), values);
433      data.add(newInst);
434    }
435   
436    m_PlotInstances = data;
437  }
438 
439  /**
440   * Performs optional post-processing.
441   *
442   * @see #scaleNumericPredictions()
443   * @see #addPredictionIntervals()
444   */
445  protected void finishUp() {
446    super.finishUp();
447   
448    if (!m_SaveForVisualization)
449      return;
450   
451    if (m_Instances.attribute(m_ClassIndex).isNumeric())
452      scaleNumericPredictions();
453    if (m_Classifier instanceof IntervalEstimator)
454      addPredictionIntervals();
455  }
456 
457  /**
458   * Assembles and returns the plot. The relation name of the dataset gets
459   * added automatically.
460   *
461   * @param name        the name of the plot
462   * @return            the plot or null if plot instances weren't saved for visualization
463   * @throws Exception  if plot generation fails
464   */
465  protected PlotData2D createPlotData(String name) throws Exception {
466    PlotData2D  result;
467   
468    if (!m_SaveForVisualization)
469      return null;
470   
471    result = new PlotData2D(m_PlotInstances);
472    result.setShapeSize(m_PlotSizes);
473    result.setShapeType(m_PlotShapes);
474    result.setPlotName(name + " (" + m_Instances.relationName() + ")");
475    result.addInstanceNumberAttribute();
476
477    return result;
478  }
479 
480  /**
481   * For freeing up memory. Plot data cannot be generated after this call!
482   */
483  public void cleanUp() {
484    super.cleanUp();
485   
486    m_Classifier = null;
487    m_PlotShapes = null;
488    m_PlotSizes  = null;
489    m_Evaluation = null;
490  }
491}
Note: See TracBrowser for help on using the repository browser.