1 | /* |
---|
2 | * This program is free software; you can redistribute it and/or modify |
---|
3 | * it under the terms of the GNU General Public License as published by |
---|
4 | * the Free Software Foundation; either version 2 of the License, or |
---|
5 | * (at your option) any later version. |
---|
6 | * |
---|
7 | * This program is distributed in the hope that it will be useful, |
---|
8 | * but WITHOUT ANY WARRANTY; without even the implied warranty of |
---|
9 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the |
---|
10 | * GNU General Public License for more details. |
---|
11 | * |
---|
12 | * You should have received a copy of the GNU General Public License |
---|
13 | * along with this program; if not, write to the Free Software |
---|
14 | * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA. |
---|
15 | */ |
---|
16 | |
---|
17 | /* |
---|
18 | * C45PruneableClassifierTreeG.java |
---|
19 | * Copyright (C) 1999 University of Waikato, Hamilton, New Zealand |
---|
20 | * Copyright (C) 2007 Geoff Webb & Janice Boughton |
---|
21 | * |
---|
22 | */ |
---|
23 | |
---|
24 | package weka.classifiers.trees.j48; |
---|
25 | |
---|
26 | import weka.core.Capabilities; |
---|
27 | import weka.core.Instances; |
---|
28 | import weka.core.Instance; |
---|
29 | import weka.core.RevisionUtils; |
---|
30 | import weka.core.Utils; |
---|
31 | import weka.core.Capabilities.Capability; |
---|
32 | import java.util.ArrayList; |
---|
33 | import java.util.Collections; |
---|
34 | |
---|
35 | /** |
---|
36 | * Class for handling a tree structure that can |
---|
37 | * be pruned using C4.5 procedures and have nodes grafted on. |
---|
38 | * |
---|
39 | * @author Janice Boughton (based on code by Eibe Frank) |
---|
40 | * @version $Revision: 5532 $ |
---|
41 | */ |
---|
42 | |
---|
43 | public class C45PruneableClassifierTreeG extends ClassifierTree{ |
---|
44 | |
---|
45 | /** for serialization */ |
---|
46 | static final long serialVersionUID = 66981207374331964L; |
---|
47 | |
---|
48 | /** True if the tree is to be pruned. */ |
---|
49 | boolean m_pruneTheTree = false; |
---|
50 | |
---|
51 | /** The confidence factor for pruning. */ |
---|
52 | float m_CF = 0.25f; |
---|
53 | |
---|
54 | /** Is subtree raising to be performed? */ |
---|
55 | boolean m_subtreeRaising = true; |
---|
56 | |
---|
57 | /** Cleanup after the tree has been built. */ |
---|
58 | boolean m_cleanup = true; |
---|
59 | |
---|
60 | /** flag for using relabelling when grafting */ |
---|
61 | boolean m_relabel = false; |
---|
62 | |
---|
63 | /** binomial probability critical value */ |
---|
64 | double m_BiProbCrit = 1.64; |
---|
65 | |
---|
66 | boolean m_Debug = false; |
---|
67 | |
---|
/**
 * Creates a node of a pruneable, graftable C4.5 tree. Only the options are
 * recorded here; the tree itself is grown later (e.g. by buildClassifier).
 *
 * @param toSelectLocModel selection method for the local splitting model
 * @param pruneTree true if the tree is to be pruned
 * @param cf the confidence factor used for pruning
 * @param raiseTree true if subtree raising is to be considered when pruning
 * @param relabel true if relabelling is to be used when grafting
 * @param cleanup true if stored data is to be cleaned up after building
 * @throws Exception if something goes wrong
 */
public C45PruneableClassifierTreeG(ModelSelection toSelectLocModel,
                                   boolean pruneTree, float cf,
                                   boolean raiseTree,
                                   boolean relabel, boolean cleanup)
    throws Exception {
  super(toSelectLocModel);

  // grafting-related options
  m_relabel = relabel;
  m_cleanup = cleanup;

  // pruning-related options
  m_subtreeRaising = raiseTree;
  m_CF = cf;
  m_pruneTheTree = pruneTree;
}
---|
93 | |
---|
94 | |
---|
95 | /** |
---|
96 | * Returns default capabilities of the classifier tree. |
---|
97 | * |
---|
98 | * @return the capabilities of this classifier tree |
---|
99 | */ |
---|
100 | public Capabilities getCapabilities() { |
---|
101 | Capabilities result = super.getCapabilities(); |
---|
102 | result.disableAll(); |
---|
103 | |
---|
104 | // attributes |
---|
105 | result.enable(Capability.NOMINAL_ATTRIBUTES); |
---|
106 | result.enable(Capability.NUMERIC_ATTRIBUTES); |
---|
107 | result.enable(Capability.MISSING_VALUES); |
---|
108 | |
---|
109 | // class |
---|
110 | result.enable(Capability.NOMINAL_CLASS); |
---|
111 | result.enable(Capability.MISSING_CLASS_VALUES); |
---|
112 | |
---|
113 | // instances |
---|
114 | result.setMinimumNumberInstances(0); |
---|
115 | |
---|
116 | return result; |
---|
117 | } |
---|
118 | |
---|
/**
 * Creates a node of a pruneable, graftable tree around an existing split
 * model. Used to create new nodes in the tree during grafting.
 *
 * @param toSelectLocModel selection method for the local splitting model
 * @param data the data used to produce the split model (stored as this node's
 *        training data)
 * @param gs the split model for this node; must not be null, since its
 *        distribution decides whether the node is empty
 * @param prune true if the tree is to be pruned
 * @param cf the confidence factor for pruning
 * @param raise true if subtree raising is to be considered when pruning
 * @param isLeaf true if this node is a leaf
 * @param relabel true if relabelling is to be used when grafting
 * @param cleanup true if stored data is to be cleaned up after building
 */
public C45PruneableClassifierTreeG(ModelSelection toSelectLocModel,
                                   Instances data, ClassifierSplitModel gs,
                                   boolean prune, float cf, boolean raise,
                                   boolean isLeaf, boolean relabel,
                                   boolean cleanup) {

  super(toSelectLocModel);
  m_relabel = relabel;
  m_cleanup = cleanup;
  m_localModel = gs;
  m_train = data;
  m_test = null;
  m_isLeaf = isLeaf;
  // the node is empty unless its split model saw a positive total weight
  m_isEmpty = !(gs.distribution().total() > 0);

  m_pruneTheTree = prune;
  m_CF = cf;
  m_subtreeRaising = raise;
}
---|
156 | |
---|
157 | /** |
---|
158 | * Method for building a pruneable classifier tree. |
---|
159 | * |
---|
160 | * @param data the data for building the tree |
---|
161 | * @throws Exception if something goes wrong |
---|
162 | */ |
---|
163 | public void buildClassifier(Instances data) throws Exception { |
---|
164 | |
---|
165 | // can classifier tree handle the data? |
---|
166 | getCapabilities().testWithFail(data); |
---|
167 | |
---|
168 | // remove instances with missing class |
---|
169 | data = new Instances(data); |
---|
170 | data.deleteWithMissingClass(); |
---|
171 | |
---|
172 | buildTree(data, m_subtreeRaising); |
---|
173 | collapse(); |
---|
174 | if (m_pruneTheTree) { |
---|
175 | prune(); |
---|
176 | } |
---|
177 | doGrafting(data); |
---|
178 | if (m_cleanup) { |
---|
179 | cleanup(new Instances(data, 0)); |
---|
180 | } |
---|
181 | } |
---|
182 | |
---|
183 | |
---|
184 | /** |
---|
185 | * Collapses a tree to a node if training error doesn't increase. |
---|
186 | */ |
---|
187 | public final void collapse(){ |
---|
188 | |
---|
189 | double errorsOfSubtree; |
---|
190 | double errorsOfTree; |
---|
191 | int i; |
---|
192 | |
---|
193 | if (!m_isLeaf){ |
---|
194 | errorsOfSubtree = getTrainingErrors(); |
---|
195 | errorsOfTree = localModel().distribution().numIncorrect(); |
---|
196 | if (errorsOfSubtree >= errorsOfTree-1E-3){ |
---|
197 | |
---|
198 | // Free adjacent trees |
---|
199 | m_sons = null; |
---|
200 | m_isLeaf = true; |
---|
201 | |
---|
202 | // Get NoSplit Model for tree. |
---|
203 | m_localModel = new NoSplit(localModel().distribution()); |
---|
204 | }else |
---|
205 | for (i=0;i<m_sons.length;i++) |
---|
206 | son(i).collapse(); |
---|
207 | } |
---|
208 | } |
---|
209 | |
---|
210 | /** |
---|
211 | * Prunes a tree using C4.5's pruning procedure. |
---|
212 | * |
---|
213 | * @throws Exception if something goes wrong |
---|
214 | */ |
---|
215 | public void prune() throws Exception { |
---|
216 | |
---|
217 | double errorsLargestBranch; |
---|
218 | double errorsLeaf; |
---|
219 | double errorsTree; |
---|
220 | int indexOfLargestBranch; |
---|
221 | C45PruneableClassifierTreeG largestBranch; |
---|
222 | int i; |
---|
223 | |
---|
224 | if (!m_isLeaf){ |
---|
225 | |
---|
226 | // Prune all subtrees. |
---|
227 | for (i=0;i<m_sons.length;i++) |
---|
228 | son(i).prune(); |
---|
229 | |
---|
230 | // Compute error for largest branch |
---|
231 | indexOfLargestBranch = localModel().distribution().maxBag(); |
---|
232 | if (m_subtreeRaising) { |
---|
233 | errorsLargestBranch = son(indexOfLargestBranch). |
---|
234 | getEstimatedErrorsForBranch((Instances)m_train); |
---|
235 | } else { |
---|
236 | errorsLargestBranch = Double.MAX_VALUE; |
---|
237 | } |
---|
238 | |
---|
239 | // Compute error if this Tree would be leaf |
---|
240 | errorsLeaf = |
---|
241 | getEstimatedErrorsForDistribution(localModel().distribution()); |
---|
242 | |
---|
243 | // Compute error for the whole subtree |
---|
244 | errorsTree = getEstimatedErrors(); |
---|
245 | |
---|
246 | // Decide if leaf is best choice. |
---|
247 | if (Utils.smOrEq(errorsLeaf,errorsTree+0.1) && |
---|
248 | Utils.smOrEq(errorsLeaf,errorsLargestBranch+0.1)){ |
---|
249 | |
---|
250 | // Free son Trees |
---|
251 | m_sons = null; |
---|
252 | m_isLeaf = true; |
---|
253 | |
---|
254 | // Get NoSplit Model for node. |
---|
255 | m_localModel = new NoSplit(localModel().distribution()); |
---|
256 | return; |
---|
257 | } |
---|
258 | |
---|
259 | // Decide if largest branch is better choice |
---|
260 | // than whole subtree. |
---|
261 | if (Utils.smOrEq(errorsLargestBranch,errorsTree+0.1)){ |
---|
262 | largestBranch = son(indexOfLargestBranch); |
---|
263 | m_sons = largestBranch.m_sons; |
---|
264 | m_localModel = largestBranch.localModel(); |
---|
265 | m_isLeaf = largestBranch.m_isLeaf; |
---|
266 | newDistribution(m_train); |
---|
267 | prune(); |
---|
268 | } |
---|
269 | } |
---|
270 | } |
---|
271 | |
---|
272 | /** |
---|
273 | * Returns a newly created tree. |
---|
274 | * |
---|
275 | * @param data the data to work with |
---|
276 | * @return the new tree |
---|
277 | * @throws Exception if something goes wrong |
---|
278 | */ |
---|
279 | protected ClassifierTree getNewTree(Instances data) throws Exception { |
---|
280 | |
---|
281 | C45PruneableClassifierTreeG newTree = |
---|
282 | new C45PruneableClassifierTreeG(m_toSelectModel, m_pruneTheTree, m_CF, |
---|
283 | m_subtreeRaising, m_relabel, m_cleanup); |
---|
284 | // ATBOP Modification // m_subtreeRaising, m_cleanup); |
---|
285 | |
---|
286 | newTree.buildTree((Instances)data, m_subtreeRaising); |
---|
287 | |
---|
288 | return newTree; |
---|
289 | } |
---|
290 | |
---|
291 | /** |
---|
292 | * Computes estimated errors for tree. |
---|
293 | * |
---|
294 | * @return the estimated errors |
---|
295 | */ |
---|
296 | private double getEstimatedErrors(){ |
---|
297 | |
---|
298 | double errors = 0; |
---|
299 | int i; |
---|
300 | |
---|
301 | if (m_isLeaf) |
---|
302 | return getEstimatedErrorsForDistribution(localModel().distribution()); |
---|
303 | else{ |
---|
304 | for (i=0;i<m_sons.length;i++) |
---|
305 | errors = errors+son(i).getEstimatedErrors(); |
---|
306 | return errors; |
---|
307 | } |
---|
308 | } |
---|
309 | |
---|
310 | /** |
---|
311 | * Computes estimated errors for one branch. |
---|
312 | * |
---|
313 | * @param data the data to work with |
---|
314 | * @return the estimated errors |
---|
315 | * @throws Exception if something goes wrong |
---|
316 | */ |
---|
317 | private double getEstimatedErrorsForBranch(Instances data) |
---|
318 | throws Exception { |
---|
319 | |
---|
320 | Instances [] localInstances; |
---|
321 | double errors = 0; |
---|
322 | int i; |
---|
323 | |
---|
324 | if (m_isLeaf) |
---|
325 | return getEstimatedErrorsForDistribution(new Distribution(data)); |
---|
326 | else{ |
---|
327 | Distribution savedDist = localModel().m_distribution; |
---|
328 | localModel().resetDistribution(data); |
---|
329 | localInstances = (Instances[])localModel().split(data); |
---|
330 | localModel().m_distribution = savedDist; |
---|
331 | for (i=0;i<m_sons.length;i++) |
---|
332 | errors = errors+ |
---|
333 | son(i).getEstimatedErrorsForBranch(localInstances[i]); |
---|
334 | return errors; |
---|
335 | } |
---|
336 | } |
---|
337 | |
---|
338 | /** |
---|
339 | * Computes estimated errors for leaf. |
---|
340 | * |
---|
341 | * @param theDistribution the distribution to use |
---|
342 | * @return the estimated errors |
---|
343 | */ |
---|
344 | private double getEstimatedErrorsForDistribution(Distribution |
---|
345 | theDistribution){ |
---|
346 | |
---|
347 | if (Utils.eq(theDistribution.total(),0)) |
---|
348 | return 0; |
---|
349 | else |
---|
350 | return theDistribution.numIncorrect()+ |
---|
351 | Stats.addErrs(theDistribution.total(), |
---|
352 | theDistribution.numIncorrect(),m_CF); |
---|
353 | } |
---|
354 | |
---|
355 | /** |
---|
356 | * Computes errors of tree on training data. |
---|
357 | * |
---|
358 | * @return the training errors |
---|
359 | */ |
---|
360 | private double getTrainingErrors(){ |
---|
361 | |
---|
362 | double errors = 0; |
---|
363 | int i; |
---|
364 | |
---|
365 | if (m_isLeaf) |
---|
366 | return localModel().distribution().numIncorrect(); |
---|
367 | else{ |
---|
368 | for (i=0;i<m_sons.length;i++) |
---|
369 | errors = errors+son(i).getTrainingErrors(); |
---|
370 | return errors; |
---|
371 | } |
---|
372 | } |
---|
373 | |
---|
374 | /** |
---|
375 | * Method just exists to make program easier to read. |
---|
376 | * |
---|
377 | * @return the local split model |
---|
378 | */ |
---|
379 | private ClassifierSplitModel localModel(){ |
---|
380 | |
---|
381 | return (ClassifierSplitModel)m_localModel; |
---|
382 | } |
---|
383 | |
---|
384 | /** |
---|
385 | * Computes new distributions of instances for nodes |
---|
386 | * in tree. |
---|
387 | * |
---|
388 | * @param data the data to compute the distributions for |
---|
389 | * @throws Exception if something goes wrong |
---|
390 | */ |
---|
391 | private void newDistribution(Instances data) throws Exception { |
---|
392 | |
---|
393 | Instances [] localInstances; |
---|
394 | |
---|
395 | localModel().resetDistribution(data); |
---|
396 | m_train = data; |
---|
397 | if (!m_isLeaf){ |
---|
398 | localInstances = |
---|
399 | (Instances [])localModel().split(data); |
---|
400 | for (int i = 0; i < m_sons.length; i++) |
---|
401 | son(i).newDistribution(localInstances[i]); |
---|
402 | } else { |
---|
403 | |
---|
404 | // Check whether there are some instances at the leaf now! |
---|
405 | if (!Utils.eq(data.sumOfWeights(), 0)) { |
---|
406 | m_isEmpty = false; |
---|
407 | } |
---|
408 | } |
---|
409 | } |
---|
410 | |
---|
411 | /** |
---|
412 | * Method just exists to make program easier to read. |
---|
413 | */ |
---|
414 | private C45PruneableClassifierTreeG son(int index){ |
---|
415 | return (C45PruneableClassifierTreeG)m_sons[index]; |
---|
416 | } |
---|
417 | |
---|
418 | |
---|
419 | /** |
---|
420 | * Initializes variables for grafting. |
---|
421 | * sets up limits array (for numeric attributes) and calls |
---|
422 | * the recursive function traverseTree. |
---|
423 | * |
---|
424 | * @param data the data for the tree |
---|
425 | * @throws Exception if anything goes wrong |
---|
426 | */ |
---|
427 | public void doGrafting(Instances data) throws Exception { |
---|
428 | |
---|
429 | // 2d array for the limits |
---|
430 | double [][] limits = new double[data.numAttributes()][2]; |
---|
431 | // 2nd dimension: index 0 == lower limit, index 1 == upper limit |
---|
432 | // initialise to no limit |
---|
433 | for(int i = 0; i < data.numAttributes(); i++) { |
---|
434 | limits[i][0] = Double.NEGATIVE_INFINITY; |
---|
435 | limits[i][1] = Double.POSITIVE_INFINITY; |
---|
436 | } |
---|
437 | |
---|
438 | // use an index instead of creating new Insances objects all the time |
---|
439 | // instanceIndex[0] == array for weights at leaf |
---|
440 | // instanceIndex[1] == array for weights in atbop |
---|
441 | double [][] instanceIndex = new double[2][data.numInstances()]; |
---|
442 | // initialize the weight for each instance |
---|
443 | for(int x = 0; x < data.numInstances(); x++) { |
---|
444 | instanceIndex[0][x] = 1; |
---|
445 | instanceIndex[1][x] = 1; // leaf instances are in atbop |
---|
446 | } |
---|
447 | |
---|
448 | // first call to graft |
---|
449 | traverseTree(data, instanceIndex, limits, this, 0, -1); |
---|
450 | } |
---|
451 | |
---|
452 | |
---|
453 | /** |
---|
454 | * recursive function. |
---|
455 | * if this node is a leaf then calls findGraft, otherwise sorts |
---|
456 | * the two sets of instances (tracked in iindex array) and calls |
---|
457 | * sortInstances for each of the child nodes (which then calls |
---|
458 | * this method). |
---|
459 | * |
---|
460 | * @param fulldata all instances |
---|
461 | * @param iindex array the tracks the weight of each instance in |
---|
462 | * the atbop and at the leaf (0.0 if not present) |
---|
463 | * @param limits array specifying current upper/lower limits for numeric atts |
---|
464 | * @param parent the node immediately before the current one |
---|
465 | * @param pL laplace for node, as calculated by parent (in case leaf is empty) |
---|
466 | * @param nodeClass class of node, determined by parent (in case leaf empty) |
---|
467 | */ |
---|
468 | private void traverseTree(Instances fulldata, double [][] iindex, |
---|
469 | double[][] limits, C45PruneableClassifierTreeG parent, |
---|
470 | double pL, int nodeClass) throws Exception { |
---|
471 | |
---|
472 | if(m_isLeaf) { |
---|
473 | |
---|
474 | findGraft(fulldata, iindex, limits, |
---|
475 | (ClassifierTree)parent, pL, nodeClass); |
---|
476 | |
---|
477 | } else { |
---|
478 | |
---|
479 | // traverse each branch |
---|
480 | for(int i = 0; i < localModel().numSubsets(); i++) { |
---|
481 | |
---|
482 | double [][] newiindex = new double[2][fulldata.numInstances()]; |
---|
483 | for(int x = 0; x < 2; x++) |
---|
484 | System.arraycopy(iindex[x], 0, newiindex[x], 0, iindex[x].length); |
---|
485 | sortInstances(fulldata, newiindex, limits, i); |
---|
486 | } |
---|
487 | } |
---|
488 | } |
---|
489 | |
---|
490 | /** |
---|
491 | * sorts/deletes instances into/from node and atbop according to |
---|
492 | * the test for subset, then calls traverseTree for subset's node. |
---|
493 | * |
---|
494 | * @param fulldata all instances |
---|
495 | * @param iindex array the tracks the weight of each instance in |
---|
496 | * the atbop and at the leaf (0.0 if not present) |
---|
497 | * @param limits array specifying current upper/lower limits for numeric atts |
---|
498 | * @param subset the subset for which to sort instances into inode & iatbop |
---|
499 | */ |
---|
500 | private void sortInstances(Instances fulldata, double [][] iindex, |
---|
501 | double [][] limits, int subset) throws Exception { |
---|
502 | |
---|
503 | C45Split test = (C45Split)localModel(); |
---|
504 | |
---|
505 | // update the instances index for subset |
---|
506 | double knownCases = 0; |
---|
507 | double thisSubsetCount = 0; |
---|
508 | for(int x = 0; x < iindex[0].length; x++) { |
---|
509 | if(iindex[0][x] == 0 && iindex[1][x] == 0) // skip "discarded" instances |
---|
510 | continue; |
---|
511 | if(!fulldata.instance(x).isMissing(test.attIndex())) { |
---|
512 | knownCases += iindex[0][x]; |
---|
513 | if(test.whichSubset(fulldata.instance(x)) != subset) { |
---|
514 | if(iindex[0][x] > 0) { |
---|
515 | // move to atbop, delete from leaf |
---|
516 | iindex[1][x] = iindex[0][x]; |
---|
517 | iindex[0][x] = 0; |
---|
518 | } else { |
---|
519 | if(iindex[1][x] > 0) { |
---|
520 | // instance is now "discarded" |
---|
521 | iindex[1][x] = 0; |
---|
522 | } |
---|
523 | } |
---|
524 | } else { |
---|
525 | thisSubsetCount += iindex[0][x]; |
---|
526 | } |
---|
527 | } |
---|
528 | } |
---|
529 | |
---|
530 | // work out proportions of weight for missing values for leaf and atbop |
---|
531 | double lprop = (knownCases == 0) ? (1.0 / (double)test.numSubsets()) |
---|
532 | : (thisSubsetCount / (double)knownCases); |
---|
533 | |
---|
534 | // add in the instances that have missing value for attIndex |
---|
535 | for(int x = 0; x < iindex[0].length; x++) { |
---|
536 | if(iindex[0][x] == 0 && iindex[1][x] == 0) |
---|
537 | continue; // skip "discarded" instances |
---|
538 | if(fulldata.instance(x).isMissing(test.attIndex())) { |
---|
539 | iindex[1][x] -= (iindex[1][x] - iindex[0][x]) * (1-lprop); |
---|
540 | iindex[0][x] *= lprop; |
---|
541 | } |
---|
542 | } |
---|
543 | |
---|
544 | int nodeClass = localModel().distribution().maxClass(subset); |
---|
545 | double pL = (localModel().distribution().perClass(nodeClass) + 1.0) |
---|
546 | / (localModel().distribution().total() + 2.0); |
---|
547 | |
---|
548 | // call traerseTree method for the child node |
---|
549 | son(subset).traverseTree(fulldata, iindex, |
---|
550 | test.minsAndMaxs(fulldata, limits, subset), this, pL, nodeClass); |
---|
551 | } |
---|
552 | |
---|
553 | /** |
---|
554 | * finds new nodes that improve accuracy and grafts them onto the tree |
---|
555 | * |
---|
556 | * @param fulldata the instances in whole trainset |
---|
557 | * @param iindex records num tests each instance has failed up to this node |
---|
558 | * @param limits the upper/lower limits for numeric attributes |
---|
559 | * @param parent the node immediately before the current one |
---|
560 | * @param pLaplace laplace for leaf, calculated by parent (in case leaf empty) |
---|
561 | * @param pLeafClass class of leaf, determined by parent (in case leaf empty) |
---|
562 | */ |
---|
563 | private void findGraft(Instances fulldata, double [][] iindex, |
---|
564 | double [][] limits, ClassifierTree parent, double pLaplace, |
---|
565 | int pLeafClass) throws Exception { |
---|
566 | |
---|
567 | // get the class for this leaf |
---|
568 | int leafClass = (m_isEmpty) |
---|
569 | ? pLeafClass |
---|
570 | : localModel().distribution().maxClass(); |
---|
571 | |
---|
572 | // get the laplace value for this leaf |
---|
573 | double leafLaplace = (m_isEmpty) |
---|
574 | ? pLaplace |
---|
575 | : laplaceLeaf(leafClass); |
---|
576 | |
---|
577 | // sort the instances into those at the leaf, those in atbop, and discarded |
---|
578 | Instances l = new Instances(fulldata, fulldata.numInstances()); |
---|
579 | Instances n = new Instances(fulldata, fulldata.numInstances()); |
---|
580 | int lcount = 0; |
---|
581 | int acount = 0; |
---|
582 | for(int x = 0; x < fulldata.numInstances(); x++) { |
---|
583 | if(iindex[0][x] <= 0 && iindex[1][x] <= 0) |
---|
584 | continue; |
---|
585 | if(iindex[0][x] != 0) { |
---|
586 | l.add(fulldata.instance(x)); |
---|
587 | l.instance(lcount).setWeight(iindex[0][x]); |
---|
588 | // move instance's weight in iindex to same index as in l |
---|
589 | iindex[0][lcount++] = iindex[0][x]; |
---|
590 | } |
---|
591 | if(iindex[1][x] > 0) { |
---|
592 | n.add(fulldata.instance(x)); |
---|
593 | n.instance(acount).setWeight(iindex[1][x]); |
---|
594 | // move instance's weight in iindex to same index as in n |
---|
595 | iindex[1][acount++] = iindex[1][x]; |
---|
596 | } |
---|
597 | } |
---|
598 | |
---|
599 | boolean graftPossible = false; |
---|
600 | double [] classDist = new double[n.numClasses()]; |
---|
601 | for(int x = 0; x < n.numInstances(); x++) { |
---|
602 | if(iindex[1][x] > 0 && !n.instance(x).classIsMissing()) |
---|
603 | classDist[(int)n.instance(x).classValue()] += iindex[1][x]; |
---|
604 | } |
---|
605 | |
---|
606 | for(int cVal = 0; cVal < n.numClasses(); cVal++) { |
---|
607 | double theLaplace = (classDist[cVal] + 1.0) / (classDist[cVal] + 2.0); |
---|
608 | if(cVal != leafClass && (theLaplace > leafLaplace) && |
---|
609 | (biprob(classDist[cVal], classDist[cVal], leafLaplace) |
---|
610 | > m_BiProbCrit)) { |
---|
611 | graftPossible = true; |
---|
612 | break; |
---|
613 | } |
---|
614 | } |
---|
615 | |
---|
616 | if(!graftPossible) { |
---|
617 | return; |
---|
618 | } |
---|
619 | |
---|
620 | // 1. Initialize to {} a set of tuples t containing potential tests |
---|
621 | ArrayList t = new ArrayList(); |
---|
622 | |
---|
623 | // go through each attribute |
---|
624 | for(int a = 0; a < n.numAttributes(); a++) { |
---|
625 | if(a == n.classIndex()) |
---|
626 | continue; // skip the class |
---|
627 | |
---|
628 | // sort instances in atbop by $a |
---|
629 | int [] sorted = sortByAttribute(n, a); |
---|
630 | |
---|
631 | // 2. For each continuous attribute $a: |
---|
632 | if(n.attribute(a).isNumeric()) { |
---|
633 | |
---|
634 | // find min and max values for this attribute at the leaf |
---|
635 | boolean prohibited = false; |
---|
636 | double minLeaf = Double.POSITIVE_INFINITY; |
---|
637 | double maxLeaf = Double.NEGATIVE_INFINITY; |
---|
638 | for(int i = 0; i < l.numInstances(); i++) { |
---|
639 | if(l.instance(i).isMissing(a)) { |
---|
640 | if(l.instance(i).classValue() == leafClass) { |
---|
641 | prohibited = true; |
---|
642 | break; |
---|
643 | } |
---|
644 | } |
---|
645 | double value = l.instance(i).value(a); |
---|
646 | if(!m_relabel || l.instance(i).classValue() == leafClass) { |
---|
647 | if(value < minLeaf) |
---|
648 | minLeaf = value; |
---|
649 | if(value > maxLeaf) |
---|
650 | maxLeaf = value; |
---|
651 | } |
---|
652 | } |
---|
653 | if(prohibited) { |
---|
654 | continue; |
---|
655 | } |
---|
656 | |
---|
657 | // (a) find values of |
---|
658 | // $n: instances in atbop (already have that, actually) |
---|
659 | // $v: a value for $a that exists for a case in the atbop, where |
---|
660 | // $v is < the min value for $a for a case at the leaf which |
---|
661 | // has the class $c, and $v is > the lowerlimit of $a at |
---|
662 | // the leaf. |
---|
663 | // (note: error in original paper stated that $v must be |
---|
664 | // smaller OR EQUAL TO the min value). |
---|
665 | // $k: $k is a class |
---|
666 | // that maximize L' = Laplace({$x: $x contained in cases($n) |
---|
667 | // & value($a,$x) <= $v & value($a,$x) > lowerlim($l,$a)}, $k). |
---|
668 | double minBestClass = Double.NaN; |
---|
669 | double minBestLaplace = leafLaplace; |
---|
670 | double minBestVal = Double.NaN; |
---|
671 | double minBestPos = Double.NaN; |
---|
672 | double minBestTotal = Double.NaN; |
---|
673 | double [][] minBestCounts = null; |
---|
674 | double [][] counts = new double[2][n.numClasses()]; |
---|
675 | for(int x = 0; x < n.numInstances(); x++) { |
---|
676 | if(n.instance(sorted[x]).isMissing(a)) |
---|
677 | break; // missing are sorted to end: no more valid vals |
---|
678 | |
---|
679 | double theval = n.instance(sorted[x]).value(a); |
---|
680 | if(m_Debug) |
---|
681 | System.out.println("\t " + theval); |
---|
682 | |
---|
683 | if(theval <= limits[a][0]) { |
---|
684 | if(m_Debug) |
---|
685 | System.out.println("\t <= lowerlim: continuing..."); |
---|
686 | continue; |
---|
687 | } |
---|
688 | // note: error in paper would have this read "theVal > minLeaf) |
---|
689 | if(theval >= minLeaf) { |
---|
690 | if(m_Debug) |
---|
691 | System.out.println("\t >= minLeaf; breaking..."); |
---|
692 | break; |
---|
693 | } |
---|
694 | counts[0][(int)n.instance(sorted[x]).classValue()] |
---|
695 | += iindex[1][sorted[x]]; |
---|
696 | |
---|
697 | if(x != n.numInstances() - 1) { |
---|
698 | int z = x + 1; |
---|
699 | while(z < n.numInstances() |
---|
700 | && n.instance(sorted[z]).value(a) == theval) { |
---|
701 | z++; x++; |
---|
702 | counts[0][(int)n.instance(sorted[x]).classValue()] |
---|
703 | += iindex[1][sorted[x]]; |
---|
704 | } |
---|
705 | } |
---|
706 | |
---|
707 | // work out the best laplace/class (for <= theval) |
---|
708 | double total = Utils.sum(counts[0]); |
---|
709 | for(int c = 0; c < n.numClasses(); c++) { |
---|
710 | double temp = (counts[0][c]+1.0)/(total+2.0); |
---|
711 | if(temp > minBestLaplace) { |
---|
712 | minBestPos = counts[0][c]; |
---|
713 | minBestTotal = total; |
---|
714 | minBestLaplace = temp; |
---|
715 | minBestClass = c; |
---|
716 | minBestCounts = copyCounts(counts); |
---|
717 | |
---|
718 | minBestVal = (x == n.numInstances()-1) |
---|
719 | ? theval |
---|
720 | : ((theval + n.instance(sorted[x+1]).value(a)) / 2.0); |
---|
721 | } |
---|
722 | } |
---|
723 | } |
---|
724 | |
---|
725 | // (b) add to t tuple <n,a,v,k,L',"<="> |
---|
726 | if(!Double.isNaN(minBestVal) |
---|
727 | && biprob(minBestPos, minBestTotal, leafLaplace) > m_BiProbCrit) { |
---|
728 | GraftSplit gsplit = null; |
---|
729 | try { |
---|
730 | gsplit = new GraftSplit(a, minBestVal, 0, |
---|
731 | leafClass, minBestCounts); |
---|
732 | } catch (Exception e) { |
---|
733 | System.err.println("graftsplit error: "+e.getMessage()); |
---|
734 | System.exit(1); |
---|
735 | } |
---|
736 | t.add(gsplit); |
---|
737 | } |
---|
738 | // free space |
---|
739 | minBestCounts = null; |
---|
740 | |
---|
741 | // (c) find values of |
---|
742 | // n: instances in atbop (already have that, actually) |
---|
743 | // $v: a value for $a that exists for a case in the atbop, where |
---|
744 | // $v is > the max value for $a for a case at the leaf which |
---|
745 | // has the class $c, and $v is <= the upperlimit of $a at |
---|
746 | // the leaf. |
---|
747 | // k: k is a class |
---|
748 | // that maximize L' = Laplace({x: x contained in cases(n) |
---|
749 | // & value(a,x) > v & value(a,x) <= upperlim(l,a)}, k). |
---|
750 | double maxBestClass = -1; |
---|
751 | double maxBestLaplace = leafLaplace; |
---|
752 | double maxBestVal = Double.NaN; |
---|
753 | double maxBestPos = Double.NaN; |
---|
754 | double maxBestTotal = Double.NaN; |
---|
755 | double [][] maxBestCounts = null; |
---|
756 | for(int c = 0; c < n.numClasses(); c++) { // zero the counts |
---|
757 | counts[0][c] = 0; |
---|
758 | counts[1][c] = 0; // shouldn't need to do this ... |
---|
759 | } |
---|
760 | |
---|
761 | // check smallest val for a in atbop is < upper limit |
---|
762 | if(n.numInstances() >= 1 |
---|
763 | && n.instance(sorted[0]).value(a) < limits[a][1]) { |
---|
764 | for(int x = n.numInstances() - 1; x >= 0; x--) { |
---|
765 | if(n.instance(sorted[x]).isMissing(a)) |
---|
766 | continue; |
---|
767 | |
---|
768 | double theval = n.instance(sorted[x]).value(a); |
---|
769 | if(m_Debug) |
---|
770 | System.out.println("\t " + theval); |
---|
771 | |
---|
772 | if(theval > limits[a][1]) { |
---|
773 | if(m_Debug) |
---|
774 | System.out.println("\t >= upperlim; continuing..."); |
---|
775 | continue; |
---|
776 | } |
---|
777 | if(theval <= maxLeaf) { |
---|
778 | if(m_Debug) |
---|
779 | System.out.println("\t < maxLeaf; breaking..."); |
---|
780 | break; |
---|
781 | } |
---|
782 | |
---|
783 | // increment counts |
---|
784 | counts[1][(int)n.instance(sorted[x]).classValue()] |
---|
785 | += iindex[1][sorted[x]]; |
---|
786 | |
---|
787 | if(x != 0 && !n.instance(sorted[x-1]).isMissing(a)) { |
---|
788 | int z = x - 1; |
---|
789 | while(z >= 0 && n.instance(sorted[z]).value(a) == theval) { |
---|
790 | z--; x--; |
---|
791 | counts[1][(int)n.instance(sorted[x]).classValue()] |
---|
792 | += iindex[1][sorted[x]]; |
---|
793 | } |
---|
794 | } |
---|
795 | |
---|
796 | // work out best laplace for > theval |
---|
797 | double total = Utils.sum(counts[1]); |
---|
798 | for(int c = 0; c < n.numClasses(); c++) { |
---|
799 | double temp = (counts[1][c]+1.0)/(total+2.0); |
---|
800 | if(temp > maxBestLaplace ) { |
---|
801 | maxBestPos = counts[1][c]; |
---|
802 | maxBestTotal = total; |
---|
803 | maxBestLaplace = temp; |
---|
804 | maxBestClass = c; |
---|
805 | maxBestCounts = copyCounts(counts); |
---|
806 | maxBestVal = (x == 0) |
---|
807 | ? theval |
---|
808 | : ((theval + n.instance(sorted[x-1]).value(a)) / 2.0); |
---|
809 | } |
---|
810 | } |
---|
811 | } |
---|
812 | |
---|
813 | // (d) add to t tuple <n,a,v,k,L',">"> |
---|
814 | if(!Double.isNaN(maxBestVal) |
---|
815 | && biprob(maxBestPos,maxBestTotal,leafLaplace) > m_BiProbCrit) { |
---|
816 | GraftSplit gsplit = null; |
---|
817 | try { |
---|
818 | gsplit = new GraftSplit(a, maxBestVal, 1, |
---|
819 | leafClass, maxBestCounts); |
---|
820 | } catch (Exception e) { |
---|
821 | System.err.println("graftsplit error:" + e.getMessage()); |
---|
822 | System.exit(1); |
---|
823 | } |
---|
824 | t.add(gsplit); |
---|
825 | } |
---|
826 | } |
---|
827 | } else { // must be a nominal attribute |
---|
828 | |
---|
829 | // 3. for each discrete attribute a for which there is no |
---|
830 | // test at an ancestor of l |
---|
831 | |
---|
832 | // skip if this attribute has already been used |
---|
833 | if(limits[a][1] == 1) { |
---|
834 | continue; |
---|
835 | } |
---|
836 | |
---|
837 | boolean [] prohibit = new boolean[l.attribute(a).numValues()]; |
---|
838 | for(int aval = 0; aval < n.attribute(a).numValues(); aval++) { |
---|
839 | for(int x = 0; x < l.numInstances(); x++) { |
---|
840 | if((l.instance(x).isMissing(a) |
---|
841 | || l.instance(x).value(a) == aval) |
---|
842 | && (!m_relabel || (l.instance(x).classValue() == leafClass))) { |
---|
843 | prohibit[aval] = true; |
---|
844 | break; |
---|
845 | } |
---|
846 | } |
---|
847 | } |
---|
848 | |
---|
849 | // (a) find values of |
---|
850 | // $n: instances in atbop (already have that, actually) |
---|
851 | // $v: $v is a value for $a |
---|
852 | // $k: $k is a class |
---|
853 | // that maximize L' = Laplace({$x: $x contained in cases($n) |
---|
854 | // & value($a,$x) = $v}, $k). |
---|
855 | double bestVal = Double.NaN; |
---|
856 | double bestClass = Double.NaN; |
---|
857 | double bestLaplace = leafLaplace; |
---|
858 | double [][] bestCounts = null; |
---|
859 | double [][] counts = new double[2][n.numClasses()]; |
---|
860 | |
---|
861 | for(int x = 0; x < n.numInstances(); x++) { |
---|
862 | if(n.instance(sorted[x]).isMissing(a)) |
---|
863 | continue; |
---|
864 | |
---|
865 | // zero the counts |
---|
866 | for(int c = 0; c < n.numClasses(); c++) |
---|
867 | counts[0][c] = 0; |
---|
868 | |
---|
869 | double theval = n.instance(sorted[x]).value(a); |
---|
870 | counts[0][(int)n.instance(sorted[x]).classValue()] |
---|
871 | += iindex[1][sorted[x]]; |
---|
872 | |
---|
873 | if(x != n.numInstances() - 1) { |
---|
874 | int z = x + 1; |
---|
875 | while(z < n.numInstances() |
---|
876 | && n.instance(sorted[z]).value(a) == theval) { |
---|
877 | z++; x++; |
---|
878 | counts[0][(int)n.instance(sorted[x]).classValue()] |
---|
879 | += iindex[1][sorted[x]]; |
---|
880 | } |
---|
881 | } |
---|
882 | |
---|
883 | if(!prohibit[(int)theval]) { |
---|
884 | // work out best laplace for > theval |
---|
885 | double total = Utils.sum(counts[0]); |
---|
886 | bestLaplace = leafLaplace; |
---|
887 | bestClass = Double.NaN; |
---|
888 | for(int c = 0; c < n.numClasses(); c++) { |
---|
889 | double temp = (counts[0][c]+1.0)/(total+2.0); |
---|
890 | if(temp > bestLaplace |
---|
891 | && biprob(counts[0][c],total,leafLaplace) > m_BiProbCrit) { |
---|
892 | bestLaplace = temp; |
---|
893 | bestClass = c; |
---|
894 | bestVal = theval; |
---|
895 | bestCounts = copyCounts(counts); |
---|
896 | } |
---|
897 | } |
---|
898 | // add to graft list |
---|
899 | if(!Double.isNaN(bestClass)) { |
---|
900 | GraftSplit gsplit = null; |
---|
901 | try { |
---|
902 | gsplit = new GraftSplit(a, bestVal, 2, |
---|
903 | leafClass, bestCounts); |
---|
904 | } catch (Exception e) { |
---|
905 | System.err.println("graftsplit error: "+e.getMessage()); |
---|
906 | System.exit(1); |
---|
907 | } |
---|
908 | t.add(gsplit); |
---|
909 | } |
---|
910 | } |
---|
911 | } |
---|
912 | // (b) add to t tuple <n,a,v,k,L',"="> |
---|
913 | // done this already |
---|
914 | } |
---|
915 | } |
---|
916 | |
---|
917 | // 4. remove from t all tuples <n,a,v,c,L,x> such that L <= |
---|
918 | // Laplace(cases(l),c) or prob(x,n,Laplace(cases(l),c) <= 0.05 |
---|
919 | // -- checked this constraint prior to adding a tuple -- |
---|
920 | |
---|
921 | // *** step six done before step five for efficiency *** |
---|
922 | // 6. for each <n,a,v,k,L,x> in t ordered on L from highest to lowest |
---|
923 | // order the tuples from highest to lowest laplace |
---|
924 | // (this actually orders lowest to highest) |
---|
925 | Collections.sort(t); |
---|
926 | |
---|
927 | // 5. remove from t all tuples <n,a,v,c,L,x> such that there is |
---|
928 | // no tuple <n',a',v',k',L',x'> such that k' != c & L' < L. |
---|
929 | for(int x = 0; x < t.size(); x++) { |
---|
930 | GraftSplit gs = (GraftSplit)t.get(x); |
---|
931 | if(gs.maxClassForSubsetOfInterest() != leafClass) { |
---|
932 | break; // reached a graft with class != leafClass, so stop deleting |
---|
933 | } else { |
---|
934 | t.remove(x); |
---|
935 | x--; |
---|
936 | } |
---|
937 | } |
---|
938 | |
---|
939 | // if no potential grafts were found, do nothing and return |
---|
940 | if(t.size() < 1) { |
---|
941 | return; |
---|
942 | } |
---|
943 | |
---|
944 | // create the distributions for each graft |
---|
945 | for(int x = t.size()-1; x >= 0; x--) { |
---|
946 | GraftSplit gs = (GraftSplit)t.get(x); |
---|
947 | try { |
---|
948 | gs.buildClassifier(l); |
---|
949 | gs.deleteGraftedCases(l); // so they don't go down the other branch |
---|
950 | } catch (Exception e) { |
---|
951 | System.err.println("graftsplit build error: " + e.getMessage()); |
---|
952 | } |
---|
953 | } |
---|
954 | |
---|
955 | // add this stuff to the tree |
---|
956 | ((C45PruneableClassifierTreeG)parent).setDescendents(t, this); |
---|
957 | } |
---|
958 | |
---|
959 | /** |
---|
960 | * sorts the int array in ascending order by attribute indexed |
---|
961 | * by a in dataset data. |
---|
962 | * @param the data the indices represent |
---|
963 | * @param the index of the attribute to sort by |
---|
964 | * @return array of sorted indicies |
---|
965 | */ |
---|
966 | private int [] sortByAttribute(Instances data, int a) { |
---|
967 | |
---|
968 | double [] attList = data.attributeToDoubleArray(a); |
---|
969 | int [] temp = Utils.sort(attList); |
---|
970 | return temp; |
---|
971 | } |
---|
972 | |
---|
973 | /** |
---|
974 | * deep copy the 2d array of counts |
---|
975 | * |
---|
976 | * @param src the array to copy |
---|
977 | * @return a copy of src |
---|
978 | */ |
---|
979 | private double [][] copyCounts(double [][] src) { |
---|
980 | |
---|
981 | double [][] newArr = new double[src.length][0]; |
---|
982 | for(int x = 0; x < src.length; x++) { |
---|
983 | newArr[x] = new double[src[x].length]; |
---|
984 | for(int y = 0; y < src[x].length; y++) { |
---|
985 | newArr[x][y] = src[x][y]; |
---|
986 | } |
---|
987 | } |
---|
988 | return newArr; |
---|
989 | } |
---|
990 | |
---|
991 | |
---|
992 | /** |
---|
993 | * Help method for computing class probabilities of |
---|
994 | * a given instance. |
---|
995 | * |
---|
996 | * @throws Exception if something goes wrong |
---|
997 | */ |
---|
998 | private double getProbsLaplace(int classIndex, Instance instance, double weight) |
---|
999 | throws Exception { |
---|
1000 | |
---|
1001 | double [] weights; |
---|
1002 | double prob = 0; |
---|
1003 | int treeIndex; |
---|
1004 | int i,j; |
---|
1005 | |
---|
1006 | if (m_isLeaf) { |
---|
1007 | return weight * localModel().classProbLaplace(classIndex, instance, -1); |
---|
1008 | } else { |
---|
1009 | treeIndex = localModel().whichSubset(instance); |
---|
1010 | |
---|
1011 | if (treeIndex == -1) { |
---|
1012 | weights = localModel().weights(instance); |
---|
1013 | for (i = 0; i < m_sons.length; i++) { |
---|
1014 | if (!son(i).m_isEmpty) { |
---|
1015 | if (!son(i).m_isLeaf) { |
---|
1016 | prob += son(i).getProbsLaplace(classIndex, instance, |
---|
1017 | weights[i] * weight); |
---|
1018 | } else { |
---|
1019 | prob += weight * weights[i] * |
---|
1020 | localModel().classProbLaplace(classIndex, instance, i); |
---|
1021 | } |
---|
1022 | } |
---|
1023 | } |
---|
1024 | return prob; |
---|
1025 | } else { |
---|
1026 | |
---|
1027 | if (son(treeIndex).m_isLeaf) { |
---|
1028 | return weight * localModel().classProbLaplace(classIndex, instance, |
---|
1029 | treeIndex); |
---|
1030 | } else { |
---|
1031 | return son(treeIndex).getProbsLaplace(classIndex,instance,weight); |
---|
1032 | } |
---|
1033 | } |
---|
1034 | } |
---|
1035 | } |
---|
1036 | |
---|
1037 | |
---|
1038 | /** |
---|
1039 | * Help method for computing class probabilities of |
---|
1040 | * a given instance. |
---|
1041 | * |
---|
1042 | * @throws Exception if something goes wrong |
---|
1043 | */ |
---|
1044 | private double getProbs(int classIndex, Instance instance, double weight) |
---|
1045 | throws Exception { |
---|
1046 | |
---|
1047 | double [] weights; |
---|
1048 | double prob = 0; |
---|
1049 | int treeIndex; |
---|
1050 | int i,j; |
---|
1051 | |
---|
1052 | if (m_isLeaf) { |
---|
1053 | return weight * localModel().classProb(classIndex, instance, -1); |
---|
1054 | } else { |
---|
1055 | treeIndex = localModel().whichSubset(instance); |
---|
1056 | if (treeIndex == -1) { |
---|
1057 | weights = localModel().weights(instance); |
---|
1058 | for (i = 0; i < m_sons.length; i++) { |
---|
1059 | if (!son(i).m_isEmpty) { |
---|
1060 | prob += son(i).getProbs(classIndex, instance, |
---|
1061 | weights[i] * weight); |
---|
1062 | } |
---|
1063 | } |
---|
1064 | return prob; |
---|
1065 | } else { |
---|
1066 | |
---|
1067 | if (son(treeIndex).m_isEmpty) { |
---|
1068 | return weight * localModel().classProb(classIndex, instance, |
---|
1069 | treeIndex); |
---|
1070 | } else { |
---|
1071 | return son(treeIndex).getProbs(classIndex, instance, weight); |
---|
1072 | } |
---|
1073 | } |
---|
1074 | } |
---|
1075 | } |
---|
1076 | |
---|
1077 | |
---|
1078 | |
---|
1079 | /** |
---|
1080 | * add the grafted nodes at originalLeaf's position in tree. |
---|
1081 | * a recursive function that terminates when t is empty. |
---|
1082 | * |
---|
1083 | * @param t the list of nodes to graft |
---|
1084 | * @param originalLeaf the leaf that the grafts are replacing |
---|
1085 | */ |
---|
1086 | public void setDescendents(ArrayList t, |
---|
1087 | C45PruneableClassifierTreeG originalLeaf) { |
---|
1088 | |
---|
1089 | Instances headerInfo = new Instances(m_train, 0); |
---|
1090 | |
---|
1091 | boolean end = false; |
---|
1092 | ClassifierSplitModel splitmod = null; |
---|
1093 | C45PruneableClassifierTreeG newNode; |
---|
1094 | if(t.size() > 0) { |
---|
1095 | splitmod = (ClassifierSplitModel)t.remove(t.size() - 1); |
---|
1096 | newNode = new C45PruneableClassifierTreeG(m_toSelectModel, headerInfo, |
---|
1097 | splitmod, m_pruneTheTree, m_CF, m_subtreeRaising, |
---|
1098 | false, m_relabel, m_cleanup); |
---|
1099 | } else { |
---|
1100 | // get the leaf for one of newNode's children |
---|
1101 | NoSplit kLeaf = ((GraftSplit)localModel()).getOtherLeaf(); |
---|
1102 | newNode = |
---|
1103 | new C45PruneableClassifierTreeG(m_toSelectModel, headerInfo, |
---|
1104 | kLeaf, m_pruneTheTree, m_CF, m_subtreeRaising, |
---|
1105 | true, m_relabel, m_cleanup); |
---|
1106 | end = true; |
---|
1107 | } |
---|
1108 | |
---|
1109 | // behave differently for parent of original leaf, since we don't |
---|
1110 | // want to destroy any of its other branches |
---|
1111 | if(m_sons != null) { |
---|
1112 | for(int x = 0; x < m_sons.length; x++) { |
---|
1113 | if(son(x).equals(originalLeaf)) { |
---|
1114 | m_sons[x] = newNode; // replace originalLeaf with newNode |
---|
1115 | } |
---|
1116 | } |
---|
1117 | } else { |
---|
1118 | |
---|
1119 | // allocate space for the children |
---|
1120 | m_sons = new C45PruneableClassifierTreeG[localModel().numSubsets()]; |
---|
1121 | |
---|
1122 | // get the leaf for one of newNode's children |
---|
1123 | NoSplit kLeaf = ((GraftSplit)localModel()).getLeaf(); |
---|
1124 | C45PruneableClassifierTreeG kNode = |
---|
1125 | new C45PruneableClassifierTreeG(m_toSelectModel, headerInfo, |
---|
1126 | kLeaf, m_pruneTheTree, m_CF, m_subtreeRaising, |
---|
1127 | true, m_relabel, m_cleanup); |
---|
1128 | |
---|
1129 | // figure where to put the new node |
---|
1130 | if(((GraftSplit)localModel()).subsetOfInterest() == 0) { |
---|
1131 | m_sons[0] = kNode; |
---|
1132 | m_sons[1] = newNode; |
---|
1133 | } else { |
---|
1134 | m_sons[0] = newNode; |
---|
1135 | m_sons[1] = kNode; |
---|
1136 | } |
---|
1137 | } |
---|
1138 | if(!end) |
---|
1139 | ((C45PruneableClassifierTreeG)newNode).setDescendents |
---|
1140 | (t, (C45PruneableClassifierTreeG)originalLeaf); |
---|
1141 | } |
---|
1142 | |
---|
1143 | |
---|
1144 | /** |
---|
1145 | * class prob with laplace correction (assumes binary class) |
---|
1146 | */ |
---|
1147 | private double laplaceLeaf(double classIndex) { |
---|
1148 | double l = (localModel().distribution().perClass((int)classIndex) + 1.0) |
---|
1149 | / (localModel().distribution().total() + 2.0); |
---|
1150 | return l; |
---|
1151 | } |
---|
1152 | |
---|
1153 | |
---|
1154 | /** |
---|
1155 | * Significance test |
---|
1156 | * @param x |
---|
1157 | * @param n |
---|
1158 | * @param r |
---|
1159 | * @return returns the probability of obtaining x or MORE out of n |
---|
1160 | * if r proportion of n are positive. |
---|
1161 | * |
---|
1162 | * z for normal estimation of binomial probability of obtaining x |
---|
1163 | * or more out of n, if r proportion of n are positive |
---|
1164 | */ |
---|
1165 | public double biprob(double x, double n, double r) throws Exception { |
---|
1166 | |
---|
1167 | return ((((x) - 0.5) - (n) * (r)) / Math.sqrt((n) * (r) * (1.0 - (r)))); |
---|
1168 | } |
---|
1169 | |
---|
1170 | /** |
---|
1171 | * Prints tree structure. |
---|
1172 | */ |
---|
1173 | public String toString() { |
---|
1174 | |
---|
1175 | try { |
---|
1176 | StringBuffer text = new StringBuffer(); |
---|
1177 | |
---|
1178 | if(m_isLeaf) { |
---|
1179 | text.append(": "); |
---|
1180 | if(m_localModel instanceof GraftSplit) |
---|
1181 | text.append(((GraftSplit)m_localModel).dumpLabelG(0,m_train)); |
---|
1182 | else |
---|
1183 | text.append(m_localModel.dumpLabel(0,m_train)); |
---|
1184 | } else |
---|
1185 | dumpTree(0,text); |
---|
1186 | text.append("\n\nNumber of Leaves : \t"+numLeaves()+"\n"); |
---|
1187 | text.append("\nSize of the tree : \t"+numNodes()+"\n"); |
---|
1188 | |
---|
1189 | return text.toString(); |
---|
1190 | } catch (Exception e) { |
---|
1191 | return "Can't print classification tree."; |
---|
1192 | } |
---|
1193 | } |
---|
1194 | |
---|
1195 | /** |
---|
1196 | * Help method for printing tree structure. |
---|
1197 | * |
---|
1198 | * @throws Exception if something goes wrong |
---|
1199 | */ |
---|
1200 | protected void dumpTree(int depth,StringBuffer text) throws Exception { |
---|
1201 | |
---|
1202 | int i,j; |
---|
1203 | |
---|
1204 | for(i=0;i<m_sons.length;i++) { |
---|
1205 | text.append("\n");; |
---|
1206 | for(j=0;j<depth;j++) |
---|
1207 | text.append("| "); |
---|
1208 | text.append(m_localModel.leftSide(m_train)); |
---|
1209 | text.append(m_localModel.rightSide(i, m_train)); |
---|
1210 | if(m_sons[i].m_isLeaf) { |
---|
1211 | text.append(": "); |
---|
1212 | if(m_localModel instanceof GraftSplit) |
---|
1213 | text.append(((GraftSplit)m_localModel).dumpLabelG(i,m_train)); |
---|
1214 | else |
---|
1215 | text.append(m_localModel.dumpLabel(i,m_train)); |
---|
1216 | } else |
---|
1217 | ((C45PruneableClassifierTreeG)m_sons[i]).dumpTree(depth+1,text); |
---|
1218 | } |
---|
1219 | } |
---|
1220 | |
---|
1221 | /** |
---|
1222 | * Returns the revision string. |
---|
1223 | * |
---|
1224 | * @return the revision |
---|
1225 | */ |
---|
1226 | public String getRevision() { |
---|
1227 | return RevisionUtils.extract("$Revision: 5532 $"); |
---|
1228 | } |
---|
1229 | } |
---|