/*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation; either version 2 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program; if not, write to the Free Software
* Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
*/
/*
* SimpleCart.java
* Copyright (C) 2007 Haijian Shi
*
*/
package weka.classifiers.trees;
import weka.classifiers.Evaluation;
import weka.classifiers.RandomizableClassifier;
import weka.core.AdditionalMeasureProducer;
import weka.core.Attribute;
import weka.core.Capabilities;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.RevisionUtils;
import weka.core.TechnicalInformation;
import weka.core.TechnicalInformationHandler;
import weka.core.Utils;
import weka.core.Capabilities.Capability;
import weka.core.TechnicalInformation.Field;
import weka.core.TechnicalInformation.Type;
import weka.core.matrix.Matrix;
import java.util.Arrays;
import java.util.Enumeration;
import java.util.Random;
import java.util.Vector;
/**
 * Class implementing minimal cost-complexity pruning.<br/>
 * Note: when dealing with missing values, the "fractional instances" method is
 * used instead of the surrogate split method.<br/>
 * <br/>
 * For more information, see:<br/>
 * <br/>
 * Leo Breiman, Jerome H. Friedman, Richard A. Olshen, Charles J. Stone (1984).
 * Classification and Regression Trees. Wadsworth International Group,
 * Belmont, California.
 * <p/>
 * BibTeX:
 * <pre>
 * &#64;book{Breiman1984,
 *    address = {Belmont, California},
 *    author = {Leo Breiman and Jerome H. Friedman and Richard A. Olshen and Charles J. Stone},
 *    publisher = {Wadsworth International Group},
 *    title = {Classification and Regression Trees},
 *    year = {1984}
 * }
 * </pre>
 * <p/>
 * Valid options are: <p/>
 *
 * <pre> -S &lt;num&gt;
 *  Random number seed.
 *  (default 1)</pre>
 *
 * <pre> -D
 *  If set, classifier is run in debug mode and
 *  may output additional info to the console</pre>
 *
 * <pre> -M &lt;min no&gt;
 *  The minimal number of instances at the terminal nodes.
 *  (default 2)</pre>
 *
 * <pre> -N &lt;num folds&gt;
 *  The number of folds used in the minimal cost-complexity pruning.
 *  (default 5)</pre>
 *
 * <pre> -U
 *  Don't use the minimal cost-complexity pruning.
 *  (default yes).</pre>
 *
 * <pre> -H
 *  Don't use the heuristic method for binary split.
 *  (default true).</pre>
 *
 * <pre> -A
 *  Use the 1 SE rule to make the pruning decision.
 *  (default no).</pre>
 *
 * <pre> -C
 *  Percentage of training data size (0-1].
 *  (default 1).</pre>
 *
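 * A minimal usage sketch; "data" stands for a hypothetical weka.core.Instances
 * object with its nominal class attribute already set, and the option values
 * are illustrative:
 * <pre>
 * SimpleCart cart = new SimpleCart();
 * cart.setNumFoldsPruning(10);  // internal CV folds for the pruning
 * cart.setUseOneSE(true);       // prefer the smallest tree within 1 SE
 * cart.buildClassifier(data);   // grow and prune the tree
 * System.out.println(cart);     // print the resulting model
 * </pre>
 *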
 * @author Haijian Shi (hs69@cs.waikato.ac.nz)
 * @version $Revision: 5987 $
 */
public class SimpleCart
  extends RandomizableClassifier
  implements AdditionalMeasureProducer, TechnicalInformationHandler {

  /** For serialization. */
  private static final long serialVersionUID = 4154189200352566053L;

  /** Training data. */
  protected Instances m_train;

  /** Successor nodes. */
  protected SimpleCart[] m_Successors;

  /** Attribute used to split data. */
  protected Attribute m_Attribute;

  /** Split point for a numeric attribute. */
  protected double m_SplitValue;

  /** Split subset used to split data for nominal attributes. */
  protected String m_SplitString;

  /** Class value if the node is a leaf. */
  protected double m_ClassValue;

  /** Class attribute of the data. */
  protected Attribute m_ClassAttribute;

  /** Minimum number of instances at the terminal nodes. */
  protected double m_minNumObj = 2;

  /** Number of folds for minimal cost-complexity pruning. */
  protected int m_numFoldsPruning = 5;

  /** Alpha-value (for pruning) at the node. */
  protected double m_Alpha;

  /** Number of training examples misclassified by the model (subtree rooted). */
  protected double m_numIncorrectModel;

  /** Number of training examples misclassified by the model (subtree not rooted). */
  protected double m_numIncorrectTree;

  /** Indicates whether the node is a leaf node. */
  protected boolean m_isLeaf;

  /** Whether to use minimal cost-complexity pruning. */
  protected boolean m_Prune = true;

  /** Total number of instances used to build the classifier. */
  protected int m_totalTrainInstances;

  /** Proportion for each branch. */
  protected double[] m_Props;

  /** Class probabilities. */
  protected double[] m_ClassProbs = null;

  /** Distribution of a leaf node (or temporary leaf node in minimal cost-complexity pruning). */
  protected double[] m_Distribution;

  /** Whether to use heuristic search for nominal attributes in multi-class problems (default true). */
  protected boolean m_Heuristic = true;

  /** Whether to use the 1 SE rule to select the final decision tree. */
  protected boolean m_UseOneSE = false;

  /** Training data size. */
  protected double m_SizePer = 1;

  /**
   * Returns a description suitable for displaying in the explorer/experimenter.
   *
   * @return a description suitable for displaying in the
   *         explorer/experimenter
   */
  public String globalInfo() {
    return "Class implementing minimal cost-complexity pruning.\n"
      + "Note: when dealing with missing values, the \"fractional "
      + "instances\" method is used instead of the surrogate split method.\n\n"
      + "For more information, see:\n\n"
      + getTechnicalInformation().toString();
  }

  /**
   * Returns an instance of a TechnicalInformation object, containing
   * detailed information about the technical background of this class,
   * e.g., the paper or book this class is based on.
   *
   * @return the technical information about this class
   */
  public TechnicalInformation getTechnicalInformation() {
    TechnicalInformation result;

    result = new TechnicalInformation(Type.BOOK);
    result.setValue(Field.AUTHOR, "Leo Breiman and Jerome H. Friedman and Richard A. Olshen and Charles J. Stone");
    result.setValue(Field.YEAR, "1984");
    result.setValue(Field.TITLE, "Classification and Regression Trees");
    result.setValue(Field.PUBLISHER, "Wadsworth International Group");
    result.setValue(Field.ADDRESS, "Belmont, California");

    return result;
  }
  /**
   * Returns default capabilities of the classifier.
   *
   * @return the capabilities of this classifier
   */
  public Capabilities getCapabilities() {
    Capabilities result = super.getCapabilities();
    result.disableAll();

    // attributes
    result.enable(Capability.NOMINAL_ATTRIBUTES);
    result.enable(Capability.NUMERIC_ATTRIBUTES);
    result.enable(Capability.MISSING_VALUES);

    // class
    result.enable(Capability.NOMINAL_CLASS);

    return result;
  }

  /**
   * Build the classifier.
   *
   * @param data the training instances
   * @throws Exception if something goes wrong
   */
  public void buildClassifier(Instances data) throws Exception {

    getCapabilities().testWithFail(data);
    data = new Instances(data);
    data.deleteWithMissingClass();

    // unpruned CART decision tree
    if (!m_Prune) {

      // calculate sorted indices and weights, and compute initial class counts.
      int[][] sortedIndices = new int[data.numAttributes()][0];
      double[][] weights = new double[data.numAttributes()][0];
      double[] classProbs = new double[data.numClasses()];
      double totalWeight = computeSortedInfo(data, sortedIndices, weights, classProbs);

      makeTree(data, data.numInstances(), sortedIndices, weights, classProbs,
          totalWeight, m_minNumObj, m_Heuristic);
      return;
    }

    Random random = new Random(m_Seed);
    Instances cvData = new Instances(data);
    cvData.randomize(random);
    cvData = new Instances(cvData, 0, (int)(cvData.numInstances() * m_SizePer) - 1);
    cvData.stratify(m_numFoldsPruning);

    double[][] alphas = new double[m_numFoldsPruning][];
    double[][] errors = new double[m_numFoldsPruning][];

    // calculate errors and alphas for each fold
    for (int i = 0; i < m_numFoldsPruning; i++) {

      // for every fold, grow a tree on the training set and record the error on the test set.
      Instances train = cvData.trainCV(m_numFoldsPruning, i);
      Instances test = cvData.testCV(m_numFoldsPruning, i);

      // calculate sorted indices and weights, and compute initial class counts for each fold
      int[][] sortedIndices = new int[train.numAttributes()][0];
      double[][] weights = new double[train.numAttributes()][0];
      double[] classProbs = new double[train.numClasses()];
      double totalWeight = computeSortedInfo(train, sortedIndices, weights, classProbs);

      makeTree(train, train.numInstances(), sortedIndices, weights, classProbs,
          totalWeight, m_minNumObj, m_Heuristic);

      int numNodes = numInnerNodes();
      alphas[i] = new double[numNodes + 2];
      errors[i] = new double[numNodes + 2];

      // prune back and log alpha-values and errors on the test set
      prune(alphas[i], errors[i], test);
    }

    // calculate sorted indices and weights, and compute initial class counts on all training instances
    int[][] sortedIndices = new int[data.numAttributes()][0];
    double[][] weights = new double[data.numAttributes()][0];
    double[] classProbs = new double[data.numClasses()];
    double totalWeight = computeSortedInfo(data, sortedIndices, weights, classProbs);

    // build the tree using all the data
    makeTree(data, data.numInstances(), sortedIndices, weights, classProbs,
        totalWeight, m_minNumObj, m_Heuristic);
    int numNodes = numInnerNodes();

    double[] treeAlphas = new double[numNodes + 2];

    // prune back and log alpha-values
    int iterations = prune(treeAlphas, null, null);

    double[] treeErrors = new double[numNodes + 2];

    // for each pruned subtree, find the cross-validated error
    for (int i = 0; i <= iterations; i++) {
      // compute midpoint alphas
      double alpha = Math.sqrt(treeAlphas[i] * treeAlphas[i + 1]);
      double error = 0;
      for (int k = 0; k < m_numFoldsPruning; k++) {
        int l = 0;
        while (alphas[k][l] <= alpha) l++;
        error += errors[k][l - 1];
      }
      treeErrors[i] = error / m_numFoldsPruning;
    }
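    // Pruning-parameter selection: treeErrors[i] is the cross-validated error
    // of the i-th pruned subtree, indexed by the geometric-mean ("midpoint")
    // alpha of successive pruning steps. Below, the subtree with the smallest
    // cross-validated error is chosen; with the 1 SE rule enabled, the
    // smallest tree within one standard error of that minimum is preferred.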
    // find best alpha
    int best = -1;
    double bestError = Double.MAX_VALUE;
    for (int i = iterations; i >= 0; i--) {
      if (treeErrors[i] < bestError) {
        bestError = treeErrors[i];
        best = i;
      }
    }

    // 1 SE rule to choose expansion
    if (m_UseOneSE) {
      double oneSE = Math.sqrt(bestError * (1 - bestError) / (data.numInstances()));
      for (int i = iterations; i >= 0; i--) {
        if (treeErrors[i] <= bestError + oneSE) {
          best = i;
          break;
        }
      }
    }

    double bestAlpha = Math.sqrt(treeAlphas[best] * treeAlphas[best + 1]);

    // "unprune" the final tree (faster than regrowing it)
    unprune();
    prune(bestAlpha);
  }

  /**
   * Make binary decision tree recursively.
   *
   * @param data the training instances
   * @param totalInstances total number of instances
   * @param sortedIndices sorted indices of the instances
   * @param weights weights of the instances
   * @param classProbs class probabilities
   * @param totalWeight total weight of instances
   * @param minNumObj minimal number of instances at leaf nodes
   * @param useHeuristic whether to use heuristic search for nominal attributes in multi-class problems
   * @throws Exception if something goes wrong
   */
  protected void makeTree(Instances data, int totalInstances,
      int[][] sortedIndices, double[][] weights, double[] classProbs,
      double totalWeight, double minNumObj, boolean useHeuristic) throws Exception {

    // if no instances have reached this node (normally won't happen)
    if (totalWeight == 0) {
      m_Attribute = null;
      m_ClassValue = Utils.missingValue();
      m_Distribution = new double[data.numClasses()];
      return;
    }

    m_totalTrainInstances = totalInstances;
    m_isLeaf = true;

    m_ClassProbs = new double[classProbs.length];
    m_Distribution = new double[classProbs.length];
    System.arraycopy(classProbs, 0, m_ClassProbs, 0, classProbs.length);
    System.arraycopy(classProbs, 0, m_Distribution, 0, classProbs.length);
    if (Utils.sum(m_ClassProbs) != 0) Utils.normalize(m_ClassProbs);

    // Compute class distributions and value of splitting
    // criterion for each attribute
    double[][][] dists = new double[data.numAttributes()][0][0];
    double[][] props = new double[data.numAttributes()][0];
    double[][] totalSubsetWeights = new double[data.numAttributes()][2];
    double[] splits = new double[data.numAttributes()];
    String[] splitString = new String[data.numAttributes()];
    double[] giniGains = new double[data.numAttributes()];

    // for each attribute find split information
    for (int i = 0; i < data.numAttributes(); i++) {
      Attribute att = data.attribute(i);
      if (i == data.classIndex()) continue;
      if (att.isNumeric()) {
        // numeric attribute
        splits[i] = numericDistribution(props, dists, att, sortedIndices[i],
            weights[i], totalSubsetWeights, giniGains, data);
      } else {
        // nominal attribute
        splitString[i] = nominalDistribution(props, dists, att, sortedIndices[i],
            weights[i], totalSubsetWeights, giniGains, data, useHeuristic);
      }
    }

    // Find best attribute (split with maximum Gini gain)
    int attIndex = Utils.maxIndex(giniGains);
    m_Attribute = data.attribute(attIndex);

    m_train = new Instances(data, sortedIndices[attIndex].length);
    for (int i=0; i
  /**
   * Parses a given list of options. Valid options are:
   *
   * <pre> -S &lt;num&gt;
   *  Random number seed.
   *  (default 1)</pre>
   *
   * <pre> -D
   *  If set, classifier is run in debug mode and
   *  may output additional info to the console</pre>
   *
   * <pre> -M &lt;min no&gt;
   *  The minimal number of instances at the terminal nodes.
   *  (default 2)</pre>
   *
   * <pre> -N &lt;num folds&gt;
   *  The number of folds used in the minimal cost-complexity pruning.
   *  (default 5)</pre>
   *
   * <pre> -U
   *  Don't use the minimal cost-complexity pruning.
   *  (default yes).</pre>
   *
   * <pre> -H
   *  Don't use the heuristic method for binary split.
   *  (default true).</pre>
   *
   * <pre> -A
   *  Use the 1 SE rule to make the pruning decision.
   *  (default no).</pre>
   *
   * <pre> -C
   *  Percentage of training data size (0-1].
   *  (default 1).</pre>
   *
   * @param options the list of options as an array of strings
   * @throws Exception if an option is not supported
   */
  public void setOptions(String[] options) throws Exception {
    String tmpStr;

    super.setOptions(options);

    tmpStr = Utils.getOption('M', options);
    if (tmpStr.length() != 0)
      setMinNumObj(Double.parseDouble(tmpStr));
    else
      setMinNumObj(2);

    tmpStr = Utils.getOption('N', options);
    if (tmpStr.length() != 0)
      setNumFoldsPruning(Integer.parseInt(tmpStr));
    else
      setNumFoldsPruning(5);

    setUsePrune(!Utils.getFlag('U', options));
    setHeuristic(!Utils.getFlag('H', options));
    setUseOneSE(Utils.getFlag('A', options));

    tmpStr = Utils.getOption('C', options);
    if (tmpStr.length() != 0)
      setSizePer(Double.parseDouble(tmpStr));
    else
      setSizePer(1);

    Utils.checkForRemainingOptions(options);
  }

  /**
   * Gets the current settings of the classifier.
   *
   * @return the current settings of the classifier
   */
  public String[] getOptions() {
    int       i;
    Vector    result;
    String[]  options;

    result = new Vector();

    options = super.getOptions();
    for (i = 0; i < options.length; i++)
      result.add(options[i]);

    result.add("-M");
    result.add("" + getMinNumObj());

    result.add("-N");
    result.add("" + getNumFoldsPruning());

    if (!getUsePrune())
      result.add("-U");

    if (!getHeuristic())
      result.add("-H");

    if (getUseOneSE())
      result.add("-A");

    result.add("-C");
    result.add("" + getSizePer());

    return (String[]) result.toArray(new String[result.size()]);
  }

  /**
   * Returns an enumeration of the measure names.
   *
   * @return an enumeration of the measure names
   */
  public Enumeration enumerateMeasures() {
    Vector result = new Vector();
    result.addElement("measureTreeSize");
    return result.elements();
  }

  /**
   * Returns the tree size.
   *
   * @return the tree size
   */
  public double measureTreeSize() {
    return numNodes();
  }

  /**
   * Returns the value of the named measure.
   *
   * @param additionalMeasureName the name of the measure to query for its value
   * @return the value of the named measure
   * @throws IllegalArgumentException if the named measure is not supported
   */
  public double getMeasure(String additionalMeasureName) {
    if (additionalMeasureName.compareToIgnoreCase("measureTreeSize") == 0) {
      return measureTreeSize();
    } else {
      throw new IllegalArgumentException(additionalMeasureName
          + " not supported (Cart pruning)");
    }
  }

  /**
   * Returns the tip text for this property
   *
   * @return tip text for this property suitable for
   *         displaying in the explorer/experimenter gui
   */
  public String minNumObjTipText() {
    return "The minimal number of observations at the terminal nodes (default 2).";
  }

  /**
   * Sets the minimal number of instances at the terminal nodes.
   *
   * @param value minimal number of instances at the terminal nodes
   */
  public void setMinNumObj(double value) {
    m_minNumObj = value;
  }

  /**
   * Gets the minimal number of instances at the terminal nodes.
   *
   * @return minimal number of instances at the terminal nodes
   */
  public double getMinNumObj() {
    return m_minNumObj;
  }

  /**
   * Returns the tip text for this property
   *
   * @return tip text for this property suitable for
   *         displaying in the explorer/experimenter gui
   */
  public String numFoldsPruningTipText() {
    return "The number of folds in the internal cross-validation (default 5).";
  }

  /**
   * Sets the number of folds in the internal cross-validation.
   *
   * @param value number of folds in the internal cross-validation
   */
  public void setNumFoldsPruning(int value) {
    m_numFoldsPruning = value;
  }

  /**
   * Gets the number of folds in the internal cross-validation.
   *
   * @return number of folds in the internal cross-validation
   */
  public int getNumFoldsPruning() {
    return m_numFoldsPruning;
  }

  /**
   * Returns the tip text for this property
   *
   * @return tip text for this property suitable for displaying in
   *         the explorer/experimenter gui.
   */
  public String usePruneTipText() {
    return "Use minimal cost-complexity pruning (default yes).";
  }

  /**
   * Sets whether to use minimal cost-complexity pruning.
   *
   * @param value whether to use minimal cost-complexity pruning
   */
  public void setUsePrune(boolean value) {
    m_Prune = value;
  }

  /**
   * Gets whether minimal cost-complexity pruning is used.
   *
   * @return whether minimal cost-complexity pruning is used
   */
  public boolean getUsePrune() {
    return m_Prune;
  }

  /**
   * Returns the tip text for this property
   *
   * @return tip text for this property suitable for
   *         displaying in the explorer/experimenter gui.
   */
  public String heuristicTipText() {
    return "If heuristic search is used for binary split for nominal attributes "
      + "in multi-class problems (default yes).";
  }

  /**
   * Sets whether to use heuristic search for nominal attributes in multi-class problems.
   *
   * @param value whether to use heuristic search for nominal attributes in
   *              multi-class problems
   */
  public void setHeuristic(boolean value) {
    m_Heuristic = value;
  }

  /**
   * Gets whether heuristic search is used for nominal attributes in multi-class problems.
   *
   * @return whether heuristic search is used for nominal attributes in
   *         multi-class problems
   */
  public boolean getHeuristic() {
    return m_Heuristic;
  }

  /**
   * Returns the tip text for this property
   *
   * @return tip text for this property suitable for
   *         displaying in the explorer/experimenter gui.
   */
  public String useOneSETipText() {
    return "Use the 1 SE rule to make the pruning decision.";
  }

  /**
   * Sets whether to use the 1 SE rule to choose the final model.
   *
   * @param value whether to use the 1 SE rule to choose the final model
   */
  public void setUseOneSE(boolean value) {
    m_UseOneSE = value;
  }

  /**
   * Gets whether the 1 SE rule is used to choose the final model.
   *
   * @return whether the 1 SE rule is used to choose the final model
   */
  public boolean getUseOneSE() {
    return m_UseOneSE;
  }

  /**
   * Returns the tip text for this property
   *
   * @return tip text for this property suitable for
   *         displaying in the explorer/experimenter gui.
   */
  public String sizePerTipText() {
    return "The percentage of the training set size (0-1, 0 not included).";
  }

  /**
   * Sets the training set size.
   *
   * @param value training set size
   */
  public void setSizePer(double value) {
    if ((value <= 0) || (value > 1))
      System.err.println(
          "The percentage of the training set size must be in the range (0, 1] "
          + "- ignored!");
    else
      m_SizePer = value;
  }

  /**
   * Gets the training set size.
   *
   * @return training set size
   */
  public double getSizePer() {
    return m_SizePer;
  }

  /**
   * Returns the revision string.
   *
   * @return the revision
   */
  public String getRevision() {
    return RevisionUtils.extract("$Revision: 5987 $");
  }

  /**
   * Main method.
   *
   * @param args the options for the classifier
   */
  public static void main(String[] args) {
    runClassifier(new SimpleCart(), args);
  }
}
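// A command-line sketch: main() delegates to the standard Weka runClassifier
// machinery, which handles the usual evaluation options (e.g. -t for the
// training file) in addition to the scheme options listed above. The ARFF
// path below is illustrative:
//
//   java weka.classifiers.trees.SimpleCart -t iris.arff -N 10 -A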