source: src/main/java/weka/estimators/NDConditionalEstimator.java @ 10

Last change on this file since 10 was 4, checked in by gnappo, 14 years ago

Import di weka.

File size: 4.5 KB
Line 
1/*
2 *    This program is free software; you can redistribute it and/or modify
3 *    it under the terms of the GNU General Public License as published by
4 *    the Free Software Foundation; either version 2 of the License, or
5 *    (at your option) any later version.
6 *
7 *    This program is distributed in the hope that it will be useful,
8 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
9 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
10 *    GNU General Public License for more details.
11 *
12 *    You should have received a copy of the GNU General Public License
13 *    along with this program; if not, write to the Free Software
14 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
15 */
16
17/*
18 *    NDConditionalEstimator.java
19 *    Copyright (C) 1999 University of Waikato, Hamilton, New Zealand
20 *
21 */
22
23package weka.estimators;
24
25import weka.core.RevisionUtils;
26
27/**
28 * Conditional probability estimator for a numeric domain conditional upon
29 * a discrete domain (utilises separate normal estimators for each discrete
30 * conditioning value).
31 *
32 * @author Len Trigg (trigg@cs.waikato.ac.nz)
33 * @version $Revision: 1.7 $
34 */
35public class NDConditionalEstimator implements ConditionalEstimator {
36
37  /** Hold the sub-estimators */
38  private NormalEstimator [] m_Estimators;
39
40  /**
41   * Constructor
42   *
43   * @param numCondSymbols the number of conditioning symbols
44   * @param precision the  precision to which numeric values are given. For
45   * example, if the precision is stated to be 0.1, the values in the
46   * interval (0.25,0.35] are all treated as 0.3.
47   */
48  public NDConditionalEstimator(int numCondSymbols, double precision) {
49
50    m_Estimators = new NormalEstimator [numCondSymbols];
51    for(int i = 0; i < numCondSymbols; i++) {
52      m_Estimators[i] = new NormalEstimator(precision);
53    }
54  }
55
56  /**
57   * Add a new data value to the current estimator.
58   *
59   * @param data the new data value
60   * @param given the new value that data is conditional upon
61   * @param weight the weight assigned to the data value
62   */
63  public void addValue(double data, double given, double weight) {
64
65    m_Estimators[(int)given].addValue(data, weight);
66  }
67
68  /**
69   * Get a probability estimator for a value
70   *
71   * @param given the new value that data is conditional upon
72   * @return the estimator for the supplied value given the condition
73   */
74  public Estimator getEstimator(double given) {
75
76    return m_Estimators[(int)given];
77  }
78
79  /**
80   * Get a probability estimate for a value
81   *
82   * @param data the value to estimate the probability of
83   * @param given the new value that data is conditional upon
84   * @return the estimated probability of the supplied value
85   */
86  public double getProbability(double data, double given) {
87
88    return getEstimator(given).getProbability(data);
89  }
90
91  /**
92   * Display a representation of this estimator
93   */
94  public String toString() {
95
96    String result = "ND Conditional Estimator. " 
97      + m_Estimators.length + " sub-estimators:\n";
98    for(int i = 0; i < m_Estimators.length; i++) {
99      result += "Sub-estimator " + i + ": " + m_Estimators[i];
100    }
101    return result;
102  }
103 
104  /**
105   * Returns the revision string.
106   *
107   * @return            the revision
108   */
109  public String getRevision() {
110    return RevisionUtils.extract("$Revision: 1.7 $");
111  }
112
113  /**
114   * Main method for testing this class.
115   *
116   * @param argv should contain a sequence of pairs of integers which
117   * will be treated as numeric, symbolic.
118   */
119  public static void main(String [] argv) {
120   
121    try {
122      if (argv.length == 0) {
123        System.out.println("Please specify a set of instances.");
124        return;
125      }
126      int currentA = Integer.parseInt(argv[0]);
127      int maxA = currentA;
128      int currentB = Integer.parseInt(argv[1]);
129      int maxB = currentB;
130      for(int i = 2; i < argv.length - 1; i += 2) {
131        currentA = Integer.parseInt(argv[i]);
132        currentB = Integer.parseInt(argv[i + 1]);
133        if (currentA > maxA) {
134          maxA = currentA;
135        }
136        if (currentB > maxB) {
137          maxB = currentB;
138        }
139      }
140      NDConditionalEstimator newEst = new NDConditionalEstimator(maxB + 1,
141                                                                 1);
142      for(int i = 0; i < argv.length - 1; i += 2) {
143        currentA = Integer.parseInt(argv[i]);
144        currentB = Integer.parseInt(argv[i + 1]);
145        System.out.println(newEst);
146        System.out.println("Prediction for " + currentA + '|' + currentB
147                           + " = "
148                           + newEst.getProbability(currentA, currentB));
149        newEst.addValue(currentA, currentB, 1);
150      }
151    } catch (Exception e) {
152      System.out.println(e.getMessage());
153    }
154  }
155}
Note: See TracBrowser for help on using the repository browser.