1 | /* |
---|
2 | * This program is free software; you can redistribute it and/or modify |
---|
3 | * it under the terms of the GNU General Public License as published by |
---|
4 | * the Free Software Foundation; either version 2 of the License, or |
---|
5 | * (at your option) any later version. |
---|
6 | * |
---|
7 | * This program is distributed in the hope that it will be useful, |
---|
8 | * but WITHOUT ANY WARRANTY; without even the implied warranty of |
---|
9 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the |
---|
10 | * GNU General Public License for more details. |
---|
11 | * |
---|
12 | * You should have received a copy of the GNU General Public License |
---|
13 | * along with this program; if not, write to the Free Software |
---|
14 | * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA. |
---|
15 | */ |
---|
16 | |
---|
17 | /* |
---|
18 | * RandomTree.java |
---|
19 | * Copyright (C) 2001 University of Waikato, Hamilton, New Zealand |
---|
20 | * |
---|
21 | */ |
---|
22 | |
---|
23 | package weka.classifiers.trees; |
---|
24 | |
---|
25 | import weka.classifiers.Classifier; |
---|
26 | import weka.classifiers.AbstractClassifier; |
---|
27 | import weka.core.Attribute; |
---|
28 | import weka.core.Capabilities; |
---|
29 | import weka.core.ContingencyTables; |
---|
30 | import weka.core.Drawable; |
---|
31 | import weka.core.Instance; |
---|
32 | import weka.core.Instances; |
---|
33 | import weka.core.Option; |
---|
34 | import weka.core.OptionHandler; |
---|
35 | import weka.core.Randomizable; |
---|
36 | import weka.core.RevisionUtils; |
---|
37 | import weka.core.Utils; |
---|
38 | import weka.core.WeightedInstancesHandler; |
---|
39 | import weka.core.Capabilities.Capability; |
---|
40 | |
---|
41 | import java.util.Enumeration; |
---|
42 | import java.util.Random; |
---|
43 | import java.util.Vector; |
---|
44 | |
---|
45 | /** |
---|
46 | * <!-- globalinfo-start --> |
---|
47 | * Class for constructing a tree that considers K randomly chosen attributes at each node. Performs no pruning. Also has an option to allow estimation of class probabilities based on a hold-out set (backfitting). |
---|
48 | * <p/> |
---|
49 | * <!-- globalinfo-end --> |
---|
50 | * |
---|
51 | * <!-- options-start --> |
---|
52 | * Valid options are: <p/> |
---|
53 | * |
---|
54 | * <pre> -K <number of attributes> |
---|
55 | * Number of attributes to randomly investigate |
---|
56 | * (<0 = int(log_2(#attributes)+1)).</pre> |
---|
57 | * |
---|
58 | * <pre> -M <minimum number of instances> |
---|
59 | * Set minimum number of instances per leaf.</pre> |
---|
60 | * |
---|
61 | * <pre> -S <num> |
---|
62 | * Seed for random number generator. |
---|
63 | * (default 1)</pre> |
---|
64 | * |
---|
65 | * <pre> -depth <num> |
---|
66 | * The maximum depth of the tree, 0 for unlimited. |
---|
67 | * (default 0)</pre> |
---|
68 | * |
---|
69 | * <pre> -N <num> |
---|
70 | * Number of folds for backfitting (default 0, no backfitting).</pre> |
---|
71 | * |
---|
72 | * <pre> -U |
---|
73 | * Allow unclassified instances.</pre> |
---|
74 | * |
---|
75 | * <pre> -D |
---|
76 | * If set, classifier is run in debug mode and |
---|
77 | * may output additional info to the console</pre> |
---|
78 | * |
---|
79 | * <!-- options-end --> |
---|
80 | * |
---|
81 | * @author Eibe Frank (eibe@cs.waikato.ac.nz) |
---|
82 | * @author Richard Kirkby (rkirkby@cs.waikato.ac.nz) |
---|
83 | * @version $Revision: 5928 $ |
---|
84 | */ |
---|
85 | public class RandomTree extends AbstractClassifier implements OptionHandler, |
---|
86 | WeightedInstancesHandler, Randomizable, Drawable { |
---|
87 | |
---|
  /** for serialization */
  static final long serialVersionUID = 8934314652175299374L;

  /** The subtrees appended to this tree (one per branch of the split). */
  protected RandomTree[] m_Successors;

  /** The index of the attribute to split on (-1 marks a leaf node). */
  protected int m_Attribute = -1;

  /** The split point (threshold used for numeric splits). */
  protected double m_SplitPoint = Double.NaN;

  /** The header information (dataset structure, no instances). */
  protected Instances m_Info = null;

  /** The proportions of training instances going down each branch;
      used to weight sub-tree predictions when the split value is missing. */
  protected double[] m_Prop = null;

  /** Class probabilities from the training data (null for an empty node). */
  protected double[] m_ClassDistribution = null;

  /** Minimum total instance weight required at a leaf. */
  protected double m_MinNum = 1.0;

  /** The number of attributes considered for a split
      (values < 1 mean int(log_2(#attributes)) + 1 is used at build time). */
  protected int m_KValue = 0;

  /** The random seed to use. */
  protected int m_randomSeed = 1;

  /** The maximum depth of the tree (0 = unlimited) */
  protected int m_MaxDepth = 0;

  /** Number of folds used to hold out data for backfitting
      (0 = no backfitting). */
  protected int m_NumFolds = 0;

  /** Whether unclassified instances (null distributions) are allowed */
  protected boolean m_AllowUnclassifiedInstances = false;

  /** a ZeroR model in case no model can be built from the data */
  protected Classifier m_ZeroR;
---|
129 | |
---|
130 | /** |
---|
131 | * Returns a string describing classifier |
---|
132 | * |
---|
133 | * @return a description suitable for displaying in the |
---|
134 | * explorer/experimenter gui |
---|
135 | */ |
---|
136 | public String globalInfo() { |
---|
137 | |
---|
138 | return "Class for constructing a tree that considers K randomly " |
---|
139 | + " chosen attributes at each node. Performs no pruning. Also has" |
---|
140 | + " an option to allow estimation of class probabilities based on" |
---|
141 | + " a hold-out set (backfitting)."; |
---|
142 | } |
---|
143 | |
---|
144 | /** |
---|
145 | * Returns the tip text for this property |
---|
146 | * |
---|
147 | * @return tip text for this property suitable for displaying in the |
---|
148 | * explorer/experimenter gui |
---|
149 | */ |
---|
150 | public String minNumTipText() { |
---|
151 | return "The minimum total weight of the instances in a leaf."; |
---|
152 | } |
---|
153 | |
---|
  /**
   * Get the value of MinNum, the minimum total weight of instances
   * required in a leaf.
   *
   * @return Value of MinNum.
   */
  public double getMinNum() {

    return m_MinNum;
  }
---|
163 | |
---|
  /**
   * Set the value of MinNum, the minimum total weight of instances
   * required in a leaf.
   *
   * @param newMinNum Value to assign to MinNum.
   */
  public void setMinNum(double newMinNum) {

    m_MinNum = newMinNum;
  }
---|
174 | |
---|
175 | /** |
---|
176 | * Returns the tip text for this property |
---|
177 | * |
---|
178 | * @return tip text for this property suitable for displaying in the |
---|
179 | * explorer/experimenter gui |
---|
180 | */ |
---|
181 | public String KValueTipText() { |
---|
182 | return "Sets the number of randomly chosen attributes. If 0, log_2(number_of_attributes) + 1 is used."; |
---|
183 | } |
---|
184 | |
---|
  /**
   * Get the value of K, the number of attributes randomly considered
   * at each split.
   *
   * @return Value of K.
   */
  public int getKValue() {

    return m_KValue;
  }
---|
194 | |
---|
  /**
   * Set the value of K, the number of attributes randomly considered
   * at each split (values &lt; 1 mean log_2(#attributes) + 1 is used).
   *
   * @param k Value to assign to K.
   */
  public void setKValue(int k) {

    m_KValue = k;
  }
---|
205 | |
---|
206 | /** |
---|
207 | * Returns the tip text for this property |
---|
208 | * |
---|
209 | * @return tip text for this property suitable for displaying in the |
---|
210 | * explorer/experimenter gui |
---|
211 | */ |
---|
212 | public String seedTipText() { |
---|
213 | return "The random number seed used for selecting attributes."; |
---|
214 | } |
---|
215 | |
---|
  /**
   * Set the seed for random number generation.
   *
   * @param seed the seed
   */
  public void setSeed(int seed) {

    m_randomSeed = seed;
  }
---|
226 | |
---|
  /**
   * Gets the seed used for random number generation.
   *
   * @return the seed for the random number generation
   */
  public int getSeed() {

    return m_randomSeed;
  }
---|
236 | |
---|
237 | /** |
---|
238 | * Returns the tip text for this property |
---|
239 | * |
---|
240 | * @return tip text for this property suitable for displaying in the |
---|
241 | * explorer/experimenter gui |
---|
242 | */ |
---|
243 | public String maxDepthTipText() { |
---|
244 | return "The maximum depth of the tree, 0 for unlimited."; |
---|
245 | } |
---|
246 | |
---|
  /**
   * Get the maximum depth of the tree, 0 for unlimited.
   *
   * @return the maximum depth.
   */
  public int getMaxDepth() {
    return m_MaxDepth;
  }
---|
255 | |
---|
256 | /** |
---|
257 | * Returns the tip text for this property |
---|
258 | * @return tip text for this property suitable for |
---|
259 | * displaying in the explorer/experimenter gui |
---|
260 | */ |
---|
261 | public String numFoldsTipText() { |
---|
262 | return "Determines the amount of data used for backfitting. One fold is used for " |
---|
263 | + "backfitting, the rest for growing the tree. (Default: 0, no backfitting)"; |
---|
264 | } |
---|
265 | |
---|
  /**
   * Get the value of NumFolds, the number of folds used to hold out
   * data for backfitting (0 = no backfitting).
   *
   * @return Value of NumFolds.
   */
  public int getNumFolds() {

    return m_NumFolds;
  }
---|
275 | |
---|
  /**
   * Set the value of NumFolds, the number of folds used to hold out
   * data for backfitting (0 = no backfitting).
   *
   * @param newNumFolds Value to assign to NumFolds.
   */
  public void setNumFolds(int newNumFolds) {

    m_NumFolds = newNumFolds;
  }
---|
285 | |
---|
286 | /** |
---|
287 | * Returns the tip text for this property |
---|
288 | * @return tip text for this property suitable for |
---|
289 | * displaying in the explorer/experimenter gui |
---|
290 | */ |
---|
291 | public String allowUnclassifiedInstancesTipText() { |
---|
292 | return "Whether to allow unclassified instances."; |
---|
293 | } |
---|
294 | |
---|
  /**
   * Get the value of AllowUnclassifiedInstances.
   * (The original javadoc referred to NumFolds — a copy-paste slip.)
   *
   * @return Value of AllowUnclassifiedInstances.
   */
  public boolean getAllowUnclassifiedInstances() {

    return m_AllowUnclassifiedInstances;
  }
---|
304 | |
---|
  /**
   * Set the value of AllowUnclassifiedInstances.
   *
   * @param newAllowUnclassifiedInstances Value to assign to
   *          AllowUnclassifiedInstances.
   */
  public void setAllowUnclassifiedInstances(boolean newAllowUnclassifiedInstances) {

    m_AllowUnclassifiedInstances = newAllowUnclassifiedInstances;
  }
---|
314 | |
---|
  /**
   * Set the maximum depth of the tree, 0 for unlimited.
   *
   * @param value the maximum depth.
   */
  public void setMaxDepth(int value) {
    m_MaxDepth = value;
  }
---|
324 | |
---|
325 | /** |
---|
326 | * Lists the command-line options for this classifier. |
---|
327 | * |
---|
328 | * @return an enumeration over all possible options |
---|
329 | */ |
---|
330 | public Enumeration listOptions() { |
---|
331 | |
---|
332 | Vector newVector = new Vector(); |
---|
333 | |
---|
334 | newVector.addElement(new Option( |
---|
335 | "\tNumber of attributes to randomly investigate\n" |
---|
336 | + "\t(<0 = int(log_2(#attributes)+1)).", "K", 1, |
---|
337 | "-K <number of attributes>")); |
---|
338 | |
---|
339 | newVector.addElement(new Option( |
---|
340 | "\tSet minimum number of instances per leaf.", "M", 1, |
---|
341 | "-M <minimum number of instances>")); |
---|
342 | |
---|
343 | newVector.addElement(new Option("\tSeed for random number generator.\n" |
---|
344 | + "\t(default 1)", "S", 1, "-S <num>")); |
---|
345 | |
---|
346 | newVector.addElement(new Option( |
---|
347 | "\tThe maximum depth of the tree, 0 for unlimited.\n" |
---|
348 | + "\t(default 0)", "depth", 1, "-depth <num>")); |
---|
349 | |
---|
350 | newVector. |
---|
351 | addElement(new Option("\tNumber of folds for backfitting " + |
---|
352 | "(default 0, no backfitting).", |
---|
353 | "N", 1, "-N <num>")); |
---|
354 | newVector. |
---|
355 | addElement(new Option("\tAllow unclassified instances.", |
---|
356 | "U", 0, "-U")); |
---|
357 | |
---|
358 | Enumeration enu = super.listOptions(); |
---|
359 | while (enu.hasMoreElements()) { |
---|
360 | newVector.addElement(enu.nextElement()); |
---|
361 | } |
---|
362 | |
---|
363 | return newVector.elements(); |
---|
364 | } |
---|
365 | |
---|
366 | /** |
---|
367 | * Gets options from this classifier. |
---|
368 | * |
---|
369 | * @return the options for the current setup |
---|
370 | */ |
---|
371 | public String[] getOptions() { |
---|
372 | Vector result; |
---|
373 | String[] options; |
---|
374 | int i; |
---|
375 | |
---|
376 | result = new Vector(); |
---|
377 | |
---|
378 | result.add("-K"); |
---|
379 | result.add("" + getKValue()); |
---|
380 | |
---|
381 | result.add("-M"); |
---|
382 | result.add("" + getMinNum()); |
---|
383 | |
---|
384 | result.add("-S"); |
---|
385 | result.add("" + getSeed()); |
---|
386 | |
---|
387 | if (getMaxDepth() > 0) { |
---|
388 | result.add("-depth"); |
---|
389 | result.add("" + getMaxDepth()); |
---|
390 | } |
---|
391 | |
---|
392 | if (getNumFolds() > 0) { |
---|
393 | result.add("-N"); |
---|
394 | result.add("" + getNumFolds()); |
---|
395 | } |
---|
396 | |
---|
397 | if (getAllowUnclassifiedInstances()) { |
---|
398 | result.add("-U"); |
---|
399 | } |
---|
400 | |
---|
401 | options = super.getOptions(); |
---|
402 | for (i = 0; i < options.length; i++) |
---|
403 | result.add(options[i]); |
---|
404 | |
---|
405 | return (String[]) result.toArray(new String[result.size()]); |
---|
406 | } |
---|
407 | |
---|
  /**
   * Parses a given list of options. Valid options are:<p/>
   *
   * -K &lt;number of attributes&gt; <br/>
   * Number of attributes to randomly investigate
   * (&lt;0 = int(log_2(#attributes)+1)). <p/>
   *
   * -M &lt;minimum number of instances&gt; <br/>
   * Set minimum number of instances per leaf. <p/>
   *
   * -S &lt;num&gt; <br/>
   * Seed for random number generator. (default 1) <p/>
   *
   * -depth &lt;num&gt; <br/>
   * The maximum depth of the tree, 0 for unlimited. (default 0) <p/>
   *
   * -N &lt;num&gt; <br/>
   * Number of folds for backfitting (default 0, no backfitting). <p/>
   *
   * -U <br/>
   * Allow unclassified instances. <p/>
   *
   * -D <br/>
   * If set, classifier is run in debug mode and may output additional
   * info to the console. <p/>
   *
   * @param options the list of options as an array of strings
   * @throws Exception if an option is not supported
   */
  public void setOptions(String[] options) throws Exception {
    String tmpStr;

    // -K: number of randomly investigated attributes; 0 means the fallback
    // int(log_2(#attributes)+1) is applied at training time
    tmpStr = Utils.getOption('K', options);
    if (tmpStr.length() != 0) {
      m_KValue = Integer.parseInt(tmpStr);
    } else {
      m_KValue = 0;
    }

    // -M: minimum total weight of instances per leaf (default 1)
    tmpStr = Utils.getOption('M', options);
    if (tmpStr.length() != 0) {
      m_MinNum = Double.parseDouble(tmpStr);
    } else {
      m_MinNum = 1;
    }

    // -S: random number seed (default 1)
    tmpStr = Utils.getOption('S', options);
    if (tmpStr.length() != 0) {
      setSeed(Integer.parseInt(tmpStr));
    } else {
      setSeed(1);
    }

    // -depth: maximum tree depth, 0 for unlimited (default 0)
    tmpStr = Utils.getOption("depth", options);
    if (tmpStr.length() != 0) {
      setMaxDepth(Integer.parseInt(tmpStr));
    } else {
      setMaxDepth(0);
    }

    // -N: number of folds for backfitting; 0 disables backfitting
    String numFoldsString = Utils.getOption('N', options);
    if (numFoldsString.length() != 0) {
      m_NumFolds = Integer.parseInt(numFoldsString);
    } else {
      m_NumFolds = 0;
    }

    // -U: allow unclassified instances (flag)
    setAllowUnclassifiedInstances(Utils.getFlag('U', options));

    // Let the superclass consume its options (e.g. -D for debug mode)
    super.setOptions(options);

    // Anything left over is an unrecognized option
    Utils.checkForRemainingOptions(options);
  }
---|
490 | |
---|
491 | /** |
---|
492 | * Returns default capabilities of the classifier. |
---|
493 | * |
---|
494 | * @return the capabilities of this classifier |
---|
495 | */ |
---|
496 | public Capabilities getCapabilities() { |
---|
497 | Capabilities result = super.getCapabilities(); |
---|
498 | result.disableAll(); |
---|
499 | |
---|
500 | // attributes |
---|
501 | result.enable(Capability.NOMINAL_ATTRIBUTES); |
---|
502 | result.enable(Capability.NUMERIC_ATTRIBUTES); |
---|
503 | result.enable(Capability.DATE_ATTRIBUTES); |
---|
504 | result.enable(Capability.MISSING_VALUES); |
---|
505 | |
---|
506 | // class |
---|
507 | result.enable(Capability.NOMINAL_CLASS); |
---|
508 | result.enable(Capability.MISSING_CLASS_VALUES); |
---|
509 | |
---|
510 | return result; |
---|
511 | } |
---|
512 | |
---|
  /**
   * Builds the classifier: grows an unpruned random tree, optionally
   * backfitting class probabilities on a held-out fold.
   *
   * @param data the data to train with
   * @throws Exception if something goes wrong or the data doesn't fit
   */
  public void buildClassifier(Instances data) throws Exception {

    // Make sure K value is in range.
    // NOTE(review): K is clamped against the attribute count BEFORE
    // missing-class instances are removed below — confirm this ordering
    // is intentional (the attribute count itself is unaffected).
    if (m_KValue > data.numAttributes() - 1)
      m_KValue = data.numAttributes() - 1;
    if (m_KValue < 1)
      m_KValue = (int) Utils.log2(data.numAttributes()) + 1;

    // can classifier handle the data?
    getCapabilities().testWithFail(data);

    // Copy the data, then remove instances with a missing class
    // (the copy keeps the caller's dataset untouched)
    data = new Instances(data);
    data.deleteWithMissingClass();

    // only class? -> fall back to a ZeroR model, warn on stderr
    if (data.numAttributes() == 1) {
      System.err
          .println("Cannot build model (only class attribute present in data!), "
              + "using ZeroR model instead!");
      m_ZeroR = new weka.classifiers.rules.ZeroR();
      m_ZeroR.buildClassifier(data);
      return;
    } else {
      m_ZeroR = null;
    }

    // Figure out appropriate datasets: with backfitting enabled, hold out
    // one stratified fold for backfitting and train on the rest
    Instances train = null;
    Instances backfit = null;
    Random rand = data.getRandomNumberGenerator(m_randomSeed);
    if (m_NumFolds <= 0) {
      train = data;
    } else {
      data.randomize(rand);
      data.stratify(m_NumFolds);
      train = data.trainCV(m_NumFolds, 1, rand);
      backfit = data.testCV(m_NumFolds, 1);
    }

    // Create the attribute indices window: every attribute index except
    // the class index
    int[] attIndicesWindow = new int[data.numAttributes() - 1];
    int j = 0;
    for (int i = 0; i < attIndicesWindow.length; i++) {
      if (j == data.classIndex())
        j++; // do not include the class
      attIndicesWindow[i] = j++;
    }

    // Compute initial (weighted) class counts
    double[] classProbs = new double[train.numClasses()];
    for (int i = 0; i < train.numInstances(); i++) {
      Instance inst = train.instance(i);
      classProbs[(int) inst.classValue()] += inst.weight();
    }

    // Build tree (buildTree is defined elsewhere in this class;
    // new Instances(data, 0) passes along an empty header copy)
    buildTree(train, classProbs, new Instances(data, 0), m_MinNum, m_Debug, attIndicesWindow,
        rand, 0, getAllowUnclassifiedInstances());

    // Backfit the held-out fold if required
    if (backfit != null) {
      backfitData(backfit);
    }
  }
---|
586 | |
---|
587 | /** |
---|
588 | * Backfits the given data into the tree. |
---|
589 | */ |
---|
590 | public void backfitData(Instances data) throws Exception { |
---|
591 | |
---|
592 | // Compute initial class counts |
---|
593 | double[] classProbs = new double[data.numClasses()]; |
---|
594 | for (int i = 0; i < data.numInstances(); i++) { |
---|
595 | Instance inst = data.instance(i); |
---|
596 | classProbs[(int) inst.classValue()] += inst.weight(); |
---|
597 | } |
---|
598 | |
---|
599 | // Fit data into tree |
---|
600 | backfitData(data, classProbs); |
---|
601 | } |
---|
602 | |
---|
  /**
   * Computes the class distribution of an instance using the decision tree.
   *
   * @param instance the instance to compute the distribution for
   * @return the computed class distribution, or null when the instance
   *         ends in an empty node and unclassified instances are not
   *         allowed
   * @throws Exception if computation fails
   */
  public double[] distributionForInstance(Instance instance) throws Exception {

    // Fallback ZeroR model (built when the data had only a class attribute)?
    if (m_ZeroR != null) {
      return m_ZeroR.distributionForInstance(instance);
    }

    double[] returnedDist = null;

    if (m_Attribute > -1) {

      // Node is not a leaf
      if (instance.isMissing(m_Attribute)) {

        // Split value is missing: combine the distributions of all
        // successors, weighted by the training proportions m_Prop
        returnedDist = new double[m_Info.numClasses()];

        for (int i = 0; i < m_Successors.length; i++) {
          double[] help = m_Successors[i]
              .distributionForInstance(instance);
          if (help != null) {
            for (int j = 0; j < help.length; j++) {
              returnedDist[j] += m_Prop[i] * help[j];
            }
          }
        }
      } else if (m_Info.attribute(m_Attribute).isNominal()) {

        // Nominal attribute: follow the branch for the attribute value
        returnedDist = m_Successors[(int) instance.value(m_Attribute)]
            .distributionForInstance(instance);
      } else {

        // Numeric attribute: branch 0 is "< split point", branch 1 is ">="
        if (instance.value(m_Attribute) < m_SplitPoint) {
          returnedDist = m_Successors[0]
              .distributionForInstance(instance);
        } else {
          returnedDist = m_Successors[1]
              .distributionForInstance(instance);
        }
      }
    }


    // Node is a leaf or the reached successor was empty?
    if ((m_Attribute == -1) || (returnedDist == null)) {

      // Empty node: no class distribution was stored here
      if (m_ClassDistribution == null) {
        if (getAllowUnclassifiedInstances()) {
          // All-zero distribution signals "unclassified"
          return new double[m_Info.numClasses()];
        } else {
          return null;
        }
      }

      // Else return the normalized class distribution of this node
      // (clone first so the stored counts stay untouched)
      double[] normalizedDistribution = (double[]) m_ClassDistribution.clone();
      Utils.normalize(normalizedDistribution);
      return normalizedDistribution;
    } else {
      return returnedDist;
    }
  }
---|
678 | |
---|
679 | /** |
---|
680 | * Outputs the decision tree as a graph |
---|
681 | * |
---|
682 | * @return the tree as a graph |
---|
683 | */ |
---|
684 | public String toGraph() { |
---|
685 | |
---|
686 | try { |
---|
687 | StringBuffer resultBuff = new StringBuffer(); |
---|
688 | toGraph(resultBuff, 0); |
---|
689 | String result = "digraph Tree {\n" + "edge [style=bold]\n" |
---|
690 | + resultBuff.toString() + "\n}\n"; |
---|
691 | return result; |
---|
692 | } catch (Exception e) { |
---|
693 | return null; |
---|
694 | } |
---|
695 | } |
---|
696 | |
---|
  /**
   * Outputs one node (and, recursively, its subtree) in dot format.
   *
   * @param text the buffer to append the output to
   * @param num unique node id
   * @return the next node id
   * @throws Exception if generation fails
   */
  public int toGraph(StringBuffer text, int num) throws Exception {

    // Majority class at this node.
    // NOTE(review): assumes m_ClassDistribution is non-null here; an
    // empty node (null distribution) would throw — confirm callers
    // guarantee this.
    int maxIndex = Utils.maxIndex(m_ClassDistribution);
    String classValue = m_Info.classAttribute().value(maxIndex);

    num++;
    if (m_Attribute == -1) {
      // Leaf: drawn as a box labelled with the majority class
      text.append("N" + Integer.toHexString(hashCode()) + " [label=\""
          + num + ": " + classValue + "\"" + "shape=box]\n");
    } else {
      // Inner node: emit the node, then one labelled edge per successor
      text.append("N" + Integer.toHexString(hashCode()) + " [label=\""
          + num + ": " + classValue + "\"]\n");
      for (int i = 0; i < m_Successors.length; i++) {
        text.append("N" + Integer.toHexString(hashCode()) + "->" + "N"
            + Integer.toHexString(m_Successors[i].hashCode())
            + " [label=\"" + m_Info.attribute(m_Attribute).name());
        if (m_Info.attribute(m_Attribute).isNumeric()) {
          // Numeric split: branch 0 is "<", branch 1 is ">="
          if (i == 0) {
            text.append(" < "
                + Utils.doubleToString(m_SplitPoint, 2));
          } else {
            text.append(" >= "
                + Utils.doubleToString(m_SplitPoint, 2));
          }
        } else {
          // Nominal split: branch i corresponds to attribute value i
          text.append(" = " + m_Info.attribute(m_Attribute).value(i));
        }
        text.append("\"]\n");
        num = m_Successors[i].toGraph(text, num);
      }
    }

    return num;
  }
---|
742 | |
---|
  /**
   * Outputs the decision tree.
   *
   * @return a string representation of the classifier
   */
  public String toString() {

    // Only a fallback ZeroR model? Print a warning header plus the
    // ZeroR model itself (the second line is an "=" underline as long
    // as the class name — ".*\\." strips the package prefix).
    if (m_ZeroR != null) {
      StringBuffer buf = new StringBuffer();
      buf
          .append(this.getClass().getName().replaceAll(".*\\.", "")
              + "\n");
      buf.append(this.getClass().getName().replaceAll(".*\\.", "")
          .replaceAll(".", "=")
          + "\n\n");
      buf
          .append("Warning: No model could be built, hence ZeroR model is used:\n\n");
      buf.append(m_ZeroR.toString());
      return buf.toString();
    }

    // No model yet vs. full recursive printout (numNodes is defined
    // elsewhere in this class)
    if (m_Successors == null) {
      return "RandomTree: no model has been built yet.";
    } else {
      return "\nRandomTree\n==========\n"
          + toString(0)
          + "\n"
          + "\nSize of the tree : "
          + numNodes()
          + (getMaxDepth() > 0 ? ("\nMax depth of tree: " + getMaxDepth())
              : (""));
    }
  }
---|
777 | |
---|
  /**
   * Outputs a leaf in the form
   * " : &lt;class&gt; (&lt;total weight&gt;/&lt;misclassified weight&gt;)".
   *
   * @return the leaf as string
   * @throws Exception if generation fails
   */
  protected String leafString() throws Exception {

    // For an empty node (null distribution) sum and maxCount stay 0 and
    // the first class value is printed
    double sum = 0, maxCount = 0;
    int maxIndex = 0;
    if (m_ClassDistribution != null) {
      sum = Utils.sum(m_ClassDistribution);
      maxIndex = Utils.maxIndex(m_ClassDistribution);
      maxCount = m_ClassDistribution[maxIndex];
    }
    // sum - maxCount is the weight of instances NOT in the majority class
    return " : "
        + m_Info.classAttribute().value(maxIndex)
        + " ("
        + Utils.doubleToString(sum, 2)
        + "/"
        + Utils.doubleToString(sum - maxCount, 2) + ")";
  }
---|
801 | |
---|
  /**
   * Recursively outputs the tree, one "| "-indented line per branch.
   *
   * @param level the current level of the tree (controls indentation)
   * @return the generated subtree
   */
  protected String toString(int level) {

    try {
      StringBuffer text = new StringBuffer();

      if (m_Attribute == -1) {

        // Leaf: output leaf info only
        return leafString();
      } else if (m_Info.attribute(m_Attribute).isNominal()) {

        // Nominal attribute: one "attr = value" line per branch
        for (int i = 0; i < m_Successors.length; i++) {
          text.append("\n");
          for (int j = 0; j < level; j++) {
            text.append("|   ");
          }
          text.append(m_Info.attribute(m_Attribute).name() + " = "
              + m_Info.attribute(m_Attribute).value(i));
          text.append(m_Successors[i].toString(level + 1));
        }
      } else {

        // Numeric attribute: "<" branch first, then ">=" branch
        text.append("\n");
        for (int j = 0; j < level; j++) {
          text.append("|   ");
        }
        text.append(m_Info.attribute(m_Attribute).name() + " < "
            + Utils.doubleToString(m_SplitPoint, 2));
        text.append(m_Successors[0].toString(level + 1));
        text.append("\n");
        for (int j = 0; j < level; j++) {
          text.append("|   ");
        }
        text.append(m_Info.attribute(m_Attribute).name() + " >= "
            + Utils.doubleToString(m_SplitPoint, 2));
        text.append(m_Successors[1].toString(level + 1));
      }

      return text.toString();
    } catch (Exception e) {
      // Printing is best-effort; report failure instead of propagating
      e.printStackTrace();
      return "RandomTree: tree can't be printed";
    }
  }
---|
855 | |
---|
856 | /** |
---|
857 | * Recursively backfits data into the tree. |
---|
858 | * |
---|
859 | * @param data |
---|
860 | * the data to work with |
---|
861 | * @param classProbs |
---|
862 | * the class distribution |
---|
863 | * @throws Exception |
---|
864 | * if generation fails |
---|
865 | */ |
---|
866 | protected void backfitData(Instances data, double[] classProbs) throws Exception { |
---|
867 | |
---|
868 | // Make leaf if there are no training instances |
---|
869 | if (data.numInstances() == 0) { |
---|
870 | m_Attribute = -1; |
---|
871 | m_ClassDistribution = null; |
---|
872 | m_Prop = null; |
---|
873 | return; |
---|
874 | } |
---|
875 | |
---|
876 | // Check if node doesn't contain enough instances or is pure |
---|
877 | // or maximum depth reached |
---|
878 | m_ClassDistribution = (double[]) classProbs.clone(); |
---|
879 | |
---|
880 | /* if (Utils.sum(m_ClassDistribution) < 2 * m_MinNum |
---|
881 | || Utils.eq(m_ClassDistribution[Utils.maxIndex(m_ClassDistribution)], Utils |
---|
882 | .sum(m_ClassDistribution))) { |
---|
883 | |
---|
884 | // Make leaf |
---|
885 | m_Attribute = -1; |
---|
886 | m_Prop = null; |
---|
887 | return; |
---|
888 | }*/ |
---|
889 | |
---|
890 | // Are we at an inner node |
---|
891 | if (m_Attribute > -1) { |
---|
892 | |
---|
893 | // Compute new weights for subsets based on backfit data |
---|
894 | m_Prop = new double[m_Successors.length]; |
---|
895 | for (int i = 0; i < data.numInstances(); i++) { |
---|
896 | Instance inst = data.instance(i); |
---|
897 | if (!inst.isMissing(m_Attribute)) { |
---|
898 | if (data.attribute(m_Attribute).isNominal()) { |
---|
899 | m_Prop[(int)inst.value(m_Attribute)] += inst.weight(); |
---|
900 | } else { |
---|
901 | m_Prop[(inst.value(m_Attribute) < m_SplitPoint) ? 0 : 1] += inst.weight(); |
---|
902 | } |
---|
903 | } |
---|
904 | } |
---|
905 | |
---|
906 | // If we only have missing values we can make this node into a leaf |
---|
907 | if (Utils.sum(m_Prop) <= 0) { |
---|
908 | m_Attribute = -1; |
---|
909 | m_Prop = null; |
---|
910 | return; |
---|
911 | } |
---|
912 | |
---|
913 | // Otherwise normalize the proportions |
---|
914 | Utils.normalize(m_Prop); |
---|
915 | |
---|
916 | // Split data |
---|
917 | Instances[] subsets = splitData(data); |
---|
918 | |
---|
919 | // Go through subsets |
---|
920 | for (int i = 0; i < subsets.length; i++) { |
---|
921 | |
---|
922 | // Compute distribution for current subset |
---|
923 | double[] dist = new double[data.numClasses()]; |
---|
924 | for (int j = 0; j < subsets[i].numInstances(); j++) { |
---|
925 | dist[(int)subsets[i].instance(j).classValue()] += subsets[i].instance(j).weight(); |
---|
926 | } |
---|
927 | |
---|
928 | // Backfit subset |
---|
929 | m_Successors[i].backfitData(subsets[i], dist); |
---|
930 | } |
---|
931 | |
---|
932 | // If unclassified instances are allowed, we don't need to store the class distribution |
---|
933 | if (getAllowUnclassifiedInstances()) { |
---|
934 | m_ClassDistribution = null; |
---|
935 | return; |
---|
936 | } |
---|
937 | |
---|
938 | // Otherwise, if all successors are non-empty, we don't need to store the class distribution |
---|
939 | boolean emptySuccessor = false; |
---|
940 | for (int i = 0; i < subsets.length; i++) { |
---|
941 | if (m_Successors[i].m_ClassDistribution == null) { |
---|
942 | emptySuccessor = true; |
---|
943 | return; |
---|
944 | } |
---|
945 | } |
---|
946 | m_ClassDistribution = null; |
---|
947 | |
---|
948 | // If we have a least two non-empty successors, we should keep this tree |
---|
949 | /* int nonEmptySuccessors = 0; |
---|
950 | for (int i = 0; i < subsets.length; i++) { |
---|
951 | if (m_Successors[i].m_ClassDistribution != null) { |
---|
952 | nonEmptySuccessors++; |
---|
953 | if (nonEmptySuccessors > 1) { |
---|
954 | return; |
---|
955 | } |
---|
956 | } |
---|
957 | } |
---|
958 | |
---|
959 | // Otherwise, this node is a leaf or should become a leaf |
---|
960 | m_Successors = null; |
---|
961 | m_Attribute = -1; |
---|
962 | m_Prop = null; |
---|
963 | return;*/ |
---|
964 | } |
---|
965 | } |
---|
966 | |
---|
  /**
   * Recursively generates a tree by picking, at each node, the best of K
   * randomly chosen attributes to split on.
   *
   * @param data
   *          the data to work with
   * @param classProbs
   *          the class distribution for this node
   * @param header
   *          the header of the data
   * @param minNum
   *          the minimum number of instances per leaf
   * @param debug
   *          whether debugging is on
   * @param attIndicesWindow
   *          the attribute window to choose attributes from
   * @param random
   *          random number generator for choosing random attributes
   * @param depth
   *          the current depth
   * @param allow
   *          whether to allow unclassified instances (stored in
   *          m_AllowUnclassifiedInstances and propagated to subtrees)
   * @throws Exception
   *           if generation fails
   */
  protected void buildTree(Instances data, double[] classProbs, Instances header,
      double minNum, boolean debug, int[] attIndicesWindow,
      Random random, int depth, boolean allow) throws Exception {

    // Store structure of dataset, set minimum number of instances
    m_Info = header;
    m_Debug = debug;
    m_MinNum = minNum;
    m_AllowUnclassifiedInstances = allow;

    // Make leaf if there are no training instances
    if (data.numInstances() == 0) {
      m_Attribute = -1;
      m_ClassDistribution = null;
      m_Prop = null;
      return;
    }

    // Check if node doesn't contain enough instances or is pure
    // or maximum depth reached
    m_ClassDistribution = (double[]) classProbs.clone();

    if (Utils.sum(m_ClassDistribution) < 2 * m_MinNum
        || Utils.eq(m_ClassDistribution[Utils.maxIndex(m_ClassDistribution)], Utils
            .sum(m_ClassDistribution))
        || ((getMaxDepth() > 0) && (depth >= getMaxDepth()))) {
      // Make leaf
      m_Attribute = -1;
      m_Prop = null;
      return;
    }

    // Compute class distributions and value of splitting
    // criterion for each attribute
    double[] vals = new double[data.numAttributes()];
    double[][][] dists = new double[data.numAttributes()][0][0];
    double[][] props = new double[data.numAttributes()][0];
    double[] splits = new double[data.numAttributes()];

    // Investigate K random attributes. Keeps drawing past K attributes
    // until at least one attribute with positive gain has been found
    // (or the window is exhausted).
    int attIndex = 0;
    int windowSize = attIndicesWindow.length;
    int k = m_KValue;
    boolean gainFound = false;
    while ((windowSize > 0) && (k-- > 0 || !gainFound)) {

      int chosenIndex = random.nextInt(windowSize);
      attIndex = attIndicesWindow[chosenIndex];

      // shift chosen attIndex out of window (swap with last active slot,
      // then shrink the window) so it cannot be drawn again at this node
      attIndicesWindow[chosenIndex] = attIndicesWindow[windowSize - 1];
      attIndicesWindow[windowSize - 1] = attIndex;
      windowSize--;

      // distribution() fills props[attIndex]/dists[attIndex] and returns
      // the numeric split point (NaN for nominal attributes)
      splits[attIndex] = distribution(props, dists, attIndex, data);
      vals[attIndex] = gain(dists[attIndex], priorVal(dists[attIndex]));

      if (Utils.gr(vals[attIndex], 0))
        gainFound = true;
    }

    // Find best attribute
    m_Attribute = Utils.maxIndex(vals);
    double[][] distribution = dists[m_Attribute];

    // Any useful split found?
    if (Utils.gr(vals[m_Attribute], 0)) {

      // Build subtrees
      m_SplitPoint = splits[m_Attribute];
      m_Prop = props[m_Attribute];
      Instances[] subsets = splitData(data);
      m_Successors = new RandomTree[distribution.length];
      for (int i = 0; i < distribution.length; i++) {
        m_Successors[i] = new RandomTree();
        m_Successors[i].setKValue(m_KValue);
        m_Successors[i].setMaxDepth(getMaxDepth());
        m_Successors[i].buildTree(subsets[i], distribution[i], header, m_MinNum, m_Debug,
            attIndicesWindow, random, depth + 1, allow);
      }

      // If all successors are non-empty, we don't need to store the class distribution
      boolean emptySuccessor = false;
      for (int i = 0; i < subsets.length; i++) {
        if (m_Successors[i].m_ClassDistribution == null) {
          emptySuccessor = true;
          break;
        }
      }
      if (!emptySuccessor) {
        m_ClassDistribution = null;
      }
    } else {

      // Make leaf
      m_Attribute = -1;
    }
  }
---|
1089 | |
---|
1090 | /** |
---|
1091 | * Computes size of the tree. |
---|
1092 | * |
---|
1093 | * @return the number of nodes |
---|
1094 | */ |
---|
1095 | public int numNodes() { |
---|
1096 | |
---|
1097 | if (m_Attribute == -1) { |
---|
1098 | return 1; |
---|
1099 | } else { |
---|
1100 | int size = 1; |
---|
1101 | for (int i = 0; i < m_Successors.length; i++) { |
---|
1102 | size += m_Successors[i].numNodes(); |
---|
1103 | } |
---|
1104 | return size; |
---|
1105 | } |
---|
1106 | } |
---|
1107 | |
---|
1108 | /** |
---|
1109 | * Splits instances into subsets based on the given split. |
---|
1110 | * |
---|
1111 | * @param data |
---|
1112 | * the data to work with |
---|
1113 | * @return the subsets of instances |
---|
1114 | * @throws Exception |
---|
1115 | * if something goes wrong |
---|
1116 | */ |
---|
1117 | protected Instances[] splitData(Instances data) throws Exception { |
---|
1118 | |
---|
1119 | // Allocate array of Instances objects |
---|
1120 | Instances[] subsets = new Instances[m_Prop.length]; |
---|
1121 | for (int i = 0; i < m_Prop.length; i++) { |
---|
1122 | subsets[i] = new Instances(data, data.numInstances()); |
---|
1123 | } |
---|
1124 | |
---|
1125 | // Go through the data |
---|
1126 | for (int i = 0; i < data.numInstances(); i++) { |
---|
1127 | |
---|
1128 | // Get instance |
---|
1129 | Instance inst = data.instance(i); |
---|
1130 | |
---|
1131 | // Does the instance have a missing value? |
---|
1132 | if (inst.isMissing(m_Attribute)) { |
---|
1133 | |
---|
1134 | // Split instance up |
---|
1135 | for (int k = 0; k < m_Prop.length; k++) { |
---|
1136 | if (m_Prop[k] > 0) { |
---|
1137 | Instance copy = (Instance)inst.copy(); |
---|
1138 | copy.setWeight(m_Prop[k] * inst.weight()); |
---|
1139 | subsets[k].add(copy); |
---|
1140 | } |
---|
1141 | } |
---|
1142 | |
---|
1143 | // Proceed to next instance |
---|
1144 | continue; |
---|
1145 | } |
---|
1146 | |
---|
1147 | // Do we have a nominal attribute? |
---|
1148 | if (data.attribute(m_Attribute).isNominal()) { |
---|
1149 | subsets[(int)inst.value(m_Attribute)].add(inst); |
---|
1150 | |
---|
1151 | // Proceed to next instance |
---|
1152 | continue; |
---|
1153 | } |
---|
1154 | |
---|
1155 | // Do we have a numeric attribute? |
---|
1156 | if (data.attribute(m_Attribute).isNumeric()) { |
---|
1157 | subsets[(inst.value(m_Attribute) < m_SplitPoint) ? 0 : 1].add(inst); |
---|
1158 | |
---|
1159 | // Proceed to next instance |
---|
1160 | continue; |
---|
1161 | } |
---|
1162 | |
---|
1163 | // Else throw an exception |
---|
1164 | throw new IllegalArgumentException("Unknown attribute type"); |
---|
1165 | } |
---|
1166 | |
---|
1167 | // Save memory |
---|
1168 | for (int i = 0; i < m_Prop.length; i++) { |
---|
1169 | subsets[i].compactify(); |
---|
1170 | } |
---|
1171 | |
---|
1172 | // Return the subsets |
---|
1173 | return subsets; |
---|
1174 | } |
---|
1175 | |
---|
  /**
   * Computes the per-branch class distribution for an attribute, along with
   * the branch weight proportions and (for numeric attributes) the best
   * binary split point.
   *
   * @param props
   *          output array: props[att] receives the normalized branch weights
   * @param dists
   *          output array: dists[att] receives the class distribution per branch
   * @param att
   *          the attribute index
   * @param data
   *          the data to work with (NOTE: sorted in place on att if numeric)
   * @return the chosen split point for a numeric attribute, or NaN for a
   *         nominal attribute (or if no sensible numeric split exists)
   * @throws Exception
   *           if something goes wrong
   */
  protected double distribution(double[][] props, double[][][] dists, int att, Instances data)
      throws Exception {

    double splitPoint = Double.NaN;
    Attribute attribute = data.attribute(att);
    double[][] dist = null;
    // Tracks where instances with missing values for att begin; -1 if none seen
    int indexOfFirstMissingValue = -1;

    if (attribute.isNominal()) {

      // For nominal attributes: one branch per attribute value
      dist = new double[attribute.numValues()][data.numClasses()];
      for (int i = 0; i < data.numInstances(); i++) {
        Instance inst = data.instance(i);
        if (inst.isMissing(att)) {

          // Skip missing values at this stage
          if (indexOfFirstMissingValue < 0) {
            indexOfFirstMissingValue = i;
          }
          continue;
        }
        dist[(int) inst.value(att)][(int) inst.classValue()] += inst.weight();
      }
    } else {

      // For numeric attributes: binary split chosen by scanning sorted values
      double[][] currDist = new double[2][data.numClasses()];
      dist = new double[2][data.numClasses()];

      // Sort data (in place) so split candidates can be scanned left to right;
      // missing values sort to the end
      data.sort(att);

      // Move all instances into second subset
      for (int j = 0; j < data.numInstances(); j++) {
        Instance inst = data.instance(j);
        if (inst.isMissing(att)) {

          // Can stop as soon as we hit a missing value
          indexOfFirstMissingValue = j;
          break;
        }
        currDist[1][(int) inst.classValue()] += inst.weight();
      }

      // Value before splitting
      double priorVal = priorVal(currDist);

      // Save initial distribution
      for (int j = 0; j < currDist.length; j++) {
        System.arraycopy(currDist[j], 0, dist[j], 0, dist[j].length);
      }

      // Try all possible split points
      double currSplit = data.instance(0).value(att);
      double currVal, bestVal = -Double.MAX_VALUE;
      for (int i = 0; i < data.numInstances(); i++) {
        Instance inst = data.instance(i);
        if (inst.isMissing(att)) {

          // Can stop as soon as we hit a missing value
          break;
        }

        // Can we place a sensible split point here?
        // (only between two distinct attribute values)
        if (inst.value(att) > currSplit) {

          // Compute gain for split point
          currVal = gain(currDist, priorVal);

          // Is the current split point the best point so far?
          if (currVal > bestVal) {

            // Store value of current point
            bestVal = currVal;

            // Save split point: midpoint between the two adjacent values
            splitPoint = (inst.value(att) + currSplit) / 2.0;

            // Save distribution
            for (int j = 0; j < currDist.length; j++) {
              System.arraycopy(currDist[j], 0, dist[j], 0, dist[j].length);
            }
          }
        }
        currSplit = inst.value(att);

        // Shift over the weight: instance moves from right subset to left
        currDist[0][(int) inst.classValue()] += inst.weight();
        currDist[1][(int) inst.classValue()] -= inst.weight();
      }
    }

    // Compute weights for subsets
    props[att] = new double[dist.length];
    for (int k = 0; k < props[att].length; k++) {
      props[att][k] = Utils.sum(dist[k]);
    }
    if (Utils.eq(Utils.sum(props[att]), 0)) {
      // No weight at all: fall back to uniform proportions
      for (int k = 0; k < props[att].length; k++) {
        props[att][k] = 1.0 / (double) props[att].length;
      }
    } else {
      Utils.normalize(props[att]);
    }

    // Any instances with missing values ?
    if (indexOfFirstMissingValue > -1) {

      // Distribute weights for instances with missing values
      // proportionally to the branch weights computed above
      for (int i = indexOfFirstMissingValue; i < data.numInstances(); i++) {
        Instance inst = data.instance(i);
        if (attribute.isNominal()) {

          // Need to check if attribute value is missing
          // (nominal scan above skipped but did not reorder instances)
          if (inst.isMissing(att)) {
            for (int j = 0; j < dist.length; j++) {
              dist[j][(int) inst.classValue()] += props[att][j] * inst.weight();
            }
          }
        } else {

          // Can be sure that value is missing, so no test required
          // (data was sorted on att, so all instances from here on are missing)
          for (int j = 0; j < dist.length; j++) {
            dist[j][(int) inst.classValue()] += props[att][j] * inst.weight();
          }
        }
      }
    }

    // Return distribution and split point
    dists[att] = dist;
    return splitPoint;
  }
---|
1322 | |
---|
1323 | /** |
---|
1324 | * Computes value of splitting criterion before split. |
---|
1325 | * |
---|
1326 | * @param dist |
---|
1327 | * the distributions |
---|
1328 | * @return the splitting criterion |
---|
1329 | */ |
---|
1330 | protected double priorVal(double[][] dist) { |
---|
1331 | |
---|
1332 | return ContingencyTables.entropyOverColumns(dist); |
---|
1333 | } |
---|
1334 | |
---|
1335 | /** |
---|
1336 | * Computes value of splitting criterion after split. |
---|
1337 | * |
---|
1338 | * @param dist |
---|
1339 | * the distributions |
---|
1340 | * @param priorVal |
---|
1341 | * the splitting criterion |
---|
1342 | * @return the gain after the split |
---|
1343 | */ |
---|
1344 | protected double gain(double[][] dist, double priorVal) { |
---|
1345 | |
---|
1346 | return priorVal - ContingencyTables.entropyConditionedOnRows(dist); |
---|
1347 | } |
---|
1348 | |
---|
1349 | /** |
---|
1350 | * Returns the revision string. |
---|
1351 | * |
---|
1352 | * @return the revision |
---|
1353 | */ |
---|
1354 | public String getRevision() { |
---|
1355 | return RevisionUtils.extract("$Revision: 5928 $"); |
---|
1356 | } |
---|
1357 | |
---|
1358 | /** |
---|
1359 | * Main method for this class. |
---|
1360 | * |
---|
1361 | * @param argv |
---|
1362 | * the commandline parameters |
---|
1363 | */ |
---|
1364 | public static void main(String[] argv) { |
---|
1365 | runClassifier(new RandomTree(), argv); |
---|
1366 | } |
---|
1367 | |
---|
1368 | /** |
---|
1369 | * Returns graph describing the tree. |
---|
1370 | * |
---|
1371 | * @return the graph describing the tree |
---|
1372 | * @throws Exception |
---|
1373 | * if graph can't be computed |
---|
1374 | */ |
---|
1375 | public String graph() throws Exception { |
---|
1376 | |
---|
1377 | if (m_Successors == null) { |
---|
1378 | throw new Exception("RandomTree: No model built yet."); |
---|
1379 | } |
---|
1380 | StringBuffer resultBuff = new StringBuffer(); |
---|
1381 | toGraph(resultBuff, 0, null); |
---|
1382 | String result = "digraph RandomTree {\n" + "edge [style=bold]\n" |
---|
1383 | + resultBuff.toString() + "\n}\n"; |
---|
1384 | return result; |
---|
1385 | } |
---|
1386 | |
---|
1387 | /** |
---|
1388 | * Returns the type of graph this classifier represents. |
---|
1389 | * |
---|
1390 | * @return Drawable.TREE |
---|
1391 | */ |
---|
1392 | public int graphType() { |
---|
1393 | return Drawable.TREE; |
---|
1394 | } |
---|
1395 | |
---|
  /**
   * Outputs one node (and, recursively, its subtree) in dot format.
   *
   * @param text
   *          the buffer to append the output to
   * @param num
   *          the current node id
   * @param parent
   *          the parent of the nodes (unused in this implementation)
   * @return the next node id
   * @throws Exception
   *           if something goes wrong
   */
  protected int toGraph(StringBuffer text, int num, RandomTree parent)
      throws Exception {

    num++;
    if (m_Attribute == -1) {
      // Leaf: draw as a box labelled with the node id and the leaf summary
      // NOTE(review): node ids are derived from RandomTree.this.hashCode();
      // this equals this.hashCode() only if RandomTree is a top-level class —
      // if nested inside another class, every node would get the same id.
      // Confirm against the (unseen) class declaration.
      text.append("N" + Integer.toHexString(RandomTree.this.hashCode())
          + " [label=\"" + num + leafString() + "\""
          + " shape=box]\n");

    } else {
      // Inner node: label with the node id and the split attribute's name
      text.append("N" + Integer.toHexString(RandomTree.this.hashCode())
          + " [label=\"" + num + ": "
          + m_Info.attribute(m_Attribute).name() + "\"]\n");
      // One labelled edge per successor
      for (int i = 0; i < m_Successors.length; i++) {
        text.append("N"
            + Integer.toHexString(RandomTree.this.hashCode())
            + "->" + "N"
            + Integer.toHexString(m_Successors[i].hashCode())
            + " [label=\"");
        if (m_Info.attribute(m_Attribute).isNumeric()) {
          // Numeric split: branch 0 is below the split point, branch 1 above
          if (i == 0) {
            text.append(" < "
                + Utils.doubleToString(m_SplitPoint, 2));
          } else {
            text.append(" >= "
                + Utils.doubleToString(m_SplitPoint, 2));
          }
        } else {
          // Nominal split: branch i corresponds to the i-th attribute value
          text.append(" = " + m_Info.attribute(m_Attribute).value(i));
        }
        text.append("\"]\n");
        num = m_Successors[i].toGraph(text, num, this);
      }
    }

    return num;
  }
---|
1446 | } |
---|
1447 | |
---|