source: tags/MetisMQIDemo/src/main/java/weka/classifiers/bayes/net/estimate/BMAEstimator.java

Last change on this file was 29, checked in by gnappo, 15 years ago

Taggata versione per la demo e aggiunto branch.

File size: 11.3 KB
Line 
1/*
2 * This program is free software; you can redistribute it and/or modify
3 * it under the terms of the GNU General Public License as published by
4 * the Free Software Foundation; either version 2 of the License, or
5 * (at your option) any later version.
6 *
7 * This program is distributed in the hope that it will be useful,
8 * but WITHOUT ANY WARRANTY; without even the implied warranty of
9 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
10 * GNU General Public License for more details.
11 *
12 * You should have received a copy of the GNU General Public License
13 * along with this program; if not, write to the Free Software
14 * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
15 */
16
17/*
18 * BayesNet.java
19 * Copyright (C) 2004 University of Waikato, Hamilton, New Zealand
20 *
21 */
22 
23package weka.classifiers.bayes.net.estimate;
24
25import weka.classifiers.bayes.BayesNet;
26import weka.classifiers.bayes.net.search.local.K2;
27import weka.core.Instance;
28import weka.core.Instances;
29import weka.core.Option;
30import weka.core.RevisionUtils;
31import weka.core.Statistics;
32import weka.core.Utils;
33import weka.estimators.Estimator;
34
35import java.util.Enumeration;
36import java.util.Vector;
37
38/**
39 <!-- globalinfo-start -->
40 * BMAEstimator estimates conditional probability tables of a Bayes network using Bayes Model Averaging (BMA).
41 * <p/>
42 <!-- globalinfo-end -->
43 *
44 <!-- options-start -->
45 * Valid options are: <p/>
46 *
47 * <pre> -k2
48 *  Whether to use K2 prior.
49 * </pre>
50 *
51 * <pre> -A &lt;alpha&gt;
52 *  Initial count (alpha)
53 * </pre>
54 *
55 <!-- options-end -->
56 *
57 * @author Remco Bouckaert (rrb@xm.co.nz)
58 * @version $Revision: 1.8 $
59 */
60public class BMAEstimator 
61    extends SimpleEstimator {
62
63    /** for serialization */
64    static final long serialVersionUID = -1846028304233257309L;
65 
66    /** whether to use K2 prior */
67    protected boolean m_bUseK2Prior = false;
68   
69    /**
70     * Returns a string describing this object
71     * @return a description of the classifier suitable for
72     * displaying in the explorer/experimenter gui
73     */
74    public String globalInfo() {
75      return 
76          "BMAEstimator estimates conditional probability tables of a Bayes "
77        + "network using Bayes Model Averaging (BMA).";
78    }
79
80    /**
81     * estimateCPTs estimates the conditional probability tables for the Bayes
82     * Net using the network structure.
83     *
84     * @param bayesNet the bayes net to use
85     * @throws Exception if an error occurs
86     */
87    public void estimateCPTs(BayesNet bayesNet) throws Exception {
88        initCPTs(bayesNet);
89
90        Instances instances = bayesNet.m_Instances;
91        // sanity check to see if nodes have not more than one parent
92        for (int iAttribute = 0; iAttribute < instances.numAttributes(); iAttribute++) {
93            if (bayesNet.getParentSet(iAttribute).getNrOfParents() > 1) {
94                throw new Exception("Cannot handle networks with nodes with more than 1 parent (yet).");
95            }
96        }
97
98        BayesNet EmptyNet = new BayesNet();
99        K2 oSearchAlgorithm = new K2();
100        oSearchAlgorithm.setInitAsNaiveBayes(false);
101        oSearchAlgorithm.setMaxNrOfParents(0);
102        EmptyNet.setSearchAlgorithm(oSearchAlgorithm);
103        EmptyNet.buildClassifier(instances);
104
105        BayesNet NBNet = new BayesNet();
106        oSearchAlgorithm.setInitAsNaiveBayes(true);
107        oSearchAlgorithm.setMaxNrOfParents(1);
108        NBNet.setSearchAlgorithm(oSearchAlgorithm);
109        NBNet.buildClassifier(instances);
110
111        // estimate CPTs
112        for (int iAttribute = 0; iAttribute < instances.numAttributes(); iAttribute++) {
113            if (iAttribute != instances.classIndex()) {
114                  double w1 = 0.0, w2 = 0.0;
115                  int nAttValues = instances.attribute(iAttribute).numValues();
116                  if (m_bUseK2Prior == true) {
117                      // use Cooper and Herskovitz's metric
118                      for (int iAttValue = 0; iAttValue < nAttValues; iAttValue++) {
119                        w1 += Statistics.lnGamma(1 + ((DiscreteEstimatorBayes)EmptyNet.m_Distributions[iAttribute][0]).getCount(iAttValue))
120                              - Statistics.lnGamma(1);
121                      }
122                      w1 += Statistics.lnGamma(nAttValues) - Statistics.lnGamma(nAttValues + instances.numInstances());
123
124                      for (int iParent = 0; iParent < bayesNet.getParentSet(iAttribute).getCardinalityOfParents(); iParent++) {
125                        int nTotal = 0;
126                          for (int iAttValue = 0; iAttValue < nAttValues; iAttValue++) {
127                            double nCount = ((DiscreteEstimatorBayes)NBNet.m_Distributions[iAttribute][iParent]).getCount(iAttValue);
128                            w2 += Statistics.lnGamma(1 + nCount)
129                                  - Statistics.lnGamma(1);
130                            nTotal += nCount;
131                          }
132                        w2 += Statistics.lnGamma(nAttValues) - Statistics.lnGamma(nAttValues + nTotal);
133                      }
134                  } else {
135                      // use BDe metric
136                      for (int iAttValue = 0; iAttValue < nAttValues; iAttValue++) {
137                        w1 += Statistics.lnGamma(1.0/nAttValues + ((DiscreteEstimatorBayes)EmptyNet.m_Distributions[iAttribute][0]).getCount(iAttValue))
138                              - Statistics.lnGamma(1.0/nAttValues);
139                      }
140                      w1 += Statistics.lnGamma(1) - Statistics.lnGamma(1 + instances.numInstances());
141
142                      int nParentValues = bayesNet.getParentSet(iAttribute).getCardinalityOfParents();
143                      for (int iParent = 0; iParent < nParentValues; iParent++) {
144                        int nTotal = 0;
145                          for (int iAttValue = 0; iAttValue < nAttValues; iAttValue++) {
146                            double nCount = ((DiscreteEstimatorBayes)NBNet.m_Distributions[iAttribute][iParent]).getCount(iAttValue);
147                            w2 += Statistics.lnGamma(1.0/(nAttValues * nParentValues) + nCount)
148                                  - Statistics.lnGamma(1.0/(nAttValues * nParentValues));
149                            nTotal += nCount;
150                          }
151                        w2 += Statistics.lnGamma(1) - Statistics.lnGamma(1 + nTotal);
152                      }
153                  }
154               
155//    System.out.println(w1 + " " + w2 + " " + (w2 - w1));
156                  if (w1 < w2) {
157                    w2 = w2 - w1;
158                    w1 = 0;
159                    w1 = 1 / (1 + Math.exp(w2));
160                    w2 = Math.exp(w2) / (1 + Math.exp(w2));
161                  } else {
162                    w1 = w1 - w2;
163                    w2 = 0;
164                    w2 = 1 / (1 + Math.exp(w1));
165                    w1 = Math.exp(w1) / (1 + Math.exp(w1));
166                  }
167               
168                  for (int iParent = 0; iParent < bayesNet.getParentSet(iAttribute).getCardinalityOfParents(); iParent++) {
169                      bayesNet.m_Distributions[iAttribute][iParent] = 
170                      new DiscreteEstimatorFullBayes(
171                        instances.attribute(iAttribute).numValues(), 
172                        w1, w2,
173                        (DiscreteEstimatorBayes) EmptyNet.m_Distributions[iAttribute][0],
174                        (DiscreteEstimatorBayes) NBNet.m_Distributions[iAttribute][iParent],
175                        m_fAlpha
176                       );
177                  } 
178            }
179        }
180        int iAttribute = instances.classIndex();
181        bayesNet.m_Distributions[iAttribute][0] = EmptyNet.m_Distributions[iAttribute][0];
182    } // estimateCPTs
183
184    /**
185     * Updates the classifier with the given instance.
186     *
187     * @param bayesNet the bayes net to use
188     * @param instance the new training instance to include in the model
189     * @throws Exception if the instance could not be incorporated in
190     * the model.
191     */
192    public void updateClassifier(BayesNet bayesNet, Instance instance) throws Exception {
193        throw new Exception("updateClassifier does not apply to BMA estimator");
194    } // updateClassifier
195
196    /**
197     * initCPTs reserves space for CPTs and set all counts to zero
198     *
199     * @param bayesNet the bayes net to use
200     * @throws Exception if something goes wrong
201     */
202    public void initCPTs(BayesNet bayesNet) throws Exception {
203        // Reserve space for CPTs
204        int nMaxParentCardinality = 1;
205
206        for (int iAttribute = 0; iAttribute < bayesNet.m_Instances.numAttributes(); iAttribute++) {
207            if (bayesNet.getParentSet(iAttribute).getCardinalityOfParents() > nMaxParentCardinality) {
208                nMaxParentCardinality = bayesNet.getParentSet(iAttribute).getCardinalityOfParents();
209            }
210        }
211
212        // Reserve plenty of memory
213        bayesNet.m_Distributions = new Estimator[bayesNet.m_Instances.numAttributes()][nMaxParentCardinality];
214    } // initCPTs
215
216
217    /**
218     * Returns whether K2 prior is used
219     *
220     * @return true if K2 prior is used
221     */
222    public boolean isUseK2Prior() {
223        return m_bUseK2Prior;
224    }
225
226    /**
227     * Sets the UseK2Prior.
228     *
229     * @param bUseK2Prior The bUseK2Prior to set
230     */
231    public void setUseK2Prior(boolean bUseK2Prior) {
232        m_bUseK2Prior = bUseK2Prior;
233    }
234
235    /**
236     * Returns an enumeration describing the available options
237     *
238     * @return an enumeration of all the available options
239     */
240    public Enumeration listOptions() {
241        Vector newVector = new Vector(1);
242
243        newVector.addElement(new Option(
244            "\tWhether to use K2 prior.\n", 
245            "k2", 0, "-k2"));
246
247        Enumeration enu = super.listOptions();
248        while (enu.hasMoreElements()) {
249                newVector.addElement(enu.nextElement());
250        }
251
252        return newVector.elements();
253    } // listOptions
254
255    /**
256     * Parses a given list of options. <p/>
257     *
258     <!-- options-start -->
259     * Valid options are: <p/>
260     *
261     * <pre> -k2
262     *  Whether to use K2 prior.
263     * </pre>
264     *
265     * <pre> -A &lt;alpha&gt;
266     *  Initial count (alpha)
267     * </pre>
268     *
269     <!-- options-end -->
270     *
271     * @param options the list of options as an array of strings
272     * @throws Exception if an option is not supported
273     */
274    public void setOptions(String[] options) throws Exception {
275        setUseK2Prior(Utils.getFlag("k2", options));
276
277        super.setOptions(options);
278    } // setOptions
279
280    /**
281     * Gets the current settings of the classifier.
282     *
283     * @return an array of strings suitable for passing to setOptions
284     */
285    public String[] getOptions() {
286        String[] superOptions = super.getOptions();
287        String[] options = new String[1 + superOptions.length];
288        int current = 0;
289
290        if (isUseK2Prior())
291          options[current++] = "-k2";
292
293        // insert options from parent class
294        for (int iOption = 0; iOption < superOptions.length; iOption++) {
295                options[current++] = superOptions[iOption];
296        }
297
298        // Fill up rest with empty strings, not nulls!
299        while (current < options.length) {
300                options[current++] = "";
301        }
302
303        return options;
304    } // getOptions
305   
306    /**
307     * Returns the revision string.
308     *
309     * @return          the revision
310     */
311    public String getRevision() {
312      return RevisionUtils.extract("$Revision: 1.8 $");
313    }
314} // class BMAEstimator
Note: See TracBrowser for help on using the repository browser.