source: src/main/java/weka/estimators/DNConditionalEstimator.java @ 12

Last change on this file since 12 was 4, checked in by gnappo, 14 years ago

Import di weka.

File size: 4.9 KB
Line 
1/*
2 *    This program is free software; you can redistribute it and/or modify
3 *    it under the terms of the GNU General Public License as published by
4 *    the Free Software Foundation; either version 2 of the License, or
5 *    (at your option) any later version.
6 *
7 *    This program is distributed in the hope that it will be useful,
8 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
9 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
10 *    GNU General Public License for more details.
11 *
12 *    You should have received a copy of the GNU General Public License
13 *    along with this program; if not, write to the Free Software
14 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
15 */
16
17/*
18 *    DNConditionalEstimator.java
19 *    Copyright (C) 1999 University of Waikato, Hamilton, New Zealand
20 *
21 */
22
23package weka.estimators;
24
25import weka.core.RevisionUtils;
26
27/**
28 * Conditional probability estimator for a discrete domain conditional upon
29 * a numeric domain.
30 *
31 * @author Len Trigg (trigg@cs.waikato.ac.nz)
32 * @version $Revision: 1.8 $
33 */
34public class DNConditionalEstimator implements ConditionalEstimator {
35
36  /** Hold the sub-estimators */
37  private NormalEstimator [] m_Estimators;
38
39  /** Hold the weights for each of the sub-estimators */
40  private DiscreteEstimator m_Weights;
41
42  /**
43   * Constructor
44   *
45   * @param numSymbols the number of symbols
46   * @param precision the  precision to which numeric values are given. For
47   * example, if the precision is stated to be 0.1, the values in the
48   * interval (0.25,0.35] are all treated as 0.3.
49   */
50  public DNConditionalEstimator(int numSymbols, double precision) {
51
52    m_Estimators = new NormalEstimator [numSymbols];
53    for(int i = 0; i < numSymbols; i++) {
54        m_Estimators[i] = new NormalEstimator(precision);
55    }
56    m_Weights = new DiscreteEstimator(numSymbols, true);
57  }
58
59  /**
60   * Add a new data value to the current estimator.
61   *
62   * @param data the new data value
63   * @param given the new value that data is conditional upon
64   * @param weight the weight assigned to the data value
65   */
66  public void addValue(double data, double given, double weight) {
67
68    m_Estimators[(int)data].addValue(given, weight);
69    m_Weights.addValue((int)data, weight);
70  }
71
72  /**
73   * Get a probability estimator for a value
74   *
75   * @param given the new value that data is conditional upon
76   * @return the estimator for the supplied value given the condition
77   */
78  public Estimator getEstimator(double given) {
79
80    Estimator result = new DiscreteEstimator(m_Estimators.length,false);
81    for(int i = 0; i < m_Estimators.length; i++) {
82      result.addValue(i,m_Weights.getProbability(i)
83                      *m_Estimators[i].getProbability(given));
84    }
85    return result;
86  }
87
88  /**
89   * Get a probability estimate for a value
90   *
91   * @param data the value to estimate the probability of
92   * @param given the new value that data is conditional upon
93   * @return the estimated probability of the supplied value
94   */
95  public double getProbability(double data, double given) {
96
97    return getEstimator(given).getProbability(data);
98  }
99
100  /** Display a representation of this estimator */
101  public String toString() {
102
103    String result = "DN Conditional Estimator. " 
104      + m_Estimators.length + " sub-estimators:\n";
105    for(int i = 0; i < m_Estimators.length; i++) {
106      result += "Sub-estimator " + i + ": " + m_Estimators[i];
107    }
108    result += "Weights of each estimator given by " + m_Weights;
109    return result;
110  }
111 
112  /**
113   * Returns the revision string.
114   *
115   * @return            the revision
116   */
117  public String getRevision() {
118    return RevisionUtils.extract("$Revision: 1.8 $");
119  }
120
121  /**
122   * Main method for testing this class.
123   *
124   * @param argv should contain a sequence of pairs of integers which
125   * will be treated as pairs of symbolic, numeric.
126   */
127  public static void main(String [] argv) {
128   
129    try {
130      if (argv.length == 0) {
131        System.out.println("Please specify a set of instances.");
132        return;
133      }
134      int currentA = Integer.parseInt(argv[0]);
135      int maxA = currentA;
136      int currentB = Integer.parseInt(argv[1]);
137      int maxB = currentB;
138      for(int i = 2; i < argv.length - 1; i += 2) {
139        currentA = Integer.parseInt(argv[i]);
140        currentB = Integer.parseInt(argv[i + 1]);
141        if (currentA > maxA) {
142          maxA = currentA;
143        }
144        if (currentB > maxB) {
145          maxB = currentB;
146        }
147      }
148      DNConditionalEstimator newEst = new DNConditionalEstimator(maxA + 1,
149                                                                 1);
150      for(int i = 0; i < argv.length - 1; i += 2) {
151        currentA = Integer.parseInt(argv[i]);
152        currentB = Integer.parseInt(argv[i + 1]);
153        System.out.println(newEst);
154        System.out.println("Prediction for " + currentA + '|' + currentB
155                           + " = "
156                           + newEst.getProbability(currentA, currentB));
157        newEst.addValue(currentA, currentB, 1);
158      }
159    } catch (Exception e) {
160      System.out.println(e.getMessage());
161    }
162  }
163}
Note: See TracBrowser for help on using the repository browser.