source: src/main/java/weka/classifiers/trees/m5/RuleNode.java @ 16

Last change on this file since 16 was 4, checked in by gnappo, 14 years ago

Import di weka.

File size: 27.2 KB
Line 
1/*
2 *    This program is free software; you can redistribute it and/or modify
3 *    it under the terms of the GNU General Public License as published by
4 *    the Free Software Foundation; either version 2 of the License, or
5 *    (at your option) any later version.
6 *
7 *    This program is distributed in the hope that it will be useful,
8 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
9 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
10 *    GNU General Public License for more details.
11 *
12 *    You should have received a copy of the GNU General Public License
13 *    along with this program; if not, write to the Free Software
14 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
15 */
16
17/*
18 *    RuleNode.java
19 *    Copyright (C) 2000 University of Waikato, Hamilton, New Zealand
20 *
21 */
22
23package weka.classifiers.trees.m5;
24
25import weka.classifiers.Classifier;
26import weka.classifiers.AbstractClassifier;
27import weka.classifiers.Evaluation;
28import weka.classifiers.functions.LinearRegression;
29import weka.core.FastVector;
30import weka.core.Instance;
31import weka.core.Instances;
32import weka.core.RevisionUtils;
33import weka.core.Utils;
34import weka.filters.Filter;
35import weka.filters.unsupervised.attribute.Remove;
36
37/**
38 * Constructs a node for use in an m5 tree or rule
39 *
40 * @author Mark Hall (mhall@cs.waikato.ac.nz)
41 * @version $Revision: 5928 $
42 */
43public class RuleNode 
44  extends AbstractClassifier {
45
46  /** for serialization */
47  static final long serialVersionUID = 1979807611124337144L;
48 
49  /**
50   * instances reaching this node
51   */
52  private Instances        m_instances;
53
54  /**
55   * the class index
56   */
57  private int              m_classIndex;
58
59  /**
60   * the number of instances reaching this node
61   */
62  protected int            m_numInstances;
63
64  /**
65   * the number of attributes
66   */
67  private int              m_numAttributes;
68
69  /**
70   * Node is a leaf
71   */
72  private boolean          m_isLeaf;
73
74  /**
75   * attribute this node splits on
76   */
77  private int              m_splitAtt;
78
79  /**
80   * the value of the split attribute
81   */
82  private double           m_splitValue;
83
84  /**
85   * the linear model at this node
86   */
87  private PreConstructedLinearModel m_nodeModel;
88
89  /**
90   * the number of paramters in the chosen model for this node---either
91   * the subtree model or the linear model.
92   * The constant term is counted as a paramter---this is for pruning
93   * purposes
94   */
95  public int               m_numParameters;
96
97  /**
98   * the mean squared error of the model at this node (either linear or
99   * subtree)
100   */
101  private double           m_rootMeanSquaredError;
102
103  /**
104   * left child node
105   */
106  protected RuleNode       m_left;
107
108  /**
109   * right child node
110   */
111  protected RuleNode       m_right;
112
113  /**
114   * the parent of this node
115   */
116  private RuleNode         m_parent;
117
118  /**
119   * a node will not be split if it contains less then m_splitNum instances
120   */
121  private double           m_splitNum = 4;
122
123  /**
124   * a node will not be split if its class standard deviation is less
125   * than 5% of the class standard deviation of all the instances
126   */
127  private double           m_devFraction = 0.05;
128  private double           m_pruningMultiplier = 2;
129
130  /**
131   * the number assigned to the linear model if this node is a leaf.
132   * = 0 if this node is not a leaf
133   */
134  private int              m_leafModelNum;
135
136  /**
137   * a node will not be split if the class deviation of its
138   * instances is less than m_devFraction of the deviation of the
139   * global class
140   */
141  private double           m_globalDeviation;
142
143  /**
144   * the absolute deviation of the global class
145   */
146  private double           m_globalAbsDeviation;
147
148  /**
149   * Indices of the attributes to be used in generating a linear model
150   * at this node
151   */
152  private int [] m_indices;
153   
154  /**
155   * Constant used in original m5 smoothing calculation
156   */
157  private static final double      SMOOTHING_CONSTANT = 15.0;
158
159  /**
160   * Node id.
161   */
162  private int m_id;
163
164  /**
165   * Save the instances at each node (for visualizing in the
166   * Explorer's treevisualizer.
167   */
168  private boolean m_saveInstances = false;
169
170  /**
171   * Make a regression tree instead of a model tree
172   */
173  private boolean m_regressionTree;
174
175  /**
176   * Creates a new <code>RuleNode</code> instance.
177   *
178   * @param globalDev the global standard deviation of the class
179   * @param globalAbsDev the global absolute deviation of the class
180   * @param parent the parent of this node
181   */
182  public RuleNode(double globalDev, double globalAbsDev, RuleNode parent) {
183    m_nodeModel = null;
184    m_right = null;
185    m_left = null;
186    m_parent = parent;
187    m_globalDeviation = globalDev;
188    m_globalAbsDeviation = globalAbsDev;
189  }
190
191   
192  /**
193   * Build this node (find an attribute and split point)
194   *
195   * @param data the instances on which to build this node
196   * @throws Exception if an error occurs
197   */
198  public void buildClassifier(Instances data) throws Exception {
199
200    m_rootMeanSquaredError = Double.MAX_VALUE;
201    //    m_instances = new Instances(data);
202    m_instances = data;
203    m_classIndex = m_instances.classIndex();
204    m_numInstances = m_instances.numInstances();
205    m_numAttributes = m_instances.numAttributes();
206    m_nodeModel = null;
207    m_right = null;
208    m_left = null;
209
210    if ((m_numInstances < m_splitNum) 
211        || (Rule.stdDev(m_classIndex, m_instances) 
212            < (m_globalDeviation * m_devFraction))) {
213      m_isLeaf = true;
214    } else {
215      m_isLeaf = false;
216    } 
217
218    split();
219  } 
220 
221  /**
222   * Classify an instance using this node. Recursively calls classifyInstance
223   * on child nodes.
224   *
225   * @param inst the instance to classify
226   * @return the prediction for this instance
227   * @throws Exception if an error occurs
228   */
229  public double classifyInstance(Instance inst) throws Exception {
230    if (m_isLeaf) {
231      if (m_nodeModel == null) {
232        throw new Exception("Classifier has not been built correctly.");
233      } 
234
235      return m_nodeModel.classifyInstance(inst);
236    }
237
238    if (inst.value(m_splitAtt) <= m_splitValue) {
239      return m_left.classifyInstance(inst);
240    } else {
241      return m_right.classifyInstance(inst);
242    } 
243  } 
244
245  /**
246   * Applies the m5 smoothing procedure to a prediction
247   *
248   * @param n number of instances in selected child of this node
249   * @param pred the prediction so far
250   * @param supportPred the prediction of the linear model at this node
251   * @return the current prediction smoothed with the prediction of the
252   * linear model at this node
253   * @throws Exception if an error occurs
254   */
255  protected static double smoothingOriginal(double n, double pred, 
256                                            double supportPred) 
257    throws Exception {
258    double   smoothed;
259
260    smoothed = 
261      ((n * pred) + (SMOOTHING_CONSTANT * supportPred)) /
262      (n + SMOOTHING_CONSTANT);
263
264    return smoothed;
265  } 
266
267
268  /**
269   * Finds an attribute and split point for this node
270   *
271   * @throws Exception if an error occurs
272   */
273  public void split() throws Exception {
274    int           i;
275    Instances     leftSubset, rightSubset;
276    SplitEvaluate bestSplit, currentSplit;
277    boolean[]     attsBelow;
278
279    if (!m_isLeaf) {
280     
281      bestSplit = new YongSplitInfo(0, m_numInstances - 1, -1);
282      currentSplit = new YongSplitInfo(0, m_numInstances - 1, -1);
283
284      // find the best attribute to split on
285      for (i = 0; i < m_numAttributes; i++) {
286        if (i != m_classIndex) {
287
288          // sort the instances by this attribute
289          m_instances.sort(i);
290          currentSplit.attrSplit(i, m_instances);
291
292          if ((Math.abs(currentSplit.maxImpurity() - 
293                        bestSplit.maxImpurity()) > 1.e-6) 
294              && (currentSplit.maxImpurity() 
295                  > bestSplit.maxImpurity() + 1.e-6)) {
296            bestSplit = currentSplit.copy();
297          } 
298        } 
299      } 
300
301      // cant find a good split or split point?
302      if (bestSplit.splitAttr() < 0 || bestSplit.position() < 1 
303          || bestSplit.position() > m_numInstances - 1) {
304        m_isLeaf = true;
305      } else {
306        m_splitAtt = bestSplit.splitAttr();
307        m_splitValue = bestSplit.splitValue();
308        leftSubset = new Instances(m_instances, m_numInstances);
309        rightSubset = new Instances(m_instances, m_numInstances);
310
311        for (i = 0; i < m_numInstances; i++) {
312          if (m_instances.instance(i).value(m_splitAtt) <= m_splitValue) {
313            leftSubset.add(m_instances.instance(i));
314          } else {
315            rightSubset.add(m_instances.instance(i));
316          } 
317        } 
318
319        leftSubset.compactify();
320        rightSubset.compactify();
321
322        // build left and right nodes
323        m_left = new RuleNode(m_globalDeviation, m_globalAbsDeviation, this);
324        m_left.setMinNumInstances(m_splitNum);
325        m_left.setRegressionTree(m_regressionTree);
326        m_left.setSaveInstances(m_saveInstances);
327        m_left.buildClassifier(leftSubset);
328
329        m_right = new RuleNode(m_globalDeviation, m_globalAbsDeviation, this);
330        m_right.setMinNumInstances(m_splitNum);
331        m_right.setRegressionTree(m_regressionTree);
332        m_right.setSaveInstances(m_saveInstances);
333        m_right.buildClassifier(rightSubset);
334
335        // now find out what attributes are tested in the left and right
336        // subtrees and use them to learn a linear model for this node
337        if (!m_regressionTree) {
338          attsBelow = attsTestedBelow();
339          attsBelow[m_classIndex] = true;
340          int count = 0, j;
341
342          for (j = 0; j < m_numAttributes; j++) {
343            if (attsBelow[j]) {
344              count++;
345            } 
346          } 
347         
348          int[] indices = new int[count];
349
350          count = 0;
351         
352          for (j = 0; j < m_numAttributes; j++) {
353            if (attsBelow[j] && (j != m_classIndex)) {
354              indices[count++] = j;
355            } 
356          } 
357         
358          indices[count] = m_classIndex;
359          m_indices = indices;
360        } else {
361          m_indices = new int [1];
362          m_indices[0] = m_classIndex;
363          m_numParameters = 1;
364        }
365      } 
366    } 
367
368    if (m_isLeaf) {
369      int [] indices = new int [1];
370      indices[0] = m_classIndex;
371      m_indices = indices;
372      m_numParameters = 1;
373     
374      // need to evaluate the model here if want correct stats for unpruned
375      // tree
376    } 
377  } 
378
379  /**
380   * Build a linear model for this node using those attributes
381   * specified in indices.
382   *
383   * @param indices an array of attribute indices to include in the linear
384   * model
385   * @throws Exception if something goes wrong
386   */
387  private void buildLinearModel(int [] indices) throws Exception {
388    // copy the training instances and remove all but the tested
389    // attributes
390    Instances reducedInst = new Instances(m_instances);
391    Remove attributeFilter = new Remove();
392   
393    attributeFilter.setInvertSelection(true);
394    attributeFilter.setAttributeIndicesArray(indices);
395    attributeFilter.setInputFormat(reducedInst);
396
397    reducedInst = Filter.useFilter(reducedInst, attributeFilter);
398   
399    // build a linear regression for the training data using the
400    // tested attributes
401    LinearRegression temp = new LinearRegression();
402    temp.buildClassifier(reducedInst);
403
404    double [] lmCoeffs = temp.coefficients();
405    double [] coeffs = new double [m_instances.numAttributes()];
406
407    for (int i = 0; i < lmCoeffs.length - 1; i++) {
408      if (indices[i] != m_classIndex) {
409        coeffs[indices[i]] = lmCoeffs[i];
410      }
411    }
412    m_nodeModel = new PreConstructedLinearModel(coeffs, lmCoeffs[lmCoeffs.length - 1]);
413    m_nodeModel.buildClassifier(m_instances);
414  }
415
416  /**
417   * Returns an array containing the indexes of attributes used in tests
418   * above this node
419   *
420   * @return an array of attribute indexes
421   */
422  private boolean[] attsTestedAbove() {
423    boolean[] atts = new boolean[m_numAttributes];
424    boolean[] attsAbove = null;
425
426    if (m_parent != null) {
427      attsAbove = m_parent.attsTestedAbove();
428    } 
429
430    if (attsAbove != null) {
431      for (int i = 0; i < m_numAttributes; i++) {
432        atts[i] = attsAbove[i];
433      } 
434    } 
435
436    atts[m_splitAtt] = true;
437    return atts;
438  } 
439
440  /**
441   * Returns an array containing the indexes of attributes used in tests
442   * below this node
443   *
444   * @return an array of attribute indexes
445   */
446  private boolean[] attsTestedBelow() {
447    boolean[] attsBelow = new boolean[m_numAttributes];
448    boolean[] attsBelowLeft = null;
449    boolean[] attsBelowRight = null;
450
451    if (m_right != null) {
452      attsBelowRight = m_right.attsTestedBelow();
453    } 
454
455    if (m_left != null) {
456      attsBelowLeft = m_left.attsTestedBelow();
457    } 
458
459    for (int i = 0; i < m_numAttributes; i++) {
460      if (attsBelowLeft != null) {
461        attsBelow[i] = (attsBelow[i] || attsBelowLeft[i]);
462      } 
463
464      if (attsBelowRight != null) {
465        attsBelow[i] = (attsBelow[i] || attsBelowRight[i]);
466      } 
467    } 
468
469    if (!m_isLeaf) {
470      attsBelow[m_splitAtt] = true;
471    } 
472    return attsBelow;
473  } 
474
475  /**
476   * Sets the leaves' numbers
477   * @param leafCounter the number of leaves counted
478   * @return the number of the total leaves under the node
479   */
480  public int numLeaves(int leafCounter) {
481
482    if (!m_isLeaf) {
483      // node
484      m_leafModelNum = 0;
485
486      if (m_left != null) {
487        leafCounter = m_left.numLeaves(leafCounter);
488      } 
489
490      if (m_right != null) {
491        leafCounter = m_right.numLeaves(leafCounter);
492      } 
493    } else {
494      // leaf
495      leafCounter++;
496      m_leafModelNum = leafCounter;
497    } 
498    return leafCounter;
499  } 
500
501  /**
502   * print the linear model at this node
503   *
504   * @return the linear model
505   */
506  public String toString() {
507    return printNodeLinearModel();
508  } 
509
510  /**
511   * print the linear model at this node
512   *
513   * @return the linear model at this node
514   */
515  public String printNodeLinearModel() {
516    return m_nodeModel.toString();
517  } 
518
519  /**
520   * print all leaf models
521   *
522   * @return the leaf models
523   */
524  public String printLeafModels() {
525    StringBuffer text = new StringBuffer();
526
527    if (m_isLeaf) {
528      text.append("\nLM num: " + m_leafModelNum);
529      text.append(m_nodeModel.toString());
530      text.append("\n");
531    } else {
532      text.append(m_left.printLeafModels());
533      text.append(m_right.printLeafModels());
534    } 
535    return text.toString();
536  } 
537
538  /**
539   * Returns a description of this node (debugging purposes)
540   *
541   * @return a string describing this node
542   */
543  public String nodeToString() {
544    StringBuffer text = new StringBuffer();
545
546    System.out.println("In to string");
547    text.append("Node:\n\tnum inst: " + m_numInstances);
548
549    if (m_isLeaf) {
550      text.append("\n\tleaf");
551    } else {
552      text.append("\tnode");
553    }
554
555    text.append("\n\tSplit att: " + m_instances.attribute(m_splitAtt).name());
556    text.append("\n\tSplit val: " + Utils.doubleToString(m_splitValue, 1, 3));
557    text.append("\n\tLM num: " + m_leafModelNum);
558    text.append("\n\tLinear model\n" + m_nodeModel.toString());
559    text.append("\n\n");
560
561    if (m_left != null) {
562      text.append(m_left.nodeToString());
563    } 
564
565    if (m_right != null) {
566      text.append(m_right.nodeToString());
567    } 
568
569    return text.toString();
570  } 
571
572  /**
573   * Recursively builds a textual description of the tree
574   *
575   * @param level the level of this node
576   * @return string describing the tree
577   */
578  public String treeToString(int level) {
579    int          i;
580    StringBuffer text = new StringBuffer();
581
582    if (!m_isLeaf) {
583      text.append("\n");
584
585      for (i = 1; i <= level; i++) {
586        text.append("|   ");
587      } 
588
589      if (m_instances.attribute(m_splitAtt).name().charAt(0) != '[') {
590        text.append(m_instances.attribute(m_splitAtt).name() + " <= " 
591                    + Utils.doubleToString(m_splitValue, 1, 3) + " : ");
592      } else {
593        text.append(m_instances.attribute(m_splitAtt).name() + " false : ");
594      } 
595
596      if (m_left != null) {
597        text.append(m_left.treeToString(level + 1));
598      } else {
599        text.append("NULL\n");
600      }
601
602      for (i = 1; i <= level; i++) {
603        text.append("|   ");
604      } 
605
606      if (m_instances.attribute(m_splitAtt).name().charAt(0) != '[') {
607        text.append(m_instances.attribute(m_splitAtt).name() + " >  " 
608                    + Utils.doubleToString(m_splitValue, 1, 3) + " : ");
609      } else {
610        text.append(m_instances.attribute(m_splitAtt).name() + " true : ");
611      } 
612
613      if (m_right != null) {
614        text.append(m_right.treeToString(level + 1));
615      } else {
616        text.append("NULL\n");
617      }
618    } else {
619      text.append("LM" + m_leafModelNum);
620
621      if (m_globalDeviation > 0.0) {
622        text
623          .append(" (" + m_numInstances + "/" 
624                  + Utils.doubleToString((100.0 * m_rootMeanSquaredError /
625                                             m_globalDeviation), 1, 3) 
626                  + "%)\n");
627      } else {
628        text.append(" (" + m_numInstances + ")\n");
629      } 
630    } 
631    return text.toString();
632  } 
633
634  /**
635   * Traverses the tree and installs linear models at each node.
636   * This method must be called if pruning is not to be performed.
637   *
638   * @throws Exception if an error occurs
639   */
640  public void installLinearModels() throws Exception {
641    Evaluation nodeModelEval;
642    if (m_isLeaf) {
643      buildLinearModel(m_indices);
644    } else {
645      if (m_left != null) {
646        m_left.installLinearModels();
647      }
648
649      if (m_right != null) {
650        m_right.installLinearModels();
651      }
652      buildLinearModel(m_indices);
653    }
654    nodeModelEval = new Evaluation(m_instances);
655    nodeModelEval.evaluateModel(m_nodeModel, m_instances);
656    m_rootMeanSquaredError = nodeModelEval.rootMeanSquaredError();
657    // save space
658    if (!m_saveInstances) {
659      m_instances = new Instances(m_instances, 0);
660    }
661  }
662
663  /**
664   *
665   * @throws Exception
666   */
667  public void installSmoothedModels() throws Exception {
668
669    if (m_isLeaf) {
670      double [] coefficients = new double [m_numAttributes];
671      double intercept;
672      double  [] coeffsUsedByLinearModel = m_nodeModel.coefficients();
673      RuleNode current = this;
674     
675      // prime array with leaf node coefficients
676      for (int i = 0; i < coeffsUsedByLinearModel.length; i++) {
677        if (i != m_classIndex) {
678          coefficients[i] = coeffsUsedByLinearModel[i];
679        }
680      }
681      // intercept
682      intercept = m_nodeModel.intercept();
683
684      do {
685        if (current.m_parent != null) {
686          double n = current.m_numInstances;
687          // contribution of the model below
688          for (int i = 0; i < coefficients.length; i++) {
689            coefficients[i] = ((coefficients[i] * n) / (n + SMOOTHING_CONSTANT));
690          }
691          intercept =  ((intercept * n) / (n + SMOOTHING_CONSTANT));
692
693          // contribution of this model
694          coeffsUsedByLinearModel = current.m_parent.getModel().coefficients();
695          for (int i = 0; i < coeffsUsedByLinearModel.length; i++) {
696            if (i != m_classIndex) {
697              // smooth in these coefficients (at this node)
698              coefficients[i] += 
699                ((SMOOTHING_CONSTANT * coeffsUsedByLinearModel[i]) /
700                 (n + SMOOTHING_CONSTANT));
701            }
702          }
703          // smooth in the intercept
704          intercept += 
705            ((SMOOTHING_CONSTANT * 
706              current.m_parent.getModel().intercept()) /
707             (n + SMOOTHING_CONSTANT));
708          current = current.m_parent;
709        }
710      } while (current.m_parent != null);
711      m_nodeModel = 
712        new PreConstructedLinearModel(coefficients, intercept);
713      m_nodeModel.buildClassifier(m_instances);
714    }
715    if (m_left != null) {
716      m_left.installSmoothedModels();
717    }
718    if (m_right != null) {
719      m_right.installSmoothedModels();
720    }
721  }
722   
723  /**
724   * Recursively prune the tree
725   *
726   * @throws Exception if an error occurs
727   */
728  public void prune() throws Exception {
729    Evaluation nodeModelEval = null;
730
731    if (m_isLeaf) {
732      buildLinearModel(m_indices);
733      nodeModelEval = new Evaluation(m_instances);
734
735      // count the constant term as a paramter for a leaf
736      // Evaluate the model
737      nodeModelEval.evaluateModel(m_nodeModel, m_instances);
738
739      m_rootMeanSquaredError = nodeModelEval.rootMeanSquaredError();
740    } else {
741
742      // Prune the left and right subtrees
743      if (m_left != null) {
744        m_left.prune();
745      } 
746
747      if (m_right != null) {
748        m_right.prune();       
749      } 
750     
751      buildLinearModel(m_indices);
752      nodeModelEval = new Evaluation(m_instances);
753
754      double rmsModel;
755      double adjustedErrorModel;
756
757      nodeModelEval.evaluateModel(m_nodeModel, m_instances);
758
759      rmsModel = nodeModelEval.rootMeanSquaredError();
760      adjustedErrorModel = rmsModel
761        * pruningFactor(m_numInstances, 
762                        m_nodeModel.numParameters() + 1);
763
764      // Evaluate this node (ie its left and right subtrees)
765      Evaluation nodeEval = new Evaluation(m_instances);
766      double     rmsSubTree;
767      double     adjustedErrorNode;
768      int        l_params = 0, r_params = 0;
769
770      nodeEval.evaluateModel(this, m_instances);
771
772      rmsSubTree = nodeEval.rootMeanSquaredError();
773
774      if (m_left != null) {
775        l_params = m_left.numParameters();
776      } 
777
778      if (m_right != null) {
779        r_params = m_right.numParameters();
780      } 
781
782      adjustedErrorNode = rmsSubTree
783        * pruningFactor(m_numInstances, 
784                        (l_params + r_params + 1));
785
786      if ((adjustedErrorModel <= adjustedErrorNode) 
787          || (adjustedErrorModel < (m_globalDeviation * 0.00001))) {
788
789        // Choose linear model for this node rather than subtree model
790        m_isLeaf = true;
791        m_right = null;
792        m_left = null;
793        m_numParameters = m_nodeModel.numParameters() + 1;
794        m_rootMeanSquaredError = rmsModel;
795      } else {
796        m_numParameters = (l_params + r_params + 1);
797        m_rootMeanSquaredError = rmsSubTree;
798      } 
799    }
800    // save space
801    if (!m_saveInstances) {
802      m_instances = new Instances(m_instances, 0);
803    }
804  } 
805
806
807  /**
808   * Compute the pruning factor
809   *
810   * @param num_instances number of instances
811   * @param num_params number of parameters in the model
812   * @return the pruning factor
813   */
814  private double pruningFactor(int num_instances, int num_params) {
815    if (num_instances <= num_params) {
816      return 10.0;    // Caution says Yong in his code
817    } 
818
819    return ((double) (num_instances + m_pruningMultiplier * num_params) 
820            / (double) (num_instances - num_params));
821  } 
822
823  /**
824   * Find the leaf with greatest coverage
825   *
826   * @param maxCoverage the greatest coverage found so far
827   * @param bestLeaf the leaf with the greatest coverage
828   */
829  public void findBestLeaf(double[] maxCoverage, RuleNode[] bestLeaf) {
830    if (!m_isLeaf) {
831      if (m_left != null) {
832        m_left.findBestLeaf(maxCoverage, bestLeaf);
833      } 
834
835      if (m_right != null) {
836        m_right.findBestLeaf(maxCoverage, bestLeaf);
837      } 
838    } else {
839      if (m_numInstances > maxCoverage[0]) {
840        maxCoverage[0] = m_numInstances;
841        bestLeaf[0] = this;
842      } 
843    } 
844  } 
845
846  /**
847   * Return a list containing all the leaves in the tree
848   *
849   * @param v a single element array containing a vector of leaves
850   */
851  public void returnLeaves(FastVector[] v) {
852    if (m_isLeaf) {
853      v[0].addElement(this);
854    } else {
855      if (m_left != null) {
856        m_left.returnLeaves(v);
857      } 
858
859      if (m_right != null) {
860        m_right.returnLeaves(v);
861      } 
862    } 
863  } 
864
865  /**
866   * Get the parent of this node
867   *
868   * @return the parent of this node
869   */
870  public RuleNode parentNode() {
871    return m_parent;
872  } 
873
874  /**
875   * Get the left child of this node
876   *
877   * @return the left child of this node
878   */
879  public RuleNode leftNode() {
880    return m_left;
881  } 
882
883  /**
884   * Get the right child of this node
885   *
886   * @return the right child of this node
887   */
888  public RuleNode rightNode() {
889    return m_right;
890  } 
891
892  /**
893   * Get the index of the splitting attribute for this node
894   *
895   * @return the index of the splitting attribute
896   */
897  public int splitAtt() {
898    return m_splitAtt;
899  } 
900
901  /**
902   * Get the split point for this node
903   *
904   * @return the split point for this node
905   */
906  public double splitVal() {
907    return m_splitValue;
908  } 
909
910  /**
911   * Get the number of linear models in the tree
912   *
913   * @return the number of linear models
914   */
915  public int numberOfLinearModels() {
916    if (m_isLeaf) {
917      return 1;
918    } else {
919      return m_left.numberOfLinearModels() + m_right.numberOfLinearModels();
920    } 
921  } 
922
923  /**
924   * Return true if this node is a leaf
925   *
926   * @return true if this node is a leaf
927   */
928  public boolean isLeaf() {
929    return m_isLeaf;
930  } 
931
932  /**
933   * Get the root mean squared error at this node
934   *
935   * @return the root mean squared error
936   */
937  protected double rootMeanSquaredError() {
938    return m_rootMeanSquaredError;
939  } 
940
941  /**
942   * Get the linear model at this node
943   *
944   * @return the linear model at this node
945   */
946  public PreConstructedLinearModel getModel() {
947    return m_nodeModel;
948  }
949
950  /**
951   * Return the number of instances that reach this node.
952   *
953   * @return the number of instances at this node.
954   */
955  public int getNumInstances() {
956    return m_numInstances;
957  }
958
959  /**
960   * Get the number of parameters in the model at this node
961   *
962   * @return the number of parameters in the model at this node
963   */
964  private int numParameters() {
965    return m_numParameters;
966  } 
967
968  /**
969   * Get the value of regressionTree.
970   *
971   * @return Value of regressionTree.
972   */
973  public boolean getRegressionTree() {
974   
975    return m_regressionTree;
976  }
977
978  /**
979   * Set the minumum number of instances to allow at a leaf node
980   *
981   * @param minNum the minimum number of instances
982   */
983  public void setMinNumInstances(double minNum) {
984    m_splitNum = minNum;
985  }
986
987  /**
988   * Get the minimum number of instances to allow at a leaf node
989   *
990   * @return a <code>double</code> value
991   */
992  public double getMinNumInstances() {
993    return m_splitNum;
994  }
995 
996  /**
997   * Set the value of regressionTree.
998   *
999   * @param newregressionTree Value to assign to regressionTree.
1000   */
1001  public void setRegressionTree(boolean newregressionTree) {
1002   
1003    m_regressionTree = newregressionTree;
1004  }
1005                                                         
1006  /**
1007   * Print all the linear models at the learf (debugging purposes)
1008   */
1009  public void printAllModels() {
1010    if (m_isLeaf) {
1011      System.out.println(m_nodeModel.toString());
1012    } else {
1013      System.out.println(m_nodeModel.toString());
1014      m_left.printAllModels();
1015      m_right.printAllModels();
1016    } 
1017  } 
1018
1019  /**
1020   * Assigns a unique identifier to each node in the tree
1021   *
1022   * @param lastID last id number used
1023   * @return ID after processing child nodes
1024   */
1025  protected int assignIDs(int lastID) {
1026    int currLastID = lastID + 1;
1027    m_id = currLastID;
1028
1029    if (m_left != null) {
1030      currLastID = m_left.assignIDs(currLastID);
1031    }
1032
1033    if (m_right != null) {
1034      currLastID = m_right.assignIDs(currLastID);
1035    }
1036    return currLastID;
1037  }
1038
1039  /**
1040   * Assign a unique identifier to each node in the tree and then
1041   * calls graphTree
1042   *
1043   * @param text a <code>StringBuffer</code> value
1044   */
1045  public void graph(StringBuffer text) {
1046    assignIDs(-1);
1047    graphTree(text);
1048  }
1049
1050  /**
1051   * Return a dotty style string describing the tree
1052   *
1053   * @param text a <code>StringBuffer</code> value
1054   */
1055  protected void graphTree(StringBuffer text) {
1056    text.append("N" + m_id
1057                + (m_isLeaf
1058                   ? " [label=\"LM " + m_leafModelNum
1059                   : " [label=\"" + m_instances.attribute(m_splitAtt).name())
1060                + (m_isLeaf
1061                 ? " (" + ((m_globalDeviation > 0.0) 
1062                          ?  m_numInstances + "/" 
1063                             + Utils.doubleToString((100.0 * 
1064                                                     m_rootMeanSquaredError /
1065                                                     m_globalDeviation), 
1066                                                    1, 3) 
1067                             + "%)"
1068                           : m_numInstances + ")")
1069                    + "\" shape=box style=filled "
1070                   : "\"")
1071                + (m_saveInstances
1072                   ? "data=\n" + m_instances + "\n,\n"
1073                   : "")
1074                + "]\n");
1075               
1076    if (m_left != null) {
1077      text.append("N" + m_id + "->" + "N" + m_left.m_id + " [label=\"<="
1078                  + Utils.doubleToString(m_splitValue, 1, 3)
1079                  + "\"]\n");
1080      m_left.graphTree(text);
1081    }
1082     
1083    if (m_right != null) {
1084      text.append("N" + m_id + "->" + "N" + m_right.m_id + " [label=\">"
1085                  + Utils.doubleToString(m_splitValue, 1, 3)
1086                  + "\"]\n");
1087      m_right.graphTree(text);
1088    }
1089  }
1090
1091  /**
1092   * Set whether to save instances for visualization purposes.
1093   * Default is to save memory.
1094   *
1095   * @param save a <code>boolean</code> value
1096   */
1097  protected void setSaveInstances(boolean save) {
1098    m_saveInstances = save;
1099  }
1100 
1101  /**
1102   * Returns the revision string.
1103   *
1104   * @return            the revision
1105   */
1106  public String getRevision() {
1107    return RevisionUtils.extract("$Revision: 5928 $");
1108  }
1109}
Note: See TracBrowser for help on using the repository browser.