source: src/main/java/weka/classifiers/trees/j48/C45PruneableClassifierTreeG.java @ 9

Last change on this file since 9 was 4, checked in by gnappo, 14 years ago

Import di weka.

File size: 39.5 KB
Line 
1/*
2 *    This program is free software; you can redistribute it and/or modify
3 *    it under the terms of the GNU General Public License as published by
4 *    the Free Software Foundation; either version 2 of the License, or
5 *    (at your option) any later version.
6 *
7 *    This program is distributed in the hope that it will be useful,
8 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
9 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
10 *    GNU General Public License for more details.
11 *
12 *    You should have received a copy of the GNU General Public License
13 *    along with this program; if not, write to the Free Software
14 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
15 */
16
17/*
18 *    C45PruneableClassifierTreeG.java
19 *    Copyright (C) 1999 University of Waikato, Hamilton, New Zealand
20 *    Copyright (C) 2007 Geoff Webb & Janice Boughton
21 *
22 */
23
24package weka.classifiers.trees.j48;
25
26import weka.core.Capabilities;
27import weka.core.Instances;
28import weka.core.Instance;
29import weka.core.RevisionUtils;
30import weka.core.Utils;
31import weka.core.Capabilities.Capability;
32import java.util.ArrayList;
33import java.util.Collections;
34
35/**
36 * Class for handling a tree structure that can
37 * be pruned using C4.5 procedures and have nodes grafted on.
38 *
39 * @author Janice Boughton (based on code by Eibe Frank)
40 * @version $Revision: 5532 $
41 */
42
43public class C45PruneableClassifierTreeG extends ClassifierTree{
44
45  /** for serialization */
46  static final long serialVersionUID = 66981207374331964L;
47
48  /** True if the tree is to be pruned. */
49  boolean m_pruneTheTree = false;
50
51  /** The confidence factor for pruning. */
52  float m_CF = 0.25f;
53
54  /** Is subtree raising to be performed? */
55  boolean m_subtreeRaising = true;
56
57  /** Cleanup after the tree has been built. */
58  boolean m_cleanup = true;
59
60  /** flag for using relabelling when grafting */
61  boolean m_relabel = false;
62
63  /** binomial probability critical value */
64  double m_BiProbCrit = 1.64;
65
66  boolean m_Debug = false;
67
68  /**
69   * Constructor for pruneable tree structure. Stores reference
70   * to associated training data at each node.
71   *
72   * @param toSelectLocModel selection method for local splitting model
73   * @param pruneTree true if the tree is to be pruned
74   * @param cf the confidence factor for pruning
75   * @param raiseTree
76   * @param cleanup
77   * @throws Exception if something goes wrong
78   */
79  public C45PruneableClassifierTreeG(ModelSelection toSelectLocModel,
80                                    boolean pruneTree,float cf,
81                                    boolean raiseTree,
82                                    boolean relabel, boolean cleanup)
83       throws Exception {
84
85    super(toSelectLocModel);
86
87    m_pruneTheTree = pruneTree;
88    m_CF = cf;
89    m_subtreeRaising = raiseTree;
90    m_cleanup = cleanup;
91    m_relabel = relabel;
92  }
93
94
95  /**
96   * Returns default capabilities of the classifier tree.
97   *
98   * @return      the capabilities of this classifier tree
99   */
100  public Capabilities getCapabilities() {
101    Capabilities result = super.getCapabilities();
102    result.disableAll();
103
104    // attributes
105    result.enable(Capability.NOMINAL_ATTRIBUTES);
106    result.enable(Capability.NUMERIC_ATTRIBUTES);
107    result.enable(Capability.MISSING_VALUES);
108
109    // class
110    result.enable(Capability.NOMINAL_CLASS);
111    result.enable(Capability.MISSING_CLASS_VALUES);
112
113    // instances
114    result.setMinimumNumberInstances(0);
115
116    return result;
117  }
118
119  /**
120   * Constructor for pruneable tree structure. Used to create new nodes
121   * in the tree during grafting.
122   *
123   * @param toSelectLocModel selection method for local splitting model
124   * @param data the dta used to produce split model
125   * @param gs the split model
126   * @param prune true if the tree is to be pruned
127   * @param cf the confidence factor for pruning
128   * @param raise
129   * @param isLeaf if this node is a leaf or not
130   * @param relabel whether relabeling occured
131   * @param cleanup
132   * @throws Exception if something goes wrong
133   */
134  public C45PruneableClassifierTreeG(ModelSelection toSelectLocModel, 
135                                    Instances data, ClassifierSplitModel gs, 
136                                    boolean prune, float cf, boolean raise,
137                                    boolean isLeaf, boolean relabel, 
138                                    boolean cleanup) {
139
140    super(toSelectLocModel);
141    m_relabel = relabel;
142    m_cleanup = cleanup;
143    m_localModel = gs;
144    m_train = data;
145    m_test = null;
146    m_isLeaf = isLeaf;
147    if(gs.distribution().total() > 0)
148       m_isEmpty = false;
149    else
150       m_isEmpty = true;
151
152    m_pruneTheTree = prune;
153    m_CF = cf;
154    m_subtreeRaising = raise;
155  }
156
157  /**
158   * Method for building a pruneable classifier tree.
159   *
160   * @param data the data for building the tree
161   * @throws Exception if something goes wrong
162   */
163  public void buildClassifier(Instances data) throws Exception {
164
165    // can classifier tree handle the data?
166    getCapabilities().testWithFail(data);
167
168    // remove instances with missing class
169    data = new Instances(data);
170    data.deleteWithMissingClass();
171
172    buildTree(data, m_subtreeRaising);
173    collapse();
174    if (m_pruneTheTree) {
175      prune();
176    }
177    doGrafting(data);
178    if (m_cleanup) {
179      cleanup(new Instances(data, 0));
180    }
181  }
182
183
184  /**
185   * Collapses a tree to a node if training error doesn't increase.
186   */
187  public final void collapse(){
188
189    double errorsOfSubtree;
190    double errorsOfTree;
191    int i;
192
193    if (!m_isLeaf){
194      errorsOfSubtree = getTrainingErrors();
195      errorsOfTree = localModel().distribution().numIncorrect();
196      if (errorsOfSubtree >= errorsOfTree-1E-3){
197
198        // Free adjacent trees
199        m_sons = null;
200        m_isLeaf = true;
201                       
202        // Get NoSplit Model for tree.
203        m_localModel = new NoSplit(localModel().distribution());
204      }else
205        for (i=0;i<m_sons.length;i++)
206          son(i).collapse();
207    }
208  }
209
210  /**
211   * Prunes a tree using C4.5's pruning procedure.
212   *
213   * @throws Exception if something goes wrong
214   */
215  public void prune() throws Exception {
216
217    double errorsLargestBranch;
218    double errorsLeaf;
219    double errorsTree;
220    int indexOfLargestBranch;
221    C45PruneableClassifierTreeG largestBranch;
222    int i;
223
224    if (!m_isLeaf){
225
226      // Prune all subtrees.
227      for (i=0;i<m_sons.length;i++)
228        son(i).prune();
229
230      // Compute error for largest branch
231      indexOfLargestBranch = localModel().distribution().maxBag();
232      if (m_subtreeRaising) {
233        errorsLargestBranch = son(indexOfLargestBranch).
234          getEstimatedErrorsForBranch((Instances)m_train);
235      } else {
236        errorsLargestBranch = Double.MAX_VALUE;
237      }
238
239      // Compute error if this Tree would be leaf
240      errorsLeaf = 
241        getEstimatedErrorsForDistribution(localModel().distribution());
242
243      // Compute error for the whole subtree
244      errorsTree = getEstimatedErrors();
245
246      // Decide if leaf is best choice.
247      if (Utils.smOrEq(errorsLeaf,errorsTree+0.1) &&
248          Utils.smOrEq(errorsLeaf,errorsLargestBranch+0.1)){
249
250        // Free son Trees
251        m_sons = null;
252        m_isLeaf = true;
253               
254        // Get NoSplit Model for node.
255        m_localModel = new NoSplit(localModel().distribution());
256        return;
257      }
258
259      // Decide if largest branch is better choice
260      // than whole subtree.
261      if (Utils.smOrEq(errorsLargestBranch,errorsTree+0.1)){
262        largestBranch = son(indexOfLargestBranch);
263        m_sons = largestBranch.m_sons;
264        m_localModel = largestBranch.localModel();
265        m_isLeaf = largestBranch.m_isLeaf;
266        newDistribution(m_train);
267        prune();
268      }
269    }
270  }
271
272  /**
273   * Returns a newly created tree.
274   *
275   * @param data the data to work with
276   * @return the new tree
277   * @throws Exception if something goes wrong
278   */
279  protected ClassifierTree getNewTree(Instances data) throws Exception {
280   
281    C45PruneableClassifierTreeG newTree = 
282      new C45PruneableClassifierTreeG(m_toSelectModel, m_pruneTheTree, m_CF,
283             m_subtreeRaising, m_relabel, m_cleanup);
284        // ATBOP Modification     // m_subtreeRaising, m_cleanup);
285
286    newTree.buildTree((Instances)data, m_subtreeRaising);
287
288    return newTree;
289  }
290
291  /**
292   * Computes estimated errors for tree.
293   *
294   * @return the estimated errors
295   */
296  private double getEstimatedErrors(){
297
298    double errors = 0;
299    int i;
300
301    if (m_isLeaf)
302      return getEstimatedErrorsForDistribution(localModel().distribution());
303    else{
304      for (i=0;i<m_sons.length;i++)
305        errors = errors+son(i).getEstimatedErrors();
306      return errors;
307    }
308  }
309 
310  /**
311   * Computes estimated errors for one branch.
312   *
313   * @param data the data to work with
314   * @return the estimated errors
315   * @throws Exception if something goes wrong
316   */
317  private double getEstimatedErrorsForBranch(Instances data) 
318       throws Exception {
319
320    Instances [] localInstances;
321    double errors = 0;
322    int i;
323
324    if (m_isLeaf)
325      return getEstimatedErrorsForDistribution(new Distribution(data));
326    else{
327      Distribution savedDist = localModel().m_distribution;
328      localModel().resetDistribution(data);
329      localInstances = (Instances[])localModel().split(data);
330      localModel().m_distribution = savedDist;
331      for (i=0;i<m_sons.length;i++)
332        errors = errors+
333          son(i).getEstimatedErrorsForBranch(localInstances[i]);
334      return errors;
335    }
336  }
337
338  /**
339   * Computes estimated errors for leaf.
340   *
341   * @param theDistribution the distribution to use
342   * @return the estimated errors
343   */
344  private double getEstimatedErrorsForDistribution(Distribution
345                                                   theDistribution){
346
347    if (Utils.eq(theDistribution.total(),0))
348      return 0;
349    else
350      return theDistribution.numIncorrect()+
351        Stats.addErrs(theDistribution.total(),
352                      theDistribution.numIncorrect(),m_CF);
353  }
354
355  /**
356   * Computes errors of tree on training data.
357   *
358   * @return the training errors
359   */
360  private double getTrainingErrors(){
361
362    double errors = 0;
363    int i;
364
365    if (m_isLeaf)
366      return localModel().distribution().numIncorrect();
367    else{
368      for (i=0;i<m_sons.length;i++)
369        errors = errors+son(i).getTrainingErrors();
370      return errors;
371    }
372  }
373
374  /**
375   * Method just exists to make program easier to read.
376   *
377   * @return the local split model
378   */
379  private ClassifierSplitModel localModel(){
380   
381    return (ClassifierSplitModel)m_localModel;
382  }
383
384  /**
385   * Computes new distributions of instances for nodes
386   * in tree.
387   *
388   * @param data the data to compute the distributions for
389   * @throws Exception if something goes wrong
390   */
391  private void newDistribution(Instances data) throws Exception {
392
393    Instances [] localInstances;
394
395    localModel().resetDistribution(data);
396    m_train = data;
397    if (!m_isLeaf){
398      localInstances = 
399        (Instances [])localModel().split(data);
400      for (int i = 0; i < m_sons.length; i++)
401        son(i).newDistribution(localInstances[i]);
402    } else {
403
404      // Check whether there are some instances at the leaf now!
405      if (!Utils.eq(data.sumOfWeights(), 0)) {
406        m_isEmpty = false;
407      }
408    }
409  }
410
411  /**
412   * Method just exists to make program easier to read.
413   */
414  private C45PruneableClassifierTreeG son(int index){
415    return (C45PruneableClassifierTreeG)m_sons[index];
416  }
417
418
419  /**
420   * Initializes variables for grafting.
421   * sets up limits array (for numeric attributes) and calls
422   * the recursive function traverseTree.
423   *
424   * @param data the data for the tree
425   * @throws Exception if anything goes wrong
426   */
427  public void doGrafting(Instances data) throws Exception {
428
429    // 2d array for the limits
430    double [][] limits = new double[data.numAttributes()][2];
431    // 2nd dimension: index 0 == lower limit, index 1 == upper limit
432    // initialise to no limit
433    for(int i = 0; i < data.numAttributes(); i++) {
434       limits[i][0] = Double.NEGATIVE_INFINITY;
435       limits[i][1] = Double.POSITIVE_INFINITY;
436    }
437
438    // use an index instead of creating new Insances objects all the time
439    // instanceIndex[0] == array for weights at leaf
440    // instanceIndex[1] == array for weights in atbop
441    double [][] instanceIndex = new double[2][data.numInstances()];
442    // initialize the weight for each instance
443    for(int x = 0; x < data.numInstances(); x++) {
444        instanceIndex[0][x] = 1;
445        instanceIndex[1][x] = 1;  // leaf instances are in atbop
446    }
447
448    // first call to graft
449    traverseTree(data, instanceIndex, limits, this, 0, -1);
450  }
451
452
453  /**
454   * recursive function.
455   * if this node is a leaf then calls findGraft, otherwise sorts
456   * the two sets of instances (tracked in iindex array) and calls
457   * sortInstances for each of the child nodes (which then calls
458   * this method).
459   *
460   * @param fulldata all instances
461   * @param iindex array the tracks the weight of each instance in
462   *        the atbop and at the leaf (0.0 if not present)
463   * @param limits array specifying current upper/lower limits for numeric atts
464   * @param parent the node immediately before the current one
465   * @param pL laplace for node, as calculated by parent (in case leaf is empty)
466   * @param nodeClass class of node, determined by parent (in case leaf empty)
467   */
468  private void traverseTree(Instances fulldata, double [][] iindex, 
469     double[][] limits, C45PruneableClassifierTreeG parent, 
470     double pL, int nodeClass) throws Exception {
471   
472    if(m_isLeaf) {
473
474       findGraft(fulldata, iindex, limits, 
475                 (ClassifierTree)parent, pL, nodeClass);
476
477    } else {
478
479       // traverse each branch
480       for(int i = 0; i < localModel().numSubsets(); i++) {
481
482          double [][] newiindex = new double[2][fulldata.numInstances()];
483          for(int x = 0; x < 2; x++)
484             System.arraycopy(iindex[x], 0, newiindex[x], 0, iindex[x].length);
485          sortInstances(fulldata, newiindex, limits, i);
486       }
487    }
488  }
489
490  /**
491   * sorts/deletes instances into/from node and atbop according to
492   * the test for subset, then calls traverseTree for subset's node.
493   *
494   * @param fulldata all instances
495   * @param iindex array the tracks the weight of each instance in
496   *        the atbop and at the leaf (0.0 if not present)
497   * @param limits array specifying current upper/lower limits for numeric atts
498   * @param subset the subset for which to sort instances into inode & iatbop
499   */
500  private void sortInstances(Instances fulldata, double [][] iindex, 
501                   double [][] limits, int subset) throws Exception {
502
503    C45Split test = (C45Split)localModel();
504
505    // update the instances index for subset
506    double knownCases = 0;
507    double thisSubsetCount = 0;
508    for(int x = 0; x < iindex[0].length; x++) {
509       if(iindex[0][x] == 0 && iindex[1][x] == 0) // skip "discarded" instances
510          continue;
511       if(!fulldata.instance(x).isMissing(test.attIndex())) {
512          knownCases += iindex[0][x];
513          if(test.whichSubset(fulldata.instance(x)) != subset) {
514             if(iindex[0][x] > 0) {
515                // move to atbop, delete from leaf
516                iindex[1][x] = iindex[0][x];
517                iindex[0][x] = 0;
518             } else {
519                if(iindex[1][x] > 0) {
520                   // instance is now "discarded"
521                   iindex[1][x] = 0;
522                }
523             }
524          } else {
525             thisSubsetCount += iindex[0][x];
526          }
527       }
528    }
529
530    // work out proportions of weight for missing values for leaf and atbop
531    double lprop = (knownCases == 0) ? (1.0 / (double)test.numSubsets()) 
532                                : (thisSubsetCount / (double)knownCases);
533
534    // add in the instances that have missing value for attIndex
535    for(int x = 0; x < iindex[0].length; x++) {
536       if(iindex[0][x] == 0 && iindex[1][x] == 0)
537          continue;     // skip "discarded" instances
538       if(fulldata.instance(x).isMissing(test.attIndex())) {
539          iindex[1][x] -= (iindex[1][x] - iindex[0][x]) * (1-lprop);
540          iindex[0][x] *= lprop;
541       }
542    }
543
544    int nodeClass = localModel().distribution().maxClass(subset);
545    double pL = (localModel().distribution().perClass(nodeClass) + 1.0)
546               / (localModel().distribution().total() + 2.0);
547
548    // call traerseTree method for the child node
549    son(subset).traverseTree(fulldata, iindex,
550          test.minsAndMaxs(fulldata, limits, subset), this, pL, nodeClass);
551  }
552
553  /**
554   * finds new nodes that improve accuracy and grafts them onto the tree
555   *
556   * @param fulldata the instances in whole trainset
557   * @param iindex records num tests each instance has failed up to this node
558   * @param limits the upper/lower limits for numeric attributes
559   * @param parent the node immediately before the current one
560   * @param pLaplace laplace for leaf, calculated by parent (in case leaf empty)
561   * @param pLeafClass class of leaf, determined by parent (in case leaf empty)
562   */
563  private void findGraft(Instances fulldata, double [][] iindex, 
564   double [][] limits, ClassifierTree parent, double pLaplace, 
565   int pLeafClass) throws Exception {
566
567    // get the class for this leaf
568    int leafClass = (m_isEmpty)
569                       ? pLeafClass
570                       :  localModel().distribution().maxClass();
571
572    // get the laplace value for this leaf
573    double leafLaplace = (m_isEmpty)
574                            ? pLaplace
575                            : laplaceLeaf(leafClass);
576
577    // sort the instances into those at the leaf, those in atbop, and discarded
578    Instances l = new Instances(fulldata, fulldata.numInstances());
579    Instances n = new Instances(fulldata, fulldata.numInstances());
580    int lcount = 0;
581    int acount = 0;
582    for(int x = 0; x < fulldata.numInstances(); x++) {
583       if(iindex[0][x] <= 0 && iindex[1][x] <= 0)
584          continue;
585       if(iindex[0][x] != 0) {
586          l.add(fulldata.instance(x));
587          l.instance(lcount).setWeight(iindex[0][x]);
588          // move instance's weight in iindex to same index as in l
589          iindex[0][lcount++] = iindex[0][x];
590       }
591       if(iindex[1][x] > 0) {
592          n.add(fulldata.instance(x));
593          n.instance(acount).setWeight(iindex[1][x]);
594          // move instance's weight in iindex to same index as in n
595          iindex[1][acount++] = iindex[1][x];
596       }
597    }
598
599    boolean graftPossible = false;
600    double [] classDist = new double[n.numClasses()];
601    for(int x = 0; x < n.numInstances(); x++) {
602       if(iindex[1][x] > 0 && !n.instance(x).classIsMissing())
603          classDist[(int)n.instance(x).classValue()] += iindex[1][x];
604    }
605
606    for(int cVal = 0; cVal < n.numClasses(); cVal++) {
607       double theLaplace = (classDist[cVal] + 1.0) / (classDist[cVal] + 2.0);
608       if(cVal != leafClass && (theLaplace > leafLaplace) && 
609        (biprob(classDist[cVal], classDist[cVal], leafLaplace)
610         > m_BiProbCrit)) {
611          graftPossible = true;
612          break;
613       }
614    }
615
616    if(!graftPossible) {
617       return;
618    }
619
620    // 1. Initialize to {} a set of tuples t containing potential tests
621    ArrayList t = new ArrayList();
622
623    // go through each attribute
624    for(int a = 0; a < n.numAttributes(); a++) {
625       if(a == n.classIndex())
626          continue;   // skip the class
627
628       // sort instances in atbop by $a
629       int [] sorted = sortByAttribute(n, a);
630
631       // 2. For each continuous attribute $a:
632       if(n.attribute(a).isNumeric()) {
633
634          // find min and max values for this attribute at the leaf
635          boolean prohibited = false;
636          double minLeaf = Double.POSITIVE_INFINITY;
637          double maxLeaf = Double.NEGATIVE_INFINITY;
638          for(int i = 0; i < l.numInstances(); i++) {
639             if(l.instance(i).isMissing(a)) {
640                if(l.instance(i).classValue() == leafClass) {
641                   prohibited = true;
642                   break;
643                }
644             }
645             double value = l.instance(i).value(a);
646             if(!m_relabel || l.instance(i).classValue() == leafClass) {
647                if(value < minLeaf)
648                   minLeaf = value;
649                if(value > maxLeaf)
650                   maxLeaf = value;
651             }
652          }
653          if(prohibited) {
654             continue;
655          }
656
657          // (a) find values of
658          //    $n: instances in atbop (already have that, actually)
659          //    $v: a value for $a that exists for a case in the atbop, where
660          //       $v is < the min value for $a for a case at the leaf which
661          //       has the class $c, and $v is > the lowerlimit of $a at
662          //       the leaf.
663          //       (note: error in original paper stated that $v must be
664          //       smaller OR EQUAL TO the min value).
665          //    $k: $k is a class
666          //  that maximize L' = Laplace({$x: $x contained in cases($n)
667          //    & value($a,$x) <= $v & value($a,$x) > lowerlim($l,$a)}, $k).
668          double minBestClass = Double.NaN;
669          double minBestLaplace = leafLaplace;
670          double minBestVal = Double.NaN;
671          double minBestPos = Double.NaN;
672          double minBestTotal = Double.NaN;
673          double [][] minBestCounts = null;
674          double [][] counts = new double[2][n.numClasses()];
675          for(int x = 0; x < n.numInstances(); x++) {
676             if(n.instance(sorted[x]).isMissing(a))
677                break;   // missing are sorted to end: no more valid vals
678
679             double theval = n.instance(sorted[x]).value(a);
680             if(m_Debug)
681                System.out.println("\t " + theval);
682
683             if(theval <= limits[a][0]) {
684                if(m_Debug)
685                   System.out.println("\t  <= lowerlim: continuing...");
686                continue;
687             }
688             // note: error in paper would have this read "theVal > minLeaf)
689             if(theval >= minLeaf) {
690                if(m_Debug)
691                   System.out.println("\t  >= minLeaf; breaking...");
692                break;
693             }
694             counts[0][(int)n.instance(sorted[x]).classValue()]
695                += iindex[1][sorted[x]];
696
697             if(x != n.numInstances() - 1) {
698                int z = x + 1;
699                while(z < n.numInstances()
700                 && n.instance(sorted[z]).value(a) == theval) {
701                   z++; x++;
702                   counts[0][(int)n.instance(sorted[x]).classValue()] 
703                    += iindex[1][sorted[x]];
704                }
705             }
706
707             // work out the best laplace/class (for <= theval)
708             double total = Utils.sum(counts[0]);
709             for(int c = 0; c < n.numClasses(); c++) {
710                double temp = (counts[0][c]+1.0)/(total+2.0);
711                if(temp > minBestLaplace) {
712                   minBestPos = counts[0][c];
713                   minBestTotal = total;
714                   minBestLaplace = temp;
715                   minBestClass = c;
716                   minBestCounts = copyCounts(counts);
717
718                   minBestVal = (x == n.numInstances()-1) 
719                      ? theval
720                      : ((theval + n.instance(sorted[x+1]).value(a)) / 2.0);
721                }
722             }
723          }
724
725          // (b) add to t tuple <n,a,v,k,L',"<=">
726          if(!Double.isNaN(minBestVal)
727             && biprob(minBestPos, minBestTotal, leafLaplace) > m_BiProbCrit) {
728             GraftSplit gsplit = null;
729             try {
730                gsplit = new GraftSplit(a, minBestVal, 0,
731                                        leafClass, minBestCounts);
732             } catch (Exception e) {
733                System.err.println("graftsplit error: "+e.getMessage());
734                System.exit(1);
735             }
736             t.add(gsplit);
737          }
738          // free space
739          minBestCounts = null;
740
741          // (c) find values of
742          //    n: instances in atbop (already have that, actually)
743          //    $v: a value for $a that exists for a case in the atbop, where
744          //       $v is > the max value for $a for a case at the leaf which
745          //       has the class $c, and $v is <= the upperlimit of $a at
746          //       the leaf.
747          //    k: k is a class
748          //   that maximize L' = Laplace({x: x contained in cases(n)
749          //       & value(a,x) > v & value(a,x) <= upperlim(l,a)}, k).
750          double maxBestClass = -1;
751          double maxBestLaplace = leafLaplace;
752          double maxBestVal = Double.NaN;
753          double maxBestPos = Double.NaN;
754          double maxBestTotal = Double.NaN;
755          double [][] maxBestCounts = null;
756          for(int c = 0; c < n.numClasses(); c++) {  // zero the counts
757             counts[0][c] = 0;
758             counts[1][c] = 0;  // shouldn't need to do this ...
759          }
760
761          // check smallest val for a in atbop is < upper limit
762          if(n.numInstances() >= 1
763           && n.instance(sorted[0]).value(a) < limits[a][1]) {
764             for(int x = n.numInstances() - 1; x >= 0; x--) {
765                if(n.instance(sorted[x]).isMissing(a))
766                   continue;
767
768                double theval = n.instance(sorted[x]).value(a);
769                if(m_Debug)
770                   System.out.println("\t " + theval);
771
772                if(theval > limits[a][1]) {
773                   if(m_Debug)
774                      System.out.println("\t  >= upperlim; continuing...");
775                   continue;
776                }
777                if(theval <= maxLeaf) {
778                   if(m_Debug)
779                      System.out.println("\t  < maxLeaf; breaking...");
780                   break;
781                }
782
783                // increment counts
784                counts[1][(int)n.instance(sorted[x]).classValue()] 
785                   += iindex[1][sorted[x]];
786
787                if(x != 0 && !n.instance(sorted[x-1]).isMissing(a)) {
788                   int z = x - 1;
789                   while(z >= 0 && n.instance(sorted[z]).value(a) == theval) {
790                      z--; x--;
791                      counts[1][(int)n.instance(sorted[x]).classValue()]
792                         += iindex[1][sorted[x]];
793                   }
794                }
795
796                // work out best laplace for > theval
797                double total = Utils.sum(counts[1]);
798                for(int c = 0; c < n.numClasses(); c++) {
799                   double temp = (counts[1][c]+1.0)/(total+2.0);
800                   if(temp > maxBestLaplace ) {
801                      maxBestPos = counts[1][c];
802                      maxBestTotal = total;
803                      maxBestLaplace = temp;
804                      maxBestClass = c;
805                      maxBestCounts = copyCounts(counts);
806                      maxBestVal = (x == 0) 
807                        ? theval
808                        : ((theval + n.instance(sorted[x-1]).value(a)) / 2.0);
809                   }
810                }
811             }
812
813             // (d) add to t tuple <n,a,v,k,L',">">
814             if(!Double.isNaN(maxBestVal)
815               && biprob(maxBestPos,maxBestTotal,leafLaplace) > m_BiProbCrit) {
816                GraftSplit gsplit = null;
817                try {
818                   gsplit = new GraftSplit(a, maxBestVal, 1,
819                      leafClass, maxBestCounts);
820                } catch (Exception e) {
821                   System.err.println("graftsplit error:" + e.getMessage());
822                   System.exit(1);
823                }
824                t.add(gsplit);
825             }
826          }
827       } else {    // must be a nominal attribute
828
829          // 3. for each discrete attribute a for which there is no
830          //    test at an ancestor of l
831
832          // skip if this attribute has already been used
833          if(limits[a][1] == 1) {
834             continue;
835          }
836
837          boolean [] prohibit = new boolean[l.attribute(a).numValues()];
838          for(int aval = 0; aval < n.attribute(a).numValues(); aval++) {
839             for(int x = 0; x < l.numInstances(); x++) {
840                if((l.instance(x).isMissing(a)
841                    || l.instance(x).value(a) == aval) 
842                 && (!m_relabel || (l.instance(x).classValue() == leafClass))) {
843                   prohibit[aval] = true;
844                   break;
845                }
846             }
847          }
848
849          // (a) find values of
850          //       $n: instances in atbop (already have that, actually)
851          //       $v: $v is a value for $a
852          //       $k: $k is a class
853          //     that maximize L' = Laplace({$x: $x contained in cases($n)
854          //           & value($a,$x) = $v}, $k).
855          double bestVal = Double.NaN;
856          double bestClass = Double.NaN;
857          double bestLaplace = leafLaplace;
858          double [][] bestCounts = null;
859          double [][] counts = new double[2][n.numClasses()];
860
861          for(int x = 0; x < n.numInstances(); x++) {
862             if(n.instance(sorted[x]).isMissing(a))
863                continue;
864
865             // zero the counts
866             for(int c = 0; c < n.numClasses(); c++)
867                counts[0][c] = 0;
868
869             double theval = n.instance(sorted[x]).value(a);
870             counts[0][(int)n.instance(sorted[x]).classValue()] 
871               += iindex[1][sorted[x]];
872
873             if(x != n.numInstances() - 1) {
874                int z = x + 1;
875                while(z < n.numInstances() 
876                 && n.instance(sorted[z]).value(a) == theval) {
877                   z++; x++;
878                   counts[0][(int)n.instance(sorted[x]).classValue()]
879                      += iindex[1][sorted[x]];
880                }
881             }
882
883             if(!prohibit[(int)theval]) {
884                // work out best laplace for > theval
885                double total = Utils.sum(counts[0]);
886                bestLaplace = leafLaplace;
887                bestClass = Double.NaN;
888                for(int c = 0; c < n.numClasses(); c++) {
889                   double temp = (counts[0][c]+1.0)/(total+2.0);
890                   if(temp > bestLaplace
891                    && biprob(counts[0][c],total,leafLaplace) > m_BiProbCrit) {
892                      bestLaplace = temp;
893                      bestClass = c;
894                      bestVal = theval;
895                      bestCounts = copyCounts(counts);
896                   }
897                }
898                // add to graft list
899                if(!Double.isNaN(bestClass)) {
900                   GraftSplit gsplit = null;
901                   try {
902                      gsplit = new GraftSplit(a, bestVal, 2,
903                         leafClass, bestCounts);
904                   } catch (Exception e) {
905                     System.err.println("graftsplit error: "+e.getMessage());
906                     System.exit(1);
907                   }
908                   t.add(gsplit);
909                }
910             }
911          }
912          // (b) add to t tuple <n,a,v,k,L',"=">
913          // done this already
914       }
915    }
916
917    // 4. remove from t all tuples <n,a,v,c,L,x> such that L <=
918    //    Laplace(cases(l),c) or prob(x,n,Laplace(cases(l),c) <= 0.05
919    //      -- checked this constraint prior to adding a tuple --
920
921    // *** step six done before step five for efficiency ***
922    // 6. for each <n,a,v,k,L,x> in t ordered on L from highest to lowest
923    // order the tuples from highest to lowest laplace
924    // (this actually orders lowest to highest)
925    Collections.sort(t);
926
927    // 5. remove from t all tuples <n,a,v,c,L,x> such that there is
928    //    no tuple <n',a',v',k',L',x'> such that k' != c & L' < L.
929    for(int x = 0; x < t.size(); x++) {
930       GraftSplit gs = (GraftSplit)t.get(x);
931       if(gs.maxClassForSubsetOfInterest() != leafClass) {
932          break; // reached a graft with class != leafClass, so stop deleting
933       } else {
934          t.remove(x);
935          x--;
936       }
937    }
938
939    // if no potential grafts were found, do nothing and return
940    if(t.size() < 1) {
941       return;
942    }
943
944    // create the distributions for each graft
945    for(int x = t.size()-1; x >= 0; x--) {
946       GraftSplit gs = (GraftSplit)t.get(x);
947       try {
948          gs.buildClassifier(l);
949          gs.deleteGraftedCases(l); // so they don't go down the other branch
950       } catch (Exception e) {
951          System.err.println("graftsplit build error: " + e.getMessage());
952       }
953    }
954
955    // add this stuff to the tree
956    ((C45PruneableClassifierTreeG)parent).setDescendents(t, this);
957  }
958
959  /**
960   * sorts the int array in ascending order by attribute indexed
961   * by a in dataset data. 
962   * @param the data the indices represent
963   * @param the index of the attribute to sort by
964   * @return array of sorted indicies
965   */
966  private int [] sortByAttribute(Instances data, int a) {
967
968    double [] attList = data.attributeToDoubleArray(a);
969    int [] temp = Utils.sort(attList);
970    return temp;
971  }
972
973  /**
974   * deep copy the 2d array of counts
975   *
976   * @param src the array to copy
977   * @return a copy of src
978   */
979  private double [][] copyCounts(double [][] src) {
980
981    double [][] newArr = new double[src.length][0];
982    for(int x = 0; x < src.length; x++) {
983       newArr[x] = new double[src[x].length];
984       for(int y = 0; y < src[x].length; y++) {
985          newArr[x][y] = src[x][y];
986       }
987    }
988    return newArr;
989  }
990 
991
992  /**
993   * Help method for computing class probabilities of
994   * a given instance.
995   *
996   * @throws Exception if something goes wrong
997   */
998  private double getProbsLaplace(int classIndex, Instance instance, double weight)
999       throws Exception {
1000
1001    double [] weights;
1002    double prob = 0;
1003    int treeIndex;
1004    int i,j;
1005
1006    if (m_isLeaf) {
1007       return weight * localModel().classProbLaplace(classIndex, instance, -1);
1008    } else {
1009       treeIndex = localModel().whichSubset(instance);
1010
1011       if (treeIndex == -1) {
1012          weights = localModel().weights(instance);
1013          for (i = 0; i < m_sons.length; i++) {
1014             if (!son(i).m_isEmpty) {
1015                if (!son(i).m_isLeaf) {
1016                   prob += son(i).getProbsLaplace(classIndex, instance,
1017                                                  weights[i] * weight);
1018                } else {
1019                   prob += weight * weights[i] *
1020                     localModel().classProbLaplace(classIndex, instance, i);
1021                }
1022             }
1023          }
1024          return prob;
1025       } else {
1026
1027          if (son(treeIndex).m_isLeaf) {
1028             return weight * localModel().classProbLaplace(classIndex, instance,
1029                                                           treeIndex);
1030          } else {
1031             return son(treeIndex).getProbsLaplace(classIndex,instance,weight);
1032          }
1033       }
1034    }
1035  }
1036
1037
1038  /**
1039   * Help method for computing class probabilities of
1040   * a given instance.
1041   *
1042   * @throws Exception if something goes wrong
1043   */
1044  private double getProbs(int classIndex, Instance instance, double weight)
1045      throws Exception {
1046
1047    double [] weights;
1048    double prob = 0;
1049    int treeIndex;
1050    int i,j;
1051
1052    if (m_isLeaf) {
1053       return weight * localModel().classProb(classIndex, instance, -1);
1054    } else {
1055       treeIndex = localModel().whichSubset(instance);
1056       if (treeIndex == -1) {
1057          weights = localModel().weights(instance);
1058          for (i = 0; i < m_sons.length; i++) {
1059             if (!son(i).m_isEmpty) {
1060                prob += son(i).getProbs(classIndex, instance,
1061                                 weights[i] * weight);
1062             }
1063          }
1064          return prob;
1065       } else {
1066
1067          if (son(treeIndex).m_isEmpty) {
1068             return weight * localModel().classProb(classIndex, instance,
1069                                                    treeIndex);
1070          } else {
1071             return son(treeIndex).getProbs(classIndex, instance, weight);
1072          }
1073       }
1074    }
1075  }
1076
1077
1078
1079  /**
1080   * add the grafted nodes at originalLeaf's position in tree.
1081   * a recursive function that terminates when t is empty.
1082   *
1083   * @param t the list of nodes to graft
1084   * @param originalLeaf the leaf that the grafts are replacing
1085   */
1086  public void setDescendents(ArrayList t, 
1087                             C45PruneableClassifierTreeG originalLeaf) {
1088
1089    Instances headerInfo = new Instances(m_train, 0);
1090
1091    boolean end = false;
1092    ClassifierSplitModel splitmod = null;
1093    C45PruneableClassifierTreeG newNode;
1094    if(t.size() > 0) {
1095       splitmod = (ClassifierSplitModel)t.remove(t.size() - 1);
1096       newNode = new C45PruneableClassifierTreeG(m_toSelectModel, headerInfo,
1097                           splitmod, m_pruneTheTree, m_CF, m_subtreeRaising,
1098                           false, m_relabel, m_cleanup);
1099    } else {
1100       // get the leaf for one of newNode's children
1101       NoSplit kLeaf = ((GraftSplit)localModel()).getOtherLeaf();
1102       newNode = 
1103             new C45PruneableClassifierTreeG(m_toSelectModel, headerInfo,
1104                           kLeaf, m_pruneTheTree, m_CF, m_subtreeRaising,
1105                           true, m_relabel, m_cleanup);
1106       end = true;
1107    }
1108
1109    // behave differently for parent of original leaf, since we don't
1110    // want to destroy any of its other branches
1111    if(m_sons != null) {
1112       for(int x = 0; x < m_sons.length; x++) {
1113          if(son(x).equals(originalLeaf)) {
1114             m_sons[x] = newNode;  // replace originalLeaf with newNode
1115          }
1116       }
1117    } else {
1118
1119       // allocate space for the children
1120       m_sons = new C45PruneableClassifierTreeG[localModel().numSubsets()];
1121 
1122       // get the leaf for one of newNode's children
1123       NoSplit kLeaf = ((GraftSplit)localModel()).getLeaf();
1124       C45PruneableClassifierTreeG kNode = 
1125                 new C45PruneableClassifierTreeG(m_toSelectModel, headerInfo,
1126                               kLeaf, m_pruneTheTree, m_CF, m_subtreeRaising,
1127                               true, m_relabel, m_cleanup);
1128 
1129       // figure where to put the new node
1130       if(((GraftSplit)localModel()).subsetOfInterest() == 0) {
1131          m_sons[0] = kNode;
1132          m_sons[1] = newNode;
1133       } else {
1134          m_sons[0] = newNode;
1135          m_sons[1] = kNode;
1136       }
1137    }
1138    if(!end)
1139       ((C45PruneableClassifierTreeG)newNode).setDescendents
1140                  (t, (C45PruneableClassifierTreeG)originalLeaf);
1141  }
1142
1143
1144  /**
1145   *  class prob with laplace correction (assumes binary class)
1146   */
1147  private double laplaceLeaf(double classIndex) {
1148    double l =  (localModel().distribution().perClass((int)classIndex) + 1.0)
1149               / (localModel().distribution().total() + 2.0);
1150    return l;
1151  }
1152
1153
1154  /**
1155   * Significance test
1156   * @param x
1157   * @param n
1158   * @param r
1159   * @return returns the probability of obtaining x or MORE out of n
1160   * if r proportion of n are positive.
1161   *
1162   * z for normal estimation of binomial probability of obtaining x
1163   * or more out of n, if r proportion of n are positive
1164   */
1165  public double biprob(double x, double n, double r) throws Exception {
1166
1167    return ((((x) - 0.5) - (n) * (r)) / Math.sqrt((n) * (r) * (1.0 - (r))));
1168  }
1169
1170  /**
1171   * Prints tree structure.
1172   */
1173  public String toString() {
1174
1175    try {
1176       StringBuffer text = new StringBuffer();
1177
1178       if(m_isLeaf) {
1179          text.append(": ");
1180          if(m_localModel instanceof GraftSplit)
1181             text.append(((GraftSplit)m_localModel).dumpLabelG(0,m_train));
1182          else
1183             text.append(m_localModel.dumpLabel(0,m_train));
1184       } else
1185          dumpTree(0,text);
1186       text.append("\n\nNumber of Leaves  : \t"+numLeaves()+"\n");
1187       text.append("\nSize of the tree : \t"+numNodes()+"\n");
1188
1189       return text.toString();
1190    } catch (Exception e) {
1191       return "Can't print classification tree.";
1192    }
1193  }
1194
1195  /**
1196   * Help method for printing tree structure.
1197   *
1198   * @throws Exception if something goes wrong
1199   */
1200  protected void dumpTree(int depth,StringBuffer text) throws Exception {
1201
1202    int i,j;
1203
1204    for(i=0;i<m_sons.length;i++) {
1205       text.append("\n");;
1206       for(j=0;j<depth;j++)
1207          text.append("|   ");
1208       text.append(m_localModel.leftSide(m_train));
1209       text.append(m_localModel.rightSide(i, m_train));
1210       if(m_sons[i].m_isLeaf) {
1211          text.append(": ");
1212          if(m_localModel instanceof GraftSplit)
1213             text.append(((GraftSplit)m_localModel).dumpLabelG(i,m_train));
1214          else
1215             text.append(m_localModel.dumpLabel(i,m_train));
1216       } else
1217          ((C45PruneableClassifierTreeG)m_sons[i]).dumpTree(depth+1,text);
1218     }
1219  }
1220 
1221  /**
1222   * Returns the revision string.
1223   *
1224   * @return            the revision
1225   */
1226  public String getRevision() {
1227    return RevisionUtils.extract("$Revision: 5532 $");
1228  }
1229}
Note: See TracBrowser for help on using the repository browser.