source: src/main/java/weka/classifiers/trees/SimpleCart.java @ 11

Last change on this file since 11 was 4, checked in by gnappo, 14 years ago

Import di weka.

File size: 58.7 KB
RevLine 
[4]1/*
2 *    This program is free software; you can redistribute it and/or modify
3 *    it under the terms of the GNU General Public License as published by
4 *    the Free Software Foundation; either version 2 of the License, or
5 *    (at your option) any later version.
6 *
7 *    This program is distributed in the hope that it will be useful,
8 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
9 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
10 *    GNU General Public License for more details.
11 *
12 *    You should have received a copy of the GNU General Public License
13 *    along with this program; if not, write to the Free Software
14 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
15 */
16
17/*
18 * SimpleCart.java
19 * Copyright (C) 2007 Haijian Shi
20 *
21 */
22
23package weka.classifiers.trees;
24
25import weka.classifiers.Evaluation;
26import weka.classifiers.RandomizableClassifier;
27import weka.core.AdditionalMeasureProducer;
28import weka.core.Attribute;
29import weka.core.Capabilities;
30import weka.core.Instance;
31import weka.core.Instances;
32import weka.core.Option;
33import weka.core.RevisionUtils;
34import weka.core.TechnicalInformation;
35import weka.core.TechnicalInformationHandler;
36import weka.core.Utils;
37import weka.core.Capabilities.Capability;
38import weka.core.TechnicalInformation.Field;
39import weka.core.TechnicalInformation.Type;
40import weka.core.matrix.Matrix;
41
42import java.util.Arrays;
43import java.util.Enumeration;
44import java.util.Random;
45import java.util.Vector;
46
47/**
48 <!-- globalinfo-start -->
49 * Class implementing minimal cost-complexity pruning.<br/>
50 * Note when dealing with missing values, use "fractional instances" method instead of surrogate split method.<br/>
51 * <br/>
52 * For more information, see:<br/>
53 * <br/>
54 * Leo Breiman, Jerome H. Friedman, Richard A. Olshen, Charles J. Stone (1984). Classification and Regression Trees. Wadsworth International Group, Belmont, California.
55 * <p/>
56 <!-- globalinfo-end -->       
57 *
58 <!-- technical-bibtex-start -->
59 * BibTeX:
60 * <pre>
61 * &#64;book{Breiman1984,
62 *    address = {Belmont, California},
63 *    author = {Leo Breiman and Jerome H. Friedman and Richard A. Olshen and Charles J. Stone},
64 *    publisher = {Wadsworth International Group},
65 *    title = {Classification and Regression Trees},
66 *    year = {1984}
67 * }
68 * </pre>
69 * <p/>
70 <!-- technical-bibtex-end -->
71 *
72 <!-- options-start -->
73 * Valid options are: <p/>
74 *
75 * <pre> -S &lt;num&gt;
76 *  Random number seed.
77 *  (default 1)</pre>
78 *
79 * <pre> -D
80 *  If set, classifier is run in debug mode and
81 *  may output additional info to the console</pre>
82 *
83 * <pre> -M &lt;min no&gt;
84 *  The minimal number of instances at the terminal nodes.
85 *  (default 2)</pre>
86 *
87 * <pre> -N &lt;num folds&gt;
88 *  The number of folds used in the minimal cost-complexity pruning.
89 *  (default 5)</pre>
90 *
91 * <pre> -U
92 *  Don't use the minimal cost-complexity pruning.
93 *  (default yes).</pre>
94 *
95 * <pre> -H
96 *  Don't use the heuristic method for binary split.
97 *  (default true).</pre>
98 *
99 * <pre> -A
100 *  Use 1 SE rule to make pruning decision.
101 *  (default no).</pre>
102 *
103 * <pre> -C
104 *  Percentage of training data size (0-1].
105 *  (default 1).</pre>
106 *
107 <!-- options-end -->
108 *
109 * @author Haijian Shi (hs69@cs.waikato.ac.nz)
110 * @version $Revision: 5987 $
111 */
112public class SimpleCart
113  extends RandomizableClassifier
114  implements AdditionalMeasureProducer, TechnicalInformationHandler {
115
116  /** For serialization.         */
117  private static final long serialVersionUID = 4154189200352566053L;
118
119  /** Training data.  */
120  protected Instances m_train;
121
122  /** Successor nodes. */
123  protected SimpleCart[] m_Successors;
124
125  /** Attribute used to split data. */
126  protected Attribute m_Attribute;
127
128  /** Split point for a numeric attribute. */
129  protected double m_SplitValue;
130
131  /** Split subset used to split data for nominal attributes. */
132  protected String m_SplitString;
133
134  /** Class value if the node is leaf. */
135  protected double m_ClassValue;
136
137  /** Class attriubte of data. */
138  protected Attribute m_ClassAttribute;
139
140  /** Minimum number of instances in at the terminal nodes. */
141  protected double m_minNumObj = 2;
142
143  /** Number of folds for minimal cost-complexity pruning. */
144  protected int m_numFoldsPruning = 5;
145
146  /** Alpha-value (for pruning) at the node. */
147  protected double m_Alpha;
148
149  /** Number of training examples misclassified by the model (subtree rooted). */
150  protected double m_numIncorrectModel;
151
152  /** Number of training examples misclassified by the model (subtree not rooted). */
153  protected double m_numIncorrectTree;
154
155  /** Indicate if the node is a leaf node. */
156  protected boolean m_isLeaf;
157
158  /** If use minimal cost-compexity pruning. */
159  protected boolean m_Prune = true;
160
161  /** Total number of instances used to build the classifier. */
162  protected int m_totalTrainInstances;
163
164  /** Proportion for each branch. */
165  protected double[] m_Props;
166
167  /** Class probabilities. */
168  protected double[] m_ClassProbs = null;
169
170  /** Distributions of leaf node (or temporary leaf node in minimal cost-complexity pruning) */
171  protected double[] m_Distribution;
172
173  /** If use huristic search for nominal attributes in multi-class problems (default true). */
174  protected boolean m_Heuristic = true;
175
176  /** If use the 1SE rule to make final decision tree. */
177  protected boolean m_UseOneSE = false;
178
179  /** Training data size. */
180  protected double m_SizePer = 1;
181
182  /**
183   * Return a description suitable for displaying in the explorer/experimenter.
184   *
185   * @return            a description suitable for displaying in the
186   *                    explorer/experimenter
187   */
188  public String globalInfo() {
189    return 
190        "Class implementing minimal cost-complexity pruning.\n"
191      + "Note when dealing with missing values, use \"fractional "
192      + "instances\" method instead of surrogate split method.\n\n"
193      + "For more information, see:\n\n"
194      + getTechnicalInformation().toString();
195  }
196
197  /**
198   * Returns an instance of a TechnicalInformation object, containing
199   * detailed information about the technical background of this class,
200   * e.g., paper reference or book this class is based on.
201   *
202   * @return            the technical information about this class
203   */
204  public TechnicalInformation getTechnicalInformation() {
205    TechnicalInformation        result;
206   
207    result = new TechnicalInformation(Type.BOOK);
208    result.setValue(Field.AUTHOR, "Leo Breiman and Jerome H. Friedman and Richard A. Olshen and Charles J. Stone");
209    result.setValue(Field.YEAR, "1984");
210    result.setValue(Field.TITLE, "Classification and Regression Trees");
211    result.setValue(Field.PUBLISHER, "Wadsworth International Group");
212    result.setValue(Field.ADDRESS, "Belmont, California");
213   
214    return result;
215  }
216
217  /**
218   * Returns default capabilities of the classifier.
219   *
220   * @return            the capabilities of this classifier
221   */
222  public Capabilities getCapabilities() {
223    Capabilities result = super.getCapabilities();
224    result.disableAll();
225
226    // attributes
227    result.enable(Capability.NOMINAL_ATTRIBUTES);
228    result.enable(Capability.NUMERIC_ATTRIBUTES);
229    result.enable(Capability.MISSING_VALUES);
230
231    // class
232    result.enable(Capability.NOMINAL_CLASS);
233
234    return result;
235  }
236
237  /**
238   * Build the classifier.
239   *
240   * @param data        the training instances
241   * @throws Exception  if something goes wrong
242   */
243  public void buildClassifier(Instances data) throws Exception {
244
245    getCapabilities().testWithFail(data);
246    data = new Instances(data);       
247    data.deleteWithMissingClass();
248
249    // unpruned CART decision tree
250    if (!m_Prune) {
251
252      // calculate sorted indices and weights, and compute initial class counts.
253      int[][] sortedIndices = new int[data.numAttributes()][0];
254      double[][] weights = new double[data.numAttributes()][0];
255      double[] classProbs = new double[data.numClasses()];
256      double totalWeight = computeSortedInfo(data,sortedIndices, weights,classProbs);
257
258      makeTree(data, data.numInstances(),sortedIndices,weights,classProbs,
259          totalWeight,m_minNumObj, m_Heuristic);
260      return;
261    }
262
263    Random random = new Random(m_Seed);
264    Instances cvData = new Instances(data);
265    cvData.randomize(random);
266    cvData = new Instances(cvData,0,(int)(cvData.numInstances()*m_SizePer)-1);
267    cvData.stratify(m_numFoldsPruning);
268
269    double[][] alphas = new double[m_numFoldsPruning][];
270    double[][] errors = new double[m_numFoldsPruning][];
271
272    // calculate errors and alphas for each fold
273    for (int i = 0; i < m_numFoldsPruning; i++) {
274
275      //for every fold, grow tree on training set and fix error on test set.
276      Instances train = cvData.trainCV(m_numFoldsPruning, i);
277      Instances test = cvData.testCV(m_numFoldsPruning, i);
278
279      // calculate sorted indices and weights, and compute initial class counts for each fold
280      int[][] sortedIndices = new int[train.numAttributes()][0];
281      double[][] weights = new double[train.numAttributes()][0];
282      double[] classProbs = new double[train.numClasses()];
283      double totalWeight = computeSortedInfo(train,sortedIndices, weights,classProbs);
284
285      makeTree(train, train.numInstances(),sortedIndices,weights,classProbs,
286          totalWeight,m_minNumObj, m_Heuristic);
287
288      int numNodes = numInnerNodes();
289      alphas[i] = new double[numNodes + 2];
290      errors[i] = new double[numNodes + 2];
291
292      // prune back and log alpha-values and errors on test set
293      prune(alphas[i], errors[i], test);
294    }
295
296    // calculate sorted indices and weights, and compute initial class counts on all training instances
297    int[][] sortedIndices = new int[data.numAttributes()][0];
298    double[][] weights = new double[data.numAttributes()][0];
299    double[] classProbs = new double[data.numClasses()];
300    double totalWeight = computeSortedInfo(data,sortedIndices, weights,classProbs);
301
302    //build tree using all the data
303    makeTree(data, data.numInstances(),sortedIndices,weights,classProbs,
304        totalWeight,m_minNumObj, m_Heuristic);
305
306    int numNodes = numInnerNodes();
307
308    double[] treeAlphas = new double[numNodes + 2];
309
310    // prune back and log alpha-values
311    int iterations = prune(treeAlphas, null, null);
312
313    double[] treeErrors = new double[numNodes + 2];
314
315    // for each pruned subtree, find the cross-validated error
316    for (int i = 0; i <= iterations; i++){
317      //compute midpoint alphas
318      double alpha = Math.sqrt(treeAlphas[i] * treeAlphas[i+1]);
319      double error = 0;
320      for (int k = 0; k < m_numFoldsPruning; k++) {
321        int l = 0;
322        while (alphas[k][l] <= alpha) l++;
323        error += errors[k][l - 1];
324      }
325      treeErrors[i] = error/m_numFoldsPruning;
326    }
327
328    // find best alpha
329    int best = -1;
330    double bestError = Double.MAX_VALUE;
331    for (int i = iterations; i >= 0; i--) {
332      if (treeErrors[i] < bestError) {
333        bestError = treeErrors[i];
334        best = i;
335      }
336    }
337
338    // 1 SE rule to choose expansion
339    if (m_UseOneSE) {
340      double oneSE = Math.sqrt(bestError*(1-bestError)/(data.numInstances()));
341      for (int i = iterations; i >= 0; i--) {
342        if (treeErrors[i] <= bestError+oneSE) {
343          best = i;
344          break;
345        }
346      }
347    }
348
349    double bestAlpha = Math.sqrt(treeAlphas[best] * treeAlphas[best + 1]);
350
351    //"unprune" final tree (faster than regrowing it)
352    unprune();
353    prune(bestAlpha);       
354  }
355
356  /**
357   * Make binary decision tree recursively.
358   *
359   * @param data                the training instances
360   * @param totalInstances      total number of instances
361   * @param sortedIndices       sorted indices of the instances
362   * @param weights             weights of the instances
363   * @param classProbs          class probabilities
364   * @param totalWeight         total weight of instances
365   * @param minNumObj           minimal number of instances at leaf nodes
366   * @param useHeuristic        if use heuristic search for nominal attributes in multi-class problem
367   * @throws Exception          if something goes wrong
368   */
369  protected void makeTree(Instances data, int totalInstances, int[][] sortedIndices,
370      double[][] weights, double[] classProbs, double totalWeight, double minNumObj,
371      boolean useHeuristic) throws Exception{
372
373    // if no instances have reached this node (normally won't happen)
374    if (totalWeight == 0){
375      m_Attribute = null;
376      m_ClassValue = Utils.missingValue();
377      m_Distribution = new double[data.numClasses()];
378      return;
379    }
380
381    m_totalTrainInstances = totalInstances;
382    m_isLeaf = true;
383
384    m_ClassProbs = new double[classProbs.length];
385    m_Distribution = new double[classProbs.length];
386    System.arraycopy(classProbs, 0, m_ClassProbs, 0, classProbs.length);
387    System.arraycopy(classProbs, 0, m_Distribution, 0, classProbs.length);
388    if (Utils.sum(m_ClassProbs)!=0) Utils.normalize(m_ClassProbs);
389
390    // Compute class distributions and value of splitting
391    // criterion for each attribute
392    double[][][] dists = new double[data.numAttributes()][0][0];
393    double[][] props = new double[data.numAttributes()][0];
394    double[][] totalSubsetWeights = new double[data.numAttributes()][2];
395    double[] splits = new double[data.numAttributes()];
396    String[] splitString = new String[data.numAttributes()];
397    double[] giniGains = new double[data.numAttributes()];
398
399    // for each attribute find split information
400    for (int i = 0; i < data.numAttributes(); i++) {
401      Attribute att = data.attribute(i);
402      if (i==data.classIndex()) continue;
403      if (att.isNumeric()) {
404        // numeric attribute
405        splits[i] = numericDistribution(props, dists, att, sortedIndices[i],
406            weights[i], totalSubsetWeights, giniGains, data);
407      } else {
408        // nominal attribute
409        splitString[i] = nominalDistribution(props, dists, att, sortedIndices[i],
410            weights[i], totalSubsetWeights, giniGains, data, useHeuristic);
411      }
412    }
413
414    // Find best attribute (split with maximum Gini gain)
415    int attIndex = Utils.maxIndex(giniGains);
416    m_Attribute = data.attribute(attIndex);
417
418    m_train = new Instances(data, sortedIndices[attIndex].length);
419    for (int i=0; i<sortedIndices[attIndex].length; i++) {
420      Instance inst = data.instance(sortedIndices[attIndex][i]);
421      Instance instCopy = (Instance)inst.copy();
422      instCopy.setWeight(weights[attIndex][i]);
423      m_train.add(instCopy);
424    }
425
426    // Check if node does not contain enough instances, or if it can not be split,
427    // or if it is pure. If does, make leaf.
428    if (totalWeight < 2 * minNumObj || giniGains[attIndex]==0 ||
429        props[attIndex][0]==0 || props[attIndex][1]==0) {
430      makeLeaf(data);
431    }
432   
433    else {           
434      m_Props = props[attIndex];
435      int[][][] subsetIndices = new int[2][data.numAttributes()][0];
436      double[][][] subsetWeights = new double[2][data.numAttributes()][0];
437
438      // numeric split
439      if (m_Attribute.isNumeric()) m_SplitValue = splits[attIndex];
440
441      // nominal split
442      else m_SplitString = splitString[attIndex];
443
444      splitData(subsetIndices, subsetWeights, m_Attribute, m_SplitValue,
445          m_SplitString, sortedIndices, weights, data);
446     
447      // If split of the node results in a node with less than minimal number of isntances,
448      // make the node leaf node.
449      if (subsetIndices[0][attIndex].length<minNumObj ||
450          subsetIndices[1][attIndex].length<minNumObj) {
451        makeLeaf(data);
452        return;
453      }
454
455      // Otherwise, split the node.
456      m_isLeaf = false;
457      m_Successors = new SimpleCart[2];
458      for (int i = 0; i < 2; i++) {
459        m_Successors[i] = new SimpleCart();
460        m_Successors[i].makeTree(data, m_totalTrainInstances, subsetIndices[i],
461            subsetWeights[i],dists[attIndex][i], totalSubsetWeights[attIndex][i],
462            minNumObj, useHeuristic);
463      }
464    }
465  }
466
467  /**
468   * Prunes the original tree using the CART pruning scheme, given a
469   * cost-complexity parameter alpha.
470   *
471   * @param alpha       the cost-complexity parameter
472   * @throws Exception  if something goes wrong
473   */
474  public void prune(double alpha) throws Exception {
475
476    Vector nodeList;
477
478    // determine training error of pruned subtrees (both with and without replacing a subtree),
479    // and calculate alpha-values from them
480    modelErrors();
481    treeErrors();
482    calculateAlphas();
483
484    // get list of all inner nodes in the tree
485    nodeList = getInnerNodes();
486
487    boolean prune = (nodeList.size() > 0);
488    double preAlpha = Double.MAX_VALUE;
489    while (prune) {
490
491      // select node with minimum alpha
492      SimpleCart nodeToPrune = nodeToPrune(nodeList);
493
494      // want to prune if its alpha is smaller than alpha
495      if (nodeToPrune.m_Alpha > alpha) {
496        break;
497      }
498
499      nodeToPrune.makeLeaf(nodeToPrune.m_train);
500
501      // normally would not happen
502      if (nodeToPrune.m_Alpha==preAlpha) {
503        nodeToPrune.makeLeaf(nodeToPrune.m_train);
504        treeErrors();
505        calculateAlphas();
506        nodeList = getInnerNodes();
507        prune = (nodeList.size() > 0);
508        continue;
509      }
510      preAlpha = nodeToPrune.m_Alpha;
511
512      //update tree errors and alphas
513      treeErrors();
514      calculateAlphas();
515
516      nodeList = getInnerNodes();
517      prune = (nodeList.size() > 0);
518    }
519  }
520
521  /**
522   * Method for performing one fold in the cross-validation of minimal
523   * cost-complexity pruning. Generates a sequence of alpha-values with error
524   * estimates for the corresponding (partially pruned) trees, given the test
525   * set of that fold.
526   *
527   * @param alphas      array to hold the generated alpha-values
528   * @param errors      array to hold the corresponding error estimates
529   * @param test        test set of that fold (to obtain error estimates)
530   * @return            the iteration of the pruning
531   * @throws Exception  if something goes wrong
532   */
533  public int prune(double[] alphas, double[] errors, Instances test) 
534    throws Exception {
535
536    Vector nodeList;
537
538    // determine training error of subtrees (both with and without replacing a subtree),
539    // and calculate alpha-values from them
540    modelErrors();
541    treeErrors();
542    calculateAlphas();
543
544    // get list of all inner nodes in the tree
545    nodeList = getInnerNodes();
546
547    boolean prune = (nodeList.size() > 0);
548
549    //alpha_0 is always zero (unpruned tree)
550    alphas[0] = 0;
551
552    Evaluation eval;
553
554    // error of unpruned tree
555    if (errors != null) {
556      eval = new Evaluation(test);
557      eval.evaluateModel(this, test);
558      errors[0] = eval.errorRate();
559    }
560
561    int iteration = 0;
562    double preAlpha = Double.MAX_VALUE;
563    while (prune) {
564
565      iteration++;
566
567      // get node with minimum alpha
568      SimpleCart nodeToPrune = nodeToPrune(nodeList);
569
570      // do not set m_sons null, want to unprune
571      nodeToPrune.m_isLeaf = true;
572
573      // normally would not happen
574      if (nodeToPrune.m_Alpha==preAlpha) {
575        iteration--;
576        treeErrors();
577        calculateAlphas();
578        nodeList = getInnerNodes();
579        prune = (nodeList.size() > 0);
580        continue;
581      }
582
583      // get alpha-value of node
584      alphas[iteration] = nodeToPrune.m_Alpha;
585
586      // log error
587      if (errors != null) {
588        eval = new Evaluation(test);
589        eval.evaluateModel(this, test);
590        errors[iteration] = eval.errorRate();
591      }
592      preAlpha = nodeToPrune.m_Alpha;
593
594      //update errors/alphas
595      treeErrors();
596      calculateAlphas();
597
598      nodeList = getInnerNodes();
599      prune = (nodeList.size() > 0);
600    }
601
602    //set last alpha 1 to indicate end
603    alphas[iteration + 1] = 1.0;
604    return iteration;
605  }
606
607  /**
608   * Method to "unprune" the CART tree. Sets all leaf-fields to false.
609   * Faster than re-growing the tree because CART do not have to be fit again.
610   */
611  protected void unprune() {
612    if (m_Successors != null) {
613      m_isLeaf = false;
614      for (int i = 0; i < m_Successors.length; i++) m_Successors[i].unprune();
615    }
616  }
617
618  /**
619   * Compute distributions, proportions and total weights of two successor
620   * nodes for a given numeric attribute.
621   *
622   * @param props               proportions of each two branches for each attribute
623   * @param dists               class distributions of two branches for each attribute
624   * @param att                 numeric att split on
625   * @param sortedIndices       sorted indices of instances for the attirubte
626   * @param weights             weights of instances for the attirbute
627   * @param subsetWeights       total weight of two branches split based on the attribute
628   * @param giniGains           Gini gains for each attribute
629   * @param data                training instances
630   * @return                    Gini gain the given numeric attribute
631   * @throws Exception          if something goes wrong
632   */
633  protected double numericDistribution(double[][] props, double[][][] dists,
634      Attribute att, int[] sortedIndices, double[] weights, double[][] subsetWeights,
635      double[] giniGains, Instances data)
636    throws Exception {
637
638    double splitPoint = Double.NaN;
639    double[][] dist = null;
640    int numClasses = data.numClasses();
641    int i; // differ instances with or without missing values
642
643    double[][] currDist = new double[2][numClasses];
644    dist = new double[2][numClasses];
645
646    // Move all instances without missing values into second subset
647    double[] parentDist = new double[numClasses];
648    int missingStart = 0;
649    for (int j = 0; j < sortedIndices.length; j++) {
650      Instance inst = data.instance(sortedIndices[j]);
651      if (!inst.isMissing(att)) {
652        missingStart ++;
653        currDist[1][(int)inst.classValue()] += weights[j];
654      }
655      parentDist[(int)inst.classValue()] += weights[j];
656    }
657    System.arraycopy(currDist[1], 0, dist[1], 0, dist[1].length);
658
659    // Try all possible split points
660    double currSplit = data.instance(sortedIndices[0]).value(att);
661    double currGiniGain;
662    double bestGiniGain = -Double.MAX_VALUE;
663
664    for (i = 0; i < sortedIndices.length; i++) {
665      Instance inst = data.instance(sortedIndices[i]);
666      if (inst.isMissing(att)) {
667        break;
668      }
669      if (inst.value(att) > currSplit) {
670
671        double[][] tempDist = new double[2][numClasses];
672        for (int k=0; k<2; k++) {
673          //tempDist[k] = currDist[k];
674          System.arraycopy(currDist[k], 0, tempDist[k], 0, tempDist[k].length);
675        }
676
677        double[] tempProps = new double[2];
678        for (int k=0; k<2; k++) {
679          tempProps[k] = Utils.sum(tempDist[k]);
680        }
681
682        if (Utils.sum(tempProps) !=0) Utils.normalize(tempProps);
683
684        // split missing values
685        int index = missingStart;
686        while (index < sortedIndices.length) {
687          Instance insta = data.instance(sortedIndices[index]);
688          for (int j = 0; j < 2; j++) {
689            tempDist[j][(int)insta.classValue()] += tempProps[j] * weights[index];
690          }
691          index++;
692        }
693
694        currGiniGain = computeGiniGain(parentDist,tempDist);
695
696        if (currGiniGain > bestGiniGain) {
697          bestGiniGain = currGiniGain;
698
699          // clean split point
700//        splitPoint = Math.rint((inst.value(att) + currSplit)/2.0*100000)/100000.0;
701          splitPoint = (inst.value(att) + currSplit) / 2.0;
702
703          for (int j = 0; j < currDist.length; j++) {
704            System.arraycopy(tempDist[j], 0, dist[j], 0,
705                dist[j].length);
706          }
707        }
708      }
709      currSplit = inst.value(att);
710      currDist[0][(int)inst.classValue()] += weights[i];
711      currDist[1][(int)inst.classValue()] -= weights[i];
712    }
713
714    // Compute weights
715    int attIndex = att.index();
716    props[attIndex] = new double[2];
717    for (int k = 0; k < 2; k++) {
718      props[attIndex][k] = Utils.sum(dist[k]);
719    }
720    if (Utils.sum(props[attIndex]) != 0) Utils.normalize(props[attIndex]);
721
722    // Compute subset weights
723    subsetWeights[attIndex] = new double[2];
724    for (int j = 0; j < 2; j++) {
725      subsetWeights[attIndex][j] += Utils.sum(dist[j]);
726    }
727
728    // clean Gini gain
729    //giniGains[attIndex] = Math.rint(bestGiniGain*10000000)/10000000.0;
730    giniGains[attIndex] = bestGiniGain;
731    dists[attIndex] = dist;
732
733    return splitPoint;
734  }
735
736  /**
737   * Compute distributions, proportions and total weights of two successor
738   * nodes for a given nominal attribute.
739   *
740   * @param props               proportions of each two branches for each attribute
741   * @param dists               class distributions of two branches for each attribute
742   * @param att                 numeric att split on
743   * @param sortedIndices       sorted indices of instances for the attirubte
744   * @param weights             weights of instances for the attirbute
745   * @param subsetWeights       total weight of two branches split based on the attribute
746   * @param giniGains           Gini gains for each attribute
747   * @param data                training instances
748   * @param useHeuristic        if use heuristic search
749   * @return                    Gini gain for the given nominal attribute
750   * @throws Exception          if something goes wrong
751   */
752  protected String nominalDistribution(double[][] props, double[][][] dists,
753      Attribute att, int[] sortedIndices, double[] weights, double[][] subsetWeights,
754      double[] giniGains, Instances data, boolean useHeuristic)
755    throws Exception {
756
757    String[] values = new String[att.numValues()];
758    int numCat = values.length; // number of values of the attribute
759    int numClasses = data.numClasses();
760
761    String bestSplitString = "";
762    double bestGiniGain = -Double.MAX_VALUE;
763
764    // class frequency for each value
765    int[] classFreq = new int[numCat];
766    for (int j=0; j<numCat; j++) classFreq[j] = 0;
767
768    double[] parentDist = new double[numClasses];
769    double[][] currDist = new double[2][numClasses];
770    double[][] dist = new double[2][numClasses];
771    int missingStart = 0;
772
773    for (int i = 0; i < sortedIndices.length; i++) {
774      Instance inst = data.instance(sortedIndices[i]);
775      if (!inst.isMissing(att)) {
776        missingStart++;
777        classFreq[(int)inst.value(att)] ++;
778      }
779      parentDist[(int)inst.classValue()] += weights[i];
780    }
781
782    // count the number of values that class frequency is not 0
783    int nonEmpty = 0;
784    for (int j=0; j<numCat; j++) {
785      if (classFreq[j]!=0) nonEmpty ++;
786    }
787
788    // attribute values that class frequency is not 0
789    String[] nonEmptyValues = new String[nonEmpty];
790    int nonEmptyIndex = 0;
791    for (int j=0; j<numCat; j++) {
792      if (classFreq[j]!=0) {
793        nonEmptyValues[nonEmptyIndex] = att.value(j);
794        nonEmptyIndex ++;
795      }
796    }
797
798    // attribute values that class frequency is 0
799    int empty = numCat - nonEmpty;
800    String[] emptyValues = new String[empty];
801    int emptyIndex = 0;
802    for (int j=0; j<numCat; j++) {
803      if (classFreq[j]==0) {
804        emptyValues[emptyIndex] = att.value(j);
805        emptyIndex ++;
806      }
807    }
808
809    if (nonEmpty<=1) {
810      giniGains[att.index()] = 0;
811      return "";
812    }
813
814    // for tow-class probloms
815    if (data.numClasses()==2) {
816
817      //// Firstly, for attribute values which class frequency is not zero
818
819      // probability of class 0 for each attribute value
820      double[] pClass0 = new double[nonEmpty];
821      // class distribution for each attribute value
822      double[][] valDist = new double[nonEmpty][2];
823
824      for (int j=0; j<nonEmpty; j++) {
825        for (int k=0; k<2; k++) {
826          valDist[j][k] = 0;
827        }
828      }
829
830      for (int i = 0; i < sortedIndices.length; i++) {
831        Instance inst = data.instance(sortedIndices[i]);
832        if (inst.isMissing(att)) {
833          break;
834        }
835
836        for (int j=0; j<nonEmpty; j++) {
837          if (att.value((int)inst.value(att)).compareTo(nonEmptyValues[j])==0) {
838            valDist[j][(int)inst.classValue()] += inst.weight();
839            break;
840          }
841        }
842      }
843
844      for (int j=0; j<nonEmpty; j++) {
845        double distSum = Utils.sum(valDist[j]);
846        if (distSum==0) pClass0[j]=0;
847        else pClass0[j] = valDist[j][0]/distSum;
848      }
849
850      // sort category according to the probability of the first class
851      String[] sortedValues = new String[nonEmpty];
852      for (int j=0; j<nonEmpty; j++) {
853        sortedValues[j] = nonEmptyValues[Utils.minIndex(pClass0)];
854        pClass0[Utils.minIndex(pClass0)] = Double.MAX_VALUE;
855      }
856
857      // Find a subset of attribute values that maximize Gini decrease
858
859      // for the attribute values that class frequency is not 0
860      String tempStr = "";
861
862      for (int j=0; j<nonEmpty-1; j++) {
863        currDist = new double[2][numClasses];
864        if (tempStr=="") tempStr="(" + sortedValues[j] + ")";
865        else tempStr += "|"+ "(" + sortedValues[j] + ")";
866        for (int i=0; i<sortedIndices.length;i++) {
867          Instance inst = data.instance(sortedIndices[i]);
868          if (inst.isMissing(att)) {
869            break;
870          }
871
872          if (tempStr.indexOf
873              ("(" + att.value((int)inst.value(att)) + ")")!=-1) {
874            currDist[0][(int)inst.classValue()] += weights[i];
875          } else currDist[1][(int)inst.classValue()] += weights[i];
876        }
877
878        double[][] tempDist = new double[2][numClasses];
879        for (int kk=0; kk<2; kk++) {
880          tempDist[kk] = currDist[kk];
881        }
882
883        double[] tempProps = new double[2];
884        for (int kk=0; kk<2; kk++) {
885          tempProps[kk] = Utils.sum(tempDist[kk]);
886        }
887
888        if (Utils.sum(tempProps)!=0) Utils.normalize(tempProps);
889
890        // split missing values
891        int mstart = missingStart;
892        while (mstart < sortedIndices.length) {
893          Instance insta = data.instance(sortedIndices[mstart]);
894          for (int jj = 0; jj < 2; jj++) {
895            tempDist[jj][(int)insta.classValue()] += tempProps[jj] * weights[mstart];
896          }
897          mstart++;
898        }
899
900        double currGiniGain = computeGiniGain(parentDist,tempDist);
901
902        if (currGiniGain>bestGiniGain) {
903          bestGiniGain = currGiniGain;
904          bestSplitString = tempStr;
905          for (int jj = 0; jj < 2; jj++) {
906            //dist[jj] = new double[currDist[jj].length];
907            System.arraycopy(tempDist[jj], 0, dist[jj], 0,
908                dist[jj].length);
909          }
910        }
911      }
912    }
913
914    // multi-class problems - exhaustive search
915    else if (!useHeuristic || nonEmpty<=4) {
916
917      // Firstly, for attribute values which class frequency is not zero
918      for (int i=0; i<(int)Math.pow(2,nonEmpty-1); i++) {
919        String tempStr="";
920        currDist = new double[2][numClasses];
921        int mod;
922        int bit10 = i;
923        for (int j=nonEmpty-1; j>=0; j--) {
924          mod = bit10%2; // convert from 10bit to 2bit
925          if (mod==1) {
926            if (tempStr=="") tempStr = "("+nonEmptyValues[j]+")";
927            else tempStr += "|" + "("+nonEmptyValues[j]+")";
928          }
929          bit10 = bit10/2;
930        }
931        for (int j=0; j<sortedIndices.length;j++) {
932          Instance inst = data.instance(sortedIndices[j]);
933          if (inst.isMissing(att)) {
934            break;
935          }
936
937          if (tempStr.indexOf("("+att.value((int)inst.value(att))+")")!=-1) {
938            currDist[0][(int)inst.classValue()] += weights[j];
939          } else currDist[1][(int)inst.classValue()] += weights[j];
940        }
941
942        double[][] tempDist = new double[2][numClasses];
943        for (int k=0; k<2; k++) {
944          tempDist[k] = currDist[k];
945        }
946
947        double[] tempProps = new double[2];
948        for (int k=0; k<2; k++) {
949          tempProps[k] = Utils.sum(tempDist[k]);
950        }
951
952        if (Utils.sum(tempProps)!=0) Utils.normalize(tempProps);
953
954        // split missing values
955        int index = missingStart;
956        while (index < sortedIndices.length) {
957          Instance insta = data.instance(sortedIndices[index]);
958          for (int j = 0; j < 2; j++) {
959            tempDist[j][(int)insta.classValue()] += tempProps[j] * weights[index];
960          }
961          index++;
962        }
963
964        double currGiniGain = computeGiniGain(parentDist,tempDist);
965
966        if (currGiniGain>bestGiniGain) {
967          bestGiniGain = currGiniGain;
968          bestSplitString = tempStr;
969          for (int j = 0; j < 2; j++) {
970            //dist[jj] = new double[currDist[jj].length];
971            System.arraycopy(tempDist[j], 0, dist[j], 0,
972                dist[j].length);
973          }
974        }
975      }
976    }
977
978    // huristic search to solve multi-classes problems
979    else {
980      // Firstly, for attribute values which class frequency is not zero
981      int n = nonEmpty;
982      int k = data.numClasses();  // number of classes of the data
983      double[][] P = new double[n][k];      // class probability matrix
984      int[] numInstancesValue = new int[n]; // number of instances for an attribute value
985      double[] meanClass = new double[k];   // vector of mean class probability
986      int numInstances = data.numInstances(); // total number of instances
987
988      // initialize the vector of mean class probability
989      for (int j=0; j<meanClass.length; j++) meanClass[j]=0;
990
991      for (int j=0; j<numInstances; j++) {
992        Instance inst = (Instance)data.instance(j);
993        int valueIndex = 0; // attribute value index in nonEmptyValues
994        for (int i=0; i<nonEmpty; i++) {
995          if (att.value((int)inst.value(att)).compareToIgnoreCase(nonEmptyValues[i])==0){
996            valueIndex = i;
997            break;
998          }
999        }
1000        P[valueIndex][(int)inst.classValue()]++;
1001        numInstancesValue[valueIndex]++;
1002        meanClass[(int)inst.classValue()]++;
1003      }
1004
1005      // calculate the class probability matrix
1006      for (int i=0; i<P.length; i++) {
1007        for (int j=0; j<P[0].length; j++) {
1008          if (numInstancesValue[i]==0) P[i][j]=0;
1009          else P[i][j]/=numInstancesValue[i];
1010        }
1011      }
1012
1013      //calculate the vector of mean class probability
1014      for (int i=0; i<meanClass.length; i++) {
1015        meanClass[i]/=numInstances;
1016      }
1017
1018      // calculate the covariance matrix
1019      double[][] covariance = new double[k][k];
1020      for (int i1=0; i1<k; i1++) {
1021        for (int i2=0; i2<k; i2++) {
1022          double element = 0;
1023          for (int j=0; j<n; j++) {
1024            element += (P[j][i2]-meanClass[i2])*(P[j][i1]-meanClass[i1])
1025            *numInstancesValue[j];
1026          }
1027          covariance[i1][i2] = element;
1028        }
1029      }
1030
1031      Matrix matrix = new Matrix(covariance);
1032      weka.core.matrix.EigenvalueDecomposition eigen =
1033        new weka.core.matrix.EigenvalueDecomposition(matrix);
1034      double[] eigenValues = eigen.getRealEigenvalues();
1035
1036      // find index of the largest eigenvalue
1037      int index=0;
1038      double largest = eigenValues[0];
1039      for (int i=1; i<eigenValues.length; i++) {
1040        if (eigenValues[i]>largest) {
1041          index=i;
1042          largest = eigenValues[i];
1043        }
1044      }
1045
1046      // calculate the first principle component
1047      double[] FPC = new double[k];
1048      Matrix eigenVector = eigen.getV();
1049      double[][] vectorArray = eigenVector.getArray();
1050      for (int i=0; i<FPC.length; i++) {
1051        FPC[i] = vectorArray[i][index];
1052      }
1053
1054      // calculate the first principle component scores
1055      //System.out.println("the first principle component scores: ");
1056      double[] Sa = new double[n];
1057      for (int i=0; i<Sa.length; i++) {
1058        Sa[i]=0;
1059        for (int j=0; j<k; j++) {
1060          Sa[i] += FPC[j]*P[i][j];
1061        }
1062      }
1063
1064      // sort category according to Sa(s)
1065      double[] pCopy = new double[n];
1066      System.arraycopy(Sa,0,pCopy,0,n);
1067      String[] sortedValues = new String[n];
1068      Arrays.sort(Sa);
1069
1070      for (int j=0; j<n; j++) {
1071        sortedValues[j] = nonEmptyValues[Utils.minIndex(pCopy)];
1072        pCopy[Utils.minIndex(pCopy)] = Double.MAX_VALUE;
1073      }
1074
1075      // for the attribute values that class frequency is not 0
1076      String tempStr = "";
1077
1078      for (int j=0; j<nonEmpty-1; j++) {
1079        currDist = new double[2][numClasses];
1080        if (tempStr=="") tempStr="(" + sortedValues[j] + ")";
1081        else tempStr += "|"+ "(" + sortedValues[j] + ")";
1082        for (int i=0; i<sortedIndices.length;i++) {
1083          Instance inst = data.instance(sortedIndices[i]);
1084          if (inst.isMissing(att)) {
1085            break;
1086          }
1087
1088          if (tempStr.indexOf
1089              ("(" + att.value((int)inst.value(att)) + ")")!=-1) {
1090            currDist[0][(int)inst.classValue()] += weights[i];
1091          } else currDist[1][(int)inst.classValue()] += weights[i];
1092        }
1093
1094        double[][] tempDist = new double[2][numClasses];
1095        for (int kk=0; kk<2; kk++) {
1096          tempDist[kk] = currDist[kk];
1097        }
1098
1099        double[] tempProps = new double[2];
1100        for (int kk=0; kk<2; kk++) {
1101          tempProps[kk] = Utils.sum(tempDist[kk]);
1102        }
1103
1104        if (Utils.sum(tempProps)!=0) Utils.normalize(tempProps);
1105
1106        // split missing values
1107        int mstart = missingStart;
1108        while (mstart < sortedIndices.length) {
1109          Instance insta = data.instance(sortedIndices[mstart]);
1110          for (int jj = 0; jj < 2; jj++) {
1111            tempDist[jj][(int)insta.classValue()] += tempProps[jj] * weights[mstart];
1112          }
1113          mstart++;
1114        }
1115
1116        double currGiniGain = computeGiniGain(parentDist,tempDist);
1117
1118        if (currGiniGain>bestGiniGain) {
1119          bestGiniGain = currGiniGain;
1120          bestSplitString = tempStr;
1121          for (int jj = 0; jj < 2; jj++) {
1122            //dist[jj] = new double[currDist[jj].length];
1123            System.arraycopy(tempDist[jj], 0, dist[jj], 0,
1124                dist[jj].length);
1125          }
1126        }
1127      }
1128    }
1129
1130    // Compute weights
1131    int attIndex = att.index();       
1132    props[attIndex] = new double[2];
1133    for (int k = 0; k < 2; k++) {
1134      props[attIndex][k] = Utils.sum(dist[k]);
1135    }
1136
1137    if (!(Utils.sum(props[attIndex]) > 0)) {
1138      for (int k = 0; k < props[attIndex].length; k++) {
1139        props[attIndex][k] = 1.0 / (double)props[attIndex].length;
1140      }
1141    } else {
1142      Utils.normalize(props[attIndex]);
1143    }
1144
1145
1146    // Compute subset weights
1147    subsetWeights[attIndex] = new double[2];
1148    for (int j = 0; j < 2; j++) {
1149      subsetWeights[attIndex][j] += Utils.sum(dist[j]);
1150    }
1151
1152    // Then, for the attribute values that class frequency is 0, split it into the
1153    // most frequent branch
1154    for (int j=0; j<empty; j++) {
1155      if (props[attIndex][0]>=props[attIndex][1]) {
1156        if (bestSplitString=="") bestSplitString = "(" + emptyValues[j] + ")";
1157        else bestSplitString += "|" + "(" + emptyValues[j] + ")";
1158      }
1159    }
1160
1161    // clean Gini gain for the attribute
1162    //giniGains[attIndex] = Math.rint(bestGiniGain*10000000)/10000000.0;
1163    giniGains[attIndex] = bestGiniGain;
1164
1165    dists[attIndex] = dist;
1166    return bestSplitString;
1167  }
1168
1169
1170  /**
1171   * Split data into two subsets and store sorted indices and weights for two
1172   * successor nodes.
1173   *
1174   * @param subsetIndices       sorted indecis of instances for each attribute
1175   *                            for two successor node
1176   * @param subsetWeights       weights of instances for each attribute for
1177   *                            two successor node
1178   * @param att                 attribute the split based on
1179   * @param splitPoint          split point the split based on if att is numeric
1180   * @param splitStr            split subset the split based on if att is nominal
1181   * @param sortedIndices       sorted indices of the instances to be split
1182   * @param weights             weights of the instances to bes split
1183   * @param data                training data
1184   * @throws Exception          if something goes wrong 
1185   */
1186  protected void splitData(int[][][] subsetIndices, double[][][] subsetWeights,
1187      Attribute att, double splitPoint, String splitStr, int[][] sortedIndices,
1188      double[][] weights, Instances data) throws Exception {
1189
1190    int j;
1191    // For each attribute
1192    for (int i = 0; i < data.numAttributes(); i++) {
1193      if (i==data.classIndex()) continue;
1194      int[] num = new int[2];
1195      for (int k = 0; k < 2; k++) {
1196        subsetIndices[k][i] = new int[sortedIndices[i].length];
1197        subsetWeights[k][i] = new double[weights[i].length];
1198      }
1199
1200      for (j = 0; j < sortedIndices[i].length; j++) {
1201        Instance inst = data.instance(sortedIndices[i][j]);
1202        if (inst.isMissing(att)) {
1203          // Split instance up
1204          for (int k = 0; k < 2; k++) {
1205            if (m_Props[k] > 0) {
1206              subsetIndices[k][i][num[k]] = sortedIndices[i][j];
1207              subsetWeights[k][i][num[k]] = m_Props[k] * weights[i][j];
1208              num[k]++;
1209            }
1210          }
1211        } else {
1212          int subset;
1213          if (att.isNumeric())  {
1214            subset = (inst.value(att) < splitPoint) ? 0 : 1;
1215          } else { // nominal attribute
1216            if (splitStr.indexOf
1217                ("(" + att.value((int)inst.value(att.index()))+")")!=-1) {
1218              subset = 0;
1219            } else subset = 1;
1220          }
1221          subsetIndices[subset][i][num[subset]] = sortedIndices[i][j];
1222          subsetWeights[subset][i][num[subset]] = weights[i][j];
1223          num[subset]++;
1224        }
1225      }
1226
1227      // Trim arrays
1228      for (int k = 0; k < 2; k++) {
1229        int[] copy = new int[num[k]];
1230        System.arraycopy(subsetIndices[k][i], 0, copy, 0, num[k]);
1231        subsetIndices[k][i] = copy;
1232        double[] copyWeights = new double[num[k]];
1233        System.arraycopy(subsetWeights[k][i], 0 ,copyWeights, 0, num[k]);
1234        subsetWeights[k][i] = copyWeights;
1235      }
1236    }
1237  }
1238
1239  /**
1240   * Updates the numIncorrectModel field for all nodes when subtree (to be
1241   * pruned) is rooted. This is needed for calculating the alpha-values.
1242   *
1243   * @throws Exception  if something goes wrong
1244   */
1245  public void modelErrors() throws Exception{
1246    Evaluation eval = new Evaluation(m_train);
1247
1248    if (!m_isLeaf) {
1249      m_isLeaf = true; //temporarily make leaf
1250
1251      // calculate distribution for evaluation
1252      eval.evaluateModel(this, m_train);
1253      m_numIncorrectModel = eval.incorrect();
1254
1255      m_isLeaf = false;
1256
1257      for (int i = 0; i < m_Successors.length; i++)
1258        m_Successors[i].modelErrors();
1259
1260    } else {
1261      eval.evaluateModel(this, m_train);
1262      m_numIncorrectModel = eval.incorrect();
1263    }       
1264  }
1265
1266  /**
1267   * Updates the numIncorrectTree field for all nodes. This is needed for
1268   * calculating the alpha-values.
1269   *
1270   * @throws Exception  if something goes wrong
1271   */
1272  public void treeErrors() throws Exception {
1273    if (m_isLeaf) {
1274      m_numIncorrectTree = m_numIncorrectModel;
1275    } else {
1276      m_numIncorrectTree = 0;
1277      for (int i = 0; i < m_Successors.length; i++) {
1278        m_Successors[i].treeErrors();
1279        m_numIncorrectTree += m_Successors[i].m_numIncorrectTree;
1280      }
1281    }
1282  }
1283
1284  /**
1285   * Updates the alpha field for all nodes.
1286   *
1287   * @throws Exception  if something goes wrong
1288   */
1289  public void calculateAlphas() throws Exception {
1290
1291    if (!m_isLeaf) {
1292      double errorDiff = m_numIncorrectModel - m_numIncorrectTree;
1293      if (errorDiff <=0) {
1294        //split increases training error (should not normally happen).
1295        //prune it instantly.
1296        makeLeaf(m_train);
1297        m_Alpha = Double.MAX_VALUE;
1298      } else {
1299        //compute alpha
1300        errorDiff /= m_totalTrainInstances;
1301        m_Alpha = errorDiff / (double)(numLeaves() - 1);
1302        long alphaLong = Math.round(m_Alpha*Math.pow(10,10));
1303        m_Alpha = (double)alphaLong/Math.pow(10,10);
1304        for (int i = 0; i < m_Successors.length; i++) {
1305          m_Successors[i].calculateAlphas();
1306        }
1307      }
1308    } else {
1309      //alpha = infinite for leaves (do not want to prune)
1310      m_Alpha = Double.MAX_VALUE;
1311    }
1312  }
1313
1314  /**
1315   * Find the node with minimal alpha value. If two nodes have the same alpha,
1316   * choose the one with more leave nodes.
1317   *
1318   * @param nodeList    list of inner nodes
1319   * @return            the node to be pruned
1320   */
1321  protected SimpleCart nodeToPrune(Vector nodeList) {
1322    if (nodeList.size()==0) return null;
1323    if (nodeList.size()==1) return (SimpleCart)nodeList.elementAt(0);
1324    SimpleCart returnNode = (SimpleCart)nodeList.elementAt(0);
1325    double baseAlpha = returnNode.m_Alpha;
1326    for (int i=1; i<nodeList.size(); i++) {
1327      SimpleCart node = (SimpleCart)nodeList.elementAt(i);
1328      if (node.m_Alpha < baseAlpha) {
1329        baseAlpha = node.m_Alpha;
1330        returnNode = node;
1331      } else if (node.m_Alpha == baseAlpha) { // break tie
1332        if (node.numLeaves()>returnNode.numLeaves()) {
1333          returnNode = node;
1334        }
1335      }
1336    }
1337    return returnNode;
1338  }
1339
1340  /**
1341   * Compute sorted indices, weights and class probabilities for a given
1342   * dataset. Return total weights of the data at the node.
1343   *
1344   * @param data                training data
1345   * @param sortedIndices       sorted indices of instances at the node
1346   * @param weights             weights of instances at the node
1347   * @param classProbs          class probabilities at the node
1348   * @return total              weights of instances at the node
1349   * @throws Exception          if something goes wrong
1350   */
1351  protected double computeSortedInfo(Instances data, int[][] sortedIndices, double[][] weights,
1352      double[] classProbs) throws Exception {
1353
1354    // Create array of sorted indices and weights
1355    double[] vals = new double[data.numInstances()];
1356    for (int j = 0; j < data.numAttributes(); j++) {
1357      if (j==data.classIndex()) continue;
1358      weights[j] = new double[data.numInstances()];
1359
1360      if (data.attribute(j).isNominal()) {
1361
1362        // Handling nominal attributes. Putting indices of
1363        // instances with missing values at the end.
1364        sortedIndices[j] = new int[data.numInstances()];
1365        int count = 0;
1366        for (int i = 0; i < data.numInstances(); i++) {
1367          Instance inst = data.instance(i);
1368          if (!inst.isMissing(j)) {
1369            sortedIndices[j][count] = i;
1370            weights[j][count] = inst.weight();
1371            count++;
1372          }
1373        }
1374        for (int i = 0; i < data.numInstances(); i++) {
1375          Instance inst = data.instance(i);
1376          if (inst.isMissing(j)) {
1377            sortedIndices[j][count] = i;
1378            weights[j][count] = inst.weight();
1379            count++;
1380          }
1381        }
1382      } else {
1383
1384        // Sorted indices are computed for numeric attributes
1385        // missing values instances are put to end
1386        for (int i = 0; i < data.numInstances(); i++) {
1387          Instance inst = data.instance(i);
1388          vals[i] = inst.value(j);
1389        }
1390        sortedIndices[j] = Utils.sort(vals);
1391        for (int i = 0; i < data.numInstances(); i++) {
1392          weights[j][i] = data.instance(sortedIndices[j][i]).weight();
1393        }
1394      }
1395    }
1396
1397    // Compute initial class counts
1398    double totalWeight = 0;
1399    for (int i = 0; i < data.numInstances(); i++) {
1400      Instance inst = data.instance(i);
1401      classProbs[(int)inst.classValue()] += inst.weight();
1402      totalWeight += inst.weight();
1403    }
1404
1405    return totalWeight;
1406  }
1407
1408  /**
1409   * Compute and return gini gain for given distributions of a node and its
1410   * successor nodes.
1411   *
1412   * @param parentDist  class distributions of parent node
1413   * @param childDist   class distributions of successor nodes
1414   * @return            Gini gain computed
1415   */
1416  protected double computeGiniGain(double[] parentDist, double[][] childDist) {
1417    double totalWeight = Utils.sum(parentDist);
1418    if (totalWeight==0) return 0;
1419
1420    double leftWeight = Utils.sum(childDist[0]);
1421    double rightWeight = Utils.sum(childDist[1]);
1422
1423    double parentGini = computeGini(parentDist, totalWeight);
1424    double leftGini = computeGini(childDist[0],leftWeight);
1425    double rightGini = computeGini(childDist[1], rightWeight);
1426
1427    return parentGini - leftWeight/totalWeight*leftGini -
1428    rightWeight/totalWeight*rightGini;
1429  }
1430
1431  /**
1432   * Compute and return gini index for a given distribution of a node.
1433   *
1434   * @param dist        class distributions
1435   * @param total       class distributions
1436   * @return            Gini index of the class distributions
1437   */
1438  protected double computeGini(double[] dist, double total) {
1439    if (total==0) return 0;
1440    double val = 0;
1441    for (int i=0; i<dist.length; i++) {
1442      val += (dist[i]/total)*(dist[i]/total);
1443    }
1444    return 1- val;
1445  }
1446
1447  /**
1448   * Computes class probabilities for instance using the decision tree.
1449   *
1450   * @param instance    the instance for which class probabilities is to be computed
1451   * @return            the class probabilities for the given instance
1452   * @throws Exception  if something goes wrong
1453   */
1454  public double[] distributionForInstance(Instance instance)
1455  throws Exception {
1456    if (!m_isLeaf) {
1457      // value of split attribute is missing
1458      if (instance.isMissing(m_Attribute)) {
1459        double[] returnedDist = new double[m_ClassProbs.length];
1460
1461        for (int i = 0; i < m_Successors.length; i++) {
1462          double[] help =
1463            m_Successors[i].distributionForInstance(instance);
1464          if (help != null) {
1465            for (int j = 0; j < help.length; j++) {
1466              returnedDist[j] += m_Props[i] * help[j];
1467            }
1468          }
1469        }
1470        return returnedDist;
1471      }
1472
1473      // split attribute is nonimal
1474      else if (m_Attribute.isNominal()) {
1475        if (m_SplitString.indexOf("(" +
1476            m_Attribute.value((int)instance.value(m_Attribute)) + ")")!=-1)
1477          return  m_Successors[0].distributionForInstance(instance);
1478        else return  m_Successors[1].distributionForInstance(instance);
1479      }
1480
1481      // split attribute is numeric
1482      else {
1483        if (instance.value(m_Attribute) < m_SplitValue)
1484          return m_Successors[0].distributionForInstance(instance);
1485        else
1486          return m_Successors[1].distributionForInstance(instance);
1487      }
1488    }
1489
1490    // leaf node
1491    else return m_ClassProbs;
1492  }
1493
1494  /**
1495   * Make the node leaf node.
1496   *
1497   * @param data        trainging data
1498   */
1499  protected void makeLeaf(Instances data) {
1500    m_Attribute = null;
1501    m_isLeaf = true;
1502    m_ClassValue=Utils.maxIndex(m_ClassProbs);
1503    m_ClassAttribute = data.classAttribute();
1504  }
1505
1506  /**
1507   * Prints the decision tree using the protected toString method from below.
1508   *
1509   * @return            a textual description of the classifier
1510   */
1511  public String toString() {
1512    if ((m_ClassProbs == null) && (m_Successors == null)) {
1513      return "CART Tree: No model built yet.";
1514    }
1515
1516    return "CART Decision Tree\n" + toString(0)+"\n\n"
1517    +"Number of Leaf Nodes: "+numLeaves()+"\n\n" +
1518    "Size of the Tree: "+numNodes();
1519  }
1520
1521  /**
1522   * Outputs a tree at a certain level.
1523   *
1524   * @param level       the level at which the tree is to be printed
1525   * @return            a tree at a certain level
1526   */
1527  protected String toString(int level) {
1528
1529    StringBuffer text = new StringBuffer();
1530    // if leaf nodes
1531    if (m_Attribute == null) {
1532      if (Utils.isMissingValue(m_ClassValue)) {
1533        text.append(": null");
1534      } else {
1535        double correctNum = (int)(m_Distribution[Utils.maxIndex(m_Distribution)]*100)/
1536        100.0;
1537        double wrongNum = (int)((Utils.sum(m_Distribution) -
1538            m_Distribution[Utils.maxIndex(m_Distribution)])*100)/100.0;
1539        String str = "("  + correctNum + "/" + wrongNum + ")";
1540        text.append(": " + m_ClassAttribute.value((int) m_ClassValue)+ str);
1541      }
1542    } else {
1543      for (int j = 0; j < 2; j++) {
1544        text.append("\n");
1545        for (int i = 0; i < level; i++) {
1546          text.append("|  ");
1547        }
1548        if (j==0) {
1549          if (m_Attribute.isNumeric())
1550            text.append(m_Attribute.name() + " < " + m_SplitValue);
1551          else
1552            text.append(m_Attribute.name() + "=" + m_SplitString);
1553        } else {
1554          if (m_Attribute.isNumeric())
1555            text.append(m_Attribute.name() + " >= " + m_SplitValue);
1556          else
1557            text.append(m_Attribute.name() + "!=" + m_SplitString);
1558        }
1559        text.append(m_Successors[j].toString(level + 1));
1560      }
1561    }
1562    return text.toString();
1563  }
1564
1565  /**
1566   * Compute size of the tree.
1567   *
1568   * @return            size of the tree
1569   */
1570  public int numNodes() {
1571    if (m_isLeaf) {
1572      return 1;
1573    } else {
1574      int size =1;
1575      for (int i=0;i<m_Successors.length;i++) {
1576        size+=m_Successors[i].numNodes();
1577      }
1578      return size;
1579    }
1580  }
1581
1582  /**
1583   * Method to count the number of inner nodes in the tree.
1584   *
1585   * @return            the number of inner nodes
1586   */
1587  public int numInnerNodes(){
1588    if (m_Attribute==null) return 0;
1589    int numNodes = 1;
1590    for (int i = 0; i < m_Successors.length; i++)
1591      numNodes += m_Successors[i].numInnerNodes();
1592    return numNodes;
1593  }
1594
1595  /**
1596   * Return a list of all inner nodes in the tree.
1597   *
1598   * @return            the list of all inner nodes
1599   */
1600  protected Vector getInnerNodes(){
1601    Vector nodeList = new Vector();
1602    fillInnerNodes(nodeList);
1603    return nodeList;
1604  }
1605
1606  /**
1607   * Fills a list with all inner nodes in the tree.
1608   *
1609   * @param nodeList    the list to be filled
1610   */
1611  protected void fillInnerNodes(Vector nodeList) {
1612    if (!m_isLeaf) {
1613      nodeList.add(this);
1614      for (int i = 0; i < m_Successors.length; i++)
1615        m_Successors[i].fillInnerNodes(nodeList);
1616    }
1617  }
1618
1619  /**
1620   * Compute number of leaf nodes.
1621   *
1622   * @return            number of leaf nodes
1623   */
1624  public int numLeaves() {
1625    if (m_isLeaf) return 1;
1626    else {
1627      int size=0;
1628      for (int i=0;i<m_Successors.length;i++) {
1629        size+=m_Successors[i].numLeaves();
1630      }
1631      return size;
1632    }
1633  }
1634
1635  /**
1636   * Returns an enumeration describing the available options.
1637   *
1638   * @return            an enumeration of all the available options.
1639   */
1640  public Enumeration listOptions() {
1641    Vector      result;
1642    Enumeration en;
1643   
1644    result = new Vector();
1645   
1646    en = super.listOptions();
1647    while (en.hasMoreElements())
1648      result.addElement(en.nextElement());
1649
1650    result.addElement(new Option(
1651        "\tThe minimal number of instances at the terminal nodes.\n" 
1652        + "\t(default 2)",
1653        "M", 1, "-M <min no>"));
1654   
1655    result.addElement(new Option(
1656        "\tThe number of folds used in the minimal cost-complexity pruning.\n"
1657        + "\t(default 5)",
1658        "N", 1, "-N <num folds>"));
1659   
1660    result.addElement(new Option(
1661        "\tDon't use the minimal cost-complexity pruning.\n"
1662        + "\t(default yes).",
1663        "U", 0, "-U"));
1664   
1665    result.addElement(new Option(
1666        "\tDon't use the heuristic method for binary split.\n"
1667        + "\t(default true).",
1668        "H", 0, "-H"));
1669   
1670    result.addElement(new Option(
1671        "\tUse 1 SE rule to make pruning decision.\n"
1672        + "\t(default no).",
1673        "A", 0, "-A"));
1674   
1675    result.addElement(new Option(
1676        "\tPercentage of training data size (0-1].\n" 
1677        + "\t(default 1).",
1678        "C", 1, "-C"));
1679
1680    return result.elements();
1681  }
1682
1683  /**
1684   * Parses a given list of options. <p/>
1685   *
1686   <!-- options-start -->
1687   * Valid options are: <p/>
1688   *
1689   * <pre> -S &lt;num&gt;
1690   *  Random number seed.
1691   *  (default 1)</pre>
1692   *
1693   * <pre> -D
1694   *  If set, classifier is run in debug mode and
1695   *  may output additional info to the console</pre>
1696   *
1697   * <pre> -M &lt;min no&gt;
1698   *  The minimal number of instances at the terminal nodes.
1699   *  (default 2)</pre>
1700   *
1701   * <pre> -N &lt;num folds&gt;
1702   *  The number of folds used in the minimal cost-complexity pruning.
1703   *  (default 5)</pre>
1704   *
1705   * <pre> -U
1706   *  Don't use the minimal cost-complexity pruning.
1707   *  (default yes).</pre>
1708   *
1709   * <pre> -H
1710   *  Don't use the heuristic method for binary split.
1711   *  (default true).</pre>
1712   *
1713   * <pre> -A
1714   *  Use 1 SE rule to make pruning decision.
1715   *  (default no).</pre>
1716   *
1717   * <pre> -C
1718   *  Percentage of training data size (0-1].
1719   *  (default 1).</pre>
1720   *
1721   <!-- options-end -->
1722   *
1723   * @param options the list of options as an array of strings
1724   * @throws Exception if an options is not supported
1725   */
1726  public void setOptions(String[] options) throws Exception {
1727    String      tmpStr;
1728   
1729    super.setOptions(options);
1730   
1731    tmpStr = Utils.getOption('M', options);
1732    if (tmpStr.length() != 0)
1733      setMinNumObj(Double.parseDouble(tmpStr));
1734    else
1735      setMinNumObj(2);
1736
1737    tmpStr = Utils.getOption('N', options);
1738    if (tmpStr.length()!=0)
1739      setNumFoldsPruning(Integer.parseInt(tmpStr));
1740    else
1741      setNumFoldsPruning(5);
1742
1743    setUsePrune(!Utils.getFlag('U',options));
1744    setHeuristic(!Utils.getFlag('H',options));
1745    setUseOneSE(Utils.getFlag('A',options));
1746
1747    tmpStr = Utils.getOption('C', options);
1748    if (tmpStr.length()!=0)
1749      setSizePer(Double.parseDouble(tmpStr));
1750    else
1751      setSizePer(1);
1752
1753    Utils.checkForRemainingOptions(options);
1754  }
1755
1756  /**
1757   * Gets the current settings of the classifier.
1758   *
1759   * @return            the current setting of the classifier
1760   */
1761  public String[] getOptions() {
1762    int         i;
1763    Vector      result;
1764    String[]    options;
1765
1766    result = new Vector();
1767
1768    options = super.getOptions();
1769    for (i = 0; i < options.length; i++)
1770      result.add(options[i]);
1771
1772    result.add("-M");
1773    result.add("" + getMinNumObj());
1774   
1775    result.add("-N");
1776    result.add("" + getNumFoldsPruning());
1777   
1778    if (!getUsePrune())
1779      result.add("-U");
1780   
1781    if (!getHeuristic())
1782      result.add("-H");
1783   
1784    if (getUseOneSE())
1785      result.add("-A");
1786   
1787    result.add("-C");
1788    result.add("" + getSizePer());
1789
1790    return (String[]) result.toArray(new String[result.size()]);         
1791  }
1792
1793  /**
1794   * Return an enumeration of the measure names.
1795   *
1796   * @return            an enumeration of the measure names
1797   */
1798  public Enumeration enumerateMeasures() {
1799    Vector result = new Vector();
1800   
1801    result.addElement("measureTreeSize");
1802   
1803    return result.elements();
1804  }
1805
1806  /**
1807   * Return number of tree size.
1808   *
1809   * @return            number of tree size
1810   */
1811  public double measureTreeSize() {
1812    return numNodes();
1813  }
1814
1815  /**
1816   * Returns the value of the named measure.
1817   *
1818   * @param additionalMeasureName       the name of the measure to query for its value
1819   * @return                            the value of the named measure
1820   * @throws IllegalArgumentException   if the named measure is not supported
1821   */
1822  public double getMeasure(String additionalMeasureName) {
1823    if (additionalMeasureName.compareToIgnoreCase("measureTreeSize") == 0) {
1824      return measureTreeSize();
1825    } else {
1826      throw new IllegalArgumentException(additionalMeasureName
1827          + " not supported (Cart pruning)");
1828    }
1829  }
1830
1831  /**
1832   * Returns the tip text for this property
1833   *
1834   * @return            tip text for this property suitable for
1835   *                    displaying in the explorer/experimenter gui
1836   */
1837  public String minNumObjTipText() {
1838    return "The minimal number of observations at the terminal nodes (default 2).";
1839  }
1840
1841  /**
1842   * Set minimal number of instances at the terminal nodes.
1843   *
1844   * @param value       minimal number of instances at the terminal nodes
1845   */
1846  public void setMinNumObj(double value) {
1847    m_minNumObj = value;
1848  }
1849
1850  /**
1851   * Get minimal number of instances at the terminal nodes.
1852   *
1853   * @return            minimal number of instances at the terminal nodes
1854   */
1855  public double getMinNumObj() {
1856    return m_minNumObj;
1857  }
1858
1859  /**
1860   * Returns the tip text for this property
1861   *
1862   * @return            tip text for this property suitable for
1863   *                    displaying in the explorer/experimenter gui
1864   */
1865  public String numFoldsPruningTipText() {
1866    return "The number of folds in the internal cross-validation (default 5).";
1867  }
1868
1869  /**
1870   * Set number of folds in internal cross-validation.
1871   *
1872   * @param value       number of folds in internal cross-validation.
1873   */
1874  public void setNumFoldsPruning(int value) {
1875    m_numFoldsPruning = value;
1876  }
1877
1878  /**
1879   * Set number of folds in internal cross-validation.
1880   *
1881   * @return            number of folds in internal cross-validation.
1882   */
1883  public int getNumFoldsPruning() {
1884    return m_numFoldsPruning;
1885  }
1886
1887  /**
1888   * Return the tip text for this property
1889   *
1890   * @return            tip text for this property suitable for displaying in
1891   *                    the explorer/experimenter gui.
1892   */
1893  public String usePruneTipText() {
1894    return "Use minimal cost-complexity pruning (default yes).";
1895  }
1896
1897  /**
1898   * Set if use minimal cost-complexity pruning.
1899   *
1900   * @param value       if use minimal cost-complexity pruning
1901   */
1902  public void setUsePrune(boolean value) {
1903    m_Prune = value;
1904  }
1905
1906  /**
1907   * Get if use minimal cost-complexity pruning.
1908   *
1909   * @return            if use minimal cost-complexity pruning
1910   */
1911  public boolean getUsePrune() {
1912    return m_Prune;
1913  }
1914
1915  /**
1916   * Returns the tip text for this property
1917   *
1918   * @return            tip text for this property suitable for
1919   *                    displaying in the explorer/experimenter gui.
1920   */
1921  public String heuristicTipText() {
1922    return 
1923        "If heuristic search is used for binary split for nominal attributes "
1924      + "in multi-class problems (default yes).";
1925  }
1926
1927  /**
1928   * Set if use heuristic search for nominal attributes in multi-class problems.
1929   *
1930   * @param value       if use heuristic search for nominal attributes in
1931   *                    multi-class problems
1932   */
1933  public void setHeuristic(boolean value) {
1934    m_Heuristic = value;
1935  }
1936
1937  /**
1938   * Get if use heuristic search for nominal attributes in multi-class problems.
1939   *
1940   * @return            if use heuristic search for nominal attributes in
1941   *                    multi-class problems
1942   */
1943  public boolean getHeuristic() {return m_Heuristic;}
1944
1945  /**
1946   * Returns the tip text for this property
1947   *
1948   * @return            tip text for this property suitable for
1949   *                    displaying in the explorer/experimenter gui.
1950   */
1951  public String useOneSETipText() {
1952    return "Use the 1SE rule to make pruning decisoin.";
1953  }
1954
1955  /**
1956   * Set if use the 1SE rule to choose final model.
1957   *
1958   * @param value       if use the 1SE rule to choose final model
1959   */
1960  public void setUseOneSE(boolean value) {
1961    m_UseOneSE = value;
1962  }
1963
1964  /**
1965   * Get if use the 1SE rule to choose final model.
1966   *
1967   * @return            if use the 1SE rule to choose final model
1968   */
1969  public boolean getUseOneSE() {
1970    return m_UseOneSE;
1971  }
1972
1973  /**
1974   * Returns the tip text for this property
1975   *
1976   * @return            tip text for this property suitable for
1977   *                    displaying in the explorer/experimenter gui.
1978   */
1979  public String sizePerTipText() {
1980    return "The percentage of the training set size (0-1, 0 not included).";
1981  }
1982
1983  /**
1984   * Set training set size.
1985   *
1986   * @param value       training set size
1987   */ 
1988  public void setSizePer(double value) {
1989    if ((value <= 0) || (value > 1))
1990      System.err.println(
1991          "The percentage of the training set size must be in range 0 to 1 "
1992          + "(0 not included) - ignored!");
1993    else
1994      m_SizePer = value;
1995  }
1996
1997  /**
1998   * Get training set size.
1999   *
2000   * @return            training set size
2001   */
2002  public double getSizePer() {
2003    return m_SizePer;
2004  }
2005 
2006  /**
2007   * Returns the revision string.
2008   *
2009   * @return            the revision
2010   */
2011  public String getRevision() {
2012    return RevisionUtils.extract("$Revision: 5987 $");
2013  }
2014
2015  /**
2016   * Main method.
2017   * @param args the options for the classifier
2018   */
2019  public static void main(String[] args) {
2020    runClassifier(new SimpleCart(), args);
2021  }
2022}
Note: See TracBrowser for help on using the repository browser.