[4] | 1 | /* |
---|
| 2 | * This program is free software; you can redistribute it and/or modify |
---|
| 3 | * it under the terms of the GNU General Public License as published by |
---|
| 4 | * the Free Software Foundation; either version 2 of the License, or |
---|
| 5 | * (at your option) any later version. |
---|
| 6 | * |
---|
| 7 | * This program is distributed in the hope that it will be useful, |
---|
| 8 | * but WITHOUT ANY WARRANTY; without even the implied warranty of |
---|
| 9 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the |
---|
| 10 | * GNU General Public License for more details. |
---|
| 11 | * |
---|
| 12 | * You should have received a copy of the GNU General Public License |
---|
| 13 | * along with this program; if not, write to the Free Software |
---|
| 14 | * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA. |
---|
| 15 | */ |
---|
| 16 | |
---|
| 17 | /* |
---|
| 18 | * BayesNet.java |
---|
| 19 | * Copyright (C) 2001 University of Waikato, Hamilton, New Zealand |
---|
| 20 | * |
---|
| 21 | */ |
---|
| 22 | package weka.classifiers.bayes; |
---|
| 23 | |
---|
| 24 | import weka.classifiers.Classifier; |
---|
| 25 | import weka.classifiers.AbstractClassifier; |
---|
| 26 | import weka.classifiers.bayes.net.ADNode; |
---|
| 27 | import weka.classifiers.bayes.net.BIFReader; |
---|
| 28 | import weka.classifiers.bayes.net.ParentSet; |
---|
| 29 | import weka.classifiers.bayes.net.estimate.BayesNetEstimator; |
---|
| 30 | import weka.classifiers.bayes.net.estimate.DiscreteEstimatorBayes; |
---|
| 31 | import weka.classifiers.bayes.net.estimate.SimpleEstimator; |
---|
| 32 | import weka.classifiers.bayes.net.search.SearchAlgorithm; |
---|
| 33 | import weka.classifiers.bayes.net.search.local.K2; |
---|
| 34 | import weka.classifiers.bayes.net.search.local.LocalScoreSearchAlgorithm; |
---|
| 35 | import weka.classifiers.bayes.net.search.local.Scoreable; |
---|
| 36 | import weka.core.AdditionalMeasureProducer; |
---|
| 37 | import weka.core.Attribute; |
---|
| 38 | import weka.core.Capabilities; |
---|
| 39 | import weka.core.Drawable; |
---|
| 40 | import weka.core.Instance; |
---|
| 41 | import weka.core.Instances; |
---|
| 42 | import weka.core.Option; |
---|
| 43 | import weka.core.OptionHandler; |
---|
| 44 | import weka.core.RevisionUtils; |
---|
| 45 | import weka.core.Utils; |
---|
| 46 | import weka.core.WeightedInstancesHandler; |
---|
| 47 | import weka.core.Capabilities.Capability; |
---|
| 48 | import weka.estimators.Estimator; |
---|
| 49 | import weka.filters.Filter; |
---|
| 50 | import weka.filters.supervised.attribute.Discretize; |
---|
| 51 | import weka.filters.unsupervised.attribute.ReplaceMissingValues; |
---|
| 52 | |
---|
| 53 | import java.util.Enumeration; |
---|
| 54 | import java.util.Vector; |
---|
| 55 | |
---|
| 56 | /** |
---|
| 57 | <!-- globalinfo-start --> |
---|
| 58 | * Bayes Network learning using various search algorithms and quality measures.<br/> |
---|
| 59 | * Base class for a Bayes Network classifier. Provides datastructures (network structure, conditional probability distributions, etc.) and facilities common to Bayes Network learning algorithms like K2 and B.<br/> |
---|
| 60 | * <br/> |
---|
| 61 | * For more information see:<br/> |
---|
| 62 | * <br/> |
---|
| 63 | * http://sourceforge.net/projects/weka/files/documentation/WekaManual-3-7-0.pdf/download |
---|
| 64 | * <p/> |
---|
| 65 | <!-- globalinfo-end --> |
---|
| 66 | * |
---|
| 67 | <!-- options-start --> |
---|
| 68 | * Valid options are: <p/> |
---|
| 69 | * |
---|
| 70 | * <pre> -D |
---|
| 71 | * Do not use ADTree data structure |
---|
| 72 | * </pre> |
---|
| 73 | * |
---|
| 74 | * <pre> -B <BIF file> |
---|
| 75 | * BIF file to compare with |
---|
| 76 | * </pre> |
---|
| 77 | * |
---|
| 78 | * <pre> -Q weka.classifiers.bayes.net.search.SearchAlgorithm |
---|
| 79 | * Search algorithm |
---|
| 80 | * </pre> |
---|
| 81 | * |
---|
| 82 | * <pre> -E weka.classifiers.bayes.net.estimate.SimpleEstimator |
---|
| 83 | * Estimator algorithm |
---|
| 84 | * </pre> |
---|
| 85 | * |
---|
| 86 | <!-- options-end --> |
---|
| 87 | * |
---|
| 88 | * @author Remco Bouckaert (rrb@xm.co.nz) |
---|
| 89 | * @version $Revision: 5928 $ |
---|
| 90 | */ |
---|
| 91 | public class BayesNet |
---|
| 92 | extends AbstractClassifier |
---|
| 93 | implements OptionHandler, WeightedInstancesHandler, Drawable, |
---|
| 94 | AdditionalMeasureProducer { |
---|
| 95 | |
---|
| 96 | /** for serialization */ |
---|
| 97 | static final long serialVersionUID = 746037443258775954L; |
---|
| 98 | |
---|
| 99 | |
---|
| 100 | /** |
---|
| 101 | * The parent sets. |
---|
| 102 | */ |
---|
| 103 | protected ParentSet[] m_ParentSets; |
---|
| 104 | |
---|
| 105 | /** |
---|
| 106 | * The attribute estimators containing CPTs. |
---|
| 107 | */ |
---|
| 108 | public Estimator[][] m_Distributions; |
---|
| 109 | |
---|
| 110 | |
---|
| 111 | /** filter used to quantize continuous variables, if any **/ |
---|
| 112 | protected Discretize m_DiscretizeFilter = null; |
---|
| 113 | |
---|
| 114 | /** attribute index of a non-nominal attribute */ |
---|
| 115 | int m_nNonDiscreteAttribute = -1; |
---|
| 116 | |
---|
| 117 | /** filter used to fill in missing values, if any **/ |
---|
| 118 | protected ReplaceMissingValues m_MissingValuesFilter = null; |
---|
| 119 | |
---|
| 120 | /** |
---|
| 121 | * The number of classes |
---|
| 122 | */ |
---|
| 123 | protected int m_NumClasses; |
---|
| 124 | |
---|
| 125 | /** |
---|
| 126 | * The dataset header for the purposes of printing out a semi-intelligible |
---|
| 127 | * model |
---|
| 128 | */ |
---|
| 129 | public Instances m_Instances; |
---|
| 130 | |
---|
| 131 | /** |
---|
| 132 | * Datastructure containing ADTree representation of the database. |
---|
| 133 | * This may result in more efficient access to the data. |
---|
| 134 | */ |
---|
| 135 | ADNode m_ADTree; |
---|
| 136 | |
---|
| 137 | /** |
---|
| 138 | * Bayes network to compare the structure with. |
---|
| 139 | */ |
---|
| 140 | protected BIFReader m_otherBayesNet = null; |
---|
| 141 | |
---|
| 142 | /** |
---|
| 143 | * Use the experimental ADTree datastructure for calculating contingency tables |
---|
| 144 | */ |
---|
| 145 | boolean m_bUseADTree = false; |
---|
| 146 | |
---|
| 147 | /** |
---|
| 148 | * Search algorithm used for learning the structure of a network. |
---|
| 149 | */ |
---|
| 150 | SearchAlgorithm m_SearchAlgorithm = new K2(); |
---|
| 151 | |
---|
| 152 | /** |
---|
| 153 | * Estimator algorithm used for calculating the conditional probability tables (CPTs) of a network. |
---|
| 154 | */ |
---|
| 155 | BayesNetEstimator m_BayesNetEstimator = new SimpleEstimator(); |
---|
| 156 | |
---|
| 157 | /** |
---|
| 158 | * Returns default capabilities of the classifier. |
---|
| 159 | * |
---|
| 160 | * @return the capabilities of this classifier |
---|
| 161 | */ |
---|
| 162 | public Capabilities getCapabilities() { |
---|
| 163 | Capabilities result = super.getCapabilities(); |
---|
| 164 | result.disableAll(); |
---|
| 165 | |
---|
| 166 | // attributes |
---|
| 167 | result.enable(Capability.NOMINAL_ATTRIBUTES); |
---|
| 168 | result.enable(Capability.NUMERIC_ATTRIBUTES); |
---|
| 169 | result.enable(Capability.MISSING_VALUES); |
---|
| 170 | |
---|
| 171 | // class |
---|
| 172 | result.enable(Capability.NOMINAL_CLASS); |
---|
| 173 | result.enable(Capability.MISSING_CLASS_VALUES); |
---|
| 174 | |
---|
| 175 | // instances |
---|
| 176 | result.setMinimumNumberInstances(0); |
---|
| 177 | |
---|
| 178 | return result; |
---|
| 179 | } |
---|
| 180 | |
---|
| 181 | /** |
---|
| 182 | * Generates the classifier. |
---|
| 183 | * |
---|
| 184 | * @param instances set of instances serving as training data |
---|
| 185 | * @throws Exception if the classifier has not been generated |
---|
| 186 | * successfully |
---|
| 187 | */ |
---|
| 188 | public void buildClassifier(Instances instances) throws Exception { |
---|
| 189 | |
---|
| 190 | // can classifier handle the data? |
---|
| 191 | getCapabilities().testWithFail(instances); |
---|
| 192 | |
---|
| 193 | // remove instances with missing class |
---|
| 194 | instances = new Instances(instances); |
---|
| 195 | instances.deleteWithMissingClass(); |
---|
| 196 | |
---|
| 197 | // ensure we have a data set with discrete variables only and with no missing values |
---|
| 198 | instances = normalizeDataSet(instances); |
---|
| 199 | |
---|
| 200 | // Copy the instances |
---|
| 201 | m_Instances = new Instances(instances); |
---|
| 202 | |
---|
| 203 | // sanity check: need more than 1 variable in datat set |
---|
| 204 | m_NumClasses = instances.numClasses(); |
---|
| 205 | |
---|
| 206 | // initialize ADTree |
---|
| 207 | if (m_bUseADTree) { |
---|
| 208 | m_ADTree = ADNode.makeADTree(instances); |
---|
| 209 | // System.out.println("Oef, done!"); |
---|
| 210 | } |
---|
| 211 | |
---|
| 212 | // build the network structure |
---|
| 213 | initStructure(); |
---|
| 214 | |
---|
| 215 | // build the network structure |
---|
| 216 | buildStructure(); |
---|
| 217 | |
---|
| 218 | // build the set of CPTs |
---|
| 219 | estimateCPTs(); |
---|
| 220 | |
---|
| 221 | // Save space |
---|
| 222 | // m_Instances = new Instances(m_Instances, 0); |
---|
| 223 | m_ADTree = null; |
---|
| 224 | } // buildClassifier |
---|
| 225 | |
---|
| 226 | /** ensure that all variables are nominal and that there are no missing values |
---|
| 227 | * @param instances data set to check and quantize and/or fill in missing values |
---|
| 228 | * @return filtered instances |
---|
| 229 | * @throws Exception if a filter (Discretize, ReplaceMissingValues) fails |
---|
| 230 | */ |
---|
| 231 | protected Instances normalizeDataSet(Instances instances) throws Exception { |
---|
| 232 | m_DiscretizeFilter = null; |
---|
| 233 | m_MissingValuesFilter = null; |
---|
| 234 | |
---|
| 235 | boolean bHasNonNominal = false; |
---|
| 236 | boolean bHasMissingValues = false; |
---|
| 237 | |
---|
| 238 | Enumeration enu = instances.enumerateAttributes(); |
---|
| 239 | while (enu.hasMoreElements()) { |
---|
| 240 | Attribute attribute = (Attribute) enu.nextElement(); |
---|
| 241 | if (attribute.type() != Attribute.NOMINAL) { |
---|
| 242 | m_nNonDiscreteAttribute = attribute.index(); |
---|
| 243 | bHasNonNominal = true; |
---|
| 244 | //throw new UnsupportedAttributeTypeException("BayesNet handles nominal variables only. Non-nominal variable in dataset detected."); |
---|
| 245 | } |
---|
| 246 | Enumeration enum2 = instances.enumerateInstances(); |
---|
| 247 | while (enum2.hasMoreElements()) { |
---|
| 248 | if (((Instance) enum2.nextElement()).isMissing(attribute)) { |
---|
| 249 | bHasMissingValues = true; |
---|
| 250 | // throw new NoSupportForMissingValuesException("BayesNet: no missing values, please."); |
---|
| 251 | } |
---|
| 252 | } |
---|
| 253 | } |
---|
| 254 | |
---|
| 255 | if (bHasNonNominal) { |
---|
| 256 | System.err.println("Warning: discretizing data set"); |
---|
| 257 | m_DiscretizeFilter = new Discretize(); |
---|
| 258 | m_DiscretizeFilter.setInputFormat(instances); |
---|
| 259 | instances = Filter.useFilter(instances, m_DiscretizeFilter); |
---|
| 260 | } |
---|
| 261 | |
---|
| 262 | if (bHasMissingValues) { |
---|
| 263 | System.err.println("Warning: filling in missing values in data set"); |
---|
| 264 | m_MissingValuesFilter = new ReplaceMissingValues(); |
---|
| 265 | m_MissingValuesFilter.setInputFormat(instances); |
---|
| 266 | instances = Filter.useFilter(instances, m_MissingValuesFilter); |
---|
| 267 | } |
---|
| 268 | return instances; |
---|
| 269 | } // normalizeDataSet |
---|
| 270 | |
---|
| 271 | /** ensure that all variables are nominal and that there are no missing values |
---|
| 272 | * @param instance instance to check and quantize and/or fill in missing values |
---|
| 273 | * @return filtered instance |
---|
| 274 | * @throws Exception if a filter (Discretize, ReplaceMissingValues) fails |
---|
| 275 | */ |
---|
| 276 | protected Instance normalizeInstance(Instance instance) throws Exception { |
---|
| 277 | if ((m_DiscretizeFilter != null) && |
---|
| 278 | (instance.attribute(m_nNonDiscreteAttribute).type() != Attribute.NOMINAL)) { |
---|
| 279 | m_DiscretizeFilter.input(instance); |
---|
| 280 | instance = m_DiscretizeFilter.output(); |
---|
| 281 | } |
---|
| 282 | if (m_MissingValuesFilter != null) { |
---|
| 283 | m_MissingValuesFilter.input(instance); |
---|
| 284 | instance = m_MissingValuesFilter.output(); |
---|
| 285 | } else { |
---|
| 286 | // is there a missing value in this instance? |
---|
| 287 | // this can happen when there is no missing value in the training set |
---|
| 288 | for (int iAttribute = 0; iAttribute < m_Instances.numAttributes(); iAttribute++) { |
---|
| 289 | if (iAttribute != instance.classIndex() && instance.isMissing(iAttribute)) { |
---|
| 290 | System.err.println("Warning: Found missing value in test set, filling in values."); |
---|
| 291 | m_MissingValuesFilter = new ReplaceMissingValues(); |
---|
| 292 | m_MissingValuesFilter.setInputFormat(m_Instances); |
---|
| 293 | Filter.useFilter(m_Instances, m_MissingValuesFilter); |
---|
| 294 | m_MissingValuesFilter.input(instance); |
---|
| 295 | instance = m_MissingValuesFilter.output(); |
---|
| 296 | iAttribute = m_Instances.numAttributes(); |
---|
| 297 | } |
---|
| 298 | } |
---|
| 299 | } |
---|
| 300 | return instance; |
---|
| 301 | } // normalizeInstance |
---|
| 302 | |
---|
| 303 | /** |
---|
| 304 | * Init structure initializes the structure to an empty graph or a Naive Bayes |
---|
| 305 | * graph (depending on the -N flag). |
---|
| 306 | * |
---|
| 307 | * @throws Exception in case of an error |
---|
| 308 | */ |
---|
| 309 | public void initStructure() throws Exception { |
---|
| 310 | |
---|
| 311 | // initialize topological ordering |
---|
| 312 | // m_nOrder = new int[m_Instances.numAttributes()]; |
---|
| 313 | // m_nOrder[0] = m_Instances.classIndex(); |
---|
| 314 | |
---|
| 315 | int nAttribute = 0; |
---|
| 316 | |
---|
| 317 | for (int iOrder = 1; iOrder < m_Instances.numAttributes(); iOrder++) { |
---|
| 318 | if (nAttribute == m_Instances.classIndex()) { |
---|
| 319 | nAttribute++; |
---|
| 320 | } |
---|
| 321 | |
---|
| 322 | // m_nOrder[iOrder] = nAttribute++; |
---|
| 323 | } |
---|
| 324 | |
---|
| 325 | // reserve memory |
---|
| 326 | m_ParentSets = new ParentSet[m_Instances.numAttributes()]; |
---|
| 327 | |
---|
| 328 | for (int iAttribute = 0; iAttribute < m_Instances.numAttributes(); iAttribute++) { |
---|
| 329 | m_ParentSets[iAttribute] = new ParentSet(m_Instances.numAttributes()); |
---|
| 330 | } |
---|
| 331 | } // initStructure |
---|
| 332 | |
---|
/**
 * buildStructure determines the network structure/graph of the network.
 * The work is delegated to the configured search algorithm
 * (m_SearchAlgorithm, a K2 instance unless replaced), which receives this
 * network and the stored training data (m_Instances).
 * This method can be overridden by derived classes to restrict the class
 * of network structures that are acceptable.
 *
 * @throws Exception in case of an error
 */
public void buildStructure() throws Exception {
  m_SearchAlgorithm.buildStructure(this, m_Instances);
} // buildStructure
---|
| 345 | |
---|
/**
 * estimateCPTs estimates the conditional probability tables for the Bayes
 * Net using the network structure. The work is delegated to the configured
 * estimator (m_BayesNetEstimator, a SimpleEstimator unless replaced).
 *
 * @throws Exception in case of an error
 */
public void estimateCPTs() throws Exception {
  m_BayesNetEstimator.estimateCPTs(this);
} // estimateCPTs
---|
| 355 | |
---|
/**
 * Initializes the conditional probabilities by delegating to the
 * configured estimator (m_BayesNetEstimator).
 *
 * @throws Exception in case of an error
 */
public void initCPTs() throws Exception {
  m_BayesNetEstimator.initCPTs(this);
} // initCPTs
---|
| 364 | |
---|
| 365 | /** |
---|
| 366 | * Updates the classifier with the given instance. |
---|
| 367 | * |
---|
| 368 | * @param instance the new training instance to include in the model |
---|
| 369 | * @throws Exception if the instance could not be incorporated in |
---|
| 370 | * the model. |
---|
| 371 | */ |
---|
| 372 | public void updateClassifier(Instance instance) throws Exception { |
---|
| 373 | instance = normalizeInstance(instance); |
---|
| 374 | m_BayesNetEstimator.updateClassifier(this, instance); |
---|
| 375 | } // updateClassifier |
---|
| 376 | |
---|
| 377 | /** |
---|
| 378 | * Calculates the class membership probabilities for the given test |
---|
| 379 | * instance. |
---|
| 380 | * |
---|
| 381 | * @param instance the instance to be classified |
---|
| 382 | * @return predicted class probability distribution |
---|
| 383 | * @throws Exception if there is a problem generating the prediction |
---|
| 384 | */ |
---|
| 385 | public double[] distributionForInstance(Instance instance) throws Exception { |
---|
| 386 | instance = normalizeInstance(instance); |
---|
| 387 | return m_BayesNetEstimator.distributionForInstance(this, instance); |
---|
| 388 | } // distributionForInstance |
---|
| 389 | |
---|
| 390 | /** |
---|
| 391 | * Calculates the counts for Dirichlet distribution for the |
---|
| 392 | * class membership probabilities for the given test instance. |
---|
| 393 | * |
---|
| 394 | * @param instance the instance to be classified |
---|
| 395 | * @return counts for Dirichlet distribution for class probability |
---|
| 396 | * @throws Exception if there is a problem generating the prediction |
---|
| 397 | */ |
---|
| 398 | public double[] countsForInstance(Instance instance) throws Exception { |
---|
| 399 | double[] fCounts = new double[m_NumClasses]; |
---|
| 400 | |
---|
| 401 | for (int iClass = 0; iClass < m_NumClasses; iClass++) { |
---|
| 402 | fCounts[iClass] = 0.0; |
---|
| 403 | } |
---|
| 404 | |
---|
| 405 | for (int iClass = 0; iClass < m_NumClasses; iClass++) { |
---|
| 406 | double fCount = 0; |
---|
| 407 | |
---|
| 408 | for (int iAttribute = 0; iAttribute < m_Instances.numAttributes(); iAttribute++) { |
---|
| 409 | double iCPT = 0; |
---|
| 410 | |
---|
| 411 | for (int iParent = 0; iParent < m_ParentSets[iAttribute].getNrOfParents(); iParent++) { |
---|
| 412 | int nParent = m_ParentSets[iAttribute].getParent(iParent); |
---|
| 413 | |
---|
| 414 | if (nParent == m_Instances.classIndex()) { |
---|
| 415 | iCPT = iCPT * m_NumClasses + iClass; |
---|
| 416 | } else { |
---|
| 417 | iCPT = iCPT * m_Instances.attribute(nParent).numValues() + instance.value(nParent); |
---|
| 418 | } |
---|
| 419 | } |
---|
| 420 | |
---|
| 421 | if (iAttribute == m_Instances.classIndex()) { |
---|
| 422 | fCount += ((DiscreteEstimatorBayes) m_Distributions[iAttribute][(int) iCPT]).getCount(iClass); |
---|
| 423 | } else { |
---|
| 424 | fCount |
---|
| 425 | += ((DiscreteEstimatorBayes) m_Distributions[iAttribute][(int) iCPT]).getCount( |
---|
| 426 | instance.value(iAttribute)); |
---|
| 427 | } |
---|
| 428 | } |
---|
| 429 | |
---|
| 430 | fCounts[iClass] += fCount; |
---|
| 431 | } |
---|
| 432 | return fCounts; |
---|
| 433 | } // countsForInstance |
---|
| 434 | |
---|
| 435 | /** |
---|
| 436 | * Returns an enumeration describing the available options |
---|
| 437 | * |
---|
| 438 | * @return an enumeration of all the available options |
---|
| 439 | */ |
---|
| 440 | public Enumeration listOptions() { |
---|
| 441 | Vector newVector = new Vector(4); |
---|
| 442 | |
---|
| 443 | newVector.addElement(new Option("\tDo not use ADTree data structure\n", "D", 0, "-D")); |
---|
| 444 | newVector.addElement(new Option("\tBIF file to compare with\n", "B", 1, "-B <BIF file>")); |
---|
| 445 | newVector.addElement(new Option("\tSearch algorithm\n", "Q", 1, "-Q weka.classifiers.bayes.net.search.SearchAlgorithm")); |
---|
| 446 | newVector.addElement(new Option("\tEstimator algorithm\n", "E", 1, "-E weka.classifiers.bayes.net.estimate.SimpleEstimator")); |
---|
| 447 | |
---|
| 448 | return newVector.elements(); |
---|
| 449 | } // listOptions |
---|
| 450 | |
---|
| 451 | /** |
---|
| 452 | * Parses a given list of options. <p> |
---|
| 453 | * |
---|
| 454 | <!-- options-start --> |
---|
| 455 | * Valid options are: <p/> |
---|
| 456 | * |
---|
| 457 | * <pre> -D |
---|
| 458 | * Do not use ADTree data structure |
---|
| 459 | * </pre> |
---|
| 460 | * |
---|
| 461 | * <pre> -B <BIF file> |
---|
| 462 | * BIF file to compare with |
---|
| 463 | * </pre> |
---|
| 464 | * |
---|
| 465 | * <pre> -Q weka.classifiers.bayes.net.search.SearchAlgorithm |
---|
| 466 | * Search algorithm |
---|
| 467 | * </pre> |
---|
| 468 | * |
---|
| 469 | * <pre> -E weka.classifiers.bayes.net.estimate.SimpleEstimator |
---|
| 470 | * Estimator algorithm |
---|
| 471 | * </pre> |
---|
| 472 | * |
---|
| 473 | <!-- options-end --> |
---|
| 474 | * |
---|
| 475 | * @param options the list of options as an array of strings |
---|
| 476 | * @throws Exception if an option is not supported |
---|
| 477 | */ |
---|
| 478 | public void setOptions(String[] options) throws Exception { |
---|
| 479 | m_bUseADTree = !(Utils.getFlag('D', options)); |
---|
| 480 | |
---|
| 481 | String sBIFFile = Utils.getOption('B', options); |
---|
| 482 | if (sBIFFile != null && !sBIFFile.equals("")) { |
---|
| 483 | setBIFFile(sBIFFile); |
---|
| 484 | } |
---|
| 485 | |
---|
| 486 | String searchAlgorithmName = Utils.getOption('Q', options); |
---|
| 487 | if (searchAlgorithmName.length() != 0) { |
---|
| 488 | setSearchAlgorithm( |
---|
| 489 | (SearchAlgorithm) Utils.forName( |
---|
| 490 | SearchAlgorithm.class, |
---|
| 491 | searchAlgorithmName, |
---|
| 492 | partitionOptions(options))); |
---|
| 493 | } |
---|
| 494 | else { |
---|
| 495 | setSearchAlgorithm(new K2()); |
---|
| 496 | } |
---|
| 497 | |
---|
| 498 | |
---|
| 499 | String estimatorName = Utils.getOption('E', options); |
---|
| 500 | if (estimatorName.length() != 0) { |
---|
| 501 | setEstimator( |
---|
| 502 | (BayesNetEstimator) Utils.forName( |
---|
| 503 | BayesNetEstimator.class, |
---|
| 504 | estimatorName, |
---|
| 505 | Utils.partitionOptions(options))); |
---|
| 506 | } |
---|
| 507 | else { |
---|
| 508 | setEstimator(new SimpleEstimator()); |
---|
| 509 | } |
---|
| 510 | |
---|
| 511 | Utils.checkForRemainingOptions(options); |
---|
| 512 | } // setOptions |
---|
| 513 | |
---|
/**
 * Returns the secondary set of options (if any) contained in
 * the supplied options array. The secondary set is defined to
 * be any options after the first "--" but before the "-E". These
 * options are removed from the original options array.
 *
 * The first "--" and every option moved into the result are overwritten
 * with empty strings in <code>options</code>. The returned array has one
 * slot for every entry that follows the "--"; slots from the "-E" onwards
 * are filled with empty strings. If there is no "--", an empty array is
 * returned and <code>options</code> is left untouched.
 *
 * @param options the input array of options
 * @return the array of secondary options
 */
public static String[] partitionOptions(String[] options) {

  for (int i = 0; i < options.length; i++) {
    if (!options[i].equals("--")) {
      continue;
    }

    // consume the separator itself
    options[i] = "";
    final int first = i + 1;
    final String[] result = new String[options.length - first];

    // move everything up to (but excluding) "-E" into the result
    int j = first;
    while ((j < options.length) && !options[j].equals("-E")) {
      result[j - first] = options[j];
      options[j] = "";
      j++;
    }

    // pad the remaining slots with empty strings, not nulls
    while (j < options.length) {
      result[j - first] = "";
      j++;
    }
    return result;
  }
  return new String[0];
}
---|
| 552 | |
---|
| 553 | |
---|
| 554 | /** |
---|
| 555 | * Gets the current settings of the classifier. |
---|
| 556 | * |
---|
| 557 | * @return an array of strings suitable for passing to setOptions |
---|
| 558 | */ |
---|
| 559 | public String[] getOptions() { |
---|
| 560 | String[] searchOptions = m_SearchAlgorithm.getOptions(); |
---|
| 561 | String[] estimatorOptions = m_BayesNetEstimator.getOptions(); |
---|
| 562 | String[] options = new String[11 + searchOptions.length + estimatorOptions.length]; |
---|
| 563 | int current = 0; |
---|
| 564 | |
---|
| 565 | if (!m_bUseADTree) { |
---|
| 566 | options[current++] = "-D"; |
---|
| 567 | } |
---|
| 568 | |
---|
| 569 | if (m_otherBayesNet != null) { |
---|
| 570 | options[current++] = "-B"; |
---|
| 571 | options[current++] = ((BIFReader) m_otherBayesNet).getFileName(); |
---|
| 572 | } |
---|
| 573 | |
---|
| 574 | options[current++] = "-Q"; |
---|
| 575 | options[current++] = "" + getSearchAlgorithm().getClass().getName(); |
---|
| 576 | options[current++] = "--"; |
---|
| 577 | for (int iOption = 0; iOption < searchOptions.length; iOption++) { |
---|
| 578 | options[current++] = searchOptions[iOption]; |
---|
| 579 | } |
---|
| 580 | |
---|
| 581 | options[current++] = "-E"; |
---|
| 582 | options[current++] = "" + getEstimator().getClass().getName(); |
---|
| 583 | options[current++] = "--"; |
---|
| 584 | for (int iOption = 0; iOption < estimatorOptions.length; iOption++) { |
---|
| 585 | options[current++] = estimatorOptions[iOption]; |
---|
| 586 | } |
---|
| 587 | |
---|
| 588 | // Fill up rest with empty strings, not nulls! |
---|
| 589 | while (current < options.length) { |
---|
| 590 | options[current++] = ""; |
---|
| 591 | } |
---|
| 592 | |
---|
| 593 | return options; |
---|
| 594 | } // getOptions |
---|
| 595 | |
---|
/**
 * Set the SearchAlgorithm used in searching for network structures.
 * The value is stored as given; buildStructure() delegates to it.
 *
 * @param newSearchAlgorithm the SearchAlgorithm to use.
 */
public void setSearchAlgorithm(SearchAlgorithm newSearchAlgorithm) {
  m_SearchAlgorithm = newSearchAlgorithm;
}
---|
| 603 | |
---|
/**
 * Get the SearchAlgorithm used for searching for network structures.
 *
 * @return the SearchAlgorithm currently configured
 */
public SearchAlgorithm getSearchAlgorithm() {
  return m_SearchAlgorithm;
}
---|
| 611 | |
---|
/**
 * Set the Estimator Algorithm used in calculating the CPTs.
 * The value is stored as given; estimateCPTs(), initCPTs(),
 * updateClassifier() and distributionForInstance() delegate to it.
 *
 * @param newBayesNetEstimator the Estimator to use.
 */
public void setEstimator(BayesNetEstimator newBayesNetEstimator) {
  m_BayesNetEstimator = newBayesNetEstimator;
}
---|
| 619 | |
---|
/**
 * Get the BayesNetEstimator used for calculating the CPTs.
 *
 * @return the BayesNetEstimator currently configured
 */
public BayesNetEstimator getEstimator() {
  return m_BayesNetEstimator;
}
---|
| 627 | |
---|
/**
 * Set whether ADTree structure is used or not. When true,
 * buildClassifier() builds an ADTree from the training data before the
 * structure is learned.
 *
 * @param bUseADTree true if an ADTree structure is used
 */
public void setUseADTree(boolean bUseADTree) {
  m_bUseADTree = bUseADTree;
}
---|
| 635 | |
---|
/**
 * Get whether ADTree structure is used or not.
 *
 * @return whether ADTree structure is used or not
 */
public boolean getUseADTree() {
  return m_bUseADTree;
}
---|
| 643 | |
---|
| 644 | /** |
---|
| 645 | * Set name of network in BIF file to compare with |
---|
| 646 | * @param sBIFFile the name of the BIF file |
---|
| 647 | */ |
---|
| 648 | public void setBIFFile(String sBIFFile) { |
---|
| 649 | try { |
---|
| 650 | m_otherBayesNet = new BIFReader().processFile(sBIFFile); |
---|
| 651 | } catch (Throwable t) { |
---|
| 652 | m_otherBayesNet = null; |
---|
| 653 | } |
---|
| 654 | } |
---|
| 655 | |
---|
| 656 | /** |
---|
| 657 | * Get name of network in BIF file to compare with |
---|
| 658 | * @return BIF file name |
---|
| 659 | */ |
---|
| 660 | public String getBIFFile() { |
---|
| 661 | if (m_otherBayesNet != null) { |
---|
| 662 | return m_otherBayesNet.getFileName(); |
---|
| 663 | } |
---|
| 664 | return ""; |
---|
| 665 | } |
---|
| 666 | |
---|
| 667 | |
---|
| 668 | /** |
---|
| 669 | * Returns a description of the classifier. |
---|
| 670 | * |
---|
| 671 | * @return a description of the classifier as a string. |
---|
| 672 | */ |
---|
| 673 | public String toString() { |
---|
| 674 | StringBuffer text = new StringBuffer(); |
---|
| 675 | |
---|
| 676 | text.append("Bayes Network Classifier"); |
---|
| 677 | text.append("\n" + (m_bUseADTree ? "Using " : "not using ") + "ADTree"); |
---|
| 678 | |
---|
| 679 | if (m_Instances == null) { |
---|
| 680 | text.append(": No model built yet."); |
---|
| 681 | } else { |
---|
| 682 | |
---|
| 683 | // flatten BayesNet down to text |
---|
| 684 | text.append("\n#attributes="); |
---|
| 685 | text.append(m_Instances.numAttributes()); |
---|
| 686 | text.append(" #classindex="); |
---|
| 687 | text.append(m_Instances.classIndex()); |
---|
| 688 | text.append("\nNetwork structure (nodes followed by parents)\n"); |
---|
| 689 | |
---|
| 690 | for (int iAttribute = 0; iAttribute < m_Instances.numAttributes(); iAttribute++) { |
---|
| 691 | text.append( |
---|
| 692 | m_Instances.attribute(iAttribute).name() |
---|
| 693 | + "(" |
---|
| 694 | + m_Instances.attribute(iAttribute).numValues() |
---|
| 695 | + "): "); |
---|
| 696 | |
---|
| 697 | for (int iParent = 0; iParent < m_ParentSets[iAttribute].getNrOfParents(); iParent++) { |
---|
| 698 | text.append(m_Instances.attribute(m_ParentSets[iAttribute].getParent(iParent)).name() + " "); |
---|
| 699 | } |
---|
| 700 | |
---|
| 701 | text.append("\n"); |
---|
| 702 | |
---|
| 703 | // Description of distributions tends to be too much detail, so it is commented out here |
---|
| 704 | // for (int iParent = 0; iParent < m_ParentSets[iAttribute].GetCardinalityOfParents(); iParent++) { |
---|
| 705 | // text.append('(' + m_Distributions[iAttribute][iParent].toString() + ')'); |
---|
| 706 | // } |
---|
| 707 | // text.append("\n"); |
---|
| 708 | } |
---|
| 709 | |
---|
| 710 | text.append("LogScore Bayes: " + measureBayesScore() + "\n"); |
---|
| 711 | text.append("LogScore BDeu: " + measureBDeuScore() + "\n"); |
---|
| 712 | text.append("LogScore MDL: " + measureMDLScore() + "\n"); |
---|
| 713 | text.append("LogScore ENTROPY: " + measureEntropyScore() + "\n"); |
---|
| 714 | text.append("LogScore AIC: " + measureAICScore() + "\n"); |
---|
| 715 | |
---|
| 716 | if (m_otherBayesNet != null) { |
---|
| 717 | text.append( |
---|
| 718 | "Missing: " |
---|
| 719 | + m_otherBayesNet.missingArcs(this) |
---|
| 720 | + " Extra: " |
---|
| 721 | + m_otherBayesNet.extraArcs(this) |
---|
| 722 | + " Reversed: " |
---|
| 723 | + m_otherBayesNet.reversedArcs(this) |
---|
| 724 | + "\n"); |
---|
| 725 | text.append("Divergence: " + m_otherBayesNet.divergence(this) + "\n"); |
---|
| 726 | } |
---|
| 727 | } |
---|
| 728 | |
---|
| 729 | return text.toString(); |
---|
| 730 | } // toString |
---|
| 731 | |
---|
| 732 | |
---|
/**
 * Returns the type of graph this classifier
 * represents.
 * @return Drawable.BayesNet
 */
public int graphType() {
  return Drawable.BayesNet;
}
---|
| 741 | |
---|
/**
 * Returns a BayesNet graph in XMLBIF ver 0.3 format.
 * This is the string produced by toXMLBIF03().
 * @return String representing this BayesNet in XMLBIF ver 0.3
 * @throws Exception in case BIF generation fails
 */
public String graph() throws Exception {
  return toXMLBIF03();
}
---|
| 750 | |
---|
| 751 | public String getBIFHeader() { |
---|
| 752 | StringBuffer text = new StringBuffer(); |
---|
| 753 | text.append("<?xml version=\"1.0\"?>\n"); |
---|
| 754 | text.append("<!-- DTD for the XMLBIF 0.3 format -->\n"); |
---|
| 755 | text.append("<!DOCTYPE BIF [\n"); |
---|
| 756 | text.append(" <!ELEMENT BIF ( NETWORK )*>\n"); |
---|
| 757 | text.append(" <!ATTLIST BIF VERSION CDATA #REQUIRED>\n"); |
---|
| 758 | text.append(" <!ELEMENT NETWORK ( NAME, ( PROPERTY | VARIABLE | DEFINITION )* )>\n"); |
---|
| 759 | text.append(" <!ELEMENT NAME (#PCDATA)>\n"); |
---|
| 760 | text.append(" <!ELEMENT VARIABLE ( NAME, ( OUTCOME | PROPERTY )* ) >\n"); |
---|
| 761 | text.append(" <!ATTLIST VARIABLE TYPE (nature|decision|utility) \"nature\">\n"); |
---|
| 762 | text.append(" <!ELEMENT OUTCOME (#PCDATA)>\n"); |
---|
| 763 | text.append(" <!ELEMENT DEFINITION ( FOR | GIVEN | TABLE | PROPERTY )* >\n"); |
---|
| 764 | text.append(" <!ELEMENT FOR (#PCDATA)>\n"); |
---|
| 765 | text.append(" <!ELEMENT GIVEN (#PCDATA)>\n"); |
---|
| 766 | text.append(" <!ELEMENT TABLE (#PCDATA)>\n"); |
---|
| 767 | text.append(" <!ELEMENT PROPERTY (#PCDATA)>\n"); |
---|
| 768 | text.append("]>\n"); |
---|
| 769 | return text.toString(); |
---|
| 770 | } // getBIFHeader |
---|
| 771 | |
---|
| 772 | /** |
---|
| 773 | * Returns a description of the classifier in XML BIF 0.3 format. |
---|
| 774 | * See http://www-2.cs.cmu.edu/~fgcozman/Research/InterchangeFormat/ |
---|
| 775 | * for details on XML BIF. |
---|
| 776 | * @return an XML BIF 0.3 description of the classifier as a string. |
---|
| 777 | */ |
---|
| 778 | public String toXMLBIF03() { |
---|
| 779 | if (m_Instances == null) { |
---|
| 780 | return("<!--No model built yet-->"); |
---|
| 781 | } |
---|
| 782 | |
---|
| 783 | StringBuffer text = new StringBuffer(); |
---|
| 784 | text.append(getBIFHeader()); |
---|
| 785 | text.append("\n"); |
---|
| 786 | text.append("\n"); |
---|
| 787 | text.append("<BIF VERSION=\"0.3\">\n"); |
---|
| 788 | text.append("<NETWORK>\n"); |
---|
| 789 | text.append("<NAME>" + XMLNormalize(m_Instances.relationName()) + "</NAME>\n"); |
---|
| 790 | for (int iAttribute = 0; iAttribute < m_Instances.numAttributes(); iAttribute++) { |
---|
| 791 | text.append("<VARIABLE TYPE=\"nature\">\n"); |
---|
| 792 | text.append("<NAME>" + XMLNormalize(m_Instances.attribute(iAttribute).name()) + "</NAME>\n"); |
---|
| 793 | for (int iValue = 0; iValue < m_Instances.attribute(iAttribute).numValues(); iValue++) { |
---|
| 794 | text.append("<OUTCOME>" + XMLNormalize(m_Instances.attribute(iAttribute).value(iValue)) + "</OUTCOME>\n"); |
---|
| 795 | } |
---|
| 796 | text.append("</VARIABLE>\n"); |
---|
| 797 | } |
---|
| 798 | |
---|
| 799 | for (int iAttribute = 0; iAttribute < m_Instances.numAttributes(); iAttribute++) { |
---|
| 800 | text.append("<DEFINITION>\n"); |
---|
| 801 | text.append("<FOR>" + XMLNormalize(m_Instances.attribute(iAttribute).name()) + "</FOR>\n"); |
---|
| 802 | for (int iParent = 0; iParent < m_ParentSets[iAttribute].getNrOfParents(); iParent++) { |
---|
| 803 | text.append("<GIVEN>" |
---|
| 804 | + XMLNormalize(m_Instances.attribute(m_ParentSets[iAttribute].getParent(iParent)).name()) + |
---|
| 805 | "</GIVEN>\n"); |
---|
| 806 | } |
---|
| 807 | text.append("<TABLE>\n"); |
---|
| 808 | for (int iParent = 0; iParent < m_ParentSets[iAttribute].getCardinalityOfParents(); iParent++) { |
---|
| 809 | for (int iValue = 0; iValue < m_Instances.attribute(iAttribute).numValues(); iValue++) { |
---|
| 810 | text.append(m_Distributions[iAttribute][iParent].getProbability(iValue)); |
---|
| 811 | text.append(' '); |
---|
| 812 | } |
---|
| 813 | text.append('\n'); |
---|
| 814 | } |
---|
| 815 | text.append("</TABLE>\n"); |
---|
| 816 | text.append("</DEFINITION>\n"); |
---|
| 817 | } |
---|
| 818 | text.append("</NETWORK>\n"); |
---|
| 819 | text.append("</BIF>\n"); |
---|
| 820 | return text.toString(); |
---|
| 821 | } // toXMLBIF03 |
---|
| 822 | |
---|
| 823 | |
---|
| 824 | /** XMLNormalize converts the five standard XML entities in a string |
---|
| 825 | * g.e. the string V&D's is returned as V&D's |
---|
| 826 | * @param sStr string to normalize |
---|
| 827 | * @return normalized string |
---|
| 828 | */ |
---|
| 829 | protected String XMLNormalize(String sStr) { |
---|
| 830 | StringBuffer sStr2 = new StringBuffer(); |
---|
| 831 | for (int iStr = 0; iStr < sStr.length(); iStr++) { |
---|
| 832 | char c = sStr.charAt(iStr); |
---|
| 833 | switch (c) { |
---|
| 834 | case '&': sStr2.append("&"); break; |
---|
| 835 | case '\'': sStr2.append("'"); break; |
---|
| 836 | case '\"': sStr2.append("""); break; |
---|
| 837 | case '<': sStr2.append("<"); break; |
---|
| 838 | case '>': sStr2.append(">"); break; |
---|
| 839 | default: |
---|
| 840 | sStr2.append(c); |
---|
| 841 | } |
---|
| 842 | } |
---|
| 843 | return sStr2.toString(); |
---|
| 844 | } // XMLNormalize |
---|
| 845 | |
---|
| 846 | |
---|
| 847 | /** |
---|
| 848 | * @return a string to describe the UseADTreeoption. |
---|
| 849 | */ |
---|
| 850 | public String useADTreeTipText() { |
---|
| 851 | return "When ADTree (the data structure for increasing speed on counts," |
---|
| 852 | + " not to be confused with the classifier under the same name) is used" |
---|
| 853 | + " learning time goes down typically. However, because ADTrees are memory" |
---|
| 854 | + " intensive, memory problems may occur. Switching this option off makes" |
---|
| 855 | + " the structure learning algorithms slower, and run with less memory." |
---|
| 856 | + " By default, ADTrees are used."; |
---|
| 857 | } |
---|
| 858 | |
---|
| 859 | /** |
---|
| 860 | * @return a string to describe the SearchAlgorithm. |
---|
| 861 | */ |
---|
| 862 | public String searchAlgorithmTipText() { |
---|
| 863 | return "Select method used for searching network structures."; |
---|
| 864 | } |
---|
| 865 | |
---|
| 866 | /** |
---|
| 867 | * This will return a string describing the BayesNetEstimator. |
---|
| 868 | * @return The string. |
---|
| 869 | */ |
---|
| 870 | public String estimatorTipText() { |
---|
| 871 | return "Select Estimator algorithm for finding the conditional probability tables" |
---|
| 872 | + " of the Bayes Network."; |
---|
| 873 | } |
---|
| 874 | |
---|
| 875 | /** |
---|
| 876 | * @return a string to describe the BIFFile. |
---|
| 877 | */ |
---|
| 878 | public String BIFFileTipText() { |
---|
| 879 | return "Set the name of a file in BIF XML format. A Bayes network learned" |
---|
| 880 | + " from data can be compared with the Bayes network represented by the BIF file." |
---|
| 881 | + " Statistics calculated are o.a. the number of missing and extra arcs."; |
---|
| 882 | } |
---|
| 883 | |
---|
| 884 | /** |
---|
| 885 | * This will return a string describing the classifier. |
---|
| 886 | * @return The string. |
---|
| 887 | */ |
---|
| 888 | public String globalInfo() { |
---|
| 889 | return |
---|
| 890 | "Bayes Network learning using various search algorithms and " |
---|
| 891 | + "quality measures.\n" |
---|
| 892 | + "Base class for a Bayes Network classifier. Provides " |
---|
| 893 | + "datastructures (network structure, conditional probability " |
---|
| 894 | + "distributions, etc.) and facilities common to Bayes Network " |
---|
| 895 | + "learning algorithms like K2 and B.\n\n" |
---|
| 896 | + "For more information see:\n\n" |
---|
| 897 | + "http://www.cs.waikato.ac.nz/~remco/weka.pdf"; |
---|
| 898 | } |
---|
| 899 | |
---|
| 900 | /** |
---|
| 901 | * Main method for testing this class. |
---|
| 902 | * |
---|
| 903 | * @param argv the options |
---|
| 904 | */ |
---|
| 905 | public static void main(String[] argv) { |
---|
| 906 | runClassifier(new BayesNet(), argv); |
---|
| 907 | } // main |
---|
| 908 | |
---|
| 909 | /** get name of the Bayes network |
---|
| 910 | * @return name of the Bayes net |
---|
| 911 | */ |
---|
| 912 | public String getName() { |
---|
| 913 | return m_Instances.relationName(); |
---|
| 914 | } |
---|
| 915 | |
---|
| 916 | /** get number of nodes in the Bayes network |
---|
| 917 | * @return number of nodes |
---|
| 918 | */ |
---|
| 919 | public int getNrOfNodes() { |
---|
| 920 | return m_Instances.numAttributes(); |
---|
| 921 | } |
---|
| 922 | |
---|
| 923 | /** get name of a node in the Bayes network |
---|
| 924 | * @param iNode index of the node |
---|
| 925 | * @return name of the specified node |
---|
| 926 | */ |
---|
| 927 | public String getNodeName(int iNode) { |
---|
| 928 | return m_Instances.attribute(iNode).name(); |
---|
| 929 | } |
---|
| 930 | |
---|
| 931 | /** get number of values a node can take |
---|
| 932 | * @param iNode index of the node |
---|
| 933 | * @return cardinality of the specified node |
---|
| 934 | */ |
---|
| 935 | public int getCardinality(int iNode) { |
---|
| 936 | return m_Instances.attribute(iNode).numValues(); |
---|
| 937 | } |
---|
| 938 | |
---|
| 939 | /** get name of a particular value of a node |
---|
| 940 | * @param iNode index of the node |
---|
| 941 | * @param iValue index of the value |
---|
| 942 | * @return cardinality of the specified node |
---|
| 943 | */ |
---|
| 944 | public String getNodeValue(int iNode, int iValue) { |
---|
| 945 | return m_Instances.attribute(iNode).value(iValue); |
---|
| 946 | } |
---|
| 947 | |
---|
| 948 | /** get number of parents of a node in the network structure |
---|
| 949 | * @param iNode index of the node |
---|
| 950 | * @return number of parents of the specified node |
---|
| 951 | */ |
---|
| 952 | public int getNrOfParents(int iNode) { |
---|
| 953 | return m_ParentSets[iNode].getNrOfParents(); |
---|
| 954 | } |
---|
| 955 | |
---|
| 956 | /** get node index of a parent of a node in the network structure |
---|
| 957 | * @param iNode index of the node |
---|
| 958 | * @param iParent index of the parents, e.g., 0 is the first parent, 1 the second parent, etc. |
---|
| 959 | * @return node index of the iParent's parent of the specified node |
---|
| 960 | */ |
---|
| 961 | public int getParent(int iNode, int iParent) { |
---|
| 962 | return m_ParentSets[iNode].getParent(iParent); |
---|
| 963 | } |
---|
| 964 | |
---|
| 965 | /** Get full set of parent sets. |
---|
| 966 | * @return parent sets; |
---|
| 967 | */ |
---|
| 968 | public ParentSet[] getParentSets() { |
---|
| 969 | return m_ParentSets; |
---|
| 970 | } |
---|
| 971 | |
---|
| 972 | /** Get full set of estimators. |
---|
| 973 | * @return estimators; |
---|
| 974 | */ |
---|
| 975 | public Estimator[][] getDistributions() { |
---|
| 976 | return m_Distributions; |
---|
| 977 | } |
---|
| 978 | |
---|
| 979 | /** get number of values the collection of parents of a node can take |
---|
| 980 | * @param iNode index of the node |
---|
| 981 | * @return cardinality of the parent set of the specified node |
---|
| 982 | */ |
---|
| 983 | public int getParentCardinality(int iNode) { |
---|
| 984 | return m_ParentSets[iNode].getCardinalityOfParents(); |
---|
| 985 | } |
---|
| 986 | |
---|
| 987 | /** get particular probability of the conditional probability distribtion |
---|
| 988 | * of a node given its parents. |
---|
| 989 | * @param iNode index of the node |
---|
| 990 | * @param iParent index of the parent set, 0 <= iParent <= getParentCardinality(iNode) |
---|
| 991 | * @param iValue index of the value, 0 <= iValue <= getCardinality(iNode) |
---|
| 992 | * @return probability |
---|
| 993 | */ |
---|
| 994 | public double getProbability(int iNode, int iParent, int iValue) { |
---|
| 995 | return m_Distributions[iNode][iParent].getProbability(iValue); |
---|
| 996 | } |
---|
| 997 | |
---|
| 998 | /** get the parent set of a node |
---|
| 999 | * @param iNode index of the node |
---|
| 1000 | * @return Parent set of the specified node. |
---|
| 1001 | */ |
---|
| 1002 | public ParentSet getParentSet(int iNode) { |
---|
| 1003 | return m_ParentSets[iNode]; |
---|
| 1004 | } |
---|
| 1005 | |
---|
| 1006 | /** get ADTree strucrture containing efficient representation of counts. |
---|
| 1007 | * @return ADTree strucrture |
---|
| 1008 | */ |
---|
| 1009 | public ADNode getADTree() { return m_ADTree;} |
---|
| 1010 | |
---|
| 1011 | // implementation of AdditionalMeasureProducer interface |
---|
| 1012 | /** |
---|
| 1013 | * Returns an enumeration of the measure names. Additional measures |
---|
| 1014 | * must follow the naming convention of starting with "measure", eg. |
---|
| 1015 | * double measureBlah() |
---|
| 1016 | * @return an enumeration of the measure names |
---|
| 1017 | */ |
---|
| 1018 | public Enumeration enumerateMeasures() { |
---|
| 1019 | Vector newVector = new Vector(4); |
---|
| 1020 | newVector.addElement("measureExtraArcs"); |
---|
| 1021 | newVector.addElement("measureMissingArcs"); |
---|
| 1022 | newVector.addElement("measureReversedArcs"); |
---|
| 1023 | newVector.addElement("measureDivergence"); |
---|
| 1024 | newVector.addElement("measureBayesScore"); |
---|
| 1025 | newVector.addElement("measureBDeuScore"); |
---|
| 1026 | newVector.addElement("measureMDLScore"); |
---|
| 1027 | newVector.addElement("measureAICScore"); |
---|
| 1028 | newVector.addElement("measureEntropyScore"); |
---|
| 1029 | return newVector.elements(); |
---|
| 1030 | } // enumerateMeasures |
---|
| 1031 | |
---|
| 1032 | public double measureExtraArcs() { |
---|
| 1033 | if (m_otherBayesNet != null) { |
---|
| 1034 | return m_otherBayesNet.extraArcs(this); |
---|
| 1035 | } |
---|
| 1036 | return 0; |
---|
| 1037 | } // measureExtraArcs |
---|
| 1038 | |
---|
| 1039 | public double measureMissingArcs() { |
---|
| 1040 | if (m_otherBayesNet != null) { |
---|
| 1041 | return m_otherBayesNet.missingArcs(this); |
---|
| 1042 | } |
---|
| 1043 | return 0; |
---|
| 1044 | } // measureMissingArcs |
---|
| 1045 | |
---|
| 1046 | public double measureReversedArcs() { |
---|
| 1047 | if (m_otherBayesNet != null) { |
---|
| 1048 | return m_otherBayesNet.reversedArcs(this); |
---|
| 1049 | } |
---|
| 1050 | return 0; |
---|
| 1051 | } // measureReversedArcs |
---|
| 1052 | |
---|
| 1053 | public double measureDivergence() { |
---|
| 1054 | if (m_otherBayesNet != null) { |
---|
| 1055 | return m_otherBayesNet.divergence(this); |
---|
| 1056 | } |
---|
| 1057 | return 0; |
---|
| 1058 | } // measureDivergence |
---|
| 1059 | |
---|
| 1060 | public double measureBayesScore() { |
---|
| 1061 | LocalScoreSearchAlgorithm s = new LocalScoreSearchAlgorithm(this, m_Instances); |
---|
| 1062 | return s.logScore(Scoreable.BAYES); |
---|
| 1063 | } // measureBayesScore |
---|
| 1064 | |
---|
| 1065 | public double measureBDeuScore() { |
---|
| 1066 | LocalScoreSearchAlgorithm s = new LocalScoreSearchAlgorithm(this, m_Instances); |
---|
| 1067 | return s.logScore(Scoreable.BDeu); |
---|
| 1068 | } // measureBDeuScore |
---|
| 1069 | |
---|
| 1070 | public double measureMDLScore() { |
---|
| 1071 | LocalScoreSearchAlgorithm s = new LocalScoreSearchAlgorithm(this, m_Instances); |
---|
| 1072 | return s.logScore(Scoreable.MDL); |
---|
| 1073 | } // measureMDLScore |
---|
| 1074 | |
---|
| 1075 | public double measureAICScore() { |
---|
| 1076 | LocalScoreSearchAlgorithm s = new LocalScoreSearchAlgorithm(this, m_Instances); |
---|
| 1077 | return s.logScore(Scoreable.AIC); |
---|
| 1078 | } // measureAICScore |
---|
| 1079 | |
---|
| 1080 | public double measureEntropyScore() { |
---|
| 1081 | LocalScoreSearchAlgorithm s = new LocalScoreSearchAlgorithm(this, m_Instances); |
---|
| 1082 | return s.logScore(Scoreable.ENTROPY); |
---|
| 1083 | } // measureEntropyScore |
---|
| 1084 | |
---|
| 1085 | /** |
---|
| 1086 | * Returns the value of the named measure |
---|
| 1087 | * @param measureName the name of the measure to query for its value |
---|
| 1088 | * @return the value of the named measure |
---|
| 1089 | * @throws IllegalArgumentException if the named measure is not supported |
---|
| 1090 | */ |
---|
| 1091 | public double getMeasure(String measureName) { |
---|
| 1092 | if (measureName.equals("measureExtraArcs")) { |
---|
| 1093 | return measureExtraArcs(); |
---|
| 1094 | } |
---|
| 1095 | if (measureName.equals("measureMissingArcs")) { |
---|
| 1096 | return measureMissingArcs(); |
---|
| 1097 | } |
---|
| 1098 | if (measureName.equals("measureReversedArcs")) { |
---|
| 1099 | return measureReversedArcs(); |
---|
| 1100 | } |
---|
| 1101 | if (measureName.equals("measureDivergence")) { |
---|
| 1102 | return measureDivergence(); |
---|
| 1103 | } |
---|
| 1104 | if (measureName.equals("measureBayesScore")) { |
---|
| 1105 | return measureBayesScore(); |
---|
| 1106 | } |
---|
| 1107 | if (measureName.equals("measureBDeuScore")) { |
---|
| 1108 | return measureBDeuScore(); |
---|
| 1109 | } |
---|
| 1110 | if (measureName.equals("measureMDLScore")) { |
---|
| 1111 | return measureMDLScore(); |
---|
| 1112 | } |
---|
| 1113 | if (measureName.equals("measureAICScore")) { |
---|
| 1114 | return measureAICScore(); |
---|
| 1115 | } |
---|
| 1116 | if (measureName.equals("measureEntropyScore")) { |
---|
| 1117 | return measureEntropyScore(); |
---|
| 1118 | } |
---|
| 1119 | return 0; |
---|
| 1120 | } // getMeasure |
---|
| 1121 | |
---|
| 1122 | /** |
---|
| 1123 | * Returns the revision string. |
---|
| 1124 | * |
---|
| 1125 | * @return the revision |
---|
| 1126 | */ |
---|
| 1127 | public String getRevision() { |
---|
| 1128 | return RevisionUtils.extract("$Revision: 5928 $"); |
---|
| 1129 | } |
---|
| 1130 | } // class BayesNet |
---|