source: src/main/java/weka/classifiers/trees/REPTree.java @ 4

Last change on this file since 4 was 4, checked in by gnappo, 14 years ago

Import di weka.

File size: 55.2 KB
Line 
1/*
2 *    This program is free software; you can redistribute it and/or modify
3 *    it under the terms of the GNU General Public License as published by
4 *    the Free Software Foundation; either version 2 of the License, or
5 *    (at your option) any later version.
6 *
7 *    This program is distributed in the hope that it will be useful,
8 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
9 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
10 *    GNU General Public License for more details.
11 *
12 *    You should have received a copy of the GNU General Public License
13 *    along with this program; if not, write to the Free Software
14 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
15 */
16
17/*
18 *    REPTree.java
19 *    Copyright (C) 1999 University of Waikato, Hamilton, New Zealand
20 *
21 */
22
23package weka.classifiers.trees;
24
25import weka.classifiers.Classifier;
26import weka.classifiers.AbstractClassifier;
27import weka.classifiers.Sourcable;
28import weka.classifiers.rules.ZeroR;
29import weka.core.AdditionalMeasureProducer;
30import weka.core.Attribute;
31import weka.core.Capabilities;
32import weka.core.ContingencyTables;
33import weka.core.Drawable;
34import weka.core.Instance;
35import weka.core.Instances;
36import weka.core.Option;
37import weka.core.OptionHandler;
38import weka.core.RevisionHandler;
39import weka.core.RevisionUtils;
40import weka.core.Utils;
41import weka.core.WeightedInstancesHandler;
42import weka.core.Capabilities.Capability;
43
44import java.io.Serializable;
45import java.util.Enumeration;
46import java.util.Random;
47import java.util.Vector;
48
49/**
50 <!-- globalinfo-start -->
51 * Fast decision tree learner. Builds a decision/regression tree using information gain/variance and prunes it using reduced-error pruning (with backfitting).  Only sorts values for numeric attributes once. Missing values are dealt with by splitting the corresponding instances into pieces (i.e. as in C4.5).
52 * <p/>
53 <!-- globalinfo-end -->
54 *
55 <!-- options-start -->
56 * Valid options are: <p/>
57 *
58 * <pre> -M &lt;minimum number of instances&gt;
59 *  Set minimum number of instances per leaf (default 2).</pre>
60 *
61 * <pre> -V &lt;minimum variance for split&gt;
62 *  Set minimum numeric class variance proportion
63 *  of train variance for split (default 1e-3).</pre>
64 *
65 * <pre> -N &lt;number of folds&gt;
66 *  Number of folds for reduced error pruning (default 3).</pre>
67 *
68 * <pre> -S &lt;seed&gt;
69 *  Seed for random data shuffling (default 1).</pre>
70 *
71 * <pre> -P
72 *  No pruning.</pre>
73 *
74 * <pre> -L
75 *  Maximum tree depth (default -1, no maximum)</pre>
76 *
77 <!-- options-end -->
78 *
79 * @author Eibe Frank (eibe@cs.waikato.ac.nz)
80 * @version $Revision: 5928 $
81 */
82public class REPTree 
83  extends AbstractClassifier
84  implements OptionHandler, WeightedInstancesHandler, Drawable, 
85             AdditionalMeasureProducer, Sourcable {
86
87  /** for serialization */
88  static final long serialVersionUID = -8562443428621539458L;
89 
90  /** ZeroR model that is used if no attributes are present. */
91  protected ZeroR m_zeroR;
92
93  /**
94   * Returns a string describing classifier
95   * @return a description suitable for
96   * displaying in the explorer/experimenter gui
97   */
98  public String globalInfo() {
99
100    return  "Fast decision tree learner. Builds a decision/regression tree using "
101      + "information gain/variance and prunes it using reduced-error pruning "
102      + "(with backfitting).  Only sorts values for numeric attributes "
103      + "once. Missing values are dealt with by splitting the corresponding "
104      + "instances into pieces (i.e. as in C4.5).";
105  }
106
107  /** An inner class for building and storing the tree structure */
108  protected class Tree 
109    implements Serializable, RevisionHandler {
110   
111    /** for serialization */
112    static final long serialVersionUID = -1635481717888437935L;
113   
114    /** The header information (for printing the tree). */
115    protected Instances m_Info = null;
116
117    /** The subtrees of this tree. */ 
118    protected Tree[] m_Successors;
119   
120    /** The attribute to split on. */
121    protected int m_Attribute = -1;
122
123    /** The split point. */
124    protected double m_SplitPoint = Double.NaN;
125   
126    /** The proportions of training instances going down each branch. */
127    protected double[] m_Prop = null;
128
129    /** Class probabilities from the training data in the nominal case.
130        Holds the mean in the numeric case. */
131    protected double[] m_ClassProbs = null;
132   
133    /** The (unnormalized) class distribution in the nominal
134        case. Holds the sum of squared errors and the weight
135        in the numeric case. */
136    protected double[] m_Distribution = null;
137   
138    /** Class distribution of hold-out set at node in the nominal case.
139        Straight sum of weights in the numeric case (i.e. array has
140        only one element. */
141    protected double[] m_HoldOutDist = null;
142   
143    /** The hold-out error of the node. The number of miss-classified
144        instances in the nominal case, the sum of squared errors in the
145        numeric case. */
146    protected double m_HoldOutError = 0;
147 
148    /**
149     * Computes class distribution of an instance using the tree.
150     *
151     * @param instance the instance to compute the distribution for
152     * @return the distribution
153     * @throws Exception if computation fails
154     */
155    protected double[] distributionForInstance(Instance instance) 
156      throws Exception {
157
158      double[] returnedDist = null;
159     
160      if (m_Attribute > -1) {
161       
162        // Node is not a leaf
163        if (instance.isMissing(m_Attribute)) {
164         
165          // Value is missing
166          returnedDist = new double[m_Info.numClasses()];
167
168          // Split instance up
169          for (int i = 0; i < m_Successors.length; i++) {
170            double[] help = 
171              m_Successors[i].distributionForInstance(instance);
172            if (help != null) {
173              for (int j = 0; j < help.length; j++) {
174                returnedDist[j] += m_Prop[i] * help[j];
175              }
176            }
177          }
178        } else if (m_Info.attribute(m_Attribute).isNominal()) {
179         
180          // For nominal attributes
181          returnedDist =  m_Successors[(int)instance.value(m_Attribute)].
182            distributionForInstance(instance);
183        } else {
184         
185          // For numeric attributes
186          if (instance.value(m_Attribute) < m_SplitPoint) {
187            returnedDist = 
188              m_Successors[0].distributionForInstance(instance);
189          } else {
190            returnedDist = 
191              m_Successors[1].distributionForInstance(instance);
192          }
193        }
194      }
195      if ((m_Attribute == -1) || (returnedDist == null)) {
196       
197        // Node is a leaf or successor is empty
198        return m_ClassProbs;
199      } else {
200        return returnedDist;
201      }
202    }
203
204   /**
205    * Returns a string containing java source code equivalent to the test
206    * made at this node. The instance being tested is called "i". This
207    * routine assumes to be called in the order of branching, enabling us to
208    * set the >= condition test (the last one) of a numeric splitpoint
209    * to just "true" (because being there in the flow implies that the
210    * previous less-than test failed).
211    *
212    * @param index index of the value tested
213    * @return a value of type 'String'
214    */
215    public final String sourceExpression(int index) {
216     
217      StringBuffer expr = null;
218      if (index < 0) {
219        return "i[" + m_Attribute + "] == null";
220      }
221      if (m_Info.attribute(m_Attribute).isNominal()) {
222        expr = new StringBuffer("i[");
223        expr.append(m_Attribute).append("]");
224        expr.append(".equals(\"").append(m_Info.attribute(m_Attribute)
225                .value(index)).append("\")");
226      } else {
227        expr = new StringBuffer("");
228        if (index == 0) {
229          expr.append("((Double)i[")
230            .append(m_Attribute).append("]).doubleValue() < ")
231            .append(m_SplitPoint);
232        } else {
233          expr.append("true");
234        }
235      }
236      return expr.toString();
237    }
238
239   /**
240    * Returns source code for the tree as if-then statements. The
241    * class is assigned to variable "p", and assumes the tested
242    * instance is named "i". The results are returned as two stringbuffers:
243    * a section of code for assignment of the class, and a section of
244    * code containing support code (eg: other support methods).
245    * <p/>
246    * TODO: If the outputted source code encounters a missing value
247    * for the evaluated attribute, it stops branching and uses the
248    * class distribution of the current node to decide the return value.
249    * This is unlike the behaviour of distributionForInstance().
250    *
251    * @param className the classname that this static classifier has
252    * @param parent parent node of the current node
253    * @return an array containing two stringbuffers, the first string containing
254    * assignment code, and the second containing source for support code.
255    * @throws Exception if something goes wrong
256    */
257    public StringBuffer [] toSource(String className, Tree parent) 
258      throws Exception {
259   
260      StringBuffer [] result = new StringBuffer[2];
261      double[] currentProbs;
262
263      if(m_ClassProbs == null)
264        currentProbs = parent.m_ClassProbs;
265      else
266        currentProbs = m_ClassProbs;
267
268      long printID = nextID();
269
270      // Is this a leaf?
271      if (m_Attribute == -1) {
272        result[0] = new StringBuffer("  p = ");
273        if(m_Info.classAttribute().isNumeric())
274          result[0].append(currentProbs[0]);
275        else {
276          result[0].append(Utils.maxIndex(currentProbs));
277        }
278        result[0].append(";\n");
279        result[1] = new StringBuffer("");
280      } else {
281        StringBuffer text = new StringBuffer("");
282        StringBuffer atEnd = new StringBuffer("");
283
284        text.append("  static double N")
285          .append(Integer.toHexString(this.hashCode()) + printID)
286          .append("(Object []i) {\n")
287          .append("    double p = Double.NaN;\n");
288
289        text.append("    /* " + m_Info.attribute(m_Attribute).name() + " */\n");
290        // Missing attribute?
291        text.append("    if (" + this.sourceExpression(-1) + ") {\n")
292          .append("      p = ");
293        if(m_Info.classAttribute().isNumeric())
294          text.append(currentProbs[0] + ";\n");
295        else
296          text.append(Utils.maxIndex(currentProbs) + ";\n");
297        text.append("    } ");
298       
299        // Branching of the tree
300        for (int i=0;i<m_Successors.length; i++) {
301          text.append("else if (" + this.sourceExpression(i) + ") {\n");
302          // Is the successor a leaf?
303          if(m_Successors[i].m_Attribute == -1) {
304            double[] successorProbs = m_Successors[i].m_ClassProbs;
305            if(successorProbs == null)
306              successorProbs = m_ClassProbs;
307            text.append("      p = ");
308            if(m_Info.classAttribute().isNumeric()) {
309              text.append(successorProbs[0] + ";\n");
310            } else {
311              text.append(Utils.maxIndex(successorProbs) + ";\n");
312            }
313          } else {
314            StringBuffer [] sub = m_Successors[i].toSource(className, this);
315            text.append("" + sub[0]);
316            atEnd.append("" + sub[1]);
317          }
318          text.append("    } ");
319          if (i == m_Successors.length - 1) {
320            text.append("\n");
321          }
322        }
323
324        text.append("    return p;\n  }\n");
325
326        result[0] = new StringBuffer("    p = " + className + ".N");
327        result[0].append(Integer.toHexString(this.hashCode()) + printID)
328          .append("(i);\n");
329        result[1] = text.append("" + atEnd);
330      }
331      return result;
332    }
333
334       
335    /**
336     * Outputs one node for graph.
337     *
338     * @param text the buffer to append the output to
339     * @param num the current node id
340     * @param parent the parent of the nodes
341     * @return the next node id
342     * @throws Exception if something goes wrong
343     */
344    protected int toGraph(StringBuffer text, int num,
345                        Tree parent) throws Exception {
346     
347      num++;
348      if (m_Attribute == -1) {
349        text.append("N" + Integer.toHexString(Tree.this.hashCode()) +
350                    " [label=\"" + num + leafString(parent) +"\"" +
351                    "shape=box]\n");
352      } else {
353        text.append("N" + Integer.toHexString(Tree.this.hashCode()) +
354                    " [label=\"" + num + ": " + 
355                    m_Info.attribute(m_Attribute).name() + 
356                    "\"]\n");
357        for (int i = 0; i < m_Successors.length; i++) {
358          text.append("N" + Integer.toHexString(Tree.this.hashCode()) 
359                      + "->" + 
360                      "N" + 
361                      Integer.toHexString(m_Successors[i].hashCode())  +
362                      " [label=\"");
363          if (m_Info.attribute(m_Attribute).isNumeric()) {
364            if (i == 0) {
365              text.append(" < " +
366                          Utils.doubleToString(m_SplitPoint, 2));
367            } else {
368              text.append(" >= " +
369                          Utils.doubleToString(m_SplitPoint, 2));
370            }
371          } else {
372            text.append(" = " + m_Info.attribute(m_Attribute).value(i));
373          }
374          text.append("\"]\n");
375          num = m_Successors[i].toGraph(text, num, this);
376        }
377      }
378     
379      return num;
380    }
381
382    /**
383     * Outputs description of a leaf node.
384     *
385     * @param parent the parent of the node
386     * @return the description of the node
387     * @throws Exception if generation fails
388     */
389    protected String leafString(Tree parent) throws Exception {
390   
391      if (m_Info.classAttribute().isNumeric()) {
392        double classMean;
393        if (m_ClassProbs == null) {
394          classMean = parent.m_ClassProbs[0];
395        } else {
396          classMean = m_ClassProbs[0];
397        }
398        StringBuffer buffer = new StringBuffer();
399        buffer.append(" : " + Utils.doubleToString(classMean, 2));
400        double avgError = 0;
401        if (m_Distribution[1] > 0) {
402          avgError = m_Distribution[0] / m_Distribution[1];
403        }
404        buffer.append(" (" +
405                      Utils.doubleToString(m_Distribution[1], 2) + "/" +
406                      Utils.doubleToString(avgError, 2) 
407                      + ")");
408        avgError = 0;
409        if (m_HoldOutDist[0] > 0) {
410          avgError = m_HoldOutError / m_HoldOutDist[0];
411        }
412        buffer.append(" [" +
413                      Utils.doubleToString(m_HoldOutDist[0], 2) + "/" +
414                      Utils.doubleToString(avgError, 2) 
415                      + "]");
416        return buffer.toString();
417      } else { 
418        int maxIndex;
419        if (m_ClassProbs == null) {
420          maxIndex = Utils.maxIndex(parent.m_ClassProbs);
421        } else {
422          maxIndex = Utils.maxIndex(m_ClassProbs);
423        }
424        return " : " + m_Info.classAttribute().value(maxIndex) + 
425          " (" + Utils.doubleToString(Utils.sum(m_Distribution), 2) + 
426          "/" + 
427          Utils.doubleToString((Utils.sum(m_Distribution) - 
428                                m_Distribution[maxIndex]), 2) + ")" +
429          " [" + Utils.doubleToString(Utils.sum(m_HoldOutDist), 2) + "/" + 
430          Utils.doubleToString((Utils.sum(m_HoldOutDist) - 
431                                m_HoldOutDist[maxIndex]), 2) + "]";
432      }
433    }
434 
435    /**
436     * Recursively outputs the tree.
437     *
438     * @param level the current level
439     * @param parent the current parent
440     * @return the generated substree
441     */
442    protected String toString(int level, Tree parent) {
443
444      try {
445        StringBuffer text = new StringBuffer();
446     
447        if (m_Attribute == -1) {
448       
449          // Output leaf info
450          return leafString(parent);
451        } else if (m_Info.attribute(m_Attribute).isNominal()) {
452       
453          // For nominal attributes
454          for (int i = 0; i < m_Successors.length; i++) {
455            text.append("\n");
456            for (int j = 0; j < level; j++) {
457              text.append("|   ");
458            }
459            text.append(m_Info.attribute(m_Attribute).name() + " = " +
460                        m_Info.attribute(m_Attribute).value(i));
461            text.append(m_Successors[i].toString(level + 1, this));
462          }
463        } else {
464       
465          // For numeric attributes
466          text.append("\n");
467          for (int j = 0; j < level; j++) {
468            text.append("|   ");
469          }
470          text.append(m_Info.attribute(m_Attribute).name() + " < " +
471                      Utils.doubleToString(m_SplitPoint, 2));
472          text.append(m_Successors[0].toString(level + 1, this));
473          text.append("\n");
474          for (int j = 0; j < level; j++) {
475            text.append("|   ");
476          }
477          text.append(m_Info.attribute(m_Attribute).name() + " >= " +
478                      Utils.doubleToString(m_SplitPoint, 2));
479          text.append(m_Successors[1].toString(level + 1, this));
480        }
481     
482        return text.toString();
483      } catch (Exception e) {
484        e.printStackTrace();
485        return "Decision tree: tree can't be printed";
486      }
487    }     
488
489    /**
490     * Recursively generates a tree.
491     *
492     * @param sortedIndices the sorted indices of the instances
493     * @param weights the weights of the instances
494     * @param data the data to work with
495     * @param totalWeight
496     * @param classProbs the class probabilities
497     * @param header the header of the data
498     * @param minNum the minimum number of instances in a leaf
499     * @param minVariance
500     * @param depth the current depth of the tree
501     * @param maxDepth the maximum allowed depth of the tree
502     * @throws Exception if generation fails
503     */
504    protected void buildTree(int[][] sortedIndices, double[][] weights,
505                             Instances data, double totalWeight, 
506                             double[] classProbs, Instances header,
507                             double minNum, double minVariance,
508                             int depth, int maxDepth) 
509      throws Exception {
510     
511      // Store structure of dataset, set minimum number of instances
512      // and make space for potential info from pruning data
513      m_Info = header;
514      m_HoldOutDist = new double[data.numClasses()];
515       
516      // Make leaf if there are no training instances
517      int helpIndex = 0;
518      if (data.classIndex() == 0) {
519        helpIndex = 1;
520      }
521      if (sortedIndices[helpIndex].length == 0) {
522        if (data.classAttribute().isNumeric()) {
523          m_Distribution = new double[2];
524        } else {
525          m_Distribution = new double[data.numClasses()];
526        }
527        m_ClassProbs = null;
528        return;
529      }
530     
531      double priorVar = 0;
532      if (data.classAttribute().isNumeric()) {
533
534        // Compute prior variance
535        double totalSum = 0, totalSumSquared = 0, totalSumOfWeights = 0; 
536        for (int i = 0; i < sortedIndices[helpIndex].length; i++) {
537          Instance inst = data.instance(sortedIndices[helpIndex][i]);
538          totalSum += inst.classValue() * weights[helpIndex][i];
539          totalSumSquared += 
540            inst.classValue() * inst.classValue() * weights[helpIndex][i];
541          totalSumOfWeights += weights[helpIndex][i];
542        }
543        priorVar = singleVariance(totalSum, totalSumSquared, 
544                                  totalSumOfWeights);
545      }
546
547      // Check if node doesn't contain enough instances, is pure
548      // or the maximum tree depth is reached
549      m_ClassProbs = new double[classProbs.length];
550      System.arraycopy(classProbs, 0, m_ClassProbs, 0, classProbs.length);
551      if ((totalWeight < (2 * minNum)) ||
552
553          // Nominal case
554          (data.classAttribute().isNominal() &&
555           Utils.eq(m_ClassProbs[Utils.maxIndex(m_ClassProbs)],
556                    Utils.sum(m_ClassProbs))) ||
557
558          // Numeric case
559          (data.classAttribute().isNumeric() && 
560           ((priorVar / totalWeight) < minVariance)) ||
561
562          // Check tree depth
563          ((m_MaxDepth >= 0) && (depth >= maxDepth))) {
564
565        // Make leaf
566        m_Attribute = -1;
567        if (data.classAttribute().isNominal()) {
568
569          // Nominal case
570          m_Distribution = new double[m_ClassProbs.length];
571          for (int i = 0; i < m_ClassProbs.length; i++) {
572            m_Distribution[i] = m_ClassProbs[i];
573          }
574          Utils.normalize(m_ClassProbs);
575        } else {
576
577          // Numeric case
578          m_Distribution = new double[2];
579          m_Distribution[0] = priorVar;
580          m_Distribution[1] = totalWeight;
581        }
582        return;
583      }
584
585      // Compute class distributions and value of splitting
586      // criterion for each attribute
587      double[] vals = new double[data.numAttributes()];
588      double[][][] dists = new double[data.numAttributes()][0][0];
589      double[][] props = new double[data.numAttributes()][0];
590      double[][] totalSubsetWeights = new double[data.numAttributes()][0];
591      double[] splits = new double[data.numAttributes()];
592      if (data.classAttribute().isNominal()) { 
593
594        // Nominal case
595        for (int i = 0; i < data.numAttributes(); i++) {
596          if (i != data.classIndex()) {
597            splits[i] = distribution(props, dists, i, sortedIndices[i], 
598                                     weights[i], totalSubsetWeights, data);
599            vals[i] = gain(dists[i], priorVal(dists[i]));
600          }
601        }
602      } else {
603
604        // Numeric case
605        for (int i = 0; i < data.numAttributes(); i++) {
606          if (i != data.classIndex()) {
607            splits[i] = 
608              numericDistribution(props, dists, i, sortedIndices[i], 
609                                  weights[i], totalSubsetWeights, data, 
610                                  vals);
611          }
612        }
613      }
614
615      // Find best attribute
616      m_Attribute = Utils.maxIndex(vals);
617      int numAttVals = dists[m_Attribute].length;
618
619      // Check if there are at least two subsets with
620      // required minimum number of instances
621      int count = 0;
622      for (int i = 0; i < numAttVals; i++) {
623        if (totalSubsetWeights[m_Attribute][i] >= minNum) {
624          count++;
625        }
626        if (count > 1) {
627          break;
628        }
629      }
630
631      // Any useful split found?
632      if ((vals[m_Attribute] > 0) && (count > 1)) {
633
634        // Build subtrees
635        m_SplitPoint = splits[m_Attribute];
636        m_Prop = props[m_Attribute];
637        int[][][] subsetIndices = 
638          new int[numAttVals][data.numAttributes()][0];
639        double[][][] subsetWeights = 
640          new double[numAttVals][data.numAttributes()][0];
641        splitData(subsetIndices, subsetWeights, m_Attribute, m_SplitPoint, 
642                  sortedIndices, weights, data);
643        m_Successors = new Tree[numAttVals];
644        for (int i = 0; i < numAttVals; i++) {
645          m_Successors[i] = new Tree();
646          m_Successors[i].
647            buildTree(subsetIndices[i], subsetWeights[i], 
648                      data, totalSubsetWeights[m_Attribute][i],
649                      dists[m_Attribute][i], header, minNum, 
650                      minVariance, depth + 1, maxDepth);
651        }
652      } else {
653     
654        // Make leaf
655        m_Attribute = -1;
656      }
657
658      // Normalize class counts
659      if (data.classAttribute().isNominal()) {
660        m_Distribution = new double[m_ClassProbs.length];
661        for (int i = 0; i < m_ClassProbs.length; i++) {
662            m_Distribution[i] = m_ClassProbs[i];
663        }
664        Utils.normalize(m_ClassProbs);
665      } else {
666        m_Distribution = new double[2];
667        m_Distribution[0] = priorVar;
668        m_Distribution[1] = totalWeight;
669      }
670    }
671
672    /**
673     * Computes size of the tree.
674     *
675     * @return the number of nodes
676     */
677    protected int numNodes() {
678   
679      if (m_Attribute == -1) {
680        return 1;
681      } else {
682        int size = 1;
683        for (int i = 0; i < m_Successors.length; i++) {
684          size += m_Successors[i].numNodes();
685        }
686        return size;
687      }
688    }
689
690    /**
691     * Splits instances into subsets.
692     *
693     * @param subsetIndices the sorted indices in the subset
694     * @param subsetWeights the weights of the subset
695     * @param att the attribute index
696     * @param splitPoint the split point for numeric attributes
697     * @param sortedIndices the sorted indices of the whole set
698     * @param weights the weights of the whole set
699     * @param data the data to work with
700     * @throws Exception if something goes wrong
701     */
702    protected void splitData(int[][][] subsetIndices, 
703                             double[][][] subsetWeights,
704                             int att, double splitPoint, 
705                             int[][] sortedIndices, double[][] weights, 
706                             Instances data) throws Exception {
707   
708      int j;
709      int[] num;
710   
711      // For each attribute
712      for (int i = 0; i < data.numAttributes(); i++) {
713        if (i != data.classIndex()) {
714          if (data.attribute(att).isNominal()) {
715
716            // For nominal attributes
717            num = new int[data.attribute(att).numValues()];
718            for (int k = 0; k < num.length; k++) {
719              subsetIndices[k][i] = new int[sortedIndices[i].length];
720              subsetWeights[k][i] = new double[sortedIndices[i].length];
721            }
722            for (j = 0; j < sortedIndices[i].length; j++) {
723              Instance inst = data.instance(sortedIndices[i][j]);
724              if (inst.isMissing(att)) {
725
726                // Split instance up
727                for (int k = 0; k < num.length; k++) {
728                  if (m_Prop[k] > 0) {
729                    subsetIndices[k][i][num[k]] = sortedIndices[i][j];
730                    subsetWeights[k][i][num[k]] = 
731                      m_Prop[k] * weights[i][j];
732                    num[k]++;
733                  }
734                }
735              } else {
736                int subset = (int)inst.value(att);
737                subsetIndices[subset][i][num[subset]] = 
738                  sortedIndices[i][j];
739                subsetWeights[subset][i][num[subset]] = weights[i][j];
740                num[subset]++;
741              }
742            }
743          } else {
744
745            // For numeric attributes
746            num = new int[2];
747            for (int k = 0; k < 2; k++) {
748              subsetIndices[k][i] = new int[sortedIndices[i].length];
749              subsetWeights[k][i] = new double[weights[i].length];
750            }
751            for (j = 0; j < sortedIndices[i].length; j++) {
752              Instance inst = data.instance(sortedIndices[i][j]);
753              if (inst.isMissing(att)) {
754
755                // Split instance up
756                for (int k = 0; k < num.length; k++) {
757                  if (m_Prop[k] > 0) {
758                    subsetIndices[k][i][num[k]] = sortedIndices[i][j];
759                    subsetWeights[k][i][num[k]] = 
760                      m_Prop[k] * weights[i][j];
761                    num[k]++;
762                  }
763                }
764              } else {
765                int subset = (inst.value(att) < splitPoint) ? 0 : 1;
766                subsetIndices[subset][i][num[subset]] = 
767                  sortedIndices[i][j];
768                subsetWeights[subset][i][num[subset]] = weights[i][j];
769                num[subset]++;
770              } 
771            }
772          }
773       
774          // Trim arrays
775          for (int k = 0; k < num.length; k++) {
776            int[] copy = new int[num[k]];
777            System.arraycopy(subsetIndices[k][i], 0, copy, 0, num[k]);
778            subsetIndices[k][i] = copy;
779            double[] copyWeights = new double[num[k]];
780            System.arraycopy(subsetWeights[k][i], 0,
781                             copyWeights, 0, num[k]);
782            subsetWeights[k][i] = copyWeights;
783          }
784        }
785      }
786    }
787
788    /**
789     * Computes class distribution for an attribute.
790     *
791     * @param props
792     * @param dists
793     * @param att the attribute index
794     * @param sortedIndices the sorted indices of the instances
795     * @param weights the weights of the instances
796     * @param subsetWeights the weights of the subset
797     * @param data the data to work with
798     * @return the split point
799     * @throws Exception if computation fails
800     */
801    protected double distribution(double[][] props,
802                                  double[][][] dists, int att, 
803                                  int[] sortedIndices,
804                                  double[] weights, 
805                                  double[][] subsetWeights, 
806                                  Instances data) 
807      throws Exception {
808
809      double splitPoint = Double.NaN;
810      Attribute attribute = data.attribute(att);
811      double[][] dist = null;
812      int i;
813
814      if (attribute.isNominal()) {
815
816        // For nominal attributes
817        dist = new double[attribute.numValues()][data.numClasses()];
818        for (i = 0; i < sortedIndices.length; i++) {
819          Instance inst = data.instance(sortedIndices[i]);
820          if (inst.isMissing(att)) {
821            break;
822          }
823          dist[(int)inst.value(att)][(int)inst.classValue()] += weights[i];
824        }
825      } else {
826
827        // For numeric attributes
828        double[][] currDist = new double[2][data.numClasses()];
829        dist = new double[2][data.numClasses()];
830
831        // Move all instances into second subset
832        for (int j = 0; j < sortedIndices.length; j++) {
833          Instance inst = data.instance(sortedIndices[j]);
834          if (inst.isMissing(att)) {
835            break;
836          }
837          currDist[1][(int)inst.classValue()] += weights[j];
838        }
839        double priorVal = priorVal(currDist);
840        System.arraycopy(currDist[1], 0, dist[1], 0, dist[1].length);
841
842        // Try all possible split points
843        double currSplit = data.instance(sortedIndices[0]).value(att);
844        double currVal, bestVal = -Double.MAX_VALUE;
845        for (i = 0; i < sortedIndices.length; i++) {
846          Instance inst = data.instance(sortedIndices[i]);
847          if (inst.isMissing(att)) {
848            break;
849          }
850          if (inst.value(att) > currSplit) {
851            currVal = gain(currDist, priorVal);
852            if (currVal > bestVal) {
853              bestVal = currVal;
854              splitPoint = (inst.value(att) + currSplit) / 2.0;
855              for (int j = 0; j < currDist.length; j++) {
856                System.arraycopy(currDist[j], 0, dist[j], 0, 
857                                 dist[j].length);
858              }
859            } 
860          } 
861          currSplit = inst.value(att);
862          currDist[0][(int)inst.classValue()] += weights[i];
863          currDist[1][(int)inst.classValue()] -= weights[i];
864        }
865      }
866
867      // Compute weights
868      props[att] = new double[dist.length];
869      for (int k = 0; k < props[att].length; k++) {
870        props[att][k] = Utils.sum(dist[k]);
871      }
872      if (!(Utils.sum(props[att]) > 0)) {
873        for (int k = 0; k < props[att].length; k++) {
874          props[att][k] = 1.0 / (double)props[att].length;
875        }
876      } else {
877        Utils.normalize(props[att]);
878      }
879   
880      // Distribute counts
881      while (i < sortedIndices.length) {
882        Instance inst = data.instance(sortedIndices[i]);
883        for (int j = 0; j < dist.length; j++) {
884          dist[j][(int)inst.classValue()] += props[att][j] * weights[i];
885        }
886        i++;
887      }
888
889      // Compute subset weights
890      subsetWeights[att] = new double[dist.length];
891      for (int j = 0; j < dist.length; j++) {
892        subsetWeights[att][j] += Utils.sum(dist[j]);
893      }
894
895      // Return distribution and split point
896      dists[att] = dist;
897      return splitPoint;
898    }     
899
900    /**
901     * Computes class distribution for an attribute.
902     *
903     * @param props
904     * @param dists
905     * @param att the attribute index
906     * @param sortedIndices the sorted indices of the instances
907     * @param weights the weights of the instances
908     * @param subsetWeights the weights of the subset
909     * @param data the data to work with
910     * @param vals
911     * @return the split point
912     * @throws Exception if computation fails
913     */
914    protected double numericDistribution(double[][] props, 
915                                         double[][][] dists, int att, 
916                                         int[] sortedIndices,
917                                         double[] weights, 
918                                         double[][] subsetWeights, 
919                                         Instances data,
920                                         double[] vals) 
921      throws Exception {
922
923      double splitPoint = Double.NaN;
924      Attribute attribute = data.attribute(att);
925      double[][] dist = null;
926      double[] sums = null;
927      double[] sumSquared = null;
928      double[] sumOfWeights = null;
929      double totalSum = 0, totalSumSquared = 0, totalSumOfWeights = 0;
930
931      int i;
932
933      if (attribute.isNominal()) {
934
935        // For nominal attributes
936        sums = new double[attribute.numValues()];
937        sumSquared = new double[attribute.numValues()];
938        sumOfWeights = new double[attribute.numValues()];
939        int attVal;
940        for (i = 0; i < sortedIndices.length; i++) {
941          Instance inst = data.instance(sortedIndices[i]);
942          if (inst.isMissing(att)) {
943            break;
944          }
945          attVal = (int)inst.value(att);
946          sums[attVal] += inst.classValue() * weights[i];
947          sumSquared[attVal] += 
948            inst.classValue() * inst.classValue() * weights[i];
949          sumOfWeights[attVal] += weights[i];
950        }
951        totalSum = Utils.sum(sums);
952        totalSumSquared = Utils.sum(sumSquared);
953        totalSumOfWeights = Utils.sum(sumOfWeights);
954      } else {
955
956        // For numeric attributes
957        sums = new double[2];
958        sumSquared = new double[2];
959        sumOfWeights = new double[2];
960        double[] currSums = new double[2];
961        double[] currSumSquared = new double[2];
962        double[] currSumOfWeights = new double[2];
963
964        // Move all instances into second subset
965        for (int j = 0; j < sortedIndices.length; j++) {
966          Instance inst = data.instance(sortedIndices[j]);
967          if (inst.isMissing(att)) {
968            break;
969          }
970          currSums[1] += inst.classValue() * weights[j];
971          currSumSquared[1] += 
972            inst.classValue() * inst.classValue() * weights[j];
973          currSumOfWeights[1] += weights[j];
974         
975        }
976        totalSum = currSums[1];
977        totalSumSquared = currSumSquared[1];
978        totalSumOfWeights = currSumOfWeights[1];
979       
980        sums[1] = currSums[1];
981        sumSquared[1] = currSumSquared[1];
982        sumOfWeights[1] = currSumOfWeights[1];
983
984        // Try all possible split points
985        double currSplit = data.instance(sortedIndices[0]).value(att);
986        double currVal, bestVal = Double.MAX_VALUE;
987        for (i = 0; i < sortedIndices.length; i++) {
988          Instance inst = data.instance(sortedIndices[i]);
989          if (inst.isMissing(att)) {
990            break;
991          }
992          if (inst.value(att) > currSplit) {
993            currVal = variance(currSums, currSumSquared, currSumOfWeights);
994            if (currVal < bestVal) {
995              bestVal = currVal;
996              splitPoint = (inst.value(att) + currSplit) / 2.0;
997              for (int j = 0; j < 2; j++) {
998                sums[j] = currSums[j];
999                sumSquared[j] = currSumSquared[j];
1000                sumOfWeights[j] = currSumOfWeights[j];
1001              }
1002            } 
1003          } 
1004
1005          currSplit = inst.value(att);
1006
1007          double classVal = inst.classValue() * weights[i];
1008          double classValSquared = inst.classValue() * classVal;
1009
1010          currSums[0] += classVal;
1011          currSumSquared[0] += classValSquared;
1012          currSumOfWeights[0] += weights[i];
1013
1014          currSums[1] -= classVal;
1015          currSumSquared[1] -= classValSquared;
1016          currSumOfWeights[1] -= weights[i];
1017        }
1018      }
1019
1020      // Compute weights
1021      props[att] = new double[sums.length];
1022      for (int k = 0; k < props[att].length; k++) {
1023        props[att][k] = sumOfWeights[k];
1024      }
1025      if (!(Utils.sum(props[att]) > 0)) {
1026        for (int k = 0; k < props[att].length; k++) {
1027          props[att][k] = 1.0 / (double)props[att].length;
1028        }
1029      } else {
1030        Utils.normalize(props[att]);
1031      }
1032   
1033       
1034      // Distribute counts for missing values
1035      while (i < sortedIndices.length) {
1036        Instance inst = data.instance(sortedIndices[i]);
1037        for (int j = 0; j < sums.length; j++) {
1038          sums[j] += props[att][j] * inst.classValue() * weights[i];
1039          sumSquared[j] += props[att][j] * inst.classValue() * 
1040            inst.classValue() * weights[i];
1041          sumOfWeights[j] += props[att][j] * weights[i];
1042        }
1043        totalSum += inst.classValue() * weights[i];
1044        totalSumSquared += 
1045          inst.classValue() * inst.classValue() * weights[i]; 
1046        totalSumOfWeights += weights[i];
1047        i++;
1048      }
1049
1050      // Compute final distribution
1051      dist = new double[sums.length][data.numClasses()];
1052      for (int j = 0; j < sums.length; j++) {
1053        if (sumOfWeights[j] > 0) {
1054          dist[j][0] = sums[j] / sumOfWeights[j];
1055        } else {
1056          dist[j][0] = totalSum / totalSumOfWeights;
1057        }
1058      }
1059     
1060      // Compute variance gain
1061      double priorVar =
1062        singleVariance(totalSum, totalSumSquared, totalSumOfWeights);
1063      double var = variance(sums, sumSquared, sumOfWeights);
1064      double gain = priorVar - var;
1065     
1066      // Return distribution and split point
1067      subsetWeights[att] = sumOfWeights;
1068      dists[att] = dist;
1069      vals[att] = gain;
1070      return splitPoint;
1071    }     
1072
1073    /**
1074     * Computes variance for subsets.
1075     *
1076     * @param s
1077     * @param sS
1078     * @param sumOfWeights
1079     * @return the variance
1080     */
1081    protected double variance(double[] s, double[] sS, 
1082                            double[] sumOfWeights) {
1083     
1084      double var = 0;
1085     
1086      for (int i = 0; i < s.length; i++) {
1087        if (sumOfWeights[i] > 0) {
1088          var += singleVariance(s[i], sS[i], sumOfWeights[i]);
1089        }
1090      }
1091     
1092      return var;
1093    }
1094   
1095    /**
1096     * Computes the variance for a single set
1097     *
1098     * @param s
1099     * @param sS
1100     * @param weight the weight
1101     * @return the variance
1102     */
1103    protected double singleVariance(double s, double sS, double weight) {
1104     
1105      return sS - ((s * s) / weight);
1106    }
1107
1108    /**
1109     * Computes value of splitting criterion before split.
1110     *
1111     * @param dist
1112     * @return the splitting criterion
1113     */
1114    protected double priorVal(double[][] dist) {
1115
1116      return ContingencyTables.entropyOverColumns(dist);
1117    }
1118
1119    /**
1120     * Computes value of splitting criterion after split.
1121     *
1122     * @param dist
1123     * @param priorVal the splitting criterion
1124     * @return the gain after splitting
1125     */
1126    protected double gain(double[][] dist, double priorVal) {
1127
1128      return priorVal - ContingencyTables.entropyConditionedOnRows(dist);
1129    }
1130
1131    /**
1132     * Prunes the tree using the hold-out data (bottom-up).
1133     *
1134     * @return the error
1135     * @throws Exception if pruning fails for some reason
1136     */
1137    protected double reducedErrorPrune() throws Exception {
1138
1139      // Is node leaf ?
1140      if (m_Attribute == -1) {
1141        return m_HoldOutError;
1142      }
1143
1144      // Prune all sub trees
1145      double errorTree = 0;
1146      for (int i = 0; i < m_Successors.length; i++) {
1147        errorTree += m_Successors[i].reducedErrorPrune();
1148      }
1149
1150      // Replace sub tree with leaf if error doesn't get worse
1151      if (errorTree >= m_HoldOutError) {
1152        m_Attribute = -1;
1153        m_Successors = null;
1154        return m_HoldOutError;
1155      } else {
1156        return errorTree;
1157      }
1158    }
1159
1160    /**
1161     * Inserts hold-out set into tree.
1162     *
1163     * @param data the data to insert
1164     * @throws Exception if something goes wrong
1165     */
1166    protected void insertHoldOutSet(Instances data) throws Exception {
1167
1168      for (int i = 0; i < data.numInstances(); i++) {
1169        insertHoldOutInstance(data.instance(i), data.instance(i).weight(),
1170                              this);
1171      }
1172    }
1173
1174    /**
1175     * Inserts an instance from the hold-out set into the tree.
1176     *
1177     * @param inst the instance to insert
1178     * @param weight the weight of the instance
1179     * @param parent the parent of the node
1180     * @throws Exception if insertion fails
1181     */
1182    protected void insertHoldOutInstance(Instance inst, double weight, 
1183                                         Tree parent) throws Exception {
1184     
1185      // Insert instance into hold-out class distribution
1186      if (inst.classAttribute().isNominal()) {
1187       
1188        // Nominal case
1189        m_HoldOutDist[(int)inst.classValue()] += weight;
1190        int predictedClass = 0;
1191        if (m_ClassProbs == null) {
1192          predictedClass = Utils.maxIndex(parent.m_ClassProbs);
1193        } else {
1194          predictedClass = Utils.maxIndex(m_ClassProbs);
1195        }
1196        if (predictedClass != (int)inst.classValue()) {
1197          m_HoldOutError += weight;
1198        }
1199      } else {
1200       
1201        // Numeric case
1202        m_HoldOutDist[0] += weight;
1203        double diff = 0;
1204        if (m_ClassProbs == null) {
1205          diff = parent.m_ClassProbs[0] - inst.classValue();
1206        } else {
1207          diff =  m_ClassProbs[0] - inst.classValue();
1208        }
1209        m_HoldOutError += diff * diff * weight;
1210      } 
1211     
1212      // The process is recursive
1213      if (m_Attribute != -1) {
1214       
1215        // If node is not a leaf
1216        if (inst.isMissing(m_Attribute)) {
1217         
1218          // Distribute instance
1219          for (int i = 0; i < m_Successors.length; i++) {
1220            if (m_Prop[i] > 0) {
1221              m_Successors[i].insertHoldOutInstance(inst, weight * 
1222                                                    m_Prop[i], this);
1223            }
1224          }
1225        } else {
1226         
1227          if (m_Info.attribute(m_Attribute).isNominal()) {
1228           
1229            // Treat nominal attributes
1230            m_Successors[(int)inst.value(m_Attribute)].
1231              insertHoldOutInstance(inst, weight, this);
1232          } else {
1233           
1234            // Treat numeric attributes
1235            if (inst.value(m_Attribute) < m_SplitPoint) {
1236              m_Successors[0].insertHoldOutInstance(inst, weight, this);
1237            } else {
1238              m_Successors[1].insertHoldOutInstance(inst, weight, this);
1239            }
1240          }
1241        }
1242      }
1243    }
1244 
1245    /**
1246     * Inserts hold-out set into tree.
1247     *
1248     * @param data the data to insert
1249     * @throws Exception if insertion fails
1250     */
1251    protected void backfitHoldOutSet(Instances data) throws Exception {
1252     
1253      for (int i = 0; i < data.numInstances(); i++) {
1254        backfitHoldOutInstance(data.instance(i), data.instance(i).weight(),
1255                               this);
1256      }
1257    }
1258   
1259    /**
1260     * Inserts an instance from the hold-out set into the tree.
1261     *
1262     * @param inst the instance to insert
1263     * @param weight the weight of the instance
1264     * @param parent the parent node
1265     * @throws Exception if insertion fails
1266     */
1267    protected void backfitHoldOutInstance(Instance inst, double weight, 
1268                                          Tree parent) throws Exception {
1269     
1270      // Insert instance into hold-out class distribution
1271      if (inst.classAttribute().isNominal()) {
1272       
1273        // Nominal case
1274        if (m_ClassProbs == null) {
1275          m_ClassProbs = new double[inst.numClasses()];
1276        }
1277        System.arraycopy(m_Distribution, 0, m_ClassProbs, 0, inst.numClasses());
1278        m_ClassProbs[(int)inst.classValue()] += weight;
1279        Utils.normalize(m_ClassProbs);
1280      } else {
1281       
1282        // Numeric case
1283        if (m_ClassProbs == null) {
1284          m_ClassProbs = new double[1];
1285        }
1286        m_ClassProbs[0] *= m_Distribution[1];
1287        m_ClassProbs[0] += weight * inst.classValue();
1288        m_ClassProbs[0] /= (m_Distribution[1] + weight);
1289      } 
1290     
1291      // The process is recursive
1292      if (m_Attribute != -1) {
1293       
1294        // If node is not a leaf
1295        if (inst.isMissing(m_Attribute)) {
1296         
1297          // Distribute instance
1298          for (int i = 0; i < m_Successors.length; i++) {
1299            if (m_Prop[i] > 0) {
1300              m_Successors[i].backfitHoldOutInstance(inst, weight * 
1301                                                     m_Prop[i], this);
1302            }
1303          }
1304        } else {
1305         
1306          if (m_Info.attribute(m_Attribute).isNominal()) {
1307           
1308            // Treat nominal attributes
1309            m_Successors[(int)inst.value(m_Attribute)].
1310              backfitHoldOutInstance(inst, weight, this);
1311          } else {
1312           
1313            // Treat numeric attributes
1314            if (inst.value(m_Attribute) < m_SplitPoint) {
1315              m_Successors[0].backfitHoldOutInstance(inst, weight, this);
1316            } else {
1317              m_Successors[1].backfitHoldOutInstance(inst, weight, this);
1318            }
1319          }
1320        }
1321      }
1322    }
1323   
1324    /**
1325     * Returns the revision string.
1326     *
1327     * @return          the revision
1328     */
1329    public String getRevision() {
1330      return RevisionUtils.extract("$Revision: 5928 $");
1331    }
1332  }
1333
1334  /** The Tree object */
1335  protected Tree m_Tree = null;
1336   
1337  /** Number of folds for reduced error pruning. */
1338  protected int m_NumFolds = 3;
1339   
1340  /** Seed for random data shuffling. */
1341  protected int m_Seed = 1;
1342   
1343  /** Don't prune */
1344  protected boolean m_NoPruning = false;
1345
1346  /** The minimum number of instances per leaf. */
1347  protected double m_MinNum = 2;
1348
1349  /** The minimum proportion of the total variance (over all the data)
1350      required for split. */
1351  protected double m_MinVarianceProp = 1e-3;
1352
1353  /** Upper bound on the tree depth */
1354  protected int m_MaxDepth = -1;
1355 
1356  /**
1357   * Returns the tip text for this property
1358   * @return tip text for this property suitable for
1359   * displaying in the explorer/experimenter gui
1360   */
1361  public String noPruningTipText() {
1362    return "Whether pruning is performed.";
1363  }
1364 
1365  /**
1366   * Get the value of NoPruning.
1367   *
1368   * @return Value of NoPruning.
1369   */
1370  public boolean getNoPruning() {
1371   
1372    return m_NoPruning;
1373  }
1374 
1375  /**
1376   * Set the value of NoPruning.
1377   *
1378   * @param newNoPruning Value to assign to NoPruning.
1379   */
1380  public void setNoPruning(boolean newNoPruning) {
1381   
1382    m_NoPruning = newNoPruning;
1383  }
1384 
1385  /**
1386   * Returns the tip text for this property
1387   * @return tip text for this property suitable for
1388   * displaying in the explorer/experimenter gui
1389   */
1390  public String minNumTipText() {
1391    return "The minimum total weight of the instances in a leaf.";
1392  }
1393
1394  /**
1395   * Get the value of MinNum.
1396   *
1397   * @return Value of MinNum.
1398   */
1399  public double getMinNum() {
1400   
1401    return m_MinNum;
1402  }
1403 
1404  /**
1405   * Set the value of MinNum.
1406   *
1407   * @param newMinNum Value to assign to MinNum.
1408   */
1409  public void setMinNum(double newMinNum) {
1410   
1411    m_MinNum = newMinNum;
1412  }
1413 
1414  /**
1415   * Returns the tip text for this property
1416   * @return tip text for this property suitable for
1417   * displaying in the explorer/experimenter gui
1418   */
1419  public String minVariancePropTipText() {
1420    return "The minimum proportion of the variance on all the data " +
1421      "that needs to be present at a node in order for splitting to " +
1422      "be performed in regression trees.";
1423  }
1424
1425  /**
1426   * Get the value of MinVarianceProp.
1427   *
1428   * @return Value of MinVarianceProp.
1429   */
1430  public double getMinVarianceProp() {
1431   
1432    return m_MinVarianceProp;
1433  }
1434 
1435  /**
1436   * Set the value of MinVarianceProp.
1437   *
1438   * @param newMinVarianceProp Value to assign to MinVarianceProp.
1439   */
1440  public void setMinVarianceProp(double newMinVarianceProp) {
1441   
1442    m_MinVarianceProp = newMinVarianceProp;
1443  }
1444
1445  /**
1446   * Returns the tip text for this property
1447   * @return tip text for this property suitable for
1448   * displaying in the explorer/experimenter gui
1449   */
1450  public String seedTipText() {
1451    return "The seed used for randomizing the data.";
1452  }
1453
1454  /**
1455   * Get the value of Seed.
1456   *
1457   * @return Value of Seed.
1458   */
1459  public int getSeed() {
1460   
1461    return m_Seed;
1462  }
1463 
1464  /**
1465   * Set the value of Seed.
1466   *
1467   * @param newSeed Value to assign to Seed.
1468   */
1469  public void setSeed(int newSeed) {
1470   
1471    m_Seed = newSeed;
1472  }
1473
1474  /**
1475   * Returns the tip text for this property
1476   * @return tip text for this property suitable for
1477   * displaying in the explorer/experimenter gui
1478   */
1479  public String numFoldsTipText() {
1480    return "Determines the amount of data used for pruning. One fold is used for "
1481      + "pruning, the rest for growing the rules.";
1482  }
1483 
1484  /**
1485   * Get the value of NumFolds.
1486   *
1487   * @return Value of NumFolds.
1488   */
1489  public int getNumFolds() {
1490   
1491    return m_NumFolds;
1492  }
1493 
1494  /**
1495   * Set the value of NumFolds.
1496   *
1497   * @param newNumFolds Value to assign to NumFolds.
1498   */
1499  public void setNumFolds(int newNumFolds) {
1500   
1501    m_NumFolds = newNumFolds;
1502  }
1503 
1504  /**
1505   * Returns the tip text for this property
1506   * @return tip text for this property suitable for
1507   * displaying in the explorer/experimenter gui
1508   */
1509  public String maxDepthTipText() {
1510    return "The maximum tree depth (-1 for no restriction).";
1511  }
1512
1513  /**
1514   * Get the value of MaxDepth.
1515   *
1516   * @return Value of MaxDepth.
1517   */
1518  public int getMaxDepth() {
1519   
1520    return m_MaxDepth;
1521  }
1522 
1523  /**
1524   * Set the value of MaxDepth.
1525   *
1526   * @param newMaxDepth Value to assign to MaxDepth.
1527   */
1528  public void setMaxDepth(int newMaxDepth) {
1529   
1530    m_MaxDepth = newMaxDepth;
1531  }
1532 
1533  /**
1534   * Lists the command-line options for this classifier.
1535   *
1536   * @return an enumeration over all commandline options
1537   */
1538  public Enumeration listOptions() {
1539   
1540    Vector newVector = new Vector(5);
1541
1542    newVector.
1543      addElement(new Option("\tSet minimum number of instances per leaf " +
1544                            "(default 2).",
1545                            "M", 1, "-M <minimum number of instances>"));
1546    newVector.
1547      addElement(new Option("\tSet minimum numeric class variance proportion\n" +
1548                            "\tof train variance for split (default 1e-3).",
1549                            "V", 1, "-V <minimum variance for split>"));
1550    newVector.
1551      addElement(new Option("\tNumber of folds for reduced error pruning " +
1552                            "(default 3).",
1553                            "N", 1, "-N <number of folds>"));
1554    newVector.
1555      addElement(new Option("\tSeed for random data shuffling (default 1).",
1556                            "S", 1, "-S <seed>"));
1557    newVector.
1558      addElement(new Option("\tNo pruning.",
1559                            "P", 0, "-P"));
1560    newVector.
1561      addElement(new Option("\tMaximum tree depth (default -1, no maximum)",
1562                            "L", 1, "-L"));
1563
1564    return newVector.elements();
1565  } 
1566
1567  /**
1568   * Gets options from this classifier.
1569   *
1570   * @return the options for the current setup
1571   */
1572  public String[] getOptions() {
1573   
1574    String [] options = new String [12];
1575    int current = 0;
1576    options[current++] = "-M"; 
1577    options[current++] = "" + (int)getMinNum();
1578    options[current++] = "-V"; 
1579    options[current++] = "" + getMinVarianceProp();
1580    options[current++] = "-N"; 
1581    options[current++] = "" + getNumFolds();
1582    options[current++] = "-S"; 
1583    options[current++] = "" + getSeed();
1584    options[current++] = "-L"; 
1585    options[current++] = "" + getMaxDepth();
1586    if (getNoPruning()) {
1587      options[current++] = "-P";
1588    }
1589    while (current < options.length) {
1590      options[current++] = "";
1591    }
1592    return options;
1593  }
1594
1595  /**
1596   * Parses a given list of options. <p/>
1597   *
1598   <!-- options-start -->
1599   * Valid options are: <p/>
1600   *
1601   * <pre> -M &lt;minimum number of instances&gt;
1602   *  Set minimum number of instances per leaf (default 2).</pre>
1603   *
1604   * <pre> -V &lt;minimum variance for split&gt;
1605   *  Set minimum numeric class variance proportion
1606   *  of train variance for split (default 1e-3).</pre>
1607   *
1608   * <pre> -N &lt;number of folds&gt;
1609   *  Number of folds for reduced error pruning (default 3).</pre>
1610   *
1611   * <pre> -S &lt;seed&gt;
1612   *  Seed for random data shuffling (default 1).</pre>
1613   *
1614   * <pre> -P
1615   *  No pruning.</pre>
1616   *
1617   * <pre> -L
1618   *  Maximum tree depth (default -1, no maximum)</pre>
1619   *
1620   <!-- options-end -->
1621   *
1622   * @param options the list of options as an array of strings
1623   * @throws Exception if an option is not supported
1624   */
1625  public void setOptions(String[] options) throws Exception {
1626   
1627    String minNumString = Utils.getOption('M', options);
1628    if (minNumString.length() != 0) {
1629      m_MinNum = (double)Integer.parseInt(minNumString);
1630    } else {
1631      m_MinNum = 2;
1632    }
1633    String minVarString = Utils.getOption('V', options);
1634    if (minVarString.length() != 0) {
1635      m_MinVarianceProp = Double.parseDouble(minVarString);
1636    } else {
1637      m_MinVarianceProp = 1e-3;
1638    }
1639    String numFoldsString = Utils.getOption('N', options);
1640    if (numFoldsString.length() != 0) {
1641      m_NumFolds = Integer.parseInt(numFoldsString);
1642    } else {
1643      m_NumFolds = 3;
1644    }
1645    String seedString = Utils.getOption('S', options);
1646    if (seedString.length() != 0) {
1647      m_Seed = Integer.parseInt(seedString);
1648    } else {
1649      m_Seed = 1;
1650    }
1651    m_NoPruning = Utils.getFlag('P', options);
1652    String depthString = Utils.getOption('L', options);
1653    if (depthString.length() != 0) {
1654      m_MaxDepth = Integer.parseInt(depthString);
1655    } else {
1656      m_MaxDepth = -1;
1657    }
1658    Utils.checkForRemainingOptions(options);
1659  }
1660 
1661  /**
1662   * Computes size of the tree.
1663   *
1664   * @return the number of nodes
1665   */
1666  public int numNodes() {
1667
1668    return m_Tree.numNodes();
1669  }
1670
1671  /**
1672   * Returns an enumeration of the additional measure names.
1673   *
1674   * @return an enumeration of the measure names
1675   */
1676  public Enumeration enumerateMeasures() {
1677   
1678    Vector newVector = new Vector(1);
1679    newVector.addElement("measureTreeSize");
1680    return newVector.elements();
1681  }
1682 
1683  /**
1684   * Returns the value of the named measure.
1685   *
1686   * @param additionalMeasureName the name of the measure to query for its value
1687   * @return the value of the named measure
1688   * @throws IllegalArgumentException if the named measure is not supported
1689   */
1690  public double getMeasure(String additionalMeasureName) {
1691   
1692    if (additionalMeasureName.equalsIgnoreCase("measureTreeSize")) {
1693      return (double) numNodes();
1694    }
1695    else {throw new IllegalArgumentException(additionalMeasureName
1696                              + " not supported (REPTree)");
1697    }
1698  }
1699
1700  /**
1701   * Returns default capabilities of the classifier.
1702   *
1703   * @return      the capabilities of this classifier
1704   */
1705  public Capabilities getCapabilities() {
1706    Capabilities result = super.getCapabilities();
1707    result.disableAll();
1708
1709    // attributes
1710    result.enable(Capability.NOMINAL_ATTRIBUTES);
1711    result.enable(Capability.NUMERIC_ATTRIBUTES);
1712    result.enable(Capability.DATE_ATTRIBUTES);
1713    result.enable(Capability.MISSING_VALUES);
1714
1715    // class
1716    result.enable(Capability.NOMINAL_CLASS);
1717    result.enable(Capability.NUMERIC_CLASS);
1718    result.enable(Capability.DATE_CLASS);
1719    result.enable(Capability.MISSING_CLASS_VALUES);
1720   
1721    return result;
1722  }
1723
1724  /**
1725   * Builds classifier.
1726   *
1727   * @param data the data to train with
1728   * @throws Exception if building fails
1729   */
1730  public void buildClassifier(Instances data) throws Exception {
1731
1732    // can classifier handle the data?
1733    getCapabilities().testWithFail(data);
1734
1735    // remove instances with missing class
1736    data = new Instances(data);
1737    data.deleteWithMissingClass();
1738   
1739    Random random = new Random(m_Seed);
1740
1741    m_zeroR = null;
1742    if (data.numAttributes() == 1) {
1743      m_zeroR = new ZeroR();
1744      m_zeroR.buildClassifier(data);
1745      return;
1746    }
1747
1748    // Randomize and stratify
1749    data.randomize(random);
1750    if (data.classAttribute().isNominal()) {
1751      data.stratify(m_NumFolds);
1752    }
1753
1754    // Split data into training and pruning set
1755    Instances train = null;
1756    Instances prune = null;
1757    if (!m_NoPruning) {
1758      train = data.trainCV(m_NumFolds, 0, random);
1759      prune = data.testCV(m_NumFolds, 0);
1760    } else {
1761      train = data;
1762    }
1763
1764    // Create array of sorted indices and weights
1765    int[][] sortedIndices = new int[train.numAttributes()][0];
1766    double[][] weights = new double[train.numAttributes()][0];
1767    double[] vals = new double[train.numInstances()];
1768    for (int j = 0; j < train.numAttributes(); j++) {
1769      if (j != train.classIndex()) {
1770        weights[j] = new double[train.numInstances()];
1771        if (train.attribute(j).isNominal()) {
1772
1773          // Handling nominal attributes. Putting indices of
1774          // instances with missing values at the end.
1775          sortedIndices[j] = new int[train.numInstances()];
1776          int count = 0;
1777          for (int i = 0; i < train.numInstances(); i++) {
1778            Instance inst = train.instance(i);
1779            if (!inst.isMissing(j)) {
1780              sortedIndices[j][count] = i;
1781              weights[j][count] = inst.weight();
1782              count++;
1783            }
1784          }
1785          for (int i = 0; i < train.numInstances(); i++) {
1786            Instance inst = train.instance(i);
1787            if (inst.isMissing(j)) {
1788              sortedIndices[j][count] = i;
1789              weights[j][count] = inst.weight();
1790              count++;
1791            }
1792          }
1793        } else {
1794
1795          // Sorted indices are computed for numeric attributes
1796          for (int i = 0; i < train.numInstances(); i++) {
1797            Instance inst = train.instance(i);
1798            vals[i] = inst.value(j);
1799          }
1800          sortedIndices[j] = Utils.sort(vals);
1801          for (int i = 0; i < train.numInstances(); i++) {
1802            weights[j][i] = train.instance(sortedIndices[j][i]).weight();
1803          }
1804        }
1805      }
1806    }
1807
1808    // Compute initial class counts
1809    double[] classProbs = new double[train.numClasses()];
1810    double totalWeight = 0, totalSumSquared = 0;
1811    for (int i = 0; i < train.numInstances(); i++) {
1812      Instance inst = train.instance(i);
1813      if (data.classAttribute().isNominal()) {
1814        classProbs[(int)inst.classValue()] += inst.weight();
1815        totalWeight += inst.weight();
1816      } else {
1817        classProbs[0] += inst.classValue() * inst.weight();
1818        totalSumSquared += inst.classValue() * inst.classValue() * inst.weight();
1819        totalWeight += inst.weight();
1820      }
1821    }
1822    m_Tree = new Tree();
1823    double trainVariance = 0;
1824    if (data.classAttribute().isNumeric()) {
1825      trainVariance = m_Tree.
1826        singleVariance(classProbs[0], totalSumSquared, totalWeight) / totalWeight;
1827      classProbs[0] /= totalWeight;
1828    }
1829
1830    // Build tree
1831    m_Tree.buildTree(sortedIndices, weights, train, totalWeight, classProbs,
1832                     new Instances(train, 0), m_MinNum, m_MinVarianceProp * 
1833                     trainVariance, 0, m_MaxDepth);
1834   
1835    // Insert pruning data and perform reduced error pruning
1836    if (!m_NoPruning) {
1837      m_Tree.insertHoldOutSet(prune);
1838      m_Tree.reducedErrorPrune();
1839      m_Tree.backfitHoldOutSet(prune);
1840    }
1841  }
1842
1843  /**
1844   * Computes class distribution of an instance using the tree.
1845   *
1846   * @param instance the instance to compute the distribution for
1847   * @return the computed class probabilities
1848   * @throws Exception if computation fails
1849   */
1850  public double[] distributionForInstance(Instance instance) 
1851    throws Exception {
1852     
1853      if (m_zeroR != null) {
1854        return m_zeroR.distributionForInstance(instance);
1855      } else {
1856        return m_Tree.distributionForInstance(instance);
1857      }
1858  }
1859
1860
1861  /**
1862   * For getting a unique ID when outputting the tree source
1863   * (hashcode isn't guaranteed unique)
1864   */
1865  private static long PRINTED_NODES = 0;
1866
1867  /**
1868   * Gets the next unique node ID.
1869   *
1870   * @return the next unique node ID.
1871   */
1872  protected static long nextID() {
1873
1874    return PRINTED_NODES ++;
1875  }
1876
1877  /**
1878   * resets the counter for the nodes
1879   */
1880  protected static void resetID() {
1881    PRINTED_NODES = 0;
1882  }
1883
1884  /**
1885   * Returns the tree as if-then statements.
1886   *
1887   * @param className the name for the generated class
1888   * @return the tree as a Java if-then type statement
1889   * @throws Exception if something goes wrong
1890   */
1891  public String toSource(String className) 
1892    throws Exception {
1893     
1894    if (m_Tree == null) {
1895      throw new Exception("REPTree: No model built yet.");
1896    } 
1897    StringBuffer [] source = m_Tree.toSource(className, m_Tree);
1898    return
1899    "class " + className + " {\n\n"
1900    +"  public static double classify(Object [] i)\n"
1901    +"    throws Exception {\n\n"
1902    +"    double p = Double.NaN;\n"
1903    + source[0]  // Assignment code
1904    +"    return p;\n"
1905    +"  }\n"
1906    + source[1]  // Support code
1907    +"}\n";
1908  }
1909
1910  /**
1911   *  Returns the type of graph this classifier
1912   *  represents.
1913   *  @return Drawable.TREE
1914   */   
1915  public int graphType() {
1916      return Drawable.TREE;
1917  }
1918
1919  /**
1920   * Outputs the decision tree as a graph
1921   *
1922   * @return the tree as a graph
1923   * @throws Exception if generation fails
1924   */
1925  public String graph() throws Exception {
1926
1927    if (m_Tree == null) {
1928      throw new Exception("REPTree: No model built yet.");
1929    } 
1930    StringBuffer resultBuff = new StringBuffer();
1931    m_Tree.toGraph(resultBuff, 0, null);
1932    String result = "digraph Tree {\n" + "edge [style=bold]\n" + resultBuff.toString()
1933      + "\n}\n";
1934    return result;
1935  }
1936 
1937  /**
1938   * Outputs the decision tree.
1939   *
1940   * @return a string representation of the classifier
1941   */
1942  public String toString() {
1943
1944    if (m_zeroR != null) {
1945      return "No attributes other than class. Using ZeroR.\n\n" + m_zeroR.toString();
1946    }
1947    if ((m_Tree == null)) {
1948      return "REPTree: No model built yet.";
1949    } 
1950    return     
1951      "\nREPTree\n============\n" + m_Tree.toString(0, null) + "\n" +
1952      "\nSize of the tree : " + numNodes();
1953  }
1954 
1955  /**
1956   * Returns the revision string.
1957   *
1958   * @return            the revision
1959   */
1960  public String getRevision() {
1961    return RevisionUtils.extract("$Revision: 5928 $");
1962  }
1963
1964  /**
1965   * Main method for this class.
1966   *
1967   * @param argv the commandline options
1968   */
1969  public static void main(String[] argv) {
1970    runClassifier(new REPTree(), argv);
1971  }
1972}
Note: See TracBrowser for help on using the repository browser.