source: src/main/java/weka/classifiers/pmml/consumer/TreeModel.java @ 17

Last change on this file since 17 was 4, checked in by gnappo, 14 years ago

Import di weka.

File size: 51.7 KB
Line 
1/*
2 *    This program is free software; you can redistribute it and/or modify
3 *    it under the terms of the GNU General Public License as published by
4 *    the Free Software Foundation; either version 2 of the License, or
5 *    (at your option) any later version.
6 *
7 *    This program is distributed in the hope that it will be useful,
8 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
9 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
10 *    GNU General Public License for more details.
11 *
12 *    You should have received a copy of the GNU General Public License
13 *    along with this program; if not, write to the Free Software
14 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
15 */
16
17/*
18 *    TreeModel.java
19 *    Copyright (C) 2009 University of Waikato, Hamilton, New Zealand
20 *
21 */
22
23package weka.classifiers.pmml.consumer;
24
25import java.io.Serializable;
26import java.util.ArrayList;
27
28import org.w3c.dom.Element;
29import org.w3c.dom.Node;
30import org.w3c.dom.NodeList;
31
32import weka.core.Attribute;
33import weka.core.Drawable;
34import weka.core.Instance;
35import weka.core.Instances;
36import weka.core.RevisionUtils;
37import weka.core.Utils;
38import weka.core.pmml.*;
39
40/**
41 * Class implementing import of PMML TreeModel. Can be used as a Weka
42 * classifier for prediction (buildClassifier() raises and Exception).
43 *
44 * @author Mark Hall (mhall{[at]}pentaho{[dot]}com)
45 * @version $Revision: 5987 $;
46 */
47public class TreeModel extends PMMLClassifier implements Drawable {
48 
49  /**
50   * For serialization
51   */
52  private static final long serialVersionUID = -2065158088298753129L;
53
54  /**
55   * Inner class representing the ScoreDistribution element
56   */
57  static class ScoreDistribution implements Serializable {
58   
59    /**
60     * For serialization
61     */
62    private static final long serialVersionUID = -123506262094299933L;
63
64    /** The class label for this distribution element */
65    private String m_classLabel;
66   
67    /** The index of the class label */
68    private int m_classLabelIndex = -1;
69   
70    /** The count for this label */
71    private double m_recordCount;
72   
73    /** The optional confidence value */
74    private double m_confidence = Utils.missingValue();
75   
76    /**
77     * Construct a ScoreDistribution entry
78     *
79     * @param scoreE the node containing the distribution
80     * @param miningSchema the mining schema
81     * @param baseCount the number of records at the node that owns this
82     * distribution entry
83     * @throws Exception if something goes wrong
84     */
85    protected ScoreDistribution(Element scoreE, MiningSchema miningSchema, double baseCount) 
86      throws Exception {
87      // get the label
88      m_classLabel = scoreE.getAttribute("value");
89      Attribute classAtt = miningSchema.getFieldsAsInstances().classAttribute();
90      if (classAtt == null || classAtt.indexOfValue(m_classLabel) < 0) {
91        throw new Exception("[ScoreDistribution] class attribute not set or class value " +
92            m_classLabel + " not found!");
93      }
94     
95      m_classLabelIndex = classAtt.indexOfValue(m_classLabel);
96     
97      // get the frequency
98      String recordC = scoreE.getAttribute("recordCount");
99      m_recordCount = Double.parseDouble(recordC);
100     
101      // get the optional confidence
102      String confidence = scoreE.getAttribute("confidence");
103      if (confidence != null && confidence.length() > 0) {
104        m_confidence = Double.parseDouble(confidence);       
105      } else if (!Utils.isMissingValue(baseCount) && baseCount > 0) {
106        m_confidence = m_recordCount / baseCount;
107      }
108    }
109   
110    /**
111     * Backfit confidence value (does nothing if the confidence
112     * value is already set).
113     *
114     * @param baseCount the total number of records (supplied either
115     * explicitly from the node that owns this distribution entry
116     * or most likely computed from summing the recordCounts of all
117     * the distribution entries in the distribution that owns this
118     * entry).
119     */
120    void deriveConfidenceValue(double baseCount) {
121      if (Utils.isMissingValue(m_confidence) && 
122          !Utils.isMissingValue(baseCount) && 
123          baseCount > 0) {
124        m_confidence = m_recordCount / baseCount;
125      }
126    }
127   
128    String getClassLabel() {
129      return m_classLabel;
130    }
131   
132    int getClassLabelIndex() {
133      return m_classLabelIndex;
134    }
135   
136    double getRecordCount() {
137      return m_recordCount;
138    }
139   
140    double getConfidence() {
141      return m_confidence;
142    }
143   
144    public String toString() {
145      return m_classLabel + ": " + m_recordCount
146        + " (" + Utils.doubleToString(m_confidence, 2) + ") ";
147    }
148  }
149 
150  /**
151   * Base class for Predicates
152   */
153  static abstract class Predicate implements Serializable {
154   
155    /**
156     * For serialization
157     */
158    private static final long serialVersionUID = 1035344165452733887L;
159
160    enum Eval {
161      TRUE,
162      FALSE,
163      UNKNOWN;
164    }
165   
166    /**
167     * Evaluate this predicate.
168     *
169     * @param input the input vector of attribute and derived field values.
170     *
171     * @return the evaluation status of this predicate.
172     */
173    abstract Eval evaluate(double[] input);
174   
175    protected String toString(int level, boolean cr) {
176      return toString(level);
177    }
178   
179    protected String toString(int level) {
180      StringBuffer text = new StringBuffer();
181      for (int j = 0; j < level; j++) {
182        text.append("|   ");
183      }
184     
185      return text.append(toString()).toString();
186    }
187   
188    static Eval booleanToEval(boolean missing, boolean result) {
189      if (missing) {
190        return Eval.UNKNOWN;
191      } else if (result) {
192        return Eval.TRUE;
193      } else {
194        return Eval.FALSE;
195      }
196    }
197   
198    /**
199     * Factory method to return the appropriate predicate for
200     * a given node in the tree.
201     *
202     * @param nodeE the XML node encapsulating the tree node.
203     * @param miningSchema the mining schema in use
204     * @return a Predicate
205     * @throws Exception of something goes wrong.
206     */
207    static Predicate getPredicate(Element nodeE, 
208        MiningSchema miningSchema) throws Exception {
209     
210      Predicate result = null;
211      NodeList children = nodeE.getChildNodes();
212      for (int i = 0; i < children.getLength(); i++) {
213        Node child = children.item(i);
214        if (child.getNodeType() == Node.ELEMENT_NODE) {
215          String tagName = ((Element)child).getTagName();
216          if (tagName.equals("True")) {
217            result = new True();
218            break;
219          } else if (tagName.equals("False")) {
220            result = new False();
221            break;
222          } else if (tagName.equals("SimplePredicate")) {
223            result = new SimplePredicate((Element)child, miningSchema);
224            break;
225          } else if (tagName.equals("CompoundPredicate")) {
226            result = new CompoundPredicate((Element)child, miningSchema);
227            break;
228          } else if (tagName.equals("SimpleSetPredicate")) {
229           result = new SimpleSetPredicate((Element)child, miningSchema);
230           break;
231          }
232        }
233      }
234     
235      if (result == null) {
236        throw new Exception("[Predicate] unknown or missing predicate type in node");
237      }
238     
239      return result;
240    }
241  }
242 
243  /**
244   * Simple True Predicate
245   */
246  static class True extends Predicate {
247   
248    /**
249     * For serialization
250     */
251    private static final long serialVersionUID = 1817942234610531627L;
252
253    public Predicate.Eval evaluate(double[] input) {
254      return Predicate.Eval.TRUE;
255    }
256   
257    public String toString() {
258      return "True: ";
259    }
260  }
261 
262  /**
263   * Simple False Predicate
264   */
265  static class False extends Predicate {
266   
267    /**
268     * For serialization
269     */
270    private static final long serialVersionUID = -3647261386442860365L;
271
272    public Predicate.Eval evaluate(double[] input) {
273      return Predicate.Eval.FALSE;
274    }
275   
276    public String toString() {
277      return "False: ";
278    }
279  }
280 
281  /**
282   * Class representing the SimplePredicate
283   */
284  static class SimplePredicate extends Predicate {
285   
286    /**
287     * For serialization
288     */
289    private static final long serialVersionUID = -6156684285069327400L;
290
291    enum Operator {
292      EQUAL("equal") {
293        Predicate.Eval evaluate(double[] input, double value, int fieldIndex) {
294          return Predicate.booleanToEval(Utils.isMissingValue(input[fieldIndex]), 
295              weka.core.Utils.eq(input[fieldIndex], value));
296        }
297       
298        String shortName() {
299          return "==";
300        }
301      },
302      NOTEQUAL("notEqual")
303       {
304        Predicate.Eval evaluate(double[] input, double value, int fieldIndex) {
305          return Predicate.booleanToEval(Utils.isMissingValue(input[fieldIndex]), 
306              (input[fieldIndex] != value));
307        }
308       
309        String shortName() {
310          return "!=";
311        }
312      },
313      LESSTHAN("lessThan")  {
314        Predicate.Eval evaluate(double[] input, double value, int fieldIndex) {
315          return Predicate.booleanToEval(Utils.isMissingValue(input[fieldIndex]),
316              (input[fieldIndex] < value));
317        }
318       
319        String shortName() {
320          return "<";
321        }
322      },
323      LESSOREQUAL("lessOrEqual") {
324        Predicate.Eval evaluate(double[] input, double value, int fieldIndex) {
325          return Predicate.booleanToEval(Utils.isMissingValue(input[fieldIndex]),
326              (input[fieldIndex] <= value));
327        }
328       
329        String shortName() {
330          return "<=";
331        }
332      },
333      GREATERTHAN("greaterThan") {
334        Predicate.Eval evaluate(double[] input, double value, int fieldIndex) {
335          return Predicate.booleanToEval(Utils.isMissingValue(input[fieldIndex]),
336              (input[fieldIndex] > value));
337        }
338       
339        String shortName() {
340          return ">";
341        }
342      },
343      GREATEROREQUAL("greaterOrEqual") {
344        Predicate.Eval evaluate(double[] input, double value, int fieldIndex) {
345          return Predicate.booleanToEval(Utils.isMissingValue(input[fieldIndex]),
346              (input[fieldIndex] >= value));
347        }
348       
349        String shortName() {
350          return ">=";
351        }
352      },
353      ISMISSING("isMissing") {
354        Predicate.Eval evaluate(double[] input, double value, int fieldIndex) {
355          return Predicate.booleanToEval(false,
356              Utils.isMissingValue(input[fieldIndex]));
357        }
358       
359        String shortName() {
360          return toString();
361        }
362      },
363      ISNOTMISSING("isNotMissing") {
364        Predicate.Eval evaluate(double[] input, double value, int fieldIndex) {
365          return Predicate.booleanToEval(false, !Utils.isMissingValue(input[fieldIndex]));
366        }
367       
368        String shortName() {
369          return toString();
370        }
371      };
372     
373      abstract Predicate.Eval evaluate(double[] input, double value, int fieldIndex);
374      abstract String shortName();
375     
376      private final String m_stringVal;
377     
378      Operator(String name) {
379        m_stringVal = name;
380      }
381           
382      public String toString() {
383        return m_stringVal;
384      }
385    }
386   
387    /** the field that we are comparing against */
388    int m_fieldIndex = -1;
389   
390    /** the name of the field */
391    String m_fieldName;
392   
393    /** true if the field is nominal */
394    boolean m_isNominal;
395   
396    /** the value as a string (if nominal) */
397    String m_nominalValue;
398   
399    /** the value to compare against (if nominal it holds the index of the value) */
400    double m_value;
401   
402    /** the operator to use */
403    Operator m_operator;
404       
405    public SimplePredicate(Element simpleP, 
406        MiningSchema miningSchema) throws Exception {
407      Instances totalStructure = miningSchema.getFieldsAsInstances();
408     
409      // get the field name and set up the index
410      String fieldS = simpleP.getAttribute("field");
411      Attribute att = totalStructure.attribute(fieldS);
412      if (att == null) {
413        throw new Exception("[SimplePredicate] unable to find field " + fieldS
414            + " in the incoming instance structure!");
415      }
416     
417      // find the index
418      int index = -1;
419      for (int i = 0; i < totalStructure.numAttributes(); i++) {
420        if (totalStructure.attribute(i).name().equals(fieldS)) {
421          index = i;
422          m_fieldName = totalStructure.attribute(i).name();
423          break;
424        }
425      }
426      m_fieldIndex = index;
427      if (att.isNominal()) {
428        m_isNominal = true;
429      }
430     
431      // get the operator
432      String oppS = simpleP.getAttribute("operator");
433      for (Operator o : Operator.values()) {
434        if (o.toString().equals(oppS)) {
435          m_operator = o;
436          break;
437        }
438      }
439     
440      if (m_operator != Operator.ISMISSING && m_operator != Operator.ISNOTMISSING) {
441        String valueS = simpleP.getAttribute("value");
442        if (att.isNumeric()) {
443          m_value = Double.parseDouble(valueS);
444        } else {
445          m_nominalValue = valueS;
446          m_value = att.indexOfValue(valueS);
447          if (m_value < 0) {
448            throw new Exception("[SimplePredicate] can't find value " + valueS + " in nominal " +
449                "attribute " + att.name());
450          }
451        }
452      }
453    }
454   
455    public Predicate.Eval evaluate(double[] input) {
456      return m_operator.evaluate(input, m_value, m_fieldIndex);
457    }
458       
459    public String toString() {
460      StringBuffer temp = new StringBuffer();
461     
462      temp.append(m_fieldName + " " + m_operator.shortName());
463      if (m_operator != Operator.ISMISSING && m_operator != Operator.ISNOTMISSING) {
464        temp.append(" " + ((m_isNominal) ? m_nominalValue : "" + m_value));
465      }
466     
467      return temp.toString();
468    }
469  }
470 
471  /**
472   * Class representing the CompoundPredicate
473   */
474  static class CompoundPredicate extends Predicate {
475   
476    /**
477     * For serialization
478     */
479    private static final long serialVersionUID = -3332091529764559077L;
480
481    enum BooleanOperator {
482      OR("or") {
483        Predicate.Eval evaluate(ArrayList<Predicate> constituents, double[] input) {
484          Predicate.Eval currentStatus = Predicate.Eval.FALSE;
485          for (Predicate p : constituents) {
486            Predicate.Eval temp = p.evaluate(input);
487            if (temp == Predicate.Eval.TRUE) {
488              currentStatus = temp;
489              break;
490            } else if (temp == Predicate.Eval.UNKNOWN) {
491              currentStatus = temp;
492            }           
493          }
494          return currentStatus;
495        }
496      },
497      AND("and") {
498        Predicate.Eval evaluate(ArrayList<Predicate> constituents, double[] input) {
499          Predicate.Eval currentStatus = Predicate.Eval.TRUE;
500          for (Predicate p : constituents) {
501            Predicate.Eval temp = p.evaluate(input);
502            if (temp == Predicate.Eval.FALSE) {
503              currentStatus = temp;
504              break;
505            } else if (temp == Predicate.Eval.UNKNOWN) {
506              currentStatus = temp;
507            }
508          }         
509          return currentStatus;
510        }
511      },
512      XOR("xor") {
513        Predicate.Eval evaluate(ArrayList<Predicate> constituents, double[] input) {
514          Predicate.Eval currentStatus = constituents.get(0).evaluate(input);
515          if (currentStatus != Predicate.Eval.UNKNOWN) {
516            for (int i = 1; i < constituents.size(); i++) {
517              Predicate.Eval temp = constituents.get(i).evaluate(input);
518              if (temp == Predicate.Eval.UNKNOWN) {
519                currentStatus = temp;
520                break;
521              } else {
522                if (currentStatus != temp) {
523                  currentStatus = Predicate.Eval.TRUE;
524                } else {
525                  currentStatus = Predicate.Eval.FALSE;
526                }
527              }
528            }
529          }
530          return currentStatus;
531        }
532      },
533      SURROGATE("surrogate") {
534        Predicate.Eval evaluate(ArrayList<Predicate> constituents, double[] input) {
535          Predicate.Eval currentStatus = constituents.get(0).evaluate(input);
536         
537          int i = 1;
538          while (currentStatus == Predicate.Eval.UNKNOWN) {
539            currentStatus = constituents.get(i).evaluate(input);           
540          }
541         
542          // return false if all our surrogates evaluate to unknown.
543          if (currentStatus == Predicate.Eval.UNKNOWN) {
544            currentStatus = Predicate.Eval.FALSE;
545          }
546         
547          return currentStatus;
548        }
549      };
550     
551      abstract Predicate.Eval evaluate(ArrayList<Predicate> constituents, double[] input);
552     
553      private final String m_stringVal;
554     
555      BooleanOperator(String name) {
556        m_stringVal = name;
557      }
558     
559      public String toString() {
560        return m_stringVal;
561      }
562    }
563   
564    /** the constituent Predicates */
565    ArrayList<Predicate> m_components = new ArrayList<Predicate>();
566   
567    /** the boolean operator */
568    BooleanOperator m_booleanOperator;
569       
570    public CompoundPredicate(Element compoundP, 
571        MiningSchema miningSchema) throws Exception {
572//      Instances totalStructure = miningSchema.getFieldsAsInstances();
573     
574      String booleanOpp = compoundP.getAttribute("booleanOperator");
575      for (BooleanOperator b : BooleanOperator.values()) {
576        if (b.toString().equals(booleanOpp)) {
577          m_booleanOperator = b;
578        }
579      }
580     
581      // now get all the encapsulated operators
582      NodeList children = compoundP.getChildNodes();
583      for (int i = 0; i < children.getLength(); i++) {
584        Node child = children.item(i);
585        if (child.getNodeType() == Node.ELEMENT_NODE) {
586          String tagName = ((Element)child).getTagName();
587          if (tagName.equals("True")) {
588            m_components.add(new True());
589          } else if (tagName.equals("False")) {
590            m_components.add(new False());
591          } else if (tagName.equals("SimplePredicate")) {
592            m_components.add(new SimplePredicate((Element)child, miningSchema));
593          } else if (tagName.equals("CompoundPredicate")) {
594            m_components.add(new CompoundPredicate((Element)child, miningSchema));
595          } else {
596            m_components.add(new SimpleSetPredicate((Element)child, miningSchema));
597          }
598        }
599      }
600    }
601   
602    public Predicate.Eval evaluate(double[] input) {
603      return m_booleanOperator.evaluate(m_components, input);
604    }
605   
606    public String toString() {
607      return toString(0, false);
608    }
609   
610    public String toString(int level, boolean cr) {
611      StringBuffer text = new StringBuffer();
612      for (int j = 0; j < level; j++) {
613        text.append("|   ");
614      }
615     
616      text.append("Compound [" + m_booleanOperator.toString() + "]");
617      if (cr) {
618        text.append("\\n");
619      } else {
620        text.append("\n");
621      }
622      for (int i = 0; i < m_components.size(); i++) {
623        text.append(m_components.get(i).toString(level, cr).replace(":", ""));
624        if (i != m_components.size()-1) {
625          if (cr) {
626            text.append("\\n");
627          } else {
628            text.append("\n");
629          }
630        }
631      }
632     
633      return text.toString();
634    }
635  }
636 
637  /**
638   * Class representing the SimpleSetPredicate
639   */
640  static class SimpleSetPredicate extends Predicate {
641   
642    /**
643     * For serialization
644     */
645    private static final long serialVersionUID = -2711995401345708486L;
646
647    enum BooleanOperator {
648        IS_IN("isIn") {
649          Predicate.Eval evaluate(double[] input, int fieldIndex, 
650              Array set, Attribute nominalLookup) {           
651            if (set.getType() == Array.ArrayType.STRING) {
652              String value = "";
653              if (!Utils.isMissingValue(input[fieldIndex])) {
654                value = nominalLookup.value((int)input[fieldIndex]);
655              }
656              return Predicate.booleanToEval(Utils.isMissingValue(input[fieldIndex]), 
657                  set.contains(value));
658            } else if (set.getType() == Array.ArrayType.NUM ||
659                set.getType() == Array.ArrayType.REAL) {
660              return Predicate.booleanToEval(Utils.isMissingValue(input[fieldIndex]), 
661                set.contains(input[fieldIndex]));
662            }
663            return Predicate.booleanToEval(Utils.isMissingValue(input[fieldIndex]), 
664                set.contains((int)input[fieldIndex]));
665          }
666        },
667        IS_NOT_IN("isNotIn") {
668          Predicate.Eval evaluate(double[] input, int fieldIndex,
669              Array set, Attribute nominalLookup) {
670            Predicate.Eval result = IS_IN.evaluate(input, fieldIndex, set, nominalLookup);
671            if (result == Predicate.Eval.FALSE) {
672              result = Predicate.Eval.TRUE;
673            } else if (result == Predicate.Eval.TRUE) {
674              result = Predicate.Eval.FALSE;
675            }
676           
677            return result;
678          }
679        };
680       
681        abstract Predicate.Eval evaluate(double[] input, int fieldIndex, 
682            Array set, Attribute nominalLookup);
683       
684        private final String m_stringVal;
685       
686        BooleanOperator(String name) {
687          m_stringVal = name;
688        }
689       
690        public String toString() {
691          return m_stringVal;
692        }
693    }
694   
695    /** the field to reference */
696    int m_fieldIndex = -1;
697   
698    /** the name of the field */
699    String m_fieldName;
700   
701    /** is the referenced field nominal? */
702    boolean m_isNominal = false;
703   
704    /** the attribute to lookup nominal values from */
705    Attribute m_nominalLookup;
706   
707    /** the boolean operator */
708    BooleanOperator m_operator = BooleanOperator.IS_IN;
709   
710    /** the array holding the set of values */
711    Array m_set;
712       
713    public SimpleSetPredicate(Element setP, 
714        MiningSchema miningSchema) throws Exception {
715      Instances totalStructure = miningSchema.getFieldsAsInstances();
716     
717      // get the field name and set up the index
718      String fieldS = setP.getAttribute("field");
719      Attribute att = totalStructure.attribute(fieldS);
720      if (att == null) {
721        throw new Exception("[SimplePredicate] unable to find field " + fieldS
722            + " in the incoming instance structure!");
723      }
724     
725      // find the index
726      int index = -1;
727      for (int i = 0; i < totalStructure.numAttributes(); i++) {
728        if (totalStructure.attribute(i).name().equals(fieldS)) {
729          index = i;
730          m_fieldName = totalStructure.attribute(i).name();
731          break;
732        }
733      }
734      m_fieldIndex = index;
735      if (att.isNominal()) {
736        m_isNominal = true;
737        m_nominalLookup = att;
738      }
739 
740      // need to scan the children looking for an array type
741      NodeList children = setP.getChildNodes();
742      for (int i = 0; i < children.getLength(); i++) {
743        Node child = children.item(i);
744        if (child.getNodeType() == Node.ELEMENT_NODE) {
745          if (Array.isArray((Element)child)) {
746            // found the array
747            m_set = Array.create((Element)child);
748            break;
749          }
750        }
751      }
752
753      if (m_set == null) {
754        throw new Exception("[SimpleSetPredictate] couldn't find an " +
755        "array containing the set values!");
756      }
757     
758      // check array type against field type
759      if (m_set.getType() == Array.ArrayType.STRING &&
760          !m_isNominal) {
761        throw new Exception("[SimpleSetPredicate] referenced field " +
762            totalStructure.attribute(m_fieldIndex).name() + 
763            " is numeric but array type is string!");
764      } else if (m_set.getType() != Array.ArrayType.STRING && 
765          m_isNominal) {
766        throw new Exception("[SimpleSetPredicate] referenced field " +
767            totalStructure.attribute(m_fieldIndex).name() +
768            " is nominal but array type is numeric!");
769      }     
770    }
771   
772    public Predicate.Eval evaluate(double[] input) {
773      return m_operator.evaluate(input, m_fieldIndex, m_set, m_nominalLookup);
774    }
775   
776    public String toString() {
777      StringBuffer temp = new StringBuffer();
778     
779      temp.append(m_fieldName + " " + m_operator.toString() + " ");
780      temp.append(m_set.toString());
781     
782      return temp.toString();
783    }
784  }
785 
786  /**
787   * Class for handling a Node in the tree
788   */
789  class TreeNode implements Serializable {
790    // TODO: perhaps implement a class called Statistics that contains Partitions?
791       
792    /**
793     * For serialization
794     */
795    private static final long serialVersionUID = 3011062274167063699L;
796
797    /** ID for this node */
798    private String m_ID = "" + this.hashCode();
799   
800    /** The score as a string */
801    private String m_scoreString;
802   
803    /** The index of this predicted value (if class is nominal) */
804    private int m_scoreIndex = -1;
805   
806    /** The score as a number (if target is numeric) */
807    private double m_scoreNumeric = Utils.missingValue();
808   
809    /** The record count at this node (if defined) */
810    private double m_recordCount = Utils.missingValue();
811   
812    /** The ID of the default child (if applicable) */
813    private String m_defaultChildID;
814   
815    /** Holds the node of the default child (if defined) */
816    private TreeNode m_defaultChild;
817   
818    /** The distribution for labels (classification) */
819    private ArrayList<ScoreDistribution> m_scoreDistributions = 
820      new ArrayList<ScoreDistribution>();
821   
822    /** The predicate for this node */
823    private Predicate m_predicate;
824   
825    /** The children of this node */
826    private ArrayList<TreeNode> m_childNodes = new ArrayList<TreeNode>();
827   
828   
829    protected TreeNode(Element nodeE, MiningSchema miningSchema) throws Exception {
830      Attribute classAtt = miningSchema.getFieldsAsInstances().classAttribute();
831     
832      // get the ID
833      String id = nodeE.getAttribute("id");
834      if (id != null && id.length() > 0) {
835        m_ID = id;
836      }
837     
838      // get the score for this node
839      String scoreS = nodeE.getAttribute("score");
840      if (scoreS != null && scoreS.length() > 0) {
841        m_scoreString = scoreS;
842       
843        // try to parse as a number in case we
844        // are part of a regression tree
845        if (classAtt.isNumeric()) {
846          try {
847            m_scoreNumeric = Double.parseDouble(scoreS);
848          } catch (NumberFormatException ex) {
849            throw new Exception("[TreeNode] class is numeric but unable to parse score " 
850                + m_scoreString + " as a number!");
851          }
852        } else {
853          // store the index of this class value
854          m_scoreIndex = classAtt.indexOfValue(m_scoreString);
855         
856          if (m_scoreIndex < 0) {
857            throw new Exception("[TreeNode] can't find match for predicted value " 
858                + m_scoreString + " in class attribute!");
859          }
860        }
861      }
862     
863      // get the record count if defined
864      String recordC = nodeE.getAttribute("recordCount");
865      if (recordC != null && recordC.length() > 0) {
866        m_recordCount = Double.parseDouble(recordC);
867      }
868     
869      // get the default child (if applicable)
870      String defaultC = nodeE.getAttribute("defaultChild");
871      if (defaultC != null && defaultC.length() > 0) {
872        m_defaultChildID = defaultC;
873      }
874     
875      //TODO: Embedded model (once we support model composition)
876     
877      // Now get the ScoreDistributions (if any and mining function
878      // is classification) at this level
879      if (m_functionType == MiningFunction.CLASSIFICATION) {
880        getScoreDistributions(nodeE, miningSchema);
881      }
882     
883      // Now get the Predicate
884      m_predicate = Predicate.getPredicate(nodeE, miningSchema);
885     
886      // Now get the child Node(s)
887      getChildNodes(nodeE, miningSchema);
888     
889      // If we have a default child specified, find it now
890      if (m_defaultChildID != null) {
891        for (TreeNode t : m_childNodes) {
892          if (t.getID().equals(m_defaultChildID)) {
893            m_defaultChild = t;
894            break;
895          }
896        }
897      }
898    }
899   
900    private void getChildNodes(Element nodeE, MiningSchema miningSchema) throws Exception {
901      NodeList children = nodeE.getChildNodes();
902     
903      for (int i = 0; i < children.getLength(); i++) {
904        Node child = children.item(i);
905        if (child.getNodeType() == Node.ELEMENT_NODE) {
906          String tagName = ((Element)child).getTagName();
907          if (tagName.equals("Node")) {
908            TreeNode tempN = new TreeNode((Element)child, miningSchema);
909            m_childNodes.add(tempN);
910          }
911        }
912      }
913    }
914   
915    private void getScoreDistributions(Element nodeE, 
916        MiningSchema miningSchema) throws Exception {
917     
918      NodeList scoreChildren = nodeE.getChildNodes();
919      for (int i = 0; i < scoreChildren.getLength(); i++) {
920        Node child = scoreChildren.item(i);
921        if (child.getNodeType() == Node.ELEMENT_NODE) {
922          String tagName = ((Element)child).getTagName();
923          if (tagName.equals("ScoreDistribution")) {
924            ScoreDistribution newDist = new ScoreDistribution((Element)child, 
925                miningSchema, m_recordCount);
926            m_scoreDistributions.add(newDist);
927          }
928        }
929      }
930     
931      // backfit the confidence values
932      if (Utils.isMissingValue(m_recordCount)) {
933        double baseCount = 0;
934        for (ScoreDistribution s : m_scoreDistributions) {
935          baseCount += s.getRecordCount();
936        }
937       
938        for (ScoreDistribution s : m_scoreDistributions) {
939          s.deriveConfidenceValue(baseCount);
940        }
941      }
942    }
943       
944    /**
945     * Get the score value as a string.
946     *
947     * @return the score value as a String.
948     */
949    protected String getScore() {
950      return m_scoreString;
951    }
952   
953    /**
954     * Get the score value as a number (regression trees only).
955     *
956     * @return the score as a number
957     */
958    protected double getScoreNumeric() {
959      return m_scoreNumeric;
960    }
961   
962    /**
963     * Get the ID of this node.
964     *
965     * @return the ID of this node.
966     */
967    protected String getID() {
968      return m_ID;
969    }
970   
971    /**
972     * Get the Predicate at this node.
973     *
974     * @return the predicate at this node.
975     */
976    protected Predicate getPredicate() {
977      return m_predicate;
978    }
979   
980    /**
981     * Get the record count at this node.
982     *
983     * @return the record count at this node.
984     */
985    protected double getRecordCount() {
986      return m_recordCount;
987    }
988   
989    protected void dumpGraph(StringBuffer text) throws Exception {
990      text.append("N" + m_ID + " ");
991      if (m_scoreString != null) {
992        text.append("[label=\"score=" + m_scoreString);
993      }
994     
995      if (m_scoreDistributions.size() > 0 && m_childNodes.size() == 0) {
996        text.append("\\n");
997        for (ScoreDistribution s : m_scoreDistributions) {
998          text.append(s + "\\n");
999        }
1000      }
1001     
1002      text.append("\"");
1003     
1004      if (m_childNodes.size() == 0) {
1005        text.append(" shape=box style=filled");
1006       
1007      }
1008     
1009      text.append("]\n");
1010     
1011      for (TreeNode c : m_childNodes) {
1012        text.append("N" + m_ID +"->" + "N" + c.getID());
1013        text.append(" [label=\"" + c.getPredicate().toString(0, true));
1014        text.append("\"]\n");
1015        c.dumpGraph(text);
1016      }
1017    }
1018   
1019    public String toString() {
1020      StringBuffer text = new StringBuffer();
1021     
1022      // print out the root
1023      dumpTree(0, text);
1024
1025      return text.toString();
1026    }
1027   
1028    protected void dumpTree(int level, StringBuffer text) {
1029      if (m_childNodes.size() > 0) {
1030
1031        for (int i = 0; i < m_childNodes.size(); i++) {
1032          text.append("\n");
1033         
1034/*          for (int j = 0; j < level; j++) {
1035            text.append("|   ");
1036          } */
1037         
1038          // output the predicate for this child node
1039          TreeNode child = m_childNodes.get(i);
1040          text.append(child.getPredicate().toString(level, false));
1041         
1042          // process recursively
1043          child.dumpTree(level + 1 , text);         
1044        }
1045      } else {
1046        // leaf
1047        text.append(": ");
1048        if (!Utils.isMissingValue(m_scoreNumeric)) {
1049          text.append(m_scoreNumeric);
1050        } else {
1051          text.append(m_scoreString + " ");
1052          if (m_scoreDistributions.size() > 0) {
1053            text.append("[");
1054            for (ScoreDistribution s : m_scoreDistributions) {
1055              text.append(s);
1056            }
1057            text.append("]");
1058          } else {
1059            text.append(m_scoreString);
1060          }
1061        }
1062      }
1063    }
1064   
1065    /**
1066     * Score an incoming instance. Invokes a missing value handling strategy.
1067     *
1068     * @param instance a vector of incoming attribute and derived field values.
1069     * @param classAtt the class attribute
1070     * @return a predicted probability distribution.
1071     * @throws Exception if something goes wrong.
1072     */
1073    protected double[] score(double[] instance, Attribute classAtt) throws Exception {
1074      double[] preds = null;
1075     
1076      if (classAtt.isNumeric()) {
1077        preds = new double[1];
1078      } else {
1079        preds = new double[classAtt.numValues()];
1080      }
1081     
1082      // leaf?
1083      if (m_childNodes.size() == 0) {
1084        doLeaf(classAtt, preds);
1085      } else {
1086        // process the children
1087        switch (TreeModel.this.m_missingValueStrategy) {
1088        case NONE:
1089          preds = missingValueStrategyNone(instance, classAtt);
1090          break;
1091        case LASTPREDICTION:
1092          preds = missingValueStrategyLastPrediction(instance, classAtt);
1093          break;
1094        case DEFAULTCHILD:
1095          preds = missingValueStrategyDefaultChild(instance, classAtt);
1096          break;
1097        default:
1098          throw new Exception("[TreeModel] not implemented!");
1099        }
1100      }
1101     
1102      return preds;
1103    }
1104   
1105    /**
1106     * Compute the predictions for a leaf.
1107     *
1108     * @param classAtt the class attribute
1109     * @param preds an array to hold the predicted probabilities.
1110     * @throws Exception if something goes wrong.
1111     */
1112    protected void doLeaf(Attribute classAtt, double[] preds) throws Exception {
1113      if (classAtt.isNumeric()) {
1114        preds[0] = m_scoreNumeric;
1115      } else {
1116        if (m_scoreDistributions.size() == 0) {
1117          preds[m_scoreIndex] = 1.0;
1118        } else {
1119          // collect confidences from the score distributions
1120          for (ScoreDistribution s : m_scoreDistributions) {
1121            preds[s.getClassLabelIndex()] = s.getConfidence();
1122          }
1123        }
1124      }
1125    }
1126   
1127    /**
1128     * Evaluate on the basis of the no true child strategy.
1129     *
1130     * @param classAtt the class attribute.
1131     * @param preds an array to hold the predicted probabilities.
1132     * @throws Exception if something goes wrong.
1133     */
1134    protected void doNoTrueChild(Attribute classAtt, double[] preds) 
1135      throws Exception {
1136      if (TreeModel.this.m_noTrueChildStrategy == 
1137        NoTrueChildStrategy.RETURNNULLPREDICTION) {
1138        for (int i = 0; i < classAtt.numValues(); i++) {
1139          preds[i] = Utils.missingValue();
1140        }
1141      } else {
1142        // return the predictions at this node
1143        doLeaf(classAtt, preds);
1144      }
1145    }
1146   
1147    /**
1148     * Compute predictions and optionally invoke the weighted confidence
1149     * missing value handling strategy.
1150     *
1151     * @param instance the incoming vector of attribute and derived field values.
1152     * @param classAtt the class attribute.
1153     * @return the predicted probability distribution.
1154     * @throws Exception if something goes wrong.
1155     */
1156    protected double[] missingValueStrategyWeightedConfidence(double[] instance,
1157        Attribute classAtt) throws Exception {
1158     
1159      if (classAtt.isNumeric()) {
1160        throw new Exception("[TreeNode] missing value strategy weighted confidence, "
1161            + "but class is numeric!");
1162      }
1163     
1164      double[] preds = null;
1165      TreeNode trueNode = null;
1166      boolean strategyInvoked = false;
1167      int nodeCount = 0;
1168     
1169      // look at the evaluation of the child predicates
1170      for (TreeNode c : m_childNodes) {
1171        if (c.getPredicate().evaluate(instance) == Predicate.Eval.TRUE) {
1172          // note the first child to evaluate to true
1173          if (trueNode == null) {
1174            trueNode = c;
1175          }
1176          nodeCount++;
1177        } else if (c.getPredicate().evaluate(instance) == Predicate.Eval.UNKNOWN) {
1178          strategyInvoked = true;
1179          nodeCount++;
1180        }
1181      }
1182     
1183      if (strategyInvoked) {
1184        // we expect to combine nodeCount distributions
1185        double[][] dists = new double[nodeCount][];
1186        double[] weights = new double[nodeCount];
1187       
1188        // collect the distributions and weights
1189        int count = 0;
1190        for (TreeNode c : m_childNodes) {
1191          if (c.getPredicate().evaluate(instance) == Predicate.Eval.TRUE ||
1192              c.getPredicate().evaluate(instance) == Predicate.Eval.UNKNOWN) {
1193           
1194            weights[count] = c.getRecordCount();
1195            if (Utils.isMissingValue(weights[count])) {
1196              throw new Exception("[TreeNode] weighted confidence missing value " +
1197                        "strategy invoked, but no record count defined for node " +
1198                        c.getID());
1199            }           
1200            dists[count++] = c.score(instance, classAtt);
1201          }
1202        }
1203       
1204        // do the combination
1205        preds = new double[classAtt.numValues()];
1206        for (int i = 0; i < classAtt.numValues(); i++) {
1207          for (int j = 0; j < nodeCount; j++) {
1208            preds[i] += ((weights[j] / m_recordCount) * dists[j][i]); 
1209          }
1210        }
1211      } else {
1212        if (trueNode != null) {
1213          preds = trueNode.score(instance, classAtt);
1214        } else {
1215          doNoTrueChild(classAtt, preds);
1216        }
1217      }
1218     
1219      return preds;
1220    }
1221   
1222    protected double[] freqCountsForAggNodesStrategy(double[] instance,
1223        Attribute classAtt) throws Exception {
1224   
1225      double[] counts = new double[classAtt.numValues()];
1226     
1227      if (m_childNodes.size() > 0) {
1228        // collect the counts
1229        for (TreeNode c : m_childNodes) {
1230          if (c.getPredicate().evaluate(instance) == Predicate.Eval.TRUE ||
1231              c.getPredicate().evaluate(instance) == Predicate.Eval.UNKNOWN) {
1232
1233            double[] temp = c.freqCountsForAggNodesStrategy(instance, classAtt);
1234            for (int i = 0; i < classAtt.numValues(); i++) {
1235              counts[i] += temp[i];
1236            }
1237          }
1238        }
1239      } else {
1240        // process the score distributions
1241        if (m_scoreDistributions.size() == 0) {
1242          throw new Exception("[TreeModel] missing value strategy aggregate nodes:" +
1243                        " no score distributions at leaf " + m_ID);
1244        }
1245        for (ScoreDistribution s : m_scoreDistributions) {
1246          counts[s.getClassLabelIndex()] = s.getRecordCount();
1247        }
1248      }
1249           
1250      return counts;
1251    }
1252   
1253    /**
1254     * Compute predictions and optionally invoke the aggregate nodes
1255     * missing value handling strategy.
1256     *
1257     * @param instance the incoming vector of attribute and derived field values.
1258     * @param classAtt the class attribute.
1259     * @return the predicted probability distribution.
1260     * @throws Exception if something goes wrong.
1261     */
1262    protected double[] missingValueStrategyAggregateNodes(double[] instance,
1263        Attribute classAtt) throws Exception {
1264     
1265      if (classAtt.isNumeric()) {
1266        throw new Exception("[TreeNode] missing value strategy aggregate nodes, "
1267            + "but class is numeric!");
1268      }
1269
1270      double[] preds = null;
1271      TreeNode trueNode = null;
1272      boolean strategyInvoked = false;
1273      int nodeCount = 0;
1274     
1275      // look at the evaluation of the child predicates
1276      for (TreeNode c : m_childNodes) {
1277        if (c.getPredicate().evaluate(instance) == Predicate.Eval.TRUE) {
1278          // note the first child to evaluate to true
1279          if (trueNode == null) {
1280            trueNode = c;
1281          }
1282          nodeCount++;
1283        } else if (c.getPredicate().evaluate(instance) == Predicate.Eval.UNKNOWN) {
1284          strategyInvoked = true;
1285          nodeCount++;
1286        }
1287      }
1288     
1289      if (strategyInvoked) {
1290        double[] aggregatedCounts = 
1291          freqCountsForAggNodesStrategy(instance, classAtt);
1292       
1293        // normalize
1294        Utils.normalize(aggregatedCounts);
1295        preds = aggregatedCounts;
1296      } else {
1297        if (trueNode != null) {
1298          preds = trueNode.score(instance, classAtt);
1299        } else {
1300          doNoTrueChild(classAtt, preds);
1301        }
1302      }
1303     
1304      return preds;             
1305    }
1306   
1307    /**
1308     * Compute predictions and optionally invoke the default child
1309     * missing value handling strategy.
1310     *
1311     * @param instance the incoming vector of attribute and derived field values.
1312     * @param classAtt the class attribute.
1313     * @return the predicted probability distribution.
1314     * @throws Exception if something goes wrong.
1315     */
1316    protected double[] missingValueStrategyDefaultChild(double[] instance, 
1317        Attribute classAtt) throws Exception {
1318     
1319      double[] preds = null;
1320      boolean strategyInvoked = false;
1321     
1322      // look for a child whose predicate evaluates to TRUE
1323      for (TreeNode c : m_childNodes) {
1324        if (c.getPredicate().evaluate(instance) == Predicate.Eval.TRUE) {
1325          preds = c.score(instance, classAtt);
1326          break;
1327        } else if (c.getPredicate().evaluate(instance) == Predicate.Eval.UNKNOWN) {
1328          strategyInvoked = true;
1329        }
1330      }
1331     
1332      // no true child found
1333      if (preds == null) {
1334        if (!strategyInvoked) {
1335          doNoTrueChild(classAtt, preds);
1336        } else {
1337          // do the strategy
1338         
1339          // NOTE: we don't actually implement the missing value penalty since
1340          // we always return a full probability distribution.
1341          if (m_defaultChild != null) {
1342            preds = m_defaultChild.score(instance, classAtt);
1343          } else {
1344            throw new Exception("[TreeNode] missing value strategy is defaultChild, but " +
1345                        "no default child has been specified in node " + m_ID);
1346          }
1347        }
1348      }
1349                 
1350      return preds;
1351    }
1352   
1353    /**
1354     * Compute predictions and optionally invoke the last prediction
1355     * missing value handling strategy.
1356     *
1357     * @param instance the incoming vector of attribute and derived field values.
1358     * @param classAtt the class attribute.
1359     * @return the predicted probability distribution.
1360     * @throws Exception if something goes wrong.
1361     */
1362    protected double[] missingValueStrategyLastPrediction(double[] instance, 
1363        Attribute classAtt) throws Exception {
1364     
1365      double[] preds = null;
1366      boolean strategyInvoked = false;
1367     
1368      // look for a child whose predicate evaluates to TRUE
1369      for (TreeNode c : m_childNodes) {
1370        if (c.getPredicate().evaluate(instance) == Predicate.Eval.TRUE) {
1371          preds = c.score(instance, classAtt);
1372          break;
1373        } else if (c.getPredicate().evaluate(instance) == Predicate.Eval.UNKNOWN) {
1374          strategyInvoked = true;
1375        }
1376      }
1377     
1378      // no true child found
1379      if (preds == null) {
1380        preds = new double[classAtt.numValues()];
1381        if (!strategyInvoked) {
1382          // no true child
1383          doNoTrueChild(classAtt, preds);
1384        } else {
1385          // do the strategy
1386          doLeaf(classAtt, preds);
1387        }
1388      }
1389     
1390      return preds;
1391    }
1392   
1393    /**
1394     * Compute predictions and optionally invoke the null prediction
1395     * missing value handling strategy.
1396     *
1397     * @param instance the incoming vector of attribute and derived field values.
1398     * @param classAtt the class attribute.
1399     * @return the predicted probability distribution.
1400     * @throws Exception if something goes wrong.
1401     */
1402    protected double[] missingValueStrategyNullPrediction(double[] instance,
1403        Attribute classAtt) throws Exception {
1404     
1405      double[] preds = null;
1406      boolean strategyInvoked = false;
1407     
1408      // look for a child whose predicate evaluates to TRUE
1409      for (TreeNode c : m_childNodes) {
1410        if (c.getPredicate().evaluate(instance) == Predicate.Eval.TRUE) {
1411          preds = c.score(instance, classAtt);
1412          break;
1413        } else if (c.getPredicate().evaluate(instance) == Predicate.Eval.UNKNOWN) {
1414          strategyInvoked = true;
1415        }
1416      }
1417     
1418      // no true child found
1419      if (preds == null) {
1420        preds = new double[classAtt.numValues()];
1421        if (!strategyInvoked) {
1422          doNoTrueChild(classAtt, preds);
1423        } else {
1424          // do the strategy
1425          for (int i = 0; i < classAtt.numValues(); i++) {
1426            preds[i] = Utils.missingValue();
1427          }
1428        }
1429      }
1430     
1431      return preds;
1432    }
1433   
1434    /**
1435     * Compute predictions and optionally invoke the "none"
1436     * missing value handling strategy (invokes no true child).
1437     *
1438     * @param instance the incoming vector of attribute and derived field values.
1439     * @param classAtt the class attribute.
1440     * @return the predicted probability distribution.
1441     * @throws Exception if something goes wrong.
1442     */
1443    protected double[] missingValueStrategyNone(double[] instance, Attribute classAtt)
1444      throws Exception {
1445     
1446      double[] preds = null;
1447     
1448      // look for a child whose predicate evaluates to TRUE
1449      for (TreeNode c : m_childNodes) {
1450        if (c.getPredicate().evaluate(instance) == Predicate.Eval.TRUE) {
1451          preds = c.score(instance, classAtt);
1452          break;
1453        }
1454      }
1455     
1456      if (preds == null) {
1457        preds = new double[classAtt.numValues()];
1458       
1459        // no true child strategy
1460        doNoTrueChild(classAtt, preds);
1461      }
1462     
1463      return preds;
1464    }
1465  }
1466 
1467  /**
1468   * Enumerated type for the mining function
1469   */
1470  enum MiningFunction {
1471    CLASSIFICATION,
1472    REGRESSION;
1473  }
1474 
1475  enum MissingValueStrategy {
1476    LASTPREDICTION("lastPrediction"),
1477    NULLPREDICTION("nullPrediction"),
1478    DEFAULTCHILD("defaultChild"),
1479    WEIGHTEDCONFIDENCE("weightedConfidence"),
1480    AGGREGATENODES("aggregateNodes"),
1481    NONE("none");
1482   
1483    private final String m_stringVal;
1484   
1485    MissingValueStrategy(String name) {
1486      m_stringVal = name;
1487    }
1488   
1489    public String toString() {
1490      return m_stringVal;
1491    }
1492  }
1493 
1494  enum NoTrueChildStrategy {
1495    RETURNNULLPREDICTION("returnNullPrediction"),
1496    RETURNLASTPREDICTION("returnLastPrediction");
1497   
1498    private final String m_stringVal;
1499   
1500    NoTrueChildStrategy(String name) {
1501      m_stringVal = name;
1502    }
1503   
1504    public String toString() {
1505      return m_stringVal;
1506    }
1507  }
1508 
1509  enum SplitCharacteristic {
1510    BINARYSPLIT("binarySplit"),
1511    MULTISPLIT("multiSplit");
1512 
1513    private final String m_stringVal;
1514   
1515    SplitCharacteristic(String name) {
1516      m_stringVal = name;
1517    }
1518   
1519    public String toString() {
1520      return m_stringVal;
1521    } 
1522  }
1523 
1524  /** The mining function */
1525  protected MiningFunction m_functionType = MiningFunction.CLASSIFICATION;
1526 
1527  /** The missing value strategy */
1528  protected MissingValueStrategy m_missingValueStrategy = MissingValueStrategy.NONE;
1529 
1530  /**
1531   * The missing value penalty (if defined).
1532   * We don't actually make use of this since we always return
1533   * full probability distributions.
1534   */
1535  protected double m_missingValuePenalty = Utils.missingValue();
1536 
1537  /** The no true child strategy to use */
1538  protected NoTrueChildStrategy m_noTrueChildStrategy = NoTrueChildStrategy.RETURNNULLPREDICTION;
1539 
1540  /** The splitting type */
1541  protected SplitCharacteristic m_splitCharacteristic = SplitCharacteristic.MULTISPLIT;
1542 
1543  /** The root of the tree */
1544  protected TreeNode m_root;
1545 
1546  public TreeModel(Element model, Instances dataDictionary, 
1547      MiningSchema miningSchema) throws Exception {
1548   
1549    super(dataDictionary, miningSchema);
1550   
1551    if (!getPMMLVersion().equals("3.2")) {
1552      // TODO: might have to throw an exception and only support 3.2
1553    }
1554   
1555    String fn = model.getAttribute("functionName");
1556    if (fn.equals("regression")) {
1557      m_functionType = MiningFunction.REGRESSION;
1558    }
1559   
1560    // get the missing value strategy (if any)
1561    String missingVS = model.getAttribute("missingValueStrategy");
1562    if (missingVS != null && missingVS.length() > 0) {
1563      for (MissingValueStrategy m : MissingValueStrategy.values()) {
1564        if (m.toString().equals(missingVS)) {
1565          m_missingValueStrategy = m;
1566          break;
1567        }
1568      }
1569    }
1570
1571    // get the missing value penalty (if any)
1572    String missingP = model.getAttribute("missingValuePenalty");
1573    if (missingP != null && missingP.length() > 0) {
1574      // try to parse as a number
1575      try {
1576        m_missingValuePenalty = Double.parseDouble(missingP);
1577      } catch (NumberFormatException ex) {
1578        System.err.println("[TreeModel] WARNING: " +
1579          "couldn't parse supplied missingValuePenalty as a number");
1580      }
1581    }
1582
1583    String splitC = model.getAttribute("splitCharacteristic");
1584
1585    if (splitC != null && splitC.length() > 0) {
1586      for (SplitCharacteristic s : SplitCharacteristic.values()) {
1587        if (s.toString().equals(splitC)) {
1588          m_splitCharacteristic = s;
1589          break;
1590        }
1591      }
1592    }
1593   
1594    // find the root node of the tree
1595    NodeList children = model.getChildNodes();
1596    for (int i = 0; i < children.getLength(); i++) {
1597      Node child = children.item(i);
1598      if (child.getNodeType() == Node.ELEMENT_NODE) {
1599        String tagName = ((Element)child).getTagName();
1600        if (tagName.equals("Node")) {
1601          m_root = new TreeNode((Element)child, miningSchema);         
1602          break;
1603        }
1604      }
1605    }   
1606  }
1607 
1608  /**                                                                                                             
1609   * Classifies the given test instance. The instance has to belong to a                                         
1610   * dataset when it's being classified.                                                         
1611   *                                                                                                             
1612   * @param inst the instance to be classified                                                               
1613   * @return the predicted most likely class for the instance or                                                 
1614   * Utils.missingValue() if no prediction is made                                                             
1615   * @exception Exception if an error occurred during the prediction                                             
1616   */
1617  public double[] distributionForInstance(Instance inst) throws Exception {
1618    if (!m_initialized) {
1619      mapToMiningSchema(inst.dataset());
1620    }
1621    double[] preds = null;
1622   
1623    if (m_miningSchema.getFieldsAsInstances().classAttribute().isNumeric()) {
1624      preds = new double[1];
1625    } else {
1626      preds = new double[m_miningSchema.getFieldsAsInstances().classAttribute().numValues()];
1627    }
1628   
1629    double[] incoming = m_fieldsMap.instanceToSchema(inst, m_miningSchema);
1630   
1631    preds = m_root.score(incoming, m_miningSchema.getFieldsAsInstances().classAttribute());
1632   
1633   return preds; 
1634  }
1635 
1636  public String toString() {
1637    StringBuffer temp = new StringBuffer();
1638
1639    temp.append("PMML version " + getPMMLVersion());
1640    if (!getCreatorApplication().equals("?")) {
1641      temp.append("\nApplication: " + getCreatorApplication());
1642    }
1643    temp.append("\nPMML Model: TreeModel");
1644    temp.append("\n\n");
1645    temp.append(m_miningSchema);
1646   
1647    temp.append("Split-type: " + m_splitCharacteristic + "\n");
1648    temp.append("No true child strategy: " + m_noTrueChildStrategy + "\n");
1649    temp.append("Missing value strategy: " + m_missingValueStrategy + "\n");
1650   
1651    temp.append(m_root.toString());
1652   
1653    return temp.toString();
1654  }
1655 
1656  public String graph() throws Exception {
1657    StringBuffer text = new StringBuffer();
1658    text.append("digraph PMMTree {\n");
1659   
1660    m_root.dumpGraph(text);
1661   
1662    text.append("}\n");
1663   
1664    return text.toString();
1665  }
1666
1667  public String getRevision() {
1668    return RevisionUtils.extract("$Revision: 5987 $");
1669  }
1670
1671  public int graphType() {
1672    return Drawable.TREE;
1673  }
1674}
Note: See TracBrowser for help on using the repository browser.