1 | /* |
---|
2 | * This program is free software; you can redistribute it and/or modify |
---|
3 | * it under the terms of the GNU General Public License as published by |
---|
4 | * the Free Software Foundation; either version 2 of the License, or |
---|
5 | * (at your option) any later version. |
---|
6 | * |
---|
7 | * This program is distributed in the hope that it will be useful, |
---|
8 | * but WITHOUT ANY WARRANTY; without even the implied warranty of |
---|
9 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the |
---|
10 | * GNU General Public License for more details. |
---|
11 | * |
---|
12 | * You should have received a copy of the GNU General Public License |
---|
13 | * along with this program; if not, write to the Free Software |
---|
14 | * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA. |
---|
15 | */ |
---|
16 | |
---|
17 | /* |
---|
18 | * CostBenefitAnalysis.java |
---|
19 | * Copyright (C) 2009 University of Waikato, Hamilton, New Zealand |
---|
20 | * |
---|
21 | */ |
---|
22 | |
---|
23 | package weka.gui.beans; |
---|
24 | |
---|
25 | import java.awt.BorderLayout; |
---|
26 | import java.awt.Color; |
---|
27 | import java.awt.Dimension; |
---|
28 | import java.awt.FlowLayout; |
---|
29 | import java.awt.GridLayout; |
---|
30 | import java.awt.Graphics; |
---|
31 | import java.awt.event.ActionEvent; |
---|
32 | import java.awt.event.ActionListener; |
---|
33 | import java.awt.event.FocusEvent; |
---|
34 | import java.awt.event.FocusListener; |
---|
35 | import java.beans.EventSetDescriptor; |
---|
36 | import java.beans.PropertyVetoException; |
---|
37 | import java.beans.VetoableChangeListener; |
---|
38 | import java.beans.beancontext.BeanContext; |
---|
39 | import java.beans.beancontext.BeanContextChild; |
---|
40 | import java.beans.beancontext.BeanContextChildSupport; |
---|
41 | import java.io.Serializable; |
---|
42 | import java.util.Enumeration; |
---|
43 | import java.util.Vector; |
---|
44 | |
---|
45 | import javax.swing.BorderFactory; |
---|
46 | import javax.swing.ButtonGroup; |
---|
47 | import javax.swing.JButton; |
---|
48 | import javax.swing.JFrame; |
---|
49 | import javax.swing.JLabel; |
---|
50 | import javax.swing.JPanel; |
---|
51 | import javax.swing.JRadioButton; |
---|
52 | import javax.swing.JSlider; |
---|
53 | import javax.swing.JTextField; |
---|
54 | import javax.swing.SwingConstants; |
---|
55 | import javax.swing.event.ChangeEvent; |
---|
56 | import javax.swing.event.ChangeListener; |
---|
57 | |
---|
58 | import weka.classifiers.evaluation.ThresholdCurve; |
---|
59 | import weka.core.Attribute; |
---|
60 | import weka.core.FastVector; |
---|
61 | import weka.core.Instance; |
---|
62 | import weka.core.DenseInstance; |
---|
63 | import weka.core.Instances; |
---|
64 | import weka.core.Utils; |
---|
65 | import weka.gui.Logger; |
---|
66 | import weka.gui.visualize.VisualizePanel; |
---|
67 | import weka.gui.visualize.Plot2D; |
---|
68 | import weka.gui.visualize.PlotData2D; |
---|
69 | |
---|
70 | |
---|
71 | /** |
---|
72 | * Bean that aids in analyzing cost/benefit tradeoffs. |
---|
73 | * |
---|
74 | * @author Mark Hall (mhall{[at]}pentaho{[dot]}com) |
---|
75 | * @version $Revision: 6137 $ |
---|
76 | */ |
---|
77 | public class CostBenefitAnalysis extends JPanel |
---|
78 | implements BeanCommon, ThresholdDataListener, Visible, UserRequestAcceptor, |
---|
79 | Serializable, BeanContextChild { |
---|
80 | |
---|
81 | /** For serialization */ |
---|
82 | private static final long serialVersionUID = 8647471654613320469L; |
---|
83 | |
---|
84 | protected BeanVisual m_visual; |
---|
85 | |
---|
86 | protected transient JFrame m_popupFrame; |
---|
87 | |
---|
88 | protected boolean m_framePoppedUp = false; |
---|
89 | |
---|
90 | private transient AnalysisPanel m_analysisPanel; |
---|
91 | |
---|
92 | /** |
---|
93 | * True if this bean's appearance is the design mode appearance |
---|
94 | */ |
---|
95 | protected boolean m_design; |
---|
96 | |
---|
97 | /** |
---|
98 | * BeanContex that this bean might be contained within |
---|
99 | */ |
---|
100 | protected transient BeanContext m_beanContext = null; |
---|
101 | |
---|
102 | /** |
---|
103 | * BeanContextChild support |
---|
104 | */ |
---|
105 | protected BeanContextChildSupport m_bcSupport = |
---|
106 | new BeanContextChildSupport(this); |
---|
107 | |
---|
108 | /** |
---|
109 | * The object sending us data (we allow only one connection at any one time) |
---|
110 | */ |
---|
111 | protected Object m_listenee; |
---|
112 | |
---|
113 | /** |
---|
114 | * Inner class for displaying the plots and all control widgets. |
---|
115 | * |
---|
116 | * @author Mark Hall (mhall{[at]}pentaho{[dot]}com) |
---|
117 | */ |
---|
118 | protected static class AnalysisPanel extends JPanel { |
---|
119 | |
---|
120 | /** For serialization */ |
---|
121 | private static final long serialVersionUID = 5364871945448769003L; |
---|
122 | |
---|
123 | /** Displays the performance graphs(s) */ |
---|
124 | protected VisualizePanel m_performancePanel = new VisualizePanel(); |
---|
125 | |
---|
126 | /** Displays the cost/benefit (profit/loss) graph */ |
---|
127 | protected VisualizePanel m_costBenefitPanel = new VisualizePanel(); |
---|
128 | |
---|
129 | /** |
---|
130 | * The class attribute from the data that was used to generate |
---|
131 | * the threshold curve |
---|
132 | */ |
---|
133 | protected Attribute m_classAttribute; |
---|
134 | |
---|
135 | /** Data for the threshold curve */ |
---|
136 | protected PlotData2D m_masterPlot; |
---|
137 | |
---|
138 | /** Data for the cost/benefit curve */ |
---|
139 | protected PlotData2D m_costBenefit; |
---|
140 | |
---|
141 | /** The size of the points being plotted */ |
---|
142 | protected int[] m_shapeSizes; |
---|
143 | |
---|
144 | /** The index of the previous plotted point that was highlighted */ |
---|
145 | protected int m_previousShapeIndex = -1; |
---|
146 | |
---|
147 | /** The slider for adjusting the threshold */ |
---|
148 | protected JSlider m_thresholdSlider = new JSlider(0,100,0); |
---|
149 | |
---|
150 | protected JRadioButton m_percPop = new JRadioButton("% of Population"); |
---|
151 | protected JRadioButton m_percOfTarget = new JRadioButton("% of Target (recall)"); |
---|
152 | protected JRadioButton m_threshold = new JRadioButton("Score Threshold"); |
---|
153 | |
---|
154 | protected JLabel m_percPopLab = new JLabel(); |
---|
155 | protected JLabel m_percOfTargetLab = new JLabel(); |
---|
156 | protected JLabel m_thresholdLab = new JLabel(); |
---|
157 | |
---|
158 | // Confusion matrix stuff |
---|
159 | protected JLabel m_conf_predictedA = new JLabel("Predicted (a)", SwingConstants.RIGHT); |
---|
160 | protected JLabel m_conf_predictedB = new JLabel("Predicted (b)", SwingConstants.RIGHT); |
---|
161 | protected JLabel m_conf_actualA = new JLabel(" Actual (a):"); |
---|
162 | protected JLabel m_conf_actualB = new JLabel(" Actual (b):"); |
---|
163 | protected ConfusionCell m_conf_aa = new ConfusionCell(); |
---|
164 | protected ConfusionCell m_conf_ab = new ConfusionCell(); |
---|
165 | protected ConfusionCell m_conf_ba = new ConfusionCell(); |
---|
166 | protected ConfusionCell m_conf_bb = new ConfusionCell(); |
---|
167 | |
---|
168 | // Cost matrix stuff |
---|
169 | protected JLabel m_cost_predictedA = new JLabel("Predicted (a)", SwingConstants.RIGHT); |
---|
170 | protected JLabel m_cost_predictedB = new JLabel("Predicted (b)", SwingConstants.RIGHT); |
---|
171 | protected JLabel m_cost_actualA = new JLabel(" Actual (a)"); |
---|
172 | protected JLabel m_cost_actualB = new JLabel(" Actual (b)"); |
---|
173 | protected JTextField m_cost_aa = new JTextField("0.0", 5); |
---|
174 | protected JTextField m_cost_ab = new JTextField("1.0", 5); |
---|
175 | protected JTextField m_cost_ba = new JTextField("1.0", 5); |
---|
176 | protected JTextField m_cost_bb = new JTextField("0.0" ,5); |
---|
177 | protected JButton m_maximizeCB = new JButton("Maximize Cost/Benefit"); |
---|
178 | protected JButton m_minimizeCB = new JButton("Minimize Cost/Benefit"); |
---|
179 | protected JRadioButton m_costR = new JRadioButton("Cost"); |
---|
180 | protected JRadioButton m_benefitR = new JRadioButton("Benefit"); |
---|
181 | protected JLabel m_costBenefitL = new JLabel("Cost: ", SwingConstants.RIGHT); |
---|
182 | protected JLabel m_costBenefitV = new JLabel("0"); |
---|
183 | protected JLabel m_randomV = new JLabel("0"); |
---|
184 | protected JLabel m_gainV = new JLabel("0"); |
---|
185 | |
---|
186 | protected int m_originalPopSize; |
---|
187 | |
---|
188 | /** Population text field */ |
---|
189 | protected JTextField m_totalPopField = new JTextField(6); |
---|
190 | protected int m_totalPopPrevious; |
---|
191 | |
---|
192 | /** Classification accuracy */ |
---|
193 | protected JLabel m_classificationAccV = new JLabel("-"); |
---|
194 | |
---|
195 | // Only update curve & stats if values in cost matrix have changed |
---|
196 | protected double m_tpPrevious; |
---|
197 | protected double m_fpPrevious; |
---|
198 | protected double m_tnPrevious; |
---|
199 | protected double m_fnPrevious; |
---|
200 | |
---|
201 | /** |
---|
202 | * Inner class for handling a single cell in the confusion matrix. |
---|
203 | * Displays the value, value as a percentage of total population and |
---|
204 | * graphical depiction of percentage. |
---|
205 | * |
---|
206 | * @author Mark Hall (mhall{[at]}pentaho{[dot]}com) |
---|
207 | */ |
---|
208 | protected static class ConfusionCell extends JPanel { |
---|
209 | |
---|
210 | /** For serialization */ |
---|
211 | private static final long serialVersionUID = 6148640235434494767L; |
---|
212 | |
---|
213 | private JLabel m_conf_cell = new JLabel("-", SwingConstants.RIGHT); |
---|
214 | JLabel m_conf_perc = new JLabel("-", SwingConstants.RIGHT); |
---|
215 | |
---|
216 | private JPanel m_percentageP; |
---|
217 | |
---|
218 | protected double m_percentage = 0; |
---|
219 | |
---|
220 | public ConfusionCell() { |
---|
221 | setLayout(new BorderLayout()); |
---|
222 | setBorder(BorderFactory.createEtchedBorder()); |
---|
223 | |
---|
224 | add(m_conf_cell, BorderLayout.NORTH); |
---|
225 | |
---|
226 | m_percentageP = new JPanel() { |
---|
227 | public void paintComponent(Graphics gx) { |
---|
228 | super.paintComponent(gx); |
---|
229 | |
---|
230 | if (m_percentage > 0) { |
---|
231 | gx.setColor(Color.BLUE); |
---|
232 | int height = this.getHeight(); |
---|
233 | double width = this.getWidth(); |
---|
234 | int barWidth = (int)(m_percentage * width); |
---|
235 | gx.fillRect(0, 0, barWidth, height); |
---|
236 | } |
---|
237 | } |
---|
238 | }; |
---|
239 | |
---|
240 | Dimension d = new Dimension(30,5); |
---|
241 | m_percentageP.setMinimumSize(d); |
---|
242 | m_percentageP.setPreferredSize(d); |
---|
243 | JPanel percHolder = new JPanel(); |
---|
244 | percHolder.setLayout(new BorderLayout()); |
---|
245 | percHolder.add(m_percentageP, BorderLayout.CENTER); |
---|
246 | percHolder.add(m_conf_perc, BorderLayout.EAST); |
---|
247 | |
---|
248 | add(percHolder, BorderLayout.SOUTH); |
---|
249 | } |
---|
250 | |
---|
251 | /** |
---|
252 | * Set the value of a cell. |
---|
253 | * |
---|
254 | * @param cellValue the value of the cell |
---|
255 | * @param max the max (for setting value as a percentage) |
---|
256 | * @param scaleFactor scale the value by this amount |
---|
257 | * @param precision precision for the percentage value |
---|
258 | */ |
---|
259 | public void setCellValue(double cellValue, double max, double scaleFactor, int precision) { |
---|
260 | if (!Utils.isMissingValue(cellValue)) { |
---|
261 | m_percentage = cellValue / max; |
---|
262 | } else { |
---|
263 | m_percentage = 0; |
---|
264 | } |
---|
265 | |
---|
266 | m_conf_cell.setText(Utils.doubleToString((cellValue * scaleFactor), 0)); |
---|
267 | m_conf_perc.setText(Utils.doubleToString(m_percentage * 100.0, precision) + "%"); |
---|
268 | |
---|
269 | // refresh the percentage bar |
---|
270 | m_percentageP.repaint(); |
---|
271 | } |
---|
272 | } |
---|
273 | |
---|
274 | public AnalysisPanel() { |
---|
275 | setLayout(new BorderLayout()); |
---|
276 | m_performancePanel.setShowAttBars(false); |
---|
277 | m_performancePanel.setShowClassPanel(false); |
---|
278 | m_costBenefitPanel.setShowAttBars(false); |
---|
279 | m_costBenefitPanel.setShowClassPanel(false); |
---|
280 | |
---|
281 | Dimension size = new Dimension(500, 400); |
---|
282 | m_performancePanel.setPreferredSize(size); |
---|
283 | m_performancePanel.setMinimumSize(size); |
---|
284 | |
---|
285 | size = new Dimension(500, 400); |
---|
286 | m_costBenefitPanel.setMinimumSize(size); |
---|
287 | m_costBenefitPanel.setPreferredSize(size); |
---|
288 | |
---|
289 | m_thresholdSlider.addChangeListener(new ChangeListener() { |
---|
290 | public void stateChanged(ChangeEvent e) { |
---|
291 | updateInfoForSliderValue((double)m_thresholdSlider.getValue() / 100.0); |
---|
292 | } |
---|
293 | }); |
---|
294 | |
---|
295 | JPanel plotHolder = new JPanel(); |
---|
296 | plotHolder.setLayout(new GridLayout(1,2)); |
---|
297 | plotHolder.add(m_performancePanel); |
---|
298 | plotHolder.add(m_costBenefitPanel); |
---|
299 | add(plotHolder, BorderLayout.CENTER); |
---|
300 | |
---|
301 | JPanel lowerPanel = new JPanel(); |
---|
302 | lowerPanel.setLayout(new BorderLayout()); |
---|
303 | |
---|
304 | ButtonGroup bGroup = new ButtonGroup(); |
---|
305 | bGroup.add(m_percPop); |
---|
306 | bGroup.add(m_percOfTarget); |
---|
307 | bGroup.add(m_threshold); |
---|
308 | |
---|
309 | ButtonGroup bGroup2 = new ButtonGroup(); |
---|
310 | bGroup2.add(m_costR); |
---|
311 | bGroup2.add(m_benefitR); |
---|
312 | ActionListener rl = new ActionListener() { |
---|
313 | public void actionPerformed(ActionEvent e) { |
---|
314 | if (m_costR.isSelected()) { |
---|
315 | m_costBenefitL.setText("Cost: "); |
---|
316 | } else { |
---|
317 | m_costBenefitL.setText("Benefit: "); |
---|
318 | } |
---|
319 | |
---|
320 | double gain = Double.parseDouble(m_gainV.getText()); |
---|
321 | gain = -gain; |
---|
322 | m_gainV.setText(Utils.doubleToString(gain, 2)); |
---|
323 | } |
---|
324 | }; |
---|
325 | m_costR.addActionListener(rl); |
---|
326 | m_benefitR.addActionListener(rl); |
---|
327 | m_costR.setSelected(true); |
---|
328 | |
---|
329 | m_percPop.setSelected(true); |
---|
330 | JPanel threshPanel = new JPanel(); |
---|
331 | threshPanel.setLayout(new BorderLayout()); |
---|
332 | JPanel radioHolder = new JPanel(); |
---|
333 | radioHolder.setLayout(new FlowLayout()); |
---|
334 | radioHolder.add(m_percPop); |
---|
335 | radioHolder.add(m_percOfTarget); |
---|
336 | radioHolder.add(m_threshold); |
---|
337 | threshPanel.add(radioHolder, BorderLayout.NORTH); |
---|
338 | threshPanel.add(m_thresholdSlider, BorderLayout.SOUTH); |
---|
339 | |
---|
340 | JPanel threshInfoPanel = new JPanel(); |
---|
341 | threshInfoPanel.setLayout(new GridLayout(3,2)); |
---|
342 | threshInfoPanel.add(new JLabel("% of Population: ", SwingConstants.RIGHT)); |
---|
343 | threshInfoPanel.add(m_percPopLab); |
---|
344 | threshInfoPanel.add(new JLabel("% of Target: ", SwingConstants.RIGHT)); |
---|
345 | threshInfoPanel.add(m_percOfTargetLab); |
---|
346 | threshInfoPanel.add(new JLabel("Score Threshold: ", SwingConstants.RIGHT)); |
---|
347 | threshInfoPanel.add(m_thresholdLab); |
---|
348 | |
---|
349 | JPanel threshHolder = new JPanel(); |
---|
350 | threshHolder.setBorder(BorderFactory.createTitledBorder("Threshold")); |
---|
351 | threshHolder.setLayout(new BorderLayout()); |
---|
352 | threshHolder.add(threshPanel, BorderLayout.CENTER); |
---|
353 | threshHolder.add(threshInfoPanel, BorderLayout.EAST); |
---|
354 | |
---|
355 | lowerPanel.add(threshHolder, BorderLayout.NORTH); |
---|
356 | |
---|
357 | // holder for the two matrixes |
---|
358 | JPanel matrixHolder = new JPanel(); |
---|
359 | matrixHolder.setLayout(new GridLayout(1,2)); |
---|
360 | |
---|
361 | // confusion matrix |
---|
362 | JPanel confusionPanel = new JPanel(); |
---|
363 | confusionPanel.setLayout(new GridLayout(3,3)); |
---|
364 | confusionPanel.add(m_conf_predictedA); |
---|
365 | confusionPanel.add(m_conf_predictedB); |
---|
366 | confusionPanel.add(new JLabel()); // dummy |
---|
367 | confusionPanel.add(m_conf_aa); |
---|
368 | confusionPanel.add(m_conf_ab); |
---|
369 | confusionPanel.add(m_conf_actualA); |
---|
370 | confusionPanel.add(m_conf_ba); |
---|
371 | confusionPanel.add(m_conf_bb); |
---|
372 | confusionPanel.add(m_conf_actualB); |
---|
373 | JPanel tempHolderCA = new JPanel(); |
---|
374 | tempHolderCA.setLayout(new BorderLayout()); |
---|
375 | tempHolderCA.setBorder(BorderFactory.createTitledBorder("Confusion Matrix")); |
---|
376 | tempHolderCA.add(confusionPanel, BorderLayout.CENTER); |
---|
377 | |
---|
378 | JPanel accHolder = new JPanel(); |
---|
379 | accHolder.setLayout(new FlowLayout(FlowLayout.LEFT)); |
---|
380 | accHolder.add(new JLabel("Classification Accuracy: ")); |
---|
381 | accHolder.add(m_classificationAccV); |
---|
382 | tempHolderCA.add(accHolder, BorderLayout.SOUTH); |
---|
383 | |
---|
384 | matrixHolder.add(tempHolderCA); |
---|
385 | |
---|
386 | // cost matrix |
---|
387 | JPanel costPanel = new JPanel(); |
---|
388 | costPanel.setBorder(BorderFactory.createTitledBorder("Cost Matrix")); |
---|
389 | costPanel.setLayout(new BorderLayout()); |
---|
390 | |
---|
391 | JPanel cmHolder = new JPanel(); |
---|
392 | cmHolder.setLayout(new GridLayout(3, 3)); |
---|
393 | cmHolder.add(m_cost_predictedA); |
---|
394 | cmHolder.add(m_cost_predictedB); |
---|
395 | cmHolder.add(new JLabel()); // dummy |
---|
396 | cmHolder.add(m_cost_aa); |
---|
397 | cmHolder.add(m_cost_ab); |
---|
398 | cmHolder.add(m_cost_actualA); |
---|
399 | cmHolder.add(m_cost_ba); |
---|
400 | cmHolder.add(m_cost_bb); |
---|
401 | cmHolder.add(m_cost_actualB); |
---|
402 | costPanel.add(cmHolder, BorderLayout.CENTER); |
---|
403 | |
---|
404 | FocusListener fl = new FocusListener() { |
---|
405 | public void focusGained(FocusEvent e) { |
---|
406 | |
---|
407 | } |
---|
408 | |
---|
409 | public void focusLost(FocusEvent e) { |
---|
410 | if (constructCostBenefitData()) { |
---|
411 | try { |
---|
412 | m_costBenefitPanel.setMasterPlot(m_costBenefit); |
---|
413 | m_costBenefitPanel.validate(); m_costBenefitPanel.repaint(); |
---|
414 | } catch (Exception ex) { |
---|
415 | ex.printStackTrace(); |
---|
416 | } |
---|
417 | updateCostBenefit(); |
---|
418 | } |
---|
419 | } |
---|
420 | }; |
---|
421 | |
---|
422 | ActionListener al = new ActionListener() { |
---|
423 | public void actionPerformed(ActionEvent e) { |
---|
424 | if (constructCostBenefitData()) { |
---|
425 | try { |
---|
426 | m_costBenefitPanel.setMasterPlot(m_costBenefit); |
---|
427 | m_costBenefitPanel.validate(); m_costBenefitPanel.repaint(); |
---|
428 | } catch (Exception ex) { |
---|
429 | ex.printStackTrace(); |
---|
430 | } |
---|
431 | updateCostBenefit(); |
---|
432 | } |
---|
433 | } |
---|
434 | }; |
---|
435 | |
---|
436 | m_cost_aa.addFocusListener(fl); |
---|
437 | m_cost_aa.addActionListener(al); |
---|
438 | m_cost_ab.addFocusListener(fl); |
---|
439 | m_cost_ab.addActionListener(al); |
---|
440 | m_cost_ba.addFocusListener(fl); |
---|
441 | m_cost_ba.addActionListener(al); |
---|
442 | m_cost_bb.addFocusListener(fl); |
---|
443 | m_cost_bb.addActionListener(al); |
---|
444 | |
---|
445 | m_totalPopField.addFocusListener(fl); |
---|
446 | m_totalPopField.addActionListener(al); |
---|
447 | |
---|
448 | JPanel cbHolder = new JPanel(); |
---|
449 | cbHolder.setLayout(new BorderLayout()); |
---|
450 | JPanel tempP = new JPanel(); |
---|
451 | tempP.setLayout(new GridLayout(3, 2)); |
---|
452 | tempP.add(m_costBenefitL); |
---|
453 | tempP.add(m_costBenefitV); |
---|
454 | tempP.add(new JLabel("Random: ", SwingConstants.RIGHT)); |
---|
455 | tempP.add(m_randomV); |
---|
456 | tempP.add(new JLabel("Gain: ", SwingConstants.RIGHT)); |
---|
457 | tempP.add(m_gainV); |
---|
458 | cbHolder.add(tempP, BorderLayout.NORTH); |
---|
459 | JPanel butHolder = new JPanel(); |
---|
460 | butHolder.setLayout(new GridLayout(2, 1)); |
---|
461 | butHolder.add(m_maximizeCB); |
---|
462 | butHolder.add(m_minimizeCB); |
---|
463 | m_maximizeCB.addActionListener(new ActionListener() { |
---|
464 | public void actionPerformed(ActionEvent e) { |
---|
465 | findMaxMinCB(true); |
---|
466 | } |
---|
467 | }); |
---|
468 | |
---|
469 | m_minimizeCB.addActionListener(new ActionListener() { |
---|
470 | public void actionPerformed(ActionEvent e) { |
---|
471 | findMaxMinCB(false); |
---|
472 | } |
---|
473 | }); |
---|
474 | |
---|
475 | cbHolder.add(butHolder, BorderLayout.SOUTH); |
---|
476 | costPanel.add(cbHolder, BorderLayout.EAST); |
---|
477 | |
---|
478 | JPanel popCBR = new JPanel(); |
---|
479 | popCBR.setLayout(new GridLayout(1, 2)); |
---|
480 | JPanel popHolder = new JPanel(); |
---|
481 | popHolder.setLayout(new FlowLayout(FlowLayout.LEFT)); |
---|
482 | popHolder.add(new JLabel("Total Population: ")); |
---|
483 | popHolder.add(m_totalPopField); |
---|
484 | |
---|
485 | JPanel radioHolder2 = new JPanel(); |
---|
486 | radioHolder2.setLayout(new FlowLayout(FlowLayout.RIGHT)); |
---|
487 | radioHolder2.add(m_costR); |
---|
488 | radioHolder2.add(m_benefitR); |
---|
489 | popCBR.add(popHolder); |
---|
490 | popCBR.add(radioHolder2); |
---|
491 | |
---|
492 | costPanel.add(popCBR, BorderLayout.SOUTH); |
---|
493 | |
---|
494 | matrixHolder.add(costPanel); |
---|
495 | |
---|
496 | |
---|
497 | lowerPanel.add(matrixHolder, BorderLayout.SOUTH); |
---|
498 | |
---|
499 | |
---|
500 | |
---|
501 | // popAccHolder.add(popHolder); |
---|
502 | |
---|
503 | //popAccHolder.add(accHolder); |
---|
504 | |
---|
505 | /*JPanel lowerPanel2 = new JPanel(); |
---|
506 | lowerPanel2.setLayout(new BorderLayout()); |
---|
507 | lowerPanel2.add(lowerPanel, BorderLayout.NORTH); |
---|
508 | lowerPanel2.add(popAccHolder, BorderLayout.SOUTH); */ |
---|
509 | |
---|
510 | add(lowerPanel, BorderLayout.SOUTH); |
---|
511 | |
---|
512 | } |
---|
513 | |
---|
514 | private void findMaxMinCB(boolean max) { |
---|
515 | double maxMin = (max) |
---|
516 | ? Double.NEGATIVE_INFINITY |
---|
517 | : Double.POSITIVE_INFINITY; |
---|
518 | |
---|
519 | Instances cBCurve = m_costBenefit.getPlotInstances(); |
---|
520 | int maxMinIndex = 0; |
---|
521 | |
---|
522 | for (int i = 0; i < cBCurve.numInstances(); i++) { |
---|
523 | Instance current = cBCurve.instance(i); |
---|
524 | if (max) { |
---|
525 | if (current.value(1) > maxMin) { |
---|
526 | maxMin = current.value(1); |
---|
527 | maxMinIndex = i; |
---|
528 | } |
---|
529 | } else { |
---|
530 | if (current.value(1) < maxMin) { |
---|
531 | maxMin = current.value(1); |
---|
532 | maxMinIndex = i; |
---|
533 | } |
---|
534 | } |
---|
535 | } |
---|
536 | |
---|
537 | |
---|
538 | // set the slider to the correct position |
---|
539 | int indexOfSampleSize = |
---|
540 | m_masterPlot.getPlotInstances().attribute(ThresholdCurve.SAMPLE_SIZE_NAME).index(); |
---|
541 | int indexOfPercOfTarget = |
---|
542 | m_masterPlot.getPlotInstances().attribute(ThresholdCurve.RECALL_NAME).index(); |
---|
543 | int indexOfThreshold = |
---|
544 | m_masterPlot.getPlotInstances().attribute(ThresholdCurve.THRESHOLD_NAME).index(); |
---|
545 | int indexOfMetric; |
---|
546 | |
---|
547 | if (m_percPop.isSelected()) { |
---|
548 | indexOfMetric = indexOfSampleSize; |
---|
549 | } else if (m_percOfTarget.isSelected()) { |
---|
550 | indexOfMetric = indexOfPercOfTarget; |
---|
551 | } else { |
---|
552 | indexOfMetric = indexOfThreshold; |
---|
553 | } |
---|
554 | |
---|
555 | double valueOfMetric = m_masterPlot.getPlotInstances().instance(maxMinIndex).value(indexOfMetric); |
---|
556 | valueOfMetric *= 100.0; |
---|
557 | |
---|
558 | // set the approximate location of the slider |
---|
559 | m_thresholdSlider.setValue((int)valueOfMetric); |
---|
560 | |
---|
561 | // make sure the actual values relate to the true min/max rather |
---|
562 | // than being off due to slider location error. |
---|
563 | updateInfoGivenIndex(maxMinIndex); |
---|
564 | } |
---|
565 | |
---|
566 | private void updateCostBenefit() { |
---|
567 | double value = (double)m_thresholdSlider.getValue() / 100.0; |
---|
568 | Instances plotInstances = m_masterPlot.getPlotInstances(); |
---|
569 | int indexOfSampleSize = |
---|
570 | m_masterPlot.getPlotInstances().attribute(ThresholdCurve.SAMPLE_SIZE_NAME).index(); |
---|
571 | int indexOfPercOfTarget = |
---|
572 | m_masterPlot.getPlotInstances().attribute(ThresholdCurve.RECALL_NAME).index(); |
---|
573 | int indexOfThreshold = |
---|
574 | m_masterPlot.getPlotInstances().attribute(ThresholdCurve.THRESHOLD_NAME).index(); |
---|
575 | int indexOfMetric; |
---|
576 | |
---|
577 | if (m_percPop.isSelected()) { |
---|
578 | indexOfMetric = indexOfSampleSize; |
---|
579 | } else if (m_percOfTarget.isSelected()) { |
---|
580 | indexOfMetric = indexOfPercOfTarget; |
---|
581 | } else { |
---|
582 | indexOfMetric = indexOfThreshold; |
---|
583 | } |
---|
584 | |
---|
585 | int index = findIndexForValue(value, plotInstances, indexOfMetric); |
---|
586 | updateCBRandomGainInfo(index); |
---|
587 | } |
---|
588 | |
---|
589 | private void updateCBRandomGainInfo(int index) { |
---|
590 | double requestedPopSize = m_originalPopSize; |
---|
591 | try { |
---|
592 | requestedPopSize = Double.parseDouble(m_totalPopField.getText()); |
---|
593 | } catch (NumberFormatException e) {} |
---|
594 | double scaleFactor = requestedPopSize / m_originalPopSize; |
---|
595 | |
---|
596 | double CB = m_costBenefit. |
---|
597 | getPlotInstances().instance(index).value(1); |
---|
598 | m_costBenefitV.setText(Utils.doubleToString(CB,2)); |
---|
599 | |
---|
600 | double totalRandomCB = 0.0; |
---|
601 | Instance first = m_masterPlot.getPlotInstances().instance(0); |
---|
602 | double totalPos = first.value(m_masterPlot.getPlotInstances(). |
---|
603 | attribute(ThresholdCurve.TRUE_POS_NAME).index()) * scaleFactor; |
---|
604 | double totalNeg = first.value(m_masterPlot.getPlotInstances(). |
---|
605 | attribute(ThresholdCurve.FALSE_POS_NAME)) * scaleFactor; |
---|
606 | |
---|
607 | double posInSample = (totalPos * (Double.parseDouble(m_percPopLab.getText()) / 100.0)); |
---|
608 | double negInSample = (totalNeg * (Double.parseDouble(m_percPopLab.getText()) / 100.0)); |
---|
609 | double posOutSample = totalPos - posInSample; |
---|
610 | double negOutSample = totalNeg - negInSample; |
---|
611 | |
---|
612 | double tpCost = 0.0; |
---|
613 | try { |
---|
614 | tpCost = Double.parseDouble(m_cost_aa.getText()); |
---|
615 | } catch (NumberFormatException n) {} |
---|
616 | double fpCost = 0.0; |
---|
617 | try { |
---|
618 | fpCost = Double.parseDouble(m_cost_ba.getText()); |
---|
619 | } catch (NumberFormatException n) {} |
---|
620 | double tnCost = 0.0; |
---|
621 | try { |
---|
622 | tnCost = Double.parseDouble(m_cost_bb.getText()); |
---|
623 | } catch (NumberFormatException n) {} |
---|
624 | double fnCost = 0.0; |
---|
625 | try { |
---|
626 | fnCost = Double.parseDouble(m_cost_ab.getText()); |
---|
627 | } catch (NumberFormatException n) {} |
---|
628 | |
---|
629 | totalRandomCB += posInSample * tpCost; |
---|
630 | totalRandomCB += negInSample * fpCost; |
---|
631 | totalRandomCB += posOutSample * fnCost; |
---|
632 | totalRandomCB += negOutSample * tnCost; |
---|
633 | |
---|
634 | m_randomV.setText(Utils.doubleToString(totalRandomCB, 2)); |
---|
635 | double gain = (m_costR.isSelected()) |
---|
636 | ? totalRandomCB - CB |
---|
637 | : CB - totalRandomCB; |
---|
638 | m_gainV.setText(Utils.doubleToString(gain, 2)); |
---|
639 | |
---|
640 | // update classification rate |
---|
641 | Instance currentInst = m_masterPlot.getPlotInstances().instance(index); |
---|
642 | double tp = currentInst.value(m_masterPlot.getPlotInstances(). |
---|
643 | attribute(ThresholdCurve.TRUE_POS_NAME).index()); |
---|
644 | double tn = currentInst.value(m_masterPlot.getPlotInstances(). |
---|
645 | attribute(ThresholdCurve.TRUE_NEG_NAME).index()); |
---|
646 | m_classificationAccV. |
---|
647 | setText(Utils.doubleToString((tp + tn) / (totalPos + totalNeg) * 100.0, 4) + "%"); |
---|
648 | } |
---|
649 | |
---|
650 | private void updateInfoGivenIndex(int index) { |
---|
651 | Instances plotInstances = m_masterPlot.getPlotInstances(); |
---|
652 | int indexOfSampleSize = |
---|
653 | m_masterPlot.getPlotInstances().attribute(ThresholdCurve.SAMPLE_SIZE_NAME).index(); |
---|
654 | int indexOfPercOfTarget = |
---|
655 | m_masterPlot.getPlotInstances().attribute(ThresholdCurve.RECALL_NAME).index(); |
---|
656 | int indexOfThreshold = |
---|
657 | m_masterPlot.getPlotInstances().attribute(ThresholdCurve.THRESHOLD_NAME).index(); |
---|
658 | |
---|
659 | // update labels |
---|
660 | m_percPopLab.setText(Utils. |
---|
661 | doubleToString(100.0 * plotInstances.instance(index).value(indexOfSampleSize), 4)); |
---|
662 | m_percOfTargetLab.setText(Utils.doubleToString( |
---|
663 | 100.0 * plotInstances.instance(index).value(indexOfPercOfTarget), 4)); |
---|
664 | m_thresholdLab.setText(Utils.doubleToString(plotInstances.instance(index).value(indexOfThreshold), 4)); |
---|
665 | /*if (m_percPop.isSelected()) { |
---|
666 | m_percPopLab.setText(Utils.doubleToString(100.0 * value, 4)); |
---|
667 | } else if (m_percOfTarget.isSelected()) { |
---|
668 | m_percOfTargetLab.setText(Utils.doubleToString(100.0 * value, 4)); |
---|
669 | } else { |
---|
670 | m_thresholdLab.setText(Utils.doubleToString(value, 4)); |
---|
671 | }*/ |
---|
672 | |
---|
673 | // Update the highlighted point on the graphs */ |
---|
674 | if (m_previousShapeIndex >= 0) { |
---|
675 | m_shapeSizes[m_previousShapeIndex] = 1; |
---|
676 | } |
---|
677 | |
---|
678 | m_shapeSizes[index] = 10; |
---|
679 | m_previousShapeIndex = index; |
---|
680 | |
---|
681 | // Update the confusion matrix |
---|
682 | // double totalInstances = |
---|
683 | int tp = plotInstances.attribute(ThresholdCurve.TRUE_POS_NAME).index(); |
---|
684 | int fp = plotInstances.attribute(ThresholdCurve.FALSE_POS_NAME).index(); |
---|
685 | int tn = plotInstances.attribute(ThresholdCurve.TRUE_NEG_NAME).index(); |
---|
686 | int fn = plotInstances.attribute(ThresholdCurve.FALSE_NEG_NAME).index(); |
---|
687 | Instance temp = plotInstances.instance(index); |
---|
688 | double totalInstances = temp.value(tp) + temp.value(fp) + temp.value(tn) + temp.value(fn); |
---|
689 | // get the value out of the total pop field (if possible) |
---|
690 | double requestedPopSize = totalInstances; |
---|
691 | try { |
---|
692 | requestedPopSize = Double.parseDouble(m_totalPopField.getText()); |
---|
693 | } catch (NumberFormatException e) {} |
---|
694 | |
---|
695 | m_conf_aa.setCellValue(temp.value(tp), totalInstances, |
---|
696 | requestedPopSize / totalInstances, 2); |
---|
697 | m_conf_ab.setCellValue(temp.value(fn), totalInstances, |
---|
698 | requestedPopSize / totalInstances, 2); |
---|
699 | m_conf_ba.setCellValue(temp.value(fp), totalInstances, |
---|
700 | requestedPopSize / totalInstances, 2); |
---|
701 | m_conf_bb.setCellValue(temp.value(tn), totalInstances, |
---|
702 | requestedPopSize / totalInstances, 2); |
---|
703 | |
---|
704 | updateCBRandomGainInfo(index); |
---|
705 | |
---|
706 | repaint(); |
---|
707 | } |
---|
708 | |
---|
709 | private void updateInfoForSliderValue(double value) { |
---|
710 | int indexOfSampleSize = |
---|
711 | m_masterPlot.getPlotInstances().attribute(ThresholdCurve.SAMPLE_SIZE_NAME).index(); |
---|
712 | int indexOfPercOfTarget = |
---|
713 | m_masterPlot.getPlotInstances().attribute(ThresholdCurve.RECALL_NAME).index(); |
---|
714 | int indexOfThreshold = |
---|
715 | m_masterPlot.getPlotInstances().attribute(ThresholdCurve.THRESHOLD_NAME).index(); |
---|
716 | int indexOfMetric; |
---|
717 | |
---|
718 | if (m_percPop.isSelected()) { |
---|
719 | indexOfMetric = indexOfSampleSize; |
---|
720 | } else if (m_percOfTarget.isSelected()) { |
---|
721 | indexOfMetric = indexOfPercOfTarget; |
---|
722 | } else { |
---|
723 | indexOfMetric = indexOfThreshold; |
---|
724 | } |
---|
725 | |
---|
726 | Instances plotInstances = m_masterPlot.getPlotInstances(); |
---|
727 | int index = findIndexForValue(value, plotInstances, indexOfMetric); |
---|
728 | updateInfoGivenIndex(index); |
---|
729 | } |
---|
730 | |
---|
731 | private int findIndexForValue(double value, Instances plotInstances, int indexOfMetric) { |
---|
732 | // binary search |
---|
733 | // threshold curve is sorted ascending in the threshold (thus |
---|
734 | // descending for recall and pop size) |
---|
735 | int index = -1; |
---|
736 | int lower = 0; |
---|
737 | int upper = plotInstances.numInstances() - 1; |
---|
738 | int mid = (upper - lower) / 2; |
---|
739 | boolean done = false; |
---|
740 | while (!done) { |
---|
741 | if (upper - lower <= 1) { |
---|
742 | |
---|
743 | // choose the one closest to the value |
---|
744 | double comp1 = plotInstances.instance(upper).value(indexOfMetric); |
---|
745 | double comp2 = plotInstances.instance(lower).value(indexOfMetric); |
---|
746 | if (Math.abs(comp1 - value) < Math.abs(comp2 - value)) { |
---|
747 | index = upper; |
---|
748 | } else { |
---|
749 | index = lower; |
---|
750 | } |
---|
751 | |
---|
752 | break; |
---|
753 | } |
---|
754 | double comparisonVal = plotInstances.instance(mid).value(indexOfMetric); |
---|
755 | if (value > comparisonVal) { |
---|
756 | if (m_threshold.isSelected()) { |
---|
757 | lower = mid; |
---|
758 | mid += (upper - lower) / 2; |
---|
759 | } else { |
---|
760 | upper = mid; |
---|
761 | mid -= (upper - lower) / 2; |
---|
762 | } |
---|
763 | } else if (value < comparisonVal) { |
---|
764 | if (m_threshold.isSelected()) { |
---|
765 | upper = mid; |
---|
766 | mid -= (upper - lower) / 2; |
---|
767 | } else { |
---|
768 | lower = mid; |
---|
769 | mid += (upper - lower) / 2; |
---|
770 | } |
---|
771 | } else { |
---|
772 | index = mid; |
---|
773 | done = true; |
---|
774 | } |
---|
775 | } |
---|
776 | |
---|
777 | // now check for ties in the appropriate direction |
---|
778 | if (!m_threshold.isSelected()) { |
---|
779 | while (index + 1 < plotInstances.numInstances()) { |
---|
780 | if (plotInstances.instance(index + 1).value(indexOfMetric) == |
---|
781 | plotInstances.instance(index).value(indexOfMetric)) { |
---|
782 | index++; |
---|
783 | } else { |
---|
784 | break; |
---|
785 | } |
---|
786 | } |
---|
787 | } else { |
---|
788 | while (index - 1 >= 0) { |
---|
789 | if (plotInstances.instance(index - 1).value(indexOfMetric) == |
---|
790 | plotInstances.instance(index).value(indexOfMetric)) { |
---|
791 | index--; |
---|
792 | } else { |
---|
793 | break; |
---|
794 | } |
---|
795 | } |
---|
796 | } |
---|
797 | return index; |
---|
798 | } |
---|
799 | |
---|
800 | /** |
---|
801 | * Set the threshold data for the panel to use. |
---|
802 | * |
---|
803 | * @param data PlotData2D object encapsulating the threshold data. |
---|
804 | * @param classAtt the class attribute from the original data used to generate |
---|
805 | * the threshold data. |
---|
806 | * @throws Exception if something goes wrong. |
---|
807 | */ |
---|
808 | public synchronized void setDataSet(PlotData2D data, Attribute classAtt) throws Exception { |
---|
809 | // make a copy of the PlotData2D object |
---|
810 | m_masterPlot = new PlotData2D(data.getPlotInstances()); |
---|
811 | boolean[] connectPoints = new boolean[m_masterPlot.getPlotInstances().numInstances()]; |
---|
812 | for (int i = 1; i < connectPoints.length; i++) { |
---|
813 | connectPoints[i] = true; |
---|
814 | } |
---|
815 | m_masterPlot.setConnectPoints(connectPoints); |
---|
816 | |
---|
817 | m_masterPlot.m_alwaysDisplayPointsOfThisSize = 10; |
---|
818 | setClassForConfusionMatrix(classAtt); |
---|
819 | m_performancePanel.setMasterPlot(m_masterPlot); |
---|
820 | m_performancePanel.validate(); m_performancePanel.repaint(); |
---|
821 | |
---|
822 | m_shapeSizes = new int[m_masterPlot.getPlotInstances().numInstances()]; |
---|
823 | for (int i = 0; i < m_shapeSizes.length; i++) { |
---|
824 | m_shapeSizes[i] = 1; |
---|
825 | } |
---|
826 | m_masterPlot.setShapeSize(m_shapeSizes); |
---|
827 | constructCostBenefitData(); |
---|
828 | m_costBenefitPanel.setMasterPlot(m_costBenefit); |
---|
829 | m_costBenefitPanel.validate(); m_costBenefitPanel.repaint(); |
---|
830 | |
---|
831 | m_totalPopPrevious = 0; |
---|
832 | m_fpPrevious = 0; |
---|
833 | m_tpPrevious = 0; |
---|
834 | m_tnPrevious = 0; |
---|
835 | m_fnPrevious = 0; |
---|
836 | m_previousShapeIndex = -1; |
---|
837 | |
---|
838 | // set the total population size |
---|
839 | Instance first = m_masterPlot.getPlotInstances().instance(0); |
---|
840 | double totalPos = first.value(m_masterPlot.getPlotInstances(). |
---|
841 | attribute(ThresholdCurve.TRUE_POS_NAME).index()); |
---|
842 | double totalNeg = first.value(m_masterPlot.getPlotInstances(). |
---|
843 | attribute(ThresholdCurve.FALSE_POS_NAME)); |
---|
844 | m_originalPopSize = (int)(totalPos + totalNeg); |
---|
845 | m_totalPopField.setText("" + m_originalPopSize); |
---|
846 | |
---|
847 | m_performancePanel.setYIndex(5); |
---|
848 | m_performancePanel.setXIndex(10); |
---|
849 | m_costBenefitPanel.setXIndex(0); |
---|
850 | m_costBenefitPanel.setYIndex(1); |
---|
851 | // System.err.println(m_masterPlot.getPlotInstances()); |
---|
852 | updateInfoForSliderValue((double)m_thresholdSlider.getValue() / 100.0); |
---|
853 | } |
---|
854 | |
---|
855 | private void setClassForConfusionMatrix(Attribute classAtt) { |
---|
856 | m_classAttribute = classAtt; |
---|
857 | m_conf_actualA.setText(" Actual (a): " + classAtt.value(0)); |
---|
858 | m_conf_actualA.setToolTipText(classAtt.value(0)); |
---|
859 | String negClasses = ""; |
---|
860 | for (int i = 1; i < classAtt.numValues(); i++) { |
---|
861 | negClasses += classAtt.value(i); |
---|
862 | if (i < classAtt.numValues() - 1) { |
---|
863 | negClasses += ","; |
---|
864 | } |
---|
865 | } |
---|
866 | m_conf_actualB.setText(" Actual (b): " + negClasses); |
---|
867 | m_conf_actualB.setToolTipText(negClasses); |
---|
868 | } |
---|
869 | |
---|
870 | private boolean constructCostBenefitData() { |
---|
871 | double tpCost = 0.0; |
---|
872 | try { |
---|
873 | tpCost = Double.parseDouble(m_cost_aa.getText()); |
---|
874 | } catch (NumberFormatException n) {} |
---|
875 | double fpCost = 0.0; |
---|
876 | try { |
---|
877 | fpCost = Double.parseDouble(m_cost_ba.getText()); |
---|
878 | } catch (NumberFormatException n) {} |
---|
879 | double tnCost = 0.0; |
---|
880 | try { |
---|
881 | tnCost = Double.parseDouble(m_cost_bb.getText()); |
---|
882 | } catch (NumberFormatException n) {} |
---|
883 | double fnCost = 0.0; |
---|
884 | try { |
---|
885 | fnCost = Double.parseDouble(m_cost_ab.getText()); |
---|
886 | } catch (NumberFormatException n) {} |
---|
887 | |
---|
888 | double requestedPopSize = m_originalPopSize; |
---|
889 | try { |
---|
890 | requestedPopSize = Double.parseDouble(m_totalPopField.getText()); |
---|
891 | } catch (NumberFormatException e) {} |
---|
892 | |
---|
893 | double scaleFactor = 1.0; |
---|
894 | if (m_originalPopSize != 0) { |
---|
895 | scaleFactor = requestedPopSize / m_originalPopSize; |
---|
896 | } |
---|
897 | |
---|
898 | if (tpCost == m_tpPrevious && fpCost == m_fpPrevious && |
---|
899 | tnCost == m_tnPrevious && fnCost == m_fnPrevious && |
---|
900 | requestedPopSize == m_totalPopPrevious) { |
---|
901 | return false; |
---|
902 | } |
---|
903 | |
---|
904 | // First construct some Instances for the curve |
---|
905 | FastVector fv = new FastVector(); |
---|
906 | fv.addElement(new Attribute("Sample Size")); |
---|
907 | fv.addElement(new Attribute("Cost/Benefit")); |
---|
908 | Instances costBenefitI = new Instances("Cost/Benefit Curve", fv, 100); |
---|
909 | |
---|
910 | // process the performance data to make this curve |
---|
911 | Instances performanceI = m_masterPlot.getPlotInstances(); |
---|
912 | |
---|
913 | for (int i = 0; i < performanceI.numInstances(); i++) { |
---|
914 | Instance current = performanceI.instance(i); |
---|
915 | |
---|
916 | double[] vals = new double[2]; |
---|
917 | vals[0] = current.value(10); // sample size |
---|
918 | vals[1] = (current.value(0) * tpCost |
---|
919 | + current.value(1) * fnCost |
---|
920 | + current.value(2) * fpCost |
---|
921 | + current.value(3) * tnCost) * scaleFactor; |
---|
922 | Instance newInst = new DenseInstance(1.0, vals); |
---|
923 | costBenefitI.add(newInst); |
---|
924 | } |
---|
925 | |
---|
926 | costBenefitI.compactify(); |
---|
927 | |
---|
928 | // now set up the plot data |
---|
929 | m_costBenefit = new PlotData2D(costBenefitI); |
---|
930 | m_costBenefit.m_alwaysDisplayPointsOfThisSize = 10; |
---|
931 | m_costBenefit.setPlotName("Cost/benefit curve"); |
---|
932 | boolean[] connectPoints = new boolean[costBenefitI.numInstances()]; |
---|
933 | |
---|
934 | for (int i = 0; i < connectPoints.length; i++) { |
---|
935 | connectPoints[i] = true; |
---|
936 | } |
---|
937 | try { |
---|
938 | m_costBenefit.setConnectPoints(connectPoints); |
---|
939 | m_costBenefit.setShapeSize(m_shapeSizes); |
---|
940 | } catch (Exception ex) { |
---|
941 | // ignore |
---|
942 | } |
---|
943 | |
---|
944 | m_tpPrevious = tpCost; |
---|
945 | m_fpPrevious = fpCost; |
---|
946 | m_tnPrevious = tnCost; |
---|
947 | m_fnPrevious = fnCost; |
---|
948 | |
---|
949 | return true; |
---|
950 | } |
---|
951 | } |
---|
952 | |
---|
953 | /** |
---|
954 | * Constructor. |
---|
955 | */ |
---|
956 | public CostBenefitAnalysis() { |
---|
957 | java.awt.GraphicsEnvironment ge = |
---|
958 | java.awt.GraphicsEnvironment.getLocalGraphicsEnvironment(); |
---|
959 | if (!ge.isHeadless()) { |
---|
960 | appearanceFinal(); |
---|
961 | } |
---|
962 | } |
---|
963 | |
---|
964 | /** |
---|
965 | * Global info for this bean |
---|
966 | * |
---|
967 | * @return a <code>String</code> value |
---|
968 | */ |
---|
969 | public String globalInfo() { |
---|
970 | return "Visualize performance charts (such as ROC)."; |
---|
971 | } |
---|
972 | |
---|
973 | /** |
---|
974 | * Accept a threshold data event and set up the visualization. |
---|
975 | * @param e a threshold data event |
---|
976 | */ |
---|
977 | public void acceptDataSet(ThresholdDataEvent e) { |
---|
978 | try { |
---|
979 | setCurveData(e.getDataSet(), e.getClassAttribute()); |
---|
980 | } catch (Exception ex) { |
---|
981 | System.err.println("[CostBenefitAnalysis] Problem setting up visualization."); |
---|
982 | ex.printStackTrace(); |
---|
983 | } |
---|
984 | } |
---|
985 | |
---|
986 | /** |
---|
987 | * Set the threshold curve data to use. |
---|
988 | * |
---|
989 | * @param curveData a PlotData2D object set up with the curve data. |
---|
990 | * @param origClassAtt the class attribute from the original data used to |
---|
991 | * generate the curve. |
---|
992 | * @throws Exception if somthing goes wrong during the setup process. |
---|
993 | */ |
---|
994 | public void setCurveData(PlotData2D curveData, Attribute origClassAtt) |
---|
995 | throws Exception { |
---|
996 | if (m_analysisPanel == null) { |
---|
997 | m_analysisPanel = new AnalysisPanel(); |
---|
998 | } |
---|
999 | m_analysisPanel.setDataSet(curveData, origClassAtt); |
---|
1000 | } |
---|
1001 | |
---|
1002 | public BeanVisual getVisual() { |
---|
1003 | return m_visual; |
---|
1004 | } |
---|
1005 | |
---|
1006 | public void setVisual(BeanVisual newVisual) { |
---|
1007 | m_visual = newVisual; |
---|
1008 | } |
---|
1009 | |
---|
1010 | public void useDefaultVisual() { |
---|
1011 | m_visual.loadIcons(BeanVisual.ICON_PATH+"DefaultDataVisualizer.gif", |
---|
1012 | BeanVisual.ICON_PATH+"DefaultDataVisualizer_animated.gif"); |
---|
1013 | } |
---|
1014 | |
---|
1015 | public Enumeration enumerateRequests() { |
---|
1016 | Vector newVector = new Vector(0); |
---|
1017 | if (m_analysisPanel != null) { |
---|
1018 | if (m_analysisPanel.m_masterPlot != null) { |
---|
1019 | newVector.addElement("Show analysis"); |
---|
1020 | } |
---|
1021 | } |
---|
1022 | return newVector.elements(); |
---|
1023 | } |
---|
1024 | |
---|
1025 | public void performRequest(String request) { |
---|
1026 | if (request.compareTo("Show analysis") == 0) { |
---|
1027 | try { |
---|
1028 | // popup visualize panel |
---|
1029 | if (!m_framePoppedUp) { |
---|
1030 | m_framePoppedUp = true; |
---|
1031 | |
---|
1032 | final javax.swing.JFrame jf = |
---|
1033 | new javax.swing.JFrame("Cost/Benefit Analysis"); |
---|
1034 | jf.setSize(1000,600); |
---|
1035 | jf.getContentPane().setLayout(new BorderLayout()); |
---|
1036 | jf.getContentPane().add(m_analysisPanel, BorderLayout.CENTER); |
---|
1037 | jf.addWindowListener(new java.awt.event.WindowAdapter() { |
---|
1038 | public void windowClosing(java.awt.event.WindowEvent e) { |
---|
1039 | jf.dispose(); |
---|
1040 | m_framePoppedUp = false; |
---|
1041 | } |
---|
1042 | }); |
---|
1043 | jf.setVisible(true); |
---|
1044 | m_popupFrame = jf; |
---|
1045 | } else { |
---|
1046 | m_popupFrame.toFront(); |
---|
1047 | } |
---|
1048 | } catch (Exception ex) { |
---|
1049 | ex.printStackTrace(); |
---|
1050 | m_framePoppedUp = false; |
---|
1051 | } |
---|
1052 | } else { |
---|
1053 | throw new IllegalArgumentException(request |
---|
1054 | + " not supported (Cost/Benefit Analysis"); |
---|
1055 | } |
---|
1056 | } |
---|
1057 | |
---|
1058 | public void addVetoableChangeListener(String name, VetoableChangeListener vcl) { |
---|
1059 | m_bcSupport.addVetoableChangeListener(name, vcl); |
---|
1060 | } |
---|
1061 | |
---|
1062 | public BeanContext getBeanContext() { |
---|
1063 | return m_beanContext; |
---|
1064 | } |
---|
1065 | |
---|
1066 | public void removeVetoableChangeListener(String name, |
---|
1067 | VetoableChangeListener vcl) { |
---|
1068 | m_bcSupport.removeVetoableChangeListener(name, vcl); |
---|
1069 | } |
---|
1070 | |
---|
1071 | protected void appearanceFinal() { |
---|
1072 | removeAll(); |
---|
1073 | setLayout(new BorderLayout()); |
---|
1074 | setUpFinal(); |
---|
1075 | } |
---|
1076 | |
---|
1077 | protected void setUpFinal() { |
---|
1078 | if (m_analysisPanel == null) { |
---|
1079 | m_analysisPanel = new AnalysisPanel(); |
---|
1080 | } |
---|
1081 | add(m_analysisPanel, BorderLayout.CENTER); |
---|
1082 | } |
---|
1083 | |
---|
1084 | protected void appearanceDesign() { |
---|
1085 | removeAll(); |
---|
1086 | m_visual = new BeanVisual("CostBenefitAnalysis", |
---|
1087 | BeanVisual.ICON_PATH+"ModelPerformanceChart.gif", |
---|
1088 | BeanVisual.ICON_PATH |
---|
1089 | +"ModelPerformanceChart_animated.gif"); |
---|
1090 | setLayout(new BorderLayout()); |
---|
1091 | add(m_visual, BorderLayout.CENTER); |
---|
1092 | } |
---|
1093 | |
---|
1094 | public void setBeanContext(BeanContext bc) throws PropertyVetoException { |
---|
1095 | m_beanContext = bc; |
---|
1096 | m_design = m_beanContext.isDesignTime(); |
---|
1097 | if (m_design) { |
---|
1098 | appearanceDesign(); |
---|
1099 | } else { |
---|
1100 | java.awt.GraphicsEnvironment ge = |
---|
1101 | java.awt.GraphicsEnvironment.getLocalGraphicsEnvironment(); |
---|
1102 | if (!ge.isHeadless()) { |
---|
1103 | appearanceFinal(); |
---|
1104 | } |
---|
1105 | } |
---|
1106 | } |
---|
1107 | |
---|
1108 | /** |
---|
1109 | * Returns true if, at this time, |
---|
1110 | * the object will accept a connection via the named event |
---|
1111 | * |
---|
1112 | * @param eventName the name of the event in question |
---|
1113 | * @return true if the object will accept a connection |
---|
1114 | */ |
---|
1115 | public boolean connectionAllowed(String eventName) { |
---|
1116 | return (m_listenee == null); |
---|
1117 | } |
---|
1118 | |
---|
1119 | /** |
---|
1120 | * Notify this object that it has been registered as a listener with |
---|
1121 | * a source for recieving events described by the named event |
---|
1122 | * This object is responsible for recording this fact. |
---|
1123 | * |
---|
1124 | * @param eventName the event |
---|
1125 | * @param source the source with which this object has been registered as |
---|
1126 | * a listener |
---|
1127 | */ |
---|
1128 | public void connectionNotification(String eventName, Object source) { |
---|
1129 | if (connectionAllowed(eventName)) { |
---|
1130 | m_listenee = source; |
---|
1131 | } |
---|
1132 | } |
---|
1133 | |
---|
1134 | /** |
---|
1135 | * Returns true if, at this time, |
---|
1136 | * the object will accept a connection according to the supplied |
---|
1137 | * EventSetDescriptor |
---|
1138 | * |
---|
1139 | * @param esd the EventSetDescriptor |
---|
1140 | * @return true if the object will accept a connection |
---|
1141 | */ |
---|
1142 | public boolean connectionAllowed(EventSetDescriptor esd) { |
---|
1143 | return connectionAllowed(esd.getName()); |
---|
1144 | } |
---|
1145 | |
---|
1146 | /** |
---|
1147 | * Notify this object that it has been deregistered as a listener with |
---|
1148 | * a source for named event. This object is responsible |
---|
1149 | * for recording this fact. |
---|
1150 | * |
---|
1151 | * @param eventName the event |
---|
1152 | * @param source the source with which this object has been registered as |
---|
1153 | * a listener |
---|
1154 | */ |
---|
1155 | public void disconnectionNotification(String eventName, Object source) { |
---|
1156 | if (m_listenee == source) { |
---|
1157 | m_listenee = null; |
---|
1158 | } |
---|
1159 | |
---|
1160 | } |
---|
1161 | |
---|
1162 | /** |
---|
1163 | * Get the custom (descriptive) name for this bean (if one has been set) |
---|
1164 | * |
---|
1165 | * @return the custom name (or the default name) |
---|
1166 | */ |
---|
1167 | public String getCustomName() { |
---|
1168 | return m_visual.getText(); |
---|
1169 | } |
---|
1170 | |
---|
1171 | /** |
---|
1172 | * Returns true if. at this time, the bean is busy with some |
---|
1173 | * (i.e. perhaps a worker thread is performing some calculation). |
---|
1174 | * |
---|
1175 | * @return true if the bean is busy. |
---|
1176 | */ |
---|
1177 | public boolean isBusy() { |
---|
1178 | return false; |
---|
1179 | } |
---|
1180 | |
---|
1181 | /** |
---|
1182 | * Set a custom (descriptive) name for this bean |
---|
1183 | * |
---|
1184 | * @param name the name to use |
---|
1185 | */ |
---|
1186 | public void setCustomName(String name) { |
---|
1187 | m_visual.setText(name); |
---|
1188 | } |
---|
1189 | |
---|
1190 | /** |
---|
1191 | * Set a logger |
---|
1192 | * |
---|
1193 | * @param logger a <code>weka.gui.Logger</code> value |
---|
1194 | */ |
---|
1195 | public void setLog(Logger logger) { |
---|
1196 | // we don't need to do any logging |
---|
1197 | } |
---|
1198 | |
---|
1199 | /** |
---|
1200 | * Stop any processing that the bean might be doing. |
---|
1201 | */ |
---|
1202 | public void stop() { |
---|
1203 | // nothing to do here |
---|
1204 | } |
---|
1205 | |
---|
1206 | public static void main(String[] args) { |
---|
1207 | try { |
---|
1208 | Instances train = new Instances(new java.io.BufferedReader(new java.io.FileReader(args[0]))); |
---|
1209 | train.setClassIndex(train.numAttributes() - 1); |
---|
1210 | weka.classifiers.evaluation.ThresholdCurve tc = |
---|
1211 | new weka.classifiers.evaluation.ThresholdCurve(); |
---|
1212 | weka.classifiers.evaluation.EvaluationUtils eu = |
---|
1213 | new weka.classifiers.evaluation.EvaluationUtils(); |
---|
1214 | //weka.classifiers.Classifier classifier = new weka.classifiers.functions.Logistic(); |
---|
1215 | weka.classifiers.Classifier classifier = new weka.classifiers.bayes.NaiveBayes(); |
---|
1216 | FastVector predictions = new FastVector(); |
---|
1217 | eu.setSeed(1); |
---|
1218 | predictions.appendElements(eu.getCVPredictions(classifier, train, 10)); |
---|
1219 | Instances result = tc.getCurve(predictions, 0); |
---|
1220 | PlotData2D pd = new PlotData2D(result); |
---|
1221 | pd.m_alwaysDisplayPointsOfThisSize = 10; |
---|
1222 | |
---|
1223 | boolean[] connectPoints = new boolean[result.numInstances()]; |
---|
1224 | for (int i = 1; i < connectPoints.length; i++) { |
---|
1225 | connectPoints[i] = true; |
---|
1226 | } |
---|
1227 | pd.setConnectPoints(connectPoints); |
---|
1228 | final javax.swing.JFrame jf = |
---|
1229 | new javax.swing.JFrame("CostBenefitTest"); |
---|
1230 | jf.setSize(1000,600); |
---|
1231 | //jf.pack(); |
---|
1232 | jf.getContentPane().setLayout(new BorderLayout()); |
---|
1233 | final CostBenefitAnalysis.AnalysisPanel analysisPanel = |
---|
1234 | new CostBenefitAnalysis.AnalysisPanel(); |
---|
1235 | |
---|
1236 | jf.getContentPane().add(analysisPanel, BorderLayout.CENTER); |
---|
1237 | jf.addWindowListener(new java.awt.event.WindowAdapter() { |
---|
1238 | public void windowClosing(java.awt.event.WindowEvent e) { |
---|
1239 | jf.dispose(); |
---|
1240 | System.exit(0); |
---|
1241 | } |
---|
1242 | }); |
---|
1243 | |
---|
1244 | jf.setVisible(true); |
---|
1245 | |
---|
1246 | analysisPanel.setDataSet(pd, train.classAttribute()); |
---|
1247 | |
---|
1248 | } catch (Exception ex) { |
---|
1249 | ex.printStackTrace(); |
---|
1250 | } |
---|
1251 | |
---|
1252 | } |
---|
1253 | } |
---|