source: tags/MetisMQIDemo/src/main/java/weka/classifiers/bayes/net/estimate/DiscreteEstimatorBayes.java

Last change on this file was 29, checked in by gnappo, 15 years ago

Taggata versione per la demo e aggiunto branch.

File size: 6.5 KB
Line 
1
2/*
3 * This program is free software; you can redistribute it and/or modify
4 * it under the terms of the GNU General Public License as published by
5 * the Free Software Foundation; either version 2 of the License, or
6 * (at your option) any later version.
7 *
8 * This program is distributed in the hope that it will be useful,
9 * but WITHOUT ANY WARRANTY; without even the implied warranty of
10 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
11 * GNU General Public License for more details.
12 *
13 * You should have received a copy of the GNU General Public License
14 * along with this program; if not, write to the Free Software
15 * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
16 */
17
18/*
19 * DiscreteEstimatorBayes.java
20 * Adapted from DiscreteEstimator.java
21 *
22 */
23package weka.classifiers.bayes.net.estimate;
24
25import weka.classifiers.bayes.net.search.local.Scoreable;
26import weka.core.RevisionUtils;
27import weka.core.Statistics;
28import weka.core.Utils;
29import weka.estimators.DiscreteEstimator;
30import weka.estimators.Estimator;
31
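/**
 * Symbolic probability estimator based on symbol counts and a prior.
 * Besides probability estimates it can report its log score
 * contribution for the score types defined by {@link Scoreable}.
 *
 * @author Remco Bouckaert (rrb@xm.co.nz)
 * @version $Revision: 1.7 $
 */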
38public class DiscreteEstimatorBayes extends Estimator
39  implements Scoreable {
40
41  /** for serialization */
42  static final long serialVersionUID = 4215400230843212684L;
43 
44  /**
45   * Hold the counts
46   */
47  protected double[] m_Counts;
48
49  /**
50   * Hold the sum of counts
51   */
52  protected double   m_SumOfCounts;
53
54  /**
55   * Holds number of symbols in distribution
56   */
57  protected int      m_nSymbols = 0;
58
59  /**
60   * Holds the prior probability
61   */
62  protected double   m_fPrior = 0.0;
63
64  /**
65   * Constructor
66   *
67   * @param nSymbols the number of possible symbols (remember to include 0)
68   * @param fPrior
69   */
70  public DiscreteEstimatorBayes(int nSymbols, double fPrior) {
71    m_fPrior = fPrior;
72    m_nSymbols = nSymbols;
73    m_Counts = new double[m_nSymbols];
74
75    for (int iSymbol = 0; iSymbol < m_nSymbols; iSymbol++) {
76      m_Counts[iSymbol] = m_fPrior;
77    } 
78
79    m_SumOfCounts = m_fPrior * (double) m_nSymbols;
80  }    // DiscreteEstimatorBayes
81
82  /**
83   * Add a new data value to the current estimator.
84   *
85   * @param data the new data value
86   * @param weight the weight assigned to the data value
87   */
88  public void addValue(double data, double weight) {
89    m_Counts[(int) data] += weight;
90    m_SumOfCounts += weight;
91  } 
92
93  /**
94   * Get a probability estimate for a value
95   *
96   * @param data the value to estimate the probability of
97   * @return the estimated probability of the supplied value
98   */
99  public double getProbability(double data) {
100    if (m_SumOfCounts == 0) {
101
102      // this can only happen if numSymbols = 0 in constructor
103      return 0;
104    } 
105
106    return (double) m_Counts[(int) data] / m_SumOfCounts;
107  } 
108
109  /**
110   * Get a counts for a value
111   *
112   * @param data the value to get the counts for
113   * @return the count of the supplied value
114   */
115  public double getCount(double data) {
116    if (m_SumOfCounts == 0) {
117      // this can only happen if numSymbols = 0 in constructor
118      return 0;
119    } 
120
121    return m_Counts[(int) data];
122  } 
123 
124  /**
125   * Gets the number of symbols this estimator operates with
126   *
127   * @return the number of estimator symbols
128   */
129  public int getNumSymbols() {
130    return (m_Counts == null) ? 0 : m_Counts.length;
131  } 
132
133  /**
134   * Gets the log score contribution of this distribution
135   * @param nType score type
136   * @return the score
137   */
138  public double logScore(int nType, int nCardinality) {
139            double fScore = 0.0;
140
141            switch (nType) {
142
143            case (Scoreable.BAYES): {
144              for (int iSymbol = 0; iSymbol < m_nSymbols; iSymbol++) {
145                fScore += Statistics.lnGamma(m_Counts[iSymbol]);
146              } 
147
148              fScore -= Statistics.lnGamma(m_SumOfCounts);
149              if (m_fPrior != 0.0) {
150                      fScore -= m_nSymbols * Statistics.lnGamma(m_fPrior);
151                  fScore += Statistics.lnGamma(m_nSymbols * m_fPrior);
152              }
153            } 
154
155              break;
156                  case (Scoreable.BDeu): {
157                  for (int iSymbol = 0; iSymbol < m_nSymbols; iSymbol++) {
158                        fScore += Statistics.lnGamma(m_Counts[iSymbol]);
159                  } 
160
161                  fScore -= Statistics.lnGamma(m_SumOfCounts);
162                  //fScore -= m_nSymbols * Statistics.lnGamma(1.0);
163                  //fScore += Statistics.lnGamma(m_nSymbols * 1.0);
164              fScore -= m_nSymbols * Statistics.lnGamma(1.0/(m_nSymbols * nCardinality));
165              fScore += Statistics.lnGamma(1.0/nCardinality);
166                } 
167                  break;
168
169            case (Scoreable.MDL):
170
171            case (Scoreable.AIC):
172
173            case (Scoreable.ENTROPY): {
174              for (int iSymbol = 0; iSymbol < m_nSymbols; iSymbol++) {
175                double fP = getProbability(iSymbol);
176
177                fScore += m_Counts[iSymbol] * Math.log(fP);
178              } 
179            } 
180
181              break;
182
183            default: {}
184            }
185
186            return fScore;
187          } 
188
189  /**
190   * Display a representation of this estimator
191   *
192   * @return a string representation of the estimator
193   */
194  public String toString() {
195    String result = "Discrete Estimator. Counts = ";
196
197    if (m_SumOfCounts > 1) {
198      for (int i = 0; i < m_Counts.length; i++) {
199        result += " " + Utils.doubleToString(m_Counts[i], 2);
200      } 
201
202      result += "  (Total = " + Utils.doubleToString(m_SumOfCounts, 2) 
203                + ")\n";
204    } else {
205      for (int i = 0; i < m_Counts.length; i++) {
206        result += " " + m_Counts[i];
207      } 
208
209      result += "  (Total = " + m_SumOfCounts + ")\n";
210    } 
211
212    return result;
213  } 
214 
215  /**
216   * Returns the revision string.
217   *
218   * @return            the revision
219   */
220  public String getRevision() {
221    return RevisionUtils.extract("$Revision: 1.7 $");
222  }
223 
224  /**
225   * Main method for testing this class.
226   *
227   * @param argv should contain a sequence of integers which
228   * will be treated as symbolic.
229   */
230  public static void main(String[] argv) {
231    try {
232      if (argv.length == 0) {
233        System.out.println("Please specify a set of instances.");
234
235        return;
236      } 
237
238      int current = Integer.parseInt(argv[0]);
239      int max = current;
240
241      for (int i = 1; i < argv.length; i++) {
242        current = Integer.parseInt(argv[i]);
243
244        if (current > max) {
245          max = current;
246        } 
247      } 
248
249      DiscreteEstimator newEst = new DiscreteEstimator(max + 1, true);
250
251      for (int i = 0; i < argv.length; i++) {
252        current = Integer.parseInt(argv[i]);
253
254        System.out.println(newEst);
255        System.out.println("Prediction for " + current + " = " 
256                           + newEst.getProbability(current));
257        newEst.addValue(current, 1);
258      } 
259    } catch (Exception e) {
260      System.out.println(e.getMessage());
261    } 
262  }    // main
263 
264}      // class DiscreteEstimatorBayes
Note: See TracBrowser for help on using the repository browser.