source: src/main/java/weka/estimators/KKConditionalEstimator.java @ 23

Last change on this file since 23 was 4, checked in by gnappo, 14 years ago

Import di weka.

File size: 8.7 KB
Line 
1/*
2 *    This program is free software; you can redistribute it and/or modify
3 *    it under the terms of the GNU General Public License as published by
4 *    the Free Software Foundation; either version 2 of the License, or
5 *    (at your option) any later version.
6 *
7 *    This program is distributed in the hope that it will be useful,
8 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
9 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
10 *    GNU General Public License for more details.
11 *
12 *    You should have received a copy of the GNU General Public License
13 *    along with this program; if not, write to the Free Software
14 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
15 */
16
17/*
18 *    KKConditionalEstimator.java
19 *    Copyright (C) 1999 University of Waikato, Hamilton, New Zealand
20 *
21 */
22
23package weka.estimators;
24
25import java.util.Random;
26
27import weka.core.RevisionUtils;
28import weka.core.Statistics;
29import weka.core.Utils;
30
31/**
32 * Conditional probability estimator for a numeric domain conditional upon
33 * a numeric domain.
34 *
35 * @author Len Trigg (trigg@cs.waikato.ac.nz)
36 * @version $Revision: 1.8 $
37 */
38public class KKConditionalEstimator implements ConditionalEstimator {
39
40  /** Vector containing all of the values seen */
41  private double [] m_Values;
42
43  /** Vector containing all of the conditioning values seen */
44  private double [] m_CondValues;
45
46  /** Vector containing the associated weights */
47  private double [] m_Weights;
48
49  /**
50   * Number of values stored in m_Weights, m_CondValues, and m_Values so far
51   */
52  private int m_NumValues;
53
54  /** The sum of the weights so far */
55  private double m_SumOfWeights;
56
57  /** Current standard dev */
58  private double m_StandardDev;
59
60  /** Whether we can optimise the kernel summation */
61  private boolean m_AllWeightsOne;
62
63  /** The numeric precision */
64  private double m_Precision;
65
66  /**
67   * Execute a binary search to locate the nearest data value
68   *
69   * @param key the data value to locate
70   * @param secondaryKey the data value to locate
71   * @return the index of the nearest data value
72   */
73  private int findNearestPair(double key, double secondaryKey) {
74
75    int low = 0; 
76    int high = m_NumValues;
77    int middle = 0;
78    while (low < high) {
79      middle = (low + high) / 2;
80      double current = m_CondValues[middle];
81      if (current == key) {
82        double secondary = m_Values[middle];
83        if (secondary == secondaryKey) {
84          return middle;
85        }
86        if (secondary > secondaryKey) {
87          high = middle;
88        } else if (secondary < secondaryKey) {
89          low = middle+1;
90        }
91      }
92      if (current > key) {
93        high = middle;
94      } else if (current < key) {
95        low = middle+1;
96      }
97    }
98    return low;
99  }
100
101  /**
102   * Round a data value using the defined precision for this estimator
103   *
104   * @param data the value to round
105   * @return the rounded data value
106   */
107  private double round(double data) {
108
109    return Math.rint(data / m_Precision) * m_Precision;
110  }
111 
112  /**
113   * Constructor
114   *
115   * @param precision the  precision to which numeric values are given. For
116   * example, if the precision is stated to be 0.1, the values in the
117   * interval (0.25,0.35] are all treated as 0.3.
118   */
119  public KKConditionalEstimator(double precision) {
120
121    m_CondValues = new double [50];
122    m_Values = new double [50];
123    m_Weights = new double [50];
124    m_NumValues = 0;
125    m_SumOfWeights = 0;
126    m_StandardDev = 0;
127    m_AllWeightsOne = true;
128    m_Precision = precision;
129  }
130
131  /**
132   * Add a new data value to the current estimator.
133   *
134   * @param data the new data value
135   * @param given the new value that data is conditional upon
136   * @param weight the weight assigned to the data value
137   */
138  public void addValue(double data, double given, double weight) {
139
140    data = round(data);
141    given = round(given);
142    int insertIndex = findNearestPair(given, data);
143    if ((m_NumValues <= insertIndex)
144        || (m_CondValues[insertIndex] != given)
145        || (m_Values[insertIndex] != data)) {
146      if (m_NumValues < m_Values.length) {
147        int left = m_NumValues - insertIndex; 
148        System.arraycopy(m_Values, insertIndex, 
149                         m_Values, insertIndex + 1, left);
150        System.arraycopy(m_CondValues, insertIndex, 
151                         m_CondValues, insertIndex + 1, left);
152        System.arraycopy(m_Weights, insertIndex, 
153                         m_Weights, insertIndex + 1, left);
154        m_Values[insertIndex] = data;
155        m_CondValues[insertIndex] = given;
156        m_Weights[insertIndex] = weight;
157        m_NumValues++;
158      } else {
159        double [] newValues = new double [m_Values.length*2];
160        double [] newCondValues = new double [m_Values.length*2];
161        double [] newWeights = new double [m_Values.length*2];
162        int left = m_NumValues - insertIndex; 
163        System.arraycopy(m_Values, 0, newValues, 0, insertIndex);
164        System.arraycopy(m_CondValues, 0, newCondValues, 0, insertIndex);
165        System.arraycopy(m_Weights, 0, newWeights, 0, insertIndex);
166        newValues[insertIndex] = data;
167        newCondValues[insertIndex] = given;
168        newWeights[insertIndex] = weight;
169        System.arraycopy(m_Values, insertIndex, 
170                         newValues, insertIndex+1, left);
171        System.arraycopy(m_CondValues, insertIndex, 
172                         newCondValues, insertIndex+1, left);
173        System.arraycopy(m_Weights, insertIndex, 
174                         newWeights, insertIndex+1, left);
175        m_NumValues++;
176        m_Values = newValues;
177        m_CondValues = newCondValues;
178        m_Weights = newWeights;
179      }
180      if (weight != 1) {
181        m_AllWeightsOne = false;
182      }
183    } else {
184      m_Weights[insertIndex] += weight;
185      m_AllWeightsOne = false;     
186    }
187    m_SumOfWeights += weight;
188    double range = m_CondValues[m_NumValues-1] - m_CondValues[0];
189    m_StandardDev = Math.max(range / Math.sqrt(m_SumOfWeights), 
190                             // allow at most 3 sds within one interval
191                             m_Precision / (2 * 3));
192  }
193
194  /**
195   * Get a probability estimator for a value
196   *
197   * @param given the new value that data is conditional upon
198   * @return the estimator for the supplied value given the condition
199   */
200  public Estimator getEstimator(double given) {
201
202    Estimator result = new KernelEstimator(m_Precision);
203    if (m_NumValues == 0) {
204      return result;
205    }
206
207    double delta = 0, currentProb = 0;
208    double zLower, zUpper;
209    for(int i = 0; i < m_NumValues; i++) {
210      delta = m_CondValues[i] - given;
211      zLower = (delta - (m_Precision / 2)) / m_StandardDev;
212      zUpper = (delta + (m_Precision / 2)) / m_StandardDev;
213      currentProb = (Statistics.normalProbability(zUpper)
214                     - Statistics.normalProbability(zLower));
215      result.addValue(m_Values[i], currentProb * m_Weights[i]);
216    }
217    return result;
218  }
219
220  /**
221   * Get a probability estimate for a value
222   *
223   * @param data the value to estimate the probability of
224   * @param given the new value that data is conditional upon
225   * @return the estimated probability of the supplied value
226   */
227  public double getProbability(double data, double given) {
228
229    return getEstimator(given).getProbability(data);
230  }
231
232  /**
233   * Display a representation of this estimator
234   */
235  public String toString() {
236
237    String result = "KK Conditional Estimator. " 
238      + m_NumValues + " Normal Kernels:\n"
239      + "StandardDev = " + Utils.doubleToString(m_StandardDev,4,2) 
240      + "  \nMeans =";
241    for(int i = 0; i < m_NumValues; i++) {
242      result += " (" + m_Values[i] + ", " + m_CondValues[i] + ")";
243      if (!m_AllWeightsOne) {
244          result += "w=" + m_Weights[i];
245      }
246    }
247    return result;
248  }
249 
250  /**
251   * Returns the revision string.
252   *
253   * @return            the revision
254   */
255  public String getRevision() {
256    return RevisionUtils.extract("$Revision: 1.8 $");
257  }
258
259  /**
260   * Main method for testing this class. Creates some random points
261   * in the range 0 - 100,
262   * and prints out a distribution conditional on some value
263   *
264   * @param argv should contain: seed conditional_value numpoints
265   */
266  public static void main(String [] argv) {
267
268    try {
269      int seed = 42;
270      if (argv.length > 0) {
271        seed = Integer.parseInt(argv[0]);
272      }
273      KKConditionalEstimator newEst = new KKConditionalEstimator(0.1);
274
275      // Create 100 random points and add them
276      Random r = new Random(seed);
277     
278      int numPoints = 50;
279      if (argv.length > 2) {
280        numPoints = Integer.parseInt(argv[2]);
281      }
282      for(int i = 0; i < numPoints; i++) {
283        int x = Math.abs(r.nextInt()%100);
284        int y = Math.abs(r.nextInt()%100);
285        System.out.println("# " + x + "  " + y);
286        newEst.addValue(x, y, 1);
287      }
288      //    System.out.println(newEst);
289      int cond;
290      if (argv.length > 1) {
291        cond = Integer.parseInt(argv[1]);
292      } else {
293        cond = Math.abs(r.nextInt()%100);
294      }
295      System.out.println("## Conditional = " + cond);
296      Estimator result = newEst.getEstimator(cond);
297      for(int i = 0; i <= 100; i+= 5) {
298        System.out.println(" " + i + "  " + result.getProbability(i));
299      }
300    } catch (Exception e) {
301      System.out.println(e.getMessage());
302    }
303  }
304}
Note: See TracBrowser for help on using the repository browser.