source: src/main/java/weka/classifiers/trees/LADTree.java @ 26

Last change on this file since 26 was 4, checked in by gnappo, 14 years ago

Import di weka.

File size: 48.2 KB
Line 
1/*
2 *    This program is free software; you can redistribute it and/or modify
3 *    it under the terms of the GNU General Public License as published by
4 *    the Free Software Foundation; either version 2 of the License, or
5 *    (at your option) any later version.
6 *
7 *    This program is distributed in the hope that it will be useful,
8 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
9 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
10 *    GNU General Public License for more details.
11 *
12 *    You should have received a copy of the GNU General Public License
13 *    along with this program; if not, write to the Free Software
14 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
15 */
16
17/*
18 *    LADTree.java
19 *    Copyright (C) 2001 University of Waikato, Hamilton, New Zealand
20 *
21 */
22
23package weka.classifiers.trees;
24
25import weka.classifiers.*;
26import weka.core.Capabilities;
27import weka.core.Capabilities.Capability;
28import weka.core.*;
29import weka.classifiers.trees.adtree.ReferenceInstances;
30import java.util.*;
31import java.io.*;
32import weka.core.TechnicalInformation;
33import weka.core.TechnicalInformationHandler;
34import weka.core.TechnicalInformation.Field;
35import weka.core.TechnicalInformation.Type;
36
37/**
38 <!-- globalinfo-start -->
39 * Class for generating a multi-class alternating decision tree using the LogitBoost strategy. For more info, see<br/>
40 * <br/>
41 * Geoffrey Holmes, Bernhard Pfahringer, Richard Kirkby, Eibe Frank, Mark Hall: Multiclass alternating decision trees. In: ECML, 161-172, 2001.
42 * <p/>
43 <!-- globalinfo-end -->
44 *
45 <!-- technical-bibtex-start -->
46 * BibTeX:
47 * <pre>
48 * &#64;inproceedings{Holmes2001,
49 *    author = {Geoffrey Holmes and Bernhard Pfahringer and Richard Kirkby and Eibe Frank and Mark Hall},
50 *    booktitle = {ECML},
51 *    pages = {161-172},
52 *    publisher = {Springer},
53 *    title = {Multiclass alternating decision trees},
54 *    year = {2001}
55 * }
56 * </pre>
57 * <p/>
58 <!-- technical-bibtex-end -->
59 *
60 <!-- options-start -->
61 * Valid options are: <p/>
62 *
63 * <pre> -B &lt;number of boosting iterations&gt;
64 *  Number of boosting iterations.
65 *  (Default = 10)</pre>
66 *
67 * <pre> -D
68 *  If set, classifier is run in debug mode and
69 *  may output additional info to the console</pre>
70 *
71 <!-- options-end -->
72 *
73 * @author Richard Kirkby
74 * @version $Revision: 6035 $
75*/
76
77public class LADTree
78  extends AbstractClassifier implements Drawable,
79                                AdditionalMeasureProducer,
80                                TechnicalInformationHandler {
81
82  /**
83   * For serialization
84   */
85  private static final long serialVersionUID = -4940716114518300302L;
86
87  // Constant from LogitBoost
88  protected double Z_MAX = 4;
89
90  // Number of classes
91  protected int m_numOfClasses;
92
93  // Instances as reference instances
94  protected ReferenceInstances m_trainInstances;
95
96  // Root of the tree
97  protected PredictionNode m_root = null; 
98
99  // To keep track of the order in which splits are added
100  protected int m_lastAddedSplitNum = 0;
101
102  // Indices for numeric attributes
103  protected int[] m_numericAttIndices;
104
105  // Variables to keep track of best options
106  protected double m_search_smallestLeastSquares;
107  protected PredictionNode m_search_bestInsertionNode;
108  protected Splitter m_search_bestSplitter;
109  protected Instances m_search_bestPathInstances;
110
111  // A collection of splitter nodes
112  protected FastVector m_staticPotentialSplitters2way;
113
114  // statistics
115  protected int m_nodesExpanded = 0;
116  protected int m_examplesCounted = 0;
117
118  // options
119  protected int m_boostingIterations = 10;
120
121  /**
122   * Returns a string describing classifier
123   * @return a description suitable for
124   * displaying in the explorer/experimenter gui
125   */
126  public String globalInfo() {
127
128    return  "Class for generating a multi-class alternating decision tree using " +
129      "the LogitBoost strategy. For more info, see\n\n"
130      + getTechnicalInformation().toString();
131  }
132
133  /**
134   * Returns an instance of a TechnicalInformation object, containing
135   * detailed information about the technical background of this class,
136   * e.g., paper reference or book this class is based on.
137   *
138   * @return the technical information about this class
139   */
140  public TechnicalInformation getTechnicalInformation() {
141    TechnicalInformation        result;
142       
143    result = new TechnicalInformation(Type.INPROCEEDINGS);
144    result.setValue(Field.AUTHOR, "Geoffrey Holmes and Bernhard Pfahringer and Richard Kirkby and Eibe Frank and Mark Hall");
145    result.setValue(Field.TITLE, "Multiclass alternating decision trees");
146    result.setValue(Field.BOOKTITLE, "ECML");
147    result.setValue(Field.YEAR, "2001");
148    result.setValue(Field.PAGES, "161-172");
149    result.setValue(Field.PUBLISHER, "Springer");
150   
151    return result;
152  }
153
154  /** helper classes ********************************************************************/
155
156  protected class LADInstance extends DenseInstance {
157    public double[] fVector;
158    public double[] wVector;
159    public double[] pVector;
160    public double[] zVector;
161    public LADInstance(Instance instance) {
162   
163      super(instance);
164     
165      setDataset(instance.dataset()); // preserve dataset
166
167      // set up vectors
168      fVector = new double[m_numOfClasses];
169      wVector = new double[m_numOfClasses];
170      pVector = new double[m_numOfClasses];
171      zVector = new double[m_numOfClasses];
172
173      // set initial probabilities
174      double initProb = 1.0 / ((double) m_numOfClasses);
175      for (int i=0; i<m_numOfClasses; i++) {
176        pVector[i] = initProb;
177      }
178      updateZVector();
179      updateWVector();
180    }
181    public void updateWeights(double[] fVectorIncrement) {
182      for (int i=0; i<fVector.length; i++) {
183        fVector[i] += fVectorIncrement[i];
184      }
185      updateVectors(fVector);
186    }
187    public void updateVectors(double[] newFVector) {
188      updatePVector(newFVector);
189      updateZVector();
190      updateWVector();
191    }
192    public void updatePVector(double[] newFVector) {
193      double max = newFVector[Utils.maxIndex(newFVector)];
194      for (int i=0; i<pVector.length; i++) {
195        pVector[i] = Math.exp(newFVector[i] - max);
196      }
197      Utils.normalize(pVector);
198    }
199    public void updateWVector() {
200      for (int i=0; i<wVector.length; i++) {
201        wVector[i] = (yVector(i) - pVector[i]) / zVector[i];
202      }
203    }
204    public void updateZVector() {
205
206      for (int i=0; i<zVector.length; i++) {
207        if (yVector(i) == 1) {
208          zVector[i] = 1.0 / pVector[i];
209          if (zVector[i] > Z_MAX) { // threshold
210            zVector[i] = Z_MAX;
211          }
212        } else {
213          zVector[i] = -1.0 / (1.0 - pVector[i]);
214          if (zVector[i] < -Z_MAX) { // threshold
215            zVector[i] = -Z_MAX;
216          }
217        }
218      }
219    }
220    public double yVector(int index) {
221      return (index == (int) classValue() ? 1.0 : 0.0); 
222    }
223    public Object copy() {
224      LADInstance copy = new LADInstance((Instance) super.copy());
225      System.arraycopy(fVector, 0, copy.fVector, 0, fVector.length);
226      System.arraycopy(wVector, 0, copy.wVector, 0, wVector.length);
227      System.arraycopy(pVector, 0, copy.pVector, 0, pVector.length);
228      System.arraycopy(zVector, 0, copy.zVector, 0, zVector.length);
229      return copy;
230    }
231    public String toString() {
232
233      StringBuffer text = new StringBuffer();
234      text.append(" * F(");
235      for (int i=0; i<fVector.length; i++) {
236        text.append(Utils.doubleToString(fVector[i], 3));
237        if (i<fVector.length-1) text.append(",");
238      }
239      text.append(") P(");
240      for (int i=0; i<pVector.length; i++) {
241        text.append(Utils.doubleToString(pVector[i], 3));
242        if (i<pVector.length-1) text.append(",");
243      }
244      text.append(") W(");
245      for (int i=0; i<wVector.length; i++) {
246        text.append(Utils.doubleToString(wVector[i], 3));
247        if (i<wVector.length-1) text.append(",");
248      }
249      text.append(")");
250      return super.toString() + text.toString();
251
252    }
253  }
254
255  protected class PredictionNode implements Serializable, Cloneable{
256    private double[] values;
257    private FastVector children; // any number of splitter nodes
258   
259    public PredictionNode(double[] newValues) {
260      values = new double[m_numOfClasses];
261      setValues(newValues);
262      children = new FastVector();
263    }
264    public void setValues(double[] newValues) {
265      System.arraycopy(newValues, 0, values, 0, m_numOfClasses);
266    }
267    public double[] getValues() {
268      return values;
269    }
270    public FastVector getChildren() { return children; }
271    public Enumeration children() { return children.elements(); }
272    public void addChild(Splitter newChild) { // merges, adds a clone (deep copy)
273      Splitter oldEqual = null;
274      for (Enumeration e = children(); e.hasMoreElements(); ) {
275        Splitter split = (Splitter) e.nextElement();
276        if (newChild.equalTo(split)) { oldEqual = split; break; }
277      }
278      if (oldEqual == null) {
279        Splitter addChild = (Splitter) newChild.clone();
280        addChild.orderAdded = ++m_lastAddedSplitNum;
281        children.addElement(addChild);
282      }
283      else { // do a merge
284        for (int i=0; i<newChild.getNumOfBranches(); i++) {
285          PredictionNode oldPred = oldEqual.getChildForBranch(i);
286          PredictionNode newPred = newChild.getChildForBranch(i);
287          if (oldPred != null && newPred != null)
288            oldPred.merge(newPred);
289        }
290      }
291    }
292    public Object clone() { // does a deep copy (recurses through tree)
293      PredictionNode clone = new PredictionNode(values);
294      // should actually clone once merges are enabled!
295      for (Enumeration e = children.elements(); e.hasMoreElements(); )
296        clone.children.addElement((Splitter)((Splitter) e.nextElement()).clone());
297      return clone;
298    }
299    public void merge(PredictionNode merger) {
300      // need to merge linear models here somehow
301      for (int i=0; i<m_numOfClasses; i++) values[i] += merger.values[i];
302      for (Enumeration e = merger.children(); e.hasMoreElements(); ) {
303        addChild((Splitter)e.nextElement());
304      }
305    }
306  }
307
308  /** splitter classes ******************************************************************/
309
310  protected abstract class Splitter implements Serializable, Cloneable {
311      protected int attIndex;
312    public int orderAdded;
313    public abstract int getNumOfBranches();
314    public abstract int branchInstanceGoesDown(Instance i);
315    public abstract Instances instancesDownBranch(int branch, Instances sourceInstances);
316    public abstract String attributeString();
317    public abstract String comparisonString(int branchNum);
318    public abstract boolean equalTo(Splitter compare);
319    public abstract void setChildForBranch(int branchNum, PredictionNode childPredictor);
320    public abstract PredictionNode getChildForBranch(int branchNum);
321    public abstract Object clone();
322  }
323
324  protected class TwoWayNominalSplit extends Splitter {
325      //private int attIndex;
326    private int trueSplitValue;
327    private PredictionNode[] children;
328    public TwoWayNominalSplit(int _attIndex, int _trueSplitValue) {
329      attIndex = _attIndex; trueSplitValue = _trueSplitValue;
330      children = new PredictionNode[2];
331    }
332    public int getNumOfBranches() { return 2; }
333    public int branchInstanceGoesDown(Instance inst) {
334      if (inst.isMissing(attIndex)) return -1;
335      else if (inst.value(attIndex) == trueSplitValue) return 0;
336      else return 1;
337    }
338    public Instances instancesDownBranch(int branch, Instances instances) {
339      ReferenceInstances filteredInstances = new ReferenceInstances(instances, 1);
340      if (branch == -1) {
341        for (Enumeration e = instances.enumerateInstances(); e.hasMoreElements(); ) {
342          Instance inst = (Instance) e.nextElement();
343          if (inst.isMissing(attIndex)) filteredInstances.addReference(inst);
344        }
345      } else if (branch == 0) {
346        for (Enumeration e = instances.enumerateInstances(); e.hasMoreElements(); ) {
347          Instance inst = (Instance) e.nextElement();
348          if (!inst.isMissing(attIndex) && inst.value(attIndex) == trueSplitValue)
349            filteredInstances.addReference(inst);
350        }
351      } else {
352        for (Enumeration e = instances.enumerateInstances(); e.hasMoreElements(); ) {
353          Instance inst = (Instance) e.nextElement();
354          if (!inst.isMissing(attIndex) && inst.value(attIndex) != trueSplitValue)
355            filteredInstances.addReference(inst);
356        }
357      }
358      return filteredInstances;
359    }
360    public String attributeString() {
361      return m_trainInstances.attribute(attIndex).name();
362    }
363    public String comparisonString(int branchNum) {
364      Attribute att = m_trainInstances.attribute(attIndex);
365      if (att.numValues() != 2) 
366        return ((branchNum == 0 ? "= " : "!= ") + att.value(trueSplitValue));
367      else return ("= " + (branchNum == 0 ?
368                           att.value(trueSplitValue) :
369                           att.value(trueSplitValue == 0 ? 1 : 0)));
370    }
371    public boolean equalTo(Splitter compare) {
372      if (compare instanceof TwoWayNominalSplit) { // test object type
373        TwoWayNominalSplit compareSame = (TwoWayNominalSplit) compare;
374        return (attIndex == compareSame.attIndex &&
375                trueSplitValue == compareSame.trueSplitValue);
376      } else return false;
377    }
378    public void setChildForBranch(int branchNum, PredictionNode childPredictor) {
379      children[branchNum] = childPredictor;
380    }
381    public PredictionNode getChildForBranch(int branchNum) {
382      return children[branchNum];
383    }
384    public Object clone() { // deep copy
385      TwoWayNominalSplit clone = new TwoWayNominalSplit(attIndex, trueSplitValue);
386      if (children[0] != null)
387        clone.setChildForBranch(0, (PredictionNode) children[0].clone());
388      if (children[1] != null)
389        clone.setChildForBranch(1, (PredictionNode) children[1].clone());
390      return clone;
391    }
392  }
393
394  protected class TwoWayNumericSplit extends Splitter implements Cloneable {
395      //private int attIndex;
396    private double splitPoint;
397    private PredictionNode[] children;
398    public TwoWayNumericSplit(int _attIndex, double _splitPoint) {
399      attIndex = _attIndex;
400      splitPoint = _splitPoint;
401      children = new PredictionNode[2];
402    }
403    public TwoWayNumericSplit(int _attIndex, Instances instances) throws Exception {
404      attIndex = _attIndex;
405      splitPoint = findSplit(instances, attIndex);
406      children = new PredictionNode[2];
407    }
408    public int getNumOfBranches() { return 2; }
409    public int branchInstanceGoesDown(Instance inst) {
410      if (inst.isMissing(attIndex)) return -1;
411      else if (inst.value(attIndex) < splitPoint) return 0;
412      else return 1;
413    }
414    public Instances instancesDownBranch(int branch, Instances instances) {
415      ReferenceInstances filteredInstances = new ReferenceInstances(instances, 1);
416      if (branch == -1) {
417        for (Enumeration e = instances.enumerateInstances(); e.hasMoreElements(); ) {
418          Instance inst = (Instance) e.nextElement();
419          if (inst.isMissing(attIndex)) filteredInstances.addReference(inst);
420        }
421      } else if (branch == 0) {
422        for (Enumeration e = instances.enumerateInstances(); e.hasMoreElements(); ) {
423          Instance inst = (Instance) e.nextElement();
424          if (!inst.isMissing(attIndex) && inst.value(attIndex) < splitPoint)
425            filteredInstances.addReference(inst);
426        }
427      } else {
428        for (Enumeration e = instances.enumerateInstances(); e.hasMoreElements(); ) {
429          Instance inst = (Instance) e.nextElement();
430          if (!inst.isMissing(attIndex) && inst.value(attIndex) >= splitPoint)
431            filteredInstances.addReference(inst);
432        }
433      }
434      return filteredInstances;
435    }
436    public String attributeString() {
437      return m_trainInstances.attribute(attIndex).name();
438    }
439    public String comparisonString(int branchNum) {
440      return ((branchNum == 0 ? "< " : ">= ") + Utils.doubleToString(splitPoint, 3));
441    }
442    public boolean equalTo(Splitter compare) {
443      if (compare instanceof TwoWayNumericSplit) { // test object type
444        TwoWayNumericSplit compareSame = (TwoWayNumericSplit) compare;
445        return (attIndex == compareSame.attIndex &&
446                splitPoint == compareSame.splitPoint);
447      } else return false;
448    }
449    public void setChildForBranch(int branchNum, PredictionNode childPredictor) {
450      children[branchNum] = childPredictor;
451    }
452    public PredictionNode getChildForBranch(int branchNum) {
453      return children[branchNum];
454    }
455    public Object clone() { // deep copy
456      TwoWayNumericSplit clone = new TwoWayNumericSplit(attIndex, splitPoint);
457      if (children[0] != null)
458        clone.setChildForBranch(0, (PredictionNode) children[0].clone());
459      if (children[1] != null)
460        clone.setChildForBranch(1, (PredictionNode) children[1].clone());
461      return clone;
462    }
463    private double findSplit(Instances instances, int index) throws Exception {
464      double splitPoint = 0;
465      double bestVal = Double.MAX_VALUE, currVal, currCutPoint;
466      int numMissing = 0;
467      double[][] distribution = new double[3][instances.numClasses()];   
468
469      // Compute counts for all the values
470      for (int i = 0; i < instances.numInstances(); i++) {
471        Instance inst = instances.instance(i);
472        if (!inst.isMissing(index)) {
473          distribution[1][(int)inst.classValue()] ++;
474        } else {
475          distribution[2][(int)inst.classValue()] ++;
476          numMissing++;
477        }
478      }
479     
480      // Sort instances
481      instances.sort(index);
482     
483      // Make split counts for each possible split and evaluate
484      for (int i = 0; i < instances.numInstances() - (numMissing + 1); i++) {
485        Instance inst = instances.instance(i);
486        Instance instPlusOne = instances.instance(i + 1);
487        distribution[0][(int)inst.classValue()] += inst.weight();
488        distribution[1][(int)inst.classValue()] -= inst.weight();
489        if (Utils.sm(inst.value(index), instPlusOne.value(index))) {
490          currCutPoint = (inst.value(index) + instPlusOne.value(index)) / 2.0;
491          currVal = ContingencyTables.entropyConditionedOnRows(distribution);
492          if (Utils.sm(currVal, bestVal)) {
493            splitPoint = currCutPoint;
494            bestVal = currVal;
495          }
496        }
497      }
498
499      return splitPoint;
500    }
501  }
502
503  /**
504   * Sets up the tree ready to be trained.
505   *
506   * @param instances the instances to train the tree with
507   * @exception Exception if training data is unsuitable
508   */
509  public void initClassifier(Instances instances) throws Exception {
510
511    // clear stats
512    m_nodesExpanded = 0;
513    m_examplesCounted = 0;
514    m_lastAddedSplitNum = 0;
515
516    m_numOfClasses = instances.numClasses();
517
518    // make sure training data is suitable
519    if (instances.checkForStringAttributes()) {
520      throw new Exception("Can't handle string attributes!");
521    }
522    if (!instances.classAttribute().isNominal()) {
523      throw new Exception("Class must be nominal!");
524    }
525
526    // create training set (use LADInstance class)
527    m_trainInstances =
528      new ReferenceInstances(instances, instances.numInstances());
529    for (Enumeration e = instances.enumerateInstances(); e.hasMoreElements(); ) {
530      Instance inst = (Instance) e.nextElement();
531      if (!inst.classIsMissing()) {
532        LADInstance adtInst = new LADInstance(inst);
533        m_trainInstances.addReference(adtInst);
534        adtInst.setDataset(m_trainInstances);
535      }
536    }
537
538    // create the root prediction node
539    m_root = new PredictionNode(new double[m_numOfClasses]);
540   
541    // pre-calculate what we can
542    generateStaticPotentialSplittersAndNumericIndices();
543  }
544
545    public void next(int iteration) throws Exception {
546        boost();
547    }
548
549    public void done() throws Exception {}
550
551  /**
552   * Performs a single boosting iteration.
553   * Will add a new splitter node and two prediction nodes to the tree
554   * (unless merging takes place).
555   *
556   * @exception Exception if try to boost without setting up tree first
557   */
558  private void boost() throws Exception {
559
560    if (m_trainInstances == null)
561      throw new Exception("Trying to boost with no training data");
562
563    // perform the search
564    searchForBestTest();
565
566    if (m_Debug) {
567      System.out.println("Best split found: "
568                         + m_search_bestSplitter.getNumOfBranches() + "-way split on "
569                         + m_search_bestSplitter.attributeString()
570                         //+ "\nsmallestLeastSquares = " + m_search_smallestLeastSquares);
571                         + "\nBestGain = " + m_search_smallestLeastSquares);
572    }
573
574    if (m_search_bestSplitter == null) return; // handle empty instances
575
576    // create the new nodes for the tree, updating the weights
577    for (int i=0; i<m_search_bestSplitter.getNumOfBranches(); i++) {
578      Instances applicableInstances =
579        m_search_bestSplitter.instancesDownBranch(i, m_search_bestPathInstances);
580      double[] predictionValues = calcPredictionValues(applicableInstances);
581      PredictionNode newPredictor = new PredictionNode(predictionValues);
582      updateWeights(applicableInstances, predictionValues);
583      m_search_bestSplitter.setChildForBranch(i, newPredictor);
584    }
585
586    // insert the new nodes
587    m_search_bestInsertionNode.addChild((Splitter) m_search_bestSplitter);
588
589    if (m_Debug) {
590      System.out.println("Tree is now:\n" + toString(m_root, 1) + "\n");
591      //System.out.println("Instances are now:\n" + m_trainInstances + "\n");
592    }
593
594    // free memory
595    m_search_bestPathInstances = null;
596  }
597
598  private void updateWeights(Instances instances, double[] newPredictionValues) {
599
600    for (int i=0; i<instances.numInstances(); i++)
601      ((LADInstance) instances.instance(i)).updateWeights(newPredictionValues);
602  }
603
604  /**
605   * Generates the m_staticPotentialSplitters2way
606   * vector to contain all possible nominal splits, and the m_numericAttIndices array to
607   * index the numeric attributes in the training data.
608   *
609   */
610  private void generateStaticPotentialSplittersAndNumericIndices() {
611   
612    m_staticPotentialSplitters2way = new FastVector();
613    FastVector numericIndices = new FastVector();
614
615    for (int i=0; i<m_trainInstances.numAttributes(); i++) {
616      if (i == m_trainInstances.classIndex()) continue;
617      if (m_trainInstances.attribute(i).isNumeric())
618        numericIndices.addElement(new Integer(i));
619      else {
620        int numValues = m_trainInstances.attribute(i).numValues();
621        if (numValues == 2) // avoid redundancy due to 2-way symmetry
622          m_staticPotentialSplitters2way.addElement(new TwoWayNominalSplit(i, 0));
623        else for (int j=0; j<numValues; j++)
624          m_staticPotentialSplitters2way.addElement(new TwoWayNominalSplit(i, j));
625      }
626    }
627
628    m_numericAttIndices = new int[numericIndices.size()];
629    for (int i=0; i<numericIndices.size(); i++)
630      m_numericAttIndices[i] = ((Integer)numericIndices.elementAt(i)).intValue();
631  }
632
633  /**
634   * Performs a search for the best test (splitter) to add to the tree, by looking
635   * for the largest weight change.
636   *
637   * @exception Exception if search fails
638   */
639  private void searchForBestTest() throws Exception {
640   
641    if (m_Debug) {
642      System.out.println("Searching for best split...");
643    }
644
645    m_search_smallestLeastSquares = 0.0; //Double.POSITIVE_INFINITY;
646    searchForBestTest(m_root, m_trainInstances);
647  }
648
649  /**
650   * Recursive function that carries out search for the best test (splitter) to add to
651   * this part of the tree, by looking for the largest weight change. Will try 2-way
652   * and/or multi-way splits depending on the current state.
653   *
654   * @param currentNode the root of the subtree to be searched, and the current node
655   * being considered as parent of a new split
656   * @param instances the instances that apply at this node
657   * @exception Exception if search fails
658   */
659  private void searchForBestTest(PredictionNode currentNode, Instances instances)
660    throws Exception
661  {
662
663    // keep stats
664    m_nodesExpanded++;
665    m_examplesCounted += instances.numInstances();
666     
667    // evaluate static splitters (nominal)
668    for (Enumeration e = m_staticPotentialSplitters2way.elements();
669         e.hasMoreElements(); ) {
670      evaluateSplitter((Splitter) e.nextElement(), currentNode, instances);
671    }
672
673    if (m_Debug) {
674        //System.out.println("Instances considered are: " + instances);
675    }
676
677    // evaluate dynamic splitters (numeric)
678    for (int i=0; i<m_numericAttIndices.length; i++) {
679      evaluateNumericSplit(currentNode, instances, m_numericAttIndices[i]);
680    }
681
682    if (currentNode.getChildren().size() == 0) return;
683
684    // keep searching
685    goDownAllPaths(currentNode, instances);
686  }
687
688  /**
689   * Continues general multi-class search by investigating every node in the
690   * subtree under currentNode.
691   *
692   * @param currentNode the root of the subtree to be searched
693   * @param instances the instances that apply at this node
694   * @exception Exception if search fails
695   */
696  private void goDownAllPaths(PredictionNode currentNode, Instances instances)
697    throws Exception
698  {
699   
700    for (Enumeration e = currentNode.children(); e.hasMoreElements(); ) {
701      Splitter split = (Splitter) e.nextElement();
702      for (int i=0; i<split.getNumOfBranches(); i++)
703        searchForBestTest(split.getChildForBranch(i),
704                          split.instancesDownBranch(i, instances));
705    }
706  }
707
708  /**
709   * Investigates the option of introducing a split under currentNode. If the
710   * split creates a weight change that is larger than has already been found it will
711   * update the search information to record this as the best option so far.
712   *
713   * @param split the splitter node to evaluate
714   * @param currentNode the parent under which the split is to be considered
715   * @param instances the instances that apply at this node
716   * @exception Exception if something goes wrong
717   */
718  private void evaluateSplitter(Splitter split, PredictionNode currentNode,
719                                Instances instances)
720    throws Exception
721  {
722   
723    double leastSquares = leastSquaresNonMissing(instances,split.attIndex);
724
725    for (int i=0; i<split.getNumOfBranches(); i++)
726      leastSquares -= leastSquares(split.instancesDownBranch(i, instances));
727
728    if (m_Debug) {
729      //System.out.println("Instances considered are: " + instances);
730      System.out.print(split.getNumOfBranches() + "-way split on " + split.attributeString()
731                       + " has leastSquares value of "
732                       + Utils.doubleToString(leastSquares,3));
733    }
734
735    if (leastSquares > m_search_smallestLeastSquares) {
736      if (m_Debug) {
737        System.out.print(" (best so far)");
738      }
739      m_search_smallestLeastSquares = leastSquares;
740      m_search_bestInsertionNode = currentNode;
741      m_search_bestSplitter = split;
742      m_search_bestPathInstances = instances;
743    }
744    if (m_Debug) {
745      System.out.print("\n");
746    }
747  }
748
749  private void evaluateNumericSplit(PredictionNode currentNode,
750                                    Instances instances, int attIndex)
751  {
752 
753    double[] splitAndLS = findNumericSplitpointAndLS(instances, attIndex);
754    double gain = leastSquaresNonMissing(instances,attIndex) - splitAndLS[1];
755 
756   if (m_Debug) {
757     //System.out.println("Instances considered are: " + instances);
758     System.out.print("Numeric split on " + instances.attribute(attIndex).name()
759                      + " has leastSquares value of " 
760                      //+ Utils.doubleToString(splitAndLS[1],3));
761                      + Utils.doubleToString(gain,3));
762    }
763
764   if (gain > m_search_smallestLeastSquares) {
765      if (m_Debug) {
766        System.out.print(" (best so far)");
767      }
768      m_search_smallestLeastSquares = gain; //splitAndLS[1];
769      m_search_bestInsertionNode = currentNode;
770      m_search_bestSplitter = new TwoWayNumericSplit(attIndex, splitAndLS[0]);;
771      m_search_bestPathInstances = instances;
772    }
773    if (m_Debug) {
774      System.out.print("\n");
775    }
776  }
777
778  private double[] findNumericSplitpointAndLS(Instances instances, int attIndex) {
779
780      double allLS = leastSquares(instances);
781
782    // all instances in right subset
783    double[] term1L = new double[m_numOfClasses];
784    double[] term2L = new double[m_numOfClasses];
785    double[] term3L = new double[m_numOfClasses];
786    double[] meanNumL = new double[m_numOfClasses];
787    double[] meanDenL = new double[m_numOfClasses];
788
789    double[] term1R = new double[m_numOfClasses];
790    double[] term2R = new double[m_numOfClasses];
791    double[] term3R = new double[m_numOfClasses];
792    double[] meanNumR = new double[m_numOfClasses];
793    double[] meanDenR = new double[m_numOfClasses];
794
795    double temp1, temp2, temp3;
796
797    double[] classMeans = new double[m_numOfClasses];
798    double[] classTotals = new double[m_numOfClasses];
799
800    // fill up RHS
801    for (int j=0; j<m_numOfClasses; j++) { 
802      for (int i=0; i<instances.numInstances(); i++) {
803        LADInstance inst = (LADInstance) instances.instance(i);
804        temp1 = inst.wVector[j] * inst.zVector[j];
805        term1R[j] += temp1 * inst.zVector[j];
806        term2R[j] += temp1;
807        term3R[j] += inst.wVector[j];
808        meanNumR[j] += inst.wVector[j] * inst.zVector[j];
809      }
810    }
811
812    //leastSquares = term1 - (2.0 * u * term2) + (u * u * term3);
813
814    double leastSquares;
815    boolean newSplit;
816    double smallestLeastSquares = Double.POSITIVE_INFINITY;
817    double bestSplit = 0.0;
818    double meanL, meanR;
819
820    instances.sort(attIndex);
821
822    for (int i=0; i<instances.numInstances()-1; i++) {// shift inst from right to left
823      if (instances.instance(i+1).isMissing(attIndex)) break;
824      if (instances.instance(i+1).value(attIndex) > instances.instance(i).value(attIndex))
825        newSplit = true;
826      else newSplit = false;
827      LADInstance inst = (LADInstance) instances.instance(i);
828      leastSquares = 0.0;
829      for (int j=0; j<m_numOfClasses; j++) {   
830        temp1 = inst.wVector[j] * inst.zVector[j];
831        temp2 = temp1 * inst.zVector[j];
832        temp3 = inst.wVector[j] * inst.zVector[j];
833        term1L[j] += temp2;
834        term2L[j] += temp1;
835        term3L[j] += inst.wVector[j];
836        term1R[j] -= temp2;
837        term2R[j] -= temp1;
838        term3R[j] -= inst.wVector[j];
839        meanNumL[j] += temp3;
840        meanNumR[j] -= temp3;
841        if (newSplit) {
842          meanL = meanNumL[j] / term3L[j];
843          meanR = meanNumR[j] / term3R[j];
844          leastSquares += term1L[j] - (2.0 * meanL * term2L[j])
845            + (meanL * meanL * term3L[j]);
846          leastSquares += term1R[j] - (2.0 * meanR * term2R[j])
847            + (meanR * meanR * term3R[j]);
848        }
849      }
850      if (m_Debug && newSplit)
851      System.out.println(attIndex + "/" + 
852                         ((instances.instance(i).value(attIndex) +
853                           instances.instance(i+1).value(attIndex)) / 2.0) +
854                         " = " + (allLS - leastSquares));
855
856      if (newSplit && leastSquares < smallestLeastSquares) {
857        bestSplit = (instances.instance(i).value(attIndex) +
858                     instances.instance(i+1).value(attIndex)) / 2.0;
859        smallestLeastSquares = leastSquares;
860      }
861    }
862    double[] result = new double[2];
863    result[0] = bestSplit;
864    result[1] = smallestLeastSquares > 0 ? smallestLeastSquares : 0;
865    return result;
866  }
867
868  private double leastSquares(Instances instances) {
869
870    double numerator=0, denominator=0, w, t;
871    double[] classMeans = new double[m_numOfClasses];
872    double[] classTotals = new double[m_numOfClasses];
873
874    for (int i=0; i<instances.numInstances(); i++) {
875      LADInstance inst = (LADInstance) instances.instance(i);
876      for (int j=0; j<m_numOfClasses; j++) {
877        classMeans[j] += inst.zVector[j] * inst.wVector[j];
878        classTotals[j] += inst.wVector[j];
879      }
880    }
881
882    double numInstances = (double) instances.numInstances();
883    for (int j=0; j<m_numOfClasses; j++) {
884      if (classTotals[j] != 0) classMeans[j] /= classTotals[j];
885    }
886
887    for (int i=0; i<instances.numInstances(); i++) 
888      for (int j=0; j<m_numOfClasses; j++) {
889        LADInstance inst = (LADInstance) instances.instance(i);
890        w = inst.wVector[j];
891        t = inst.zVector[j] - classMeans[j];
892        numerator += w * (t * t);
893        denominator += w;
894      }
895    //System.out.println(numerator + " / " + denominator);
896    return numerator > 0 ? numerator : 0;//  / denominator;
897  }
898
899
900  private double leastSquaresNonMissing(Instances instances, int attIndex) {
901
902    double numerator=0, denominator=0, w, t;
903    double[] classMeans = new double[m_numOfClasses];
904    double[] classTotals = new double[m_numOfClasses];
905
906    for (int i=0; i<instances.numInstances(); i++) {
907      LADInstance inst = (LADInstance) instances.instance(i);
908      for (int j=0; j<m_numOfClasses; j++) {
909          classMeans[j] += inst.zVector[j] * inst.wVector[j];
910          classTotals[j] += inst.wVector[j];
911      }
912    }
913
914    double numInstances = (double) instances.numInstances();
915    for (int j=0; j<m_numOfClasses; j++) {
916      if (classTotals[j] != 0) classMeans[j] /= classTotals[j];
917    }
918
919    for (int i=0; i<instances.numInstances(); i++) 
920      for (int j=0; j<m_numOfClasses; j++) {
921        LADInstance inst = (LADInstance) instances.instance(i);
922        if(!inst.isMissing(attIndex)) {
923            w = inst.wVector[j];
924            t = inst.zVector[j] - classMeans[j];
925            numerator += w * (t * t);
926            denominator += w;
927        }
928      }
929    //System.out.println(numerator + " / " + denominator);
930    return numerator > 0 ? numerator : 0;//  / denominator;
931  }
932
933  private double[] calcPredictionValues(Instances instances) {
934
935    double[] classMeans = new double[m_numOfClasses];
936    double meansSum = 0;
937    double multiplier = ((double) (m_numOfClasses-1)) / ((double) (m_numOfClasses));
938
939    double[] classTotals = new double[m_numOfClasses];
940
941    for (int i=0; i<instances.numInstances(); i++) {
942      LADInstance inst = (LADInstance) instances.instance(i);
943      for (int j=0; j<m_numOfClasses; j++) {
944        classMeans[j] += inst.zVector[j] * inst.wVector[j];
945        classTotals[j] += inst.wVector[j];
946      }
947    }
948    double numInstances = (double) instances.numInstances();
949    for (int j=0; j<m_numOfClasses; j++) {
950      if (classTotals[j] != 0) classMeans[j] /= classTotals[j];
951      meansSum += classMeans[j];
952    }
953    meansSum /= m_numOfClasses;
954
955    for (int j=0; j<m_numOfClasses; j++) {
956      classMeans[j] = multiplier * (classMeans[j] - meansSum);
957    }
958    return classMeans;
959  }
960
961  /**
962   * Returns the class probability distribution for an instance.
963   *
964   * @param instance the instance to be classified
965   * @return the distribution the tree generates for the instance
966   */
967  public double[] distributionForInstance(Instance instance) {
968   
969    double[] predValues = new double[m_numOfClasses];
970    for (int i=0; i<m_numOfClasses; i++) predValues[i] = 0.0;
971    double[] distribution = predictionValuesForInstance(instance, m_root, predValues);
972    double max = distribution[Utils.maxIndex(distribution)];
973    for (int i=0; i<m_numOfClasses; i++) {
974      distribution[i] = Math.exp(distribution[i] - max);
975    }
976    double sum = Utils.sum(distribution);
977    if (sum > 0.0) Utils.normalize(distribution, sum);
978    return distribution;
979  }
980
981  /**
982   * Returns the class prediction values (votes) for an instance.
983   *
984   * @param inst the instance
985   * @param currentNode the root of the tree to get the values from
986   * @param currentValues the current values before adding the values contained in the
987   * subtree
988   * @return the class prediction values (votes)
989   */
990  private double[] predictionValuesForInstance(Instance inst, PredictionNode currentNode,
991                                               double[] currentValues) {
992   
993    double[] predValues = currentNode.getValues();
994    for (int i=0; i<m_numOfClasses; i++) currentValues[i] += predValues[i];
995    //for (int i=0; i<m_numOfClasses; i++) currentValues[i] = predValues[i];
996    for (Enumeration e = currentNode.children(); e.hasMoreElements(); ) {
997      Splitter split = (Splitter) e.nextElement();
998      int branch = split.branchInstanceGoesDown(inst);
999      if (branch >= 0)
1000        currentValues = predictionValuesForInstance(inst, split.getChildForBranch(branch),
1001                                                    currentValues);
1002    }
1003    return currentValues;
1004  }
1005
1006
1007
1008  /** model output functions ************************************************************/
1009
1010  /**
1011   * Returns a description of the classifier.
1012   *
1013   * @return a string containing a description of the classifier
1014   */
1015  public String toString() {
1016   
1017    String className = getClass().getName();
1018    if (m_root == null)
1019      return (className +" not built yet");
1020    else {
1021      return (className + ":\n\n" + toString(m_root, 1) +
1022              "\nLegend: " + legend() +
1023              "\n#Tree size (total): " +
1024              numOfAllNodes(m_root) + 
1025              "\n#Tree size (number of predictor nodes): " +
1026              numOfPredictionNodes(m_root) + 
1027              "\n#Leaves (number of predictor nodes): " +
1028              numOfLeafNodes(m_root) + 
1029              "\n#Expanded nodes: " +
1030              m_nodesExpanded +
1031              "\n#Processed examples: " +
1032              m_examplesCounted + 
1033              "\n#Ratio e/n: " + 
1034              ((double)m_examplesCounted/(double)m_nodesExpanded)
1035              );
1036    }
1037  }
1038
1039  /**
1040   * Traverses the tree, forming a string that describes it.
1041   *
1042   * @param currentNode the current node under investigation
1043   * @param level the current level in the tree
1044   * @return the string describing the subtree
1045   */     
1046  private String toString(PredictionNode currentNode, int level) {
1047   
1048    StringBuffer text = new StringBuffer();
1049   
1050    text.append(": ");
1051    double[] predValues = currentNode.getValues();
1052    for (int i=0; i<m_numOfClasses; i++) {
1053      text.append(Utils.doubleToString(predValues[i],3));
1054      if (i<m_numOfClasses-1) text.append(",");
1055    }
1056    for (Enumeration e = currentNode.children(); e.hasMoreElements(); ) {
1057      Splitter split = (Splitter) e.nextElement();
1058           
1059      for (int j=0; j<split.getNumOfBranches(); j++) {
1060        PredictionNode child = split.getChildForBranch(j);
1061        if (child != null) {
1062          text.append("\n");
1063          for (int k = 0; k < level; k++) {
1064            text.append("|  ");
1065          }
1066          text.append("(" + split.orderAdded + ")");
1067          text.append(split.attributeString() + " " + split.comparisonString(j));
1068          text.append(toString(child, level + 1));
1069        }
1070      }
1071    }
1072    return text.toString();
1073  }
1074
1075  /**
1076   * Returns graph describing the tree.
1077   *
1078   * @return the graph of the tree in dotty format
1079   * @exception Exception if something goes wrong
1080   */
1081  public String graph() throws Exception {
1082   
1083    StringBuffer text = new StringBuffer();
1084    text.append("digraph ADTree {\n");
1085    //text.append("center=true\nsize=\"8.27,11.69\"\n");
1086    graphTraverse(m_root, text, 0, 0);
1087    return text.toString() +"}\n";
1088  }
1089
1090
1091  /**
1092   * Traverses the tree, graphing each node.
1093   *
1094   * @param currentNode the currentNode under investigation
1095   * @param text the string built so far
1096   * @param splitOrder the order the parent splitter was added to the tree
1097   * @param predOrder the order this predictor was added to the split
1098   * @exception Exception if something goes wrong
1099   */       
1100  protected void graphTraverse(PredictionNode currentNode, StringBuffer text,
1101                               int splitOrder, int predOrder)
1102    throws Exception
1103  {
1104   
1105    text.append("S" + splitOrder + "P" + predOrder + " [label=\"");
1106    double[] predValues = currentNode.getValues();
1107    for (int i=0; i<m_numOfClasses; i++) {
1108      text.append(Utils.doubleToString(predValues[i],3));
1109      if (i<m_numOfClasses-1) text.append(",");
1110    }
1111    if (splitOrder == 0) // show legend in root
1112      text.append(" (" + legend() + ")");
1113    text.append("\" shape=box style=filled]\n");
1114    for (Enumeration e = currentNode.children(); e.hasMoreElements(); ) {
1115      Splitter split = (Splitter) e.nextElement();
1116      text.append("S" + splitOrder + "P" + predOrder + "->" + "S" + split.orderAdded +
1117                  " [style=dotted]\n");
1118      text.append("S" + split.orderAdded + " [label=\"" + split.orderAdded + ": " +
1119                  split.attributeString() + "\"]\n");
1120
1121      for (int i=0; i<split.getNumOfBranches(); i++) {
1122        PredictionNode child = split.getChildForBranch(i);
1123        if (child != null) {
1124          text.append("S" + split.orderAdded + "->" + "S" + split.orderAdded + "P" + i +
1125                      " [label=\"" + split.comparisonString(i) + "\"]\n");
1126          graphTraverse(child, text, split.orderAdded, i);
1127        }
1128      }
1129    } 
1130  }
1131
1132  /**
1133   * Returns the legend of the tree, describing how results are to be interpreted.
1134   *
1135   * @return a string containing the legend of the classifier
1136   */
1137  public String legend() {
1138   
1139    Attribute classAttribute = null;
1140    if (m_trainInstances == null) return "";
1141    try {classAttribute = m_trainInstances.classAttribute();} catch (Exception x){};
1142    if (m_numOfClasses == 1) {
1143      return ("-ve = " + classAttribute.value(0)
1144              + ", +ve = " + classAttribute.value(1));
1145    } else {
1146      StringBuffer text = new StringBuffer();
1147      for (int i=0; i<m_numOfClasses; i++) {
1148        if (i>0) text.append(", ");
1149        text.append(classAttribute.value(i));
1150      }
1151      return text.toString();
1152    }
1153  }
1154
1155
1156
1157  /** option handling  ******************************************************************/
1158
1159  /**
1160   * @return tip text for this property suitable for
1161   * displaying in the explorer/experimenter gui
1162   */
1163  public String numOfBoostingIterationsTipText() {
1164
1165    return "The number of boosting iterations to use, which determines the size of the tree.";
1166  }
1167
1168  /**
1169   * Gets the number of boosting iterations.
1170   *
1171   * @return the number of boosting iterations
1172   */
1173  public int getNumOfBoostingIterations() {
1174   
1175    return m_boostingIterations;
1176  }
1177
1178  /**
1179   * Sets the number of boosting iterations.
1180   *
1181   * @param b the number of boosting iterations to use
1182   */
1183  public void setNumOfBoostingIterations(int b) {
1184   
1185    m_boostingIterations = b; 
1186  }
1187
1188  /**
1189   * Returns an enumeration describing the available options.
1190   *
1191   * @return an enumeration of all the available options
1192   */
1193  public Enumeration listOptions() {
1194   
1195    Vector newVector = new Vector(1);
1196    newVector.addElement(new Option(
1197                                    "\tNumber of boosting iterations.\n"
1198                                    +"\t(Default = 10)",
1199                                    "B", 1,"-B <number of boosting iterations>"));
1200
1201    Enumeration enu = super.listOptions();
1202    while (enu.hasMoreElements()) {
1203      newVector.addElement(enu.nextElement());
1204    }
1205
1206    return newVector.elements();
1207  }
1208
1209  /**
1210   * Parses a given list of options. Valid options are:<p>
1211   *
1212   * -B num <br>
1213   * Set the number of boosting iterations
1214   * (default 10) <p>
1215   *
1216   * @param options the list of options as an array of strings
1217   * @exception Exception if an option is not supported
1218   */
1219  public void setOptions(String[] options) throws Exception {
1220   
1221    String bString = Utils.getOption('B', options);
1222    if (bString.length() != 0) setNumOfBoostingIterations(Integer.parseInt(bString));
1223
1224    super.setOptions(options);
1225
1226    Utils.checkForRemainingOptions(options);
1227  }
1228
1229  /**
1230   * Gets the current settings of ADTree.
1231   *
1232   * @return an array of strings suitable for passing to setOptions()
1233   */
1234  public String[] getOptions() {
1235   
1236    String[] options = new String[2  + super.getOptions().length];
1237
1238    int current = 0;
1239    options[current++] = "-B"; options[current++] = "" + getNumOfBoostingIterations();
1240
1241    System.arraycopy(super.getOptions(), 0, options, current, super.getOptions().length);
1242
1243    while (current < options.length) options[current++] = "";
1244    return options;
1245  }
1246
1247
1248
1249  /** additional measures ***************************************************************/
1250
1251  /**
1252   * Calls measure function for tree size.
1253   *
1254   * @return the tree size
1255   */
1256  public double measureTreeSize() {
1257   
1258    return numOfAllNodes(m_root);
1259  }
1260
1261  /**
1262   * Calls measure function for leaf size.
1263   *
1264   * @return the leaf size
1265   */
1266  public double measureNumLeaves() {
1267   
1268    return numOfPredictionNodes(m_root);
1269  }
1270
1271  /**
1272   * Calls measure function for leaf size.
1273   *
1274   * @return the leaf size
1275   */
1276  public double measureNumPredictionLeaves() {
1277   
1278    return numOfLeafNodes(m_root);
1279  }
1280
1281  /**
1282   * Returns the number of nodes expanded.
1283   *
1284   * @return the number of nodes expanded during search
1285   */
1286  public double measureNodesExpanded() {
1287   
1288    return m_nodesExpanded;
1289  }
1290
1291  /**
1292   * Returns the number of examples "counted".
1293   *
1294   * @return the number of nodes processed during search
1295   */
1296  public double measureExamplesCounted() {
1297   
1298    return m_examplesCounted;
1299  }
1300
1301  /**
1302   * Returns an enumeration of the additional measure names.
1303   *
1304   * @return an enumeration of the measure names
1305   */
1306  public Enumeration enumerateMeasures() {
1307   
1308    Vector newVector = new Vector(5);
1309    newVector.addElement("measureTreeSize");
1310    newVector.addElement("measureNumLeaves");
1311    newVector.addElement("measureNumPredictionLeaves");
1312    newVector.addElement("measureNodesExpanded");
1313    newVector.addElement("measureExamplesCounted");
1314    return newVector.elements();
1315  }
1316 
1317  /**
1318   * Returns the value of the named measure.
1319   *
1320   * @param additionalMeasureName the name of the measure to query for its value
1321   * @return the value of the named measure
1322   * @exception IllegalArgumentException if the named measure is not supported
1323   */
1324  public double getMeasure(String additionalMeasureName) {
1325   
1326    if (additionalMeasureName.equals("measureTreeSize")) {
1327      return measureTreeSize();
1328    }
1329    else if (additionalMeasureName.equals("measureNodesExpanded")) {
1330      return measureNodesExpanded();
1331    }
1332    else if (additionalMeasureName.equals("measureNumLeaves")) {
1333      return measureNumLeaves();
1334    }
1335    else if (additionalMeasureName.equals("measureNumPredictionLeaves")) {
1336      return measureNumPredictionLeaves();
1337    }
1338    else if (additionalMeasureName.equals("measureExamplesCounted")) {
1339      return measureExamplesCounted();
1340    }
1341    else {throw new IllegalArgumentException(additionalMeasureName
1342                              + " not supported (ADTree)");
1343    }
1344  }
1345
1346  /**
1347   * Returns the number of prediction nodes in a tree.
1348   *
1349   * @param root the root of the tree being measured
1350   * @return tree size in number of prediction nodes
1351   */       
1352  protected int numOfPredictionNodes(PredictionNode root) {
1353   
1354    int numSoFar = 0;
1355    if (root != null) {
1356      numSoFar++;
1357      for (Enumeration e = root.children(); e.hasMoreElements(); ) {
1358        Splitter split = (Splitter) e.nextElement();
1359        for (int i=0; i<split.getNumOfBranches(); i++)
1360            numSoFar += numOfPredictionNodes(split.getChildForBranch(i));
1361      }
1362    }
1363    return numSoFar;
1364  }
1365
1366  /**
1367   * Returns the number of leaf nodes in a tree.
1368   *
1369   * @param root the root of the tree being measured
1370   * @return tree leaf size in number of prediction nodes
1371   */       
1372  protected int numOfLeafNodes(PredictionNode root) {
1373   
1374    int numSoFar = 0;
1375    if (root.getChildren().size() > 0) {
1376      for (Enumeration e = root.children(); e.hasMoreElements(); ) {
1377        Splitter split = (Splitter) e.nextElement();
1378        for (int i=0; i<split.getNumOfBranches(); i++)
1379            numSoFar += numOfLeafNodes(split.getChildForBranch(i));
1380      }
1381    } else numSoFar = 1;
1382    return numSoFar;
1383  }
1384
1385  /**
1386   * Returns the total number of nodes in a tree.
1387   *
1388   * @param root the root of the tree being measured
1389   * @return tree size in number of splitter + prediction nodes
1390   */       
1391  protected int numOfAllNodes(PredictionNode root) {
1392   
1393    int numSoFar = 0;
1394    if (root != null) {
1395      numSoFar++;
1396      for (Enumeration e = root.children(); e.hasMoreElements(); ) {
1397        numSoFar++;
1398        Splitter split = (Splitter) e.nextElement();
1399        for (int i=0; i<split.getNumOfBranches(); i++)
1400            numSoFar += numOfAllNodes(split.getChildForBranch(i));
1401      }
1402    }
1403    return numSoFar;
1404  }
1405 
1406  /** main functions ********************************************************************/
1407
1408  /**
1409   * Builds a classifier for a set of instances.
1410   *
1411   * @param instances the instances to train the classifier with
1412   * @exception Exception if something goes wrong
1413   */
1414  public void buildClassifier(Instances instances) throws Exception {
1415
1416    // set up the tree
1417    initClassifier(instances);
1418
1419    // build the tree
1420    for (int T = 0; T < m_boostingIterations; T++) {
1421        boost();   
1422    }
1423  }
1424
1425    public int predictiveError(Instances test) {
1426        int error = 0;
1427        for(int i = test.numInstances()-1; i>=0; i--) {
1428            Instance inst = test.instance(i);
1429            try {
1430                if (classifyInstance(inst) != inst.classValue())
1431                    error++;
1432            } catch (Exception e) { error++;}
1433        }
1434        return error;
1435    }
1436
1437  /**
1438   * Merges two trees together. Modifies the tree being acted on, leaving tree passed
1439   * as a parameter untouched (cloned). Does not check to see whether training instances
1440   * are compatible - strange things could occur if they are not.
1441   *
1442   * @param mergeWith the tree to merge with
1443   * @exception Exception if merge could not be performed
1444   */
1445  public void merge(LADTree mergeWith) throws Exception {
1446   
1447    if (m_root == null || mergeWith.m_root == null)
1448      throw new Exception("Trying to merge an uninitialized tree");
1449    if (m_numOfClasses != mergeWith.m_numOfClasses)
1450      throw new Exception("Trees not suitable for merge - "
1451                          + "different sized prediction nodes");
1452    m_root.merge(mergeWith.m_root);
1453  }
1454
1455  /**
1456   *  Returns the type of graph this classifier
1457   *  represents.
1458   *  @return Drawable.TREE
1459   */   
1460  public int graphType() {
1461      return Drawable.TREE;
1462  }
1463 
1464  /**
1465   * Returns the revision string.
1466   *
1467   * @return            the revision
1468   */
1469  public String getRevision() {
1470    return RevisionUtils.extract("$Revision: 6035 $");
1471  }
1472
1473  /**
1474   * Returns default capabilities of the classifier.
1475   *
1476   * @return      the capabilities of this classifier
1477   */
1478  public Capabilities getCapabilities() {
1479    Capabilities result = super.getCapabilities();
1480    result.disableAll();
1481
1482    // attributes
1483    result.enable(Capability.NOMINAL_ATTRIBUTES);
1484    result.enable(Capability.NUMERIC_ATTRIBUTES);
1485    result.enable(Capability.DATE_ATTRIBUTES);
1486    result.enable(Capability.MISSING_VALUES);
1487
1488    // class
1489    result.enable(Capability.NOMINAL_CLASS);
1490    result.enable(Capability.MISSING_CLASS_VALUES);
1491   
1492    return result;
1493  }
1494
1495  /**
1496   * Main method for testing this class.
1497   *
1498   * @param argv the options
1499   */
1500  public static void main(String [] argv) {   
1501    runClassifier(new LADTree(), argv);
1502  }
1503
1504}
1505
Note: See TracBrowser for help on using the repository browser.