source: src/main/java/weka/estimators/KernelEstimator.java @ 6

Last change on this file since 6 was 4, checked in by gnappo, 14 years ago

Import di weka.

File size: 10.3 KB
Line 
1/*
2 *    This program is free software; you can redistribute it and/or modify
3 *    it under the terms of the GNU General Public License as published by
4 *    the Free Software Foundation; either version 2 of the License, or
5 *    (at your option) any later version.
6 *
7 *    This program is distributed in the hope that it will be useful,
8 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
9 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
10 *    GNU General Public License for more details.
11 *
12 *    You should have received a copy of the GNU General Public License
13 *    along with this program; if not, write to the Free Software
14 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
15 */
16
17/*
18 *    KernelEstimator.java
19 *    Copyright (C) 1999 University of Waikato, Hamilton, New Zealand
20 *
21 */
22
23package weka.estimators;
24
25import weka.core.Capabilities.Capability;
26import weka.core.Capabilities;
27import weka.core.RevisionUtils;
28import weka.core.Utils;
29import weka.core.Statistics;
30
31/**
32 * Simple kernel density estimator. Uses one gaussian kernel per observed
33 * data value.
34 *
35 * @author Len Trigg (trigg@cs.waikato.ac.nz)
36 * @version $Revision: 5490 $
37 */
38public class KernelEstimator extends Estimator implements IncrementalEstimator {
39
40  /** for serialization */
41  private static final long serialVersionUID = 3646923563367683925L;
42
43  /** Vector containing all of the values seen */
44  private double [] m_Values;
45
46  /** Vector containing the associated weights */
47  private double [] m_Weights;
48
49  /** Number of values stored in m_Weights and m_Values so far */
50  private int m_NumValues;
51
52  /** The sum of the weights so far */
53  private double m_SumOfWeights;
54
55  /** The standard deviation */
56  private double m_StandardDev;
57
58  /** The precision of data values */
59  private double m_Precision;
60
61  /** Whether we can optimise the kernel summation */
62  private boolean m_AllWeightsOne;
63
64  /** Maximum percentage error permitted in probability calculations */
65  private static double MAX_ERROR = 0.01;
66
67
68  /**
69   * Execute a binary search to locate the nearest data value
70   *
71   * @param the data value to locate
72   * @return the index of the nearest data value
73   */
74  private int findNearestValue(double key) {
75
76    int low = 0; 
77    int high = m_NumValues;
78    int middle = 0;
79    while (low < high) {
80      middle = (low + high) / 2;
81      double current = m_Values[middle];
82      if (current == key) {
83        return middle;
84      }
85      if (current > key) {
86        high = middle;
87      } else if (current < key) {
88        low = middle + 1;
89      }
90    }
91    return low;
92  }
93
94  /**
95   * Round a data value using the defined precision for this estimator
96   *
97   * @param data the value to round
98   * @return the rounded data value
99   */
100  private double round(double data) {
101
102    return Math.rint(data / m_Precision) * m_Precision;
103  }
104 
105  // ===============
106  // Public methods.
107  // ===============
108 
109  /**
110   * Constructor that takes a precision argument.
111   *
112   * @param precision the  precision to which numeric values are given. For
113   * example, if the precision is stated to be 0.1, the values in the
114   * interval (0.25,0.35] are all treated as 0.3.
115   */
116  public KernelEstimator(double precision) {
117
118    m_Values = new double [50];
119    m_Weights = new double [50];
120    m_NumValues = 0;
121    m_SumOfWeights = 0;
122    m_AllWeightsOne = true;
123    m_Precision = precision;
124    // precision cannot be zero
125    if (m_Precision < Utils.SMALL) m_Precision = Utils.SMALL;
126    //    m_StandardDev = 1e10 * m_Precision; // Set the standard deviation initially very wide
127    m_StandardDev = m_Precision / (2 * 3);
128  }
129
130  /**
131   * Add a new data value to the current estimator.
132   *
133   * @param data the new data value
134   * @param weight the weight assigned to the data value
135   */
136  public void addValue(double data, double weight) {
137   
138    if (weight == 0) {
139      return;
140    }
141    data = round(data);
142    int insertIndex = findNearestValue(data);
143    if ((m_NumValues <= insertIndex) || (m_Values[insertIndex] != data)) {
144      if (m_NumValues < m_Values.length) {
145        int left = m_NumValues - insertIndex; 
146        System.arraycopy(m_Values, insertIndex, 
147            m_Values, insertIndex + 1, left);
148        System.arraycopy(m_Weights, insertIndex, 
149            m_Weights, insertIndex + 1, left);
150       
151        m_Values[insertIndex] = data;
152        m_Weights[insertIndex] = weight;
153        m_NumValues++;
154      } else {
155        double [] newValues = new double [m_Values.length * 2];
156        double [] newWeights = new double [m_Values.length * 2];
157        int left = m_NumValues - insertIndex; 
158        System.arraycopy(m_Values, 0, newValues, 0, insertIndex);
159        System.arraycopy(m_Weights, 0, newWeights, 0, insertIndex);
160        newValues[insertIndex] = data;
161        newWeights[insertIndex] = weight;
162        System.arraycopy(m_Values, insertIndex, 
163            newValues, insertIndex + 1, left);
164        System.arraycopy(m_Weights, insertIndex, 
165            newWeights, insertIndex + 1, left);
166        m_NumValues++;
167        m_Values = newValues;
168        m_Weights = newWeights;
169      }
170      if (weight != 1) {
171        m_AllWeightsOne = false;
172      }
173    } else {
174      m_Weights[insertIndex] += weight;
175      m_AllWeightsOne = false;     
176    }
177    m_SumOfWeights += weight;
178    double range = m_Values[m_NumValues - 1] - m_Values[0];
179    if (range > 0) {
180      m_StandardDev = Math.max(range / Math.sqrt(m_SumOfWeights), 
181          // allow at most 3 sds within one interval
182          m_Precision / (2 * 3));
183    }
184  }
185 
186  /**
187   * Get a probability estimate for a value.
188   *
189   * @param data the value to estimate the probability of
190   * @return the estimated probability of the supplied value
191   */
192  public double getProbability(double data) {
193
194    double delta = 0, sum = 0, currentProb = 0;
195    double zLower = 0, zUpper = 0;
196    if (m_NumValues == 0) {
197      zLower = (data - (m_Precision / 2)) / m_StandardDev;
198      zUpper = (data + (m_Precision / 2)) / m_StandardDev;
199      return (Statistics.normalProbability(zUpper)
200              - Statistics.normalProbability(zLower));
201    }
202    double weightSum = 0;
203    int start = findNearestValue(data);
204    for (int i = start; i < m_NumValues; i++) {
205      delta = m_Values[i] - data;
206      zLower = (delta - (m_Precision / 2)) / m_StandardDev;
207      zUpper = (delta + (m_Precision / 2)) / m_StandardDev;
208      currentProb = (Statistics.normalProbability(zUpper)
209                     - Statistics.normalProbability(zLower));
210      sum += currentProb * m_Weights[i];
211      /*
212      System.out.print("zL" + (i + 1) + ": " + zLower + " ");
213      System.out.print("zU" + (i + 1) + ": " + zUpper + " ");
214      System.out.print("P" + (i + 1) + ": " + currentProb + " ");
215      System.out.println("total: " + (currentProb * m_Weights[i]) + " ");
216      */
217      weightSum += m_Weights[i];
218      if (currentProb * (m_SumOfWeights - weightSum) < sum * MAX_ERROR) {
219        break;
220      }
221    }
222    for (int i = start - 1; i >= 0; i--) {
223      delta = m_Values[i] - data;
224      zLower = (delta - (m_Precision / 2)) / m_StandardDev;
225      zUpper = (delta + (m_Precision / 2)) / m_StandardDev;
226      currentProb = (Statistics.normalProbability(zUpper)
227                     - Statistics.normalProbability(zLower));
228      sum += currentProb * m_Weights[i];
229      weightSum += m_Weights[i];
230      if (currentProb * (m_SumOfWeights - weightSum) < sum * MAX_ERROR) {
231        break;
232      }
233    }
234    return sum / m_SumOfWeights;
235  }
236
237  /** Display a representation of this estimator */
238  public String toString() {
239
240    String result = m_NumValues + " Normal Kernels. \nStandardDev = " 
241      + Utils.doubleToString(m_StandardDev,6,4)
242      + " Precision = " + m_Precision;
243    if (m_NumValues == 0) {
244      result += "  \nMean = 0";
245    } else {
246      result += "  \nMeans =";
247      for (int i = 0; i < m_NumValues; i++) {
248        result += " " + m_Values[i];
249      }
250      if (!m_AllWeightsOne) {
251        result += "\nWeights = ";
252        for (int i = 0; i < m_NumValues; i++) {
253          result += " " + m_Weights[i];
254        }
255      }
256    }
257    return result + "\n";
258  }
259
260  /**
261   * Return the number of kernels in this kernel estimator
262   *
263   * @return the number of kernels
264   */
265  public int getNumKernels() {
266    return m_NumValues;
267  }
268
269  /**
270   * Return the means of the kernels.
271   *
272   * @return the means of the kernels
273   */
274  public double[] getMeans() {
275    return m_Values;
276  }
277
278  /**
279   * Return the weights of the kernels.
280   *
281   * @return the weights of the kernels
282   */
283  public double[] getWeights() {
284    return m_Weights;
285  }
286
287  /**
288   * Return the precision of this kernel estimator.
289   *
290   * @return the precision
291   */
292  public double getPrecision() {
293    return m_Precision;
294  }
295
296  /**
297   * Return the standard deviation of this kernel estimator.
298   *
299   * @return the standard deviation
300   */
301  public double getStdDev() {
302    return m_StandardDev;
303  }
304
305  /**
306   * Returns default capabilities of the classifier.
307   *
308   * @return      the capabilities of this classifier
309   */
310  public Capabilities getCapabilities() {
311    Capabilities result = super.getCapabilities();
312    result.disableAll();
313    // class
314    if (!m_noClass) {
315      result.enable(Capability.NOMINAL_CLASS);
316      result.enable(Capability.MISSING_CLASS_VALUES);
317    } else {
318      result.enable(Capability.NO_CLASS);
319    }
320   
321    // attributes
322    result.enable(Capability.NUMERIC_ATTRIBUTES);
323    return result;
324  }
325 
326  /**
327   * Returns the revision string.
328   *
329   * @return            the revision
330   */
331  public String getRevision() {
332    return RevisionUtils.extract("$Revision: 5490 $");
333  }
334
335  /**
336   * Main method for testing this class.
337   *
338   * @param argv should contain a sequence of numeric values
339   */
340  public static void main(String [] argv) {
341
342    try {
343      if (argv.length < 2) {
344        System.out.println("Please specify a set of instances.");
345        return;
346      }
347      KernelEstimator newEst = new KernelEstimator(0.01);
348      for (int i = 0; i < argv.length - 3; i += 2) {
349        newEst.addValue(Double.valueOf(argv[i]).doubleValue(), 
350                        Double.valueOf(argv[i + 1]).doubleValue());
351      }
352      System.out.println(newEst);
353
354      double start = Double.valueOf(argv[argv.length - 2]).doubleValue();
355      double finish = Double.valueOf(argv[argv.length - 1]).doubleValue();
356      for (double current = start; current < finish; 
357          current += (finish - start) / 50) {
358        System.out.println("Data: " + current + " " 
359                           + newEst.getProbability(current));
360      }
361    } catch (Exception e) {
362      System.out.println(e.getMessage());
363    }
364  }
365}
Note: See TracBrowser for help on using the repository browser.