source: src/main/java/weka/gui/boundaryvisualizer/BoundaryPanel.java @ 15

Last change on this file since 15 was 4, checked in by gnappo, 14 years ago

Import di weka.

File size: 36.3 KB
Line 
1/*
2 *    This program is free software; you can redistribute it and/or modify
3 *    it under the terms of the GNU General Public License as published by
4 *    the Free Software Foundation; either version 2 of the License, or
5 *    (at your option) any later version.
6 *
7 *    This program is distributed in the hope that it will be useful,
8 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
9 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
10 *    GNU General Public License for more details.
11 *
12 *    You should have received a copy of the GNU General Public License
13 *    along with this program; if not, write to the Free Software
14 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
15 */
16
17/*
18 *   BoundaryPanel.java
19 *   Copyright (C) 2002 University of Waikato, Hamilton, New Zealand
20 *
21 */
22
23package weka.gui.boundaryvisualizer;
24
25import weka.classifiers.Classifier;
26import weka.classifiers.AbstractClassifier;
27import weka.core.FastVector;
28import weka.core.Instance;
29import weka.core.DenseInstance;
30import weka.core.Instances;
31import weka.core.Utils;
32import weka.gui.visualize.JPEGWriter;
33
34import java.awt.BorderLayout;
35import java.awt.Color;
36import java.awt.Dimension;
37import java.awt.Graphics;
38import java.awt.Graphics2D;
39import java.awt.Image;
40import java.awt.RenderingHints;
41import java.awt.event.ActionEvent;
42import java.awt.event.ActionListener;
43import java.awt.event.MouseEvent;
44import java.awt.event.MouseListener;
45import java.awt.image.BufferedImage;
46import java.io.File;
47import java.io.FileInputStream;
48import java.io.ObjectInputStream;
49import java.util.Iterator;
50import java.util.Locale;
51import java.util.Random;
52import java.util.Vector;
53
54import javax.imageio.IIOImage;
55import javax.imageio.ImageIO;
56import javax.imageio.ImageWriteParam;
57import javax.imageio.ImageWriter;
58import javax.imageio.plugins.jpeg.JPEGImageWriteParam;
59import javax.imageio.stream.ImageOutputStream;
60import javax.swing.JOptionPane;
61import javax.swing.JPanel;
62import javax.swing.ToolTipManager;
63
64/**
65 * BoundaryPanel. A class to handle the plotting operations
66 * associated with generating a 2D picture of a classifier's decision
67 * boundaries.
68 *
69 * @author <a href="mailto:mhall@cs.waikato.ac.nz">Mark Hall</a>
70 * @version $Revision: 5987 $
71 * @since 1.0
72 * @see JPanel
73 */
74public class BoundaryPanel
75  extends JPanel {
76
77  /** for serialization */
78  private static final long serialVersionUID = -8499445518744770458L;
79 
80  /** default colours for classes */
81  public static final Color [] DEFAULT_COLORS = {
82    Color.red,
83    Color.green,
84    Color.blue,
85    new Color(0, 255, 255), // cyan
86    new Color(255, 0, 255), // pink
87    new Color(255, 255, 0), // yellow
88    new Color(255, 255, 255), //white
89    new Color(0, 0, 0)};
90   
91  /** The distance we can click away from a point in the GUI and still remove it. */
92  public static final double REMOVE_POINT_RADIUS = 7.0;
93
94  protected FastVector m_Colors = new FastVector();
95
96  /** training data */
97  protected Instances m_trainingData;
98
99  /** distribution classifier to use */
100  protected Classifier m_classifier;
101
102  /** data generator to use */
103  protected DataGenerator m_dataGenerator;
104
105  /** index of the class attribute */
106  private int m_classIndex = -1;
107
108  // attributes for visualizing on
109  protected int m_xAttribute;
110  protected int m_yAttribute;
111
112  // min, max and ranges of these attributes
113  protected double m_minX;
114  protected double m_minY;
115  protected double m_maxX;
116  protected double m_maxY;
117  private double m_rangeX;
118  private double m_rangeY;
119
120  // pixel width and height in terms of attribute values
121  protected double m_pixHeight;
122  protected double m_pixWidth;
123
124  /** used for offscreen drawing */
125  protected Image m_osi = null;
126
127  // width and height of the display area
128  protected int m_panelWidth;
129  protected int m_panelHeight;
130
131  // number of samples to take from each region in the fixed dimensions
132  protected int m_numOfSamplesPerRegion = 2;
133
134  // number of samples per kernel = base ^ (# non-fixed dimensions)
135  protected int m_numOfSamplesPerGenerator;
136  protected double m_samplesBase = 2.0;
137
138  /** listeners to be notified when plot is complete */
139  private Vector m_listeners = new Vector();
140
141  /**
142   * small inner class for rendering the bitmap on to
143   */
144  private class PlotPanel
145    extends JPanel {
146
147    /** for serialization */
148    private static final long serialVersionUID = 743629498352235060L;
149   
150    public PlotPanel() {
151      this.setToolTipText("");
152    }
153   
154    public void paintComponent(Graphics g) {
155      super.paintComponent(g);
156      if (m_osi != null) {
157        g.drawImage(m_osi,0,0,this);
158      }
159    }
160   
161    public String getToolTipText(MouseEvent event) {
162      if (m_probabilityCache == null) {
163        return null;
164      }
165     
166      if (m_probabilityCache[event.getY()][event.getX()] == null) {
167        return null;
168      }
169     
170      String pVec = "(X: "
171        +Utils.doubleToString(convertFromPanelX((double)event.getX()), 2)
172        +" Y: "
173        +Utils.doubleToString(convertFromPanelY((double)event.getY()), 2)+") ";
174      // construct a string holding the probability vector
175      for (int i = 0; i < m_trainingData.classAttribute().numValues(); i++) {
176        pVec += 
177          Utils.
178          doubleToString(m_probabilityCache[event.getY()][event.getX()][i],
179                         3)+" ";
180      }
181      return pVec;
182    }
183  }
184
185  /** the actual plotting area */
186  private PlotPanel m_plotPanel = new PlotPanel();
187
188  /** thread for running the plotting operation in */
189  private Thread m_plotThread = null;
190
191  /** Stop the plotting thread */
192  protected boolean m_stopPlotting = false;
193
194  /** Stop any replotting threads */
195  protected boolean m_stopReplotting = false;
196
197  // Used by replotting threads to pause and resume the main plot thread
198  private Double m_dummy = new Double(1.0);
199  private boolean m_pausePlotting = false;
200  /** what size of tile is currently being plotted */
201  private int m_size = 1;
202  /** is the main plot thread performing the initial coarse tiling */
203  private boolean m_initialTiling;
204
205  /** A random number generator  */
206  private Random m_random = null;
207
208  /** cache of probabilities for fast replotting */
209  protected double [][][] m_probabilityCache;
210
211  /** plot the training data */
212  protected boolean m_plotTrainingData = true;
213
214  /**
215   * Creates a new <code>BoundaryPanel</code> instance.
216   *
217   * @param panelWidth the width in pixels of the panel
218   * @param panelHeight the height in pixels of the panel
219   */
220  public BoundaryPanel(int panelWidth, int panelHeight) {
221    ToolTipManager.sharedInstance().setDismissDelay(Integer.MAX_VALUE);
222    m_panelWidth = panelWidth;
223    m_panelHeight = panelHeight;
224    setLayout(new BorderLayout());
225    m_plotPanel.setMinimumSize(new Dimension(m_panelWidth, m_panelHeight));
226    m_plotPanel.setPreferredSize(new Dimension(m_panelWidth, m_panelHeight));
227    m_plotPanel.setMaximumSize(new Dimension(m_panelWidth, m_panelHeight));
228    add(m_plotPanel, BorderLayout.CENTER);
229    setPreferredSize(m_plotPanel.getPreferredSize());
230    setMaximumSize(m_plotPanel.getMaximumSize());
231    setMinimumSize(m_plotPanel.getMinimumSize());
232
233    m_random = new Random(1);
234    for (int i = 0; i < DEFAULT_COLORS.length; i++) {
235      m_Colors.addElement(new Color(DEFAULT_COLORS[i].getRed(),
236                                    DEFAULT_COLORS[i].getGreen(),
237                                    DEFAULT_COLORS[i].getBlue()));
238    }
239    m_probabilityCache = new double[m_panelHeight][m_panelWidth][];
240   
241  }
242
243  /**
244   * Set the number of points to uniformly sample from a region (fixed
245   * dimensions).
246   *
247   * @param num an <code>int</code> value
248   */
249  public void setNumSamplesPerRegion(int num) {
250    m_numOfSamplesPerRegion = num;
251  }
252
253  /**
254   * Get the number of points to sample from a region (fixed dimensions).
255   *
256   * @return an <code>int</code> value
257   */
258  public int getNumSamplesPerRegion() {
259    return m_numOfSamplesPerRegion;
260  }
261
262  /**
263   * Set the base for computing the number of samples to obtain from each
264   * generator. number of samples = base ^ (# non fixed dimensions)
265   *
266   * @param ksb a <code>double</code> value
267   */
268  public void setGeneratorSamplesBase(double ksb) {
269    m_samplesBase = ksb;
270  }
271
272  /**
273   * Get the base used for computing the number of samples to obtain from
274   * each generator
275   *
276   * @return a <code>double</code> value
277   */
278  public double getGeneratorSamplesBase() {
279    return m_samplesBase;
280  }
281
282  /**
283   * Set up the off screen bitmap for rendering to
284   */
285  protected void initialize() {
286    int iwidth = m_plotPanel.getWidth();
287    int iheight = m_plotPanel.getHeight();
288    //    System.err.println(iwidth+" "+iheight);
289    m_osi = m_plotPanel.createImage(iwidth, iheight);
290    Graphics m = m_osi.getGraphics();
291    m.fillRect(0,0,iwidth,iheight);
292  }
293
294  /**
295   * Stop the plotting thread
296   */
297  public void stopPlotting() {
298    m_stopPlotting = true;
299    try {
300        m_plotThread.join(100);
301    } catch (Exception e){};
302  }
303
304  /** Set up the bounds of our graphic based by finding the smallest reasonable
305      area in the instance space to surround our data points.
306  */
307  public void computeMinMaxAtts() {
308    m_minX = Double.MAX_VALUE;
309    m_minY = Double.MAX_VALUE;
310    m_maxX = Double.MIN_VALUE;
311    m_maxY = Double.MIN_VALUE;
312   
313    boolean allPointsLessThanOne = true;
314   
315    if (m_trainingData.numInstances() == 0) {
316      m_minX = m_minY = 0.0;
317      m_maxX = m_maxY = 1.0;
318    }
319    else
320    {
321        for (int i = 0; i < m_trainingData.numInstances(); i++) {
322                Instance inst = m_trainingData.instance(i);
323                double x = inst.value(m_xAttribute);
324                double y = inst.value(m_yAttribute);
325                if (!Utils.isMissingValue(x) && !Utils.isMissingValue(y)) {
326                        if (x < m_minX) {
327                        m_minX = x;
328                        }
329                        if (x > m_maxX) {
330                        m_maxX = x;
331                        }
332               
333                        if (y < m_minY) {
334                        m_minY = y;
335                        }
336                        if (y > m_maxY) {
337                        m_maxY = y;
338                        }
339                        if (x > 1.0 || y > 1.0)
340                                allPointsLessThanOne = false;
341                }
342        }
343    }
344   
345    if (m_minX == m_maxX)
346        m_minX = 0;
347    if (m_minY == m_maxY)
348        m_minY = 0;
349    if (m_minX == Double.MAX_VALUE)
350        m_minX = 0;
351    if (m_minY == Double.MAX_VALUE)
352        m_minY = 0;
353    if (m_maxX == Double.MIN_VALUE)
354        m_maxX = 1;
355    if (m_maxY == Double.MIN_VALUE)
356        m_maxY = 1;
357    if (allPointsLessThanOne) {
358        m_minX = m_minY = 0.0;
359        m_maxX = m_maxY = 1.0;
360    }
361   
362   
363   
364    m_rangeX = (m_maxX - m_minX);
365    m_rangeY = (m_maxY - m_minY);
366   
367    m_pixWidth = m_rangeX / (double)m_panelWidth;
368    m_pixHeight = m_rangeY / (double) m_panelHeight;
369  }
370
371  /**
372   * Return a random x attribute value contained within
373   * the pix'th horizontal pixel
374   *
375   * @param pix the horizontal pixel number
376   * @return a value in attribute space
377   */
378  private double getRandomX(int pix) {
379
380    double minPix =  m_minX + (pix * m_pixWidth);
381
382    return minPix + m_random.nextDouble() * m_pixWidth;
383  }
384
385  /**
386   * Return a random y attribute value contained within
387   * the pix'th vertical pixel
388   *
389   * @param pix the vertical pixel number
390   * @return a value in attribute space
391   */
392  private double getRandomY(int pix) {
393   
394    double minPix = m_minY + (pix * m_pixHeight);
395   
396    return minPix +  m_random.nextDouble() * m_pixHeight;
397  }
398 
399  /**
400   * Start the plotting thread
401   *
402   * @exception Exception if an error occurs
403   */
404  public void start() throws Exception {
405    m_numOfSamplesPerGenerator = 
406      (int)Math.pow(m_samplesBase, m_trainingData.numAttributes()-3);
407
408    m_stopReplotting = true;
409    if (m_trainingData == null) {
410      throw new Exception("No training data set (BoundaryPanel)");
411    }
412    if (m_classifier == null) {
413      throw new Exception("No classifier set (BoundaryPanel)");
414    }
415    if (m_dataGenerator == null) {
416      throw new Exception("No data generator set (BoundaryPanel)");
417    }
418    if (m_trainingData.attribute(m_xAttribute).isNominal() || 
419        m_trainingData.attribute(m_yAttribute).isNominal()) {
420      throw new Exception("Visualization dimensions must be numeric "
421                          +"(BoundaryPanel)");
422    }
423   
424    computeMinMaxAtts();
425   
426    startPlotThread();
427    /*if (m_plotThread == null) {
428      m_plotThread = new PlotThread();
429      m_plotThread.setPriority(Thread.MIN_PRIORITY);
430      m_plotThread.start();
431    }*/
432  }
433 
434  // Thread for main plotting operation
435  protected class PlotThread extends Thread {
436    double [] m_weightingAttsValues;
437    boolean [] m_attsToWeightOn;
438    double [] m_vals;
439    double [] m_dist;
440    Instance m_predInst;
441    public void run() {
442
443      m_stopPlotting = false;
444      try {
445        initialize();
446        repaint();
447       
448        // train the classifier
449        m_probabilityCache = new double[m_panelHeight][m_panelWidth][];
450        m_classifier.buildClassifier(m_trainingData);
451       
452        // build DataGenerator
453        m_attsToWeightOn = new boolean[m_trainingData.numAttributes()];
454        m_attsToWeightOn[m_xAttribute] = true;
455        m_attsToWeightOn[m_yAttribute] = true;
456             
457        m_dataGenerator.setWeightingDimensions(m_attsToWeightOn);
458             
459        m_dataGenerator.buildGenerator(m_trainingData);
460
461        // generate samples
462        m_weightingAttsValues = new double [m_attsToWeightOn.length];
463        m_vals = new double[m_trainingData.numAttributes()];
464        m_predInst = new DenseInstance(1.0, m_vals);
465        m_predInst.setDataset(m_trainingData);
466
467       
468        m_size = 1 << 4;  // Current sample region size
469       
470        m_initialTiling = true;
471        // Display the initial coarse image tiling.
472      abortInitial:
473        for (int i = 0; i <= m_panelHeight; i += m_size) {   
474          for (int j = 0; j <= m_panelWidth; j += m_size) {   
475            if (m_stopPlotting) {
476              break abortInitial;
477            }
478            if (m_pausePlotting) {
479              synchronized (m_dummy) {
480                try {
481                  m_dummy.wait();
482                } catch (InterruptedException ex) {
483                  m_pausePlotting = false;
484                }
485              }
486            }
487            plotPoint(j, i, m_size, m_size, 
488                      calculateRegionProbs(j, i), (j == 0));
489          }
490        }
491        if (!m_stopPlotting) {
492          m_initialTiling = false;
493        }
494       
495        // Sampling and gridding loop
496        int size2 = m_size / 2;
497        abortPlot: 
498        while (m_size > 1) { // Subdivide down to the pixel level
499          for (int i = 0; i <= m_panelHeight; i += m_size) {
500            for (int j = 0; j <= m_panelWidth; j += m_size) {
501              if (m_stopPlotting) {
502                break abortPlot;
503              }
504              if (m_pausePlotting) {
505                synchronized (m_dummy) {
506                  try {
507                    m_dummy.wait();
508                  } catch (InterruptedException ex) {
509                    m_pausePlotting = false;
510                  }
511                }
512              }
513              boolean update = (j == 0 && i % 2 == 0);
514              // Draw the three new subpixel regions
515              plotPoint(j, i + size2, size2, size2, 
516                        calculateRegionProbs(j, i + size2), update);
517              plotPoint(j + size2, i + size2, size2, size2, 
518                        calculateRegionProbs(j + size2, i + size2), update);
519              plotPoint(j + size2, i, size2, size2, 
520                        calculateRegionProbs(j + size2, i), update);
521            }
522          }
523          // The new region edge length is half the old edge length
524          m_size = size2;
525          size2 = size2 / 2;
526        }
527        update();
528       
529
530        /*
531        // Old method without sampling.
532        abortPlot:
533        for (int i = 0; i < m_panelHeight; i++) {
534          for (int j = 0; j < m_panelWidth; j++) {
535            if (m_stopPlotting) {
536              break abortPlot;
537            }
538            plotPoint(j, i, calculateRegionProbs(j, i), (j == 0));
539          }
540        }
541        */
542
543
544        if (m_plotTrainingData) {
545          plotTrainingData();
546        }
547             
548      } catch (Exception ex) {
549        ex.printStackTrace();
550        JOptionPane.showMessageDialog(null,"Error while plotting: \"" + ex.getMessage() + "\"");
551      } finally {
552        m_plotThread = null;
553        // notify any listeners that we are finished
554        Vector l;
555        ActionEvent e = new ActionEvent(this, 0, "");
556        synchronized(this) {
557          l = (Vector)m_listeners.clone();
558        }
559        for (int i = 0; i < l.size(); i++) {
560          ActionListener al = (ActionListener)l.elementAt(i);
561          al.actionPerformed(e);
562        }
563      }
564    }
565   
566    private double [] calculateRegionProbs(int j, int i) throws Exception {
567      double [] sumOfProbsForRegion = 
568        new double [m_trainingData.classAttribute().numValues()];
569
570      for (int u = 0; u < m_numOfSamplesPerRegion; u++) {
571     
572        double [] sumOfProbsForLocation = 
573          new double [m_trainingData.classAttribute().numValues()];
574     
575        m_weightingAttsValues[m_xAttribute] = getRandomX(j);
576        m_weightingAttsValues[m_yAttribute] = getRandomY(m_panelHeight-i-1);
577     
578        m_dataGenerator.setWeightingValues(m_weightingAttsValues);
579     
580        double [] weights = m_dataGenerator.getWeights();
581        double sumOfWeights = Utils.sum(weights);
582        int [] indices = Utils.sort(weights);
583     
584        // Prune 1% of weight mass
585        int [] newIndices = new int[indices.length];
586        double sumSoFar = 0; 
587        double criticalMass = 0.99 * sumOfWeights;
588        int index = weights.length - 1; int counter = 0;
589        for (int z = weights.length - 1; z >= 0; z--) {
590          newIndices[index--] = indices[z];
591          sumSoFar += weights[indices[z]];
592          counter++;
593          if (sumSoFar > criticalMass) {
594            break;
595          }
596        }
597        indices = new int[counter];
598        System.arraycopy(newIndices, index + 1, indices, 0, counter);
599     
600        for (int z = 0; z < m_numOfSamplesPerGenerator; z++) {
601       
602          m_dataGenerator.setWeightingValues(m_weightingAttsValues);
603          double [][] values = m_dataGenerator.generateInstances(indices);
604       
605          for (int q = 0; q < values.length; q++) {
606            if (values[q] != null) {
607              System.arraycopy(values[q], 0, m_vals, 0, m_vals.length);
608              m_vals[m_xAttribute] = m_weightingAttsValues[m_xAttribute];
609              m_vals[m_yAttribute] = m_weightingAttsValues[m_yAttribute];
610           
611              // classify the instance
612              m_dist = m_classifier.distributionForInstance(m_predInst);
613              for (int k = 0; k < sumOfProbsForLocation.length; k++) {
614                sumOfProbsForLocation[k] += (m_dist[k] * weights[q]); 
615              }
616            }
617          }
618        }
619     
620        for (int k = 0; k < sumOfProbsForRegion.length; k++) {
621          sumOfProbsForRegion[k] += (sumOfProbsForLocation[k] * 
622                                     sumOfWeights); 
623        }
624      }
625   
626      // average
627      Utils.normalize(sumOfProbsForRegion);
628
629      // cache
630      if ((i < m_panelHeight) && (j < m_panelWidth)) {
631        m_probabilityCache[i][j] = new double[sumOfProbsForRegion.length];
632        System.arraycopy(sumOfProbsForRegion, 0, m_probabilityCache[i][j], 
633                         0, sumOfProbsForRegion.length);
634      }
635               
636      return sumOfProbsForRegion;
637    }
638  }
639
640  /** Render the training points on-screen.
641  */
642  public void plotTrainingData() {
643   
644    Graphics2D osg = (Graphics2D)m_osi.getGraphics();
645    Graphics g = m_plotPanel.getGraphics();
646    osg.setRenderingHint(RenderingHints.KEY_ANTIALIASING,
647                         RenderingHints.VALUE_ANTIALIAS_ON);
648    double xval = 0; double yval = 0;
649   
650    for (int i = 0; i < m_trainingData.numInstances(); i++) {
651      if (!m_trainingData.instance(i).isMissing(m_xAttribute) &&
652          !m_trainingData.instance(i).isMissing(m_yAttribute)) {
653         
654        if (m_trainingData.instance(i).isMissing(m_classIndex)) //jimmy.
655                continue; //don't plot if class is missing. TODO could we plot it differently instead?
656       
657        xval = m_trainingData.instance(i).value(m_xAttribute);
658        yval = m_trainingData.instance(i).value(m_yAttribute);
659       
660        int panelX = convertToPanelX(xval);
661        int panelY = convertToPanelY(yval);
662        Color ColorToPlotWith = 
663          ((Color)m_Colors.elementAt((int)m_trainingData.instance(i).
664                                     value(m_classIndex) % m_Colors.size()));
665       
666        if (ColorToPlotWith.equals(Color.white)) {
667          osg.setColor(Color.black);
668        } else {
669          osg.setColor(Color.white);
670        }
671        osg.fillOval(panelX-3, panelY-3, 7, 7);
672        osg.setColor(ColorToPlotWith);
673        osg.fillOval(panelX-2, panelY-2, 5, 5);
674      }
675    }
676    g.drawImage(m_osi,0,0,m_plotPanel);
677  }
678 
679  /** Convert an X coordinate from the instance space to the panel space.
680  */
681  private int convertToPanelX(double xval) {
682    double temp = (xval - m_minX) / m_rangeX;
683    temp = temp * (double) m_panelWidth;
684
685    return (int)temp;
686  }
687
688  /** Convert a Y coordinate from the instance space to the panel space.
689  */
690  private int convertToPanelY(double yval) {
691    double temp = (yval - m_minY) / m_rangeY;
692    temp = temp * (double) m_panelHeight;
693    temp = m_panelHeight - temp;
694   
695    return (int)temp;
696  }
697 
698  /** Convert an X coordinate from the panel space to the instance space.
699  */
700  private double convertFromPanelX(double pX) {
701    pX /= (double) m_panelWidth;
702    pX *= m_rangeX;
703    return pX + m_minX;
704  }
705
706  /** Convert a Y coordinate from the panel space to the instance space.
707  */
708  private double convertFromPanelY(double pY) {
709    pY  = m_panelHeight - pY;
710    pY /= (double) m_panelHeight;
711    pY *= m_rangeY;
712   
713    return pY + m_minY;
714  }
715
716
717  /** Plot a point in our visualization on-screen.
718  */
719  protected  void plotPoint(int x, int y, double [] probs, boolean update) {
720    plotPoint(x, y, 1, 1, probs, update);
721  }
722 
723  /** Plot a point in our visualization on-screen.
724  */
725  private void plotPoint(int x, int y, int width, int height, 
726                         double [] probs, boolean update) {
727
728    // draw a progress line
729    Graphics osg = m_osi.getGraphics();
730    if (update) {
731      osg.setXORMode(Color.white);
732      osg.drawLine(0, y, m_panelWidth-1, y);
733      update();
734      osg.drawLine(0, y, m_panelWidth-1, y);
735    }
736
737    // plot the point
738    osg.setPaintMode();
739    float [] colVal = new float[3];
740   
741    float [] tempCols = new float[3];
742    for (int k = 0; k < probs.length; k++) {
743      Color curr = (Color)m_Colors.elementAt(k % m_Colors.size());
744
745      curr.getRGBColorComponents(tempCols);
746      for (int z = 0 ; z < 3; z++) {
747        colVal[z] += probs[k] * tempCols[z];
748      }
749    }
750
751    for (int z = 0; z < 3; z++) {
752      if (colVal[z] < 0) {
753        colVal[z] = 0;
754      } else if (colVal[z] > 1) {
755        colVal[z] = 1;
756      }
757    }
758   
759    osg.setColor(new Color(colVal[0], 
760                           colVal[1], 
761                           colVal[2]));
762    osg.fillRect(x, y, width, height);
763  }
764 
765  /** Update the rendered image.
766  */
767  private void update() {
768    Graphics g = m_plotPanel.getGraphics();
769    g.drawImage(m_osi, 0, 0, m_plotPanel);
770  }
771
772  /**
773   * Set the training data to use
774   *
775   * @param trainingData the training data
776   * @exception Exception if an error occurs
777   */
778  public void setTrainingData(Instances trainingData) throws Exception {
779
780    m_trainingData = trainingData;
781    if (m_trainingData.classIndex() < 0) {
782      throw new Exception("No class attribute set (BoundaryPanel)");
783    }
784    m_classIndex = m_trainingData.classIndex();
785  }
786 
787  /** Adds a training instance to the visualization dataset.
788  */
789  public void addTrainingInstance(Instance instance) {
790       
791        if (m_trainingData == null) {
792                //TODO
793                System.err.println("Trying to add to a null training set (BoundaryPanel)");
794        }
795       
796        m_trainingData.add(instance);
797  }
798
799  /**
800   * Register a listener to be notified when plotting completes
801   *
802   * @param newListener the listener to add
803   */
804  public void addActionListener(ActionListener newListener) {
805    m_listeners.add(newListener);
806  }
807 
808  /**
809   * Remove a listener
810   *
811   * @param removeListener the listener to remove
812   */
813  public void removeActionListener(ActionListener removeListener) {
814    m_listeners.removeElement(removeListener);
815  }
816 
817  /**
818   * Set the classifier to use.
819   *
820   * @param classifier the classifier to use
821   */
822  public void setClassifier(Classifier classifier) {
823    m_classifier = classifier;
824  }
825 
826  /**
827   * Set the data generator to use for generating new instances
828   *
829   * @param dataGenerator the data generator to use
830   */
831  public void setDataGenerator(DataGenerator dataGenerator) {
832    m_dataGenerator = dataGenerator;
833  }
834 
835  /**
836   * Set the x attribute index
837   *
838   * @param xatt index of the attribute to use on the x axis
839   * @exception Exception if an error occurs
840   */
841  public void setXAttribute(int xatt) throws Exception {
842    if (m_trainingData == null) {
843      throw new Exception("No training data set (BoundaryPanel)");
844    }
845    if (xatt < 0 || 
846        xatt > m_trainingData.numAttributes()) {
847      throw new Exception("X attribute out of range (BoundaryPanel)");
848    }
849    if (m_trainingData.attribute(xatt).isNominal()) {
850      throw new Exception("Visualization dimensions must be numeric "
851                          +"(BoundaryPanel)");
852    }
853    /*if (m_trainingData.numDistinctValues(xatt) < 2) {
854      throw new Exception("Too few distinct values for X attribute "
855                          +"(BoundaryPanel)");
856    }*/ //removed by jimmy. TESTING!
857    m_xAttribute = xatt;
858  }
859
860  /**
861   * Set the y attribute index
862   *
863   * @param yatt index of the attribute to use on the y axis
864   * @exception Exception if an error occurs
865   */
866  public void setYAttribute(int yatt) throws Exception {
867    if (m_trainingData == null) {
868      throw new Exception("No training data set (BoundaryPanel)");
869    }
870    if (yatt < 0 || 
871        yatt > m_trainingData.numAttributes()) {
872      throw new Exception("X attribute out of range (BoundaryPanel)");
873    }
874    if (m_trainingData.attribute(yatt).isNominal()) {
875      throw new Exception("Visualization dimensions must be numeric "
876                          +"(BoundaryPanel)");
877    }
878    /*if (m_trainingData.numDistinctValues(yatt) < 2) {
879      throw new Exception("Too few distinct values for Y attribute "
880                          +"(BoundaryPanel)");
881    }*/ //removed by jimmy. TESTING!
882    m_yAttribute = yatt;
883  }
884 
885  /**
886   * Set a vector of Color objects for the classes
887   *
888   * @param colors a <code>FastVector</code> value
889   */
890  public void setColors(FastVector colors) {
891    synchronized (m_Colors) {
892      m_Colors = colors;
893    }
894    //replot(); //commented by jimmy
895    update(); //added by jimmy
896  }
897
898  /**
899   * Set whether to superimpose the training data
900   * plot
901   *
902   * @param pg a <code>boolean</code> value
903   */
904  public void setPlotTrainingData(boolean pg) {
905    m_plotTrainingData = pg;
906  }
907
908  /**
909   * Returns true if training data is to be superimposed
910   *
911   * @return a <code>boolean</code> value
912   */
913  public boolean getPlotTrainingData() {
914    return m_plotTrainingData;
915  }
916 
917  /**
918   * Get the current vector of Color objects used for the classes
919   *
920   * @return a <code>FastVector</code> value
921   */
922  public FastVector getColors() {
923    return m_Colors;
924  }
925 
926  /**
927   * Quickly replot the display using cached probability estimates
928   */
929  public void replot() {
930    if (m_probabilityCache[0][0] == null) {
931      return;
932    }
933    m_stopReplotting = true;
934    m_pausePlotting = true;
935    // wait 300 ms to give any other replot threads a chance to halt
936    try {
937      Thread.sleep(300);
938    } catch (Exception ex) {}
939
940    final Thread replotThread = new Thread() {
941        public void run() {
942          m_stopReplotting = false;
943          int size2 = m_size / 2;
944          finishedReplot: for (int i = 0; i < m_panelHeight; i += m_size) {
945            for (int j = 0; j < m_panelWidth; j += m_size) {
946              if (m_probabilityCache[i][j] == null || m_stopReplotting) {
947                break finishedReplot;
948              }
949
950              boolean update = (j == 0 && i % 2 == 0);
951              if (i < m_panelHeight && j < m_panelWidth) {
952                // Draw the three new subpixel regions or single course tiling
953                if (m_initialTiling || m_size == 1) {
954                  if (m_probabilityCache[i][j] == null) {
955                    break finishedReplot;
956                  }
957                  plotPoint(j, i, m_size, m_size, 
958                            m_probabilityCache[i][j], update);
959                } else {
960                  if (m_probabilityCache[i+size2][j] == null) {
961                    break finishedReplot;
962                  }
963                  plotPoint(j, i + size2, size2, size2, 
964                            m_probabilityCache[i + size2][j], update);
965                  if (m_probabilityCache[i+size2][j+size2] == null) {
966                    break finishedReplot;
967                  }
968                  plotPoint(j + size2, i + size2, size2, size2, 
969                            m_probabilityCache[i + size2][j + size2], update);
970                  if (m_probabilityCache[i][j+size2] == null) {
971                    break finishedReplot;
972                  }
973                  plotPoint(j + size2, i, size2, size2, 
974                            m_probabilityCache[i + size2][j], update);
975                }
976              }
977            }
978          }
979          update();
980          if (m_plotTrainingData) {
981            plotTrainingData();
982          }
983          m_pausePlotting = false;
984          if (!m_stopPlotting) {
985            synchronized (m_dummy) {
986              m_dummy.notifyAll();
987            }
988          }
989        }
990      };
991   
992    replotThread.start();     
993  }
994
995  protected void saveImage(String fileName) {
996    BufferedImage       bi;
997    Graphics2D          gr2;
998    ImageWriter         writer;
999    Iterator            iter;
1000    ImageOutputStream   ios;
1001    ImageWriteParam     param;
1002
1003    try {
1004      // render image
1005      bi  = new BufferedImage(m_panelWidth, m_panelHeight, BufferedImage.TYPE_INT_RGB);
1006      gr2 = bi.createGraphics();
1007      gr2.drawImage(m_osi, 0, 0, m_panelWidth, m_panelHeight, null);
1008
1009      // get jpeg writer
1010      writer = null;
1011      iter   = ImageIO.getImageWritersByFormatName("jpg");
1012      if (iter.hasNext())
1013        writer = (ImageWriter) iter.next();
1014      else
1015        throw new Exception("No JPEG writer available!");
1016
1017      // prepare output file
1018      ios = ImageIO.createImageOutputStream(new File(fileName));
1019      writer.setOutput(ios);
1020
1021      // set the quality
1022      param = new JPEGImageWriteParam(Locale.getDefault());
1023      param.setCompressionMode(ImageWriteParam.MODE_EXPLICIT) ;
1024      param.setCompressionQuality(1.0f);
1025
1026      // write the image
1027      writer.write(null, new IIOImage(bi, null, null), param);
1028
1029      // cleanup
1030      ios.flush();
1031      writer.dispose();
1032      ios.close();   
1033    }
1034    catch (Exception e) {
1035      e.printStackTrace();
1036    }
1037  }
1038 
1039  /** Adds a training instance to our dataset, based on the coordinates of the mouse on the panel.
1040      This method sets the x and y attributes and the class (as defined by classAttIndex), and sets
1041      all other values as Missing.
1042   *  @param mouseX the x coordinate of the mouse, in pixels.
1043   *  @param mouseY the y coordinate of the mouse, in pixels.
1044   *  @param classAttIndex the index of the attribute that is currently selected as the class attribute.
1045   *  @param classValue the value to set the class to in our new point.
1046   */
1047  public void addTrainingInstanceFromMouseLocation(int mouseX, int mouseY, int classAttIndex, double classValue) {
1048        //convert to coordinates in the training instance space.
1049        double x = convertFromPanelX(mouseX);
1050        double y = convertFromPanelY(mouseY);
1051       
1052        //build the training instance
1053        Instance newInstance = new DenseInstance(m_trainingData.numAttributes());
1054        for (int i = 0; i < newInstance.numAttributes(); i++) {
1055                if (i == classAttIndex) {
1056                        newInstance.setValue(i,classValue);
1057                }
1058                else if (i == m_xAttribute)
1059                        newInstance.setValue(i,x);
1060                else if (i == m_yAttribute)
1061                        newInstance.setValue(i,y);
1062                else newInstance.setMissing(i);
1063        }
1064       
1065        //add it to our data set.
1066        addTrainingInstance(newInstance);
1067  }
1068 
1069  /** Deletes all training instances from our dataset.
1070  */
1071  public void removeAllInstances() {
1072        if (m_trainingData != null)
1073        {
1074                m_trainingData.delete();
1075                try { initialize();} catch (Exception e) {};
1076        }
1077       
1078  }
1079 
1080  /** Removes a single training instance from our dataset, if there is one that is close enough
1081      to the specified mouse location.
1082  */
1083  public void removeTrainingInstanceFromMouseLocation(int mouseX, int mouseY) {
1084       
1085        //convert to coordinates in the training instance space.
1086        double x = convertFromPanelX(mouseX);
1087        double y = convertFromPanelY(mouseY);
1088       
1089        int bestIndex = -1;
1090        double bestDistanceBetween = Integer.MAX_VALUE;
1091       
1092        //find the closest point.
1093        for (int i = 0; i < m_trainingData.numInstances(); i++) {
1094                Instance current = m_trainingData.instance(i);
1095                double distanceBetween = (current.value(m_xAttribute) - x) * (current.value(m_xAttribute) - x) + (current.value(m_yAttribute) - y) * (current.value(m_yAttribute) - y); // won't bother to sqrt, just used square values.
1096               
1097                if (distanceBetween < bestDistanceBetween)
1098                {
1099                        bestIndex = i;
1100                        bestDistanceBetween = distanceBetween;
1101                }
1102        }
1103        if (bestIndex == -1)
1104                return;
1105        Instance best = m_trainingData.instance(bestIndex);
1106        double panelDistance = (convertToPanelX(best.value(m_xAttribute)) - mouseX) * (convertToPanelX(best.value(m_xAttribute)) - mouseX)
1107                + (convertToPanelY(best.value(m_yAttribute)) - mouseY) * (convertToPanelY(best.value(m_yAttribute)) - mouseY);
1108        if (panelDistance < REMOVE_POINT_RADIUS * REMOVE_POINT_RADIUS) {//the best point is close enough. (using squared distances)
1109                m_trainingData.delete(bestIndex);
1110        }
1111  }
1112 
1113  /** Starts the plotting thread.  Will also create it if necessary.
1114  */
1115  public void startPlotThread() {
1116        if (m_plotThread == null) { //jimmy
1117                m_plotThread = new PlotThread();
1118                m_plotThread.setPriority(Thread.MIN_PRIORITY);
1119                m_plotThread.start();
1120        }
1121  }
1122 
1123  /** Adds a mouse listener.
1124  */
1125  public void addMouseListener(MouseListener l) {
1126        m_plotPanel.addMouseListener(l);
1127  }
1128 
1129  /** Gets the minimum x-coordinate bound, in training-instance units (not mouse coordinates).
1130  */
1131  public double getMinXBound() {
1132        return m_minX;
1133  }
1134 
1135  /** Gets the minimum y-coordinate bound, in training-instance units (not mouse coordinates).
1136  */
1137  public double getMinYBound() {
1138        return m_minY;
1139  }
1140 
1141  /** Gets the maximum x-coordinate bound, in training-instance units (not mouse coordinates).
1142  */
1143  public double getMaxXBound() {
1144        return m_maxX;
1145  }
1146 
1147  /** Gets the maximum x-coordinate bound, in training-instance units (not mouse coordinates).
1148  */
1149  public double getMaxYBound() {
1150        return m_maxY;
1151  }
1152
1153  /**
1154   * Main method for testing this class
1155   *
1156   * @param args a <code>String[]</code> value
1157   */
1158  public static void main (String [] args) {
1159    try {
1160      if (args.length < 8) {
1161        System.err.println("Usage : BoundaryPanel <dataset> "
1162                           +"<class col> <xAtt> <yAtt> "
1163                           +"<base> <# loc/pixel> <kernel bandwidth> "
1164                           +"<display width> "
1165                           +"<display height> <classifier "
1166                           +"[classifier options]>");
1167        System.exit(1);
1168      }
1169      final javax.swing.JFrame jf = 
1170        new javax.swing.JFrame("Weka classification boundary visualizer");
1171      jf.getContentPane().setLayout(new BorderLayout());
1172
1173      System.err.println("Loading instances from : "+args[0]);
1174      java.io.Reader r = new java.io.BufferedReader(
1175                         new java.io.FileReader(args[0]));
1176      final Instances i = new Instances(r);
1177      i.setClassIndex(Integer.parseInt(args[1]));
1178
1179      //      bv.setClassifier(new Logistic());
1180      final int xatt = Integer.parseInt(args[2]);
1181      final int yatt = Integer.parseInt(args[3]);
1182      int base = Integer.parseInt(args[4]);
1183      int loc = Integer.parseInt(args[5]);
1184
1185      int bandWidth = Integer.parseInt(args[6]);
1186      int panelWidth = Integer.parseInt(args[7]);
1187      int panelHeight = Integer.parseInt(args[8]);
1188
1189      final String classifierName = args[9];
1190      final BoundaryPanel bv = new BoundaryPanel(panelWidth,panelHeight);
1191      bv.addActionListener(new ActionListener() {
1192          public void actionPerformed(ActionEvent e) {
1193            String classifierNameNew = 
1194              classifierName.substring(classifierName.lastIndexOf('.')+1, 
1195                                       classifierName.length());
1196            bv.saveImage(classifierNameNew+"_"+i.relationName()
1197                         +"_X"+xatt+"_Y"+yatt+".jpg");
1198          }
1199        });
1200
1201      jf.getContentPane().add(bv, BorderLayout.CENTER);
1202      jf.setSize(bv.getMinimumSize());
1203      //      jf.setSize(200,200);
1204      jf.addWindowListener(new java.awt.event.WindowAdapter() {
1205          public void windowClosing(java.awt.event.WindowEvent e) {
1206            jf.dispose();
1207            System.exit(0);
1208          }
1209        });
1210
1211      jf.pack();
1212      jf.setVisible(true);
1213      //      bv.initialize();
1214      bv.repaint();
1215     
1216
1217      String [] argsR = null;
1218      if (args.length > 10) {
1219        argsR = new String [args.length-10];
1220        for (int j = 10; j < args.length; j++) {
1221          argsR[j-10] = args[j];
1222        }
1223      }
1224      Classifier c = AbstractClassifier.forName(args[9], argsR);
1225      KDDataGenerator dataGen = new KDDataGenerator();
1226      dataGen.setKernelBandwidth(bandWidth);
1227      bv.setDataGenerator(dataGen);
1228      bv.setNumSamplesPerRegion(loc);
1229      bv.setGeneratorSamplesBase(base);
1230      bv.setClassifier(c);
1231      bv.setTrainingData(i);
1232      bv.setXAttribute(xatt);
1233      bv.setYAttribute(yatt);
1234
1235      try {
1236        // try and load a color map if one exists
1237        FileInputStream fis = new FileInputStream("colors.ser");
1238        ObjectInputStream ois = new ObjectInputStream(fis);
1239        FastVector colors = (FastVector)ois.readObject();
1240        bv.setColors(colors);   
1241      } catch (Exception ex) {
1242        System.err.println("No color map file");
1243      }
1244      bv.start();
1245    } catch (Exception ex) {
1246      ex.printStackTrace();
1247    }
1248  }
1249}
1250
Note: See TracBrowser for help on using the repository browser.