/*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation; either version 2 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program; if not, write to the Free Software
* Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
*/
/*
* ADTree.java
* Copyright (C) 2001 University of Waikato, Hamilton, New Zealand
*
*/
package weka.classifiers.trees;
import weka.classifiers.Classifier;
import weka.classifiers.AbstractClassifier;
import weka.classifiers.IterativeClassifier;
import weka.classifiers.trees.adtree.PredictionNode;
import weka.classifiers.trees.adtree.ReferenceInstances;
import weka.classifiers.trees.adtree.Splitter;
import weka.classifiers.trees.adtree.TwoWayNominalSplit;
import weka.classifiers.trees.adtree.TwoWayNumericSplit;
import weka.core.AdditionalMeasureProducer;
import weka.core.Attribute;
import weka.core.Capabilities;
import weka.core.Drawable;
import weka.core.FastVector;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.OptionHandler;
import weka.core.RevisionUtils;
import weka.core.SelectedTag;
import weka.core.SerializedObject;
import weka.core.Tag;
import weka.core.TechnicalInformation;
import weka.core.TechnicalInformationHandler;
import weka.core.Utils;
import weka.core.WeightedInstancesHandler;
import weka.core.Capabilities.Capability;
import weka.core.TechnicalInformation.Field;
import weka.core.TechnicalInformation.Type;
import java.util.Enumeration;
import java.util.Random;
import java.util.Vector;
/**
* Class for generating an alternating decision tree. The basic algorithm is based on:
*
* Freund, Y., Mason, L.: The alternating decision tree learning algorithm. In: Proceeding of the Sixteenth International Conference on Machine Learning, Bled, Slovenia, 124-133, 1999.
*
* This version currently only supports two-class problems. The number of boosting iterations needs to be manually tuned to suit the dataset and the desired complexity/accuracy tradeoff. Induction of the trees has been optimized, and heuristic search methods have been introduced to speed learning.
*
* @inproceedings{Freund1999, * address = {Bled, Slovenia}, * author = {Freund, Y. and Mason, L.}, * booktitle = {Proceeding of the Sixteenth International Conference on Machine Learning}, * pages = {124-133}, * title = {The alternating decision tree learning algorithm}, * year = {1999} * } ** * * Valid options are: * *
-B <number of boosting iterations> * Number of boosting iterations. * (Default = 10)* *
-E <-3|-2|-1|>=0> * Expand nodes: -3(all), -2(weight), -1(z_pure), >=0 seed for random walk * (Default = -3)* *
-D * Save the instance data with the model* * * @author Richard Kirkby (rkirkby@cs.waikato.ac.nz) * @author Bernhard Pfahringer (bernhard@cs.waikato.ac.nz) * @version $Revision: 5928 $ */ public class ADTree extends AbstractClassifier implements OptionHandler, Drawable, AdditionalMeasureProducer, WeightedInstancesHandler, IterativeClassifier, TechnicalInformationHandler { /** for serialization */ static final long serialVersionUID = -1532264837167690683L; /** * Returns a string describing classifier * @return a description suitable for * displaying in the explorer/experimenter gui */ public String globalInfo() { return "Class for generating an alternating decision tree. The basic " + "algorithm is based on:\n\n" + getTechnicalInformation().toString() + "\n\n" + "This version currently only supports two-class problems. The number of boosting " + "iterations needs to be manually tuned to suit the dataset and the desired " + "complexity/accuracy tradeoff. Induction of the trees has been optimized, and heuristic " + "search methods have been introduced to speed learning."; } /** search mode: Expand all paths */ public static final int SEARCHPATH_ALL = 0; /** search mode: Expand the heaviest path */ public static final int SEARCHPATH_HEAVIEST = 1; /** search mode: Expand the best z-pure path */ public static final int SEARCHPATH_ZPURE = 2; /** search mode: Expand a random path */ public static final int SEARCHPATH_RANDOM = 3; /** The search modes */ public static final Tag [] TAGS_SEARCHPATH = { new Tag(SEARCHPATH_ALL, "Expand all paths"), new Tag(SEARCHPATH_HEAVIEST, "Expand the heaviest path"), new Tag(SEARCHPATH_ZPURE, "Expand the best z-pure path"), new Tag(SEARCHPATH_RANDOM, "Expand a random path") }; /** The instances used to train the tree */ protected Instances m_trainInstances; /** The root of the tree */ protected PredictionNode m_root = null; /** The random number generator - used for the random search heuristic */ protected Random m_random = null; /** The 
number of the last splitter added to the tree */ protected int m_lastAddedSplitNum = 0; /** An array containing the inidices to the numeric attributes in the data */ protected int[] m_numericAttIndices; /** An array containing the inidices to the nominal attributes in the data */ protected int[] m_nominalAttIndices; /** The total weight of the instances - used to speed Z calculations */ protected double m_trainTotalWeight; /** The training instances with positive class - referencing the training dataset */ protected ReferenceInstances m_posTrainInstances; /** The training instances with negative class - referencing the training dataset */ protected ReferenceInstances m_negTrainInstances; /** The best node to insert under, as found so far by the latest search */ protected PredictionNode m_search_bestInsertionNode; /** The best splitter to insert, as found so far by the latest search */ protected Splitter m_search_bestSplitter; /** The smallest Z value found so far by the latest search */ protected double m_search_smallestZ; /** The positive instances that apply to the best path found so far */ protected Instances m_search_bestPathPosInstances; /** The negative instances that apply to the best path found so far */ protected Instances m_search_bestPathNegInstances; /** Statistics - the number of prediction nodes investigated during search */ protected int m_nodesExpanded = 0; /** Statistics - the number of instances processed during search */ protected int m_examplesCounted = 0; /** Option - the number of boosting iterations o perform */ protected int m_boostingIterations = 10; /** Option - the search mode */ protected int m_searchPath = 0; /** Option - the seed to use for a random search */ protected int m_randomSeed = 0; /** Option - whether the tree should remember the instance data */ protected boolean m_saveInstanceData = false; /** * Returns an instance of a TechnicalInformation object, containing * detailed information about the technical background of this 
class, * e.g., paper reference or book this class is based on. * * @return the technical information about this class */ public TechnicalInformation getTechnicalInformation() { TechnicalInformation result; result = new TechnicalInformation(Type.INPROCEEDINGS); result.setValue(Field.AUTHOR, "Freund, Y. and Mason, L."); result.setValue(Field.YEAR, "1999"); result.setValue(Field.TITLE, "The alternating decision tree learning algorithm"); result.setValue(Field.BOOKTITLE, "Proceeding of the Sixteenth International Conference on Machine Learning"); result.setValue(Field.ADDRESS, "Bled, Slovenia"); result.setValue(Field.PAGES, "124-133"); return result; } /** * Sets up the tree ready to be trained, using two-class optimized method. * * @param instances the instances to train the tree with * @exception Exception if training data is unsuitable */ public void initClassifier(Instances instances) throws Exception { // clear stats m_nodesExpanded = 0; m_examplesCounted = 0; m_lastAddedSplitNum = 0; // prepare the random generator m_random = new Random(m_randomSeed); // create training set m_trainInstances = new Instances(instances); // create positive/negative subsets m_posTrainInstances = new ReferenceInstances(m_trainInstances, m_trainInstances.numInstances()); m_negTrainInstances = new ReferenceInstances(m_trainInstances, m_trainInstances.numInstances()); for (Enumeration e = m_trainInstances.enumerateInstances(); e.hasMoreElements(); ) { Instance inst = (Instance) e.nextElement(); if ((int) inst.classValue() == 0) m_negTrainInstances.addReference(inst); // belongs in negative class else m_posTrainInstances.addReference(inst); // belongs in positive class } m_posTrainInstances.compactify(); m_negTrainInstances.compactify(); // create the root prediction node double rootPredictionValue = calcPredictionValue(m_posTrainInstances, m_negTrainInstances); m_root = new PredictionNode(rootPredictionValue); // pre-adjust weights updateWeights(m_posTrainInstances, m_negTrainInstances, 
rootPredictionValue); // pre-calculate what we can generateAttributeIndicesSingle(); } /** * Performs one iteration. * * @param iteration the index of the current iteration (0-based) * @exception Exception if this iteration fails */ public void next(int iteration) throws Exception { boost(); } /** * Performs a single boosting iteration, using two-class optimized method. * Will add a new splitter node and two prediction nodes to the tree * (unless merging takes place). * * @exception Exception if try to boost without setting up tree first or there are no * instances to train with */ public void boost() throws Exception { if (m_trainInstances == null || m_trainInstances.numInstances() == 0) throw new Exception("Trying to boost with no training data"); // perform the search searchForBestTestSingle(); if (m_search_bestSplitter == null) return; // handle empty instances // create the new nodes for the tree, updating the weights for (int i=0; i<2; i++) { Instances posInstances = m_search_bestSplitter.instancesDownBranch(i, m_search_bestPathPosInstances); Instances negInstances = m_search_bestSplitter.instancesDownBranch(i, m_search_bestPathNegInstances); double predictionValue = calcPredictionValue(posInstances, negInstances); PredictionNode newPredictor = new PredictionNode(predictionValue); updateWeights(posInstances, negInstances, predictionValue); m_search_bestSplitter.setChildForBranch(i, newPredictor); } // insert the new nodes m_search_bestInsertionNode.addChild((Splitter) m_search_bestSplitter, this); // free memory m_search_bestPathPosInstances = null; m_search_bestPathNegInstances = null; m_search_bestSplitter = null; } /** * Generates the m_nominalAttIndices and m_numericAttIndices arrays to index * the respective attribute types in the training data. * */ private void generateAttributeIndicesSingle() { // insert indices into vectors FastVector nominalIndices = new FastVector(); FastVector numericIndices = new FastVector(); for (int i=0; i
*
* -B num
* Set the number of boosting iterations
* (default 10)
*
* -E num
* Set the nodes to expand: -3(all), -2(weight), -1(z_pure), >=0 seed for random walk
* (default -3)
*
* -D
* Save the instance data with the model
*
* @param options the list of options as an array of strings
* @exception Exception if an option is not supported
*/
public void setOptions(String[] options) throws Exception {
  // -B <n>: number of boosting iterations; an empty result is treated as
  // "option not supplied" and leaves the current setting alone.
  final String boostingArg = Utils.getOption('B', options);
  if (boostingArg.length() > 0) {
    setNumOfBoostingIterations(Integer.parseInt(boostingArg));
  }

  // -E <n>: a non-negative value selects the random-walk search and doubles
  // as its seed; -3, -2 and -1 are shifted by 3 onto the SEARCHPATH_ALL,
  // SEARCHPATH_HEAVIEST and SEARCHPATH_ZPURE tag ids (0, 1 and 2).
  final String expandArg = Utils.getOption('E', options);
  if (expandArg.length() > 0) {
    final int expandCode = Integer.parseInt(expandArg);
    if (expandCode < 0) {
      setSearchPath(new SelectedTag(expandCode + 3, TAGS_SEARCHPATH));
    } else {
      setSearchPath(new SelectedTag(SEARCHPATH_RANDOM, TAGS_SEARCHPATH));
      setRandomSeed(expandCode);
    }
  }

  // -D: keep the instance data with the model
  final boolean keepInstanceData = Utils.getFlag('D', options);
  setSaveInstanceData(keepInstanceData);

  Utils.checkForRemainingOptions(options);
}
/**
 * Reports the current configuration of ADTree as command-line options.
 *
 * The returned array always has six slots: "-B" followed by its value,
 * "-E" followed by its value, "-D" when instance data is being saved, and
 * empty strings in whatever slots remain.
 *
 * @return an array of strings suitable for passing to setOptions()
 */
public String[] getOptions() {
  final String[] result = new String[6];
  int next = 0;

  result[next++] = "-B";
  result[next++] = "" + getNumOfBoostingIterations();

  // -E carries the seed for the random search mode; every other mode is
  // written as its tag id shifted down by 3 (giving -3, -2 or -1).
  final int expandCode;
  if (m_searchPath == SEARCHPATH_RANDOM) {
    expandCode = m_randomSeed;
  } else {
    expandCode = m_searchPath - 3;
  }
  result[next++] = "-E";
  result[next++] = Integer.toString(expandCode);

  if (getSaveInstanceData()) {
    result[next++] = "-D";
  }

  // pad the unused tail with empty strings
  for (; next < result.length; next++) {
    result[next] = "";
  }
  return result;
}
/**
 * Additional measure: the size of the tree, obtained by passing the root
 * to numOfAllNodes (splitter plus prediction nodes).
 *
 * @return the tree size
 */
public double measureTreeSize() {
  final int totalNodes = numOfAllNodes(m_root);
  return totalNodes;
}
/**
 * Additional measure: the number of prediction nodes, obtained by passing
 * the root to numOfPredictionNodes.
 *
 * @return the number of prediction nodes
 */
public double measureNumLeaves() {
  return (double) numOfPredictionNodes(m_root);
}
/**
 * Additional measure: the number of prediction nodes without children,
 * obtained by passing the root to numOfPredictionLeafNodes.
 *
 * @return the number of childless prediction nodes
 */
public double measureNumPredictionLeaves() {
  return (double) numOfPredictionLeafNodes(m_root);
}
/**
 * Additional measure: the current value of the nodes-expanded counter.
 *
 * @return the number of nodes expanded during search
 */
public double measureNodesExpanded() {
  return (double) m_nodesExpanded;
}
/**
 * Additional measure: the current value of the examples-counted counter.
 *
 * @return the number of examples "counted" during search
 */
public double measureExamplesProcessed() {
  return (double) m_examplesCounted;
}
/**
 * Lists the names of the additional measures this classifier offers; each
 * name is one that getMeasure() accepts.
 *
 * @return an enumeration of the measure names
 */
public Enumeration enumerateMeasures() {
  final String[] measureNames = {
    "measureTreeSize",
    "measureNumLeaves",
    "measureNumPredictionLeaves",
    "measureNodesExpanded",
    "measureExamplesProcessed"
  };
  Vector names = new Vector(measureNames.length);
  for (int i = 0; i < measureNames.length; i++) {
    names.addElement(measureNames[i]);
  }
  return names.elements();
}
/**
 * Looks up an additional measure by name and returns its current value.
 * Names are matched case-insensitively.
 *
 * @param additionalMeasureName the name of the measure to query for its value
 * @return the value of the named measure
 * @exception IllegalArgumentException if the named measure is not supported
 */
public double getMeasure(String additionalMeasureName) {
  if (additionalMeasureName.equalsIgnoreCase("measureTreeSize")) {
    return measureTreeSize();
  }
  if (additionalMeasureName.equalsIgnoreCase("measureNumLeaves")) {
    return measureNumLeaves();
  }
  if (additionalMeasureName.equalsIgnoreCase("measureNumPredictionLeaves")) {
    return measureNumPredictionLeaves();
  }
  if (additionalMeasureName.equalsIgnoreCase("measureNodesExpanded")) {
    return measureNodesExpanded();
  }
  if (additionalMeasureName.equalsIgnoreCase("measureExamplesProcessed")) {
    return measureExamplesProcessed();
  }
  // no known measure matched the requested name
  throw new IllegalArgumentException(additionalMeasureName
      + " not supported (ADTree)");
}
/**
* Returns the total number of nodes in a tree.
*
* @param root the root of the tree being measured
* @return tree size in number of splitter + prediction nodes
*/
protected int numOfAllNodes(PredictionNode root) {
int numSoFar = 0;
if (root != null) {
numSoFar++;
for (Enumeration e = root.children(); e.hasMoreElements(); ) {
numSoFar++;
Splitter split = (Splitter) e.nextElement();
for (int i=0; i