source: src/main/java/weka/gui/visualize/ThresholdVisualizePanel.java @ 24

Last change on this file since 24 was 4, checked in by gnappo, 14 years ago

Import di weka.

File size: 9.7 KB
RevLine 
[4]1/*
2 *    This program is free software; you can redistribute it and/or modify
3 *    it under the terms of the GNU General Public License as published by
4 *    the Free Software Foundation; either version 2 of the License, or
5 *    (at your option) any later version.
6 *
7 *    This program is distributed in the hope that it will be useful,
8 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
9 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
10 *    GNU General Public License for more details.
11 *
12 *    You should have received a copy of the GNU General Public License
13 *    along with this program; if not, write to the Free Software
14 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
15 */
16
17/*
18 *    ThresholdVisualizePanel.java
19 *    Copyright (C) 2003 University of Waikato, Hamilton, New Zealand
20 *
21 */
22
23package weka.gui.visualize;
24
25import weka.classifiers.Classifier;
26import weka.classifiers.AbstractClassifier;
27import weka.classifiers.AbstractClassifier;
28import weka.classifiers.evaluation.EvaluationUtils;
29import weka.classifiers.evaluation.ThresholdCurve;
30import weka.core.FastVector;
31import weka.core.Instances;
32import weka.core.SingleIndex;
33import weka.core.Utils;
34
35import java.awt.BorderLayout;
36import java.awt.event.ActionEvent;
37import java.awt.event.ActionListener;
38import java.awt.event.WindowAdapter;
39import java.awt.event.WindowEvent;
40import java.io.BufferedReader;
41import java.io.FileReader;
42
43import javax.swing.BorderFactory;
44import javax.swing.JFrame;
45import javax.swing.border.TitledBorder;
46
47/**
48 * This panel is a VisualizePanel, with the added ablility to display the
49 * area under the ROC curve if an ROC curve is chosen.
50 *
51 * @author Dale Fletcher (dale@cs.waikato.ac.nz)
52 * @author FracPete (fracpete at waikato dot ac dot nz)
53 * @version $Revision: 5928 $
54 */
55public class ThresholdVisualizePanel 
56  extends VisualizePanel {
57
58  /** for serialization */
59  private static final long serialVersionUID = 3070002211779443890L;
60
61  /** The string to add to the Plot Border. */
62  private String m_ROCString="";
63 
64  /** Original border text */
65  private String m_savePanelBorderText;
66
67  /**
68   * default constructor
69   */
70  public ThresholdVisualizePanel() {
71    super();
72
73    // Save the current border text
74    TitledBorder tb=(TitledBorder) m_plotSurround.getBorder();
75    m_savePanelBorderText = tb.getTitle();
76  }
77 
78  /**
79   * Set the string with ROC area
80   * @param str ROC area string to add to border
81   */ 
82  public void setROCString(String str) {
83    m_ROCString=str;
84  }
85
86  /**
87   * This extracts the ROC area string
88   * @return ROC area string
89   */
90  public String getROCString() {
91    return m_ROCString;
92  }
93
94  /**
95   * This overloads VisualizePanel's setUpComboBoxes to add
96   * ActionListeners to watch for when the X/Y Axis comboboxes
97   * are changed.
98   * @param inst a set of instances with data for plotting
99   */
100  public void setUpComboBoxes(Instances inst) {
101    super.setUpComboBoxes(inst);
102
103    m_XCombo.addActionListener(new ActionListener() {
104        public void actionPerformed(ActionEvent e) {
105          setBorderText();
106        }
107    });
108    m_YCombo.addActionListener(new ActionListener() {
109        public void actionPerformed(ActionEvent e) {
110          setBorderText();
111        }
112    });
113
114    // Just in case the default is ROC
115    setBorderText();
116  }
117
118  /**
119   * This checks the current selected X/Y Axis comboBoxes to see if
120   * an ROC graph is selected. If so, add the ROC area string to the
121   * plot border, otherwise display the original border text.
122   */
123  private void setBorderText() {
124
125    String xs = m_XCombo.getSelectedItem().toString();
126    String ys = m_YCombo.getSelectedItem().toString();
127
128    if (xs.equals("X: False Positive Rate (Num)") && ys.equals("Y: True Positive Rate (Num)"))   {
129        m_plotSurround.setBorder((BorderFactory.createTitledBorder(m_savePanelBorderText+" "+m_ROCString)));
130    } else
131        m_plotSurround.setBorder((BorderFactory.createTitledBorder(m_savePanelBorderText))); 
132  }
133
134  /**
135   * displays the previously saved instances
136   *
137   * @param insts       the instances to display
138   * @throws Exception  if display is not possible
139   */
140  protected void openVisibleInstances(Instances insts) throws Exception {
141    super.openVisibleInstances(insts);
142
143    setROCString(
144        "(Area under ROC = " 
145        + Utils.doubleToString(ThresholdCurve.getROCArea(insts), 4) + ")");
146   
147    setBorderText();
148  }
149 
150  /**
151   * Starts the ThresholdVisualizationPanel with parameters from the command line. <p/>
152   *
153   * Valid options are: <p/>
154   *  -h <br/>
155   *  lists all the commandline parameters <p/>
156   * 
157   *  -t file <br/>
158   *  Dataset to process with given classifier. <p/>
159   * 
160   *  -W classname <br/>
161   *  Full classname of classifier to run.<br/>
162   *  Options after '--' are passed to the classifier. <br/>
163   *  (default weka.classifiers.functions.Logistic) <p/>
164   * 
165   *  -r number <br/>
166   *  The number of runs to perform (default 2). <p/>
167   * 
168   *  -x number <br/>
169   *  The number of Cross-validation folds (default 10). <p/>
170   * 
171   *  -l file <br/>
172   *  Previously saved threshold curve ARFF file. <p/>
173   *
174   * @param args optional commandline parameters
175   */
176  public static void main(String [] args) {
177    Instances           inst;
178    Classifier          classifier;
179    int                 runs;
180    int                 folds;
181    String              tmpStr;
182    boolean             compute;
183    Instances           result;
184    String[]            options;
185    SingleIndex         classIndex;
186    SingleIndex         valueIndex;
187    int                 seed;
188   
189    inst       = null;
190    classifier = null;
191    runs       = 2;
192    folds      = 10;
193    compute    = true;
194    result     = null;
195    classIndex = null;
196    valueIndex = null;
197    seed       = 1;
198   
199    try {
200      // help?
201      if (Utils.getFlag('h', args)) {
202        System.out.println("\nOptions for " + ThresholdVisualizePanel.class.getName() + ":\n");
203        System.out.println("-h\n\tThis help.");
204        System.out.println("-t <file>\n\tDataset to process with given classifier.");
205        System.out.println("-c <num>\n\tThe class index. first and last are valid, too (default: last).");
206        System.out.println("-C <num>\n\tThe index of the class value to get the the curve for (default: first).");
207        System.out.println("-W <classname>\n\tFull classname of classifier to run.\n\tOptions after '--' are passed to the classifier.\n\t(default: weka.classifiers.functions.Logistic)");
208        System.out.println("-r <number>\n\tThe number of runs to perform (default: 1).");
209        System.out.println("-x <number>\n\tThe number of Cross-validation folds (default: 10).");
210        System.out.println("-S <number>\n\tThe seed value for randomizing the data (default: 1).");
211        System.out.println("-l <file>\n\tPreviously saved threshold curve ARFF file.");
212        return;
213      }
214     
215      // regular options
216      tmpStr = Utils.getOption('l', args);
217      if (tmpStr.length() != 0) {
218        result = new Instances(new BufferedReader(new FileReader(tmpStr)));
219        compute = false;
220      }
221     
222      if (compute) {
223        tmpStr = Utils.getOption('r', args);
224        if (tmpStr.length() != 0)
225          runs = Integer.parseInt(tmpStr);
226        else
227          runs = 1;
228       
229        tmpStr = Utils.getOption('x', args);
230        if (tmpStr.length() != 0)
231          folds = Integer.parseInt(tmpStr);
232        else
233          folds = 10;
234       
235        tmpStr = Utils.getOption('S', args);
236        if (tmpStr.length() != 0)
237          seed = Integer.parseInt(tmpStr);
238        else
239          seed = 1;
240       
241        tmpStr = Utils.getOption('t', args);
242        if (tmpStr.length() != 0) {
243          inst = new Instances(new BufferedReader(new FileReader(tmpStr)));
244          inst.setClassIndex(inst.numAttributes() - 1);
245        }
246       
247        tmpStr = Utils.getOption('W', args);
248        if (tmpStr.length() != 0) {
249          options = Utils.partitionOptions(args);
250        }
251        else {
252          tmpStr = weka.classifiers.functions.Logistic.class.getName();
253          options = new String[0];
254        }
255        classifier = AbstractClassifier.forName(tmpStr, options);
256       
257        tmpStr = Utils.getOption('c', args);
258        if (tmpStr.length() != 0)
259          classIndex = new SingleIndex(tmpStr);
260        else
261          classIndex = new SingleIndex("last");
262       
263        tmpStr = Utils.getOption('C', args);
264        if (tmpStr.length() != 0)
265          valueIndex = new SingleIndex(tmpStr);
266        else
267          valueIndex = new SingleIndex("first");
268      }
269     
270      // compute if necessary
271      if (compute) {
272        if (classIndex != null) {
273          classIndex.setUpper(inst.numAttributes() - 1);
274          inst.setClassIndex(classIndex.getIndex());
275        }
276        else {
277          inst.setClassIndex(inst.numAttributes() - 1);
278        }
279       
280        if (valueIndex != null) {
281          valueIndex.setUpper(inst.classAttribute().numValues() - 1);
282        }
283       
284        ThresholdCurve tc = new ThresholdCurve();
285        EvaluationUtils eu = new EvaluationUtils();
286        FastVector predictions = new FastVector();
287        for (int i = 0; i < runs; i++) {
288          eu.setSeed(seed + i);
289          predictions.appendElements(eu.getCVPredictions(classifier, inst, folds));
290        }
291       
292        if (valueIndex != null)
293          result = tc.getCurve(predictions, valueIndex.getIndex());
294        else
295          result = tc.getCurve(predictions);
296      }
297     
298      // setup GUI
299      ThresholdVisualizePanel vmc = new ThresholdVisualizePanel();
300      vmc.setROCString("(Area under ROC = " + 
301          Utils.doubleToString(ThresholdCurve.getROCArea(result), 4) + ")");
302      if (compute)     
303        vmc.setName(
304            result.relationName() 
305            + ". (Class value " + inst.classAttribute().value(valueIndex.getIndex()) + ")");
306      else
307        vmc.setName(
308            result.relationName()
309            + " (display only)");
310      PlotData2D tempd = new PlotData2D(result);
311      tempd.setPlotName(result.relationName());
312      tempd.addInstanceNumberAttribute();
313      vmc.addPlot(tempd);
314     
315      String plotName = vmc.getName(); 
316      final JFrame jf = new JFrame("Weka Classifier Visualize: "+plotName);
317      jf.setSize(500,400);
318      jf.getContentPane().setLayout(new BorderLayout());
319     
320      jf.getContentPane().add(vmc, BorderLayout.CENTER);
321      jf.addWindowListener(new WindowAdapter() {
322        public void windowClosing(WindowEvent e) {
323          jf.dispose();
324        }
325      });
326     
327      jf.setVisible(true);
328    }
329    catch (Exception e) {
330      e.printStackTrace();
331    }
332  }
333}
334
335 
336
337
338
339
340
341
Note: See TracBrowser for help on using the repository browser.