source: branches/MetisMQI/src/main/java/weka/classifiers/bayes/net/estimate/MultiNomialBMAEstimator.java

Last change on this file was 29, checked in by gnappo, 15 years ago

Taggata versione per la demo e aggiunto branch.

File size: 14.5 KB
Line 
1/*
2 *    This program is free software; you can redistribute it and/or modify
3 *    it under the terms of the GNU General Public License as published by
4 *    the Free Software Foundation; either version 2 of the License, or
5 *    (at your option) any later version.
6 *
7 *    This program is distributed in the hope that it will be useful,
8 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
9 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
10 *    GNU General Public License for more details.
11 *
12 *    You should have received a copy of the GNU General Public License
13 *    along with this program; if not, write to the Free Software
14 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
15 */
16
17package weka.classifiers.bayes.net.estimate;
18
19import weka.classifiers.bayes.BayesNet;
20import weka.classifiers.bayes.net.search.local.K2;
21import weka.core.Attribute;
22import weka.core.FastVector;
23import weka.core.Instance;
24import weka.core.DenseInstance;
25import weka.core.Instances;
26import weka.core.Option;
27import weka.core.RevisionUtils;
28import weka.core.Statistics;
29import weka.core.Utils;
30import weka.estimators.Estimator;
31
32import java.util.Enumeration;
33import java.util.Vector;
34
35/**
36 <!-- globalinfo-start -->
37 * Multinomial BMA Estimator.
38 * <p/>
39 <!-- globalinfo-end -->
40 *
41 <!-- options-start -->
42 * Valid options are: <p/>
43 *
44 * <pre> -k2
45 *  Whether to use K2 prior.
46 * </pre>
47 *
48 * <pre> -A &lt;alpha&gt;
49 *  Initial count (alpha)
50 * </pre>
51 *
52 <!-- options-end -->
53 *
54 * @version $Revision: 5987 $
55 * @author Remco Bouckaert (rrb@xm.co.nz)
56 */
57public class MultiNomialBMAEstimator 
58    extends BayesNetEstimator {
59
60    /** for serialization */
61    static final long serialVersionUID = 8330705772601586313L;
62   
63    /** whether to use K2 prior */
64    protected boolean m_bUseK2Prior = true;
65   
66    /**
67     * Returns a string describing this object
68     * @return a description of the classifier suitable for
69     * displaying in the explorer/experimenter gui
70     */
71    public String globalInfo() {
72      return 
73          "Multinomial BMA Estimator.";
74    }
75
76    /**
77     * estimateCPTs estimates the conditional probability tables for the Bayes
78     * Net using the network structure.
79     *
80     * @param bayesNet the bayes net to use
81     * @throws Exception if number of parents doesn't fit (more than 1)
82     */
83    public void estimateCPTs(BayesNet bayesNet) throws Exception {
84        initCPTs(bayesNet);
85       
86        // sanity check to see if nodes have not more than one parent
87        for (int iAttribute = 0; iAttribute < bayesNet.m_Instances.numAttributes(); iAttribute++) {
88            if (bayesNet.getParentSet(iAttribute).getNrOfParents() > 1) {
89                throw new Exception("Cannot handle networks with nodes with more than 1 parent (yet).");
90            }
91        }
92
93                // filter data to binary
94        Instances instances = new Instances(bayesNet.m_Instances);
95        while (instances.numInstances() > 0) {
96            instances.delete(0);
97        }
98        for (int iAttribute = instances.numAttributes() - 1; iAttribute >= 0; iAttribute--) {
99            if (iAttribute != instances.classIndex()) {
100                FastVector values = new FastVector();
101                values.addElement("0");
102                values.addElement("1");
103                Attribute a = new Attribute(instances.attribute(iAttribute).name(), (FastVector) values);
104                instances.deleteAttributeAt(iAttribute);
105                instances.insertAttributeAt(a,iAttribute);
106            }
107        }
108       
109        for (int iInstance = 0; iInstance < bayesNet.m_Instances.numInstances(); iInstance++) {
110            Instance instanceOrig = bayesNet.m_Instances.instance(iInstance);
111            Instance instance = new DenseInstance(instances.numAttributes());
112            for (int iAttribute = 0; iAttribute < instances.numAttributes(); iAttribute++) {
113                if (iAttribute != instances.classIndex()) {
114                    if (instanceOrig.value(iAttribute) > 0) {
115                        instance.setValue(iAttribute, 1);
116                    }
117                } else {
118                    instance.setValue(iAttribute, instanceOrig.value(iAttribute));
119                }
120            }
121        }
122        // ok, now all data is binary, except the class attribute
123        // now learn the empty and tree network
124
125        BayesNet EmptyNet = new BayesNet();
126        K2 oSearchAlgorithm = new K2();
127        oSearchAlgorithm.setInitAsNaiveBayes(false);
128        oSearchAlgorithm.setMaxNrOfParents(0);
129        EmptyNet.setSearchAlgorithm(oSearchAlgorithm);
130        EmptyNet.buildClassifier(instances);
131
132        BayesNet NBNet = new BayesNet();
133        oSearchAlgorithm.setInitAsNaiveBayes(true);
134        oSearchAlgorithm.setMaxNrOfParents(1);
135        NBNet.setSearchAlgorithm(oSearchAlgorithm);
136        NBNet.buildClassifier(instances);
137
138        // estimate CPTs
139        for (int iAttribute = 0; iAttribute < instances.numAttributes(); iAttribute++) {
140            if (iAttribute != instances.classIndex()) {
141                  double w1 = 0.0, w2 = 0.0;
142                  int nAttValues = instances.attribute(iAttribute).numValues();
143                  if (m_bUseK2Prior == true) {
144                      // use Cooper and Herskovitz's metric
145                      for (int iAttValue = 0; iAttValue < nAttValues; iAttValue++) {
146                        w1 += Statistics.lnGamma(1 + ((DiscreteEstimatorBayes)EmptyNet.m_Distributions[iAttribute][0]).getCount(iAttValue))
147                              - Statistics.lnGamma(1);
148                      }
149                      w1 += Statistics.lnGamma(nAttValues) - Statistics.lnGamma(nAttValues + instances.numInstances());
150
151                      for (int iParent = 0; iParent < bayesNet.getParentSet(iAttribute).getCardinalityOfParents(); iParent++) {
152                        int nTotal = 0;
153                          for (int iAttValue = 0; iAttValue < nAttValues; iAttValue++) {
154                            double nCount = ((DiscreteEstimatorBayes)NBNet.m_Distributions[iAttribute][iParent]).getCount(iAttValue);
155                            w2 += Statistics.lnGamma(1 + nCount)
156                                  - Statistics.lnGamma(1);
157                            nTotal += nCount;
158                          }
159                        w2 += Statistics.lnGamma(nAttValues) - Statistics.lnGamma(nAttValues + nTotal);
160                      }
161                  } else {
162                      // use BDe metric
163                      for (int iAttValue = 0; iAttValue < nAttValues; iAttValue++) {
164                        w1 += Statistics.lnGamma(1.0/nAttValues + ((DiscreteEstimatorBayes)EmptyNet.m_Distributions[iAttribute][0]).getCount(iAttValue))
165                              - Statistics.lnGamma(1.0/nAttValues);
166                      }
167                      w1 += Statistics.lnGamma(1) - Statistics.lnGamma(1 + instances.numInstances());
168
169                                          int nParentValues = bayesNet.getParentSet(iAttribute).getCardinalityOfParents();
170                      for (int iParent = 0; iParent < nParentValues; iParent++) {
171                        int nTotal = 0;
172                          for (int iAttValue = 0; iAttValue < nAttValues; iAttValue++) {
173                            double nCount = ((DiscreteEstimatorBayes)NBNet.m_Distributions[iAttribute][iParent]).getCount(iAttValue);
174                            w2 += Statistics.lnGamma(1.0/(nAttValues * nParentValues) + nCount)
175                                  - Statistics.lnGamma(1.0/(nAttValues * nParentValues));
176                            nTotal += nCount;
177                          }
178                        w2 += Statistics.lnGamma(1) - Statistics.lnGamma(1 + nTotal);
179                      }
180                  }
181               
182//    System.out.println(w1 + " " + w2 + " " + (w2 - w1));
183                  // normalize weigths
184                  if (w1 < w2) {
185                    w2 = w2 - w1;
186                    w1 = 0;
187                    w1 = 1 / (1 + Math.exp(w2));
188                    w2 = Math.exp(w2) / (1 + Math.exp(w2));
189                  } else {
190                    w1 = w1 - w2;
191                    w2 = 0;
192                    w2 = 1 / (1 + Math.exp(w1));
193                    w1 = Math.exp(w1) / (1 + Math.exp(w1));
194                  }
195               
196                  for (int iParent = 0; iParent < bayesNet.getParentSet(iAttribute).getCardinalityOfParents(); iParent++) {
197                      bayesNet.m_Distributions[iAttribute][iParent] = 
198                      new DiscreteEstimatorFullBayes(
199                        instances.attribute(iAttribute).numValues(), 
200                        w1, w2,
201                        (DiscreteEstimatorBayes) EmptyNet.m_Distributions[iAttribute][0],
202                        (DiscreteEstimatorBayes) NBNet.m_Distributions[iAttribute][iParent],
203                        m_fAlpha
204                       );
205                  } 
206            }
207        }
208        int iAttribute = instances.classIndex();
209        bayesNet.m_Distributions[iAttribute][0] = EmptyNet.m_Distributions[iAttribute][0];
210    } // estimateCPTs
211
212    /**
213     * Updates the classifier with the given instance.
214     *
215     * @param bayesNet the bayes net to use
216     * @param instance the new training instance to include in the model
217     * @throws Exception if the instance could not be incorporated in
218     * the model.
219     */
220    public void updateClassifier(BayesNet bayesNet, Instance instance) throws Exception {
221        throw new Exception("updateClassifier does not apply to BMA estimator");
222    } // updateClassifier
223
224    /**
225     * initCPTs reserves space for CPTs and set all counts to zero
226     *
227     * @param bayesNet the bayes net to use
228     * @throws Exception doesn't apply
229     */
230    public void initCPTs(BayesNet bayesNet) throws Exception {
231        // Reserve sufficient memory
232        bayesNet.m_Distributions = new Estimator[bayesNet.m_Instances.numAttributes()][2];
233    } // initCPTs
234
235
236    /**
237     * @return boolean
238     */
239    public boolean isUseK2Prior() {
240        return m_bUseK2Prior;
241    }
242
243    /**
244     * Sets the UseK2Prior.
245     *
246     * @param bUseK2Prior The bUseK2Prior to set
247     */
248    public void setUseK2Prior(boolean bUseK2Prior) {
249        m_bUseK2Prior = bUseK2Prior;
250    }
251
252    /**
253     * Calculates the class membership probabilities for the given test
254     * instance.
255     *
256     * @param bayesNet the bayes net to use
257     * @param instance the instance to be classified
258     * @return predicted class probability distribution
259     * @throws Exception if there is a problem generating the prediction
260     */
261    public double[] distributionForInstance(BayesNet bayesNet, Instance instance) throws Exception {
262        Instances instances = bayesNet.m_Instances;
263        int nNumClasses = instances.numClasses();
264        double[] fProbs = new double[nNumClasses];
265
266        for (int iClass = 0; iClass < nNumClasses; iClass++) {
267            fProbs[iClass] = 1.0;
268        }
269
270        for (int iClass = 0; iClass < nNumClasses; iClass++) {
271            double logfP = 0;
272
273            for (int iAttribute = 0; iAttribute < instances.numAttributes(); iAttribute++) {
274                double iCPT = 0;
275
276                for (int iParent = 0; iParent < bayesNet.getParentSet(iAttribute).getNrOfParents(); iParent++) {
277                    int nParent = bayesNet.getParentSet(iAttribute).getParent(iParent);
278
279                    if (nParent == instances.classIndex()) {
280                        iCPT = iCPT * nNumClasses + iClass;
281                    } else {
282                        iCPT = iCPT * instances.attribute(nParent).numValues() + instance.value(nParent);
283                    }
284                }
285
286                if (iAttribute == instances.classIndex()) {
287                    logfP += Math.log(bayesNet.m_Distributions[iAttribute][(int) iCPT].getProbability(iClass));
288                } else {
289                    logfP += instance.value(iAttribute) * Math.log(
290                      bayesNet.m_Distributions[iAttribute][(int) iCPT].getProbability(instance.value(1)));
291                }
292            }
293
294            fProbs[iClass] += logfP;
295        }
296
297        // Find maximum
298        double fMax = fProbs[0];
299        for (int iClass = 0; iClass < nNumClasses; iClass++) {
300            if (fProbs[iClass] > fMax) {
301                fMax = fProbs[iClass];
302            }
303        }
304        // transform from log-space to normal-space
305        for (int iClass = 0; iClass < nNumClasses; iClass++) {
306            fProbs[iClass] = Math.exp(fProbs[iClass] - fMax);
307        }
308
309        // Display probabilities
310        Utils.normalize(fProbs);
311
312        return fProbs;
313    } // distributionForInstance
314
315    /**
316     * Returns an enumeration describing the available options
317     *
318     * @return an enumeration of all the available options
319     */
320    public Enumeration listOptions() {
321        Vector newVector = new Vector(1);
322
323        newVector.addElement(new Option(
324            "\tWhether to use K2 prior.\n", 
325            "k2", 0, "-k2"));
326
327        Enumeration enu = super.listOptions();
328        while (enu.hasMoreElements()) {
329                newVector.addElement(enu.nextElement());
330        }
331
332        return newVector.elements();
333    } // listOptions
334
335    /**
336     * Parses a given list of options. <p/>
337     *
338     <!-- options-start -->
339     * Valid options are: <p/>
340     *
341     * <pre> -k2
342     *  Whether to use K2 prior.
343     * </pre>
344     *
345     * <pre> -A &lt;alpha&gt;
346     *  Initial count (alpha)
347     * </pre>
348     *
349     <!-- options-end -->
350     *
351     * @param options the list of options as an array of strings
352     * @throws Exception if an option is not supported
353     */
354    public void setOptions(String[] options) throws Exception {
355        setUseK2Prior(Utils.getFlag("k2", options));
356
357        super.setOptions(options);
358    } // setOptions
359
360    /**
361     * Gets the current settings of the classifier.
362     *
363     * @return an array of strings suitable for passing to setOptions
364     */
365    public String[] getOptions() {
366        String[] superOptions = super.getOptions();
367        String[] options = new String[1 + superOptions.length];
368        int current = 0;
369
370        if (isUseK2Prior())
371          options[current++] = "-k2";
372
373        // insert options from parent class
374        for (int iOption = 0; iOption < superOptions.length; iOption++) {
375                options[current++] = superOptions[iOption];
376        }
377
378        // Fill up rest with empty strings, not nulls!
379        while (current < options.length) {
380                options[current++] = "";
381        }
382
383        return options;
384    } // getOptions
385   
386    /**
387     * Returns the revision string.
388     *
389     * @return          the revision
390     */
391    public String getRevision() {
392      return RevisionUtils.extract("$Revision: 5987 $");
393    }
394} // class MultiNomialBMAEstimator
Note: See TracBrowser for help on using the repository browser.