/* * This program is free software; you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by * the Free Software Foundation; either version 2 of the License, or * (at your option) any later version. * * This program is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU General Public License for more details. * * You should have received a copy of the GNU General Public License * along with this program; if not, write to the Free Software * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA. */ /* * DecisionTable.java * Copyright (C) 1999 University of Waikato, Hamilton, New Zealand * */ package weka.classifiers.rules; import weka.attributeSelection.ASSearch; import weka.attributeSelection.BestFirst; import weka.attributeSelection.SubsetEvaluator; import weka.attributeSelection.ASEvaluation; import weka.classifiers.Classifier; import weka.classifiers.AbstractClassifier; import weka.classifiers.Evaluation; import weka.classifiers.lazy.IBk; import weka.core.AdditionalMeasureProducer; import weka.core.Capabilities; import weka.core.Instance; import weka.core.Instances; import weka.core.Option; import weka.core.OptionHandler; import weka.core.RevisionUtils; import weka.core.SelectedTag; import weka.core.Tag; import weka.core.TechnicalInformation; import weka.core.TechnicalInformationHandler; import weka.core.Utils; import weka.core.WeightedInstancesHandler; import weka.core.Capabilities.Capability; import weka.core.TechnicalInformation.Field; import weka.core.TechnicalInformation.Type; import weka.filters.Filter; import weka.filters.unsupervised.attribute.Remove; import java.util.Arrays; import java.util.BitSet; import java.util.Enumeration; import java.util.Hashtable; import java.util.Random; import java.util.Vector; /** * Class for building and using a simple decision table 
majority classifier.<br/>
 * <br/>
 * For more information see:<br/>
 * <br/>
 * Ron Kohavi: The Power of Decision Tables. In: 8th European Conference on Machine Learning, 174-189, 1995.
 * <p/>
 *
 * BibTeX:
 * <pre>
 * &#64;inproceedings{Kohavi1995,
 *    author = {Ron Kohavi},
 *    booktitle = {8th European Conference on Machine Learning},
 *    pages = {174-189},
 *    publisher = {Springer},
 *    title = {The Power of Decision Tables},
 *    year = {1995}
 * }
 * </pre>
 * <p/>
 *
 * Valid options are: <p/>
 *
 * <pre> -S &lt;search method specification&gt;
 *  Full class name of search method, followed
 *  by its options.
 *  eg: "weka.attributeSelection.BestFirst -D 1"
 *  (default weka.attributeSelection.BestFirst)</pre>
 *
 * <pre> -X &lt;number of folds&gt;
 *  Use cross validation to evaluate features.
 *  Use number of folds = 1 for leave one out CV.
 *  (Default = leave one out CV)</pre>
 *
 * <pre> -E &lt;acc | rmse | mae | auc&gt;
 *  Performance evaluation measure to use for selecting attributes.
 *  (Default = accuracy for discrete class and rmse for numeric class)</pre>
 *
 * <pre> -I
 *  Use nearest neighbour instead of global table majority.</pre>
 *
 * <pre> -R
 *  Display decision table rules.
 * </pre>
 *
 
 * <pre>
 * Options specific to search method weka.attributeSelection.BestFirst:
 * </pre>
 *
 * <pre> -P &lt;start set&gt;
 *  Specify a starting set of attributes.
 *  Eg. 1,3,5-7.</pre>
 *
 * <pre> -D &lt;0 = backward | 1 = forward | 2 = bi-directional&gt;
 *  Direction of search. (default = 1).</pre>
 *
 * <pre> -N &lt;num&gt;
 *  Number of non-improving nodes to
 *  consider before terminating search.</pre>
 *
 * <pre> -S &lt;num&gt;
 *  Size of lookup cache for evaluated subsets.
 *  Expressed as a multiple of the number of
 *  attributes in the data set. (default = 1)</pre>
*
 *
 * @author Mark Hall (mhall@cs.waikato.ac.nz)
 * @version $Revision: 5987 $
 */
public class DecisionTable
  extends AbstractClassifier
  implements OptionHandler, WeightedInstancesHandler,
             AdditionalMeasureProducer, TechnicalInformationHandler {

  /** for serialization */
  static final long serialVersionUID = 2888557078165701326L;

  /**
   * The hashtable used to hold training instances. Keys are
   * DecisionTableHashKey objects; values are double arrays holding, for a
   * nominal class, one weight per class value and, for a numeric class,
   * the pair {weighted sum of class values, sum of weights}.
   */
  protected Hashtable m_entries;

  /** The class priors to use when there is no match in the table */
  protected double [] m_classPriorCounts;

  /**
   * Companion array to m_classPriorCounts.
   * NOTE(review): presumably the normalized form of the counts - confirm
   * against the code that fills it.
   */
  protected double [] m_classPriors;

  /** Holds the final feature set */
  protected int [] m_decisionFeatures;

  /** Discretization filter */
  protected Filter m_disTransform;

  /** Filter used to remove columns discarded by feature selection */
  protected Remove m_delTransform;

  /** IB1 used to classify non matching instances rather than majority class */
  protected IBk m_ibk;

  /** Holds the original training instances */
  protected Instances m_theInstances;

  /** Holds the final feature selected set of instances */
  protected Instances m_dtInstances;

  /** The number of attributes in the dataset */
  protected int m_numAttributes;

  /** The number of instances in the dataset */
  private int m_numInstances;

  /** Class is nominal */
  protected boolean m_classIsNominal;

  /** Use the IBk classifier rather than majority class */
  protected boolean m_useIBk;

  /** Display Rules */
  protected boolean m_displayRules;

  /** Number of folds for cross validating feature sets (1 = leave one out) */
  private int m_CVFolds;

  /** Random numbers for use in cross validation */
  private Random m_rr;

  /** Holds the majority class */
  protected double m_majority;

  /** The search method to use (BestFirst unless replaced via setSearch) */
  protected ASSearch m_search = new BestFirst();

  /** Our own internal evaluator */
  protected ASEvaluation m_evaluator;

  /** The evaluation object used to evaluate subsets */
  protected Evaluation m_evaluation;

  /** default is accuracy for discrete class and RMSE for numeric class */
  public static final int EVAL_DEFAULT = 1;
  public static final
int EVAL_ACCURACY = 2; // evaluate subsets by accuracy (discrete class only)

  /** evaluate subsets by RMSE (of the class probabilities for discrete class) */
  public static final int EVAL_RMSE = 3;

  /** evaluate subsets by MAE (of the class probabilities for discrete class) */
  public static final int EVAL_MAE = 4;

  /** evaluate subsets by area under the ROC curve (discrete class only) */
  public static final int EVAL_AUC = 5;

  /**
   * Tags pairing each EVAL_* identifier with its human-readable label.
   * setEvaluationMeasure only accepts SelectedTag objects built on this
   * exact array.
   */
  public static final Tag [] TAGS_EVALUATION = {
    new Tag(EVAL_DEFAULT, "Default: accuracy (discrete class); RMSE (numeric class)"),
    // the label previously lacked its closing parenthesis
    new Tag(EVAL_ACCURACY, "Accuracy (discrete class only)"),
    new Tag(EVAL_RMSE, "RMSE (of the class probabilities for discrete class)"),
    new Tag(EVAL_MAE, "MAE (of the class probabilities for discrete class)"),
    new Tag(EVAL_AUC, "AUC (area under the ROC curve - discrete class only)")
  };

  /** The currently selected evaluation measure (one of the EVAL_* values) */
  protected int m_evaluationMeasure = EVAL_DEFAULT;

  /**
   * Returns a string describing classifier
   *
   * @return a description suitable for
   * displaying in the explorer/experimenter gui; it ends with the string
   * form of getTechnicalInformation()
   */
  public String globalInfo() {
    return
      "Class for building and using a simple decision table majority "
      + "classifier.\n\n"
      + "For more information see: \n\n"
      + getTechnicalInformation().toString();
  }

  /**
   * Returns an instance of a TechnicalInformation object, containing
   * detailed information about the technical background of this class,
   * e.g., paper reference or book this class is based on.
* * @return the technical information about this class */ public TechnicalInformation getTechnicalInformation() { TechnicalInformation result; result = new TechnicalInformation(Type.INPROCEEDINGS); result.setValue(Field.AUTHOR, "Ron Kohavi"); result.setValue(Field.TITLE, "The Power of Decision Tables"); result.setValue(Field.BOOKTITLE, "8th European Conference on Machine Learning"); result.setValue(Field.YEAR, "1995"); result.setValue(Field.PAGES, "174-189"); result.setValue(Field.PUBLISHER, "Springer"); return result; } /** * Inserts an instance into the hash table * * @param inst instance to be inserted * @param instA to create the hash key from * @throws Exception if the instance can't be inserted */ private void insertIntoTable(Instance inst, double [] instA) throws Exception { double [] tempClassDist2; double [] newDist; DecisionTableHashKey thekey; if (instA != null) { thekey = new DecisionTableHashKey(instA); } else { thekey = new DecisionTableHashKey(inst, inst.numAttributes(), false); } // see if this one is already in the table tempClassDist2 = (double []) m_entries.get(thekey); if (tempClassDist2 == null) { if (m_classIsNominal) { newDist = new double [m_theInstances.classAttribute().numValues()]; //Leplace estimation for (int i = 0; i < m_theInstances.classAttribute().numValues(); i++) { newDist[i] = 1.0; } newDist[(int)inst.classValue()] = inst.weight(); // add to the table m_entries.put(thekey, newDist); } else { newDist = new double [2]; newDist[0] = inst.classValue() * inst.weight(); newDist[1] = inst.weight(); // add to the table m_entries.put(thekey, newDist); } } else { // update the distribution for this instance if (m_classIsNominal) { tempClassDist2[(int)inst.classValue()]+=inst.weight(); // update the table m_entries.put(thekey, tempClassDist2); } else { tempClassDist2[0] += (inst.classValue() * inst.weight()); tempClassDist2[1] += inst.weight(); // update the table m_entries.put(thekey, tempClassDist2); } } } /** * Classifies an instance for 
internal leave one out cross validation * of feature sets * * @param instance instance to be "left out" and classified * @param instA feature values of the selected features for the instance * @return the classification of the instance * @throws Exception if something goes wrong */ double evaluateInstanceLeaveOneOut(Instance instance, double [] instA) throws Exception { DecisionTableHashKey thekey; double [] tempDist; double [] normDist; thekey = new DecisionTableHashKey(instA); if (m_classIsNominal) { // if this one is not in the table if ((tempDist = (double [])m_entries.get(thekey)) == null) { throw new Error("This should never happen!"); } else { normDist = new double [tempDist.length]; System.arraycopy(tempDist,0,normDist,0,tempDist.length); normDist[(int)instance.classValue()] -= instance.weight(); // update the table // first check to see if the class counts are all zero now boolean ok = false; for (int i=0;i")); newVector.addElement(new Option( "\tUse cross validation to evaluate features.\n" + "\tUse number of folds = 1 for leave one out CV.\n" + "\t(Default = leave one out CV)", "X", 1, "-X ")); newVector.addElement(new Option( "\tPerformance evaluation measure to use for selecting attributes.\n" + "\t(Default = accuracy for discrete class and rmse for numeric class)", "E", 1, "-E ")); newVector.addElement(new Option( "\tUse nearest neighbour instead of global table majority.", "I", 0, "-I")); newVector.addElement(new Option( "\tDisplay decision table rules.\n", "R", 0, "-R")); newVector.addElement(new Option( "", "", 0, "\nOptions specific to search method " + m_search.getClass().getName() + ":")); Enumeration enu = ((OptionHandler)m_search).listOptions(); while (enu.hasMoreElements()) { newVector.addElement(enu.nextElement()); } return newVector.elements(); } /** * Returns the tip text for this property * @return tip text for this property suitable for * displaying in the explorer/experimenter gui */ public String crossValTipText() { return "Sets the 
number of folds for cross validation (1 = leave one out)."; } /** * Sets the number of folds for cross validation (1 = leave one out) * * @param folds the number of folds */ public void setCrossVal(int folds) { m_CVFolds = folds; } /** * Gets the number of folds for cross validation * * @return the number of cross validation folds */ public int getCrossVal() { return m_CVFolds; } /** * Returns the tip text for this property * @return tip text for this property suitable for * displaying in the explorer/experimenter gui */ public String useIBkTipText() { return "Sets whether IBk should be used instead of the majority class."; } /** * Sets whether IBk should be used instead of the majority class * * @param ibk true if IBk is to be used */ public void setUseIBk(boolean ibk) { m_useIBk = ibk; } /** * Gets whether IBk is being used instead of the majority class * * @return true if IBk is being used */ public boolean getUseIBk() { return m_useIBk; } /** * Returns the tip text for this property * @return tip text for this property suitable for * displaying in the explorer/experimenter gui */ public String displayRulesTipText() { return "Sets whether rules are to be printed."; } /** * Sets whether rules are to be printed * * @param rules true if rules are to be printed */ public void setDisplayRules(boolean rules) { m_displayRules = rules; } /** * Gets whether rules are being printed * * @return true if rules are being printed */ public boolean getDisplayRules() { return m_displayRules; } /** * Returns the tip text for this property * @return tip text for this property suitable for * displaying in the explorer/experimenter gui */ public String searchTipText() { return "The search method used to find good attribute combinations for the " + "decision table."; } /** * Sets the search method to use * * @param search */ public void setSearch(ASSearch search) { m_search = search; } /** * Gets the current search method * * @return the search method used */ public ASSearch 
getSearch() {
    return this.m_search;
  }

  /**
   * Tip text for the evaluationMeasure property.
   *
   * @return tip text for this property suitable for
   * displaying in the explorer/experimenter gui
   */
  public String evaluationMeasureTipText() {
    return "The measure used to evaluate the performance of attribute combinations "
      + "used in the decision table.";
  }

  /**
   * Wraps the currently configured evaluation measure in a SelectedTag
   * built on TAGS_EVALUATION.
   *
   * @return the performance evaluation measure
   */
  public SelectedTag getEvaluationMeasure() {
    SelectedTag current = new SelectedTag(m_evaluationMeasure, TAGS_EVALUATION);
    return current;
  }

  /**
   * Sets the performance evaluation measure to use for selecting attributes
   * for the decision table. Tags that were not built on TAGS_EVALUATION
   * are silently ignored.
   *
   * @param newMethod the new performance evaluation metric to use
   */
  public void setEvaluationMeasure(SelectedTag newMethod) {
    if (newMethod.getTags() != TAGS_EVALUATION) {
      // foreign tag set: leave the current measure untouched
      return;
    }
    m_evaluationMeasure = newMethod.getSelectedTag().getID();
  }

  /**
   * Parses the options for this object.

   * <p/>
   *
   * Valid options are: <p/>
   *
   * <pre> -S &lt;search method specification&gt;
   *  Full class name of search method, followed
   *  by its options.
   *  eg: "weka.attributeSelection.BestFirst -D 1"
   *  (default weka.attributeSelection.BestFirst)</pre>
   *
   * <pre> -X &lt;number of folds&gt;
   *  Use cross validation to evaluate features.
   *  Use number of folds = 1 for leave one out CV.
   *  (Default = leave one out CV)</pre>
   *
   * <pre> -E &lt;acc | rmse | mae | auc&gt;
   *  Performance evaluation measure to use for selecting attributes.
   *  (Default = accuracy for discrete class and rmse for numeric class)</pre>
   *
   * <pre> -I
   *  Use nearest neighbour instead of global table majority.</pre>
   *
   * <pre> -R
   *  Display decision table rules.
   * </pre>
   *
 
   * <pre>
   * Options specific to search method weka.attributeSelection.BestFirst:
   * </pre>
   *
   * <pre> -P &lt;start set&gt;
   *  Specify a starting set of attributes.
   *  Eg. 1,3,5-7.</pre>
   *
   * <pre> -D &lt;0 = backward | 1 = forward | 2 = bi-directional&gt;
   *  Direction of search. (default = 1).</pre>
   *
   * <pre> -N &lt;num&gt;
   *  Number of non-improving nodes to
   *  consider before terminating search.</pre>
   *
   * <pre> -S &lt;num&gt;
   *  Size of lookup cache for evaluated subsets.
   *  Expressed as a multiple of the number of
   *  attributes in the data set. (default = 1)</pre>
* * * @param options the list of options as an array of strings * @throws Exception if an option is not supported */ public void setOptions(String[] options) throws Exception { String optionString; resetOptions(); optionString = Utils.getOption('X',options); if (optionString.length() != 0) { m_CVFolds = Integer.parseInt(optionString); } m_useIBk = Utils.getFlag('I',options); m_displayRules = Utils.getFlag('R',options); optionString = Utils.getOption('E', options); if (optionString.length() != 0) { if (optionString.equals("acc")) { setEvaluationMeasure(new SelectedTag(EVAL_ACCURACY, TAGS_EVALUATION)); } else if (optionString.equals("rmse")) { setEvaluationMeasure(new SelectedTag(EVAL_RMSE, TAGS_EVALUATION)); } else if (optionString.equals("mae")) { setEvaluationMeasure(new SelectedTag(EVAL_MAE, TAGS_EVALUATION)); } else if (optionString.equals("auc")) { setEvaluationMeasure(new SelectedTag(EVAL_AUC, TAGS_EVALUATION)); } else { throw new IllegalArgumentException("Invalid evaluation measure"); } } String searchString = Utils.getOption('S', options); if (searchString.length() == 0) searchString = weka.attributeSelection.BestFirst.class.getName(); String [] searchSpec = Utils.splitOptions(searchString); if (searchSpec.length == 0) { throw new IllegalArgumentException("Invalid search specification string"); } String searchName = searchSpec[0]; searchSpec[0] = ""; setSearch(ASSearch.forName(searchName, searchSpec)); } /** * Gets the current settings of the classifier. 
*
   * @return an array of strings suitable for passing to setOptions; the
   * array always has nine slots and unused trailing slots hold empty strings
   */
  public String [] getOptions() {

    String [] result = new String [9];
    int next = 0;

    result[next++] = "-X";
    result[next++] = "" + m_CVFolds;

    // -E is only emitted when a non-default measure has been chosen
    if (m_evaluationMeasure != EVAL_DEFAULT) {
      result[next++] = "-E";
      if (m_evaluationMeasure == EVAL_ACCURACY) {
        result[next++] = "acc";
      } else if (m_evaluationMeasure == EVAL_RMSE) {
        result[next++] = "rmse";
      } else if (m_evaluationMeasure == EVAL_MAE) {
        result[next++] = "mae";
      } else if (m_evaluationMeasure == EVAL_AUC) {
        result[next++] = "auc";
      }
    }

    if (m_useIBk) {
      result[next++] = "-I";
    }
    if (m_displayRules) {
      result[next++] = "-R";
    }

    result[next++] = "-S";
    result[next++] = "" + getSearchSpec();

    // pad whatever is left with empty strings
    Arrays.fill(result, next, result.length, "");

    return result;
  }

  /**
   * Builds the search specification string: the class name of the search
   * method followed, when the method is an OptionHandler, by a space and
   * its joined options.
   *
   * @return the search string.
   */
  protected String getSearchSpec() {

    ASSearch search = getSearch();
    String className = search.getClass().getName();

    if (!(search instanceof OptionHandler)) {
      return className;
    }
    String [] searchOptions = ((OptionHandler) search).getOptions();
    return className + " " + Utils.joinOptions(searchOptions);
  }

  /**
   * Returns default capabilities of the classifier.
*
   * @return the capabilities of this classifier
   */
  public Capabilities getCapabilities() {
    Capabilities caps = super.getCapabilities();
    caps.disableAll();

    // attributes
    caps.enable(Capability.NOMINAL_ATTRIBUTES);
    caps.enable(Capability.NUMERIC_ATTRIBUTES);
    caps.enable(Capability.DATE_ATTRIBUTES);
    caps.enable(Capability.MISSING_VALUES);

    // class: accuracy and AUC are restricted to a nominal class
    caps.enable(Capability.NOMINAL_CLASS);
    boolean nominalClassOnly =
      (m_evaluationMeasure == EVAL_ACCURACY)
      || (m_evaluationMeasure == EVAL_AUC);
    if (!nominalClassOnly) {
      caps.enable(Capability.NUMERIC_CLASS);
      caps.enable(Capability.DATE_CLASS);
    }

    caps.enable(Capability.MISSING_CLASS_VALUES);

    return caps;
  }

  /**
   * Subset evaluator that hands every subset to the enclosing
   * DecisionTable's estimatePerformance method.
   */
  private class DummySubsetEvaluator
    extends ASEvaluation
    implements SubsetEvaluator {

    /** for serialization */
    private static final long serialVersionUID = 3927442457704974150L;

    /**
     * Does nothing.
     *
     * @param data ignored
     * @throws Exception never thrown by this implementation
     */
    public void buildEvaluator(Instances data) throws Exception {
    }

    /**
     * Counts the selected attributes with an index below m_numAttributes
     * and passes the subset together with that count to
     * estimatePerformance.
     *
     * @param subset the attribute subset to evaluate
     * @return whatever estimatePerformance returns for the subset
     * @throws Exception if estimatePerformance fails
     */
    public double evaluateSubset(BitSet subset) throws Exception {

      int selected = 0;
      for (int bit = subset.nextSetBit(0);
           bit >= 0 && bit < m_numAttributes;
           bit = subset.nextSetBit(bit + 1)) {
        selected++;
      }

      return estimatePerformance(subset, selected);
    }
  }

  /**
   * Sets up a dummy subset evaluator that basically just delegates
   * evaluation to the estimatePerformance method in DecisionTable
   *
   * @throws Exception declared for overriding implementations; not thrown
   * here
   */
  protected void setUpEvaluator() throws Exception {
    m_evaluator = new DummySubsetEvaluator();
  }

  /**
   * Memory-saving flag, true by default.
   * NOTE(review): where it is read is not visible in this part of the
   * file - confirm its exact effect.
   */
  protected boolean m_saveMemory = true;

  /**
   * Generates the classifier.
   *
   * @param data set of instances serving as training data
   * @throws Exception if the classifier has not been generated successfully
   */
  public void buildClassifier(Instances data) throws Exception {

    // can classifier handle the data?
getCapabilities().testWithFail(data); // remove instances with missing class m_theInstances = new Instances(data); m_theInstances.deleteWithMissingClass(); m_rr = new Random(1); if (m_theInstances.classAttribute().isNominal()) {// Set up class priors m_classPriorCounts = new double [data.classAttribute().numValues()]; Arrays.fill(m_classPriorCounts, 1.0); for (int i = 0; i 1) { text.append("("+m_CVFolds+" fold) "); } else { text.append("(leave one out) "); } text.append("\nFeature set: "+printFeatures()); if (m_displayRules) { // find out the max column width int maxColWidth = 0; for (int i=0;i maxColWidth) { maxColWidth = m_dtInstances.attribute(i).name().length(); } if (m_classIsNominal || (i != m_dtInstances.classIndex())) { Enumeration e = m_dtInstances.attribute(i).enumerateValues(); while (e.hasMoreElements()) { String ss = (String)e.nextElement(); if (ss.length() > maxColWidth) { maxColWidth = ss.length(); } } } } text.append("\n\nRules:\n"); StringBuffer tm = new StringBuffer(); for (int i=0;i