source: src/main/java/weka/classifiers/bayes/BayesNet.java @ 18

Last change on this file since 18 was 4, checked in by gnappo, 14 years ago

Import di weka.

File size: 35.2 KB
RevLine 
[4]1/*
2 * This program is free software; you can redistribute it and/or modify
3 * it under the terms of the GNU General Public License as published by
4 * the Free Software Foundation; either version 2 of the License, or
5 * (at your option) any later version.
6 *
7 * This program is distributed in the hope that it will be useful,
8 * but WITHOUT ANY WARRANTY; without even the implied warranty of
9 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
10 * GNU General Public License for more details.
11 *
12 * You should have received a copy of the GNU General Public License
13 * along with this program; if not, write to the Free Software
14 * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
15 */
16
17/*
18 * BayesNet.java
19 * Copyright (C) 2001 University of Waikato, Hamilton, New Zealand
20 *
21 */
22package weka.classifiers.bayes;
23
24import weka.classifiers.Classifier;
25import weka.classifiers.AbstractClassifier;
26import weka.classifiers.bayes.net.ADNode;
27import weka.classifiers.bayes.net.BIFReader;
28import weka.classifiers.bayes.net.ParentSet;
29import weka.classifiers.bayes.net.estimate.BayesNetEstimator;
30import weka.classifiers.bayes.net.estimate.DiscreteEstimatorBayes;
31import weka.classifiers.bayes.net.estimate.SimpleEstimator;
32import weka.classifiers.bayes.net.search.SearchAlgorithm;
33import weka.classifiers.bayes.net.search.local.K2;
34import weka.classifiers.bayes.net.search.local.LocalScoreSearchAlgorithm;
35import weka.classifiers.bayes.net.search.local.Scoreable;
36import weka.core.AdditionalMeasureProducer;
37import weka.core.Attribute;
38import weka.core.Capabilities;
39import weka.core.Drawable;
40import weka.core.Instance;
41import weka.core.Instances;
42import weka.core.Option;
43import weka.core.OptionHandler;
44import weka.core.RevisionUtils;
45import weka.core.Utils;
46import weka.core.WeightedInstancesHandler;
47import weka.core.Capabilities.Capability;
48import weka.estimators.Estimator;
49import weka.filters.Filter;
50import weka.filters.supervised.attribute.Discretize;
51import weka.filters.unsupervised.attribute.ReplaceMissingValues;
52
53import java.util.Enumeration;
54import java.util.Vector;
55
56/**
57 <!-- globalinfo-start -->
58 * Bayes Network learning using various search algorithms and quality measures.<br/>
59 * Base class for a Bayes Network classifier. Provides datastructures (network structure, conditional probability distributions, etc.) and facilities common to Bayes Network learning algorithms like K2 and B.<br/>
60 * <br/>
61 * For more information see:<br/>
62 * <br/>
63 * http://sourceforge.net/projects/weka/files/documentation/WekaManual-3-7-0.pdf/download
64 * <p/>
65 <!-- globalinfo-end -->
66 *
67 <!-- options-start -->
68 * Valid options are: <p/>
69 *
70 * <pre> -D
71 *  Do not use ADTree data structure
72 * </pre>
73 *
74 * <pre> -B &lt;BIF file&gt;
75 *  BIF file to compare with
76 * </pre>
77 *
78 * <pre> -Q weka.classifiers.bayes.net.search.SearchAlgorithm
79 *  Search algorithm
80 * </pre>
81 *
82 * <pre> -E weka.classifiers.bayes.net.estimate.SimpleEstimator
83 *  Estimator algorithm
84 * </pre>
85 *
86 <!-- options-end -->
87 *
88 * @author Remco Bouckaert (rrb@xm.co.nz)
89 * @version $Revision: 5928 $
90 */
91public class BayesNet
92  extends AbstractClassifier
93  implements OptionHandler, WeightedInstancesHandler, Drawable, 
94             AdditionalMeasureProducer {
95
96  /** for serialization */
97  static final long serialVersionUID = 746037443258775954L;
98
99
100  /**
101   * The parent sets.
102   */
103  protected ParentSet[] m_ParentSets;
104
105  /**
106   * The attribute estimators containing CPTs.
107   */
108  public Estimator[][] m_Distributions;
109
110
111  /** filter used to quantize continuous variables, if any **/
112  protected Discretize m_DiscretizeFilter = null;
113
114  /** attribute index of a non-nominal attribute */
115  int m_nNonDiscreteAttribute = -1;
116
117  /** filter used to fill in missing values, if any **/
118  protected ReplaceMissingValues m_MissingValuesFilter = null; 
119
120  /**
121   * The number of classes
122   */
123  protected int m_NumClasses;
124
125  /**
126   * The dataset header for the purposes of printing out a semi-intelligible
127   * model
128   */
129  public Instances m_Instances;
130
131  /**
132   * Datastructure containing ADTree representation of the database.
133   * This may result in more efficient access to the data.
134   */
135  ADNode m_ADTree;
136
137  /**
138   * Bayes network to compare the structure with.
139   */
140  protected BIFReader m_otherBayesNet = null;
141
142  /**
143   * Use the experimental ADTree datastructure for calculating contingency tables
144   */
145  boolean m_bUseADTree = false;
146
147  /**
148   * Search algorithm used for learning the structure of a network.
149   */
150  SearchAlgorithm m_SearchAlgorithm = new K2();
151
152  /**
   * Estimator algorithm used for calculating the conditional probability tables (CPTs) of the network.
154   */
155  BayesNetEstimator m_BayesNetEstimator = new SimpleEstimator();
156
157  /**
158   * Returns default capabilities of the classifier.
159   *
160   * @return      the capabilities of this classifier
161   */
162  public Capabilities getCapabilities() {
163    Capabilities result = super.getCapabilities();
164    result.disableAll();
165
166    // attributes
167    result.enable(Capability.NOMINAL_ATTRIBUTES);
168    result.enable(Capability.NUMERIC_ATTRIBUTES);
169    result.enable(Capability.MISSING_VALUES);
170
171    // class
172    result.enable(Capability.NOMINAL_CLASS);
173    result.enable(Capability.MISSING_CLASS_VALUES);
174
175    // instances
176    result.setMinimumNumberInstances(0);
177
178    return result;
179  }
180
181  /**
182   * Generates the classifier.
183   *
184   * @param instances set of instances serving as training data
185   * @throws Exception if the classifier has not been generated
186   * successfully
187   */
188  public void buildClassifier(Instances instances) throws Exception {
189
190    // can classifier handle the data?
191    getCapabilities().testWithFail(instances);
192
193    // remove instances with missing class
194    instances = new Instances(instances);
195    instances.deleteWithMissingClass();
196
197    // ensure we have a data set with discrete variables only and with no missing values
198    instances = normalizeDataSet(instances);
199
200    // Copy the instances
201    m_Instances = new Instances(instances);
202
203    // sanity check: need more than 1 variable in datat set
204    m_NumClasses = instances.numClasses();
205
206    // initialize ADTree
207    if (m_bUseADTree) {
208      m_ADTree = ADNode.makeADTree(instances);
209      //      System.out.println("Oef, done!");
210    }
211
212    // build the network structure
213    initStructure();
214
215    // build the network structure
216    buildStructure();
217
218    // build the set of CPTs
219    estimateCPTs();
220
221    // Save space
222    // m_Instances = new Instances(m_Instances, 0);
223    m_ADTree = null;
224  } // buildClassifier
225
226  /** ensure that all variables are nominal and that there are no missing values
227   * @param instances data set to check and quantize and/or fill in missing values
228   * @return filtered instances
229   * @throws Exception if a filter (Discretize, ReplaceMissingValues) fails
230   */
231  protected Instances normalizeDataSet(Instances instances) throws Exception {
232    m_DiscretizeFilter = null;
233    m_MissingValuesFilter = null;
234
235    boolean bHasNonNominal = false;
236    boolean bHasMissingValues = false;
237
238    Enumeration enu = instances.enumerateAttributes();         
239    while (enu.hasMoreElements()) {
240      Attribute attribute = (Attribute) enu.nextElement();
241      if (attribute.type() != Attribute.NOMINAL) {
242        m_nNonDiscreteAttribute = attribute.index();
243        bHasNonNominal = true;
244        //throw new UnsupportedAttributeTypeException("BayesNet handles nominal variables only. Non-nominal variable in dataset detected.");
245      }
246      Enumeration enum2 = instances.enumerateInstances();
247      while (enum2.hasMoreElements()) {
248        if (((Instance) enum2.nextElement()).isMissing(attribute)) {
249          bHasMissingValues = true;
250          // throw new NoSupportForMissingValuesException("BayesNet: no missing values, please.");
251        }
252      }
253    }
254
255    if (bHasNonNominal) {
256      System.err.println("Warning: discretizing data set");
257      m_DiscretizeFilter = new Discretize();
258      m_DiscretizeFilter.setInputFormat(instances);
259      instances = Filter.useFilter(instances, m_DiscretizeFilter);
260    }
261
262    if (bHasMissingValues) {
263      System.err.println("Warning: filling in missing values in data set");
264      m_MissingValuesFilter = new ReplaceMissingValues();
265      m_MissingValuesFilter.setInputFormat(instances);
266      instances = Filter.useFilter(instances, m_MissingValuesFilter);
267    }
268    return instances;
269  } // normalizeDataSet
270
271  /** ensure that all variables are nominal and that there are no missing values
272   * @param instance instance to check and quantize and/or fill in missing values
273   * @return filtered instance
274   * @throws Exception if a filter (Discretize, ReplaceMissingValues) fails
275   */
276  protected Instance normalizeInstance(Instance instance) throws Exception {
277    if ((m_DiscretizeFilter != null) &&
278        (instance.attribute(m_nNonDiscreteAttribute).type() != Attribute.NOMINAL)) {
279      m_DiscretizeFilter.input(instance);
280      instance = m_DiscretizeFilter.output();
281    }
282    if (m_MissingValuesFilter != null) {
283      m_MissingValuesFilter.input(instance);
284      instance = m_MissingValuesFilter.output();
285    } else {
286      // is there a missing value in this instance?
287      // this can happen when there is no missing value in the training set
288      for (int iAttribute = 0; iAttribute < m_Instances.numAttributes(); iAttribute++) {
289        if (iAttribute != instance.classIndex() && instance.isMissing(iAttribute)) {
290          System.err.println("Warning: Found missing value in test set, filling in values.");
291          m_MissingValuesFilter = new ReplaceMissingValues();
292          m_MissingValuesFilter.setInputFormat(m_Instances);
293          Filter.useFilter(m_Instances, m_MissingValuesFilter);
294          m_MissingValuesFilter.input(instance);
295          instance = m_MissingValuesFilter.output();
296          iAttribute = m_Instances.numAttributes();
297        }
298      }
299    }
300    return instance;
301  } // normalizeInstance
302
303  /**
304   * Init structure initializes the structure to an empty graph or a Naive Bayes
305   * graph (depending on the -N flag).
306   *
307   * @throws Exception in case of an error
308   */
309  public void initStructure() throws Exception {
310
311    // initialize topological ordering
312    //    m_nOrder = new int[m_Instances.numAttributes()];
313    //    m_nOrder[0] = m_Instances.classIndex();
314
315    int nAttribute = 0;
316
317    for (int iOrder = 1; iOrder < m_Instances.numAttributes(); iOrder++) {
318      if (nAttribute == m_Instances.classIndex()) {
319        nAttribute++;
320      }
321
322      //      m_nOrder[iOrder] = nAttribute++;
323    }
324
325    // reserve memory
326    m_ParentSets = new ParentSet[m_Instances.numAttributes()];
327
328    for (int iAttribute = 0; iAttribute < m_Instances.numAttributes(); iAttribute++) {
329      m_ParentSets[iAttribute] = new ParentSet(m_Instances.numAttributes());
330    }
331  } // initStructure
332
333  /**
334   * buildStructure determines the network structure/graph of the network.
335   * The default behavior is creating a network where all nodes have the first
336   * node as its parent (i.e., a BayesNet that behaves like a naive Bayes classifier).
337   * This method can be overridden by derived classes to restrict the class
338   * of network structures that are acceptable.
339   *
340   * @throws Exception in case of an error
341   */
342  public void buildStructure() throws Exception {
343    m_SearchAlgorithm.buildStructure(this, m_Instances);
344  } // buildStructure
345
346  /**
347   * estimateCPTs estimates the conditional probability tables for the Bayes
348   * Net using the network structure.
349   *
350   * @throws Exception in case of an error
351   */
352  public void estimateCPTs() throws Exception {
353    m_BayesNetEstimator.estimateCPTs(this);
354  } // estimateCPTs
355
356  /**
357   * initializes the conditional probabilities
358   *
359   * @throws Exception in case of an error
360   */
361  public void initCPTs() throws Exception {
362    m_BayesNetEstimator.initCPTs(this);
363  } // estimateCPTs
364
365  /**
366   * Updates the classifier with the given instance.
367   *
368   * @param instance the new training instance to include in the model
369   * @throws Exception if the instance could not be incorporated in
370   * the model.
371   */
372  public void updateClassifier(Instance instance) throws Exception {
373    instance = normalizeInstance(instance);
374    m_BayesNetEstimator.updateClassifier(this, instance);
375  } // updateClassifier
376
377  /**
378   * Calculates the class membership probabilities for the given test
379   * instance.
380   *
381   * @param instance the instance to be classified
382   * @return predicted class probability distribution
383   * @throws Exception if there is a problem generating the prediction
384   */
385  public double[] distributionForInstance(Instance instance) throws Exception {
386    instance = normalizeInstance(instance);
387    return m_BayesNetEstimator.distributionForInstance(this, instance);
388  } // distributionForInstance
389
390  /**
391   * Calculates the counts for Dirichlet distribution for the
392   * class membership probabilities for the given test instance.
393   *
394   * @param instance the instance to be classified
395   * @return counts for Dirichlet distribution for class probability
396   * @throws Exception if there is a problem generating the prediction
397   */
398  public double[] countsForInstance(Instance instance) throws Exception {
399    double[] fCounts = new double[m_NumClasses];
400
401    for (int iClass = 0; iClass < m_NumClasses; iClass++) {
402      fCounts[iClass] = 0.0;
403    }
404
405    for (int iClass = 0; iClass < m_NumClasses; iClass++) {
406      double fCount = 0;
407
408      for (int iAttribute = 0; iAttribute < m_Instances.numAttributes(); iAttribute++) {
409        double iCPT = 0;
410
411        for (int iParent = 0; iParent < m_ParentSets[iAttribute].getNrOfParents(); iParent++) {
412          int nParent = m_ParentSets[iAttribute].getParent(iParent);
413
414          if (nParent == m_Instances.classIndex()) {
415            iCPT = iCPT * m_NumClasses + iClass;
416          } else {
417            iCPT = iCPT * m_Instances.attribute(nParent).numValues() + instance.value(nParent);
418          }
419        }
420
421        if (iAttribute == m_Instances.classIndex()) {
422          fCount += ((DiscreteEstimatorBayes) m_Distributions[iAttribute][(int) iCPT]).getCount(iClass);
423        } else {
424          fCount
425          += ((DiscreteEstimatorBayes) m_Distributions[iAttribute][(int) iCPT]).getCount(
426              instance.value(iAttribute));
427        }
428      }
429
430      fCounts[iClass] += fCount;
431    }
432    return fCounts;
433  } // countsForInstance
434
435  /**
436   * Returns an enumeration describing the available options
437   *
438   * @return an enumeration of all the available options
439   */
440  public Enumeration listOptions() {
441    Vector newVector = new Vector(4);
442
443    newVector.addElement(new Option("\tDo not use ADTree data structure\n", "D", 0, "-D"));
444    newVector.addElement(new Option("\tBIF file to compare with\n", "B", 1, "-B <BIF file>"));
445    newVector.addElement(new Option("\tSearch algorithm\n", "Q", 1, "-Q weka.classifiers.bayes.net.search.SearchAlgorithm"));
446    newVector.addElement(new Option("\tEstimator algorithm\n", "E", 1, "-E weka.classifiers.bayes.net.estimate.SimpleEstimator"));
447
448    return newVector.elements();
449  } // listOptions
450
451  /**
452   * Parses a given list of options. <p>
453   *
454     <!-- options-start -->
455   * Valid options are: <p/>
456   *
457   * <pre> -D
458   *  Do not use ADTree data structure
459   * </pre>
460   *
461   * <pre> -B &lt;BIF file&gt;
462   *  BIF file to compare with
463   * </pre>
464   *
465   * <pre> -Q weka.classifiers.bayes.net.search.SearchAlgorithm
466   *  Search algorithm
467   * </pre>
468   *
469   * <pre> -E weka.classifiers.bayes.net.estimate.SimpleEstimator
470   *  Estimator algorithm
471   * </pre>
472   *
473     <!-- options-end -->
474   *
475   * @param options the list of options as an array of strings
476   * @throws Exception if an option is not supported
477   */
478  public void setOptions(String[] options) throws Exception {
479    m_bUseADTree = !(Utils.getFlag('D', options));
480
481    String sBIFFile = Utils.getOption('B', options);
482    if (sBIFFile != null && !sBIFFile.equals("")) {
483      setBIFFile(sBIFFile);
484    }
485
486    String searchAlgorithmName = Utils.getOption('Q', options);
487    if (searchAlgorithmName.length() != 0) {
488      setSearchAlgorithm(
489          (SearchAlgorithm) Utils.forName(
490              SearchAlgorithm.class,
491              searchAlgorithmName,
492              partitionOptions(options)));
493    }
494    else {
495      setSearchAlgorithm(new K2());
496    }
497
498
499    String estimatorName = Utils.getOption('E', options);
500    if (estimatorName.length() != 0) {
501      setEstimator(
502          (BayesNetEstimator) Utils.forName(
503              BayesNetEstimator.class,
504              estimatorName,
505              Utils.partitionOptions(options)));
506    }
507    else {
508      setEstimator(new SimpleEstimator());
509    }
510
511    Utils.checkForRemainingOptions(options);
512  } // setOptions
513
514  /**
515   * Returns the secondary set of options (if any) contained in
516   * the supplied options array. The secondary set is defined to
517   * be any options after the first "--" but before the "-E". These
518   * options are removed from the original options array.
519   *
520   * @param options the input array of options
521   * @return the array of secondary options
522   */
523  public static String [] partitionOptions(String [] options) {
524
525    for (int i = 0; i < options.length; i++) {
526      if (options[i].equals("--")) {
527        // ensure it follows by a -E option
528        int j = i;
529        while ((j < options.length) && !(options[j].equals("-E"))) {
530          j++;
531        }
532        /*      if (j >= options.length) {
533          return new String[0];
534          } */
535        options[i++] = "";
536        String [] result = new String [options.length - i];
537        j = i;
538        while ((j < options.length) && !(options[j].equals("-E"))) {
539          result[j - i] = options[j];
540          options[j] = "";
541          j++;
542        }
543        while(j < options.length) {
544          result[j - i] = "";
545          j++;
546        }               
547        return result;
548      }
549    }
550    return new String [0];
551  }
552
553
554  /**
555   * Gets the current settings of the classifier.
556   *
557   * @return an array of strings suitable for passing to setOptions
558   */
559  public String[] getOptions() {
560    String[] searchOptions = m_SearchAlgorithm.getOptions();
561    String[] estimatorOptions = m_BayesNetEstimator.getOptions();
562    String[] options = new String[11 + searchOptions.length + estimatorOptions.length];
563    int current = 0;
564
565    if (!m_bUseADTree) {
566      options[current++] = "-D";
567    }
568
569    if (m_otherBayesNet != null) {
570      options[current++] = "-B";
571      options[current++] = ((BIFReader) m_otherBayesNet).getFileName();
572    }
573
574    options[current++] = "-Q";
575    options[current++] = "" + getSearchAlgorithm().getClass().getName();
576    options[current++] = "--";
577    for (int iOption = 0; iOption < searchOptions.length; iOption++) {
578      options[current++] = searchOptions[iOption];
579    }
580
581    options[current++] = "-E";
582    options[current++] = "" + getEstimator().getClass().getName();
583    options[current++] = "--";
584    for (int iOption = 0; iOption < estimatorOptions.length; iOption++) {
585      options[current++] = estimatorOptions[iOption];
586    }
587
588    // Fill up rest with empty strings, not nulls!
589    while (current < options.length) {
590      options[current++] = "";
591    }
592
593    return options;
594  } // getOptions
595
596  /**
597   * Set the SearchAlgorithm used in searching for network structures.
598   * @param newSearchAlgorithm the SearchAlgorithm to use.
599   */
600  public void setSearchAlgorithm(SearchAlgorithm newSearchAlgorithm) {
601    m_SearchAlgorithm = newSearchAlgorithm;
602  }
603
604  /**
605   * Get the SearchAlgorithm used as the search algorithm
606   * @return the SearchAlgorithm used as the search algorithm
607   */
608  public SearchAlgorithm getSearchAlgorithm() {
609    return m_SearchAlgorithm;
610  }
611
612  /**
613   * Set the Estimator Algorithm used in calculating the CPTs
614   * @param newBayesNetEstimator the Estimator to use.
615   */
616  public void setEstimator(BayesNetEstimator newBayesNetEstimator) {
617    m_BayesNetEstimator = newBayesNetEstimator;
618  }
619
620  /**
621   * Get the BayesNetEstimator used for calculating the CPTs
622   * @return the BayesNetEstimator used.
623   */
624  public BayesNetEstimator getEstimator() {
625    return m_BayesNetEstimator;
626  }
627
628  /**
629   * Set whether ADTree structure is used or not
630   * @param bUseADTree true if an ADTree structure is used
631   */
632  public void setUseADTree(boolean bUseADTree) {
633    m_bUseADTree = bUseADTree;
634  }
635
636  /**
637   * Method declaration
638   * @return whether ADTree structure is used or not
639   */
640  public boolean getUseADTree() {
641    return m_bUseADTree;
642  }
643
644  /**
645   * Set name of network in BIF file to compare with
646   * @param sBIFFile the name of the BIF file
647   */
648  public void setBIFFile(String sBIFFile) {
649    try {
650      m_otherBayesNet = new BIFReader().processFile(sBIFFile);
651    } catch (Throwable t) {
652      m_otherBayesNet = null;
653    }
654  }
655
656  /**
657   * Get name of network in BIF file to compare with
658   * @return BIF file name
659   */
660  public String getBIFFile() {
661    if (m_otherBayesNet != null) {
662      return m_otherBayesNet.getFileName();
663    }
664    return "";
665  }
666
667
668  /**
669   * Returns a description of the classifier.
670   *
671   * @return a description of the classifier as a string.
672   */
673  public String toString() {
674    StringBuffer text = new StringBuffer();
675
676    text.append("Bayes Network Classifier");
677    text.append("\n" + (m_bUseADTree ? "Using " : "not using ") + "ADTree");
678
679    if (m_Instances == null) {
680      text.append(": No model built yet.");
681    } else {
682
683      // flatten BayesNet down to text
684      text.append("\n#attributes=");
685      text.append(m_Instances.numAttributes());
686      text.append(" #classindex=");
687      text.append(m_Instances.classIndex());
688      text.append("\nNetwork structure (nodes followed by parents)\n");
689
690      for (int iAttribute = 0; iAttribute < m_Instances.numAttributes(); iAttribute++) {
691        text.append(
692            m_Instances.attribute(iAttribute).name()
693            + "("
694                + m_Instances.attribute(iAttribute).numValues()
695                + "): ");
696
697        for (int iParent = 0; iParent < m_ParentSets[iAttribute].getNrOfParents(); iParent++) {
698          text.append(m_Instances.attribute(m_ParentSets[iAttribute].getParent(iParent)).name() + " ");
699        }
700
701        text.append("\n");
702
703        // Description of distributions tends to be too much detail, so it is commented out here
704        // for (int iParent = 0; iParent < m_ParentSets[iAttribute].GetCardinalityOfParents(); iParent++) {
705        // text.append('(' + m_Distributions[iAttribute][iParent].toString() + ')');
706        // }
707        // text.append("\n");
708      }
709
710      text.append("LogScore Bayes: " + measureBayesScore() + "\n");
711      text.append("LogScore BDeu: " + measureBDeuScore() + "\n");
712      text.append("LogScore MDL: " + measureMDLScore() + "\n");
713      text.append("LogScore ENTROPY: " + measureEntropyScore() + "\n");
714      text.append("LogScore AIC: " + measureAICScore() + "\n");
715
716      if (m_otherBayesNet != null) {
717        text.append(
718            "Missing: "
719            + m_otherBayesNet.missingArcs(this)
720            + " Extra: "
721            + m_otherBayesNet.extraArcs(this)
722            + " Reversed: "
723            + m_otherBayesNet.reversedArcs(this)
724            + "\n");
725        text.append("Divergence: " + m_otherBayesNet.divergence(this) + "\n");
726      }
727    }
728
729    return text.toString();
730  } // toString
731
732
733  /**
734   *  Returns the type of graph this classifier
735   *  represents.
736   *  @return Drawable.TREE
737   */   
738  public int graphType() {
739    return Drawable.BayesNet;
740  }
741
742  /**
743   * Returns a BayesNet graph in XMLBIF ver 0.3 format.
744   * @return String representing this BayesNet in XMLBIF ver  0.3
745   * @throws Exception in case BIF generation fails
746   */
747  public String graph() throws Exception {
748    return toXMLBIF03();
749  }
750
751  public String getBIFHeader() {
752    StringBuffer text = new StringBuffer();
753    text.append("<?xml version=\"1.0\"?>\n");
754    text.append("<!-- DTD for the XMLBIF 0.3 format -->\n");
755    text.append("<!DOCTYPE BIF [\n");
756    text.append("       <!ELEMENT BIF ( NETWORK )*>\n");
757    text.append("             <!ATTLIST BIF VERSION CDATA #REQUIRED>\n");
758    text.append("       <!ELEMENT NETWORK ( NAME, ( PROPERTY | VARIABLE | DEFINITION )* )>\n");
759    text.append("       <!ELEMENT NAME (#PCDATA)>\n");
760    text.append("       <!ELEMENT VARIABLE ( NAME, ( OUTCOME |  PROPERTY )* ) >\n");
761    text.append("             <!ATTLIST VARIABLE TYPE (nature|decision|utility) \"nature\">\n");
762    text.append("       <!ELEMENT OUTCOME (#PCDATA)>\n");
763    text.append("       <!ELEMENT DEFINITION ( FOR | GIVEN | TABLE | PROPERTY )* >\n");
764    text.append("       <!ELEMENT FOR (#PCDATA)>\n");
765    text.append("       <!ELEMENT GIVEN (#PCDATA)>\n");
766    text.append("       <!ELEMENT TABLE (#PCDATA)>\n");
767    text.append("       <!ELEMENT PROPERTY (#PCDATA)>\n");
768    text.append("]>\n");
769    return text.toString();
770  } // getBIFHeader
771
772  /**
773   * Returns a description of the classifier in XML BIF 0.3 format.
774   * See http://www-2.cs.cmu.edu/~fgcozman/Research/InterchangeFormat/
775   * for details on XML BIF.
776   * @return an XML BIF 0.3 description of the classifier as a string.
777   */
778  public String toXMLBIF03() {
779    if (m_Instances == null) {
780      return("<!--No model built yet-->");
781    }
782
783    StringBuffer text = new StringBuffer();
784    text.append(getBIFHeader());
785    text.append("\n");
786    text.append("\n");
787    text.append("<BIF VERSION=\"0.3\">\n");
788    text.append("<NETWORK>\n");
789    text.append("<NAME>" + XMLNormalize(m_Instances.relationName()) + "</NAME>\n");
790    for (int iAttribute = 0; iAttribute < m_Instances.numAttributes(); iAttribute++) {
791      text.append("<VARIABLE TYPE=\"nature\">\n");
792      text.append("<NAME>" + XMLNormalize(m_Instances.attribute(iAttribute).name()) + "</NAME>\n");
793      for (int iValue = 0; iValue < m_Instances.attribute(iAttribute).numValues(); iValue++) {
794        text.append("<OUTCOME>" + XMLNormalize(m_Instances.attribute(iAttribute).value(iValue)) + "</OUTCOME>\n");
795      }
796      text.append("</VARIABLE>\n");
797    }
798
799    for (int iAttribute = 0; iAttribute < m_Instances.numAttributes(); iAttribute++) {
800      text.append("<DEFINITION>\n");
801      text.append("<FOR>" + XMLNormalize(m_Instances.attribute(iAttribute).name()) + "</FOR>\n");
802      for (int iParent = 0; iParent < m_ParentSets[iAttribute].getNrOfParents(); iParent++) {
803        text.append("<GIVEN>"
804            + XMLNormalize(m_Instances.attribute(m_ParentSets[iAttribute].getParent(iParent)).name()) +
805        "</GIVEN>\n");
806      }
807      text.append("<TABLE>\n");
808      for (int iParent = 0; iParent < m_ParentSets[iAttribute].getCardinalityOfParents(); iParent++) {
809        for (int iValue = 0; iValue < m_Instances.attribute(iAttribute).numValues(); iValue++) {
810          text.append(m_Distributions[iAttribute][iParent].getProbability(iValue));
811          text.append(' ');
812        }
813        text.append('\n');
814      }
815      text.append("</TABLE>\n");
816      text.append("</DEFINITION>\n");
817    }
818    text.append("</NETWORK>\n");
819    text.append("</BIF>\n");
820    return text.toString();
821  } // toXMLBIF03
822
823
824  /** XMLNormalize converts the five standard XML entities in a string
825   * g.e. the string V&D's is returned as V&amp;D&apos;s
826   * @param sStr string to normalize
827   * @return normalized string
828   */
829  protected String XMLNormalize(String sStr) {
830    StringBuffer sStr2 = new StringBuffer();
831    for (int iStr = 0; iStr < sStr.length(); iStr++) {
832      char c = sStr.charAt(iStr);
833      switch (c) {
834        case '&': sStr2.append("&amp;"); break;
835        case '\'': sStr2.append("&apos;"); break;
836        case '\"': sStr2.append("&quot;"); break;
837        case '<': sStr2.append("&lt;"); break;
838        case '>': sStr2.append("&gt;"); break;
839        default:
840          sStr2.append(c);
841      }
842    }
843    return sStr2.toString();
844  } // XMLNormalize
845
846
847  /**
848   * @return a string to describe the UseADTreeoption.
849   */
850  public String useADTreeTipText() {
851    return "When ADTree (the data structure for increasing speed on counts,"
852    + " not to be confused with the classifier under the same name) is used"
853    + " learning time goes down typically. However, because ADTrees are memory"
854    + " intensive, memory problems may occur. Switching this option off makes"
855    + " the structure learning algorithms slower, and run with less memory."
856    + " By default, ADTrees are used.";
857  }
858
859  /**
860   * @return a string to describe the SearchAlgorithm.
861   */
862  public String searchAlgorithmTipText() {
863    return "Select method used for searching network structures.";
864  }
865
866  /**
867   * This will return a string describing the BayesNetEstimator.
868   * @return The string.
869   */
870  public String estimatorTipText() {
871    return "Select Estimator algorithm for finding the conditional probability tables"
872    + " of the Bayes Network.";
873  }
874
875  /**
876   * @return a string to describe the BIFFile.
877   */
878  public String BIFFileTipText() {
879    return "Set the name of a file in BIF XML format. A Bayes network learned"
880    + " from data can be compared with the Bayes network represented by the BIF file."
881    + " Statistics calculated are o.a. the number of missing and extra arcs.";
882  }
883
884  /**
885   * This will return a string describing the classifier.
886   * @return The string.
887   */
888  public String globalInfo() {
889    return 
890    "Bayes Network learning using various search algorithms and "
891    + "quality measures.\n"
892    + "Base class for a Bayes Network classifier. Provides "
893    + "datastructures (network structure, conditional probability "
894    + "distributions, etc.) and facilities common to Bayes Network "
895    + "learning algorithms like K2 and B.\n\n"
896    + "For more information see:\n\n"
897    + "http://www.cs.waikato.ac.nz/~remco/weka.pdf";
898  }
899
900  /**
901   * Main method for testing this class.
902   *
903   * @param argv the options
904   */
905  public static void main(String[] argv) {
906    runClassifier(new BayesNet(), argv);
907  } // main
908
909  /** get name of the Bayes network
910   * @return name of the Bayes net
911   */
912  public String getName() {
913    return m_Instances.relationName();
914  }
915
916  /** get number of nodes in the Bayes network
917   * @return number of nodes
918   */
919  public int getNrOfNodes() {
920    return m_Instances.numAttributes();
921  }
922
923  /** get name of a node in the Bayes network
924   * @param iNode index of the node
925   * @return name of the specified node
926   */
927  public String getNodeName(int iNode) {
928    return m_Instances.attribute(iNode).name();
929  }
930
931  /** get number of values a node can take
932   * @param iNode index of the node
933   * @return cardinality of the specified node
934   */
935  public int getCardinality(int iNode) {
936    return m_Instances.attribute(iNode).numValues();
937  }
938
939  /** get name of a particular value of a node
940   * @param iNode index of the node
941   * @param iValue index of the value
942   * @return cardinality of the specified node
943   */
944  public String getNodeValue(int iNode, int iValue) {
945    return m_Instances.attribute(iNode).value(iValue);
946  }
947
948  /** get number of parents of a node in the network structure
949   * @param iNode index of the node
950   * @return number of parents of the specified node
951   */
952  public int getNrOfParents(int iNode) {
953    return m_ParentSets[iNode].getNrOfParents();
954  }
955
956  /** get node index of a parent of a node in the network structure
957   * @param iNode index of the node
958   * @param iParent index of the parents, e.g., 0 is the first parent, 1 the second parent, etc.
959   * @return node index of the iParent's parent of the specified node
960   */
961  public int getParent(int iNode, int iParent) {
962    return m_ParentSets[iNode].getParent(iParent);
963  }
964
965  /** Get full set of parent sets.
966   * @return parent sets;
967   */
968  public ParentSet[] getParentSets() { 
969    return m_ParentSets;
970  }
971
972  /** Get full set of estimators.
973   * @return estimators;
974   */
975  public Estimator[][] getDistributions() {
976    return m_Distributions;
977  }
978
979  /** get number of values the collection of parents of a node can take
980   * @param iNode index of the node
981   * @return cardinality of the parent set of the specified node
982   */
983  public int getParentCardinality(int iNode) {
984    return m_ParentSets[iNode].getCardinalityOfParents();
985  }
986
987  /** get particular probability of the conditional probability distribtion
988   * of a node given its parents.
989   * @param iNode index of the node
990   * @param iParent index of the parent set, 0 <= iParent <= getParentCardinality(iNode)
991   * @param iValue index of the value, 0 <= iValue <= getCardinality(iNode)
992   * @return probability
993   */
994  public double getProbability(int iNode, int iParent, int iValue) {
995    return m_Distributions[iNode][iParent].getProbability(iValue);
996  }
997
998  /** get the parent set of a node
999   * @param iNode index of the node
1000   * @return Parent set of the specified node.
1001   */
1002  public ParentSet getParentSet(int iNode) {
1003    return m_ParentSets[iNode];
1004  }
1005
1006  /** get ADTree strucrture containing efficient representation of counts.
1007   * @return ADTree strucrture
1008   */
1009  public ADNode getADTree() { return m_ADTree;}
1010
1011  // implementation of AdditionalMeasureProducer interface
1012  /**
1013   * Returns an enumeration of the measure names. Additional measures
1014   * must follow the naming convention of starting with "measure", eg.
1015   * double measureBlah()
1016   * @return an enumeration of the measure names
1017   */
1018  public Enumeration enumerateMeasures() {
1019    Vector newVector = new Vector(4);
1020    newVector.addElement("measureExtraArcs");
1021    newVector.addElement("measureMissingArcs");
1022    newVector.addElement("measureReversedArcs");
1023    newVector.addElement("measureDivergence");
1024    newVector.addElement("measureBayesScore");
1025    newVector.addElement("measureBDeuScore");
1026    newVector.addElement("measureMDLScore");
1027    newVector.addElement("measureAICScore");
1028    newVector.addElement("measureEntropyScore");
1029    return newVector.elements();
1030  } // enumerateMeasures
1031
1032  public double measureExtraArcs() {
1033    if (m_otherBayesNet != null) {
1034      return m_otherBayesNet.extraArcs(this); 
1035    }
1036    return 0;
1037  } // measureExtraArcs
1038
1039  public double measureMissingArcs() {
1040    if (m_otherBayesNet != null) {
1041      return m_otherBayesNet.missingArcs(this); 
1042    }
1043    return 0;
1044  } // measureMissingArcs
1045
1046  public double measureReversedArcs() {
1047    if (m_otherBayesNet != null) {
1048      return m_otherBayesNet.reversedArcs(this); 
1049    }
1050    return 0;
1051  } // measureReversedArcs
1052
1053  public double measureDivergence() {
1054    if (m_otherBayesNet != null) {
1055      return m_otherBayesNet.divergence(this); 
1056    }
1057    return 0;
1058  } // measureDivergence
1059
1060  public double measureBayesScore() {
1061    LocalScoreSearchAlgorithm s = new LocalScoreSearchAlgorithm(this, m_Instances);
1062    return s.logScore(Scoreable.BAYES);
1063  } // measureBayesScore
1064
1065  public double measureBDeuScore() {
1066    LocalScoreSearchAlgorithm s = new LocalScoreSearchAlgorithm(this, m_Instances);
1067    return s.logScore(Scoreable.BDeu);
1068  } // measureBDeuScore
1069
1070  public double measureMDLScore() {
1071    LocalScoreSearchAlgorithm s = new LocalScoreSearchAlgorithm(this, m_Instances);
1072    return s.logScore(Scoreable.MDL);
1073  } // measureMDLScore
1074
1075  public double measureAICScore() {
1076    LocalScoreSearchAlgorithm s = new LocalScoreSearchAlgorithm(this, m_Instances);
1077    return s.logScore(Scoreable.AIC);
1078  } // measureAICScore
1079
1080  public double measureEntropyScore() {
1081    LocalScoreSearchAlgorithm s = new LocalScoreSearchAlgorithm(this, m_Instances);
1082    return s.logScore(Scoreable.ENTROPY);
1083  } // measureEntropyScore
1084
1085  /**
1086   * Returns the value of the named measure
1087   * @param measureName the name of the measure to query for its value
1088   * @return the value of the named measure
1089   * @throws IllegalArgumentException if the named measure is not supported
1090   */
1091  public double getMeasure(String measureName) {
1092    if (measureName.equals("measureExtraArcs")) {
1093      return measureExtraArcs();
1094    }
1095    if (measureName.equals("measureMissingArcs")) {
1096      return measureMissingArcs();
1097    }
1098    if (measureName.equals("measureReversedArcs")) {
1099      return measureReversedArcs();
1100    }
1101    if (measureName.equals("measureDivergence")) {
1102      return measureDivergence();
1103    }
1104    if (measureName.equals("measureBayesScore")) {
1105      return measureBayesScore();
1106    }
1107    if (measureName.equals("measureBDeuScore")) {
1108      return measureBDeuScore();
1109    }
1110    if (measureName.equals("measureMDLScore")) {
1111      return measureMDLScore();
1112    }
1113    if (measureName.equals("measureAICScore")) {
1114      return measureAICScore();
1115    }
1116    if (measureName.equals("measureEntropyScore")) {
1117      return measureEntropyScore();
1118    }
1119    return 0;
1120  } // getMeasure
1121
1122  /**
1123   * Returns the revision string.
1124   *
1125   * @return            the revision
1126   */
1127  public String getRevision() {
1128    return RevisionUtils.extract("$Revision: 5928 $");
1129  }
1130} // class BayesNet
Note: See TracBrowser for help on using the repository browser.