source: src/main/java/weka/classifiers/trees/RandomTree.java @ 18

Last change on this file since 18 was 4, checked in by gnappo, 14 years ago

Import di weka.

File size: 39.6 KB
Line 
1/*
2 *    This program is free software; you can redistribute it and/or modify
3 *    it under the terms of the GNU General Public License as published by
4 *    the Free Software Foundation; either version 2 of the License, or
5 *    (at your option) any later version.
6 *
7 *    This program is distributed in the hope that it will be useful,
8 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
9 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
10 *    GNU General Public License for more details.
11 *
12 *    You should have received a copy of the GNU General Public License
13 *    along with this program; if not, write to the Free Software
14 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
15 */
16
17/*
18 *    RandomTree.java
19 *    Copyright (C) 2001 University of Waikato, Hamilton, New Zealand
20 *
21 */
22
23package weka.classifiers.trees;
24
25import weka.classifiers.Classifier;
26import weka.classifiers.AbstractClassifier;
27import weka.core.Attribute;
28import weka.core.Capabilities;
29import weka.core.ContingencyTables;
30import weka.core.Drawable;
31import weka.core.Instance;
32import weka.core.Instances;
33import weka.core.Option;
34import weka.core.OptionHandler;
35import weka.core.Randomizable;
36import weka.core.RevisionUtils;
37import weka.core.Utils;
38import weka.core.WeightedInstancesHandler;
39import weka.core.Capabilities.Capability;
40
41import java.util.Enumeration;
42import java.util.Random;
43import java.util.Vector;
44
45/**
46 * <!-- globalinfo-start -->
47 * Class for constructing a tree that considers K randomly  chosen attributes at each node. Performs no pruning. Also has an option to allow estimation of class probabilities based on a hold-out set (backfitting).
48 * <p/>
49 * <!-- globalinfo-end -->
50 *
51 * <!-- options-start -->
52 * Valid options are: <p/>
53 *
54 * <pre> -K &lt;number of attributes&gt;
55 *  Number of attributes to randomly investigate
56 *  (&lt;0 = int(log_2(#attributes)+1)).</pre>
57 *
58 * <pre> -M &lt;minimum number of instances&gt;
59 *  Set minimum number of instances per leaf.</pre>
60 *
61 * <pre> -S &lt;num&gt;
62 *  Seed for random number generator.
63 *  (default 1)</pre>
64 *
65 * <pre> -depth &lt;num&gt;
66 *  The maximum depth of the tree, 0 for unlimited.
67 *  (default 0)</pre>
68 *
69 * <pre> -N &lt;num&gt;
70 *  Number of folds for backfitting (default 0, no backfitting).</pre>
71 *
72 * <pre> -U
73 *  Allow unclassified instances.</pre>
74 *
75 * <pre> -D
76 *  If set, classifier is run in debug mode and
77 *  may output additional info to the console</pre>
78 *
79 * <!-- options-end -->
80 *
81 * @author Eibe Frank (eibe@cs.waikato.ac.nz)
82 * @author Richard Kirkby (rkirkby@cs.waikato.ac.nz)
83 * @version $Revision: 5928 $
84 */
85public class RandomTree extends AbstractClassifier implements OptionHandler,
86WeightedInstancesHandler, Randomizable, Drawable {
87
88  /** for serialization */
89  static final long serialVersionUID = 8934314652175299374L;
90
91  /** The subtrees appended to this tree. */
92  protected RandomTree[] m_Successors;
93
94  /** The attribute to split on. */
95  protected int m_Attribute = -1;
96
97  /** The split point. */
98  protected double m_SplitPoint = Double.NaN;
99
100  /** The header information. */
101  protected Instances m_Info = null;
102
103  /** The proportions of training instances going down each branch. */
104  protected double[] m_Prop = null;
105
106  /** Class probabilities from the training data. */
107  protected double[] m_ClassDistribution = null;
108
109  /** Minimum number of instances for leaf. */
110  protected double m_MinNum = 1.0;
111
112  /** The number of attributes considered for a split. */
113  protected int m_KValue = 0;
114
115  /** The random seed to use. */
116  protected int m_randomSeed = 1;
117
118  /** The maximum depth of the tree (0 = unlimited) */
119  protected int m_MaxDepth = 0;
120
121  /** Determines how much data is used for backfitting */
122  protected int m_NumFolds = 0;
123
124  /** Whether unclassified instances are allowed */
125  protected boolean m_AllowUnclassifiedInstances = false;
126
127  /** a ZeroR model in case no model can be built from the data */
128  protected Classifier m_ZeroR;
129
130  /**
131   * Returns a string describing classifier
132   *
133   * @return a description suitable for displaying in the
134   *         explorer/experimenter gui
135   */
136  public String globalInfo() {
137
138    return "Class for constructing a tree that considers K randomly "
139    + " chosen attributes at each node. Performs no pruning. Also has"
140    + " an option to allow estimation of class probabilities based on"
141    + " a hold-out set (backfitting).";
142  }
143
144  /**
145   * Returns the tip text for this property
146   *
147   * @return tip text for this property suitable for displaying in the
148   *         explorer/experimenter gui
149   */
150  public String minNumTipText() {
151    return "The minimum total weight of the instances in a leaf.";
152  }
153
154  /**
155   * Get the value of MinNum.
156   *
157   * @return Value of MinNum.
158   */
159  public double getMinNum() {
160
161    return m_MinNum;
162  }
163
164  /**
165   * Set the value of MinNum.
166   *
167   * @param newMinNum
168   *            Value to assign to MinNum.
169   */
170  public void setMinNum(double newMinNum) {
171
172    m_MinNum = newMinNum;
173  }
174
175  /**
176   * Returns the tip text for this property
177   *
178   * @return tip text for this property suitable for displaying in the
179   *         explorer/experimenter gui
180   */
181  public String KValueTipText() {
182    return "Sets the number of randomly chosen attributes. If 0, log_2(number_of_attributes) + 1 is used.";
183  }
184
185  /**
186   * Get the value of K.
187   *
188   * @return Value of K.
189   */
190  public int getKValue() {
191
192    return m_KValue;
193  }
194
195  /**
196   * Set the value of K.
197   *
198   * @param k
199   *            Value to assign to K.
200   */
201  public void setKValue(int k) {
202
203    m_KValue = k;
204  }
205
206  /**
207   * Returns the tip text for this property
208   *
209   * @return tip text for this property suitable for displaying in the
210   *         explorer/experimenter gui
211   */
212  public String seedTipText() {
213    return "The random number seed used for selecting attributes.";
214  }
215
216  /**
217   * Set the seed for random number generation.
218   *
219   * @param seed
220   *            the seed
221   */
222  public void setSeed(int seed) {
223
224    m_randomSeed = seed;
225  }
226
227  /**
228   * Gets the seed for the random number generations
229   *
230   * @return the seed for the random number generation
231   */
232  public int getSeed() {
233
234    return m_randomSeed;
235  }
236
237  /**
238   * Returns the tip text for this property
239   *
240   * @return tip text for this property suitable for displaying in the
241   *         explorer/experimenter gui
242   */
243  public String maxDepthTipText() {
244    return "The maximum depth of the tree, 0 for unlimited.";
245  }
246
247  /**
248   * Get the maximum depth of trh tree, 0 for unlimited.
249   *
250   * @return the maximum depth.
251   */
252  public int getMaxDepth() {
253    return m_MaxDepth;
254  }
255
256  /**
257   * Returns the tip text for this property
258   * @return tip text for this property suitable for
259   * displaying in the explorer/experimenter gui
260   */
261  public String numFoldsTipText() {
262    return "Determines the amount of data used for backfitting. One fold is used for "
263      + "backfitting, the rest for growing the tree. (Default: 0, no backfitting)";
264  }
265 
266  /**
267   * Get the value of NumFolds.
268   *
269   * @return Value of NumFolds.
270   */
271  public int getNumFolds() {
272   
273    return m_NumFolds;
274  }
275 
276  /**
277   * Set the value of NumFolds.
278   *
279   * @param newNumFolds Value to assign to NumFolds.
280   */
281  public void setNumFolds(int newNumFolds) {
282   
283    m_NumFolds = newNumFolds;
284  }
285
286  /**
287   * Returns the tip text for this property
288   * @return tip text for this property suitable for
289   * displaying in the explorer/experimenter gui
290   */
291  public String allowUnclassifiedInstancesTipText() {
292    return "Whether to allow unclassified instances.";
293  }
294 
295  /**
296   * Get the value of NumFolds.
297   *
298   * @return Value of NumFolds.
299   */
300  public boolean getAllowUnclassifiedInstances() {
301   
302    return m_AllowUnclassifiedInstances;
303  }
304 
305  /**
306   * Set the value of AllowUnclassifiedInstances.
307   *
308   * @param newAllowUnclassifiedInstances Value to assign to AllowUnclassifiedInstances.
309   */
310  public void setAllowUnclassifiedInstances(boolean newAllowUnclassifiedInstances) {
311   
312    m_AllowUnclassifiedInstances = newAllowUnclassifiedInstances;
313  }
314
315  /**
316   * Set the maximum depth of the tree, 0 for unlimited.
317   *
318   * @param value
319   *            the maximum depth.
320   */
321  public void setMaxDepth(int value) {
322    m_MaxDepth = value;
323  }
324
325  /**
326   * Lists the command-line options for this classifier.
327   *
328   * @return an enumeration over all possible options
329   */
330  public Enumeration listOptions() {
331
332    Vector newVector = new Vector();
333
334    newVector.addElement(new Option(
335        "\tNumber of attributes to randomly investigate\n"
336        + "\t(<0 = int(log_2(#attributes)+1)).", "K", 1,
337    "-K <number of attributes>"));
338
339    newVector.addElement(new Option(
340        "\tSet minimum number of instances per leaf.", "M", 1,
341    "-M <minimum number of instances>"));
342
343    newVector.addElement(new Option("\tSeed for random number generator.\n"
344        + "\t(default 1)", "S", 1, "-S <num>"));
345
346    newVector.addElement(new Option(
347        "\tThe maximum depth of the tree, 0 for unlimited.\n"
348        + "\t(default 0)", "depth", 1, "-depth <num>"));
349
350    newVector.
351      addElement(new Option("\tNumber of folds for backfitting " +
352                            "(default 0, no backfitting).",
353                            "N", 1, "-N <num>"));
354    newVector.
355      addElement(new Option("\tAllow unclassified instances.",
356                            "U", 0, "-U"));
357
358    Enumeration enu = super.listOptions();
359    while (enu.hasMoreElements()) {
360      newVector.addElement(enu.nextElement());
361    }
362
363    return newVector.elements();
364  }
365
366  /**
367   * Gets options from this classifier.
368   *
369   * @return the options for the current setup
370   */
371  public String[] getOptions() {
372    Vector result;
373    String[] options;
374    int i;
375
376    result = new Vector();
377
378    result.add("-K");
379    result.add("" + getKValue());
380
381    result.add("-M");
382    result.add("" + getMinNum());
383
384    result.add("-S");
385    result.add("" + getSeed());
386
387    if (getMaxDepth() > 0) {
388      result.add("-depth");
389      result.add("" + getMaxDepth());
390    }
391
392    if (getNumFolds() > 0) {
393      result.add("-N"); 
394      result.add("" + getNumFolds());
395    }
396
397    if (getAllowUnclassifiedInstances()) {
398      result.add("-U");
399    }
400
401    options = super.getOptions();
402    for (i = 0; i < options.length; i++)
403      result.add(options[i]);
404
405    return (String[]) result.toArray(new String[result.size()]);
406  }
407
408  /**
409   * Parses a given list of options.
410   * <p/>
411   *
412   * <!-- options-start -->
413   * Valid options are: <p/>
414   *
415   * <pre> -K &lt;number of attributes&gt;
416   *  Number of attributes to randomly investigate
417   *  (&lt;0 = int(log_2(#attributes)+1)).</pre>
418   *
419   * <pre> -M &lt;minimum number of instances&gt;
420   *  Set minimum number of instances per leaf.</pre>
421   *
422   * <pre> -S &lt;num&gt;
423   *  Seed for random number generator.
424   *  (default 1)</pre>
425   *
426   * <pre> -depth &lt;num&gt;
427   *  The maximum depth of the tree, 0 for unlimited.
428   *  (default 0)</pre>
429   *
430   * <pre> -N &lt;num&gt;
431   *  Number of folds for backfitting (default 0, no backfitting).</pre>
432   *
433   * <pre> -U
434   *  Allow unclassified instances.</pre>
435   *
436   * <pre> -D
437   *  If set, classifier is run in debug mode and
438   *  may output additional info to the console</pre>
439   *
440   * <!-- options-end -->
441   *
442   * @param options
443   *            the list of options as an array of strings
444   * @throws Exception
445   *             if an option is not supported
446   */
447  public void setOptions(String[] options) throws Exception {
448    String tmpStr;
449
450    tmpStr = Utils.getOption('K', options);
451    if (tmpStr.length() != 0) {
452      m_KValue = Integer.parseInt(tmpStr);
453    } else {
454      m_KValue = 0;
455    }
456
457    tmpStr = Utils.getOption('M', options);
458    if (tmpStr.length() != 0) {
459      m_MinNum = Double.parseDouble(tmpStr);
460    } else {
461      m_MinNum = 1;
462    }
463
464    tmpStr = Utils.getOption('S', options);
465    if (tmpStr.length() != 0) {
466      setSeed(Integer.parseInt(tmpStr));
467    } else {
468      setSeed(1);
469    }
470
471    tmpStr = Utils.getOption("depth", options);
472    if (tmpStr.length() != 0) {
473      setMaxDepth(Integer.parseInt(tmpStr));
474    } else {
475      setMaxDepth(0);
476    }
477    String numFoldsString = Utils.getOption('N', options);
478    if (numFoldsString.length() != 0) {
479      m_NumFolds = Integer.parseInt(numFoldsString);
480    } else {
481      m_NumFolds = 0;
482    }
483
484    setAllowUnclassifiedInstances(Utils.getFlag('U', options));
485
486    super.setOptions(options);
487
488    Utils.checkForRemainingOptions(options);
489  }
490
491  /**
492   * Returns default capabilities of the classifier.
493   *
494   * @return the capabilities of this classifier
495   */
496  public Capabilities getCapabilities() {
497    Capabilities result = super.getCapabilities();
498    result.disableAll();
499
500    // attributes
501    result.enable(Capability.NOMINAL_ATTRIBUTES);
502    result.enable(Capability.NUMERIC_ATTRIBUTES);
503    result.enable(Capability.DATE_ATTRIBUTES);
504    result.enable(Capability.MISSING_VALUES);
505
506    // class
507    result.enable(Capability.NOMINAL_CLASS);
508    result.enable(Capability.MISSING_CLASS_VALUES);
509
510    return result;
511  }
512
513  /**
514   * Builds classifier.
515   *
516   * @param data
517   *            the data to train with
518   * @throws Exception
519   *             if something goes wrong or the data doesn't fit
520   */
521  public void buildClassifier(Instances data) throws Exception {
522
523    // Make sure K value is in range
524    if (m_KValue > data.numAttributes() - 1)
525      m_KValue = data.numAttributes() - 1;
526    if (m_KValue < 1)
527      m_KValue = (int) Utils.log2(data.numAttributes()) + 1;
528
529    // can classifier handle the data?
530    getCapabilities().testWithFail(data);
531
532    // remove instances with missing class
533    data = new Instances(data);
534    data.deleteWithMissingClass();
535
536    // only class? -> build ZeroR model
537    if (data.numAttributes() == 1) {
538      System.err
539      .println("Cannot build model (only class attribute present in data!), "
540          + "using ZeroR model instead!");
541      m_ZeroR = new weka.classifiers.rules.ZeroR();
542      m_ZeroR.buildClassifier(data);
543      return;
544    } else {
545      m_ZeroR = null;
546    }
547
548    // Figure out appropriate datasets
549    Instances train = null;
550    Instances backfit = null;
551    Random rand = data.getRandomNumberGenerator(m_randomSeed);
552    if (m_NumFolds <= 0) {
553      train = data;
554    } else {
555      data.randomize(rand);
556      data.stratify(m_NumFolds);
557      train = data.trainCV(m_NumFolds, 1, rand);
558      backfit = data.testCV(m_NumFolds, 1);
559    }
560
561    // Create the attribute indices window
562    int[] attIndicesWindow = new int[data.numAttributes() - 1];
563    int j = 0;
564    for (int i = 0; i < attIndicesWindow.length; i++) {
565      if (j == data.classIndex())
566        j++; // do not include the class
567      attIndicesWindow[i] = j++;
568    }
569
570    // Compute initial class counts
571    double[] classProbs = new double[train.numClasses()];
572    for (int i = 0; i < train.numInstances(); i++) {
573      Instance inst = train.instance(i);
574      classProbs[(int) inst.classValue()] += inst.weight();
575    }
576
577    // Build tree
578    buildTree(train, classProbs, new Instances(data, 0), m_MinNum, m_Debug, attIndicesWindow, 
579              rand, 0, getAllowUnclassifiedInstances());
580     
581    // Backfit if required
582    if (backfit != null) {
583      backfitData(backfit);
584    }
585  }
586
587  /**
588   * Backfits the given data into the tree.
589   */
590  public void backfitData(Instances data) throws Exception {
591
592    // Compute initial class counts
593    double[] classProbs = new double[data.numClasses()];
594    for (int i = 0; i < data.numInstances(); i++) {
595      Instance inst = data.instance(i);
596      classProbs[(int) inst.classValue()] += inst.weight();
597    }
598
599    // Fit data into tree
600    backfitData(data, classProbs);
601  }
602
603  /**
604   * Computes class distribution of an instance using the decision tree.
605   *
606   * @param instance
607   *            the instance to compute the distribution for
608   * @return the computed class distribution
609   * @throws Exception
610   *             if computation fails
611   */
612  public double[] distributionForInstance(Instance instance) throws Exception {
613
614    // default model?
615    if (m_ZeroR != null) {
616      return m_ZeroR.distributionForInstance(instance);
617    }
618
619    double[] returnedDist = null;
620
621    if (m_Attribute > -1) {
622
623      // Node is not a leaf
624      if (instance.isMissing(m_Attribute)) {
625
626        // Value is missing
627        returnedDist = new double[m_Info.numClasses()];
628
629        // Split instance up
630        for (int i = 0; i < m_Successors.length; i++) {
631          double[] help = m_Successors[i]
632                                       .distributionForInstance(instance);
633          if (help != null) {
634            for (int j = 0; j < help.length; j++) {
635              returnedDist[j] += m_Prop[i] * help[j];
636            }
637          }
638        }
639      } else if (m_Info.attribute(m_Attribute).isNominal()) {
640
641        // For nominal attributes
642        returnedDist = m_Successors[(int) instance.value(m_Attribute)]
643                                    .distributionForInstance(instance);
644      } else {
645
646        // For numeric attributes
647        if (instance.value(m_Attribute) < m_SplitPoint) {
648          returnedDist = m_Successors[0]
649                                      .distributionForInstance(instance);
650        } else {
651          returnedDist = m_Successors[1]
652                                      .distributionForInstance(instance);
653        }
654      }
655    }
656
657
658    // Node is a leaf or successor is empty?
659    if ((m_Attribute == -1) || (returnedDist == null)) {
660 
661      // Is node empty?
662      if (m_ClassDistribution == null) {
663        if (getAllowUnclassifiedInstances()) {
664          return new double[m_Info.numClasses()];
665        } else {
666          return null;
667        }
668      }
669
670      // Else return normalized distribution
671      double[] normalizedDistribution = (double[]) m_ClassDistribution.clone();
672      Utils.normalize(normalizedDistribution);
673      return normalizedDistribution;
674    } else {
675      return returnedDist;
676    }
677  }
678
679  /**
680   * Outputs the decision tree as a graph
681   *
682   * @return the tree as a graph
683   */
684  public String toGraph() {
685
686    try {
687      StringBuffer resultBuff = new StringBuffer();
688      toGraph(resultBuff, 0);
689      String result = "digraph Tree {\n" + "edge [style=bold]\n"
690      + resultBuff.toString() + "\n}\n";
691      return result;
692    } catch (Exception e) {
693      return null;
694    }
695  }
696
697  /**
698   * Outputs one node for graph.
699   *
700   * @param text
701   *            the buffer to append the output to
702   * @param num
703   *            unique node id
704   * @return the next node id
705   * @throws Exception
706   *             if generation fails
707   */
708  public int toGraph(StringBuffer text, int num) throws Exception {
709
710    int maxIndex = Utils.maxIndex(m_ClassDistribution);
711    String classValue = m_Info.classAttribute().value(maxIndex);
712
713    num++;
714    if (m_Attribute == -1) {
715      text.append("N" + Integer.toHexString(hashCode()) + " [label=\""
716          + num + ": " + classValue + "\"" + "shape=box]\n");
717    } else {
718      text.append("N" + Integer.toHexString(hashCode()) + " [label=\""
719          + num + ": " + classValue + "\"]\n");
720      for (int i = 0; i < m_Successors.length; i++) {
721        text.append("N" + Integer.toHexString(hashCode()) + "->" + "N"
722            + Integer.toHexString(m_Successors[i].hashCode())
723            + " [label=\"" + m_Info.attribute(m_Attribute).name());
724        if (m_Info.attribute(m_Attribute).isNumeric()) {
725          if (i == 0) {
726            text.append(" < "
727                + Utils.doubleToString(m_SplitPoint, 2));
728          } else {
729            text.append(" >= "
730                + Utils.doubleToString(m_SplitPoint, 2));
731          }
732        } else {
733          text.append(" = " + m_Info.attribute(m_Attribute).value(i));
734        }
735        text.append("\"]\n");
736        num = m_Successors[i].toGraph(text, num);
737      }
738    }
739
740    return num;
741  }
742
743  /**
744   * Outputs the decision tree.
745   *
746   * @return a string representation of the classifier
747   */
748  public String toString() {
749
750    // only ZeroR model?
751    if (m_ZeroR != null) {
752      StringBuffer buf = new StringBuffer();
753      buf
754      .append(this.getClass().getName().replaceAll(".*\\.", "")
755          + "\n");
756      buf.append(this.getClass().getName().replaceAll(".*\\.", "")
757          .replaceAll(".", "=")
758          + "\n\n");
759      buf
760      .append("Warning: No model could be built, hence ZeroR model is used:\n\n");
761      buf.append(m_ZeroR.toString());
762      return buf.toString();
763    }
764
765    if (m_Successors == null) {
766      return "RandomTree: no model has been built yet.";
767    } else {
768      return "\nRandomTree\n==========\n"
769      + toString(0)
770      + "\n"
771      + "\nSize of the tree : "
772      + numNodes()
773      + (getMaxDepth() > 0 ? ("\nMax depth of tree: " + getMaxDepth())
774          : (""));
775    }
776  }
777
778  /**
779   * Outputs a leaf.
780   *
781   * @return the leaf as string
782   * @throws Exception
783   *             if generation fails
784   */
785  protected String leafString() throws Exception {
786
787    double sum = 0, maxCount = 0;
788    int maxIndex = 0;
789    if (m_ClassDistribution != null) {
790      sum = Utils.sum(m_ClassDistribution);
791      maxIndex = Utils.maxIndex(m_ClassDistribution);
792      maxCount = m_ClassDistribution[maxIndex];
793    } 
794    return " : "
795    + m_Info.classAttribute().value(maxIndex)
796    + " ("
797    + Utils.doubleToString(sum, 2)
798    + "/"
799    + Utils.doubleToString(sum - maxCount, 2) + ")";
800  }
801
802  /**
803   * Recursively outputs the tree.
804   *
805   * @param level
806   *            the current level of the tree
807   * @return the generated subtree
808   */
809  protected String toString(int level) {
810
811    try {
812      StringBuffer text = new StringBuffer();
813
814      if (m_Attribute == -1) {
815
816        // Output leaf info
817        return leafString();
818      } else if (m_Info.attribute(m_Attribute).isNominal()) {
819
820        // For nominal attributes
821        for (int i = 0; i < m_Successors.length; i++) {
822          text.append("\n");
823          for (int j = 0; j < level; j++) {
824            text.append("|   ");
825          }
826          text.append(m_Info.attribute(m_Attribute).name() + " = "
827              + m_Info.attribute(m_Attribute).value(i));
828          text.append(m_Successors[i].toString(level + 1));
829        }
830      } else {
831
832        // For numeric attributes
833        text.append("\n");
834        for (int j = 0; j < level; j++) {
835          text.append("|   ");
836        }
837        text.append(m_Info.attribute(m_Attribute).name() + " < "
838            + Utils.doubleToString(m_SplitPoint, 2));
839        text.append(m_Successors[0].toString(level + 1));
840        text.append("\n");
841        for (int j = 0; j < level; j++) {
842          text.append("|   ");
843        }
844        text.append(m_Info.attribute(m_Attribute).name() + " >= "
845            + Utils.doubleToString(m_SplitPoint, 2));
846        text.append(m_Successors[1].toString(level + 1));
847      }
848
849      return text.toString();
850    } catch (Exception e) {
851      e.printStackTrace();
852      return "RandomTree: tree can't be printed";
853    }
854  }
855
856  /**
857   * Recursively backfits data into the tree.
858   *
859   * @param data
860   *            the data to work with
861   * @param classProbs
862   *            the class distribution
863   * @throws Exception
864   *             if generation fails
865   */
866  protected void backfitData(Instances data, double[] classProbs) throws Exception {
867
868    // Make leaf if there are no training instances
869    if (data.numInstances() == 0) {
870      m_Attribute = -1;
871      m_ClassDistribution = null;
872      m_Prop = null;
873      return;
874    }
875
876    // Check if node doesn't contain enough instances or is pure
877    // or maximum depth reached
878    m_ClassDistribution = (double[]) classProbs.clone();
879
880    /*    if (Utils.sum(m_ClassDistribution) < 2 * m_MinNum
881        || Utils.eq(m_ClassDistribution[Utils.maxIndex(m_ClassDistribution)], Utils
882                    .sum(m_ClassDistribution))) {
883     
884      // Make leaf
885      m_Attribute = -1;
886      m_Prop = null;
887      return;
888      }*/
889
890    // Are we at an inner node
891    if (m_Attribute > -1) {
892     
893      // Compute new weights for subsets based on backfit data
894      m_Prop = new double[m_Successors.length];
895      for (int i = 0; i < data.numInstances(); i++) {
896        Instance inst = data.instance(i);
897        if (!inst.isMissing(m_Attribute)) {
898          if (data.attribute(m_Attribute).isNominal()) {
899            m_Prop[(int)inst.value(m_Attribute)] += inst.weight();
900          } else {
901            m_Prop[(inst.value(m_Attribute) < m_SplitPoint) ? 0 : 1] += inst.weight();
902          }
903        }
904      }
905
906      // If we only have missing values we can make this node into a leaf
907      if (Utils.sum(m_Prop) <= 0) {
908        m_Attribute = -1;
909        m_Prop = null;
910        return;
911      }
912
913      // Otherwise normalize the proportions
914      Utils.normalize(m_Prop);
915
916      // Split data
917      Instances[] subsets = splitData(data);
918     
919      // Go through subsets
920      for (int i = 0; i < subsets.length; i++) {
921       
922        // Compute distribution for current subset
923        double[] dist = new double[data.numClasses()];
924        for (int j = 0; j < subsets[i].numInstances(); j++) {
925          dist[(int)subsets[i].instance(j).classValue()] += subsets[i].instance(j).weight();
926        }
927       
928        // Backfit subset
929        m_Successors[i].backfitData(subsets[i], dist);
930      }
931
932      // If unclassified instances are allowed, we don't need to store the class distribution
933      if (getAllowUnclassifiedInstances()) {
934        m_ClassDistribution = null;
935        return;
936      }
937
938      // Otherwise, if all successors are non-empty, we don't need to store the class distribution
939      boolean emptySuccessor = false;
940      for (int i = 0; i < subsets.length; i++) {
941        if (m_Successors[i].m_ClassDistribution == null) {
942          emptySuccessor = true;
943          return;
944        }
945      }
946      m_ClassDistribution = null;
947     
948      // If we have a least two non-empty successors, we should keep this tree
949      /*      int nonEmptySuccessors = 0;
950      for (int i = 0; i < subsets.length; i++) {
951        if (m_Successors[i].m_ClassDistribution != null) {
952          nonEmptySuccessors++;
953          if (nonEmptySuccessors > 1) {
954            return;
955          }
956        }
957      }
958     
959      // Otherwise, this node is a leaf or should become a leaf
960      m_Successors = null;
961      m_Attribute = -1;
962      m_Prop = null;
963      return;*/
964    }
965  }
966
967  /**
968   * Recursively generates a tree.
969   *
970   * @param data
971   *            the data to work with
972   * @param classProbs
973   *            the class distribution
974   * @param header
975   *            the header of the data
976   * @param minNum
977   *            the minimum number of instances per leaf
978   * @param debug
979   *            whether debugging is on
980   * @param attIndicesWindow
981   *            the attribute window to choose attributes from
982   * @param random
983   *            random number generator for choosing random attributes
984   * @param depth
985   *            the current depth
986   * @param determineStructure
987   *            whether to determine structure
988   * @throws Exception
989   *             if generation fails
990   */
991  protected void buildTree(Instances data, double[] classProbs, Instances header,
992                           double minNum, boolean debug, int[] attIndicesWindow,
993                           Random random, int depth, boolean allow) throws Exception {
994
995    // Store structure of dataset, set minimum number of instances
996    m_Info = header;
997    m_Debug = debug;
998    m_MinNum = minNum;
999    m_AllowUnclassifiedInstances = allow;
1000
1001    // Make leaf if there are no training instances
1002    if (data.numInstances() == 0) {
1003      m_Attribute = -1;
1004      m_ClassDistribution = null;
1005      m_Prop = null;
1006      return;
1007    }
1008
1009    // Check if node doesn't contain enough instances or is pure
1010    // or maximum depth reached
1011    m_ClassDistribution = (double[]) classProbs.clone();
1012
1013    if (Utils.sum(m_ClassDistribution) < 2 * m_MinNum
1014        || Utils.eq(m_ClassDistribution[Utils.maxIndex(m_ClassDistribution)], Utils
1015            .sum(m_ClassDistribution))
1016            || ((getMaxDepth() > 0) && (depth >= getMaxDepth()))) {
1017      // Make leaf
1018      m_Attribute = -1;
1019      m_Prop = null;
1020      return;
1021    }
1022
1023    // Compute class distributions and value of splitting
1024    // criterion for each attribute
1025    double[] vals = new double[data.numAttributes()];
1026    double[][][] dists = new double[data.numAttributes()][0][0];
1027    double[][] props = new double[data.numAttributes()][0];
1028    double[] splits = new double[data.numAttributes()];
1029   
1030    // Investigate K random attributes
1031    int attIndex = 0;
1032    int windowSize = attIndicesWindow.length;
1033    int k = m_KValue;
1034    boolean gainFound = false;
1035    while ((windowSize > 0) && (k-- > 0 || !gainFound)) {
1036     
1037      int chosenIndex = random.nextInt(windowSize);
1038      attIndex = attIndicesWindow[chosenIndex];
1039     
1040      // shift chosen attIndex out of window
1041      attIndicesWindow[chosenIndex] = attIndicesWindow[windowSize - 1];
1042      attIndicesWindow[windowSize - 1] = attIndex;
1043      windowSize--;
1044     
1045      splits[attIndex] = distribution(props, dists, attIndex, data);
1046      vals[attIndex] = gain(dists[attIndex], priorVal(dists[attIndex]));
1047     
1048      if (Utils.gr(vals[attIndex], 0))
1049        gainFound = true;
1050    }
1051     
1052    // Find best attribute
1053    m_Attribute = Utils.maxIndex(vals);
1054    double[][] distribution = dists[m_Attribute];
1055
1056    // Any useful split found?
1057    if (Utils.gr(vals[m_Attribute], 0)) {
1058
1059      // Build subtrees
1060      m_SplitPoint = splits[m_Attribute];
1061      m_Prop = props[m_Attribute];
1062      Instances[] subsets = splitData(data);
1063      m_Successors = new RandomTree[distribution.length];
1064      for (int i = 0; i < distribution.length; i++) {
1065        m_Successors[i] = new RandomTree();
1066        m_Successors[i].setKValue(m_KValue);
1067        m_Successors[i].setMaxDepth(getMaxDepth());
1068        m_Successors[i].buildTree(subsets[i], distribution[i], header, m_MinNum, m_Debug,
1069                                  attIndicesWindow, random, depth + 1, allow);
1070      }
1071
1072      // If all successors are non-empty, we don't need to store the class distribution
1073      boolean emptySuccessor = false;
1074      for (int i = 0; i < subsets.length; i++) {
1075        if (m_Successors[i].m_ClassDistribution == null) {
1076          emptySuccessor = true;
1077          break;
1078        }
1079      }
1080      if (!emptySuccessor) {
1081        m_ClassDistribution = null;
1082      }
1083    } else {
1084
1085      // Make leaf
1086      m_Attribute = -1;
1087    }
1088  }
1089
1090  /**
1091   * Computes size of the tree.
1092   *
1093   * @return the number of nodes
1094   */
1095  public int numNodes() {
1096
1097    if (m_Attribute == -1) {
1098      return 1;
1099    } else {
1100      int size = 1;
1101      for (int i = 0; i < m_Successors.length; i++) {
1102        size += m_Successors[i].numNodes();
1103      }
1104      return size;
1105    }
1106  }
1107
1108  /**
1109   * Splits instances into subsets based on the given split.
1110   *
1111   * @param data
1112   *            the data to work with
1113   * @return  the subsets of instances
1114   * @throws Exception
1115   *             if something goes wrong
1116   */
1117  protected Instances[] splitData(Instances data) throws Exception {
1118
1119    // Allocate array of Instances objects
1120    Instances[] subsets = new Instances[m_Prop.length];
1121    for (int i = 0; i < m_Prop.length; i++) {
1122      subsets[i] = new Instances(data, data.numInstances());
1123    }
1124
1125    // Go through the data
1126    for (int i = 0; i < data.numInstances(); i++) {
1127
1128      // Get instance
1129      Instance inst = data.instance(i);
1130
1131      // Does the instance have a missing value?
1132      if (inst.isMissing(m_Attribute)) {
1133       
1134        // Split instance up
1135        for (int k = 0; k < m_Prop.length; k++) {
1136          if (m_Prop[k] > 0) {
1137            Instance copy = (Instance)inst.copy();
1138            copy.setWeight(m_Prop[k] * inst.weight());
1139            subsets[k].add(copy);
1140          }
1141        }
1142
1143        // Proceed to next instance
1144        continue;
1145      }
1146
1147      // Do we have a nominal attribute?
1148      if (data.attribute(m_Attribute).isNominal()) {
1149        subsets[(int)inst.value(m_Attribute)].add(inst);
1150
1151        // Proceed to next instance
1152        continue;
1153      }
1154
1155      // Do we have a numeric attribute?
1156      if (data.attribute(m_Attribute).isNumeric()) {
1157        subsets[(inst.value(m_Attribute) < m_SplitPoint) ? 0 : 1].add(inst);
1158
1159        // Proceed to next instance
1160        continue;
1161      }
1162     
1163      // Else throw an exception
1164      throw new IllegalArgumentException("Unknown attribute type");
1165    }
1166
1167    // Save memory
1168    for (int i = 0; i < m_Prop.length; i++) {
1169      subsets[i].compactify();
1170    }
1171
1172    // Return the subsets
1173    return subsets;
1174  }
1175
1176  /**
1177   * Computes class distribution for an attribute.
1178   *
1179   * @param props
1180   * @param dists
1181   * @param att
1182   *            the attribute index
1183   * @param data
1184   *            the data to work with
1185   * @throws Exception
1186   *             if something goes wrong
1187   */
1188  protected double distribution(double[][] props, double[][][] dists, int att, Instances data)
1189  throws Exception {
1190
1191    double splitPoint = Double.NaN;
1192    Attribute attribute = data.attribute(att);
1193    double[][] dist = null;
1194    int indexOfFirstMissingValue = -1;
1195
1196    if (attribute.isNominal()) {
1197
1198      // For nominal attributes
1199      dist = new double[attribute.numValues()][data.numClasses()];
1200      for (int i = 0; i < data.numInstances(); i++) {
1201        Instance inst = data.instance(i);
1202        if (inst.isMissing(att)) {
1203
1204          // Skip missing values at this stage
1205          if (indexOfFirstMissingValue < 0) {
1206            indexOfFirstMissingValue = i;
1207          }
1208          continue;
1209        }
1210        dist[(int) inst.value(att)][(int) inst.classValue()] += inst.weight();
1211      }
1212    } else {
1213
1214      // For numeric attributes
1215      double[][] currDist = new double[2][data.numClasses()];
1216      dist = new double[2][data.numClasses()];
1217
1218      // Sort data
1219      data.sort(att);
1220
1221      // Move all instances into second subset
1222      for (int j = 0; j < data.numInstances(); j++) {
1223        Instance inst = data.instance(j);
1224        if (inst.isMissing(att)) {
1225
1226          // Can stop as soon as we hit a missing value
1227          indexOfFirstMissingValue = j;
1228          break;
1229        }
1230        currDist[1][(int) inst.classValue()] += inst.weight();
1231      }
1232
1233      // Value before splitting
1234      double priorVal = priorVal(currDist);
1235
1236      // Save initial distribution
1237      for (int j = 0; j < currDist.length; j++) {
1238        System.arraycopy(currDist[j], 0, dist[j], 0, dist[j].length);
1239      }
1240
1241      // Try all possible split points
1242      double currSplit = data.instance(0).value(att);
1243      double currVal, bestVal = -Double.MAX_VALUE;
1244      for (int i = 0; i < data.numInstances(); i++) {
1245        Instance inst = data.instance(i);
1246        if (inst.isMissing(att)) {
1247
1248          // Can stop as soon as we hit a missing value
1249          break;
1250        }
1251
1252        // Can we place a sensible split point here?
1253        if (inst.value(att) > currSplit) {
1254
1255          // Compute gain for split point
1256          currVal = gain(currDist, priorVal);
1257
1258          // Is the current split point the best point so far?
1259          if (currVal > bestVal) {
1260
1261            // Store value of current point
1262            bestVal = currVal;
1263
1264            // Save split point
1265            splitPoint = (inst.value(att) + currSplit) / 2.0;
1266
1267            // Save distribution
1268            for (int j = 0; j < currDist.length; j++) {
1269              System.arraycopy(currDist[j], 0, dist[j], 0, dist[j].length);
1270            }
1271          }
1272        }
1273        currSplit = inst.value(att);
1274
1275        // Shift over the weight
1276        currDist[0][(int) inst.classValue()] += inst.weight();
1277        currDist[1][(int) inst.classValue()] -= inst.weight();
1278      }
1279    }
1280
1281    // Compute weights for subsets
1282    props[att] = new double[dist.length];
1283    for (int k = 0; k < props[att].length; k++) {
1284      props[att][k] = Utils.sum(dist[k]);
1285    }
1286    if (Utils.eq(Utils.sum(props[att]), 0)) {
1287      for (int k = 0; k < props[att].length; k++) {
1288        props[att][k] = 1.0 / (double) props[att].length;
1289      }
1290    } else {
1291      Utils.normalize(props[att]);
1292    }
1293
1294    // Any instances with missing values ?
1295    if (indexOfFirstMissingValue > -1) {
1296
1297      // Distribute weights for instances with missing values
1298      for (int i = indexOfFirstMissingValue; i < data.numInstances(); i++) {
1299        Instance inst = data.instance(i);
1300        if (attribute.isNominal()) {
1301
1302          // Need to check if attribute value is missing
1303          if (inst.isMissing(att)) {
1304            for (int j = 0; j < dist.length; j++) {
1305              dist[j][(int) inst.classValue()] += props[att][j] * inst.weight();
1306            }
1307          }
1308        } else {
1309
1310          // Can be sure that value is missing, so no test required
1311          for (int j = 0; j < dist.length; j++) {
1312            dist[j][(int) inst.classValue()] += props[att][j] * inst.weight();
1313          }
1314        }
1315      }
1316    }
1317
1318    // Return distribution and split point
1319    dists[att] = dist;
1320    return splitPoint;
1321  }
1322
1323  /**
1324   * Computes value of splitting criterion before split.
1325   *
1326   * @param dist
1327   *            the distributions
1328   * @return the splitting criterion
1329   */
1330  protected double priorVal(double[][] dist) {
1331
1332    return ContingencyTables.entropyOverColumns(dist);
1333  }
1334
1335  /**
1336   * Computes value of splitting criterion after split.
1337   *
1338   * @param dist
1339   *            the distributions
1340   * @param priorVal
1341   *            the splitting criterion
1342   * @return the gain after the split
1343   */
1344  protected double gain(double[][] dist, double priorVal) {
1345
1346    return priorVal - ContingencyTables.entropyConditionedOnRows(dist);
1347  }
1348
1349  /**
1350   * Returns the revision string.
1351   *
1352   * @return the revision
1353   */
1354  public String getRevision() {
1355    return RevisionUtils.extract("$Revision: 5928 $");
1356  }
1357
1358  /**
1359   * Main method for this class.
1360   *
1361   * @param argv
1362   *            the commandline parameters
1363   */
1364  public static void main(String[] argv) {
1365    runClassifier(new RandomTree(), argv);
1366  }
1367
1368  /**
1369   * Returns graph describing the tree.
1370   *
1371   * @return the graph describing the tree
1372   * @throws Exception
1373   *             if graph can't be computed
1374   */
1375  public String graph() throws Exception {
1376
1377    if (m_Successors == null) {
1378      throw new Exception("RandomTree: No model built yet.");
1379    }
1380    StringBuffer resultBuff = new StringBuffer();
1381    toGraph(resultBuff, 0, null);
1382    String result = "digraph RandomTree {\n" + "edge [style=bold]\n"
1383    + resultBuff.toString() + "\n}\n";
1384    return result;
1385  }
1386
1387  /**
1388   * Returns the type of graph this classifier represents.
1389   *
1390   * @return Drawable.TREE
1391   */
1392  public int graphType() {
1393    return Drawable.TREE;
1394  }
1395
1396  /**
1397   * Outputs one node for graph.
1398   *
1399   * @param text
1400   *            the buffer to append the output to
1401   * @param num
1402   *            the current node id
1403   * @param parent
1404   *            the parent of the nodes
1405   * @return the next node id
1406   * @throws Exception
1407   *             if something goes wrong
1408   */
1409  protected int toGraph(StringBuffer text, int num, RandomTree parent)
1410  throws Exception {
1411
1412    num++;
1413    if (m_Attribute == -1) {
1414      text.append("N" + Integer.toHexString(RandomTree.this.hashCode())
1415          + " [label=\"" + num + leafString() + "\""
1416          + " shape=box]\n");
1417
1418    } else {
1419      text.append("N" + Integer.toHexString(RandomTree.this.hashCode())
1420          + " [label=\"" + num + ": "
1421          + m_Info.attribute(m_Attribute).name() + "\"]\n");
1422      for (int i = 0; i < m_Successors.length; i++) {
1423        text.append("N"
1424            + Integer.toHexString(RandomTree.this.hashCode())
1425            + "->" + "N"
1426            + Integer.toHexString(m_Successors[i].hashCode())
1427            + " [label=\"");
1428        if (m_Info.attribute(m_Attribute).isNumeric()) {
1429          if (i == 0) {
1430            text.append(" < "
1431                + Utils.doubleToString(m_SplitPoint, 2));
1432          } else {
1433            text.append(" >= "
1434                + Utils.doubleToString(m_SplitPoint, 2));
1435          }
1436        } else {
1437          text.append(" = " + m_Info.attribute(m_Attribute).value(i));
1438        }
1439        text.append("\"]\n");
1440        num = m_Successors[i].toGraph(text, num, this);
1441      }
1442    }
1443
1444    return num;
1445  }
1446}
1447
Note: See TracBrowser for help on using the repository browser.