source: src/main/java/weka/estimators/UnivariateKernelEstimator.java @ 20

Last change on this file since 20 was 4, checked in by gnappo, 14 years ago

Import di weka.

File size: 14.7 KB
Line 
1/*
2 *    This program is free software; you can redistribute it and/or modify
3 *    it under the terms of the GNU General Public License as published by
4 *    the Free Software Foundation; either version 2 of the License, or
5 *    (at your option) any later version.
6 *
7 *    This program is distributed in the hope that it will be useful,
8 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
9 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
10 *    GNU General Public License for more details.
11 *
12 *    You should have received a copy of the GNU General Public License
13 *    along with this program; if not, write to the Free Software
14 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
15 */
16
17/*
18 *    UnivariateKernelEstimator.java
19 *    Copyright (C) 2009 University of Waikato, Hamilton, New Zealand
20 *
21 */
22
23package weka.estimators;
24
25import java.util.Random;
26import java.util.Collection;
27import java.util.Set;
28import java.util.Map;
29import java.util.Iterator;
30import java.util.TreeMap;
31import java.util.ArrayList;
32
33import weka.core.Statistics;
34import weka.core.Utils;
35
36/**
37 * Simple weighted kernel density estimator.
38 *
39 * @author Eibe Frank (eibe@cs.waikato.ac.nz)
40 * @version $Revision: 5680 $
41 */
42public class UnivariateKernelEstimator implements UnivariateDensityEstimator,
43                                                  UnivariateIntervalEstimator {
44
45  /** The collection used to store the weighted values. */
46  protected TreeMap<Double, Double> m_TM = new TreeMap<Double, Double>();
47
48  /** The weighted sum of values */
49  protected double m_WeightedSum = 0;
50
51  /** The weighted sum of squared values */
52  protected double m_WeightedSumSquared = 0;
53
54  /** The weight of the values collected so far */
55  protected double m_SumOfWeights = 0;
56
57  /** The current bandwidth (only computed when needed) */
58  protected double m_Width = Double.MAX_VALUE;
59
60  /** The exponent to use in computation of bandwidth (default: -0.25) */
61  protected double m_Exponent = -0.25;
62
63  /** The minimum allowed value of the kernel width (default: 1.0E-6) */
64  protected double m_MinWidth = 1.0E-6;
65
66  /** Constant for Gaussian density. */
67  public static final double CONST = - 0.5 * Math.log(2 * Math.PI);
68
69  /** Threshold at which further kernels are no longer added to sum. */
70  protected double m_Threshold = 1.0E-6;
71
72  /** The number of intervals used to approximate prediction interval. */
73  protected int m_NumIntervals = 1000;
74
75  /**
76   * Adds a value to the density estimator.
77   *
78   * @param value the value to add
79   * @param weight the weight of the value
80   */
81  public void addValue(double value, double weight) {
82
83    m_WeightedSum += value * weight;
84    m_WeightedSumSquared += value * value * weight;
85    m_SumOfWeights += weight;
86    if (m_TM.get(value) == null) {
87      m_TM.put(value, weight);
88    } else {
89      m_TM.put(value, m_TM.get(value) + weight);
90    }
91  }
92
93  /**
94   * Updates bandwidth: the sample standard deviation is multiplied by
95   * the total weight to the power of the given exponent.
96   *
97   * If the total weight is not greater than zero, the width is set to
98   * Double.MAX_VALUE. If that is not the case, but the width becomes
99   * smaller than m_MinWidth, the width is set to the value of
100   * m_MinWidth.
101   */
102  public void updateWidth() {
103
104    // OK, need to do some work
105    if (m_SumOfWeights > 0) {
106
107      // Compute variance for scaling
108      double mean = m_WeightedSum / m_SumOfWeights;
109      double variance = m_WeightedSumSquared / m_SumOfWeights - mean * mean;
110      if (variance < 0) {
111        variance = 0;
112      }
113
114      // Compute kernel bandwidth
115      m_Width = Math.sqrt(variance) * Math.pow(m_SumOfWeights, m_Exponent);
116
117      if (m_Width <= m_MinWidth) {
118        m_Width = m_MinWidth;
119      }
120    } else {
121      m_Width = Double.MAX_VALUE;
122    }
123  }
124   
125
126  /**
127   * Returns the interval for the given confidence value.
128   *
129   * @param conf the confidence value in the interval [0, 1]
130   * @return the interval
131   */
132  public double[][] predictIntervals(double conf) {
133
134    // Update the bandwidth
135    updateWidth();
136
137    // Compute minimum and maximum value, and delta
138    double val = Statistics.normalInverse(1.0 - (1.0 - conf) / 2);
139    double min = m_TM.firstKey() - val * m_Width;
140    double max = m_TM.lastKey() + val * m_Width;
141    double delta = (max - min) / m_NumIntervals;
142
143    // Create array with estimated probabilities
144    double[] probabilities = new double[m_NumIntervals];
145    double leftVal = Math.exp(logDensity(min));
146    for (int i = 0; i < m_NumIntervals; i++) {
147      double rightVal = Math.exp(logDensity(min + (i + 1) * delta));
148      probabilities[i] = 0.5 * (leftVal + rightVal) * delta;
149      leftVal = rightVal;
150    }
151
152    // Sort array based on area of bin estimates
153    int[] sortedIndices = Utils.sort(probabilities);
154
155    // Mark the intervals to use
156    double sum = 0;
157    boolean[] toUse = new boolean[probabilities.length];
158    int k = 0;
159    while ((sum < conf) && (k < toUse.length)){
160      toUse[sortedIndices[toUse.length - (k + 1)]] = true;
161      sum += probabilities[sortedIndices[toUse.length - (k + 1)]];
162      k++;
163    }
164
165    // Don't need probabilities anymore
166    probabilities = null;
167
168    // Create final list of intervals
169    ArrayList<double[]> intervals = new ArrayList<double[]>();
170
171    // The current interval
172    double[] interval = null;
173   
174    // Iterate through kernels
175    boolean haveStartedInterval = false;
176    for (int i = 0; i < m_NumIntervals; i++) {
177
178      // Should the current bin be used?
179      if (toUse[i]) {
180
181        // Do we need to create a new interval?
182        if (haveStartedInterval == false) {
183          haveStartedInterval = true;
184          interval = new double[2];
185          interval[0] = min + i * delta;
186        }
187
188        // Regardless, we should update the upper boundary
189        interval[1] = min + (i + 1) * delta;
190      } else {
191
192        // We need to finalize and store the last interval
193        // if necessary.
194        if (haveStartedInterval) {
195          haveStartedInterval = false;
196          intervals.add(interval);
197        }
198      }
199    }
200
201    // Add last interval if there is one
202    if (haveStartedInterval) {
203      intervals.add(interval);
204    }
205
206    return intervals.toArray(new double[0][0]);
207  }
208
209  /**
210   * Computes the logarithm of x and y given the logarithms of x and y.
211   *
212   * This is based on Tobias P. Mann's description in "Numerically
213   * Stable Hidden Markov Implementation" (2006).
214   */
215  protected double logOfSum(double logOfX, double logOfY) {
216
217    // Check for cases where log of zero is present
218    if (Double.isNaN(logOfX)) {
219      return logOfY;
220    } 
221    if (Double.isNaN(logOfY)) {
222      return logOfX;
223    }
224
225    // Otherwise return proper result, taken care of overflows
226    if (logOfX > logOfY) {
227      return logOfX + Math.log(1 + Math.exp(logOfY - logOfX));
228    } else {
229      return logOfY + Math.log(1 + Math.exp(logOfX - logOfY));
230    }
231  }
232
233  /**
234   * Compute running sum of density values and weights.
235   */
236  protected void runningSum(Set<Map.Entry<Double,Double>> c, double value, 
237                            double[] sums) {
238
239    // Auxiliary variables
240    double offset = CONST - Math.log(m_Width);
241    double logFactor = Math.log(m_Threshold) - Math.log(1 - m_Threshold);
242    double logSumOfWeights = Math.log(m_SumOfWeights);
243
244    // Iterate through values
245    Iterator<Map.Entry<Double,Double>> itr = c.iterator();
246    while(itr.hasNext()) {
247      Map.Entry<Double,Double> entry = itr.next();
248
249      // Skip entry if weight is zero because it cannot contribute to sum
250      if (entry.getValue() > 0) {
251        double diff = (entry.getKey() - value) / m_Width;
252        double logDensity = offset - 0.5 * diff * diff;
253        double logWeight = Math.log(entry.getValue());
254        sums[0] = logOfSum(sums[0], logWeight + logDensity);
255        sums[1] = logOfSum(sums[1], logWeight);
256
257        // Can we stop assuming worst case?
258        if (logDensity + logSumOfWeights < logOfSum(logFactor + sums[0], logDensity + sums[1])) {
259          break;
260        }
261      }
262    }
263  }
264
265  /**
266   * Returns the natural logarithm of the density estimate at the given
267   * point.
268   *
269   * @param value the value at which to evaluate
270   * @return the natural logarithm of the density estimate at the given
271   * value
272   */
273  public double logDensity(double value) {
274
275    // Update the bandwidth
276    updateWidth();
277
278    // Array used to keep running sums
279    double[] sums = new double[2];
280    sums[0] = Double.NaN;
281    sums[1] = Double.NaN;
282
283    // Examine right-hand size of value
284    runningSum(m_TM.tailMap(value, true).entrySet(), value, sums);
285
286    // Examine left-hand size of value
287    runningSum(m_TM.headMap(value, false).descendingMap().entrySet(), value, sums);
288
289    // Need to normalize
290    return sums[0] - Math.log(m_SumOfWeights);
291  }
292
293  /**
294   * Returns textual description of this estimator.
295   */
296  public String toString() {
297
298    return "Kernel estimator with bandwidth " + m_Width + 
299      " and total weight " + m_SumOfWeights +
300      " based on\n" + m_TM.toString();
301  }
302
303  /**
304   * Main method, used for testing this class.
305   */
306  public static void main(String[] args) {
307
308    // Get random number generator initialized by system
309    Random r = new Random();
310
311    // Create density estimator
312    UnivariateKernelEstimator e = new UnivariateKernelEstimator();
313
314    // Output the density estimator
315    System.out.println(e);
316   
317    // Monte Carlo integration
318    double sum = 0;
319    for (int i = 0; i < 1000; i++) {
320      sum += Math.exp(e.logDensity(r.nextDouble() * 10.0 - 5.0));
321    }
322    System.out.println("Approximate integral: " + 10.0 * sum / 1000);
323   
324    // Add Gaussian values into it
325    for (int i = 0; i < 1000; i++) {
326      e.addValue(0.1 * r.nextGaussian() - 3, 1);
327      e.addValue(r.nextGaussian() * 0.25, 3);
328    }
329
330    // Monte Carlo integration
331    sum = 0;
332    int points = 10000;
333    for (int i = 0; i < points; i++) {
334      double value = r.nextDouble() * 10.0 - 5.0;
335      sum += Math.exp(e.logDensity(value));
336    }
337    System.out.println("Approximate integral: " + 10.0 * sum / points);
338
339    // Check interval estimates
340    double[][] Intervals = e.predictIntervals(0.9);
341   
342    System.out.println("Printing kernel intervals ---------------------");
343   
344    for (int k = 0; k < Intervals.length; k++) {
345      System.out.println("Left: " + Intervals[k][0] + "\t Right: " + Intervals[k][1]);
346    }
347   
348    System.out.println("Finished kernel printing intervals ---------------------");
349
350    double Covered = 0;
351    for (int i = 0; i < 1000; i++) {
352      double val = -1;
353      if (r.nextDouble() < 0.25) {
354        val = 0.1 * r.nextGaussian() - 3.0;
355      } else {
356        val = r.nextGaussian() * 0.25;
357      }
358      for (int k = 0; k < Intervals.length; k++) {
359        if (val >= Intervals[k][0] && val <= Intervals[k][1]) {
360          Covered++;
361          break;
362        }
363      }
364    }
365    System.out.println("Coverage at 0.9 level for kernel intervals: " + Covered / 1000);
366
367    // Compare performance to normal estimator on normally distributed data
368    UnivariateKernelEstimator eKernel = new UnivariateKernelEstimator();
369    UnivariateNormalEstimator eNormal = new UnivariateNormalEstimator();
370
371    for (int j = 1; j < 5; j++) {
372      double numTrain = Math.pow(10, j);
373      System.out.println("Number of training cases: " +
374                         numTrain); 
375
376      // Add training cases
377      for (int i = 0; i < numTrain; i++) {
378        double val = r.nextGaussian() * 1.5 + 0.5;
379        eKernel.addValue(val, 1);
380        eNormal.addValue(val, 1);
381      }
382
383      // Monte Carlo integration
384      sum = 0;
385      points = 10000;
386      for (int i = 0; i < points; i++) {
387        double value = r.nextDouble() * 20.0 - 10.0;
388        sum += Math.exp(eKernel.logDensity(value));
389      }
390      System.out.println("Approximate integral for kernel estimator: " + 20.0 * sum / points);
391
392      // Evaluate estimators
393      double loglikelihoodKernel = 0, loglikelihoodNormal = 0;
394      for (int i = 0; i < 1000; i++) {
395        double val = r.nextGaussian() * 1.5 + 0.5;
396        loglikelihoodKernel += eKernel.logDensity(val);
397        loglikelihoodNormal += eNormal.logDensity(val);
398      }
399      System.out.println("Loglikelihood for kernel estimator: " +
400                         loglikelihoodKernel / 1000);
401      System.out.println("Loglikelihood for normal estimator: " +
402                         loglikelihoodNormal / 1000);
403
404      // Check interval estimates
405      double[][] kernelIntervals = eKernel.predictIntervals(0.95);
406      double[][] normalIntervals = eNormal.predictIntervals(0.95);
407
408      System.out.println("Printing kernel intervals ---------------------");
409     
410      for (int k = 0; k < kernelIntervals.length; k++) {
411        System.out.println("Left: " + kernelIntervals[k][0] + "\t Right: " + kernelIntervals[k][1]);
412      }
413
414      System.out.println("Finished kernel printing intervals ---------------------");
415
416      System.out.println("Printing normal intervals ---------------------");
417     
418      for (int k = 0; k < normalIntervals.length; k++) {
419        System.out.println("Left: " + normalIntervals[k][0] + "\t Right: " + normalIntervals[k][1]);
420      }
421
422      System.out.println("Finished normal printing intervals ---------------------");
423 
424      double kernelCovered = 0;
425      double normalCovered = 0;
426      for (int i = 0; i < 1000; i++) {
427        double val = r.nextGaussian() * 1.5 + 0.5;
428        for (int k = 0; k < kernelIntervals.length; k++) {
429          if (val >= kernelIntervals[k][0] && val <= kernelIntervals[k][1]) {
430            kernelCovered++;
431            break;
432          }
433        }
434        for (int k = 0; k < normalIntervals.length; k++) {
435          if (val >= normalIntervals[k][0] && val <= normalIntervals[k][1]) {
436            normalCovered++;
437            break;
438          }
439        }
440      }
441      System.out.println("Coverage at 0.95 level for kernel intervals: " + kernelCovered / 1000);
442      System.out.println("Coverage at 0.95 level for normal intervals: " + normalCovered / 1000);
443     
444      kernelIntervals = eKernel.predictIntervals(0.8);
445      normalIntervals = eNormal.predictIntervals(0.8);
446      kernelCovered = 0;
447      normalCovered = 0;
448      for (int i = 0; i < 1000; i++) {
449        double val = r.nextGaussian() * 1.5 + 0.5;
450        for (int k = 0; k < kernelIntervals.length; k++) {
451          if (val >= kernelIntervals[k][0] && val <= kernelIntervals[k][1]) {
452            kernelCovered++;
453            break;
454          }
455        }
456        for (int k = 0; k < normalIntervals.length; k++) {
457          if (val >= normalIntervals[k][0] && val <= normalIntervals[k][1]) {
458            normalCovered++;
459            break;
460          }
461        }
462      }
463      System.out.println("Coverage at 0.8 level for kernel intervals: " + kernelCovered / 1000);
464      System.out.println("Coverage at 0.8 level for normal intervals: " + normalCovered / 1000);
465    }
466  }
467}
Note: See TracBrowser for help on using the repository browser.