source: src/main/java/weka/estimators/DDConditionalEstimator.java @ 20

Last change on this file since 20 was 4, checked in by gnappo, 14 years ago

Import di weka.

File size: 4.4 KB
Line 
1/*
2 *    This program is free software; you can redistribute it and/or modify
3 *    it under the terms of the GNU General Public License as published by
4 *    the Free Software Foundation; either version 2 of the License, or
5 *    (at your option) any later version.
6 *
7 *    This program is distributed in the hope that it will be useful,
8 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
9 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
10 *    GNU General Public License for more details.
11 *
12 *    You should have received a copy of the GNU General Public License
13 *    along with this program; if not, write to the Free Software
14 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
15 */
16
17/*
18 *    DDConditionalEstimator.java
19 *    Copyright (C) 1999 University of Waikato, Hamilton, New Zealand
20 *
21 */
22
23package weka.estimators;
24
25import weka.core.RevisionUtils;
26
27 
28/**
29 * Conditional probability estimator for a discrete domain conditional upon
30 * a discrete domain.
31 *
32 * @author Len Trigg (trigg@cs.waikato.ac.nz)
33 * @version $Revision: 1.8 $
34 */
35public class DDConditionalEstimator implements ConditionalEstimator {
36
37  /** Hold the sub-estimators */
38  private DiscreteEstimator [] m_Estimators;
39
40  /**
41   * Constructor
42   *
43   * @param numSymbols the number of possible symbols (remember to include 0)
44   * @param numCondSymbols the number of conditioning symbols
45   * @param laplace if true, sub-estimators will use laplace
46   */
47  public DDConditionalEstimator(int numSymbols, int numCondSymbols,
48                                boolean laplace) {
49   
50    m_Estimators = new DiscreteEstimator [numCondSymbols];
51    for(int i = 0; i < numCondSymbols; i++) {
52      m_Estimators[i] = new DiscreteEstimator(numSymbols, laplace);
53    }
54  }
55
56  /**
57   * Add a new data value to the current estimator.
58   *
59   * @param data the new data value
60   * @param given the new value that data is conditional upon
61   * @param weight the weight assigned to the data value
62   */
63  public void addValue(double data, double given, double weight) {
64   
65    m_Estimators[(int)given].addValue(data, weight);
66  }
67
68  /**
69   * Get a probability estimator for a value
70   *
71   * @param given the new value that data is conditional upon
72   * @return the estimator for the supplied value given the condition
73   */
74  public Estimator getEstimator(double given) {
75   
76    return m_Estimators[(int)given];
77  }
78
79  /**
80   * Get a probability estimate for a value
81   *
82   * @param data the value to estimate the probability of
83   * @param given the new value that data is conditional upon
84   * @return the estimated probability of the supplied value
85   */
86  public double getProbability(double data, double given) {
87   
88    return getEstimator(given).getProbability(data);
89  }
90
91  /** Display a representation of this estimator */
92  public String toString() {
93   
94    String result = "DD Conditional Estimator. " 
95      + m_Estimators.length + " sub-estimators:\n";
96    for(int i = 0; i < m_Estimators.length; i++) {
97      result += "Sub-estimator " + i + ": " + m_Estimators[i];
98    }
99    return result;
100  }
101 
102  /**
103   * Returns the revision string.
104   *
105   * @return            the revision
106   */
107  public String getRevision() {
108    return RevisionUtils.extract("$Revision: 1.8 $");
109  }
110
111  /**
112   * Main method for testing this class.
113   *
114   * @param argv should contain a sequence of pairs of integers which
115   * will be treated as symbolic.
116   */
117  public static void main(String [] argv) {
118   
119    try {
120      if (argv.length == 0) {
121        System.out.println("Please specify a set of instances.");
122        return;
123      }
124      int currentA = Integer.parseInt(argv[0]);
125      int maxA = currentA;
126      int currentB = Integer.parseInt(argv[1]);
127      int maxB = currentB;
128      for(int i = 2; i < argv.length - 1; i += 2) {
129        currentA = Integer.parseInt(argv[i]);
130        currentB = Integer.parseInt(argv[i + 1]);
131        if (currentA > maxA) {
132          maxA = currentA;
133        }
134        if (currentB > maxB) {
135          maxB = currentB;
136        }
137      }
138      DDConditionalEstimator newEst = new DDConditionalEstimator(maxA + 1,
139                                                                 maxB + 1,
140                                                                 true);
141      for(int i = 0; i < argv.length - 1; i += 2) {
142        currentA = Integer.parseInt(argv[i]);
143        currentB = Integer.parseInt(argv[i + 1]);
144        System.out.println(newEst);
145        System.out.println("Prediction for " + currentA + '|' + currentB
146                           + " = "
147                           + newEst.getProbability(currentA, currentB));
148        newEst.addValue(currentA, currentB, 1);
149      }
150    } catch (Exception e) {
151      System.out.println(e.getMessage());
152    }
153  }
154}
Note: See TracBrowser for help on using the repository browser.