source: branches/MetisMQI/src/main/java/weka/gui/boundaryvisualizer/KDDataGenerator.java

Last change on this file was 29, checked in by gnappo, 15 years ago

Taggata versione per la demo e aggiunto branch.

File size: 12.3 KB
Line 
1/*
2 *    This program is free software; you can redistribute it and/or modify
3 *    it under the terms of the GNU General Public License as published by
4 *    the Free Software Foundation; either version 2 of the License, or
5 *    (at your option) any later version.
6 *
7 *    This program is distributed in the hope that it will be useful,
8 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
9 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
10 *    GNU General Public License for more details.
11 *
12 *    You should have received a copy of the GNU General Public License
13 *    along with this program; if not, write to the Free Software
14 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
15 */
16
17/*
18 *   KDDataGenerator.java
19 *   Copyright (C) 2002 University of Waikato, Hamilton, New Zealand
20 *
21 */
22
23package weka.gui.boundaryvisualizer;
24
25import weka.core.Attribute;
26import weka.core.Instance;
27import weka.core.Instances;
28import weka.core.Utils;
29
30import java.io.Serializable;
31import java.util.Random;
32
33/**
34 * KDDataGenerator. Class that uses kernels to generate new random
35 * instances based on a supplied set of instances.
36 *
37 * @author <a href="mailto:mhall@cs.waikato.ac.nz">Mark Hall</a>
38 * @version $Revision: 1.7 $
39 * @since 1.0
40 * @see DataGenerator
41 * @see Serializable
42 */
43public class KDDataGenerator
44  implements DataGenerator, Serializable {
45
46  /** for serialization */
47  private static final long serialVersionUID = -958573275606402792L;
48
49  /** the instances to use */
50  private Instances m_instances;
51
52  /** standard deviations of the normal distributions for numeric attributes in
53   * each KD estimator */
54  private double [] m_standardDeviations;
55
56  /** global means or modes to use for missing values */
57  private double [] m_globalMeansOrModes;
58
59  /** minimum standard deviation for numeric attributes */
60  private double m_minStdDev = 1e-5;
61
62  /** Laplace correction for discrete distributions */
63  private double m_laplaceConst = 1.0;
64
65  /** random number seed */
66  private int m_seed = 1;
67
68  /** random number generator */
69  private Random m_random;
70
71  /** which dimensions to use for computing a weight for each generated
72   * instance */
73  private boolean [] m_weightingDimensions;
74 
75  /** the values for the weighting dimensions to use for computing the weight
76   * for the next instance to be generated */
77  private double [] m_weightingValues;
78
79  private static double m_normConst = Math.sqrt(2*Math.PI);
80
81  /** Number of neighbours to use for kernel bandwidth */
82  private int m_kernelBandwidth = 3;
83
84  /** standard deviations for numeric attributes computed from the
85   * m_kernelBandwidth nearest neighbours for each kernel. */
86  private double [][] m_kernelParams;
87
88  /** The minimum values for numeric attributes. */
89  protected double [] m_Min;
90 
91  /** The maximum values for numeric attributes. */
92  protected double [] m_Max;
93
94  /**
95   * Initialize the generator using the supplied instances
96   *
97   * @param inputInstances the instances to use as the basis of the kernels
98   * @throws Exception if an error occurs
99   */
100  public void buildGenerator(Instances inputInstances) throws Exception {
101    m_random = new Random(m_seed);
102   
103    m_instances = inputInstances;
104    m_standardDeviations = new double [m_instances.numAttributes()];
105    m_globalMeansOrModes = new double [m_instances.numAttributes()];
106    if (m_weightingDimensions == null) {
107      m_weightingDimensions = new boolean[m_instances.numAttributes()];
108    }
109    /*    for (int i = 0; i < m_instances.numAttributes(); i++) {
110      if (i != m_instances.classIndex()) {
111        if (m_instances.attribute(i).isNumeric()) {
112          // global standard deviations
113          double var = m_instances.variance(i);
114          if (var == 0) {
115            var = m_minStdDev;
116          } else {
117            var = Math.sqrt(var);
118            //  heuristic to take into account # instances and dimensions
119            double adjust = Math.pow((double) m_instances.numInstances(),
120                                     1.0 / m_instances.numAttributes());
121            //    double adjust = m_instances.numInstances();
122            var /= adjust;
123          }
124          m_standardDeviations[i] = var;
125        } else {
126          m_globalMeansOrModes[i] = m_instances.meanOrMode(i);
127        }
128      }
129      } */
130    for (int i = 0; i < m_instances.numAttributes(); i++) {
131      if (i != m_instances.classIndex()) {
132        m_globalMeansOrModes[i] = m_instances.meanOrMode(i);
133      }
134    }
135
136    m_kernelParams = 
137      new double [m_instances.numInstances()][m_instances.numAttributes()];
138    computeParams();
139  }
140
141  public double [] getWeights() {
142
143    double [] weights = new double[m_instances.numInstances()];
144
145    for (int k = 0; k < m_instances.numInstances(); k++) {
146      double weight = 1;
147      for (int i = 0; i < m_instances.numAttributes(); i++) {
148        if (m_weightingDimensions[i]) {
149          double mean = 0;
150          if (!m_instances.instance(k).isMissing(i)) {
151            mean = m_instances.instance(k).value(i);
152          } else {
153            mean = m_globalMeansOrModes[i];
154          }
155          double wm = 1.0;
156         
157          //        wm = normalDens(m_weightingValues[i], mean, m_standardDeviations[i]);
158          wm = normalDens(m_weightingValues[i], mean, 
159                          m_kernelParams[k][i]);
160         
161          weight *= wm;
162        }
163      }
164      weights[k] = weight;
165    }
166    return weights;
167  }
168
169  /**
170   * Return a cumulative distribution from a discrete distribution
171   *
172   * @param dist the distribution to use
173   * @return the cumulative distribution
174   */
175  private double [] computeCumulativeDistribution(double [] dist) {
176
177    double [] cumDist = new double[dist.length];
178    double sum = 0;
179    for (int i = 0; i < dist.length; i++) {
180      sum += dist[i];
181      cumDist[i] = sum;
182    }
183   
184    return cumDist;
185  }
186
187  /**
188   * Generates a new instance using one kernel estimator. Each successive
189   * call to this method incremets the index of the kernel to use.
190   *
191   * @return the new random instance
192   * @throws Exception if an error occurs
193   */
194  public double [][] generateInstances(int [] indices) throws Exception {
195   
196    double [][] values = new double[m_instances.numInstances()][];
197
198    for (int k = 0; k < indices.length; k++) {
199      values[indices[k]] = new double[m_instances.numAttributes()];
200      for (int i = 0; i < m_instances.numAttributes(); i++) {
201        if ((!m_weightingDimensions[i]) && (i != m_instances.classIndex())) {
202          if (m_instances.attribute(i).isNumeric()) {
203            double mean = 0;
204            double val = m_random.nextGaussian();
205            if (!m_instances.instance(indices[k]).isMissing(i)) {
206              mean = m_instances.instance(indices[k]).value(i);
207            } else {
208              mean = m_globalMeansOrModes[i];
209            }
210           
211            val *= m_kernelParams[indices[k]][i];
212            val += mean;
213
214            values[indices[k]][i] = val;
215          } else {
216            // nominal attribute
217            double [] dist = new double[m_instances.attribute(i).numValues()];
218            for (int j = 0; j < dist.length; j++) {
219              dist[j] = m_laplaceConst;
220            }
221            if (!m_instances.instance(indices[k]).isMissing(i)) {
222              dist[(int)m_instances.instance(indices[k]).value(i)]++;
223            } else {
224              dist[(int)m_globalMeansOrModes[i]]++;
225            }
226            Utils.normalize(dist);
227            double [] cumDist = computeCumulativeDistribution(dist);
228            double randomVal = m_random.nextDouble();
229            int instVal = 0;
230            for (int j = 0; j < cumDist.length; j++) {
231              if (randomVal <= cumDist[j]) {
232                instVal = j;
233                break;
234              }
235            }
236            values[indices[k]][i] = (double)instVal;
237          }
238        }
239      }
240    }
241    return values;
242  }
243
244  /**
245   * Density function of normal distribution.
246   * @param x input value
247   * @param mean mean of distribution
248   * @param stdDev standard deviation of distribution
249   */
250  private double normalDens (double x, double mean, double stdDev) {
251    double diff = x - mean;
252   
253    return  (1/(m_normConst*stdDev))*Math.exp(-(diff*diff/(2*stdDev*stdDev)));
254  }
255
256  /**
257   * Set which dimensions to use when computing a weight for the next
258   * instance to generate
259   *
260   * @param dims an array of booleans indicating which dimensions to use
261   */
262  public void setWeightingDimensions(boolean [] dims) {
263    m_weightingDimensions = dims;
264  }
265
266  /**
267   * Set the values for the weighting dimensions to be used when computing
268   * the weight for the next instance to be generated
269   *
270   * @param vals an array of doubles containing the values of the
271   * weighting dimensions (corresponding to the entries that are set to
272   * true throw setWeightingDimensions)
273   */
274  public void setWeightingValues(double [] vals) {
275    m_weightingValues = vals;
276  }
277
278  /**
279   * Return the number of kernels (there is one per training instance)
280   *
281   * @return the number of kernels
282   */
283  public int getNumGeneratingModels() {
284    if (m_instances != null) {
285      return m_instances.numInstances();
286    }
287    return 0;
288  }
289
290  /**
291   * Set the kernel bandwidth (number of nearest neighbours to cover)
292   *
293   * @param kb an <code>int</code> value
294   */
295  public void setKernelBandwidth(int kb) {
296    m_kernelBandwidth = kb;
297  }
298
299  /**
300   * Get the kernel bandwidth
301   *
302   * @return an <code>int</code> value
303   */
304  public int getKernelBandwidth() {
305    return m_kernelBandwidth;
306  } 
307
308  /**
309   * Initializes a new random number generator using the
310   * supplied seed.
311   *
312   * @param seed an <code>int</code> value
313   */
314  public void setSeed(int seed) {
315    m_seed = seed;
316    m_random = new Random(m_seed);
317  }
318
319  /**
320   * Calculates the distance between two instances
321   *
322   * @param test the first instance
323   * @param train the second instance
324   * @return the distance between the two given instances, between 0 and 1
325   */         
326  private double distance(Instance first, Instance second) { 
327
328    double diff, distance = 0;
329
330    for(int i = 0; i < m_instances.numAttributes(); i++) { 
331      if (i == m_instances.classIndex()) {
332        continue;
333      }
334      double firstVal = m_globalMeansOrModes[i];
335      double secondVal = m_globalMeansOrModes[i];
336
337      switch (m_instances.attribute(i).type()) {
338      case Attribute.NUMERIC:
339        // If attribute is numeric
340        if (!first.isMissing(i)) {
341          firstVal = first.value(i);
342        }
343       
344        if (!second.isMissing(i)) {
345          secondVal = second.value(i);
346        }
347
348        diff = norm(firstVal,i) - norm(secondVal,i);
349
350        break;
351      default:
352        diff = 0;
353        break;
354      }
355      distance += diff * diff;
356    }
357    return Math.sqrt(distance);
358  }
359
360  /**
361   * Normalizes a given value of a numeric attribute.
362   *
363   * @param x the value to be normalized
364   * @param i the attribute's index
365   */
366  private double norm(double x,int i) {
367   
368    if (Double.isNaN(m_Min[i]) || Utils.eq(m_Max[i], m_Min[i])) {
369      return 0;
370    } else {
371      return (x - m_Min[i]) / (m_Max[i] - m_Min[i]);
372    }
373  }
374
375  /**
376   * Updates the minimum and maximum values for all the attributes
377   * based on a new instance.
378   *
379   * @param instance the new instance
380   */
381  private void updateMinMax(Instance instance) { 
382
383    for (int j = 0; j < m_instances.numAttributes(); j++) {
384      if (!instance.isMissing(j)) {
385        if (Double.isNaN(m_Min[j])) {
386          m_Min[j] = instance.value(j);
387          m_Max[j] = instance.value(j);
388        } else if (instance.value(j) < m_Min[j]) {
389          m_Min[j] = instance.value(j);
390        } else if (instance.value(j) > m_Max[j]) {
391          m_Max[j] = instance.value(j);
392        }
393      }
394    }
395  }
396
397  private void computeParams() throws Exception {
398    // Calculate the minimum and maximum values
399    m_Min = new double [m_instances.numAttributes()];
400    m_Max = new double [m_instances.numAttributes()];
401    for (int i = 0; i < m_instances.numAttributes(); i++) {
402      m_Min[i] = m_Max[i] = Double.NaN;
403    }
404    for (int i = 0; i < m_instances.numInstances(); i++) {
405      updateMinMax(m_instances.instance(i));
406    }
407
408    double [] distances = new double[m_instances.numInstances()];
409    for (int i = 0; i < m_instances.numInstances(); i++) {
410      Instance current = m_instances.instance(i);
411      for (int j = 0; j < m_instances.numInstances(); j++) {
412        distances[j] = distance(current, m_instances.instance(j));
413      }
414      int [] sorted = Utils.sort(distances);
415      int k = m_kernelBandwidth;
416      double bandwidth = distances[sorted[k]];
417
418      // Check for bandwidth zero
419      if (bandwidth <= 0) {
420        for (int j = k + 1; j < sorted.length; j++) {
421          if (distances[sorted[j]] > bandwidth) {
422            bandwidth = distances[sorted[j]];
423            break;
424          }
425        }
426        if (bandwidth <= 0) {
427          throw new Exception("All training instances coincide with "
428                              +"test instance!");
429        }
430      }
431      for (int j = 0; j < m_instances.numAttributes(); j++) {
432        if ((m_Max[j] - m_Min[j]) > 0) {
433          m_kernelParams[i][j] = bandwidth * (m_Max[j] - m_Min[j]);
434        }
435      }
436    }
437  }
438}
439
440
Note: See TracBrowser for help on using the repository browser.