/*
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 2 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 */

/*
 * SimpleCart.java
 * Copyright (C) 2007 Haijian Shi
 *
 */
package weka.classifiers.trees;

import weka.classifiers.Evaluation;
import weka.classifiers.RandomizableClassifier;
import weka.core.AdditionalMeasureProducer;
import weka.core.Attribute;
import weka.core.Capabilities;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.RevisionUtils;
import weka.core.TechnicalInformation;
import weka.core.TechnicalInformationHandler;
import weka.core.Utils;
import weka.core.Capabilities.Capability;
import weka.core.TechnicalInformation.Field;
import weka.core.TechnicalInformation.Type;
import weka.core.matrix.Matrix;

import java.util.Arrays;
import java.util.Enumeration;
import java.util.Random;
import java.util.Vector;

/**
 <!-- globalinfo-start -->
 * Class implementing minimal cost-complexity pruning.<br/>
 * Note when dealing with missing values, use "fractional instances" method instead of surrogate split method.<br/>
 * <br/>
 * For more information, see:<br/>
 * <br/>
 * Leo Breiman, Jerome H. Friedman, Richard A. Olshen, Charles J. Stone (1984). Classification and Regression Trees. Wadsworth International Group, Belmont, California.
 * <p/>
 <!-- globalinfo-end -->
 *
 <!-- technical-bibtex-start -->
 * BibTeX:
 * <pre>
 * &#64;book{Breiman1984,
 *    address = {Belmont, California},
 *    author = {Leo Breiman and Jerome H. Friedman and Richard A. Olshen and Charles J. Stone},
 *    publisher = {Wadsworth International Group},
 *    title = {Classification and Regression Trees},
 *    year = {1984}
 * }
 * </pre>
 * <p/>
 <!-- technical-bibtex-end -->
 *
 <!-- options-start -->
 * Valid options are: <p/>
 *
 * <pre> -S &lt;num&gt;
 *  Random number seed.
 *  (default 1)</pre>
 *
 * <pre> -D
 *  If set, classifier is run in debug mode and
 *  may output additional info to the console</pre>
 *
 * <pre> -M &lt;min no&gt;
 *  The minimal number of instances at the terminal nodes.
 *  (default 2)</pre>
 *
 * <pre> -N &lt;num folds&gt;
 *  The number of folds used in the minimal cost-complexity pruning.
 *  (default 5)</pre>
 *
 * <pre> -U
 *  Don't use the minimal cost-complexity pruning.
 *  (default yes).</pre>
 *
 * <pre> -H
 *  Don't use the heuristic method for binary split.
 *  (default true).</pre>
 *
 * <pre> -A
 *  Use 1 SE rule to make pruning decision.
 *  (default no).</pre>
 *
 * <pre> -C
 *  Percentage of training data size (0-1].
 *  (default 1).</pre>
 *
 <!-- options-end -->
 *
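 * <p>
 * A minimal usage sketch (the ARFF file name is hypothetical; assumes the
 * class attribute is the last attribute):
 * <pre>
 * Instances data = new Instances(new java.io.FileReader("weather.arff"));
 * data.setClassIndex(data.numAttributes() - 1);
 * SimpleCart tree = new SimpleCart();
 * tree.buildClassifier(data);
 * double[] probs = tree.distributionForInstance(data.instance(0));
 * </pre>
 *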
 * @author Haijian Shi (hs69@cs.waikato.ac.nz)
 * @version $Revision: 5987 $
 */
public class SimpleCart
  extends RandomizableClassifier
  implements AdditionalMeasureProducer, TechnicalInformationHandler {

  /** For serialization. */
  private static final long serialVersionUID = 4154189200352566053L;

  /** Training data. */
  protected Instances m_train;

  /** Successor nodes. */
  protected SimpleCart[] m_Successors;

  /** Attribute used to split data. */
  protected Attribute m_Attribute;

  /** Split point for a numeric attribute. */
  protected double m_SplitValue;

  /** Split subset used to split data for nominal attributes. */
  protected String m_SplitString;

  /** Class value if the node is a leaf. */
  protected double m_ClassValue;

  /** Class attribute of the data. */
  protected Attribute m_ClassAttribute;

  /** Minimum number of instances at the terminal nodes. */
  protected double m_minNumObj = 2;

  /** Number of folds for minimal cost-complexity pruning. */
  protected int m_numFoldsPruning = 5;

  /** Alpha-value (for pruning) at the node. */
  protected double m_Alpha;

  /** Number of training examples misclassified when this node is replaced by a leaf. */
  protected double m_numIncorrectModel;

  /** Number of training examples misclassified by the subtree rooted at this node. */
  protected double m_numIncorrectTree;

  /** Indicates whether the node is a leaf node. */
  protected boolean m_isLeaf;

  /** Whether to use minimal cost-complexity pruning. */
  protected boolean m_Prune = true;

  /** Total number of instances used to build the classifier. */
  protected int m_totalTrainInstances;

  /** Proportion for each branch. */
  protected double[] m_Props;

  /** Class probabilities. */
  protected double[] m_ClassProbs = null;

  /** Distribution of a leaf node (or temporary leaf node in minimal cost-complexity pruning). */
  protected double[] m_Distribution;

  /** Whether to use heuristic search for nominal attributes in multi-class problems (default true). */
  protected boolean m_Heuristic = true;

  /** Whether to use the 1 SE rule to select the final tree. */
  protected boolean m_UseOneSE = false;

  /** Proportion of the training data to use (0-1]. */
  protected double m_SizePer = 1;

  /**
   * Return a description suitable for displaying in the explorer/experimenter.
   *
   * @return a description suitable for displaying in the
   * explorer/experimenter
   */
  public String globalInfo() {
    return
      "Class implementing minimal cost-complexity pruning.\n"
      + "Note when dealing with missing values, use \"fractional "
      + "instances\" method instead of surrogate split method.\n\n"
      + "For more information, see:\n\n"
      + getTechnicalInformation().toString();
  }

  /**
   * Returns an instance of a TechnicalInformation object, containing
   * detailed information about the technical background of this class,
   * e.g., paper reference or book this class is based on.
   *
   * @return the technical information about this class
   */
  public TechnicalInformation getTechnicalInformation() {
    TechnicalInformation result;

    result = new TechnicalInformation(Type.BOOK);
    result.setValue(Field.AUTHOR, "Leo Breiman and Jerome H. Friedman and Richard A. Olshen and Charles J. Stone");
    result.setValue(Field.YEAR, "1984");
    result.setValue(Field.TITLE, "Classification and Regression Trees");
    result.setValue(Field.PUBLISHER, "Wadsworth International Group");
    result.setValue(Field.ADDRESS, "Belmont, California");

    return result;
  }

  /**
   * Returns default capabilities of the classifier.
   *
   * @return the capabilities of this classifier
   */
  public Capabilities getCapabilities() {
    Capabilities result = super.getCapabilities();
    result.disableAll();

    // attributes
    result.enable(Capability.NOMINAL_ATTRIBUTES);
    result.enable(Capability.NUMERIC_ATTRIBUTES);
    result.enable(Capability.MISSING_VALUES);

    // class
    result.enable(Capability.NOMINAL_CLASS);

    return result;
  }

  /**
   * Build the classifier.
   *
   * @param data the training instances
   * @throws Exception if something goes wrong
   */
  public void buildClassifier(Instances data) throws Exception {

    getCapabilities().testWithFail(data);
    data = new Instances(data);
    data.deleteWithMissingClass();

    // unpruned CART decision tree
    if (!m_Prune) {

      // calculate sorted indices and weights, and compute initial class counts
      int[][] sortedIndices = new int[data.numAttributes()][0];
      double[][] weights = new double[data.numAttributes()][0];
      double[] classProbs = new double[data.numClasses()];
      double totalWeight = computeSortedInfo(data, sortedIndices, weights, classProbs);

      makeTree(data, data.numInstances(), sortedIndices, weights, classProbs,
          totalWeight, m_minNumObj, m_Heuristic);
      return;
    }

    Random random = new Random(m_Seed);
    Instances cvData = new Instances(data);
    cvData.randomize(random);
    cvData = new Instances(cvData, 0, (int)(cvData.numInstances() * m_SizePer) - 1);
    cvData.stratify(m_numFoldsPruning);

    double[][] alphas = new double[m_numFoldsPruning][];
    double[][] errors = new double[m_numFoldsPruning][];

    // calculate errors and alphas for each fold
    for (int i = 0; i < m_numFoldsPruning; i++) {

      // for every fold, grow a tree on the training set and record its error on the test set
      Instances train = cvData.trainCV(m_numFoldsPruning, i);
      Instances test = cvData.testCV(m_numFoldsPruning, i);

      // calculate sorted indices and weights, and compute initial class counts for each fold
      int[][] sortedIndices = new int[train.numAttributes()][0];
      double[][] weights = new double[train.numAttributes()][0];
      double[] classProbs = new double[train.numClasses()];
      double totalWeight = computeSortedInfo(train, sortedIndices, weights, classProbs);

      makeTree(train, train.numInstances(), sortedIndices, weights, classProbs,
          totalWeight, m_minNumObj, m_Heuristic);

      int numNodes = numInnerNodes();
      alphas[i] = new double[numNodes + 2];
      errors[i] = new double[numNodes + 2];

      // prune back and log alpha-values and errors on the test set
      prune(alphas[i], errors[i], test);
    }

    // calculate sorted indices and weights, and compute initial class counts on all training instances
    int[][] sortedIndices = new int[data.numAttributes()][0];
    double[][] weights = new double[data.numAttributes()][0];
    double[] classProbs = new double[data.numClasses()];
    double totalWeight = computeSortedInfo(data, sortedIndices, weights, classProbs);

    // build tree using all the data
    makeTree(data, data.numInstances(), sortedIndices, weights, classProbs,
        totalWeight, m_minNumObj, m_Heuristic);

    int numNodes = numInnerNodes();

    double[] treeAlphas = new double[numNodes + 2];

    // prune back and log alpha-values
    int iterations = prune(treeAlphas, null, null);

    double[] treeErrors = new double[numNodes + 2];

    // for each pruned subtree, find the cross-validated error
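    // Each alpha interval [treeAlphas[i], treeAlphas[i+1]) is represented by its
    // geometric mean, as in Breiman et al. (1984); the CV error for that alpha is
    // averaged over the folds, each fold contributing the error recorded at the
    // largest of its alphas not exceeding the midpoint.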
    for (int i = 0; i <= iterations; i++) {
      // compute the midpoint alpha (geometric mean of adjacent alphas)
      double alpha = Math.sqrt(treeAlphas[i] * treeAlphas[i + 1]);
      double error = 0;
      for (int k = 0; k < m_numFoldsPruning; k++) {
        int l = 0;
        while (alphas[k][l] <= alpha) l++;
        error += errors[k][l - 1];
      }
      treeErrors[i] = error / m_numFoldsPruning;
    }

    // find best alpha
    int best = -1;
    double bestError = Double.MAX_VALUE;
    for (int i = iterations; i >= 0; i--) {
      if (treeErrors[i] < bestError) {
        bestError = treeErrors[i];
        best = i;
      }
    }

    // 1 SE rule: choose the smallest tree whose error is within one standard error of the minimum
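    // The standard error below is the binomial estimate sqrt(E(1-E)/N) of the
    // cross-validated error rate E over the N training instances.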
    if (m_UseOneSE) {
      double oneSE = Math.sqrt(bestError * (1 - bestError) / (data.numInstances()));
      for (int i = iterations; i >= 0; i--) {
        if (treeErrors[i] <= bestError + oneSE) {
          best = i;
          break;
        }
      }
    }

    double bestAlpha = Math.sqrt(treeAlphas[best] * treeAlphas[best + 1]);

    // "unprune" the final tree (faster than regrowing it)
    unprune();
    prune(bestAlpha);
  }

  /**
   * Make the binary decision tree recursively.
   *
   * @param data the training instances
   * @param totalInstances total number of instances
   * @param sortedIndices sorted indices of the instances
   * @param weights weights of the instances
   * @param classProbs class probabilities
   * @param totalWeight total weight of instances
   * @param minNumObj minimal number of instances at leaf nodes
   * @param useHeuristic whether to use heuristic search for nominal attributes in multi-class problems
   * @throws Exception if something goes wrong
   */
  protected void makeTree(Instances data, int totalInstances, int[][] sortedIndices,
      double[][] weights, double[] classProbs, double totalWeight, double minNumObj,
      boolean useHeuristic) throws Exception {

    // if no instances have reached this node (normally won't happen)
    if (totalWeight == 0) {
      m_Attribute = null;
      m_ClassValue = Utils.missingValue();
      m_Distribution = new double[data.numClasses()];
      return;
    }

    m_totalTrainInstances = totalInstances;
    m_isLeaf = true;

    m_ClassProbs = new double[classProbs.length];
    m_Distribution = new double[classProbs.length];
    System.arraycopy(classProbs, 0, m_ClassProbs, 0, classProbs.length);
    System.arraycopy(classProbs, 0, m_Distribution, 0, classProbs.length);
    if (Utils.sum(m_ClassProbs) != 0) Utils.normalize(m_ClassProbs);

    // Compute class distributions and value of splitting
    // criterion for each attribute
    double[][][] dists = new double[data.numAttributes()][0][0];
    double[][] props = new double[data.numAttributes()][0];
    double[][] totalSubsetWeights = new double[data.numAttributes()][2];
    double[] splits = new double[data.numAttributes()];
    String[] splitString = new String[data.numAttributes()];
    double[] giniGains = new double[data.numAttributes()];

    // for each attribute find split information
    for (int i = 0; i < data.numAttributes(); i++) {
      Attribute att = data.attribute(i);
      if (i == data.classIndex()) continue;
      if (att.isNumeric()) {
        // numeric attribute
        splits[i] = numericDistribution(props, dists, att, sortedIndices[i],
            weights[i], totalSubsetWeights, giniGains, data);
      } else {
        // nominal attribute
        splitString[i] = nominalDistribution(props, dists, att, sortedIndices[i],
            weights[i], totalSubsetWeights, giniGains, data, useHeuristic);
      }
    }

    // Find best attribute (split with maximum Gini gain)
    int attIndex = Utils.maxIndex(giniGains);
    m_Attribute = data.attribute(attIndex);

    m_train = new Instances(data, sortedIndices[attIndex].length);
    for (int i = 0; i < sortedIndices[attIndex].length; i++) {
      Instance inst = data.instance(sortedIndices[attIndex][i]);
      Instance instCopy = (Instance)inst.copy();
      instCopy.setWeight(weights[attIndex][i]);
      m_train.add(instCopy);
    }

    // Check if the node does not contain enough instances, cannot be split,
    // or is pure. If so, make it a leaf.
    if (totalWeight < 2 * minNumObj || giniGains[attIndex] == 0 ||
        props[attIndex][0] == 0 || props[attIndex][1] == 0) {
      makeLeaf(data);
    }

    else {
      m_Props = props[attIndex];
      int[][][] subsetIndices = new int[2][data.numAttributes()][0];
      double[][][] subsetWeights = new double[2][data.numAttributes()][0];

      // numeric split
      if (m_Attribute.isNumeric()) m_SplitValue = splits[attIndex];

      // nominal split
      else m_SplitString = splitString[attIndex];

      splitData(subsetIndices, subsetWeights, m_Attribute, m_SplitValue,
          m_SplitString, sortedIndices, weights, data);

      // If the split would produce a branch with fewer than the minimal number
      // of instances, make the node a leaf node.
      if (subsetIndices[0][attIndex].length < minNumObj ||
          subsetIndices[1][attIndex].length < minNumObj) {
        makeLeaf(data);
        return;
      }

      // Otherwise, split the node.
      m_isLeaf = false;
      m_Successors = new SimpleCart[2];
      for (int i = 0; i < 2; i++) {
        m_Successors[i] = new SimpleCart();
        m_Successors[i].makeTree(data, m_totalTrainInstances, subsetIndices[i],
            subsetWeights[i], dists[attIndex][i], totalSubsetWeights[attIndex][i],
            minNumObj, useHeuristic);
      }
    }
  }

  /**
   * Prunes the original tree using the CART pruning scheme, given a
   * cost-complexity parameter alpha.
   *
   * @param alpha the cost-complexity parameter
   * @throws Exception if something goes wrong
   */
  public void prune(double alpha) throws Exception {

    Vector nodeList;

    // determine training error of pruned subtrees (both with and without replacing a subtree),
    // and calculate alpha-values from them
    modelErrors();
    treeErrors();
    calculateAlphas();

    // get list of all inner nodes in the tree
    nodeList = getInnerNodes();

    boolean prune = (nodeList.size() > 0);
    double preAlpha = Double.MAX_VALUE;
    while (prune) {

      // select the node with minimum alpha
      SimpleCart nodeToPrune = nodeToPrune(nodeList);

      // prune the node only if its alpha does not exceed the target alpha
      if (nodeToPrune.m_Alpha > alpha) {
        break;
      }

      nodeToPrune.makeLeaf(nodeToPrune.m_train);

      // normally would not happen
      if (nodeToPrune.m_Alpha == preAlpha) {
        nodeToPrune.makeLeaf(nodeToPrune.m_train);
        treeErrors();
        calculateAlphas();
        nodeList = getInnerNodes();
        prune = (nodeList.size() > 0);
        continue;
      }
      preAlpha = nodeToPrune.m_Alpha;

      // update tree errors and alphas
      treeErrors();
      calculateAlphas();

      nodeList = getInnerNodes();
      prune = (nodeList.size() > 0);
    }
  }

  /**
   * Method for performing one fold in the cross-validation of minimal
   * cost-complexity pruning. Generates a sequence of alpha-values with error
   * estimates for the corresponding (partially pruned) trees, given the test
   * set of that fold.
   *
   * @param alphas array to hold the generated alpha-values
   * @param errors array to hold the corresponding error estimates
   * @param test test set of that fold (to obtain error estimates)
   * @return the iteration of the pruning
   * @throws Exception if something goes wrong
   */
  public int prune(double[] alphas, double[] errors, Instances test)
      throws Exception {

    Vector nodeList;

    // determine training error of subtrees (both with and without replacing a subtree),
    // and calculate alpha-values from them
    modelErrors();
    treeErrors();
    calculateAlphas();

    // get list of all inner nodes in the tree
    nodeList = getInnerNodes();

    boolean prune = (nodeList.size() > 0);

    // alpha_0 is always zero (unpruned tree)
    alphas[0] = 0;

    Evaluation eval;

    // error of the unpruned tree
    if (errors != null) {
      eval = new Evaluation(test);
      eval.evaluateModel(this, test);
      errors[0] = eval.errorRate();
    }

    int iteration = 0;
    double preAlpha = Double.MAX_VALUE;
    while (prune) {

      iteration++;

      // get the node with minimum alpha
      SimpleCart nodeToPrune = nodeToPrune(nodeList);

      // do not set the successors to null: the tree will be unpruned later
      nodeToPrune.m_isLeaf = true;

      // normally would not happen
      if (nodeToPrune.m_Alpha == preAlpha) {
        iteration--;
        treeErrors();
        calculateAlphas();
        nodeList = getInnerNodes();
        prune = (nodeList.size() > 0);
        continue;
      }

      // get the alpha-value of the node
      alphas[iteration] = nodeToPrune.m_Alpha;

      // log the error
      if (errors != null) {
        eval = new Evaluation(test);
        eval.evaluateModel(this, test);
        errors[iteration] = eval.errorRate();
      }
      preAlpha = nodeToPrune.m_Alpha;

      // update errors/alphas
      treeErrors();
      calculateAlphas();

      nodeList = getInnerNodes();
      prune = (nodeList.size() > 0);
    }

    // set the last alpha to 1 to indicate the end of the sequence
    alphas[iteration + 1] = 1.0;
    return iteration;
  }

  /**
   * Method to "unprune" the CART tree. Sets all leaf-fields to false.
   * Faster than re-growing the tree because the splits do not have to be
   * recomputed.
   */
  protected void unprune() {
    if (m_Successors != null) {
      m_isLeaf = false;
      for (int i = 0; i < m_Successors.length; i++) m_Successors[i].unprune();
    }
  }

  /**
   * Compute distributions, proportions and total weights of two successor
   * nodes for a given numeric attribute.
   *
   * @param props proportions of the two branches for each attribute
   * @param dists class distributions of the two branches for each attribute
   * @param att numeric attribute to split on
   * @param sortedIndices sorted indices of instances for the attribute
   * @param weights weights of instances for the attribute
   * @param subsetWeights total weight of the two branches split based on the attribute
   * @param giniGains Gini gains for each attribute
   * @param data training instances
   * @return the best split point for the given numeric attribute
   * @throws Exception if something goes wrong
   */
  protected double numericDistribution(double[][] props, double[][][] dists,
      Attribute att, int[] sortedIndices, double[] weights, double[][] subsetWeights,
      double[] giniGains, Instances data)
      throws Exception {

    double splitPoint = Double.NaN;
    double[][] dist = null;
    int numClasses = data.numClasses();
    int i; // index used to separate instances with and without missing values

    double[][] currDist = new double[2][numClasses];
    dist = new double[2][numClasses];

    // Move all instances without missing values into the second subset
    double[] parentDist = new double[numClasses];
    int missingStart = 0;
    for (int j = 0; j < sortedIndices.length; j++) {
      Instance inst = data.instance(sortedIndices[j]);
      if (!inst.isMissing(att)) {
        missingStart++;
        currDist[1][(int)inst.classValue()] += weights[j];
      }
      parentDist[(int)inst.classValue()] += weights[j];
    }
    System.arraycopy(currDist[1], 0, dist[1], 0, dist[1].length);

    // Try all possible split points
    double currSplit = data.instance(sortedIndices[0]).value(att);
    double currGiniGain;
    double bestGiniGain = -Double.MAX_VALUE;

    for (i = 0; i < sortedIndices.length; i++) {
      Instance inst = data.instance(sortedIndices[i]);
      if (inst.isMissing(att)) {
        break;
      }
      if (inst.value(att) > currSplit) {

        double[][] tempDist = new double[2][numClasses];
        for (int k = 0; k < 2; k++) {
          //tempDist[k] = currDist[k];
          System.arraycopy(currDist[k], 0, tempDist[k], 0, tempDist[k].length);
        }

        double[] tempProps = new double[2];
        for (int k = 0; k < 2; k++) {
          tempProps[k] = Utils.sum(tempDist[k]);
        }

        if (Utils.sum(tempProps) != 0) Utils.normalize(tempProps);

        // split missing values
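        // "Fractional instances": an instance with a missing value is sent down
        // both branches, its weight scaled by each branch's proportion of the
        // non-missing weight (tempProps), rather than using surrogate splits.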
        int index = missingStart;
        while (index < sortedIndices.length) {
          Instance insta = data.instance(sortedIndices[index]);
          for (int j = 0; j < 2; j++) {
            tempDist[j][(int)insta.classValue()] += tempProps[j] * weights[index];
          }
          index++;
        }

        currGiniGain = computeGiniGain(parentDist, tempDist);

        if (currGiniGain > bestGiniGain) {
          bestGiniGain = currGiniGain;

          // clean split point
          // splitPoint = Math.rint((inst.value(att) + currSplit)/2.0*100000)/100000.0;
          splitPoint = (inst.value(att) + currSplit) / 2.0;

          for (int j = 0; j < currDist.length; j++) {
            System.arraycopy(tempDist[j], 0, dist[j], 0,
                dist[j].length);
          }
        }
      }
      currSplit = inst.value(att);
      currDist[0][(int)inst.classValue()] += weights[i];
      currDist[1][(int)inst.classValue()] -= weights[i];
    }

    // Compute weights
    int attIndex = att.index();
    props[attIndex] = new double[2];
    for (int k = 0; k < 2; k++) {
      props[attIndex][k] = Utils.sum(dist[k]);
    }
    if (Utils.sum(props[attIndex]) != 0) Utils.normalize(props[attIndex]);

    // Compute subset weights
    subsetWeights[attIndex] = new double[2];
    for (int j = 0; j < 2; j++) {
      subsetWeights[attIndex][j] += Utils.sum(dist[j]);
    }

    // clean Gini gain
    //giniGains[attIndex] = Math.rint(bestGiniGain*10000000)/10000000.0;
    giniGains[attIndex] = bestGiniGain;
    dists[attIndex] = dist;

    return splitPoint;
  }

  /**
   * Compute distributions, proportions and total weights of two successor
   * nodes for a given nominal attribute.
   *
   * @param props proportions of the two branches for each attribute
   * @param dists class distributions of the two branches for each attribute
   * @param att nominal attribute to split on
   * @param sortedIndices sorted indices of instances for the attribute
   * @param weights weights of instances for the attribute
   * @param subsetWeights total weight of the two branches split based on the attribute
   * @param giniGains Gini gains for each attribute
   * @param data training instances
   * @param useHeuristic whether to use heuristic search
   * @return the best split subset (as a string) for the given nominal attribute
   * @throws Exception if something goes wrong
   */
  protected String nominalDistribution(double[][] props, double[][][] dists,
      Attribute att, int[] sortedIndices, double[] weights, double[][] subsetWeights,
      double[] giniGains, Instances data, boolean useHeuristic)
      throws Exception {

    String[] values = new String[att.numValues()];
    int numCat = values.length; // number of values of the attribute
    int numClasses = data.numClasses();

    String bestSplitString = "";
    double bestGiniGain = -Double.MAX_VALUE;

    // class frequency for each value
    int[] classFreq = new int[numCat];
    for (int j = 0; j < numCat; j++) classFreq[j] = 0;

    double[] parentDist = new double[numClasses];
    double[][] currDist = new double[2][numClasses];
    double[][] dist = new double[2][numClasses];
    int missingStart = 0;

    for (int i = 0; i < sortedIndices.length; i++) {
      Instance inst = data.instance(sortedIndices[i]);
      if (!inst.isMissing(att)) {
        missingStart++;
        classFreq[(int)inst.value(att)]++;
      }
      parentDist[(int)inst.classValue()] += weights[i];
    }

    // count the number of values whose class frequency is not 0
    int nonEmpty = 0;
    for (int j = 0; j < numCat; j++) {
      if (classFreq[j] != 0) nonEmpty++;
    }

    // attribute values whose class frequency is not 0
    String[] nonEmptyValues = new String[nonEmpty];
    int nonEmptyIndex = 0;
    for (int j = 0; j < numCat; j++) {
      if (classFreq[j] != 0) {
        nonEmptyValues[nonEmptyIndex] = att.value(j);
        nonEmptyIndex++;
      }
    }

    // attribute values whose class frequency is 0
    int empty = numCat - nonEmpty;
    String[] emptyValues = new String[empty];
    int emptyIndex = 0;
    for (int j = 0; j < numCat; j++) {
      if (classFreq[j] == 0) {
        emptyValues[emptyIndex] = att.value(j);
        emptyIndex++;
      }
    }

    if (nonEmpty <= 1) {
      giniGains[att.index()] = 0;
      return "";
    }

    // two-class problems
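    // For two classes, ordering the attribute values by their probability of
    // the first class and checking only the nonEmpty-1 ordered splits suffices
    // to find the best subset split (Breiman et al., 1984), avoiding the
    // exhaustive search over all 2^(nonEmpty-1)-1 subsets.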
    if (data.numClasses() == 2) {

      // First, handle attribute values whose class frequency is not zero

      // probability of class 0 for each attribute value
      double[] pClass0 = new double[nonEmpty];
      // class distribution for each attribute value
      double[][] valDist = new double[nonEmpty][2];

      for (int j = 0; j < nonEmpty; j++) {
        for (int k = 0; k < 2; k++) {
          valDist[j][k] = 0;
        }
      }

      for (int i = 0; i < sortedIndices.length; i++) {
        Instance inst = data.instance(sortedIndices[i]);
        if (inst.isMissing(att)) {
          break;
        }

        for (int j = 0; j < nonEmpty; j++) {
          if (att.value((int)inst.value(att)).compareTo(nonEmptyValues[j]) == 0) {
            valDist[j][(int)inst.classValue()] += inst.weight();
            break;
          }
        }
      }

      for (int j = 0; j < nonEmpty; j++) {
        double distSum = Utils.sum(valDist[j]);
        if (distSum == 0) pClass0[j] = 0;
        else pClass0[j] = valDist[j][0] / distSum;
      }

      // sort the categories according to the probability of the first class
      String[] sortedValues = new String[nonEmpty];
      for (int j = 0; j < nonEmpty; j++) {
        sortedValues[j] = nonEmptyValues[Utils.minIndex(pClass0)];
        pClass0[Utils.minIndex(pClass0)] = Double.MAX_VALUE;
      }

      // Find a subset of attribute values that maximizes the Gini decrease

      // for the attribute values whose class frequency is not 0
      String tempStr = "";

      for (int j = 0; j < nonEmpty - 1; j++) {
        currDist = new double[2][numClasses];
        if (tempStr.equals("")) tempStr = "(" + sortedValues[j] + ")";
        else tempStr += "|" + "(" + sortedValues[j] + ")";
        for (int i = 0; i < sortedIndices.length; i++) {
          Instance inst = data.instance(sortedIndices[i]);
          if (inst.isMissing(att)) {
            break;
          }

          if (tempStr.indexOf
              ("(" + att.value((int)inst.value(att)) + ")") != -1) {
            currDist[0][(int)inst.classValue()] += weights[i];
          } else currDist[1][(int)inst.classValue()] += weights[i];
        }

        double[][] tempDist = new double[2][numClasses];
        for (int kk = 0; kk < 2; kk++) {
          tempDist[kk] = currDist[kk];
        }

        double[] tempProps = new double[2];
        for (int kk = 0; kk < 2; kk++) {
          tempProps[kk] = Utils.sum(tempDist[kk]);
        }

        if (Utils.sum(tempProps) != 0) Utils.normalize(tempProps);

        // split missing values
        int mstart = missingStart;
        while (mstart < sortedIndices.length) {
          Instance insta = data.instance(sortedIndices[mstart]);
          for (int jj = 0; jj < 2; jj++) {
            tempDist[jj][(int)insta.classValue()] += tempProps[jj] * weights[mstart];
          }
          mstart++;
        }

        double currGiniGain = computeGiniGain(parentDist, tempDist);

        if (currGiniGain > bestGiniGain) {
          bestGiniGain = currGiniGain;
          bestSplitString = tempStr;
          for (int jj = 0; jj < 2; jj++) {
            //dist[jj] = new double[currDist[jj].length];
            System.arraycopy(tempDist[jj], 0, dist[jj], 0,
                dist[jj].length);
          }
        }
      }
    }

    // multi-class problems - exhaustive search
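    // Each candidate subset is encoded by the binary digits of the loop counter;
    // only 2^(nonEmpty-1) of the 2^nonEmpty subsets are enumerated, since a
    // subset and its complement describe the same binary split.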
    else if (!useHeuristic || nonEmpty <= 4) {

      // First, handle attribute values whose class frequency is not zero
      for (int i = 0; i < (int)Math.pow(2, nonEmpty - 1); i++) {
        String tempStr = "";
        currDist = new double[2][numClasses];
        int mod;
        int bit10 = i;
        for (int j = nonEmpty - 1; j >= 0; j--) {
          mod = bit10 % 2; // extract the j-th binary digit (decimal-to-binary conversion)
          if (mod == 1) {
            if (tempStr.equals("")) tempStr = "(" + nonEmptyValues[j] + ")";
            else tempStr += "|" + "(" + nonEmptyValues[j] + ")";
          }
          bit10 = bit10 / 2;
        }
        for (int j = 0; j < sortedIndices.length; j++) {
          Instance inst = data.instance(sortedIndices[j]);
          if (inst.isMissing(att)) {
            break;
          }

          if (tempStr.indexOf("(" + att.value((int)inst.value(att)) + ")") != -1) {
            currDist[0][(int)inst.classValue()] += weights[j];
          } else currDist[1][(int)inst.classValue()] += weights[j];
        }

        double[][] tempDist = new double[2][numClasses];
        for (int k = 0; k < 2; k++) {
          tempDist[k] = currDist[k];
        }

        double[] tempProps = new double[2];
        for (int k = 0; k < 2; k++) {
          tempProps[k] = Utils.sum(tempDist[k]);
        }

        if (Utils.sum(tempProps) != 0) Utils.normalize(tempProps);

        // split missing values
        int index = missingStart;
        while (index < sortedIndices.length) {
          Instance insta = data.instance(sortedIndices[index]);
          for (int j = 0; j < 2; j++) {
            tempDist[j][(int)insta.classValue()] += tempProps[j] * weights[index];
          }
          index++;
        }

        double currGiniGain = computeGiniGain(parentDist, tempDist);

        if (currGiniGain > bestGiniGain) {
          bestGiniGain = currGiniGain;
          bestSplitString = tempStr;
          for (int j = 0; j < 2; j++) {
            //dist[jj] = new double[currDist[jj].length];
            System.arraycopy(tempDist[j], 0, dist[j], 0,
                dist[j].length);
          }
        }
      }
    }

    // heuristic search for multi-class problems
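    // The heuristic (from Breiman et al., 1984) orders the attribute values by
    // their score on the first principal component of the class-probability
    // matrix, then searches only the nonEmpty-1 ordered splits, as in the
    // two-class case.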
    else {
      // First, handle attribute values whose class frequency is not zero
      int n = nonEmpty;
      int k = data.numClasses(); // number of classes of the data
      double[][] P = new double[n][k]; // class probability matrix
      int[] numInstancesValue = new int[n]; // number of instances for an attribute value
      double[] meanClass = new double[k]; // vector of mean class probability
      int numInstances = data.numInstances(); // total number of instances

      // initialize the vector of mean class probability
      for (int j = 0; j < meanClass.length; j++) meanClass[j] = 0;

      for (int j = 0; j < numInstances; j++) {
        Instance inst = (Instance)data.instance(j);
        int valueIndex = 0; // attribute value index in nonEmptyValues
        for (int i = 0; i < nonEmpty; i++) {
          if (att.value((int)inst.value(att)).compareToIgnoreCase(nonEmptyValues[i]) == 0) {
            valueIndex = i;
            break;
          }
        }
        P[valueIndex][(int)inst.classValue()]++;
        numInstancesValue[valueIndex]++;
        meanClass[(int)inst.classValue()]++;
      }

      // calculate the class probability matrix
      for (int i = 0; i < P.length; i++) {
        for (int j = 0; j < P[0].length; j++) {
          if (numInstancesValue[i] == 0) P[i][j] = 0;
          else P[i][j] /= numInstancesValue[i];
        }
      }

      // calculate the vector of mean class probability
      for (int i = 0; i < meanClass.length; i++) {
        meanClass[i] /= numInstances;
      }

      // calculate the covariance matrix
      double[][] covariance = new double[k][k];
      for (int i1 = 0; i1 < k; i1++) {
        for (int i2 = 0; i2 < k; i2++) {
          double element = 0;
          for (int j = 0; j < n; j++) {
            element += (P[j][i2] - meanClass[i2]) * (P[j][i1] - meanClass[i1])
                * numInstancesValue[j];
          }
          covariance[i1][i2] = element;
        }
      }

      Matrix matrix = new Matrix(covariance);
      weka.core.matrix.EigenvalueDecomposition eigen =
          new weka.core.matrix.EigenvalueDecomposition(matrix);
      double[] eigenValues = eigen.getRealEigenvalues();

      // find the index of the largest eigenvalue
      int index = 0;
      double largest = eigenValues[0];
      for (int i = 1; i < eigenValues.length; i++) {
        if (eigenValues[i] > largest) {
          index = i;
          largest = eigenValues[i];
        }
      }

      // calculate the first principal component
      double[] FPC = new double[k];
      Matrix eigenVector = eigen.getV();
      double[][] vectorArray = eigenVector.getArray();
      for (int i = 0; i < FPC.length; i++) {
        FPC[i] = vectorArray[i][index];
      }

      // calculate the first principal component scores
      //System.out.println("the first principal component scores: ");
      double[] Sa = new double[n];
      for (int i = 0; i < Sa.length; i++) {
        Sa[i] = 0;
        for (int j = 0; j < k; j++) {
          Sa[i] += FPC[j] * P[i][j];
        }
      }

      // sort the categories according to Sa(s)
      double[] pCopy = new double[n];
      System.arraycopy(Sa, 0, pCopy, 0, n);
      String[] sortedValues = new String[n];
      Arrays.sort(Sa);

      for (int j = 0; j < n; j++) {
        sortedValues[j] = nonEmptyValues[Utils.minIndex(pCopy)];
        pCopy[Utils.minIndex(pCopy)] = Double.MAX_VALUE;
      }

      // for the attribute values whose class frequency is not 0
      String tempStr = "";

      for (int j = 0; j < nonEmpty - 1; j++) {
        currDist = new double[2][numClasses];
        if (tempStr.equals("")) tempStr = "(" + sortedValues[j] + ")";
        else tempStr += "|" + "(" + sortedValues[j] + ")";
        for (int i = 0; i < sortedIndices.length; i++) {
          Instance inst = data.instance(sortedIndices[i]);
          if (inst.isMissing(att)) {
            break;
          }

          if (tempStr.indexOf
              ("(" + att.value((int)inst.value(att)) + ")") != -1) {
            currDist[0][(int)inst.classValue()] += weights[i];
          } else currDist[1][(int)inst.classValue()] += weights[i];
        }

        double[][] tempDist = new double[2][numClasses];
        for (int kk = 0; kk < 2; kk++) {
          tempDist[kk] = currDist[kk];
        }

        double[] tempProps = new double[2];
        for (int kk = 0; kk < 2; kk++) {
          tempProps[kk] = Utils.sum(tempDist[kk]);
        }

        if (Utils.sum(tempProps) != 0) Utils.normalize(tempProps);

        // split missing values
        int mstart = missingStart;
        while (mstart < sortedIndices.length) {
          Instance insta = data.instance(sortedIndices[mstart]);
          for (int jj = 0; jj < 2; jj++) {
            tempDist[jj][(int)insta.classValue()] += tempProps[jj] * weights[mstart];
          }
          mstart++;
        }

        double currGiniGain = computeGiniGain(parentDist, tempDist);

        if (currGiniGain > bestGiniGain) {
          bestGiniGain = currGiniGain;
          bestSplitString = tempStr;
          for (int jj = 0; jj < 2; jj++) {
            //dist[jj] = new double[currDist[jj].length];
            System.arraycopy(tempDist[jj], 0, dist[jj], 0,
                dist[jj].length);
          }
        }
      }
    }

    // Compute weights
    int attIndex = att.index();
    props[attIndex] = new double[2];
    for (int k = 0; k < 2; k++) {
      props[attIndex][k] = Utils.sum(dist[k]);
    }

    if (!(Utils.sum(props[attIndex]) > 0)) {
      for (int k = 0; k < props[attIndex].length; k++) {
        props[attIndex][k] = 1.0 / (double)props[attIndex].length;
      }
    } else {
      Utils.normalize(props[attIndex]);
    }

    // Compute subset weights
    subsetWeights[attIndex] = new double[2];
    for (int j = 0; j < 2; j++) {
      subsetWeights[attIndex][j] += Utils.sum(dist[j]);
    }

    // Then send the attribute values whose class frequency is 0 to the
    // more frequent branch
    for (int j = 0; j < empty; j++) {
      if (props[attIndex][0] >= props[attIndex][1]) {
        if (bestSplitString.equals("")) bestSplitString = "(" + emptyValues[j] + ")";
        else bestSplitString += "|" + "(" + emptyValues[j] + ")";
      }
    }

    // clean Gini gain for the attribute
    //giniGains[attIndex] = Math.rint(bestGiniGain*10000000)/10000000.0;
    giniGains[attIndex] = bestGiniGain;

    dists[attIndex] = dist;
    return bestSplitString;
  }

  /**
   * Split data into two subsets and store sorted indices and weights for the
   * two successor nodes.
   *
   * @param subsetIndices sorted indices of instances for each attribute
   *        for the two successor nodes
   * @param subsetWeights weights of instances for each attribute for
   *        the two successor nodes
   * @param att attribute the split is based on
   * @param splitPoint split point the split is based on, if att is numeric
   * @param splitStr split subset the split is based on, if att is nominal
   * @param sortedIndices sorted indices of the instances to be split
   * @param weights weights of the instances to be split
   * @param data training data
   * @throws Exception if something goes wrong
   */
  protected void splitData(int[][][] subsetIndices, double[][][] subsetWeights,
      Attribute att, double splitPoint, String splitStr, int[][] sortedIndices,
      double[][] weights, Instances data) throws Exception {

    int j;
    // For each attribute
    for (int i = 0; i < data.numAttributes(); i++) {
      if (i == data.classIndex()) continue;
      int[] num = new int[2];
      for (int k = 0; k < 2; k++) {
        subsetIndices[k][i] = new int[sortedIndices[i].length];
        subsetWeights[k][i] = new double[weights[i].length];
      }

      for (j = 0; j < sortedIndices[i].length; j++) {
        Instance inst = data.instance(sortedIndices[i][j]);
        if (inst.isMissing(att)) {
          // Split the instance up
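          // (fractional instances: the instance goes to both subsets, with its
          // weight scaled by the branch proportions in m_Props)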
          for (int k = 0; k < 2; k++) {
            if (m_Props[k] > 0) {
              subsetIndices[k][i][num[k]] = sortedIndices[i][j];
              subsetWeights[k][i][num[k]] = m_Props[k] * weights[i][j];
              num[k]++;
            }
          }
        } else {
          int subset;
          if (att.isNumeric()) {
            subset = (inst.value(att) < splitPoint) ? 0 : 1;
          } else { // nominal attribute
            if (splitStr.indexOf
                ("(" + att.value((int)inst.value(att.index())) + ")") != -1) {
              subset = 0;
            } else subset = 1;
          }
          subsetIndices[subset][i][num[subset]] = sortedIndices[i][j];
          subsetWeights[subset][i][num[subset]] = weights[i][j];
          num[subset]++;
        }
      }

      // Trim arrays
      for (int k = 0; k < 2; k++) {
        int[] copy = new int[num[k]];
        System.arraycopy(subsetIndices[k][i], 0, copy, 0, num[k]);
        subsetIndices[k][i] = copy;
        double[] copyWeights = new double[num[k]];
        System.arraycopy(subsetWeights[k][i], 0, copyWeights, 0, num[k]);
        subsetWeights[k][i] = copyWeights;
      }
    }
  }

  /**
   * Updates the numIncorrectModel field for all nodes, i.e. the number of
   * training errors when the node is (temporarily) replaced by a leaf.
   * This is needed for calculating the alpha-values.
   *
   * @throws Exception if something goes wrong
   */
  public void modelErrors() throws Exception {
    Evaluation eval = new Evaluation(m_train);

    if (!m_isLeaf) {
      m_isLeaf = true; // temporarily make the node a leaf

      // calculate distribution for evaluation
      eval.evaluateModel(this, m_train);
      m_numIncorrectModel = eval.incorrect();

      m_isLeaf = false;

      for (int i = 0; i < m_Successors.length; i++)
        m_Successors[i].modelErrors();

    } else {
      eval.evaluateModel(this, m_train);
      m_numIncorrectModel = eval.incorrect();
    }
  }

  /**
   * Updates the numIncorrectTree field for all nodes, i.e. the number of
   * training errors of the subtree rooted at the node. This is needed for
   * calculating the alpha-values.
   *
   * @throws Exception if something goes wrong
   */
  public void treeErrors() throws Exception {
    if (m_isLeaf) {
      m_numIncorrectTree = m_numIncorrectModel;
    } else {
      m_numIncorrectTree = 0;
      for (int i = 0; i < m_Successors.length; i++) {
        m_Successors[i].treeErrors();
        m_numIncorrectTree += m_Successors[i].m_numIncorrectTree;
      }
    }
  }

  /**
   * Updates the alpha field for all nodes.
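   * For an inner node, alpha is the cost-complexity measure of Breiman et
   * al. (1984): the increase in training error caused by collapsing the
   * subtree into a leaf, divided by the number of leaves saved, i.e.
   * alpha = (numIncorrectModel - numIncorrectTree)
   *         / (totalTrainInstances * (numLeaves - 1)).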
   *
   * @throws Exception if something goes wrong
   */
  public void calculateAlphas() throws Exception {

    if (!m_isLeaf) {
      double errorDiff = m_numIncorrectModel - m_numIncorrectTree;
      if (errorDiff <= 0) {
        // split increases training error (should not normally happen);
        // prune it instantly
        makeLeaf(m_train);
        m_Alpha = Double.MAX_VALUE;
      } else {
        // compute alpha, rounded to 10 decimal places
        errorDiff /= m_totalTrainInstances;
        m_Alpha = errorDiff / (double)(numLeaves() - 1);
        long alphaLong = Math.round(m_Alpha * Math.pow(10, 10));
        m_Alpha = (double)alphaLong / Math.pow(10, 10);
        for (int i = 0; i < m_Successors.length; i++) {
          m_Successors[i].calculateAlphas();
        }
      }
    } else {
      // alpha = infinity for leaves (do not want to prune)
      m_Alpha = Double.MAX_VALUE;
    }
  }

  /**
   * Find the node with the minimal alpha-value. If two nodes have the same
   * alpha, choose the one with more leaf nodes.
   *
   * @param nodeList list of inner nodes
   * @return the node to be pruned
   */
  protected SimpleCart nodeToPrune(Vector nodeList) {
    if (nodeList.size() == 0) return null;
    if (nodeList.size() == 1) return (SimpleCart)nodeList.elementAt(0);
    SimpleCart returnNode = (SimpleCart)nodeList.elementAt(0);
    double baseAlpha = returnNode.m_Alpha;
    for (int i = 1; i < nodeList.size(); i++) {
      SimpleCart node = (SimpleCart)nodeList.elementAt(i);
      if (node.m_Alpha < baseAlpha) {
        baseAlpha = node.m_Alpha;
        returnNode = node;
      } else if (node.m_Alpha == baseAlpha) { // break tie
        if (node.numLeaves() > returnNode.numLeaves()) {
          returnNode = node;
        }
      }
    }
    return returnNode;
  }

---|
1340 | /** |
---|
1341 | * Compute sorted indices, weights and class probabilities for a given |
---|
1342 | * dataset. Return total weights of the data at the node. |
---|
1343 | * |
---|
1344 | * @param data training data |
---|
1345 | * @param sortedIndices sorted indices of instances at the node |
---|
1346 | * @param weights weights of instances at the node |
---|
1347 | * @param classProbs class probabilities at the node |
---|
1348 | * @return total weights of instances at the node |
---|
1349 | * @throws Exception if something goes wrong |
---|
1350 | */ |
  protected double computeSortedInfo(Instances data, int[][] sortedIndices,
      double[][] weights, double[] classProbs) throws Exception {

    // Create array of sorted indices and weights
    double[] vals = new double[data.numInstances()];
    for (int j = 0; j < data.numAttributes(); j++) {
      if (j == data.classIndex()) continue;
      weights[j] = new double[data.numInstances()];

      if (data.attribute(j).isNominal()) {

        // Handling nominal attributes: put the indices of
        // instances with missing values at the end.
        sortedIndices[j] = new int[data.numInstances()];
        int count = 0;
        for (int i = 0; i < data.numInstances(); i++) {
          Instance inst = data.instance(i);
          if (!inst.isMissing(j)) {
            sortedIndices[j][count] = i;
            weights[j][count] = inst.weight();
            count++;
          }
        }
        for (int i = 0; i < data.numInstances(); i++) {
          Instance inst = data.instance(i);
          if (inst.isMissing(j)) {
            sortedIndices[j][count] = i;
            weights[j][count] = inst.weight();
            count++;
          }
        }
      } else {

        // Sorted indices are computed for numeric attributes;
        // instances with missing values are put at the end.
        for (int i = 0; i < data.numInstances(); i++) {
          Instance inst = data.instance(i);
          vals[i] = inst.value(j);
        }
        sortedIndices[j] = Utils.sort(vals);
        for (int i = 0; i < data.numInstances(); i++) {
          weights[j][i] = data.instance(sortedIndices[j][i]).weight();
        }
      }
    }

    // Compute initial class counts
    double totalWeight = 0;
    for (int i = 0; i < data.numInstances(); i++) {
      Instance inst = data.instance(i);
      classProbs[(int) inst.classValue()] += inst.weight();
      totalWeight += inst.weight();
    }

    return totalWeight;
  }
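  // Layout sketch (hypothetical values): for a nominal attribute whose values
  // at instance indices 0..3 are {a, missing, b, a}, the loops above yield
  // sortedIndices[j] = {0, 2, 3, 1} -- the non-missing instances first, in
  // their original order, followed by the instance with the missing value.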

  /**
   * Compute and return the Gini gain for the given class distributions of a
   * node and its successor nodes.
   *
   * @param parentDist class distribution of the parent node
   * @param childDist class distributions of the successor nodes
   * @return Gini gain computed
   */
  protected double computeGiniGain(double[] parentDist, double[][] childDist) {
    double totalWeight = Utils.sum(parentDist);
    if (totalWeight == 0) return 0;

    double leftWeight = Utils.sum(childDist[0]);
    double rightWeight = Utils.sum(childDist[1]);

    double parentGini = computeGini(parentDist, totalWeight);
    double leftGini = computeGini(childDist[0], leftWeight);
    double rightGini = computeGini(childDist[1], rightWeight);

    return parentGini - leftWeight / totalWeight * leftGini
        - rightWeight / totalWeight * rightGini;
  }
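  // Worked example (hypothetical counts): parentDist = {6, 4} and
  // childDist = {{5, 1}, {1, 3}} give parentGini = 0.48,
  // leftGini = 1 - ((5/6)^2 + (1/6)^2) ~ 0.278, rightGini = 0.375, and
  // gain = 0.48 - 0.6 * 0.278 - 0.4 * 0.375 ~ 0.163.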

  /**
   * Compute and return the Gini index for a given class distribution of a node.
   *
   * @param dist class distribution
   * @param total total weight of the distribution
   * @return Gini index of the class distribution
   */
  protected double computeGini(double[] dist, double total) {
    if (total == 0) return 0;
    double val = 0;
    for (int i = 0; i < dist.length; i++) {
      val += (dist[i] / total) * (dist[i] / total);
    }
    return 1 - val;
  }
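  // The formula implemented above is Gini(t) = 1 - sum_i p_i^2, where p_i is
  // the weight fraction of class i at the node. Example (hypothetical
  // distribution): dist = {5, 1}, total = 6 gives
  // Gini = 1 - ((5/6)^2 + (1/6)^2) = 10/36 ~ 0.278.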

  /**
   * Computes class probabilities for an instance using the decision tree.
   *
   * @param instance the instance for which class probabilities are to be computed
   * @return the class probabilities for the given instance
   * @throws Exception if something goes wrong
   */
  public double[] distributionForInstance(Instance instance)
      throws Exception {
    if (!m_isLeaf) {
      // value of split attribute is missing
      if (instance.isMissing(m_Attribute)) {
        double[] returnedDist = new double[m_ClassProbs.length];

        for (int i = 0; i < m_Successors.length; i++) {
          double[] help =
            m_Successors[i].distributionForInstance(instance);
          if (help != null) {
            for (int j = 0; j < help.length; j++) {
              returnedDist[j] += m_Props[i] * help[j];
            }
          }
        }
        return returnedDist;
      }

      // split attribute is nominal
      else if (m_Attribute.isNominal()) {
        if (m_SplitString.indexOf("(" +
            m_Attribute.value((int) instance.value(m_Attribute)) + ")") != -1)
          return m_Successors[0].distributionForInstance(instance);
        else return m_Successors[1].distributionForInstance(instance);
      }

      // split attribute is numeric
      else {
        if (instance.value(m_Attribute) < m_SplitValue)
          return m_Successors[0].distributionForInstance(instance);
        else
          return m_Successors[1].distributionForInstance(instance);
      }
    }

    // leaf node
    else return m_ClassProbs;
  }
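  // Missing-value sketch (hypothetical numbers): the "fractional instances"
  // branch above sends the instance down both subtrees. With branch
  // proportions m_Props = {0.7, 0.3} and child distributions {0.9, 0.1} and
  // {0.2, 0.8}, the returned distribution is
  // {0.7*0.9 + 0.3*0.2, 0.7*0.1 + 0.3*0.8} = {0.69, 0.31}.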

  /**
   * Make the node a leaf node.
   *
   * @param data training data
   */
  protected void makeLeaf(Instances data) {
    m_Attribute = null;
    m_isLeaf = true;
    m_ClassValue = Utils.maxIndex(m_ClassProbs);
    m_ClassAttribute = data.classAttribute();
  }

  /**
   * Prints the decision tree using the protected toString method from below.
   *
   * @return a textual description of the classifier
   */
  public String toString() {
    if ((m_ClassProbs == null) && (m_Successors == null)) {
      return "CART Tree: No model built yet.";
    }

    return "CART Decision Tree\n" + toString(0) + "\n\n"
      + "Number of Leaf Nodes: " + numLeaves() + "\n\n"
      + "Size of the Tree: " + numNodes();
  }

  /**
   * Outputs the subtree rooted at this node, indented to the given level.
   *
   * @param level the level at which the tree is to be printed
   * @return a textual description of the subtree at that level
   */
  protected String toString(int level) {

    StringBuffer text = new StringBuffer();
    // leaf node
    if (m_Attribute == null) {
      if (Utils.isMissingValue(m_ClassValue)) {
        text.append(": null");
      } else {
        double correctNum =
          (int) (m_Distribution[Utils.maxIndex(m_Distribution)] * 100) / 100.0;
        double wrongNum = (int) ((Utils.sum(m_Distribution) -
            m_Distribution[Utils.maxIndex(m_Distribution)]) * 100) / 100.0;
        String str = "(" + correctNum + "/" + wrongNum + ")";
        text.append(": " + m_ClassAttribute.value((int) m_ClassValue) + str);
      }
    } else {
      for (int j = 0; j < 2; j++) {
        text.append("\n");
        for (int i = 0; i < level; i++) {
          text.append("| ");
        }
        if (j == 0) {
          if (m_Attribute.isNumeric())
            text.append(m_Attribute.name() + " < " + m_SplitValue);
          else
            text.append(m_Attribute.name() + "=" + m_SplitString);
        } else {
          if (m_Attribute.isNumeric())
            text.append(m_Attribute.name() + " >= " + m_SplitValue);
          else
            text.append(m_Attribute.name() + "!=" + m_SplitString);
        }
        text.append(m_Successors[j].toString(level + 1));
      }
    }
    return text.toString();
  }
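  // Output sketch (hypothetical attributes and counts) for a depth-2 tree,
  // following the format produced above:
  //
  //   petallength < 2.45: setosa(50.0/0.0)
  //   petallength >= 2.45
  //   | petalwidth < 1.75: versicolor(49.0/5.0)
  //   | petalwidth >= 1.75: virginica(45.0/1.0)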

  /**
   * Compute size of the tree.
   *
   * @return size of the tree
   */
  public int numNodes() {
    if (m_isLeaf) {
      return 1;
    } else {
      int size = 1;
      for (int i = 0; i < m_Successors.length; i++) {
        size += m_Successors[i].numNodes();
      }
      return size;
    }
  }

  /**
   * Method to count the number of inner nodes in the tree.
   *
   * @return the number of inner nodes
   */
  public int numInnerNodes() {
    if (m_Attribute == null) return 0;
    int numNodes = 1;
    for (int i = 0; i < m_Successors.length; i++)
      numNodes += m_Successors[i].numInnerNodes();
    return numNodes;
  }

  /**
   * Return a list of all inner nodes in the tree.
   *
   * @return the list of all inner nodes
   */
  protected Vector getInnerNodes() {
    Vector nodeList = new Vector();
    fillInnerNodes(nodeList);
    return nodeList;
  }

  /**
   * Fills a list with all inner nodes in the tree.
   *
   * @param nodeList the list to be filled
   */
  protected void fillInnerNodes(Vector nodeList) {
    if (!m_isLeaf) {
      nodeList.add(this);
      for (int i = 0; i < m_Successors.length; i++)
        m_Successors[i].fillInnerNodes(nodeList);
    }
  }

  /**
   * Compute number of leaf nodes.
   *
   * @return number of leaf nodes
   */
  public int numLeaves() {
    if (m_isLeaf) return 1;
    else {
      int size = 0;
      for (int i = 0; i < m_Successors.length; i++) {
        size += m_Successors[i].numLeaves();
      }
      return size;
    }
  }
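  // Counting sketch: a tree consisting of one root split and two leaf
  // children has numNodes() = 3, numLeaves() = 2 and numInnerNodes() = 1;
  // since every inner node of this tree is binary, the counts always satisfy
  // numNodes() = 2 * numLeaves() - 1.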

  /**
   * Returns an enumeration describing the available options.
   *
   * @return an enumeration of all the available options.
   */
  public Enumeration listOptions() {
    Vector result;
    Enumeration en;

    result = new Vector();

    en = super.listOptions();
    while (en.hasMoreElements())
      result.addElement(en.nextElement());

    result.addElement(new Option(
        "\tThe minimal number of instances at the terminal nodes.\n"
        + "\t(default 2)",
        "M", 1, "-M <min no>"));

    result.addElement(new Option(
        "\tThe number of folds used in the minimal cost-complexity pruning.\n"
        + "\t(default 5)",
        "N", 1, "-N <num folds>"));

    result.addElement(new Option(
        "\tDon't use the minimal cost-complexity pruning.\n"
        + "\t(default: pruning is used).",
        "U", 0, "-U"));

    result.addElement(new Option(
        "\tDon't use the heuristic method for binary split.\n"
        + "\t(default: the heuristic is used).",
        "H", 0, "-H"));

    result.addElement(new Option(
        "\tUse the 1 SE rule to make the pruning decision.\n"
        + "\t(default no).",
        "A", 0, "-A"));

    result.addElement(new Option(
        "\tFraction of the training set size to use, in (0, 1].\n"
        + "\t(default 1).",
        "C", 1, "-C"));

    return result.elements();
  }

  /**
   * Parses a given list of options. <p/>
   *
   <!-- options-start -->
   * Valid options are: <p/>
   *
   * <pre> -S <num>
   * Random number seed.
   * (default 1)</pre>
   *
   * <pre> -D
   * If set, classifier is run in debug mode and
   * may output additional info to the console</pre>
   *
   * <pre> -M <min no>
   * The minimal number of instances at the terminal nodes.
   * (default 2)</pre>
   *
   * <pre> -N <num folds>
   * The number of folds used in the minimal cost-complexity pruning.
   * (default 5)</pre>
   *
   * <pre> -U
   * Don't use the minimal cost-complexity pruning.
   * (default: pruning is used).</pre>
   *
   * <pre> -H
   * Don't use the heuristic method for binary split.
   * (default: the heuristic is used).</pre>
   *
   * <pre> -A
   * Use the 1 SE rule to make the pruning decision.
   * (default no).</pre>
   *
   * <pre> -C
   * Fraction of the training set size to use, in (0, 1].
   * (default 1).</pre>
   *
   <!-- options-end -->
   *
   * @param options the list of options as an array of strings
   * @throws Exception if an option is not supported
   */
  public void setOptions(String[] options) throws Exception {
    String tmpStr;

    super.setOptions(options);

    tmpStr = Utils.getOption('M', options);
    if (tmpStr.length() != 0)
      setMinNumObj(Double.parseDouble(tmpStr));
    else
      setMinNumObj(2);

    tmpStr = Utils.getOption('N', options);
    if (tmpStr.length() != 0)
      setNumFoldsPruning(Integer.parseInt(tmpStr));
    else
      setNumFoldsPruning(5);

    setUsePrune(!Utils.getFlag('U', options));
    setHeuristic(!Utils.getFlag('H', options));
    setUseOneSE(Utils.getFlag('A', options));

    tmpStr = Utils.getOption('C', options);
    if (tmpStr.length() != 0)
      setSizePer(Double.parseDouble(tmpStr));
    else
      setSizePer(1);

    Utils.checkForRemainingOptions(options);
  }
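  // Option-parsing sketch (hypothetical values):
  //   setOptions(new String[]{"-M", "3", "-N", "10", "-A"});
  // sets the minimal number of instances per terminal node to 3, uses 10
  // cross-validation folds for pruning, enables the 1 SE rule, and leaves
  // pruning and the heuristic split search at their defaults (both on).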

  /**
   * Gets the current settings of the classifier.
   *
   * @return the current settings of the classifier
   */
  public String[] getOptions() {
    int i;
    Vector result;
    String[] options;

    result = new Vector();

    options = super.getOptions();
    for (i = 0; i < options.length; i++)
      result.add(options[i]);

    result.add("-M");
    result.add("" + getMinNumObj());

    result.add("-N");
    result.add("" + getNumFoldsPruning());

    if (!getUsePrune())
      result.add("-U");

    if (!getHeuristic())
      result.add("-H");

    if (getUseOneSE())
      result.add("-A");

    result.add("-C");
    result.add("" + getSizePer());

    return (String[]) result.toArray(new String[result.size()]);
  }
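  // Round-trip sketch: with all defaults, and assuming the superclass
  // contributes the random-seed option, this returns something like
  //   {"-S", "1", "-M", "2.0", "-N", "5", "-C", "1.0"},
  // which setOptions(...) above accepts unchanged.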

  /**
   * Return an enumeration of the measure names.
   *
   * @return an enumeration of the measure names
   */
  public Enumeration enumerateMeasures() {
    Vector result = new Vector();

    result.addElement("measureTreeSize");

    return result.elements();
  }

  /**
   * Return the size of the tree, i.e. its number of nodes.
   *
   * @return the size of the tree
   */
  public double measureTreeSize() {
    return numNodes();
  }

  /**
   * Returns the value of the named measure.
   *
   * @param additionalMeasureName the name of the measure to query for its value
   * @return the value of the named measure
   * @throws IllegalArgumentException if the named measure is not supported
   */
  public double getMeasure(String additionalMeasureName) {
    if (additionalMeasureName.compareToIgnoreCase("measureTreeSize") == 0) {
      return measureTreeSize();
    } else {
      throw new IllegalArgumentException(additionalMeasureName
          + " not supported (Cart pruning)");
    }
  }

  /**
   * Returns the tip text for this property.
   *
   * @return tip text for this property suitable for
   * displaying in the explorer/experimenter gui
   */
  public String minNumObjTipText() {
    return "The minimal number of observations at the terminal nodes (default 2).";
  }

  /**
   * Set the minimal number of instances at the terminal nodes.
   *
   * @param value minimal number of instances at the terminal nodes
   */
  public void setMinNumObj(double value) {
    m_minNumObj = value;
  }

  /**
   * Get the minimal number of instances at the terminal nodes.
   *
   * @return minimal number of instances at the terminal nodes
   */
  public double getMinNumObj() {
    return m_minNumObj;
  }

  /**
   * Returns the tip text for this property.
   *
   * @return tip text for this property suitable for
   * displaying in the explorer/experimenter gui
   */
  public String numFoldsPruningTipText() {
    return "The number of folds in the internal cross-validation (default 5).";
  }

  /**
   * Set the number of folds in the internal cross-validation.
   *
   * @param value number of folds in the internal cross-validation
   */
  public void setNumFoldsPruning(int value) {
    m_numFoldsPruning = value;
  }

  /**
   * Get the number of folds in the internal cross-validation.
   *
   * @return number of folds in the internal cross-validation
   */
  public int getNumFoldsPruning() {
    return m_numFoldsPruning;
  }

  /**
   * Returns the tip text for this property.
   *
   * @return tip text for this property suitable for displaying in
   * the explorer/experimenter gui.
   */
  public String usePruneTipText() {
    return "Use minimal cost-complexity pruning (default yes).";
  }

  /**
   * Set whether to use minimal cost-complexity pruning.
   *
   * @param value whether to use minimal cost-complexity pruning
   */
  public void setUsePrune(boolean value) {
    m_Prune = value;
  }

  /**
   * Get whether minimal cost-complexity pruning is used.
   *
   * @return whether minimal cost-complexity pruning is used
   */
  public boolean getUsePrune() {
    return m_Prune;
  }

  /**
   * Returns the tip text for this property.
   *
   * @return tip text for this property suitable for
   * displaying in the explorer/experimenter gui.
   */
  public String heuristicTipText() {
    return
      "Whether heuristic search is used for binary splits on nominal attributes "
      + "in multi-class problems (default yes).";
  }

  /**
   * Set whether to use heuristic search for nominal attributes in multi-class problems.
   *
   * @param value whether to use heuristic search for nominal attributes in
   * multi-class problems
   */
  public void setHeuristic(boolean value) {
    m_Heuristic = value;
  }

  /**
   * Get whether heuristic search is used for nominal attributes in multi-class problems.
   *
   * @return whether heuristic search is used for nominal attributes in
   * multi-class problems
   */
  public boolean getHeuristic() {
    return m_Heuristic;
  }

  /**
   * Returns the tip text for this property.
   *
   * @return tip text for this property suitable for
   * displaying in the explorer/experimenter gui.
   */
  public String useOneSETipText() {
    return "Use the 1 SE rule to make the pruning decision.";
  }

  /**
   * Set whether to use the 1 SE rule to choose the final model.
   *
   * @param value whether to use the 1 SE rule to choose the final model
   */
  public void setUseOneSE(boolean value) {
    m_UseOneSE = value;
  }

  /**
   * Get whether the 1 SE rule is used to choose the final model.
   *
   * @return whether the 1 SE rule is used to choose the final model
   */
  public boolean getUseOneSE() {
    return m_UseOneSE;
  }

  /**
   * Returns the tip text for this property.
   *
   * @return tip text for this property suitable for
   * displaying in the explorer/experimenter gui.
   */
  public String sizePerTipText() {
    return "The fraction of the training set size to use (0-1, 0 not included).";
  }

  /**
   * Set the fraction of the training set size to use.
   *
   * @param value fraction of the training set size, in (0, 1]
   */
  public void setSizePer(double value) {
    if ((value <= 0) || (value > 1))
      System.err.println(
          "The fraction of the training set size must be in the range 0 to 1 "
          + "(0 not included) - ignored!");
    else
      m_SizePer = value;
  }

  /**
   * Get the fraction of the training set size to use.
   *
   * @return fraction of the training set size, in (0, 1]
   */
  public double getSizePer() {
    return m_SizePer;
  }

  /**
   * Returns the revision string.
   *
   * @return the revision
   */
  public String getRevision() {
    return RevisionUtils.extract("$Revision: 5987 $");
  }

  /**
   * Main method.
   *
   * @param args the options for the classifier
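   * <p>
   * For example (hypothetical dataset path), the classifier can be run from
   * the command line as:
   * <pre> java weka.classifiers.trees.SimpleCart -t data/train.arff -M 2 -N 5</pre>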
   */
  public static void main(String[] args) {
    runClassifier(new SimpleCart(), args);
  }
}