source: src/main/java/weka/estimators/DiscreteEstimator.java @ 6

Last change on this file since 6 was 4, checked in by gnappo, 14 years ago

Import di weka.

File size: 5.9 KB
Line 
1/*
2 *    This program is free software; you can redistribute it and/or modify
3 *    it under the terms of the GNU General Public License as published by
4 *    the Free Software Foundation; either version 2 of the License, or
5 *    (at your option) any later version.
6 *
7 *    This program is distributed in the hope that it will be useful,
8 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
9 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
10 *    GNU General Public License for more details.
11 *
12 *    You should have received a copy of the GNU General Public License
13 *    along with this program; if not, write to the Free Software
14 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
15 */
16
17/*
18 *    DiscreteEstimator.java
19 *    Copyright (C) 1999 University of Waikato, Hamilton, New Zealand
20 *
21 */
22
23package weka.estimators;
24
25import weka.core.Capabilities.Capability;
26import weka.core.Capabilities;
27import weka.core.RevisionUtils;
28import weka.core.Utils;
29
30/**
31 * Simple symbolic probability estimator based on symbol counts.
32 *
33 * @author Len Trigg (trigg@cs.waikato.ac.nz)
34 * @version $Revision: 5490 $
35 */
36public class DiscreteEstimator extends Estimator implements IncrementalEstimator {
37 
38  /** for serialization */
39  private static final long serialVersionUID = -5526486742612434779L;
40
41  /** Hold the counts */
42  private double [] m_Counts;
43 
44  /** Hold the sum of counts */
45  private double m_SumOfCounts;
46 
47 
48  /**
49   * Constructor
50   *
51   * @param numSymbols the number of possible symbols (remember to include 0)
52   * @param laplace if true, counts will be initialised to 1
53   */
54  public DiscreteEstimator(int numSymbols, boolean laplace) {
55   
56    m_Counts = new double [numSymbols];
57    m_SumOfCounts = 0;
58    if (laplace) {
59      for(int i = 0; i < numSymbols; i++) {
60        m_Counts[i] = 1;
61      }
62      m_SumOfCounts = (double)numSymbols;
63    }
64  }
65 
66  /**
67   * Constructor
68   *
69   * @param nSymbols the number of possible symbols (remember to include 0)
70   * @param fPrior value with which counts will be initialised
71   */
72  public DiscreteEstimator(int nSymbols, double fPrior) {   
73   
74    m_Counts = new double [nSymbols];
75    for(int iSymbol = 0; iSymbol < nSymbols; iSymbol++) {
76      m_Counts[iSymbol] = fPrior;
77    }
78    m_SumOfCounts = fPrior * (double) nSymbols;
79  }
80 
81  /**
82   * Add a new data value to the current estimator.
83   *
84   * @param data the new data value
85   * @param weight the weight assigned to the data value
86   */
87  public void addValue(double data, double weight) {
88   
89    m_Counts[(int)data] += weight;
90    m_SumOfCounts += weight;
91  }
92 
93  /**
94   * Get a probability estimate for a value
95   *
96   * @param data the value to estimate the probability of
97   * @return the estimated probability of the supplied value
98   */
99  public double getProbability(double data) {
100   
101    if (m_SumOfCounts == 0) {
102      return 0;
103    }
104    return (double)m_Counts[(int)data] / m_SumOfCounts;
105  }
106 
107  /**
108   * Gets the number of symbols this estimator operates with
109   *
110   * @return the number of estimator symbols
111   */
112  public int getNumSymbols() {
113   
114    return (m_Counts == null) ? 0 : m_Counts.length;
115  }
116 
117 
118  /**
119   * Get the count for a value
120   *
121   * @param data the value to get the count of
122   * @return the count of the supplied value
123   */
124  public double getCount(double data) {
125   
126    if (m_SumOfCounts == 0) {
127      return 0;
128    }
129    return m_Counts[(int)data];
130  }
131 
132 
133  /**
134   * Get the sum of all the counts
135   *
136   * @return the total sum of counts
137   */
138  public double getSumOfCounts() {
139   
140    return m_SumOfCounts;
141  }
142 
143 
144  /**
145   * Display a representation of this estimator
146   */
147  public String toString() {
148   
149    StringBuffer result = new StringBuffer("Discrete Estimator. Counts = ");
150    if (m_SumOfCounts > 1) {
151      for(int i = 0; i < m_Counts.length; i++) {
152        result.append(" ").append(Utils.doubleToString(m_Counts[i], 2));
153      }
154      result.append("  (Total = " ).append(Utils.doubleToString(m_SumOfCounts, 2));
155      result.append(")\n"); 
156    } else {
157      for(int i = 0; i < m_Counts.length; i++) {
158        result.append(" ").append(m_Counts[i]);
159      }
160      result.append("  (Total = ").append(m_SumOfCounts).append(")\n"); 
161    }
162    return result.toString();
163  }
164 
165  /**
166   * Returns default capabilities of the classifier.
167   *
168   * @return      the capabilities of this classifier
169   */
170  public Capabilities getCapabilities() {
171    Capabilities result = super.getCapabilities();
172    result.disableAll();
173   
174    // class
175    if (!m_noClass) {
176      result.enable(Capability.NOMINAL_CLASS);
177      result.enable(Capability.MISSING_CLASS_VALUES);
178    } else {
179      result.enable(Capability.NO_CLASS);
180    }
181   
182    // attributes
183    result.enable(Capability.NUMERIC_ATTRIBUTES);
184    return result;
185  }
186 
187  /**
188   * Returns the revision string.
189   *
190   * @return            the revision
191   */
192  public String getRevision() {
193    return RevisionUtils.extract("$Revision: 5490 $");
194  }
195 
196  /**
197   * Main method for testing this class.
198   *
199   * @param argv should contain a sequence of integers which
200   * will be treated as symbolic.
201   */
202  public static void main(String [] argv) {
203   
204    try {
205      if (argv.length == 0) {
206        System.out.println("Please specify a set of instances.");
207        return;
208      }
209      int current = Integer.parseInt(argv[0]);
210      int max = current;
211      for(int i = 1; i < argv.length; i++) {
212        current = Integer.parseInt(argv[i]);
213        if (current > max) {
214          max = current;
215        }
216      }
217      DiscreteEstimator newEst = new DiscreteEstimator(max + 1, true);
218      for(int i = 0; i < argv.length; i++) {
219        current = Integer.parseInt(argv[i]);
220        System.out.println(newEst);
221        System.out.println("Prediction for " + current
222            + " = " + newEst.getProbability(current));
223        newEst.addValue(current, 1);
224      }
225    } catch (Exception e) {
226      System.out.println(e.getMessage());
227    }
228  }
229}
Note: See TracBrowser for help on using the repository browser.