source: tags/MetisMQIDemo/src/main/java/weka/classifiers/bayes/net/estimate/SimpleEstimator.java

Last change on this file was 29, checked in by gnappo, 15 years ago

Taggata versione per la demo e aggiunto branch.

File size: 7.8 KB
Line 
1/*
2 * This program is free software; you can redistribute it and/or modify
3 * it under the terms of the GNU General Public License as published by
4 * the Free Software Foundation; either version 2 of the License, or
5 * (at your option) any later version.
6 *
7 * This program is distributed in the hope that it will be useful,
8 * but WITHOUT ANY WARRANTY; without even the implied warranty of
9 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
10 * GNU General Public License for more details.
11 *
12 * You should have received a copy of the GNU General Public License
13 * along with this program; if not, write to the Free Software
14 * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
15 */
16
17/*
18 * SimpleEstimator.java
19 * Copyright (C) 2004 University of Waikato, Hamilton, New Zealand
20 *
21 */
22 
23package weka.classifiers.bayes.net.estimate;
24
25import weka.classifiers.bayes.BayesNet;
26import weka.core.Instance;
27import weka.core.Instances;
28import weka.core.RevisionUtils;
29import weka.core.Utils;
30import weka.estimators.Estimator;
31
32import java.util.Enumeration;
33
34/**
35 <!-- globalinfo-start -->
36 * SimpleEstimator is used for estimating the conditional probability tables of a Bayes network once the structure has been learned. Estimates probabilities directly from data.
37 * <p/>
38 <!-- globalinfo-end -->
39 *
40 <!-- options-start -->
41 * Valid options are: <p/>
42 *
43 * <pre> -A &lt;alpha&gt;
44 *  Initial count (alpha)
45 * </pre>
46 *
47 <!-- options-end -->
48 *
49 * @author Remco Bouckaert (rrb@xm.co.nz)
50 * @version $Revision: 1.6 $
51 */
52public class SimpleEstimator 
53    extends BayesNetEstimator {
54
55    /** for serialization */
56    static final long serialVersionUID = 5874941612331806172L;
57   
58    /**
59     * Returns a string describing this object
60     * @return a description of the classifier suitable for
61     * displaying in the explorer/experimenter gui
62     */
63    public String globalInfo() {
64      return 
65          "SimpleEstimator is used for estimating the conditional probability "
66        + "tables of a Bayes network once the structure has been learned. "
67        + "Estimates probabilities directly from data.";
68    }
69 
70    /**
71     * estimateCPTs estimates the conditional probability tables for the Bayes
72     * Net using the network structure.
73     *
74     * @param bayesNet the bayes net to use
75     * @throws Exception if something goes wrong
76     */
77    public void estimateCPTs(BayesNet bayesNet) throws Exception {
78            initCPTs(bayesNet);
79
80            // Compute counts
81            Enumeration enumInsts = bayesNet.m_Instances.enumerateInstances();
82            while (enumInsts.hasMoreElements()) {
83                Instance instance = (Instance) enumInsts.nextElement();
84
85                updateClassifier(bayesNet, instance);
86            }
87    } // estimateCPTs
88
89    /**
90     * Updates the classifier with the given instance.
91     *
92     * @param bayesNet the bayes net to use
93     * @param instance the new training instance to include in the model
94     * @throws Exception if the instance could not be incorporated in
95     * the model.
96     */
97    public void updateClassifier(BayesNet bayesNet, Instance instance) throws Exception {
98        for (int iAttribute = 0; iAttribute < bayesNet.m_Instances.numAttributes(); iAttribute++) {
99            double iCPT = 0;
100
101            for (int iParent = 0; iParent < bayesNet.getParentSet(iAttribute).getNrOfParents(); iParent++) {
102                int nParent = bayesNet.getParentSet(iAttribute).getParent(iParent);
103
104                iCPT = iCPT * bayesNet.m_Instances.attribute(nParent).numValues() + instance.value(nParent);
105            }
106
107            bayesNet.m_Distributions[iAttribute][(int) iCPT].addValue(instance.value(iAttribute), instance.weight());
108        }
109    } // updateClassifier
110
111
112    /**
113     * initCPTs reserves space for CPTs and set all counts to zero
114     *
115     * @param bayesNet the bayes net to use
116     * @throws Exception if something goes wrong
117     */
118    public void initCPTs(BayesNet bayesNet) throws Exception {
119        Instances instances = bayesNet.m_Instances;
120       
121        // Reserve space for CPTs
122        int nMaxParentCardinality = 1;
123        for (int iAttribute = 0; iAttribute < instances.numAttributes(); iAttribute++) {
124            if (bayesNet.getParentSet(iAttribute).getCardinalityOfParents() > nMaxParentCardinality) {
125                nMaxParentCardinality = bayesNet.getParentSet(iAttribute).getCardinalityOfParents();
126            }
127        }
128       
129        // Reserve plenty of memory
130        bayesNet.m_Distributions = new Estimator[instances.numAttributes()][nMaxParentCardinality];
131       
132        // estimate CPTs
133        for (int iAttribute = 0; iAttribute < instances.numAttributes(); iAttribute++) {
134            for (int iParent = 0; iParent < bayesNet.getParentSet(iAttribute).getCardinalityOfParents(); iParent++) {
135                bayesNet.m_Distributions[iAttribute][iParent] =
136                    new DiscreteEstimatorBayes(instances.attribute(iAttribute).numValues(), m_fAlpha);
137            }
138        }
139    } // initCPTs
140
141    /**
142     * Calculates the class membership probabilities for the given test
143     * instance.
144     *
145     * @param bayesNet the bayes net to use
146     * @param instance the instance to be classified
147     * @return predicted class probability distribution
148     * @throws Exception if there is a problem generating the prediction
149     */
150    public double[] distributionForInstance(BayesNet bayesNet, Instance instance) throws Exception {
151        Instances instances = bayesNet.m_Instances;
152        int nNumClasses = instances.numClasses();
153        double[] fProbs = new double[nNumClasses];
154
155        for (int iClass = 0; iClass < nNumClasses; iClass++) {
156            fProbs[iClass] = 1.0;
157        }
158
159        for (int iClass = 0; iClass < nNumClasses; iClass++) {
160            double logfP = 0;
161
162            for (int iAttribute = 0; iAttribute < instances.numAttributes(); iAttribute++) {
163                double iCPT = 0;
164
165                for (int iParent = 0; iParent < bayesNet.getParentSet(iAttribute).getNrOfParents(); iParent++) {
166                    int nParent = bayesNet.getParentSet(iAttribute).getParent(iParent);
167
168                    if (nParent == instances.classIndex()) {
169                        iCPT = iCPT * nNumClasses + iClass;
170                    } else {
171                        iCPT = iCPT * instances.attribute(nParent).numValues() + instance.value(nParent);
172                    }
173                }
174
175                if (iAttribute == instances.classIndex()) {
176                    //    fP *=
177                    //      m_Distributions[iAttribute][(int) iCPT].getProbability(iClass);
178                    logfP += Math.log(bayesNet.m_Distributions[iAttribute][(int) iCPT].getProbability(iClass));
179                } else {
180                    //    fP *=
181                    //      m_Distributions[iAttribute][(int) iCPT]
182                    //        .getProbability(instance.value(iAttribute));
183                    logfP
184                        += Math.log(bayesNet.m_Distributions[iAttribute][(int) iCPT].getProbability(instance.value(iAttribute)));
185                }
186            }
187
188            //      fProbs[iClass] *= fP;
189            fProbs[iClass] += logfP;
190        }
191
192        // Find maximum
193        double fMax = fProbs[0];
194        for (int iClass = 0; iClass < nNumClasses; iClass++) {
195            if (fProbs[iClass] > fMax) {
196                fMax = fProbs[iClass];
197            }
198        }
199        // transform from log-space to normal-space
200        for (int iClass = 0; iClass < nNumClasses; iClass++) {
201            fProbs[iClass] = Math.exp(fProbs[iClass] - fMax);
202        }
203
204        // Display probabilities
205        Utils.normalize(fProbs);
206
207        return fProbs;
208    } // distributionForInstance
209   
210    /**
211     * Returns the revision string.
212     *
213     * @return          the revision
214     */
215    public String getRevision() {
216      return RevisionUtils.extract("$Revision: 1.6 $");
217    }
218
219} // SimpleEstimator
Note: See TracBrowser for help on using the repository browser.