source: src/main/java/weka/classifiers/rules/DTNB.java @ 4

Last change on this file since 4 was 4, checked in by gnappo, 14 years ago

Import di weka.

File size: 28.0 KB
Line 
1/*
2 *    This program is free software; you can redistribute it and/or modify
3 *    it under the terms of the GNU General Public License as published by
4 *    the Free Software Foundation; either version 2 of the License, or
5 *    (at your option) any later version.
6 *
7 *    This program is distributed in the hope that it will be useful,
8 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
9 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
10 *    GNU General Public License for more details.
11 *
12 *    You should have received a copy of the GNU General Public License
13 *    along with this program; if not, write to the Free Software
14 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
15 */
16
17/*
18 *    DTNB.java
19 *    Copyright (C) 2008 University of Waikato, Hamilton, New Zealand
20 *
21 */
22
23package weka.classifiers.rules;
24
25import weka.attributeSelection.ASEvaluation;
26import weka.attributeSelection.ASSearch;
27import weka.attributeSelection.SubsetEvaluator;
28import weka.classifiers.bayes.NaiveBayes;
29import weka.core.Capabilities;
30import weka.core.Instance;
31import weka.core.Instances;
32import weka.core.Option;
33import weka.core.RevisionUtils;
34import weka.core.SelectedTag;
35import weka.core.TechnicalInformation;
36import weka.core.Utils;
37import weka.core.Capabilities.Capability;
38import weka.core.TechnicalInformation.Field;
39import weka.core.TechnicalInformation.Type;
40
41import java.util.BitSet;
42import java.util.Enumeration;
43import java.util.Vector;
44
45/**
46 *
47 <!-- globalinfo-start -->
48 * Class for building and using a decision table/naive bayes hybrid classifier. At each point in the search, the algorithm evaluates the merit of dividing the attributes into two disjoint subsets: one for the decision table, the other for naive Bayes. A forward selection search is used, where at each step, selected attributes are modeled by naive Bayes and the remainder by the decision table, and all attributes are modelled by the decision table initially. At each step, the algorithm also considers dropping an attribute entirely from the model.<br/>
49 * <br/>
50 * For more information, see: <br/>
51 * <br/>
52 * Mark Hall, Eibe Frank: Combining Naive Bayes and Decision Tables. In: Proceedings of the 21st Florida Artificial Intelligence Society Conference (FLAIRS), ???-???, 2008.
53 * <p/>
54 <!-- globalinfo-end -->
55 *
56 <!-- technical-bibtex-start -->
57 * BibTeX:
58 * <pre>
59 * &#64;inproceedings{Hall2008,
60 *    author = {Mark Hall and Eibe Frank},
61 *    booktitle = {Proceedings of the 21st Florida Artificial Intelligence Society Conference (FLAIRS)},
62 *    pages = {???-???},
63 *    publisher = {AAAI press},
64 *    title = {Combining Naive Bayes and Decision Tables},
65 *    year = {2008}
66 * }
67 * </pre>
68 * <p/>
69 <!-- technical-bibtex-end -->
70 *
71 <!-- options-start -->
72 * Valid options are: <p/>
73 *
74 * <pre> -X &lt;number of folds&gt;
75 *  Use cross validation to evaluate features.
76 *  Use number of folds = 1 for leave one out CV.
77 *  (Default = leave one out CV)</pre>
78 *
79 * <pre> -E &lt;acc | rmse | mae | auc&gt;
80 *  Performance evaluation measure to use for selecting attributes.
81 *  (Default = accuracy for discrete class and rmse for numeric class)</pre>
82 *
83 * <pre> -I
84 *  Use nearest neighbour instead of global table majority.</pre>
85 *
86 * <pre> -R
87 *  Display decision table rules.
88 * </pre>
89 *
90 <!-- options-end -->
91 *
92 * @author Mark Hall (mhall{[at]}pentaho{[dot]}org)
93 * @author Eibe Frank (eibe{[at]}cs{[dot]}waikato{[dot]}ac{[dot]}nz)
94 *
95 * @version $Revision: 1.4 $
96 *
97 */
98public class DTNB extends DecisionTable {
99
100  /**
101   * The naive Bayes half of the hybrid
102   */
103  protected NaiveBayes m_NB;
104
105  /**
106   * The features used by naive Bayes
107   */
108  private int [] m_nbFeatures;
109
110  /**
111   * Percentage of the total number of features used by the decision table
112   */
113  private double m_percentUsedByDT;
114 
115  /**
116   * Percentage of the features features that were dropped entirely
117   */
118  private double m_percentDeleted;
119
120  static final long serialVersionUID = 2999557077765701326L;
121
122  /**
123   * Returns a string describing classifier
124   * @return a description suitable for
125   * displaying in the explorer/experimenter gui
126   */
127  public String globalInfo() {
128
129    return 
130      "Class for building and using a decision table/naive bayes hybrid classifier. At each point "
131      + "in the search, the algorithm evaluates the merit of dividing the attributes into two disjoint "
132      + "subsets: one for the decision table, the other for naive Bayes. A forward selection search is "
133      + "used, where at each step, selected attributes are modeled by naive Bayes and the remainder "
134      + "by the decision table, and all attributes are modelled by the decision table initially. At each "
135      + "step, the algorithm also considers dropping an attribute entirely from the model.\n\n"
136      + "For more information, see: \n\n"
137      + getTechnicalInformation().toString();
138  }
139
140  /**
141   * Returns an instance of a TechnicalInformation object, containing
142   * detailed information about the technical background of this class,
143   * e.g., paper reference or book this class is based on.
144   *
145   * @return the technical information about this class
146   */
147  public TechnicalInformation getTechnicalInformation() {
148    TechnicalInformation        result;
149
150    result = new TechnicalInformation(Type.INPROCEEDINGS);
151    result.setValue(Field.AUTHOR, "Mark Hall and Eibe Frank");
152    result.setValue(Field.TITLE, "Combining Naive Bayes and Decision Tables");
153    result.setValue(Field.BOOKTITLE, "Proceedings of the 21st Florida Artificial Intelligence "
154                    + "Society Conference (FLAIRS)");
155    result.setValue(Field.YEAR, "2008");
156    result.setValue(Field.PAGES, "???-???");
157    result.setValue(Field.PUBLISHER, "AAAI press");
158
159    return result;
160  }
161
162  /**
163   * Calculates the accuracy on a test fold for internal cross validation
164   * of feature sets
165   *
166   * @param fold set of instances to be "left out" and classified
167   * @param fs currently selected feature set
168   * @return the accuracy for the fold
169   * @throws Exception if something goes wrong
170   */
171  double evaluateFoldCV(Instances fold, int [] fs) throws Exception {
172
173    int i;
174    int ruleCount = 0;
175    int numFold = fold.numInstances();
176    int numCl = m_theInstances.classAttribute().numValues();
177    double [][] class_distribs = new double [numFold][numCl];
178    double [] instA = new double [fs.length];
179    double [] normDist;
180    DecisionTableHashKey thekey;
181    double acc = 0.0;
182    int classI = m_theInstances.classIndex();
183    Instance inst;
184
185    if (m_classIsNominal) {
186      normDist = new double [numCl];
187    } else {
188      normDist = new double [2];
189    }
190
191    // first *remove* instances
192    for (i=0;i<numFold;i++) {
193      inst = fold.instance(i);
194      for (int j=0;j<fs.length;j++) {
195        if (fs[j] == classI) {
196          instA[j] = Double.MAX_VALUE; // missing for the class
197        } else if (inst.isMissing(fs[j])) {
198          instA[j] = Double.MAX_VALUE;
199        } else{
200          instA[j] = inst.value(fs[j]);
201        }
202      }
203      thekey = new DecisionTableHashKey(instA);
204      if ((class_distribs[i] = (double [])m_entries.get(thekey)) == null) {
205        throw new Error("This should never happen!");
206      } else {
207        if (m_classIsNominal) {
208          class_distribs[i][(int)inst.classValue()] -= inst.weight();
209          inst.setWeight(-inst.weight());
210          m_NB.updateClassifier(inst);
211          inst.setWeight(-inst.weight());
212        } else {
213          class_distribs[i][0] -= (inst.classValue() * inst.weight());
214          class_distribs[i][1] -= inst.weight();
215        }
216        ruleCount++;
217      }
218      m_classPriorCounts[(int)inst.classValue()] -= 
219        inst.weight(); 
220    }
221    double [] classPriors = m_classPriorCounts.clone();
222    Utils.normalize(classPriors);
223
224    // now classify instances
225    for (i=0;i<numFold;i++) {
226      inst = fold.instance(i);
227      System.arraycopy(class_distribs[i],0,normDist,0,normDist.length);
228      if (m_classIsNominal) {
229        boolean ok = false;
230        for (int j=0;j<normDist.length;j++) {
231          if (Utils.gr(normDist[j],1.0)) {
232            ok = true;
233            break;
234          }
235        }
236
237        if (!ok) { // majority class
238          normDist = classPriors.clone();
239        } else {
240          Utils.normalize(normDist);
241        }
242
243        double [] nbDist = m_NB.distributionForInstance(inst);
244
245        for (int l = 0; l < normDist.length; l++) {
246          normDist[l] = (Math.log(normDist[l]) - Math.log(classPriors[l]));
247          normDist[l] += Math.log(nbDist[l]);
248        }
249        normDist = Utils.logs2probs(normDist);
250        // Utils.normalize(normDist);
251
252        //      System.out.println(normDist[0] + " " + normDist[1] + " " + inst.classValue());
253
254        if (m_evaluationMeasure == EVAL_AUC) {
255          m_evaluation.evaluateModelOnceAndRecordPrediction(normDist, inst);
256        } else {
257          m_evaluation.evaluateModelOnce(normDist, inst);
258        }
259        /*      } else {                                       
260          normDist[(int)m_majority] = 1.0;
261          if (m_evaluationMeasure == EVAL_AUC) {
262            m_evaluation.evaluateModelOnceAndRecordPrediction(normDist, inst);                                         
263          } else {
264            m_evaluation.evaluateModelOnce(normDist, inst);                                     
265          }
266        } */
267      } else {
268        if (Utils.eq(normDist[1],0.0)) {
269          double [] temp = new double[1];
270          temp[0] = m_majority;
271          m_evaluation.evaluateModelOnce(temp, inst);
272        } else {
273          double [] temp = new double[1];
274          temp[0] = normDist[0] / normDist[1];
275          m_evaluation.evaluateModelOnce(temp, inst);
276        }
277      }
278    }
279
280    // now re-insert instances
281    for (i=0;i<numFold;i++) {
282      inst = fold.instance(i);
283
284      m_classPriorCounts[(int)inst.classValue()] += 
285        inst.weight();
286
287      if (m_classIsNominal) {
288        class_distribs[i][(int)inst.classValue()] += inst.weight();
289        m_NB.updateClassifier(inst);
290      } else {
291        class_distribs[i][0] += (inst.classValue() * inst.weight());
292        class_distribs[i][1] += inst.weight();
293      }
294    }
295    return acc;
296  }
297
298  /**
299   * Classifies an instance for internal leave one out cross validation
300   * of feature sets
301   *
302   * @param instance instance to be "left out" and classified
303   * @param instA feature values of the selected features for the instance
304   * @return the classification of the instance
305   * @throws Exception if something goes wrong
306   */
307  double evaluateInstanceLeaveOneOut(Instance instance, double [] instA)
308  throws Exception {
309
310    DecisionTableHashKey thekey;
311    double [] tempDist;
312    double [] normDist;
313
314    thekey = new DecisionTableHashKey(instA);
315
316    // if this one is not in the table
317    if ((tempDist = (double [])m_entries.get(thekey)) == null) {
318      throw new Error("This should never happen!");
319    } else {
320      normDist = new double [tempDist.length];
321      System.arraycopy(tempDist,0,normDist,0,tempDist.length);
322      normDist[(int)instance.classValue()] -= instance.weight();
323
324      // update the table
325      // first check to see if the class counts are all zero now
326      boolean ok = false;
327      for (int i=0;i<normDist.length;i++) {
328        if (Utils.gr(normDist[i],1.0)) {
329          ok = true;
330          break;
331        }
332      }
333
334      // downdate the class prior counts
335      m_classPriorCounts[(int)instance.classValue()] -= 
336        instance.weight(); 
337      double [] classPriors = m_classPriorCounts.clone();
338      Utils.normalize(classPriors);
339      if (!ok) { // majority class     
340        normDist = classPriors;
341      } else {
342        Utils.normalize(normDist);
343      }
344
345      m_classPriorCounts[(int)instance.classValue()] += 
346      instance.weight();
347
348      if (m_NB != null){
349        // downdate NaiveBayes
350
351        instance.setWeight(-instance.weight());
352        m_NB.updateClassifier(instance);
353        double [] nbDist = m_NB.distributionForInstance(instance);
354        instance.setWeight(-instance.weight());
355        m_NB.updateClassifier(instance);
356
357        for (int i = 0; i < normDist.length; i++) {
358          normDist[i] = (Math.log(normDist[i]) - Math.log(classPriors[i]));
359          normDist[i] += Math.log(nbDist[i]);
360        }
361        normDist = Utils.logs2probs(normDist);
362        // Utils.normalize(normDist);
363      }
364
365      if (m_evaluationMeasure == EVAL_AUC) {
366        m_evaluation.evaluateModelOnceAndRecordPrediction(normDist, instance);                                         
367      } else {
368        m_evaluation.evaluateModelOnce(normDist, instance);
369      }
370      return Utils.maxIndex(normDist);
371    }
372  }
373
374  /**
375   * Sets up a dummy subset evaluator that basically just delegates
376   * evaluation to the estimatePerformance method in DecisionTable
377   */
378  protected void setUpEvaluator() throws Exception {
379    m_evaluator = new EvalWithDelete();
380    m_evaluator.buildEvaluator(m_theInstances);
381  }
382 
383  protected class EvalWithDelete extends ASEvaluation implements SubsetEvaluator {
384   
385    // holds the list of attributes that are no longer in the model at all
386    private BitSet m_deletedFromDTNB;
387   
388    public void buildEvaluator(Instances data) throws Exception {
389      m_NB = null;
390      m_deletedFromDTNB = new BitSet(data.numAttributes());
391      // System.err.println("Here");
392    }
393   
394   private int setUpForEval(BitSet subset) throws Exception {
395     
396     int fc = 0;
397     for (int jj = 0;jj < m_numAttributes; jj++) {
398        if (subset.get(jj)) {
399          fc++;
400        }
401     }
402
403     //int [] nbFs = new int [fc];
404     //int count = 0;
405
406     for (int j = 0; j < m_numAttributes; j++) {
407        m_theInstances.attribute(j).setWeight(1.0); // reset weight
408        if (j != m_theInstances.classIndex()) {
409          if (subset.get(j)) {
410        //    nbFs[count++] = j;
411            m_theInstances.attribute(j).setWeight(0.0); // no influence for NB
412          }
413        }
414     }
415     
416     // process delete set
417     for (int i = 0; i < m_numAttributes; i++) {
418        if (m_deletedFromDTNB.get(i)) {
419           m_theInstances.attribute(i).setWeight(0.0); // no influence for NB
420        }
421     }
422     
423     if (m_NB == null) {
424        // construct naive bayes for the first time
425        m_NB = new NaiveBayes();
426        m_NB.buildClassifier(m_theInstances);
427     }
428     return fc;
429   }
430
431    public double evaluateSubset(BitSet subset) throws Exception {
432      int fc = setUpForEval(subset);
433     
434      return estimatePerformance(subset, fc);
435    }
436   
437    public double evaluateSubsetDelete(BitSet subset, int potentialDelete) throws Exception {
438     
439      int fc = setUpForEval(subset);
440     
441      // clear potentail delete for naive Bayes
442      m_theInstances.attribute(potentialDelete).setWeight(0.0);
443      //copy.clear(potentialDelete);
444      //fc--;
445      return estimatePerformance(subset, fc);
446    }
447   
448    public BitSet getDeletedList() {
449      return m_deletedFromDTNB;
450    }
451   
452    /**
453     * Returns the revision string.
454     *
455     * @return          the revision
456     */
457    public String getRevision() {
458      return RevisionUtils.extract("$Revision: 1.4 $");
459    }
460  }
461
462  protected ASSearch m_backwardWithDelete;
463
464  /**
465   * Inner class implementing a special forwards search that looks for a good
466   * split of attributes between naive Bayes and the decision table. It also
467   * considers dropping attributes entirely from the model.
468   */
469  protected class BackwardsWithDelete extends ASSearch {
470
471    public String globalInfo() {
472      return "Specialized search that performs a forward selection (naive Bayes)/"
473        + "backward elimination (decision table). Also considers dropping attributes "
474        + "entirely from the combined model.";
475    }
476
477    public String toString() {
478      return "";
479    }
480
481    public int [] search(ASEvaluation eval, Instances data)
482        throws Exception {
483        int i;
484        double best_merit = -Double.MAX_VALUE;
485        double temp_best = 0, temp_merit = 0, temp_merit_delete = 0;
486        int temp_index=0;
487        BitSet temp_group;
488        BitSet best_group = null;
489
490        int numAttribs = data.numAttributes();
491
492        if (best_group == null) {
493          best_group = new BitSet(numAttribs);
494        }
495
496       
497        int classIndex = data.classIndex();
498        for (i = 0; i < numAttribs; i++) {
499          if (i != classIndex) {
500            best_group.set(i);
501          }
502        }
503
504        //System.err.println(best_group);
505       
506        // Evaluate the initial subset
507        //      best_merit = m_evaluator.evaluateSubset(best_group);
508        best_merit = ((SubsetEvaluator)eval).evaluateSubset(best_group);
509
510        //System.err.println(best_merit);
511
512        // main search loop
513        boolean done = false;
514        boolean addone = false;
515        boolean z;
516        boolean deleted = false;
517        while (!done) {
518          temp_group = (BitSet)best_group.clone();
519          temp_best = best_merit;
520         
521          done = true;
522          addone = false;
523          for (i = 0; i < numAttribs;i++) {
524            z = ((i != classIndex) && (temp_group.get(i)));
525
526            if (z) {
527              // set/unset the bit
528              temp_group.clear(i);
529
530              //              temp_merit = m_evaluator.evaluateSubset(temp_group);
531              temp_merit = ((SubsetEvaluator)eval).evaluateSubset(temp_group);
532              //              temp_merit_delete = ((EvalWithDelete)m_evaluator).evaluateSubsetDelete(temp_group, i);
533              temp_merit_delete = ((EvalWithDelete)eval).evaluateSubsetDelete(temp_group, i);
534              boolean deleteBetter = false;
535              //System.out.println("Merit: " + temp_merit + "\t" + "Delete merit: " + temp_merit_delete);
536              if (temp_merit_delete >= temp_merit) {
537                temp_merit = temp_merit_delete;
538                deleteBetter = true;
539              }
540             
541              z = (temp_merit >= temp_best);
542
543              if (z) {
544                temp_best = temp_merit;
545                temp_index = i;
546                addone = true;
547                done = false;
548                if (deleteBetter) {
549                  deleted = true;
550                } else {
551                  deleted = false;
552                }
553              }
554
555              // unset this addition/deletion
556                temp_group.set(i);
557            }
558          }
559          if (addone) {
560            best_group.clear(temp_index);
561            best_merit = temp_best;
562            if (deleted) {
563              //              ((EvalWithDelete)m_evaluator).getDeletedList().set(temp_index);
564              ((EvalWithDelete)eval).getDeletedList().set(temp_index);
565            }
566            //System.err.println("----------------------");
567            //System.err.println("Best subset: (dec table)" + best_group);
568            //System.err.println("Best subset: (deleted)" + ((EvalWithDelete)m_evaluator).getDeletedList());
569            //System.err.println(best_merit);
570          }
571        }
572        return attributeList(best_group);
573      }
574     
575      /**
576       * converts a BitSet into a list of attribute indexes
577       * @param group the BitSet to convert
578       * @return an array of attribute indexes
579       **/
580      protected int[] attributeList (BitSet group) {
581        int count = 0;
582        BitSet copy = (BitSet)group.clone();
583       
584        /* remove any that have been completely deleted from DTNB
585        BitSet deleted = ((EvalWithDelete)m_evaluator).getDeletedList();
586        for (int i = 0; i < m_numAttributes; i++) {
587          if (deleted.get(i)) {
588            copy.clear(i);
589          }
590        } */
591       
592        // count how many were selected
593        for (int i = 0; i < m_numAttributes; i++) {
594          if (copy.get(i)) {
595            count++;
596          }
597        }
598
599        int[] list = new int[count];
600        count = 0;
601
602        for (int i = 0; i < m_numAttributes; i++) {
603          if (copy.get(i)) {
604            list[count++] = i;
605          }
606        }
607
608        return  list;
609      }
610     
611      /**
612       * Returns the revision string.
613       *
614       * @return                the revision
615       */
616      public String getRevision() {
617        return RevisionUtils.extract("$Revision: 1.4 $");
618      }
619  }
620
621  private void setUpSearch() {
622    m_backwardWithDelete = new BackwardsWithDelete();
623  }
624 
625  /**
626   * Generates the classifier.
627   *
628   * @param data set of instances serving as training data
629   * @throws Exception if the classifier has not been generated successfully
630   */
631  public void buildClassifier(Instances data) throws Exception {
632
633    m_saveMemory = false;
634
635    if (data.classAttribute().isNumeric()) {
636      throw new Exception("Can only handle nominal class!");
637    }
638
639    if (m_backwardWithDelete == null) {
640      setUpSearch();
641      m_search = m_backwardWithDelete;
642    }
643
644    /*    if (m_search != m_backwardWithDelete) {
645      m_search = m_backwardWithDelete;
646      } */
647    super.buildClassifier(data);
648
649    // new NB stuff
650
651    // delete the features used by the decision table (not the class!!)
652    for (int i = 0; i < m_theInstances.numAttributes(); i++) {
653      m_theInstances.attribute(i).setWeight(1.0); // reset all weights
654    }
655    // m_nbFeatures = new int [m_decisionFeatures.length - 1];
656     int count = 0;
657
658    for (int i = 0; i < m_decisionFeatures.length; i++) {
659      if (m_decisionFeatures[i] != m_theInstances.classIndex()) {
660        count++;
661//      m_nbFeatures[count++] = m_decisionFeatures[i];
662        m_theInstances.attribute(m_decisionFeatures[i]).setWeight(0.0); // No influence for NB
663      }
664    }
665   
666    double numDeleted = 0;
667    // remove any attributes that have been deleted completely from the DTNB
668    BitSet deleted = ((EvalWithDelete)m_evaluator).getDeletedList();
669    for (int i = 0; i < m_theInstances.numAttributes(); i++) {
670      if (deleted.get(i)) {
671        m_theInstances.attribute(i).setWeight(0.0);
672        // count--;
673        numDeleted++;
674        // System.err.println("Attribute "+i+" was eliminated completely");
675      }
676    }
677   
678    m_percentUsedByDT = (double)count / (m_theInstances.numAttributes() - 1);
679    m_percentDeleted = numDeleted / (m_theInstances.numAttributes() -1);
680
681    m_NB = new NaiveBayes();
682    m_NB.buildClassifier(m_theInstances);
683
684    m_dtInstances = new Instances(m_dtInstances, 0);
685    m_theInstances = new Instances(m_theInstances, 0);
686  }
687
688  /**
689   * Calculates the class membership probabilities for the given
690   * test instance.
691   *
692   * @param instance the instance to be classified
693   * @return predicted class probability distribution
694   * @exception Exception if distribution can't be computed
695   */
696  public double [] distributionForInstance(Instance instance)
697  throws Exception {
698
699    DecisionTableHashKey thekey;
700    double [] tempDist;
701    double [] normDist;
702
703    m_disTransform.input(instance);
704    m_disTransform.batchFinished();
705    instance = m_disTransform.output();
706
707    m_delTransform.input(instance);
708    m_delTransform.batchFinished();
709    Instance dtInstance = m_delTransform.output();
710
711    thekey = new DecisionTableHashKey(dtInstance, dtInstance.numAttributes(), false);
712
713    // if this one is not in the table
714    if ((tempDist = (double [])m_entries.get(thekey)) == null) {
715      if (m_useIBk) {
716        tempDist = m_ibk.distributionForInstance(dtInstance);
717      } else { 
718        // tempDist = new double [m_theInstances.classAttribute().numValues()];
719//      tempDist[(int)m_majority] = 1.0;
720       
721        tempDist = m_classPriors.clone();
722        // return tempDist; ??????
723      }
724    } else {
725      // normalise distribution
726      normDist = new double [tempDist.length];
727      System.arraycopy(tempDist,0,normDist,0,tempDist.length);
728      Utils.normalize(normDist);
729      tempDist = normDist;                     
730    }
731
732    double [] nbDist = m_NB.distributionForInstance(instance);
733    for (int i = 0; i < nbDist.length; i++) {
734      tempDist[i] = (Math.log(tempDist[i]) - Math.log(m_classPriors[i]));
735      tempDist[i] += Math.log(nbDist[i]);
736
737      /*tempDist[i] *= nbDist[i];
738      tempDist[i] /= m_classPriors[i];*/
739    }
740    tempDist = Utils.logs2probs(tempDist);
741    Utils.normalize(tempDist);
742
743    return tempDist;
744  }
745
746  public String toString() {
747
748    String sS = super.toString();
749    if (m_displayRules && m_NB != null) {
750      sS += m_NB.toString();                   
751    }
752    return sS;
753  }
754 
755  /**
756   * Returns the number of rules
757   * @return the number of rules
758   */
759  public double measurePercentAttsUsedByDT() {
760    return m_percentUsedByDT;
761  }
762 
763  /**
764   * Returns an enumeration of the additional measure names
765   * @return an enumeration of the measure names
766   */
767  public Enumeration enumerateMeasures() {
768    Vector newVector = new Vector(2);
769    newVector.addElement("measureNumRules");
770    newVector.addElement("measurePercentAttsUsedByDT");
771    return newVector.elements();
772  }
773
774  /**
775   * Returns the value of the named measure
776   * @param additionalMeasureName the name of the measure to query for its value
777   * @return the value of the named measure
778   * @throws IllegalArgumentException if the named measure is not supported
779   */
780  public double getMeasure(String additionalMeasureName) {
781    if (additionalMeasureName.compareToIgnoreCase("measureNumRules") == 0) {
782      return measureNumRules();
783    } else if (additionalMeasureName.compareToIgnoreCase("measurePercentAttsUsedByDT") == 0) {
784      return measurePercentAttsUsedByDT();
785    } else {
786      throw new IllegalArgumentException(additionalMeasureName
787          + " not supported (DecisionTable)");
788    }
789  }
790
791  /**
792   * Returns default capabilities of the classifier.
793   *
794   * @return      the capabilities of this classifier
795   */
796  public Capabilities getCapabilities() {
797    Capabilities result = super.getCapabilities();
798
799    result.disable(Capability.NUMERIC_CLASS);
800    result.disable(Capability.DATE_CLASS);
801
802    return result;
803  }
804
805  /**
806   * Sets the search method to use
807   *
808   * @param search
809   */
810  public void setSearch(ASSearch search) {
811    // Search method cannot be changed.
812    // Must be BackwardsWithDelete
813    return;
814  }
815
816  /**
817   * Gets the current search method
818   *
819   * @return the search method used
820   */
821  public ASSearch getSearch() {
822    if (m_backwardWithDelete == null) {
823      setUpSearch();
824      //      setSearch(m_backwardWithDelete);
825      m_search = m_backwardWithDelete;
826    }
827    return m_search;
828  }
829
830  /**
831   * Returns an enumeration describing the available options.
832   *
833   * @return an enumeration of all the available options.
834   */
835  public Enumeration listOptions() {
836
837    Vector newVector = new Vector(7);
838
839    newVector.addElement(new Option(
840        "\tUse cross validation to evaluate features.\n" +
841        "\tUse number of folds = 1 for leave one out CV.\n" +
842        "\t(Default = leave one out CV)",
843        "X", 1, "-X <number of folds>"));
844
845    newVector.addElement(new Option(
846        "\tPerformance evaluation measure to use for selecting attributes.\n" +
847        "\t(Default = accuracy for discrete class and rmse for numeric class)",
848        "E", 1, "-E <acc | rmse | mae | auc>"));
849
850    newVector.addElement(new Option(
851        "\tUse nearest neighbour instead of global table majority.",
852        "I", 0, "-I"));
853
854    newVector.addElement(new Option(
855        "\tDisplay decision table rules.\n",
856        "R", 0, "-R")); 
857
858    return newVector.elements();
859  }
860
861  /**
862   * Parses the options for this object. <p/>
863   *
864   <!-- options-start -->
865   * Valid options are: <p/>
866   *
867   * <pre> -X &lt;number of folds&gt;
868   *  Use cross validation to evaluate features.
869   *  Use number of folds = 1 for leave one out CV.
870   *  (Default = leave one out CV)</pre>
871   *
872   * <pre> -E &lt;acc | rmse | mae | auc&gt;
873   *  Performance evaluation measure to use for selecting attributes.
874   *  (Default = accuracy for discrete class and rmse for numeric class)</pre>
875   *
876   * <pre> -I
877   *  Use nearest neighbour instead of global table majority.</pre>
878   *
879   * <pre> -R
880   *  Display decision table rules.
881   * </pre>
882   *
883   <!-- options-end -->
884   *
885   * @param options the list of options as an array of strings
886   * @throws Exception if an option is not supported
887   */
888  public void setOptions(String[] options) throws Exception {
889
890    String optionString;
891
892    resetOptions();
893
894    optionString = Utils.getOption('X',options);
895    if (optionString.length() != 0) {
896      setCrossVal(Integer.parseInt(optionString));
897    }
898
899    m_useIBk = Utils.getFlag('I',options);
900
901    m_displayRules = Utils.getFlag('R',options);
902
903    optionString = Utils.getOption('E', options);
904    if (optionString.length() != 0) {
905      if (optionString.equals("acc")) {
906        setEvaluationMeasure(new SelectedTag(EVAL_ACCURACY, TAGS_EVALUATION));
907      } else if (optionString.equals("rmse")) {
908        setEvaluationMeasure(new SelectedTag(EVAL_RMSE, TAGS_EVALUATION));
909      } else if (optionString.equals("mae")) {
910        setEvaluationMeasure(new SelectedTag(EVAL_MAE, TAGS_EVALUATION));
911      } else if (optionString.equals("auc")) {
912        setEvaluationMeasure(new SelectedTag(EVAL_AUC, TAGS_EVALUATION));
913      } else {
914        throw new IllegalArgumentException("Invalid evaluation measure");
915      }
916    }
917  }
918
919  /**
920   * Gets the current settings of the classifier.
921   *
922   * @return an array of strings suitable for passing to setOptions
923   */
924  public String [] getOptions() {
925
926    String [] options = new String [9];
927    int current = 0;
928
929    options[current++] = "-X"; options[current++] = "" + getCrossVal();
930
931    if (m_evaluationMeasure != EVAL_DEFAULT) {
932      options[current++] = "-E";
933      switch (m_evaluationMeasure) {
934      case EVAL_ACCURACY:
935        options[current++] = "acc";
936        break;
937      case EVAL_RMSE:
938        options[current++] = "rmse";
939        break;
940      case EVAL_MAE:
941        options[current++] = "mae";
942        break;
943      case EVAL_AUC:
944        options[current++] = "auc";
945        break;
946      }
947    }
948    if (m_useIBk) {
949      options[current++] = "-I";
950    }
951    if (m_displayRules) {
952      options[current++] = "-R";
953    }
954
955    while (current < options.length) {
956      options[current++] = "";
957    }
958    return options;
959  }
960 
961  /**
962   * Returns the revision string.
963   *
964   * @return            the revision
965   */
966  public String getRevision() {
967    return RevisionUtils.extract("$Revision: 1.4 $");
968  }
969
970  /**
971   * Main method for testing this class.
972   *
973   * @param argv the command-line options
974   */
975  public static void main(String [] argv) {
976    runClassifier(new DTNB(), argv);
977  }
978}
979
Note: See TracBrowser for help on using the repository browser.