source: src/main/java/weka/gui/boundaryvisualizer/RemoteBoundaryVisualizerSubTask.java @ 17

Last change on this file since 17 was 4, checked in by gnappo, 14 years ago

Import di weka.

File size: 10.6 KB
RevLine 
[4]1/*
2 *    This program is free software; you can redistribute it and/or modify
3 *    it under the terms of the GNU General Public License as published by
4 *    the Free Software Foundation; either version 2 of the License, or
5 *    (at your option) any later version.
6 *
7 *    This program is distributed in the hope that it will be useful,
8 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
9 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
10 *    GNU General Public License for more details.
11 *
12 *    You should have received a copy of the GNU General Public License
13 *    along with this program; if not, write to the Free Software
14 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
15 */
16
17/*
18 *   RemoteBoundaryVisualizerSubTask.java
19 *   Copyright (C) 2003 University of Waikato, Hamilton, New Zealand
20 *
21 */
22
23package weka.gui.boundaryvisualizer;
24
25import weka.classifiers.Classifier;
26import weka.classifiers.AbstractClassifier;
27import weka.core.Instance;
28import weka.core.DenseInstance;
29import weka.core.Instances;
30import weka.core.Utils;
31import weka.experiment.Task;
32import weka.experiment.TaskStatusInfo;
33
34import java.util.Random;
35
36/**
37 * Class that encapsulates a sub task for distributed boundary
38 * visualization. Produces probability distributions for each pixel
39 * in one row of the visualization.
40 *
41 * @author <a href="mailto:mhall@cs.waikato.ac.nz">Mark Hall</a>
42 * @version $Revision: 5987 $
43 * @since 1.0
44 * @see Task
45 */
46public class RemoteBoundaryVisualizerSubTask implements Task {
47
48  // status information for this sub task
49  private TaskStatusInfo m_status = new TaskStatusInfo();
50
51  // the result of this sub task
52  private RemoteResult m_result;
53
54  // which row are we doing
55  private int m_rowNumber;
56
57  // width and height of the visualization
58  private int m_panelHeight;
59  private int m_panelWidth;
60
61  // the classifier to use
62  private Classifier m_classifier;
63
64  // the kernel density estimator
65  private DataGenerator m_dataGenerator;
66
67  // the training data
68  private Instances m_trainingData;
69
70  // attributes for visualizing on (fixed dimensions)
71  private int m_xAttribute;
72  private int m_yAttribute;
73
74  // pixel width and height in terms of attribute values
75  private double m_pixHeight;
76  private double m_pixWidth;
77
78  // min, max of these attributes
79  private double m_minX;
80  private double m_minY;
81  private double m_maxX;
82  private double m_maxY;
83
84  // number of samples to take from each region in the fixed dimensions
85  private int m_numOfSamplesPerRegion = 2;
86
87  // number of samples per kernel = base ^ (# non-fixed dimensions)
88  private int m_numOfSamplesPerGenerator;
89  private double m_samplesBase = 2.0;
90
91  // A random number generator
92  private Random m_random;
93
94  private double [] m_weightingAttsValues;
95  private boolean [] m_attsToWeightOn;
96  private double [] m_vals;
97  private double [] m_dist;
98  private Instance m_predInst;
99 
100  /**
101   * Set the row number for this sub task
102   *
103   * @param rn the row number
104   */
105  public void setRowNumber(int rn) {
106    m_rowNumber = rn;
107  }
108
109  /**
110   * Set the width of the visualization
111   *
112   * @param pw the width
113   */
114  public void setPanelWidth(int pw) {
115    m_panelWidth = pw;
116  }
117
118  /**
119   * Set the height of the visualization
120   *
121   * @param ph the height
122   */
123  public void setPanelHeight(int ph) {
124    m_panelHeight = ph;
125  }
126
127  /**
128   * Set the height of a pixel
129   *
130   * @param ph the height of a pixel
131   */
132  public void setPixHeight(double ph) {
133    m_pixHeight = ph;
134  }
135
136  /**
137   * Set the width of a pixel
138   *
139   * @param pw the width of a pixel
140   */
141  public void setPixWidth(double pw) {
142    m_pixWidth = pw;
143  }
144
145  /**
146   * Set the classifier to use
147   *
148   * @param dc the classifier
149   */
150  public void setClassifier(Classifier dc) {
151    m_classifier = dc;
152  }
153
154  /**
155   * Set the density estimator to use
156   *
157   * @param dg the density estimator
158   */
159  public void setDataGenerator(DataGenerator dg) {
160    m_dataGenerator = dg;
161  }
162
163  /**
164   * Set the training data
165   *
166   * @param i the training data
167   */
168  public void setInstances(Instances i) {
169    m_trainingData = i;
170  }
171
172  /**
173   * Set the minimum and maximum values of the x axis fixed dimension
174   *
175   * @param minx a <code>double</code> value
176   * @param maxx a <code>double</code> value
177   */
178  public void setMinMaxX(double minx, double maxx) {
179    m_minX = minx; m_maxX = maxx;
180  }
181
182  /**
183   * Set the minimum and maximum values of the y axis fixed dimension
184   *
185   * @param miny a <code>double</code> value
186   * @param maxy a <code>double</code> value
187   */
188  public void setMinMaxY(double miny, double maxy) {
189    m_minY = miny; m_maxY = maxy;
190  }
191
192  /**
193   * Set the x axis fixed dimension
194   *
195   * @param xatt an <code>int</code> value
196   */
197  public void setXAttribute(int xatt) {
198    m_xAttribute = xatt;
199  }
200
201  /**
202   * Set the y axis fixed dimension
203   *
204   * @param yatt an <code>int</code> value
205   */
206  public void setYAttribute(int yatt) {
207    m_yAttribute = yatt;
208  }
209
210  /**
211   * Set the number of points to uniformly sample from a region (fixed
212   * dimensions).
213   *
214   * @param num an <code>int</code> value
215   */
216  public void setNumSamplesPerRegion(int num) {
217    m_numOfSamplesPerRegion = num;
218  }
219
220  /**
221   * Set the base for computing the number of samples to obtain from each
222   * generator. number of samples = base ^ (# non fixed dimensions)
223   *
224   * @param ksb a <code>double</code> value
225   */
226  public void setGeneratorSamplesBase(double ksb) {
227    m_samplesBase = ksb;
228  }
229
230  /**
231   * Perform the sub task
232   */
233  public void execute() {
234
235    m_random = new Random(m_rowNumber * 11);
236    m_dataGenerator.setSeed(m_rowNumber * 11);
237    m_result = new RemoteResult(m_rowNumber, m_panelWidth);
238    m_status.setTaskResult(m_result);
239    m_status.setExecutionStatus(TaskStatusInfo.PROCESSING);
240
241    try {
242      m_numOfSamplesPerGenerator = 
243        (int)Math.pow(m_samplesBase, m_trainingData.numAttributes()-3);
244      if (m_trainingData == null) {
245        throw new Exception("No training data set (BoundaryPanel)");
246      }
247      if (m_classifier == null) {
248        throw new Exception("No classifier set (BoundaryPanel)");
249      }
250      if (m_dataGenerator == null) {
251        throw new Exception("No data generator set (BoundaryPanel)");
252      }
253      if (m_trainingData.attribute(m_xAttribute).isNominal() || 
254        m_trainingData.attribute(m_yAttribute).isNominal()) {
255        throw new Exception("Visualization dimensions must be numeric "
256                            +"(RemoteBoundaryVisualizerSubTask)");
257      }
258     
259      m_attsToWeightOn = new boolean[m_trainingData.numAttributes()];
260      m_attsToWeightOn[m_xAttribute] = true;
261      m_attsToWeightOn[m_yAttribute] = true;
262     
263      // generate samples
264      m_weightingAttsValues = new double [m_attsToWeightOn.length];
265      m_vals = new double[m_trainingData.numAttributes()];
266      m_predInst = new DenseInstance(1.0, m_vals);
267      m_predInst.setDataset(m_trainingData);
268
269      System.err.println("Executing row number "+m_rowNumber);
270      for (int j = 0; j < m_panelWidth; j++) {
271        double [] preds = calculateRegionProbs(j, m_rowNumber);
272        m_result.setLocationProbs(j, preds);
273        m_result.
274          setPercentCompleted((int)(100 * ((double)j / (double)m_panelWidth)));
275      }
276    } catch (Exception ex) {
277      m_status.setExecutionStatus(TaskStatusInfo.FAILED);
278      m_status.setStatusMessage("Row "+m_rowNumber+" failed.");
279      System.err.print(ex);
280      return;
281    }
282
283    // finished
284    m_status.setExecutionStatus(TaskStatusInfo.FINISHED);
285    m_status.setStatusMessage("Row "+m_rowNumber+" completed successfully.");
286  }
287
288
289  private double [] calculateRegionProbs(int j, int i) throws Exception {
290    double [] sumOfProbsForRegion = 
291      new double [m_trainingData.classAttribute().numValues()];
292
293    for (int u = 0; u < m_numOfSamplesPerRegion; u++) {
294     
295      double [] sumOfProbsForLocation = 
296        new double [m_trainingData.classAttribute().numValues()];
297     
298      m_weightingAttsValues[m_xAttribute] = getRandomX(j);
299      m_weightingAttsValues[m_yAttribute] = getRandomY(m_panelHeight-i-1);
300     
301      m_dataGenerator.setWeightingValues(m_weightingAttsValues);
302     
303      double [] weights = m_dataGenerator.getWeights();
304      double sumOfWeights = Utils.sum(weights);
305      int [] indices = Utils.sort(weights);
306     
307      // Prune 1% of weight mass
308      int [] newIndices = new int[indices.length];
309      double sumSoFar = 0; 
310      double criticalMass = 0.99 * sumOfWeights;
311      int index = weights.length - 1; int counter = 0;
312      for (int z = weights.length - 1; z >= 0; z--) {
313        newIndices[index--] = indices[z];
314        sumSoFar += weights[indices[z]];
315        counter++;
316        if (sumSoFar > criticalMass) {
317          break;
318        }
319      }
320      indices = new int[counter];
321      System.arraycopy(newIndices, index + 1, indices, 0, counter);
322     
323      for (int z = 0; z < m_numOfSamplesPerGenerator; z++) {
324       
325        m_dataGenerator.setWeightingValues(m_weightingAttsValues);
326        double [][] values = m_dataGenerator.generateInstances(indices);
327       
328        for (int q = 0; q < values.length; q++) {
329          if (values[q] != null) {
330            System.arraycopy(values[q], 0, m_vals, 0, m_vals.length);
331            m_vals[m_xAttribute] = m_weightingAttsValues[m_xAttribute];
332            m_vals[m_yAttribute] = m_weightingAttsValues[m_yAttribute];
333           
334            // classify the instance
335            m_dist = m_classifier.distributionForInstance(m_predInst);
336
337            for (int k = 0; k < sumOfProbsForLocation.length; k++) {
338              sumOfProbsForLocation[k] += (m_dist[k] * weights[q]); 
339            }
340          }
341        }
342      }
343     
344      for (int k = 0; k < sumOfProbsForRegion.length; k++) {
345        sumOfProbsForRegion[k] += (sumOfProbsForLocation[k] * sumOfWeights); 
346      }
347    }
348   
349    // average
350    Utils.normalize(sumOfProbsForRegion);
351
352    // cache
353    double [] tempDist = new double[sumOfProbsForRegion.length];
354    System.arraycopy(sumOfProbsForRegion, 0, tempDist, 
355                     0, sumOfProbsForRegion.length);
356               
357    return tempDist;
358  }
359
360  /**
361   * Return a random x attribute value contained within
362   * the pix'th horizontal pixel
363   *
364   * @param pix the horizontal pixel number
365   * @return a value in attribute space
366   */
367  private double getRandomX(int pix) {
368
369    double minPix =  m_minX + (pix * m_pixWidth);
370
371    return minPix + m_random.nextDouble() * m_pixWidth;
372  }
373
374  /**
375   * Return a random y attribute value contained within
376   * the pix'th vertical pixel
377   *
378   * @param pix the vertical pixel number
379   * @return a value in attribute space
380   */
381  private double getRandomY(int pix) {
382   
383    double minPix = m_minY + (pix * m_pixHeight);
384   
385    return minPix +  m_random.nextDouble() * m_pixHeight;
386  }
387 
388  /**
389   * Return status information for this sub task
390   *
391   * @return a <code>TaskStatusInfo</code> value
392   */
393  public TaskStatusInfo getTaskStatus() {
394    return m_status;
395  }
396}
Note: See TracBrowser for help on using the repository browser.