1 | /* |
---|
2 | * This program is free software; you can redistribute it and/or modify |
---|
3 | * it under the terms of the GNU General Public License as published by |
---|
4 | * the Free Software Foundation; either version 2 of the License, or |
---|
5 | * (at your option) any later version. |
---|
6 | * |
---|
7 | * This program is distributed in the hope that it will be useful, |
---|
8 | * but WITHOUT ANY WARRANTY; without even the implied warranty of |
---|
9 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the |
---|
10 | * GNU General Public License for more details. |
---|
11 | * |
---|
12 | * You should have received a copy of the GNU General Public License |
---|
13 | * along with this program; if not, write to the Free Software |
---|
14 | * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA. |
---|
15 | */ |
---|
16 | |
---|
17 | /* |
---|
18 | * LADTree.java |
---|
19 | * Copyright (C) 2001 University of Waikato, Hamilton, New Zealand |
---|
20 | * |
---|
21 | */ |
---|
22 | |
---|
23 | package weka.classifiers.trees; |
---|
24 | |
---|
25 | import weka.classifiers.*; |
---|
26 | import weka.core.Capabilities; |
---|
27 | import weka.core.Capabilities.Capability; |
---|
28 | import weka.core.*; |
---|
29 | import weka.classifiers.trees.adtree.ReferenceInstances; |
---|
30 | import java.util.*; |
---|
31 | import java.io.*; |
---|
32 | import weka.core.TechnicalInformation; |
---|
33 | import weka.core.TechnicalInformationHandler; |
---|
34 | import weka.core.TechnicalInformation.Field; |
---|
35 | import weka.core.TechnicalInformation.Type; |
---|
36 | |
---|
37 | /** |
---|
38 | <!-- globalinfo-start --> |
---|
39 | * Class for generating a multi-class alternating decision tree using the LogitBoost strategy. For more info, see<br/> |
---|
40 | * <br/> |
---|
41 | * Geoffrey Holmes, Bernhard Pfahringer, Richard Kirkby, Eibe Frank, Mark Hall: Multiclass alternating decision trees. In: ECML, 161-172, 2001. |
---|
42 | * <p/> |
---|
43 | <!-- globalinfo-end --> |
---|
44 | * |
---|
45 | <!-- technical-bibtex-start --> |
---|
46 | * BibTeX: |
---|
47 | * <pre> |
---|
48 | * @inproceedings{Holmes2001, |
---|
49 | * author = {Geoffrey Holmes and Bernhard Pfahringer and Richard Kirkby and Eibe Frank and Mark Hall}, |
---|
50 | * booktitle = {ECML}, |
---|
51 | * pages = {161-172}, |
---|
52 | * publisher = {Springer}, |
---|
53 | * title = {Multiclass alternating decision trees}, |
---|
54 | * year = {2001} |
---|
55 | * } |
---|
56 | * </pre> |
---|
57 | * <p/> |
---|
58 | <!-- technical-bibtex-end --> |
---|
59 | * |
---|
60 | <!-- options-start --> |
---|
61 | * Valid options are: <p/> |
---|
62 | * |
---|
63 | * <pre> -B <number of boosting iterations> |
---|
64 | * Number of boosting iterations. |
---|
65 | * (Default = 10)</pre> |
---|
66 | * |
---|
67 | * <pre> -D |
---|
68 | * If set, classifier is run in debug mode and |
---|
69 | * may output additional info to the console</pre> |
---|
70 | * |
---|
71 | <!-- options-end --> |
---|
72 | * |
---|
73 | * @author Richard Kirkby |
---|
74 | * @version $Revision: 6035 $ |
---|
75 | */ |
---|
76 | |
---|
77 | public class LADTree |
---|
78 | extends AbstractClassifier implements Drawable, |
---|
79 | AdditionalMeasureProducer, |
---|
80 | TechnicalInformationHandler { |
---|
81 | |
---|
/**
 * For serialization
 */
private static final long serialVersionUID = -4940716114518300302L;

// Constant from LogitBoost. LADInstance.updateZVector() clamps every
// zVector entry to the range [-Z_MAX, Z_MAX] using this value.
protected double Z_MAX = 4;

// Number of classes (taken from the training data in initClassifier())
protected int m_numOfClasses;

// Training instances (each wrapped in an LADInstance) held as reference instances
protected ReferenceInstances m_trainInstances;

// Root of the tree
protected PredictionNode m_root = null;

// Counter recording the order in which splits are added
// (incremented by PredictionNode.addChild())
protected int m_lastAddedSplitNum = 0;

// Indices of the numeric, non-class attributes of the training data
protected int[] m_numericAttIndices;

// Search state describing the best split found so far.
// NOTE: despite its name, m_search_smallestLeastSquares holds the LARGEST
// least-squares gain seen in the current search (evaluateSplitter() replaces
// it only when a candidate's gain is greater).
protected double m_search_smallestLeastSquares;
protected PredictionNode m_search_bestInsertionNode;
protected Splitter m_search_bestSplitter;
protected Instances m_search_bestPathInstances;

// All candidate two-way splits on nominal attributes (built once in initClassifier())
protected FastVector m_staticPotentialSplitters2way;

// Search statistics: prediction nodes visited and instances examined
protected int m_nodesExpanded = 0;
protected int m_examplesCounted = 0;

// Options: number of boosting iterations (default 10)
protected int m_boostingIterations = 10;
---|
120 | |
---|
121 | /** |
---|
122 | * Returns a string describing classifier |
---|
123 | * @return a description suitable for |
---|
124 | * displaying in the explorer/experimenter gui |
---|
125 | */ |
---|
126 | public String globalInfo() { |
---|
127 | |
---|
128 | return "Class for generating a multi-class alternating decision tree using " + |
---|
129 | "the LogitBoost strategy. For more info, see\n\n" |
---|
130 | + getTechnicalInformation().toString(); |
---|
131 | } |
---|
132 | |
---|
133 | /** |
---|
134 | * Returns an instance of a TechnicalInformation object, containing |
---|
135 | * detailed information about the technical background of this class, |
---|
136 | * e.g., paper reference or book this class is based on. |
---|
137 | * |
---|
138 | * @return the technical information about this class |
---|
139 | */ |
---|
140 | public TechnicalInformation getTechnicalInformation() { |
---|
141 | TechnicalInformation result; |
---|
142 | |
---|
143 | result = new TechnicalInformation(Type.INPROCEEDINGS); |
---|
144 | result.setValue(Field.AUTHOR, "Geoffrey Holmes and Bernhard Pfahringer and Richard Kirkby and Eibe Frank and Mark Hall"); |
---|
145 | result.setValue(Field.TITLE, "Multiclass alternating decision trees"); |
---|
146 | result.setValue(Field.BOOKTITLE, "ECML"); |
---|
147 | result.setValue(Field.YEAR, "2001"); |
---|
148 | result.setValue(Field.PAGES, "161-172"); |
---|
149 | result.setValue(Field.PUBLISHER, "Springer"); |
---|
150 | |
---|
151 | return result; |
---|
152 | } |
---|
153 | |
---|
154 | /** helper classes ********************************************************************/ |
---|
155 | |
---|
156 | protected class LADInstance extends DenseInstance { |
---|
157 | public double[] fVector; |
---|
158 | public double[] wVector; |
---|
159 | public double[] pVector; |
---|
160 | public double[] zVector; |
---|
161 | public LADInstance(Instance instance) { |
---|
162 | |
---|
163 | super(instance); |
---|
164 | |
---|
165 | setDataset(instance.dataset()); // preserve dataset |
---|
166 | |
---|
167 | // set up vectors |
---|
168 | fVector = new double[m_numOfClasses]; |
---|
169 | wVector = new double[m_numOfClasses]; |
---|
170 | pVector = new double[m_numOfClasses]; |
---|
171 | zVector = new double[m_numOfClasses]; |
---|
172 | |
---|
173 | // set initial probabilities |
---|
174 | double initProb = 1.0 / ((double) m_numOfClasses); |
---|
175 | for (int i=0; i<m_numOfClasses; i++) { |
---|
176 | pVector[i] = initProb; |
---|
177 | } |
---|
178 | updateZVector(); |
---|
179 | updateWVector(); |
---|
180 | } |
---|
181 | public void updateWeights(double[] fVectorIncrement) { |
---|
182 | for (int i=0; i<fVector.length; i++) { |
---|
183 | fVector[i] += fVectorIncrement[i]; |
---|
184 | } |
---|
185 | updateVectors(fVector); |
---|
186 | } |
---|
187 | public void updateVectors(double[] newFVector) { |
---|
188 | updatePVector(newFVector); |
---|
189 | updateZVector(); |
---|
190 | updateWVector(); |
---|
191 | } |
---|
192 | public void updatePVector(double[] newFVector) { |
---|
193 | double max = newFVector[Utils.maxIndex(newFVector)]; |
---|
194 | for (int i=0; i<pVector.length; i++) { |
---|
195 | pVector[i] = Math.exp(newFVector[i] - max); |
---|
196 | } |
---|
197 | Utils.normalize(pVector); |
---|
198 | } |
---|
199 | public void updateWVector() { |
---|
200 | for (int i=0; i<wVector.length; i++) { |
---|
201 | wVector[i] = (yVector(i) - pVector[i]) / zVector[i]; |
---|
202 | } |
---|
203 | } |
---|
204 | public void updateZVector() { |
---|
205 | |
---|
206 | for (int i=0; i<zVector.length; i++) { |
---|
207 | if (yVector(i) == 1) { |
---|
208 | zVector[i] = 1.0 / pVector[i]; |
---|
209 | if (zVector[i] > Z_MAX) { // threshold |
---|
210 | zVector[i] = Z_MAX; |
---|
211 | } |
---|
212 | } else { |
---|
213 | zVector[i] = -1.0 / (1.0 - pVector[i]); |
---|
214 | if (zVector[i] < -Z_MAX) { // threshold |
---|
215 | zVector[i] = -Z_MAX; |
---|
216 | } |
---|
217 | } |
---|
218 | } |
---|
219 | } |
---|
220 | public double yVector(int index) { |
---|
221 | return (index == (int) classValue() ? 1.0 : 0.0); |
---|
222 | } |
---|
223 | public Object copy() { |
---|
224 | LADInstance copy = new LADInstance((Instance) super.copy()); |
---|
225 | System.arraycopy(fVector, 0, copy.fVector, 0, fVector.length); |
---|
226 | System.arraycopy(wVector, 0, copy.wVector, 0, wVector.length); |
---|
227 | System.arraycopy(pVector, 0, copy.pVector, 0, pVector.length); |
---|
228 | System.arraycopy(zVector, 0, copy.zVector, 0, zVector.length); |
---|
229 | return copy; |
---|
230 | } |
---|
231 | public String toString() { |
---|
232 | |
---|
233 | StringBuffer text = new StringBuffer(); |
---|
234 | text.append(" * F("); |
---|
235 | for (int i=0; i<fVector.length; i++) { |
---|
236 | text.append(Utils.doubleToString(fVector[i], 3)); |
---|
237 | if (i<fVector.length-1) text.append(","); |
---|
238 | } |
---|
239 | text.append(") P("); |
---|
240 | for (int i=0; i<pVector.length; i++) { |
---|
241 | text.append(Utils.doubleToString(pVector[i], 3)); |
---|
242 | if (i<pVector.length-1) text.append(","); |
---|
243 | } |
---|
244 | text.append(") W("); |
---|
245 | for (int i=0; i<wVector.length; i++) { |
---|
246 | text.append(Utils.doubleToString(wVector[i], 3)); |
---|
247 | if (i<wVector.length-1) text.append(","); |
---|
248 | } |
---|
249 | text.append(")"); |
---|
250 | return super.toString() + text.toString(); |
---|
251 | |
---|
252 | } |
---|
253 | } |
---|
254 | |
---|
255 | protected class PredictionNode implements Serializable, Cloneable{ |
---|
256 | private double[] values; |
---|
257 | private FastVector children; // any number of splitter nodes |
---|
258 | |
---|
259 | public PredictionNode(double[] newValues) { |
---|
260 | values = new double[m_numOfClasses]; |
---|
261 | setValues(newValues); |
---|
262 | children = new FastVector(); |
---|
263 | } |
---|
264 | public void setValues(double[] newValues) { |
---|
265 | System.arraycopy(newValues, 0, values, 0, m_numOfClasses); |
---|
266 | } |
---|
267 | public double[] getValues() { |
---|
268 | return values; |
---|
269 | } |
---|
270 | public FastVector getChildren() { return children; } |
---|
271 | public Enumeration children() { return children.elements(); } |
---|
272 | public void addChild(Splitter newChild) { // merges, adds a clone (deep copy) |
---|
273 | Splitter oldEqual = null; |
---|
274 | for (Enumeration e = children(); e.hasMoreElements(); ) { |
---|
275 | Splitter split = (Splitter) e.nextElement(); |
---|
276 | if (newChild.equalTo(split)) { oldEqual = split; break; } |
---|
277 | } |
---|
278 | if (oldEqual == null) { |
---|
279 | Splitter addChild = (Splitter) newChild.clone(); |
---|
280 | addChild.orderAdded = ++m_lastAddedSplitNum; |
---|
281 | children.addElement(addChild); |
---|
282 | } |
---|
283 | else { // do a merge |
---|
284 | for (int i=0; i<newChild.getNumOfBranches(); i++) { |
---|
285 | PredictionNode oldPred = oldEqual.getChildForBranch(i); |
---|
286 | PredictionNode newPred = newChild.getChildForBranch(i); |
---|
287 | if (oldPred != null && newPred != null) |
---|
288 | oldPred.merge(newPred); |
---|
289 | } |
---|
290 | } |
---|
291 | } |
---|
292 | public Object clone() { // does a deep copy (recurses through tree) |
---|
293 | PredictionNode clone = new PredictionNode(values); |
---|
294 | // should actually clone once merges are enabled! |
---|
295 | for (Enumeration e = children.elements(); e.hasMoreElements(); ) |
---|
296 | clone.children.addElement((Splitter)((Splitter) e.nextElement()).clone()); |
---|
297 | return clone; |
---|
298 | } |
---|
299 | public void merge(PredictionNode merger) { |
---|
300 | // need to merge linear models here somehow |
---|
301 | for (int i=0; i<m_numOfClasses; i++) values[i] += merger.values[i]; |
---|
302 | for (Enumeration e = merger.children(); e.hasMoreElements(); ) { |
---|
303 | addChild((Splitter)e.nextElement()); |
---|
304 | } |
---|
305 | } |
---|
306 | } |
---|
307 | |
---|
308 | /** splitter classes ******************************************************************/ |
---|
309 | |
---|
310 | protected abstract class Splitter implements Serializable, Cloneable { |
---|
311 | protected int attIndex; |
---|
312 | public int orderAdded; |
---|
313 | public abstract int getNumOfBranches(); |
---|
314 | public abstract int branchInstanceGoesDown(Instance i); |
---|
315 | public abstract Instances instancesDownBranch(int branch, Instances sourceInstances); |
---|
316 | public abstract String attributeString(); |
---|
317 | public abstract String comparisonString(int branchNum); |
---|
318 | public abstract boolean equalTo(Splitter compare); |
---|
319 | public abstract void setChildForBranch(int branchNum, PredictionNode childPredictor); |
---|
320 | public abstract PredictionNode getChildForBranch(int branchNum); |
---|
321 | public abstract Object clone(); |
---|
322 | } |
---|
323 | |
---|
324 | protected class TwoWayNominalSplit extends Splitter { |
---|
325 | //private int attIndex; |
---|
326 | private int trueSplitValue; |
---|
327 | private PredictionNode[] children; |
---|
328 | public TwoWayNominalSplit(int _attIndex, int _trueSplitValue) { |
---|
329 | attIndex = _attIndex; trueSplitValue = _trueSplitValue; |
---|
330 | children = new PredictionNode[2]; |
---|
331 | } |
---|
332 | public int getNumOfBranches() { return 2; } |
---|
333 | public int branchInstanceGoesDown(Instance inst) { |
---|
334 | if (inst.isMissing(attIndex)) return -1; |
---|
335 | else if (inst.value(attIndex) == trueSplitValue) return 0; |
---|
336 | else return 1; |
---|
337 | } |
---|
338 | public Instances instancesDownBranch(int branch, Instances instances) { |
---|
339 | ReferenceInstances filteredInstances = new ReferenceInstances(instances, 1); |
---|
340 | if (branch == -1) { |
---|
341 | for (Enumeration e = instances.enumerateInstances(); e.hasMoreElements(); ) { |
---|
342 | Instance inst = (Instance) e.nextElement(); |
---|
343 | if (inst.isMissing(attIndex)) filteredInstances.addReference(inst); |
---|
344 | } |
---|
345 | } else if (branch == 0) { |
---|
346 | for (Enumeration e = instances.enumerateInstances(); e.hasMoreElements(); ) { |
---|
347 | Instance inst = (Instance) e.nextElement(); |
---|
348 | if (!inst.isMissing(attIndex) && inst.value(attIndex) == trueSplitValue) |
---|
349 | filteredInstances.addReference(inst); |
---|
350 | } |
---|
351 | } else { |
---|
352 | for (Enumeration e = instances.enumerateInstances(); e.hasMoreElements(); ) { |
---|
353 | Instance inst = (Instance) e.nextElement(); |
---|
354 | if (!inst.isMissing(attIndex) && inst.value(attIndex) != trueSplitValue) |
---|
355 | filteredInstances.addReference(inst); |
---|
356 | } |
---|
357 | } |
---|
358 | return filteredInstances; |
---|
359 | } |
---|
360 | public String attributeString() { |
---|
361 | return m_trainInstances.attribute(attIndex).name(); |
---|
362 | } |
---|
363 | public String comparisonString(int branchNum) { |
---|
364 | Attribute att = m_trainInstances.attribute(attIndex); |
---|
365 | if (att.numValues() != 2) |
---|
366 | return ((branchNum == 0 ? "= " : "!= ") + att.value(trueSplitValue)); |
---|
367 | else return ("= " + (branchNum == 0 ? |
---|
368 | att.value(trueSplitValue) : |
---|
369 | att.value(trueSplitValue == 0 ? 1 : 0))); |
---|
370 | } |
---|
371 | public boolean equalTo(Splitter compare) { |
---|
372 | if (compare instanceof TwoWayNominalSplit) { // test object type |
---|
373 | TwoWayNominalSplit compareSame = (TwoWayNominalSplit) compare; |
---|
374 | return (attIndex == compareSame.attIndex && |
---|
375 | trueSplitValue == compareSame.trueSplitValue); |
---|
376 | } else return false; |
---|
377 | } |
---|
378 | public void setChildForBranch(int branchNum, PredictionNode childPredictor) { |
---|
379 | children[branchNum] = childPredictor; |
---|
380 | } |
---|
381 | public PredictionNode getChildForBranch(int branchNum) { |
---|
382 | return children[branchNum]; |
---|
383 | } |
---|
384 | public Object clone() { // deep copy |
---|
385 | TwoWayNominalSplit clone = new TwoWayNominalSplit(attIndex, trueSplitValue); |
---|
386 | if (children[0] != null) |
---|
387 | clone.setChildForBranch(0, (PredictionNode) children[0].clone()); |
---|
388 | if (children[1] != null) |
---|
389 | clone.setChildForBranch(1, (PredictionNode) children[1].clone()); |
---|
390 | return clone; |
---|
391 | } |
---|
392 | } |
---|
393 | |
---|
394 | protected class TwoWayNumericSplit extends Splitter implements Cloneable { |
---|
395 | //private int attIndex; |
---|
396 | private double splitPoint; |
---|
397 | private PredictionNode[] children; |
---|
398 | public TwoWayNumericSplit(int _attIndex, double _splitPoint) { |
---|
399 | attIndex = _attIndex; |
---|
400 | splitPoint = _splitPoint; |
---|
401 | children = new PredictionNode[2]; |
---|
402 | } |
---|
403 | public TwoWayNumericSplit(int _attIndex, Instances instances) throws Exception { |
---|
404 | attIndex = _attIndex; |
---|
405 | splitPoint = findSplit(instances, attIndex); |
---|
406 | children = new PredictionNode[2]; |
---|
407 | } |
---|
408 | public int getNumOfBranches() { return 2; } |
---|
409 | public int branchInstanceGoesDown(Instance inst) { |
---|
410 | if (inst.isMissing(attIndex)) return -1; |
---|
411 | else if (inst.value(attIndex) < splitPoint) return 0; |
---|
412 | else return 1; |
---|
413 | } |
---|
414 | public Instances instancesDownBranch(int branch, Instances instances) { |
---|
415 | ReferenceInstances filteredInstances = new ReferenceInstances(instances, 1); |
---|
416 | if (branch == -1) { |
---|
417 | for (Enumeration e = instances.enumerateInstances(); e.hasMoreElements(); ) { |
---|
418 | Instance inst = (Instance) e.nextElement(); |
---|
419 | if (inst.isMissing(attIndex)) filteredInstances.addReference(inst); |
---|
420 | } |
---|
421 | } else if (branch == 0) { |
---|
422 | for (Enumeration e = instances.enumerateInstances(); e.hasMoreElements(); ) { |
---|
423 | Instance inst = (Instance) e.nextElement(); |
---|
424 | if (!inst.isMissing(attIndex) && inst.value(attIndex) < splitPoint) |
---|
425 | filteredInstances.addReference(inst); |
---|
426 | } |
---|
427 | } else { |
---|
428 | for (Enumeration e = instances.enumerateInstances(); e.hasMoreElements(); ) { |
---|
429 | Instance inst = (Instance) e.nextElement(); |
---|
430 | if (!inst.isMissing(attIndex) && inst.value(attIndex) >= splitPoint) |
---|
431 | filteredInstances.addReference(inst); |
---|
432 | } |
---|
433 | } |
---|
434 | return filteredInstances; |
---|
435 | } |
---|
436 | public String attributeString() { |
---|
437 | return m_trainInstances.attribute(attIndex).name(); |
---|
438 | } |
---|
439 | public String comparisonString(int branchNum) { |
---|
440 | return ((branchNum == 0 ? "< " : ">= ") + Utils.doubleToString(splitPoint, 3)); |
---|
441 | } |
---|
442 | public boolean equalTo(Splitter compare) { |
---|
443 | if (compare instanceof TwoWayNumericSplit) { // test object type |
---|
444 | TwoWayNumericSplit compareSame = (TwoWayNumericSplit) compare; |
---|
445 | return (attIndex == compareSame.attIndex && |
---|
446 | splitPoint == compareSame.splitPoint); |
---|
447 | } else return false; |
---|
448 | } |
---|
449 | public void setChildForBranch(int branchNum, PredictionNode childPredictor) { |
---|
450 | children[branchNum] = childPredictor; |
---|
451 | } |
---|
452 | public PredictionNode getChildForBranch(int branchNum) { |
---|
453 | return children[branchNum]; |
---|
454 | } |
---|
455 | public Object clone() { // deep copy |
---|
456 | TwoWayNumericSplit clone = new TwoWayNumericSplit(attIndex, splitPoint); |
---|
457 | if (children[0] != null) |
---|
458 | clone.setChildForBranch(0, (PredictionNode) children[0].clone()); |
---|
459 | if (children[1] != null) |
---|
460 | clone.setChildForBranch(1, (PredictionNode) children[1].clone()); |
---|
461 | return clone; |
---|
462 | } |
---|
463 | private double findSplit(Instances instances, int index) throws Exception { |
---|
464 | double splitPoint = 0; |
---|
465 | double bestVal = Double.MAX_VALUE, currVal, currCutPoint; |
---|
466 | int numMissing = 0; |
---|
467 | double[][] distribution = new double[3][instances.numClasses()]; |
---|
468 | |
---|
469 | // Compute counts for all the values |
---|
470 | for (int i = 0; i < instances.numInstances(); i++) { |
---|
471 | Instance inst = instances.instance(i); |
---|
472 | if (!inst.isMissing(index)) { |
---|
473 | distribution[1][(int)inst.classValue()] ++; |
---|
474 | } else { |
---|
475 | distribution[2][(int)inst.classValue()] ++; |
---|
476 | numMissing++; |
---|
477 | } |
---|
478 | } |
---|
479 | |
---|
480 | // Sort instances |
---|
481 | instances.sort(index); |
---|
482 | |
---|
483 | // Make split counts for each possible split and evaluate |
---|
484 | for (int i = 0; i < instances.numInstances() - (numMissing + 1); i++) { |
---|
485 | Instance inst = instances.instance(i); |
---|
486 | Instance instPlusOne = instances.instance(i + 1); |
---|
487 | distribution[0][(int)inst.classValue()] += inst.weight(); |
---|
488 | distribution[1][(int)inst.classValue()] -= inst.weight(); |
---|
489 | if (Utils.sm(inst.value(index), instPlusOne.value(index))) { |
---|
490 | currCutPoint = (inst.value(index) + instPlusOne.value(index)) / 2.0; |
---|
491 | currVal = ContingencyTables.entropyConditionedOnRows(distribution); |
---|
492 | if (Utils.sm(currVal, bestVal)) { |
---|
493 | splitPoint = currCutPoint; |
---|
494 | bestVal = currVal; |
---|
495 | } |
---|
496 | } |
---|
497 | } |
---|
498 | |
---|
499 | return splitPoint; |
---|
500 | } |
---|
501 | } |
---|
502 | |
---|
503 | /** |
---|
504 | * Sets up the tree ready to be trained. |
---|
505 | * |
---|
506 | * @param instances the instances to train the tree with |
---|
507 | * @exception Exception if training data is unsuitable |
---|
508 | */ |
---|
509 | public void initClassifier(Instances instances) throws Exception { |
---|
510 | |
---|
511 | // clear stats |
---|
512 | m_nodesExpanded = 0; |
---|
513 | m_examplesCounted = 0; |
---|
514 | m_lastAddedSplitNum = 0; |
---|
515 | |
---|
516 | m_numOfClasses = instances.numClasses(); |
---|
517 | |
---|
518 | // make sure training data is suitable |
---|
519 | if (instances.checkForStringAttributes()) { |
---|
520 | throw new Exception("Can't handle string attributes!"); |
---|
521 | } |
---|
522 | if (!instances.classAttribute().isNominal()) { |
---|
523 | throw new Exception("Class must be nominal!"); |
---|
524 | } |
---|
525 | |
---|
526 | // create training set (use LADInstance class) |
---|
527 | m_trainInstances = |
---|
528 | new ReferenceInstances(instances, instances.numInstances()); |
---|
529 | for (Enumeration e = instances.enumerateInstances(); e.hasMoreElements(); ) { |
---|
530 | Instance inst = (Instance) e.nextElement(); |
---|
531 | if (!inst.classIsMissing()) { |
---|
532 | LADInstance adtInst = new LADInstance(inst); |
---|
533 | m_trainInstances.addReference(adtInst); |
---|
534 | adtInst.setDataset(m_trainInstances); |
---|
535 | } |
---|
536 | } |
---|
537 | |
---|
538 | // create the root prediction node |
---|
539 | m_root = new PredictionNode(new double[m_numOfClasses]); |
---|
540 | |
---|
541 | // pre-calculate what we can |
---|
542 | generateStaticPotentialSplittersAndNumericIndices(); |
---|
543 | } |
---|
544 | |
---|
545 | public void next(int iteration) throws Exception { |
---|
546 | boost(); |
---|
547 | } |
---|
548 | |
---|
549 | public void done() throws Exception {} |
---|
550 | |
---|
551 | /** |
---|
552 | * Performs a single boosting iteration. |
---|
553 | * Will add a new splitter node and two prediction nodes to the tree |
---|
554 | * (unless merging takes place). |
---|
555 | * |
---|
556 | * @exception Exception if try to boost without setting up tree first |
---|
557 | */ |
---|
558 | private void boost() throws Exception { |
---|
559 | |
---|
560 | if (m_trainInstances == null) |
---|
561 | throw new Exception("Trying to boost with no training data"); |
---|
562 | |
---|
563 | // perform the search |
---|
564 | searchForBestTest(); |
---|
565 | |
---|
566 | if (m_Debug) { |
---|
567 | System.out.println("Best split found: " |
---|
568 | + m_search_bestSplitter.getNumOfBranches() + "-way split on " |
---|
569 | + m_search_bestSplitter.attributeString() |
---|
570 | //+ "\nsmallestLeastSquares = " + m_search_smallestLeastSquares); |
---|
571 | + "\nBestGain = " + m_search_smallestLeastSquares); |
---|
572 | } |
---|
573 | |
---|
574 | if (m_search_bestSplitter == null) return; // handle empty instances |
---|
575 | |
---|
576 | // create the new nodes for the tree, updating the weights |
---|
577 | for (int i=0; i<m_search_bestSplitter.getNumOfBranches(); i++) { |
---|
578 | Instances applicableInstances = |
---|
579 | m_search_bestSplitter.instancesDownBranch(i, m_search_bestPathInstances); |
---|
580 | double[] predictionValues = calcPredictionValues(applicableInstances); |
---|
581 | PredictionNode newPredictor = new PredictionNode(predictionValues); |
---|
582 | updateWeights(applicableInstances, predictionValues); |
---|
583 | m_search_bestSplitter.setChildForBranch(i, newPredictor); |
---|
584 | } |
---|
585 | |
---|
586 | // insert the new nodes |
---|
587 | m_search_bestInsertionNode.addChild((Splitter) m_search_bestSplitter); |
---|
588 | |
---|
589 | if (m_Debug) { |
---|
590 | System.out.println("Tree is now:\n" + toString(m_root, 1) + "\n"); |
---|
591 | //System.out.println("Instances are now:\n" + m_trainInstances + "\n"); |
---|
592 | } |
---|
593 | |
---|
594 | // free memory |
---|
595 | m_search_bestPathInstances = null; |
---|
596 | } |
---|
597 | |
---|
598 | private void updateWeights(Instances instances, double[] newPredictionValues) { |
---|
599 | |
---|
600 | for (int i=0; i<instances.numInstances(); i++) |
---|
601 | ((LADInstance) instances.instance(i)).updateWeights(newPredictionValues); |
---|
602 | } |
---|
603 | |
---|
604 | /** |
---|
605 | * Generates the m_staticPotentialSplitters2way |
---|
606 | * vector to contain all possible nominal splits, and the m_numericAttIndices array to |
---|
607 | * index the numeric attributes in the training data. |
---|
608 | * |
---|
609 | */ |
---|
610 | private void generateStaticPotentialSplittersAndNumericIndices() { |
---|
611 | |
---|
612 | m_staticPotentialSplitters2way = new FastVector(); |
---|
613 | FastVector numericIndices = new FastVector(); |
---|
614 | |
---|
615 | for (int i=0; i<m_trainInstances.numAttributes(); i++) { |
---|
616 | if (i == m_trainInstances.classIndex()) continue; |
---|
617 | if (m_trainInstances.attribute(i).isNumeric()) |
---|
618 | numericIndices.addElement(new Integer(i)); |
---|
619 | else { |
---|
620 | int numValues = m_trainInstances.attribute(i).numValues(); |
---|
621 | if (numValues == 2) // avoid redundancy due to 2-way symmetry |
---|
622 | m_staticPotentialSplitters2way.addElement(new TwoWayNominalSplit(i, 0)); |
---|
623 | else for (int j=0; j<numValues; j++) |
---|
624 | m_staticPotentialSplitters2way.addElement(new TwoWayNominalSplit(i, j)); |
---|
625 | } |
---|
626 | } |
---|
627 | |
---|
628 | m_numericAttIndices = new int[numericIndices.size()]; |
---|
629 | for (int i=0; i<numericIndices.size(); i++) |
---|
630 | m_numericAttIndices[i] = ((Integer)numericIndices.elementAt(i)).intValue(); |
---|
631 | } |
---|
632 | |
---|
633 | /** |
---|
634 | * Performs a search for the best test (splitter) to add to the tree, by looking |
---|
635 | * for the largest weight change. |
---|
636 | * |
---|
637 | * @exception Exception if search fails |
---|
638 | */ |
---|
639 | private void searchForBestTest() throws Exception { |
---|
640 | |
---|
641 | if (m_Debug) { |
---|
642 | System.out.println("Searching for best split..."); |
---|
643 | } |
---|
644 | |
---|
645 | m_search_smallestLeastSquares = 0.0; //Double.POSITIVE_INFINITY; |
---|
646 | searchForBestTest(m_root, m_trainInstances); |
---|
647 | } |
---|
648 | |
---|
649 | /** |
---|
650 | * Recursive function that carries out search for the best test (splitter) to add to |
---|
651 | * this part of the tree, by looking for the largest weight change. Will try 2-way |
---|
652 | * and/or multi-way splits depending on the current state. |
---|
653 | * |
---|
654 | * @param currentNode the root of the subtree to be searched, and the current node |
---|
655 | * being considered as parent of a new split |
---|
656 | * @param instances the instances that apply at this node |
---|
657 | * @exception Exception if search fails |
---|
658 | */ |
---|
659 | private void searchForBestTest(PredictionNode currentNode, Instances instances) |
---|
660 | throws Exception |
---|
661 | { |
---|
662 | |
---|
663 | // keep stats |
---|
664 | m_nodesExpanded++; |
---|
665 | m_examplesCounted += instances.numInstances(); |
---|
666 | |
---|
667 | // evaluate static splitters (nominal) |
---|
668 | for (Enumeration e = m_staticPotentialSplitters2way.elements(); |
---|
669 | e.hasMoreElements(); ) { |
---|
670 | evaluateSplitter((Splitter) e.nextElement(), currentNode, instances); |
---|
671 | } |
---|
672 | |
---|
673 | if (m_Debug) { |
---|
674 | //System.out.println("Instances considered are: " + instances); |
---|
675 | } |
---|
676 | |
---|
677 | // evaluate dynamic splitters (numeric) |
---|
678 | for (int i=0; i<m_numericAttIndices.length; i++) { |
---|
679 | evaluateNumericSplit(currentNode, instances, m_numericAttIndices[i]); |
---|
680 | } |
---|
681 | |
---|
682 | if (currentNode.getChildren().size() == 0) return; |
---|
683 | |
---|
684 | // keep searching |
---|
685 | goDownAllPaths(currentNode, instances); |
---|
686 | } |
---|
687 | |
---|
688 | /** |
---|
689 | * Continues general multi-class search by investigating every node in the |
---|
690 | * subtree under currentNode. |
---|
691 | * |
---|
692 | * @param currentNode the root of the subtree to be searched |
---|
693 | * @param instances the instances that apply at this node |
---|
694 | * @exception Exception if search fails |
---|
695 | */ |
---|
696 | private void goDownAllPaths(PredictionNode currentNode, Instances instances) |
---|
697 | throws Exception |
---|
698 | { |
---|
699 | |
---|
700 | for (Enumeration e = currentNode.children(); e.hasMoreElements(); ) { |
---|
701 | Splitter split = (Splitter) e.nextElement(); |
---|
702 | for (int i=0; i<split.getNumOfBranches(); i++) |
---|
703 | searchForBestTest(split.getChildForBranch(i), |
---|
704 | split.instancesDownBranch(i, instances)); |
---|
705 | } |
---|
706 | } |
---|
707 | |
---|
708 | /** |
---|
709 | * Investigates the option of introducing a split under currentNode. If the |
---|
710 | * split creates a weight change that is larger than has already been found it will |
---|
711 | * update the search information to record this as the best option so far. |
---|
712 | * |
---|
713 | * @param split the splitter node to evaluate |
---|
714 | * @param currentNode the parent under which the split is to be considered |
---|
715 | * @param instances the instances that apply at this node |
---|
716 | * @exception Exception if something goes wrong |
---|
717 | */ |
---|
718 | private void evaluateSplitter(Splitter split, PredictionNode currentNode, |
---|
719 | Instances instances) |
---|
720 | throws Exception |
---|
721 | { |
---|
722 | |
---|
723 | double leastSquares = leastSquaresNonMissing(instances,split.attIndex); |
---|
724 | |
---|
725 | for (int i=0; i<split.getNumOfBranches(); i++) |
---|
726 | leastSquares -= leastSquares(split.instancesDownBranch(i, instances)); |
---|
727 | |
---|
728 | if (m_Debug) { |
---|
729 | //System.out.println("Instances considered are: " + instances); |
---|
730 | System.out.print(split.getNumOfBranches() + "-way split on " + split.attributeString() |
---|
731 | + " has leastSquares value of " |
---|
732 | + Utils.doubleToString(leastSquares,3)); |
---|
733 | } |
---|
734 | |
---|
735 | if (leastSquares > m_search_smallestLeastSquares) { |
---|
736 | if (m_Debug) { |
---|
737 | System.out.print(" (best so far)"); |
---|
738 | } |
---|
739 | m_search_smallestLeastSquares = leastSquares; |
---|
740 | m_search_bestInsertionNode = currentNode; |
---|
741 | m_search_bestSplitter = split; |
---|
742 | m_search_bestPathInstances = instances; |
---|
743 | } |
---|
744 | if (m_Debug) { |
---|
745 | System.out.print("\n"); |
---|
746 | } |
---|
747 | } |
---|
748 | |
---|
749 | private void evaluateNumericSplit(PredictionNode currentNode, |
---|
750 | Instances instances, int attIndex) |
---|
751 | { |
---|
752 | |
---|
753 | double[] splitAndLS = findNumericSplitpointAndLS(instances, attIndex); |
---|
754 | double gain = leastSquaresNonMissing(instances,attIndex) - splitAndLS[1]; |
---|
755 | |
---|
756 | if (m_Debug) { |
---|
757 | //System.out.println("Instances considered are: " + instances); |
---|
758 | System.out.print("Numeric split on " + instances.attribute(attIndex).name() |
---|
759 | + " has leastSquares value of " |
---|
760 | //+ Utils.doubleToString(splitAndLS[1],3)); |
---|
761 | + Utils.doubleToString(gain,3)); |
---|
762 | } |
---|
763 | |
---|
764 | if (gain > m_search_smallestLeastSquares) { |
---|
765 | if (m_Debug) { |
---|
766 | System.out.print(" (best so far)"); |
---|
767 | } |
---|
768 | m_search_smallestLeastSquares = gain; //splitAndLS[1]; |
---|
769 | m_search_bestInsertionNode = currentNode; |
---|
770 | m_search_bestSplitter = new TwoWayNumericSplit(attIndex, splitAndLS[0]);; |
---|
771 | m_search_bestPathInstances = instances; |
---|
772 | } |
---|
773 | if (m_Debug) { |
---|
774 | System.out.print("\n"); |
---|
775 | } |
---|
776 | } |
---|
777 | |
---|
778 | private double[] findNumericSplitpointAndLS(Instances instances, int attIndex) { |
---|
779 | |
---|
780 | double allLS = leastSquares(instances); |
---|
781 | |
---|
782 | // all instances in right subset |
---|
783 | double[] term1L = new double[m_numOfClasses]; |
---|
784 | double[] term2L = new double[m_numOfClasses]; |
---|
785 | double[] term3L = new double[m_numOfClasses]; |
---|
786 | double[] meanNumL = new double[m_numOfClasses]; |
---|
787 | double[] meanDenL = new double[m_numOfClasses]; |
---|
788 | |
---|
789 | double[] term1R = new double[m_numOfClasses]; |
---|
790 | double[] term2R = new double[m_numOfClasses]; |
---|
791 | double[] term3R = new double[m_numOfClasses]; |
---|
792 | double[] meanNumR = new double[m_numOfClasses]; |
---|
793 | double[] meanDenR = new double[m_numOfClasses]; |
---|
794 | |
---|
795 | double temp1, temp2, temp3; |
---|
796 | |
---|
797 | double[] classMeans = new double[m_numOfClasses]; |
---|
798 | double[] classTotals = new double[m_numOfClasses]; |
---|
799 | |
---|
800 | // fill up RHS |
---|
801 | for (int j=0; j<m_numOfClasses; j++) { |
---|
802 | for (int i=0; i<instances.numInstances(); i++) { |
---|
803 | LADInstance inst = (LADInstance) instances.instance(i); |
---|
804 | temp1 = inst.wVector[j] * inst.zVector[j]; |
---|
805 | term1R[j] += temp1 * inst.zVector[j]; |
---|
806 | term2R[j] += temp1; |
---|
807 | term3R[j] += inst.wVector[j]; |
---|
808 | meanNumR[j] += inst.wVector[j] * inst.zVector[j]; |
---|
809 | } |
---|
810 | } |
---|
811 | |
---|
812 | //leastSquares = term1 - (2.0 * u * term2) + (u * u * term3); |
---|
813 | |
---|
814 | double leastSquares; |
---|
815 | boolean newSplit; |
---|
816 | double smallestLeastSquares = Double.POSITIVE_INFINITY; |
---|
817 | double bestSplit = 0.0; |
---|
818 | double meanL, meanR; |
---|
819 | |
---|
820 | instances.sort(attIndex); |
---|
821 | |
---|
822 | for (int i=0; i<instances.numInstances()-1; i++) {// shift inst from right to left |
---|
823 | if (instances.instance(i+1).isMissing(attIndex)) break; |
---|
824 | if (instances.instance(i+1).value(attIndex) > instances.instance(i).value(attIndex)) |
---|
825 | newSplit = true; |
---|
826 | else newSplit = false; |
---|
827 | LADInstance inst = (LADInstance) instances.instance(i); |
---|
828 | leastSquares = 0.0; |
---|
829 | for (int j=0; j<m_numOfClasses; j++) { |
---|
830 | temp1 = inst.wVector[j] * inst.zVector[j]; |
---|
831 | temp2 = temp1 * inst.zVector[j]; |
---|
832 | temp3 = inst.wVector[j] * inst.zVector[j]; |
---|
833 | term1L[j] += temp2; |
---|
834 | term2L[j] += temp1; |
---|
835 | term3L[j] += inst.wVector[j]; |
---|
836 | term1R[j] -= temp2; |
---|
837 | term2R[j] -= temp1; |
---|
838 | term3R[j] -= inst.wVector[j]; |
---|
839 | meanNumL[j] += temp3; |
---|
840 | meanNumR[j] -= temp3; |
---|
841 | if (newSplit) { |
---|
842 | meanL = meanNumL[j] / term3L[j]; |
---|
843 | meanR = meanNumR[j] / term3R[j]; |
---|
844 | leastSquares += term1L[j] - (2.0 * meanL * term2L[j]) |
---|
845 | + (meanL * meanL * term3L[j]); |
---|
846 | leastSquares += term1R[j] - (2.0 * meanR * term2R[j]) |
---|
847 | + (meanR * meanR * term3R[j]); |
---|
848 | } |
---|
849 | } |
---|
850 | if (m_Debug && newSplit) |
---|
851 | System.out.println(attIndex + "/" + |
---|
852 | ((instances.instance(i).value(attIndex) + |
---|
853 | instances.instance(i+1).value(attIndex)) / 2.0) + |
---|
854 | " = " + (allLS - leastSquares)); |
---|
855 | |
---|
856 | if (newSplit && leastSquares < smallestLeastSquares) { |
---|
857 | bestSplit = (instances.instance(i).value(attIndex) + |
---|
858 | instances.instance(i+1).value(attIndex)) / 2.0; |
---|
859 | smallestLeastSquares = leastSquares; |
---|
860 | } |
---|
861 | } |
---|
862 | double[] result = new double[2]; |
---|
863 | result[0] = bestSplit; |
---|
864 | result[1] = smallestLeastSquares > 0 ? smallestLeastSquares : 0; |
---|
865 | return result; |
---|
866 | } |
---|
867 | |
---|
868 | private double leastSquares(Instances instances) { |
---|
869 | |
---|
870 | double numerator=0, denominator=0, w, t; |
---|
871 | double[] classMeans = new double[m_numOfClasses]; |
---|
872 | double[] classTotals = new double[m_numOfClasses]; |
---|
873 | |
---|
874 | for (int i=0; i<instances.numInstances(); i++) { |
---|
875 | LADInstance inst = (LADInstance) instances.instance(i); |
---|
876 | for (int j=0; j<m_numOfClasses; j++) { |
---|
877 | classMeans[j] += inst.zVector[j] * inst.wVector[j]; |
---|
878 | classTotals[j] += inst.wVector[j]; |
---|
879 | } |
---|
880 | } |
---|
881 | |
---|
882 | double numInstances = (double) instances.numInstances(); |
---|
883 | for (int j=0; j<m_numOfClasses; j++) { |
---|
884 | if (classTotals[j] != 0) classMeans[j] /= classTotals[j]; |
---|
885 | } |
---|
886 | |
---|
887 | for (int i=0; i<instances.numInstances(); i++) |
---|
888 | for (int j=0; j<m_numOfClasses; j++) { |
---|
889 | LADInstance inst = (LADInstance) instances.instance(i); |
---|
890 | w = inst.wVector[j]; |
---|
891 | t = inst.zVector[j] - classMeans[j]; |
---|
892 | numerator += w * (t * t); |
---|
893 | denominator += w; |
---|
894 | } |
---|
895 | //System.out.println(numerator + " / " + denominator); |
---|
896 | return numerator > 0 ? numerator : 0;// / denominator; |
---|
897 | } |
---|
898 | |
---|
899 | |
---|
900 | private double leastSquaresNonMissing(Instances instances, int attIndex) { |
---|
901 | |
---|
902 | double numerator=0, denominator=0, w, t; |
---|
903 | double[] classMeans = new double[m_numOfClasses]; |
---|
904 | double[] classTotals = new double[m_numOfClasses]; |
---|
905 | |
---|
906 | for (int i=0; i<instances.numInstances(); i++) { |
---|
907 | LADInstance inst = (LADInstance) instances.instance(i); |
---|
908 | for (int j=0; j<m_numOfClasses; j++) { |
---|
909 | classMeans[j] += inst.zVector[j] * inst.wVector[j]; |
---|
910 | classTotals[j] += inst.wVector[j]; |
---|
911 | } |
---|
912 | } |
---|
913 | |
---|
914 | double numInstances = (double) instances.numInstances(); |
---|
915 | for (int j=0; j<m_numOfClasses; j++) { |
---|
916 | if (classTotals[j] != 0) classMeans[j] /= classTotals[j]; |
---|
917 | } |
---|
918 | |
---|
919 | for (int i=0; i<instances.numInstances(); i++) |
---|
920 | for (int j=0; j<m_numOfClasses; j++) { |
---|
921 | LADInstance inst = (LADInstance) instances.instance(i); |
---|
922 | if(!inst.isMissing(attIndex)) { |
---|
923 | w = inst.wVector[j]; |
---|
924 | t = inst.zVector[j] - classMeans[j]; |
---|
925 | numerator += w * (t * t); |
---|
926 | denominator += w; |
---|
927 | } |
---|
928 | } |
---|
929 | //System.out.println(numerator + " / " + denominator); |
---|
930 | return numerator > 0 ? numerator : 0;// / denominator; |
---|
931 | } |
---|
932 | |
---|
933 | private double[] calcPredictionValues(Instances instances) { |
---|
934 | |
---|
935 | double[] classMeans = new double[m_numOfClasses]; |
---|
936 | double meansSum = 0; |
---|
937 | double multiplier = ((double) (m_numOfClasses-1)) / ((double) (m_numOfClasses)); |
---|
938 | |
---|
939 | double[] classTotals = new double[m_numOfClasses]; |
---|
940 | |
---|
941 | for (int i=0; i<instances.numInstances(); i++) { |
---|
942 | LADInstance inst = (LADInstance) instances.instance(i); |
---|
943 | for (int j=0; j<m_numOfClasses; j++) { |
---|
944 | classMeans[j] += inst.zVector[j] * inst.wVector[j]; |
---|
945 | classTotals[j] += inst.wVector[j]; |
---|
946 | } |
---|
947 | } |
---|
948 | double numInstances = (double) instances.numInstances(); |
---|
949 | for (int j=0; j<m_numOfClasses; j++) { |
---|
950 | if (classTotals[j] != 0) classMeans[j] /= classTotals[j]; |
---|
951 | meansSum += classMeans[j]; |
---|
952 | } |
---|
953 | meansSum /= m_numOfClasses; |
---|
954 | |
---|
955 | for (int j=0; j<m_numOfClasses; j++) { |
---|
956 | classMeans[j] = multiplier * (classMeans[j] - meansSum); |
---|
957 | } |
---|
958 | return classMeans; |
---|
959 | } |
---|
960 | |
---|
961 | /** |
---|
962 | * Returns the class probability distribution for an instance. |
---|
963 | * |
---|
964 | * @param instance the instance to be classified |
---|
965 | * @return the distribution the tree generates for the instance |
---|
966 | */ |
---|
967 | public double[] distributionForInstance(Instance instance) { |
---|
968 | |
---|
969 | double[] predValues = new double[m_numOfClasses]; |
---|
970 | for (int i=0; i<m_numOfClasses; i++) predValues[i] = 0.0; |
---|
971 | double[] distribution = predictionValuesForInstance(instance, m_root, predValues); |
---|
972 | double max = distribution[Utils.maxIndex(distribution)]; |
---|
973 | for (int i=0; i<m_numOfClasses; i++) { |
---|
974 | distribution[i] = Math.exp(distribution[i] - max); |
---|
975 | } |
---|
976 | double sum = Utils.sum(distribution); |
---|
977 | if (sum > 0.0) Utils.normalize(distribution, sum); |
---|
978 | return distribution; |
---|
979 | } |
---|
980 | |
---|
981 | /** |
---|
982 | * Returns the class prediction values (votes) for an instance. |
---|
983 | * |
---|
984 | * @param inst the instance |
---|
985 | * @param currentNode the root of the tree to get the values from |
---|
986 | * @param currentValues the current values before adding the values contained in the |
---|
987 | * subtree |
---|
988 | * @return the class prediction values (votes) |
---|
989 | */ |
---|
990 | private double[] predictionValuesForInstance(Instance inst, PredictionNode currentNode, |
---|
991 | double[] currentValues) { |
---|
992 | |
---|
993 | double[] predValues = currentNode.getValues(); |
---|
994 | for (int i=0; i<m_numOfClasses; i++) currentValues[i] += predValues[i]; |
---|
995 | //for (int i=0; i<m_numOfClasses; i++) currentValues[i] = predValues[i]; |
---|
996 | for (Enumeration e = currentNode.children(); e.hasMoreElements(); ) { |
---|
997 | Splitter split = (Splitter) e.nextElement(); |
---|
998 | int branch = split.branchInstanceGoesDown(inst); |
---|
999 | if (branch >= 0) |
---|
1000 | currentValues = predictionValuesForInstance(inst, split.getChildForBranch(branch), |
---|
1001 | currentValues); |
---|
1002 | } |
---|
1003 | return currentValues; |
---|
1004 | } |
---|
1005 | |
---|
1006 | |
---|
1007 | |
---|
1008 | /** model output functions ************************************************************/ |
---|
1009 | |
---|
1010 | /** |
---|
1011 | * Returns a description of the classifier. |
---|
1012 | * |
---|
1013 | * @return a string containing a description of the classifier |
---|
1014 | */ |
---|
1015 | public String toString() { |
---|
1016 | |
---|
1017 | String className = getClass().getName(); |
---|
1018 | if (m_root == null) |
---|
1019 | return (className +" not built yet"); |
---|
1020 | else { |
---|
1021 | return (className + ":\n\n" + toString(m_root, 1) + |
---|
1022 | "\nLegend: " + legend() + |
---|
1023 | "\n#Tree size (total): " + |
---|
1024 | numOfAllNodes(m_root) + |
---|
1025 | "\n#Tree size (number of predictor nodes): " + |
---|
1026 | numOfPredictionNodes(m_root) + |
---|
1027 | "\n#Leaves (number of predictor nodes): " + |
---|
1028 | numOfLeafNodes(m_root) + |
---|
1029 | "\n#Expanded nodes: " + |
---|
1030 | m_nodesExpanded + |
---|
1031 | "\n#Processed examples: " + |
---|
1032 | m_examplesCounted + |
---|
1033 | "\n#Ratio e/n: " + |
---|
1034 | ((double)m_examplesCounted/(double)m_nodesExpanded) |
---|
1035 | ); |
---|
1036 | } |
---|
1037 | } |
---|
1038 | |
---|
1039 | /** |
---|
1040 | * Traverses the tree, forming a string that describes it. |
---|
1041 | * |
---|
1042 | * @param currentNode the current node under investigation |
---|
1043 | * @param level the current level in the tree |
---|
1044 | * @return the string describing the subtree |
---|
1045 | */ |
---|
1046 | private String toString(PredictionNode currentNode, int level) { |
---|
1047 | |
---|
1048 | StringBuffer text = new StringBuffer(); |
---|
1049 | |
---|
1050 | text.append(": "); |
---|
1051 | double[] predValues = currentNode.getValues(); |
---|
1052 | for (int i=0; i<m_numOfClasses; i++) { |
---|
1053 | text.append(Utils.doubleToString(predValues[i],3)); |
---|
1054 | if (i<m_numOfClasses-1) text.append(","); |
---|
1055 | } |
---|
1056 | for (Enumeration e = currentNode.children(); e.hasMoreElements(); ) { |
---|
1057 | Splitter split = (Splitter) e.nextElement(); |
---|
1058 | |
---|
1059 | for (int j=0; j<split.getNumOfBranches(); j++) { |
---|
1060 | PredictionNode child = split.getChildForBranch(j); |
---|
1061 | if (child != null) { |
---|
1062 | text.append("\n"); |
---|
1063 | for (int k = 0; k < level; k++) { |
---|
1064 | text.append("| "); |
---|
1065 | } |
---|
1066 | text.append("(" + split.orderAdded + ")"); |
---|
1067 | text.append(split.attributeString() + " " + split.comparisonString(j)); |
---|
1068 | text.append(toString(child, level + 1)); |
---|
1069 | } |
---|
1070 | } |
---|
1071 | } |
---|
1072 | return text.toString(); |
---|
1073 | } |
---|
1074 | |
---|
1075 | /** |
---|
1076 | * Returns graph describing the tree. |
---|
1077 | * |
---|
1078 | * @return the graph of the tree in dotty format |
---|
1079 | * @exception Exception if something goes wrong |
---|
1080 | */ |
---|
1081 | public String graph() throws Exception { |
---|
1082 | |
---|
1083 | StringBuffer text = new StringBuffer(); |
---|
1084 | text.append("digraph ADTree {\n"); |
---|
1085 | //text.append("center=true\nsize=\"8.27,11.69\"\n"); |
---|
1086 | graphTraverse(m_root, text, 0, 0); |
---|
1087 | return text.toString() +"}\n"; |
---|
1088 | } |
---|
1089 | |
---|
1090 | |
---|
1091 | /** |
---|
1092 | * Traverses the tree, graphing each node. |
---|
1093 | * |
---|
1094 | * @param currentNode the currentNode under investigation |
---|
1095 | * @param text the string built so far |
---|
1096 | * @param splitOrder the order the parent splitter was added to the tree |
---|
1097 | * @param predOrder the order this predictor was added to the split |
---|
1098 | * @exception Exception if something goes wrong |
---|
1099 | */ |
---|
1100 | protected void graphTraverse(PredictionNode currentNode, StringBuffer text, |
---|
1101 | int splitOrder, int predOrder) |
---|
1102 | throws Exception |
---|
1103 | { |
---|
1104 | |
---|
1105 | text.append("S" + splitOrder + "P" + predOrder + " [label=\""); |
---|
1106 | double[] predValues = currentNode.getValues(); |
---|
1107 | for (int i=0; i<m_numOfClasses; i++) { |
---|
1108 | text.append(Utils.doubleToString(predValues[i],3)); |
---|
1109 | if (i<m_numOfClasses-1) text.append(","); |
---|
1110 | } |
---|
1111 | if (splitOrder == 0) // show legend in root |
---|
1112 | text.append(" (" + legend() + ")"); |
---|
1113 | text.append("\" shape=box style=filled]\n"); |
---|
1114 | for (Enumeration e = currentNode.children(); e.hasMoreElements(); ) { |
---|
1115 | Splitter split = (Splitter) e.nextElement(); |
---|
1116 | text.append("S" + splitOrder + "P" + predOrder + "->" + "S" + split.orderAdded + |
---|
1117 | " [style=dotted]\n"); |
---|
1118 | text.append("S" + split.orderAdded + " [label=\"" + split.orderAdded + ": " + |
---|
1119 | split.attributeString() + "\"]\n"); |
---|
1120 | |
---|
1121 | for (int i=0; i<split.getNumOfBranches(); i++) { |
---|
1122 | PredictionNode child = split.getChildForBranch(i); |
---|
1123 | if (child != null) { |
---|
1124 | text.append("S" + split.orderAdded + "->" + "S" + split.orderAdded + "P" + i + |
---|
1125 | " [label=\"" + split.comparisonString(i) + "\"]\n"); |
---|
1126 | graphTraverse(child, text, split.orderAdded, i); |
---|
1127 | } |
---|
1128 | } |
---|
1129 | } |
---|
1130 | } |
---|
1131 | |
---|
1132 | /** |
---|
1133 | * Returns the legend of the tree, describing how results are to be interpreted. |
---|
1134 | * |
---|
1135 | * @return a string containing the legend of the classifier |
---|
1136 | */ |
---|
1137 | public String legend() { |
---|
1138 | |
---|
1139 | Attribute classAttribute = null; |
---|
1140 | if (m_trainInstances == null) return ""; |
---|
1141 | try {classAttribute = m_trainInstances.classAttribute();} catch (Exception x){}; |
---|
1142 | if (m_numOfClasses == 1) { |
---|
1143 | return ("-ve = " + classAttribute.value(0) |
---|
1144 | + ", +ve = " + classAttribute.value(1)); |
---|
1145 | } else { |
---|
1146 | StringBuffer text = new StringBuffer(); |
---|
1147 | for (int i=0; i<m_numOfClasses; i++) { |
---|
1148 | if (i>0) text.append(", "); |
---|
1149 | text.append(classAttribute.value(i)); |
---|
1150 | } |
---|
1151 | return text.toString(); |
---|
1152 | } |
---|
1153 | } |
---|
1154 | |
---|
1155 | |
---|
1156 | |
---|
1157 | /** option handling ******************************************************************/ |
---|
1158 | |
---|
1159 | /** |
---|
1160 | * @return tip text for this property suitable for |
---|
1161 | * displaying in the explorer/experimenter gui |
---|
1162 | */ |
---|
1163 | public String numOfBoostingIterationsTipText() { |
---|
1164 | |
---|
1165 | return "The number of boosting iterations to use, which determines the size of the tree."; |
---|
1166 | } |
---|
1167 | |
---|
1168 | /** |
---|
1169 | * Gets the number of boosting iterations. |
---|
1170 | * |
---|
1171 | * @return the number of boosting iterations |
---|
1172 | */ |
---|
1173 | public int getNumOfBoostingIterations() { |
---|
1174 | |
---|
1175 | return m_boostingIterations; |
---|
1176 | } |
---|
1177 | |
---|
1178 | /** |
---|
1179 | * Sets the number of boosting iterations. |
---|
1180 | * |
---|
1181 | * @param b the number of boosting iterations to use |
---|
1182 | */ |
---|
1183 | public void setNumOfBoostingIterations(int b) { |
---|
1184 | |
---|
1185 | m_boostingIterations = b; |
---|
1186 | } |
---|
1187 | |
---|
1188 | /** |
---|
1189 | * Returns an enumeration describing the available options. |
---|
1190 | * |
---|
1191 | * @return an enumeration of all the available options |
---|
1192 | */ |
---|
1193 | public Enumeration listOptions() { |
---|
1194 | |
---|
1195 | Vector newVector = new Vector(1); |
---|
1196 | newVector.addElement(new Option( |
---|
1197 | "\tNumber of boosting iterations.\n" |
---|
1198 | +"\t(Default = 10)", |
---|
1199 | "B", 1,"-B <number of boosting iterations>")); |
---|
1200 | |
---|
1201 | Enumeration enu = super.listOptions(); |
---|
1202 | while (enu.hasMoreElements()) { |
---|
1203 | newVector.addElement(enu.nextElement()); |
---|
1204 | } |
---|
1205 | |
---|
1206 | return newVector.elements(); |
---|
1207 | } |
---|
1208 | |
---|
1209 | /** |
---|
1210 | * Parses a given list of options. Valid options are:<p> |
---|
1211 | * |
---|
1212 | * -B num <br> |
---|
1213 | * Set the number of boosting iterations |
---|
1214 | * (default 10) <p> |
---|
1215 | * |
---|
1216 | * @param options the list of options as an array of strings |
---|
1217 | * @exception Exception if an option is not supported |
---|
1218 | */ |
---|
1219 | public void setOptions(String[] options) throws Exception { |
---|
1220 | |
---|
1221 | String bString = Utils.getOption('B', options); |
---|
1222 | if (bString.length() != 0) setNumOfBoostingIterations(Integer.parseInt(bString)); |
---|
1223 | |
---|
1224 | super.setOptions(options); |
---|
1225 | |
---|
1226 | Utils.checkForRemainingOptions(options); |
---|
1227 | } |
---|
1228 | |
---|
1229 | /** |
---|
1230 | * Gets the current settings of ADTree. |
---|
1231 | * |
---|
1232 | * @return an array of strings suitable for passing to setOptions() |
---|
1233 | */ |
---|
1234 | public String[] getOptions() { |
---|
1235 | |
---|
1236 | String[] options = new String[2 + super.getOptions().length]; |
---|
1237 | |
---|
1238 | int current = 0; |
---|
1239 | options[current++] = "-B"; options[current++] = "" + getNumOfBoostingIterations(); |
---|
1240 | |
---|
1241 | System.arraycopy(super.getOptions(), 0, options, current, super.getOptions().length); |
---|
1242 | |
---|
1243 | while (current < options.length) options[current++] = ""; |
---|
1244 | return options; |
---|
1245 | } |
---|
1246 | |
---|
1247 | |
---|
1248 | |
---|
1249 | /** additional measures ***************************************************************/ |
---|
1250 | |
---|
1251 | /** |
---|
1252 | * Calls measure function for tree size. |
---|
1253 | * |
---|
1254 | * @return the tree size |
---|
1255 | */ |
---|
1256 | public double measureTreeSize() { |
---|
1257 | |
---|
1258 | return numOfAllNodes(m_root); |
---|
1259 | } |
---|
1260 | |
---|
1261 | /** |
---|
1262 | * Calls measure function for leaf size. |
---|
1263 | * |
---|
1264 | * @return the leaf size |
---|
1265 | */ |
---|
1266 | public double measureNumLeaves() { |
---|
1267 | |
---|
1268 | return numOfPredictionNodes(m_root); |
---|
1269 | } |
---|
1270 | |
---|
1271 | /** |
---|
1272 | * Calls measure function for leaf size. |
---|
1273 | * |
---|
1274 | * @return the leaf size |
---|
1275 | */ |
---|
1276 | public double measureNumPredictionLeaves() { |
---|
1277 | |
---|
1278 | return numOfLeafNodes(m_root); |
---|
1279 | } |
---|
1280 | |
---|
1281 | /** |
---|
1282 | * Returns the number of nodes expanded. |
---|
1283 | * |
---|
1284 | * @return the number of nodes expanded during search |
---|
1285 | */ |
---|
1286 | public double measureNodesExpanded() { |
---|
1287 | |
---|
1288 | return m_nodesExpanded; |
---|
1289 | } |
---|
1290 | |
---|
1291 | /** |
---|
1292 | * Returns the number of examples "counted". |
---|
1293 | * |
---|
1294 | * @return the number of nodes processed during search |
---|
1295 | */ |
---|
1296 | public double measureExamplesCounted() { |
---|
1297 | |
---|
1298 | return m_examplesCounted; |
---|
1299 | } |
---|
1300 | |
---|
1301 | /** |
---|
1302 | * Returns an enumeration of the additional measure names. |
---|
1303 | * |
---|
1304 | * @return an enumeration of the measure names |
---|
1305 | */ |
---|
1306 | public Enumeration enumerateMeasures() { |
---|
1307 | |
---|
1308 | Vector newVector = new Vector(5); |
---|
1309 | newVector.addElement("measureTreeSize"); |
---|
1310 | newVector.addElement("measureNumLeaves"); |
---|
1311 | newVector.addElement("measureNumPredictionLeaves"); |
---|
1312 | newVector.addElement("measureNodesExpanded"); |
---|
1313 | newVector.addElement("measureExamplesCounted"); |
---|
1314 | return newVector.elements(); |
---|
1315 | } |
---|
1316 | |
---|
1317 | /** |
---|
1318 | * Returns the value of the named measure. |
---|
1319 | * |
---|
1320 | * @param additionalMeasureName the name of the measure to query for its value |
---|
1321 | * @return the value of the named measure |
---|
1322 | * @exception IllegalArgumentException if the named measure is not supported |
---|
1323 | */ |
---|
1324 | public double getMeasure(String additionalMeasureName) { |
---|
1325 | |
---|
1326 | if (additionalMeasureName.equals("measureTreeSize")) { |
---|
1327 | return measureTreeSize(); |
---|
1328 | } |
---|
1329 | else if (additionalMeasureName.equals("measureNodesExpanded")) { |
---|
1330 | return measureNodesExpanded(); |
---|
1331 | } |
---|
1332 | else if (additionalMeasureName.equals("measureNumLeaves")) { |
---|
1333 | return measureNumLeaves(); |
---|
1334 | } |
---|
1335 | else if (additionalMeasureName.equals("measureNumPredictionLeaves")) { |
---|
1336 | return measureNumPredictionLeaves(); |
---|
1337 | } |
---|
1338 | else if (additionalMeasureName.equals("measureExamplesCounted")) { |
---|
1339 | return measureExamplesCounted(); |
---|
1340 | } |
---|
1341 | else {throw new IllegalArgumentException(additionalMeasureName |
---|
1342 | + " not supported (ADTree)"); |
---|
1343 | } |
---|
1344 | } |
---|
1345 | |
---|
1346 | /** |
---|
1347 | * Returns the number of prediction nodes in a tree. |
---|
1348 | * |
---|
1349 | * @param root the root of the tree being measured |
---|
1350 | * @return tree size in number of prediction nodes |
---|
1351 | */ |
---|
1352 | protected int numOfPredictionNodes(PredictionNode root) { |
---|
1353 | |
---|
1354 | int numSoFar = 0; |
---|
1355 | if (root != null) { |
---|
1356 | numSoFar++; |
---|
1357 | for (Enumeration e = root.children(); e.hasMoreElements(); ) { |
---|
1358 | Splitter split = (Splitter) e.nextElement(); |
---|
1359 | for (int i=0; i<split.getNumOfBranches(); i++) |
---|
1360 | numSoFar += numOfPredictionNodes(split.getChildForBranch(i)); |
---|
1361 | } |
---|
1362 | } |
---|
1363 | return numSoFar; |
---|
1364 | } |
---|
1365 | |
---|
1366 | /** |
---|
1367 | * Returns the number of leaf nodes in a tree. |
---|
1368 | * |
---|
1369 | * @param root the root of the tree being measured |
---|
1370 | * @return tree leaf size in number of prediction nodes |
---|
1371 | */ |
---|
1372 | protected int numOfLeafNodes(PredictionNode root) { |
---|
1373 | |
---|
1374 | int numSoFar = 0; |
---|
1375 | if (root.getChildren().size() > 0) { |
---|
1376 | for (Enumeration e = root.children(); e.hasMoreElements(); ) { |
---|
1377 | Splitter split = (Splitter) e.nextElement(); |
---|
1378 | for (int i=0; i<split.getNumOfBranches(); i++) |
---|
1379 | numSoFar += numOfLeafNodes(split.getChildForBranch(i)); |
---|
1380 | } |
---|
1381 | } else numSoFar = 1; |
---|
1382 | return numSoFar; |
---|
1383 | } |
---|
1384 | |
---|
1385 | /** |
---|
1386 | * Returns the total number of nodes in a tree. |
---|
1387 | * |
---|
1388 | * @param root the root of the tree being measured |
---|
1389 | * @return tree size in number of splitter + prediction nodes |
---|
1390 | */ |
---|
1391 | protected int numOfAllNodes(PredictionNode root) { |
---|
1392 | |
---|
1393 | int numSoFar = 0; |
---|
1394 | if (root != null) { |
---|
1395 | numSoFar++; |
---|
1396 | for (Enumeration e = root.children(); e.hasMoreElements(); ) { |
---|
1397 | numSoFar++; |
---|
1398 | Splitter split = (Splitter) e.nextElement(); |
---|
1399 | for (int i=0; i<split.getNumOfBranches(); i++) |
---|
1400 | numSoFar += numOfAllNodes(split.getChildForBranch(i)); |
---|
1401 | } |
---|
1402 | } |
---|
1403 | return numSoFar; |
---|
1404 | } |
---|
1405 | |
---|
1406 | /** main functions ********************************************************************/ |
---|
1407 | |
---|
1408 | /** |
---|
1409 | * Builds a classifier for a set of instances. |
---|
1410 | * |
---|
1411 | * @param instances the instances to train the classifier with |
---|
1412 | * @exception Exception if something goes wrong |
---|
1413 | */ |
---|
1414 | public void buildClassifier(Instances instances) throws Exception { |
---|
1415 | |
---|
1416 | // set up the tree |
---|
1417 | initClassifier(instances); |
---|
1418 | |
---|
1419 | // build the tree |
---|
1420 | for (int T = 0; T < m_boostingIterations; T++) { |
---|
1421 | boost(); |
---|
1422 | } |
---|
1423 | } |
---|
1424 | |
---|
1425 | public int predictiveError(Instances test) { |
---|
1426 | int error = 0; |
---|
1427 | for(int i = test.numInstances()-1; i>=0; i--) { |
---|
1428 | Instance inst = test.instance(i); |
---|
1429 | try { |
---|
1430 | if (classifyInstance(inst) != inst.classValue()) |
---|
1431 | error++; |
---|
1432 | } catch (Exception e) { error++;} |
---|
1433 | } |
---|
1434 | return error; |
---|
1435 | } |
---|
1436 | |
---|
1437 | /** |
---|
1438 | * Merges two trees together. Modifies the tree being acted on, leaving tree passed |
---|
1439 | * as a parameter untouched (cloned). Does not check to see whether training instances |
---|
1440 | * are compatible - strange things could occur if they are not. |
---|
1441 | * |
---|
1442 | * @param mergeWith the tree to merge with |
---|
1443 | * @exception Exception if merge could not be performed |
---|
1444 | */ |
---|
1445 | public void merge(LADTree mergeWith) throws Exception { |
---|
1446 | |
---|
1447 | if (m_root == null || mergeWith.m_root == null) |
---|
1448 | throw new Exception("Trying to merge an uninitialized tree"); |
---|
1449 | if (m_numOfClasses != mergeWith.m_numOfClasses) |
---|
1450 | throw new Exception("Trees not suitable for merge - " |
---|
1451 | + "different sized prediction nodes"); |
---|
1452 | m_root.merge(mergeWith.m_root); |
---|
1453 | } |
---|
1454 | |
---|
1455 | /** |
---|
1456 | * Returns the type of graph this classifier |
---|
1457 | * represents. |
---|
1458 | * @return Drawable.TREE |
---|
1459 | */ |
---|
1460 | public int graphType() { |
---|
1461 | return Drawable.TREE; |
---|
1462 | } |
---|
1463 | |
---|
1464 | /** |
---|
1465 | * Returns the revision string. |
---|
1466 | * |
---|
1467 | * @return the revision |
---|
1468 | */ |
---|
1469 | public String getRevision() { |
---|
1470 | return RevisionUtils.extract("$Revision: 6035 $"); |
---|
1471 | } |
---|
1472 | |
---|
1473 | /** |
---|
1474 | * Returns default capabilities of the classifier. |
---|
1475 | * |
---|
1476 | * @return the capabilities of this classifier |
---|
1477 | */ |
---|
1478 | public Capabilities getCapabilities() { |
---|
1479 | Capabilities result = super.getCapabilities(); |
---|
1480 | result.disableAll(); |
---|
1481 | |
---|
1482 | // attributes |
---|
1483 | result.enable(Capability.NOMINAL_ATTRIBUTES); |
---|
1484 | result.enable(Capability.NUMERIC_ATTRIBUTES); |
---|
1485 | result.enable(Capability.DATE_ATTRIBUTES); |
---|
1486 | result.enable(Capability.MISSING_VALUES); |
---|
1487 | |
---|
1488 | // class |
---|
1489 | result.enable(Capability.NOMINAL_CLASS); |
---|
1490 | result.enable(Capability.MISSING_CLASS_VALUES); |
---|
1491 | |
---|
1492 | return result; |
---|
1493 | } |
---|
1494 | |
---|
1495 | /** |
---|
1496 | * Main method for testing this class. |
---|
1497 | * |
---|
1498 | * @param argv the options |
---|
1499 | */ |
---|
1500 | public static void main(String [] argv) { |
---|
1501 | runClassifier(new LADTree(), argv); |
---|
1502 | } |
---|
1503 | |
---|
1504 | } |
---|
1505 | |
---|