source: src/main/java/weka/estimators/NNConditionalEstimator.java @ 22

Last change on this file since 22 was 4, checked in by gnappo, 14 years ago

Import di weka.

File size: 8.2 KB
Line 
1/*
2 *    This program is free software; you can redistribute it and/or modify
3 *    it under the terms of the GNU General Public License as published by
4 *    the Free Software Foundation; either version 2 of the License, or
5 *    (at your option) any later version.
6 *
7 *    This program is distributed in the hope that it will be useful,
8 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
9 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
10 *    GNU General Public License for more details.
11 *
12 *    You should have received a copy of the GNU General Public License
13 *    along with this program; if not, write to the Free Software
14 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
15 */
16
17/*
18 *    NNConditionalEstimator.java
19 *    Copyright (C) 1999 University of Waikato, Hamilton, New Zealand
20 *
21 */
22
23package weka.estimators;
24
25import java.util.Random;
26import java.util.Vector;
27
28import weka.core.matrix.Matrix;
29import weka.core.RevisionUtils;
30import weka.core.Utils;
31
32/**
33 * Conditional probability estimator for a numeric domain conditional upon
34 * a numeric domain (using Mahalanobis distance).
35 *
36 * @author Len Trigg (trigg@cs.waikato.ac.nz)
37 * @version $Revision: 1.8 $
38 */
39public class NNConditionalEstimator implements ConditionalEstimator {
40
41  /** Vector containing all of the values seen */
42  private Vector m_Values = new Vector();
43
44  /** Vector containing all of the conditioning values seen */
45  private Vector m_CondValues = new Vector();
46
47  /** Vector containing the associated weights */
48  private Vector m_Weights = new Vector();
49
50  /** The sum of the weights so far */
51  private double m_SumOfWeights;
52
53  /** Current Conditional mean */
54  private double m_CondMean;
55
56  /** Current Values mean */
57  private double m_ValueMean;
58
59  /** Current covariance matrix */
60  private Matrix m_Covariance;
61
62  /** Whether we can optimise the kernel summation */
63  private boolean m_AllWeightsOne = true;
64
65  /** 2 * PI */
66  private static double TWO_PI = 2 * Math.PI;
67 
68  // ===============
69  // Private methods
70  // ===============
71
72  /**
73   * Execute a binary search to locate the nearest data value
74   *
75   * @param key the data value to locate
76   * @param secondaryKey the data value to locate
77   * @return the index of the nearest data value
78   */
79  private int findNearestPair(double key, double secondaryKey) {
80   
81    int low = 0; 
82    int high = m_CondValues.size();
83    int middle = 0;
84    while (low < high) {
85      middle = (low + high) / 2;
86      double current = ((Double)m_CondValues.elementAt(middle)).doubleValue();
87      if (current == key) {
88        double secondary = ((Double)m_Values.elementAt(middle)).doubleValue();
89        if (secondary == secondaryKey) {
90          return middle;
91        }
92        if (secondary > secondaryKey) {
93          high = middle;
94        } else if (secondary < secondaryKey) {
95          low = middle + 1;
96        }
97      }
98      if (current > key) {
99        high = middle;
100      } else if (current < key) {
101        low = middle + 1;
102      }
103    }
104    return low;
105  }
106
107  /** Calculate covariance and value means */
108  private void calculateCovariance() {
109   
110    double sumValues = 0, sumConds = 0;
111    for(int i = 0; i < m_Values.size(); i++) {
112      sumValues += ((Double)m_Values.elementAt(i)).doubleValue()
113        * ((Double)m_Weights.elementAt(i)).doubleValue();
114      sumConds += ((Double)m_CondValues.elementAt(i)).doubleValue()
115        * ((Double)m_Weights.elementAt(i)).doubleValue();
116    }
117    m_ValueMean = sumValues / m_SumOfWeights;
118    m_CondMean = sumConds / m_SumOfWeights;
119    double c00 = 0, c01 = 0, c10 = 0, c11 = 0;
120    for(int i = 0; i < m_Values.size(); i++) {
121      double x = ((Double)m_Values.elementAt(i)).doubleValue();
122      double y = ((Double)m_CondValues.elementAt(i)).doubleValue();
123      double weight = ((Double)m_Weights.elementAt(i)).doubleValue();
124      c00 += (x - m_ValueMean) * (x - m_ValueMean) * weight;
125      c01 += (x - m_ValueMean) * (y - m_CondMean) * weight;
126      c11 += (y - m_CondMean) * (y - m_CondMean) * weight;
127    }
128    c00 /= (m_SumOfWeights - 1.0);
129    c01 /= (m_SumOfWeights - 1.0);
130    c10 = c01;
131    c11 /= (m_SumOfWeights - 1.0);
132    m_Covariance = new Matrix(2, 2);
133    m_Covariance.set(0, 0, c00);
134    m_Covariance.set(0, 1, c01);
135    m_Covariance.set(1, 0, c10);
136    m_Covariance.set(1, 1, c11);
137  }
138
139  /**
140   * Returns value for normal kernel
141   *
142   * @param x the argument to the kernel function
143   * @param variance the variance
144   * @return the value for a normal kernel
145   */
146  private double normalKernel(double x, double variance) {
147   
148    return Math.exp(-x * x / (2 * variance)) / Math.sqrt(variance * TWO_PI);
149  }
150 
151  /**
152   * Add a new data value to the current estimator.
153   *
154   * @param data the new data value
155   * @param given the new value that data is conditional upon
156   * @param weight the weight assigned to the data value
157   */
158  public void addValue(double data, double given, double weight) {
159   
160    int insertIndex = findNearestPair(given, data);
161    if ((m_Values.size() <= insertIndex)
162        || (((Double)m_CondValues.elementAt(insertIndex)).doubleValue()
163            != given)
164        || (((Double)m_Values.elementAt(insertIndex)).doubleValue()
165            != data)) {
166      m_CondValues.insertElementAt(new Double(given), insertIndex);
167      m_Values.insertElementAt(new Double(data), insertIndex);
168      m_Weights.insertElementAt(new Double(weight), insertIndex);
169      if (weight != 1) {
170        m_AllWeightsOne = false;
171      }
172    } else {
173      double newWeight = ((Double)m_Weights.elementAt(insertIndex))
174        .doubleValue();
175      newWeight += weight;
176      m_Weights.setElementAt(new Double(newWeight), insertIndex);
177      m_AllWeightsOne = false;     
178    }
179    m_SumOfWeights += weight;
180    // Invalidate any previously calculated covariance matrix
181    m_Covariance = null;
182  }
183
184  /**
185   * Get a probability estimator for a value
186   *
187   * @param given the new value that data is conditional upon
188   * @return the estimator for the supplied value given the condition
189   */
190  public Estimator getEstimator(double given) {
191   
192    if (m_Covariance == null) {
193      calculateCovariance();
194    }
195    Estimator result = new MahalanobisEstimator(m_Covariance,
196                                                given - m_CondMean,
197                                                m_ValueMean);
198    return result;
199  }
200
201  /**
202   * Get a probability estimate for a value
203   *
204   * @param data the value to estimate the probability of
205   * @param given the new value that data is conditional upon
206   * @return the estimated probability of the supplied value
207   */
208  public double getProbability(double data, double given) {
209   
210    return getEstimator(given).getProbability(data);
211  }
212
213  /** Display a representation of this estimator */
214  public String toString() {
215   
216    if (m_Covariance == null) {
217      calculateCovariance();
218    }
219    String result = "NN Conditional Estimator. "
220      + m_CondValues.size() 
221      + " data points.  Mean = " + Utils.doubleToString(m_ValueMean, 4, 2)
222      + "  Conditional mean = " + Utils.doubleToString(m_CondMean, 4, 2);
223    result += "  Covariance Matrix: \n" + m_Covariance;
224    return result;
225  }
226 
227  /**
228   * Returns the revision string.
229   *
230   * @return            the revision
231   */
232  public String getRevision() {
233    return RevisionUtils.extract("$Revision: 1.8 $");
234  }
235
236  /**
237   * Main method for testing this class.
238   *
239   * @param argv should contain a sequence of numeric values
240   */
241  public static void main(String [] argv) {
242   
243    try {
244      int seed = 42;
245      if (argv.length > 0) {
246        seed = Integer.parseInt(argv[0]);
247      }
248      NNConditionalEstimator newEst = new NNConditionalEstimator();
249
250      // Create 100 random points and add them
251      Random r = new Random(seed);
252     
253      int numPoints = 50;
254      if (argv.length > 2) {
255        numPoints = Integer.parseInt(argv[2]);
256      }
257      for(int i = 0; i < numPoints; i++) {
258        int x = Math.abs(r.nextInt() % 100);
259        int y = Math.abs(r.nextInt() % 100);
260        System.out.println("# " + x + "  " + y);
261        newEst.addValue(x, y, 1);
262      }
263      //    System.out.println(newEst);
264      int cond;
265      if (argv.length > 1) {
266        cond = Integer.parseInt(argv[1]);
267      }
268      else cond = Math.abs(r.nextInt() % 100);
269      System.out.println("## Conditional = " + cond);
270      Estimator result = newEst.getEstimator(cond);
271      for(int i = 0; i <= 100; i+= 5) {
272        System.out.println(" " + i + "  " + result.getProbability(i));
273      }
274    } catch (Exception e) {
275      System.out.println(e.getMessage());
276    }
277  }
278}
Note: See TracBrowser for help on using the repository browser.