1 | /* |
---|
2 | * This program is free software; you can redistribute it and/or modify |
---|
3 | * it under the terms of the GNU General Public License as published by |
---|
4 | * the Free Software Foundation; either version 2 of the License, or |
---|
5 | * (at your option) any later version. |
---|
6 | * |
---|
7 | * This program is distributed in the hope that it will be useful, |
---|
8 | * but WITHOUT ANY WARRANTY; without even the implied warranty of |
---|
9 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the |
---|
10 | * GNU General Public License for more details. |
---|
11 | * |
---|
12 | * You should have received a copy of the GNU General Public License |
---|
13 | * along with this program; if not, write to the Free Software |
---|
14 | * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA. |
---|
15 | */ |
---|
16 | |
---|
17 | /* |
---|
18 | * BayesNet.java |
---|
19 | * Copyright (C) 2001 University of Waikato, Hamilton, New Zealand |
---|
20 | * |
---|
21 | */ |
---|
22 | package weka.classifiers.bayes; |
---|
23 | |
---|
24 | import weka.classifiers.Classifier; |
---|
25 | import weka.classifiers.AbstractClassifier; |
---|
26 | import weka.classifiers.bayes.net.ADNode; |
---|
27 | import weka.classifiers.bayes.net.BIFReader; |
---|
28 | import weka.classifiers.bayes.net.ParentSet; |
---|
29 | import weka.classifiers.bayes.net.estimate.BayesNetEstimator; |
---|
30 | import weka.classifiers.bayes.net.estimate.DiscreteEstimatorBayes; |
---|
31 | import weka.classifiers.bayes.net.estimate.SimpleEstimator; |
---|
32 | import weka.classifiers.bayes.net.search.SearchAlgorithm; |
---|
33 | import weka.classifiers.bayes.net.search.local.K2; |
---|
34 | import weka.classifiers.bayes.net.search.local.LocalScoreSearchAlgorithm; |
---|
35 | import weka.classifiers.bayes.net.search.local.Scoreable; |
---|
36 | import weka.core.AdditionalMeasureProducer; |
---|
37 | import weka.core.Attribute; |
---|
38 | import weka.core.Capabilities; |
---|
39 | import weka.core.Drawable; |
---|
40 | import weka.core.Instance; |
---|
41 | import weka.core.Instances; |
---|
42 | import weka.core.Option; |
---|
43 | import weka.core.OptionHandler; |
---|
44 | import weka.core.RevisionUtils; |
---|
45 | import weka.core.Utils; |
---|
46 | import weka.core.WeightedInstancesHandler; |
---|
47 | import weka.core.Capabilities.Capability; |
---|
48 | import weka.estimators.Estimator; |
---|
49 | import weka.filters.Filter; |
---|
50 | import weka.filters.supervised.attribute.Discretize; |
---|
51 | import weka.filters.unsupervised.attribute.ReplaceMissingValues; |
---|
52 | |
---|
53 | import java.util.Enumeration; |
---|
54 | import java.util.Vector; |
---|
55 | |
---|
56 | /** |
---|
57 | <!-- globalinfo-start --> |
---|
58 | * Bayes Network learning using various search algorithms and quality measures.<br/> |
---|
59 | * Base class for a Bayes Network classifier. Provides datastructures (network structure, conditional probability distributions, etc.) and facilities common to Bayes Network learning algorithms like K2 and B.<br/> |
---|
60 | * <br/> |
---|
61 | * For more information see:<br/> |
---|
62 | * <br/> |
---|
63 | * http://sourceforge.net/projects/weka/files/documentation/WekaManual-3-7-0.pdf/download |
---|
64 | * <p/> |
---|
65 | <!-- globalinfo-end --> |
---|
66 | * |
---|
67 | <!-- options-start --> |
---|
68 | * Valid options are: <p/> |
---|
69 | * |
---|
70 | * <pre> -D |
---|
71 | * Do not use ADTree data structure |
---|
72 | * </pre> |
---|
73 | * |
---|
74 | * <pre> -B <BIF file> |
---|
75 | * BIF file to compare with |
---|
76 | * </pre> |
---|
77 | * |
---|
78 | * <pre> -Q weka.classifiers.bayes.net.search.SearchAlgorithm |
---|
79 | * Search algorithm |
---|
80 | * </pre> |
---|
81 | * |
---|
82 | * <pre> -E weka.classifiers.bayes.net.estimate.SimpleEstimator |
---|
83 | * Estimator algorithm |
---|
84 | * </pre> |
---|
85 | * |
---|
86 | <!-- options-end --> |
---|
87 | * |
---|
88 | * @author Remco Bouckaert (rrb@xm.co.nz) |
---|
89 | * @version $Revision: 5928 $ |
---|
90 | */ |
---|
91 | public class BayesNet |
---|
92 | extends AbstractClassifier |
---|
93 | implements OptionHandler, WeightedInstancesHandler, Drawable, |
---|
94 | AdditionalMeasureProducer { |
---|
95 | |
---|
/** for serialization */
static final long serialVersionUID = 746037443258775954L;


/**
 * The parent sets, one per attribute (allocated in initStructure).
 */
protected ParentSet[] m_ParentSets;

/**
 * The attribute estimators containing CPTs, indexed as
 * [attribute][parent configuration] (see countsForInstance).
 */
public Estimator[][] m_Distributions;


/**
 * Filter used to quantize continuous variables, if any; null when the
 * training data contained only nominal attributes (see normalizeDataSet).
 */
protected Discretize m_DiscretizeFilter = null;

/**
 * Index of a non-nominal attribute of the training data (the last one found
 * by normalizeDataSet); initially -1. normalizeInstance checks the type of
 * this attribute to decide whether an instance still needs discretizing.
 */
int m_nNonDiscreteAttribute = -1;

/**
 * Filter used to fill in missing values, if any. Created by normalizeDataSet
 * when the training data has missing values, or lazily by normalizeInstance
 * when an instance with a missing value is encountered later.
 */
protected ReplaceMissingValues m_MissingValuesFilter = null;

/**
 * The number of classes
 */
protected int m_NumClasses;

/**
 * The dataset header for the purposes of printing out a semi-intelligible
 * model. Note: buildClassifier stores the full normalized training data here
 * (not only the header), and it is also what the search algorithm learns from.
 */
public Instances m_Instances;

/**
 * Datastructure containing ADTree representation of the database.
 * This may result in more efficient access to the data.
 * Only built in buildClassifier when m_bUseADTree is set; it is set back to
 * null at the end of buildClassifier.
 */
ADNode m_ADTree;

/**
 * Bayes network to compare the structure with (read from a BIF file via
 * setBIFFile / option -B); null if none.
 */
protected BIFReader m_otherBayesNet = null;

/**
 * Use the experimental ADTree datastructure for calculating contingency tables.
 * Defaults to false here, but setOptions sets it to true unless -D is given.
 */
boolean m_bUseADTree = false;

/**
 * Search algorithm used for learning the structure of a network.
 */
SearchAlgorithm m_SearchAlgorithm = new K2();

/**
 * Estimator used for estimating the conditional probability tables (CPTs)
 * of the network.
 */
BayesNetEstimator m_BayesNetEstimator = new SimpleEstimator();
---|
156 | |
---|
157 | /** |
---|
158 | * Returns default capabilities of the classifier. |
---|
159 | * |
---|
160 | * @return the capabilities of this classifier |
---|
161 | */ |
---|
162 | public Capabilities getCapabilities() { |
---|
163 | Capabilities result = super.getCapabilities(); |
---|
164 | result.disableAll(); |
---|
165 | |
---|
166 | // attributes |
---|
167 | result.enable(Capability.NOMINAL_ATTRIBUTES); |
---|
168 | result.enable(Capability.NUMERIC_ATTRIBUTES); |
---|
169 | result.enable(Capability.MISSING_VALUES); |
---|
170 | |
---|
171 | // class |
---|
172 | result.enable(Capability.NOMINAL_CLASS); |
---|
173 | result.enable(Capability.MISSING_CLASS_VALUES); |
---|
174 | |
---|
175 | // instances |
---|
176 | result.setMinimumNumberInstances(0); |
---|
177 | |
---|
178 | return result; |
---|
179 | } |
---|
180 | |
---|
181 | /** |
---|
182 | * Generates the classifier. |
---|
183 | * |
---|
184 | * @param instances set of instances serving as training data |
---|
185 | * @throws Exception if the classifier has not been generated |
---|
186 | * successfully |
---|
187 | */ |
---|
188 | public void buildClassifier(Instances instances) throws Exception { |
---|
189 | |
---|
190 | // can classifier handle the data? |
---|
191 | getCapabilities().testWithFail(instances); |
---|
192 | |
---|
193 | // remove instances with missing class |
---|
194 | instances = new Instances(instances); |
---|
195 | instances.deleteWithMissingClass(); |
---|
196 | |
---|
197 | // ensure we have a data set with discrete variables only and with no missing values |
---|
198 | instances = normalizeDataSet(instances); |
---|
199 | |
---|
200 | // Copy the instances |
---|
201 | m_Instances = new Instances(instances); |
---|
202 | |
---|
203 | // sanity check: need more than 1 variable in datat set |
---|
204 | m_NumClasses = instances.numClasses(); |
---|
205 | |
---|
206 | // initialize ADTree |
---|
207 | if (m_bUseADTree) { |
---|
208 | m_ADTree = ADNode.makeADTree(instances); |
---|
209 | // System.out.println("Oef, done!"); |
---|
210 | } |
---|
211 | |
---|
212 | // build the network structure |
---|
213 | initStructure(); |
---|
214 | |
---|
215 | // build the network structure |
---|
216 | buildStructure(); |
---|
217 | |
---|
218 | // build the set of CPTs |
---|
219 | estimateCPTs(); |
---|
220 | |
---|
221 | // Save space |
---|
222 | // m_Instances = new Instances(m_Instances, 0); |
---|
223 | m_ADTree = null; |
---|
224 | } // buildClassifier |
---|
225 | |
---|
226 | /** ensure that all variables are nominal and that there are no missing values |
---|
227 | * @param instances data set to check and quantize and/or fill in missing values |
---|
228 | * @return filtered instances |
---|
229 | * @throws Exception if a filter (Discretize, ReplaceMissingValues) fails |
---|
230 | */ |
---|
231 | protected Instances normalizeDataSet(Instances instances) throws Exception { |
---|
232 | m_DiscretizeFilter = null; |
---|
233 | m_MissingValuesFilter = null; |
---|
234 | |
---|
235 | boolean bHasNonNominal = false; |
---|
236 | boolean bHasMissingValues = false; |
---|
237 | |
---|
238 | Enumeration enu = instances.enumerateAttributes(); |
---|
239 | while (enu.hasMoreElements()) { |
---|
240 | Attribute attribute = (Attribute) enu.nextElement(); |
---|
241 | if (attribute.type() != Attribute.NOMINAL) { |
---|
242 | m_nNonDiscreteAttribute = attribute.index(); |
---|
243 | bHasNonNominal = true; |
---|
244 | //throw new UnsupportedAttributeTypeException("BayesNet handles nominal variables only. Non-nominal variable in dataset detected."); |
---|
245 | } |
---|
246 | Enumeration enum2 = instances.enumerateInstances(); |
---|
247 | while (enum2.hasMoreElements()) { |
---|
248 | if (((Instance) enum2.nextElement()).isMissing(attribute)) { |
---|
249 | bHasMissingValues = true; |
---|
250 | // throw new NoSupportForMissingValuesException("BayesNet: no missing values, please."); |
---|
251 | } |
---|
252 | } |
---|
253 | } |
---|
254 | |
---|
255 | if (bHasNonNominal) { |
---|
256 | System.err.println("Warning: discretizing data set"); |
---|
257 | m_DiscretizeFilter = new Discretize(); |
---|
258 | m_DiscretizeFilter.setInputFormat(instances); |
---|
259 | instances = Filter.useFilter(instances, m_DiscretizeFilter); |
---|
260 | } |
---|
261 | |
---|
262 | if (bHasMissingValues) { |
---|
263 | System.err.println("Warning: filling in missing values in data set"); |
---|
264 | m_MissingValuesFilter = new ReplaceMissingValues(); |
---|
265 | m_MissingValuesFilter.setInputFormat(instances); |
---|
266 | instances = Filter.useFilter(instances, m_MissingValuesFilter); |
---|
267 | } |
---|
268 | return instances; |
---|
269 | } // normalizeDataSet |
---|
270 | |
---|
271 | /** ensure that all variables are nominal and that there are no missing values |
---|
272 | * @param instance instance to check and quantize and/or fill in missing values |
---|
273 | * @return filtered instance |
---|
274 | * @throws Exception if a filter (Discretize, ReplaceMissingValues) fails |
---|
275 | */ |
---|
276 | protected Instance normalizeInstance(Instance instance) throws Exception { |
---|
277 | if ((m_DiscretizeFilter != null) && |
---|
278 | (instance.attribute(m_nNonDiscreteAttribute).type() != Attribute.NOMINAL)) { |
---|
279 | m_DiscretizeFilter.input(instance); |
---|
280 | instance = m_DiscretizeFilter.output(); |
---|
281 | } |
---|
282 | if (m_MissingValuesFilter != null) { |
---|
283 | m_MissingValuesFilter.input(instance); |
---|
284 | instance = m_MissingValuesFilter.output(); |
---|
285 | } else { |
---|
286 | // is there a missing value in this instance? |
---|
287 | // this can happen when there is no missing value in the training set |
---|
288 | for (int iAttribute = 0; iAttribute < m_Instances.numAttributes(); iAttribute++) { |
---|
289 | if (iAttribute != instance.classIndex() && instance.isMissing(iAttribute)) { |
---|
290 | System.err.println("Warning: Found missing value in test set, filling in values."); |
---|
291 | m_MissingValuesFilter = new ReplaceMissingValues(); |
---|
292 | m_MissingValuesFilter.setInputFormat(m_Instances); |
---|
293 | Filter.useFilter(m_Instances, m_MissingValuesFilter); |
---|
294 | m_MissingValuesFilter.input(instance); |
---|
295 | instance = m_MissingValuesFilter.output(); |
---|
296 | iAttribute = m_Instances.numAttributes(); |
---|
297 | } |
---|
298 | } |
---|
299 | } |
---|
300 | return instance; |
---|
301 | } // normalizeInstance |
---|
302 | |
---|
303 | /** |
---|
304 | * Init structure initializes the structure to an empty graph or a Naive Bayes |
---|
305 | * graph (depending on the -N flag). |
---|
306 | * |
---|
307 | * @throws Exception in case of an error |
---|
308 | */ |
---|
309 | public void initStructure() throws Exception { |
---|
310 | |
---|
311 | // initialize topological ordering |
---|
312 | // m_nOrder = new int[m_Instances.numAttributes()]; |
---|
313 | // m_nOrder[0] = m_Instances.classIndex(); |
---|
314 | |
---|
315 | int nAttribute = 0; |
---|
316 | |
---|
317 | for (int iOrder = 1; iOrder < m_Instances.numAttributes(); iOrder++) { |
---|
318 | if (nAttribute == m_Instances.classIndex()) { |
---|
319 | nAttribute++; |
---|
320 | } |
---|
321 | |
---|
322 | // m_nOrder[iOrder] = nAttribute++; |
---|
323 | } |
---|
324 | |
---|
325 | // reserve memory |
---|
326 | m_ParentSets = new ParentSet[m_Instances.numAttributes()]; |
---|
327 | |
---|
328 | for (int iAttribute = 0; iAttribute < m_Instances.numAttributes(); iAttribute++) { |
---|
329 | m_ParentSets[iAttribute] = new ParentSet(m_Instances.numAttributes()); |
---|
330 | } |
---|
331 | } // initStructure |
---|
332 | |
---|
333 | /** |
---|
334 | * buildStructure determines the network structure/graph of the network. |
---|
335 | * The default behavior is creating a network where all nodes have the first |
---|
336 | * node as its parent (i.e., a BayesNet that behaves like a naive Bayes classifier). |
---|
337 | * This method can be overridden by derived classes to restrict the class |
---|
338 | * of network structures that are acceptable. |
---|
339 | * |
---|
340 | * @throws Exception in case of an error |
---|
341 | */ |
---|
342 | public void buildStructure() throws Exception { |
---|
343 | m_SearchAlgorithm.buildStructure(this, m_Instances); |
---|
344 | } // buildStructure |
---|
345 | |
---|
346 | /** |
---|
347 | * estimateCPTs estimates the conditional probability tables for the Bayes |
---|
348 | * Net using the network structure. |
---|
349 | * |
---|
350 | * @throws Exception in case of an error |
---|
351 | */ |
---|
352 | public void estimateCPTs() throws Exception { |
---|
353 | m_BayesNetEstimator.estimateCPTs(this); |
---|
354 | } // estimateCPTs |
---|
355 | |
---|
356 | /** |
---|
357 | * initializes the conditional probabilities |
---|
358 | * |
---|
359 | * @throws Exception in case of an error |
---|
360 | */ |
---|
361 | public void initCPTs() throws Exception { |
---|
362 | m_BayesNetEstimator.initCPTs(this); |
---|
363 | } // estimateCPTs |
---|
364 | |
---|
365 | /** |
---|
366 | * Updates the classifier with the given instance. |
---|
367 | * |
---|
368 | * @param instance the new training instance to include in the model |
---|
369 | * @throws Exception if the instance could not be incorporated in |
---|
370 | * the model. |
---|
371 | */ |
---|
372 | public void updateClassifier(Instance instance) throws Exception { |
---|
373 | instance = normalizeInstance(instance); |
---|
374 | m_BayesNetEstimator.updateClassifier(this, instance); |
---|
375 | } // updateClassifier |
---|
376 | |
---|
377 | /** |
---|
378 | * Calculates the class membership probabilities for the given test |
---|
379 | * instance. |
---|
380 | * |
---|
381 | * @param instance the instance to be classified |
---|
382 | * @return predicted class probability distribution |
---|
383 | * @throws Exception if there is a problem generating the prediction |
---|
384 | */ |
---|
385 | public double[] distributionForInstance(Instance instance) throws Exception { |
---|
386 | instance = normalizeInstance(instance); |
---|
387 | return m_BayesNetEstimator.distributionForInstance(this, instance); |
---|
388 | } // distributionForInstance |
---|
389 | |
---|
390 | /** |
---|
391 | * Calculates the counts for Dirichlet distribution for the |
---|
392 | * class membership probabilities for the given test instance. |
---|
393 | * |
---|
394 | * @param instance the instance to be classified |
---|
395 | * @return counts for Dirichlet distribution for class probability |
---|
396 | * @throws Exception if there is a problem generating the prediction |
---|
397 | */ |
---|
398 | public double[] countsForInstance(Instance instance) throws Exception { |
---|
399 | double[] fCounts = new double[m_NumClasses]; |
---|
400 | |
---|
401 | for (int iClass = 0; iClass < m_NumClasses; iClass++) { |
---|
402 | fCounts[iClass] = 0.0; |
---|
403 | } |
---|
404 | |
---|
405 | for (int iClass = 0; iClass < m_NumClasses; iClass++) { |
---|
406 | double fCount = 0; |
---|
407 | |
---|
408 | for (int iAttribute = 0; iAttribute < m_Instances.numAttributes(); iAttribute++) { |
---|
409 | double iCPT = 0; |
---|
410 | |
---|
411 | for (int iParent = 0; iParent < m_ParentSets[iAttribute].getNrOfParents(); iParent++) { |
---|
412 | int nParent = m_ParentSets[iAttribute].getParent(iParent); |
---|
413 | |
---|
414 | if (nParent == m_Instances.classIndex()) { |
---|
415 | iCPT = iCPT * m_NumClasses + iClass; |
---|
416 | } else { |
---|
417 | iCPT = iCPT * m_Instances.attribute(nParent).numValues() + instance.value(nParent); |
---|
418 | } |
---|
419 | } |
---|
420 | |
---|
421 | if (iAttribute == m_Instances.classIndex()) { |
---|
422 | fCount += ((DiscreteEstimatorBayes) m_Distributions[iAttribute][(int) iCPT]).getCount(iClass); |
---|
423 | } else { |
---|
424 | fCount |
---|
425 | += ((DiscreteEstimatorBayes) m_Distributions[iAttribute][(int) iCPT]).getCount( |
---|
426 | instance.value(iAttribute)); |
---|
427 | } |
---|
428 | } |
---|
429 | |
---|
430 | fCounts[iClass] += fCount; |
---|
431 | } |
---|
432 | return fCounts; |
---|
433 | } // countsForInstance |
---|
434 | |
---|
435 | /** |
---|
436 | * Returns an enumeration describing the available options |
---|
437 | * |
---|
438 | * @return an enumeration of all the available options |
---|
439 | */ |
---|
440 | public Enumeration listOptions() { |
---|
441 | Vector newVector = new Vector(4); |
---|
442 | |
---|
443 | newVector.addElement(new Option("\tDo not use ADTree data structure\n", "D", 0, "-D")); |
---|
444 | newVector.addElement(new Option("\tBIF file to compare with\n", "B", 1, "-B <BIF file>")); |
---|
445 | newVector.addElement(new Option("\tSearch algorithm\n", "Q", 1, "-Q weka.classifiers.bayes.net.search.SearchAlgorithm")); |
---|
446 | newVector.addElement(new Option("\tEstimator algorithm\n", "E", 1, "-E weka.classifiers.bayes.net.estimate.SimpleEstimator")); |
---|
447 | |
---|
448 | return newVector.elements(); |
---|
449 | } // listOptions |
---|
450 | |
---|
451 | /** |
---|
452 | * Parses a given list of options. <p> |
---|
453 | * |
---|
454 | <!-- options-start --> |
---|
455 | * Valid options are: <p/> |
---|
456 | * |
---|
457 | * <pre> -D |
---|
458 | * Do not use ADTree data structure |
---|
459 | * </pre> |
---|
460 | * |
---|
461 | * <pre> -B <BIF file> |
---|
462 | * BIF file to compare with |
---|
463 | * </pre> |
---|
464 | * |
---|
465 | * <pre> -Q weka.classifiers.bayes.net.search.SearchAlgorithm |
---|
466 | * Search algorithm |
---|
467 | * </pre> |
---|
468 | * |
---|
469 | * <pre> -E weka.classifiers.bayes.net.estimate.SimpleEstimator |
---|
470 | * Estimator algorithm |
---|
471 | * </pre> |
---|
472 | * |
---|
473 | <!-- options-end --> |
---|
474 | * |
---|
475 | * @param options the list of options as an array of strings |
---|
476 | * @throws Exception if an option is not supported |
---|
477 | */ |
---|
478 | public void setOptions(String[] options) throws Exception { |
---|
479 | m_bUseADTree = !(Utils.getFlag('D', options)); |
---|
480 | |
---|
481 | String sBIFFile = Utils.getOption('B', options); |
---|
482 | if (sBIFFile != null && !sBIFFile.equals("")) { |
---|
483 | setBIFFile(sBIFFile); |
---|
484 | } |
---|
485 | |
---|
486 | String searchAlgorithmName = Utils.getOption('Q', options); |
---|
487 | if (searchAlgorithmName.length() != 0) { |
---|
488 | setSearchAlgorithm( |
---|
489 | (SearchAlgorithm) Utils.forName( |
---|
490 | SearchAlgorithm.class, |
---|
491 | searchAlgorithmName, |
---|
492 | partitionOptions(options))); |
---|
493 | } |
---|
494 | else { |
---|
495 | setSearchAlgorithm(new K2()); |
---|
496 | } |
---|
497 | |
---|
498 | |
---|
499 | String estimatorName = Utils.getOption('E', options); |
---|
500 | if (estimatorName.length() != 0) { |
---|
501 | setEstimator( |
---|
502 | (BayesNetEstimator) Utils.forName( |
---|
503 | BayesNetEstimator.class, |
---|
504 | estimatorName, |
---|
505 | Utils.partitionOptions(options))); |
---|
506 | } |
---|
507 | else { |
---|
508 | setEstimator(new SimpleEstimator()); |
---|
509 | } |
---|
510 | |
---|
511 | Utils.checkForRemainingOptions(options); |
---|
512 | } // setOptions |
---|
513 | |
---|
514 | /** |
---|
515 | * Returns the secondary set of options (if any) contained in |
---|
516 | * the supplied options array. The secondary set is defined to |
---|
517 | * be any options after the first "--" but before the "-E". These |
---|
518 | * options are removed from the original options array. |
---|
519 | * |
---|
520 | * @param options the input array of options |
---|
521 | * @return the array of secondary options |
---|
522 | */ |
---|
523 | public static String [] partitionOptions(String [] options) { |
---|
524 | |
---|
525 | for (int i = 0; i < options.length; i++) { |
---|
526 | if (options[i].equals("--")) { |
---|
527 | // ensure it follows by a -E option |
---|
528 | int j = i; |
---|
529 | while ((j < options.length) && !(options[j].equals("-E"))) { |
---|
530 | j++; |
---|
531 | } |
---|
532 | /* if (j >= options.length) { |
---|
533 | return new String[0]; |
---|
534 | } */ |
---|
535 | options[i++] = ""; |
---|
536 | String [] result = new String [options.length - i]; |
---|
537 | j = i; |
---|
538 | while ((j < options.length) && !(options[j].equals("-E"))) { |
---|
539 | result[j - i] = options[j]; |
---|
540 | options[j] = ""; |
---|
541 | j++; |
---|
542 | } |
---|
543 | while(j < options.length) { |
---|
544 | result[j - i] = ""; |
---|
545 | j++; |
---|
546 | } |
---|
547 | return result; |
---|
548 | } |
---|
549 | } |
---|
550 | return new String [0]; |
---|
551 | } |
---|
552 | |
---|
553 | |
---|
554 | /** |
---|
555 | * Gets the current settings of the classifier. |
---|
556 | * |
---|
557 | * @return an array of strings suitable for passing to setOptions |
---|
558 | */ |
---|
559 | public String[] getOptions() { |
---|
560 | String[] searchOptions = m_SearchAlgorithm.getOptions(); |
---|
561 | String[] estimatorOptions = m_BayesNetEstimator.getOptions(); |
---|
562 | String[] options = new String[11 + searchOptions.length + estimatorOptions.length]; |
---|
563 | int current = 0; |
---|
564 | |
---|
565 | if (!m_bUseADTree) { |
---|
566 | options[current++] = "-D"; |
---|
567 | } |
---|
568 | |
---|
569 | if (m_otherBayesNet != null) { |
---|
570 | options[current++] = "-B"; |
---|
571 | options[current++] = ((BIFReader) m_otherBayesNet).getFileName(); |
---|
572 | } |
---|
573 | |
---|
574 | options[current++] = "-Q"; |
---|
575 | options[current++] = "" + getSearchAlgorithm().getClass().getName(); |
---|
576 | options[current++] = "--"; |
---|
577 | for (int iOption = 0; iOption < searchOptions.length; iOption++) { |
---|
578 | options[current++] = searchOptions[iOption]; |
---|
579 | } |
---|
580 | |
---|
581 | options[current++] = "-E"; |
---|
582 | options[current++] = "" + getEstimator().getClass().getName(); |
---|
583 | options[current++] = "--"; |
---|
584 | for (int iOption = 0; iOption < estimatorOptions.length; iOption++) { |
---|
585 | options[current++] = estimatorOptions[iOption]; |
---|
586 | } |
---|
587 | |
---|
588 | // Fill up rest with empty strings, not nulls! |
---|
589 | while (current < options.length) { |
---|
590 | options[current++] = ""; |
---|
591 | } |
---|
592 | |
---|
593 | return options; |
---|
594 | } // getOptions |
---|
595 | |
---|
596 | /** |
---|
597 | * Set the SearchAlgorithm used in searching for network structures. |
---|
598 | * @param newSearchAlgorithm the SearchAlgorithm to use. |
---|
599 | */ |
---|
600 | public void setSearchAlgorithm(SearchAlgorithm newSearchAlgorithm) { |
---|
601 | m_SearchAlgorithm = newSearchAlgorithm; |
---|
602 | } |
---|
603 | |
---|
604 | /** |
---|
605 | * Get the SearchAlgorithm used as the search algorithm |
---|
606 | * @return the SearchAlgorithm used as the search algorithm |
---|
607 | */ |
---|
608 | public SearchAlgorithm getSearchAlgorithm() { |
---|
609 | return m_SearchAlgorithm; |
---|
610 | } |
---|
611 | |
---|
612 | /** |
---|
613 | * Set the Estimator Algorithm used in calculating the CPTs |
---|
614 | * @param newBayesNetEstimator the Estimator to use. |
---|
615 | */ |
---|
616 | public void setEstimator(BayesNetEstimator newBayesNetEstimator) { |
---|
617 | m_BayesNetEstimator = newBayesNetEstimator; |
---|
618 | } |
---|
619 | |
---|
620 | /** |
---|
621 | * Get the BayesNetEstimator used for calculating the CPTs |
---|
622 | * @return the BayesNetEstimator used. |
---|
623 | */ |
---|
624 | public BayesNetEstimator getEstimator() { |
---|
625 | return m_BayesNetEstimator; |
---|
626 | } |
---|
627 | |
---|
628 | /** |
---|
629 | * Set whether ADTree structure is used or not |
---|
630 | * @param bUseADTree true if an ADTree structure is used |
---|
631 | */ |
---|
632 | public void setUseADTree(boolean bUseADTree) { |
---|
633 | m_bUseADTree = bUseADTree; |
---|
634 | } |
---|
635 | |
---|
636 | /** |
---|
637 | * Method declaration |
---|
638 | * @return whether ADTree structure is used or not |
---|
639 | */ |
---|
640 | public boolean getUseADTree() { |
---|
641 | return m_bUseADTree; |
---|
642 | } |
---|
643 | |
---|
644 | /** |
---|
645 | * Set name of network in BIF file to compare with |
---|
646 | * @param sBIFFile the name of the BIF file |
---|
647 | */ |
---|
648 | public void setBIFFile(String sBIFFile) { |
---|
649 | try { |
---|
650 | m_otherBayesNet = new BIFReader().processFile(sBIFFile); |
---|
651 | } catch (Throwable t) { |
---|
652 | m_otherBayesNet = null; |
---|
653 | } |
---|
654 | } |
---|
655 | |
---|
656 | /** |
---|
657 | * Get name of network in BIF file to compare with |
---|
658 | * @return BIF file name |
---|
659 | */ |
---|
660 | public String getBIFFile() { |
---|
661 | if (m_otherBayesNet != null) { |
---|
662 | return m_otherBayesNet.getFileName(); |
---|
663 | } |
---|
664 | return ""; |
---|
665 | } |
---|
666 | |
---|
667 | |
---|
668 | /** |
---|
669 | * Returns a description of the classifier. |
---|
670 | * |
---|
671 | * @return a description of the classifier as a string. |
---|
672 | */ |
---|
673 | public String toString() { |
---|
674 | StringBuffer text = new StringBuffer(); |
---|
675 | |
---|
676 | text.append("Bayes Network Classifier"); |
---|
677 | text.append("\n" + (m_bUseADTree ? "Using " : "not using ") + "ADTree"); |
---|
678 | |
---|
679 | if (m_Instances == null) { |
---|
680 | text.append(": No model built yet."); |
---|
681 | } else { |
---|
682 | |
---|
683 | // flatten BayesNet down to text |
---|
684 | text.append("\n#attributes="); |
---|
685 | text.append(m_Instances.numAttributes()); |
---|
686 | text.append(" #classindex="); |
---|
687 | text.append(m_Instances.classIndex()); |
---|
688 | text.append("\nNetwork structure (nodes followed by parents)\n"); |
---|
689 | |
---|
690 | for (int iAttribute = 0; iAttribute < m_Instances.numAttributes(); iAttribute++) { |
---|
691 | text.append( |
---|
692 | m_Instances.attribute(iAttribute).name() |
---|
693 | + "(" |
---|
694 | + m_Instances.attribute(iAttribute).numValues() |
---|
695 | + "): "); |
---|
696 | |
---|
697 | for (int iParent = 0; iParent < m_ParentSets[iAttribute].getNrOfParents(); iParent++) { |
---|
698 | text.append(m_Instances.attribute(m_ParentSets[iAttribute].getParent(iParent)).name() + " "); |
---|
699 | } |
---|
700 | |
---|
701 | text.append("\n"); |
---|
702 | |
---|
703 | // Description of distributions tends to be too much detail, so it is commented out here |
---|
704 | // for (int iParent = 0; iParent < m_ParentSets[iAttribute].GetCardinalityOfParents(); iParent++) { |
---|
705 | // text.append('(' + m_Distributions[iAttribute][iParent].toString() + ')'); |
---|
706 | // } |
---|
707 | // text.append("\n"); |
---|
708 | } |
---|
709 | |
---|
710 | text.append("LogScore Bayes: " + measureBayesScore() + "\n"); |
---|
711 | text.append("LogScore BDeu: " + measureBDeuScore() + "\n"); |
---|
712 | text.append("LogScore MDL: " + measureMDLScore() + "\n"); |
---|
713 | text.append("LogScore ENTROPY: " + measureEntropyScore() + "\n"); |
---|
714 | text.append("LogScore AIC: " + measureAICScore() + "\n"); |
---|
715 | |
---|
716 | if (m_otherBayesNet != null) { |
---|
717 | text.append( |
---|
718 | "Missing: " |
---|
719 | + m_otherBayesNet.missingArcs(this) |
---|
720 | + " Extra: " |
---|
721 | + m_otherBayesNet.extraArcs(this) |
---|
722 | + " Reversed: " |
---|
723 | + m_otherBayesNet.reversedArcs(this) |
---|
724 | + "\n"); |
---|
725 | text.append("Divergence: " + m_otherBayesNet.divergence(this) + "\n"); |
---|
726 | } |
---|
727 | } |
---|
728 | |
---|
729 | return text.toString(); |
---|
730 | } // toString |
---|
731 | |
---|
732 | |
---|
733 | /** |
---|
734 | * Returns the type of graph this classifier |
---|
735 | * represents. |
---|
736 | * @return Drawable.TREE |
---|
737 | */ |
---|
738 | public int graphType() { |
---|
739 | return Drawable.BayesNet; |
---|
740 | } |
---|
741 | |
---|
742 | /** |
---|
743 | * Returns a BayesNet graph in XMLBIF ver 0.3 format. |
---|
744 | * @return String representing this BayesNet in XMLBIF ver 0.3 |
---|
745 | * @throws Exception in case BIF generation fails |
---|
746 | */ |
---|
747 | public String graph() throws Exception { |
---|
748 | return toXMLBIF03(); |
---|
749 | } |
---|
750 | |
---|
751 | public String getBIFHeader() { |
---|
752 | StringBuffer text = new StringBuffer(); |
---|
753 | text.append("<?xml version=\"1.0\"?>\n"); |
---|
754 | text.append("<!-- DTD for the XMLBIF 0.3 format -->\n"); |
---|
755 | text.append("<!DOCTYPE BIF [\n"); |
---|
756 | text.append(" <!ELEMENT BIF ( NETWORK )*>\n"); |
---|
757 | text.append(" <!ATTLIST BIF VERSION CDATA #REQUIRED>\n"); |
---|
758 | text.append(" <!ELEMENT NETWORK ( NAME, ( PROPERTY | VARIABLE | DEFINITION )* )>\n"); |
---|
759 | text.append(" <!ELEMENT NAME (#PCDATA)>\n"); |
---|
760 | text.append(" <!ELEMENT VARIABLE ( NAME, ( OUTCOME | PROPERTY )* ) >\n"); |
---|
761 | text.append(" <!ATTLIST VARIABLE TYPE (nature|decision|utility) \"nature\">\n"); |
---|
762 | text.append(" <!ELEMENT OUTCOME (#PCDATA)>\n"); |
---|
763 | text.append(" <!ELEMENT DEFINITION ( FOR | GIVEN | TABLE | PROPERTY )* >\n"); |
---|
764 | text.append(" <!ELEMENT FOR (#PCDATA)>\n"); |
---|
765 | text.append(" <!ELEMENT GIVEN (#PCDATA)>\n"); |
---|
766 | text.append(" <!ELEMENT TABLE (#PCDATA)>\n"); |
---|
767 | text.append(" <!ELEMENT PROPERTY (#PCDATA)>\n"); |
---|
768 | text.append("]>\n"); |
---|
769 | return text.toString(); |
---|
770 | } // getBIFHeader |
---|
771 | |
---|
772 | /** |
---|
773 | * Returns a description of the classifier in XML BIF 0.3 format. |
---|
774 | * See http://www-2.cs.cmu.edu/~fgcozman/Research/InterchangeFormat/ |
---|
775 | * for details on XML BIF. |
---|
776 | * @return an XML BIF 0.3 description of the classifier as a string. |
---|
777 | */ |
---|
778 | public String toXMLBIF03() { |
---|
779 | if (m_Instances == null) { |
---|
780 | return("<!--No model built yet-->"); |
---|
781 | } |
---|
782 | |
---|
783 | StringBuffer text = new StringBuffer(); |
---|
784 | text.append(getBIFHeader()); |
---|
785 | text.append("\n"); |
---|
786 | text.append("\n"); |
---|
787 | text.append("<BIF VERSION=\"0.3\">\n"); |
---|
788 | text.append("<NETWORK>\n"); |
---|
789 | text.append("<NAME>" + XMLNormalize(m_Instances.relationName()) + "</NAME>\n"); |
---|
790 | for (int iAttribute = 0; iAttribute < m_Instances.numAttributes(); iAttribute++) { |
---|
791 | text.append("<VARIABLE TYPE=\"nature\">\n"); |
---|
792 | text.append("<NAME>" + XMLNormalize(m_Instances.attribute(iAttribute).name()) + "</NAME>\n"); |
---|
793 | for (int iValue = 0; iValue < m_Instances.attribute(iAttribute).numValues(); iValue++) { |
---|
794 | text.append("<OUTCOME>" + XMLNormalize(m_Instances.attribute(iAttribute).value(iValue)) + "</OUTCOME>\n"); |
---|
795 | } |
---|
796 | text.append("</VARIABLE>\n"); |
---|
797 | } |
---|
798 | |
---|
799 | for (int iAttribute = 0; iAttribute < m_Instances.numAttributes(); iAttribute++) { |
---|
800 | text.append("<DEFINITION>\n"); |
---|
801 | text.append("<FOR>" + XMLNormalize(m_Instances.attribute(iAttribute).name()) + "</FOR>\n"); |
---|
802 | for (int iParent = 0; iParent < m_ParentSets[iAttribute].getNrOfParents(); iParent++) { |
---|
803 | text.append("<GIVEN>" |
---|
804 | + XMLNormalize(m_Instances.attribute(m_ParentSets[iAttribute].getParent(iParent)).name()) + |
---|
805 | "</GIVEN>\n"); |
---|
806 | } |
---|
807 | text.append("<TABLE>\n"); |
---|
808 | for (int iParent = 0; iParent < m_ParentSets[iAttribute].getCardinalityOfParents(); iParent++) { |
---|
809 | for (int iValue = 0; iValue < m_Instances.attribute(iAttribute).numValues(); iValue++) { |
---|
810 | text.append(m_Distributions[iAttribute][iParent].getProbability(iValue)); |
---|
811 | text.append(' '); |
---|
812 | } |
---|
813 | text.append('\n'); |
---|
814 | } |
---|
815 | text.append("</TABLE>\n"); |
---|
816 | text.append("</DEFINITION>\n"); |
---|
817 | } |
---|
818 | text.append("</NETWORK>\n"); |
---|
819 | text.append("</BIF>\n"); |
---|
820 | return text.toString(); |
---|
821 | } // toXMLBIF03 |
---|
822 | |
---|
823 | |
---|
824 | /** XMLNormalize converts the five standard XML entities in a string |
---|
825 | * g.e. the string V&D's is returned as V&D's |
---|
826 | * @param sStr string to normalize |
---|
827 | * @return normalized string |
---|
828 | */ |
---|
829 | protected String XMLNormalize(String sStr) { |
---|
830 | StringBuffer sStr2 = new StringBuffer(); |
---|
831 | for (int iStr = 0; iStr < sStr.length(); iStr++) { |
---|
832 | char c = sStr.charAt(iStr); |
---|
833 | switch (c) { |
---|
834 | case '&': sStr2.append("&"); break; |
---|
835 | case '\'': sStr2.append("'"); break; |
---|
836 | case '\"': sStr2.append("""); break; |
---|
837 | case '<': sStr2.append("<"); break; |
---|
838 | case '>': sStr2.append(">"); break; |
---|
839 | default: |
---|
840 | sStr2.append(c); |
---|
841 | } |
---|
842 | } |
---|
843 | return sStr2.toString(); |
---|
844 | } // XMLNormalize |
---|
845 | |
---|
846 | |
---|
847 | /** |
---|
848 | * @return a string to describe the UseADTreeoption. |
---|
849 | */ |
---|
850 | public String useADTreeTipText() { |
---|
851 | return "When ADTree (the data structure for increasing speed on counts," |
---|
852 | + " not to be confused with the classifier under the same name) is used" |
---|
853 | + " learning time goes down typically. However, because ADTrees are memory" |
---|
854 | + " intensive, memory problems may occur. Switching this option off makes" |
---|
855 | + " the structure learning algorithms slower, and run with less memory." |
---|
856 | + " By default, ADTrees are used."; |
---|
857 | } |
---|
858 | |
---|
859 | /** |
---|
860 | * @return a string to describe the SearchAlgorithm. |
---|
861 | */ |
---|
862 | public String searchAlgorithmTipText() { |
---|
863 | return "Select method used for searching network structures."; |
---|
864 | } |
---|
865 | |
---|
866 | /** |
---|
867 | * This will return a string describing the BayesNetEstimator. |
---|
868 | * @return The string. |
---|
869 | */ |
---|
870 | public String estimatorTipText() { |
---|
871 | return "Select Estimator algorithm for finding the conditional probability tables" |
---|
872 | + " of the Bayes Network."; |
---|
873 | } |
---|
874 | |
---|
875 | /** |
---|
876 | * @return a string to describe the BIFFile. |
---|
877 | */ |
---|
878 | public String BIFFileTipText() { |
---|
879 | return "Set the name of a file in BIF XML format. A Bayes network learned" |
---|
880 | + " from data can be compared with the Bayes network represented by the BIF file." |
---|
881 | + " Statistics calculated are o.a. the number of missing and extra arcs."; |
---|
882 | } |
---|
883 | |
---|
884 | /** |
---|
885 | * This will return a string describing the classifier. |
---|
886 | * @return The string. |
---|
887 | */ |
---|
888 | public String globalInfo() { |
---|
889 | return |
---|
890 | "Bayes Network learning using various search algorithms and " |
---|
891 | + "quality measures.\n" |
---|
892 | + "Base class for a Bayes Network classifier. Provides " |
---|
893 | + "datastructures (network structure, conditional probability " |
---|
894 | + "distributions, etc.) and facilities common to Bayes Network " |
---|
895 | + "learning algorithms like K2 and B.\n\n" |
---|
896 | + "For more information see:\n\n" |
---|
897 | + "http://www.cs.waikato.ac.nz/~remco/weka.pdf"; |
---|
898 | } |
---|
899 | |
---|
900 | /** |
---|
901 | * Main method for testing this class. |
---|
902 | * |
---|
903 | * @param argv the options |
---|
904 | */ |
---|
905 | public static void main(String[] argv) { |
---|
906 | runClassifier(new BayesNet(), argv); |
---|
907 | } // main |
---|
908 | |
---|
909 | /** get name of the Bayes network |
---|
910 | * @return name of the Bayes net |
---|
911 | */ |
---|
912 | public String getName() { |
---|
913 | return m_Instances.relationName(); |
---|
914 | } |
---|
915 | |
---|
916 | /** get number of nodes in the Bayes network |
---|
917 | * @return number of nodes |
---|
918 | */ |
---|
919 | public int getNrOfNodes() { |
---|
920 | return m_Instances.numAttributes(); |
---|
921 | } |
---|
922 | |
---|
923 | /** get name of a node in the Bayes network |
---|
924 | * @param iNode index of the node |
---|
925 | * @return name of the specified node |
---|
926 | */ |
---|
927 | public String getNodeName(int iNode) { |
---|
928 | return m_Instances.attribute(iNode).name(); |
---|
929 | } |
---|
930 | |
---|
931 | /** get number of values a node can take |
---|
932 | * @param iNode index of the node |
---|
933 | * @return cardinality of the specified node |
---|
934 | */ |
---|
935 | public int getCardinality(int iNode) { |
---|
936 | return m_Instances.attribute(iNode).numValues(); |
---|
937 | } |
---|
938 | |
---|
939 | /** get name of a particular value of a node |
---|
940 | * @param iNode index of the node |
---|
941 | * @param iValue index of the value |
---|
942 | * @return cardinality of the specified node |
---|
943 | */ |
---|
944 | public String getNodeValue(int iNode, int iValue) { |
---|
945 | return m_Instances.attribute(iNode).value(iValue); |
---|
946 | } |
---|
947 | |
---|
948 | /** get number of parents of a node in the network structure |
---|
949 | * @param iNode index of the node |
---|
950 | * @return number of parents of the specified node |
---|
951 | */ |
---|
952 | public int getNrOfParents(int iNode) { |
---|
953 | return m_ParentSets[iNode].getNrOfParents(); |
---|
954 | } |
---|
955 | |
---|
956 | /** get node index of a parent of a node in the network structure |
---|
957 | * @param iNode index of the node |
---|
958 | * @param iParent index of the parents, e.g., 0 is the first parent, 1 the second parent, etc. |
---|
959 | * @return node index of the iParent's parent of the specified node |
---|
960 | */ |
---|
961 | public int getParent(int iNode, int iParent) { |
---|
962 | return m_ParentSets[iNode].getParent(iParent); |
---|
963 | } |
---|
964 | |
---|
965 | /** Get full set of parent sets. |
---|
966 | * @return parent sets; |
---|
967 | */ |
---|
968 | public ParentSet[] getParentSets() { |
---|
969 | return m_ParentSets; |
---|
970 | } |
---|
971 | |
---|
972 | /** Get full set of estimators. |
---|
973 | * @return estimators; |
---|
974 | */ |
---|
975 | public Estimator[][] getDistributions() { |
---|
976 | return m_Distributions; |
---|
977 | } |
---|
978 | |
---|
979 | /** get number of values the collection of parents of a node can take |
---|
980 | * @param iNode index of the node |
---|
981 | * @return cardinality of the parent set of the specified node |
---|
982 | */ |
---|
983 | public int getParentCardinality(int iNode) { |
---|
984 | return m_ParentSets[iNode].getCardinalityOfParents(); |
---|
985 | } |
---|
986 | |
---|
987 | /** get particular probability of the conditional probability distribtion |
---|
988 | * of a node given its parents. |
---|
989 | * @param iNode index of the node |
---|
990 | * @param iParent index of the parent set, 0 <= iParent <= getParentCardinality(iNode) |
---|
991 | * @param iValue index of the value, 0 <= iValue <= getCardinality(iNode) |
---|
992 | * @return probability |
---|
993 | */ |
---|
994 | public double getProbability(int iNode, int iParent, int iValue) { |
---|
995 | return m_Distributions[iNode][iParent].getProbability(iValue); |
---|
996 | } |
---|
997 | |
---|
998 | /** get the parent set of a node |
---|
999 | * @param iNode index of the node |
---|
1000 | * @return Parent set of the specified node. |
---|
1001 | */ |
---|
1002 | public ParentSet getParentSet(int iNode) { |
---|
1003 | return m_ParentSets[iNode]; |
---|
1004 | } |
---|
1005 | |
---|
1006 | /** get ADTree strucrture containing efficient representation of counts. |
---|
1007 | * @return ADTree strucrture |
---|
1008 | */ |
---|
1009 | public ADNode getADTree() { return m_ADTree;} |
---|
1010 | |
---|
1011 | // implementation of AdditionalMeasureProducer interface |
---|
1012 | /** |
---|
1013 | * Returns an enumeration of the measure names. Additional measures |
---|
1014 | * must follow the naming convention of starting with "measure", eg. |
---|
1015 | * double measureBlah() |
---|
1016 | * @return an enumeration of the measure names |
---|
1017 | */ |
---|
1018 | public Enumeration enumerateMeasures() { |
---|
1019 | Vector newVector = new Vector(4); |
---|
1020 | newVector.addElement("measureExtraArcs"); |
---|
1021 | newVector.addElement("measureMissingArcs"); |
---|
1022 | newVector.addElement("measureReversedArcs"); |
---|
1023 | newVector.addElement("measureDivergence"); |
---|
1024 | newVector.addElement("measureBayesScore"); |
---|
1025 | newVector.addElement("measureBDeuScore"); |
---|
1026 | newVector.addElement("measureMDLScore"); |
---|
1027 | newVector.addElement("measureAICScore"); |
---|
1028 | newVector.addElement("measureEntropyScore"); |
---|
1029 | return newVector.elements(); |
---|
1030 | } // enumerateMeasures |
---|
1031 | |
---|
1032 | public double measureExtraArcs() { |
---|
1033 | if (m_otherBayesNet != null) { |
---|
1034 | return m_otherBayesNet.extraArcs(this); |
---|
1035 | } |
---|
1036 | return 0; |
---|
1037 | } // measureExtraArcs |
---|
1038 | |
---|
1039 | public double measureMissingArcs() { |
---|
1040 | if (m_otherBayesNet != null) { |
---|
1041 | return m_otherBayesNet.missingArcs(this); |
---|
1042 | } |
---|
1043 | return 0; |
---|
1044 | } // measureMissingArcs |
---|
1045 | |
---|
1046 | public double measureReversedArcs() { |
---|
1047 | if (m_otherBayesNet != null) { |
---|
1048 | return m_otherBayesNet.reversedArcs(this); |
---|
1049 | } |
---|
1050 | return 0; |
---|
1051 | } // measureReversedArcs |
---|
1052 | |
---|
1053 | public double measureDivergence() { |
---|
1054 | if (m_otherBayesNet != null) { |
---|
1055 | return m_otherBayesNet.divergence(this); |
---|
1056 | } |
---|
1057 | return 0; |
---|
1058 | } // measureDivergence |
---|
1059 | |
---|
1060 | public double measureBayesScore() { |
---|
1061 | LocalScoreSearchAlgorithm s = new LocalScoreSearchAlgorithm(this, m_Instances); |
---|
1062 | return s.logScore(Scoreable.BAYES); |
---|
1063 | } // measureBayesScore |
---|
1064 | |
---|
1065 | public double measureBDeuScore() { |
---|
1066 | LocalScoreSearchAlgorithm s = new LocalScoreSearchAlgorithm(this, m_Instances); |
---|
1067 | return s.logScore(Scoreable.BDeu); |
---|
1068 | } // measureBDeuScore |
---|
1069 | |
---|
1070 | public double measureMDLScore() { |
---|
1071 | LocalScoreSearchAlgorithm s = new LocalScoreSearchAlgorithm(this, m_Instances); |
---|
1072 | return s.logScore(Scoreable.MDL); |
---|
1073 | } // measureMDLScore |
---|
1074 | |
---|
1075 | public double measureAICScore() { |
---|
1076 | LocalScoreSearchAlgorithm s = new LocalScoreSearchAlgorithm(this, m_Instances); |
---|
1077 | return s.logScore(Scoreable.AIC); |
---|
1078 | } // measureAICScore |
---|
1079 | |
---|
1080 | public double measureEntropyScore() { |
---|
1081 | LocalScoreSearchAlgorithm s = new LocalScoreSearchAlgorithm(this, m_Instances); |
---|
1082 | return s.logScore(Scoreable.ENTROPY); |
---|
1083 | } // measureEntropyScore |
---|
1084 | |
---|
1085 | /** |
---|
1086 | * Returns the value of the named measure |
---|
1087 | * @param measureName the name of the measure to query for its value |
---|
1088 | * @return the value of the named measure |
---|
1089 | * @throws IllegalArgumentException if the named measure is not supported |
---|
1090 | */ |
---|
1091 | public double getMeasure(String measureName) { |
---|
1092 | if (measureName.equals("measureExtraArcs")) { |
---|
1093 | return measureExtraArcs(); |
---|
1094 | } |
---|
1095 | if (measureName.equals("measureMissingArcs")) { |
---|
1096 | return measureMissingArcs(); |
---|
1097 | } |
---|
1098 | if (measureName.equals("measureReversedArcs")) { |
---|
1099 | return measureReversedArcs(); |
---|
1100 | } |
---|
1101 | if (measureName.equals("measureDivergence")) { |
---|
1102 | return measureDivergence(); |
---|
1103 | } |
---|
1104 | if (measureName.equals("measureBayesScore")) { |
---|
1105 | return measureBayesScore(); |
---|
1106 | } |
---|
1107 | if (measureName.equals("measureBDeuScore")) { |
---|
1108 | return measureBDeuScore(); |
---|
1109 | } |
---|
1110 | if (measureName.equals("measureMDLScore")) { |
---|
1111 | return measureMDLScore(); |
---|
1112 | } |
---|
1113 | if (measureName.equals("measureAICScore")) { |
---|
1114 | return measureAICScore(); |
---|
1115 | } |
---|
1116 | if (measureName.equals("measureEntropyScore")) { |
---|
1117 | return measureEntropyScore(); |
---|
1118 | } |
---|
1119 | return 0; |
---|
1120 | } // getMeasure |
---|
1121 | |
---|
1122 | /** |
---|
1123 | * Returns the revision string. |
---|
1124 | * |
---|
1125 | * @return the revision |
---|
1126 | */ |
---|
1127 | public String getRevision() { |
---|
1128 | return RevisionUtils.extract("$Revision: 5928 $"); |
---|
1129 | } |
---|
1130 | } // class BayesNet |
---|