source: src/main/java/weka/classifiers/trees/j48/C45PruneableClassifierTree.java @ 11

Last change on this file since 11 was 4, checked in by gnappo, 14 years ago

Import di weka.

File size: 9.6 KB
Line 
1/*
2 *    This program is free software; you can redistribute it and/or modify
3 *    it under the terms of the GNU General Public License as published by
4 *    the Free Software Foundation; either version 2 of the License, or
5 *    (at your option) any later version.
6 *
7 *    This program is distributed in the hope that it will be useful,
8 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
9 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
10 *    GNU General Public License for more details.
11 *
12 *    You should have received a copy of the GNU General Public License
13 *    along with this program; if not, write to the Free Software
14 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
15 */
16
17/*
18 *    C45PruneableClassifierTree.java
19 *    Copyright (C) 1999 University of Waikato, Hamilton, New Zealand
20 *
21 */
22
23package weka.classifiers.trees.j48;
24
25import weka.core.Capabilities;
26import weka.core.Instances;
27import weka.core.RevisionUtils;
28import weka.core.Utils;
29import weka.core.Capabilities.Capability;
30
31/**
32 * Class for handling a tree structure that can
33 * be pruned using C4.5 procedures.
34 *
35 * @author Eibe Frank (eibe@cs.waikato.ac.nz)
36 * @version $Revision: 6073 $
37 */
38
public class C45PruneableClassifierTree 
  extends ClassifierTree {

  /** for serialization */
  static final long serialVersionUID = -4813820170260388194L;
 
  /** True if the tree is to be pruned. */
  boolean m_pruneTheTree = false;
 
  /** True if the tree is to be collapsed (before pruning). */
  boolean m_collapseTheTree = false;

  /** The confidence factor for pruning (smaller values mean more pruning). */
  float m_CF = 0.25f;

  /** Is subtree raising to be performed during pruning? */
  boolean m_subtreeRaising = true;

  /** Cleanup (drop references to the training data) after the tree has been built. */
  boolean m_cleanup = true;

  /**
   * Constructor for pruneable tree structure. Stores reference
   * to associated training data at each node.
   *
   * @param toSelectLocModel selection method for local splitting model
   * @param pruneTree true if the tree is to be pruned
   * @param cf the confidence factor for pruning
   * @param raiseTree true if subtree raising is to be performed when pruning
   * @param cleanup true if training-data references are to be released after building
   * @param collapseTree true if the tree is to be collapsed before pruning
   * @throws Exception if something goes wrong
   */
  public C45PruneableClassifierTree(ModelSelection toSelectLocModel,
				    boolean pruneTree,float cf,
				    boolean raiseTree,
				    boolean cleanup,
				    boolean collapseTree)
       throws Exception {

    super(toSelectLocModel);

    m_pruneTheTree = pruneTree;
    m_CF = cf;
    m_subtreeRaising = raiseTree;
    m_cleanup = cleanup;
    m_collapseTheTree = collapseTree;
  }

  /**
   * Returns default capabilities of the classifier tree: nominal, numeric
   * and date attributes with missing values allowed; nominal class only.
   *
   * @return      the capabilities of this classifier tree
   */
  public Capabilities getCapabilities() {
    Capabilities result = super.getCapabilities();
    result.disableAll();

    // attributes
    result.enable(Capability.NOMINAL_ATTRIBUTES);
    result.enable(Capability.NUMERIC_ATTRIBUTES);
    result.enable(Capability.DATE_ATTRIBUTES);
    result.enable(Capability.MISSING_VALUES);

    // class
    result.enable(Capability.NOMINAL_CLASS);
    result.enable(Capability.MISSING_CLASS_VALUES);

    // instances
    result.setMinimumNumberInstances(0);
   
    return result;
  }

  /**
   * Method for building a pruneable classifier tree.
   * Processing order: grow the full tree, optionally collapse it,
   * optionally prune it, then optionally release the training data.
   *
   * @param data the data for building the tree
   * @throws Exception if something goes wrong
   */
  public void buildClassifier(Instances data) throws Exception {

    // can classifier tree handle the data?
    getCapabilities().testWithFail(data);

    // remove instances with missing class (on a copy, so the
    // caller's dataset is left untouched)
    data = new Instances(data);
    data.deleteWithMissingClass();
   
   buildTree(data, m_subtreeRaising);
   if (m_collapseTheTree) {
     collapse();
   }
   if (m_pruneTheTree) {
     prune();
   }
   if (m_cleanup) {
     // keep only an empty header copy of the data
     cleanup(new Instances(data, 0));
   }
  }

  /**
   * Collapses a tree to a node if training error doesn't increase.
   * Works top-down: if turning this node into a leaf does not increase
   * the number of training errors, the whole subtree is discarded;
   * otherwise the test recurses into the children.
   */
  public final void collapse(){

    double errorsOfSubtree;
    double errorsOfTree;
    int i;

    if (!m_isLeaf){
      errorsOfSubtree = getTrainingErrors();
      errorsOfTree = localModel().distribution().numIncorrect();
      // 1E-3 tolerance guards against floating-point round-off when
      // comparing the (weighted) error counts
      if (errorsOfSubtree >= errorsOfTree-1E-3){

        // Free adjacent trees
        m_sons = null;
        m_isLeaf = true;
			
        // Get NoSplit Model for tree.
        m_localModel = new NoSplit(localModel().distribution());
      }else
        for (i=0;i<m_sons.length;i++)
          son(i).collapse();
    }
  }

  /**
   * Prunes a tree using C4.5's pruning procedure (bottom-up,
   * based on the pessimistic error estimate). At each node it compares
   * the estimated errors of (a) the subtree as is, (b) this node made
   * a leaf, and (c) the largest branch raised into this node's place
   * (only if subtree raising is enabled), and keeps the best option.
   *
   * @throws Exception if something goes wrong
   */
  public void prune() throws Exception {

    double errorsLargestBranch;
    double errorsLeaf;
    double errorsTree;
    int indexOfLargestBranch;
    C45PruneableClassifierTree largestBranch;
    int i;

    if (!m_isLeaf){

      // Prune all subtrees.
      for (i=0;i<m_sons.length;i++)
        son(i).prune();

      // Compute error for largest branch
      indexOfLargestBranch = localModel().distribution().maxBag();
      if (m_subtreeRaising) {
        errorsLargestBranch = son(indexOfLargestBranch).
          getEstimatedErrorsForBranch((Instances)m_train);
      } else {
        // sentinel: with subtree raising disabled the largest branch
        // can never win the comparisons below
        errorsLargestBranch = Double.MAX_VALUE;
      }

      // Compute error if this Tree would be leaf
      errorsLeaf = 
        getEstimatedErrorsForDistribution(localModel().distribution());

      // Compute error for the whole subtree
      errorsTree = getEstimatedErrors();

      // Decide if leaf is best choice. The 0.1 slack makes the
      // comparison favor the simpler structure on (near-)ties.
      if (Utils.smOrEq(errorsLeaf,errorsTree+0.1) &&
          Utils.smOrEq(errorsLeaf,errorsLargestBranch+0.1)){

        // Free son Trees
        m_sons = null;
        m_isLeaf = true;
		
        // Get NoSplit Model for node.
        m_localModel = new NoSplit(localModel().distribution());
        return;
      }

      // Decide if largest branch is better choice
      // than whole subtree.
      if (Utils.smOrEq(errorsLargestBranch,errorsTree+0.1)){
        // graft the largest branch into this node's place
        largestBranch = son(indexOfLargestBranch);
        m_sons = largestBranch.m_sons;
        m_localModel = largestBranch.localModel();
        m_isLeaf = largestBranch.m_isLeaf;
        // redistribute this node's training data over the raised subtree,
        // then prune again since the errors have changed
        newDistribution(m_train);
        prune();
      }
    }
  }

  /**
   * Returns a newly created tree (built but not collapsed/pruned),
   * configured with the same settings as this one.
   *
   * @param data the data to work with
   * @return the new tree
   * @throws Exception if something goes wrong
   */
  protected ClassifierTree getNewTree(Instances data) throws Exception {
   
    C45PruneableClassifierTree newTree = 
      new C45PruneableClassifierTree(m_toSelectModel, m_pruneTheTree, m_CF,
				     m_subtreeRaising, m_cleanup, m_collapseTheTree);
    newTree.buildTree((Instances)data, m_subtreeRaising);

    return newTree;
  }

  /**
   * Computes estimated errors for tree: sum of the leaf estimates
   * over the whole subtree rooted at this node.
   *
   * @return the estimated errors
   */
  private double getEstimatedErrors(){

    double errors = 0;
    int i;

    if (m_isLeaf)
      return getEstimatedErrorsForDistribution(localModel().distribution());
    else{
      for (i=0;i<m_sons.length;i++)
        errors = errors+son(i).getEstimatedErrors();
      return errors;
    }
  }
 
  /**
   * Computes estimated errors for one branch, as if the given data
   * were routed through it (used to evaluate subtree raising).
   *
   * @param data the data to work with
   * @return the estimated errors
   * @throws Exception if something goes wrong
   */
  private double getEstimatedErrorsForBranch(Instances data) 
       throws Exception {

    Instances [] localInstances;
    double errors = 0;
    int i;

    if (m_isLeaf)
      return getEstimatedErrorsForDistribution(new Distribution(data));
    else{
      // temporarily re-fit the split's distribution to the candidate data,
      // split it, then restore the original distribution: this evaluation
      // must not permanently alter the node
      Distribution savedDist = localModel().m_distribution;
      localModel().resetDistribution(data);
      localInstances = (Instances[])localModel().split(data);
      localModel().m_distribution = savedDist;
      for (i=0;i<m_sons.length;i++)
        errors = errors+
          son(i).getEstimatedErrorsForBranch(localInstances[i]);
      return errors;
    }
  }

  /**
   * Computes estimated errors for leaf: the observed errors plus a
   * pessimistic correction derived from the confidence factor m_CF
   * (presumably the C4.5 upper confidence bound — see Stats.addErrs).
   *
   * @param theDistribution the distribution to use
   * @return the estimated errors
   */
  private double getEstimatedErrorsForDistribution(Distribution
						   theDistribution){

    if (Utils.eq(theDistribution.total(),0))
      return 0;
    else
      return theDistribution.numIncorrect()+
	Stats.addErrs(theDistribution.total(),
		      theDistribution.numIncorrect(),m_CF);
  }

  /**
   * Computes errors of tree on training data: sum of numIncorrect()
   * over all leaves of the subtree rooted at this node.
   *
   * @return the training errors
   */
  private double getTrainingErrors(){

    double errors = 0;
    int i;

    if (m_isLeaf)
      return localModel().distribution().numIncorrect();
    else{
      for (i=0;i<m_sons.length;i++)
        errors = errors+son(i).getTrainingErrors();
      return errors;
    }
  }

  /**
   * Method just exists to make program easier to read: casts the
   * inherited local model to its split-model type.
   *
   * @return the local split model
   */
  private ClassifierSplitModel localModel(){
   
    return (ClassifierSplitModel)m_localModel;
  }

  /**
   * Computes new distributions of instances for nodes
   * in tree. Called after subtree raising so that the raised subtree's
   * statistics reflect the data that now reaches it.
   *
   * @param data the data to compute the distributions for
   * @throws Exception if something goes wrong
   */
  private void newDistribution(Instances data) throws Exception {

    Instances [] localInstances;

    localModel().resetDistribution(data);
    m_train = data;
    if (!m_isLeaf){
      localInstances = 
	(Instances [])localModel().split(data);
      for (int i = 0; i < m_sons.length; i++)
	son(i).newDistribution(localInstances[i]);
    } else {

      // Check whether there are some instances at the leaf now!
      if (!Utils.eq(data.sumOfWeights(), 0)) {
	m_isEmpty = false;
      }
    }
  }

  /**
   * Method just exists to make program easier to read: returns the
   * child at the given index, cast to this class.
   */
  private C45PruneableClassifierTree son(int index){

    return (C45PruneableClassifierTree)m_sons[index];
  }
 
  /**
   * Returns the revision string.
   *
   * @return		the revision
   */
  public String getRevision() {
    return RevisionUtils.extract("$Revision: 6073 $");
  }
}
Note: See TracBrowser for help on using the repository browser.