/* * This program is free software; you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by * the Free Software Foundation; either version 2 of the License, or * (at your option) any later version. * * This program is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU General Public License for more details. * * You should have received a copy of the GNU General Public License * along with this program; if not, write to the Free Software * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA. */ /* * BoundaryPanel.java * Copyright (C) 2002 University of Waikato, Hamilton, New Zealand * */ package weka.gui.boundaryvisualizer; import weka.classifiers.Classifier; import weka.classifiers.AbstractClassifier; import weka.core.FastVector; import weka.core.Instance; import weka.core.DenseInstance; import weka.core.Instances; import weka.core.Utils; import weka.gui.visualize.JPEGWriter; import java.awt.BorderLayout; import java.awt.Color; import java.awt.Dimension; import java.awt.Graphics; import java.awt.Graphics2D; import java.awt.Image; import java.awt.RenderingHints; import java.awt.event.ActionEvent; import java.awt.event.ActionListener; import java.awt.event.MouseEvent; import java.awt.event.MouseListener; import java.awt.image.BufferedImage; import java.io.File; import java.io.FileInputStream; import java.io.ObjectInputStream; import java.util.Iterator; import java.util.Locale; import java.util.Random; import java.util.Vector; import javax.imageio.IIOImage; import javax.imageio.ImageIO; import javax.imageio.ImageWriteParam; import javax.imageio.ImageWriter; import javax.imageio.plugins.jpeg.JPEGImageWriteParam; import javax.imageio.stream.ImageOutputStream; import javax.swing.JOptionPane; import javax.swing.JPanel; import javax.swing.ToolTipManager; /** * BoundaryPanel. A class to handle the plotting operations * associated with generating a 2D picture of a classifier's decision * boundaries. * * @author Mark Hall * @version $Revision: 5987 $ * @since 1.0 * @see JPanel */ public class BoundaryPanel extends JPanel { /** for serialization */ private static final long serialVersionUID = -8499445518744770458L; /** default colours for classes */ public static final Color [] DEFAULT_COLORS = { Color.red, Color.green, Color.blue, new Color(0, 255, 255), // cyan new Color(255, 0, 255), // pink new Color(255, 255, 0), // yellow new Color(255, 255, 255), //white new Color(0, 0, 0)}; /** The distance we can click away from a point in the GUI and still remove it. */ public static final double REMOVE_POINT_RADIUS = 7.0; protected FastVector m_Colors = new FastVector(); /** training data */ protected Instances m_trainingData; /** distribution classifier to use */ protected Classifier m_classifier; /** data generator to use */ protected DataGenerator m_dataGenerator; /** index of the class attribute */ private int m_classIndex = -1; // attributes for visualizing on protected int m_xAttribute; protected int m_yAttribute; // min, max and ranges of these attributes protected double m_minX; protected double m_minY; protected double m_maxX; protected double m_maxY; private double m_rangeX; private double m_rangeY; // pixel width and height in terms of attribute values protected double m_pixHeight; protected double m_pixWidth; /** used for offscreen drawing */ protected Image m_osi = null; // width and height of the display area protected int m_panelWidth; protected int m_panelHeight; // number of samples to take from each region in the fixed dimensions protected int m_numOfSamplesPerRegion = 2; // number of samples per kernel = base ^ (# non-fixed dimensions) protected int m_numOfSamplesPerGenerator; protected double m_samplesBase = 2.0; /** listeners to be notified when plot is complete */ private Vector m_listeners = new Vector(); /** * small inner class for rendering the bitmap on to */ private class PlotPanel extends JPanel { /** for serialization */ private static final long serialVersionUID = 743629498352235060L; public PlotPanel() { this.setToolTipText(""); } public void paintComponent(Graphics g) { super.paintComponent(g); if (m_osi != null) { g.drawImage(m_osi,0,0,this); } } public String getToolTipText(MouseEvent event) { if (m_probabilityCache == null) { return null; } if (m_probabilityCache[event.getY()][event.getX()] == null) { return null; } String pVec = "(X: " +Utils.doubleToString(convertFromPanelX((double)event.getX()), 2) +" Y: " +Utils.doubleToString(convertFromPanelY((double)event.getY()), 2)+") "; // construct a string holding the probability vector for (int i = 0; i < m_trainingData.classAttribute().numValues(); i++) { pVec += Utils. doubleToString(m_probabilityCache[event.getY()][event.getX()][i], 3)+" "; } return pVec; } } /** the actual plotting area */ private PlotPanel m_plotPanel = new PlotPanel(); /** thread for running the plotting operation in */ private Thread m_plotThread = null; /** Stop the plotting thread */ protected boolean m_stopPlotting = false; /** Stop any replotting threads */ protected boolean m_stopReplotting = false; // Used by replotting threads to pause and resume the main plot thread private Double m_dummy = new Double(1.0); private boolean m_pausePlotting = false; /** what size of tile is currently being plotted */ private int m_size = 1; /** is the main plot thread performing the initial coarse tiling */ private boolean m_initialTiling; /** A random number generator */ private Random m_random = null; /** cache of probabilities for fast replotting */ protected double [][][] m_probabilityCache; /** plot the training data */ protected boolean m_plotTrainingData = true; /** * Creates a new BoundaryPanel instance. * * @param panelWidth the width in pixels of the panel * @param panelHeight the height in pixels of the panel */ public BoundaryPanel(int panelWidth, int panelHeight) { ToolTipManager.sharedInstance().setDismissDelay(Integer.MAX_VALUE); m_panelWidth = panelWidth; m_panelHeight = panelHeight; setLayout(new BorderLayout()); m_plotPanel.setMinimumSize(new Dimension(m_panelWidth, m_panelHeight)); m_plotPanel.setPreferredSize(new Dimension(m_panelWidth, m_panelHeight)); m_plotPanel.setMaximumSize(new Dimension(m_panelWidth, m_panelHeight)); add(m_plotPanel, BorderLayout.CENTER); setPreferredSize(m_plotPanel.getPreferredSize()); setMaximumSize(m_plotPanel.getMaximumSize()); setMinimumSize(m_plotPanel.getMinimumSize()); m_random = new Random(1); for (int i = 0; i < DEFAULT_COLORS.length; i++) { m_Colors.addElement(new Color(DEFAULT_COLORS[i].getRed(), DEFAULT_COLORS[i].getGreen(), DEFAULT_COLORS[i].getBlue())); } m_probabilityCache = new double[m_panelHeight][m_panelWidth][]; } /** * Set the number of points to uniformly sample from a region (fixed * dimensions). * * @param num an int value */ public void setNumSamplesPerRegion(int num) { m_numOfSamplesPerRegion = num; } /** * Get the number of points to sample from a region (fixed dimensions). * * @return an int value */ public int getNumSamplesPerRegion() { return m_numOfSamplesPerRegion; } /** * Set the base for computing the number of samples to obtain from each * generator. number of samples = base ^ (# non fixed dimensions) * * @param ksb a double value */ public void setGeneratorSamplesBase(double ksb) { m_samplesBase = ksb; } /** * Get the base used for computing the number of samples to obtain from * each generator * * @return a double value */ public double getGeneratorSamplesBase() { return m_samplesBase; } /** * Set up the off screen bitmap for rendering to */ protected void initialize() { int iwidth = m_plotPanel.getWidth(); int iheight = m_plotPanel.getHeight(); // System.err.println(iwidth+" "+iheight); m_osi = m_plotPanel.createImage(iwidth, iheight); Graphics m = m_osi.getGraphics(); m.fillRect(0,0,iwidth,iheight); } /** * Stop the plotting thread */ public void stopPlotting() { m_stopPlotting = true; try { m_plotThread.join(100); } catch (Exception e){}; } /** Set up the bounds of our graphic based by finding the smallest reasonable area in the instance space to surround our data points. */ public void computeMinMaxAtts() { m_minX = Double.MAX_VALUE; m_minY = Double.MAX_VALUE; m_maxX = Double.MIN_VALUE; m_maxY = Double.MIN_VALUE; boolean allPointsLessThanOne = true; if (m_trainingData.numInstances() == 0) { m_minX = m_minY = 0.0; m_maxX = m_maxY = 1.0; } else { for (int i = 0; i < m_trainingData.numInstances(); i++) { Instance inst = m_trainingData.instance(i); double x = inst.value(m_xAttribute); double y = inst.value(m_yAttribute); if (!Utils.isMissingValue(x) && !Utils.isMissingValue(y)) { if (x < m_minX) { m_minX = x; } if (x > m_maxX) { m_maxX = x; } if (y < m_minY) { m_minY = y; } if (y > m_maxY) { m_maxY = y; } if (x > 1.0 || y > 1.0) allPointsLessThanOne = false; } } } if (m_minX == m_maxX) m_minX = 0; if (m_minY == m_maxY) m_minY = 0; if (m_minX == Double.MAX_VALUE) m_minX = 0; if (m_minY == Double.MAX_VALUE) m_minY = 0; if (m_maxX == Double.MIN_VALUE) m_maxX = 1; if (m_maxY == Double.MIN_VALUE) m_maxY = 1; if (allPointsLessThanOne) { m_minX = m_minY = 0.0; m_maxX = m_maxY = 1.0; } m_rangeX = (m_maxX - m_minX); m_rangeY = (m_maxY - m_minY); m_pixWidth = m_rangeX / (double)m_panelWidth; m_pixHeight = m_rangeY / (double) m_panelHeight; } /** * Return a random x attribute value contained within * the pix'th horizontal pixel * * @param pix the horizontal pixel number * @return a value in attribute space */ private double getRandomX(int pix) { double minPix = m_minX + (pix * m_pixWidth); return minPix + m_random.nextDouble() * m_pixWidth; } /** * Return a random y attribute value contained within * the pix'th vertical pixel * * @param pix the vertical pixel number * @return a value in attribute space */ private double getRandomY(int pix) { double minPix = m_minY + (pix * m_pixHeight); return minPix + m_random.nextDouble() * m_pixHeight; } /** * Start the plotting thread * * @exception Exception if an error occurs */ public void start() throws Exception { m_numOfSamplesPerGenerator = (int)Math.pow(m_samplesBase, m_trainingData.numAttributes()-3); m_stopReplotting = true; if (m_trainingData == null) { throw new Exception("No training data set (BoundaryPanel)"); } if (m_classifier == null) { throw new Exception("No classifier set (BoundaryPanel)"); } if (m_dataGenerator == null) { throw new Exception("No data generator set (BoundaryPanel)"); } if (m_trainingData.attribute(m_xAttribute).isNominal() || m_trainingData.attribute(m_yAttribute).isNominal()) { throw new Exception("Visualization dimensions must be numeric " +"(BoundaryPanel)"); } computeMinMaxAtts(); startPlotThread(); /*if (m_plotThread == null) { m_plotThread = new PlotThread(); m_plotThread.setPriority(Thread.MIN_PRIORITY); m_plotThread.start(); }*/ } // Thread for main plotting operation protected class PlotThread extends Thread { double [] m_weightingAttsValues; boolean [] m_attsToWeightOn; double [] m_vals; double [] m_dist; Instance m_predInst; public void run() { m_stopPlotting = false; try { initialize(); repaint(); // train the classifier m_probabilityCache = new double[m_panelHeight][m_panelWidth][]; m_classifier.buildClassifier(m_trainingData); // build DataGenerator m_attsToWeightOn = new boolean[m_trainingData.numAttributes()]; m_attsToWeightOn[m_xAttribute] = true; m_attsToWeightOn[m_yAttribute] = true; m_dataGenerator.setWeightingDimensions(m_attsToWeightOn); m_dataGenerator.buildGenerator(m_trainingData); // generate samples m_weightingAttsValues = new double [m_attsToWeightOn.length]; m_vals = new double[m_trainingData.numAttributes()]; m_predInst = new DenseInstance(1.0, m_vals); m_predInst.setDataset(m_trainingData); m_size = 1 << 4; // Current sample region size m_initialTiling = true; // Display the initial coarse image tiling. abortInitial: for (int i = 0; i <= m_panelHeight; i += m_size) { for (int j = 0; j <= m_panelWidth; j += m_size) { if (m_stopPlotting) { break abortInitial; } if (m_pausePlotting) { synchronized (m_dummy) { try { m_dummy.wait(); } catch (InterruptedException ex) { m_pausePlotting = false; } } } plotPoint(j, i, m_size, m_size, calculateRegionProbs(j, i), (j == 0)); } } if (!m_stopPlotting) { m_initialTiling = false; } // Sampling and gridding loop int size2 = m_size / 2; abortPlot: while (m_size > 1) { // Subdivide down to the pixel level for (int i = 0; i <= m_panelHeight; i += m_size) { for (int j = 0; j <= m_panelWidth; j += m_size) { if (m_stopPlotting) { break abortPlot; } if (m_pausePlotting) { synchronized (m_dummy) { try { m_dummy.wait(); } catch (InterruptedException ex) { m_pausePlotting = false; } } } boolean update = (j == 0 && i % 2 == 0); // Draw the three new subpixel regions plotPoint(j, i + size2, size2, size2, calculateRegionProbs(j, i + size2), update); plotPoint(j + size2, i + size2, size2, size2, calculateRegionProbs(j + size2, i + size2), update); plotPoint(j + size2, i, size2, size2, calculateRegionProbs(j + size2, i), update); } } // The new region edge length is half the old edge length m_size = size2; size2 = size2 / 2; } update(); /* // Old method without sampling. abortPlot: for (int i = 0; i < m_panelHeight; i++) { for (int j = 0; j < m_panelWidth; j++) { if (m_stopPlotting) { break abortPlot; } plotPoint(j, i, calculateRegionProbs(j, i), (j == 0)); } } */ if (m_plotTrainingData) { plotTrainingData(); } } catch (Exception ex) { ex.printStackTrace(); JOptionPane.showMessageDialog(null,"Error while plotting: \"" + ex.getMessage() + "\""); } finally { m_plotThread = null; // notify any listeners that we are finished Vector l; ActionEvent e = new ActionEvent(this, 0, ""); synchronized(this) { l = (Vector)m_listeners.clone(); } for (int i = 0; i < l.size(); i++) { ActionListener al = (ActionListener)l.elementAt(i); al.actionPerformed(e); } } } private double [] calculateRegionProbs(int j, int i) throws Exception { double [] sumOfProbsForRegion = new double [m_trainingData.classAttribute().numValues()]; for (int u = 0; u < m_numOfSamplesPerRegion; u++) { double [] sumOfProbsForLocation = new double [m_trainingData.classAttribute().numValues()]; m_weightingAttsValues[m_xAttribute] = getRandomX(j); m_weightingAttsValues[m_yAttribute] = getRandomY(m_panelHeight-i-1); m_dataGenerator.setWeightingValues(m_weightingAttsValues); double [] weights = m_dataGenerator.getWeights(); double sumOfWeights = Utils.sum(weights); int [] indices = Utils.sort(weights); // Prune 1% of weight mass int [] newIndices = new int[indices.length]; double sumSoFar = 0; double criticalMass = 0.99 * sumOfWeights; int index = weights.length - 1; int counter = 0; for (int z = weights.length - 1; z >= 0; z--) { newIndices[index--] = indices[z]; sumSoFar += weights[indices[z]]; counter++; if (sumSoFar > criticalMass) { break; } } indices = new int[counter]; System.arraycopy(newIndices, index + 1, indices, 0, counter); for (int z = 0; z < m_numOfSamplesPerGenerator; z++) { m_dataGenerator.setWeightingValues(m_weightingAttsValues); double [][] values = m_dataGenerator.generateInstances(indices); for (int q = 0; q < values.length; q++) { if (values[q] != null) { System.arraycopy(values[q], 0, m_vals, 0, m_vals.length); m_vals[m_xAttribute] = m_weightingAttsValues[m_xAttribute]; m_vals[m_yAttribute] = m_weightingAttsValues[m_yAttribute]; // classify the instance m_dist = m_classifier.distributionForInstance(m_predInst); for (int k = 0; k < sumOfProbsForLocation.length; k++) { sumOfProbsForLocation[k] += (m_dist[k] * weights[q]); } } } } for (int k = 0; k < sumOfProbsForRegion.length; k++) { sumOfProbsForRegion[k] += (sumOfProbsForLocation[k] * sumOfWeights); } } // average Utils.normalize(sumOfProbsForRegion); // cache if ((i < m_panelHeight) && (j < m_panelWidth)) { m_probabilityCache[i][j] = new double[sumOfProbsForRegion.length]; System.arraycopy(sumOfProbsForRegion, 0, m_probabilityCache[i][j], 0, sumOfProbsForRegion.length); } return sumOfProbsForRegion; } } /** Render the training points on-screen. */ public void plotTrainingData() { Graphics2D osg = (Graphics2D)m_osi.getGraphics(); Graphics g = m_plotPanel.getGraphics(); osg.setRenderingHint(RenderingHints.KEY_ANTIALIASING, RenderingHints.VALUE_ANTIALIAS_ON); double xval = 0; double yval = 0; for (int i = 0; i < m_trainingData.numInstances(); i++) { if (!m_trainingData.instance(i).isMissing(m_xAttribute) && !m_trainingData.instance(i).isMissing(m_yAttribute)) { if (m_trainingData.instance(i).isMissing(m_classIndex)) //jimmy. continue; //don't plot if class is missing. TODO could we plot it differently instead? xval = m_trainingData.instance(i).value(m_xAttribute); yval = m_trainingData.instance(i).value(m_yAttribute); int panelX = convertToPanelX(xval); int panelY = convertToPanelY(yval); Color ColorToPlotWith = ((Color)m_Colors.elementAt((int)m_trainingData.instance(i). value(m_classIndex) % m_Colors.size())); if (ColorToPlotWith.equals(Color.white)) { osg.setColor(Color.black); } else { osg.setColor(Color.white); } osg.fillOval(panelX-3, panelY-3, 7, 7); osg.setColor(ColorToPlotWith); osg.fillOval(panelX-2, panelY-2, 5, 5); } } g.drawImage(m_osi,0,0,m_plotPanel); } /** Convert an X coordinate from the instance space to the panel space. */ private int convertToPanelX(double xval) { double temp = (xval - m_minX) / m_rangeX; temp = temp * (double) m_panelWidth; return (int)temp; } /** Convert a Y coordinate from the instance space to the panel space. */ private int convertToPanelY(double yval) { double temp = (yval - m_minY) / m_rangeY; temp = temp * (double) m_panelHeight; temp = m_panelHeight - temp; return (int)temp; } /** Convert an X coordinate from the panel space to the instance space. */ private double convertFromPanelX(double pX) { pX /= (double) m_panelWidth; pX *= m_rangeX; return pX + m_minX; } /** Convert a Y coordinate from the panel space to the instance space. */ private double convertFromPanelY(double pY) { pY = m_panelHeight - pY; pY /= (double) m_panelHeight; pY *= m_rangeY; return pY + m_minY; } /** Plot a point in our visualization on-screen. */ protected void plotPoint(int x, int y, double [] probs, boolean update) { plotPoint(x, y, 1, 1, probs, update); } /** Plot a point in our visualization on-screen. */ private void plotPoint(int x, int y, int width, int height, double [] probs, boolean update) { // draw a progress line Graphics osg = m_osi.getGraphics(); if (update) { osg.setXORMode(Color.white); osg.drawLine(0, y, m_panelWidth-1, y); update(); osg.drawLine(0, y, m_panelWidth-1, y); } // plot the point osg.setPaintMode(); float [] colVal = new float[3]; float [] tempCols = new float[3]; for (int k = 0; k < probs.length; k++) { Color curr = (Color)m_Colors.elementAt(k % m_Colors.size()); curr.getRGBColorComponents(tempCols); for (int z = 0 ; z < 3; z++) { colVal[z] += probs[k] * tempCols[z]; } } for (int z = 0; z < 3; z++) { if (colVal[z] < 0) { colVal[z] = 0; } else if (colVal[z] > 1) { colVal[z] = 1; } } osg.setColor(new Color(colVal[0], colVal[1], colVal[2])); osg.fillRect(x, y, width, height); } /** Update the rendered image. */ private void update() { Graphics g = m_plotPanel.getGraphics(); g.drawImage(m_osi, 0, 0, m_plotPanel); } /** * Set the training data to use * * @param trainingData the training data * @exception Exception if an error occurs */ public void setTrainingData(Instances trainingData) throws Exception { m_trainingData = trainingData; if (m_trainingData.classIndex() < 0) { throw new Exception("No class attribute set (BoundaryPanel)"); } m_classIndex = m_trainingData.classIndex(); } /** Adds a training instance to the visualization dataset. */ public void addTrainingInstance(Instance instance) { if (m_trainingData == null) { //TODO System.err.println("Trying to add to a null training set (BoundaryPanel)"); } m_trainingData.add(instance); } /** * Register a listener to be notified when plotting completes * * @param newListener the listener to add */ public void addActionListener(ActionListener newListener) { m_listeners.add(newListener); } /** * Remove a listener * * @param removeListener the listener to remove */ public void removeActionListener(ActionListener removeListener) { m_listeners.removeElement(removeListener); } /** * Set the classifier to use. * * @param classifier the classifier to use */ public void setClassifier(Classifier classifier) { m_classifier = classifier; } /** * Set the data generator to use for generating new instances * * @param dataGenerator the data generator to use */ public void setDataGenerator(DataGenerator dataGenerator) { m_dataGenerator = dataGenerator; } /** * Set the x attribute index * * @param xatt index of the attribute to use on the x axis * @exception Exception if an error occurs */ public void setXAttribute(int xatt) throws Exception { if (m_trainingData == null) { throw new Exception("No training data set (BoundaryPanel)"); } if (xatt < 0 || xatt > m_trainingData.numAttributes()) { throw new Exception("X attribute out of range (BoundaryPanel)"); } if (m_trainingData.attribute(xatt).isNominal()) { throw new Exception("Visualization dimensions must be numeric " +"(BoundaryPanel)"); } /*if (m_trainingData.numDistinctValues(xatt) < 2) { throw new Exception("Too few distinct values for X attribute " +"(BoundaryPanel)"); }*/ //removed by jimmy. TESTING! m_xAttribute = xatt; } /** * Set the y attribute index * * @param yatt index of the attribute to use on the y axis * @exception Exception if an error occurs */ public void setYAttribute(int yatt) throws Exception { if (m_trainingData == null) { throw new Exception("No training data set (BoundaryPanel)"); } if (yatt < 0 || yatt > m_trainingData.numAttributes()) { throw new Exception("X attribute out of range (BoundaryPanel)"); } if (m_trainingData.attribute(yatt).isNominal()) { throw new Exception("Visualization dimensions must be numeric " +"(BoundaryPanel)"); } /*if (m_trainingData.numDistinctValues(yatt) < 2) { throw new Exception("Too few distinct values for Y attribute " +"(BoundaryPanel)"); }*/ //removed by jimmy. TESTING! m_yAttribute = yatt; } /** * Set a vector of Color objects for the classes * * @param colors a FastVector value */ public void setColors(FastVector colors) { synchronized (m_Colors) { m_Colors = colors; } //replot(); //commented by jimmy update(); //added by jimmy } /** * Set whether to superimpose the training data * plot * * @param pg a boolean value */ public void setPlotTrainingData(boolean pg) { m_plotTrainingData = pg; } /** * Returns true if training data is to be superimposed * * @return a boolean value */ public boolean getPlotTrainingData() { return m_plotTrainingData; } /** * Get the current vector of Color objects used for the classes * * @return a FastVector value */ public FastVector getColors() { return m_Colors; } /** * Quickly replot the display using cached probability estimates */ public void replot() { if (m_probabilityCache[0][0] == null) { return; } m_stopReplotting = true; m_pausePlotting = true; // wait 300 ms to give any other replot threads a chance to halt try { Thread.sleep(300); } catch (Exception ex) {} final Thread replotThread = new Thread() { public void run() { m_stopReplotting = false; int size2 = m_size / 2; finishedReplot: for (int i = 0; i < m_panelHeight; i += m_size) { for (int j = 0; j < m_panelWidth; j += m_size) { if (m_probabilityCache[i][j] == null || m_stopReplotting) { break finishedReplot; } boolean update = (j == 0 && i % 2 == 0); if (i < m_panelHeight && j < m_panelWidth) { // Draw the three new subpixel regions or single course tiling if (m_initialTiling || m_size == 1) { if (m_probabilityCache[i][j] == null) { break finishedReplot; } plotPoint(j, i, m_size, m_size, m_probabilityCache[i][j], update); } else { if (m_probabilityCache[i+size2][j] == null) { break finishedReplot; } plotPoint(j, i + size2, size2, size2, m_probabilityCache[i + size2][j], update); if (m_probabilityCache[i+size2][j+size2] == null) { break finishedReplot; } plotPoint(j + size2, i + size2, size2, size2, m_probabilityCache[i + size2][j + size2], update); if (m_probabilityCache[i][j+size2] == null) { break finishedReplot; } plotPoint(j + size2, i, size2, size2, m_probabilityCache[i + size2][j], update); } } } } update(); if (m_plotTrainingData) { plotTrainingData(); } m_pausePlotting = false; if (!m_stopPlotting) { synchronized (m_dummy) { m_dummy.notifyAll(); } } } }; replotThread.start(); } protected void saveImage(String fileName) { BufferedImage bi; Graphics2D gr2; ImageWriter writer; Iterator iter; ImageOutputStream ios; ImageWriteParam param; try { // render image bi = new BufferedImage(m_panelWidth, m_panelHeight, BufferedImage.TYPE_INT_RGB); gr2 = bi.createGraphics(); gr2.drawImage(m_osi, 0, 0, m_panelWidth, m_panelHeight, null); // get jpeg writer writer = null; iter = ImageIO.getImageWritersByFormatName("jpg"); if (iter.hasNext()) writer = (ImageWriter) iter.next(); else throw new Exception("No JPEG writer available!"); // prepare output file ios = ImageIO.createImageOutputStream(new File(fileName)); writer.setOutput(ios); // set the quality param = new JPEGImageWriteParam(Locale.getDefault()); param.setCompressionMode(ImageWriteParam.MODE_EXPLICIT) ; param.setCompressionQuality(1.0f); // write the image writer.write(null, new IIOImage(bi, null, null), param); // cleanup ios.flush(); writer.dispose(); ios.close(); } catch (Exception e) { e.printStackTrace(); } } /** Adds a training instance to our dataset, based on the coordinates of the mouse on the panel. This method sets the x and y attributes and the class (as defined by classAttIndex), and sets all other values as Missing. * @param mouseX the x coordinate of the mouse, in pixels. * @param mouseY the y coordinate of the mouse, in pixels. * @param classAttIndex the index of the attribute that is currently selected as the class attribute. * @param classValue the value to set the class to in our new point. */ public void addTrainingInstanceFromMouseLocation(int mouseX, int mouseY, int classAttIndex, double classValue) { //convert to coordinates in the training instance space. double x = convertFromPanelX(mouseX); double y = convertFromPanelY(mouseY); //build the training instance Instance newInstance = new DenseInstance(m_trainingData.numAttributes()); for (int i = 0; i < newInstance.numAttributes(); i++) { if (i == classAttIndex) { newInstance.setValue(i,classValue); } else if (i == m_xAttribute) newInstance.setValue(i,x); else if (i == m_yAttribute) newInstance.setValue(i,y); else newInstance.setMissing(i); } //add it to our data set. addTrainingInstance(newInstance); } /** Deletes all training instances from our dataset. */ public void removeAllInstances() { if (m_trainingData != null) { m_trainingData.delete(); try { initialize();} catch (Exception e) {}; } } /** Removes a single training instance from our dataset, if there is one that is close enough to the specified mouse location. */ public void removeTrainingInstanceFromMouseLocation(int mouseX, int mouseY) { //convert to coordinates in the training instance space. double x = convertFromPanelX(mouseX); double y = convertFromPanelY(mouseY); int bestIndex = -1; double bestDistanceBetween = Integer.MAX_VALUE; //find the closest point. for (int i = 0; i < m_trainingData.numInstances(); i++) { Instance current = m_trainingData.instance(i); double distanceBetween = (current.value(m_xAttribute) - x) * (current.value(m_xAttribute) - x) + (current.value(m_yAttribute) - y) * (current.value(m_yAttribute) - y); // won't bother to sqrt, just used square values. if (distanceBetween < bestDistanceBetween) { bestIndex = i; bestDistanceBetween = distanceBetween; } } if (bestIndex == -1) return; Instance best = m_trainingData.instance(bestIndex); double panelDistance = (convertToPanelX(best.value(m_xAttribute)) - mouseX) * (convertToPanelX(best.value(m_xAttribute)) - mouseX) + (convertToPanelY(best.value(m_yAttribute)) - mouseY) * (convertToPanelY(best.value(m_yAttribute)) - mouseY); if (panelDistance < REMOVE_POINT_RADIUS * REMOVE_POINT_RADIUS) {//the best point is close enough. (using squared distances) m_trainingData.delete(bestIndex); } } /** Starts the plotting thread. Will also create it if necessary. */ public void startPlotThread() { if (m_plotThread == null) { //jimmy m_plotThread = new PlotThread(); m_plotThread.setPriority(Thread.MIN_PRIORITY); m_plotThread.start(); } } /** Adds a mouse listener. */ public void addMouseListener(MouseListener l) { m_plotPanel.addMouseListener(l); } /** Gets the minimum x-coordinate bound, in training-instance units (not mouse coordinates). */ public double getMinXBound() { return m_minX; } /** Gets the minimum y-coordinate bound, in training-instance units (not mouse coordinates). */ public double getMinYBound() { return m_minY; } /** Gets the maximum x-coordinate bound, in training-instance units (not mouse coordinates). */ public double getMaxXBound() { return m_maxX; } /** Gets the maximum x-coordinate bound, in training-instance units (not mouse coordinates). */ public double getMaxYBound() { return m_maxY; } /** * Main method for testing this class * * @param args a String[] value */ public static void main (String [] args) { try { if (args.length < 8) { System.err.println("Usage : BoundaryPanel " +" " +" <# loc/pixel> " +" " +" "); System.exit(1); } final javax.swing.JFrame jf = new javax.swing.JFrame("Weka classification boundary visualizer"); jf.getContentPane().setLayout(new BorderLayout()); System.err.println("Loading instances from : "+args[0]); java.io.Reader r = new java.io.BufferedReader( new java.io.FileReader(args[0])); final Instances i = new Instances(r); i.setClassIndex(Integer.parseInt(args[1])); // bv.setClassifier(new Logistic()); final int xatt = Integer.parseInt(args[2]); final int yatt = Integer.parseInt(args[3]); int base = Integer.parseInt(args[4]); int loc = Integer.parseInt(args[5]); int bandWidth = Integer.parseInt(args[6]); int panelWidth = Integer.parseInt(args[7]); int panelHeight = Integer.parseInt(args[8]); final String classifierName = args[9]; final BoundaryPanel bv = new BoundaryPanel(panelWidth,panelHeight); bv.addActionListener(new ActionListener() { public void actionPerformed(ActionEvent e) { String classifierNameNew = classifierName.substring(classifierName.lastIndexOf('.')+1, classifierName.length()); bv.saveImage(classifierNameNew+"_"+i.relationName() +"_X"+xatt+"_Y"+yatt+".jpg"); } }); jf.getContentPane().add(bv, BorderLayout.CENTER); jf.setSize(bv.getMinimumSize()); // jf.setSize(200,200); jf.addWindowListener(new java.awt.event.WindowAdapter() { public void windowClosing(java.awt.event.WindowEvent e) { jf.dispose(); System.exit(0); } }); jf.pack(); jf.setVisible(true); // bv.initialize(); bv.repaint(); String [] argsR = null; if (args.length > 10) { argsR = new String [args.length-10]; for (int j = 10; j < args.length; j++) { argsR[j-10] = args[j]; } } Classifier c = AbstractClassifier.forName(args[9], argsR); KDDataGenerator dataGen = new KDDataGenerator(); dataGen.setKernelBandwidth(bandWidth); bv.setDataGenerator(dataGen); bv.setNumSamplesPerRegion(loc); bv.setGeneratorSamplesBase(base); bv.setClassifier(c); bv.setTrainingData(i); bv.setXAttribute(xatt); bv.setYAttribute(yatt); try { // try and load a color map if one exists FileInputStream fis = new FileInputStream("colors.ser"); ObjectInputStream ois = new ObjectInputStream(fis); FastVector colors = (FastVector)ois.readObject(); bv.setColors(colors); } catch (Exception ex) { System.err.println("No color map file"); } bv.start(); } catch (Exception ex) { ex.printStackTrace(); } } }