source: branches/MetisMQI/src/main/java/weka/classifiers/trees/ADTree.java @ 37

Last change to this file as of revision 37 was revision 29, checked in by gnappo 14 years ago.

Taggata versione per la demo e aggiunto branch. (English: tagged the demo version and added a branch.)

File size: 50.1 KB
Line 
1/*
2 *    This program is free software; you can redistribute it and/or modify
3 *    it under the terms of the GNU General Public License as published by
4 *    the Free Software Foundation; either version 2 of the License, or
5 *    (at your option) any later version.
6 *
7 *    This program is distributed in the hope that it will be useful,
8 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
9 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
10 *    GNU General Public License for more details.
11 *
12 *    You should have received a copy of the GNU General Public License
13 *    along with this program; if not, write to the Free Software
14 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
15 */
16
17/*
18 *    ADTree.java
19 *    Copyright (C) 2001 University of Waikato, Hamilton, New Zealand
20 *
21 */
22
23package weka.classifiers.trees;
24
25import weka.classifiers.Classifier;
26import weka.classifiers.AbstractClassifier;
27import weka.classifiers.IterativeClassifier;
28import weka.classifiers.trees.adtree.PredictionNode;
29import weka.classifiers.trees.adtree.ReferenceInstances;
30import weka.classifiers.trees.adtree.Splitter;
31import weka.classifiers.trees.adtree.TwoWayNominalSplit;
32import weka.classifiers.trees.adtree.TwoWayNumericSplit;
33import weka.core.AdditionalMeasureProducer;
34import weka.core.Attribute;
35import weka.core.Capabilities;
36import weka.core.Drawable;
37import weka.core.FastVector;
38import weka.core.Instance;
39import weka.core.Instances;
40import weka.core.Option;
41import weka.core.OptionHandler;
42import weka.core.RevisionUtils;
43import weka.core.SelectedTag;
44import weka.core.SerializedObject;
45import weka.core.Tag;
46import weka.core.TechnicalInformation;
47import weka.core.TechnicalInformationHandler;
48import weka.core.Utils;
49import weka.core.WeightedInstancesHandler;
50import weka.core.Capabilities.Capability;
51import weka.core.TechnicalInformation.Field;
52import weka.core.TechnicalInformation.Type;
53
54import java.util.Enumeration;
55import java.util.Random;
56import java.util.Vector;
57
58/**
59 <!-- globalinfo-start -->
60 * Class for generating an alternating decision tree. The basic algorithm is based on:<br/>
61 * <br/>
62 * Freund, Y., Mason, L.: The alternating decision tree learning algorithm. In: Proceeding of the Sixteenth International Conference on Machine Learning, Bled, Slovenia, 124-133, 1999.<br/>
63 * <br/>
64 * This version currently only supports two-class problems. The number of boosting iterations needs to be manually tuned to suit the dataset and the desired complexity/accuracy tradeoff. Induction of the trees has been optimized, and heuristic search methods have been introduced to speed learning.
65 * <p/>
66 <!-- globalinfo-end -->
67 *
68 <!-- technical-bibtex-start -->
69 * BibTeX:
70 * <pre>
71 * &#64;inproceedings{Freund1999,
72 *    address = {Bled, Slovenia},
73 *    author = {Freund, Y. and Mason, L.},
74 *    booktitle = {Proceeding of the Sixteenth International Conference on Machine Learning},
75 *    pages = {124-133},
76 *    title = {The alternating decision tree learning algorithm},
77 *    year = {1999}
78 * }
79 * </pre>
80 * <p/>
81 <!-- technical-bibtex-end -->
82 *
83 <!-- options-start -->
84 * Valid options are: <p/>
85 *
86 * <pre> -B &lt;number of boosting iterations&gt;
87 *  Number of boosting iterations.
88 *  (Default = 10)</pre>
89 *
90 * <pre> -E &lt;-3|-2|-1|&gt;=0&gt;
91 *  Expand nodes: -3(all), -2(weight), -1(z_pure), &gt;=0 seed for random walk
92 *  (Default = -3)</pre>
93 *
94 * <pre> -D
95 *  Save the instance data with the model</pre>
96 *
97 <!-- options-end -->
98 *
99 * @author Richard Kirkby (rkirkby@cs.waikato.ac.nz)
100 * @author Bernhard Pfahringer (bernhard@cs.waikato.ac.nz)
101 * @version $Revision: 5928 $
102 */
103public class ADTree
104  extends AbstractClassifier
105  implements OptionHandler, Drawable, AdditionalMeasureProducer,
106             WeightedInstancesHandler, IterativeClassifier, 
107             TechnicalInformationHandler {
108
109  /** Version identifier used by Java serialization for this class. */
110  static final long serialVersionUID = -1532264837167690683L;
111 
112  /**
113   * Returns a string describing classifier
114   * @return a description suitable for
115   * displaying in the explorer/experimenter gui
116   */
117  public String globalInfo() {
118
119    return  "Class for generating an alternating decision tree. The basic "
120      + "algorithm is based on:\n\n"
121      + getTechnicalInformation().toString() + "\n\n"
122      + "This version currently only supports two-class problems. The number of boosting "
123      + "iterations needs to be manually tuned to suit the dataset and the desired "
124      + "complexity/accuracy tradeoff. Induction of the trees has been optimized, and heuristic "
125      + "search methods have been introduced to speed learning.";
126  }
127
128  /** search mode: expand all paths, following every branch (subject to the Z-pure cutoff) */
129  public static final int SEARCHPATH_ALL = 0;
130  /** search mode: follow only the branch whose instances carry the largest total weight */
131  public static final int SEARCHPATH_HEAVIEST = 1;
132  /** search mode: follow only the branch with the lowest Z-pure value */
133  public static final int SEARCHPATH_ZPURE = 2;
134  /** search mode: follow a single randomly chosen branch */
135  public static final int SEARCHPATH_RANDOM = 3;
136  /** Tags pairing each SEARCHPATH_* constant with a readable description */
137  public static final Tag [] TAGS_SEARCHPATH = {
138    new Tag(SEARCHPATH_ALL, "Expand all paths"),
139    new Tag(SEARCHPATH_HEAVIEST, "Expand the heaviest path"),
140    new Tag(SEARCHPATH_ZPURE, "Expand the best z-pure path"),
141    new Tag(SEARCHPATH_RANDOM, "Expand a random path")
142  };
143
144  /** The instances used to train the tree, created in initClassifier from the supplied data */
145  protected Instances m_trainInstances;
146
147  /** The root of the tree */
148  protected PredictionNode m_root = null;
149
150  /** The random number generator - reseeded from m_randomSeed in initClassifier and used for the random search heuristic */
151  protected Random m_random = null; 
152
153  /** The number of the last splitter added to the tree */
154  protected int m_lastAddedSplitNum = 0;
155
156  /** An array containing the indices of the numeric attributes in the data (class attribute excluded) */
157  protected int[] m_numericAttIndices;
158
159  /** An array containing the indices of the nominal attributes in the data (class attribute excluded) */
160  protected int[] m_nominalAttIndices;
161
162  /** The total weight of the training instances, recomputed at the start of each search - used to speed Z calculations */
163  protected double m_trainTotalWeight;
164
165  /** The training instances with positive class - referencing the training dataset */
166  protected ReferenceInstances m_posTrainInstances;
167
168  /** The training instances with negative class (class value 0) - referencing the training dataset */
169  protected ReferenceInstances m_negTrainInstances;
170
171  /** The best node to insert under, as found so far by the latest search */
172  protected PredictionNode m_search_bestInsertionNode;
173
174  /** The best splitter to insert, as found so far by the latest search (null if none found) */
175  protected Splitter m_search_bestSplitter;
176
177  /** The smallest Z value found so far by the latest search */
178  protected double m_search_smallestZ;
179
180  /** The positive instances that apply to the best path found so far */
181  protected Instances m_search_bestPathPosInstances;
182
183  /** The negative instances that apply to the best path found so far */
184  protected Instances m_search_bestPathNegInstances;
185
186  /** Statistics - the number of prediction nodes investigated during search */
187  protected int m_nodesExpanded = 0;
188
189  /** Statistics - the number of instances processed during search */
190  protected int m_examplesCounted = 0;
191
192  /** Option - the number of boosting iterations to perform */
193  protected int m_boostingIterations = 10;
194
195  /** Option - the search mode, one of the SEARCHPATH_* constants */
196  protected int m_searchPath = 0;
197
198  /** Option - the seed for the random number generator (used by the random search mode) */
199  protected int m_randomSeed = 0; 
200
201  /** Option - whether the tree should remember the instance data */
202  protected boolean m_saveInstanceData = false; 
203
204  /**
205   * Returns an instance of a TechnicalInformation object, containing
206   * detailed information about the technical background of this class,
207   * e.g., paper reference or book this class is based on.
208   *
209   * @return the technical information about this class
210   */
211  public TechnicalInformation getTechnicalInformation() {
212    TechnicalInformation        result;
213   
214    result = new TechnicalInformation(Type.INPROCEEDINGS);
215    result.setValue(Field.AUTHOR, "Freund, Y. and Mason, L.");
216    result.setValue(Field.YEAR, "1999");
217    result.setValue(Field.TITLE, "The alternating decision tree learning algorithm");
218    result.setValue(Field.BOOKTITLE, "Proceeding of the Sixteenth International Conference on Machine Learning");
219    result.setValue(Field.ADDRESS, "Bled, Slovenia");
220    result.setValue(Field.PAGES, "124-133");
221   
222    return result;
223  }
224
225  /**
226   * Sets up the tree ready to be trained, using two-class optimized method.
227   *
228   * @param instances the instances to train the tree with
229   * @exception Exception if training data is unsuitable
230   */
231  public void initClassifier(Instances instances) throws Exception {
232
233    // clear stats
234    m_nodesExpanded = 0;
235    m_examplesCounted = 0;
236    m_lastAddedSplitNum = 0;
237
238    // prepare the random generator
239    m_random = new Random(m_randomSeed);
240
241    // create training set
242    m_trainInstances = new Instances(instances);
243
244    // create positive/negative subsets
245    m_posTrainInstances = new ReferenceInstances(m_trainInstances,
246                                                 m_trainInstances.numInstances());
247    m_negTrainInstances = new ReferenceInstances(m_trainInstances,
248                                                 m_trainInstances.numInstances());
249    for (Enumeration e = m_trainInstances.enumerateInstances(); e.hasMoreElements(); ) {
250      Instance inst = (Instance) e.nextElement();
251      if ((int) inst.classValue() == 0)
252        m_negTrainInstances.addReference(inst); // belongs in negative class
253      else
254        m_posTrainInstances.addReference(inst); // belongs in positive class
255    }
256    m_posTrainInstances.compactify();
257    m_negTrainInstances.compactify();
258
259    // create the root prediction node
260    double rootPredictionValue = calcPredictionValue(m_posTrainInstances,
261                                                     m_negTrainInstances);
262    m_root = new PredictionNode(rootPredictionValue);
263
264    // pre-adjust weights
265    updateWeights(m_posTrainInstances, m_negTrainInstances, rootPredictionValue);
266   
267    // pre-calculate what we can
268    generateAttributeIndicesSingle();
269  }
270
271  /**
272   * Performs one iteration.
273   *
274   * @param iteration the index of the current iteration (0-based)
275   * @exception Exception if this iteration fails
276   */ 
277  public void next(int iteration) throws Exception {
278
279    boost();
280  }
281
282  /**
283   * Performs a single boosting iteration, using two-class optimized method.
284   * Will add a new splitter node and two prediction nodes to the tree
285   * (unless merging takes place).
286   *
287   * @exception Exception if try to boost without setting up tree first or there are no
288   * instances to train with
289   */
290  public void boost() throws Exception {
291
292    if (m_trainInstances == null || m_trainInstances.numInstances() == 0)
293      throw new Exception("Trying to boost with no training data");
294
295    // perform the search
296    searchForBestTestSingle();
297
298    if (m_search_bestSplitter == null) return; // handle empty instances
299
300    // create the new nodes for the tree, updating the weights
301    for (int i=0; i<2; i++) {
302      Instances posInstances =
303        m_search_bestSplitter.instancesDownBranch(i, m_search_bestPathPosInstances);
304      Instances negInstances =
305        m_search_bestSplitter.instancesDownBranch(i, m_search_bestPathNegInstances);
306      double predictionValue = calcPredictionValue(posInstances, negInstances);
307      PredictionNode newPredictor = new PredictionNode(predictionValue);
308      updateWeights(posInstances, negInstances, predictionValue);
309      m_search_bestSplitter.setChildForBranch(i, newPredictor);
310    }
311
312    // insert the new nodes
313    m_search_bestInsertionNode.addChild((Splitter) m_search_bestSplitter, this);
314
315    // free memory
316    m_search_bestPathPosInstances = null;
317    m_search_bestPathNegInstances = null;
318    m_search_bestSplitter = null;
319  }
320
321  /**
322   * Generates the m_nominalAttIndices and m_numericAttIndices arrays to index
323   * the respective attribute types in the training data.
324   *
325   */
326  private void generateAttributeIndicesSingle() {
327
328    // insert indices into vectors
329    FastVector nominalIndices = new FastVector();
330    FastVector numericIndices = new FastVector();
331
332    for (int i=0; i<m_trainInstances.numAttributes(); i++) {
333      if (i != m_trainInstances.classIndex()) {
334        if (m_trainInstances.attribute(i).isNumeric())
335          numericIndices.addElement(new Integer(i));
336        else
337          nominalIndices.addElement(new Integer(i));
338      }
339    }
340
341    // create nominal array
342    m_nominalAttIndices = new int[nominalIndices.size()];
343    for (int i=0; i<nominalIndices.size(); i++)
344      m_nominalAttIndices[i] = ((Integer)nominalIndices.elementAt(i)).intValue();
345   
346    // create numeric array
347    m_numericAttIndices = new int[numericIndices.size()];
348    for (int i=0; i<numericIndices.size(); i++)
349      m_numericAttIndices[i] = ((Integer)numericIndices.elementAt(i)).intValue();
350  }
351
352  /**
353   * Performs a search for the best test (splitter) to add to the tree, by aiming to
354   * minimize the Z value.
355   *
356   * @exception Exception if search fails
357   */
358  private void searchForBestTestSingle() throws Exception {
359
360    // keep track of total weight for efficient wRemainder calculations
361    m_trainTotalWeight = m_trainInstances.sumOfWeights();
362   
363    m_search_smallestZ = Double.POSITIVE_INFINITY;
364    searchForBestTestSingle(m_root, m_posTrainInstances, m_negTrainInstances);
365  }
366
367  /**
368   * Recursive function that carries out search for the best test (splitter) to add to
369   * this part of the tree, by aiming to minimize the Z value. Performs Z-pure cutoff to
370   * reduce search space.
371   *
372   * @param currentNode the root of the subtree to be searched, and the current node
373   * being considered as parent of a new split
374   * @param posInstances the positive-class instances that apply at this node
375   * @param negInstances the negative-class instances that apply at this node
376   * @exception Exception if search fails
377   */
378  private void searchForBestTestSingle(PredictionNode currentNode,
379                                       Instances posInstances, Instances negInstances)
380    throws Exception {
381
382    // don't investigate pure or empty nodes any further
383    if (posInstances.numInstances() == 0 || negInstances.numInstances() == 0) return;
384
385    // do z-pure cutoff
386    if (calcZpure(posInstances, negInstances) >= m_search_smallestZ) return;
387
388    // keep stats
389    m_nodesExpanded++;
390    m_examplesCounted += posInstances.numInstances() + negInstances.numInstances();
391
392    // evaluate static splitters (nominal)
393    for (int i=0; i<m_nominalAttIndices.length; i++)
394      evaluateNominalSplitSingle(m_nominalAttIndices[i], currentNode,
395                                 posInstances, negInstances);
396
397    // evaluate dynamic splitters (numeric)
398    if (m_numericAttIndices.length > 0) {
399
400      // merge the two sets of instances into one
401      Instances allInstances = new Instances(posInstances);
402      for (Enumeration e = negInstances.enumerateInstances(); e.hasMoreElements(); )
403        allInstances.add((Instance) e.nextElement());
404   
405      // use method of finding the optimal Z split-point
406      for (int i=0; i<m_numericAttIndices.length; i++)
407        evaluateNumericSplitSingle(m_numericAttIndices[i], currentNode,
408                                   posInstances, negInstances, allInstances);
409    }
410
411    if (currentNode.getChildren().size() == 0) return;
412
413    // keep searching
414    switch (m_searchPath) {
415    case SEARCHPATH_ALL:
416      goDownAllPathsSingle(currentNode, posInstances, negInstances);
417      break;
418    case SEARCHPATH_HEAVIEST: 
419      goDownHeaviestPathSingle(currentNode, posInstances, negInstances);
420      break;
421    case SEARCHPATH_ZPURE: 
422      goDownZpurePathSingle(currentNode, posInstances, negInstances);
423      break;
424    case SEARCHPATH_RANDOM: 
425      goDownRandomPathSingle(currentNode, posInstances, negInstances);
426      break;
427    }
428  }
429
430  /**
431   * Continues single (two-class optimized) search by investigating every node in the
432   * subtree under currentNode.
433   *
434   * @param currentNode the root of the subtree to be searched
435   * @param posInstances the positive-class instances that apply at this node
436   * @param negInstances the negative-class instances that apply at this node
437   * @exception Exception if search fails
438   */
439  private void goDownAllPathsSingle(PredictionNode currentNode,
440                                    Instances posInstances, Instances negInstances)
441    throws Exception {
442
443    for (Enumeration e = currentNode.children(); e.hasMoreElements(); ) {
444      Splitter split = (Splitter) e.nextElement();
445      for (int i=0; i<split.getNumOfBranches(); i++)
446        searchForBestTestSingle(split.getChildForBranch(i),
447                                split.instancesDownBranch(i, posInstances),
448                                split.instancesDownBranch(i, negInstances));
449    }
450  }
451
452  /**
453   * Continues single (two-class optimized) search by investigating only the path
454   * with the most heavily weighted instances.
455   *
456   * @param currentNode the root of the subtree to be searched
457   * @param posInstances the positive-class instances that apply at this node
458   * @param negInstances the negative-class instances that apply at this node
459   * @exception Exception if search fails
460   */
461  private void goDownHeaviestPathSingle(PredictionNode currentNode,
462                                        Instances posInstances, Instances negInstances)
463    throws Exception {
464
465    Splitter heaviestSplit = null;
466    int heaviestBranch = 0;
467    double largestWeight = 0.0;
468    for (Enumeration e = currentNode.children(); e.hasMoreElements(); ) {
469      Splitter split = (Splitter) e.nextElement();
470      for (int i=0; i<split.getNumOfBranches(); i++) {
471        double weight =
472          split.instancesDownBranch(i, posInstances).sumOfWeights() +
473          split.instancesDownBranch(i, negInstances).sumOfWeights();
474        if (weight > largestWeight) {
475          heaviestSplit = split;
476          heaviestBranch = i;
477          largestWeight = weight;
478        }
479      }
480    }
481    if (heaviestSplit != null)
482      searchForBestTestSingle(heaviestSplit.getChildForBranch(heaviestBranch),
483                              heaviestSplit.instancesDownBranch(heaviestBranch,
484                                                                posInstances),
485                              heaviestSplit.instancesDownBranch(heaviestBranch,
486                                                                negInstances));
487  }
488
489  /**
490   * Continues single (two-class optimized) search by investigating only the path
491   * with the best Z-pure value at each branch.
492   *
493   * @param currentNode the root of the subtree to be searched
494   * @param posInstances the positive-class instances that apply at this node
495   * @param negInstances the negative-class instances that apply at this node
496   * @exception Exception if search fails
497   */
498  private void goDownZpurePathSingle(PredictionNode currentNode,
499                                     Instances posInstances, Instances negInstances)
500    throws Exception {
501
502    double lowestZpure = m_search_smallestZ; // do z-pure cutoff
503    PredictionNode bestPath = null;
504    Instances bestPosSplit = null, bestNegSplit = null;
505
506    // search for branch with lowest Z-pure
507    for (Enumeration e = currentNode.children(); e.hasMoreElements(); ) {
508      Splitter split = (Splitter) e.nextElement();
509      for (int i=0; i<split.getNumOfBranches(); i++) {
510        Instances posSplit = split.instancesDownBranch(i, posInstances);
511        Instances negSplit = split.instancesDownBranch(i, negInstances);
512        double newZpure = calcZpure(posSplit, negSplit);
513        if (newZpure < lowestZpure) {
514          lowestZpure = newZpure;
515          bestPath = split.getChildForBranch(i);
516          bestPosSplit = posSplit;
517          bestNegSplit = negSplit;
518        }
519      }
520    }
521
522    if (bestPath != null)
523      searchForBestTestSingle(bestPath, bestPosSplit, bestNegSplit);
524  }
525
526  /**
527   * Continues single (two-class optimized) search by investigating a random path.
528   *
529   * @param currentNode the root of the subtree to be searched
530   * @param posInstances the positive-class instances that apply at this node
531   * @param negInstances the negative-class instances that apply at this node
532   * @exception Exception if search fails
533   */
534  private void goDownRandomPathSingle(PredictionNode currentNode,
535                                      Instances posInstances, Instances negInstances)
536    throws Exception {
537
538    FastVector children = currentNode.getChildren();
539    Splitter split = (Splitter) children.elementAt(getRandom(children.size()));
540    int branch = getRandom(split.getNumOfBranches());
541    searchForBestTestSingle(split.getChildForBranch(branch),
542                            split.instancesDownBranch(branch, posInstances),
543                            split.instancesDownBranch(branch, negInstances));
544  }
545
546  /**
547   * Investigates the option of introducing a nominal split under currentNode. If it
548   * finds a split that has a Z-value lower than has already been found it will
549   * update the search information to record this as the best option so far.
550   *
551   * @param attIndex index of a nominal attribute to create a split from
552   * @param currentNode the parent under which a split is to be considered
553   * @param posInstances the positive-class instances that apply at this node
554   * @param negInstances the negative-class instances that apply at this node
555   */
556  private void evaluateNominalSplitSingle(int attIndex, PredictionNode currentNode,
557                                          Instances posInstances, Instances negInstances)
558  {
559   
560    double[] indexAndZ = findLowestZNominalSplit(posInstances, negInstances, attIndex);
561
562    if (indexAndZ[1] < m_search_smallestZ) {
563      m_search_smallestZ = indexAndZ[1];
564      m_search_bestInsertionNode = currentNode;
565      m_search_bestSplitter = new TwoWayNominalSplit(attIndex, (int) indexAndZ[0]);
566      m_search_bestPathPosInstances = posInstances;
567      m_search_bestPathNegInstances = negInstances;
568    }
569  }
570
571  /**
572   * Investigates the option of introducing a two-way numeric split under currentNode.
573   * If it finds a split that has a Z-value lower than has already been found it will
574   * update the search information to record this as the best option so far.
575   *
576   * @param attIndex index of a numeric attribute to create a split from
577   * @param currentNode the parent under which a split is to be considered
578   * @param posInstances the positive-class instances that apply at this node
579   * @param negInstances the negative-class instances that apply at this node
580   * @param allInstances all of the instances the apply at this node (pos+neg combined)
581   * @throws Exception in case of an error
582   */
583  private void evaluateNumericSplitSingle(int attIndex, PredictionNode currentNode,
584                                          Instances posInstances, Instances negInstances,
585                                          Instances allInstances)
586    throws Exception {
587   
588    double[] splitAndZ = findLowestZNumericSplit(allInstances, attIndex);
589
590    if (splitAndZ[1] < m_search_smallestZ) {
591      m_search_smallestZ = splitAndZ[1];
592      m_search_bestInsertionNode = currentNode;
593      m_search_bestSplitter = new TwoWayNumericSplit(attIndex, splitAndZ[0]);
594      m_search_bestPathPosInstances = posInstances;
595      m_search_bestPathNegInstances = negInstances;
596    }
597  }
598
599  /**
600   * Calculates the prediction value used for a particular set of instances.
601   *
602   * @param posInstances the positive-class instances
603   * @param negInstances the negative-class instances
604   * @return the prediction value
605   */
606  private double calcPredictionValue(Instances posInstances, Instances negInstances) {
607   
608    return 0.5 * Math.log( (posInstances.sumOfWeights() + 1.0)
609                          / (negInstances.sumOfWeights() + 1.0) );
610  }
611
612  /**
613   * Calculates the Z-pure value for a particular set of instances.
614   *
615   * @param posInstances the positive-class instances
616   * @param negInstances the negative-class instances
617   * @return the Z-pure value
618   */
619  private double calcZpure(Instances posInstances, Instances negInstances) {
620   
621    double posWeight = posInstances.sumOfWeights();
622    double negWeight = negInstances.sumOfWeights();
623    return (2.0 * (Math.sqrt(posWeight+1.0) + Math.sqrt(negWeight+1.0))) + 
624      (m_trainTotalWeight - (posWeight + negWeight));
625  }
626
627  /**
628   * Updates the weights of instances that are influenced by a new prediction value.
629   *
630   * @param posInstances positive-class instances to which the prediction value applies
631   * @param negInstances negative-class instances to which the prediction value applies
632   * @param predictionValue the new prediction value
633   */
634  private void updateWeights(Instances posInstances, Instances negInstances,
635                             double predictionValue) {
636   
637    // do positives
638    double weightMultiplier = Math.pow(Math.E, -predictionValue);
639    for (Enumeration e = posInstances.enumerateInstances(); e.hasMoreElements(); ) {
640      Instance inst = (Instance) e.nextElement();
641      inst.setWeight(inst.weight() * weightMultiplier);
642    }
643    // do negatives
644    weightMultiplier = Math.pow(Math.E, predictionValue);
645    for (Enumeration e = negInstances.enumerateInstances(); e.hasMoreElements(); ) {
646      Instance inst = (Instance) e.nextElement();
647      inst.setWeight(inst.weight() * weightMultiplier);
648    }
649  }
650
651  /**
652   * Finds the nominal attribute value to split on that results in the lowest Z-value.
653   *
654   * @param posInstances the positive-class instances to split
655   * @param negInstances the negative-class instances to split
656   * @param attIndex the index of the nominal attribute to find a split for
657   * @return a double array, index[0] contains the value to split on, index[1] contains
658   * the Z-value of the split
659   */
660  private double[] findLowestZNominalSplit(Instances posInstances, Instances negInstances,
661                                           int attIndex)
662  {
663   
664    double lowestZ = Double.MAX_VALUE;
665    int bestIndex = 0;
666
667    // set up arrays
668    double[] posWeights = attributeValueWeights(posInstances, attIndex);
669    double[] negWeights = attributeValueWeights(negInstances, attIndex);
670    double posWeight = Utils.sum(posWeights);
671    double negWeight = Utils.sum(negWeights);
672
673    int maxIndex = posWeights.length;
674    if (maxIndex == 2) maxIndex = 1; // avoid repeating due to 2-way symmetry
675
676    for (int i = 0; i < maxIndex; i++) {
677      // calculate Z
678      double w1 = posWeights[i] + 1.0;
679      double w2 = negWeights[i] + 1.0;
680      double w3 = posWeight - w1 + 2.0;
681      double w4 = negWeight - w2 + 2.0;
682      double wRemainder = m_trainTotalWeight + 4.0 - (w1 + w2 + w3 + w4);
683      double newZ = (2.0 * (Math.sqrt(w1 * w2) + Math.sqrt(w3 * w4))) + wRemainder;
684
685      // record best option
686      if (newZ < lowestZ) { 
687        lowestZ = newZ;
688        bestIndex = i;
689      }
690    }
691
692    // return result
693    double[] indexAndZ = new double[2];
694    indexAndZ[0] = (double) bestIndex;
695    indexAndZ[1] = lowestZ;
696    return indexAndZ; 
697  }
698
699  /**
700   * Simultanously sum the weights of all attribute values for all instances.
701   *
702   * @param instances the instances to get the weights from
703   * @param attIndex index of the attribute to be evaluated
704   * @return a double array containing the weight of each attribute value
705   */   
706  private double[] attributeValueWeights(Instances instances, int attIndex)
707  {
708   
709    double[] weights = new double[instances.attribute(attIndex).numValues()];
710    for(int i = 0; i < weights.length; i++) weights[i] = 0.0;
711
712    for (Enumeration e = instances.enumerateInstances(); e.hasMoreElements(); ) {
713      Instance inst = (Instance) e.nextElement();
714      if (!inst.isMissing(attIndex)) weights[(int)inst.value(attIndex)] += inst.weight();
715    }
716    return weights;
717  }
718
719  /**
720   * Finds the numeric split-point that results in the lowest Z-value.
721   *
722   * @param instances the instances to find a split for
723   * @param attIndex the index of the numeric attribute to find a split for
724   * @return a double array, index[0] contains the split-point, index[1] contains the
725   * Z-value of the split
726   * @throws Exception in case of an error
727   */
728  private double[] findLowestZNumericSplit(Instances instances, int attIndex)
729    throws Exception {
730   
731    double splitPoint = 0.0;
732    double bestVal = Double.MAX_VALUE, currVal, currCutPoint;
733    int numMissing = 0;
734    double[][] distribution = new double[3][instances.numClasses()];   
735
736    // compute counts for all the values
737    for (int i = 0; i < instances.numInstances(); i++) {
738      Instance inst = instances.instance(i);
739      if (!inst.isMissing(attIndex)) {
740        distribution[1][(int)inst.classValue()] += inst.weight();
741      } else {
742        distribution[2][(int)inst.classValue()] += inst.weight();
743        numMissing++;
744      }
745    }
746
747    // sort instances
748    instances.sort(attIndex);
749   
750    // make split counts for each possible split and evaluate
751    for (int i = 0; i < instances.numInstances() - (numMissing + 1); i++) {
752      Instance inst = instances.instance(i);
753      Instance instPlusOne = instances.instance(i + 1);
754      distribution[0][(int)inst.classValue()] += inst.weight();
755      distribution[1][(int)inst.classValue()] -= inst.weight();
756      if (Utils.sm(inst.value(attIndex), instPlusOne.value(attIndex))) {
757        currCutPoint = (inst.value(attIndex) + instPlusOne.value(attIndex)) / 2.0;
758        currVal = conditionedZOnRows(distribution);
759        if (currVal < bestVal) {
760          splitPoint = currCutPoint;
761          bestVal = currVal;
762        }
763      }
764    }
765       
766    double[] splitAndZ = new double[2];
767    splitAndZ[0] = splitPoint;
768    splitAndZ[1] = bestVal;
769    return splitAndZ;
770  }
771
772  /**
773   * Calculates the Z-value from the rows of a weight distribution array.
774   *
775   * @param distribution the weight distribution
776   * @return the Z-value
777   */
778  private double conditionedZOnRows(double [][] distribution) {
779   
780    double w1 = distribution[0][0] + 1.0;
781    double w2 = distribution[0][1] + 1.0;
782    double w3 = distribution[1][0] + 1.0; 
783    double w4 = distribution[1][1] + 1.0;
784    double wRemainder = m_trainTotalWeight + 4.0 - (w1 + w2 + w3 + w4);
785    return (2.0 * (Math.sqrt(w1 * w2) + Math.sqrt(w3 * w4))) + wRemainder;
786  }
787
788  /**
789   * Returns the class probability distribution for an instance.
790   *
791   * @param instance the instance to be classified
792   * @return the distribution the tree generates for the instance
793   */
794  public double[] distributionForInstance(Instance instance) {
795   
796    double predVal = predictionValueForInstance(instance, m_root, 0.0);
797   
798    double[] distribution = new double[2];
799    distribution[0] = 1.0 / (1.0 + Math.pow(Math.E, predVal));
800    distribution[1] = 1.0 / (1.0 + Math.pow(Math.E, -predVal));
801
802    return distribution;
803  }
804
805  /**
806   * Returns the class prediction value (vote) for an instance.
807   *
808   * @param inst the instance
809   * @param currentNode the root of the tree to get the values from
810   * @param currentValue the current value before adding the value contained in the
811   * subtree
812   * @return the class prediction value (vote)
813   */
814  protected double predictionValueForInstance(Instance inst, PredictionNode currentNode,
815                                            double currentValue) {
816   
817    currentValue += currentNode.getValue();
818    for (Enumeration e = currentNode.children(); e.hasMoreElements(); ) {
819      Splitter split = (Splitter) e.nextElement();
820      int branch = split.branchInstanceGoesDown(inst);
821      if (branch >= 0)
822        currentValue = predictionValueForInstance(inst, split.getChildForBranch(branch),
823                                                  currentValue);
824    }
825    return currentValue;
826  }
827
828  /**
829   * Returns a description of the classifier.
830   *
831   * @return a string containing a description of the classifier
832   */
833  public String toString() {
834   
835    if (m_root == null)
836      return ("ADTree not built yet");
837    else {
838      return ("Alternating decision tree:\n\n" + toString(m_root, 1) +
839              "\nLegend: " + legend() +
840              "\nTree size (total number of nodes): " + numOfAllNodes(m_root) + 
841              "\nLeaves (number of predictor nodes): " + numOfPredictionNodes(m_root)
842              );
843    }
844  }
845
846  /**
847   * Traverses the tree, forming a string that describes it.
848   *
849   * @param currentNode the current node under investigation
850   * @param level the current level in the tree
851   * @return the string describing the subtree
852   */     
853  protected String toString(PredictionNode currentNode, int level) {
854   
855    StringBuffer text = new StringBuffer();
856   
857    text.append(": " + Utils.doubleToString(currentNode.getValue(),3));
858   
859    for (Enumeration e = currentNode.children(); e.hasMoreElements(); ) {
860      Splitter split = (Splitter) e.nextElement();
861           
862      for (int j=0; j<split.getNumOfBranches(); j++) {
863        PredictionNode child = split.getChildForBranch(j);
864        if (child != null) {
865          text.append("\n");
866          for (int k = 0; k < level; k++) {
867            text.append("|  ");
868          }
869          text.append("(" + split.orderAdded + ")");
870          text.append(split.attributeString(m_trainInstances) + " "
871                      + split.comparisonString(j, m_trainInstances));
872          text.append(toString(child, level + 1));
873        }
874      }
875    }
876    return text.toString();
877  }
878
879  /**
880   *  Returns the type of graph this classifier
881   *  represents.
882   *  @return Drawable.TREE
883   */   
884  public int graphType() {
885      return Drawable.TREE;
886  }
887
888  /**
889   * Returns graph describing the tree.
890   *
891   * @return the graph of the tree in dotty format
892   * @exception Exception if something goes wrong
893   */
894  public String graph() throws Exception {
895   
896    StringBuffer text = new StringBuffer();
897    text.append("digraph ADTree {\n");
898    graphTraverse(m_root, text, 0, 0, m_trainInstances);
899    return text.toString() +"}\n";
900  }
901
902  /**
903   * Traverses the tree, graphing each node.
904   *
905   * @param currentNode the currentNode under investigation
906   * @param text the string built so far
907   * @param splitOrder the order the parent splitter was added to the tree
908   * @param predOrder the order this predictor was added to the split
909   * @param instances the data to work on
910   * @exception Exception if something goes wrong
911   */       
912  protected void graphTraverse(PredictionNode currentNode, StringBuffer text,
913                               int splitOrder, int predOrder, Instances instances)
914    throws Exception {
915   
916    text.append("S" + splitOrder + "P" + predOrder + " [label=\"");
917    text.append(Utils.doubleToString(currentNode.getValue(),3));
918    if (splitOrder == 0) // show legend in root
919      text.append(" (" + legend() + ")");
920    text.append("\" shape=box style=filled");
921    if (instances.numInstances() > 0) text.append(" data=\n" + instances + "\n,\n");
922    text.append("]\n");
923    for (Enumeration e = currentNode.children(); e.hasMoreElements(); ) {
924      Splitter split = (Splitter) e.nextElement();
925      text.append("S" + splitOrder + "P" + predOrder + "->" + "S" + split.orderAdded +
926                  " [style=dotted]\n");
927      text.append("S" + split.orderAdded + " [label=\"" + split.orderAdded + ": " +
928                  split.attributeString(m_trainInstances) + "\"]\n");
929
930      for (int i=0; i<split.getNumOfBranches(); i++) {
931        PredictionNode child = split.getChildForBranch(i);
932        if (child != null) {
933          text.append("S" + split.orderAdded + "->" + "S" + split.orderAdded + "P" + i +
934                      " [label=\"" + split.comparisonString(i, m_trainInstances) + "\"]\n");
935          graphTraverse(child, text, split.orderAdded, i,
936                        split.instancesDownBranch(i, instances));
937        }
938      }
939    } 
940  }
941
942  /**
943   * Returns the legend of the tree, describing how results are to be interpreted.
944   *
945   * @return a string containing the legend of the classifier
946   */
947  public String legend() {
948   
949    Attribute classAttribute = null;
950    if (m_trainInstances == null) return "";
951    try {classAttribute = m_trainInstances.classAttribute();} catch (Exception x){};
952    return ("-ve = " + classAttribute.value(0) +
953            ", +ve = " + classAttribute.value(1));
954  }
955
956  /**
957   * @return tip text for this property suitable for
958   * displaying in the explorer/experimenter gui
959   */
960  public String numOfBoostingIterationsTipText() {
961
962    return "Sets the number of boosting iterations to perform. You will need to manually "
963      + "tune this parameter to suit the dataset and the desired complexity/accuracy "
964      + "tradeoff. More boosting iterations will result in larger (potentially more "
965      + " accurate) trees, but will make learning slower. Each iteration will add 3 nodes "
966      + "(1 split + 2 prediction) to the tree unless merging occurs.";
967  }
968
969  /**
970   * Gets the number of boosting iterations.
971   *
972   * @return the number of boosting iterations
973   */
974  public int getNumOfBoostingIterations() {
975   
976    return m_boostingIterations;
977  }
978
979  /**
980   * Sets the number of boosting iterations.
981   *
982   * @param b the number of boosting iterations to use
983   */
984  public void setNumOfBoostingIterations(int b) {
985   
986    m_boostingIterations = b; 
987  }
988
989  /**
990   * @return tip text for this property suitable for
991   * displaying in the explorer/experimenter gui
992   */
993  public String searchPathTipText() {
994
995    return "Sets the type of search to perform when building the tree. The default option"
996      + " (Expand all paths) will do an exhaustive search. The other search methods are"
997      + " heuristic, so they are not guaranteed to find an optimal solution but they are"
998      + " much faster. Expand the heaviest path: searches the path with the most heavily"
999      + " weighted instances. Expand the best z-pure path: searches the path determined"
1000      + " by the best z-pure estimate. Expand a random path: the fastest method, simply"
1001      + " searches down a single random path on each iteration.";
1002  }
1003
1004  /**
1005   * Gets the method of searching the tree for a new insertion. Will be one of
1006   * SEARCHPATH_ALL, SEARCHPATH_HEAVIEST, SEARCHPATH_ZPURE, SEARCHPATH_RANDOM.
1007   *
1008   * @return the tree searching mode
1009   */
1010  public SelectedTag getSearchPath() {
1011
1012    return new SelectedTag(m_searchPath, TAGS_SEARCHPATH);
1013  }
1014 
1015  /**
1016   * Sets the method of searching the tree for a new insertion. Will be one of
1017   * SEARCHPATH_ALL, SEARCHPATH_HEAVIEST, SEARCHPATH_ZPURE, SEARCHPATH_RANDOM.
1018   *
1019   * @param newMethod the new tree searching mode
1020   */
1021  public void setSearchPath(SelectedTag newMethod) {
1022   
1023    if (newMethod.getTags() == TAGS_SEARCHPATH) {
1024      m_searchPath = newMethod.getSelectedTag().getID();
1025    }
1026  }
1027
1028  /**
1029   * @return tip text for this property suitable for
1030   * displaying in the explorer/experimenter gui
1031   */
1032  public String randomSeedTipText() {
1033
1034    return "Sets the random seed to use for a random search.";
1035  }
1036
1037  /**
1038   * Gets random seed for a random walk.
1039   *
1040   * @return the random seed
1041   */
1042  public int getRandomSeed() {
1043   
1044    return m_randomSeed;
1045  }
1046
1047  /**
1048   * Sets random seed for a random walk.
1049   *
1050   * @param seed the random seed
1051   */
1052  public void setRandomSeed(int seed) {
1053   
1054    // the actual random object is created when the tree is initialized
1055    m_randomSeed = seed; 
1056  } 
1057
1058  /**
1059   * @return tip text for this property suitable for
1060   * displaying in the explorer/experimenter gui
1061   */
1062  public String saveInstanceDataTipText() {
1063
1064    return "Sets whether the tree is to save instance data - the model will take up more"
1065      + " memory if it does. If enabled you will be able to visualize the instances at"
1066      + " the prediction nodes when visualizing the tree.";
1067  }
1068
1069  /**
1070   * Gets whether the tree is to save instance data.
1071   *
1072   * @return the random seed
1073   */
1074  public boolean getSaveInstanceData() {
1075   
1076    return m_saveInstanceData;
1077  }
1078
1079  /**
1080   * Sets whether the tree is to save instance data.
1081   *
1082   * @param v true then the tree saves instance data
1083   */
1084  public void setSaveInstanceData(boolean v) {
1085   
1086    m_saveInstanceData = v;
1087  }
1088
1089  /**
1090   * Returns an enumeration describing the available options..
1091   *
1092   * @return an enumeration of all the available options.
1093   */
1094  public Enumeration listOptions() {
1095   
1096    Vector newVector = new Vector(3);
1097    newVector.addElement(new Option(
1098                                    "\tNumber of boosting iterations.\n"
1099                                    +"\t(Default = 10)",
1100                                    "B", 1,"-B <number of boosting iterations>"));
1101    newVector.addElement(new Option(
1102                                    "\tExpand nodes: -3(all), -2(weight), -1(z_pure), "
1103                                    +">=0 seed for random walk\n"
1104                                    +"\t(Default = -3)",
1105                                    "E", 1,"-E <-3|-2|-1|>=0>"));
1106    newVector.addElement(new Option(
1107                                    "\tSave the instance data with the model",
1108                                    "D", 0,"-D"));
1109    return newVector.elements();
1110  }
1111
1112  /**
1113   * Parses a given list of options. Valid options are:<p>
1114   *
1115   * -B num <br>
1116   * Set the number of boosting iterations
1117   * (default 10) <p>
1118   *
1119   * -E num <br>
1120   * Set the nodes to expand: -3(all), -2(weight), -1(z_pure), >=0 seed for random walk
1121   * (default -3) <p>
1122   *
1123   * -D <br>
1124   * Save the instance data with the model <p>
1125   *
1126   * @param options the list of options as an array of strings
1127   * @exception Exception if an option is not supported
1128   */
1129  public void setOptions(String[] options) throws Exception {
1130   
1131    String bString = Utils.getOption('B', options);
1132    if (bString.length() != 0) setNumOfBoostingIterations(Integer.parseInt(bString));
1133
1134    String eString = Utils.getOption('E', options);
1135    if (eString.length() != 0) {
1136      int value = Integer.parseInt(eString);
1137      if (value >= 0) {
1138        setSearchPath(new SelectedTag(SEARCHPATH_RANDOM, TAGS_SEARCHPATH));
1139        setRandomSeed(value);
1140      } else setSearchPath(new SelectedTag(value + 3, TAGS_SEARCHPATH));
1141    }
1142
1143    setSaveInstanceData(Utils.getFlag('D', options));
1144
1145    Utils.checkForRemainingOptions(options);
1146  }
1147
1148  /**
1149   * Gets the current settings of ADTree.
1150   *
1151   * @return an array of strings suitable for passing to setOptions()
1152   */
1153  public String[] getOptions() {
1154   
1155    String[] options = new String[6];
1156    int current = 0;
1157    options[current++] = "-B"; options[current++] = "" + getNumOfBoostingIterations();
1158    options[current++] = "-E"; options[current++] = "" +
1159                                 (m_searchPath == SEARCHPATH_RANDOM ?
1160                                  m_randomSeed : m_searchPath - 3);
1161    if (getSaveInstanceData()) options[current++] = "-D";
1162    while (current < options.length) options[current++] = "";
1163    return options;
1164  }
1165
1166  /**
1167   * Calls measure function for tree size - the total number of nodes.
1168   *
1169   * @return the tree size
1170   */
1171  public double measureTreeSize() {
1172   
1173    return numOfAllNodes(m_root);
1174  }
1175
1176  /**
1177   * Calls measure function for leaf size - the number of prediction nodes.
1178   *
1179   * @return the leaf size
1180   */
1181  public double measureNumLeaves() {
1182   
1183    return numOfPredictionNodes(m_root);
1184  }
1185
1186  /**
1187   * Calls measure function for prediction leaf size - the number of
1188   * prediction nodes without children.
1189   *
1190   * @return the leaf size
1191   */
1192  public double measureNumPredictionLeaves() {
1193   
1194    return numOfPredictionLeafNodes(m_root);
1195  }
1196
1197  /**
1198   * Returns the number of nodes expanded.
1199   *
1200   * @return the number of nodes expanded during search
1201   */
1202  public double measureNodesExpanded() {
1203   
1204    return m_nodesExpanded;
1205  }
1206
1207  /**
1208   * Returns the number of examples "counted".
1209   *
1210   * @return the number of nodes processed during search
1211   */
1212
1213  public double measureExamplesProcessed() {
1214   
1215    return m_examplesCounted;
1216  }
1217
1218  /**
1219   * Returns an enumeration of the additional measure names.
1220   *
1221   * @return an enumeration of the measure names
1222   */
1223  public Enumeration enumerateMeasures() {
1224   
1225    Vector newVector = new Vector(4);
1226    newVector.addElement("measureTreeSize");
1227    newVector.addElement("measureNumLeaves");
1228    newVector.addElement("measureNumPredictionLeaves");
1229    newVector.addElement("measureNodesExpanded");
1230    newVector.addElement("measureExamplesProcessed");
1231    return newVector.elements();
1232  }
1233 
1234  /**
1235   * Returns the value of the named measure.
1236   *
1237   * @param additionalMeasureName the name of the measure to query for its value
1238   * @return the value of the named measure
1239   * @exception IllegalArgumentException if the named measure is not supported
1240   */
1241  public double getMeasure(String additionalMeasureName) {
1242   
1243    if (additionalMeasureName.equalsIgnoreCase("measureTreeSize")) {
1244      return measureTreeSize();
1245    }
1246    else if (additionalMeasureName.equalsIgnoreCase("measureNumLeaves")) {
1247      return measureNumLeaves();
1248    }
1249    else if (additionalMeasureName.equalsIgnoreCase("measureNumPredictionLeaves")) {
1250      return measureNumPredictionLeaves();
1251    }
1252    else if (additionalMeasureName.equalsIgnoreCase("measureNodesExpanded")) {
1253      return measureNodesExpanded();
1254    }
1255    else if (additionalMeasureName.equalsIgnoreCase("measureExamplesProcessed")) {
1256      return measureExamplesProcessed();
1257    }
1258    else {throw new IllegalArgumentException(additionalMeasureName
1259                              + " not supported (ADTree)");
1260    }
1261  }
1262
1263  /**
1264   * Returns the total number of nodes in a tree.
1265   *
1266   * @param root the root of the tree being measured
1267   * @return tree size in number of splitter + prediction nodes
1268   */       
1269  protected int numOfAllNodes(PredictionNode root) {
1270   
1271    int numSoFar = 0;
1272    if (root != null) {
1273      numSoFar++;
1274      for (Enumeration e = root.children(); e.hasMoreElements(); ) {
1275        numSoFar++;
1276        Splitter split = (Splitter) e.nextElement();
1277        for (int i=0; i<split.getNumOfBranches(); i++)
1278            numSoFar += numOfAllNodes(split.getChildForBranch(i));
1279      }
1280    }
1281    return numSoFar;
1282  }
1283
1284  /**
1285   * Returns the number of prediction nodes in a tree.
1286   *
1287   * @param root the root of the tree being measured
1288   * @return tree size in number of prediction nodes
1289   */       
1290  protected int numOfPredictionNodes(PredictionNode root) {
1291   
1292    int numSoFar = 0;
1293    if (root != null) {
1294      numSoFar++;
1295      for (Enumeration e = root.children(); e.hasMoreElements(); ) {
1296        Splitter split = (Splitter) e.nextElement();
1297        for (int i=0; i<split.getNumOfBranches(); i++)
1298            numSoFar += numOfPredictionNodes(split.getChildForBranch(i));
1299      }
1300    }
1301    return numSoFar;
1302  }
1303
1304  /**
1305   * Returns the number of leaf nodes in a tree - prediction nodes without
1306   * children.
1307   *
1308   * @param root the root of the tree being measured
1309   * @return tree leaf size in number of prediction nodes
1310   */       
1311  protected int numOfPredictionLeafNodes(PredictionNode root) {
1312   
1313    int numSoFar = 0;
1314    if (root.getChildren().size() > 0) {
1315      for (Enumeration e = root.children(); e.hasMoreElements(); ) {
1316        Splitter split = (Splitter) e.nextElement();
1317        for (int i=0; i<split.getNumOfBranches(); i++)
1318            numSoFar += numOfPredictionLeafNodes(split.getChildForBranch(i));
1319      }
1320    } else numSoFar = 1;
1321    return numSoFar;
1322  }
1323
1324  /**
1325   * Gets the next random value.
1326   *
1327   * @param max the maximum value (+1) to be returned
1328   * @return the next random value (between 0 and max-1)
1329   */
1330  protected int getRandom(int max) {
1331   
1332    return m_random.nextInt(max);
1333  }
1334
1335  /**
1336   * Returns the next number in the order that splitter nodes have been added to
1337   * the tree, and records that a new splitter has been added.
1338   *
1339   * @return the next number in the order
1340   */
1341  public int nextSplitAddedOrder() {
1342
1343    return ++m_lastAddedSplitNum;
1344  }
1345
1346  /**
1347   * Returns default capabilities of the classifier.
1348   *
1349   * @return      the capabilities of this classifier
1350   */
1351  public Capabilities getCapabilities() {
1352    Capabilities result = super.getCapabilities();
1353    result.disableAll();
1354
1355    // attributes
1356    result.enable(Capability.NOMINAL_ATTRIBUTES);
1357    result.enable(Capability.NUMERIC_ATTRIBUTES);
1358    result.enable(Capability.DATE_ATTRIBUTES);
1359    result.enable(Capability.MISSING_VALUES);
1360
1361    // class
1362    result.enable(Capability.BINARY_CLASS);
1363    result.enable(Capability.MISSING_CLASS_VALUES);
1364   
1365    return result;
1366  }
1367
1368  /**
1369   * Builds a classifier for a set of instances.
1370   *
1371   * @param instances the instances to train the classifier with
1372   * @exception Exception if something goes wrong
1373   */
1374  public void buildClassifier(Instances instances) throws Exception {
1375
1376    // can classifier handle the data?
1377    getCapabilities().testWithFail(instances);
1378
1379    // remove instances with missing class
1380    instances = new Instances(instances);
1381    instances.deleteWithMissingClass();
1382
1383    // set up the tree
1384    initClassifier(instances);
1385
1386    // build the tree
1387    for (int T = 0; T < m_boostingIterations; T++) boost();
1388
1389    // clean up if desired
1390    if (!m_saveInstanceData) done();
1391  }
1392
1393  /**
1394   * Frees memory that is no longer needed for a final model - will no longer be able
1395   * to increment the classifier after calling this.
1396   *
1397   */
1398  public void done() {
1399
1400    m_trainInstances = new Instances(m_trainInstances, 0);
1401    m_random = null; 
1402    m_numericAttIndices = null;
1403    m_nominalAttIndices = null;
1404    m_posTrainInstances = null;
1405    m_negTrainInstances = null;
1406  }
1407
1408  /**
1409   * Creates a clone that is identical to the current tree, but is independent.
1410   * Deep copies the essential elements such as the tree nodes, and the instances
1411   * (because the weights change.) Reference copies several elements such as the
1412   * potential splitter sets, assuming that such elements should never differ between
1413   * clones.
1414   *
1415   * @return the clone
1416   */
1417  public Object clone() {
1418   
1419    ADTree clone = new ADTree();
1420
1421    if (m_root != null) { // check for initialization first
1422      clone.m_root = (PredictionNode) m_root.clone(); // deep copy the tree
1423
1424      clone.m_trainInstances = new Instances(m_trainInstances); // copy training instances
1425     
1426      // deep copy the random object
1427      if (m_random != null) { 
1428        SerializedObject randomSerial = null;
1429        try {
1430          randomSerial = new SerializedObject(m_random);
1431        } catch (Exception ignored) {} // we know that Random is serializable
1432        clone.m_random = (Random) randomSerial.getObject();
1433      }
1434
1435      clone.m_lastAddedSplitNum = m_lastAddedSplitNum;
1436      clone.m_numericAttIndices = m_numericAttIndices;
1437      clone.m_nominalAttIndices = m_nominalAttIndices;
1438      clone.m_trainTotalWeight = m_trainTotalWeight;
1439
1440      // reconstruct pos/negTrainInstances references
1441      if (m_posTrainInstances != null) { 
1442        clone.m_posTrainInstances =
1443          new ReferenceInstances(m_trainInstances, m_posTrainInstances.numInstances());
1444        clone.m_negTrainInstances =
1445          new ReferenceInstances(m_trainInstances, m_negTrainInstances.numInstances());
1446        for (Enumeration e = clone.m_trainInstances.enumerateInstances();
1447             e.hasMoreElements(); ) {
1448          Instance inst = (Instance) e.nextElement();
1449          try { // ignore classValue() exception
1450            if ((int) inst.classValue() == 0)
1451              clone.m_negTrainInstances.addReference(inst); // belongs in negative class
1452            else
1453              clone.m_posTrainInstances.addReference(inst); // belongs in positive class
1454          } catch (Exception ignored) {} 
1455        }
1456      }
1457    }
1458    clone.m_nodesExpanded = m_nodesExpanded;
1459    clone.m_examplesCounted = m_examplesCounted;
1460    clone.m_boostingIterations = m_boostingIterations;
1461    clone.m_searchPath = m_searchPath;
1462    clone.m_randomSeed = m_randomSeed;
1463
1464    return clone;
1465  }
1466
1467  /**
1468   * Merges two trees together. Modifies the tree being acted on, leaving tree passed
1469   * as a parameter untouched (cloned). Does not check to see whether training instances
1470   * are compatible - strange things could occur if they are not.
1471   *
1472   * @param mergeWith the tree to merge with
1473   * @exception Exception if merge could not be performed
1474   */
1475  public void merge(ADTree mergeWith) throws Exception {
1476   
1477    if (m_root == null || mergeWith.m_root == null)
1478      throw new Exception("Trying to merge an uninitialized tree");
1479    m_root.merge(mergeWith.m_root, this);
1480  }
1481 
1482  /**
1483   * Returns the revision string.
1484   *
1485   * @return            the revision
1486   */
1487  public String getRevision() {
1488    return RevisionUtils.extract("$Revision: 5928 $");
1489  }
1490
1491  /**
1492   * Main method for testing this class.
1493   *
1494   * @param argv the options
1495   */
1496  public static void main(String [] argv) {
1497    runClassifier(new ADTree(), argv);
1498  }
1499}
Note: See TracBrowser for help on using the repository browser.