source: src/main/java/weka/classifiers/trees/BFTree.java @ 7

Last change on this file since 7 was 4, checked in by gnappo, 14 years ago

Import di weka.

File size: 83.2 KB
RevLine 
[4]1/*
2 *    This program is free software; you can redistribute it and/or modify
3 *    it under the terms of the GNU General Public License as published by
4 *    the Free Software Foundation; either version 2 of the License, or
5 *    (at your option) any later version.
6 *
7 *    This program is distributed in the hope that it will be useful,
8 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
9 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
10 *    GNU General Public License for more details.
11 *
12 *    You should have received a copy of the GNU General Public License
13 *    along with this program; if not, write to the Free Software
14 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
15 */
16
17/*
18 * BFTree.java
19 * Copyright (C) 2007 Haijian Shi
20 *
21 */
22
23package weka.classifiers.trees;
24
25import weka.classifiers.Evaluation;
26import weka.classifiers.RandomizableClassifier;
27import weka.core.AdditionalMeasureProducer;
28import weka.core.Attribute;
29import weka.core.Capabilities;
30import weka.core.FastVector;
31import weka.core.Instance;
32import weka.core.Instances;
33import weka.core.Option;
34import weka.core.RevisionUtils;
35import weka.core.SelectedTag;
36import weka.core.Tag;
37import weka.core.TechnicalInformation;
38import weka.core.TechnicalInformationHandler;
39import weka.core.Utils;
40import weka.core.Capabilities.Capability;
41import weka.core.TechnicalInformation.Field;
42import weka.core.TechnicalInformation.Type;
43import weka.core.matrix.Matrix;
44
45import java.util.Arrays;
46import java.util.Enumeration;
47import java.util.Random;
48import java.util.Vector;
49
50/**
51 <!-- globalinfo-start -->
52 * Class for building a best-first decision tree classifier. This class uses binary split for both nominal and numeric attributes. For missing values, the method of 'fractional' instances is used.<br/>
53 * <br/>
54 * For more information, see:<br/>
55 * <br/>
56 * Haijian Shi (2007). Best-first decision tree learning. Hamilton, NZ.<br/>
57 * <br/>
58 * Jerome Friedman, Trevor Hastie, Robert Tibshirani (2000). Additive logistic regression : A statistical view of boosting. Annals of statistics. 28(2):337-407.
59 * <p/>
60 <!-- globalinfo-end -->
61 *
62 <!-- technical-bibtex-start -->
63 * BibTeX:
64 * <pre>
65 * &#64;mastersthesis{Shi2007,
66 *    address = {Hamilton, NZ},
67 *    author = {Haijian Shi},
68 *    note = {COMP594},
69 *    school = {University of Waikato},
70 *    title = {Best-first decision tree learning},
71 *    year = {2007}
72 * }
73 *
74 * &#64;article{Friedman2000,
75 *    author = {Jerome Friedman and Trevor Hastie and Robert Tibshirani},
76 *    journal = {Annals of statistics},
77 *    number = {2},
78 *    pages = {337-407},
79 *    title = {Additive logistic regression : A statistical view of boosting},
80 *    volume = {28},
81 *    year = {2000},
82 *    ISSN = {0090-5364}
83 * }
84 * </pre>
85 * <p/>
86 <!-- technical-bibtex-end -->
87 *
88 <!-- options-start -->
89 * Valid options are: <p/>
90 *
91 * <pre> -S &lt;num&gt;
92 *  Random number seed.
93 *  (default 1)</pre>
94 *
95 * <pre> -D
96 *  If set, classifier is run in debug mode and
97 *  may output additional info to the console</pre>
98 *
99 * <pre> -P &lt;UNPRUNED|POSTPRUNED|PREPRUNED&gt;
100 *  The pruning strategy.
101 *  (default: POSTPRUNED)</pre>
102 *
103 * <pre> -M &lt;min no&gt;
104 *  The minimal number of instances at the terminal nodes.
105 *  (default 2)</pre>
106 *
107 * <pre> -N &lt;num folds&gt;
108 *  The number of folds used in the pruning.
109 *  (default 5)</pre>
110 *
111 * <pre> -H
112 *  Don't use heuristic search for nominal attributes in multi-class
113 *  problem (default yes).
114 * </pre>
115 *
116 * <pre> -G
117 *  Don't use Gini index for splitting (default yes),
118 *  if not information is used.</pre>
119 *
120 * <pre> -R
121 *  Don't use error rate in internal cross-validation (default yes),
122 *  but root mean squared error.</pre>
123 *
124 * <pre> -A
125 *  Use the 1 SE rule to make pruning decision.
126 *  (default no).</pre>
127 *
128 * <pre> -C
129 *  Percentage of training data size (0-1]
130 *  (default 1).</pre>
131 *
132 <!-- options-end -->
133 *
134 * @author Haijian Shi (hs69@cs.waikato.ac.nz)
135 * @version $Revision: 5987 $
136 */
137public class BFTree
138  extends RandomizableClassifier
139  implements AdditionalMeasureProducer, TechnicalInformationHandler {
140
141  /** For serialization.         */
142  private static final long serialVersionUID = -7035607375962528217L;
143
144  /** pruning strategy: un-pruned */
145  public static final int PRUNING_UNPRUNED = 0;
146  /** pruning strategy: post-pruning */
147  public static final int PRUNING_POSTPRUNING = 1;
148  /** pruning strategy: pre-pruning */
149  public static final int PRUNING_PREPRUNING = 2;
150  /** pruning strategy */
151  public static final Tag[] TAGS_PRUNING = {
152    new Tag(PRUNING_UNPRUNED, "unpruned", "Un-pruned"),
153    new Tag(PRUNING_POSTPRUNING, "postpruned", "Post-pruning"),
154    new Tag(PRUNING_PREPRUNING, "prepruned", "Pre-pruning")
155  };
156 
157  /** the pruning strategy */
158  protected int m_PruningStrategy = PRUNING_POSTPRUNING;
159
160  /** Successor nodes. */
161  protected BFTree[] m_Successors;
162
163  /** Attribute used for splitting. */
164  protected Attribute m_Attribute;
165
166  /** Split point (for numeric attributes). */
167  protected double m_SplitValue;
168
169  /** Split subset (for nominal attributes). */
170  protected String m_SplitString;
171
172  /** Class value for a node. */
173  protected double m_ClassValue;
174
175  /** Class attribute of a dataset. */
176  protected Attribute m_ClassAttribute;
177
178  /** Minimum number of instances at leaf nodes. */
179  protected int m_minNumObj = 2;
180
181  /** Number of folds for the pruning. */
182  protected int m_numFoldsPruning = 5;
183
184  /** If the ndoe is leaf node. */
185  protected boolean m_isLeaf;
186
187  /** Number of expansions. */
188  protected static int m_Expansion;
189
190  /** Fixed number of expansions (if no pruning method is used, its value is -1. Otherwise,
191   *  its value is gotten from internal cross-validation).   */
192  protected int m_FixedExpansion = -1;
193
194  /** If use huristic search for binary split (default true). Note even if its value is true, it is only
195   * used when the number of values of a nominal attribute is larger than 4. */
196  protected boolean m_Heuristic = true;
197
198  /** If use Gini index as the splitting criterion - default (if not, information is used). */
199  protected boolean m_UseGini = true;
200
201  /** If use error rate in internal cross-validation to fix the number of expansions - default
202   *  (if not, root mean squared error is used). */
203  protected boolean m_UseErrorRate = true;
204
205  /** If use the 1SE rule to make the decision. */
206  protected boolean m_UseOneSE = false;
207
208  /** Class distributions.  */
209  protected double[] m_Distribution;
210
211  /** Branch proportions. */
212  protected double[] m_Props;
213
214  /** Sorted indices. */
215  protected int[][] m_SortedIndices;
216
217  /** Sorted weights. */
218  protected double[][] m_Weights;
219
220  /** Distributions of each attribute for two successor nodes. */
221  protected double[][][] m_Dists;
222
223  /** Class probabilities. */
224  protected double[] m_ClassProbs;
225
226  /** Total weights. */
227  protected double m_TotalWeight;
228
229  /** The training data size (0-1). Default 1. */
230  protected double m_SizePer = 1;
231
232  /**
233   * Returns a string describing classifier
234   *
235   * @return            a description suitable for displaying in the
236   *                    explorer/experimenter gui
237   */
238  public String globalInfo() {
239    return 
240        "Class for building a best-first decision tree classifier. "
241      + "This class uses binary split for both nominal and numeric attributes. "
242      + "For missing values, the method of 'fractional' instances is used.\n\n"
243      + "For more information, see:\n\n"
244      + getTechnicalInformation().toString();
245  }
246 
247  /**
248   * Returns an instance of a TechnicalInformation object, containing
249   * detailed information about the technical background of this class,
250   * e.g., paper reference or book this class is based on.
251   *
252   * @return the technical information about this class
253   */
254  public TechnicalInformation getTechnicalInformation() {
255    TechnicalInformation        result;
256    TechnicalInformation        additional;
257   
258    result = new TechnicalInformation(Type.MASTERSTHESIS);
259    result.setValue(Field.AUTHOR, "Haijian Shi");
260    result.setValue(Field.YEAR, "2007");
261    result.setValue(Field.TITLE, "Best-first decision tree learning");
262    result.setValue(Field.SCHOOL, "University of Waikato");
263    result.setValue(Field.ADDRESS, "Hamilton, NZ");
264    result.setValue(Field.NOTE, "COMP594");
265   
266    additional = result.add(Type.ARTICLE);
267    additional.setValue(Field.AUTHOR, "Jerome Friedman and Trevor Hastie and Robert Tibshirani");
268    additional.setValue(Field.YEAR, "2000");
269    additional.setValue(Field.TITLE, "Additive logistic regression : A statistical view of boosting");
270    additional.setValue(Field.JOURNAL, "Annals of statistics");
271    additional.setValue(Field.VOLUME, "28");
272    additional.setValue(Field.NUMBER, "2");
273    additional.setValue(Field.PAGES, "337-407");
274    additional.setValue(Field.ISSN, "0090-5364");
275   
276    return result;
277  }
278
279  /**
280   * Returns default capabilities of the classifier.
281   *
282   * @return            the capabilities of this classifier
283   */
284  public Capabilities getCapabilities() {
285    Capabilities result = super.getCapabilities();
286    result.disableAll();
287
288    // attributes
289    result.enable(Capability.NOMINAL_ATTRIBUTES);
290    result.enable(Capability.NUMERIC_ATTRIBUTES);
291    result.enable(Capability.MISSING_VALUES);
292
293    // class
294    result.enable(Capability.NOMINAL_CLASS);
295
296    return result;
297  }
298
299  /**
300   * Method for building a BestFirst decision tree classifier.
301   *
302   * @param data        set of instances serving as training data
303   * @throws Exception  if decision tree cannot be built successfully
304   */
305  public void buildClassifier(Instances data) throws Exception {
306
307    getCapabilities().testWithFail(data);
308    data = new Instances(data);
309    data.deleteWithMissingClass();
310
311    // build an unpruned tree
312    if (m_PruningStrategy == PRUNING_UNPRUNED) {
313
314      // calculate sorted indices, weights and initial class probabilities
315      int[][] sortedIndices = new int[data.numAttributes()][0];
316      double[][] weights = new double[data.numAttributes()][0];
317      double[] classProbs = new double[data.numClasses()];
318      double totalWeight = computeSortedInfo(data,sortedIndices, weights,classProbs);
319
320      // Compute information of the best split for this node (include split attribute,
321      // split value and gini gain (or information gain)). At the same time, compute
322      // variables dists, props and totalSubsetWeights.
323      double[][][] dists = new double[data.numAttributes()][2][data.numClasses()];
324      double[][] props = new double[data.numAttributes()][2];
325      double[][] totalSubsetWeights = new double[data.numAttributes()][2];
326      FastVector nodeInfo = computeSplitInfo(this, data, sortedIndices, weights, dists,
327          props, totalSubsetWeights, m_Heuristic, m_UseGini);
328
329      // add the node (with all split info) into BestFirstElements
330      FastVector BestFirstElements = new FastVector();
331      BestFirstElements.addElement(nodeInfo);
332
333      // Make the best-first decision tree.
334      int attIndex = ((Attribute)nodeInfo.elementAt(1)).index();
335      m_Expansion = 0;
336      makeTree(BestFirstElements, data, sortedIndices, weights, dists, classProbs,
337          totalWeight, props[attIndex] ,m_minNumObj, m_Heuristic, m_UseGini, m_FixedExpansion);
338
339      return;
340    }
341
342    // the following code is for pre-pruning and post-pruning methods
343
344    // Compute train data, test data, sorted indices, sorted weights, total weights,
345    // class probabilities, class distributions, branch proportions and total subset
346    // weights for root nodes of each fold for prepruning and postpruning.
347    int expansion = 0;
348
349    Random random = new Random(m_Seed);
350    Instances cvData = new Instances(data);
351    cvData.randomize(random);
352    cvData = new Instances(cvData,0,(int)(cvData.numInstances()*m_SizePer)-1);
353    cvData.stratify(m_numFoldsPruning);
354
355    Instances[] train = new Instances[m_numFoldsPruning];
356    Instances[] test = new Instances[m_numFoldsPruning];
357    FastVector[] parallelBFElements = new FastVector [m_numFoldsPruning];
358    BFTree[] m_roots = new BFTree[m_numFoldsPruning];
359
360    int[][][] sortedIndices = new int[m_numFoldsPruning][data.numAttributes()][0];
361    double[][][] weights = new double[m_numFoldsPruning][data.numAttributes()][0];
362    double[][] classProbs = new double[m_numFoldsPruning][data.numClasses()];
363    double[] totalWeight = new double[m_numFoldsPruning];
364
365    double[][][][] dists =
366      new double[m_numFoldsPruning][data.numAttributes()][2][data.numClasses()];
367    double[][][] props =
368      new double[m_numFoldsPruning][data.numAttributes()][2];
369    double[][][] totalSubsetWeights =
370      new double[m_numFoldsPruning][data.numAttributes()][2];
371    FastVector[] nodeInfo = new FastVector[m_numFoldsPruning];
372
373    for (int i = 0; i < m_numFoldsPruning; i++) {
374      train[i] = cvData.trainCV(m_numFoldsPruning, i);
375      test[i] = cvData.testCV(m_numFoldsPruning, i);
376      parallelBFElements[i] = new FastVector();
377      m_roots[i] = new BFTree();
378
379      // calculate sorted indices, weights, initial class counts and total weights for each training data
380      totalWeight[i] = computeSortedInfo(train[i],sortedIndices[i], weights[i],
381          classProbs[i]);
382
383      // compute information of the best split for this node (include split attribute,
384      // split value and gini gain (or information gain)) in this fold
385      nodeInfo[i] = computeSplitInfo(m_roots[i], train[i], sortedIndices[i],
386          weights[i], dists[i], props[i], totalSubsetWeights[i], m_Heuristic, m_UseGini);
387
388      // compute information for root nodes
389
390      int attIndex = ((Attribute)nodeInfo[i].elementAt(1)).index();
391
392      m_roots[i].m_SortedIndices = new int[sortedIndices[i].length][0];
393      m_roots[i].m_Weights = new double[weights[i].length][0];
394      m_roots[i].m_Dists = new double[dists[i].length][0][0];
395      m_roots[i].m_ClassProbs = new double[classProbs[i].length];
396      m_roots[i].m_Distribution = new double[classProbs[i].length];
397      m_roots[i].m_Props = new double[2];
398
399      for (int j=0; j<m_roots[i].m_SortedIndices.length; j++) {
400        m_roots[i].m_SortedIndices[j] = sortedIndices[i][j];
401        m_roots[i].m_Weights[j] = weights[i][j];
402        m_roots[i].m_Dists[j] = dists[i][j];
403      }
404
405      System.arraycopy(classProbs[i], 0, m_roots[i].m_ClassProbs, 0,
406          classProbs[i].length);
407      if (Utils.sum(m_roots[i].m_ClassProbs)!=0)
408        Utils.normalize(m_roots[i].m_ClassProbs);
409
410      System.arraycopy(classProbs[i], 0, m_roots[i].m_Distribution, 0,
411          classProbs[i].length);
412      System.arraycopy(props[i][attIndex], 0, m_roots[i].m_Props, 0,
413          props[i][attIndex].length);
414
415      m_roots[i].m_TotalWeight = totalWeight[i];
416
417      parallelBFElements[i].addElement(nodeInfo[i]);
418    }
419
420    // build a pre-pruned tree
421    if (m_PruningStrategy == PRUNING_PREPRUNING) {
422
423      double previousError = Double.MAX_VALUE;
424      double currentError = previousError;
425      double minError = Double.MAX_VALUE;
426      int minExpansion = 0;
427      FastVector errorList = new FastVector();
428      while(true) {
429        // compute average error
430        double expansionError = 0;
431        int count = 0;
432
433        for (int i=0; i<m_numFoldsPruning; i++) {
434          Evaluation eval;
435
436          // calculate error rate if only root node
437          if (expansion==0) {
438            m_roots[i].m_isLeaf = true;
439            eval = new Evaluation(test[i]);
440            eval.evaluateModel(m_roots[i], test[i]);
441            if (m_UseErrorRate) expansionError += eval.errorRate();
442            else expansionError += eval.rootMeanSquaredError();
443            count ++;
444          }
445
446          // make tree - expand one node at a time
447          else {
448            if (m_roots[i] == null) continue; // if the tree cannot be expanded, go to next fold
449            m_roots[i].m_isLeaf = false;
450            BFTree nodeToSplit = (BFTree)
451            (((FastVector)(parallelBFElements[i].elementAt(0))).elementAt(0));
452            if (!m_roots[i].makeTree(parallelBFElements[i], m_roots[i], train[i],
453                nodeToSplit.m_SortedIndices, nodeToSplit.m_Weights,
454                nodeToSplit.m_Dists, nodeToSplit.m_ClassProbs,
455                nodeToSplit.m_TotalWeight, nodeToSplit.m_Props, m_minNumObj,
456                m_Heuristic, m_UseGini)) {
457              m_roots[i] = null; // cannot be expanded
458              continue;
459            }
460            eval = new Evaluation(test[i]);
461            eval.evaluateModel(m_roots[i], test[i]);
462            if (m_UseErrorRate) expansionError += eval.errorRate();
463            else expansionError += eval.rootMeanSquaredError();
464            count ++;
465          }
466        }
467
468        // no tree can be expanded any more
469        if (count==0) break;
470
471        expansionError /=count;
472        errorList.addElement(new Double(expansionError));
473        currentError = expansionError;
474
475        if (!m_UseOneSE) {
476          if (currentError>previousError)
477            break;
478        }
479
480        else {
481          if (expansionError < minError) {
482            minError = expansionError;
483            minExpansion = expansion;
484          }
485
486          if (currentError>previousError) {
487            double oneSE = Math.sqrt(minError*(1-minError)/
488                data.numInstances());
489            if (currentError > minError + oneSE) {
490              break;
491            }
492          }
493        }
494
495        expansion ++;
496        previousError = currentError;
497      }
498
499      if (!m_UseOneSE) expansion = expansion - 1;
500      else {
501        double oneSE = Math.sqrt(minError*(1-minError)/data.numInstances());
502        for (int i=0; i<errorList.size(); i++) {
503          double error = ((Double)(errorList.elementAt(i))).doubleValue();
504          if (error<=minError + oneSE) { // && counts[i]>=m_numFoldsPruning/2) {
505            expansion = i;
506            break;
507          }
508        }
509      }
510    }
511
512    // build a postpruned tree
513    else {
514      FastVector[] modelError = new FastVector[m_numFoldsPruning];
515
516      // calculate error of each expansion for each fold
517      for (int i = 0; i < m_numFoldsPruning; i++) {
518        modelError[i] = new FastVector();
519
520        m_roots[i].m_isLeaf = true;
521        Evaluation eval = new Evaluation(test[i]);
522        eval.evaluateModel(m_roots[i], test[i]);
523        double error;
524        if (m_UseErrorRate) error = eval.errorRate();
525        else error = eval.rootMeanSquaredError();
526        modelError[i].addElement(new Double(error));
527
528        m_roots[i].m_isLeaf = false;
529        BFTree nodeToSplit = (BFTree)
530        (((FastVector)(parallelBFElements[i].elementAt(0))).elementAt(0));
531
532        m_roots[i].makeTree(parallelBFElements[i], m_roots[i], train[i], test[i],
533            modelError[i],nodeToSplit.m_SortedIndices, nodeToSplit.m_Weights,
534            nodeToSplit.m_Dists, nodeToSplit.m_ClassProbs,
535            nodeToSplit.m_TotalWeight, nodeToSplit.m_Props, m_minNumObj,
536            m_Heuristic, m_UseGini, m_UseErrorRate);
537        m_roots[i] = null;
538      }
539
540      // find the expansion with minimal error rate
541      double minError = Double.MAX_VALUE;
542
543      int maxExpansion = modelError[0].size();
544      for (int i=1; i<modelError.length; i++) {
545        if (modelError[i].size()>maxExpansion)
546          maxExpansion = modelError[i].size();
547      }
548
549      double[] error = new double[maxExpansion];
550      int[] counts = new int[maxExpansion];
551      for (int i=0; i<maxExpansion; i++) {
552        counts[i] = 0;
553        error[i] = 0;
554        for (int j=0; j<m_numFoldsPruning; j++) {
555          if (i<modelError[j].size()) {
556            error[i] += ((Double)modelError[j].elementAt(i)).doubleValue();
557            counts[i]++;
558          }
559        }
560        error[i] = error[i]/counts[i]; //average error for each expansion
561
562        if (error[i]<minError) {// && counts[i]>=m_numFoldsPruning/2) {
563          minError = error[i];
564          expansion = i;
565        }
566      }
567
568      // the 1 SE rule choosen
569      if (m_UseOneSE) {
570        double oneSE = Math.sqrt(minError*(1-minError)/
571            data.numInstances());
572        for (int i=0; i<maxExpansion; i++) {
573          if (error[i]<=minError + oneSE) { // && counts[i]>=m_numFoldsPruning/2) {
574            expansion = i;
575            break;
576          }
577        }
578      }
579    }
580
581    // make tree on all data based on the expansion caculated
582    // from cross-validation
583
584    // calculate sorted indices, weights and initial class counts
585    int[][] prune_sortedIndices = new int[data.numAttributes()][0];
586    double[][] prune_weights = new double[data.numAttributes()][0];
587    double[] prune_classProbs = new double[data.numClasses()];
588    double prune_totalWeight = computeSortedInfo(data, prune_sortedIndices,
589        prune_weights, prune_classProbs);
590
591    // compute information of the best split for this node (include split attribute,
592    // split value and gini gain)
593    double[][][] prune_dists = new double[data.numAttributes()][2][data.numClasses()];
594    double[][] prune_props = new double[data.numAttributes()][2];
595    double[][] prune_totalSubsetWeights = new double[data.numAttributes()][2];
596    FastVector prune_nodeInfo = computeSplitInfo(this, data, prune_sortedIndices,
597        prune_weights, prune_dists, prune_props, prune_totalSubsetWeights, m_Heuristic,m_UseGini);
598
599    // add the root node (with its split info) to BestFirstElements
600    FastVector BestFirstElements = new FastVector();
601    BestFirstElements.addElement(prune_nodeInfo);
602
603    int attIndex = ((Attribute)prune_nodeInfo.elementAt(1)).index();
604    m_Expansion = 0;
605    makeTree(BestFirstElements, data, prune_sortedIndices, prune_weights, prune_dists,
606        prune_classProbs, prune_totalWeight, prune_props[attIndex] ,m_minNumObj,
607        m_Heuristic, m_UseGini, expansion);
608  }
609
610  /**
611   * Recursively build a best-first decision tree.
612   * Method for building a Best-First tree for a given number of expansions.
613   * preExpasion is -1 means that no expansion is specified (just for a
614   * tree without any pruning method). Pre-pruning and post-pruning methods also
615   * use this method to build the final tree on all training data based on the
616   * expansion calculated from internal cross-validation.
617   *
618   * @param BestFirstElements   list to store BFTree nodes
619   * @param data                training data
620   * @param sortedIndices       sorted indices of the instances
621   * @param weights             weights of the instances
622   * @param dists               class distributions for each attribute
623   * @param classProbs          class probabilities of this node
624   * @param totalWeight         total weight of this node (note if the node
625   *                            can not split, this value is not calculated.)
626   * @param branchProps         proportions of two subbranches
627   * @param minNumObj           minimal number of instances at leaf nodes
628   * @param useHeuristic        if use heuristic search for nominal attributes
629   *                            in multi-class problem
630   * @param useGini             if use Gini index as splitting criterion
631   * @param preExpansion        the number of expansions the tree to be expanded
632   * @throws Exception          if something goes wrong
633   */
634  protected void makeTree(FastVector BestFirstElements,Instances data,
635      int[][] sortedIndices, double[][] weights, double[][][] dists,
636      double[] classProbs, double totalWeight, double[] branchProps,
637      int minNumObj, boolean useHeuristic, boolean useGini, int preExpansion)
638        throws Exception {
639
640    if (BestFirstElements.size()==0) return;
641
642    ///////////////////////////////////////////////////////////////////////
643    // All information about the node to split (the first BestFirst object in
644    // BestFirstElements)
645    FastVector firstElement = (FastVector)BestFirstElements.elementAt(0);
646
647    // split attribute
648    Attribute att = (Attribute)firstElement.elementAt(1);
649
650    // info of split value or split string
651    double splitValue = Double.NaN;
652    String splitStr = null;
653    if (att.isNumeric())
654      splitValue = ((Double)firstElement.elementAt(2)).doubleValue();
655    else {
656      splitStr=((String)firstElement.elementAt(2)).toString();
657    }
658
659    // the best gini gain or information gain of this node
660    double gain = ((Double)firstElement.elementAt(3)).doubleValue();
661    ///////////////////////////////////////////////////////////////////////
662
663    if (m_ClassProbs==null) {
664      m_SortedIndices = new int[sortedIndices.length][0];
665      m_Weights = new double[weights.length][0];
666      m_Dists = new double[dists.length][0][0];
667      m_ClassProbs = new double[classProbs.length];
668      m_Distribution = new double[classProbs.length];
669      m_Props = new double[2];
670
671      for (int i=0; i<m_SortedIndices.length; i++) {
672        m_SortedIndices[i] = sortedIndices[i];
673        m_Weights[i] = weights[i];
674        m_Dists[i] = dists[i];
675      }
676
677      System.arraycopy(classProbs, 0, m_ClassProbs, 0, classProbs.length);
678      System.arraycopy(classProbs, 0, m_Distribution, 0, classProbs.length);
679      System.arraycopy(branchProps, 0, m_Props, 0, m_Props.length);
680      m_TotalWeight = totalWeight;
681      if (Utils.sum(m_ClassProbs)!=0) Utils.normalize(m_ClassProbs);
682    }
683
684    // If no enough data or this node can not be split, find next node to split.
685    if (totalWeight < 2*minNumObj || branchProps[0]==0
686        || branchProps[1]==0) {
687      // remove the first element
688      BestFirstElements.removeElementAt(0);
689
690      makeLeaf(data);
691      if (BestFirstElements.size()!=0) {
692        FastVector nextSplitElement = (FastVector)BestFirstElements.elementAt(0);
693        BFTree nextSplitNode = (BFTree)nextSplitElement.elementAt(0);
694        nextSplitNode.makeTree(BestFirstElements,data,
695            nextSplitNode.m_SortedIndices, nextSplitNode.m_Weights,
696            nextSplitNode.m_Dists,
697            nextSplitNode.m_ClassProbs, nextSplitNode.m_TotalWeight,
698            nextSplitNode.m_Props, minNumObj, useHeuristic, useGini, preExpansion);
699      }
700      return;
701    }
702
703    // If gini gain or information gain is 0, make all nodes in the BestFirstElements leaf nodes
704    // because these nodes are sorted descendingly according to gini gain or information gain.
705    // (namely, gini gain or information gain of all nodes in BestFirstEelements is 0).
706    if (gain==0 || preExpansion==m_Expansion) {
707      for (int i=0; i<BestFirstElements.size(); i++) {
708        FastVector element = (FastVector)BestFirstElements.elementAt(i);
709        BFTree node = (BFTree)element.elementAt(0);
710        node.makeLeaf(data);
711      }
712      BestFirstElements.removeAllElements();
713    }
714
715    // gain is not 0
716    else {
717      // remove the first element
718      BestFirstElements.removeElementAt(0);
719
720      m_Attribute = att;
721      if (m_Attribute.isNumeric()) m_SplitValue = splitValue;
722      else m_SplitString = splitStr;
723
724      int[][][] subsetIndices = new int[2][data.numAttributes()][0];
725      double[][][] subsetWeights = new double[2][data.numAttributes()][0];
726
727      splitData(subsetIndices, subsetWeights, m_Attribute, m_SplitValue,
728          m_SplitString, sortedIndices, weights, data);
729
730      // If split will generate node(s) which has total weights less than m_minNumObj,
731      // do not split.
732      int attIndex = att.index();
733      if (subsetIndices[0][attIndex].length<minNumObj ||
734          subsetIndices[1][attIndex].length<minNumObj) {
735        makeLeaf(data);
736      }
737
738      // split the node
739      else {
740        m_isLeaf = false;
741        m_Attribute = att;
742
743        // if expansion is specified (if pruning method used)
744        if (    (m_PruningStrategy == PRUNING_PREPRUNING) 
745             || (m_PruningStrategy == PRUNING_POSTPRUNING)
746             || (preExpansion != -1)) 
747          m_Expansion++;
748
749        makeSuccessors(BestFirstElements,data,subsetIndices,subsetWeights,dists,
750            att,useHeuristic, useGini);
751      }
752
753      // choose next node to split
754      if (BestFirstElements.size()!=0) {
755        FastVector nextSplitElement = (FastVector)BestFirstElements.elementAt(0);
756        BFTree nextSplitNode = (BFTree)nextSplitElement.elementAt(0);
757        nextSplitNode.makeTree(BestFirstElements,data,
758            nextSplitNode.m_SortedIndices, nextSplitNode.m_Weights,
759            nextSplitNode.m_Dists,
760            nextSplitNode.m_ClassProbs, nextSplitNode.m_TotalWeight,
761            nextSplitNode.m_Props, minNumObj, useHeuristic, useGini, preExpansion);
762      }
763
764    }
765  }
766
767  /**
768   * This method is to find the number of expansions based on internal
769   * cross-validation for just pre-pruning. It expands the first BestFirst
770   * node in the BestFirstElements if it is expansible, otherwise it looks
771   * for next exapansible node. If it finds a node is expansibel, expand the
772   * node, then return true. (note it just expands one node at a time).
773   *
774   * @param BestFirstElements   list to store BFTree nodes
775   * @param root                root node of tree in each fold
776   * @param train               training data
777   * @param sortedIndices       sorted indices of the instances
778   * @param weights             weights of the instances
779   * @param dists               class distributions for each attribute
780   * @param classProbs          class probabilities of this node
781   * @param totalWeight         total weight of this node (note if the node
782   *                            can not split, this value is not calculated.)
783   * @param branchProps         proportions of two subbranches
784   * @param minNumObj   minimal number of instances at leaf nodes
785   * @param useHeuristic        if use heuristic search for nominal attributes
786   *                            in multi-class problem
787   * @param useGini             if use Gini index as splitting criterion
788   * @return true               if expand successfully, otherwise return false
789   *                            (all nodes in BestFirstElements cannot be
790   *                            expanded).
791   * @throws Exception          if something goes wrong
792   */
793  protected boolean makeTree(FastVector BestFirstElements, BFTree root,
794      Instances train, int[][] sortedIndices, double[][] weights,
795      double[][][] dists, double[] classProbs, double totalWeight,
796      double[] branchProps, int minNumObj, boolean useHeuristic, boolean useGini)
797  throws Exception {
798
799    if (BestFirstElements.size()==0) return false;
800
801    ///////////////////////////////////////////////////////////////////////
802    // All information about the node to split (first BestFirst object in
803    // BestFirstElements)
804    FastVector firstElement = (FastVector)BestFirstElements.elementAt(0);
805
806    // node to split
807    BFTree nodeToSplit = (BFTree)firstElement.elementAt(0);
808
809    // split attribute
810    Attribute att = (Attribute)firstElement.elementAt(1);
811
812    // info of split value or split string
813    double splitValue = Double.NaN;
814    String splitStr = null;
815    if (att.isNumeric())
816      splitValue = ((Double)firstElement.elementAt(2)).doubleValue();
817    else {
818      splitStr=((String)firstElement.elementAt(2)).toString();
819    }
820
821    // the best gini gain or information gain of this node
822    double gain = ((Double)firstElement.elementAt(3)).doubleValue();
823    ///////////////////////////////////////////////////////////////////////
824
825    // If no enough data to split for this node or this node can not be split find next node to split.
826    if (totalWeight < 2*minNumObj || branchProps[0]==0
827        || branchProps[1]==0) {
828      // remove the first element
829      BestFirstElements.removeElementAt(0);
830      nodeToSplit.makeLeaf(train);
831      BFTree nextNode = (BFTree)
832      ((FastVector)BestFirstElements.elementAt(0)).elementAt(0);
833      return root.makeTree(BestFirstElements, root, train,
834          nextNode.m_SortedIndices, nextNode.m_Weights, nextNode.m_Dists,
835          nextNode.m_ClassProbs, nextNode.m_TotalWeight,
836          nextNode.m_Props, minNumObj, useHeuristic, useGini);
837    }
838
839    // If gini gain or information is 0, make all nodes in the BestFirstElements leaf nodes
840    // because these node sorted descendingly according to gini gain or information gain.
841    // (namely, gini gain or information gain of all nodes in BestFirstEelements is 0).
842    if (gain==0) {
843      for (int i=0; i<BestFirstElements.size(); i++) {
844        FastVector element = (FastVector)BestFirstElements.elementAt(i);
845        BFTree node = (BFTree)element.elementAt(0);
846        node.makeLeaf(train);
847      }
848      BestFirstElements.removeAllElements();
849      return false;
850    }
851
852    else {
853      // remove the first element
854      BestFirstElements.removeElementAt(0);
855      nodeToSplit.m_Attribute = att;
856      if (att.isNumeric()) nodeToSplit.m_SplitValue = splitValue;
857      else nodeToSplit.m_SplitString = splitStr;
858
859      int[][][] subsetIndices = new int[2][train.numAttributes()][0];
860      double[][][] subsetWeights = new double[2][train.numAttributes()][0];
861
862      splitData(subsetIndices, subsetWeights, nodeToSplit.m_Attribute,
863          nodeToSplit.m_SplitValue, nodeToSplit.m_SplitString,
864          nodeToSplit.m_SortedIndices, nodeToSplit.m_Weights, train);
865
866      // if split will generate node(s) which has total weights less than m_minNumObj,
867      // do not split
868      int attIndex = att.index();
869      if (subsetIndices[0][attIndex].length<minNumObj ||
870          subsetIndices[1][attIndex].length<minNumObj) {
871
872        nodeToSplit.makeLeaf(train);
873        BFTree nextNode = (BFTree)
874        ((FastVector)BestFirstElements.elementAt(0)).elementAt(0);
875        return root.makeTree(BestFirstElements, root, train,
876            nextNode.m_SortedIndices, nextNode.m_Weights, nextNode.m_Dists,
877            nextNode.m_ClassProbs, nextNode.m_TotalWeight,
878            nextNode.m_Props, minNumObj, useHeuristic, useGini);
879      }
880
881      // split the node
882      else {
883        nodeToSplit.m_isLeaf = false;
884        nodeToSplit.m_Attribute = att;
885
886        nodeToSplit.makeSuccessors(BestFirstElements,train,subsetIndices,
887            subsetWeights,dists, nodeToSplit.m_Attribute,useHeuristic,useGini);
888
889        for (int i=0; i<2; i++){
890          nodeToSplit.m_Successors[i].makeLeaf(train);
891        }
892
893        return true;
894      }
895    }
896  }
897
898  /**
899   * This method is to find the number of expansions based on internal
900   * cross-validation for just post-pruning. It expands the first BestFirst
901   * node in the BestFirstElements until no node can be split. When building
902   * the tree, stroe error for each temporary tree, namely for each expansion.
903   *
904   * @param BestFirstElements   list to store BFTree nodes
905   * @param root                root node of tree in each fold
906   * @param train               training data in each fold
907   * @param test                test data in each fold
908   * @param modelError          list to store error for each expansion in
909   *                            each fold
910   * @param sortedIndices       sorted indices of the instances
911   * @param weights             weights of the instances
912   * @param dists               class distributions for each attribute
913   * @param classProbs          class probabilities of this node
914   * @param totalWeight         total weight of this node (note if the node
915   *                            can not split, this value is not calculated.)
916   * @param branchProps         proportions of two subbranches
917   * @param minNumObj           minimal number of instances at leaf nodes
918   * @param useHeuristic        if use heuristic search for nominal attributes
919   *                            in multi-class problem
920   * @param useGini             if use Gini index as splitting criterion
921   * @param useErrorRate        if use error rate in internal cross-validation
922   * @throws Exception          if something goes wrong
923   */
924  protected void makeTree(FastVector BestFirstElements, BFTree root,
925      Instances train, Instances test, FastVector modelError, int[][] sortedIndices,
926      double[][] weights, double[][][] dists, double[] classProbs, double totalWeight,
927      double[] branchProps, int minNumObj, boolean useHeuristic, boolean useGini, boolean useErrorRate)
928  throws Exception {
929
930    if (BestFirstElements.size()==0) return;
931
932    ///////////////////////////////////////////////////////////////////////
933    // All information about the node to split (first BestFirst object in
934    // BestFirstElements)
935    FastVector firstElement = (FastVector)BestFirstElements.elementAt(0);
936
937    // node to split
938    //BFTree nodeToSplit = (BFTree)firstElement.elementAt(0);
939
940    // split attribute
941    Attribute att = (Attribute)firstElement.elementAt(1);
942
943    // info of split value or split string
944    double splitValue = Double.NaN;
945    String splitStr = null;
946    if (att.isNumeric())
947      splitValue = ((Double)firstElement.elementAt(2)).doubleValue();
948    else {
949      splitStr=((String)firstElement.elementAt(2)).toString();
950    }
951
952    // the best gini gain or information of this node
953    double gain = ((Double)firstElement.elementAt(3)).doubleValue();
954    ///////////////////////////////////////////////////////////////////////
955
956    if (totalWeight < 2*minNumObj || branchProps[0]==0
957        || branchProps[1]==0) {
958      // remove the first element
959      BestFirstElements.removeElementAt(0);
960      makeLeaf(train);
961      if (BestFirstElements.size() == 0) {
962        return;
963      }
964
965      BFTree nextSplitNode = (BFTree)
966      ((FastVector)BestFirstElements.elementAt(0)).elementAt(0);
967      nextSplitNode.makeTree(BestFirstElements, root, train, test, modelError,
968          nextSplitNode.m_SortedIndices, nextSplitNode.m_Weights,
969          nextSplitNode.m_Dists, nextSplitNode.m_ClassProbs,
970          nextSplitNode.m_TotalWeight, nextSplitNode.m_Props, minNumObj,
971          useHeuristic, useGini, useErrorRate);
972      return;
973
974    }
975
976    // If gini gain or information gain is 0, make all nodes in the BestFirstElements leaf nodes
977    // because these node sorted descendingly according to gini gain or information gain.
978    // (namely, gini gain or information gain of all nodes in BestFirstEelements is 0).
979    if (gain==0) {
980      for (int i=0; i<BestFirstElements.size(); i++) {
981        FastVector element = (FastVector)BestFirstElements.elementAt(i);
982        BFTree node = (BFTree)element.elementAt(0);
983        node.makeLeaf(train);
984      }
985      BestFirstElements.removeAllElements();
986    }
987
988    // gini gain or information gain is not 0
989    else {
990      // remove the first element
991      BestFirstElements.removeElementAt(0);
992      m_Attribute = att;
993      if (att.isNumeric()) m_SplitValue = splitValue;
994      else m_SplitString = splitStr;
995
996      int[][][] subsetIndices = new int[2][train.numAttributes()][0];
997      double[][][] subsetWeights = new double[2][train.numAttributes()][0];
998
999      splitData(subsetIndices, subsetWeights, m_Attribute,
1000          m_SplitValue, m_SplitString,
1001          sortedIndices, weights, train);
1002
1003      // if split will generate node(s) which has total weights less than m_minNumObj,
1004      // do not split
1005      int attIndex = att.index();
1006      if (subsetIndices[0][attIndex].length<minNumObj ||
1007          subsetIndices[1][attIndex].length<minNumObj) {
1008        makeLeaf(train);
1009      }
1010
1011      // split the node and cauculate error rate of this temporary tree
1012      else {
1013        m_isLeaf = false;
1014        m_Attribute = att;
1015
1016        makeSuccessors(BestFirstElements,train,subsetIndices,
1017            subsetWeights,dists, m_Attribute, useHeuristic, useGini);
1018        for (int i=0; i<2; i++){
1019          m_Successors[i].makeLeaf(train);
1020        }
1021
1022        Evaluation eval = new Evaluation(test);
1023        eval.evaluateModel(root, test);
1024        double error;
1025        if (useErrorRate) error = eval.errorRate();
1026        else error = eval.rootMeanSquaredError();
1027        modelError.addElement(new Double(error));
1028      }
1029
1030      if (BestFirstElements.size()!=0) {
1031        FastVector nextSplitElement = (FastVector)BestFirstElements.elementAt(0);
1032        BFTree nextSplitNode = (BFTree)nextSplitElement.elementAt(0);
1033        nextSplitNode.makeTree(BestFirstElements, root, train, test, modelError,
1034            nextSplitNode.m_SortedIndices, nextSplitNode.m_Weights,
1035            nextSplitNode.m_Dists, nextSplitNode.m_ClassProbs,
1036            nextSplitNode.m_TotalWeight, nextSplitNode.m_Props, minNumObj,
1037            useHeuristic, useGini,useErrorRate);
1038      }
1039    }
1040  }
1041
1042
1043  /**
1044   * Generate successor nodes for a node and put them into BestFirstElements
1045   * according to gini gain or information gain in a descending order.
1046   *
1047   * @param BestFirstElements   list to store BestFirst nodes
1048   * @param data                training instance
1049   * @param subsetSortedIndices sorted indices of instances of successor nodes
1050   * @param subsetWeights       weights of instances of successor nodes
1051   * @param dists               class distributions of successor nodes
1052   * @param att                 attribute used to split the node
1053   * @param useHeuristic        if use heuristic search for nominal attributes in multi-class problem
1054   * @param useGini             if use Gini index as splitting criterion
1055   * @throws Exception          if something goes wrong
1056   */
1057  protected void makeSuccessors(FastVector BestFirstElements,Instances data,
1058      int[][][] subsetSortedIndices, double[][][] subsetWeights,
1059      double[][][] dists,
1060      Attribute att, boolean useHeuristic, boolean useGini) throws Exception {
1061
1062    m_Successors = new BFTree[2];
1063
1064    for (int i=0; i<2; i++) {
1065      m_Successors[i] = new BFTree();
1066      m_Successors[i].m_isLeaf = true;
1067
1068      // class probability and distribution for this successor node
1069      m_Successors[i].m_ClassProbs = new double[data.numClasses()];
1070      m_Successors[i].m_Distribution = new double[data.numClasses()];
1071      System.arraycopy(dists[att.index()][i], 0, m_Successors[i].m_ClassProbs,
1072          0,m_Successors[i].m_ClassProbs.length);
1073      System.arraycopy(dists[att.index()][i], 0, m_Successors[i].m_Distribution,
1074          0,m_Successors[i].m_Distribution.length);
1075      if (Utils.sum(m_Successors[i].m_ClassProbs)!=0)
1076        Utils.normalize(m_Successors[i].m_ClassProbs);
1077
1078      // split information for this successor node
1079      double[][] props = new double[data.numAttributes()][2];
1080      double[][][] subDists = new double[data.numAttributes()][2][data.numClasses()];
1081      double[][] totalSubsetWeights = new double[data.numAttributes()][2];
1082      FastVector splitInfo = m_Successors[i].computeSplitInfo(m_Successors[i], data,
1083          subsetSortedIndices[i], subsetWeights[i], subDists, props,
1084          totalSubsetWeights, useHeuristic, useGini);
1085
1086      // branch proportion for this successor node
1087      int splitIndex = ((Attribute)splitInfo.elementAt(1)).index();
1088      m_Successors[i].m_Props = new double[2];
1089      System.arraycopy(props[splitIndex], 0, m_Successors[i].m_Props, 0,
1090          m_Successors[i].m_Props.length);
1091
1092      // sorted indices and weights of each attribute for this successor node
1093      m_Successors[i].m_SortedIndices = new int[data.numAttributes()][0];
1094      m_Successors[i].m_Weights = new double[data.numAttributes()][0];
1095      for (int j=0; j<m_Successors[i].m_SortedIndices.length; j++) {
1096        m_Successors[i].m_SortedIndices[j] = subsetSortedIndices[i][j];
1097        m_Successors[i].m_Weights[j] = subsetWeights[i][j];
1098      }
1099
1100      // distribution of each attribute for this successor node
1101      m_Successors[i].m_Dists = new double[data.numAttributes()][2][data.numClasses()];
1102      for (int j=0; j<subDists.length; j++) {
1103        m_Successors[i].m_Dists[j] = subDists[j];
1104      }
1105
1106      // total weights for this successor node.
1107      m_Successors[i].m_TotalWeight = Utils.sum(totalSubsetWeights[splitIndex]);
1108
1109      // insert this successor node into BestFirstElements according to gini gain or information gain
1110      //  descendingly
1111      if (BestFirstElements.size()==0) {
1112        BestFirstElements.addElement(splitInfo);
1113      } else {
1114        double gGain = ((Double)(splitInfo.elementAt(3))).doubleValue();
1115        int vectorSize = BestFirstElements.size();
1116        FastVector lastNode = (FastVector)BestFirstElements.elementAt(vectorSize-1);
1117
1118        // If gini gain is less than that of last node in FastVector
1119        if (gGain<((Double)(lastNode.elementAt(3))).doubleValue()) {
1120          BestFirstElements.insertElementAt(splitInfo, vectorSize);
1121        } else {
1122          for (int j=0; j<vectorSize; j++) {
1123            FastVector node = (FastVector)BestFirstElements.elementAt(j);
1124            double nodeGain = ((Double)(node.elementAt(3))).doubleValue();
1125            if (gGain>=nodeGain) {
1126              BestFirstElements.insertElementAt(splitInfo, j);
1127              break;
1128            }
1129          }
1130        }
1131      }
1132    }
1133  }
1134
1135  /**
1136   * Compute sorted indices, weights and class probabilities for a given
1137   * dataset. Return total weights of the data at the node.
1138   *
1139   * @param data                training data
1140   * @param sortedIndices       sorted indices of instances at the node
1141   * @param weights             weights of instances at the node
1142   * @param classProbs          class probabilities at the node
1143   * @return                    total weights of instances at the node
1144   * @throws Exception          if something goes wrong
1145   */
1146  protected double computeSortedInfo(Instances data, int[][] sortedIndices, double[][] weights,
1147      double[] classProbs) throws Exception {
1148
1149    // Create array of sorted indices and weights
1150    double[] vals = new double[data.numInstances()];
1151    for (int j = 0; j < data.numAttributes(); j++) {
1152      if (j==data.classIndex()) continue;
1153      weights[j] = new double[data.numInstances()];
1154
1155      if (data.attribute(j).isNominal()) {
1156
1157        // Handling nominal attributes. Putting indices of
1158        // instances with missing values at the end.
1159        sortedIndices[j] = new int[data.numInstances()];
1160        int count = 0;
1161        for (int i = 0; i < data.numInstances(); i++) {
1162          Instance inst = data.instance(i);
1163          if (!inst.isMissing(j)) {
1164            sortedIndices[j][count] = i;
1165            weights[j][count] = inst.weight();
1166            count++;
1167          }
1168        }
1169        for (int i = 0; i < data.numInstances(); i++) {
1170          Instance inst = data.instance(i);
1171          if (inst.isMissing(j)) {
1172            sortedIndices[j][count] = i;
1173            weights[j][count] = inst.weight();
1174            count++;
1175          }
1176        }
1177      } else {
1178
1179        // Sorted indices are computed for numeric attributes
1180        // missing values instances are put to end (through Utils.sort() method)
1181        for (int i = 0; i < data.numInstances(); i++) {
1182          Instance inst = data.instance(i);
1183          vals[i] = inst.value(j);
1184        }
1185        sortedIndices[j] = Utils.sort(vals);
1186        for (int i = 0; i < data.numInstances(); i++) {
1187          weights[j][i] = data.instance(sortedIndices[j][i]).weight();
1188        }
1189      }
1190    }
1191
1192    // Compute initial class counts and total weight
1193    double totalWeight = 0;
1194    for (int i = 0; i < data.numInstances(); i++) {
1195      Instance inst = data.instance(i);
1196      classProbs[(int)inst.classValue()] += inst.weight();
1197      totalWeight += inst.weight();
1198    }
1199
1200    return totalWeight;
1201  }
1202
1203  /**
1204   * Compute the best splitting attribute, split point or subset and the best
1205   * gini gain or iformation gain for a given dataset.
1206   *
1207   * @param node                node to be split
1208   * @param data                training data
1209   * @param sortedIndices       sorted indices of the instances
1210   * @param weights             weights of the instances
1211   * @param dists               class distributions for each attribute
1212   * @param props               proportions of two branches
1213   * @param totalSubsetWeights  total weight of two subsets
1214   * @param useHeuristic        if use heuristic search for nominal attributes
1215   *                            in multi-class problem
1216   * @param useGini             if use Gini index as splitting criterion
1217   * @return                    split information about the node
1218   * @throws Exception          if something is wrong
1219   */
1220  protected FastVector computeSplitInfo(BFTree node, Instances data, int[][] sortedIndices,
1221      double[][] weights, double[][][] dists, double[][] props,
1222      double[][] totalSubsetWeights, boolean useHeuristic, boolean useGini) throws Exception {
1223
1224    double[] splits = new double[data.numAttributes()];
1225    String[] splitString = new String[data.numAttributes()];
1226    double[] gains = new double[data.numAttributes()];
1227
1228    for (int i = 0; i < data.numAttributes(); i++) {
1229      if (i==data.classIndex()) continue;
1230      Attribute att = data.attribute(i);
1231      if (att.isNumeric()) {
1232        // numeric attribute
1233        splits[i] = numericDistribution(props, dists, att, sortedIndices[i],
1234            weights[i], totalSubsetWeights, gains, data, useGini);
1235      } else {
1236        // nominal attribute
1237        splitString[i] = nominalDistribution(props, dists, att, sortedIndices[i],
1238            weights[i], totalSubsetWeights, gains, data, useHeuristic, useGini);
1239      }
1240    }
1241
1242    int index = Utils.maxIndex(gains);
1243    double mBestGain = gains[index];
1244
1245    Attribute att = data.attribute(index);
1246    double mValue =Double.NaN;
1247    String mString = null;
1248    if (att.isNumeric())  mValue= splits[index];
1249    else {
1250      mString = splitString[index];
1251      if (mString==null) mString = "";
1252    }
1253
1254    // split information
1255    FastVector splitInfo = new FastVector();
1256    splitInfo.addElement(node);
1257    splitInfo.addElement(att);
1258    if (att.isNumeric()) splitInfo.addElement(new Double(mValue));
1259    else splitInfo.addElement(mString);
1260    splitInfo.addElement(new Double(mBestGain));
1261
1262    return splitInfo;
1263  }
1264
1265  /**
1266   * Compute distributions, proportions and total weights of two successor nodes for
1267   * a given numeric attribute.
1268   *
1269   * @param props               proportions of each two branches for each attribute
1270   * @param dists               class distributions of two branches for each attribute
1271   * @param att                 numeric att split on
1272   * @param sortedIndices       sorted indices of instances for the attirubte
1273   * @param weights             weights of instances for the attirbute
1274   * @param subsetWeights       total weight of two branches split based on the attribute
1275   * @param gains               Gini gains or information gains for each attribute
1276   * @param data                training instances
1277   * @param useGini             if use Gini index as splitting criterion
1278   * @return                    Gini gain or information gain for the given attribute
1279   * @throws Exception          if something goes wrong
1280   */
1281  protected double numericDistribution(double[][] props, double[][][] dists,
1282      Attribute att, int[] sortedIndices, double[] weights, double[][] subsetWeights,
1283      double[] gains, Instances data, boolean useGini)
1284  throws Exception {
1285
1286    double splitPoint = Double.NaN;
1287    double[][] dist = null;
1288    int numClasses = data.numClasses();
1289    int i; // differ instances with or without missing values
1290
1291    double[][] currDist = new double[2][numClasses];
1292    dist = new double[2][numClasses];
1293
1294    // Move all instances without missing values into second subset
1295    double[] parentDist = new double[numClasses];
1296    int missingStart = 0;
1297    for (int j = 0; j < sortedIndices.length; j++) {
1298      Instance inst = data.instance(sortedIndices[j]);
1299      if (!inst.isMissing(att)) {
1300        missingStart ++;
1301        currDist[1][(int)inst.classValue()] += weights[j];
1302      }
1303      parentDist[(int)inst.classValue()] += weights[j];
1304    }
1305    System.arraycopy(currDist[1], 0, dist[1], 0, dist[1].length);
1306
1307    // Try all possible split points
1308    double currSplit = data.instance(sortedIndices[0]).value(att);
1309    double currGain;
1310    double bestGain = -Double.MAX_VALUE;
1311
1312    for (i = 0; i < sortedIndices.length; i++) {
1313      Instance inst = data.instance(sortedIndices[i]);
1314      if (inst.isMissing(att)) {
1315        break;
1316      }
1317      if (inst.value(att) > currSplit) {
1318
1319        double[][] tempDist = new double[2][numClasses];
1320        for (int k=0; k<2; k++) {
1321          //tempDist[k] = currDist[k];
1322          System.arraycopy(currDist[k], 0, tempDist[k], 0, tempDist[k].length);
1323        }
1324
1325        double[] tempProps = new double[2];
1326        for (int k=0; k<2; k++) {
1327          tempProps[k] = Utils.sum(tempDist[k]);
1328        }
1329
1330        if (Utils.sum(tempProps) !=0) Utils.normalize(tempProps);
1331
1332        // split missing values
1333        int index = missingStart;
1334        while (index < sortedIndices.length) {
1335          Instance insta = data.instance(sortedIndices[index]);
1336          for (int j = 0; j < 2; j++) {
1337            tempDist[j][(int)insta.classValue()] += tempProps[j] * weights[index];
1338          }
1339          index++;
1340        }
1341
1342        if (useGini) currGain = computeGiniGain(parentDist,tempDist);
1343        else currGain = computeInfoGain(parentDist,tempDist);
1344
1345        if (currGain > bestGain) {
1346          bestGain = currGain;
1347          // clean split point
1348          splitPoint = Math.rint((inst.value(att) + currSplit)/2.0*100000)/100000.0;
1349          for (int j = 0; j < currDist.length; j++) {
1350            System.arraycopy(tempDist[j], 0, dist[j], 0,
1351                dist[j].length);
1352          }
1353        }
1354      }
1355      currSplit = inst.value(att);
1356      currDist[0][(int)inst.classValue()] += weights[i];
1357      currDist[1][(int)inst.classValue()] -= weights[i];
1358    }
1359
1360    // Compute weights
1361    int attIndex = att.index();
1362    props[attIndex] = new double[2];
1363    for (int k = 0; k < 2; k++) {
1364      props[attIndex][k] = Utils.sum(dist[k]);
1365    }
1366    if (Utils.sum(props[attIndex]) != 0) Utils.normalize(props[attIndex]);
1367
1368    // Compute subset weights
1369    subsetWeights[attIndex] = new double[2];
1370    for (int j = 0; j < 2; j++) {
1371      subsetWeights[attIndex][j] += Utils.sum(dist[j]);
1372    }
1373
1374    // clean gain
1375    gains[attIndex] = Math.rint(bestGain*10000000)/10000000.0;
1376    dists[attIndex] = dist;
1377    return splitPoint;
1378  }
1379
1380  /**
1381   * Compute distributions, proportions and total weights of two successor
1382   * nodes for a given nominal attribute.
1383   *
1384   * @param props               proportions of each two branches for each attribute
1385   * @param dists               class distributions of two branches for each attribute
1386   * @param att                 numeric att split on
1387   * @param sortedIndices       sorted indices of instances for the attirubte
1388   * @param weights             weights of instances for the attirbute
1389   * @param subsetWeights       total weight of two branches split based on the attribute
1390   * @param gains               Gini gains for each attribute
1391   * @param data                training instances
1392   * @param useHeuristic        if use heuristic search
1393   * @param useGini             if use Gini index as splitting criterion
1394   * @return                    Gini gain for the given attribute
1395   * @throws Exception          if something goes wrong
1396   */
1397  protected String nominalDistribution(double[][] props, double[][][] dists,
1398      Attribute att, int[] sortedIndices, double[] weights, double[][] subsetWeights,
1399      double[] gains, Instances data, boolean useHeuristic, boolean useGini)
1400  throws Exception {
1401
1402    String[] values = new String[att.numValues()];
1403    int numCat = values.length; // number of values of the attribute
1404    int numClasses = data.numClasses();
1405
1406    String bestSplitString = "";
1407    double bestGain = -Double.MAX_VALUE;
1408
1409    // class frequency for each value
1410    int[] classFreq = new int[numCat];
1411    for (int j=0; j<numCat; j++) classFreq[j] = 0;
1412
1413    double[] parentDist = new double[numClasses];
1414    double[][] currDist = new double[2][numClasses];
1415    double[][] dist = new double[2][numClasses];
1416    int missingStart = 0;
1417
1418    for (int i = 0; i < sortedIndices.length; i++) {
1419      Instance inst = data.instance(sortedIndices[i]);
1420      if (!inst.isMissing(att)) {
1421        missingStart++;
1422        classFreq[(int)inst.value(att)] ++;
1423      }
1424      parentDist[(int)inst.classValue()] += weights[i];
1425    }
1426
1427    // count the number of values that class frequency is not 0
1428    int nonEmpty = 0;
1429    for (int j=0; j<numCat; j++) {
1430      if (classFreq[j]!=0) nonEmpty ++;
1431    }
1432
1433    // attribute values which class frequency is not 0
1434    String[] nonEmptyValues = new String[nonEmpty];
1435    int nonEmptyIndex = 0;
1436    for (int j=0; j<numCat; j++) {
1437      if (classFreq[j]!=0) {
1438        nonEmptyValues[nonEmptyIndex] = att.value(j);
1439        nonEmptyIndex ++;
1440      }
1441    }
1442
1443    // attribute values which class frequency is 0
1444    int empty = numCat - nonEmpty;
1445    String[] emptyValues = new String[empty];
1446    int emptyIndex = 0;
1447    for (int j=0; j<numCat; j++) {
1448      if (classFreq[j]==0) {
1449        emptyValues[emptyIndex] = att.value(j);
1450        emptyIndex ++;
1451      }
1452    }
1453
1454    if (nonEmpty<=1) {
1455      gains[att.index()] = 0;
1456      return "";
1457    }
1458
1459    // for tow-class probloms
1460    if (data.numClasses()==2) {
1461
1462      //// Firstly, for attribute values which class frequency is not zero
1463
1464      // probability of class 0 for each attribute value
1465      double[] pClass0 = new double[nonEmpty];
1466      // class distribution for each attribute value
1467      double[][] valDist = new double[nonEmpty][2];
1468
1469      for (int j=0; j<nonEmpty; j++) {
1470        for (int k=0; k<2; k++) {
1471          valDist[j][k] = 0;
1472        }
1473      }
1474
1475      for (int i = 0; i < sortedIndices.length; i++) {
1476        Instance inst = data.instance(sortedIndices[i]);
1477        if (inst.isMissing(att)) {
1478          break;
1479        }
1480
1481        for (int j=0; j<nonEmpty; j++) {
1482          if (att.value((int)inst.value(att)).compareTo(nonEmptyValues[j])==0) {
1483            valDist[j][(int)inst.classValue()] += inst.weight();
1484            break;
1485          }
1486        }
1487      }
1488
1489      for (int j=0; j<nonEmpty; j++) {
1490        double distSum = Utils.sum(valDist[j]);
1491        if (distSum==0) pClass0[j]=0;
1492        else pClass0[j] = valDist[j][0]/distSum;
1493      }
1494
1495      // sort category according to the probability of class 0.0
1496      String[] sortedValues = new String[nonEmpty];
1497      for (int j=0; j<nonEmpty; j++) {
1498        sortedValues[j] = nonEmptyValues[Utils.minIndex(pClass0)];
1499        pClass0[Utils.minIndex(pClass0)] = Double.MAX_VALUE;
1500      }
1501
1502      // Find a subset of attribute values that maximize impurity decrease
1503
1504      // for the attribute values that class frequency is not 0
1505      String tempStr = "";
1506
1507      for (int j=0; j<nonEmpty-1; j++) {
1508        currDist = new double[2][numClasses];
1509        if (tempStr=="") tempStr="(" + sortedValues[j] + ")";
1510        else tempStr += "|"+ "(" + sortedValues[j] + ")";
1511        //System.out.println(sortedValues[j]);
1512        for (int i=0; i<sortedIndices.length;i++) {
1513          Instance inst = data.instance(sortedIndices[i]);
1514          if (inst.isMissing(att)) {
1515            break;
1516          }
1517
1518          if (tempStr.indexOf
1519              ("(" + att.value((int)inst.value(att)) + ")")!=-1) {
1520            currDist[0][(int)inst.classValue()] += weights[i];
1521          } else currDist[1][(int)inst.classValue()] += weights[i];
1522        }
1523
1524        double[][] tempDist = new double[2][numClasses];
1525        for (int kk=0; kk<2; kk++) {
1526          tempDist[kk] = currDist[kk];
1527        }
1528
1529        double[] tempProps = new double[2];
1530        for (int kk=0; kk<2; kk++) {
1531          tempProps[kk] = Utils.sum(tempDist[kk]);
1532        }
1533
1534        if (Utils.sum(tempProps)!=0) Utils.normalize(tempProps);
1535
1536        // split missing values
1537        int mstart = missingStart;
1538        while (mstart < sortedIndices.length) {
1539          Instance insta = data.instance(sortedIndices[mstart]);
1540          for (int jj = 0; jj < 2; jj++) {
1541            tempDist[jj][(int)insta.classValue()] += tempProps[jj] * weights[mstart];
1542          }
1543          mstart++;
1544        }
1545
1546        double currGain;
1547        if (useGini) currGain = computeGiniGain(parentDist,tempDist);
1548        else currGain = computeInfoGain(parentDist,tempDist);
1549
1550        if (currGain>bestGain) {
1551          bestGain = currGain;
1552          bestSplitString = tempStr;
1553          for (int jj = 0; jj < 2; jj++) {
1554            System.arraycopy(tempDist[jj], 0, dist[jj], 0,
1555                dist[jj].length);
1556          }
1557        }
1558      }
1559    }
1560
1561    // multi-class problems (exhaustive search)
1562    else if (!useHeuristic || nonEmpty<=4) {
1563      //else if (!useHeuristic || nonEmpty==2) {
1564
1565      // Firstly, for attribute values which class frequency is not zero
1566      for (int i=0; i<(int)Math.pow(2,nonEmpty-1); i++) {
1567        String tempStr="";
1568        currDist = new double[2][numClasses];
1569        int mod;
1570        int bit10 = i;
1571        for (int j=nonEmpty-1; j>=0; j--) {
1572          mod = bit10%2; // convert from 10bit to 2bit
1573          if (mod==1) {
1574            if (tempStr=="") tempStr = "("+nonEmptyValues[j]+")";
1575            else tempStr += "|" + "("+nonEmptyValues[j]+")";
1576          }
1577          bit10 = bit10/2;
1578        }
1579        for (int j=0; j<sortedIndices.length;j++) {
1580          Instance inst = data.instance(sortedIndices[j]);
1581          if (inst.isMissing(att)) {
1582            break;
1583          }
1584
1585          if (tempStr.indexOf("("+att.value((int)inst.value(att))+")")!=-1) {
1586            currDist[0][(int)inst.classValue()] += weights[j];
1587          } else currDist[1][(int)inst.classValue()] += weights[j];
1588        }
1589
1590        double[][] tempDist = new double[2][numClasses];
1591        for (int k=0; k<2; k++) {
1592          tempDist[k] = currDist[k];
1593        }
1594
1595        double[] tempProps = new double[2];
1596        for (int k=0; k<2; k++) {
1597          tempProps[k] = Utils.sum(tempDist[k]);
1598        }
1599
1600        if (Utils.sum(tempProps)!=0) Utils.normalize(tempProps);
1601
1602        // split missing values
1603        int index = missingStart;
1604        while (index < sortedIndices.length) {
1605          Instance insta = data.instance(sortedIndices[index]);
1606          for (int j = 0; j < 2; j++) {
1607            tempDist[j][(int)insta.classValue()] += tempProps[j] * weights[index];
1608          }
1609          index++;
1610        }
1611
1612        double currGain;
1613        if (useGini) currGain = computeGiniGain(parentDist,tempDist);
1614        else currGain = computeInfoGain(parentDist,tempDist);
1615
1616        if (currGain>bestGain) {
1617          bestGain = currGain;
1618          bestSplitString = tempStr;
1619          for (int j = 0; j < 2; j++) {
1620            //dist[jj] = new double[currDist[jj].length];
1621            System.arraycopy(tempDist[j], 0, dist[j], 0,
1622                dist[j].length);
1623          }
1624        }
1625      }
1626    }
1627
1628    // huristic method to solve multi-classes problems
1629    else {
1630      // Firstly, for attribute values which class frequency is not zero
1631      int n = nonEmpty;
1632      int k = data.numClasses();  // number of classes of the data
1633      double[][] P = new double[n][k];      // class probability matrix
1634      int[] numInstancesValue = new int[n]; // number of instances for an attribute value
1635      double[] meanClass = new double[k];   // vector of mean class probability
1636      int numInstances = data.numInstances(); // total number of instances
1637
1638      // initialize the vector of mean class probability
1639      for (int j=0; j<meanClass.length; j++) meanClass[j]=0;
1640
1641      for (int j=0; j<numInstances; j++) {
1642        Instance inst = (Instance)data.instance(j);
1643        int valueIndex = 0; // attribute value index in nonEmptyValues
1644        for (int i=0; i<nonEmpty; i++) {
1645          if (att.value((int)inst.value(att)).compareToIgnoreCase(nonEmptyValues[i])==0){
1646            valueIndex = i;
1647            break;
1648          }
1649        }
1650        P[valueIndex][(int)inst.classValue()]++;
1651        numInstancesValue[valueIndex]++;
1652        meanClass[(int)inst.classValue()]++;
1653      }
1654
1655      // calculate the class probability matrix
1656      for (int i=0; i<P.length; i++) {
1657        for (int j=0; j<P[0].length; j++) {
1658          if (numInstancesValue[i]==0) P[i][j]=0;
1659          else P[i][j]/=numInstancesValue[i];
1660        }
1661      }
1662
1663      //calculate the vector of mean class probability
1664      for (int i=0; i<meanClass.length; i++) {
1665        meanClass[i]/=numInstances;
1666      }
1667
1668      // calculate the covariance matrix
1669      double[][] covariance = new double[k][k];
1670      for (int i1=0; i1<k; i1++) {
1671        for (int i2=0; i2<k; i2++) {
1672          double element = 0;
1673          for (int j=0; j<n; j++) {
1674            element += (P[j][i2]-meanClass[i2])*(P[j][i1]-meanClass[i1])
1675            *numInstancesValue[j];
1676          }
1677          covariance[i1][i2] = element;
1678        }
1679      }
1680
1681      Matrix matrix = new Matrix(covariance);
1682      weka.core.matrix.EigenvalueDecomposition eigen =
1683        new weka.core.matrix.EigenvalueDecomposition(matrix);
1684      double[] eigenValues = eigen.getRealEigenvalues();
1685
1686      // find index of the largest eigenvalue
1687      int index=0;
1688      double largest = eigenValues[0];
1689      for (int i=1; i<eigenValues.length; i++) {
1690        if (eigenValues[i]>largest) {
1691          index=i;
1692          largest = eigenValues[i];
1693        }
1694      }
1695
1696      // calculate the first principle component
1697      double[] FPC = new double[k];
1698      Matrix eigenVector = eigen.getV();
1699      double[][] vectorArray = eigenVector.getArray();
1700      for (int i=0; i<FPC.length; i++) {
1701        FPC[i] = vectorArray[i][index];
1702      }
1703
1704      // calculate the first principle component scores
1705      double[] Sa = new double[n];
1706      for (int i=0; i<Sa.length; i++) {
1707        Sa[i]=0;
1708        for (int j=0; j<k; j++) {
1709          Sa[i] += FPC[j]*P[i][j];
1710        }
1711      }
1712
1713      // sort category according to Sa(s)
1714      double[] pCopy = new double[n];
1715      System.arraycopy(Sa,0,pCopy,0,n);
1716      String[] sortedValues = new String[n];
1717      Arrays.sort(Sa);
1718
1719      for (int j=0; j<n; j++) {
1720        sortedValues[j] = nonEmptyValues[Utils.minIndex(pCopy)];
1721        pCopy[Utils.minIndex(pCopy)] = Double.MAX_VALUE;
1722      }
1723
1724      // for the attribute values that class frequency is not 0
1725      String tempStr = "";
1726
1727      for (int j=0; j<nonEmpty-1; j++) {
1728        currDist = new double[2][numClasses];
1729        if (tempStr=="") tempStr="(" + sortedValues[j] + ")";
1730        else tempStr += "|"+ "(" + sortedValues[j] + ")";
1731        for (int i=0; i<sortedIndices.length;i++) {
1732          Instance inst = data.instance(sortedIndices[i]);
1733          if (inst.isMissing(att)) {
1734            break;
1735          }
1736
1737          if (tempStr.indexOf
1738              ("(" + att.value((int)inst.value(att)) + ")")!=-1) {
1739            currDist[0][(int)inst.classValue()] += weights[i];
1740          } else currDist[1][(int)inst.classValue()] += weights[i];
1741        }
1742
1743        double[][] tempDist = new double[2][numClasses];
1744        for (int kk=0; kk<2; kk++) {
1745          tempDist[kk] = currDist[kk];
1746        }
1747
1748        double[] tempProps = new double[2];
1749        for (int kk=0; kk<2; kk++) {
1750          tempProps[kk] = Utils.sum(tempDist[kk]);
1751        }
1752
1753        if (Utils.sum(tempProps)!=0) Utils.normalize(tempProps);
1754
1755        // split missing values
1756        int mstart = missingStart;
1757        while (mstart < sortedIndices.length) {
1758          Instance insta = data.instance(sortedIndices[mstart]);
1759          for (int jj = 0; jj < 2; jj++) {
1760            tempDist[jj][(int)insta.classValue()] += tempProps[jj] * weights[mstart];
1761          }
1762          mstart++;
1763        }
1764
1765        double currGain;
1766        if (useGini) currGain = computeGiniGain(parentDist,tempDist);
1767        else currGain = computeInfoGain(parentDist,tempDist);
1768
1769        if (currGain>bestGain) {
1770          bestGain = currGain;
1771          bestSplitString = tempStr;
1772          for (int jj = 0; jj < 2; jj++) {
1773            //dist[jj] = new double[currDist[jj].length];
1774            System.arraycopy(tempDist[jj], 0, dist[jj], 0,
1775                dist[jj].length);
1776          }
1777        }
1778      }
1779    }
1780
1781    // Compute weights
1782    int attIndex = att.index();
1783    props[attIndex] = new double[2];
1784    for (int k = 0; k < 2; k++) {
1785      props[attIndex][k] = Utils.sum(dist[k]);
1786    }
1787    if (!(Utils.sum(props[attIndex]) > 0)) {
1788      for (int k = 0; k < props[attIndex].length; k++) {
1789        props[attIndex][k] = 1.0 / (double)props[attIndex].length;
1790      }
1791    } else {
1792      Utils.normalize(props[attIndex]);
1793    }
1794
1795    // Compute subset weights
1796    subsetWeights[attIndex] = new double[2];
1797    for (int j = 0; j < 2; j++) {
1798      subsetWeights[attIndex][j] += Utils.sum(dist[j]);
1799    }
1800
1801    // Then, for the attribute values that class frequency is 0, split it into the
1802    // most frequent branch
1803    for (int j=0; j<empty; j++) {
1804      if (props[attIndex][0]>=props[attIndex][1]) {
1805        if (bestSplitString=="") bestSplitString = "(" + emptyValues[j] + ")";
1806        else bestSplitString += "|" + "(" + emptyValues[j] + ")";
1807      }
1808    }
1809
1810    // clean gain
1811    gains[attIndex] = Math.rint(bestGain*10000000)/10000000.0;
1812
1813    dists[attIndex] = dist;
1814    return bestSplitString;
1815  }
1816
1817
1818  /**
1819   * Split data into two subsets and store sorted indices and weights for two
1820   * successor nodes.
1821   *
1822   * @param subsetIndices       sorted indecis of instances for each attribute for two successor node
1823   * @param subsetWeights       weights of instances for each attribute for two successor node
1824   * @param att                 attribute the split based on
1825   * @param splitPoint          split point the split based on if att is numeric
1826   * @param splitStr            split subset the split based on if att is nominal
1827   * @param sortedIndices       sorted indices of the instances to be split
1828   * @param weights             weights of the instances to bes split
1829   * @param data                training data
1830   * @throws Exception          if something goes wrong 
1831   */
1832  protected void splitData(int[][][] subsetIndices, double[][][] subsetWeights,
1833      Attribute att, double splitPoint, String splitStr, int[][] sortedIndices,
1834      double[][] weights, Instances data) throws Exception {
1835
1836    int j;
1837    // For each attribute
1838    for (int i = 0; i < data.numAttributes(); i++) {
1839      if (i==data.classIndex()) continue;
1840      int[] num = new int[2];
1841      for (int k = 0; k < 2; k++) {
1842        subsetIndices[k][i] = new int[sortedIndices[i].length];
1843        subsetWeights[k][i] = new double[weights[i].length];
1844      }
1845
1846      for (j = 0; j < sortedIndices[i].length; j++) {
1847        Instance inst = data.instance(sortedIndices[i][j]);
1848        if (inst.isMissing(att)) {
1849          // Split instance up
1850          for (int k = 0; k < 2; k++) {
1851            if (m_Props[k] > 0) {
1852              subsetIndices[k][i][num[k]] = sortedIndices[i][j];
1853              subsetWeights[k][i][num[k]] = m_Props[k] * weights[i][j];
1854              num[k]++;
1855            }
1856          }
1857        } else {
1858          int subset;
1859          if (att.isNumeric())  {
1860            subset = (inst.value(att) < splitPoint) ? 0 : 1;
1861          } else { // nominal attribute
1862            if (splitStr.indexOf
1863                ("(" + att.value((int)inst.value(att.index()))+")")!=-1) {
1864              subset = 0;
1865            } else subset = 1;
1866          }
1867          subsetIndices[subset][i][num[subset]] = sortedIndices[i][j];
1868          subsetWeights[subset][i][num[subset]] = weights[i][j];
1869          num[subset]++;
1870        }
1871      }
1872
1873      // Trim arrays
1874      for (int k = 0; k < 2; k++) {
1875        int[] copy = new int[num[k]];
1876        System.arraycopy(subsetIndices[k][i], 0, copy, 0, num[k]);
1877        subsetIndices[k][i] = copy;
1878        double[] copyWeights = new double[num[k]];
1879        System.arraycopy(subsetWeights[k][i], 0 ,copyWeights, 0, num[k]);
1880        subsetWeights[k][i] = copyWeights;
1881      }
1882    }
1883  }
1884
1885
1886  /**
1887   * Compute and return gini gain for given distributions of a node and its
1888   * successor nodes.
1889   *
1890   * @param parentDist  class distributions of parent node
1891   * @param childDist   class distributions of successor nodes
1892   * @return            Gini gain computed
1893   */
1894  protected double computeGiniGain(double[] parentDist, double[][] childDist) {
1895    double totalWeight = Utils.sum(parentDist);
1896    if (totalWeight==0) return 0;
1897
1898    double leftWeight = Utils.sum(childDist[0]);
1899    double rightWeight = Utils.sum(childDist[1]);
1900
1901    double parentGini = computeGini(parentDist, totalWeight);
1902    double leftGini = computeGini(childDist[0],leftWeight);
1903    double rightGini = computeGini(childDist[1], rightWeight);
1904
1905    return parentGini - leftWeight/totalWeight*leftGini -
1906    rightWeight/totalWeight*rightGini;
1907  }
1908
1909  /**
1910   * Compute and return gini index for a given distribution of a node.
1911   *
1912   * @param dist        class distributions
1913   * @param total       class distributions
1914   * @return            Gini index of the class distributions
1915   */
1916  protected double computeGini(double[] dist, double total) {
1917    if (total==0) return 0;
1918    double val = 0;
1919    for (int i=0; i<dist.length; i++) {
1920      val += (dist[i]/total)*(dist[i]/total);
1921    }
1922    return 1- val;
1923  }
1924
1925  /**
1926   * Compute and return information gain for given distributions of a node
1927   * and its successor nodes.
1928   *
1929   * @param parentDist  class distributions of parent node
1930   * @param childDist   class distributions of successor nodes
1931   * @return            information gain computed
1932   */
1933  protected double computeInfoGain(double[] parentDist, double[][] childDist) {
1934    double totalWeight = Utils.sum(parentDist);
1935    if (totalWeight==0) return 0;
1936
1937    double leftWeight = Utils.sum(childDist[0]);
1938    double rightWeight = Utils.sum(childDist[1]);
1939
1940    double parentInfo = computeEntropy(parentDist, totalWeight);
1941    double leftInfo = computeEntropy(childDist[0],leftWeight);
1942    double rightInfo = computeEntropy(childDist[1], rightWeight);
1943
1944    return parentInfo - leftWeight/totalWeight*leftInfo -
1945    rightWeight/totalWeight*rightInfo;
1946  }
1947
1948  /**
1949   * Compute and return entropy for a given distribution of a node.
1950   *
1951   * @param dist        class distributions
1952   * @param total       class distributions
1953   * @return            entropy of the class distributions
1954   */
1955  protected double computeEntropy(double[] dist, double total) {
1956    if (total==0) return 0;
1957    double entropy = 0;
1958    for (int i=0; i<dist.length; i++) {
1959      if (dist[i]!=0) entropy -= dist[i]/total * Utils.log2(dist[i]/total);
1960    }
1961    return entropy;
1962  }
1963
1964  /**
1965   * Make the node leaf node.
1966   *
1967   * @param data        training data
1968   */
1969  protected void makeLeaf(Instances data) {
1970    m_Attribute = null;
1971    m_isLeaf = true;
1972    m_ClassValue=Utils.maxIndex(m_ClassProbs);
1973    m_ClassAttribute = data.classAttribute();
1974  }
1975
1976  /**
1977   * Computes class probabilities for instance using the decision tree.
1978   *
1979   * @param instance    the instance for which class probabilities is to be computed
1980   * @return            the class probabilities for the given instance
1981   * @throws Exception  if something goes wrong
1982   */
1983  public double[] distributionForInstance(Instance instance)
1984  throws Exception {
1985    if (!m_isLeaf) {
1986      // value of split attribute is missing
1987      if (instance.isMissing(m_Attribute)) {
1988        double[] returnedDist = new double[m_ClassProbs.length];
1989
1990        for (int i = 0; i < m_Successors.length; i++) {
1991          double[] help =
1992            m_Successors[i].distributionForInstance(instance);
1993          if (help != null) {
1994            for (int j = 0; j < help.length; j++) {
1995              returnedDist[j] += m_Props[i] * help[j];
1996            }
1997          }
1998        }
1999        return returnedDist;
2000      }
2001
2002      // split attribute is nonimal
2003      else if (m_Attribute.isNominal()) {
2004        if (m_SplitString.indexOf("(" +
2005            m_Attribute.value((int)instance.value(m_Attribute)) + ")")!=-1)
2006          return  m_Successors[0].distributionForInstance(instance);
2007        else return  m_Successors[1].distributionForInstance(instance);
2008      }
2009
2010      // split attribute is numeric
2011      else {
2012        if (instance.value(m_Attribute) < m_SplitValue)
2013          return m_Successors[0].distributionForInstance(instance);
2014        else
2015          return m_Successors[1].distributionForInstance(instance);
2016      }
2017    }
2018
2019    // leaf node
2020    else return m_ClassProbs;
2021  }
2022
2023  /**
2024   * Prints the decision tree using the protected toString method from below.
2025   *
2026   * @return            a textual description of the classifier
2027   */
2028  public String toString() {
2029    if ((m_Distribution == null) && (m_Successors == null)) {
2030      return "Best-First: No model built yet.";
2031    }
2032    return "Best-First Decision Tree\n" + toString(0)+"\n\n"
2033    +"Size of the Tree: "+numNodes()+"\n\n"
2034    +"Number of Leaf Nodes: "+numLeaves();
2035  }
2036
2037  /**
2038   * Outputs a tree at a certain level.
2039   *
2040   * @param level       the level at which the tree is to be printed
2041   * @return            a tree at a certain level.
2042   */
2043  protected String toString(int level) {
2044    StringBuffer text = new StringBuffer();
2045    // if leaf nodes
2046    if (m_Attribute == null) {
2047      if (Utils.isMissingValue(m_ClassValue)) {
2048        text.append(": null");
2049      } else {
2050        double correctNum = Math.rint(m_Distribution[Utils.maxIndex(m_Distribution)]*100)/
2051        100.0;
2052        double wrongNum = Math.rint((Utils.sum(m_Distribution) -
2053            m_Distribution[Utils.maxIndex(m_Distribution)])*100)/100.0;
2054        String str = "("  + correctNum + "/" + wrongNum + ")";
2055        text.append(": " + m_ClassAttribute.value((int) m_ClassValue)+ str);
2056      }
2057    } else {
2058      for (int j = 0; j < 2; j++) {
2059        text.append("\n");
2060        for (int i = 0; i < level; i++) {
2061          text.append("|  ");
2062        }
2063        if (j==0) {
2064          if (m_Attribute.isNumeric())
2065            text.append(m_Attribute.name() + " < " + m_SplitValue);
2066          else
2067            text.append(m_Attribute.name() + "=" + m_SplitString);
2068        } else {
2069          if (m_Attribute.isNumeric())
2070            text.append(m_Attribute.name() + " >= " + m_SplitValue);
2071          else
2072            text.append(m_Attribute.name() + "!=" + m_SplitString);
2073        }
2074        text.append(m_Successors[j].toString(level + 1));
2075      }
2076    }
2077    return text.toString();
2078  }
2079
2080  /**
2081   * Compute size of the tree.
2082   *
2083   * @return            size of the tree
2084   */
2085  public int numNodes() {
2086    if (m_isLeaf) {
2087      return 1;
2088    } else {
2089      int size =1;
2090      for (int i=0;i<m_Successors.length;i++) {
2091        size+=m_Successors[i].numNodes();
2092      }
2093      return size;
2094    }
2095  }
2096
2097  /**
2098   * Compute number of leaf nodes.
2099   *
2100   * @return            number of leaf nodes
2101   */
2102  public int numLeaves() {
2103    if (m_isLeaf) return 1;
2104    else {
2105      int size=0;
2106      for (int i=0;i<m_Successors.length;i++) {
2107        size+=m_Successors[i].numLeaves();
2108      }
2109      return size;
2110    }
2111  }
2112
2113  /**
2114   * Returns an enumeration describing the available options.
2115   *
2116   * @return            an enumeration describing the available options.
2117   */
2118  public Enumeration listOptions() {
2119    Vector              result;
2120    Enumeration         en;
2121   
2122    result = new Vector();
2123
2124    en = super.listOptions();
2125    while (en.hasMoreElements())
2126      result.addElement(en.nextElement());
2127
2128    result.addElement(new Option(
2129        "\tThe pruning strategy.\n"
2130        + "\t(default: " + new SelectedTag(PRUNING_POSTPRUNING, TAGS_PRUNING) + ")",
2131        "P", 1, "-P " + Tag.toOptionList(TAGS_PRUNING)));
2132
2133    result.addElement(new Option(
2134        "\tThe minimal number of instances at the terminal nodes.\n" 
2135        + "\t(default 2)",
2136        "M", 1, "-M <min no>"));
2137   
2138    result.addElement(new Option(
2139        "\tThe number of folds used in the pruning.\n"
2140        + "\t(default 5)",
2141        "N", 5, "-N <num folds>"));
2142   
2143    result.addElement(new Option(
2144        "\tDon't use heuristic search for nominal attributes in multi-class\n"
2145        + "\tproblem (default yes).\n",
2146        "H", 0, "-H"));
2147   
2148    result.addElement(new Option(
2149        "\tDon't use Gini index for splitting (default yes),\n"
2150        + "\tif not information is used.", 
2151        "G", 0, "-G"));
2152   
2153    result.addElement(new Option(
2154        "\tDon't use error rate in internal cross-validation (default yes), \n"
2155        + "\tbut root mean squared error.", 
2156        "R", 0, "-R"));
2157   
2158    result.addElement(new Option(
2159        "\tUse the 1 SE rule to make pruning decision.\n"
2160        + "\t(default no).", 
2161        "A", 0, "-A"));
2162   
2163    result.addElement(new Option(
2164        "\tPercentage of training data size (0-1]\n"
2165        + "\t(default 1).", 
2166        "C", 0, "-C"));
2167
2168    return result.elements();
2169  }
2170
2171  /**
2172   * Parses the options for this object. <p/>
2173   *
2174   <!-- options-start -->
2175   * Valid options are: <p/>
2176   *
2177   * <pre> -S &lt;num&gt;
2178   *  Random number seed.
2179   *  (default 1)</pre>
2180   *
2181   * <pre> -D
2182   *  If set, classifier is run in debug mode and
2183   *  may output additional info to the console</pre>
2184   *
2185   * <pre> -P &lt;UNPRUNED|POSTPRUNED|PREPRUNED&gt;
2186   *  The pruning strategy.
2187   *  (default: POSTPRUNED)</pre>
2188   *
2189   * <pre> -M &lt;min no&gt;
2190   *  The minimal number of instances at the terminal nodes.
2191   *  (default 2)</pre>
2192   *
2193   * <pre> -N &lt;num folds&gt;
2194   *  The number of folds used in the pruning.
2195   *  (default 5)</pre>
2196   *
2197   * <pre> -H
2198   *  Don't use heuristic search for nominal attributes in multi-class
2199   *  problem (default yes).
2200   * </pre>
2201   *
2202   * <pre> -G
2203   *  Don't use Gini index for splitting (default yes),
2204   *  if not information is used.</pre>
2205   *
2206   * <pre> -R
2207   *  Don't use error rate in internal cross-validation (default yes),
2208   *  but root mean squared error.</pre>
2209   *
2210   * <pre> -A
2211   *  Use the 1 SE rule to make pruning decision.
2212   *  (default no).</pre>
2213   *
2214   * <pre> -C
2215   *  Percentage of training data size (0-1]
2216   *  (default 1).</pre>
2217   *
2218   <!-- options-end -->
2219   *
2220   * @param options     the options to use
2221   * @throws Exception  if setting of options fails
2222   */
2223  public void setOptions(String[] options) throws Exception {
2224    String      tmpStr;
2225   
2226    super.setOptions(options);
2227
2228    tmpStr = Utils.getOption('M', options);
2229    if (tmpStr.length() != 0) 
2230      setMinNumObj(Integer.parseInt(tmpStr));
2231    else
2232      setMinNumObj(2);
2233
2234    tmpStr = Utils.getOption('N', options);
2235    if (tmpStr.length() != 0)
2236      setNumFoldsPruning(Integer.parseInt(tmpStr));
2237    else
2238      setNumFoldsPruning(5);
2239
2240    tmpStr = Utils.getOption('C', options);
2241    if (tmpStr.length()!=0)
2242      setSizePer(Double.parseDouble(tmpStr));
2243    else
2244      setSizePer(1);
2245
2246    tmpStr = Utils.getOption('P', options);
2247    if (tmpStr.length() != 0)
2248      setPruningStrategy(new SelectedTag(tmpStr, TAGS_PRUNING));
2249    else
2250      setPruningStrategy(new SelectedTag(PRUNING_POSTPRUNING, TAGS_PRUNING));
2251
2252    setHeuristic(!Utils.getFlag('H',options));
2253
2254    setUseGini(!Utils.getFlag('G',options));
2255   
2256    setUseErrorRate(!Utils.getFlag('R',options));
2257   
2258    setUseOneSE(Utils.getFlag('A',options));
2259  }
2260
2261  /**
2262   * Gets the current settings of the Classifier.
2263   *
2264   * @return            the current settings of the Classifier
2265   */
2266  public String[] getOptions() {
2267    int         i;
2268    Vector      result;
2269    String[]    options;
2270
2271    result = new Vector();
2272
2273    options = super.getOptions();
2274    for (i = 0; i < options.length; i++)
2275      result.add(options[i]);
2276
2277    result.add("-M");
2278    result.add("" + getMinNumObj());
2279
2280    result.add("-N");
2281    result.add("" + getNumFoldsPruning());
2282
2283    if (!getHeuristic())
2284      result.add("-H");
2285
2286    if (!getUseGini())
2287      result.add("-G");
2288
2289    if (!getUseErrorRate())
2290      result.add("-R");
2291
2292    if (getUseOneSE())
2293      result.add("-A");
2294
2295    result.add("-C");
2296    result.add("" + getSizePer());
2297
2298    result.add("-P");
2299    result.add("" + getPruningStrategy());
2300
2301    return (String[]) result.toArray(new String[result.size()]);         
2302  }
2303
2304  /**
2305   * Return an enumeration of the measure names.
2306   *
2307   * @return            an enumeration of the measure names
2308   */
2309  public Enumeration enumerateMeasures() {
2310    Vector result = new Vector();
2311   
2312    result.addElement("measureTreeSize");
2313   
2314    return result.elements();
2315  }
2316
2317  /**
2318   * Return number of tree size.
2319   *
2320   * @return            number of tree size
2321   */
2322  public double measureTreeSize() {
2323    return numNodes();
2324  }
2325
2326  /**
2327   * Returns the value of the named measure
2328   *
2329   * @param additionalMeasureName       the name of the measure to query for its value
2330   * @return                            the value of the named measure
2331   * @throws IllegalArgumentException   if the named measure is not supported
2332   */
2333  public double getMeasure(String additionalMeasureName) {
2334    if (additionalMeasureName.compareToIgnoreCase("measureTreeSize") == 0) {
2335      return measureTreeSize();
2336    } else {
2337      throw new IllegalArgumentException(additionalMeasureName
2338          + " not supported (Best-First)");
2339    }
2340  }
2341
2342  /**
2343   * Returns the tip text for this property
2344   *
2345   * @return            tip text for this property suitable for
2346   *                    displaying in the explorer/experimenter gui
2347   */
2348  public String pruningStrategyTipText() {
2349    return "Sets the pruning strategy.";
2350  }
2351
2352  /**
2353   * Sets the pruning strategy.
2354   *
2355   * @param value       the strategy
2356   */
2357  public void setPruningStrategy(SelectedTag value) {
2358    if (value.getTags() == TAGS_PRUNING) {
2359      m_PruningStrategy = value.getSelectedTag().getID();
2360    }
2361  }
2362
2363  /**
2364   * Gets the pruning strategy.
2365   *
2366   * @return            the current strategy.
2367   */
2368  public SelectedTag getPruningStrategy() {
2369    return new SelectedTag(m_PruningStrategy, TAGS_PRUNING);
2370  }
2371
2372  /**
2373   * Returns the tip text for this property
2374   *
2375   * @return            tip text for this property suitable for
2376   *                    displaying in the explorer/experimenter gui
2377   */
2378  public String minNumObjTipText() {
2379    return "Set minimal number of instances at the terminal nodes.";
2380  }
2381
2382  /**
2383   * Set minimal number of instances at the terminal nodes.
2384   *
2385   * @param value       minimal number of instances at the terminal nodes
2386   */
2387  public void setMinNumObj(int value) {
2388    m_minNumObj = value;
2389  }
2390
2391  /**
2392   * Get minimal number of instances at the terminal nodes.
2393   *
2394   * @return            minimal number of instances at the terminal nodes
2395   */
2396  public int getMinNumObj() {
2397    return m_minNumObj;
2398  }
2399
2400  /**
2401   * Returns the tip text for this property
2402   *
2403   * @return            tip text for this property suitable for
2404   *                    displaying in the explorer/experimenter gui
2405   */
2406  public String numFoldsPruningTipText() {
2407    return "Number of folds in internal cross-validation.";
2408  }
2409
2410  /**
2411   * Set number of folds in internal cross-validation.
2412   *
2413   * @param value       the number of folds
2414   */
2415  public void setNumFoldsPruning(int value) {
2416    m_numFoldsPruning = value;
2417  }
2418
2419  /**
2420   * Set number of folds in internal cross-validation.
2421   *
2422   * @return            number of folds in internal cross-validation
2423   */
2424  public int getNumFoldsPruning() {
2425    return m_numFoldsPruning;
2426  }
2427
2428  /**
2429   * Returns the tip text for this property
2430   *
2431   * @return            tip text for this property suitable for
2432   *                    displaying in the explorer/experimenter gui.
2433   */
2434  public String heuristicTipText() {
2435    return "If heuristic search is used for binary split for nominal attributes.";
2436  }
2437
2438  /**
2439   * Set if use heuristic search for nominal attributes in multi-class problems.
2440   *
2441   * @param value       if use heuristic search for nominal attributes in
2442   *                    multi-class problems
2443   */
2444  public void setHeuristic(boolean value) {
2445    m_Heuristic = value;
2446  }
2447
2448  /**
2449   * Get if use heuristic search for nominal attributes in multi-class problems.
2450   *
2451   * @return            if use heuristic search for nominal attributes in
2452   *                    multi-class problems
2453   */
2454  public boolean getHeuristic() {
2455    return m_Heuristic;
2456  }
2457
2458  /**
2459   * Returns the tip text for this property
2460   *
2461   * @return            tip text for this property suitable for
2462   *                    displaying in the explorer/experimenter gui.
2463   */
2464  public String useGiniTipText() {
2465    return "If true the Gini index is used for splitting criterion, otherwise the information is used.";
2466  }
2467
2468  /**
2469   * Set if use Gini index as splitting criterion.
2470   *
2471   * @param value       if use Gini index splitting criterion
2472   */
2473  public void setUseGini(boolean value) {
2474    m_UseGini = value;
2475  }
2476
2477  /**
2478   * Get if use Gini index as splitting criterion.
2479   *
2480   * @return            if use Gini index as splitting criterion
2481   */
2482  public boolean getUseGini() {
2483    return m_UseGini;
2484  }
2485
2486  /**
2487   * Returns the tip text for this property
2488   *
2489   * @return            tip text for this property suitable for
2490   *                    displaying in the explorer/experimenter gui.
2491   */
2492  public String useErrorRateTipText() {
2493    return "If error rate is used as error estimate. if not, root mean squared error is used.";
2494  }
2495
2496  /**
2497   * Set if use error rate in internal cross-validation.
2498   *
2499   * @param value       if use error rate in internal cross-validation
2500   */
2501  public void setUseErrorRate(boolean value) {
2502    m_UseErrorRate = value;
2503  }
2504
2505  /**
2506   * Get if use error rate in internal cross-validation.
2507   *
2508   * @return            if use error rate in internal cross-validation.
2509   */
2510  public boolean getUseErrorRate() {
2511    return m_UseErrorRate;
2512  }
2513
2514  /**
2515   * Returns the tip text for this property
2516   *
2517   * @return            tip text for this property suitable for
2518   *                    displaying in the explorer/experimenter gui.
2519   */
2520  public String useOneSETipText() {
2521    return "Use the 1SE rule to make pruning decision.";
2522  }
2523
2524  /**
2525   * Set if use the 1SE rule to choose final model.
2526   *
2527   * @param value       if use the 1SE rule to choose final model
2528   */
2529  public void setUseOneSE(boolean value) {
2530    m_UseOneSE = value;
2531  }
2532
2533  /**
2534   * Get if use the 1SE rule to choose final model.
2535   *
2536   * @return            if use the 1SE rule to choose final model
2537   */
2538  public boolean getUseOneSE() {
2539    return m_UseOneSE;
2540  }
2541
2542  /**
2543   * Returns the tip text for this property
2544   *
2545   * @return            tip text for this property suitable for
2546   *                    displaying in the explorer/experimenter gui.
2547   */
2548  public String sizePerTipText() {
2549    return "The percentage of the training set size (0-1, 0 not included).";
2550  }
2551
2552  /**
2553   * Set training set size.
2554   *
2555   * @param value       training set size
2556   */
2557  public void setSizePer(double value) {
2558    if ((value <= 0) || (value > 1))
2559      System.err.println(
2560          "The percentage of the training set size must be in range 0 to 1 "
2561          + "(0 not included) - ignored!");
2562    else
2563      m_SizePer = value;
2564  }
2565
2566  /**
2567   * Get training set size.
2568   *
2569   * @return            training set size
2570   */
2571  public double getSizePer() {
2572    return m_SizePer;
2573  }
2574 
2575  /**
2576   * Returns the revision string.
2577   *
2578   * @return            the revision
2579   */
2580  public String getRevision() {
2581    return RevisionUtils.extract("$Revision: 5987 $");
2582  }
2583
2584  /**
2585   * Main method.
2586   *
2587   * @param args the options for the classifier
2588   */
2589  public static void main(String[] args) {
2590    runClassifier(new BFTree(), args);
2591  }
2592}
Note: See TracBrowser for help on using the repository browser.