1 | /* |
---|
2 | * This program is free software; you can redistribute it and/or modify |
---|
3 | * it under the terms of the GNU General Public License as published by |
---|
4 | * the Free Software Foundation; either version 2 of the License, or |
---|
5 | * (at your option) any later version. |
---|
6 | * |
---|
7 | * This program is distributed in the hope that it will be useful, |
---|
8 | * but WITHOUT ANY WARRANTY; without even the implied warranty of |
---|
9 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the |
---|
10 | * GNU General Public License for more details. |
---|
11 | * |
---|
12 | * You should have received a copy of the GNU General Public License |
---|
13 | * along with this program; if not, write to the Free Software |
---|
14 | * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA. |
---|
15 | */ |
---|
16 | |
---|
17 | /* |
---|
18 | * BoundaryPanel.java |
---|
19 | * Copyright (C) 2002 University of Waikato, Hamilton, New Zealand |
---|
20 | * |
---|
21 | */ |
---|
22 | |
---|
23 | package weka.gui.boundaryvisualizer; |
---|
24 | |
---|
25 | import weka.classifiers.Classifier; |
---|
26 | import weka.classifiers.AbstractClassifier; |
---|
27 | import weka.core.FastVector; |
---|
28 | import weka.core.Instance; |
---|
29 | import weka.core.DenseInstance; |
---|
30 | import weka.core.Instances; |
---|
31 | import weka.core.Utils; |
---|
32 | import weka.gui.visualize.JPEGWriter; |
---|
33 | |
---|
34 | import java.awt.BorderLayout; |
---|
35 | import java.awt.Color; |
---|
36 | import java.awt.Dimension; |
---|
37 | import java.awt.Graphics; |
---|
38 | import java.awt.Graphics2D; |
---|
39 | import java.awt.Image; |
---|
40 | import java.awt.RenderingHints; |
---|
41 | import java.awt.event.ActionEvent; |
---|
42 | import java.awt.event.ActionListener; |
---|
43 | import java.awt.event.MouseEvent; |
---|
44 | import java.awt.event.MouseListener; |
---|
45 | import java.awt.image.BufferedImage; |
---|
46 | import java.io.File; |
---|
47 | import java.io.FileInputStream; |
---|
48 | import java.io.ObjectInputStream; |
---|
49 | import java.util.Iterator; |
---|
50 | import java.util.Locale; |
---|
51 | import java.util.Random; |
---|
52 | import java.util.Vector; |
---|
53 | |
---|
54 | import javax.imageio.IIOImage; |
---|
55 | import javax.imageio.ImageIO; |
---|
56 | import javax.imageio.ImageWriteParam; |
---|
57 | import javax.imageio.ImageWriter; |
---|
58 | import javax.imageio.plugins.jpeg.JPEGImageWriteParam; |
---|
59 | import javax.imageio.stream.ImageOutputStream; |
---|
60 | import javax.swing.JOptionPane; |
---|
61 | import javax.swing.JPanel; |
---|
62 | import javax.swing.ToolTipManager; |
---|
63 | |
---|
64 | /** |
---|
65 | * BoundaryPanel. A class to handle the plotting operations |
---|
66 | * associated with generating a 2D picture of a classifier's decision |
---|
67 | * boundaries. |
---|
68 | * |
---|
69 | * @author <a href="mailto:mhall@cs.waikato.ac.nz">Mark Hall</a> |
---|
70 | * @version $Revision: 5987 $ |
---|
71 | * @since 1.0 |
---|
72 | * @see JPanel |
---|
73 | */ |
---|
74 | public class BoundaryPanel |
---|
75 | extends JPanel { |
---|
76 | |
---|
/** for serialization */
private static final long serialVersionUID = -8499445518744770458L;

/**
 * Default colours for the class values, in class-value order. The
 * constructor copies them into m_Colors.
 */
public static final Color [] DEFAULT_COLORS = {
  Color.red,
  Color.green,
  Color.blue,
  new Color(0, 255, 255),   // cyan
  new Color(255, 0, 255),   // pink
  new Color(255, 255, 0),   // yellow
  new Color(255, 255, 255), // white
  new Color(0, 0, 0)        // black
};

/** The distance we can click away from a point in the GUI and still remove it. */
public static final double REMOVE_POINT_RADIUS = 7.0;

/** colours used for the class values (a FastVector of Color objects) */
protected FastVector m_Colors = new FastVector();

/** training data */
protected Instances m_trainingData;

/** distribution classifier to use */
protected Classifier m_classifier;

/** data generator to use */
protected DataGenerator m_dataGenerator;

/** index of the class attribute (-1 until training data has been set) */
private int m_classIndex = -1;

/** index of the attribute plotted on the x axis */
protected int m_xAttribute;
/** index of the attribute plotted on the y axis */
protected int m_yAttribute;

/** bounds of the plotted area, in attribute space */
protected double m_minX, m_minY, m_maxX, m_maxY;
/** extents (max - min) of the plotted area, in attribute space */
private double m_rangeX, m_rangeY;

/** height and width of a single pixel, measured in attribute units */
protected double m_pixHeight, m_pixWidth;

/** used for offscreen drawing */
protected Image m_osi = null;

/** width and height of the display area, in pixels */
protected int m_panelWidth, m_panelHeight;

/** number of samples to take from each region in the fixed dimensions */
protected int m_numOfSamplesPerRegion = 2;

/** number of samples per generator = base ^ (# non-fixed dimensions) */
protected int m_numOfSamplesPerGenerator;
/** the base used when computing m_numOfSamplesPerGenerator */
protected double m_samplesBase = 2.0;

/** listeners to be notified when plot is complete */
private Vector m_listeners = new Vector();
---|
140 | |
---|
141 | /** |
---|
142 | * small inner class for rendering the bitmap on to |
---|
143 | */ |
---|
144 | private class PlotPanel |
---|
145 | extends JPanel { |
---|
146 | |
---|
147 | /** for serialization */ |
---|
148 | private static final long serialVersionUID = 743629498352235060L; |
---|
149 | |
---|
150 | public PlotPanel() { |
---|
151 | this.setToolTipText(""); |
---|
152 | } |
---|
153 | |
---|
154 | public void paintComponent(Graphics g) { |
---|
155 | super.paintComponent(g); |
---|
156 | if (m_osi != null) { |
---|
157 | g.drawImage(m_osi,0,0,this); |
---|
158 | } |
---|
159 | } |
---|
160 | |
---|
161 | public String getToolTipText(MouseEvent event) { |
---|
162 | if (m_probabilityCache == null) { |
---|
163 | return null; |
---|
164 | } |
---|
165 | |
---|
166 | if (m_probabilityCache[event.getY()][event.getX()] == null) { |
---|
167 | return null; |
---|
168 | } |
---|
169 | |
---|
170 | String pVec = "(X: " |
---|
171 | +Utils.doubleToString(convertFromPanelX((double)event.getX()), 2) |
---|
172 | +" Y: " |
---|
173 | +Utils.doubleToString(convertFromPanelY((double)event.getY()), 2)+") "; |
---|
174 | // construct a string holding the probability vector |
---|
175 | for (int i = 0; i < m_trainingData.classAttribute().numValues(); i++) { |
---|
176 | pVec += |
---|
177 | Utils. |
---|
178 | doubleToString(m_probabilityCache[event.getY()][event.getX()][i], |
---|
179 | 3)+" "; |
---|
180 | } |
---|
181 | return pVec; |
---|
182 | } |
---|
183 | } |
---|
184 | |
---|
185 | /** the actual plotting area */ |
---|
186 | private PlotPanel m_plotPanel = new PlotPanel(); |
---|
187 | |
---|
188 | /** thread for running the plotting operation in */ |
---|
189 | private Thread m_plotThread = null; |
---|
190 | |
---|
191 | /** Stop the plotting thread */ |
---|
192 | protected boolean m_stopPlotting = false; |
---|
193 | |
---|
194 | /** Stop any replotting threads */ |
---|
195 | protected boolean m_stopReplotting = false; |
---|
196 | |
---|
197 | // Used by replotting threads to pause and resume the main plot thread |
---|
198 | private Double m_dummy = new Double(1.0); |
---|
199 | private boolean m_pausePlotting = false; |
---|
200 | /** what size of tile is currently being plotted */ |
---|
201 | private int m_size = 1; |
---|
202 | /** is the main plot thread performing the initial coarse tiling */ |
---|
203 | private boolean m_initialTiling; |
---|
204 | |
---|
205 | /** A random number generator */ |
---|
206 | private Random m_random = null; |
---|
207 | |
---|
208 | /** cache of probabilities for fast replotting */ |
---|
209 | protected double [][][] m_probabilityCache; |
---|
210 | |
---|
211 | /** plot the training data */ |
---|
212 | protected boolean m_plotTrainingData = true; |
---|
213 | |
---|
/**
 * Creates a new <code>BoundaryPanel</code> whose plotting area is fixed to
 * the given size.
 *
 * @param panelWidth the width in pixels of the panel
 * @param panelHeight the height in pixels of the panel
 */
public BoundaryPanel(int panelWidth, int panelHeight) {
  // keep the probability tooltips on screen for as long as the mouse stays
  ToolTipManager.sharedInstance().setDismissDelay(Integer.MAX_VALUE);

  m_panelWidth = panelWidth;
  m_panelHeight = panelHeight;

  // pin the plotting area to exactly the requested size
  m_plotPanel.setMinimumSize(new Dimension(m_panelWidth, m_panelHeight));
  m_plotPanel.setPreferredSize(new Dimension(m_panelWidth, m_panelHeight));
  m_plotPanel.setMaximumSize(new Dimension(m_panelWidth, m_panelHeight));

  setLayout(new BorderLayout());
  add(m_plotPanel, BorderLayout.CENTER);

  // this panel mirrors the size constraints of the plotting area
  setPreferredSize(m_plotPanel.getPreferredSize());
  setMaximumSize(m_plotPanel.getMaximumSize());
  setMinimumSize(m_plotPanel.getMinimumSize());

  // fixed seed, so the random sampling is the same on every run
  m_random = new Random(1);

  // independent copies of the default class colours
  for (int c = 0; c < DEFAULT_COLORS.length; c++) {
    Color base = DEFAULT_COLORS[c];
    m_Colors.addElement(new Color(base.getRed(),
                                  base.getGreen(),
                                  base.getBlue()));
  }

  // one (initially empty) probability slot per pixel, indexed [row][column]
  m_probabilityCache = new double[m_panelHeight][m_panelWidth][];
}
---|
242 | |
---|
243 | /** |
---|
244 | * Set the number of points to uniformly sample from a region (fixed |
---|
245 | * dimensions). |
---|
246 | * |
---|
247 | * @param num an <code>int</code> value |
---|
248 | */ |
---|
249 | public void setNumSamplesPerRegion(int num) { |
---|
250 | m_numOfSamplesPerRegion = num; |
---|
251 | } |
---|
252 | |
---|
253 | /** |
---|
254 | * Get the number of points to sample from a region (fixed dimensions). |
---|
255 | * |
---|
256 | * @return an <code>int</code> value |
---|
257 | */ |
---|
258 | public int getNumSamplesPerRegion() { |
---|
259 | return m_numOfSamplesPerRegion; |
---|
260 | } |
---|
261 | |
---|
262 | /** |
---|
263 | * Set the base for computing the number of samples to obtain from each |
---|
264 | * generator. number of samples = base ^ (# non fixed dimensions) |
---|
265 | * |
---|
266 | * @param ksb a <code>double</code> value |
---|
267 | */ |
---|
268 | public void setGeneratorSamplesBase(double ksb) { |
---|
269 | m_samplesBase = ksb; |
---|
270 | } |
---|
271 | |
---|
272 | /** |
---|
273 | * Get the base used for computing the number of samples to obtain from |
---|
274 | * each generator |
---|
275 | * |
---|
276 | * @return a <code>double</code> value |
---|
277 | */ |
---|
278 | public double getGeneratorSamplesBase() { |
---|
279 | return m_samplesBase; |
---|
280 | } |
---|
281 | |
---|
282 | /** |
---|
283 | * Set up the off screen bitmap for rendering to |
---|
284 | */ |
---|
285 | protected void initialize() { |
---|
286 | int iwidth = m_plotPanel.getWidth(); |
---|
287 | int iheight = m_plotPanel.getHeight(); |
---|
288 | // System.err.println(iwidth+" "+iheight); |
---|
289 | m_osi = m_plotPanel.createImage(iwidth, iheight); |
---|
290 | Graphics m = m_osi.getGraphics(); |
---|
291 | m.fillRect(0,0,iwidth,iheight); |
---|
292 | } |
---|
293 | |
---|
294 | /** |
---|
295 | * Stop the plotting thread |
---|
296 | */ |
---|
297 | public void stopPlotting() { |
---|
298 | m_stopPlotting = true; |
---|
299 | try { |
---|
300 | m_plotThread.join(100); |
---|
301 | } catch (Exception e){}; |
---|
302 | } |
---|
303 | |
---|
304 | /** Set up the bounds of our graphic based by finding the smallest reasonable |
---|
305 | area in the instance space to surround our data points. |
---|
306 | */ |
---|
307 | public void computeMinMaxAtts() { |
---|
308 | m_minX = Double.MAX_VALUE; |
---|
309 | m_minY = Double.MAX_VALUE; |
---|
310 | m_maxX = Double.MIN_VALUE; |
---|
311 | m_maxY = Double.MIN_VALUE; |
---|
312 | |
---|
313 | boolean allPointsLessThanOne = true; |
---|
314 | |
---|
315 | if (m_trainingData.numInstances() == 0) { |
---|
316 | m_minX = m_minY = 0.0; |
---|
317 | m_maxX = m_maxY = 1.0; |
---|
318 | } |
---|
319 | else |
---|
320 | { |
---|
321 | for (int i = 0; i < m_trainingData.numInstances(); i++) { |
---|
322 | Instance inst = m_trainingData.instance(i); |
---|
323 | double x = inst.value(m_xAttribute); |
---|
324 | double y = inst.value(m_yAttribute); |
---|
325 | if (!Utils.isMissingValue(x) && !Utils.isMissingValue(y)) { |
---|
326 | if (x < m_minX) { |
---|
327 | m_minX = x; |
---|
328 | } |
---|
329 | if (x > m_maxX) { |
---|
330 | m_maxX = x; |
---|
331 | } |
---|
332 | |
---|
333 | if (y < m_minY) { |
---|
334 | m_minY = y; |
---|
335 | } |
---|
336 | if (y > m_maxY) { |
---|
337 | m_maxY = y; |
---|
338 | } |
---|
339 | if (x > 1.0 || y > 1.0) |
---|
340 | allPointsLessThanOne = false; |
---|
341 | } |
---|
342 | } |
---|
343 | } |
---|
344 | |
---|
345 | if (m_minX == m_maxX) |
---|
346 | m_minX = 0; |
---|
347 | if (m_minY == m_maxY) |
---|
348 | m_minY = 0; |
---|
349 | if (m_minX == Double.MAX_VALUE) |
---|
350 | m_minX = 0; |
---|
351 | if (m_minY == Double.MAX_VALUE) |
---|
352 | m_minY = 0; |
---|
353 | if (m_maxX == Double.MIN_VALUE) |
---|
354 | m_maxX = 1; |
---|
355 | if (m_maxY == Double.MIN_VALUE) |
---|
356 | m_maxY = 1; |
---|
357 | if (allPointsLessThanOne) { |
---|
358 | m_minX = m_minY = 0.0; |
---|
359 | m_maxX = m_maxY = 1.0; |
---|
360 | } |
---|
361 | |
---|
362 | |
---|
363 | |
---|
364 | m_rangeX = (m_maxX - m_minX); |
---|
365 | m_rangeY = (m_maxY - m_minY); |
---|
366 | |
---|
367 | m_pixWidth = m_rangeX / (double)m_panelWidth; |
---|
368 | m_pixHeight = m_rangeY / (double) m_panelHeight; |
---|
369 | } |
---|
370 | |
---|
371 | /** |
---|
372 | * Return a random x attribute value contained within |
---|
373 | * the pix'th horizontal pixel |
---|
374 | * |
---|
375 | * @param pix the horizontal pixel number |
---|
376 | * @return a value in attribute space |
---|
377 | */ |
---|
378 | private double getRandomX(int pix) { |
---|
379 | |
---|
380 | double minPix = m_minX + (pix * m_pixWidth); |
---|
381 | |
---|
382 | return minPix + m_random.nextDouble() * m_pixWidth; |
---|
383 | } |
---|
384 | |
---|
385 | /** |
---|
386 | * Return a random y attribute value contained within |
---|
387 | * the pix'th vertical pixel |
---|
388 | * |
---|
389 | * @param pix the vertical pixel number |
---|
390 | * @return a value in attribute space |
---|
391 | */ |
---|
392 | private double getRandomY(int pix) { |
---|
393 | |
---|
394 | double minPix = m_minY + (pix * m_pixHeight); |
---|
395 | |
---|
396 | return minPix + m_random.nextDouble() * m_pixHeight; |
---|
397 | } |
---|
398 | |
---|
399 | /** |
---|
400 | * Start the plotting thread |
---|
401 | * |
---|
402 | * @exception Exception if an error occurs |
---|
403 | */ |
---|
404 | public void start() throws Exception { |
---|
405 | m_numOfSamplesPerGenerator = |
---|
406 | (int)Math.pow(m_samplesBase, m_trainingData.numAttributes()-3); |
---|
407 | |
---|
408 | m_stopReplotting = true; |
---|
409 | if (m_trainingData == null) { |
---|
410 | throw new Exception("No training data set (BoundaryPanel)"); |
---|
411 | } |
---|
412 | if (m_classifier == null) { |
---|
413 | throw new Exception("No classifier set (BoundaryPanel)"); |
---|
414 | } |
---|
415 | if (m_dataGenerator == null) { |
---|
416 | throw new Exception("No data generator set (BoundaryPanel)"); |
---|
417 | } |
---|
418 | if (m_trainingData.attribute(m_xAttribute).isNominal() || |
---|
419 | m_trainingData.attribute(m_yAttribute).isNominal()) { |
---|
420 | throw new Exception("Visualization dimensions must be numeric " |
---|
421 | +"(BoundaryPanel)"); |
---|
422 | } |
---|
423 | |
---|
424 | computeMinMaxAtts(); |
---|
425 | |
---|
426 | startPlotThread(); |
---|
427 | /*if (m_plotThread == null) { |
---|
428 | m_plotThread = new PlotThread(); |
---|
429 | m_plotThread.setPriority(Thread.MIN_PRIORITY); |
---|
430 | m_plotThread.start(); |
---|
431 | }*/ |
---|
432 | } |
---|
433 | |
---|
434 | // Thread for main plotting operation |
---|
435 | protected class PlotThread extends Thread { |
---|
436 | double [] m_weightingAttsValues; |
---|
437 | boolean [] m_attsToWeightOn; |
---|
438 | double [] m_vals; |
---|
439 | double [] m_dist; |
---|
440 | Instance m_predInst; |
---|
441 | public void run() { |
---|
442 | |
---|
443 | m_stopPlotting = false; |
---|
444 | try { |
---|
445 | initialize(); |
---|
446 | repaint(); |
---|
447 | |
---|
448 | // train the classifier |
---|
449 | m_probabilityCache = new double[m_panelHeight][m_panelWidth][]; |
---|
450 | m_classifier.buildClassifier(m_trainingData); |
---|
451 | |
---|
452 | // build DataGenerator |
---|
453 | m_attsToWeightOn = new boolean[m_trainingData.numAttributes()]; |
---|
454 | m_attsToWeightOn[m_xAttribute] = true; |
---|
455 | m_attsToWeightOn[m_yAttribute] = true; |
---|
456 | |
---|
457 | m_dataGenerator.setWeightingDimensions(m_attsToWeightOn); |
---|
458 | |
---|
459 | m_dataGenerator.buildGenerator(m_trainingData); |
---|
460 | |
---|
461 | // generate samples |
---|
462 | m_weightingAttsValues = new double [m_attsToWeightOn.length]; |
---|
463 | m_vals = new double[m_trainingData.numAttributes()]; |
---|
464 | m_predInst = new DenseInstance(1.0, m_vals); |
---|
465 | m_predInst.setDataset(m_trainingData); |
---|
466 | |
---|
467 | |
---|
468 | m_size = 1 << 4; // Current sample region size |
---|
469 | |
---|
470 | m_initialTiling = true; |
---|
471 | // Display the initial coarse image tiling. |
---|
472 | abortInitial: |
---|
473 | for (int i = 0; i <= m_panelHeight; i += m_size) { |
---|
474 | for (int j = 0; j <= m_panelWidth; j += m_size) { |
---|
475 | if (m_stopPlotting) { |
---|
476 | break abortInitial; |
---|
477 | } |
---|
478 | if (m_pausePlotting) { |
---|
479 | synchronized (m_dummy) { |
---|
480 | try { |
---|
481 | m_dummy.wait(); |
---|
482 | } catch (InterruptedException ex) { |
---|
483 | m_pausePlotting = false; |
---|
484 | } |
---|
485 | } |
---|
486 | } |
---|
487 | plotPoint(j, i, m_size, m_size, |
---|
488 | calculateRegionProbs(j, i), (j == 0)); |
---|
489 | } |
---|
490 | } |
---|
491 | if (!m_stopPlotting) { |
---|
492 | m_initialTiling = false; |
---|
493 | } |
---|
494 | |
---|
495 | // Sampling and gridding loop |
---|
496 | int size2 = m_size / 2; |
---|
497 | abortPlot: |
---|
498 | while (m_size > 1) { // Subdivide down to the pixel level |
---|
499 | for (int i = 0; i <= m_panelHeight; i += m_size) { |
---|
500 | for (int j = 0; j <= m_panelWidth; j += m_size) { |
---|
501 | if (m_stopPlotting) { |
---|
502 | break abortPlot; |
---|
503 | } |
---|
504 | if (m_pausePlotting) { |
---|
505 | synchronized (m_dummy) { |
---|
506 | try { |
---|
507 | m_dummy.wait(); |
---|
508 | } catch (InterruptedException ex) { |
---|
509 | m_pausePlotting = false; |
---|
510 | } |
---|
511 | } |
---|
512 | } |
---|
513 | boolean update = (j == 0 && i % 2 == 0); |
---|
514 | // Draw the three new subpixel regions |
---|
515 | plotPoint(j, i + size2, size2, size2, |
---|
516 | calculateRegionProbs(j, i + size2), update); |
---|
517 | plotPoint(j + size2, i + size2, size2, size2, |
---|
518 | calculateRegionProbs(j + size2, i + size2), update); |
---|
519 | plotPoint(j + size2, i, size2, size2, |
---|
520 | calculateRegionProbs(j + size2, i), update); |
---|
521 | } |
---|
522 | } |
---|
523 | // The new region edge length is half the old edge length |
---|
524 | m_size = size2; |
---|
525 | size2 = size2 / 2; |
---|
526 | } |
---|
527 | update(); |
---|
528 | |
---|
529 | |
---|
530 | /* |
---|
531 | // Old method without sampling. |
---|
532 | abortPlot: |
---|
533 | for (int i = 0; i < m_panelHeight; i++) { |
---|
534 | for (int j = 0; j < m_panelWidth; j++) { |
---|
535 | if (m_stopPlotting) { |
---|
536 | break abortPlot; |
---|
537 | } |
---|
538 | plotPoint(j, i, calculateRegionProbs(j, i), (j == 0)); |
---|
539 | } |
---|
540 | } |
---|
541 | */ |
---|
542 | |
---|
543 | |
---|
544 | if (m_plotTrainingData) { |
---|
545 | plotTrainingData(); |
---|
546 | } |
---|
547 | |
---|
548 | } catch (Exception ex) { |
---|
549 | ex.printStackTrace(); |
---|
550 | JOptionPane.showMessageDialog(null,"Error while plotting: \"" + ex.getMessage() + "\""); |
---|
551 | } finally { |
---|
552 | m_plotThread = null; |
---|
553 | // notify any listeners that we are finished |
---|
554 | Vector l; |
---|
555 | ActionEvent e = new ActionEvent(this, 0, ""); |
---|
556 | synchronized(this) { |
---|
557 | l = (Vector)m_listeners.clone(); |
---|
558 | } |
---|
559 | for (int i = 0; i < l.size(); i++) { |
---|
560 | ActionListener al = (ActionListener)l.elementAt(i); |
---|
561 | al.actionPerformed(e); |
---|
562 | } |
---|
563 | } |
---|
564 | } |
---|
565 | |
---|
566 | private double [] calculateRegionProbs(int j, int i) throws Exception { |
---|
567 | double [] sumOfProbsForRegion = |
---|
568 | new double [m_trainingData.classAttribute().numValues()]; |
---|
569 | |
---|
570 | for (int u = 0; u < m_numOfSamplesPerRegion; u++) { |
---|
571 | |
---|
572 | double [] sumOfProbsForLocation = |
---|
573 | new double [m_trainingData.classAttribute().numValues()]; |
---|
574 | |
---|
575 | m_weightingAttsValues[m_xAttribute] = getRandomX(j); |
---|
576 | m_weightingAttsValues[m_yAttribute] = getRandomY(m_panelHeight-i-1); |
---|
577 | |
---|
578 | m_dataGenerator.setWeightingValues(m_weightingAttsValues); |
---|
579 | |
---|
580 | double [] weights = m_dataGenerator.getWeights(); |
---|
581 | double sumOfWeights = Utils.sum(weights); |
---|
582 | int [] indices = Utils.sort(weights); |
---|
583 | |
---|
584 | // Prune 1% of weight mass |
---|
585 | int [] newIndices = new int[indices.length]; |
---|
586 | double sumSoFar = 0; |
---|
587 | double criticalMass = 0.99 * sumOfWeights; |
---|
588 | int index = weights.length - 1; int counter = 0; |
---|
589 | for (int z = weights.length - 1; z >= 0; z--) { |
---|
590 | newIndices[index--] = indices[z]; |
---|
591 | sumSoFar += weights[indices[z]]; |
---|
592 | counter++; |
---|
593 | if (sumSoFar > criticalMass) { |
---|
594 | break; |
---|
595 | } |
---|
596 | } |
---|
597 | indices = new int[counter]; |
---|
598 | System.arraycopy(newIndices, index + 1, indices, 0, counter); |
---|
599 | |
---|
600 | for (int z = 0; z < m_numOfSamplesPerGenerator; z++) { |
---|
601 | |
---|
602 | m_dataGenerator.setWeightingValues(m_weightingAttsValues); |
---|
603 | double [][] values = m_dataGenerator.generateInstances(indices); |
---|
604 | |
---|
605 | for (int q = 0; q < values.length; q++) { |
---|
606 | if (values[q] != null) { |
---|
607 | System.arraycopy(values[q], 0, m_vals, 0, m_vals.length); |
---|
608 | m_vals[m_xAttribute] = m_weightingAttsValues[m_xAttribute]; |
---|
609 | m_vals[m_yAttribute] = m_weightingAttsValues[m_yAttribute]; |
---|
610 | |
---|
611 | // classify the instance |
---|
612 | m_dist = m_classifier.distributionForInstance(m_predInst); |
---|
613 | for (int k = 0; k < sumOfProbsForLocation.length; k++) { |
---|
614 | sumOfProbsForLocation[k] += (m_dist[k] * weights[q]); |
---|
615 | } |
---|
616 | } |
---|
617 | } |
---|
618 | } |
---|
619 | |
---|
620 | for (int k = 0; k < sumOfProbsForRegion.length; k++) { |
---|
621 | sumOfProbsForRegion[k] += (sumOfProbsForLocation[k] * |
---|
622 | sumOfWeights); |
---|
623 | } |
---|
624 | } |
---|
625 | |
---|
626 | // average |
---|
627 | Utils.normalize(sumOfProbsForRegion); |
---|
628 | |
---|
629 | // cache |
---|
630 | if ((i < m_panelHeight) && (j < m_panelWidth)) { |
---|
631 | m_probabilityCache[i][j] = new double[sumOfProbsForRegion.length]; |
---|
632 | System.arraycopy(sumOfProbsForRegion, 0, m_probabilityCache[i][j], |
---|
633 | 0, sumOfProbsForRegion.length); |
---|
634 | } |
---|
635 | |
---|
636 | return sumOfProbsForRegion; |
---|
637 | } |
---|
638 | } |
---|
639 | |
---|
640 | /** Render the training points on-screen. |
---|
641 | */ |
---|
642 | public void plotTrainingData() { |
---|
643 | |
---|
644 | Graphics2D osg = (Graphics2D)m_osi.getGraphics(); |
---|
645 | Graphics g = m_plotPanel.getGraphics(); |
---|
646 | osg.setRenderingHint(RenderingHints.KEY_ANTIALIASING, |
---|
647 | RenderingHints.VALUE_ANTIALIAS_ON); |
---|
648 | double xval = 0; double yval = 0; |
---|
649 | |
---|
650 | for (int i = 0; i < m_trainingData.numInstances(); i++) { |
---|
651 | if (!m_trainingData.instance(i).isMissing(m_xAttribute) && |
---|
652 | !m_trainingData.instance(i).isMissing(m_yAttribute)) { |
---|
653 | |
---|
654 | if (m_trainingData.instance(i).isMissing(m_classIndex)) //jimmy. |
---|
655 | continue; //don't plot if class is missing. TODO could we plot it differently instead? |
---|
656 | |
---|
657 | xval = m_trainingData.instance(i).value(m_xAttribute); |
---|
658 | yval = m_trainingData.instance(i).value(m_yAttribute); |
---|
659 | |
---|
660 | int panelX = convertToPanelX(xval); |
---|
661 | int panelY = convertToPanelY(yval); |
---|
662 | Color ColorToPlotWith = |
---|
663 | ((Color)m_Colors.elementAt((int)m_trainingData.instance(i). |
---|
664 | value(m_classIndex) % m_Colors.size())); |
---|
665 | |
---|
666 | if (ColorToPlotWith.equals(Color.white)) { |
---|
667 | osg.setColor(Color.black); |
---|
668 | } else { |
---|
669 | osg.setColor(Color.white); |
---|
670 | } |
---|
671 | osg.fillOval(panelX-3, panelY-3, 7, 7); |
---|
672 | osg.setColor(ColorToPlotWith); |
---|
673 | osg.fillOval(panelX-2, panelY-2, 5, 5); |
---|
674 | } |
---|
675 | } |
---|
676 | g.drawImage(m_osi,0,0,m_plotPanel); |
---|
677 | } |
---|
678 | |
---|
679 | /** Convert an X coordinate from the instance space to the panel space. |
---|
680 | */ |
---|
681 | private int convertToPanelX(double xval) { |
---|
682 | double temp = (xval - m_minX) / m_rangeX; |
---|
683 | temp = temp * (double) m_panelWidth; |
---|
684 | |
---|
685 | return (int)temp; |
---|
686 | } |
---|
687 | |
---|
688 | /** Convert a Y coordinate from the instance space to the panel space. |
---|
689 | */ |
---|
690 | private int convertToPanelY(double yval) { |
---|
691 | double temp = (yval - m_minY) / m_rangeY; |
---|
692 | temp = temp * (double) m_panelHeight; |
---|
693 | temp = m_panelHeight - temp; |
---|
694 | |
---|
695 | return (int)temp; |
---|
696 | } |
---|
697 | |
---|
698 | /** Convert an X coordinate from the panel space to the instance space. |
---|
699 | */ |
---|
700 | private double convertFromPanelX(double pX) { |
---|
701 | pX /= (double) m_panelWidth; |
---|
702 | pX *= m_rangeX; |
---|
703 | return pX + m_minX; |
---|
704 | } |
---|
705 | |
---|
706 | /** Convert a Y coordinate from the panel space to the instance space. |
---|
707 | */ |
---|
708 | private double convertFromPanelY(double pY) { |
---|
709 | pY = m_panelHeight - pY; |
---|
710 | pY /= (double) m_panelHeight; |
---|
711 | pY *= m_rangeY; |
---|
712 | |
---|
713 | return pY + m_minY; |
---|
714 | } |
---|
715 | |
---|
716 | |
---|
717 | /** Plot a point in our visualization on-screen. |
---|
718 | */ |
---|
719 | protected void plotPoint(int x, int y, double [] probs, boolean update) { |
---|
720 | plotPoint(x, y, 1, 1, probs, update); |
---|
721 | } |
---|
722 | |
---|
723 | /** Plot a point in our visualization on-screen. |
---|
724 | */ |
---|
725 | private void plotPoint(int x, int y, int width, int height, |
---|
726 | double [] probs, boolean update) { |
---|
727 | |
---|
728 | // draw a progress line |
---|
729 | Graphics osg = m_osi.getGraphics(); |
---|
730 | if (update) { |
---|
731 | osg.setXORMode(Color.white); |
---|
732 | osg.drawLine(0, y, m_panelWidth-1, y); |
---|
733 | update(); |
---|
734 | osg.drawLine(0, y, m_panelWidth-1, y); |
---|
735 | } |
---|
736 | |
---|
737 | // plot the point |
---|
738 | osg.setPaintMode(); |
---|
739 | float [] colVal = new float[3]; |
---|
740 | |
---|
741 | float [] tempCols = new float[3]; |
---|
742 | for (int k = 0; k < probs.length; k++) { |
---|
743 | Color curr = (Color)m_Colors.elementAt(k % m_Colors.size()); |
---|
744 | |
---|
745 | curr.getRGBColorComponents(tempCols); |
---|
746 | for (int z = 0 ; z < 3; z++) { |
---|
747 | colVal[z] += probs[k] * tempCols[z]; |
---|
748 | } |
---|
749 | } |
---|
750 | |
---|
751 | for (int z = 0; z < 3; z++) { |
---|
752 | if (colVal[z] < 0) { |
---|
753 | colVal[z] = 0; |
---|
754 | } else if (colVal[z] > 1) { |
---|
755 | colVal[z] = 1; |
---|
756 | } |
---|
757 | } |
---|
758 | |
---|
759 | osg.setColor(new Color(colVal[0], |
---|
760 | colVal[1], |
---|
761 | colVal[2])); |
---|
762 | osg.fillRect(x, y, width, height); |
---|
763 | } |
---|
764 | |
---|
765 | /** Update the rendered image. |
---|
766 | */ |
---|
767 | private void update() { |
---|
768 | Graphics g = m_plotPanel.getGraphics(); |
---|
769 | g.drawImage(m_osi, 0, 0, m_plotPanel); |
---|
770 | } |
---|
771 | |
---|
772 | /** |
---|
773 | * Set the training data to use |
---|
774 | * |
---|
775 | * @param trainingData the training data |
---|
776 | * @exception Exception if an error occurs |
---|
777 | */ |
---|
778 | public void setTrainingData(Instances trainingData) throws Exception { |
---|
779 | |
---|
780 | m_trainingData = trainingData; |
---|
781 | if (m_trainingData.classIndex() < 0) { |
---|
782 | throw new Exception("No class attribute set (BoundaryPanel)"); |
---|
783 | } |
---|
784 | m_classIndex = m_trainingData.classIndex(); |
---|
785 | } |
---|
786 | |
---|
787 | /** Adds a training instance to the visualization dataset. |
---|
788 | */ |
---|
789 | public void addTrainingInstance(Instance instance) { |
---|
790 | |
---|
791 | if (m_trainingData == null) { |
---|
792 | //TODO |
---|
793 | System.err.println("Trying to add to a null training set (BoundaryPanel)"); |
---|
794 | } |
---|
795 | |
---|
796 | m_trainingData.add(instance); |
---|
797 | } |
---|
798 | |
---|
799 | /** |
---|
800 | * Register a listener to be notified when plotting completes |
---|
801 | * |
---|
802 | * @param newListener the listener to add |
---|
803 | */ |
---|
804 | public void addActionListener(ActionListener newListener) { |
---|
805 | m_listeners.add(newListener); |
---|
806 | } |
---|
807 | |
---|
808 | /** |
---|
809 | * Remove a listener |
---|
810 | * |
---|
811 | * @param removeListener the listener to remove |
---|
812 | */ |
---|
813 | public void removeActionListener(ActionListener removeListener) { |
---|
814 | m_listeners.removeElement(removeListener); |
---|
815 | } |
---|
816 | |
---|
817 | /** |
---|
818 | * Set the classifier to use. |
---|
819 | * |
---|
820 | * @param classifier the classifier to use |
---|
821 | */ |
---|
822 | public void setClassifier(Classifier classifier) { |
---|
823 | m_classifier = classifier; |
---|
824 | } |
---|
825 | |
---|
826 | /** |
---|
827 | * Set the data generator to use for generating new instances |
---|
828 | * |
---|
829 | * @param dataGenerator the data generator to use |
---|
830 | */ |
---|
831 | public void setDataGenerator(DataGenerator dataGenerator) { |
---|
832 | m_dataGenerator = dataGenerator; |
---|
833 | } |
---|
834 | |
---|
835 | /** |
---|
836 | * Set the x attribute index |
---|
837 | * |
---|
838 | * @param xatt index of the attribute to use on the x axis |
---|
839 | * @exception Exception if an error occurs |
---|
840 | */ |
---|
841 | public void setXAttribute(int xatt) throws Exception { |
---|
842 | if (m_trainingData == null) { |
---|
843 | throw new Exception("No training data set (BoundaryPanel)"); |
---|
844 | } |
---|
845 | if (xatt < 0 || |
---|
846 | xatt > m_trainingData.numAttributes()) { |
---|
847 | throw new Exception("X attribute out of range (BoundaryPanel)"); |
---|
848 | } |
---|
849 | if (m_trainingData.attribute(xatt).isNominal()) { |
---|
850 | throw new Exception("Visualization dimensions must be numeric " |
---|
851 | +"(BoundaryPanel)"); |
---|
852 | } |
---|
853 | /*if (m_trainingData.numDistinctValues(xatt) < 2) { |
---|
854 | throw new Exception("Too few distinct values for X attribute " |
---|
855 | +"(BoundaryPanel)"); |
---|
856 | }*/ //removed by jimmy. TESTING! |
---|
857 | m_xAttribute = xatt; |
---|
858 | } |
---|
859 | |
---|
860 | /** |
---|
861 | * Set the y attribute index |
---|
862 | * |
---|
863 | * @param yatt index of the attribute to use on the y axis |
---|
864 | * @exception Exception if an error occurs |
---|
865 | */ |
---|
866 | public void setYAttribute(int yatt) throws Exception { |
---|
867 | if (m_trainingData == null) { |
---|
868 | throw new Exception("No training data set (BoundaryPanel)"); |
---|
869 | } |
---|
870 | if (yatt < 0 || |
---|
871 | yatt > m_trainingData.numAttributes()) { |
---|
872 | throw new Exception("X attribute out of range (BoundaryPanel)"); |
---|
873 | } |
---|
874 | if (m_trainingData.attribute(yatt).isNominal()) { |
---|
875 | throw new Exception("Visualization dimensions must be numeric " |
---|
876 | +"(BoundaryPanel)"); |
---|
877 | } |
---|
878 | /*if (m_trainingData.numDistinctValues(yatt) < 2) { |
---|
879 | throw new Exception("Too few distinct values for Y attribute " |
---|
880 | +"(BoundaryPanel)"); |
---|
881 | }*/ //removed by jimmy. TESTING! |
---|
882 | m_yAttribute = yatt; |
---|
883 | } |
---|
884 | |
---|
885 | /** |
---|
886 | * Set a vector of Color objects for the classes |
---|
887 | * |
---|
888 | * @param colors a <code>FastVector</code> value |
---|
889 | */ |
---|
890 | public void setColors(FastVector colors) { |
---|
891 | synchronized (m_Colors) { |
---|
892 | m_Colors = colors; |
---|
893 | } |
---|
894 | //replot(); //commented by jimmy |
---|
895 | update(); //added by jimmy |
---|
896 | } |
---|
897 | |
---|
898 | /** |
---|
899 | * Set whether to superimpose the training data |
---|
900 | * plot |
---|
901 | * |
---|
902 | * @param pg a <code>boolean</code> value |
---|
903 | */ |
---|
904 | public void setPlotTrainingData(boolean pg) { |
---|
905 | m_plotTrainingData = pg; |
---|
906 | } |
---|
907 | |
---|
908 | /** |
---|
909 | * Returns true if training data is to be superimposed |
---|
910 | * |
---|
911 | * @return a <code>boolean</code> value |
---|
912 | */ |
---|
913 | public boolean getPlotTrainingData() { |
---|
914 | return m_plotTrainingData; |
---|
915 | } |
---|
916 | |
---|
917 | /** |
---|
918 | * Get the current vector of Color objects used for the classes |
---|
919 | * |
---|
920 | * @return a <code>FastVector</code> value |
---|
921 | */ |
---|
922 | public FastVector getColors() { |
---|
923 | return m_Colors; |
---|
924 | } |
---|
925 | |
---|
926 | /** |
---|
927 | * Quickly replot the display using cached probability estimates |
---|
928 | */ |
---|
929 | public void replot() { |
---|
930 | if (m_probabilityCache[0][0] == null) { |
---|
931 | return; |
---|
932 | } |
---|
933 | m_stopReplotting = true; |
---|
934 | m_pausePlotting = true; |
---|
935 | // wait 300 ms to give any other replot threads a chance to halt |
---|
936 | try { |
---|
937 | Thread.sleep(300); |
---|
938 | } catch (Exception ex) {} |
---|
939 | |
---|
940 | final Thread replotThread = new Thread() { |
---|
941 | public void run() { |
---|
942 | m_stopReplotting = false; |
---|
943 | int size2 = m_size / 2; |
---|
944 | finishedReplot: for (int i = 0; i < m_panelHeight; i += m_size) { |
---|
945 | for (int j = 0; j < m_panelWidth; j += m_size) { |
---|
946 | if (m_probabilityCache[i][j] == null || m_stopReplotting) { |
---|
947 | break finishedReplot; |
---|
948 | } |
---|
949 | |
---|
950 | boolean update = (j == 0 && i % 2 == 0); |
---|
951 | if (i < m_panelHeight && j < m_panelWidth) { |
---|
952 | // Draw the three new subpixel regions or single course tiling |
---|
953 | if (m_initialTiling || m_size == 1) { |
---|
954 | if (m_probabilityCache[i][j] == null) { |
---|
955 | break finishedReplot; |
---|
956 | } |
---|
957 | plotPoint(j, i, m_size, m_size, |
---|
958 | m_probabilityCache[i][j], update); |
---|
959 | } else { |
---|
960 | if (m_probabilityCache[i+size2][j] == null) { |
---|
961 | break finishedReplot; |
---|
962 | } |
---|
963 | plotPoint(j, i + size2, size2, size2, |
---|
964 | m_probabilityCache[i + size2][j], update); |
---|
965 | if (m_probabilityCache[i+size2][j+size2] == null) { |
---|
966 | break finishedReplot; |
---|
967 | } |
---|
968 | plotPoint(j + size2, i + size2, size2, size2, |
---|
969 | m_probabilityCache[i + size2][j + size2], update); |
---|
970 | if (m_probabilityCache[i][j+size2] == null) { |
---|
971 | break finishedReplot; |
---|
972 | } |
---|
973 | plotPoint(j + size2, i, size2, size2, |
---|
974 | m_probabilityCache[i + size2][j], update); |
---|
975 | } |
---|
976 | } |
---|
977 | } |
---|
978 | } |
---|
979 | update(); |
---|
980 | if (m_plotTrainingData) { |
---|
981 | plotTrainingData(); |
---|
982 | } |
---|
983 | m_pausePlotting = false; |
---|
984 | if (!m_stopPlotting) { |
---|
985 | synchronized (m_dummy) { |
---|
986 | m_dummy.notifyAll(); |
---|
987 | } |
---|
988 | } |
---|
989 | } |
---|
990 | }; |
---|
991 | |
---|
992 | replotThread.start(); |
---|
993 | } |
---|
994 | |
---|
995 | protected void saveImage(String fileName) { |
---|
996 | BufferedImage bi; |
---|
997 | Graphics2D gr2; |
---|
998 | ImageWriter writer; |
---|
999 | Iterator iter; |
---|
1000 | ImageOutputStream ios; |
---|
1001 | ImageWriteParam param; |
---|
1002 | |
---|
1003 | try { |
---|
1004 | // render image |
---|
1005 | bi = new BufferedImage(m_panelWidth, m_panelHeight, BufferedImage.TYPE_INT_RGB); |
---|
1006 | gr2 = bi.createGraphics(); |
---|
1007 | gr2.drawImage(m_osi, 0, 0, m_panelWidth, m_panelHeight, null); |
---|
1008 | |
---|
1009 | // get jpeg writer |
---|
1010 | writer = null; |
---|
1011 | iter = ImageIO.getImageWritersByFormatName("jpg"); |
---|
1012 | if (iter.hasNext()) |
---|
1013 | writer = (ImageWriter) iter.next(); |
---|
1014 | else |
---|
1015 | throw new Exception("No JPEG writer available!"); |
---|
1016 | |
---|
1017 | // prepare output file |
---|
1018 | ios = ImageIO.createImageOutputStream(new File(fileName)); |
---|
1019 | writer.setOutput(ios); |
---|
1020 | |
---|
1021 | // set the quality |
---|
1022 | param = new JPEGImageWriteParam(Locale.getDefault()); |
---|
1023 | param.setCompressionMode(ImageWriteParam.MODE_EXPLICIT) ; |
---|
1024 | param.setCompressionQuality(1.0f); |
---|
1025 | |
---|
1026 | // write the image |
---|
1027 | writer.write(null, new IIOImage(bi, null, null), param); |
---|
1028 | |
---|
1029 | // cleanup |
---|
1030 | ios.flush(); |
---|
1031 | writer.dispose(); |
---|
1032 | ios.close(); |
---|
1033 | } |
---|
1034 | catch (Exception e) { |
---|
1035 | e.printStackTrace(); |
---|
1036 | } |
---|
1037 | } |
---|
1038 | |
---|
1039 | /** Adds a training instance to our dataset, based on the coordinates of the mouse on the panel. |
---|
1040 | This method sets the x and y attributes and the class (as defined by classAttIndex), and sets |
---|
1041 | all other values as Missing. |
---|
1042 | * @param mouseX the x coordinate of the mouse, in pixels. |
---|
1043 | * @param mouseY the y coordinate of the mouse, in pixels. |
---|
1044 | * @param classAttIndex the index of the attribute that is currently selected as the class attribute. |
---|
1045 | * @param classValue the value to set the class to in our new point. |
---|
1046 | */ |
---|
1047 | public void addTrainingInstanceFromMouseLocation(int mouseX, int mouseY, int classAttIndex, double classValue) { |
---|
1048 | //convert to coordinates in the training instance space. |
---|
1049 | double x = convertFromPanelX(mouseX); |
---|
1050 | double y = convertFromPanelY(mouseY); |
---|
1051 | |
---|
1052 | //build the training instance |
---|
1053 | Instance newInstance = new DenseInstance(m_trainingData.numAttributes()); |
---|
1054 | for (int i = 0; i < newInstance.numAttributes(); i++) { |
---|
1055 | if (i == classAttIndex) { |
---|
1056 | newInstance.setValue(i,classValue); |
---|
1057 | } |
---|
1058 | else if (i == m_xAttribute) |
---|
1059 | newInstance.setValue(i,x); |
---|
1060 | else if (i == m_yAttribute) |
---|
1061 | newInstance.setValue(i,y); |
---|
1062 | else newInstance.setMissing(i); |
---|
1063 | } |
---|
1064 | |
---|
1065 | //add it to our data set. |
---|
1066 | addTrainingInstance(newInstance); |
---|
1067 | } |
---|
1068 | |
---|
1069 | /** Deletes all training instances from our dataset. |
---|
1070 | */ |
---|
1071 | public void removeAllInstances() { |
---|
1072 | if (m_trainingData != null) |
---|
1073 | { |
---|
1074 | m_trainingData.delete(); |
---|
1075 | try { initialize();} catch (Exception e) {}; |
---|
1076 | } |
---|
1077 | |
---|
1078 | } |
---|
1079 | |
---|
1080 | /** Removes a single training instance from our dataset, if there is one that is close enough |
---|
1081 | to the specified mouse location. |
---|
1082 | */ |
---|
1083 | public void removeTrainingInstanceFromMouseLocation(int mouseX, int mouseY) { |
---|
1084 | |
---|
1085 | //convert to coordinates in the training instance space. |
---|
1086 | double x = convertFromPanelX(mouseX); |
---|
1087 | double y = convertFromPanelY(mouseY); |
---|
1088 | |
---|
1089 | int bestIndex = -1; |
---|
1090 | double bestDistanceBetween = Integer.MAX_VALUE; |
---|
1091 | |
---|
1092 | //find the closest point. |
---|
1093 | for (int i = 0; i < m_trainingData.numInstances(); i++) { |
---|
1094 | Instance current = m_trainingData.instance(i); |
---|
1095 | double distanceBetween = (current.value(m_xAttribute) - x) * (current.value(m_xAttribute) - x) + (current.value(m_yAttribute) - y) * (current.value(m_yAttribute) - y); // won't bother to sqrt, just used square values. |
---|
1096 | |
---|
1097 | if (distanceBetween < bestDistanceBetween) |
---|
1098 | { |
---|
1099 | bestIndex = i; |
---|
1100 | bestDistanceBetween = distanceBetween; |
---|
1101 | } |
---|
1102 | } |
---|
1103 | if (bestIndex == -1) |
---|
1104 | return; |
---|
1105 | Instance best = m_trainingData.instance(bestIndex); |
---|
1106 | double panelDistance = (convertToPanelX(best.value(m_xAttribute)) - mouseX) * (convertToPanelX(best.value(m_xAttribute)) - mouseX) |
---|
1107 | + (convertToPanelY(best.value(m_yAttribute)) - mouseY) * (convertToPanelY(best.value(m_yAttribute)) - mouseY); |
---|
1108 | if (panelDistance < REMOVE_POINT_RADIUS * REMOVE_POINT_RADIUS) {//the best point is close enough. (using squared distances) |
---|
1109 | m_trainingData.delete(bestIndex); |
---|
1110 | } |
---|
1111 | } |
---|
1112 | |
---|
1113 | /** Starts the plotting thread. Will also create it if necessary. |
---|
1114 | */ |
---|
1115 | public void startPlotThread() { |
---|
1116 | if (m_plotThread == null) { //jimmy |
---|
1117 | m_plotThread = new PlotThread(); |
---|
1118 | m_plotThread.setPriority(Thread.MIN_PRIORITY); |
---|
1119 | m_plotThread.start(); |
---|
1120 | } |
---|
1121 | } |
---|
1122 | |
---|
1123 | /** Adds a mouse listener. |
---|
1124 | */ |
---|
1125 | public void addMouseListener(MouseListener l) { |
---|
1126 | m_plotPanel.addMouseListener(l); |
---|
1127 | } |
---|
1128 | |
---|
1129 | /** Gets the minimum x-coordinate bound, in training-instance units (not mouse coordinates). |
---|
1130 | */ |
---|
1131 | public double getMinXBound() { |
---|
1132 | return m_minX; |
---|
1133 | } |
---|
1134 | |
---|
1135 | /** Gets the minimum y-coordinate bound, in training-instance units (not mouse coordinates). |
---|
1136 | */ |
---|
1137 | public double getMinYBound() { |
---|
1138 | return m_minY; |
---|
1139 | } |
---|
1140 | |
---|
1141 | /** Gets the maximum x-coordinate bound, in training-instance units (not mouse coordinates). |
---|
1142 | */ |
---|
1143 | public double getMaxXBound() { |
---|
1144 | return m_maxX; |
---|
1145 | } |
---|
1146 | |
---|
1147 | /** Gets the maximum x-coordinate bound, in training-instance units (not mouse coordinates). |
---|
1148 | */ |
---|
1149 | public double getMaxYBound() { |
---|
1150 | return m_maxY; |
---|
1151 | } |
---|
1152 | |
---|
1153 | /** |
---|
1154 | * Main method for testing this class |
---|
1155 | * |
---|
1156 | * @param args a <code>String[]</code> value |
---|
1157 | */ |
---|
1158 | public static void main (String [] args) { |
---|
1159 | try { |
---|
1160 | if (args.length < 8) { |
---|
1161 | System.err.println("Usage : BoundaryPanel <dataset> " |
---|
1162 | +"<class col> <xAtt> <yAtt> " |
---|
1163 | +"<base> <# loc/pixel> <kernel bandwidth> " |
---|
1164 | +"<display width> " |
---|
1165 | +"<display height> <classifier " |
---|
1166 | +"[classifier options]>"); |
---|
1167 | System.exit(1); |
---|
1168 | } |
---|
1169 | final javax.swing.JFrame jf = |
---|
1170 | new javax.swing.JFrame("Weka classification boundary visualizer"); |
---|
1171 | jf.getContentPane().setLayout(new BorderLayout()); |
---|
1172 | |
---|
1173 | System.err.println("Loading instances from : "+args[0]); |
---|
1174 | java.io.Reader r = new java.io.BufferedReader( |
---|
1175 | new java.io.FileReader(args[0])); |
---|
1176 | final Instances i = new Instances(r); |
---|
1177 | i.setClassIndex(Integer.parseInt(args[1])); |
---|
1178 | |
---|
1179 | // bv.setClassifier(new Logistic()); |
---|
1180 | final int xatt = Integer.parseInt(args[2]); |
---|
1181 | final int yatt = Integer.parseInt(args[3]); |
---|
1182 | int base = Integer.parseInt(args[4]); |
---|
1183 | int loc = Integer.parseInt(args[5]); |
---|
1184 | |
---|
1185 | int bandWidth = Integer.parseInt(args[6]); |
---|
1186 | int panelWidth = Integer.parseInt(args[7]); |
---|
1187 | int panelHeight = Integer.parseInt(args[8]); |
---|
1188 | |
---|
1189 | final String classifierName = args[9]; |
---|
1190 | final BoundaryPanel bv = new BoundaryPanel(panelWidth,panelHeight); |
---|
1191 | bv.addActionListener(new ActionListener() { |
---|
1192 | public void actionPerformed(ActionEvent e) { |
---|
1193 | String classifierNameNew = |
---|
1194 | classifierName.substring(classifierName.lastIndexOf('.')+1, |
---|
1195 | classifierName.length()); |
---|
1196 | bv.saveImage(classifierNameNew+"_"+i.relationName() |
---|
1197 | +"_X"+xatt+"_Y"+yatt+".jpg"); |
---|
1198 | } |
---|
1199 | }); |
---|
1200 | |
---|
1201 | jf.getContentPane().add(bv, BorderLayout.CENTER); |
---|
1202 | jf.setSize(bv.getMinimumSize()); |
---|
1203 | // jf.setSize(200,200); |
---|
1204 | jf.addWindowListener(new java.awt.event.WindowAdapter() { |
---|
1205 | public void windowClosing(java.awt.event.WindowEvent e) { |
---|
1206 | jf.dispose(); |
---|
1207 | System.exit(0); |
---|
1208 | } |
---|
1209 | }); |
---|
1210 | |
---|
1211 | jf.pack(); |
---|
1212 | jf.setVisible(true); |
---|
1213 | // bv.initialize(); |
---|
1214 | bv.repaint(); |
---|
1215 | |
---|
1216 | |
---|
1217 | String [] argsR = null; |
---|
1218 | if (args.length > 10) { |
---|
1219 | argsR = new String [args.length-10]; |
---|
1220 | for (int j = 10; j < args.length; j++) { |
---|
1221 | argsR[j-10] = args[j]; |
---|
1222 | } |
---|
1223 | } |
---|
1224 | Classifier c = AbstractClassifier.forName(args[9], argsR); |
---|
1225 | KDDataGenerator dataGen = new KDDataGenerator(); |
---|
1226 | dataGen.setKernelBandwidth(bandWidth); |
---|
1227 | bv.setDataGenerator(dataGen); |
---|
1228 | bv.setNumSamplesPerRegion(loc); |
---|
1229 | bv.setGeneratorSamplesBase(base); |
---|
1230 | bv.setClassifier(c); |
---|
1231 | bv.setTrainingData(i); |
---|
1232 | bv.setXAttribute(xatt); |
---|
1233 | bv.setYAttribute(yatt); |
---|
1234 | |
---|
1235 | try { |
---|
1236 | // try and load a color map if one exists |
---|
1237 | FileInputStream fis = new FileInputStream("colors.ser"); |
---|
1238 | ObjectInputStream ois = new ObjectInputStream(fis); |
---|
1239 | FastVector colors = (FastVector)ois.readObject(); |
---|
1240 | bv.setColors(colors); |
---|
1241 | } catch (Exception ex) { |
---|
1242 | System.err.println("No color map file"); |
---|
1243 | } |
---|
1244 | bv.start(); |
---|
1245 | } catch (Exception ex) { |
---|
1246 | ex.printStackTrace(); |
---|
1247 | } |
---|
1248 | } |
---|
1249 | } |
---|
1250 | |
---|