1 | /* |
---|
2 | * This program is free software; you can redistribute it and/or modify |
---|
3 | * it under the terms of the GNU General Public License as published by |
---|
4 | * the Free Software Foundation; either version 2 of the License, or |
---|
5 | * (at your option) any later version. |
---|
6 | * |
---|
7 | * This program is distributed in the hope that it will be useful, |
---|
8 | * but WITHOUT ANY WARRANTY; without even the implied warranty of |
---|
9 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the |
---|
10 | * GNU General Public License for more details. |
---|
11 | * |
---|
12 | * You should have received a copy of the GNU General Public License |
---|
13 | * along with this program; if not, write to the Free Software |
---|
14 | * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA. |
---|
15 | */ |
---|
16 | |
---|
17 | /* |
---|
18 | * DecisionTable.java |
---|
19 | * Copyright (C) 2008 University of Waikato, Hamilton, New Zealand |
---|
20 | * |
---|
21 | */ |
---|
22 | |
---|
23 | package weka.classifiers.rules; |
---|
24 | |
---|
25 | import weka.attributeSelection.ASEvaluation; |
---|
26 | import weka.attributeSelection.ASSearch; |
---|
27 | import weka.attributeSelection.SubsetEvaluator; |
---|
28 | import weka.classifiers.bayes.NaiveBayes; |
---|
29 | import weka.core.Capabilities; |
---|
30 | import weka.core.Instance; |
---|
31 | import weka.core.Instances; |
---|
32 | import weka.core.Option; |
---|
33 | import weka.core.RevisionUtils; |
---|
34 | import weka.core.SelectedTag; |
---|
35 | import weka.core.TechnicalInformation; |
---|
36 | import weka.core.Utils; |
---|
37 | import weka.core.Capabilities.Capability; |
---|
38 | import weka.core.TechnicalInformation.Field; |
---|
39 | import weka.core.TechnicalInformation.Type; |
---|
40 | |
---|
41 | import java.util.BitSet; |
---|
42 | import java.util.Enumeration; |
---|
43 | import java.util.Vector; |
---|
44 | |
---|
45 | /** |
---|
46 | * |
---|
47 | <!-- globalinfo-start --> |
---|
48 | * Class for building and using a decision table/naive bayes hybrid classifier. At each point in the search, the algorithm evaluates the merit of dividing the attributes into two disjoint subsets: one for the decision table, the other for naive Bayes. A forward selection search is used, where at each step, selected attributes are modeled by naive Bayes and the remainder by the decision table, and all attributes are modelled by the decision table initially. At each step, the algorithm also considers dropping an attribute entirely from the model.<br/> |
---|
49 | * <br/> |
---|
50 | * For more information, see: <br/> |
---|
51 | * <br/> |
---|
52 | * Mark Hall, Eibe Frank: Combining Naive Bayes and Decision Tables. In: Proceedings of the 21st Florida Artificial Intelligence Society Conference (FLAIRS), ???-???, 2008. |
---|
53 | * <p/> |
---|
54 | <!-- globalinfo-end --> |
---|
55 | * |
---|
56 | <!-- technical-bibtex-start --> |
---|
57 | * BibTeX: |
---|
58 | * <pre> |
---|
59 | * @inproceedings{Hall2008, |
---|
60 | * author = {Mark Hall and Eibe Frank}, |
---|
61 | * booktitle = {Proceedings of the 21st Florida Artificial Intelligence Society Conference (FLAIRS)}, |
---|
62 | * pages = {???-???}, |
---|
63 | * publisher = {AAAI press}, |
---|
64 | * title = {Combining Naive Bayes and Decision Tables}, |
---|
65 | * year = {2008} |
---|
66 | * } |
---|
67 | * </pre> |
---|
68 | * <p/> |
---|
69 | <!-- technical-bibtex-end --> |
---|
70 | * |
---|
71 | <!-- options-start --> |
---|
72 | * Valid options are: <p/> |
---|
73 | * |
---|
74 | * <pre> -X <number of folds> |
---|
75 | * Use cross validation to evaluate features. |
---|
76 | * Use number of folds = 1 for leave one out CV. |
---|
77 | * (Default = leave one out CV)</pre> |
---|
78 | * |
---|
79 | * <pre> -E <acc | rmse | mae | auc> |
---|
80 | * Performance evaluation measure to use for selecting attributes. |
---|
81 | * (Default = accuracy for discrete class and rmse for numeric class)</pre> |
---|
82 | * |
---|
83 | * <pre> -I |
---|
84 | * Use nearest neighbour instead of global table majority.</pre> |
---|
85 | * |
---|
86 | * <pre> -R |
---|
87 | * Display decision table rules. |
---|
88 | * </pre> |
---|
89 | * |
---|
90 | <!-- options-end --> |
---|
91 | * |
---|
92 | * @author Mark Hall (mhall{[at]}pentaho{[dot]}org) |
---|
93 | * @author Eibe Frank (eibe{[at]}cs{[dot]}waikato{[dot]}ac{[dot]}nz) |
---|
94 | * |
---|
95 | * @version $Revision: 1.4 $ |
---|
96 | * |
---|
97 | */ |
---|
98 | public class DTNB extends DecisionTable { |
---|
99 | |
---|
100 | /** |
---|
101 | * The naive Bayes half of the hybrid |
---|
102 | */ |
---|
103 | protected NaiveBayes m_NB; |
---|
104 | |
---|
105 | /** |
---|
106 | * The features used by naive Bayes |
---|
107 | */ |
---|
108 | private int [] m_nbFeatures; |
---|
109 | |
---|
110 | /** |
---|
111 | * Percentage of the total number of features used by the decision table |
---|
112 | */ |
---|
113 | private double m_percentUsedByDT; |
---|
114 | |
---|
115 | /** |
---|
116 | * Percentage of the features that were dropped entirely |
---|
117 | */ |
---|
118 | private double m_percentDeleted; |
---|
119 | |
---|
120 | static final long serialVersionUID = 2999557077765701326L; |
---|
121 | |
---|
122 | /** |
---|
123 | * Returns a string describing classifier |
---|
124 | * @return a description suitable for |
---|
125 | * displaying in the explorer/experimenter gui |
---|
126 | */ |
---|
127 | public String globalInfo() { |
---|
128 | |
---|
129 | return |
---|
130 | "Class for building and using a decision table/naive bayes hybrid classifier. At each point " |
---|
131 | + "in the search, the algorithm evaluates the merit of dividing the attributes into two disjoint " |
---|
132 | + "subsets: one for the decision table, the other for naive Bayes. A forward selection search is " |
---|
133 | + "used, where at each step, selected attributes are modeled by naive Bayes and the remainder " |
---|
134 | + "by the decision table, and all attributes are modelled by the decision table initially. At each " |
---|
135 | + "step, the algorithm also considers dropping an attribute entirely from the model.\n\n" |
---|
136 | + "For more information, see: \n\n" |
---|
137 | + getTechnicalInformation().toString(); |
---|
138 | } |
---|
139 | |
---|
140 | /** |
---|
141 | * Returns an instance of a TechnicalInformation object, containing |
---|
142 | * detailed information about the technical background of this class, |
---|
143 | * e.g., paper reference or book this class is based on. |
---|
144 | * |
---|
145 | * @return the technical information about this class |
---|
146 | */ |
---|
147 | public TechnicalInformation getTechnicalInformation() { |
---|
148 | TechnicalInformation result; |
---|
149 | |
---|
150 | result = new TechnicalInformation(Type.INPROCEEDINGS); |
---|
151 | result.setValue(Field.AUTHOR, "Mark Hall and Eibe Frank"); |
---|
152 | result.setValue(Field.TITLE, "Combining Naive Bayes and Decision Tables"); |
---|
153 | result.setValue(Field.BOOKTITLE, "Proceedings of the 21st Florida Artificial Intelligence " |
---|
154 | + "Society Conference (FLAIRS)"); |
---|
155 | result.setValue(Field.YEAR, "2008"); |
---|
156 | result.setValue(Field.PAGES, "???-???"); |
---|
157 | result.setValue(Field.PUBLISHER, "AAAI press"); |
---|
158 | |
---|
159 | return result; |
---|
160 | } |
---|
161 | |
---|
162 | /** |
---|
163 | * Calculates the accuracy on a test fold for internal cross validation |
---|
164 | * of feature sets |
---|
165 | * |
---|
166 | * @param fold set of instances to be "left out" and classified |
---|
167 | * @param fs currently selected feature set |
---|
168 | * @return the accuracy for the fold |
---|
169 | * @throws Exception if something goes wrong |
---|
170 | */ |
---|
171 | double evaluateFoldCV(Instances fold, int [] fs) throws Exception { |
---|
172 | |
---|
173 | int i; |
---|
174 | int ruleCount = 0; |
---|
175 | int numFold = fold.numInstances(); |
---|
176 | int numCl = m_theInstances.classAttribute().numValues(); |
---|
177 | double [][] class_distribs = new double [numFold][numCl]; |
---|
178 | double [] instA = new double [fs.length]; |
---|
179 | double [] normDist; |
---|
180 | DecisionTableHashKey thekey; |
---|
181 | double acc = 0.0; |
---|
182 | int classI = m_theInstances.classIndex(); |
---|
183 | Instance inst; |
---|
184 | |
---|
185 | if (m_classIsNominal) { |
---|
186 | normDist = new double [numCl]; |
---|
187 | } else { |
---|
188 | normDist = new double [2]; |
---|
189 | } |
---|
190 | |
---|
191 | // first *remove* instances |
---|
192 | for (i=0;i<numFold;i++) { |
---|
193 | inst = fold.instance(i); |
---|
194 | for (int j=0;j<fs.length;j++) { |
---|
195 | if (fs[j] == classI) { |
---|
196 | instA[j] = Double.MAX_VALUE; // missing for the class |
---|
197 | } else if (inst.isMissing(fs[j])) { |
---|
198 | instA[j] = Double.MAX_VALUE; |
---|
199 | } else{ |
---|
200 | instA[j] = inst.value(fs[j]); |
---|
201 | } |
---|
202 | } |
---|
203 | thekey = new DecisionTableHashKey(instA); |
---|
204 | if ((class_distribs[i] = (double [])m_entries.get(thekey)) == null) { |
---|
205 | throw new Error("This should never happen!"); |
---|
206 | } else { |
---|
207 | if (m_classIsNominal) { |
---|
208 | class_distribs[i][(int)inst.classValue()] -= inst.weight(); |
---|
209 | inst.setWeight(-inst.weight()); |
---|
210 | m_NB.updateClassifier(inst); |
---|
211 | inst.setWeight(-inst.weight()); |
---|
212 | } else { |
---|
213 | class_distribs[i][0] -= (inst.classValue() * inst.weight()); |
---|
214 | class_distribs[i][1] -= inst.weight(); |
---|
215 | } |
---|
216 | ruleCount++; |
---|
217 | } |
---|
218 | m_classPriorCounts[(int)inst.classValue()] -= |
---|
219 | inst.weight(); |
---|
220 | } |
---|
221 | double [] classPriors = m_classPriorCounts.clone(); |
---|
222 | Utils.normalize(classPriors); |
---|
223 | |
---|
224 | // now classify instances |
---|
225 | for (i=0;i<numFold;i++) { |
---|
226 | inst = fold.instance(i); |
---|
227 | System.arraycopy(class_distribs[i],0,normDist,0,normDist.length); |
---|
228 | if (m_classIsNominal) { |
---|
229 | boolean ok = false; |
---|
230 | for (int j=0;j<normDist.length;j++) { |
---|
231 | if (Utils.gr(normDist[j],1.0)) { |
---|
232 | ok = true; |
---|
233 | break; |
---|
234 | } |
---|
235 | } |
---|
236 | |
---|
237 | if (!ok) { // majority class |
---|
238 | normDist = classPriors.clone(); |
---|
239 | } else { |
---|
240 | Utils.normalize(normDist); |
---|
241 | } |
---|
242 | |
---|
243 | double [] nbDist = m_NB.distributionForInstance(inst); |
---|
244 | |
---|
245 | for (int l = 0; l < normDist.length; l++) { |
---|
246 | normDist[l] = (Math.log(normDist[l]) - Math.log(classPriors[l])); |
---|
247 | normDist[l] += Math.log(nbDist[l]); |
---|
248 | } |
---|
249 | normDist = Utils.logs2probs(normDist); |
---|
250 | // Utils.normalize(normDist); |
---|
251 | |
---|
252 | // System.out.println(normDist[0] + " " + normDist[1] + " " + inst.classValue()); |
---|
253 | |
---|
254 | if (m_evaluationMeasure == EVAL_AUC) { |
---|
255 | m_evaluation.evaluateModelOnceAndRecordPrediction(normDist, inst); |
---|
256 | } else { |
---|
257 | m_evaluation.evaluateModelOnce(normDist, inst); |
---|
258 | } |
---|
259 | /* } else { |
---|
260 | normDist[(int)m_majority] = 1.0; |
---|
261 | if (m_evaluationMeasure == EVAL_AUC) { |
---|
262 | m_evaluation.evaluateModelOnceAndRecordPrediction(normDist, inst); |
---|
263 | } else { |
---|
264 | m_evaluation.evaluateModelOnce(normDist, inst); |
---|
265 | } |
---|
266 | } */ |
---|
267 | } else { |
---|
268 | if (Utils.eq(normDist[1],0.0)) { |
---|
269 | double [] temp = new double[1]; |
---|
270 | temp[0] = m_majority; |
---|
271 | m_evaluation.evaluateModelOnce(temp, inst); |
---|
272 | } else { |
---|
273 | double [] temp = new double[1]; |
---|
274 | temp[0] = normDist[0] / normDist[1]; |
---|
275 | m_evaluation.evaluateModelOnce(temp, inst); |
---|
276 | } |
---|
277 | } |
---|
278 | } |
---|
279 | |
---|
280 | // now re-insert instances |
---|
281 | for (i=0;i<numFold;i++) { |
---|
282 | inst = fold.instance(i); |
---|
283 | |
---|
284 | m_classPriorCounts[(int)inst.classValue()] += |
---|
285 | inst.weight(); |
---|
286 | |
---|
287 | if (m_classIsNominal) { |
---|
288 | class_distribs[i][(int)inst.classValue()] += inst.weight(); |
---|
289 | m_NB.updateClassifier(inst); |
---|
290 | } else { |
---|
291 | class_distribs[i][0] += (inst.classValue() * inst.weight()); |
---|
292 | class_distribs[i][1] += inst.weight(); |
---|
293 | } |
---|
294 | } |
---|
295 | return acc; |
---|
296 | } |
---|
297 | |
---|
298 | /** |
---|
299 | * Classifies an instance for internal leave one out cross validation |
---|
300 | * of feature sets |
---|
301 | * |
---|
302 | * @param instance instance to be "left out" and classified |
---|
303 | * @param instA feature values of the selected features for the instance |
---|
304 | * @return the classification of the instance |
---|
305 | * @throws Exception if something goes wrong |
---|
306 | */ |
---|
307 | double evaluateInstanceLeaveOneOut(Instance instance, double [] instA) |
---|
308 | throws Exception { |
---|
309 | |
---|
310 | DecisionTableHashKey thekey; |
---|
311 | double [] tempDist; |
---|
312 | double [] normDist; |
---|
313 | |
---|
314 | thekey = new DecisionTableHashKey(instA); |
---|
315 | |
---|
316 | // if this one is not in the table |
---|
317 | if ((tempDist = (double [])m_entries.get(thekey)) == null) { |
---|
318 | throw new Error("This should never happen!"); |
---|
319 | } else { |
---|
320 | normDist = new double [tempDist.length]; |
---|
321 | System.arraycopy(tempDist,0,normDist,0,tempDist.length); |
---|
322 | normDist[(int)instance.classValue()] -= instance.weight(); |
---|
323 | |
---|
324 | // update the table |
---|
325 | // first check to see if the class counts are all zero now |
---|
326 | boolean ok = false; |
---|
327 | for (int i=0;i<normDist.length;i++) { |
---|
328 | if (Utils.gr(normDist[i],1.0)) { |
---|
329 | ok = true; |
---|
330 | break; |
---|
331 | } |
---|
332 | } |
---|
333 | |
---|
334 | // downdate the class prior counts |
---|
335 | m_classPriorCounts[(int)instance.classValue()] -= |
---|
336 | instance.weight(); |
---|
337 | double [] classPriors = m_classPriorCounts.clone(); |
---|
338 | Utils.normalize(classPriors); |
---|
339 | if (!ok) { // majority class |
---|
340 | normDist = classPriors; |
---|
341 | } else { |
---|
342 | Utils.normalize(normDist); |
---|
343 | } |
---|
344 | |
---|
345 | m_classPriorCounts[(int)instance.classValue()] += |
---|
346 | instance.weight(); |
---|
347 | |
---|
348 | if (m_NB != null){ |
---|
349 | // downdate NaiveBayes |
---|
350 | |
---|
351 | instance.setWeight(-instance.weight()); |
---|
352 | m_NB.updateClassifier(instance); |
---|
353 | double [] nbDist = m_NB.distributionForInstance(instance); |
---|
354 | instance.setWeight(-instance.weight()); |
---|
355 | m_NB.updateClassifier(instance); |
---|
356 | |
---|
357 | for (int i = 0; i < normDist.length; i++) { |
---|
358 | normDist[i] = (Math.log(normDist[i]) - Math.log(classPriors[i])); |
---|
359 | normDist[i] += Math.log(nbDist[i]); |
---|
360 | } |
---|
361 | normDist = Utils.logs2probs(normDist); |
---|
362 | // Utils.normalize(normDist); |
---|
363 | } |
---|
364 | |
---|
365 | if (m_evaluationMeasure == EVAL_AUC) { |
---|
366 | m_evaluation.evaluateModelOnceAndRecordPrediction(normDist, instance); |
---|
367 | } else { |
---|
368 | m_evaluation.evaluateModelOnce(normDist, instance); |
---|
369 | } |
---|
370 | return Utils.maxIndex(normDist); |
---|
371 | } |
---|
372 | } |
---|
373 | |
---|
374 | /** |
---|
375 | * Sets up a dummy subset evaluator that basically just delegates |
---|
376 | * evaluation to the estimatePerformance method in DecisionTable |
---|
377 | */ |
---|
378 | protected void setUpEvaluator() throws Exception { |
---|
379 | m_evaluator = new EvalWithDelete(); |
---|
380 | m_evaluator.buildEvaluator(m_theInstances); |
---|
381 | } |
---|
382 | |
---|
383 | protected class EvalWithDelete extends ASEvaluation implements SubsetEvaluator { |
---|
384 | |
---|
385 | // holds the list of attributes that are no longer in the model at all |
---|
386 | private BitSet m_deletedFromDTNB; |
---|
387 | |
---|
388 | public void buildEvaluator(Instances data) throws Exception { |
---|
389 | m_NB = null; |
---|
390 | m_deletedFromDTNB = new BitSet(data.numAttributes()); |
---|
391 | // System.err.println("Here"); |
---|
392 | } |
---|
393 | |
---|
394 | private int setUpForEval(BitSet subset) throws Exception { |
---|
395 | |
---|
396 | int fc = 0; |
---|
397 | for (int jj = 0;jj < m_numAttributes; jj++) { |
---|
398 | if (subset.get(jj)) { |
---|
399 | fc++; |
---|
400 | } |
---|
401 | } |
---|
402 | |
---|
403 | //int [] nbFs = new int [fc]; |
---|
404 | //int count = 0; |
---|
405 | |
---|
406 | for (int j = 0; j < m_numAttributes; j++) { |
---|
407 | m_theInstances.attribute(j).setWeight(1.0); // reset weight |
---|
408 | if (j != m_theInstances.classIndex()) { |
---|
409 | if (subset.get(j)) { |
---|
410 | // nbFs[count++] = j; |
---|
411 | m_theInstances.attribute(j).setWeight(0.0); // no influence for NB |
---|
412 | } |
---|
413 | } |
---|
414 | } |
---|
415 | |
---|
416 | // process delete set |
---|
417 | for (int i = 0; i < m_numAttributes; i++) { |
---|
418 | if (m_deletedFromDTNB.get(i)) { |
---|
419 | m_theInstances.attribute(i).setWeight(0.0); // no influence for NB |
---|
420 | } |
---|
421 | } |
---|
422 | |
---|
423 | if (m_NB == null) { |
---|
424 | // construct naive bayes for the first time |
---|
425 | m_NB = new NaiveBayes(); |
---|
426 | m_NB.buildClassifier(m_theInstances); |
---|
427 | } |
---|
428 | return fc; |
---|
429 | } |
---|
430 | |
---|
431 | public double evaluateSubset(BitSet subset) throws Exception { |
---|
432 | int fc = setUpForEval(subset); |
---|
433 | |
---|
434 | return estimatePerformance(subset, fc); |
---|
435 | } |
---|
436 | |
---|
437 | public double evaluateSubsetDelete(BitSet subset, int potentialDelete) throws Exception { |
---|
438 | |
---|
439 | int fc = setUpForEval(subset); |
---|
440 | |
---|
441 | // clear potentail delete for naive Bayes |
---|
442 | m_theInstances.attribute(potentialDelete).setWeight(0.0); |
---|
443 | //copy.clear(potentialDelete); |
---|
444 | //fc--; |
---|
445 | return estimatePerformance(subset, fc); |
---|
446 | } |
---|
447 | |
---|
448 | public BitSet getDeletedList() { |
---|
449 | return m_deletedFromDTNB; |
---|
450 | } |
---|
451 | |
---|
452 | /** |
---|
453 | * Returns the revision string. |
---|
454 | * |
---|
455 | * @return the revision |
---|
456 | */ |
---|
457 | public String getRevision() { |
---|
458 | return RevisionUtils.extract("$Revision: 1.4 $"); |
---|
459 | } |
---|
460 | } |
---|
461 | |
---|
462 | protected ASSearch m_backwardWithDelete; |
---|
463 | |
---|
464 | /** |
---|
465 | * Inner class implementing a special forwards search that looks for a good |
---|
466 | * split of attributes between naive Bayes and the decision table. It also |
---|
467 | * considers dropping attributes entirely from the model. |
---|
468 | */ |
---|
469 | protected class BackwardsWithDelete extends ASSearch { |
---|
470 | |
---|
471 | public String globalInfo() { |
---|
472 | return "Specialized search that performs a forward selection (naive Bayes)/" |
---|
473 | + "backward elimination (decision table). Also considers dropping attributes " |
---|
474 | + "entirely from the combined model."; |
---|
475 | } |
---|
476 | |
---|
477 | public String toString() { |
---|
478 | return ""; |
---|
479 | } |
---|
480 | |
---|
481 | public int [] search(ASEvaluation eval, Instances data) |
---|
482 | throws Exception { |
---|
483 | int i; |
---|
484 | double best_merit = -Double.MAX_VALUE; |
---|
485 | double temp_best = 0, temp_merit = 0, temp_merit_delete = 0; |
---|
486 | int temp_index=0; |
---|
487 | BitSet temp_group; |
---|
488 | BitSet best_group = null; |
---|
489 | |
---|
490 | int numAttribs = data.numAttributes(); |
---|
491 | |
---|
492 | if (best_group == null) { |
---|
493 | best_group = new BitSet(numAttribs); |
---|
494 | } |
---|
495 | |
---|
496 | |
---|
497 | int classIndex = data.classIndex(); |
---|
498 | for (i = 0; i < numAttribs; i++) { |
---|
499 | if (i != classIndex) { |
---|
500 | best_group.set(i); |
---|
501 | } |
---|
502 | } |
---|
503 | |
---|
504 | //System.err.println(best_group); |
---|
505 | |
---|
506 | // Evaluate the initial subset |
---|
507 | // best_merit = m_evaluator.evaluateSubset(best_group); |
---|
508 | best_merit = ((SubsetEvaluator)eval).evaluateSubset(best_group); |
---|
509 | |
---|
510 | //System.err.println(best_merit); |
---|
511 | |
---|
512 | // main search loop |
---|
513 | boolean done = false; |
---|
514 | boolean addone = false; |
---|
515 | boolean z; |
---|
516 | boolean deleted = false; |
---|
517 | while (!done) { |
---|
518 | temp_group = (BitSet)best_group.clone(); |
---|
519 | temp_best = best_merit; |
---|
520 | |
---|
521 | done = true; |
---|
522 | addone = false; |
---|
523 | for (i = 0; i < numAttribs;i++) { |
---|
524 | z = ((i != classIndex) && (temp_group.get(i))); |
---|
525 | |
---|
526 | if (z) { |
---|
527 | // set/unset the bit |
---|
528 | temp_group.clear(i); |
---|
529 | |
---|
530 | // temp_merit = m_evaluator.evaluateSubset(temp_group); |
---|
531 | temp_merit = ((SubsetEvaluator)eval).evaluateSubset(temp_group); |
---|
532 | // temp_merit_delete = ((EvalWithDelete)m_evaluator).evaluateSubsetDelete(temp_group, i); |
---|
533 | temp_merit_delete = ((EvalWithDelete)eval).evaluateSubsetDelete(temp_group, i); |
---|
534 | boolean deleteBetter = false; |
---|
535 | //System.out.println("Merit: " + temp_merit + "\t" + "Delete merit: " + temp_merit_delete); |
---|
536 | if (temp_merit_delete >= temp_merit) { |
---|
537 | temp_merit = temp_merit_delete; |
---|
538 | deleteBetter = true; |
---|
539 | } |
---|
540 | |
---|
541 | z = (temp_merit >= temp_best); |
---|
542 | |
---|
543 | if (z) { |
---|
544 | temp_best = temp_merit; |
---|
545 | temp_index = i; |
---|
546 | addone = true; |
---|
547 | done = false; |
---|
548 | if (deleteBetter) { |
---|
549 | deleted = true; |
---|
550 | } else { |
---|
551 | deleted = false; |
---|
552 | } |
---|
553 | } |
---|
554 | |
---|
555 | // unset this addition/deletion |
---|
556 | temp_group.set(i); |
---|
557 | } |
---|
558 | } |
---|
559 | if (addone) { |
---|
560 | best_group.clear(temp_index); |
---|
561 | best_merit = temp_best; |
---|
562 | if (deleted) { |
---|
563 | // ((EvalWithDelete)m_evaluator).getDeletedList().set(temp_index); |
---|
564 | ((EvalWithDelete)eval).getDeletedList().set(temp_index); |
---|
565 | } |
---|
566 | //System.err.println("----------------------"); |
---|
567 | //System.err.println("Best subset: (dec table)" + best_group); |
---|
568 | //System.err.println("Best subset: (deleted)" + ((EvalWithDelete)m_evaluator).getDeletedList()); |
---|
569 | //System.err.println(best_merit); |
---|
570 | } |
---|
571 | } |
---|
572 | return attributeList(best_group); |
---|
573 | } |
---|
574 | |
---|
575 | /** |
---|
576 | * converts a BitSet into a list of attribute indexes |
---|
577 | * @param group the BitSet to convert |
---|
578 | * @return an array of attribute indexes |
---|
579 | **/ |
---|
580 | protected int[] attributeList (BitSet group) { |
---|
581 | int count = 0; |
---|
582 | BitSet copy = (BitSet)group.clone(); |
---|
583 | |
---|
584 | /* remove any that have been completely deleted from DTNB |
---|
585 | BitSet deleted = ((EvalWithDelete)m_evaluator).getDeletedList(); |
---|
586 | for (int i = 0; i < m_numAttributes; i++) { |
---|
587 | if (deleted.get(i)) { |
---|
588 | copy.clear(i); |
---|
589 | } |
---|
590 | } */ |
---|
591 | |
---|
592 | // count how many were selected |
---|
593 | for (int i = 0; i < m_numAttributes; i++) { |
---|
594 | if (copy.get(i)) { |
---|
595 | count++; |
---|
596 | } |
---|
597 | } |
---|
598 | |
---|
599 | int[] list = new int[count]; |
---|
600 | count = 0; |
---|
601 | |
---|
602 | for (int i = 0; i < m_numAttributes; i++) { |
---|
603 | if (copy.get(i)) { |
---|
604 | list[count++] = i; |
---|
605 | } |
---|
606 | } |
---|
607 | |
---|
608 | return list; |
---|
609 | } |
---|
610 | |
---|
611 | /** |
---|
612 | * Returns the revision string. |
---|
613 | * |
---|
614 | * @return the revision |
---|
615 | */ |
---|
616 | public String getRevision() { |
---|
617 | return RevisionUtils.extract("$Revision: 1.4 $"); |
---|
618 | } |
---|
619 | } |
---|
620 | |
---|
621 | private void setUpSearch() { |
---|
622 | m_backwardWithDelete = new BackwardsWithDelete(); |
---|
623 | } |
---|
624 | |
---|
625 | /** |
---|
626 | * Generates the classifier. |
---|
627 | * |
---|
628 | * @param data set of instances serving as training data |
---|
629 | * @throws Exception if the classifier has not been generated successfully |
---|
630 | */ |
---|
631 | public void buildClassifier(Instances data) throws Exception { |
---|
632 | |
---|
633 | m_saveMemory = false; |
---|
634 | |
---|
635 | if (data.classAttribute().isNumeric()) { |
---|
636 | throw new Exception("Can only handle nominal class!"); |
---|
637 | } |
---|
638 | |
---|
639 | if (m_backwardWithDelete == null) { |
---|
640 | setUpSearch(); |
---|
641 | m_search = m_backwardWithDelete; |
---|
642 | } |
---|
643 | |
---|
644 | /* if (m_search != m_backwardWithDelete) { |
---|
645 | m_search = m_backwardWithDelete; |
---|
646 | } */ |
---|
647 | super.buildClassifier(data); |
---|
648 | |
---|
649 | // new NB stuff |
---|
650 | |
---|
651 | // delete the features used by the decision table (not the class!!) |
---|
652 | for (int i = 0; i < m_theInstances.numAttributes(); i++) { |
---|
653 | m_theInstances.attribute(i).setWeight(1.0); // reset all weights |
---|
654 | } |
---|
655 | // m_nbFeatures = new int [m_decisionFeatures.length - 1]; |
---|
656 | int count = 0; |
---|
657 | |
---|
658 | for (int i = 0; i < m_decisionFeatures.length; i++) { |
---|
659 | if (m_decisionFeatures[i] != m_theInstances.classIndex()) { |
---|
660 | count++; |
---|
661 | // m_nbFeatures[count++] = m_decisionFeatures[i]; |
---|
662 | m_theInstances.attribute(m_decisionFeatures[i]).setWeight(0.0); // No influence for NB |
---|
663 | } |
---|
664 | } |
---|
665 | |
---|
666 | double numDeleted = 0; |
---|
667 | // remove any attributes that have been deleted completely from the DTNB |
---|
668 | BitSet deleted = ((EvalWithDelete)m_evaluator).getDeletedList(); |
---|
669 | for (int i = 0; i < m_theInstances.numAttributes(); i++) { |
---|
670 | if (deleted.get(i)) { |
---|
671 | m_theInstances.attribute(i).setWeight(0.0); |
---|
672 | // count--; |
---|
673 | numDeleted++; |
---|
674 | // System.err.println("Attribute "+i+" was eliminated completely"); |
---|
675 | } |
---|
676 | } |
---|
677 | |
---|
678 | m_percentUsedByDT = (double)count / (m_theInstances.numAttributes() - 1); |
---|
679 | m_percentDeleted = numDeleted / (m_theInstances.numAttributes() -1); |
---|
680 | |
---|
681 | m_NB = new NaiveBayes(); |
---|
682 | m_NB.buildClassifier(m_theInstances); |
---|
683 | |
---|
684 | m_dtInstances = new Instances(m_dtInstances, 0); |
---|
685 | m_theInstances = new Instances(m_theInstances, 0); |
---|
686 | } |
---|
687 | |
---|
688 | /** |
---|
689 | * Calculates the class membership probabilities for the given |
---|
690 | * test instance. |
---|
691 | * |
---|
692 | * @param instance the instance to be classified |
---|
693 | * @return predicted class probability distribution |
---|
694 | * @exception Exception if distribution can't be computed |
---|
695 | */ |
---|
696 | public double [] distributionForInstance(Instance instance) |
---|
697 | throws Exception { |
---|
698 | |
---|
699 | DecisionTableHashKey thekey; |
---|
700 | double [] tempDist; |
---|
701 | double [] normDist; |
---|
702 | |
---|
703 | m_disTransform.input(instance); |
---|
704 | m_disTransform.batchFinished(); |
---|
705 | instance = m_disTransform.output(); |
---|
706 | |
---|
707 | m_delTransform.input(instance); |
---|
708 | m_delTransform.batchFinished(); |
---|
709 | Instance dtInstance = m_delTransform.output(); |
---|
710 | |
---|
711 | thekey = new DecisionTableHashKey(dtInstance, dtInstance.numAttributes(), false); |
---|
712 | |
---|
713 | // if this one is not in the table |
---|
714 | if ((tempDist = (double [])m_entries.get(thekey)) == null) { |
---|
715 | if (m_useIBk) { |
---|
716 | tempDist = m_ibk.distributionForInstance(dtInstance); |
---|
717 | } else { |
---|
718 | // tempDist = new double [m_theInstances.classAttribute().numValues()]; |
---|
719 | // tempDist[(int)m_majority] = 1.0; |
---|
720 | |
---|
721 | tempDist = m_classPriors.clone(); |
---|
722 | // return tempDist; ?????? |
---|
723 | } |
---|
724 | } else { |
---|
725 | // normalise distribution |
---|
726 | normDist = new double [tempDist.length]; |
---|
727 | System.arraycopy(tempDist,0,normDist,0,tempDist.length); |
---|
728 | Utils.normalize(normDist); |
---|
729 | tempDist = normDist; |
---|
730 | } |
---|
731 | |
---|
732 | double [] nbDist = m_NB.distributionForInstance(instance); |
---|
733 | for (int i = 0; i < nbDist.length; i++) { |
---|
734 | tempDist[i] = (Math.log(tempDist[i]) - Math.log(m_classPriors[i])); |
---|
735 | tempDist[i] += Math.log(nbDist[i]); |
---|
736 | |
---|
737 | /*tempDist[i] *= nbDist[i]; |
---|
738 | tempDist[i] /= m_classPriors[i];*/ |
---|
739 | } |
---|
740 | tempDist = Utils.logs2probs(tempDist); |
---|
741 | Utils.normalize(tempDist); |
---|
742 | |
---|
743 | return tempDist; |
---|
744 | } |
---|
745 | |
---|
746 | public String toString() { |
---|
747 | |
---|
748 | String sS = super.toString(); |
---|
749 | if (m_displayRules && m_NB != null) { |
---|
750 | sS += m_NB.toString(); |
---|
751 | } |
---|
752 | return sS; |
---|
753 | } |
---|
754 | |
---|
755 | /** |
---|
756 | * Returns the number of rules |
---|
757 | * @return the number of rules |
---|
758 | */ |
---|
759 | public double measurePercentAttsUsedByDT() { |
---|
760 | return m_percentUsedByDT; |
---|
761 | } |
---|
762 | |
---|
763 | /** |
---|
764 | * Returns an enumeration of the additional measure names |
---|
765 | * @return an enumeration of the measure names |
---|
766 | */ |
---|
767 | public Enumeration enumerateMeasures() { |
---|
768 | Vector newVector = new Vector(2); |
---|
769 | newVector.addElement("measureNumRules"); |
---|
770 | newVector.addElement("measurePercentAttsUsedByDT"); |
---|
771 | return newVector.elements(); |
---|
772 | } |
---|
773 | |
---|
774 | /** |
---|
775 | * Returns the value of the named measure |
---|
776 | * @param additionalMeasureName the name of the measure to query for its value |
---|
777 | * @return the value of the named measure |
---|
778 | * @throws IllegalArgumentException if the named measure is not supported |
---|
779 | */ |
---|
780 | public double getMeasure(String additionalMeasureName) { |
---|
781 | if (additionalMeasureName.compareToIgnoreCase("measureNumRules") == 0) { |
---|
782 | return measureNumRules(); |
---|
783 | } else if (additionalMeasureName.compareToIgnoreCase("measurePercentAttsUsedByDT") == 0) { |
---|
784 | return measurePercentAttsUsedByDT(); |
---|
785 | } else { |
---|
786 | throw new IllegalArgumentException(additionalMeasureName |
---|
787 | + " not supported (DecisionTable)"); |
---|
788 | } |
---|
789 | } |
---|
790 | |
---|
791 | /** |
---|
792 | * Returns default capabilities of the classifier. |
---|
793 | * |
---|
794 | * @return the capabilities of this classifier |
---|
795 | */ |
---|
796 | public Capabilities getCapabilities() { |
---|
797 | Capabilities result = super.getCapabilities(); |
---|
798 | |
---|
799 | result.disable(Capability.NUMERIC_CLASS); |
---|
800 | result.disable(Capability.DATE_CLASS); |
---|
801 | |
---|
802 | return result; |
---|
803 | } |
---|
804 | |
---|
805 | /** |
---|
806 | * Sets the search method to use |
---|
807 | * |
---|
808 | * @param search |
---|
809 | */ |
---|
810 | public void setSearch(ASSearch search) { |
---|
811 | // Search method cannot be changed. |
---|
812 | // Must be BackwardsWithDelete |
---|
813 | return; |
---|
814 | } |
---|
815 | |
---|
816 | /** |
---|
817 | * Gets the current search method |
---|
818 | * |
---|
819 | * @return the search method used |
---|
820 | */ |
---|
821 | public ASSearch getSearch() { |
---|
822 | if (m_backwardWithDelete == null) { |
---|
823 | setUpSearch(); |
---|
824 | // setSearch(m_backwardWithDelete); |
---|
825 | m_search = m_backwardWithDelete; |
---|
826 | } |
---|
827 | return m_search; |
---|
828 | } |
---|
829 | |
---|
830 | /** |
---|
831 | * Returns an enumeration describing the available options. |
---|
832 | * |
---|
833 | * @return an enumeration of all the available options. |
---|
834 | */ |
---|
835 | public Enumeration listOptions() { |
---|
836 | |
---|
837 | Vector newVector = new Vector(7); |
---|
838 | |
---|
839 | newVector.addElement(new Option( |
---|
840 | "\tUse cross validation to evaluate features.\n" + |
---|
841 | "\tUse number of folds = 1 for leave one out CV.\n" + |
---|
842 | "\t(Default = leave one out CV)", |
---|
843 | "X", 1, "-X <number of folds>")); |
---|
844 | |
---|
845 | newVector.addElement(new Option( |
---|
846 | "\tPerformance evaluation measure to use for selecting attributes.\n" + |
---|
847 | "\t(Default = accuracy for discrete class and rmse for numeric class)", |
---|
848 | "E", 1, "-E <acc | rmse | mae | auc>")); |
---|
849 | |
---|
850 | newVector.addElement(new Option( |
---|
851 | "\tUse nearest neighbour instead of global table majority.", |
---|
852 | "I", 0, "-I")); |
---|
853 | |
---|
854 | newVector.addElement(new Option( |
---|
855 | "\tDisplay decision table rules.\n", |
---|
856 | "R", 0, "-R")); |
---|
857 | |
---|
858 | return newVector.elements(); |
---|
859 | } |
---|
860 | |
---|
861 | /** |
---|
862 | * Parses the options for this object. <p/> |
---|
863 | * |
---|
864 | <!-- options-start --> |
---|
865 | * Valid options are: <p/> |
---|
866 | * |
---|
867 | * <pre> -X <number of folds> |
---|
868 | * Use cross validation to evaluate features. |
---|
869 | * Use number of folds = 1 for leave one out CV. |
---|
870 | * (Default = leave one out CV)</pre> |
---|
871 | * |
---|
872 | * <pre> -E <acc | rmse | mae | auc> |
---|
873 | * Performance evaluation measure to use for selecting attributes. |
---|
874 | * (Default = accuracy for discrete class and rmse for numeric class)</pre> |
---|
875 | * |
---|
876 | * <pre> -I |
---|
877 | * Use nearest neighbour instead of global table majority.</pre> |
---|
878 | * |
---|
879 | * <pre> -R |
---|
880 | * Display decision table rules. |
---|
881 | * </pre> |
---|
882 | * |
---|
883 | <!-- options-end --> |
---|
884 | * |
---|
885 | * @param options the list of options as an array of strings |
---|
886 | * @throws Exception if an option is not supported |
---|
887 | */ |
---|
888 | public void setOptions(String[] options) throws Exception { |
---|
889 | |
---|
890 | String optionString; |
---|
891 | |
---|
892 | resetOptions(); |
---|
893 | |
---|
894 | optionString = Utils.getOption('X',options); |
---|
895 | if (optionString.length() != 0) { |
---|
896 | setCrossVal(Integer.parseInt(optionString)); |
---|
897 | } |
---|
898 | |
---|
899 | m_useIBk = Utils.getFlag('I',options); |
---|
900 | |
---|
901 | m_displayRules = Utils.getFlag('R',options); |
---|
902 | |
---|
903 | optionString = Utils.getOption('E', options); |
---|
904 | if (optionString.length() != 0) { |
---|
905 | if (optionString.equals("acc")) { |
---|
906 | setEvaluationMeasure(new SelectedTag(EVAL_ACCURACY, TAGS_EVALUATION)); |
---|
907 | } else if (optionString.equals("rmse")) { |
---|
908 | setEvaluationMeasure(new SelectedTag(EVAL_RMSE, TAGS_EVALUATION)); |
---|
909 | } else if (optionString.equals("mae")) { |
---|
910 | setEvaluationMeasure(new SelectedTag(EVAL_MAE, TAGS_EVALUATION)); |
---|
911 | } else if (optionString.equals("auc")) { |
---|
912 | setEvaluationMeasure(new SelectedTag(EVAL_AUC, TAGS_EVALUATION)); |
---|
913 | } else { |
---|
914 | throw new IllegalArgumentException("Invalid evaluation measure"); |
---|
915 | } |
---|
916 | } |
---|
917 | } |
---|
918 | |
---|
919 | /** |
---|
920 | * Gets the current settings of the classifier. |
---|
921 | * |
---|
922 | * @return an array of strings suitable for passing to setOptions |
---|
923 | */ |
---|
924 | public String [] getOptions() { |
---|
925 | |
---|
926 | String [] options = new String [9]; |
---|
927 | int current = 0; |
---|
928 | |
---|
929 | options[current++] = "-X"; options[current++] = "" + getCrossVal(); |
---|
930 | |
---|
931 | if (m_evaluationMeasure != EVAL_DEFAULT) { |
---|
932 | options[current++] = "-E"; |
---|
933 | switch (m_evaluationMeasure) { |
---|
934 | case EVAL_ACCURACY: |
---|
935 | options[current++] = "acc"; |
---|
936 | break; |
---|
937 | case EVAL_RMSE: |
---|
938 | options[current++] = "rmse"; |
---|
939 | break; |
---|
940 | case EVAL_MAE: |
---|
941 | options[current++] = "mae"; |
---|
942 | break; |
---|
943 | case EVAL_AUC: |
---|
944 | options[current++] = "auc"; |
---|
945 | break; |
---|
946 | } |
---|
947 | } |
---|
948 | if (m_useIBk) { |
---|
949 | options[current++] = "-I"; |
---|
950 | } |
---|
951 | if (m_displayRules) { |
---|
952 | options[current++] = "-R"; |
---|
953 | } |
---|
954 | |
---|
955 | while (current < options.length) { |
---|
956 | options[current++] = ""; |
---|
957 | } |
---|
958 | return options; |
---|
959 | } |
---|
960 | |
---|
961 | /** |
---|
962 | * Returns the revision string. |
---|
963 | * |
---|
964 | * @return the revision |
---|
965 | */ |
---|
966 | public String getRevision() { |
---|
967 | return RevisionUtils.extract("$Revision: 1.4 $"); |
---|
968 | } |
---|
969 | |
---|
970 | /** |
---|
971 | * Main method for testing this class. |
---|
972 | * |
---|
973 | * @param argv the command-line options |
---|
974 | */ |
---|
975 | public static void main(String [] argv) { |
---|
976 | runClassifier(new DTNB(), argv); |
---|
977 | } |
---|
978 | } |
---|
979 | |
---|