/*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation; either version 2 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program; if not, write to the Free Software
* Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
*/
/*
* BayesNet.java
* Copyright (C) 2001 University of Waikato, Hamilton, New Zealand
*
*/
package weka.classifiers.bayes;
import weka.classifiers.Classifier;
import weka.classifiers.AbstractClassifier;
import weka.classifiers.bayes.net.ADNode;
import weka.classifiers.bayes.net.BIFReader;
import weka.classifiers.bayes.net.ParentSet;
import weka.classifiers.bayes.net.estimate.BayesNetEstimator;
import weka.classifiers.bayes.net.estimate.DiscreteEstimatorBayes;
import weka.classifiers.bayes.net.estimate.SimpleEstimator;
import weka.classifiers.bayes.net.search.SearchAlgorithm;
import weka.classifiers.bayes.net.search.local.K2;
import weka.classifiers.bayes.net.search.local.LocalScoreSearchAlgorithm;
import weka.classifiers.bayes.net.search.local.Scoreable;
import weka.core.AdditionalMeasureProducer;
import weka.core.Attribute;
import weka.core.Capabilities;
import weka.core.Drawable;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.OptionHandler;
import weka.core.RevisionUtils;
import weka.core.Utils;
import weka.core.WeightedInstancesHandler;
import weka.core.Capabilities.Capability;
import weka.estimators.Estimator;
import weka.filters.Filter;
import weka.filters.supervised.attribute.Discretize;
import weka.filters.unsupervised.attribute.ReplaceMissingValues;
import java.util.Enumeration;
import java.util.Vector;
/**
* Bayes Network learning using various search algorithms and quality measures.
* Base class for a Bayes Network classifier. Provides datastructures (network structure, conditional probability distributions, etc.) and facilities common to Bayes Network learning algorithms like K2 and B.
*
* For more information see:
*
* http://sourceforge.net/projects/weka/files/documentation/WekaManual-3-7-0.pdf/download
*
-D * Do not use ADTree data structure ** *
-B <BIF file> * BIF file to compare with ** *
-Q weka.classifiers.bayes.net.search.SearchAlgorithm * Search algorithm ** *
-E weka.classifiers.bayes.net.estimate.SimpleEstimator * Estimator algorithm ** * * @author Remco Bouckaert (rrb@xm.co.nz) * @version $Revision: 5928 $ */ public class BayesNet extends AbstractClassifier implements OptionHandler, WeightedInstancesHandler, Drawable, AdditionalMeasureProducer { /** for serialization */ static final long serialVersionUID = 746037443258775954L; /** * The parent sets. */ protected ParentSet[] m_ParentSets; /** * The attribute estimators containing CPTs. */ public Estimator[][] m_Distributions; /** filter used to quantize continuous variables, if any **/ protected Discretize m_DiscretizeFilter = null; /** attribute index of a non-nominal attribute */ int m_nNonDiscreteAttribute = -1; /** filter used to fill in missing values, if any **/ protected ReplaceMissingValues m_MissingValuesFilter = null; /** * The number of classes */ protected int m_NumClasses; /** * The dataset header for the purposes of printing out a semi-intelligible * model */ public Instances m_Instances; /** * Datastructure containing ADTree representation of the database. * This may result in more efficient access to the data. */ ADNode m_ADTree; /** * Bayes network to compare the structure with. */ protected BIFReader m_otherBayesNet = null; /** * Use the experimental ADTree datastructure for calculating contingency tables */ boolean m_bUseADTree = false; /** * Search algorithm used for learning the structure of a network. */ SearchAlgorithm m_SearchAlgorithm = new K2(); /** * Search algorithm used for learning the structure of a network. */ BayesNetEstimator m_BayesNetEstimator = new SimpleEstimator(); /** * Returns default capabilities of the classifier. * * @return the capabilities of this classifier */ public Capabilities getCapabilities() { Capabilities result = super.getCapabilities(); result.disableAll(); // attributes result.enable(Capability.NOMINAL_ATTRIBUTES); result.enable(Capability.NUMERIC_ATTRIBUTES); result.enable(Capability.MISSING_VALUES); // class result.enable(Capability.NOMINAL_CLASS); result.enable(Capability.MISSING_CLASS_VALUES); // instances result.setMinimumNumberInstances(0); return result; } /** * Generates the classifier. * * @param instances set of instances serving as training data * @throws Exception if the classifier has not been generated * successfully */ public void buildClassifier(Instances instances) throws Exception { // can classifier handle the data? getCapabilities().testWithFail(instances); // remove instances with missing class instances = new Instances(instances); instances.deleteWithMissingClass(); // ensure we have a data set with discrete variables only and with no missing values instances = normalizeDataSet(instances); // Copy the instances m_Instances = new Instances(instances); // sanity check: need more than 1 variable in datat set m_NumClasses = instances.numClasses(); // initialize ADTree if (m_bUseADTree) { m_ADTree = ADNode.makeADTree(instances); // System.out.println("Oef, done!"); } // build the network structure initStructure(); // build the network structure buildStructure(); // build the set of CPTs estimateCPTs(); // Save space // m_Instances = new Instances(m_Instances, 0); m_ADTree = null; } // buildClassifier /** ensure that all variables are nominal and that there are no missing values * @param instances data set to check and quantize and/or fill in missing values * @return filtered instances * @throws Exception if a filter (Discretize, ReplaceMissingValues) fails */ protected Instances normalizeDataSet(Instances instances) throws Exception { m_DiscretizeFilter = null; m_MissingValuesFilter = null; boolean bHasNonNominal = false; boolean bHasMissingValues = false; Enumeration enu = instances.enumerateAttributes(); while (enu.hasMoreElements()) { Attribute attribute = (Attribute) enu.nextElement(); if (attribute.type() != Attribute.NOMINAL) { m_nNonDiscreteAttribute = attribute.index(); bHasNonNominal = true; //throw new UnsupportedAttributeTypeException("BayesNet handles nominal variables only. Non-nominal variable in dataset detected."); } Enumeration enum2 = instances.enumerateInstances(); while (enum2.hasMoreElements()) { if (((Instance) enum2.nextElement()).isMissing(attribute)) { bHasMissingValues = true; // throw new NoSupportForMissingValuesException("BayesNet: no missing values, please."); } } } if (bHasNonNominal) { System.err.println("Warning: discretizing data set"); m_DiscretizeFilter = new Discretize(); m_DiscretizeFilter.setInputFormat(instances); instances = Filter.useFilter(instances, m_DiscretizeFilter); } if (bHasMissingValues) { System.err.println("Warning: filling in missing values in data set"); m_MissingValuesFilter = new ReplaceMissingValues(); m_MissingValuesFilter.setInputFormat(instances); instances = Filter.useFilter(instances, m_MissingValuesFilter); } return instances; } // normalizeDataSet /** ensure that all variables are nominal and that there are no missing values * @param instance instance to check and quantize and/or fill in missing values * @return filtered instance * @throws Exception if a filter (Discretize, ReplaceMissingValues) fails */ protected Instance normalizeInstance(Instance instance) throws Exception { if ((m_DiscretizeFilter != null) && (instance.attribute(m_nNonDiscreteAttribute).type() != Attribute.NOMINAL)) { m_DiscretizeFilter.input(instance); instance = m_DiscretizeFilter.output(); } if (m_MissingValuesFilter != null) { m_MissingValuesFilter.input(instance); instance = m_MissingValuesFilter.output(); } else { // is there a missing value in this instance? // this can happen when there is no missing value in the training set for (int iAttribute = 0; iAttribute < m_Instances.numAttributes(); iAttribute++) { if (iAttribute != instance.classIndex() && instance.isMissing(iAttribute)) { System.err.println("Warning: Found missing value in test set, filling in values."); m_MissingValuesFilter = new ReplaceMissingValues(); m_MissingValuesFilter.setInputFormat(m_Instances); Filter.useFilter(m_Instances, m_MissingValuesFilter); m_MissingValuesFilter.input(instance); instance = m_MissingValuesFilter.output(); iAttribute = m_Instances.numAttributes(); } } } return instance; } // normalizeInstance /** * Init structure initializes the structure to an empty graph or a Naive Bayes * graph (depending on the -N flag). * * @throws Exception in case of an error */ public void initStructure() throws Exception { // initialize topological ordering // m_nOrder = new int[m_Instances.numAttributes()]; // m_nOrder[0] = m_Instances.classIndex(); int nAttribute = 0; for (int iOrder = 1; iOrder < m_Instances.numAttributes(); iOrder++) { if (nAttribute == m_Instances.classIndex()) { nAttribute++; } // m_nOrder[iOrder] = nAttribute++; } // reserve memory m_ParentSets = new ParentSet[m_Instances.numAttributes()]; for (int iAttribute = 0; iAttribute < m_Instances.numAttributes(); iAttribute++) { m_ParentSets[iAttribute] = new ParentSet(m_Instances.numAttributes()); } } // initStructure /** * buildStructure determines the network structure/graph of the network. * The default behavior is creating a network where all nodes have the first * node as its parent (i.e., a BayesNet that behaves like a naive Bayes classifier). * This method can be overridden by derived classes to restrict the class * of network structures that are acceptable. * * @throws Exception in case of an error */ public void buildStructure() throws Exception { m_SearchAlgorithm.buildStructure(this, m_Instances); } // buildStructure /** * estimateCPTs estimates the conditional probability tables for the Bayes * Net using the network structure. * * @throws Exception in case of an error */ public void estimateCPTs() throws Exception { m_BayesNetEstimator.estimateCPTs(this); } // estimateCPTs /** * initializes the conditional probabilities * * @throws Exception in case of an error */ public void initCPTs() throws Exception { m_BayesNetEstimator.initCPTs(this); } // estimateCPTs /** * Updates the classifier with the given instance. * * @param instance the new training instance to include in the model * @throws Exception if the instance could not be incorporated in * the model. */ public void updateClassifier(Instance instance) throws Exception { instance = normalizeInstance(instance); m_BayesNetEstimator.updateClassifier(this, instance); } // updateClassifier /** * Calculates the class membership probabilities for the given test * instance. * * @param instance the instance to be classified * @return predicted class probability distribution * @throws Exception if there is a problem generating the prediction */ public double[] distributionForInstance(Instance instance) throws Exception { instance = normalizeInstance(instance); return m_BayesNetEstimator.distributionForInstance(this, instance); } // distributionForInstance /** * Calculates the counts for Dirichlet distribution for the * class membership probabilities for the given test instance. * * @param instance the instance to be classified * @return counts for Dirichlet distribution for class probability * @throws Exception if there is a problem generating the prediction */ public double[] countsForInstance(Instance instance) throws Exception { double[] fCounts = new double[m_NumClasses]; for (int iClass = 0; iClass < m_NumClasses; iClass++) { fCounts[iClass] = 0.0; } for (int iClass = 0; iClass < m_NumClasses; iClass++) { double fCount = 0; for (int iAttribute = 0; iAttribute < m_Instances.numAttributes(); iAttribute++) { double iCPT = 0; for (int iParent = 0; iParent < m_ParentSets[iAttribute].getNrOfParents(); iParent++) { int nParent = m_ParentSets[iAttribute].getParent(iParent); if (nParent == m_Instances.classIndex()) { iCPT = iCPT * m_NumClasses + iClass; } else { iCPT = iCPT * m_Instances.attribute(nParent).numValues() + instance.value(nParent); } } if (iAttribute == m_Instances.classIndex()) { fCount += ((DiscreteEstimatorBayes) m_Distributions[iAttribute][(int) iCPT]).getCount(iClass); } else { fCount += ((DiscreteEstimatorBayes) m_Distributions[iAttribute][(int) iCPT]).getCount( instance.value(iAttribute)); } } fCounts[iClass] += fCount; } return fCounts; } // countsForInstance /** * Returns an enumeration describing the available options * * @return an enumeration of all the available options */ public Enumeration listOptions() { Vector newVector = new Vector(4); newVector.addElement(new Option("\tDo not use ADTree data structure\n", "D", 0, "-D")); newVector.addElement(new Option("\tBIF file to compare with\n", "B", 1, "-B
* * Valid options are:
* *-D * Do not use ADTree data structure ** *
-B <BIF file> * BIF file to compare with ** *
-Q weka.classifiers.bayes.net.search.SearchAlgorithm * Search algorithm ** *
-E weka.classifiers.bayes.net.estimate.SimpleEstimator * Estimator algorithm ** * * @param options the list of options as an array of strings * @throws Exception if an option is not supported */ public void setOptions(String[] options) throws Exception { m_bUseADTree = !(Utils.getFlag('D', options)); String sBIFFile = Utils.getOption('B', options); if (sBIFFile != null && !sBIFFile.equals("")) { setBIFFile(sBIFFile); } String searchAlgorithmName = Utils.getOption('Q', options); if (searchAlgorithmName.length() != 0) { setSearchAlgorithm( (SearchAlgorithm) Utils.forName( SearchAlgorithm.class, searchAlgorithmName, partitionOptions(options))); } else { setSearchAlgorithm(new K2()); } String estimatorName = Utils.getOption('E', options); if (estimatorName.length() != 0) { setEstimator( (BayesNetEstimator) Utils.forName( BayesNetEstimator.class, estimatorName, Utils.partitionOptions(options))); } else { setEstimator(new SimpleEstimator()); } Utils.checkForRemainingOptions(options); } // setOptions /** * Returns the secondary set of options (if any) contained in * the supplied options array. The secondary set is defined to * be any options after the first "--" but before the "-E". These * options are removed from the original options array. * * @param options the input array of options * @return the array of secondary options */ public static String [] partitionOptions(String [] options) { for (int i = 0; i < options.length; i++) { if (options[i].equals("--")) { // ensure it follows by a -E option int j = i; while ((j < options.length) && !(options[j].equals("-E"))) { j++; } /* if (j >= options.length) { return new String[0]; } */ options[i++] = ""; String [] result = new String [options.length - i]; j = i; while ((j < options.length) && !(options[j].equals("-E"))) { result[j - i] = options[j]; options[j] = ""; j++; } while(j < options.length) { result[j - i] = ""; j++; } return result; } } return new String [0]; } /** * Gets the current settings of the classifier. * * @return an array of strings suitable for passing to setOptions */ public String[] getOptions() { String[] searchOptions = m_SearchAlgorithm.getOptions(); String[] estimatorOptions = m_BayesNetEstimator.getOptions(); String[] options = new String[11 + searchOptions.length + estimatorOptions.length]; int current = 0; if (!m_bUseADTree) { options[current++] = "-D"; } if (m_otherBayesNet != null) { options[current++] = "-B"; options[current++] = ((BIFReader) m_otherBayesNet).getFileName(); } options[current++] = "-Q"; options[current++] = "" + getSearchAlgorithm().getClass().getName(); options[current++] = "--"; for (int iOption = 0; iOption < searchOptions.length; iOption++) { options[current++] = searchOptions[iOption]; } options[current++] = "-E"; options[current++] = "" + getEstimator().getClass().getName(); options[current++] = "--"; for (int iOption = 0; iOption < estimatorOptions.length; iOption++) { options[current++] = estimatorOptions[iOption]; } // Fill up rest with empty strings, not nulls! while (current < options.length) { options[current++] = ""; } return options; } // getOptions /** * Set the SearchAlgorithm used in searching for network structures. * @param newSearchAlgorithm the SearchAlgorithm to use. */ public void setSearchAlgorithm(SearchAlgorithm newSearchAlgorithm) { m_SearchAlgorithm = newSearchAlgorithm; } /** * Get the SearchAlgorithm used as the search algorithm * @return the SearchAlgorithm used as the search algorithm */ public SearchAlgorithm getSearchAlgorithm() { return m_SearchAlgorithm; } /** * Set the Estimator Algorithm used in calculating the CPTs * @param newBayesNetEstimator the Estimator to use. */ public void setEstimator(BayesNetEstimator newBayesNetEstimator) { m_BayesNetEstimator = newBayesNetEstimator; } /** * Get the BayesNetEstimator used for calculating the CPTs * @return the BayesNetEstimator used. */ public BayesNetEstimator getEstimator() { return m_BayesNetEstimator; } /** * Set whether ADTree structure is used or not * @param bUseADTree true if an ADTree structure is used */ public void setUseADTree(boolean bUseADTree) { m_bUseADTree = bUseADTree; } /** * Method declaration * @return whether ADTree structure is used or not */ public boolean getUseADTree() { return m_bUseADTree; } /** * Set name of network in BIF file to compare with * @param sBIFFile the name of the BIF file */ public void setBIFFile(String sBIFFile) { try { m_otherBayesNet = new BIFReader().processFile(sBIFFile); } catch (Throwable t) { m_otherBayesNet = null; } } /** * Get name of network in BIF file to compare with * @return BIF file name */ public String getBIFFile() { if (m_otherBayesNet != null) { return m_otherBayesNet.getFileName(); } return ""; } /** * Returns a description of the classifier. * * @return a description of the classifier as a string. */ public String toString() { StringBuffer text = new StringBuffer(); text.append("Bayes Network Classifier"); text.append("\n" + (m_bUseADTree ? "Using " : "not using ") + "ADTree"); if (m_Instances == null) { text.append(": No model built yet."); } else { // flatten BayesNet down to text text.append("\n#attributes="); text.append(m_Instances.numAttributes()); text.append(" #classindex="); text.append(m_Instances.classIndex()); text.append("\nNetwork structure (nodes followed by parents)\n"); for (int iAttribute = 0; iAttribute < m_Instances.numAttributes(); iAttribute++) { text.append( m_Instances.attribute(iAttribute).name() + "(" + m_Instances.attribute(iAttribute).numValues() + "): "); for (int iParent = 0; iParent < m_ParentSets[iAttribute].getNrOfParents(); iParent++) { text.append(m_Instances.attribute(m_ParentSets[iAttribute].getParent(iParent)).name() + " "); } text.append("\n"); // Description of distributions tends to be too much detail, so it is commented out here // for (int iParent = 0; iParent < m_ParentSets[iAttribute].GetCardinalityOfParents(); iParent++) { // text.append('(' + m_Distributions[iAttribute][iParent].toString() + ')'); // } // text.append("\n"); } text.append("LogScore Bayes: " + measureBayesScore() + "\n"); text.append("LogScore BDeu: " + measureBDeuScore() + "\n"); text.append("LogScore MDL: " + measureMDLScore() + "\n"); text.append("LogScore ENTROPY: " + measureEntropyScore() + "\n"); text.append("LogScore AIC: " + measureAICScore() + "\n"); if (m_otherBayesNet != null) { text.append( "Missing: " + m_otherBayesNet.missingArcs(this) + " Extra: " + m_otherBayesNet.extraArcs(this) + " Reversed: " + m_otherBayesNet.reversedArcs(this) + "\n"); text.append("Divergence: " + m_otherBayesNet.divergence(this) + "\n"); } } return text.toString(); } // toString /** * Returns the type of graph this classifier * represents. * @return Drawable.TREE */ public int graphType() { return Drawable.BayesNet; } /** * Returns a BayesNet graph in XMLBIF ver 0.3 format. * @return String representing this BayesNet in XMLBIF ver 0.3 * @throws Exception in case BIF generation fails */ public String graph() throws Exception { return toXMLBIF03(); } public String getBIFHeader() { StringBuffer text = new StringBuffer(); text.append("\n"); text.append("\n"); text.append("\n"); text.append(" \n"); text.append(" \n"); text.append(" \n"); text.append(" \n"); text.append(" \n"); text.append(" \n"); text.append(" \n"); text.append(" \n"); text.append(" \n"); text.append(" \n"); text.append(" \n"); text.append("]>\n"); return text.toString(); } // getBIFHeader /** * Returns a description of the classifier in XML BIF 0.3 format. * See http://www-2.cs.cmu.edu/~fgcozman/Research/InterchangeFormat/ * for details on XML BIF. * @return an XML BIF 0.3 description of the classifier as a string. */ public String toXMLBIF03() { if (m_Instances == null) { return(""); } StringBuffer text = new StringBuffer(); text.append(getBIFHeader()); text.append("\n"); text.append("\n"); text.append("