source: src/main/java/weka/classifiers/pmml/consumer/RuleSetModel.java @ 28

Last change on this file since 28 was 4, checked in by gnappo, 14 years ago

Import di weka.

File size: 27.6 KB
Line 
1/*
2 *    This program is free software; you can redistribute it and/or modify
3 *    it under the terms of the GNU General Public License as published by
4 *    the Free Software Foundation; either version 2 of the License, or
5 *    (at your option) any later version.
6 *
7 *    This program is distributed in the hope that it will be useful,
8 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
9 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
10 *    GNU General Public License for more details.
11 *
12 *    You should have received a copy of the GNU General Public License
13 *    along with this program; if not, write to the Free Software
14 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
15 */
16
17/*
18 *    RuleSetModel.java
19 *    Copyright (C) 2009 University of Waikato, Hamilton, New Zealand
20 *
21 */
22
23package weka.classifiers.pmml.consumer;
24
25import java.io.Serializable;
26import java.util.ArrayList;
27
28import org.w3c.dom.Element;
29import org.w3c.dom.Node;
30import org.w3c.dom.NodeList;
31
32import weka.classifiers.pmml.consumer.TreeModel.MiningFunction;
33import weka.core.Attribute;
34import weka.core.Instance;
35import weka.core.Instances;
36import weka.core.RevisionUtils;
37import weka.core.Utils;
38import weka.core.pmml.MiningSchema;
39
40/**
41 * Class implementing import of PMML RuleSetModel. Can be used as a Weka
42 * classifier for prediction only (buildClassifier() raises an Exception).
43 *
44 * @author Mark Hall (mhall{[at]}pentaho{[dot]}com)
45 * @version $Revision: 5987 $
46 */
47public class RuleSetModel extends PMMLClassifier {
48 
49  /** For serialization */
50  private static final long serialVersionUID = 1993161168811020547L;
51
52  /**
53   * Abstract inner base class for Rules
54   */
55  static abstract class Rule implements Serializable {
56   
57    /** For serialization */
58    private static final long serialVersionUID = 6236231263477446102L;
59   
60    /** The predicate for this rule */
61    protected TreeModel.Predicate m_predicate;
62   
63    public Rule(Element ruleE, MiningSchema miningSchema) throws Exception {
64      // Set up the predicate
65      m_predicate = TreeModel.Predicate.getPredicate(ruleE, miningSchema);
66    }
67   
68    /**
69     * Collect the rule(s) that fire for the supplied incoming instance
70     *
71     * @param input a vector of independent and derived independent variables
72     * @param ruleCollection the array list to add any firing rules into
73     */
74    public abstract void fires(double[] input, ArrayList<SimpleRule> ruleCollection);
75   
76    /**
77     * Get a textual description of this Rule
78     *
79     * @param prefix prefix string (typically some number of spaces) to prepend
80     * @param indent the number of additional spaces to add to the prefix
81     * @return a description of this Rule as a String
82     */
83    public abstract String toString(String prefix, int indent);
84   
85  }
86 
87  /**
88   * Inner class for representing simple rules
89   */
90  static class SimpleRule extends Rule {
91   
92    /** For serialization */
93    private static final long serialVersionUID = -2612893679476049682L;
94
95    /** The ID for the rule (optional) */
96    protected String m_ID;
97   
98    /** The predicted value when the rule fires (required) */
99    protected String m_scoreString;
100   
101    /**
102     * The predicted value as a number (regression) or index (classification)
103     * when the rule fires (required)
104     */
105    protected double m_score = Utils.missingValue();
106   
107    /** The number of training/test instances on which the rule fired (optional) */
108    protected double m_recordCount = Utils.missingValue();
109   
110    /**
111     * The number of training/test instances on which the rule fired and the
112     * prediction was correct (optional)
113     */
114    protected double m_nbCorrect = Utils.missingValue();
115   
116    /** The confidence of the rule (optional) */
117    protected double m_confidence = Utils.missingValue();
118   
119    /** The score distributions for this rule (if any) */
120    protected ArrayList<TreeModel.ScoreDistribution> m_scoreDistributions = 
121      new ArrayList<TreeModel.ScoreDistribution>();
122   
123    /**
124     *  The relative importance of the rule. May or may not be equal to the
125     * confidence (optional).
126     */
127    protected double m_weight = Utils.missingValue();
128   
129    public String toString(String prefix, int indent) {
130      StringBuffer temp = new StringBuffer();
131     
132      for (int i = 0; i < indent; i++) {
133        prefix += " ";
134      }
135     
136      temp.append(prefix + "Simple rule: " + m_predicate + "\n");
137      temp.append(prefix + " => " + m_scoreString + "\n");
138      if (!Utils.isMissingValue(m_recordCount)) {
139        temp.append(prefix + " recordCount: " + m_recordCount + "\n");
140      }
141      if (!Utils.isMissingValue(m_nbCorrect)) {
142        temp.append(prefix + "   nbCorrect: " + m_nbCorrect + "\n");
143      }
144      if (!Utils.isMissingValue(m_confidence)) {
145        temp.append(prefix + "  confidence: " + m_confidence + "\n");
146      }
147      if (!Utils.isMissingValue(m_weight)) {
148        temp.append(prefix + "      weight: " + m_weight + "\n");
149      }
150     
151      return temp.toString();
152    }
153   
154    public String toString() {
155      return toString("", 0);
156    }
157       
158    /**
159     * Constructor for a simple rule
160     *
161     * @param ruleE the XML element holding the simple rule
162     * @param miningSchema the mining schema to use
163     * @throws Exception if something goes wrong
164     */
165    public SimpleRule(Element ruleE, MiningSchema miningSchema) throws Exception {
166      super(ruleE, miningSchema);
167     
168      String id = ruleE.getAttribute("id");
169      if (id != null && id.length() > 0) {
170        m_ID = id;
171      }
172     
173      m_scoreString = ruleE.getAttribute("score");
174      Attribute classAtt = miningSchema.getFieldsAsInstances().classAttribute(); 
175      if (classAtt.isNumeric()) {
176        m_score = Double.parseDouble(m_scoreString);
177      } else {
178        if (classAtt.indexOfValue(m_scoreString) < 0) {
179          throw new Exception("[SimpleRule] class value " + m_scoreString + 
180              "does not exist in class attribute " + classAtt.name());
181        }
182        m_score = classAtt.indexOfValue(m_scoreString);
183      }
184     
185      String recordCount = ruleE.getAttribute("recordCount");
186      if (recordCount != null && recordCount.length() > 0) {
187        m_recordCount = Double.parseDouble(recordCount);
188      }
189     
190      String nbCorrect = ruleE.getAttribute("nbCorrect");
191      if (nbCorrect != null && nbCorrect.length() > 0) {
192        m_nbCorrect = Double.parseDouble(nbCorrect);
193      }
194     
195      String confidence = ruleE.getAttribute("confidence");
196      if (confidence != null && confidence.length() > 0) {
197        m_confidence = Double.parseDouble(confidence);
198      }
199     
200      String weight = ruleE.getAttribute("weight");
201      if (weight != null && weight.length() > 0) {
202        m_weight = Double.parseDouble(weight);
203      }
204     
205      // get the ScoreDistributions (if any)
206      if (miningSchema.getFieldsAsInstances().classAttribute().isNominal()) {
207        // see if we have any ScoreDistribution entries
208        NodeList scoreChildren = ruleE.getChildNodes();
209               
210        for (int i = 0; i < scoreChildren.getLength(); i++) {
211          Node child = scoreChildren.item(i);
212          if (child.getNodeType() == Node.ELEMENT_NODE) {
213            String tagName = ((Element)child).getTagName();
214            if (tagName.equals("ScoreDistribution")) {
215              TreeModel.ScoreDistribution newDist = 
216                new TreeModel.ScoreDistribution((Element)child, 
217                  miningSchema, m_recordCount);
218              m_scoreDistributions.add(newDist);
219            }
220          }
221        }
222       
223        // check that we have as many score distribution elements as there
224        // are class labels in the data
225        if (m_scoreDistributions.size() > 0 && 
226            m_scoreDistributions.size() != 
227              miningSchema.getFieldsAsInstances().classAttribute().numValues()) {
228          throw new Exception("[SimpleRule] Number of score distribution elements is "
229              + " different than the number of class labels!");
230        }
231       
232        //backfit the confidence values (if necessary)
233        if (Utils.isMissingValue(m_recordCount)) {
234          double baseCount = 0;
235          for (TreeModel.ScoreDistribution s : m_scoreDistributions) {
236            baseCount += s.getRecordCount();
237          }
238         
239          for (TreeModel.ScoreDistribution s : m_scoreDistributions) {
240            s.deriveConfidenceValue(baseCount);
241          }
242        }
243      }
244    }
245   
246    /**
247     * Collect the rule(s) that fire for the supplied incoming instance
248     *
249     * @param input a vector of independent and derived independent variables
250     * @param ruleCollection the array list to add any firing rules into
251     */
252    public void fires(double[] input, ArrayList<SimpleRule> ruleCollection) {
253      if (m_predicate.evaluate(input) == TreeModel.Predicate.Eval.TRUE) {
254        ruleCollection.add(this);
255      }     
256    }
257   
258    /**
259     * Score the incoming instance
260     *
261     * @param instance a vector containing the incoming independent and
262     * derived independent variables
263     * @param classAtt the class attribute
264     * @param rsm the rule selection method (ignored by simple rules)
265     * @return a probability distribution over the class labels or
266     * the predicted value (in element zero of the array if the class is numeric)
267     * @throws Exception if something goes wrong
268     */
269    public double[] score(double[] instance, Attribute classAtt) 
270      throws Exception {
271     
272      double[] preds;
273      if (classAtt.isNumeric()) {
274        preds = new double[1];
275        preds[0] = m_score;
276      } else {
277        preds = new double[classAtt.numValues()];
278        if (m_scoreDistributions.size() > 0) {
279          for (TreeModel.ScoreDistribution s : m_scoreDistributions) {
280            preds[s.getClassLabelIndex()] = s.getConfidence();
281          }
282        } else if (!Utils.isMissingValue(m_confidence)) {
283          preds[classAtt.indexOfValue(m_scoreString)] = m_confidence;
284        } else {
285          preds[classAtt.indexOfValue(m_scoreString)] = 1.0;
286        }
287      }     
288     
289      return preds;
290    }
291   
292    /**
293     * Get the weight of the rule
294     *
295     * @return the weight of the rule
296     */
297    public double getWeight() {
298      return m_weight;
299    }
300   
301    /**
302     * Get the ID of the rule
303     *
304     * @return the ID of the rule
305     */
306    public String getID() {
307      return m_ID;
308    }
309   
310    /**
311     * Get the predicted value of this rule (either a number
312     * for regression problems or an index of a class label for
313     * classification problems)
314     *
315     * @return the predicted value of this rule
316     */
317    public double getScore() {
318      return m_score;
319    }
320  }
321 
322  /**
323   * Inner class representing a compound rule
324   */
325  static class CompoundRule extends Rule {
326   
327    /** For serialization */
328    private static final long serialVersionUID = -2853658811459970718L;
329   
330    /** The child rules of this compound rule */
331    ArrayList<Rule> m_childRules = new ArrayList<Rule>();
332   
333    public String toString(String prefix, int indent) {
334      StringBuffer temp = new StringBuffer();
335
336      for (int i = 0; i < indent; i++) {
337        prefix += " ";
338      }
339     
340      temp.append(prefix + "Compound rule: " + m_predicate + "\n");
341     
342      for (Rule r : m_childRules) {
343        temp.append(r.toString(prefix, indent + 1));
344      }
345
346      return temp.toString();
347    }
348   
349    public String toString() {
350      return toString("", 0);
351    }
352   
353    /**
354     * Constructor.
355     *
356     * @param ruleE XML node holding the rule
357     * @param miningSchema the mining schema to use
358     * @throws Exception if something goes wrong
359     */
360    public CompoundRule(Element ruleE, MiningSchema miningSchema) throws Exception {
361     
362      // get the Predicate
363      super(ruleE, miningSchema);
364     
365      // get the nested rules
366      NodeList ruleChildren = ruleE.getChildNodes();
367      for (int i = 0; i < ruleChildren.getLength(); i++) {
368        Node child = ruleChildren.item(i);
369        if (child.getNodeType() == Node.ELEMENT_NODE) {
370          String tagName = ((Element)child).getTagName();
371          if (tagName.equals("SimpleRule")) {
372            Rule childRule = new SimpleRule(((Element)child), miningSchema);
373            m_childRules.add(childRule);
374          } else if (tagName.equals("CompoundRule")) {
375            Rule childRule = new CompoundRule(((Element)child), miningSchema);
376            m_childRules.add(childRule);
377          }
378        }
379      }
380    }
381   
382    /**
383     * Collect the rule(s) that fire for the supplied incoming instance
384     *
385     * @param input a vector of independent and derived independent variables
386     * @param ruleCollection the array list to add any firing rules into
387     */
388    public void fires(double[] input, ArrayList<SimpleRule> ruleCollection) {
389     
390      // evaluate our predicate first
391      if (m_predicate.evaluate(input) == TreeModel.Predicate.Eval.TRUE) {
392        // now check the child rules
393        for (Rule r : m_childRules) {
394          r.fires(input, ruleCollection);
395        }
396      }     
397    }
398  }
399 
400  /**
401   * Inner class representing a set of rules
402   */
403  static class RuleSet implements Serializable {
404   
405    /** For serialization */
406    private static final long serialVersionUID = -8718126887943074376L;
407
408    enum RuleSelectionMethod {
409      WEIGHTEDSUM("weightedSum"),
410      WEIGHTEDMAX("weightedMax"),
411      FIRSTHIT("firstHit");
412     
413      private final String m_stringVal;
414     
415      RuleSelectionMethod(String name) {
416        m_stringVal = name;
417      }
418     
419      public String toString() {
420        return m_stringVal;
421      }
422    }
423   
424    /**
425     * The number of training/test cases to which the ruleset was
426     * applied to generate support and confidence measures for individual
427     * rules (optional)
428     */
429    private double m_recordCount = Utils.missingValue();
430   
431    /**
432     * The number of training/test cases for which the default
433     * score is correct (optional)
434     */
435    private double m_nbCorrect = Utils.missingValue();
436   
437    /**
438     * The default value to predict when no rule in the
439     * ruleset fires (as a String; optional)
440     * */
441    private String m_defaultScore;
442   
443    /**
444     * The default value to predict (either a real value or an
445     * index)
446     * */
447    private double m_defaultPrediction = Utils.missingValue();
448   
449    /**
450     * The default distribution to predict when no rule in the
451     * ruleset fires (nominal class only, optional)
452     */
453    private ArrayList<TreeModel.ScoreDistribution> m_scoreDistributions =
454      new ArrayList<TreeModel.ScoreDistribution>();
455   
456    /**
457     * The default confidence value to return along with a score
458     * when no rules in the set fire (optional)
459     */
460    private double m_defaultConfidence = Utils.missingValue();
461   
462    /** The active rule selection method */
463    private RuleSelectionMethod m_currentMethod;
464   
465    /** The selection of rule selection methods allowed */
466    private ArrayList<RuleSelectionMethod> m_availableRuleSelectionMethods = 
467      new ArrayList<RuleSelectionMethod>();
468   
469    /** The rules contained in the rule set */
470    private ArrayList<Rule> m_rules = new ArrayList<Rule>();
471   
472    /* (non-Javadoc)
473     * @see java.lang.Object#toString()
474     */
475    public String toString() {
476      StringBuffer temp = new StringBuffer();
477     
478      temp.append("Rule selection method: " + m_currentMethod + "\n");
479      if (m_defaultScore != null) {
480        temp.append("Default prediction: " + m_defaultScore + "\n");
481       
482        if (!Utils.isMissingValue(m_recordCount)) {
483          temp.append("       recordCount: " + m_recordCount + "\n");
484        }
485        if (!Utils.isMissingValue(m_nbCorrect)) {
486          temp.append("         nbCorrect: " + m_nbCorrect + "\n");
487        }
488        if (!Utils.isMissingValue(m_defaultConfidence)) {
489          temp.append(" defaultConfidence: " + m_defaultConfidence + "\n");
490        }
491       
492        temp.append("\n");
493      }
494     
495      for (Rule r : m_rules) {
496        temp.append(r + "\n");
497      }
498     
499      return temp.toString();
500    }
501   
502    /**
503     * Constructor for a RuleSet.
504     *
505     * @param ruleSetNode the XML node holding the RuleSet
506     * @param miningSchema the mining schema to use
507     * @throws Exception if something goes wrong
508     */
509    public RuleSet(Element ruleSetNode, MiningSchema miningSchema) 
510      throws Exception {
511     
512      String recordCount = ruleSetNode.getAttribute("recordCount");
513      if (recordCount != null && recordCount.length() > 0) {
514        m_recordCount = Double.parseDouble(recordCount);
515      }
516     
517      String nbCorrect = ruleSetNode.getAttribute("nbCorrect");
518      if (nbCorrect != null & nbCorrect.length() > 0) {
519        m_nbCorrect = Double.parseDouble(nbCorrect);
520      }
521     
522      String defaultScore = ruleSetNode.getAttribute("defaultScore");
523      if (defaultScore != null && defaultScore.length() > 0) {
524        m_defaultScore = defaultScore;
525       
526        Attribute classAtt = miningSchema.getFieldsAsInstances().classAttribute();
527        if (classAtt == null) {
528          throw new Exception("[RuleSet] class attribute not set!");
529        }
530       
531        if (classAtt.isNumeric()) {
532          m_defaultPrediction = Double.parseDouble(defaultScore);
533        } else {
534          if (classAtt.indexOfValue(defaultScore) < 0) {
535            throw new Exception("[RuleSet] class value " + defaultScore + 
536                " not found!");
537          }
538          m_defaultPrediction = classAtt.indexOfValue(defaultScore);
539        }
540      }
541     
542      String defaultConfidence = ruleSetNode.getAttribute("defaultConfidence");
543      if (defaultConfidence != null && defaultConfidence.length() > 0) {
544        m_defaultConfidence = Double.parseDouble(defaultConfidence);
545      }
546     
547      // get the rule selection methods
548      NodeList selectionNL = ruleSetNode.getElementsByTagName("RuleSelectionMethod");
549      for (int i = 0; i < selectionNL.getLength(); i++) {
550        Node selectN = selectionNL.item(i);
551        if (selectN.getNodeType() == Node.ELEMENT_NODE) {
552          Element sN = (Element)selectN;
553          String criterion = sN.getAttribute("criterion");
554          for (RuleSelectionMethod m : RuleSelectionMethod.values()) {
555            if (m.toString().equals(criterion)) {
556              m_availableRuleSelectionMethods.add(m);
557              if (i == 0) {
558                // set the default (first specified one)
559                m_currentMethod = m;
560              }
561            }
562          }
563        }
564      }
565     
566      if (miningSchema.getFieldsAsInstances().classAttribute().isNominal()) {
567        // see if we have any ScoreDistribution entries
568        NodeList scoreChildren = ruleSetNode.getChildNodes();
569        for (int i = 0; i < scoreChildren.getLength(); i++) {
570          Node child = scoreChildren.item(i);
571          if (child.getNodeType() == Node.ELEMENT_NODE) {
572            String tagName = ((Element)child).getTagName();
573            if (tagName.equals("ScoreDistribution")) {
574              TreeModel.ScoreDistribution newDist = 
575                new TreeModel.ScoreDistribution((Element)child, 
576                  miningSchema, m_recordCount);
577              m_scoreDistributions.add(newDist);
578            }
579          }
580        }
581       
582        //backfit the confidence values (if necessary)
583        if (Utils.isMissingValue(m_recordCount)) {
584          double baseCount = 0;
585          for (TreeModel.ScoreDistribution s : m_scoreDistributions) {
586            baseCount += s.getRecordCount();
587          }
588         
589          for (TreeModel.ScoreDistribution s : m_scoreDistributions) {
590            s.deriveConfidenceValue(baseCount);
591          }
592        }
593      }
594     
595      // Get the rules in this rule set
596      NodeList ruleChildren = ruleSetNode.getChildNodes();
597      for (int i = 0; i < ruleChildren.getLength(); i++) {
598        Node child = ruleChildren.item(i);
599        if (child.getNodeType() == Node.ELEMENT_NODE) {
600          String tagName = ((Element)child).getTagName();
601          if (tagName.equals("SimpleRule")) {
602            Rule tempRule = new SimpleRule(((Element)child), miningSchema);
603            m_rules.add(tempRule);
604          } else if (tagName.equals("CompoundRule")) {
605            Rule tempRule = new CompoundRule(((Element)child), miningSchema);
606            m_rules.add(tempRule);
607          }
608        }
609      }
610    }
611   
612    /**
613     * Score an incoming instance by collecting all rules that fire.
614     *
615     * @param instance a vector of incoming attribte and derived field values
616     * @param classAtt the class attribute
617     * @return a predicted probability distribution
618     * @throws Exception is something goes wrong
619     */
620    protected double[] score(double[] instance, Attribute classAtt)
621      throws Exception {
622     
623      double[] preds = null;
624      if (classAtt.isNumeric()) {
625        preds = new double[1];
626      } else {
627        preds = new double[classAtt.numValues()];
628      }
629     
630      // holds the rules that fire for this test case
631      ArrayList<SimpleRule> firingRules = new ArrayList<SimpleRule>();
632     
633      for (Rule r : m_rules) {
634        r.fires(instance, firingRules);
635      }
636     
637      if (firingRules.size() > 0) {
638        if (m_currentMethod == RuleSelectionMethod.FIRSTHIT) {
639          preds = firingRules.get(0).score(instance, classAtt);
640        } else if (m_currentMethod == RuleSelectionMethod.WEIGHTEDMAX) {
641          double wMax = Double.NEGATIVE_INFINITY;
642          SimpleRule best = null;
643          for (SimpleRule s : firingRules) {
644            if (Utils.isMissingValue(s.getWeight())) {
645              throw new Exception("[RuleSet] Scoring criterion is WEIGHTEDMAX, but " +
646                        "rule " + s.getID() + " does not have a weight defined!");
647            }
648            if (s.getWeight() > wMax) {
649              wMax = s.getWeight();
650              best = s;
651            }
652          }
653          if (best == null) {
654            throw new Exception("[RuleSet] Unable to determine the best rule under " +
655                        "the WEIGHTEDMAX criterion!");
656          }
657          preds = best.score(instance, classAtt);         
658        } else if (m_currentMethod == RuleSelectionMethod.WEIGHTEDSUM) {
659          double sumOfWeights = 0;
660          for (SimpleRule s : firingRules) {
661            if (Utils.isMissingValue(s.getWeight())) {
662              throw new Exception("[RuleSet] Scoring criterion is WEIGHTEDSUM, but " +
663                        "rule " + s.getID() + " does not have a weight defined!");
664            }           
665            if (classAtt.isNumeric()) {
666              sumOfWeights += s.getWeight();
667              preds[0] += (s.getScore() * s.getWeight());
668            } else {
669              preds[(int)s.getScore()] += s.getWeight();
670            }
671          }
672          if (classAtt.isNumeric()) {
673            if (sumOfWeights == 0) {
674              throw new Exception("[RuleSet] Sum of weights is zero!");
675            }
676            preds[0] /= sumOfWeights;
677          } else {
678            // array gets normalized in the distributionForInstance() method
679          }
680        }
681      } else {
682        // default prediction
683        if (classAtt.isNumeric()) {
684          preds[0] = m_defaultPrediction;
685        } else {
686          if (m_scoreDistributions.size() > 0) {
687            for (TreeModel.ScoreDistribution s : m_scoreDistributions) {
688              preds[s.getClassLabelIndex()] = s.getConfidence();
689            }
690          } else if (!Utils.isMissingValue(m_defaultConfidence)) {
691            preds[(int)m_defaultPrediction] = m_defaultConfidence;
692          } else {
693            preds[(int)m_defaultPrediction] = 1.0;
694          }
695        }
696      }
697     
698      return preds;
699    }
700  }
701 
702  /** The mining function */
703  protected MiningFunction m_functionType = MiningFunction.CLASSIFICATION;
704 
705  /** The model name (if defined) */
706  protected String m_modelName;
707 
708  /** The algorithm name (if defined) */
709  protected String m_algorithmName;
710 
711  /** The set of rules */
712  protected RuleSet m_ruleSet;
713
714  /**
715   * Constructor for a RuleSetModel
716   *
717   * @param model the XML element encapsulating the RuleSetModel
718   * @param dataDictionary the data dictionary to use
719   * @param miningSchema the mining schema to use
720   * @throws Exception if something goes wrong
721   */
722  public RuleSetModel(Element model, Instances dataDictionary,
723      MiningSchema miningSchema) throws Exception {
724   
725    super(dataDictionary, miningSchema);
726   
727    if (!getPMMLVersion().equals("3.2")) {
728      // TODO: might have to throw an exception and only support 3.2
729    }
730   
731    String fn = model.getAttribute("functionName");
732    if (fn.equals("regression")) {
733      m_functionType = MiningFunction.REGRESSION;
734    }
735   
736    String modelName = model.getAttribute("modelName");
737    if (modelName != null && modelName.length() > 0) {
738      m_modelName = modelName;
739    }
740   
741    String algoName = model.getAttribute("algorithmName");
742    if (algoName != null && algoName.length() > 0) {
743      m_algorithmName = algoName;
744    }   
745   
746    NodeList ruleset = model.getElementsByTagName("RuleSet");
747    if (ruleset.getLength() == 1) {
748      Node ruleSetNode = ruleset.item(0);
749      if (ruleSetNode.getNodeType() == Node.ELEMENT_NODE) {
750        m_ruleSet = new RuleSet((Element)ruleSetNode, miningSchema);
751      }
752    } else {
753      throw new Exception ("[RuleSetModel] Should only have a single RuleSet!");
754    }
755  }
756 
757  /**                                                                                                             
758   * Classifies the given test instance. The instance has to belong to a                                         
759   * dataset when it's being classified.                                                         
760   *                                                                                                             
761   * @param inst the instance to be classified                                                               
762   * @return the predicted most likely class for the instance or                                                 
763   * Utils.missingValue() if no prediction is made                                                             
764   * @exception Exception if an error occurred during the prediction                                             
765   */
766  public double[] distributionForInstance(Instance inst) throws Exception {
767    if (!m_initialized) {
768      mapToMiningSchema(inst.dataset());
769    }
770    double[] preds = null;
771   
772    if (m_miningSchema.getFieldsAsInstances().classAttribute().isNumeric()) {
773      preds = new double[1];
774    } else {
775      preds = new double[m_miningSchema.getFieldsAsInstances().classAttribute().numValues()];
776    }
777   
778    double[] incoming = m_fieldsMap.instanceToSchema(inst, m_miningSchema);
779   
780    preds = m_ruleSet.score(incoming, 
781        m_miningSchema.getFieldsAsInstances().classAttribute());
782   
783    if (m_miningSchema.getFieldsAsInstances().classAttribute().isNominal()) {
784      Utils.normalize(preds);
785    }
786   
787    return preds;
788  }
789 
790  /**
791   * Return a textual description of this model.
792   *
793   * @return a textual description of this model
794   */
795  public String toString() {
796    StringBuffer temp = new StringBuffer();
797   
798    temp.append("PMML version " + getPMMLVersion());
799    if (!getCreatorApplication().equals("?")) {
800      temp.append("\nApplication: " + getCreatorApplication());
801    }
802    temp.append("\nPMML Model: RuleSetModel");
803    temp.append("\n\n");
804    temp.append(m_miningSchema);
805   
806    if (m_algorithmName != null) {
807      temp.append("\nAlgorithm: " + m_algorithmName + "\n");
808    }
809   
810    temp.append(m_ruleSet);
811 
812    return temp.toString();
813  }
814
815  /**
816   * Get the revision string for this class
817   *
818   * @return the revision string
819   */
820  public String getRevision() {
821    return RevisionUtils.extract("$Revision: 5987 $");
822  }
823 
824}
Note: See TracBrowser for help on using the repository browser.