source: src/main/java/weka/classifiers/rules/DecisionTable.java @ 17

Last change on this file since 17 was 4, checked in by gnappo, 14 years ago

Import of weka.

File size: 39.3 KB
/*
 *    This program is free software; you can redistribute it and/or modify
 *    it under the terms of the GNU General Public License as published by
 *    the Free Software Foundation; either version 2 of the License, or
 *    (at your option) any later version.
 *
 *    This program is distributed in the hope that it will be useful,
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *    GNU General Public License for more details.
 *
 *    You should have received a copy of the GNU General Public License
 *    along with this program; if not, write to the Free Software
 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 */

/*
 *    DecisionTable.java
 *    Copyright (C) 1999 University of Waikato, Hamilton, New Zealand
 *
 */

package weka.classifiers.rules;

import weka.attributeSelection.ASSearch;
import weka.attributeSelection.BestFirst;
import weka.attributeSelection.SubsetEvaluator;
import weka.attributeSelection.ASEvaluation;
import weka.classifiers.Classifier;
import weka.classifiers.AbstractClassifier;
import weka.classifiers.Evaluation;
import weka.classifiers.lazy.IBk;
import weka.core.AdditionalMeasureProducer;
import weka.core.Capabilities;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.OptionHandler;
import weka.core.RevisionUtils;
import weka.core.SelectedTag;
import weka.core.Tag;
import weka.core.TechnicalInformation;
import weka.core.TechnicalInformationHandler;
import weka.core.Utils;
import weka.core.WeightedInstancesHandler;
import weka.core.Capabilities.Capability;
import weka.core.TechnicalInformation.Field;
import weka.core.TechnicalInformation.Type;
import weka.filters.Filter;
import weka.filters.unsupervised.attribute.Remove;

import java.util.Arrays;
import java.util.BitSet;
import java.util.Enumeration;
import java.util.Hashtable;
import java.util.Random;
import java.util.Vector;

/**
 <!-- globalinfo-start -->
 * Class for building and using a simple decision table majority classifier.<br/>
 * <br/>
 * For more information see: <br/>
 * <br/>
 * Ron Kohavi: The Power of Decision Tables. In: 8th European Conference on Machine Learning, 174-189, 1995.
 * <p/>
 <!-- globalinfo-end -->
 *
 <!-- technical-bibtex-start -->
 * BibTeX:
 * <pre>
 * &#64;inproceedings{Kohavi1995,
 *    author = {Ron Kohavi},
 *    booktitle = {8th European Conference on Machine Learning},
 *    pages = {174-189},
 *    publisher = {Springer},
 *    title = {The Power of Decision Tables},
 *    year = {1995}
 * }
 * </pre>
 * <p/>
 <!-- technical-bibtex-end -->
 *
 <!-- options-start -->
 * Valid options are: <p/>
 *
 * <pre> -S &lt;search method specification&gt;
 *  Full class name of search method, followed
 *  by its options.
 *  eg: "weka.attributeSelection.BestFirst -D 1"
 *  (default weka.attributeSelection.BestFirst)</pre>
 *
 * <pre> -X &lt;number of folds&gt;
 *  Use cross validation to evaluate features.
 *  Use number of folds = 1 for leave one out CV.
 *  (Default = leave one out CV)</pre>
 *
 * <pre> -E &lt;acc | rmse | mae | auc&gt;
 *  Performance evaluation measure to use for selecting attributes.
 *  (Default = accuracy for discrete class and rmse for numeric class)</pre>
 *
 * <pre> -I
 *  Use nearest neighbour instead of global table majority.</pre>
 *
 * <pre> -R
 *  Display decision table rules.
 * </pre>
 *
 * <pre>
 * Options specific to search method weka.attributeSelection.BestFirst:
 * </pre>
 *
 * <pre> -P &lt;start set&gt;
 *  Specify a starting set of attributes.
 *  Eg. 1,3,5-7.</pre>
 *
 * <pre> -D &lt;0 = backward | 1 = forward | 2 = bi-directional&gt;
 *  Direction of search. (default = 1).</pre>
 *
 * <pre> -N &lt;num&gt;
 *  Number of non-improving nodes to
 *  consider before terminating search.</pre>
 *
 * <pre> -S &lt;num&gt;
 *  Size of lookup cache for evaluated subsets.
 *  Expressed as a multiple of the number of
 *  attributes in the data set. (default = 1)</pre>
 *
 <!-- options-end -->
 *
 * @author Mark Hall (mhall@cs.waikato.ac.nz)
 * @version $Revision: 5987 $
 */
public class DecisionTable 
  extends AbstractClassifier
  implements OptionHandler, WeightedInstancesHandler, 
             AdditionalMeasureProducer, TechnicalInformationHandler {

  /** for serialization */
  static final long serialVersionUID = 2888557078165701326L;

  /** The hashtable used to hold training instances */
  protected Hashtable m_entries;

  /** The class priors to use when there is no match in the table */
  protected double [] m_classPriorCounts;
  protected double [] m_classPriors;

  /** Holds the final feature set */
  protected int [] m_decisionFeatures;

  /** Discretization filter */
  protected Filter m_disTransform;

  /** Filter used to remove columns discarded by feature selection */
  protected Remove m_delTransform;

  /** IB1 used to classify non matching instances rather than majority class */
  protected IBk m_ibk;

  /** Holds the original training instances */
  protected Instances m_theInstances;

  /** Holds the final feature selected set of instances */
  protected Instances m_dtInstances;

  /** The number of attributes in the dataset */
  protected int m_numAttributes;

  /** The number of instances in the dataset */
  private int m_numInstances;

  /** Class is nominal */
  protected boolean m_classIsNominal;

  /** Use the IBk classifier rather than majority class */
  protected boolean m_useIBk;

  /** Display Rules */
  protected boolean m_displayRules;

  /** Number of folds for cross validating feature sets */
  private int m_CVFolds;

  /** Random numbers for use in cross validation */
  private Random m_rr;

  /** Holds the majority class */
  protected double m_majority;

  /** The search method to use */
  protected ASSearch m_search = new BestFirst();

  /** Our own internal evaluator */
  protected ASEvaluation m_evaluator;

  /** The evaluation object used to evaluate subsets */
  protected Evaluation m_evaluation;

  /** default is accuracy for discrete class and RMSE for numeric class */
  public static final int EVAL_DEFAULT = 1;
  public static final int EVAL_ACCURACY = 2;
  public static final int EVAL_RMSE = 3;
  public static final int EVAL_MAE = 4;
  public static final int EVAL_AUC = 5;

  public static final Tag [] TAGS_EVALUATION = {
    new Tag(EVAL_DEFAULT, "Default: accuracy (discrete class); RMSE (numeric class)"),
    new Tag(EVAL_ACCURACY, "Accuracy (discrete class only)"),
    new Tag(EVAL_RMSE, "RMSE (of the class probabilities for discrete class)"),
    new Tag(EVAL_MAE, "MAE (of the class probabilities for discrete class)"),
    new Tag(EVAL_AUC, "AUC (area under the ROC curve - discrete class only)")
  };

  protected int m_evaluationMeasure = EVAL_DEFAULT;

  /**
   * Returns a string describing the classifier
   * @return a description suitable for
   * displaying in the explorer/experimenter gui
   */
  public String globalInfo() {

    return 
    "Class for building and using a simple decision table majority "
    + "classifier.\n\n"
    + "For more information see: \n\n"
    + getTechnicalInformation().toString();
  }

  /**
   * Returns an instance of a TechnicalInformation object, containing
   * detailed information about the technical background of this class,
   * e.g., paper reference or book this class is based on.
   *
   * @return the technical information about this class
   */
  public TechnicalInformation getTechnicalInformation() {
    TechnicalInformation        result;

    result = new TechnicalInformation(Type.INPROCEEDINGS);
    result.setValue(Field.AUTHOR, "Ron Kohavi");
    result.setValue(Field.TITLE, "The Power of Decision Tables");
    result.setValue(Field.BOOKTITLE, "8th European Conference on Machine Learning");
    result.setValue(Field.YEAR, "1995");
    result.setValue(Field.PAGES, "174-189");
    result.setValue(Field.PUBLISHER, "Springer");

    return result;
  }

  /**
   * Inserts an instance into the hash table
   *
   * @param inst instance to be inserted
   * @param instA the array of attribute values used to create the hash key;
   * if null, the key is created from all of the instance's attributes
   * @throws Exception if the instance can't be inserted
   */
  private void insertIntoTable(Instance inst, double [] instA)
  throws Exception {

    double [] tempClassDist2;
    double [] newDist;
    DecisionTableHashKey thekey;

    if (instA != null) {
      thekey = new DecisionTableHashKey(instA);
    } else {
      thekey = new DecisionTableHashKey(inst, inst.numAttributes(), false);
    }

    // see if this one is already in the table
    tempClassDist2 = (double []) m_entries.get(thekey);
    if (tempClassDist2 == null) {
      if (m_classIsNominal) {
        newDist = new double [m_theInstances.classAttribute().numValues()];

        // Laplace estimation
        for (int i = 0; i < m_theInstances.classAttribute().numValues(); i++) {
          newDist[i] = 1.0;
        }

        newDist[(int)inst.classValue()] = inst.weight();

        // add to the table
        m_entries.put(thekey, newDist);
      } else {
        newDist = new double [2];
        newDist[0] = inst.classValue() * inst.weight();
        newDist[1] = inst.weight();

        // add to the table
        m_entries.put(thekey, newDist);
      }
    } else {

      // update the distribution for this instance
      if (m_classIsNominal) {
        tempClassDist2[(int)inst.classValue()]+=inst.weight();

        // update the table
        m_entries.put(thekey, tempClassDist2);
      } else {
        tempClassDist2[0] += (inst.classValue() * inst.weight());
        tempClassDist2[1] += inst.weight();

        // update the table
        m_entries.put(thekey, tempClassDist2);
      }
    }
  }
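
  /*
   * A minimal sketch (hypothetical values) of the entries the method above
   * maintains. For a nominal class the stored array holds Laplace-style
   * per-class counts; for a numeric class it holds the weighted sum of class
   * values and the sum of weights:
   *
   *   DecisionTableHashKey key = new DecisionTableHashKey(instA);
   *   double [] dist = (double []) m_entries.get(key);
   *   // nominal class (3 values): dist == {1.0, 4.5, 1.0}
   *   // numeric class: dist == {sum of weight*classValue, sum of weights},
   *   //                so the table's prediction is dist[0] / dist[1]
   */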

  /**
   * Classifies an instance for internal leave one out cross validation
   * of feature sets
   *
   * @param instance instance to be "left out" and classified
   * @param instA feature values of the selected features for the instance
   * @return the classification of the instance
   * @throws Exception if something goes wrong
   */
  double evaluateInstanceLeaveOneOut(Instance instance, double [] instA)
  throws Exception {

    DecisionTableHashKey thekey;
    double [] tempDist;
    double [] normDist;

    thekey = new DecisionTableHashKey(instA);
    if (m_classIsNominal) {

      // if this one is not in the table
      if ((tempDist = (double [])m_entries.get(thekey)) == null) {
        throw new Error("This should never happen!");
      } else {
        normDist = new double [tempDist.length];
        System.arraycopy(tempDist,0,normDist,0,tempDist.length);
        normDist[(int)instance.classValue()] -= instance.weight();

        // update the table
        // first check to see if the class counts are all zero now
        boolean ok = false;
        for (int i=0;i<normDist.length;i++) {
          if (Utils.gr(normDist[i],1.0)) {
            ok = true;
            break;
          }
        }

        // downdate the class prior counts
        m_classPriorCounts[(int)instance.classValue()] -= 
          instance.weight();
        double [] classPriors = m_classPriorCounts.clone();
        Utils.normalize(classPriors);
        if (!ok) { // majority class
          normDist = classPriors;
        }

        m_classPriorCounts[(int)instance.classValue()] += 
          instance.weight();

        //if (ok) {
        Utils.normalize(normDist);
        if (m_evaluationMeasure == EVAL_AUC) {
          m_evaluation.evaluateModelOnceAndRecordPrediction(normDist, instance);
        } else {
          m_evaluation.evaluateModelOnce(normDist, instance);
        }
        return Utils.maxIndex(normDist);
        /*} else {
          normDist = new double [normDist.length];
          normDist[(int)m_majority] = 1.0;
          if (m_evaluationMeasure == EVAL_AUC) {
            m_evaluation.evaluateModelOnceAndRecordPrediction(normDist, instance);
          } else {
            m_evaluation.evaluateModelOnce(normDist, instance);
          }
          return m_majority;
        } */
      }
      //      return Utils.maxIndex(tempDist);
    } else {

      // see if this one is already in the table
      if ((tempDist = (double[])m_entries.get(thekey)) != null) {
        normDist = new double [tempDist.length];
        System.arraycopy(tempDist,0,normDist,0,tempDist.length);
        normDist[0] -= (instance.classValue() * instance.weight());
        normDist[1] -= instance.weight();
        if (Utils.eq(normDist[1],0.0)) {
          double [] temp = new double[1];
          temp[0] = m_majority;
          m_evaluation.evaluateModelOnce(temp, instance);
          return m_majority;
        } else {
          double [] temp = new double[1];
          temp[0] = normDist[0] / normDist[1];
          m_evaluation.evaluateModelOnce(temp, instance);
          return temp[0];
        }
      } else {
        throw new Error("This should never happen!");
      }
    }

    // shouldn't get here
    // return 0.0;
  }
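
  /*
   * Sketch of the leave-one-out bookkeeping above (hypothetical two-class
   * entry, instance weight 1.0): the instance's own weight is removed from
   * its table entry before it is classified, and the prior counts are
   * restored afterwards:
   *
   *   tempDist == {3.0, 2.0}            // counts including the instance
   *   normDist[(int)instance.classValue()] -= instance.weight();  // {3.0, 1.0}
   *   // if no remaining count exceeds 1.0, the table has no evidence left
   *   // for this key, so the normalized class priors are used instead
   */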

  /**
   * Evaluates a test fold for internal cross validation of feature sets.
   * The per-instance results are accumulated in m_evaluation.
   *
   * @param fold set of instances to be "left out" and classified
   * @param fs currently selected feature set
   * @return always zero; the fold's results are recorded in m_evaluation
   * @throws Exception if something goes wrong
   */
  double evaluateFoldCV(Instances fold, int [] fs) throws Exception {

    int i;
    int ruleCount = 0;
    int numFold = fold.numInstances();
    int numCl = m_theInstances.classAttribute().numValues();
    double [][] class_distribs = new double [numFold][numCl];
    double [] instA = new double [fs.length];
    double [] normDist;
    DecisionTableHashKey thekey;
    double acc = 0.0;
    int classI = m_theInstances.classIndex();
    Instance inst;

    if (m_classIsNominal) {
      normDist = new double [numCl];
    } else {
      normDist = new double [2];
    }

    // first *remove* instances
    for (i=0;i<numFold;i++) {
      inst = fold.instance(i);
      for (int j=0;j<fs.length;j++) {
        if (fs[j] == classI) {
          instA[j] = Double.MAX_VALUE; // missing for the class
        } else if (inst.isMissing(fs[j])) {
          instA[j] = Double.MAX_VALUE;
        } else {
          instA[j] = inst.value(fs[j]);
        }
      }
      thekey = new DecisionTableHashKey(instA);
      if ((class_distribs[i] = (double [])m_entries.get(thekey)) == null) {
        throw new Error("This should never happen!");
      } else {
        if (m_classIsNominal) {
          class_distribs[i][(int)inst.classValue()] -= inst.weight();
        } else {
          class_distribs[i][0] -= (inst.classValue() * inst.weight());
          class_distribs[i][1] -= inst.weight();
        }
        ruleCount++;
      }
      m_classPriorCounts[(int)inst.classValue()] -= 
        inst.weight();
    }
    double [] classPriors = m_classPriorCounts.clone();
    Utils.normalize(classPriors);

    // now classify instances
    for (i=0;i<numFold;i++) {
      inst = fold.instance(i);
      System.arraycopy(class_distribs[i],0,normDist,0,normDist.length);
      if (m_classIsNominal) {
        boolean ok = false;
        for (int j=0;j<normDist.length;j++) {
          if (Utils.gr(normDist[j],1.0)) {
            ok = true;
            break;
          }
        }

        if (!ok) { // majority class
          normDist = classPriors.clone();
        }

        // if (ok) {
        Utils.normalize(normDist);
        if (m_evaluationMeasure == EVAL_AUC) {
          m_evaluation.evaluateModelOnceAndRecordPrediction(normDist, inst);
        } else {
          m_evaluation.evaluateModelOnce(normDist, inst);
        }
        /*      } else {
          normDist[(int)m_majority] = 1.0;
          if (m_evaluationMeasure == EVAL_AUC) {
            m_evaluation.evaluateModelOnceAndRecordPrediction(normDist, inst);
          } else {
            m_evaluation.evaluateModelOnce(normDist, inst);
          }
        } */
      } else {
        if (Utils.eq(normDist[1],0.0)) {
          double [] temp = new double[1];
          temp[0] = m_majority;
          m_evaluation.evaluateModelOnce(temp, inst);
        } else {
          double [] temp = new double[1];
          temp[0] = normDist[0] / normDist[1];
          m_evaluation.evaluateModelOnce(temp, inst);
        }
      }
    }

    // now re-insert instances
    for (i=0;i<numFold;i++) {
      inst = fold.instance(i);

      m_classPriorCounts[(int)inst.classValue()] += 
        inst.weight();

      if (m_classIsNominal) {
        class_distribs[i][(int)inst.classValue()] += inst.weight();
      } else {
        class_distribs[i][0] += (inst.classValue() * inst.weight());
        class_distribs[i][1] += inst.weight();
      }
    }
    return acc;
  }


  /**
   * Evaluates a feature subset by cross validation
   *
   * @param feature_set the subset to be evaluated
   * @param num_atts the number of attributes in the subset
   * @return the estimated performance of the subset (larger is better;
   * error-based measures are negated)
   * @throws Exception if subset can't be evaluated
   */
  protected double estimatePerformance(BitSet feature_set, int num_atts)
  throws Exception {

    m_evaluation = new Evaluation(m_theInstances);
    int i;
    int [] fs = new int [num_atts];

    double [] instA = new double [num_atts];
    int classI = m_theInstances.classIndex();

    int index = 0;
    for (i=0;i<m_numAttributes;i++) {
      if (feature_set.get(i)) {
        fs[index++] = i;
      }
    }

    // create new hash table
    m_entries = new Hashtable((int)(m_theInstances.numInstances() * 1.5));

    // insert instances into the hash table
    for (i=0;i<m_numInstances;i++) {

      Instance inst = m_theInstances.instance(i);
      for (int j=0;j<fs.length;j++) {
        if (fs[j] == classI) {
          instA[j] = Double.MAX_VALUE; // missing for the class
        } else if (inst.isMissing(fs[j])) {
          instA[j] = Double.MAX_VALUE;
        } else {
          instA[j] = inst.value(fs[j]);
        }
      }
      insertIntoTable(inst, instA);
    }


    if (m_CVFolds == 1) {

      // calculate leave one out error
      for (i=0;i<m_numInstances;i++) {
        Instance inst = m_theInstances.instance(i);
        for (int j=0;j<fs.length;j++) {
          if (fs[j] == classI) {
            instA[j] = Double.MAX_VALUE; // missing for the class
          } else if (inst.isMissing(fs[j])) {
            instA[j] = Double.MAX_VALUE;
          } else {
            instA[j] = inst.value(fs[j]);
          }
        }
        evaluateInstanceLeaveOneOut(inst, instA);
      }
    } else {
      m_theInstances.randomize(m_rr);
      m_theInstances.stratify(m_CVFolds);

      // calculate the m_CVFolds-fold cross validation error
      for (i=0;i<m_CVFolds;i++) {
        Instances insts = m_theInstances.testCV(m_CVFolds,i);
        evaluateFoldCV(insts, fs);
      }
    }

    switch (m_evaluationMeasure) {
    case EVAL_DEFAULT:
      if (m_classIsNominal) {
        return m_evaluation.pctCorrect();
      }
      return -m_evaluation.rootMeanSquaredError();
    case EVAL_ACCURACY:
      return m_evaluation.pctCorrect();
    case EVAL_RMSE:
      return -m_evaluation.rootMeanSquaredError();
    case EVAL_MAE:
      return -m_evaluation.meanAbsoluteError();
    case EVAL_AUC:
      double [] classPriors = m_evaluation.getClassPriors();
      Utils.normalize(classPriors);
      double weightedAUC = 0;
      for (i = 0; i < m_theInstances.classAttribute().numValues(); i++) {
        double tempAUC = m_evaluation.areaUnderROC(i);
        if (!Utils.isMissingValue(tempAUC)) {
          weightedAUC += (classPriors[i] * tempAUC);
        } else {
          System.err.println("Undefined AUC!!");
        }
      }
      return weightedAUC;
    }
    // shouldn't get here
    return 0.0;
  }

  /**
   * Returns a String representation of a feature subset
   *
   * @param sub BitSet representation of a subset
   * @return String containing subset
   */
  private String printSub(BitSet sub) {

    String s="";
    for (int jj=0;jj<m_numAttributes;jj++) {
      if (sub.get(jj)) {
        s += " "+(jj+1);
      }
    }
    return s;
  }

  /**
   * Resets the options.
   */
  protected void resetOptions()  {

    m_entries = null;
    m_decisionFeatures = null;
    m_useIBk = false;
    m_CVFolds = 1;
    m_displayRules = false;
    m_evaluationMeasure = EVAL_DEFAULT;
  }

  /**
   * Constructor for a DecisionTable
   */
  public DecisionTable() {

    resetOptions();
  }

  /**
   * Returns an enumeration describing the available options.
   *
   * @return an enumeration of all the available options.
   */
  public Enumeration listOptions() {

    Vector newVector = new Vector(7);

    newVector.addElement(new Option(
        "\tFull class name of search method, followed\n"
        + "\tby its options.\n"
        + "\teg: \"weka.attributeSelection.BestFirst -D 1\"\n"
        + "\t(default weka.attributeSelection.BestFirst)",
        "S", 1, "-S <search method specification>"));

    newVector.addElement(new Option(
        "\tUse cross validation to evaluate features.\n" +
        "\tUse number of folds = 1 for leave one out CV.\n" +
        "\t(Default = leave one out CV)",
        "X", 1, "-X <number of folds>"));

    newVector.addElement(new Option(
        "\tPerformance evaluation measure to use for selecting attributes.\n" +
        "\t(Default = accuracy for discrete class and rmse for numeric class)",
        "E", 1, "-E <acc | rmse | mae | auc>"));

    newVector.addElement(new Option(
        "\tUse nearest neighbour instead of global table majority.",
        "I", 0, "-I"));

    newVector.addElement(new Option(
        "\tDisplay decision table rules.\n",
        "R", 0, "-R"));

    newVector.addElement(new Option(
        "",
        "", 0, "\nOptions specific to search method "
        + m_search.getClass().getName() + ":"));
    Enumeration enu = ((OptionHandler)m_search).listOptions();
    while (enu.hasMoreElements()) {
      newVector.addElement(enu.nextElement());
    }
    return newVector.elements();
  }

  /**
   * Returns the tip text for this property
   * @return tip text for this property suitable for
   * displaying in the explorer/experimenter gui
   */
  public String crossValTipText() {
    return "Sets the number of folds for cross validation (1 = leave one out).";
  }

  /**
   * Sets the number of folds for cross validation (1 = leave one out)
   *
   * @param folds the number of folds
   */
  public void setCrossVal(int folds) {

    m_CVFolds = folds;
  }

  /**
   * Gets the number of folds for cross validation
   *
   * @return the number of cross validation folds
   */
  public int getCrossVal() {

    return m_CVFolds;
  }

  /**
   * Returns the tip text for this property
   * @return tip text for this property suitable for
   * displaying in the explorer/experimenter gui
   */
  public String useIBkTipText() {
    return "Sets whether IBk should be used instead of the majority class.";
  }

  /**
   * Sets whether IBk should be used instead of the majority class
   *
   * @param ibk true if IBk is to be used
   */
  public void setUseIBk(boolean ibk) {

    m_useIBk = ibk;
  }

  /**
   * Gets whether IBk is being used instead of the majority class
   *
   * @return true if IBk is being used
   */
  public boolean getUseIBk() {

    return m_useIBk;
  }

  /**
   * Returns the tip text for this property
   * @return tip text for this property suitable for
   * displaying in the explorer/experimenter gui
   */
  public String displayRulesTipText() {
    return "Sets whether rules are to be printed.";
  }

  /**
   * Sets whether rules are to be printed
   *
   * @param rules true if rules are to be printed
   */
  public void setDisplayRules(boolean rules) {

    m_displayRules = rules;
  }

  /**
   * Gets whether rules are being printed
   *
   * @return true if rules are being printed
   */
  public boolean getDisplayRules() {

    return m_displayRules;
  }

  /**
   * Returns the tip text for this property
   * @return tip text for this property suitable for
   * displaying in the explorer/experimenter gui
   */
  public String searchTipText() {
    return "The search method used to find good attribute combinations for the "
    + "decision table.";
  }

  /**
   * Sets the search method to use
   *
   * @param search the search method to use
   */
  public void setSearch(ASSearch search) {
    m_search = search;
  }

  /**
   * Gets the current search method
   *
   * @return the search method used
   */
  public ASSearch getSearch() {
    return m_search;
  }
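
  /*
   * A minimal sketch of configuring the search method programmatically,
   * using only the BestFirst options documented above (the -D and -N
   * values are illustrative):
   *
   *   BestFirst bf = new BestFirst();
   *   bf.setOptions(Utils.splitOptions("-D 2 -N 7")); // bi-directional search
   *   DecisionTable dt = new DecisionTable();
   *   dt.setSearch(bf);
   */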

  /**
   * Returns the tip text for this property
   * @return tip text for this property suitable for
   * displaying in the explorer/experimenter gui
   */
  public String evaluationMeasureTipText() {
    return "The measure used to evaluate the performance of attribute combinations "
    + "used in the decision table.";
  }

  /**
   * Gets the currently set performance evaluation measure used for selecting
   * attributes for the decision table
   *
   * @return the performance evaluation measure
   */
  public SelectedTag getEvaluationMeasure() {
    return new SelectedTag(m_evaluationMeasure, TAGS_EVALUATION);
  }

  /**
   * Sets the performance evaluation measure to use for selecting attributes
   * for the decision table
   *
   * @param newMethod the new performance evaluation metric to use
   */
  public void setEvaluationMeasure(SelectedTag newMethod) {
    if (newMethod.getTags() == TAGS_EVALUATION) {
      m_evaluationMeasure = newMethod.getSelectedTag().getID();
    }
  }
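
  /*
   * Example of selecting an evaluation measure via the tags defined in this
   * class, e.g. weighted AUC for a discrete class:
   *
   *   DecisionTable dt = new DecisionTable();
   *   dt.setEvaluationMeasure(new SelectedTag(DecisionTable.EVAL_AUC,
   *                                           DecisionTable.TAGS_EVALUATION));
   */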

  /**
   * Parses the options for this object. <p/>
   *
   <!-- options-start -->
   * Valid options are: <p/>
   *
   * <pre> -S &lt;search method specification&gt;
   *  Full class name of search method, followed
   *  by its options.
   *  eg: "weka.attributeSelection.BestFirst -D 1"
   *  (default weka.attributeSelection.BestFirst)</pre>
   *
   * <pre> -X &lt;number of folds&gt;
   *  Use cross validation to evaluate features.
   *  Use number of folds = 1 for leave one out CV.
   *  (Default = leave one out CV)</pre>
   *
   * <pre> -E &lt;acc | rmse | mae | auc&gt;
   *  Performance evaluation measure to use for selecting attributes.
   *  (Default = accuracy for discrete class and rmse for numeric class)</pre>
   *
   * <pre> -I
   *  Use nearest neighbour instead of global table majority.</pre>
   *
   * <pre> -R
   *  Display decision table rules.
   * </pre>
   *
   * <pre>
   * Options specific to search method weka.attributeSelection.BestFirst:
   * </pre>
   *
   * <pre> -P &lt;start set&gt;
   *  Specify a starting set of attributes.
   *  Eg. 1,3,5-7.</pre>
   *
   * <pre> -D &lt;0 = backward | 1 = forward | 2 = bi-directional&gt;
   *  Direction of search. (default = 1).</pre>
   *
   * <pre> -N &lt;num&gt;
   *  Number of non-improving nodes to
   *  consider before terminating search.</pre>
   *
   * <pre> -S &lt;num&gt;
   *  Size of lookup cache for evaluated subsets.
   *  Expressed as a multiple of the number of
   *  attributes in the data set. (default = 1)</pre>
   *
   <!-- options-end -->
   *
   * @param options the list of options as an array of strings
   * @throws Exception if an option is not supported
   */
  public void setOptions(String[] options) throws Exception {

    String optionString;

    resetOptions();

    optionString = Utils.getOption('X',options);
    if (optionString.length() != 0) {
      m_CVFolds = Integer.parseInt(optionString);
    }

    m_useIBk = Utils.getFlag('I',options);

    m_displayRules = Utils.getFlag('R',options);

    optionString = Utils.getOption('E', options);
    if (optionString.length() != 0) {
      if (optionString.equals("acc")) {
        setEvaluationMeasure(new SelectedTag(EVAL_ACCURACY, TAGS_EVALUATION));
      } else if (optionString.equals("rmse")) {
        setEvaluationMeasure(new SelectedTag(EVAL_RMSE, TAGS_EVALUATION));
      } else if (optionString.equals("mae")) {
        setEvaluationMeasure(new SelectedTag(EVAL_MAE, TAGS_EVALUATION));
      } else if (optionString.equals("auc")) {
        setEvaluationMeasure(new SelectedTag(EVAL_AUC, TAGS_EVALUATION));
      } else {
        throw new IllegalArgumentException("Invalid evaluation measure");
      }
    }

    String searchString = Utils.getOption('S', options);
    if (searchString.length() == 0)
      searchString = weka.attributeSelection.BestFirst.class.getName();
    String [] searchSpec = Utils.splitOptions(searchString);
    if (searchSpec.length == 0) {
      throw new IllegalArgumentException("Invalid search specification string");
    }
    String searchName = searchSpec[0];
    searchSpec[0] = "";
    setSearch(ASSearch.forName(searchName, searchSpec));
  }
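
  /*
   * Example option string accepted by setOptions (the values shown are
   * illustrative):
   *
   *   String [] opts = Utils.splitOptions(
   *       "-X 5 -E auc -I -S \"weka.attributeSelection.BestFirst -D 2 -N 7\"");
   *   DecisionTable dt = new DecisionTable();
   *   dt.setOptions(opts);
   */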

  /**
   * Gets the current settings of the classifier.
   *
   * @return an array of strings suitable for passing to setOptions
   */
  public String [] getOptions() {

    String [] options = new String [9];
    int current = 0;

    options[current++] = "-X"; options[current++] = "" + m_CVFolds;

    if (m_evaluationMeasure != EVAL_DEFAULT) {
      options[current++] = "-E";
      switch (m_evaluationMeasure) {
      case EVAL_ACCURACY:
        options[current++] = "acc";
        break;
      case EVAL_RMSE:
        options[current++] = "rmse";
        break;
      case EVAL_MAE:
        options[current++] = "mae";
        break;
      case EVAL_AUC:
        options[current++] = "auc";
        break;
      }
    }
    if (m_useIBk) {
      options[current++] = "-I";
    }
    if (m_displayRules) {
      options[current++] = "-R";
    }

    options[current++] = "-S";
    options[current++] = "" + getSearchSpec();

    while (current < options.length) {
      options[current++] = "";
    }
    return options;
  }

  /**
   * Gets the search specification string, which contains the class name of
   * the search method and any options to it
   *
   * @return the search string.
   */
  protected String getSearchSpec() {

    ASSearch s = getSearch();
    if (s instanceof OptionHandler) {
      return s.getClass().getName() + " "
      + Utils.joinOptions(((OptionHandler)s).getOptions());
    }
    return s.getClass().getName();
  }

  /**
   * Returns default capabilities of the classifier.
   *
   * @return      the capabilities of this classifier
   */
  public Capabilities getCapabilities() {
    Capabilities result = super.getCapabilities();
    result.disableAll();

    // attributes
    result.enable(Capability.NOMINAL_ATTRIBUTES);
    result.enable(Capability.NUMERIC_ATTRIBUTES);
    result.enable(Capability.DATE_ATTRIBUTES);
    result.enable(Capability.MISSING_VALUES);

    // class
    result.enable(Capability.NOMINAL_CLASS);
    if (m_evaluationMeasure != EVAL_ACCURACY && m_evaluationMeasure != EVAL_AUC) {
      result.enable(Capability.NUMERIC_CLASS);
      result.enable(Capability.DATE_CLASS);
    }

    result.enable(Capability.MISSING_CLASS_VALUES);

    return result;
  }

  private class DummySubsetEvaluator extends ASEvaluation implements SubsetEvaluator {
    /** for serialization */
    private static final long serialVersionUID = 3927442457704974150L;

    public void buildEvaluator(Instances data) throws Exception {
    }

    public double evaluateSubset(BitSet subset) throws Exception {

      int fc = 0;
      for (int jj = 0;jj < m_numAttributes; jj++) {
        if (subset.get(jj)) {
          fc++;
        }
      }

      return estimatePerformance(subset, fc);
    }
  }

  /**
   * Sets up a dummy subset evaluator that basically just delegates
   * evaluation to the estimatePerformance method in DecisionTable
   */
  protected void setUpEvaluator() throws Exception {
    m_evaluator = new DummySubsetEvaluator();
  }

  protected boolean m_saveMemory = true;

  /**
   * Generates the classifier.
   *
   * @param data set of instances serving as training data
   * @throws Exception if the classifier has not been generated successfully
   */
  public void buildClassifier(Instances data) throws Exception {

    // can classifier handle the data?
    getCapabilities().testWithFail(data);

    // remove instances with missing class
    m_theInstances = new Instances(data);
    m_theInstances.deleteWithMissingClass();

    m_rr = new Random(1);

    if (m_theInstances.classAttribute().isNominal()) { // set up class priors
      m_classPriorCounts = 
        new double [data.classAttribute().numValues()];
      Arrays.fill(m_classPriorCounts, 1.0);
      for (int i = 0; i < data.numInstances(); i++) {
        Instance curr = data.instance(i);
        m_classPriorCounts[(int)curr.classValue()] += 
          curr.weight();
      }
      m_classPriors = m_classPriorCounts.clone();
      Utils.normalize(m_classPriors);
    }

    setUpEvaluator();

    if (m_theInstances.classAttribute().isNumeric()) {
      m_disTransform = new weka.filters.unsupervised.attribute.Discretize();
      m_classIsNominal = false;

      // use binned discretisation if the class is numeric
      ((weka.filters.unsupervised.attribute.Discretize)m_disTransform).
      setBins(10);
      ((weka.filters.unsupervised.attribute.Discretize)m_disTransform).
      setInvertSelection(true);

      // Discretize all attributes EXCEPT the class
      String rangeList = "";
      rangeList += (m_theInstances.classIndex()+1);
      //System.out.println("The class col: "+m_theInstances.classIndex());

      ((weka.filters.unsupervised.attribute.Discretize)m_disTransform).
      setAttributeIndices(rangeList);
    } else {
      m_disTransform = new weka.filters.supervised.attribute.Discretize();
      ((weka.filters.supervised.attribute.Discretize)m_disTransform).setUseBetterEncoding(true);
      m_classIsNominal = true;
    }

    m_disTransform.setInputFormat(m_theInstances);
    m_theInstances = Filter.useFilter(m_theInstances, m_disTransform);

    m_numAttributes = m_theInstances.numAttributes();
    m_numInstances = m_theInstances.numInstances();
    m_majority = m_theInstances.meanOrMode(m_theInstances.classAttribute());

    // Perform the search
    int [] selected = m_search.search(m_evaluator, m_theInstances);

    m_decisionFeatures = new int [selected.length+1];
    System.arraycopy(selected, 0, m_decisionFeatures, 0, selected.length);
    m_decisionFeatures[m_decisionFeatures.length-1] = m_theInstances.classIndex();

    // reduce instances to selected features
    m_delTransform = new Remove();
    m_delTransform.setInvertSelection(true);

    // set features to keep
    m_delTransform.setAttributeIndicesArray(m_decisionFeatures);
    m_delTransform.setInputFormat(m_theInstances);
    m_dtInstances = Filter.useFilter(m_theInstances, m_delTransform);

    // reset the number of attributes
    m_numAttributes = m_dtInstances.numAttributes();

    // create hash table
    m_entries = new Hashtable((int)(m_dtInstances.numInstances() * 1.5));

    // insert instances into the hash table
    for (int i = 0; i < m_numInstances; i++) {
      Instance inst = m_dtInstances.instance(i);
      insertIntoTable(inst, null);
    }

    // Replace the global table majority with nearest neighbour?
    if (m_useIBk) {
      m_ibk = new IBk();
      m_ibk.buildClassifier(m_theInstances);
    }

    // Save memory
    if (m_saveMemory) {
      m_theInstances = new Instances(m_theInstances, 0);
      m_dtInstances = new Instances(m_dtInstances, 0);
    }
    m_evaluation = null;
  }
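
  /*
   * A minimal sketch of building the classifier (the file name and class
   * index are illustrative):
   *
   *   Instances train = new Instances(
   *       new java.io.BufferedReader(new java.io.FileReader("train.arff")));
   *   train.setClassIndex(train.numAttributes() - 1);
   *   DecisionTable dt = new DecisionTable();
   *   dt.setCrossVal(5);       // 5-fold CV for evaluating feature subsets
   *   dt.buildClassifier(train);
   */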

  /**
   * Calculates the class membership probabilities for the given
   * test instance.
   *
   * @param instance the instance to be classified
   * @return predicted class probability distribution
   * @throws Exception if distribution can't be computed
   */
  public double [] distributionForInstance(Instance instance)
  throws Exception {

    DecisionTableHashKey thekey;
    double [] tempDist;
    double [] normDist;

    m_disTransform.input(instance);
    m_disTransform.batchFinished();
    instance = m_disTransform.output();

    m_delTransform.input(instance);
    m_delTransform.batchFinished();
    instance = m_delTransform.output();

    thekey = new DecisionTableHashKey(instance, instance.numAttributes(), false);

    // if this one is not in the table
    if ((tempDist = (double [])m_entries.get(thekey)) == null) {
      if (m_useIBk) {
        tempDist = m_ibk.distributionForInstance(instance);
      } else {
        if (!m_classIsNominal) {
          tempDist = new double[1];
          tempDist[0] = m_majority;
        } else {
          tempDist = m_classPriors.clone();
          /*tempDist = new double [m_theInstances.classAttribute().numValues()];
          tempDist[(int)m_majority] = 1.0; */
        }
      }
    } else {
      if (!m_classIsNominal) {
        normDist = new double[1];
        normDist[0] = (tempDist[0] / tempDist[1]);
        tempDist = normDist;
      } else {

        // normalise distribution
        normDist = new double [tempDist.length];
        System.arraycopy(tempDist,0,normDist,0,tempDist.length);
        Utils.normalize(normDist);
        tempDist = normDist;
      }
    }
    return tempDist;
  }
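
  /*
   * Example of querying a built model (variable names are illustrative;
   * dt is a trained DecisionTable and test a compatible dataset):
   *
   *   double [] dist = dt.distributionForInstance(test.instance(0));
   *   int predicted = Utils.maxIndex(dist);   // nominal class
   *   // for a numeric class, dist has length 1 and dist[0] is the prediction
   */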

  /**
   * Returns a string description of the features selected
   *
   * @return a string of features
   */
  public String printFeatures() {

    int i;
    String s = "";

    for (i=0;i<m_decisionFeatures.length;i++) {
      if (i==0) {
        s = ""+(m_decisionFeatures[i]+1);
      } else {
        s += ","+(m_decisionFeatures[i]+1);
      }
    }
    return s;
  }

  /**
   * Returns the number of rules
   * @return the number of rules
   */
  public double measureNumRules() {
    return m_entries.size();
  }

  /**
   * Returns an enumeration of the additional measure names
   * @return an enumeration of the measure names
   */
  public Enumeration enumerateMeasures() {
    Vector newVector = new Vector(1);
    newVector.addElement("measureNumRules");
    return newVector.elements();
  }

  /**
   * Returns the value of the named measure
   * @param additionalMeasureName the name of the measure to query for its value
   * @return the value of the named measure
   * @throws IllegalArgumentException if the named measure is not supported
   */
  public double getMeasure(String additionalMeasureName) {
    if (additionalMeasureName.compareToIgnoreCase("measureNumRules") == 0) {
      return measureNumRules();
    } else {
      throw new IllegalArgumentException(additionalMeasureName
          + " not supported (DecisionTable)");
    }
  }

  /**
   * Returns a description of the classifier.
   *
   * @return a description of the classifier as a string.
   */
  public String toString() {

    if (m_entries == null) {
      return "Decision Table: No model built yet.";
    } else {
      StringBuffer text = new StringBuffer();

      text.append("Decision Table:"+
          "\n\nNumber of training instances: "+m_numInstances+
          "\nNumber of Rules : "+m_entries.size()+"\n");

      if (m_useIBk) {
        text.append("Non matches covered by IB1.\n");
      } else {
        text.append("Non matches covered by Majority class.\n");
      }

      text.append(m_search.toString());
      /*text.append("Best first search for feature set,\nterminated after "+
                                        m_maxStale+" non improving subsets.\n"); */

      text.append("Evaluation (for feature selection): CV ");
      if (m_CVFolds > 1) {
        text.append("("+m_CVFolds+" fold) ");
      } else {
        text.append("(leave one out) ");
      }
      text.append("\nFeature set: "+printFeatures());

      if (m_displayRules) {

        // find out the max column width
        int maxColWidth = 0;
        for (int i=0;i<m_dtInstances.numAttributes();i++) {
          if (m_dtInstances.attribute(i).name().length() > maxColWidth) {
            maxColWidth = m_dtInstances.attribute(i).name().length();
          }

          if (m_classIsNominal || (i != m_dtInstances.classIndex())) {
            Enumeration e = m_dtInstances.attribute(i).enumerateValues();
            while (e.hasMoreElements()) {
              String ss = (String)e.nextElement();
              if (ss.length() > maxColWidth) {
                maxColWidth = ss.length();
              }
            }
          }
        }

        text.append("\n\nRules:\n");
        StringBuffer tm = new StringBuffer();
        for (int i=0;i<m_dtInstances.numAttributes();i++) {
          if (m_dtInstances.classIndex() != i) {
            int d = maxColWidth - m_dtInstances.attribute(i).name().length();
            tm.append(m_dtInstances.attribute(i).name());
            for (int j=0;j<d+1;j++) {
              tm.append(" ");
            }
          }
        }
        tm.append(m_dtInstances.attribute(m_dtInstances.classIndex()).name()+"  ");

        for (int i=0;i<tm.length()+10;i++) {
          text.append("=");
        }
        text.append("\n");
        text.append(tm);
        text.append("\n");
        for (int i=0;i<tm.length()+10;i++) {
          text.append("=");
        }
        text.append("\n");

        Enumeration e = m_entries.keys();
        while (e.hasMoreElements()) {
          DecisionTableHashKey tt = (DecisionTableHashKey)e.nextElement();
          text.append(tt.toString(m_dtInstances,maxColWidth));
          double [] ClassDist = (double []) m_entries.get(tt);

          if (m_classIsNominal) {
            int m = Utils.maxIndex(ClassDist);
            try {
              text.append(m_dtInstances.classAttribute().value(m)+"\n");
            } catch (Exception ee) {
              System.out.println(ee.getMessage());
            }
          } else {
            text.append((ClassDist[0] / ClassDist[1])+"\n");
          }
        }

        for (int i=0;i<tm.length()+10;i++) {
          text.append("=");
        }
        text.append("\n");
        text.append("\n");
      }
      return text.toString();
    }
  }

  /**
   * Returns the revision string.
   *
   * @return            the revision
   */
  public String getRevision() {
    return RevisionUtils.extract("$Revision: 5987 $");
  }

  /**
   * Main method for testing this class.
   *
   * @param argv the command-line options
   */
  public static void main(String [] argv) {
    runClassifier(new DecisionTable(), argv);
  }
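
  /*
   * Example command line invocation (the dataset name is illustrative):
   *
   *   java weka.classifiers.rules.DecisionTable -t iris.arff -X 5 -R \
   *        -S "weka.attributeSelection.BestFirst -D 1"
   */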
}