source: src/main/java/weka/estimators/UnivariateEqualFrequencyHistogramEstimator.java @ 6

Last change on this file since 6 was 4, checked in by gnappo, 14 years ago

Import di weka.

File size: 19.6 KB
Line 
1/*
2 *    This program is free software; you can redistribute it and/or modify
3 *    it under the terms of the GNU General Public License as published by
4 *    the Free Software Foundation; either version 2 of the License, or
5 *    (at your option) any later version.
6 *
7 *    This program is distributed in the hope that it will be useful,
8 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
9 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
10 *    GNU General Public License for more details.
11 *
12 *    You should have received a copy of the GNU General Public License
13 *    along with this program; if not, write to the Free Software
14 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
15 */
16
17/*
18 *    UnivariateEqualFrequencyEstimator.java
19 *    Copyright (C) 2009 University of Waikato, Hamilton, New Zealand
20 *
21 */
22
23package weka.estimators;
24
25import java.util.Random;
26import java.util.Collection;
27import java.util.Set;
28import java.util.Map;
29import java.util.Iterator;
30import java.util.TreeMap;
31import java.util.ArrayList;
32import java.util.Arrays;
33
34import weka.core.Statistics;
35import weka.core.Utils;
36
37/**
38 * Simple histogram density estimator. Uses equal-frequency histograms
39 * based on the specified number of bins (default: 10).
40 *
41 * @author Eibe Frank (eibe@cs.waikato.ac.nz)
42 * @version $Revision: 5680 $
43 */
44public class UnivariateEqualFrequencyHistogramEstimator implements UnivariateDensityEstimator,
45                                                     UnivariateIntervalEstimator {
46
47  /** The collection used to store the weighted values. */
48  protected TreeMap<Double, Double> m_TM = new TreeMap<Double, Double>();
49
50  /** The interval boundaries. */
51  protected double[] m_Boundaries = null;
52
53  /** The weight of each interval. */
54  protected double[] m_Weights = null;
55
56  /** The weighted sum of values */
57  protected double m_WeightedSum = 0;
58
59  /** The weighted sum of squared values */
60  protected double m_WeightedSumSquared = 0;
61
62  /** The total sum of weights. */
63  protected double m_SumOfWeights = 0;
64
65  /** The number of bins to use. */
66  protected int m_NumBins = 10;
67
68  /** The current bandwidth (only computed when needed) */
69  protected double m_Width = Double.MAX_VALUE;
70
71  /** The exponent to use in computation of bandwidth (default: -0.25) */
72  protected double m_Exponent = -0.25;
73
74  /** The minimum allowed value of the kernel width (default: 1.0E-6) */
75  protected double m_MinWidth = 1.0E-6;
76
77  /** Constant for Gaussian density. */
78  public static final double CONST = - 0.5 * Math.log(2 * Math.PI);
79
80  /** The number of intervals used to approximate prediction interval. */
81  protected int m_NumIntervals = 1000;
82
83  /** Whether boundaries are updated or only weights. */
84  protected boolean m_UpdateWeightsOnly = false;
85
86  /**
87   * Gets the number of bins
88   *
89   * @return the number of bins.
90   */
91  public int getNumBins() {
92
93    return m_NumBins;
94  }
95
96  /**
97   * Sets the number of bins
98   *
99   * @param numBins the number of bins
100   */
101  public void setNumBins(int numBins) {
102
103    m_NumBins = numBins;
104  }
105
106  /**
107   * Triggers construction of estimator based on current data
108   * and then initializes the statistics.
109   */
110  public void initializeStatistics() {
111
112    updateBoundariesAndOrWeights();
113
114    m_TM = new TreeMap<Double, Double>();
115    m_WeightedSum = 0;
116    m_WeightedSumSquared = 0;
117    m_SumOfWeights = 0;
118    m_Weights = null;
119  }   
120
121  /**
122   * Sets whether only weights should be udpated.
123   */
124  public void setUpdateWeightsOnly(boolean flag) {
125
126    m_UpdateWeightsOnly = flag;
127  }
128
129  /**
130   * Gets whether only weights should be udpated.*
131   */
132  public boolean getUpdateWeightsOnly() {
133
134    return m_UpdateWeightsOnly;
135  }
136
137  /**
138   * Adds a value to the density estimator.
139   *
140   * @param value the value to add
141   * @param weight the weight of the value
142   */
143  public void addValue(double value, double weight) {
144
145    // Add data point to collection
146    m_WeightedSum += value * weight;
147    m_WeightedSumSquared += value * value * weight;
148    m_SumOfWeights += weight;
149    if (m_TM.get(value) == null) {
150      m_TM.put(value, weight);
151    } else {
152      m_TM.put(value, m_TM.get(value) + weight);
153    }
154
155    // Make sure estimator is updated
156    if (!getUpdateWeightsOnly()) {
157      m_Boundaries = null;
158    }
159    m_Weights = null;
160  }
161
162  /**
163   * Updates the boundaries if necessary.
164   */
165  protected void updateBoundariesAndOrWeights() {
166
167    // Do we need to update?
168    if (m_Weights != null) {
169      return;
170    }
171
172    // Update widths for cases that are out of bounds,
173    // using same code as in kernel estimator
174
175    // First, compute variance for scaling
176    double mean = m_WeightedSum / m_SumOfWeights;
177    double variance = m_WeightedSumSquared / m_SumOfWeights - mean * mean;
178    if (variance < 0) {
179      variance = 0;
180    }
181   
182    // Compute kernel bandwidth
183    m_Width = Math.sqrt(variance) * Math.pow(m_SumOfWeights, m_Exponent);
184   
185    if (m_Width <= m_MinWidth) {
186      m_Width = m_MinWidth;
187    }
188   
189    // Do we need to update weights only
190    if (getUpdateWeightsOnly()) {
191      updateWeightsOnly();
192    } else {
193      updateBoundariesAndWeights();
194    }
195  }
196   
197  /**
198   * Updates the weights only.
199   */
200  protected void updateWeightsOnly() throws IllegalArgumentException {
201
202    // Get values and keys from tree map
203    Iterator<Map.Entry<Double,Double>> itr = m_TM.entrySet().iterator();
204    int j = 1;
205    m_Weights = new double[m_Boundaries.length - 1];
206    while(itr.hasNext()) {
207      Map.Entry<Double,Double> entry = itr.next();
208      double value = entry.getKey();
209      double weight = entry.getValue();
210      if ((value < m_Boundaries[0]) || (value > m_Boundaries[m_Boundaries.length - 1])) {
211        throw new IllegalArgumentException("Out-of-range value during weight update");
212      }
213      while (value > m_Boundaries[j]) {
214        j++;
215      }
216      m_Weights[j - 1] += weight;
217    }
218  }
219
220  /**
221   * Updates the boundaries and weights.
222   */
223  protected void updateBoundariesAndWeights() {
224
225    // Get values and keys from tree map
226    double[] values = new double[m_TM.size()];
227    double[] weights = new double[m_TM.size()];
228    Iterator<Map.Entry<Double,Double>> itr = m_TM.entrySet().iterator();
229    int j = 0;
230    while(itr.hasNext()) {
231      Map.Entry<Double,Double> entry = itr.next();
232      values[j] = entry.getKey();
233      weights[j] = entry.getValue();
234      j++;
235    }
236
237    double freq = m_SumOfWeights / m_NumBins;
238    double[] cutPoints = new double[m_NumBins - 1];
239    double[] binWeights = new double[m_NumBins];
240    double sumOfWeights = m_SumOfWeights;
241
242    // Compute break points
243    double weightSumSoFar = 0, lastWeightSum = 0;
244    int cpindex = 0, lastIndex = -1;
245    for (int i = 0; i < values.length - 1; i++) {
246
247      // Update weight statistics
248      weightSumSoFar += weights[i];
249      sumOfWeights -= weights[i];
250
251      // Have we passed the ideal size?
252      if (weightSumSoFar >= freq) {
253
254        // Is this break point worse than the last one?
255        if (((freq - lastWeightSum) < (weightSumSoFar - freq)) && (lastIndex != -1)) {
256          cutPoints[cpindex] = (values[lastIndex] + values[lastIndex + 1]) / 2;
257          weightSumSoFar -= lastWeightSum;
258          binWeights[cpindex] = lastWeightSum;
259          lastWeightSum = weightSumSoFar;
260          lastIndex = i;
261        } else {
262          cutPoints[cpindex] = (values[i] + values[i + 1]) / 2;
263          binWeights[cpindex] = weightSumSoFar;
264          weightSumSoFar = 0;
265          lastWeightSum = 0;
266          lastIndex = -1;
267        }
268        cpindex++;
269        freq = (sumOfWeights + weightSumSoFar) / ((cutPoints.length + 1) - cpindex);
270      } else {
271        lastIndex = i;
272        lastWeightSum = weightSumSoFar;
273      }
274    }
275
276    // Check whether there was another possibility for a cut point
277    if ((cpindex < cutPoints.length) && (lastIndex != -1)) {
278      cutPoints[cpindex] = (values[lastIndex] + values[lastIndex + 1]) / 2;     
279      binWeights[cpindex] = lastWeightSum;
280      cpindex++;
281      binWeights[cpindex] = weightSumSoFar - lastWeightSum;
282    } else {
283      binWeights[cpindex] = weightSumSoFar;
284    }
285
286    // Did we find any cutpoints?
287    if (cpindex == 0) {
288      m_Boundaries = null;
289      m_Weights = null;
290    } else {
291
292      // Need to add weight of last data point to right-most bin
293      binWeights[cpindex] += weights[values.length - 1];
294
295      // Copy over boundaries and weights
296      m_Boundaries = new double[cpindex + 2];
297      m_Boundaries[0] = m_TM.firstKey();
298      m_Boundaries[cpindex + 1] = m_TM.lastKey();
299      System.arraycopy(cutPoints, 0, m_Boundaries, 1, cpindex);
300      m_Weights = new double[cpindex + 1];
301      System.arraycopy(binWeights, 0, m_Weights, 0, cpindex + 1);
302    }
303  }
304   
305
306  /**
307   * Returns the interval for the given confidence value.
308   *
309   * @param conf the confidence value in the interval [0, 1]
310   * @return the interval
311   */
312  public double[][] predictIntervals(double conf) {
313
314    // Update the bandwidth
315    updateBoundariesAndOrWeights();
316
317    // Compute minimum and maximum value, and delta
318    double val = Statistics.normalInverse(1.0 - (1.0 - conf) / 2);
319    double min = m_TM.firstKey() - val * m_Width;
320    double max = m_TM.lastKey() + val * m_Width;
321    double delta = (max - min) / m_NumIntervals;
322
323    // Create array with estimated probabilities
324    double[] probabilities = new double[m_NumIntervals];
325    double leftVal = Math.exp(logDensity(min));
326    for (int i = 0; i < m_NumIntervals; i++) {
327      double rightVal = Math.exp(logDensity(min + (i + 1) * delta));
328      probabilities[i] = 0.5 * (leftVal + rightVal) * delta;
329      leftVal = rightVal;
330    }
331
332    // Sort array based on area of bin estimates
333    int[] sortedIndices = Utils.sort(probabilities);
334
335    // Mark the intervals to use
336    double sum = 0;
337    boolean[] toUse = new boolean[probabilities.length];
338    int k = 0;
339    while ((sum < conf) && (k < toUse.length)){
340      toUse[sortedIndices[toUse.length - (k + 1)]] = true;
341      sum += probabilities[sortedIndices[toUse.length - (k + 1)]];
342      k++;
343    }
344
345    // Don't need probabilities anymore
346    probabilities = null;
347
348    // Create final list of intervals
349    ArrayList<double[]> intervals = new ArrayList<double[]>();
350
351    // The current interval
352    double[] interval = null;
353   
354    // Iterate through kernels
355    boolean haveStartedInterval = false;
356    for (int i = 0; i < m_NumIntervals; i++) {
357
358      // Should the current bin be used?
359      if (toUse[i]) {
360
361        // Do we need to create a new interval?
362        if (haveStartedInterval == false) {
363          haveStartedInterval = true;
364          interval = new double[2];
365          interval[0] = min + i * delta;
366        }
367
368        // Regardless, we should update the upper boundary
369        interval[1] = min + (i + 1) * delta;
370      } else {
371
372        // We need to finalize and store the last interval
373        // if necessary.
374        if (haveStartedInterval) {
375          haveStartedInterval = false;
376          intervals.add(interval);
377        }
378      }
379    }
380
381    // Add last interval if there is one
382    if (haveStartedInterval) {
383      intervals.add(interval);
384    }
385
386    return intervals.toArray(new double[0][0]);
387  }
388
389  /**
390   * Returns the natural logarithm of the density estimate at the given
391   * point.
392   *
393   * @param value the value at which to evaluate
394   * @return the natural logarithm of the density estimate at the given
395   * value
396   */
397  public double logDensity(double value) {
398
399    // Update boundaries if necessary
400    updateBoundariesAndOrWeights();
401
402    if (m_Boundaries == null) {
403      return Math.log(Double.MIN_VALUE);
404    }
405
406    // Find the bin
407    int index = Arrays.binarySearch(m_Boundaries, value);
408
409    // Is the value outside?
410    if ((index == -1) || (index == -m_Boundaries.length - 1)) {
411
412      // Use normal density outside
413      double val = 0;
414      if (index == -1) { // Smaller than minimum
415        val = m_TM.firstKey() - value;
416      } else {
417        val = value - m_TM.lastKey();
418      }
419      return (CONST - Math.log(m_Width) - 0.5 * (val * val / (m_Width * m_Width))) -
420        Math.log(m_SumOfWeights + 2); 
421    }
422   
423    // Is value exactly equal to right-most boundary?
424    if (index == m_Boundaries.length - 1) {
425      index--;
426    } else {
427
428      // Need to reverse index if necessary
429      if (index < 0) {
430        index = -index - 2;
431      }
432    }
433   
434    // Figure out of width
435    double width = m_Boundaries[index + 1] - m_Boundaries[index];
436
437    // Density compontent from smeared-out data point
438    double densSmearedOut = 1.0 / ((m_SumOfWeights + 2) * (m_Boundaries[m_Boundaries.length - 1] -
439                                                           m_Boundaries[0]));
440
441    // Return log of density
442    if (m_Weights[index] <= 0) {
443
444      /*      System.out.println(value);
445      System.out.println(this);
446      System.exit(1);*/
447      // Just use one smeared-out data point
448      return Math.log(densSmearedOut);
449    } else {
450      return Math.log(densSmearedOut + m_Weights[index] / ((m_SumOfWeights + 2) * width));
451    }
452  }
453
454  /**
455   * Returns textual description of this estimator.
456   */
457  public String toString() {
458
459    StringBuffer text = new StringBuffer();
460
461    text.append("EqualFrequencyHistogram estimator\n\n" +
462                "Bandwidth for out of range cases " + m_Width + 
463                ", total weight " + m_SumOfWeights);
464
465    if (m_Boundaries != null) {
466      text.append("\nLeft boundary\tRight boundary\tWeight\n");
467      for (int i = 0; i < m_Boundaries.length - 1; i++) {
468        text.append(m_Boundaries[i] + "\t" + m_Boundaries[i + 1] + "\t" + m_Weights[i] + "\t" +
469                    Math.exp(logDensity((m_Boundaries[i + 1] + m_Boundaries[i]) / 2)) + "\n");
470      }
471    }
472
473    return text.toString();
474  }
475
476  /**
477   * Main method, used for testing this class.
478   */
479  public static void main(String[] args) {
480
481    // Get random number generator initialized by system
482    Random r = new Random();
483
484    // Create density estimator
485    UnivariateEqualFrequencyHistogramEstimator e = new UnivariateEqualFrequencyHistogramEstimator();
486
487    // Output the density estimator
488    System.out.println(e);
489   
490    // Monte Carlo integration
491    double sum = 0;
492    for (int i = 0; i < 1000; i++) {
493      sum += Math.exp(e.logDensity(r.nextDouble() * 10.0 - 5.0));
494    }
495    System.out.println("Approximate integral: " + 10.0 * sum / 1000);
496   
497    // Add Gaussian values into it
498    for (int i = 0; i < 1000; i++) {
499      e.addValue(0.1 * r.nextGaussian() - 3, 1);
500      e.addValue(r.nextGaussian() * 0.25, 3);
501    }
502
503    // Monte Carlo integration
504    sum = 0;
505    int points = 10000000;
506    for (int i = 0; i < points; i++) {
507      double value = r.nextDouble() * 20.0 - 10.0;
508      sum += Math.exp(e.logDensity(value));
509    }
510
511    // Output the density estimator
512    System.out.println(e);
513
514    System.out.println("Approximate integral: " + 20.0 * sum / points);
515
516    // Check interval estimates
517    double[][] Intervals = e.predictIntervals(0.9);
518   
519    System.out.println("Printing histogram intervals ---------------------");
520   
521    for (int k = 0; k < Intervals.length; k++) {
522      System.out.println("Left: " + Intervals[k][0] + "\t Right: " + Intervals[k][1]);
523    }
524   
525    System.out.println("Finished histogram printing intervals ---------------------");
526
527    double Covered = 0;
528    for (int i = 0; i < 1000; i++) {
529      double val = -1;
530      if (r.nextDouble() < 0.25) {
531        val = 0.1 * r.nextGaussian() - 3.0;
532      } else {
533        val = r.nextGaussian() * 0.25;
534      }
535      for (int k = 0; k < Intervals.length; k++) {
536        if (val >= Intervals[k][0] && val <= Intervals[k][1]) {
537          Covered++;
538          break;
539        }
540      }
541    }
542    System.out.println("Coverage at 0.9 level for histogram intervals: " + Covered / 1000);
543
544    for (int j = 1; j < 5; j++) {
545      double numTrain = Math.pow(10, j);
546      System.out.println("Number of training cases: " +
547                         numTrain); 
548
549      // Compare performance to normal estimator on normally distributed data
550      UnivariateEqualFrequencyHistogramEstimator eHistogram = new UnivariateEqualFrequencyHistogramEstimator();
551      UnivariateNormalEstimator eNormal = new UnivariateNormalEstimator();
552     
553      // Add training cases
554      for (int i = 0; i < numTrain; i++) {
555        double val = r.nextGaussian() * 1.5 + 0.5;
556        /*        if (j == 4) {
557          System.err.println(val);
558          }*/
559        eHistogram.addValue(val, 1);
560        eNormal.addValue(val, 1);
561      }
562
563      // Monte Carlo integration
564      sum = 0;
565      points = 10000000;
566      for (int i = 0; i < points; i++) {
567        double value = r.nextDouble() * 20.0 - 10.0;
568        sum += Math.exp(eHistogram.logDensity(value));
569      }
570      System.out.println(eHistogram);
571      System.out.println("Approximate integral for histogram estimator: " + 20.0 * sum / points);
572
573      // Evaluate estimators
574      double loglikelihoodHistogram = 0, loglikelihoodNormal = 0;
575      for (int i = 0; i < 1000; i++) {
576        double val = r.nextGaussian() * 1.5 + 0.5;
577        loglikelihoodHistogram += eHistogram.logDensity(val);
578        loglikelihoodNormal += eNormal.logDensity(val);
579      }
580      System.out.println("Loglikelihood for histogram estimator: " +
581                         loglikelihoodHistogram / 1000);
582      System.out.println("Loglikelihood for normal estimator: " +
583                         loglikelihoodNormal / 1000);
584
585      // Check interval estimates
586      double[][] histogramIntervals = eHistogram.predictIntervals(0.95);
587      double[][] normalIntervals = eNormal.predictIntervals(0.95);
588
589      System.out.println("Printing histogram intervals ---------------------");
590     
591      for (int k = 0; k < histogramIntervals.length; k++) {
592        System.out.println("Left: " + histogramIntervals[k][0] + "\t Right: " + histogramIntervals[k][1]);
593      }
594
595      System.out.println("Finished histogram printing intervals ---------------------");
596
597      System.out.println("Printing normal intervals ---------------------");
598     
599      for (int k = 0; k < normalIntervals.length; k++) {
600        System.out.println("Left: " + normalIntervals[k][0] + "\t Right: " + normalIntervals[k][1]);
601      }
602
603      System.out.println("Finished normal printing intervals ---------------------");
604 
605      double histogramCovered = 0;
606      double normalCovered = 0;
607      for (int i = 0; i < 1000; i++) {
608        double val = r.nextGaussian() * 1.5 + 0.5;
609        for (int k = 0; k < histogramIntervals.length; k++) {
610          if (val >= histogramIntervals[k][0] && val <= histogramIntervals[k][1]) {
611            histogramCovered++;
612            break;
613          }
614        }
615        for (int k = 0; k < normalIntervals.length; k++) {
616          if (val >= normalIntervals[k][0] && val <= normalIntervals[k][1]) {
617            normalCovered++;
618            break;
619          }
620        }
621      }
622      System.out.println("Coverage at 0.95 level for histogram intervals: " + histogramCovered / 1000);
623      System.out.println("Coverage at 0.95 level for normal intervals: " + normalCovered / 1000);
624     
625      histogramIntervals = eHistogram.predictIntervals(0.8);
626      normalIntervals = eNormal.predictIntervals(0.8);
627      histogramCovered = 0;
628      normalCovered = 0;
629      for (int i = 0; i < 1000; i++) {
630        double val = r.nextGaussian() * 1.5 + 0.5;
631        for (int k = 0; k < histogramIntervals.length; k++) {
632          if (val >= histogramIntervals[k][0] && val <= histogramIntervals[k][1]) {
633            histogramCovered++;
634            break;
635          }
636        }
637        for (int k = 0; k < normalIntervals.length; k++) {
638          if (val >= normalIntervals[k][0] && val <= normalIntervals[k][1]) {
639            normalCovered++;
640            break;
641          }
642        }
643      }
644      System.out.println("Coverage at 0.8 level for histogram intervals: " + histogramCovered / 1000);
645      System.out.println("Coverage at 0.8 level for normal intervals: " + normalCovered / 1000);
646    }
647  }
648}
Note: See TracBrowser for help on using the repository browser.