source: src/main/java/weka/gui/beans/CostBenefitAnalysis.java @ 11

Last change on this file since 11 was 4, checked in by gnappo, 14 years ago

Import di weka.

File size: 43.0 KB
Line 
1/*
2 *    This program is free software; you can redistribute it and/or modify
3 *    it under the terms of the GNU General Public License as published by
4 *    the Free Software Foundation; either version 2 of the License, or
5 *    (at your option) any later version.
6 *
7 *    This program is distributed in the hope that it will be useful,
8 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
9 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
10 *    GNU General Public License for more details.
11 *
12 *    You should have received a copy of the GNU General Public License
13 *    along with this program; if not, write to the Free Software
14 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
15 */
16
17/*
18 *    CostBenefitAnalysis.java
19 *    Copyright (C) 2009 University of Waikato, Hamilton, New Zealand
20 *
21 */
22
23package weka.gui.beans;
24
25import java.awt.BorderLayout;
26import java.awt.Color;
27import java.awt.Dimension;
28import java.awt.FlowLayout;
29import java.awt.GridLayout;
30import java.awt.Graphics;
31import java.awt.event.ActionEvent;
32import java.awt.event.ActionListener;
33import java.awt.event.FocusEvent;
34import java.awt.event.FocusListener;
35import java.beans.EventSetDescriptor;
36import java.beans.PropertyVetoException;
37import java.beans.VetoableChangeListener;
38import java.beans.beancontext.BeanContext;
39import java.beans.beancontext.BeanContextChild;
40import java.beans.beancontext.BeanContextChildSupport;
41import java.io.Serializable;
42import java.util.Enumeration;
43import java.util.Vector;
44
45import javax.swing.BorderFactory;
46import javax.swing.ButtonGroup;
47import javax.swing.JButton;
48import javax.swing.JFrame;
49import javax.swing.JLabel;
50import javax.swing.JPanel;
51import javax.swing.JRadioButton;
52import javax.swing.JSlider;
53import javax.swing.JTextField;
54import javax.swing.SwingConstants;
55import javax.swing.event.ChangeEvent;
56import javax.swing.event.ChangeListener;
57
58import weka.classifiers.evaluation.ThresholdCurve;
59import weka.core.Attribute;
60import weka.core.FastVector;
61import weka.core.Instance;
62import weka.core.DenseInstance;
63import weka.core.Instances;
64import weka.core.Utils;
65import weka.gui.Logger;
66import weka.gui.visualize.VisualizePanel;
67import weka.gui.visualize.Plot2D;
68import weka.gui.visualize.PlotData2D;
69
70
71/**
72 * Bean that aids in analyzing cost/benefit tradeoffs.
73 *
74 * @author Mark Hall (mhall{[at]}pentaho{[dot]}com)
75 * @version $Revision: 6137 $
76 */
77public class CostBenefitAnalysis extends JPanel
78  implements BeanCommon, ThresholdDataListener, Visible, UserRequestAcceptor,
79  Serializable, BeanContextChild {
80 
81  /** For serialization */
82  private static final long serialVersionUID = 8647471654613320469L;
83
84  protected BeanVisual m_visual;
85 
86  protected transient JFrame m_popupFrame;
87
88  protected boolean m_framePoppedUp = false;
89 
90  private transient AnalysisPanel m_analysisPanel;
91 
92  /**
93   * True if this bean's appearance is the design mode appearance
94   */
95  protected boolean m_design;
96
97  /**
98   * BeanContex that this bean might be contained within
99   */
100  protected transient BeanContext m_beanContext = null;
101 
102  /**
103   * BeanContextChild support
104   */
105  protected BeanContextChildSupport m_bcSupport = 
106    new BeanContextChildSupport(this);
107 
108  /**
109   * The object sending us data (we allow only one connection at any one time)
110   */
111  protected Object m_listenee;
112 
113  /**
114   * Inner class for displaying the plots and all control widgets.
115   *
116   * @author Mark Hall (mhall{[at]}pentaho{[dot]}com)
117   */
118  protected static class AnalysisPanel extends JPanel {
119   
120    /** For serialization */
121    private static final long serialVersionUID = 5364871945448769003L;
122
123    /** Displays the performance graphs(s) */
124    protected VisualizePanel m_performancePanel = new VisualizePanel();
125   
126    /** Displays the cost/benefit (profit/loss) graph */
127    protected VisualizePanel m_costBenefitPanel = new VisualizePanel();
128   
129    /**
130     * The class attribute from the data that was used to generate
131     * the threshold curve
132     */
133    protected Attribute m_classAttribute;
134   
135    /** Data for the threshold curve */
136    protected PlotData2D m_masterPlot;
137   
138    /** Data for the cost/benefit curve */
139    protected PlotData2D m_costBenefit;
140   
141    /** The size of the points being plotted */
142    protected int[] m_shapeSizes;
143   
144    /** The index of the previous plotted point that was highlighted */
145    protected int m_previousShapeIndex = -1;
146       
147    /** The slider for adjusting the threshold */
148    protected JSlider m_thresholdSlider = new JSlider(0,100,0);
149   
150    protected JRadioButton m_percPop = new JRadioButton("% of Population");
151    protected JRadioButton m_percOfTarget = new JRadioButton("% of Target (recall)");
152    protected JRadioButton m_threshold = new JRadioButton("Score Threshold");
153   
154    protected JLabel m_percPopLab = new JLabel();
155    protected JLabel m_percOfTargetLab = new JLabel();
156    protected JLabel m_thresholdLab = new JLabel();
157   
158    // Confusion matrix stuff
159    protected JLabel m_conf_predictedA = new JLabel("Predicted (a)", SwingConstants.RIGHT);
160    protected JLabel m_conf_predictedB = new JLabel("Predicted (b)", SwingConstants.RIGHT);
161    protected JLabel m_conf_actualA = new JLabel(" Actual (a):");
162    protected JLabel m_conf_actualB = new JLabel(" Actual (b):");
163    protected ConfusionCell m_conf_aa = new ConfusionCell();
164    protected ConfusionCell m_conf_ab = new ConfusionCell();
165    protected ConfusionCell m_conf_ba = new ConfusionCell();
166    protected ConfusionCell m_conf_bb = new ConfusionCell();
167   
168    // Cost matrix stuff
169    protected JLabel m_cost_predictedA = new JLabel("Predicted (a)", SwingConstants.RIGHT);
170    protected JLabel m_cost_predictedB = new JLabel("Predicted (b)", SwingConstants.RIGHT);
171    protected JLabel m_cost_actualA = new JLabel(" Actual (a)");
172    protected JLabel m_cost_actualB = new JLabel(" Actual (b)");
173    protected JTextField m_cost_aa = new JTextField("0.0", 5);
174    protected JTextField m_cost_ab = new JTextField("1.0", 5);
175    protected JTextField m_cost_ba = new JTextField("1.0", 5);
176    protected JTextField m_cost_bb = new JTextField("0.0" ,5);
177    protected JButton m_maximizeCB = new JButton("Maximize Cost/Benefit");
178    protected JButton m_minimizeCB = new JButton("Minimize Cost/Benefit");
179    protected JRadioButton m_costR = new JRadioButton("Cost");
180    protected JRadioButton m_benefitR = new JRadioButton("Benefit");
181    protected JLabel m_costBenefitL = new JLabel("Cost: ", SwingConstants.RIGHT);
182    protected JLabel m_costBenefitV = new JLabel("0");
183    protected JLabel m_randomV = new JLabel("0");
184    protected JLabel m_gainV = new JLabel("0");
185   
186    protected int m_originalPopSize;
187   
188    /** Population text field */
189    protected JTextField m_totalPopField = new JTextField(6);
190    protected int m_totalPopPrevious;
191   
192    /** Classification accuracy */
193    protected JLabel m_classificationAccV = new JLabel("-");
194   
195    // Only update curve & stats if values in cost matrix have changed
196    protected double m_tpPrevious;
197    protected double m_fpPrevious;
198    protected double m_tnPrevious;
199    protected double m_fnPrevious;
200   
201    /**
202     * Inner class for handling a single cell in the confusion matrix.
203     * Displays the value, value as a percentage of total population and
204     * graphical depiction of percentage.
205     *
206     * @author Mark Hall (mhall{[at]}pentaho{[dot]}com)
207     */
208    protected static class ConfusionCell extends JPanel {
209
210      /** For serialization */
211      private static final long serialVersionUID = 6148640235434494767L;
212     
213      private JLabel m_conf_cell = new JLabel("-", SwingConstants.RIGHT);
214      JLabel m_conf_perc = new JLabel("-", SwingConstants.RIGHT);
215     
216      private JPanel m_percentageP;
217     
218      protected double m_percentage = 0;
219     
220      public ConfusionCell() {
221        setLayout(new BorderLayout());
222        setBorder(BorderFactory.createEtchedBorder());
223       
224        add(m_conf_cell, BorderLayout.NORTH);
225       
226        m_percentageP = new JPanel() {
227          public void paintComponent(Graphics gx) {
228            super.paintComponent(gx);
229           
230            if (m_percentage > 0) {
231              gx.setColor(Color.BLUE);
232              int height = this.getHeight();
233              double width = this.getWidth();
234              int barWidth = (int)(m_percentage * width); 
235              gx.fillRect(0, 0, barWidth, height);
236            }
237          }
238        };
239       
240        Dimension d = new Dimension(30,5);
241        m_percentageP.setMinimumSize(d);
242        m_percentageP.setPreferredSize(d);
243        JPanel percHolder = new JPanel();
244        percHolder.setLayout(new BorderLayout());
245        percHolder.add(m_percentageP, BorderLayout.CENTER);
246        percHolder.add(m_conf_perc, BorderLayout.EAST);
247       
248        add(percHolder, BorderLayout.SOUTH);
249      }
250     
251      /**
252       * Set the value of a cell.
253       *
254       * @param cellValue the value of the cell
255       * @param max the max (for setting value as a percentage)
256       * @param scaleFactor scale the value by this amount
257       * @param precision precision for the percentage value
258       */
259      public void setCellValue(double cellValue, double max, double scaleFactor, int precision) {
260        if (!Utils.isMissingValue(cellValue)) {
261          m_percentage = cellValue / max;
262        } else {
263          m_percentage = 0;
264        }
265       
266        m_conf_cell.setText(Utils.doubleToString((cellValue * scaleFactor), 0));
267        m_conf_perc.setText(Utils.doubleToString(m_percentage * 100.0, precision) + "%");
268       
269        // refresh the percentage bar
270        m_percentageP.repaint();
271      }
272    }
273   
274    public AnalysisPanel() {
275      setLayout(new BorderLayout());
276      m_performancePanel.setShowAttBars(false);
277      m_performancePanel.setShowClassPanel(false);
278      m_costBenefitPanel.setShowAttBars(false);
279      m_costBenefitPanel.setShowClassPanel(false);
280     
281      Dimension size = new Dimension(500, 400);
282      m_performancePanel.setPreferredSize(size);
283      m_performancePanel.setMinimumSize(size);
284     
285      size = new Dimension(500, 400);
286      m_costBenefitPanel.setMinimumSize(size);
287      m_costBenefitPanel.setPreferredSize(size);
288     
289      m_thresholdSlider.addChangeListener(new ChangeListener() {
290        public void stateChanged(ChangeEvent e) {
291          updateInfoForSliderValue((double)m_thresholdSlider.getValue() / 100.0);
292        }
293      });
294     
295      JPanel plotHolder = new JPanel();
296      plotHolder.setLayout(new GridLayout(1,2));     
297      plotHolder.add(m_performancePanel);
298      plotHolder.add(m_costBenefitPanel);
299      add(plotHolder, BorderLayout.CENTER);
300     
301      JPanel lowerPanel = new JPanel();
302      lowerPanel.setLayout(new BorderLayout());
303     
304      ButtonGroup bGroup = new ButtonGroup();
305      bGroup.add(m_percPop);
306      bGroup.add(m_percOfTarget);
307      bGroup.add(m_threshold);
308     
309      ButtonGroup bGroup2 = new ButtonGroup();
310      bGroup2.add(m_costR);
311      bGroup2.add(m_benefitR);
312      ActionListener rl = new ActionListener() {
313        public void actionPerformed(ActionEvent e) {
314          if (m_costR.isSelected()) {
315            m_costBenefitL.setText("Cost: ");
316          } else {
317            m_costBenefitL.setText("Benefit: ");
318          }
319
320          double gain = Double.parseDouble(m_gainV.getText());
321          gain = -gain;
322          m_gainV.setText(Utils.doubleToString(gain, 2));
323        }
324      };
325      m_costR.addActionListener(rl);
326      m_benefitR.addActionListener(rl);
327      m_costR.setSelected(true);
328     
329      m_percPop.setSelected(true);
330      JPanel threshPanel = new JPanel();
331      threshPanel.setLayout(new BorderLayout());
332      JPanel radioHolder = new JPanel();
333      radioHolder.setLayout(new FlowLayout());
334      radioHolder.add(m_percPop);
335      radioHolder.add(m_percOfTarget);
336      radioHolder.add(m_threshold);
337      threshPanel.add(radioHolder, BorderLayout.NORTH);
338      threshPanel.add(m_thresholdSlider, BorderLayout.SOUTH);
339     
340      JPanel threshInfoPanel = new JPanel();
341      threshInfoPanel.setLayout(new GridLayout(3,2));
342      threshInfoPanel.add(new JLabel("% of Population: ", SwingConstants.RIGHT));
343      threshInfoPanel.add(m_percPopLab);
344      threshInfoPanel.add(new JLabel("% of Target: ", SwingConstants.RIGHT));
345      threshInfoPanel.add(m_percOfTargetLab);
346      threshInfoPanel.add(new JLabel("Score Threshold: ", SwingConstants.RIGHT));
347      threshInfoPanel.add(m_thresholdLab);
348     
349      JPanel threshHolder = new JPanel();
350      threshHolder.setBorder(BorderFactory.createTitledBorder("Threshold"));
351      threshHolder.setLayout(new BorderLayout());
352      threshHolder.add(threshPanel, BorderLayout.CENTER);
353      threshHolder.add(threshInfoPanel, BorderLayout.EAST);
354     
355      lowerPanel.add(threshHolder, BorderLayout.NORTH);
356     
357      // holder for the two matrixes
358      JPanel matrixHolder = new JPanel();
359      matrixHolder.setLayout(new GridLayout(1,2));
360     
361      // confusion matrix
362      JPanel confusionPanel = new JPanel();
363      confusionPanel.setLayout(new GridLayout(3,3));
364      confusionPanel.add(m_conf_predictedA);
365      confusionPanel.add(m_conf_predictedB);
366      confusionPanel.add(new JLabel()); // dummy
367      confusionPanel.add(m_conf_aa);
368      confusionPanel.add(m_conf_ab);
369      confusionPanel.add(m_conf_actualA);
370      confusionPanel.add(m_conf_ba);
371      confusionPanel.add(m_conf_bb);
372      confusionPanel.add(m_conf_actualB);
373      JPanel tempHolderCA = new JPanel();
374      tempHolderCA.setLayout(new BorderLayout());
375      tempHolderCA.setBorder(BorderFactory.createTitledBorder("Confusion Matrix"));
376      tempHolderCA.add(confusionPanel, BorderLayout.CENTER);
377     
378      JPanel accHolder = new JPanel();
379      accHolder.setLayout(new FlowLayout(FlowLayout.LEFT));
380      accHolder.add(new JLabel("Classification Accuracy: "));
381      accHolder.add(m_classificationAccV);
382      tempHolderCA.add(accHolder, BorderLayout.SOUTH);
383     
384      matrixHolder.add(tempHolderCA);
385     
386      // cost matrix
387      JPanel costPanel = new JPanel();
388      costPanel.setBorder(BorderFactory.createTitledBorder("Cost Matrix"));
389      costPanel.setLayout(new BorderLayout());
390     
391      JPanel cmHolder = new JPanel();
392      cmHolder.setLayout(new GridLayout(3, 3));
393      cmHolder.add(m_cost_predictedA);     
394      cmHolder.add(m_cost_predictedB);
395      cmHolder.add(new JLabel()); // dummy
396      cmHolder.add(m_cost_aa);
397      cmHolder.add(m_cost_ab);
398      cmHolder.add(m_cost_actualA);
399      cmHolder.add(m_cost_ba);
400      cmHolder.add(m_cost_bb);
401      cmHolder.add(m_cost_actualB);
402      costPanel.add(cmHolder, BorderLayout.CENTER);
403     
404      FocusListener fl = new FocusListener() {
405        public void focusGained(FocusEvent e) {
406         
407        }
408       
409        public void focusLost(FocusEvent e) {
410          if (constructCostBenefitData()) {
411            try {
412              m_costBenefitPanel.setMasterPlot(m_costBenefit);
413              m_costBenefitPanel.validate(); m_costBenefitPanel.repaint();
414            } catch (Exception ex) {
415              ex.printStackTrace();
416            }
417            updateCostBenefit();
418          }
419        }
420      };
421     
422      ActionListener al = new ActionListener() {
423        public void actionPerformed(ActionEvent e) {
424          if (constructCostBenefitData()) {
425            try {
426              m_costBenefitPanel.setMasterPlot(m_costBenefit);
427              m_costBenefitPanel.validate(); m_costBenefitPanel.repaint();
428            } catch (Exception ex) {
429              ex.printStackTrace();
430            }
431            updateCostBenefit();
432          }
433        }
434      };
435           
436      m_cost_aa.addFocusListener(fl);
437      m_cost_aa.addActionListener(al);
438      m_cost_ab.addFocusListener(fl);
439      m_cost_ab.addActionListener(al);
440      m_cost_ba.addFocusListener(fl);
441      m_cost_ba.addActionListener(al);
442      m_cost_bb.addFocusListener(fl);
443      m_cost_bb.addActionListener(al);
444     
445      m_totalPopField.addFocusListener(fl);
446      m_totalPopField.addActionListener(al);
447     
448      JPanel cbHolder = new JPanel();
449      cbHolder.setLayout(new BorderLayout());
450      JPanel tempP = new JPanel();
451      tempP.setLayout(new GridLayout(3, 2));
452      tempP.add(m_costBenefitL);
453      tempP.add(m_costBenefitV);
454      tempP.add(new JLabel("Random: ", SwingConstants.RIGHT));
455      tempP.add(m_randomV);
456      tempP.add(new JLabel("Gain: ", SwingConstants.RIGHT));
457      tempP.add(m_gainV);
458      cbHolder.add(tempP, BorderLayout.NORTH);
459      JPanel butHolder = new JPanel();
460      butHolder.setLayout(new GridLayout(2, 1));
461      butHolder.add(m_maximizeCB);
462      butHolder.add(m_minimizeCB);
463      m_maximizeCB.addActionListener(new ActionListener() {
464        public void actionPerformed(ActionEvent e) {
465          findMaxMinCB(true);
466        }
467      });
468     
469      m_minimizeCB.addActionListener(new ActionListener() {
470        public void actionPerformed(ActionEvent e) {
471          findMaxMinCB(false);
472        }
473      });
474     
475      cbHolder.add(butHolder, BorderLayout.SOUTH);
476      costPanel.add(cbHolder, BorderLayout.EAST);
477     
478      JPanel popCBR = new JPanel();
479      popCBR.setLayout(new GridLayout(1, 2));
480      JPanel popHolder = new JPanel();
481      popHolder.setLayout(new FlowLayout(FlowLayout.LEFT));
482      popHolder.add(new JLabel("Total Population: "));
483      popHolder.add(m_totalPopField);
484     
485      JPanel radioHolder2 = new JPanel();
486      radioHolder2.setLayout(new FlowLayout(FlowLayout.RIGHT));
487      radioHolder2.add(m_costR);
488      radioHolder2.add(m_benefitR);
489      popCBR.add(popHolder);
490      popCBR.add(radioHolder2);
491     
492      costPanel.add(popCBR, BorderLayout.SOUTH);
493     
494      matrixHolder.add(costPanel);
495     
496     
497      lowerPanel.add(matrixHolder, BorderLayout.SOUTH);
498     
499
500
501//      popAccHolder.add(popHolder);
502     
503      //popAccHolder.add(accHolder);
504     
505      /*JPanel lowerPanel2 = new JPanel();
506      lowerPanel2.setLayout(new BorderLayout());
507      lowerPanel2.add(lowerPanel, BorderLayout.NORTH);
508      lowerPanel2.add(popAccHolder, BorderLayout.SOUTH); */
509     
510      add(lowerPanel, BorderLayout.SOUTH);
511     
512    }
513   
514    private void findMaxMinCB(boolean max) {
515      double maxMin = (max) 
516      ? Double.NEGATIVE_INFINITY 
517          : Double.POSITIVE_INFINITY;
518     
519      Instances cBCurve = m_costBenefit.getPlotInstances();
520      int maxMinIndex = 0;
521     
522      for (int i = 0; i < cBCurve.numInstances(); i++) {
523        Instance current = cBCurve.instance(i);
524        if (max) {
525          if (current.value(1) > maxMin) {
526            maxMin = current.value(1);
527            maxMinIndex = i;
528          }
529        } else {
530          if (current.value(1) < maxMin) {
531            maxMin = current.value(1);
532            maxMinIndex = i;
533          }
534        }
535      }
536     
537     
538      // set the slider to the correct position
539      int indexOfSampleSize = 
540        m_masterPlot.getPlotInstances().attribute(ThresholdCurve.SAMPLE_SIZE_NAME).index();
541      int indexOfPercOfTarget = 
542        m_masterPlot.getPlotInstances().attribute(ThresholdCurve.RECALL_NAME).index();
543      int indexOfThreshold =
544        m_masterPlot.getPlotInstances().attribute(ThresholdCurve.THRESHOLD_NAME).index();
545      int indexOfMetric;
546     
547      if (m_percPop.isSelected()) {
548        indexOfMetric = indexOfSampleSize;           
549      } else if (m_percOfTarget.isSelected()) {
550        indexOfMetric = indexOfPercOfTarget;
551      } else {
552        indexOfMetric = indexOfThreshold;
553      }
554     
555      double valueOfMetric = m_masterPlot.getPlotInstances().instance(maxMinIndex).value(indexOfMetric);
556      valueOfMetric *= 100.0;
557     
558      // set the approximate location of the slider
559      m_thresholdSlider.setValue((int)valueOfMetric);
560     
561      // make sure the actual values relate to the true min/max rather
562      // than being off due to slider location error.
563      updateInfoGivenIndex(maxMinIndex);
564    }
565   
566    private void updateCostBenefit() {
567      double value = (double)m_thresholdSlider.getValue() / 100.0;
568      Instances plotInstances = m_masterPlot.getPlotInstances();
569      int indexOfSampleSize = 
570        m_masterPlot.getPlotInstances().attribute(ThresholdCurve.SAMPLE_SIZE_NAME).index();
571      int indexOfPercOfTarget = 
572        m_masterPlot.getPlotInstances().attribute(ThresholdCurve.RECALL_NAME).index();
573      int indexOfThreshold =
574        m_masterPlot.getPlotInstances().attribute(ThresholdCurve.THRESHOLD_NAME).index();
575      int indexOfMetric;
576     
577      if (m_percPop.isSelected()) {
578        indexOfMetric = indexOfSampleSize;           
579      } else if (m_percOfTarget.isSelected()) {
580        indexOfMetric = indexOfPercOfTarget;
581      } else {
582        indexOfMetric = indexOfThreshold;
583      }
584     
585      int index = findIndexForValue(value, plotInstances, indexOfMetric);
586      updateCBRandomGainInfo(index);
587    }
588   
589    private void updateCBRandomGainInfo(int index) {
590      double requestedPopSize = m_originalPopSize;
591      try {
592        requestedPopSize = Double.parseDouble(m_totalPopField.getText());
593      } catch (NumberFormatException e) {}
594      double scaleFactor = requestedPopSize / m_originalPopSize;
595     
596      double CB = m_costBenefit.
597        getPlotInstances().instance(index).value(1);
598      m_costBenefitV.setText(Utils.doubleToString(CB,2));
599     
600      double totalRandomCB = 0.0;
601      Instance first = m_masterPlot.getPlotInstances().instance(0);
602      double totalPos = first.value(m_masterPlot.getPlotInstances().
603          attribute(ThresholdCurve.TRUE_POS_NAME).index()) * scaleFactor;
604      double totalNeg = first.value(m_masterPlot.getPlotInstances().
605          attribute(ThresholdCurve.FALSE_POS_NAME)) * scaleFactor;
606
607      double posInSample = (totalPos * (Double.parseDouble(m_percPopLab.getText()) / 100.0));
608      double negInSample = (totalNeg * (Double.parseDouble(m_percPopLab.getText()) / 100.0));
609      double posOutSample = totalPos - posInSample;
610      double negOutSample = totalNeg - negInSample;
611     
612      double tpCost = 0.0;
613      try {
614        tpCost = Double.parseDouble(m_cost_aa.getText());
615      } catch (NumberFormatException n) {}
616      double fpCost = 0.0;
617      try {
618        fpCost = Double.parseDouble(m_cost_ba.getText());
619      } catch (NumberFormatException n) {}
620      double tnCost = 0.0;
621      try {
622        tnCost = Double.parseDouble(m_cost_bb.getText());
623      } catch (NumberFormatException n) {}
624      double fnCost = 0.0;
625      try {
626        fnCost = Double.parseDouble(m_cost_ab.getText());
627      } catch (NumberFormatException n) {}
628           
629      totalRandomCB += posInSample * tpCost;
630      totalRandomCB += negInSample * fpCost;
631      totalRandomCB += posOutSample * fnCost;
632      totalRandomCB += negOutSample * tnCost;
633     
634      m_randomV.setText(Utils.doubleToString(totalRandomCB, 2));
635      double gain = (m_costR.isSelected()) 
636      ? totalRandomCB - CB
637          : CB - totalRandomCB;
638      m_gainV.setText(Utils.doubleToString(gain, 2));
639     
640      // update classification rate
641      Instance currentInst = m_masterPlot.getPlotInstances().instance(index);
642      double tp = currentInst.value(m_masterPlot.getPlotInstances().
643          attribute(ThresholdCurve.TRUE_POS_NAME).index());
644      double tn = currentInst.value(m_masterPlot.getPlotInstances().
645          attribute(ThresholdCurve.TRUE_NEG_NAME).index());
646      m_classificationAccV.
647        setText(Utils.doubleToString((tp + tn) / (totalPos + totalNeg) * 100.0, 4) + "%");     
648    }
649   
650    private void updateInfoGivenIndex(int index) {
651      Instances plotInstances = m_masterPlot.getPlotInstances();
652      int indexOfSampleSize = 
653        m_masterPlot.getPlotInstances().attribute(ThresholdCurve.SAMPLE_SIZE_NAME).index();
654      int indexOfPercOfTarget = 
655        m_masterPlot.getPlotInstances().attribute(ThresholdCurve.RECALL_NAME).index();
656      int indexOfThreshold =
657        m_masterPlot.getPlotInstances().attribute(ThresholdCurve.THRESHOLD_NAME).index();
658     
659      // update labels
660      m_percPopLab.setText(Utils.
661          doubleToString(100.0 * plotInstances.instance(index).value(indexOfSampleSize), 4));
662      m_percOfTargetLab.setText(Utils.doubleToString(
663          100.0 * plotInstances.instance(index).value(indexOfPercOfTarget), 4));
664      m_thresholdLab.setText(Utils.doubleToString(plotInstances.instance(index).value(indexOfThreshold), 4));
665      /*if (m_percPop.isSelected()) {
666        m_percPopLab.setText(Utils.doubleToString(100.0 * value, 4));
667      } else if (m_percOfTarget.isSelected()) {
668        m_percOfTargetLab.setText(Utils.doubleToString(100.0 * value, 4));
669      } else {
670        m_thresholdLab.setText(Utils.doubleToString(value, 4));
671      }*/
672     
673      // Update the highlighted point on the graphs */
674      if (m_previousShapeIndex >= 0) {
675        m_shapeSizes[m_previousShapeIndex] = 1;
676      }
677     
678      m_shapeSizes[index] = 10;
679      m_previousShapeIndex = index;
680     
681      // Update the confusion matrix
682//      double totalInstances =
683      int tp = plotInstances.attribute(ThresholdCurve.TRUE_POS_NAME).index();
684      int fp = plotInstances.attribute(ThresholdCurve.FALSE_POS_NAME).index();
685      int tn = plotInstances.attribute(ThresholdCurve.TRUE_NEG_NAME).index();
686      int fn = plotInstances.attribute(ThresholdCurve.FALSE_NEG_NAME).index();
687      Instance temp = plotInstances.instance(index);
688      double totalInstances = temp.value(tp) + temp.value(fp) + temp.value(tn) + temp.value(fn);
689      // get the value out of the total pop field (if possible)
690      double requestedPopSize = totalInstances;
691      try {
692        requestedPopSize = Double.parseDouble(m_totalPopField.getText());
693      } catch (NumberFormatException e) {}
694     
695      m_conf_aa.setCellValue(temp.value(tp), totalInstances, 
696          requestedPopSize / totalInstances, 2);
697      m_conf_ab.setCellValue(temp.value(fn), totalInstances, 
698          requestedPopSize / totalInstances, 2);
699      m_conf_ba.setCellValue(temp.value(fp), totalInstances, 
700          requestedPopSize / totalInstances, 2);
701      m_conf_bb.setCellValue(temp.value(tn), totalInstances, 
702            requestedPopSize / totalInstances, 2);
703     
704      updateCBRandomGainInfo(index);
705     
706      repaint();
707    }
708   
709    private void updateInfoForSliderValue(double value) {
710      int indexOfSampleSize = 
711        m_masterPlot.getPlotInstances().attribute(ThresholdCurve.SAMPLE_SIZE_NAME).index();
712      int indexOfPercOfTarget = 
713        m_masterPlot.getPlotInstances().attribute(ThresholdCurve.RECALL_NAME).index();
714      int indexOfThreshold =
715        m_masterPlot.getPlotInstances().attribute(ThresholdCurve.THRESHOLD_NAME).index();
716      int indexOfMetric;
717     
718      if (m_percPop.isSelected()) {
719        indexOfMetric = indexOfSampleSize;           
720      } else if (m_percOfTarget.isSelected()) {
721        indexOfMetric = indexOfPercOfTarget;
722      } else {
723        indexOfMetric = indexOfThreshold;
724      }
725     
726      Instances plotInstances = m_masterPlot.getPlotInstances();
727      int index = findIndexForValue(value, plotInstances, indexOfMetric);
728      updateInfoGivenIndex(index);
729    }
730   
731    private int findIndexForValue(double value, Instances plotInstances, int indexOfMetric) {
732      // binary search
733      // threshold curve is sorted ascending in the threshold (thus
734      // descending for recall and pop size)
735      int index = -1;
736      int lower = 0;
737      int upper = plotInstances.numInstances() - 1;
738      int mid = (upper - lower) / 2;
739      boolean done = false;
740      while (!done) {
741        if (upper - lower <= 1) {
742         
743          // choose the one closest to the value
744          double comp1 = plotInstances.instance(upper).value(indexOfMetric);
745          double comp2 = plotInstances.instance(lower).value(indexOfMetric);
746          if (Math.abs(comp1 - value) < Math.abs(comp2 - value)) {
747            index = upper;
748          } else {
749            index = lower;
750          }
751         
752          break;
753        }
754        double comparisonVal = plotInstances.instance(mid).value(indexOfMetric);
755        if (value > comparisonVal) {
756          if (m_threshold.isSelected()) {
757            lower = mid;
758            mid += (upper - lower) / 2;
759          } else {
760            upper = mid;
761            mid -= (upper - lower) / 2;
762          }
763        } else if (value < comparisonVal) {
764          if (m_threshold.isSelected()) {
765            upper = mid;
766            mid -= (upper - lower) / 2;
767          } else {
768            lower = mid;
769            mid += (upper - lower) / 2;
770          }
771        } else {
772          index = mid;
773          done = true;
774        }
775      }
776     
777      // now check for ties in the appropriate direction
778      if (!m_threshold.isSelected()) {
779        while (index + 1 < plotInstances.numInstances()) {
780          if (plotInstances.instance(index + 1).value(indexOfMetric) == 
781            plotInstances.instance(index).value(indexOfMetric)) {
782            index++;
783          } else {
784            break;
785          }
786        }
787      } else {
788        while (index - 1 >= 0) {
789          if (plotInstances.instance(index - 1).value(indexOfMetric) == 
790            plotInstances.instance(index).value(indexOfMetric)) {
791            index--;
792          } else {
793            break;
794          } 
795        }
796      }
797      return index;
798    }
799   
800    /**
801     * Set the threshold data for the panel to use.
802     *
803     * @param data PlotData2D object encapsulating the threshold data.
804     * @param classAtt the class attribute from the original data used to generate
805     * the threshold data.
806     * @throws Exception if something goes wrong.
807     */
808    public synchronized void setDataSet(PlotData2D data, Attribute classAtt) throws Exception {     
809      // make a copy of the PlotData2D object
810      m_masterPlot = new PlotData2D(data.getPlotInstances());
811      boolean[] connectPoints = new boolean[m_masterPlot.getPlotInstances().numInstances()];
812      for (int i = 1; i < connectPoints.length; i++) {
813        connectPoints[i] = true;
814      }
815      m_masterPlot.setConnectPoints(connectPoints);
816
817      m_masterPlot.m_alwaysDisplayPointsOfThisSize = 10;
818      setClassForConfusionMatrix(classAtt);
819      m_performancePanel.setMasterPlot(m_masterPlot);
820      m_performancePanel.validate(); m_performancePanel.repaint();
821
822      m_shapeSizes = new int[m_masterPlot.getPlotInstances().numInstances()];
823      for (int i = 0; i < m_shapeSizes.length; i++) {
824        m_shapeSizes[i] = 1;
825      }
826      m_masterPlot.setShapeSize(m_shapeSizes);
827      constructCostBenefitData();
828      m_costBenefitPanel.setMasterPlot(m_costBenefit);
829      m_costBenefitPanel.validate(); m_costBenefitPanel.repaint();
830
831      m_totalPopPrevious = 0;
832      m_fpPrevious = 0;
833      m_tpPrevious = 0;
834      m_tnPrevious = 0;
835      m_fnPrevious = 0;
836      m_previousShapeIndex = -1;
837
838      // set the total population size
839      Instance first = m_masterPlot.getPlotInstances().instance(0);
840      double totalPos = first.value(m_masterPlot.getPlotInstances().
841          attribute(ThresholdCurve.TRUE_POS_NAME).index());
842      double totalNeg = first.value(m_masterPlot.getPlotInstances().
843          attribute(ThresholdCurve.FALSE_POS_NAME));
844      m_originalPopSize = (int)(totalPos + totalNeg);
845      m_totalPopField.setText("" + m_originalPopSize);
846
847      m_performancePanel.setYIndex(5);
848      m_performancePanel.setXIndex(10);
849      m_costBenefitPanel.setXIndex(0);
850      m_costBenefitPanel.setYIndex(1);
851      //      System.err.println(m_masterPlot.getPlotInstances());
852      updateInfoForSliderValue((double)m_thresholdSlider.getValue() / 100.0);
853    }
854   
855    private void setClassForConfusionMatrix(Attribute classAtt) {
856      m_classAttribute = classAtt;
857      m_conf_actualA.setText(" Actual (a): " + classAtt.value(0));
858      m_conf_actualA.setToolTipText(classAtt.value(0));
859      String negClasses = "";
860      for (int i = 1; i < classAtt.numValues(); i++) {
861        negClasses += classAtt.value(i);
862        if (i < classAtt.numValues() - 1) {
863          negClasses += ",";
864        }
865      }
866      m_conf_actualB.setText(" Actual (b): " + negClasses);
867      m_conf_actualB.setToolTipText(negClasses);
868    }
869   
870    private boolean constructCostBenefitData() {
871      double tpCost = 0.0;
872      try {
873        tpCost = Double.parseDouble(m_cost_aa.getText());
874      } catch (NumberFormatException n) {}
875      double fpCost = 0.0;
876      try {
877        fpCost = Double.parseDouble(m_cost_ba.getText());
878      } catch (NumberFormatException n) {}
879      double tnCost = 0.0;
880      try {
881        tnCost = Double.parseDouble(m_cost_bb.getText());
882      } catch (NumberFormatException n) {}
883      double fnCost = 0.0;
884      try {
885        fnCost = Double.parseDouble(m_cost_ab.getText());
886      } catch (NumberFormatException n) {}
887     
888      double requestedPopSize = m_originalPopSize;
889      try {
890        requestedPopSize = Double.parseDouble(m_totalPopField.getText());
891      } catch (NumberFormatException e) {}
892     
893      double scaleFactor = 1.0;
894      if (m_originalPopSize != 0) {
895        scaleFactor = requestedPopSize / m_originalPopSize;
896      }
897     
898      if (tpCost == m_tpPrevious && fpCost == m_fpPrevious &&
899          tnCost == m_tnPrevious && fnCost == m_fnPrevious &&
900          requestedPopSize == m_totalPopPrevious) {
901        return false;
902      }
903     
904      // First construct some Instances for the curve
905      FastVector fv = new FastVector();
906      fv.addElement(new Attribute("Sample Size"));
907      fv.addElement(new Attribute("Cost/Benefit"));
908      Instances costBenefitI = new Instances("Cost/Benefit Curve", fv, 100);
909     
910      // process the performance data to make this curve
911      Instances performanceI = m_masterPlot.getPlotInstances();
912     
913      for (int i = 0; i < performanceI.numInstances(); i++) {
914        Instance current = performanceI.instance(i);
915       
916        double[] vals = new double[2];
917        vals[0] = current.value(10); // sample size
918        vals[1] = (current.value(0) * tpCost
919            + current.value(1) * fnCost
920            + current.value(2) * fpCost
921            + current.value(3) * tnCost) * scaleFactor;
922        Instance newInst = new DenseInstance(1.0, vals);
923        costBenefitI.add(newInst);
924      }
925     
926      costBenefitI.compactify();
927     
928      // now set up the plot data
929      m_costBenefit = new PlotData2D(costBenefitI);
930      m_costBenefit.m_alwaysDisplayPointsOfThisSize = 10;
931      m_costBenefit.setPlotName("Cost/benefit curve");
932      boolean[] connectPoints = new boolean[costBenefitI.numInstances()];
933     
934      for (int i = 0; i < connectPoints.length; i++) {
935        connectPoints[i] = true;
936      }
937      try {
938        m_costBenefit.setConnectPoints(connectPoints);
939        m_costBenefit.setShapeSize(m_shapeSizes);
940      } catch (Exception ex) {
941        // ignore
942      }
943     
944      m_tpPrevious = tpCost;
945      m_fpPrevious = fpCost;
946      m_tnPrevious = tnCost;
947      m_fnPrevious = fnCost;
948     
949      return true;
950    }
951  }
952 
953  /**
954   * Constructor.
955   */
956  public CostBenefitAnalysis() {
957    java.awt.GraphicsEnvironment ge = 
958      java.awt.GraphicsEnvironment.getLocalGraphicsEnvironment();
959    if (!ge.isHeadless()) {
960      appearanceFinal();
961    }
962  }
963 
964  /**
965   * Global info for this bean
966   *
967   * @return a <code>String</code> value
968   */
969  public String globalInfo() {
970    return "Visualize performance charts (such as ROC).";
971  }
972
973  /**
974   * Accept a threshold data event and set up the visualization.
975   * @param e a threshold data event
976   */
977  public void acceptDataSet(ThresholdDataEvent e) {
978    try {
979      setCurveData(e.getDataSet(), e.getClassAttribute());
980    } catch (Exception ex) {
981      System.err.println("[CostBenefitAnalysis] Problem setting up visualization.");
982      ex.printStackTrace();
983    }
984  }
985 
986  /**
987   * Set the threshold curve data to use.
988   *
989   * @param curveData a PlotData2D object set up with the curve data.
990   * @param origClassAtt the class attribute from the original data used to
991   * generate the curve.
992   * @throws Exception if somthing goes wrong during the setup process.
993   */
994  public void setCurveData(PlotData2D curveData, Attribute origClassAtt) 
995    throws Exception {
996    if (m_analysisPanel == null) {
997      m_analysisPanel = new AnalysisPanel();
998    }
999    m_analysisPanel.setDataSet(curveData, origClassAtt);
1000  }
1001
1002  public BeanVisual getVisual() {
1003    return m_visual;
1004  }
1005
1006  public void setVisual(BeanVisual newVisual) {
1007    m_visual = newVisual;
1008  }
1009
1010  public void useDefaultVisual() {
1011    m_visual.loadIcons(BeanVisual.ICON_PATH+"DefaultDataVisualizer.gif",
1012        BeanVisual.ICON_PATH+"DefaultDataVisualizer_animated.gif");
1013  }
1014
1015  public Enumeration enumerateRequests() {
1016    Vector newVector = new Vector(0);
1017    if (m_analysisPanel != null) {
1018      if (m_analysisPanel.m_masterPlot != null) {
1019        newVector.addElement("Show analysis");
1020      }
1021    }
1022    return newVector.elements();
1023  }
1024
1025  public void performRequest(String request) {
1026    if (request.compareTo("Show analysis") == 0) {
1027      try {
1028        // popup visualize panel
1029        if (!m_framePoppedUp) {
1030          m_framePoppedUp = true;
1031
1032          final javax.swing.JFrame jf = 
1033            new javax.swing.JFrame("Cost/Benefit Analysis");
1034          jf.setSize(1000,600);
1035          jf.getContentPane().setLayout(new BorderLayout());
1036          jf.getContentPane().add(m_analysisPanel, BorderLayout.CENTER);
1037          jf.addWindowListener(new java.awt.event.WindowAdapter() {
1038              public void windowClosing(java.awt.event.WindowEvent e) {
1039                jf.dispose();
1040                m_framePoppedUp = false;
1041              }
1042            });
1043          jf.setVisible(true);
1044          m_popupFrame = jf;
1045        } else {
1046          m_popupFrame.toFront();
1047        }
1048      } catch (Exception ex) {
1049        ex.printStackTrace();
1050        m_framePoppedUp = false;
1051      }
1052    } else {
1053      throw new IllegalArgumentException(request
1054          + " not supported (Cost/Benefit Analysis");
1055    }
1056  }
1057
1058  public void addVetoableChangeListener(String name, VetoableChangeListener vcl) {
1059    m_bcSupport.addVetoableChangeListener(name, vcl);
1060  }
1061
1062  public BeanContext getBeanContext() {
1063    return m_beanContext;
1064  }
1065
1066  public void removeVetoableChangeListener(String name,
1067      VetoableChangeListener vcl) {
1068    m_bcSupport.removeVetoableChangeListener(name, vcl);
1069  }
1070 
1071  protected void appearanceFinal() {
1072    removeAll();
1073    setLayout(new BorderLayout());
1074    setUpFinal();
1075  }
1076 
1077  protected void setUpFinal() {
1078    if (m_analysisPanel == null) {
1079      m_analysisPanel = new AnalysisPanel();
1080    }
1081    add(m_analysisPanel, BorderLayout.CENTER);
1082  }
1083 
1084  protected void appearanceDesign() {
1085    removeAll();
1086    m_visual = new BeanVisual("CostBenefitAnalysis", 
1087                              BeanVisual.ICON_PATH+"ModelPerformanceChart.gif",
1088                              BeanVisual.ICON_PATH
1089                              +"ModelPerformanceChart_animated.gif");
1090    setLayout(new BorderLayout());
1091    add(m_visual, BorderLayout.CENTER);
1092  }
1093
1094  public void setBeanContext(BeanContext bc) throws PropertyVetoException {
1095    m_beanContext = bc;
1096    m_design = m_beanContext.isDesignTime();
1097    if (m_design) {
1098      appearanceDesign();
1099    } else {
1100      java.awt.GraphicsEnvironment ge = 
1101        java.awt.GraphicsEnvironment.getLocalGraphicsEnvironment(); 
1102      if (!ge.isHeadless()) {
1103        appearanceFinal();
1104      }
1105    }
1106  }
1107 
1108  /**
1109   * Returns true if, at this time,
1110   * the object will accept a connection via the named event
1111   *
1112   * @param eventName the name of the event in question
1113   * @return true if the object will accept a connection
1114   */
1115  public boolean connectionAllowed(String eventName) {
1116    return (m_listenee == null);
1117  }
1118
1119  /**
1120   * Notify this object that it has been registered as a listener with
1121   * a source for recieving events described by the named event
1122   * This object is responsible for recording this fact.
1123   *
1124   * @param eventName the event
1125   * @param source the source with which this object has been registered as
1126   * a listener
1127   */
1128  public void connectionNotification(String eventName, Object source) {
1129    if (connectionAllowed(eventName)) {
1130      m_listenee = source;
1131    }
1132  }
1133 
1134  /**
1135   * Returns true if, at this time,
1136   * the object will accept a connection according to the supplied
1137   * EventSetDescriptor
1138   *
1139   * @param esd the EventSetDescriptor
1140   * @return true if the object will accept a connection
1141   */
1142  public boolean connectionAllowed(EventSetDescriptor esd) {
1143    return connectionAllowed(esd.getName());
1144  }
1145
1146  /**
1147   * Notify this object that it has been deregistered as a listener with
1148   * a source for named event. This object is responsible
1149   * for recording this fact.
1150   *
1151   * @param eventName the event
1152   * @param source the source with which this object has been registered as
1153   * a listener
1154   */
1155  public void disconnectionNotification(String eventName, Object source) {
1156    if (m_listenee == source) {
1157      m_listenee = null;
1158    }
1159   
1160  }
1161
1162  /**
1163   * Get the custom (descriptive) name for this bean (if one has been set)
1164   *
1165   * @return the custom name (or the default name)
1166   */
1167  public String getCustomName() {
1168    return m_visual.getText();
1169  }
1170
1171  /**
1172   * Returns true if. at this time, the bean is busy with some
1173   * (i.e. perhaps a worker thread is performing some calculation).
1174   *
1175   * @return true if the bean is busy.
1176   */
1177  public boolean isBusy() {
1178    return false;
1179  }
1180
1181  /**
1182   * Set a custom (descriptive) name for this bean
1183   *
1184   * @param name the name to use
1185   */
1186  public void setCustomName(String name) {
1187    m_visual.setText(name);
1188  }
1189
1190  /**
1191   * Set a logger
1192   *
1193   * @param logger a <code>weka.gui.Logger</code> value
1194   */
1195  public void setLog(Logger logger) {
1196    // we don't need to do any logging   
1197  }
1198
1199  /**
1200   * Stop any processing that the bean might be doing.
1201   */
1202  public void stop() {
1203    // nothing to do here
1204  }
1205   
1206  public static void main(String[] args) {
1207    try {
1208      Instances train = new Instances(new java.io.BufferedReader(new java.io.FileReader(args[0])));
1209      train.setClassIndex(train.numAttributes() - 1);
1210      weka.classifiers.evaluation.ThresholdCurve tc = 
1211        new weka.classifiers.evaluation.ThresholdCurve();
1212      weka.classifiers.evaluation.EvaluationUtils eu = 
1213        new weka.classifiers.evaluation.EvaluationUtils();
1214      //weka.classifiers.Classifier classifier = new weka.classifiers.functions.Logistic();
1215      weka.classifiers.Classifier classifier = new weka.classifiers.bayes.NaiveBayes();
1216      FastVector predictions = new FastVector();
1217      eu.setSeed(1);
1218      predictions.appendElements(eu.getCVPredictions(classifier, train, 10));
1219      Instances result = tc.getCurve(predictions, 0);
1220      PlotData2D pd = new PlotData2D(result);
1221      pd.m_alwaysDisplayPointsOfThisSize = 10;
1222
1223      boolean[] connectPoints = new boolean[result.numInstances()];
1224      for (int i = 1; i < connectPoints.length; i++) {
1225        connectPoints[i] = true;
1226      }
1227      pd.setConnectPoints(connectPoints);
1228      final javax.swing.JFrame jf = 
1229        new javax.swing.JFrame("CostBenefitTest");
1230      jf.setSize(1000,600);
1231      //jf.pack();
1232      jf.getContentPane().setLayout(new BorderLayout());
1233      final CostBenefitAnalysis.AnalysisPanel analysisPanel = 
1234        new CostBenefitAnalysis.AnalysisPanel();
1235     
1236      jf.getContentPane().add(analysisPanel, BorderLayout.CENTER);
1237      jf.addWindowListener(new java.awt.event.WindowAdapter() {
1238        public void windowClosing(java.awt.event.WindowEvent e) {
1239          jf.dispose();
1240          System.exit(0);
1241        }
1242      });
1243     
1244      jf.setVisible(true);
1245     
1246      analysisPanel.setDataSet(pd, train.classAttribute());
1247     
1248    } catch (Exception ex) {
1249      ex.printStackTrace();
1250    }
1251 
1252  }
1253}
Note: See TracBrowser for help on using the repository browser.