/* * This program is free software; you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by * the Free Software Foundation; either version 2 of the License, or * (at your option) any later version. * * This program is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU General Public License for more details. * * You should have received a copy of the GNU General Public License * along with this program; if not, write to the Free Software * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA. */ /* * CostBenefitAnalysis.java * Copyright (C) 2009 University of Waikato, Hamilton, New Zealand * */ package weka.gui.beans; import java.awt.BorderLayout; import java.awt.Color; import java.awt.Dimension; import java.awt.FlowLayout; import java.awt.GridLayout; import java.awt.Graphics; import java.awt.event.ActionEvent; import java.awt.event.ActionListener; import java.awt.event.FocusEvent; import java.awt.event.FocusListener; import java.beans.EventSetDescriptor; import java.beans.PropertyVetoException; import java.beans.VetoableChangeListener; import java.beans.beancontext.BeanContext; import java.beans.beancontext.BeanContextChild; import java.beans.beancontext.BeanContextChildSupport; import java.io.Serializable; import java.util.Enumeration; import java.util.Vector; import javax.swing.BorderFactory; import javax.swing.ButtonGroup; import javax.swing.JButton; import javax.swing.JFrame; import javax.swing.JLabel; import javax.swing.JPanel; import javax.swing.JRadioButton; import javax.swing.JSlider; import javax.swing.JTextField; import javax.swing.SwingConstants; import javax.swing.event.ChangeEvent; import javax.swing.event.ChangeListener; import weka.classifiers.evaluation.ThresholdCurve; import weka.core.Attribute; import weka.core.FastVector; import weka.core.Instance; import 
weka.core.DenseInstance;
import weka.core.Instances;
import weka.core.Utils;
import weka.gui.Logger;
import weka.gui.visualize.VisualizePanel;
import weka.gui.visualize.Plot2D;
import weka.gui.visualize.PlotData2D;

/**
 * Bean that aids in analyzing cost/benefit tradeoffs.
 * <p>
 * Receives threshold curve data (as a {@link ThresholdDataListener}) and
 * displays a performance plot, a cost/benefit plot, a confusion matrix and an
 * editable cost matrix for an interactively chosen threshold.
 *
 * @author Mark Hall (mhall{[at]}pentaho{[dot]}com)
 * @version $Revision: 6137 $
 */
public class CostBenefitAnalysis extends JPanel 
  implements BeanCommon, ThresholdDataListener, Visible, UserRequestAcceptor,
  Serializable, BeanContextChild {
  
  /** For serialization */
  private static final long serialVersionUID = 8647471654613320469L;

  /**
   * Icon/label representation of this bean (created when the bean is shown
   * in design mode, or supplied via setVisual()).
   */
  protected BeanVisual m_visual;
  
  /** Pop-up window showing the analysis panel (not serialized). */
  protected transient JFrame m_popupFrame;

  /** True while the pop-up analysis window is open. */
  protected boolean m_framePoppedUp = false;
  
  /** Panel holding the plots and all control widgets (not serialized). */
  private transient AnalysisPanel m_analysisPanel;
  
  /**
   * True if this bean's appearance is the design mode appearance
   */
  protected boolean m_design;
  
  /**
   * BeanContext that this bean might be contained within
   */
  protected transient BeanContext m_beanContext = null;
  
  /**
   * BeanContextChild support
   */
  protected BeanContextChildSupport m_bcSupport = 
    new BeanContextChildSupport(this);
  
  /**
   * The object sending us data (we allow only one connection at any one time)
   */
  protected Object m_listenee;
  
  /**
   * Inner class for displaying the plots and all control widgets.
* * @author Mark Hall (mhall{[at]}pentaho{[dot]}com) */ protected static class AnalysisPanel extends JPanel { /** For serialization */ private static final long serialVersionUID = 5364871945448769003L; /** Displays the performance graphs(s) */ protected VisualizePanel m_performancePanel = new VisualizePanel(); /** Displays the cost/benefit (profit/loss) graph */ protected VisualizePanel m_costBenefitPanel = new VisualizePanel(); /** * The class attribute from the data that was used to generate * the threshold curve */ protected Attribute m_classAttribute; /** Data for the threshold curve */ protected PlotData2D m_masterPlot; /** Data for the cost/benefit curve */ protected PlotData2D m_costBenefit; /** The size of the points being plotted */ protected int[] m_shapeSizes; /** The index of the previous plotted point that was highlighted */ protected int m_previousShapeIndex = -1; /** The slider for adjusting the threshold */ protected JSlider m_thresholdSlider = new JSlider(0,100,0); protected JRadioButton m_percPop = new JRadioButton("% of Population"); protected JRadioButton m_percOfTarget = new JRadioButton("% of Target (recall)"); protected JRadioButton m_threshold = new JRadioButton("Score Threshold"); protected JLabel m_percPopLab = new JLabel(); protected JLabel m_percOfTargetLab = new JLabel(); protected JLabel m_thresholdLab = new JLabel(); // Confusion matrix stuff protected JLabel m_conf_predictedA = new JLabel("Predicted (a)", SwingConstants.RIGHT); protected JLabel m_conf_predictedB = new JLabel("Predicted (b)", SwingConstants.RIGHT); protected JLabel m_conf_actualA = new JLabel(" Actual (a):"); protected JLabel m_conf_actualB = new JLabel(" Actual (b):"); protected ConfusionCell m_conf_aa = new ConfusionCell(); protected ConfusionCell m_conf_ab = new ConfusionCell(); protected ConfusionCell m_conf_ba = new ConfusionCell(); protected ConfusionCell m_conf_bb = new ConfusionCell(); // Cost matrix stuff protected JLabel m_cost_predictedA = new 
JLabel("Predicted (a)", SwingConstants.RIGHT); protected JLabel m_cost_predictedB = new JLabel("Predicted (b)", SwingConstants.RIGHT); protected JLabel m_cost_actualA = new JLabel(" Actual (a)"); protected JLabel m_cost_actualB = new JLabel(" Actual (b)"); protected JTextField m_cost_aa = new JTextField("0.0", 5); protected JTextField m_cost_ab = new JTextField("1.0", 5); protected JTextField m_cost_ba = new JTextField("1.0", 5); protected JTextField m_cost_bb = new JTextField("0.0" ,5); protected JButton m_maximizeCB = new JButton("Maximize Cost/Benefit"); protected JButton m_minimizeCB = new JButton("Minimize Cost/Benefit"); protected JRadioButton m_costR = new JRadioButton("Cost"); protected JRadioButton m_benefitR = new JRadioButton("Benefit"); protected JLabel m_costBenefitL = new JLabel("Cost: ", SwingConstants.RIGHT); protected JLabel m_costBenefitV = new JLabel("0"); protected JLabel m_randomV = new JLabel("0"); protected JLabel m_gainV = new JLabel("0"); protected int m_originalPopSize; /** Population text field */ protected JTextField m_totalPopField = new JTextField(6); protected int m_totalPopPrevious; /** Classification accuracy */ protected JLabel m_classificationAccV = new JLabel("-"); // Only update curve & stats if values in cost matrix have changed protected double m_tpPrevious; protected double m_fpPrevious; protected double m_tnPrevious; protected double m_fnPrevious; /** * Inner class for handling a single cell in the confusion matrix. * Displays the value, value as a percentage of total population and * graphical depiction of percentage. 
* * @author Mark Hall (mhall{[at]}pentaho{[dot]}com) */ protected static class ConfusionCell extends JPanel { /** For serialization */ private static final long serialVersionUID = 6148640235434494767L; private JLabel m_conf_cell = new JLabel("-", SwingConstants.RIGHT); JLabel m_conf_perc = new JLabel("-", SwingConstants.RIGHT); private JPanel m_percentageP; protected double m_percentage = 0; public ConfusionCell() { setLayout(new BorderLayout()); setBorder(BorderFactory.createEtchedBorder()); add(m_conf_cell, BorderLayout.NORTH); m_percentageP = new JPanel() { public void paintComponent(Graphics gx) { super.paintComponent(gx); if (m_percentage > 0) { gx.setColor(Color.BLUE); int height = this.getHeight(); double width = this.getWidth(); int barWidth = (int)(m_percentage * width); gx.fillRect(0, 0, barWidth, height); } } }; Dimension d = new Dimension(30,5); m_percentageP.setMinimumSize(d); m_percentageP.setPreferredSize(d); JPanel percHolder = new JPanel(); percHolder.setLayout(new BorderLayout()); percHolder.add(m_percentageP, BorderLayout.CENTER); percHolder.add(m_conf_perc, BorderLayout.EAST); add(percHolder, BorderLayout.SOUTH); } /** * Set the value of a cell. 
* * @param cellValue the value of the cell * @param max the max (for setting value as a percentage) * @param scaleFactor scale the value by this amount * @param precision precision for the percentage value */ public void setCellValue(double cellValue, double max, double scaleFactor, int precision) { if (!Utils.isMissingValue(cellValue)) { m_percentage = cellValue / max; } else { m_percentage = 0; } m_conf_cell.setText(Utils.doubleToString((cellValue * scaleFactor), 0)); m_conf_perc.setText(Utils.doubleToString(m_percentage * 100.0, precision) + "%"); // refresh the percentage bar m_percentageP.repaint(); } } public AnalysisPanel() { setLayout(new BorderLayout()); m_performancePanel.setShowAttBars(false); m_performancePanel.setShowClassPanel(false); m_costBenefitPanel.setShowAttBars(false); m_costBenefitPanel.setShowClassPanel(false); Dimension size = new Dimension(500, 400); m_performancePanel.setPreferredSize(size); m_performancePanel.setMinimumSize(size); size = new Dimension(500, 400); m_costBenefitPanel.setMinimumSize(size); m_costBenefitPanel.setPreferredSize(size); m_thresholdSlider.addChangeListener(new ChangeListener() { public void stateChanged(ChangeEvent e) { updateInfoForSliderValue((double)m_thresholdSlider.getValue() / 100.0); } }); JPanel plotHolder = new JPanel(); plotHolder.setLayout(new GridLayout(1,2)); plotHolder.add(m_performancePanel); plotHolder.add(m_costBenefitPanel); add(plotHolder, BorderLayout.CENTER); JPanel lowerPanel = new JPanel(); lowerPanel.setLayout(new BorderLayout()); ButtonGroup bGroup = new ButtonGroup(); bGroup.add(m_percPop); bGroup.add(m_percOfTarget); bGroup.add(m_threshold); ButtonGroup bGroup2 = new ButtonGroup(); bGroup2.add(m_costR); bGroup2.add(m_benefitR); ActionListener rl = new ActionListener() { public void actionPerformed(ActionEvent e) { if (m_costR.isSelected()) { m_costBenefitL.setText("Cost: "); } else { m_costBenefitL.setText("Benefit: "); } double gain = Double.parseDouble(m_gainV.getText()); gain = -gain; 
m_gainV.setText(Utils.doubleToString(gain, 2)); } }; m_costR.addActionListener(rl); m_benefitR.addActionListener(rl); m_costR.setSelected(true); m_percPop.setSelected(true); JPanel threshPanel = new JPanel(); threshPanel.setLayout(new BorderLayout()); JPanel radioHolder = new JPanel(); radioHolder.setLayout(new FlowLayout()); radioHolder.add(m_percPop); radioHolder.add(m_percOfTarget); radioHolder.add(m_threshold); threshPanel.add(radioHolder, BorderLayout.NORTH); threshPanel.add(m_thresholdSlider, BorderLayout.SOUTH); JPanel threshInfoPanel = new JPanel(); threshInfoPanel.setLayout(new GridLayout(3,2)); threshInfoPanel.add(new JLabel("% of Population: ", SwingConstants.RIGHT)); threshInfoPanel.add(m_percPopLab); threshInfoPanel.add(new JLabel("% of Target: ", SwingConstants.RIGHT)); threshInfoPanel.add(m_percOfTargetLab); threshInfoPanel.add(new JLabel("Score Threshold: ", SwingConstants.RIGHT)); threshInfoPanel.add(m_thresholdLab); JPanel threshHolder = new JPanel(); threshHolder.setBorder(BorderFactory.createTitledBorder("Threshold")); threshHolder.setLayout(new BorderLayout()); threshHolder.add(threshPanel, BorderLayout.CENTER); threshHolder.add(threshInfoPanel, BorderLayout.EAST); lowerPanel.add(threshHolder, BorderLayout.NORTH); // holder for the two matrixes JPanel matrixHolder = new JPanel(); matrixHolder.setLayout(new GridLayout(1,2)); // confusion matrix JPanel confusionPanel = new JPanel(); confusionPanel.setLayout(new GridLayout(3,3)); confusionPanel.add(m_conf_predictedA); confusionPanel.add(m_conf_predictedB); confusionPanel.add(new JLabel()); // dummy confusionPanel.add(m_conf_aa); confusionPanel.add(m_conf_ab); confusionPanel.add(m_conf_actualA); confusionPanel.add(m_conf_ba); confusionPanel.add(m_conf_bb); confusionPanel.add(m_conf_actualB); JPanel tempHolderCA = new JPanel(); tempHolderCA.setLayout(new BorderLayout()); tempHolderCA.setBorder(BorderFactory.createTitledBorder("Confusion Matrix")); tempHolderCA.add(confusionPanel, 
BorderLayout.CENTER); JPanel accHolder = new JPanel(); accHolder.setLayout(new FlowLayout(FlowLayout.LEFT)); accHolder.add(new JLabel("Classification Accuracy: ")); accHolder.add(m_classificationAccV); tempHolderCA.add(accHolder, BorderLayout.SOUTH); matrixHolder.add(tempHolderCA); // cost matrix JPanel costPanel = new JPanel(); costPanel.setBorder(BorderFactory.createTitledBorder("Cost Matrix")); costPanel.setLayout(new BorderLayout()); JPanel cmHolder = new JPanel(); cmHolder.setLayout(new GridLayout(3, 3)); cmHolder.add(m_cost_predictedA); cmHolder.add(m_cost_predictedB); cmHolder.add(new JLabel()); // dummy cmHolder.add(m_cost_aa); cmHolder.add(m_cost_ab); cmHolder.add(m_cost_actualA); cmHolder.add(m_cost_ba); cmHolder.add(m_cost_bb); cmHolder.add(m_cost_actualB); costPanel.add(cmHolder, BorderLayout.CENTER); FocusListener fl = new FocusListener() { public void focusGained(FocusEvent e) { } public void focusLost(FocusEvent e) { if (constructCostBenefitData()) { try { m_costBenefitPanel.setMasterPlot(m_costBenefit); m_costBenefitPanel.validate(); m_costBenefitPanel.repaint(); } catch (Exception ex) { ex.printStackTrace(); } updateCostBenefit(); } } }; ActionListener al = new ActionListener() { public void actionPerformed(ActionEvent e) { if (constructCostBenefitData()) { try { m_costBenefitPanel.setMasterPlot(m_costBenefit); m_costBenefitPanel.validate(); m_costBenefitPanel.repaint(); } catch (Exception ex) { ex.printStackTrace(); } updateCostBenefit(); } } }; m_cost_aa.addFocusListener(fl); m_cost_aa.addActionListener(al); m_cost_ab.addFocusListener(fl); m_cost_ab.addActionListener(al); m_cost_ba.addFocusListener(fl); m_cost_ba.addActionListener(al); m_cost_bb.addFocusListener(fl); m_cost_bb.addActionListener(al); m_totalPopField.addFocusListener(fl); m_totalPopField.addActionListener(al); JPanel cbHolder = new JPanel(); cbHolder.setLayout(new BorderLayout()); JPanel tempP = new JPanel(); tempP.setLayout(new GridLayout(3, 2)); tempP.add(m_costBenefitL); 
tempP.add(m_costBenefitV); tempP.add(new JLabel("Random: ", SwingConstants.RIGHT)); tempP.add(m_randomV); tempP.add(new JLabel("Gain: ", SwingConstants.RIGHT)); tempP.add(m_gainV); cbHolder.add(tempP, BorderLayout.NORTH); JPanel butHolder = new JPanel(); butHolder.setLayout(new GridLayout(2, 1)); butHolder.add(m_maximizeCB); butHolder.add(m_minimizeCB); m_maximizeCB.addActionListener(new ActionListener() { public void actionPerformed(ActionEvent e) { findMaxMinCB(true); } }); m_minimizeCB.addActionListener(new ActionListener() { public void actionPerformed(ActionEvent e) { findMaxMinCB(false); } }); cbHolder.add(butHolder, BorderLayout.SOUTH); costPanel.add(cbHolder, BorderLayout.EAST); JPanel popCBR = new JPanel(); popCBR.setLayout(new GridLayout(1, 2)); JPanel popHolder = new JPanel(); popHolder.setLayout(new FlowLayout(FlowLayout.LEFT)); popHolder.add(new JLabel("Total Population: ")); popHolder.add(m_totalPopField); JPanel radioHolder2 = new JPanel(); radioHolder2.setLayout(new FlowLayout(FlowLayout.RIGHT)); radioHolder2.add(m_costR); radioHolder2.add(m_benefitR); popCBR.add(popHolder); popCBR.add(radioHolder2); costPanel.add(popCBR, BorderLayout.SOUTH); matrixHolder.add(costPanel); lowerPanel.add(matrixHolder, BorderLayout.SOUTH); // popAccHolder.add(popHolder); //popAccHolder.add(accHolder); /*JPanel lowerPanel2 = new JPanel(); lowerPanel2.setLayout(new BorderLayout()); lowerPanel2.add(lowerPanel, BorderLayout.NORTH); lowerPanel2.add(popAccHolder, BorderLayout.SOUTH); */ add(lowerPanel, BorderLayout.SOUTH); } private void findMaxMinCB(boolean max) { double maxMin = (max) ? 
Double.NEGATIVE_INFINITY : Double.POSITIVE_INFINITY; Instances cBCurve = m_costBenefit.getPlotInstances(); int maxMinIndex = 0; for (int i = 0; i < cBCurve.numInstances(); i++) { Instance current = cBCurve.instance(i); if (max) { if (current.value(1) > maxMin) { maxMin = current.value(1); maxMinIndex = i; } } else { if (current.value(1) < maxMin) { maxMin = current.value(1); maxMinIndex = i; } } } // set the slider to the correct position int indexOfSampleSize = m_masterPlot.getPlotInstances().attribute(ThresholdCurve.SAMPLE_SIZE_NAME).index(); int indexOfPercOfTarget = m_masterPlot.getPlotInstances().attribute(ThresholdCurve.RECALL_NAME).index(); int indexOfThreshold = m_masterPlot.getPlotInstances().attribute(ThresholdCurve.THRESHOLD_NAME).index(); int indexOfMetric; if (m_percPop.isSelected()) { indexOfMetric = indexOfSampleSize; } else if (m_percOfTarget.isSelected()) { indexOfMetric = indexOfPercOfTarget; } else { indexOfMetric = indexOfThreshold; } double valueOfMetric = m_masterPlot.getPlotInstances().instance(maxMinIndex).value(indexOfMetric); valueOfMetric *= 100.0; // set the approximate location of the slider m_thresholdSlider.setValue((int)valueOfMetric); // make sure the actual values relate to the true min/max rather // than being off due to slider location error. 
updateInfoGivenIndex(maxMinIndex); } private void updateCostBenefit() { double value = (double)m_thresholdSlider.getValue() / 100.0; Instances plotInstances = m_masterPlot.getPlotInstances(); int indexOfSampleSize = m_masterPlot.getPlotInstances().attribute(ThresholdCurve.SAMPLE_SIZE_NAME).index(); int indexOfPercOfTarget = m_masterPlot.getPlotInstances().attribute(ThresholdCurve.RECALL_NAME).index(); int indexOfThreshold = m_masterPlot.getPlotInstances().attribute(ThresholdCurve.THRESHOLD_NAME).index(); int indexOfMetric; if (m_percPop.isSelected()) { indexOfMetric = indexOfSampleSize; } else if (m_percOfTarget.isSelected()) { indexOfMetric = indexOfPercOfTarget; } else { indexOfMetric = indexOfThreshold; } int index = findIndexForValue(value, plotInstances, indexOfMetric); updateCBRandomGainInfo(index); } private void updateCBRandomGainInfo(int index) { double requestedPopSize = m_originalPopSize; try { requestedPopSize = Double.parseDouble(m_totalPopField.getText()); } catch (NumberFormatException e) {} double scaleFactor = requestedPopSize / m_originalPopSize; double CB = m_costBenefit. getPlotInstances().instance(index).value(1); m_costBenefitV.setText(Utils.doubleToString(CB,2)); double totalRandomCB = 0.0; Instance first = m_masterPlot.getPlotInstances().instance(0); double totalPos = first.value(m_masterPlot.getPlotInstances(). attribute(ThresholdCurve.TRUE_POS_NAME).index()) * scaleFactor; double totalNeg = first.value(m_masterPlot.getPlotInstances(). 
attribute(ThresholdCurve.FALSE_POS_NAME)) * scaleFactor; double posInSample = (totalPos * (Double.parseDouble(m_percPopLab.getText()) / 100.0)); double negInSample = (totalNeg * (Double.parseDouble(m_percPopLab.getText()) / 100.0)); double posOutSample = totalPos - posInSample; double negOutSample = totalNeg - negInSample; double tpCost = 0.0; try { tpCost = Double.parseDouble(m_cost_aa.getText()); } catch (NumberFormatException n) {} double fpCost = 0.0; try { fpCost = Double.parseDouble(m_cost_ba.getText()); } catch (NumberFormatException n) {} double tnCost = 0.0; try { tnCost = Double.parseDouble(m_cost_bb.getText()); } catch (NumberFormatException n) {} double fnCost = 0.0; try { fnCost = Double.parseDouble(m_cost_ab.getText()); } catch (NumberFormatException n) {} totalRandomCB += posInSample * tpCost; totalRandomCB += negInSample * fpCost; totalRandomCB += posOutSample * fnCost; totalRandomCB += negOutSample * tnCost; m_randomV.setText(Utils.doubleToString(totalRandomCB, 2)); double gain = (m_costR.isSelected()) ? totalRandomCB - CB : CB - totalRandomCB; m_gainV.setText(Utils.doubleToString(gain, 2)); // update classification rate Instance currentInst = m_masterPlot.getPlotInstances().instance(index); double tp = currentInst.value(m_masterPlot.getPlotInstances(). attribute(ThresholdCurve.TRUE_POS_NAME).index()); double tn = currentInst.value(m_masterPlot.getPlotInstances(). attribute(ThresholdCurve.TRUE_NEG_NAME).index()); m_classificationAccV. 
setText(Utils.doubleToString((tp + tn) / (totalPos + totalNeg) * 100.0, 4) + "%"); } private void updateInfoGivenIndex(int index) { Instances plotInstances = m_masterPlot.getPlotInstances(); int indexOfSampleSize = m_masterPlot.getPlotInstances().attribute(ThresholdCurve.SAMPLE_SIZE_NAME).index(); int indexOfPercOfTarget = m_masterPlot.getPlotInstances().attribute(ThresholdCurve.RECALL_NAME).index(); int indexOfThreshold = m_masterPlot.getPlotInstances().attribute(ThresholdCurve.THRESHOLD_NAME).index(); // update labels m_percPopLab.setText(Utils. doubleToString(100.0 * plotInstances.instance(index).value(indexOfSampleSize), 4)); m_percOfTargetLab.setText(Utils.doubleToString( 100.0 * plotInstances.instance(index).value(indexOfPercOfTarget), 4)); m_thresholdLab.setText(Utils.doubleToString(plotInstances.instance(index).value(indexOfThreshold), 4)); /*if (m_percPop.isSelected()) { m_percPopLab.setText(Utils.doubleToString(100.0 * value, 4)); } else if (m_percOfTarget.isSelected()) { m_percOfTargetLab.setText(Utils.doubleToString(100.0 * value, 4)); } else { m_thresholdLab.setText(Utils.doubleToString(value, 4)); }*/ // Update the highlighted point on the graphs */ if (m_previousShapeIndex >= 0) { m_shapeSizes[m_previousShapeIndex] = 1; } m_shapeSizes[index] = 10; m_previousShapeIndex = index; // Update the confusion matrix // double totalInstances = int tp = plotInstances.attribute(ThresholdCurve.TRUE_POS_NAME).index(); int fp = plotInstances.attribute(ThresholdCurve.FALSE_POS_NAME).index(); int tn = plotInstances.attribute(ThresholdCurve.TRUE_NEG_NAME).index(); int fn = plotInstances.attribute(ThresholdCurve.FALSE_NEG_NAME).index(); Instance temp = plotInstances.instance(index); double totalInstances = temp.value(tp) + temp.value(fp) + temp.value(tn) + temp.value(fn); // get the value out of the total pop field (if possible) double requestedPopSize = totalInstances; try { requestedPopSize = Double.parseDouble(m_totalPopField.getText()); } catch 
(NumberFormatException e) {} m_conf_aa.setCellValue(temp.value(tp), totalInstances, requestedPopSize / totalInstances, 2); m_conf_ab.setCellValue(temp.value(fn), totalInstances, requestedPopSize / totalInstances, 2); m_conf_ba.setCellValue(temp.value(fp), totalInstances, requestedPopSize / totalInstances, 2); m_conf_bb.setCellValue(temp.value(tn), totalInstances, requestedPopSize / totalInstances, 2); updateCBRandomGainInfo(index); repaint(); } private void updateInfoForSliderValue(double value) { int indexOfSampleSize = m_masterPlot.getPlotInstances().attribute(ThresholdCurve.SAMPLE_SIZE_NAME).index(); int indexOfPercOfTarget = m_masterPlot.getPlotInstances().attribute(ThresholdCurve.RECALL_NAME).index(); int indexOfThreshold = m_masterPlot.getPlotInstances().attribute(ThresholdCurve.THRESHOLD_NAME).index(); int indexOfMetric; if (m_percPop.isSelected()) { indexOfMetric = indexOfSampleSize; } else if (m_percOfTarget.isSelected()) { indexOfMetric = indexOfPercOfTarget; } else { indexOfMetric = indexOfThreshold; } Instances plotInstances = m_masterPlot.getPlotInstances(); int index = findIndexForValue(value, plotInstances, indexOfMetric); updateInfoGivenIndex(index); } private int findIndexForValue(double value, Instances plotInstances, int indexOfMetric) { // binary search // threshold curve is sorted ascending in the threshold (thus // descending for recall and pop size) int index = -1; int lower = 0; int upper = plotInstances.numInstances() - 1; int mid = (upper - lower) / 2; boolean done = false; while (!done) { if (upper - lower <= 1) { // choose the one closest to the value double comp1 = plotInstances.instance(upper).value(indexOfMetric); double comp2 = plotInstances.instance(lower).value(indexOfMetric); if (Math.abs(comp1 - value) < Math.abs(comp2 - value)) { index = upper; } else { index = lower; } break; } double comparisonVal = plotInstances.instance(mid).value(indexOfMetric); if (value > comparisonVal) { if (m_threshold.isSelected()) { lower = mid; mid 
+= (upper - lower) / 2; } else { upper = mid; mid -= (upper - lower) / 2; } } else if (value < comparisonVal) { if (m_threshold.isSelected()) { upper = mid; mid -= (upper - lower) / 2; } else { lower = mid; mid += (upper - lower) / 2; } } else { index = mid; done = true; } } // now check for ties in the appropriate direction if (!m_threshold.isSelected()) { while (index + 1 < plotInstances.numInstances()) { if (plotInstances.instance(index + 1).value(indexOfMetric) == plotInstances.instance(index).value(indexOfMetric)) { index++; } else { break; } } } else { while (index - 1 >= 0) { if (plotInstances.instance(index - 1).value(indexOfMetric) == plotInstances.instance(index).value(indexOfMetric)) { index--; } else { break; } } } return index; } /** * Set the threshold data for the panel to use. * * @param data PlotData2D object encapsulating the threshold data. * @param classAtt the class attribute from the original data used to generate * the threshold data. * @throws Exception if something goes wrong. 
*/ public synchronized void setDataSet(PlotData2D data, Attribute classAtt) throws Exception { // make a copy of the PlotData2D object m_masterPlot = new PlotData2D(data.getPlotInstances()); boolean[] connectPoints = new boolean[m_masterPlot.getPlotInstances().numInstances()]; for (int i = 1; i < connectPoints.length; i++) { connectPoints[i] = true; } m_masterPlot.setConnectPoints(connectPoints); m_masterPlot.m_alwaysDisplayPointsOfThisSize = 10; setClassForConfusionMatrix(classAtt); m_performancePanel.setMasterPlot(m_masterPlot); m_performancePanel.validate(); m_performancePanel.repaint(); m_shapeSizes = new int[m_masterPlot.getPlotInstances().numInstances()]; for (int i = 0; i < m_shapeSizes.length; i++) { m_shapeSizes[i] = 1; } m_masterPlot.setShapeSize(m_shapeSizes); constructCostBenefitData(); m_costBenefitPanel.setMasterPlot(m_costBenefit); m_costBenefitPanel.validate(); m_costBenefitPanel.repaint(); m_totalPopPrevious = 0; m_fpPrevious = 0; m_tpPrevious = 0; m_tnPrevious = 0; m_fnPrevious = 0; m_previousShapeIndex = -1; // set the total population size Instance first = m_masterPlot.getPlotInstances().instance(0); double totalPos = first.value(m_masterPlot.getPlotInstances(). attribute(ThresholdCurve.TRUE_POS_NAME).index()); double totalNeg = first.value(m_masterPlot.getPlotInstances(). 
attribute(ThresholdCurve.FALSE_POS_NAME)); m_originalPopSize = (int)(totalPos + totalNeg); m_totalPopField.setText("" + m_originalPopSize); m_performancePanel.setYIndex(5); m_performancePanel.setXIndex(10); m_costBenefitPanel.setXIndex(0); m_costBenefitPanel.setYIndex(1); // System.err.println(m_masterPlot.getPlotInstances()); updateInfoForSliderValue((double)m_thresholdSlider.getValue() / 100.0); } private void setClassForConfusionMatrix(Attribute classAtt) { m_classAttribute = classAtt; m_conf_actualA.setText(" Actual (a): " + classAtt.value(0)); m_conf_actualA.setToolTipText(classAtt.value(0)); String negClasses = ""; for (int i = 1; i < classAtt.numValues(); i++) { negClasses += classAtt.value(i); if (i < classAtt.numValues() - 1) { negClasses += ","; } } m_conf_actualB.setText(" Actual (b): " + negClasses); m_conf_actualB.setToolTipText(negClasses); } private boolean constructCostBenefitData() { double tpCost = 0.0; try { tpCost = Double.parseDouble(m_cost_aa.getText()); } catch (NumberFormatException n) {} double fpCost = 0.0; try { fpCost = Double.parseDouble(m_cost_ba.getText()); } catch (NumberFormatException n) {} double tnCost = 0.0; try { tnCost = Double.parseDouble(m_cost_bb.getText()); } catch (NumberFormatException n) {} double fnCost = 0.0; try { fnCost = Double.parseDouble(m_cost_ab.getText()); } catch (NumberFormatException n) {} double requestedPopSize = m_originalPopSize; try { requestedPopSize = Double.parseDouble(m_totalPopField.getText()); } catch (NumberFormatException e) {} double scaleFactor = 1.0; if (m_originalPopSize != 0) { scaleFactor = requestedPopSize / m_originalPopSize; } if (tpCost == m_tpPrevious && fpCost == m_fpPrevious && tnCost == m_tnPrevious && fnCost == m_fnPrevious && requestedPopSize == m_totalPopPrevious) { return false; } // First construct some Instances for the curve FastVector fv = new FastVector(); fv.addElement(new Attribute("Sample Size")); fv.addElement(new Attribute("Cost/Benefit")); Instances costBenefitI = 
new Instances("Cost/Benefit Curve", fv, 100); // process the performance data to make this curve Instances performanceI = m_masterPlot.getPlotInstances(); for (int i = 0; i < performanceI.numInstances(); i++) { Instance current = performanceI.instance(i); double[] vals = new double[2]; vals[0] = current.value(10); // sample size vals[1] = (current.value(0) * tpCost + current.value(1) * fnCost + current.value(2) * fpCost + current.value(3) * tnCost) * scaleFactor; Instance newInst = new DenseInstance(1.0, vals); costBenefitI.add(newInst); } costBenefitI.compactify(); // now set up the plot data m_costBenefit = new PlotData2D(costBenefitI); m_costBenefit.m_alwaysDisplayPointsOfThisSize = 10; m_costBenefit.setPlotName("Cost/benefit curve"); boolean[] connectPoints = new boolean[costBenefitI.numInstances()]; for (int i = 0; i < connectPoints.length; i++) { connectPoints[i] = true; } try { m_costBenefit.setConnectPoints(connectPoints); m_costBenefit.setShapeSize(m_shapeSizes); } catch (Exception ex) { // ignore } m_tpPrevious = tpCost; m_fpPrevious = fpCost; m_tnPrevious = tnCost; m_fnPrevious = fnCost; return true; } } /** * Constructor. */ public CostBenefitAnalysis() { java.awt.GraphicsEnvironment ge = java.awt.GraphicsEnvironment.getLocalGraphicsEnvironment(); if (!ge.isHeadless()) { appearanceFinal(); } } /** * Global info for this bean * * @return a String value */ public String globalInfo() { return "Visualize performance charts (such as ROC)."; } /** * Accept a threshold data event and set up the visualization. * @param e a threshold data event */ public void acceptDataSet(ThresholdDataEvent e) { try { setCurveData(e.getDataSet(), e.getClassAttribute()); } catch (Exception ex) { System.err.println("[CostBenefitAnalysis] Problem setting up visualization."); ex.printStackTrace(); } } /** * Set the threshold curve data to use. * * @param curveData a PlotData2D object set up with the curve data. 
* @param origClassAtt the class attribute from the original data used to * generate the curve. * @throws Exception if somthing goes wrong during the setup process. */ public void setCurveData(PlotData2D curveData, Attribute origClassAtt) throws Exception { if (m_analysisPanel == null) { m_analysisPanel = new AnalysisPanel(); } m_analysisPanel.setDataSet(curveData, origClassAtt); } public BeanVisual getVisual() { return m_visual; } public void setVisual(BeanVisual newVisual) { m_visual = newVisual; } public void useDefaultVisual() { m_visual.loadIcons(BeanVisual.ICON_PATH+"DefaultDataVisualizer.gif", BeanVisual.ICON_PATH+"DefaultDataVisualizer_animated.gif"); } public Enumeration enumerateRequests() { Vector newVector = new Vector(0); if (m_analysisPanel != null) { if (m_analysisPanel.m_masterPlot != null) { newVector.addElement("Show analysis"); } } return newVector.elements(); } public void performRequest(String request) { if (request.compareTo("Show analysis") == 0) { try { // popup visualize panel if (!m_framePoppedUp) { m_framePoppedUp = true; final javax.swing.JFrame jf = new javax.swing.JFrame("Cost/Benefit Analysis"); jf.setSize(1000,600); jf.getContentPane().setLayout(new BorderLayout()); jf.getContentPane().add(m_analysisPanel, BorderLayout.CENTER); jf.addWindowListener(new java.awt.event.WindowAdapter() { public void windowClosing(java.awt.event.WindowEvent e) { jf.dispose(); m_framePoppedUp = false; } }); jf.setVisible(true); m_popupFrame = jf; } else { m_popupFrame.toFront(); } } catch (Exception ex) { ex.printStackTrace(); m_framePoppedUp = false; } } else { throw new IllegalArgumentException(request + " not supported (Cost/Benefit Analysis"); } } public void addVetoableChangeListener(String name, VetoableChangeListener vcl) { m_bcSupport.addVetoableChangeListener(name, vcl); } public BeanContext getBeanContext() { return m_beanContext; } public void removeVetoableChangeListener(String name, VetoableChangeListener vcl) { 
m_bcSupport.removeVetoableChangeListener(name, vcl); } protected void appearanceFinal() { removeAll(); setLayout(new BorderLayout()); setUpFinal(); } protected void setUpFinal() { if (m_analysisPanel == null) { m_analysisPanel = new AnalysisPanel(); } add(m_analysisPanel, BorderLayout.CENTER); } protected void appearanceDesign() { removeAll(); m_visual = new BeanVisual("CostBenefitAnalysis", BeanVisual.ICON_PATH+"ModelPerformanceChart.gif", BeanVisual.ICON_PATH +"ModelPerformanceChart_animated.gif"); setLayout(new BorderLayout()); add(m_visual, BorderLayout.CENTER); } public void setBeanContext(BeanContext bc) throws PropertyVetoException { m_beanContext = bc; m_design = m_beanContext.isDesignTime(); if (m_design) { appearanceDesign(); } else { java.awt.GraphicsEnvironment ge = java.awt.GraphicsEnvironment.getLocalGraphicsEnvironment(); if (!ge.isHeadless()) { appearanceFinal(); } } } /** * Returns true if, at this time, * the object will accept a connection via the named event * * @param eventName the name of the event in question * @return true if the object will accept a connection */ public boolean connectionAllowed(String eventName) { return (m_listenee == null); } /** * Notify this object that it has been registered as a listener with * a source for recieving events described by the named event * This object is responsible for recording this fact. 
* * @param eventName the event * @param source the source with which this object has been registered as * a listener */ public void connectionNotification(String eventName, Object source) { if (connectionAllowed(eventName)) { m_listenee = source; } } /** * Returns true if, at this time, * the object will accept a connection according to the supplied * EventSetDescriptor * * @param esd the EventSetDescriptor * @return true if the object will accept a connection */ public boolean connectionAllowed(EventSetDescriptor esd) { return connectionAllowed(esd.getName()); } /** * Notify this object that it has been deregistered as a listener with * a source for named event. This object is responsible * for recording this fact. * * @param eventName the event * @param source the source with which this object has been registered as * a listener */ public void disconnectionNotification(String eventName, Object source) { if (m_listenee == source) { m_listenee = null; } } /** * Get the custom (descriptive) name for this bean (if one has been set) * * @return the custom name (or the default name) */ public String getCustomName() { return m_visual.getText(); } /** * Returns true if. at this time, the bean is busy with some * (i.e. perhaps a worker thread is performing some calculation). * * @return true if the bean is busy. */ public boolean isBusy() { return false; } /** * Set a custom (descriptive) name for this bean * * @param name the name to use */ public void setCustomName(String name) { m_visual.setText(name); } /** * Set a logger * * @param logger a weka.gui.Logger value */ public void setLog(Logger logger) { // we don't need to do any logging } /** * Stop any processing that the bean might be doing. 
*
   * Currently a no-op: isBusy() always reports false and there is no
   * background work to interrupt.
   */
  public void stop() {
    // nothing to do here
  }

  /**
   * Stand-alone test harness.
   *
   * Loads a data set, makes the last attribute the class, obtains
   * cross-validated predictions (10 folds, seed 1) from NaiveBayes, builds
   * a threshold curve for class value index 0 and shows it in an
   * AnalysisPanel inside a JFrame. Closing the frame exits the JVM.
   *
   * @param args args[0] is the path of the data file to load (read through
   *          the Instances(Reader) constructor, i.e. ARFF)
   */
  public static void main(String[] args) {
    if (args.length < 1) {
      System.err.println("Usage: java weka.gui.beans.CostBenefitAnalysis <arff file>");
      return;
    }
    try {
      Instances train;
      java.io.BufferedReader reader =
        new java.io.BufferedReader(new java.io.FileReader(args[0]));
      try {
        train = new Instances(reader);
      } finally {
        // release the file handle even if parsing fails
        reader.close();
      }
      train.setClassIndex(train.numAttributes() - 1);

      weka.classifiers.evaluation.ThresholdCurve tc =
        new weka.classifiers.evaluation.ThresholdCurve();
      weka.classifiers.evaluation.EvaluationUtils eu =
        new weka.classifiers.evaluation.EvaluationUtils();
      // any classifier can be plugged in here, e.g.
      // weka.classifiers.functions.Logistic
      weka.classifiers.Classifier classifier =
        new weka.classifiers.bayes.NaiveBayes();
      FastVector predictions = new FastVector();
      eu.setSeed(1);
      predictions.appendElements(eu.getCVPredictions(classifier, train, 10));
      Instances result = tc.getCurve(predictions, 0);
      PlotData2D pd = new PlotData2D(result);
      pd.m_alwaysDisplayPointsOfThisSize = 10;

      // NOTE(review): presumably connectPoints[i] joins point i to point
      // i-1, which is why index 0 is left false -- confirm.
      boolean[] connectPoints = new boolean[result.numInstances()];
      for (int i = 1; i < connectPoints.length; i++) {
        connectPoints[i] = true;
      }
      pd.setConnectPoints(connectPoints);

      final javax.swing.JFrame jf =
        new javax.swing.JFrame("CostBenefitTest");
      jf.setSize(1000, 600);
      jf.getContentPane().setLayout(new BorderLayout());
      final CostBenefitAnalysis.AnalysisPanel analysisPanel =
        new CostBenefitAnalysis.AnalysisPanel();

      jf.getContentPane().add(analysisPanel, BorderLayout.CENTER);
      jf.addWindowListener(new java.awt.event.WindowAdapter() {
          public void windowClosing(java.awt.event.WindowEvent e) {
            jf.dispose();
            System.exit(0);
          }
        });

      jf.setVisible(true);

      analysisPanel.setDataSet(pd, train.classAttribute());

    } catch (Exception ex) {
      ex.printStackTrace();
    }
  }
}