source: src/main/java/weka/estimators/NormalEstimator.java @ 7

Last change on this file since 7 was 4, checked in by gnappo, 14 years ago

Import di weka.

File size: 6.3 KB
Line 
/*
 *    This program is free software; you can redistribute it and/or modify
 *    it under the terms of the GNU General Public License as published by
 *    the Free Software Foundation; either version 2 of the License, or
 *    (at your option) any later version.
 *
 *    This program is distributed in the hope that it will be useful,
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *    GNU General Public License for more details.
 *
 *    You should have received a copy of the GNU General Public License
 *    along with this program; if not, write to the Free Software
 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 */

/*
 *    NormalEstimator.java
 *    Copyright (C) 1999 University of Waikato, Hamilton, New Zealand
 *
 */
22
23package weka.estimators;
24
25import weka.core.Capabilities.Capability;
26import weka.core.Capabilities;
27import weka.core.RevisionUtils;
28import weka.core.Statistics;
29import weka.core.Utils;
30
31/**
32 * Simple probability estimator that places a single normal distribution
33 * over the observed values.
34 *
35 * @author Len Trigg (trigg@cs.waikato.ac.nz)
36 * @version $Revision: 5490 $
37 */
38public class NormalEstimator
39  extends Estimator
40  implements IncrementalEstimator {
41
42  /** for serialization */
43  private static final long serialVersionUID = 93584379632315841L;
44
45  /** The sum of the weights */
46  private double m_SumOfWeights;
47
48  /** The sum of the values seen */
49  private double m_SumOfValues;
50
51  /** The sum of the values squared */
52  private double m_SumOfValuesSq;
53
54  /** The current mean */
55  private double m_Mean;
56
57  /** The current standard deviation */
58  private double m_StandardDev;
59
60  /** The precision of numeric values ( = minimum std dev permitted) */
61  private double m_Precision;
62
63  /**
64   * Round a data value using the defined precision for this estimator
65   *
66   * @param data the value to round
67   * @return the rounded data value
68   */
69  private double round(double data) {
70
71    return Math.rint(data / m_Precision) * m_Precision;
72  }
73 
74  // ===============
75  // Public methods.
76  // ===============
77 
78  /**
79   * Constructor that takes a precision argument.
80   *
81   * @param precision the precision to which numeric values are given. For
82   * example, if the precision is stated to be 0.1, the values in the
83   * interval (0.25,0.35] are all treated as 0.3.
84   */
85  public NormalEstimator(double precision) {
86
87    m_Precision = precision;
88
89    // Allow at most 3 sd's within one interval
90    m_StandardDev = m_Precision / (2 * 3);
91  }
92
93  /**
94   * Add a new data value to the current estimator.
95   *
96   * @param data the new data value
97   * @param weight the weight assigned to the data value
98   */
99  public void addValue(double data, double weight) {
100
101    if (weight == 0) {
102      return;
103    }
104    data = round(data);
105    m_SumOfWeights += weight;
106    m_SumOfValues += data * weight;
107    m_SumOfValuesSq += data * data * weight;
108
109    if (m_SumOfWeights > 0) {
110      m_Mean = m_SumOfValues / m_SumOfWeights;
111      double stdDev = Math.sqrt(Math.abs(m_SumOfValuesSq
112                                          - m_Mean * m_SumOfValues) 
113                                         / m_SumOfWeights);
114      // If the stdDev ~= 0, we really have no idea of scale yet,
115      // so stick with the default. Otherwise...
116      if (stdDev > 1e-10) {
117        m_StandardDev = Math.max(m_Precision / (2 * 3), 
118                                 // allow at most 3sd's within one interval
119                                 stdDev);
120      }
121    }
122  }
123
124  /**
125   * Get a probability estimate for a value
126   *
127   * @param data the value to estimate the probability of
128   * @return the estimated probability of the supplied value
129   */
130  public double getProbability(double data) {
131
132    data = round(data);
133    double zLower = (data - m_Mean - (m_Precision / 2)) / m_StandardDev;
134    double zUpper = (data - m_Mean + (m_Precision / 2)) / m_StandardDev;
135   
136    double pLower = Statistics.normalProbability(zLower);
137    double pUpper = Statistics.normalProbability(zUpper);
138    return pUpper - pLower;
139  }
140
141  /**
142   * Display a representation of this estimator
143   */
144  public String toString() {
145
146    return "Normal Distribution. Mean = " + Utils.doubleToString(m_Mean, 4)
147      + " StandardDev = " + Utils.doubleToString(m_StandardDev, 4)
148      + " WeightSum = " + Utils.doubleToString(m_SumOfWeights, 4)
149      + " Precision = " + m_Precision + "\n";
150  }
151
152  /**
153   * Returns default capabilities of the classifier.
154   *
155   * @return      the capabilities of this classifier
156   */
157  public Capabilities getCapabilities() {
158    Capabilities result = super.getCapabilities();
159    result.disableAll();
160   
161    // class
162    if (!m_noClass) {
163      result.enable(Capability.NOMINAL_CLASS);
164      result.enable(Capability.MISSING_CLASS_VALUES);
165    } else {
166      result.enable(Capability.NO_CLASS);
167    }
168   
169    // attributes
170    result.enable(Capability.NUMERIC_ATTRIBUTES);
171    return result;
172  }
173
174  /**
175   * Return the value of the mean of this normal estimator.
176   *
177   * @return the mean
178   */
179  public double getMean() {
180    return m_Mean;
181  }
182
183  /**
184   * Return the value of the standard deviation of this normal estimator.
185   *
186   * @return the standard deviation
187   */
188  public double getStdDev() {
189    return m_StandardDev;
190  }
191
192  /**
193   * Return the value of the precision of this normal estimator.
194   *
195   * @return the precision
196   */
197  public double getPrecision() {
198    return m_Precision;
199  }
200
201  /**
202   * Return the sum of the weights for this normal estimator.
203   *
204   * @return the sum of the weights
205   */
206  public double getSumOfWeights() {
207    return m_SumOfWeights;
208  }
209 
210  /**
211   * Returns the revision string.
212   *
213   * @return            the revision
214   */
215  public String getRevision() {
216    return RevisionUtils.extract("$Revision: 5490 $");
217  }
218
219  /**
220   * Main method for testing this class.
221   *
222   * @param argv should contain a sequence of numeric values
223   */
224  public static void main(String [] argv) {
225
226    try {
227
228      if (argv.length == 0) {
229        System.out.println("Please specify a set of instances.");
230        return;
231      }
232      NormalEstimator newEst = new NormalEstimator(0.01);
233      for(int i = 0; i < argv.length; i++) {
234        double current = Double.valueOf(argv[i]).doubleValue();
235        System.out.println(newEst);
236        System.out.println("Prediction for " + current
237                           + " = " + newEst.getProbability(current));
238        newEst.addValue(current, 1);
239      }
240    } catch (Exception e) {
241      System.out.println(e.getMessage());
242    }
243  }
244}
Note: See TracBrowser for help on using the repository browser.