1 | /* |
---|
2 | * This program is free software; you can redistribute it and/or modify |
---|
3 | * it under the terms of the GNU General Public License as published by |
---|
4 | * the Free Software Foundation; either version 2 of the License, or |
---|
5 | * (at your option) any later version. |
---|
6 | * |
---|
7 | * This program is distributed in the hope that it will be useful, |
---|
8 | * but WITHOUT ANY WARRANTY; without even the implied warranty of |
---|
9 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the |
---|
10 | * GNU General Public License for more details. |
---|
11 | * |
---|
12 | * You should have received a copy of the GNU General Public License |
---|
13 | * along with this program; if not, write to the Free Software |
---|
14 | * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA. |
---|
15 | */ |
---|
16 | |
---|
17 | /* |
---|
18 | * TreeModel.java |
---|
19 | * Copyright (C) 2009 University of Waikato, Hamilton, New Zealand |
---|
20 | * |
---|
21 | */ |
---|
22 | |
---|
23 | package weka.classifiers.pmml.consumer; |
---|
24 | |
---|
25 | import java.io.Serializable; |
---|
26 | import java.util.ArrayList; |
---|
27 | |
---|
28 | import org.w3c.dom.Element; |
---|
29 | import org.w3c.dom.Node; |
---|
30 | import org.w3c.dom.NodeList; |
---|
31 | |
---|
32 | import weka.core.Attribute; |
---|
33 | import weka.core.Drawable; |
---|
34 | import weka.core.Instance; |
---|
35 | import weka.core.Instances; |
---|
36 | import weka.core.RevisionUtils; |
---|
37 | import weka.core.Utils; |
---|
38 | import weka.core.pmml.*; |
---|
39 | |
---|
40 | /** |
---|
41 | * Class implementing import of PMML TreeModel. Can be used as a Weka |
---|
42 | * classifier for prediction (buildClassifier() raises an Exception). |
---|
43 | * |
---|
44 | * @author Mark Hall (mhall{[at]}pentaho{[dot]}com) |
---|
45 | * @version $Revision: 5987 $; |
---|
46 | */ |
---|
47 | public class TreeModel extends PMMLClassifier implements Drawable { |
---|
48 | |
---|
49 | /** |
---|
50 | * For serialization |
---|
51 | */ |
---|
52 | private static final long serialVersionUID = -2065158088298753129L; |
---|
53 | |
---|
54 | /** |
---|
55 | * Inner class representing the ScoreDistribution element |
---|
56 | */ |
---|
57 | static class ScoreDistribution implements Serializable { |
---|
58 | |
---|
59 | /** |
---|
60 | * For serialization |
---|
61 | */ |
---|
62 | private static final long serialVersionUID = -123506262094299933L; |
---|
63 | |
---|
64 | /** The class label for this distribution element */ |
---|
65 | private String m_classLabel; |
---|
66 | |
---|
67 | /** The index of the class label */ |
---|
68 | private int m_classLabelIndex = -1; |
---|
69 | |
---|
70 | /** The count for this label */ |
---|
71 | private double m_recordCount; |
---|
72 | |
---|
73 | /** The optional confidence value */ |
---|
74 | private double m_confidence = Utils.missingValue(); |
---|
75 | |
---|
76 | /** |
---|
77 | * Construct a ScoreDistribution entry |
---|
78 | * |
---|
79 | * @param scoreE the node containing the distribution |
---|
80 | * @param miningSchema the mining schema |
---|
81 | * @param baseCount the number of records at the node that owns this |
---|
82 | * distribution entry |
---|
83 | * @throws Exception if something goes wrong |
---|
84 | */ |
---|
85 | protected ScoreDistribution(Element scoreE, MiningSchema miningSchema, double baseCount) |
---|
86 | throws Exception { |
---|
87 | // get the label |
---|
88 | m_classLabel = scoreE.getAttribute("value"); |
---|
89 | Attribute classAtt = miningSchema.getFieldsAsInstances().classAttribute(); |
---|
90 | if (classAtt == null || classAtt.indexOfValue(m_classLabel) < 0) { |
---|
91 | throw new Exception("[ScoreDistribution] class attribute not set or class value " + |
---|
92 | m_classLabel + " not found!"); |
---|
93 | } |
---|
94 | |
---|
95 | m_classLabelIndex = classAtt.indexOfValue(m_classLabel); |
---|
96 | |
---|
97 | // get the frequency |
---|
98 | String recordC = scoreE.getAttribute("recordCount"); |
---|
99 | m_recordCount = Double.parseDouble(recordC); |
---|
100 | |
---|
101 | // get the optional confidence |
---|
102 | String confidence = scoreE.getAttribute("confidence"); |
---|
103 | if (confidence != null && confidence.length() > 0) { |
---|
104 | m_confidence = Double.parseDouble(confidence); |
---|
105 | } else if (!Utils.isMissingValue(baseCount) && baseCount > 0) { |
---|
106 | m_confidence = m_recordCount / baseCount; |
---|
107 | } |
---|
108 | } |
---|
109 | |
---|
110 | /** |
---|
111 | * Backfit confidence value (does nothing if the confidence |
---|
112 | * value is already set). |
---|
113 | * |
---|
114 | * @param baseCount the total number of records (supplied either |
---|
115 | * explicitly from the node that owns this distribution entry |
---|
116 | * or most likely computed from summing the recordCounts of all |
---|
117 | * the distribution entries in the distribution that owns this |
---|
118 | * entry). |
---|
119 | */ |
---|
120 | void deriveConfidenceValue(double baseCount) { |
---|
121 | if (Utils.isMissingValue(m_confidence) && |
---|
122 | !Utils.isMissingValue(baseCount) && |
---|
123 | baseCount > 0) { |
---|
124 | m_confidence = m_recordCount / baseCount; |
---|
125 | } |
---|
126 | } |
---|
127 | |
---|
128 | String getClassLabel() { |
---|
129 | return m_classLabel; |
---|
130 | } |
---|
131 | |
---|
132 | int getClassLabelIndex() { |
---|
133 | return m_classLabelIndex; |
---|
134 | } |
---|
135 | |
---|
136 | double getRecordCount() { |
---|
137 | return m_recordCount; |
---|
138 | } |
---|
139 | |
---|
140 | double getConfidence() { |
---|
141 | return m_confidence; |
---|
142 | } |
---|
143 | |
---|
144 | public String toString() { |
---|
145 | return m_classLabel + ": " + m_recordCount |
---|
146 | + " (" + Utils.doubleToString(m_confidence, 2) + ") "; |
---|
147 | } |
---|
148 | } |
---|
149 | |
---|
150 | /** |
---|
151 | * Base class for Predicates |
---|
152 | */ |
---|
153 | static abstract class Predicate implements Serializable { |
---|
154 | |
---|
155 | /** |
---|
156 | * For serialization |
---|
157 | */ |
---|
158 | private static final long serialVersionUID = 1035344165452733887L; |
---|
159 | |
---|
160 | enum Eval { |
---|
161 | TRUE, |
---|
162 | FALSE, |
---|
163 | UNKNOWN; |
---|
164 | } |
---|
165 | |
---|
166 | /** |
---|
167 | * Evaluate this predicate. |
---|
168 | * |
---|
169 | * @param input the input vector of attribute and derived field values. |
---|
170 | * |
---|
171 | * @return the evaluation status of this predicate. |
---|
172 | */ |
---|
173 | abstract Eval evaluate(double[] input); |
---|
174 | |
---|
175 | protected String toString(int level, boolean cr) { |
---|
176 | return toString(level); |
---|
177 | } |
---|
178 | |
---|
179 | protected String toString(int level) { |
---|
180 | StringBuffer text = new StringBuffer(); |
---|
181 | for (int j = 0; j < level; j++) { |
---|
182 | text.append("| "); |
---|
183 | } |
---|
184 | |
---|
185 | return text.append(toString()).toString(); |
---|
186 | } |
---|
187 | |
---|
188 | static Eval booleanToEval(boolean missing, boolean result) { |
---|
189 | if (missing) { |
---|
190 | return Eval.UNKNOWN; |
---|
191 | } else if (result) { |
---|
192 | return Eval.TRUE; |
---|
193 | } else { |
---|
194 | return Eval.FALSE; |
---|
195 | } |
---|
196 | } |
---|
197 | |
---|
198 | /** |
---|
199 | * Factory method to return the appropriate predicate for |
---|
200 | * a given node in the tree. |
---|
201 | * |
---|
202 | * @param nodeE the XML node encapsulating the tree node. |
---|
203 | * @param miningSchema the mining schema in use |
---|
204 | * @return a Predicate |
---|
205 | * @throws Exception of something goes wrong. |
---|
206 | */ |
---|
207 | static Predicate getPredicate(Element nodeE, |
---|
208 | MiningSchema miningSchema) throws Exception { |
---|
209 | |
---|
210 | Predicate result = null; |
---|
211 | NodeList children = nodeE.getChildNodes(); |
---|
212 | for (int i = 0; i < children.getLength(); i++) { |
---|
213 | Node child = children.item(i); |
---|
214 | if (child.getNodeType() == Node.ELEMENT_NODE) { |
---|
215 | String tagName = ((Element)child).getTagName(); |
---|
216 | if (tagName.equals("True")) { |
---|
217 | result = new True(); |
---|
218 | break; |
---|
219 | } else if (tagName.equals("False")) { |
---|
220 | result = new False(); |
---|
221 | break; |
---|
222 | } else if (tagName.equals("SimplePredicate")) { |
---|
223 | result = new SimplePredicate((Element)child, miningSchema); |
---|
224 | break; |
---|
225 | } else if (tagName.equals("CompoundPredicate")) { |
---|
226 | result = new CompoundPredicate((Element)child, miningSchema); |
---|
227 | break; |
---|
228 | } else if (tagName.equals("SimpleSetPredicate")) { |
---|
229 | result = new SimpleSetPredicate((Element)child, miningSchema); |
---|
230 | break; |
---|
231 | } |
---|
232 | } |
---|
233 | } |
---|
234 | |
---|
235 | if (result == null) { |
---|
236 | throw new Exception("[Predicate] unknown or missing predicate type in node"); |
---|
237 | } |
---|
238 | |
---|
239 | return result; |
---|
240 | } |
---|
241 | } |
---|
242 | |
---|
243 | /** |
---|
244 | * Simple True Predicate |
---|
245 | */ |
---|
246 | static class True extends Predicate { |
---|
247 | |
---|
248 | /** |
---|
249 | * For serialization |
---|
250 | */ |
---|
251 | private static final long serialVersionUID = 1817942234610531627L; |
---|
252 | |
---|
253 | public Predicate.Eval evaluate(double[] input) { |
---|
254 | return Predicate.Eval.TRUE; |
---|
255 | } |
---|
256 | |
---|
257 | public String toString() { |
---|
258 | return "True: "; |
---|
259 | } |
---|
260 | } |
---|
261 | |
---|
262 | /** |
---|
263 | * Simple False Predicate |
---|
264 | */ |
---|
265 | static class False extends Predicate { |
---|
266 | |
---|
267 | /** |
---|
268 | * For serialization |
---|
269 | */ |
---|
270 | private static final long serialVersionUID = -3647261386442860365L; |
---|
271 | |
---|
272 | public Predicate.Eval evaluate(double[] input) { |
---|
273 | return Predicate.Eval.FALSE; |
---|
274 | } |
---|
275 | |
---|
276 | public String toString() { |
---|
277 | return "False: "; |
---|
278 | } |
---|
279 | } |
---|
280 | |
---|
281 | /** |
---|
282 | * Class representing the SimplePredicate |
---|
283 | */ |
---|
284 | static class SimplePredicate extends Predicate { |
---|
285 | |
---|
286 | /** |
---|
287 | * For serialization |
---|
288 | */ |
---|
289 | private static final long serialVersionUID = -6156684285069327400L; |
---|
290 | |
---|
291 | enum Operator { |
---|
292 | EQUAL("equal") { |
---|
293 | Predicate.Eval evaluate(double[] input, double value, int fieldIndex) { |
---|
294 | return Predicate.booleanToEval(Utils.isMissingValue(input[fieldIndex]), |
---|
295 | weka.core.Utils.eq(input[fieldIndex], value)); |
---|
296 | } |
---|
297 | |
---|
298 | String shortName() { |
---|
299 | return "=="; |
---|
300 | } |
---|
301 | }, |
---|
302 | NOTEQUAL("notEqual") |
---|
303 | { |
---|
304 | Predicate.Eval evaluate(double[] input, double value, int fieldIndex) { |
---|
305 | return Predicate.booleanToEval(Utils.isMissingValue(input[fieldIndex]), |
---|
306 | (input[fieldIndex] != value)); |
---|
307 | } |
---|
308 | |
---|
309 | String shortName() { |
---|
310 | return "!="; |
---|
311 | } |
---|
312 | }, |
---|
313 | LESSTHAN("lessThan") { |
---|
314 | Predicate.Eval evaluate(double[] input, double value, int fieldIndex) { |
---|
315 | return Predicate.booleanToEval(Utils.isMissingValue(input[fieldIndex]), |
---|
316 | (input[fieldIndex] < value)); |
---|
317 | } |
---|
318 | |
---|
319 | String shortName() { |
---|
320 | return "<"; |
---|
321 | } |
---|
322 | }, |
---|
323 | LESSOREQUAL("lessOrEqual") { |
---|
324 | Predicate.Eval evaluate(double[] input, double value, int fieldIndex) { |
---|
325 | return Predicate.booleanToEval(Utils.isMissingValue(input[fieldIndex]), |
---|
326 | (input[fieldIndex] <= value)); |
---|
327 | } |
---|
328 | |
---|
329 | String shortName() { |
---|
330 | return "<="; |
---|
331 | } |
---|
332 | }, |
---|
333 | GREATERTHAN("greaterThan") { |
---|
334 | Predicate.Eval evaluate(double[] input, double value, int fieldIndex) { |
---|
335 | return Predicate.booleanToEval(Utils.isMissingValue(input[fieldIndex]), |
---|
336 | (input[fieldIndex] > value)); |
---|
337 | } |
---|
338 | |
---|
339 | String shortName() { |
---|
340 | return ">"; |
---|
341 | } |
---|
342 | }, |
---|
343 | GREATEROREQUAL("greaterOrEqual") { |
---|
344 | Predicate.Eval evaluate(double[] input, double value, int fieldIndex) { |
---|
345 | return Predicate.booleanToEval(Utils.isMissingValue(input[fieldIndex]), |
---|
346 | (input[fieldIndex] >= value)); |
---|
347 | } |
---|
348 | |
---|
349 | String shortName() { |
---|
350 | return ">="; |
---|
351 | } |
---|
352 | }, |
---|
353 | ISMISSING("isMissing") { |
---|
354 | Predicate.Eval evaluate(double[] input, double value, int fieldIndex) { |
---|
355 | return Predicate.booleanToEval(false, |
---|
356 | Utils.isMissingValue(input[fieldIndex])); |
---|
357 | } |
---|
358 | |
---|
359 | String shortName() { |
---|
360 | return toString(); |
---|
361 | } |
---|
362 | }, |
---|
363 | ISNOTMISSING("isNotMissing") { |
---|
364 | Predicate.Eval evaluate(double[] input, double value, int fieldIndex) { |
---|
365 | return Predicate.booleanToEval(false, !Utils.isMissingValue(input[fieldIndex])); |
---|
366 | } |
---|
367 | |
---|
368 | String shortName() { |
---|
369 | return toString(); |
---|
370 | } |
---|
371 | }; |
---|
372 | |
---|
373 | abstract Predicate.Eval evaluate(double[] input, double value, int fieldIndex); |
---|
374 | abstract String shortName(); |
---|
375 | |
---|
376 | private final String m_stringVal; |
---|
377 | |
---|
378 | Operator(String name) { |
---|
379 | m_stringVal = name; |
---|
380 | } |
---|
381 | |
---|
382 | public String toString() { |
---|
383 | return m_stringVal; |
---|
384 | } |
---|
385 | } |
---|
386 | |
---|
387 | /** the field that we are comparing against */ |
---|
388 | int m_fieldIndex = -1; |
---|
389 | |
---|
390 | /** the name of the field */ |
---|
391 | String m_fieldName; |
---|
392 | |
---|
393 | /** true if the field is nominal */ |
---|
394 | boolean m_isNominal; |
---|
395 | |
---|
396 | /** the value as a string (if nominal) */ |
---|
397 | String m_nominalValue; |
---|
398 | |
---|
399 | /** the value to compare against (if nominal it holds the index of the value) */ |
---|
400 | double m_value; |
---|
401 | |
---|
402 | /** the operator to use */ |
---|
403 | Operator m_operator; |
---|
404 | |
---|
405 | public SimplePredicate(Element simpleP, |
---|
406 | MiningSchema miningSchema) throws Exception { |
---|
407 | Instances totalStructure = miningSchema.getFieldsAsInstances(); |
---|
408 | |
---|
409 | // get the field name and set up the index |
---|
410 | String fieldS = simpleP.getAttribute("field"); |
---|
411 | Attribute att = totalStructure.attribute(fieldS); |
---|
412 | if (att == null) { |
---|
413 | throw new Exception("[SimplePredicate] unable to find field " + fieldS |
---|
414 | + " in the incoming instance structure!"); |
---|
415 | } |
---|
416 | |
---|
417 | // find the index |
---|
418 | int index = -1; |
---|
419 | for (int i = 0; i < totalStructure.numAttributes(); i++) { |
---|
420 | if (totalStructure.attribute(i).name().equals(fieldS)) { |
---|
421 | index = i; |
---|
422 | m_fieldName = totalStructure.attribute(i).name(); |
---|
423 | break; |
---|
424 | } |
---|
425 | } |
---|
426 | m_fieldIndex = index; |
---|
427 | if (att.isNominal()) { |
---|
428 | m_isNominal = true; |
---|
429 | } |
---|
430 | |
---|
431 | // get the operator |
---|
432 | String oppS = simpleP.getAttribute("operator"); |
---|
433 | for (Operator o : Operator.values()) { |
---|
434 | if (o.toString().equals(oppS)) { |
---|
435 | m_operator = o; |
---|
436 | break; |
---|
437 | } |
---|
438 | } |
---|
439 | |
---|
440 | if (m_operator != Operator.ISMISSING && m_operator != Operator.ISNOTMISSING) { |
---|
441 | String valueS = simpleP.getAttribute("value"); |
---|
442 | if (att.isNumeric()) { |
---|
443 | m_value = Double.parseDouble(valueS); |
---|
444 | } else { |
---|
445 | m_nominalValue = valueS; |
---|
446 | m_value = att.indexOfValue(valueS); |
---|
447 | if (m_value < 0) { |
---|
448 | throw new Exception("[SimplePredicate] can't find value " + valueS + " in nominal " + |
---|
449 | "attribute " + att.name()); |
---|
450 | } |
---|
451 | } |
---|
452 | } |
---|
453 | } |
---|
454 | |
---|
455 | public Predicate.Eval evaluate(double[] input) { |
---|
456 | return m_operator.evaluate(input, m_value, m_fieldIndex); |
---|
457 | } |
---|
458 | |
---|
459 | public String toString() { |
---|
460 | StringBuffer temp = new StringBuffer(); |
---|
461 | |
---|
462 | temp.append(m_fieldName + " " + m_operator.shortName()); |
---|
463 | if (m_operator != Operator.ISMISSING && m_operator != Operator.ISNOTMISSING) { |
---|
464 | temp.append(" " + ((m_isNominal) ? m_nominalValue : "" + m_value)); |
---|
465 | } |
---|
466 | |
---|
467 | return temp.toString(); |
---|
468 | } |
---|
469 | } |
---|
470 | |
---|
471 | /** |
---|
472 | * Class representing the CompoundPredicate |
---|
473 | */ |
---|
474 | static class CompoundPredicate extends Predicate { |
---|
475 | |
---|
476 | /** |
---|
477 | * For serialization |
---|
478 | */ |
---|
479 | private static final long serialVersionUID = -3332091529764559077L; |
---|
480 | |
---|
481 | enum BooleanOperator { |
---|
482 | OR("or") { |
---|
483 | Predicate.Eval evaluate(ArrayList<Predicate> constituents, double[] input) { |
---|
484 | Predicate.Eval currentStatus = Predicate.Eval.FALSE; |
---|
485 | for (Predicate p : constituents) { |
---|
486 | Predicate.Eval temp = p.evaluate(input); |
---|
487 | if (temp == Predicate.Eval.TRUE) { |
---|
488 | currentStatus = temp; |
---|
489 | break; |
---|
490 | } else if (temp == Predicate.Eval.UNKNOWN) { |
---|
491 | currentStatus = temp; |
---|
492 | } |
---|
493 | } |
---|
494 | return currentStatus; |
---|
495 | } |
---|
496 | }, |
---|
497 | AND("and") { |
---|
498 | Predicate.Eval evaluate(ArrayList<Predicate> constituents, double[] input) { |
---|
499 | Predicate.Eval currentStatus = Predicate.Eval.TRUE; |
---|
500 | for (Predicate p : constituents) { |
---|
501 | Predicate.Eval temp = p.evaluate(input); |
---|
502 | if (temp == Predicate.Eval.FALSE) { |
---|
503 | currentStatus = temp; |
---|
504 | break; |
---|
505 | } else if (temp == Predicate.Eval.UNKNOWN) { |
---|
506 | currentStatus = temp; |
---|
507 | } |
---|
508 | } |
---|
509 | return currentStatus; |
---|
510 | } |
---|
511 | }, |
---|
512 | XOR("xor") { |
---|
513 | Predicate.Eval evaluate(ArrayList<Predicate> constituents, double[] input) { |
---|
514 | Predicate.Eval currentStatus = constituents.get(0).evaluate(input); |
---|
515 | if (currentStatus != Predicate.Eval.UNKNOWN) { |
---|
516 | for (int i = 1; i < constituents.size(); i++) { |
---|
517 | Predicate.Eval temp = constituents.get(i).evaluate(input); |
---|
518 | if (temp == Predicate.Eval.UNKNOWN) { |
---|
519 | currentStatus = temp; |
---|
520 | break; |
---|
521 | } else { |
---|
522 | if (currentStatus != temp) { |
---|
523 | currentStatus = Predicate.Eval.TRUE; |
---|
524 | } else { |
---|
525 | currentStatus = Predicate.Eval.FALSE; |
---|
526 | } |
---|
527 | } |
---|
528 | } |
---|
529 | } |
---|
530 | return currentStatus; |
---|
531 | } |
---|
532 | }, |
---|
533 | SURROGATE("surrogate") { |
---|
534 | Predicate.Eval evaluate(ArrayList<Predicate> constituents, double[] input) { |
---|
535 | Predicate.Eval currentStatus = constituents.get(0).evaluate(input); |
---|
536 | |
---|
537 | int i = 1; |
---|
538 | while (currentStatus == Predicate.Eval.UNKNOWN) { |
---|
539 | currentStatus = constituents.get(i).evaluate(input); |
---|
540 | } |
---|
541 | |
---|
542 | // return false if all our surrogates evaluate to unknown. |
---|
543 | if (currentStatus == Predicate.Eval.UNKNOWN) { |
---|
544 | currentStatus = Predicate.Eval.FALSE; |
---|
545 | } |
---|
546 | |
---|
547 | return currentStatus; |
---|
548 | } |
---|
549 | }; |
---|
550 | |
---|
551 | abstract Predicate.Eval evaluate(ArrayList<Predicate> constituents, double[] input); |
---|
552 | |
---|
553 | private final String m_stringVal; |
---|
554 | |
---|
555 | BooleanOperator(String name) { |
---|
556 | m_stringVal = name; |
---|
557 | } |
---|
558 | |
---|
559 | public String toString() { |
---|
560 | return m_stringVal; |
---|
561 | } |
---|
562 | } |
---|
563 | |
---|
564 | /** the constituent Predicates */ |
---|
565 | ArrayList<Predicate> m_components = new ArrayList<Predicate>(); |
---|
566 | |
---|
567 | /** the boolean operator */ |
---|
568 | BooleanOperator m_booleanOperator; |
---|
569 | |
---|
570 | public CompoundPredicate(Element compoundP, |
---|
571 | MiningSchema miningSchema) throws Exception { |
---|
572 | // Instances totalStructure = miningSchema.getFieldsAsInstances(); |
---|
573 | |
---|
574 | String booleanOpp = compoundP.getAttribute("booleanOperator"); |
---|
575 | for (BooleanOperator b : BooleanOperator.values()) { |
---|
576 | if (b.toString().equals(booleanOpp)) { |
---|
577 | m_booleanOperator = b; |
---|
578 | } |
---|
579 | } |
---|
580 | |
---|
581 | // now get all the encapsulated operators |
---|
582 | NodeList children = compoundP.getChildNodes(); |
---|
583 | for (int i = 0; i < children.getLength(); i++) { |
---|
584 | Node child = children.item(i); |
---|
585 | if (child.getNodeType() == Node.ELEMENT_NODE) { |
---|
586 | String tagName = ((Element)child).getTagName(); |
---|
587 | if (tagName.equals("True")) { |
---|
588 | m_components.add(new True()); |
---|
589 | } else if (tagName.equals("False")) { |
---|
590 | m_components.add(new False()); |
---|
591 | } else if (tagName.equals("SimplePredicate")) { |
---|
592 | m_components.add(new SimplePredicate((Element)child, miningSchema)); |
---|
593 | } else if (tagName.equals("CompoundPredicate")) { |
---|
594 | m_components.add(new CompoundPredicate((Element)child, miningSchema)); |
---|
595 | } else { |
---|
596 | m_components.add(new SimpleSetPredicate((Element)child, miningSchema)); |
---|
597 | } |
---|
598 | } |
---|
599 | } |
---|
600 | } |
---|
601 | |
---|
602 | public Predicate.Eval evaluate(double[] input) { |
---|
603 | return m_booleanOperator.evaluate(m_components, input); |
---|
604 | } |
---|
605 | |
---|
606 | public String toString() { |
---|
607 | return toString(0, false); |
---|
608 | } |
---|
609 | |
---|
610 | public String toString(int level, boolean cr) { |
---|
611 | StringBuffer text = new StringBuffer(); |
---|
612 | for (int j = 0; j < level; j++) { |
---|
613 | text.append("| "); |
---|
614 | } |
---|
615 | |
---|
616 | text.append("Compound [" + m_booleanOperator.toString() + "]"); |
---|
617 | if (cr) { |
---|
618 | text.append("\\n"); |
---|
619 | } else { |
---|
620 | text.append("\n"); |
---|
621 | } |
---|
622 | for (int i = 0; i < m_components.size(); i++) { |
---|
623 | text.append(m_components.get(i).toString(level, cr).replace(":", "")); |
---|
624 | if (i != m_components.size()-1) { |
---|
625 | if (cr) { |
---|
626 | text.append("\\n"); |
---|
627 | } else { |
---|
628 | text.append("\n"); |
---|
629 | } |
---|
630 | } |
---|
631 | } |
---|
632 | |
---|
633 | return text.toString(); |
---|
634 | } |
---|
635 | } |
---|
636 | |
---|
637 | /** |
---|
638 | * Class representing the SimpleSetPredicate |
---|
639 | */ |
---|
640 | static class SimpleSetPredicate extends Predicate { |
---|
641 | |
---|
642 | /** |
---|
643 | * For serialization |
---|
644 | */ |
---|
645 | private static final long serialVersionUID = -2711995401345708486L; |
---|
646 | |
---|
647 | enum BooleanOperator { |
---|
648 | IS_IN("isIn") { |
---|
649 | Predicate.Eval evaluate(double[] input, int fieldIndex, |
---|
650 | Array set, Attribute nominalLookup) { |
---|
651 | if (set.getType() == Array.ArrayType.STRING) { |
---|
652 | String value = ""; |
---|
653 | if (!Utils.isMissingValue(input[fieldIndex])) { |
---|
654 | value = nominalLookup.value((int)input[fieldIndex]); |
---|
655 | } |
---|
656 | return Predicate.booleanToEval(Utils.isMissingValue(input[fieldIndex]), |
---|
657 | set.contains(value)); |
---|
658 | } else if (set.getType() == Array.ArrayType.NUM || |
---|
659 | set.getType() == Array.ArrayType.REAL) { |
---|
660 | return Predicate.booleanToEval(Utils.isMissingValue(input[fieldIndex]), |
---|
661 | set.contains(input[fieldIndex])); |
---|
662 | } |
---|
663 | return Predicate.booleanToEval(Utils.isMissingValue(input[fieldIndex]), |
---|
664 | set.contains((int)input[fieldIndex])); |
---|
665 | } |
---|
666 | }, |
---|
667 | IS_NOT_IN("isNotIn") { |
---|
668 | Predicate.Eval evaluate(double[] input, int fieldIndex, |
---|
669 | Array set, Attribute nominalLookup) { |
---|
670 | Predicate.Eval result = IS_IN.evaluate(input, fieldIndex, set, nominalLookup); |
---|
671 | if (result == Predicate.Eval.FALSE) { |
---|
672 | result = Predicate.Eval.TRUE; |
---|
673 | } else if (result == Predicate.Eval.TRUE) { |
---|
674 | result = Predicate.Eval.FALSE; |
---|
675 | } |
---|
676 | |
---|
677 | return result; |
---|
678 | } |
---|
679 | }; |
---|
680 | |
---|
681 | abstract Predicate.Eval evaluate(double[] input, int fieldIndex, |
---|
682 | Array set, Attribute nominalLookup); |
---|
683 | |
---|
684 | private final String m_stringVal; |
---|
685 | |
---|
686 | BooleanOperator(String name) { |
---|
687 | m_stringVal = name; |
---|
688 | } |
---|
689 | |
---|
690 | public String toString() { |
---|
691 | return m_stringVal; |
---|
692 | } |
---|
693 | } |
---|
694 | |
---|
695 | /** the field to reference */ |
---|
696 | int m_fieldIndex = -1; |
---|
697 | |
---|
698 | /** the name of the field */ |
---|
699 | String m_fieldName; |
---|
700 | |
---|
701 | /** is the referenced field nominal? */ |
---|
702 | boolean m_isNominal = false; |
---|
703 | |
---|
704 | /** the attribute to lookup nominal values from */ |
---|
705 | Attribute m_nominalLookup; |
---|
706 | |
---|
707 | /** the boolean operator */ |
---|
708 | BooleanOperator m_operator = BooleanOperator.IS_IN; |
---|
709 | |
---|
710 | /** the array holding the set of values */ |
---|
711 | Array m_set; |
---|
712 | |
---|
713 | public SimpleSetPredicate(Element setP, |
---|
714 | MiningSchema miningSchema) throws Exception { |
---|
715 | Instances totalStructure = miningSchema.getFieldsAsInstances(); |
---|
716 | |
---|
717 | // get the field name and set up the index |
---|
718 | String fieldS = setP.getAttribute("field"); |
---|
719 | Attribute att = totalStructure.attribute(fieldS); |
---|
720 | if (att == null) { |
---|
721 | throw new Exception("[SimplePredicate] unable to find field " + fieldS |
---|
722 | + " in the incoming instance structure!"); |
---|
723 | } |
---|
724 | |
---|
725 | // find the index |
---|
726 | int index = -1; |
---|
727 | for (int i = 0; i < totalStructure.numAttributes(); i++) { |
---|
728 | if (totalStructure.attribute(i).name().equals(fieldS)) { |
---|
729 | index = i; |
---|
730 | m_fieldName = totalStructure.attribute(i).name(); |
---|
731 | break; |
---|
732 | } |
---|
733 | } |
---|
734 | m_fieldIndex = index; |
---|
735 | if (att.isNominal()) { |
---|
736 | m_isNominal = true; |
---|
737 | m_nominalLookup = att; |
---|
738 | } |
---|
739 | |
---|
740 | // need to scan the children looking for an array type |
---|
741 | NodeList children = setP.getChildNodes(); |
---|
742 | for (int i = 0; i < children.getLength(); i++) { |
---|
743 | Node child = children.item(i); |
---|
744 | if (child.getNodeType() == Node.ELEMENT_NODE) { |
---|
745 | if (Array.isArray((Element)child)) { |
---|
746 | // found the array |
---|
747 | m_set = Array.create((Element)child); |
---|
748 | break; |
---|
749 | } |
---|
750 | } |
---|
751 | } |
---|
752 | |
---|
753 | if (m_set == null) { |
---|
754 | throw new Exception("[SimpleSetPredictate] couldn't find an " + |
---|
755 | "array containing the set values!"); |
---|
756 | } |
---|
757 | |
---|
758 | // check array type against field type |
---|
759 | if (m_set.getType() == Array.ArrayType.STRING && |
---|
760 | !m_isNominal) { |
---|
761 | throw new Exception("[SimpleSetPredicate] referenced field " + |
---|
762 | totalStructure.attribute(m_fieldIndex).name() + |
---|
763 | " is numeric but array type is string!"); |
---|
764 | } else if (m_set.getType() != Array.ArrayType.STRING && |
---|
765 | m_isNominal) { |
---|
766 | throw new Exception("[SimpleSetPredicate] referenced field " + |
---|
767 | totalStructure.attribute(m_fieldIndex).name() + |
---|
768 | " is nominal but array type is numeric!"); |
---|
769 | } |
---|
770 | } |
---|
771 | |
---|
772 | public Predicate.Eval evaluate(double[] input) { |
---|
773 | return m_operator.evaluate(input, m_fieldIndex, m_set, m_nominalLookup); |
---|
774 | } |
---|
775 | |
---|
776 | public String toString() { |
---|
777 | StringBuffer temp = new StringBuffer(); |
---|
778 | |
---|
779 | temp.append(m_fieldName + " " + m_operator.toString() + " "); |
---|
780 | temp.append(m_set.toString()); |
---|
781 | |
---|
782 | return temp.toString(); |
---|
783 | } |
---|
784 | } |
---|
785 | |
---|
786 | /** |
---|
787 | * Class for handling a Node in the tree |
---|
788 | */ |
---|
789 | class TreeNode implements Serializable { |
---|
790 | // TODO: perhaps implement a class called Statistics that contains Partitions? |
---|
791 | |
---|
792 | /** |
---|
793 | * For serialization |
---|
794 | */ |
---|
795 | private static final long serialVersionUID = 3011062274167063699L; |
---|
796 | |
---|
797 | /** ID for this node */ |
---|
798 | private String m_ID = "" + this.hashCode(); |
---|
799 | |
---|
800 | /** The score as a string */ |
---|
801 | private String m_scoreString; |
---|
802 | |
---|
803 | /** The index of this predicted value (if class is nominal) */ |
---|
804 | private int m_scoreIndex = -1; |
---|
805 | |
---|
806 | /** The score as a number (if target is numeric) */ |
---|
807 | private double m_scoreNumeric = Utils.missingValue(); |
---|
808 | |
---|
809 | /** The record count at this node (if defined) */ |
---|
810 | private double m_recordCount = Utils.missingValue(); |
---|
811 | |
---|
812 | /** The ID of the default child (if applicable) */ |
---|
813 | private String m_defaultChildID; |
---|
814 | |
---|
815 | /** Holds the node of the default child (if defined) */ |
---|
816 | private TreeNode m_defaultChild; |
---|
817 | |
---|
818 | /** The distribution for labels (classification) */ |
---|
819 | private ArrayList<ScoreDistribution> m_scoreDistributions = |
---|
820 | new ArrayList<ScoreDistribution>(); |
---|
821 | |
---|
822 | /** The predicate for this node */ |
---|
823 | private Predicate m_predicate; |
---|
824 | |
---|
825 | /** The children of this node */ |
---|
826 | private ArrayList<TreeNode> m_childNodes = new ArrayList<TreeNode>(); |
---|
827 | |
---|
828 | |
---|
829 | protected TreeNode(Element nodeE, MiningSchema miningSchema) throws Exception { |
---|
830 | Attribute classAtt = miningSchema.getFieldsAsInstances().classAttribute(); |
---|
831 | |
---|
832 | // get the ID |
---|
833 | String id = nodeE.getAttribute("id"); |
---|
834 | if (id != null && id.length() > 0) { |
---|
835 | m_ID = id; |
---|
836 | } |
---|
837 | |
---|
838 | // get the score for this node |
---|
839 | String scoreS = nodeE.getAttribute("score"); |
---|
840 | if (scoreS != null && scoreS.length() > 0) { |
---|
841 | m_scoreString = scoreS; |
---|
842 | |
---|
843 | // try to parse as a number in case we |
---|
844 | // are part of a regression tree |
---|
845 | if (classAtt.isNumeric()) { |
---|
846 | try { |
---|
847 | m_scoreNumeric = Double.parseDouble(scoreS); |
---|
848 | } catch (NumberFormatException ex) { |
---|
849 | throw new Exception("[TreeNode] class is numeric but unable to parse score " |
---|
850 | + m_scoreString + " as a number!"); |
---|
851 | } |
---|
852 | } else { |
---|
853 | // store the index of this class value |
---|
854 | m_scoreIndex = classAtt.indexOfValue(m_scoreString); |
---|
855 | |
---|
856 | if (m_scoreIndex < 0) { |
---|
857 | throw new Exception("[TreeNode] can't find match for predicted value " |
---|
858 | + m_scoreString + " in class attribute!"); |
---|
859 | } |
---|
860 | } |
---|
861 | } |
---|
862 | |
---|
863 | // get the record count if defined |
---|
864 | String recordC = nodeE.getAttribute("recordCount"); |
---|
865 | if (recordC != null && recordC.length() > 0) { |
---|
866 | m_recordCount = Double.parseDouble(recordC); |
---|
867 | } |
---|
868 | |
---|
869 | // get the default child (if applicable) |
---|
870 | String defaultC = nodeE.getAttribute("defaultChild"); |
---|
871 | if (defaultC != null && defaultC.length() > 0) { |
---|
872 | m_defaultChildID = defaultC; |
---|
873 | } |
---|
874 | |
---|
875 | //TODO: Embedded model (once we support model composition) |
---|
876 | |
---|
877 | // Now get the ScoreDistributions (if any and mining function |
---|
878 | // is classification) at this level |
---|
879 | if (m_functionType == MiningFunction.CLASSIFICATION) { |
---|
880 | getScoreDistributions(nodeE, miningSchema); |
---|
881 | } |
---|
882 | |
---|
883 | // Now get the Predicate |
---|
884 | m_predicate = Predicate.getPredicate(nodeE, miningSchema); |
---|
885 | |
---|
886 | // Now get the child Node(s) |
---|
887 | getChildNodes(nodeE, miningSchema); |
---|
888 | |
---|
889 | // If we have a default child specified, find it now |
---|
890 | if (m_defaultChildID != null) { |
---|
891 | for (TreeNode t : m_childNodes) { |
---|
892 | if (t.getID().equals(m_defaultChildID)) { |
---|
893 | m_defaultChild = t; |
---|
894 | break; |
---|
895 | } |
---|
896 | } |
---|
897 | } |
---|
898 | } |
---|
899 | |
---|
900 | private void getChildNodes(Element nodeE, MiningSchema miningSchema) throws Exception { |
---|
901 | NodeList children = nodeE.getChildNodes(); |
---|
902 | |
---|
903 | for (int i = 0; i < children.getLength(); i++) { |
---|
904 | Node child = children.item(i); |
---|
905 | if (child.getNodeType() == Node.ELEMENT_NODE) { |
---|
906 | String tagName = ((Element)child).getTagName(); |
---|
907 | if (tagName.equals("Node")) { |
---|
908 | TreeNode tempN = new TreeNode((Element)child, miningSchema); |
---|
909 | m_childNodes.add(tempN); |
---|
910 | } |
---|
911 | } |
---|
912 | } |
---|
913 | } |
---|
914 | |
---|
915 | private void getScoreDistributions(Element nodeE, |
---|
916 | MiningSchema miningSchema) throws Exception { |
---|
917 | |
---|
918 | NodeList scoreChildren = nodeE.getChildNodes(); |
---|
919 | for (int i = 0; i < scoreChildren.getLength(); i++) { |
---|
920 | Node child = scoreChildren.item(i); |
---|
921 | if (child.getNodeType() == Node.ELEMENT_NODE) { |
---|
922 | String tagName = ((Element)child).getTagName(); |
---|
923 | if (tagName.equals("ScoreDistribution")) { |
---|
924 | ScoreDistribution newDist = new ScoreDistribution((Element)child, |
---|
925 | miningSchema, m_recordCount); |
---|
926 | m_scoreDistributions.add(newDist); |
---|
927 | } |
---|
928 | } |
---|
929 | } |
---|
930 | |
---|
931 | // backfit the confidence values |
---|
932 | if (Utils.isMissingValue(m_recordCount)) { |
---|
933 | double baseCount = 0; |
---|
934 | for (ScoreDistribution s : m_scoreDistributions) { |
---|
935 | baseCount += s.getRecordCount(); |
---|
936 | } |
---|
937 | |
---|
938 | for (ScoreDistribution s : m_scoreDistributions) { |
---|
939 | s.deriveConfidenceValue(baseCount); |
---|
940 | } |
---|
941 | } |
---|
942 | } |
---|
943 | |
---|
944 | /** |
---|
945 | * Get the score value as a string. |
---|
946 | * |
---|
947 | * @return the score value as a String. |
---|
948 | */ |
---|
949 | protected String getScore() { |
---|
950 | return m_scoreString; |
---|
951 | } |
---|
952 | |
---|
953 | /** |
---|
954 | * Get the score value as a number (regression trees only). |
---|
955 | * |
---|
956 | * @return the score as a number |
---|
957 | */ |
---|
958 | protected double getScoreNumeric() { |
---|
959 | return m_scoreNumeric; |
---|
960 | } |
---|
961 | |
---|
962 | /** |
---|
963 | * Get the ID of this node. |
---|
964 | * |
---|
965 | * @return the ID of this node. |
---|
966 | */ |
---|
967 | protected String getID() { |
---|
968 | return m_ID; |
---|
969 | } |
---|
970 | |
---|
971 | /** |
---|
972 | * Get the Predicate at this node. |
---|
973 | * |
---|
974 | * @return the predicate at this node. |
---|
975 | */ |
---|
976 | protected Predicate getPredicate() { |
---|
977 | return m_predicate; |
---|
978 | } |
---|
979 | |
---|
980 | /** |
---|
981 | * Get the record count at this node. |
---|
982 | * |
---|
983 | * @return the record count at this node. |
---|
984 | */ |
---|
985 | protected double getRecordCount() { |
---|
986 | return m_recordCount; |
---|
987 | } |
---|
988 | |
---|
989 | protected void dumpGraph(StringBuffer text) throws Exception { |
---|
990 | text.append("N" + m_ID + " "); |
---|
991 | if (m_scoreString != null) { |
---|
992 | text.append("[label=\"score=" + m_scoreString); |
---|
993 | } |
---|
994 | |
---|
995 | if (m_scoreDistributions.size() > 0 && m_childNodes.size() == 0) { |
---|
996 | text.append("\\n"); |
---|
997 | for (ScoreDistribution s : m_scoreDistributions) { |
---|
998 | text.append(s + "\\n"); |
---|
999 | } |
---|
1000 | } |
---|
1001 | |
---|
1002 | text.append("\""); |
---|
1003 | |
---|
1004 | if (m_childNodes.size() == 0) { |
---|
1005 | text.append(" shape=box style=filled"); |
---|
1006 | |
---|
1007 | } |
---|
1008 | |
---|
1009 | text.append("]\n"); |
---|
1010 | |
---|
1011 | for (TreeNode c : m_childNodes) { |
---|
1012 | text.append("N" + m_ID +"->" + "N" + c.getID()); |
---|
1013 | text.append(" [label=\"" + c.getPredicate().toString(0, true)); |
---|
1014 | text.append("\"]\n"); |
---|
1015 | c.dumpGraph(text); |
---|
1016 | } |
---|
1017 | } |
---|
1018 | |
---|
1019 | public String toString() { |
---|
1020 | StringBuffer text = new StringBuffer(); |
---|
1021 | |
---|
1022 | // print out the root |
---|
1023 | dumpTree(0, text); |
---|
1024 | |
---|
1025 | return text.toString(); |
---|
1026 | } |
---|
1027 | |
---|
1028 | protected void dumpTree(int level, StringBuffer text) { |
---|
1029 | if (m_childNodes.size() > 0) { |
---|
1030 | |
---|
1031 | for (int i = 0; i < m_childNodes.size(); i++) { |
---|
1032 | text.append("\n"); |
---|
1033 | |
---|
1034 | /* for (int j = 0; j < level; j++) { |
---|
1035 | text.append("| "); |
---|
1036 | } */ |
---|
1037 | |
---|
1038 | // output the predicate for this child node |
---|
1039 | TreeNode child = m_childNodes.get(i); |
---|
1040 | text.append(child.getPredicate().toString(level, false)); |
---|
1041 | |
---|
1042 | // process recursively |
---|
1043 | child.dumpTree(level + 1 , text); |
---|
1044 | } |
---|
1045 | } else { |
---|
1046 | // leaf |
---|
1047 | text.append(": "); |
---|
1048 | if (!Utils.isMissingValue(m_scoreNumeric)) { |
---|
1049 | text.append(m_scoreNumeric); |
---|
1050 | } else { |
---|
1051 | text.append(m_scoreString + " "); |
---|
1052 | if (m_scoreDistributions.size() > 0) { |
---|
1053 | text.append("["); |
---|
1054 | for (ScoreDistribution s : m_scoreDistributions) { |
---|
1055 | text.append(s); |
---|
1056 | } |
---|
1057 | text.append("]"); |
---|
1058 | } else { |
---|
1059 | text.append(m_scoreString); |
---|
1060 | } |
---|
1061 | } |
---|
1062 | } |
---|
1063 | } |
---|
1064 | |
---|
1065 | /** |
---|
1066 | * Score an incoming instance. Invokes a missing value handling strategy. |
---|
1067 | * |
---|
1068 | * @param instance a vector of incoming attribute and derived field values. |
---|
1069 | * @param classAtt the class attribute |
---|
1070 | * @return a predicted probability distribution. |
---|
1071 | * @throws Exception if something goes wrong. |
---|
1072 | */ |
---|
1073 | protected double[] score(double[] instance, Attribute classAtt) throws Exception { |
---|
1074 | double[] preds = null; |
---|
1075 | |
---|
1076 | if (classAtt.isNumeric()) { |
---|
1077 | preds = new double[1]; |
---|
1078 | } else { |
---|
1079 | preds = new double[classAtt.numValues()]; |
---|
1080 | } |
---|
1081 | |
---|
1082 | // leaf? |
---|
1083 | if (m_childNodes.size() == 0) { |
---|
1084 | doLeaf(classAtt, preds); |
---|
1085 | } else { |
---|
1086 | // process the children |
---|
1087 | switch (TreeModel.this.m_missingValueStrategy) { |
---|
1088 | case NONE: |
---|
1089 | preds = missingValueStrategyNone(instance, classAtt); |
---|
1090 | break; |
---|
1091 | case LASTPREDICTION: |
---|
1092 | preds = missingValueStrategyLastPrediction(instance, classAtt); |
---|
1093 | break; |
---|
1094 | case DEFAULTCHILD: |
---|
1095 | preds = missingValueStrategyDefaultChild(instance, classAtt); |
---|
1096 | break; |
---|
1097 | default: |
---|
1098 | throw new Exception("[TreeModel] not implemented!"); |
---|
1099 | } |
---|
1100 | } |
---|
1101 | |
---|
1102 | return preds; |
---|
1103 | } |
---|
1104 | |
---|
1105 | /** |
---|
1106 | * Compute the predictions for a leaf. |
---|
1107 | * |
---|
1108 | * @param classAtt the class attribute |
---|
1109 | * @param preds an array to hold the predicted probabilities. |
---|
1110 | * @throws Exception if something goes wrong. |
---|
1111 | */ |
---|
1112 | protected void doLeaf(Attribute classAtt, double[] preds) throws Exception { |
---|
1113 | if (classAtt.isNumeric()) { |
---|
1114 | preds[0] = m_scoreNumeric; |
---|
1115 | } else { |
---|
1116 | if (m_scoreDistributions.size() == 0) { |
---|
1117 | preds[m_scoreIndex] = 1.0; |
---|
1118 | } else { |
---|
1119 | // collect confidences from the score distributions |
---|
1120 | for (ScoreDistribution s : m_scoreDistributions) { |
---|
1121 | preds[s.getClassLabelIndex()] = s.getConfidence(); |
---|
1122 | } |
---|
1123 | } |
---|
1124 | } |
---|
1125 | } |
---|
1126 | |
---|
1127 | /** |
---|
1128 | * Evaluate on the basis of the no true child strategy. |
---|
1129 | * |
---|
1130 | * @param classAtt the class attribute. |
---|
1131 | * @param preds an array to hold the predicted probabilities. |
---|
1132 | * @throws Exception if something goes wrong. |
---|
1133 | */ |
---|
1134 | protected void doNoTrueChild(Attribute classAtt, double[] preds) |
---|
1135 | throws Exception { |
---|
1136 | if (TreeModel.this.m_noTrueChildStrategy == |
---|
1137 | NoTrueChildStrategy.RETURNNULLPREDICTION) { |
---|
1138 | for (int i = 0; i < classAtt.numValues(); i++) { |
---|
1139 | preds[i] = Utils.missingValue(); |
---|
1140 | } |
---|
1141 | } else { |
---|
1142 | // return the predictions at this node |
---|
1143 | doLeaf(classAtt, preds); |
---|
1144 | } |
---|
1145 | } |
---|
1146 | |
---|
1147 | /** |
---|
1148 | * Compute predictions and optionally invoke the weighted confidence |
---|
1149 | * missing value handling strategy. |
---|
1150 | * |
---|
1151 | * @param instance the incoming vector of attribute and derived field values. |
---|
1152 | * @param classAtt the class attribute. |
---|
1153 | * @return the predicted probability distribution. |
---|
1154 | * @throws Exception if something goes wrong. |
---|
1155 | */ |
---|
1156 | protected double[] missingValueStrategyWeightedConfidence(double[] instance, |
---|
1157 | Attribute classAtt) throws Exception { |
---|
1158 | |
---|
1159 | if (classAtt.isNumeric()) { |
---|
1160 | throw new Exception("[TreeNode] missing value strategy weighted confidence, " |
---|
1161 | + "but class is numeric!"); |
---|
1162 | } |
---|
1163 | |
---|
1164 | double[] preds = null; |
---|
1165 | TreeNode trueNode = null; |
---|
1166 | boolean strategyInvoked = false; |
---|
1167 | int nodeCount = 0; |
---|
1168 | |
---|
1169 | // look at the evaluation of the child predicates |
---|
1170 | for (TreeNode c : m_childNodes) { |
---|
1171 | if (c.getPredicate().evaluate(instance) == Predicate.Eval.TRUE) { |
---|
1172 | // note the first child to evaluate to true |
---|
1173 | if (trueNode == null) { |
---|
1174 | trueNode = c; |
---|
1175 | } |
---|
1176 | nodeCount++; |
---|
1177 | } else if (c.getPredicate().evaluate(instance) == Predicate.Eval.UNKNOWN) { |
---|
1178 | strategyInvoked = true; |
---|
1179 | nodeCount++; |
---|
1180 | } |
---|
1181 | } |
---|
1182 | |
---|
1183 | if (strategyInvoked) { |
---|
1184 | // we expect to combine nodeCount distributions |
---|
1185 | double[][] dists = new double[nodeCount][]; |
---|
1186 | double[] weights = new double[nodeCount]; |
---|
1187 | |
---|
1188 | // collect the distributions and weights |
---|
1189 | int count = 0; |
---|
1190 | for (TreeNode c : m_childNodes) { |
---|
1191 | if (c.getPredicate().evaluate(instance) == Predicate.Eval.TRUE || |
---|
1192 | c.getPredicate().evaluate(instance) == Predicate.Eval.UNKNOWN) { |
---|
1193 | |
---|
1194 | weights[count] = c.getRecordCount(); |
---|
1195 | if (Utils.isMissingValue(weights[count])) { |
---|
1196 | throw new Exception("[TreeNode] weighted confidence missing value " + |
---|
1197 | "strategy invoked, but no record count defined for node " + |
---|
1198 | c.getID()); |
---|
1199 | } |
---|
1200 | dists[count++] = c.score(instance, classAtt); |
---|
1201 | } |
---|
1202 | } |
---|
1203 | |
---|
1204 | // do the combination |
---|
1205 | preds = new double[classAtt.numValues()]; |
---|
1206 | for (int i = 0; i < classAtt.numValues(); i++) { |
---|
1207 | for (int j = 0; j < nodeCount; j++) { |
---|
1208 | preds[i] += ((weights[j] / m_recordCount) * dists[j][i]); |
---|
1209 | } |
---|
1210 | } |
---|
1211 | } else { |
---|
1212 | if (trueNode != null) { |
---|
1213 | preds = trueNode.score(instance, classAtt); |
---|
1214 | } else { |
---|
1215 | doNoTrueChild(classAtt, preds); |
---|
1216 | } |
---|
1217 | } |
---|
1218 | |
---|
1219 | return preds; |
---|
1220 | } |
---|
1221 | |
---|
1222 | protected double[] freqCountsForAggNodesStrategy(double[] instance, |
---|
1223 | Attribute classAtt) throws Exception { |
---|
1224 | |
---|
1225 | double[] counts = new double[classAtt.numValues()]; |
---|
1226 | |
---|
1227 | if (m_childNodes.size() > 0) { |
---|
1228 | // collect the counts |
---|
1229 | for (TreeNode c : m_childNodes) { |
---|
1230 | if (c.getPredicate().evaluate(instance) == Predicate.Eval.TRUE || |
---|
1231 | c.getPredicate().evaluate(instance) == Predicate.Eval.UNKNOWN) { |
---|
1232 | |
---|
1233 | double[] temp = c.freqCountsForAggNodesStrategy(instance, classAtt); |
---|
1234 | for (int i = 0; i < classAtt.numValues(); i++) { |
---|
1235 | counts[i] += temp[i]; |
---|
1236 | } |
---|
1237 | } |
---|
1238 | } |
---|
1239 | } else { |
---|
1240 | // process the score distributions |
---|
1241 | if (m_scoreDistributions.size() == 0) { |
---|
1242 | throw new Exception("[TreeModel] missing value strategy aggregate nodes:" + |
---|
1243 | " no score distributions at leaf " + m_ID); |
---|
1244 | } |
---|
1245 | for (ScoreDistribution s : m_scoreDistributions) { |
---|
1246 | counts[s.getClassLabelIndex()] = s.getRecordCount(); |
---|
1247 | } |
---|
1248 | } |
---|
1249 | |
---|
1250 | return counts; |
---|
1251 | } |
---|
1252 | |
---|
1253 | /** |
---|
1254 | * Compute predictions and optionally invoke the aggregate nodes |
---|
1255 | * missing value handling strategy. |
---|
1256 | * |
---|
1257 | * @param instance the incoming vector of attribute and derived field values. |
---|
1258 | * @param classAtt the class attribute. |
---|
1259 | * @return the predicted probability distribution. |
---|
1260 | * @throws Exception if something goes wrong. |
---|
1261 | */ |
---|
1262 | protected double[] missingValueStrategyAggregateNodes(double[] instance, |
---|
1263 | Attribute classAtt) throws Exception { |
---|
1264 | |
---|
1265 | if (classAtt.isNumeric()) { |
---|
1266 | throw new Exception("[TreeNode] missing value strategy aggregate nodes, " |
---|
1267 | + "but class is numeric!"); |
---|
1268 | } |
---|
1269 | |
---|
1270 | double[] preds = null; |
---|
1271 | TreeNode trueNode = null; |
---|
1272 | boolean strategyInvoked = false; |
---|
1273 | int nodeCount = 0; |
---|
1274 | |
---|
1275 | // look at the evaluation of the child predicates |
---|
1276 | for (TreeNode c : m_childNodes) { |
---|
1277 | if (c.getPredicate().evaluate(instance) == Predicate.Eval.TRUE) { |
---|
1278 | // note the first child to evaluate to true |
---|
1279 | if (trueNode == null) { |
---|
1280 | trueNode = c; |
---|
1281 | } |
---|
1282 | nodeCount++; |
---|
1283 | } else if (c.getPredicate().evaluate(instance) == Predicate.Eval.UNKNOWN) { |
---|
1284 | strategyInvoked = true; |
---|
1285 | nodeCount++; |
---|
1286 | } |
---|
1287 | } |
---|
1288 | |
---|
1289 | if (strategyInvoked) { |
---|
1290 | double[] aggregatedCounts = |
---|
1291 | freqCountsForAggNodesStrategy(instance, classAtt); |
---|
1292 | |
---|
1293 | // normalize |
---|
1294 | Utils.normalize(aggregatedCounts); |
---|
1295 | preds = aggregatedCounts; |
---|
1296 | } else { |
---|
1297 | if (trueNode != null) { |
---|
1298 | preds = trueNode.score(instance, classAtt); |
---|
1299 | } else { |
---|
1300 | doNoTrueChild(classAtt, preds); |
---|
1301 | } |
---|
1302 | } |
---|
1303 | |
---|
1304 | return preds; |
---|
1305 | } |
---|
1306 | |
---|
1307 | /** |
---|
1308 | * Compute predictions and optionally invoke the default child |
---|
1309 | * missing value handling strategy. |
---|
1310 | * |
---|
1311 | * @param instance the incoming vector of attribute and derived field values. |
---|
1312 | * @param classAtt the class attribute. |
---|
1313 | * @return the predicted probability distribution. |
---|
1314 | * @throws Exception if something goes wrong. |
---|
1315 | */ |
---|
1316 | protected double[] missingValueStrategyDefaultChild(double[] instance, |
---|
1317 | Attribute classAtt) throws Exception { |
---|
1318 | |
---|
1319 | double[] preds = null; |
---|
1320 | boolean strategyInvoked = false; |
---|
1321 | |
---|
1322 | // look for a child whose predicate evaluates to TRUE |
---|
1323 | for (TreeNode c : m_childNodes) { |
---|
1324 | if (c.getPredicate().evaluate(instance) == Predicate.Eval.TRUE) { |
---|
1325 | preds = c.score(instance, classAtt); |
---|
1326 | break; |
---|
1327 | } else if (c.getPredicate().evaluate(instance) == Predicate.Eval.UNKNOWN) { |
---|
1328 | strategyInvoked = true; |
---|
1329 | } |
---|
1330 | } |
---|
1331 | |
---|
1332 | // no true child found |
---|
1333 | if (preds == null) { |
---|
1334 | if (!strategyInvoked) { |
---|
1335 | doNoTrueChild(classAtt, preds); |
---|
1336 | } else { |
---|
1337 | // do the strategy |
---|
1338 | |
---|
1339 | // NOTE: we don't actually implement the missing value penalty since |
---|
1340 | // we always return a full probability distribution. |
---|
1341 | if (m_defaultChild != null) { |
---|
1342 | preds = m_defaultChild.score(instance, classAtt); |
---|
1343 | } else { |
---|
1344 | throw new Exception("[TreeNode] missing value strategy is defaultChild, but " + |
---|
1345 | "no default child has been specified in node " + m_ID); |
---|
1346 | } |
---|
1347 | } |
---|
1348 | } |
---|
1349 | |
---|
1350 | return preds; |
---|
1351 | } |
---|
1352 | |
---|
1353 | /** |
---|
1354 | * Compute predictions and optionally invoke the last prediction |
---|
1355 | * missing value handling strategy. |
---|
1356 | * |
---|
1357 | * @param instance the incoming vector of attribute and derived field values. |
---|
1358 | * @param classAtt the class attribute. |
---|
1359 | * @return the predicted probability distribution. |
---|
1360 | * @throws Exception if something goes wrong. |
---|
1361 | */ |
---|
1362 | protected double[] missingValueStrategyLastPrediction(double[] instance, |
---|
1363 | Attribute classAtt) throws Exception { |
---|
1364 | |
---|
1365 | double[] preds = null; |
---|
1366 | boolean strategyInvoked = false; |
---|
1367 | |
---|
1368 | // look for a child whose predicate evaluates to TRUE |
---|
1369 | for (TreeNode c : m_childNodes) { |
---|
1370 | if (c.getPredicate().evaluate(instance) == Predicate.Eval.TRUE) { |
---|
1371 | preds = c.score(instance, classAtt); |
---|
1372 | break; |
---|
1373 | } else if (c.getPredicate().evaluate(instance) == Predicate.Eval.UNKNOWN) { |
---|
1374 | strategyInvoked = true; |
---|
1375 | } |
---|
1376 | } |
---|
1377 | |
---|
1378 | // no true child found |
---|
1379 | if (preds == null) { |
---|
1380 | preds = new double[classAtt.numValues()]; |
---|
1381 | if (!strategyInvoked) { |
---|
1382 | // no true child |
---|
1383 | doNoTrueChild(classAtt, preds); |
---|
1384 | } else { |
---|
1385 | // do the strategy |
---|
1386 | doLeaf(classAtt, preds); |
---|
1387 | } |
---|
1388 | } |
---|
1389 | |
---|
1390 | return preds; |
---|
1391 | } |
---|
1392 | |
---|
1393 | /** |
---|
1394 | * Compute predictions and optionally invoke the null prediction |
---|
1395 | * missing value handling strategy. |
---|
1396 | * |
---|
1397 | * @param instance the incoming vector of attribute and derived field values. |
---|
1398 | * @param classAtt the class attribute. |
---|
1399 | * @return the predicted probability distribution. |
---|
1400 | * @throws Exception if something goes wrong. |
---|
1401 | */ |
---|
1402 | protected double[] missingValueStrategyNullPrediction(double[] instance, |
---|
1403 | Attribute classAtt) throws Exception { |
---|
1404 | |
---|
1405 | double[] preds = null; |
---|
1406 | boolean strategyInvoked = false; |
---|
1407 | |
---|
1408 | // look for a child whose predicate evaluates to TRUE |
---|
1409 | for (TreeNode c : m_childNodes) { |
---|
1410 | if (c.getPredicate().evaluate(instance) == Predicate.Eval.TRUE) { |
---|
1411 | preds = c.score(instance, classAtt); |
---|
1412 | break; |
---|
1413 | } else if (c.getPredicate().evaluate(instance) == Predicate.Eval.UNKNOWN) { |
---|
1414 | strategyInvoked = true; |
---|
1415 | } |
---|
1416 | } |
---|
1417 | |
---|
1418 | // no true child found |
---|
1419 | if (preds == null) { |
---|
1420 | preds = new double[classAtt.numValues()]; |
---|
1421 | if (!strategyInvoked) { |
---|
1422 | doNoTrueChild(classAtt, preds); |
---|
1423 | } else { |
---|
1424 | // do the strategy |
---|
1425 | for (int i = 0; i < classAtt.numValues(); i++) { |
---|
1426 | preds[i] = Utils.missingValue(); |
---|
1427 | } |
---|
1428 | } |
---|
1429 | } |
---|
1430 | |
---|
1431 | return preds; |
---|
1432 | } |
---|
1433 | |
---|
1434 | /** |
---|
1435 | * Compute predictions and optionally invoke the "none" |
---|
1436 | * missing value handling strategy (invokes no true child). |
---|
1437 | * |
---|
1438 | * @param instance the incoming vector of attribute and derived field values. |
---|
1439 | * @param classAtt the class attribute. |
---|
1440 | * @return the predicted probability distribution. |
---|
1441 | * @throws Exception if something goes wrong. |
---|
1442 | */ |
---|
1443 | protected double[] missingValueStrategyNone(double[] instance, Attribute classAtt) |
---|
1444 | throws Exception { |
---|
1445 | |
---|
1446 | double[] preds = null; |
---|
1447 | |
---|
1448 | // look for a child whose predicate evaluates to TRUE |
---|
1449 | for (TreeNode c : m_childNodes) { |
---|
1450 | if (c.getPredicate().evaluate(instance) == Predicate.Eval.TRUE) { |
---|
1451 | preds = c.score(instance, classAtt); |
---|
1452 | break; |
---|
1453 | } |
---|
1454 | } |
---|
1455 | |
---|
1456 | if (preds == null) { |
---|
1457 | preds = new double[classAtt.numValues()]; |
---|
1458 | |
---|
1459 | // no true child strategy |
---|
1460 | doNoTrueChild(classAtt, preds); |
---|
1461 | } |
---|
1462 | |
---|
1463 | return preds; |
---|
1464 | } |
---|
1465 | } |
---|
1466 | |
---|
/**
 * The mining functions distinguished by the tree model: classification and
 * regression.
 */
enum MiningFunction {
  CLASSIFICATION,
  REGRESSION
}
---|
1474 | |
---|
/**
 * The missing value strategies known to the tree model. Each constant
 * carries the string that is compared against the model's
 * missingValueStrategy attribute and that toString() returns.
 */
enum MissingValueStrategy {
  LASTPREDICTION("lastPrediction"),
  NULLPREDICTION("nullPrediction"),
  DEFAULTCHILD("defaultChild"),
  WEIGHTEDCONFIDENCE("weightedConfidence"),
  AGGREGATENODES("aggregateNodes"),
  NONE("none");

  /** The textual form of this strategy */
  private final String m_text;

  MissingValueStrategy(String text) {
    this.m_text = text;
  }

  /**
   * Returns the textual form of this strategy.
   *
   * @return the textual form
   */
  @Override
  public String toString() {
    return this.m_text;
  }
}
---|
1493 | |
---|
/**
 * The strategies for the case that no child predicate evaluates to true.
 * Each constant carries the string that toString() returns.
 */
enum NoTrueChildStrategy {
  RETURNNULLPREDICTION("returnNullPrediction"),
  RETURNLASTPREDICTION("returnLastPrediction");

  /** The textual form of this strategy */
  private final String m_text;

  NoTrueChildStrategy(String text) {
    this.m_text = text;
  }

  /**
   * Returns the textual form of this strategy.
   *
   * @return the textual form
   */
  @Override
  public String toString() {
    return this.m_text;
  }
}
---|
1508 | |
---|
/**
 * The split characteristics known to the tree model. Each constant carries
 * the string that is compared against the model's splitCharacteristic
 * attribute and that toString() returns.
 */
enum SplitCharacteristic {
  BINARYSPLIT("binarySplit"),
  MULTISPLIT("multiSplit");

  /** The textual form of this split characteristic */
  private final String m_text;

  SplitCharacteristic(String text) {
    this.m_text = text;
  }

  /**
   * Returns the textual form of this split characteristic.
   *
   * @return the textual form
   */
  @Override
  public String toString() {
    return this.m_text;
  }
}
---|
1523 | |
---|
1524 | /** The mining function (classification unless the model declares regression) */ |
---|
1525 | protected MiningFunction m_functionType = MiningFunction.CLASSIFICATION; |
---|
1526 | |
---|
1527 | /** The missing value strategy (defaults to none) */ |
---|
1528 | protected MissingValueStrategy m_missingValueStrategy = MissingValueStrategy.NONE; |
---|
1529 | |
---|
1530 | /** |
---|
1531 | * The missing value penalty (a missing value if not defined). |
---|
1532 | * We don't actually make use of this since we always return |
---|
1533 | * full probability distributions. |
---|
1534 | */ |
---|
1535 | protected double m_missingValuePenalty = Utils.missingValue(); |
---|
1536 | |
---|
1537 | /** The no true child strategy to use (defaults to returnNullPrediction) */ |
---|
1538 | protected NoTrueChildStrategy m_noTrueChildStrategy = NoTrueChildStrategy.RETURNNULLPREDICTION; |
---|
1539 | |
---|
1540 | /** The splitting type (defaults to multiSplit) */ |
---|
1541 | protected SplitCharacteristic m_splitCharacteristic = SplitCharacteristic.MULTISPLIT; |
---|
1542 | |
---|
1543 | /** The root of the tree; scoring, printing and graphing all start here */ |
---|
1544 | protected TreeNode m_root; |
---|
1545 | |
---|
1546 | public TreeModel(Element model, Instances dataDictionary, |
---|
1547 | MiningSchema miningSchema) throws Exception { |
---|
1548 | |
---|
1549 | super(dataDictionary, miningSchema); |
---|
1550 | |
---|
1551 | if (!getPMMLVersion().equals("3.2")) { |
---|
1552 | // TODO: might have to throw an exception and only support 3.2 |
---|
1553 | } |
---|
1554 | |
---|
1555 | String fn = model.getAttribute("functionName"); |
---|
1556 | if (fn.equals("regression")) { |
---|
1557 | m_functionType = MiningFunction.REGRESSION; |
---|
1558 | } |
---|
1559 | |
---|
1560 | // get the missing value strategy (if any) |
---|
1561 | String missingVS = model.getAttribute("missingValueStrategy"); |
---|
1562 | if (missingVS != null && missingVS.length() > 0) { |
---|
1563 | for (MissingValueStrategy m : MissingValueStrategy.values()) { |
---|
1564 | if (m.toString().equals(missingVS)) { |
---|
1565 | m_missingValueStrategy = m; |
---|
1566 | break; |
---|
1567 | } |
---|
1568 | } |
---|
1569 | } |
---|
1570 | |
---|
1571 | // get the missing value penalty (if any) |
---|
1572 | String missingP = model.getAttribute("missingValuePenalty"); |
---|
1573 | if (missingP != null && missingP.length() > 0) { |
---|
1574 | // try to parse as a number |
---|
1575 | try { |
---|
1576 | m_missingValuePenalty = Double.parseDouble(missingP); |
---|
1577 | } catch (NumberFormatException ex) { |
---|
1578 | System.err.println("[TreeModel] WARNING: " + |
---|
1579 | "couldn't parse supplied missingValuePenalty as a number"); |
---|
1580 | } |
---|
1581 | } |
---|
1582 | |
---|
1583 | String splitC = model.getAttribute("splitCharacteristic"); |
---|
1584 | |
---|
1585 | if (splitC != null && splitC.length() > 0) { |
---|
1586 | for (SplitCharacteristic s : SplitCharacteristic.values()) { |
---|
1587 | if (s.toString().equals(splitC)) { |
---|
1588 | m_splitCharacteristic = s; |
---|
1589 | break; |
---|
1590 | } |
---|
1591 | } |
---|
1592 | } |
---|
1593 | |
---|
1594 | // find the root node of the tree |
---|
1595 | NodeList children = model.getChildNodes(); |
---|
1596 | for (int i = 0; i < children.getLength(); i++) { |
---|
1597 | Node child = children.item(i); |
---|
1598 | if (child.getNodeType() == Node.ELEMENT_NODE) { |
---|
1599 | String tagName = ((Element)child).getTagName(); |
---|
1600 | if (tagName.equals("Node")) { |
---|
1601 | m_root = new TreeNode((Element)child, miningSchema); |
---|
1602 | break; |
---|
1603 | } |
---|
1604 | } |
---|
1605 | } |
---|
1606 | } |
---|
1607 | |
---|
1608 | /** |
---|
1609 | * Classifies the given test instance. The instance has to belong to a |
---|
1610 | * dataset when it's being classified. |
---|
1611 | * |
---|
1612 | * @param inst the instance to be classified |
---|
1613 | * @return the predicted most likely class for the instance or |
---|
1614 | * Utils.missingValue() if no prediction is made |
---|
1615 | * @exception Exception if an error occurred during the prediction |
---|
1616 | */ |
---|
1617 | public double[] distributionForInstance(Instance inst) throws Exception { |
---|
1618 | if (!m_initialized) { |
---|
1619 | mapToMiningSchema(inst.dataset()); |
---|
1620 | } |
---|
1621 | double[] preds = null; |
---|
1622 | |
---|
1623 | if (m_miningSchema.getFieldsAsInstances().classAttribute().isNumeric()) { |
---|
1624 | preds = new double[1]; |
---|
1625 | } else { |
---|
1626 | preds = new double[m_miningSchema.getFieldsAsInstances().classAttribute().numValues()]; |
---|
1627 | } |
---|
1628 | |
---|
1629 | double[] incoming = m_fieldsMap.instanceToSchema(inst, m_miningSchema); |
---|
1630 | |
---|
1631 | preds = m_root.score(incoming, m_miningSchema.getFieldsAsInstances().classAttribute()); |
---|
1632 | |
---|
1633 | return preds; |
---|
1634 | } |
---|
1635 | |
---|
1636 | public String toString() { |
---|
1637 | StringBuffer temp = new StringBuffer(); |
---|
1638 | |
---|
1639 | temp.append("PMML version " + getPMMLVersion()); |
---|
1640 | if (!getCreatorApplication().equals("?")) { |
---|
1641 | temp.append("\nApplication: " + getCreatorApplication()); |
---|
1642 | } |
---|
1643 | temp.append("\nPMML Model: TreeModel"); |
---|
1644 | temp.append("\n\n"); |
---|
1645 | temp.append(m_miningSchema); |
---|
1646 | |
---|
1647 | temp.append("Split-type: " + m_splitCharacteristic + "\n"); |
---|
1648 | temp.append("No true child strategy: " + m_noTrueChildStrategy + "\n"); |
---|
1649 | temp.append("Missing value strategy: " + m_missingValueStrategy + "\n"); |
---|
1650 | |
---|
1651 | temp.append(m_root.toString()); |
---|
1652 | |
---|
1653 | return temp.toString(); |
---|
1654 | } |
---|
1655 | |
---|
1656 | public String graph() throws Exception { |
---|
1657 | StringBuffer text = new StringBuffer(); |
---|
1658 | text.append("digraph PMMTree {\n"); |
---|
1659 | |
---|
1660 | m_root.dumpGraph(text); |
---|
1661 | |
---|
1662 | text.append("}\n"); |
---|
1663 | |
---|
1664 | return text.toString(); |
---|
1665 | } |
---|
1666 | |
---|
1667 | public String getRevision() { |
---|
1668 | return RevisionUtils.extract("$Revision: 5987 $"); |
---|
1669 | } |
---|
1670 | |
---|
1671 | public int graphType() { |
---|
1672 | return Drawable.TREE; |
---|
1673 | } |
---|
1674 | } |
---|