[29] | 1 | /* |
---|
| 2 | * This program is free software; you can redistribute it and/or modify |
---|
| 3 | * it under the terms of the GNU General Public License as published by |
---|
| 4 | * the Free Software Foundation; either version 2 of the License, or |
---|
| 5 | * (at your option) any later version. |
---|
| 6 | * |
---|
| 7 | * This program is distributed in the hope that it will be useful, |
---|
| 8 | * but WITHOUT ANY WARRANTY; without even the implied warranty of |
---|
| 9 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the |
---|
| 10 | * GNU General Public License for more details. |
---|
| 11 | * |
---|
| 12 | * You should have received a copy of the GNU General Public License |
---|
| 13 | * along with this program; if not, write to the Free Software |
---|
| 14 | * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA. |
---|
| 15 | */ |
---|
| 16 | |
---|
| 17 | /* |
---|
| 18 | * ThresholdVisualizePanel.java |
---|
| 19 | * Copyright (C) 2003 University of Waikato, Hamilton, New Zealand |
---|
| 20 | * |
---|
| 21 | */ |
---|
| 22 | |
---|
| 23 | package weka.gui.visualize; |
---|
| 24 | |
---|
| 25 | import weka.classifiers.Classifier; |
---|
| 26 | import weka.classifiers.AbstractClassifier; |
---|
| 27 | import weka.classifiers.AbstractClassifier; |
---|
| 28 | import weka.classifiers.evaluation.EvaluationUtils; |
---|
| 29 | import weka.classifiers.evaluation.ThresholdCurve; |
---|
| 30 | import weka.core.FastVector; |
---|
| 31 | import weka.core.Instances; |
---|
| 32 | import weka.core.SingleIndex; |
---|
| 33 | import weka.core.Utils; |
---|
| 34 | |
---|
| 35 | import java.awt.BorderLayout; |
---|
| 36 | import java.awt.event.ActionEvent; |
---|
| 37 | import java.awt.event.ActionListener; |
---|
| 38 | import java.awt.event.WindowAdapter; |
---|
| 39 | import java.awt.event.WindowEvent; |
---|
| 40 | import java.io.BufferedReader; |
---|
| 41 | import java.io.FileReader; |
---|
| 42 | |
---|
| 43 | import javax.swing.BorderFactory; |
---|
| 44 | import javax.swing.JFrame; |
---|
| 45 | import javax.swing.border.TitledBorder; |
---|
| 46 | |
---|
| 47 | /** |
---|
| 48 | * This panel is a VisualizePanel, with the added ablility to display the |
---|
| 49 | * area under the ROC curve if an ROC curve is chosen. |
---|
| 50 | * |
---|
| 51 | * @author Dale Fletcher (dale@cs.waikato.ac.nz) |
---|
| 52 | * @author FracPete (fracpete at waikato dot ac dot nz) |
---|
| 53 | * @version $Revision: 5928 $ |
---|
| 54 | */ |
---|
| 55 | public class ThresholdVisualizePanel |
---|
| 56 | extends VisualizePanel { |
---|
| 57 | |
---|
| 58 | /** for serialization */ |
---|
| 59 | private static final long serialVersionUID = 3070002211779443890L; |
---|
| 60 | |
---|
| 61 | /** The string to add to the Plot Border. */ |
---|
| 62 | private String m_ROCString=""; |
---|
| 63 | |
---|
| 64 | /** Original border text */ |
---|
| 65 | private String m_savePanelBorderText; |
---|
| 66 | |
---|
| 67 | /** |
---|
| 68 | * default constructor |
---|
| 69 | */ |
---|
| 70 | public ThresholdVisualizePanel() { |
---|
| 71 | super(); |
---|
| 72 | |
---|
| 73 | // Save the current border text |
---|
| 74 | TitledBorder tb=(TitledBorder) m_plotSurround.getBorder(); |
---|
| 75 | m_savePanelBorderText = tb.getTitle(); |
---|
| 76 | } |
---|
| 77 | |
---|
| 78 | /** |
---|
| 79 | * Set the string with ROC area |
---|
| 80 | * @param str ROC area string to add to border |
---|
| 81 | */ |
---|
| 82 | public void setROCString(String str) { |
---|
| 83 | m_ROCString=str; |
---|
| 84 | } |
---|
| 85 | |
---|
| 86 | /** |
---|
| 87 | * This extracts the ROC area string |
---|
| 88 | * @return ROC area string |
---|
| 89 | */ |
---|
| 90 | public String getROCString() { |
---|
| 91 | return m_ROCString; |
---|
| 92 | } |
---|
| 93 | |
---|
| 94 | /** |
---|
| 95 | * This overloads VisualizePanel's setUpComboBoxes to add |
---|
| 96 | * ActionListeners to watch for when the X/Y Axis comboboxes |
---|
| 97 | * are changed. |
---|
| 98 | * @param inst a set of instances with data for plotting |
---|
| 99 | */ |
---|
| 100 | public void setUpComboBoxes(Instances inst) { |
---|
| 101 | super.setUpComboBoxes(inst); |
---|
| 102 | |
---|
| 103 | m_XCombo.addActionListener(new ActionListener() { |
---|
| 104 | public void actionPerformed(ActionEvent e) { |
---|
| 105 | setBorderText(); |
---|
| 106 | } |
---|
| 107 | }); |
---|
| 108 | m_YCombo.addActionListener(new ActionListener() { |
---|
| 109 | public void actionPerformed(ActionEvent e) { |
---|
| 110 | setBorderText(); |
---|
| 111 | } |
---|
| 112 | }); |
---|
| 113 | |
---|
| 114 | // Just in case the default is ROC |
---|
| 115 | setBorderText(); |
---|
| 116 | } |
---|
| 117 | |
---|
| 118 | /** |
---|
| 119 | * This checks the current selected X/Y Axis comboBoxes to see if |
---|
| 120 | * an ROC graph is selected. If so, add the ROC area string to the |
---|
| 121 | * plot border, otherwise display the original border text. |
---|
| 122 | */ |
---|
| 123 | private void setBorderText() { |
---|
| 124 | |
---|
| 125 | String xs = m_XCombo.getSelectedItem().toString(); |
---|
| 126 | String ys = m_YCombo.getSelectedItem().toString(); |
---|
| 127 | |
---|
| 128 | if (xs.equals("X: False Positive Rate (Num)") && ys.equals("Y: True Positive Rate (Num)")) { |
---|
| 129 | m_plotSurround.setBorder((BorderFactory.createTitledBorder(m_savePanelBorderText+" "+m_ROCString))); |
---|
| 130 | } else |
---|
| 131 | m_plotSurround.setBorder((BorderFactory.createTitledBorder(m_savePanelBorderText))); |
---|
| 132 | } |
---|
| 133 | |
---|
| 134 | /** |
---|
| 135 | * displays the previously saved instances |
---|
| 136 | * |
---|
| 137 | * @param insts the instances to display |
---|
| 138 | * @throws Exception if display is not possible |
---|
| 139 | */ |
---|
| 140 | protected void openVisibleInstances(Instances insts) throws Exception { |
---|
| 141 | super.openVisibleInstances(insts); |
---|
| 142 | |
---|
| 143 | setROCString( |
---|
| 144 | "(Area under ROC = " |
---|
| 145 | + Utils.doubleToString(ThresholdCurve.getROCArea(insts), 4) + ")"); |
---|
| 146 | |
---|
| 147 | setBorderText(); |
---|
| 148 | } |
---|
| 149 | |
---|
| 150 | /** |
---|
| 151 | * Starts the ThresholdVisualizationPanel with parameters from the command line. <p/> |
---|
| 152 | * |
---|
| 153 | * Valid options are: <p/> |
---|
| 154 | * -h <br/> |
---|
| 155 | * lists all the commandline parameters <p/> |
---|
| 156 | * |
---|
| 157 | * -t file <br/> |
---|
| 158 | * Dataset to process with given classifier. <p/> |
---|
| 159 | * |
---|
| 160 | * -W classname <br/> |
---|
| 161 | * Full classname of classifier to run.<br/> |
---|
| 162 | * Options after '--' are passed to the classifier. <br/> |
---|
| 163 | * (default weka.classifiers.functions.Logistic) <p/> |
---|
| 164 | * |
---|
| 165 | * -r number <br/> |
---|
| 166 | * The number of runs to perform (default 2). <p/> |
---|
| 167 | * |
---|
| 168 | * -x number <br/> |
---|
| 169 | * The number of Cross-validation folds (default 10). <p/> |
---|
| 170 | * |
---|
| 171 | * -l file <br/> |
---|
| 172 | * Previously saved threshold curve ARFF file. <p/> |
---|
| 173 | * |
---|
| 174 | * @param args optional commandline parameters |
---|
| 175 | */ |
---|
| 176 | public static void main(String [] args) { |
---|
| 177 | Instances inst; |
---|
| 178 | Classifier classifier; |
---|
| 179 | int runs; |
---|
| 180 | int folds; |
---|
| 181 | String tmpStr; |
---|
| 182 | boolean compute; |
---|
| 183 | Instances result; |
---|
| 184 | String[] options; |
---|
| 185 | SingleIndex classIndex; |
---|
| 186 | SingleIndex valueIndex; |
---|
| 187 | int seed; |
---|
| 188 | |
---|
| 189 | inst = null; |
---|
| 190 | classifier = null; |
---|
| 191 | runs = 2; |
---|
| 192 | folds = 10; |
---|
| 193 | compute = true; |
---|
| 194 | result = null; |
---|
| 195 | classIndex = null; |
---|
| 196 | valueIndex = null; |
---|
| 197 | seed = 1; |
---|
| 198 | |
---|
| 199 | try { |
---|
| 200 | // help? |
---|
| 201 | if (Utils.getFlag('h', args)) { |
---|
| 202 | System.out.println("\nOptions for " + ThresholdVisualizePanel.class.getName() + ":\n"); |
---|
| 203 | System.out.println("-h\n\tThis help."); |
---|
| 204 | System.out.println("-t <file>\n\tDataset to process with given classifier."); |
---|
| 205 | System.out.println("-c <num>\n\tThe class index. first and last are valid, too (default: last)."); |
---|
| 206 | System.out.println("-C <num>\n\tThe index of the class value to get the the curve for (default: first)."); |
---|
| 207 | System.out.println("-W <classname>\n\tFull classname of classifier to run.\n\tOptions after '--' are passed to the classifier.\n\t(default: weka.classifiers.functions.Logistic)"); |
---|
| 208 | System.out.println("-r <number>\n\tThe number of runs to perform (default: 1)."); |
---|
| 209 | System.out.println("-x <number>\n\tThe number of Cross-validation folds (default: 10)."); |
---|
| 210 | System.out.println("-S <number>\n\tThe seed value for randomizing the data (default: 1)."); |
---|
| 211 | System.out.println("-l <file>\n\tPreviously saved threshold curve ARFF file."); |
---|
| 212 | return; |
---|
| 213 | } |
---|
| 214 | |
---|
| 215 | // regular options |
---|
| 216 | tmpStr = Utils.getOption('l', args); |
---|
| 217 | if (tmpStr.length() != 0) { |
---|
| 218 | result = new Instances(new BufferedReader(new FileReader(tmpStr))); |
---|
| 219 | compute = false; |
---|
| 220 | } |
---|
| 221 | |
---|
| 222 | if (compute) { |
---|
| 223 | tmpStr = Utils.getOption('r', args); |
---|
| 224 | if (tmpStr.length() != 0) |
---|
| 225 | runs = Integer.parseInt(tmpStr); |
---|
| 226 | else |
---|
| 227 | runs = 1; |
---|
| 228 | |
---|
| 229 | tmpStr = Utils.getOption('x', args); |
---|
| 230 | if (tmpStr.length() != 0) |
---|
| 231 | folds = Integer.parseInt(tmpStr); |
---|
| 232 | else |
---|
| 233 | folds = 10; |
---|
| 234 | |
---|
| 235 | tmpStr = Utils.getOption('S', args); |
---|
| 236 | if (tmpStr.length() != 0) |
---|
| 237 | seed = Integer.parseInt(tmpStr); |
---|
| 238 | else |
---|
| 239 | seed = 1; |
---|
| 240 | |
---|
| 241 | tmpStr = Utils.getOption('t', args); |
---|
| 242 | if (tmpStr.length() != 0) { |
---|
| 243 | inst = new Instances(new BufferedReader(new FileReader(tmpStr))); |
---|
| 244 | inst.setClassIndex(inst.numAttributes() - 1); |
---|
| 245 | } |
---|
| 246 | |
---|
| 247 | tmpStr = Utils.getOption('W', args); |
---|
| 248 | if (tmpStr.length() != 0) { |
---|
| 249 | options = Utils.partitionOptions(args); |
---|
| 250 | } |
---|
| 251 | else { |
---|
| 252 | tmpStr = weka.classifiers.functions.Logistic.class.getName(); |
---|
| 253 | options = new String[0]; |
---|
| 254 | } |
---|
| 255 | classifier = AbstractClassifier.forName(tmpStr, options); |
---|
| 256 | |
---|
| 257 | tmpStr = Utils.getOption('c', args); |
---|
| 258 | if (tmpStr.length() != 0) |
---|
| 259 | classIndex = new SingleIndex(tmpStr); |
---|
| 260 | else |
---|
| 261 | classIndex = new SingleIndex("last"); |
---|
| 262 | |
---|
| 263 | tmpStr = Utils.getOption('C', args); |
---|
| 264 | if (tmpStr.length() != 0) |
---|
| 265 | valueIndex = new SingleIndex(tmpStr); |
---|
| 266 | else |
---|
| 267 | valueIndex = new SingleIndex("first"); |
---|
| 268 | } |
---|
| 269 | |
---|
| 270 | // compute if necessary |
---|
| 271 | if (compute) { |
---|
| 272 | if (classIndex != null) { |
---|
| 273 | classIndex.setUpper(inst.numAttributes() - 1); |
---|
| 274 | inst.setClassIndex(classIndex.getIndex()); |
---|
| 275 | } |
---|
| 276 | else { |
---|
| 277 | inst.setClassIndex(inst.numAttributes() - 1); |
---|
| 278 | } |
---|
| 279 | |
---|
| 280 | if (valueIndex != null) { |
---|
| 281 | valueIndex.setUpper(inst.classAttribute().numValues() - 1); |
---|
| 282 | } |
---|
| 283 | |
---|
| 284 | ThresholdCurve tc = new ThresholdCurve(); |
---|
| 285 | EvaluationUtils eu = new EvaluationUtils(); |
---|
| 286 | FastVector predictions = new FastVector(); |
---|
| 287 | for (int i = 0; i < runs; i++) { |
---|
| 288 | eu.setSeed(seed + i); |
---|
| 289 | predictions.appendElements(eu.getCVPredictions(classifier, inst, folds)); |
---|
| 290 | } |
---|
| 291 | |
---|
| 292 | if (valueIndex != null) |
---|
| 293 | result = tc.getCurve(predictions, valueIndex.getIndex()); |
---|
| 294 | else |
---|
| 295 | result = tc.getCurve(predictions); |
---|
| 296 | } |
---|
| 297 | |
---|
| 298 | // setup GUI |
---|
| 299 | ThresholdVisualizePanel vmc = new ThresholdVisualizePanel(); |
---|
| 300 | vmc.setROCString("(Area under ROC = " + |
---|
| 301 | Utils.doubleToString(ThresholdCurve.getROCArea(result), 4) + ")"); |
---|
| 302 | if (compute) |
---|
| 303 | vmc.setName( |
---|
| 304 | result.relationName() |
---|
| 305 | + ". (Class value " + inst.classAttribute().value(valueIndex.getIndex()) + ")"); |
---|
| 306 | else |
---|
| 307 | vmc.setName( |
---|
| 308 | result.relationName() |
---|
| 309 | + " (display only)"); |
---|
| 310 | PlotData2D tempd = new PlotData2D(result); |
---|
| 311 | tempd.setPlotName(result.relationName()); |
---|
| 312 | tempd.addInstanceNumberAttribute(); |
---|
| 313 | vmc.addPlot(tempd); |
---|
| 314 | |
---|
| 315 | String plotName = vmc.getName(); |
---|
| 316 | final JFrame jf = new JFrame("Weka Classifier Visualize: "+plotName); |
---|
| 317 | jf.setSize(500,400); |
---|
| 318 | jf.getContentPane().setLayout(new BorderLayout()); |
---|
| 319 | |
---|
| 320 | jf.getContentPane().add(vmc, BorderLayout.CENTER); |
---|
| 321 | jf.addWindowListener(new WindowAdapter() { |
---|
| 322 | public void windowClosing(WindowEvent e) { |
---|
| 323 | jf.dispose(); |
---|
| 324 | } |
---|
| 325 | }); |
---|
| 326 | |
---|
| 327 | jf.setVisible(true); |
---|
| 328 | } |
---|
| 329 | catch (Exception e) { |
---|
| 330 | e.printStackTrace(); |
---|
| 331 | } |
---|
| 332 | } |
---|
| 333 | } |
---|
| 334 | |
---|
| 335 | |
---|
| 336 | |
---|
| 337 | |
---|
| 338 | |
---|
| 339 | |
---|
| 340 | |
---|
| 341 | |
---|