source: src/main/java/weka/classifiers/trees/j48/PruneableClassifierTree.java @ 28

Last change on this file since 28 was 4, checked in by gnappo, 14 years ago

Import of weka.

File size: 6.1 KB
/*
 *    This program is free software; you can redistribute it and/or modify
 *    it under the terms of the GNU General Public License as published by
 *    the Free Software Foundation; either version 2 of the License, or
 *    (at your option) any later version.
 *
 *    This program is distributed in the hope that it will be useful,
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *    GNU General Public License for more details.
 *
 *    You should have received a copy of the GNU General Public License
 *    along with this program; if not, write to the Free Software
 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 */

/*
 *    PruneableClassifierTree.java
 *    Copyright (C) 1999 University of Waikato, Hamilton, New Zealand
 *
 */

package weka.classifiers.trees.j48;

import weka.core.Capabilities;
import weka.core.Instances;
import weka.core.RevisionUtils;
import weka.core.Utils;
import weka.core.Capabilities.Capability;

import java.util.Random;

/**
 * Class for handling a tree structure that can
 * be pruned using a pruning set.
 *
 * @author Eibe Frank (eibe@cs.waikato.ac.nz)
 * @version $Revision: 5533 $
 */
public class PruneableClassifierTree
  extends ClassifierTree {

  /** for serialization */
  static final long serialVersionUID = -555775736857600201L;

  /** True if the tree is to be pruned. */
  private boolean pruneTheTree = false;

  /** How many subsets of equal size? One used for pruning, the rest for training. */
  private int numSets = 3;

  /** Cleanup after the tree has been built. */
  private boolean m_cleanup = true;

  /** The random number seed. */
  private int m_seed = 1;

  /**
   * Constructor for pruneable tree structure. Stores reference
   * to associated training data at each node.
   *
   * @param toSelectLocModel selection method for local splitting model
   * @param pruneTree true if the tree is to be pruned
   * @param num number of subsets of equal size
   * @param cleanup true if instance data is to be released after the tree has been built
   * @param seed the seed value to use
   * @throws Exception if something goes wrong
   */
  public PruneableClassifierTree(ModelSelection toSelectLocModel,
                                 boolean pruneTree, int num, boolean cleanup,
                                 int seed)
       throws Exception {

    super(toSelectLocModel);

    pruneTheTree = pruneTree;
    numSets = num;
    m_cleanup = cleanup;
    m_seed = seed;
  }

  /**
   * Returns default capabilities of the classifier tree.
   *
   * @return      the capabilities of this classifier tree
   */
  public Capabilities getCapabilities() {
    Capabilities result = super.getCapabilities();
    result.disableAll();

    // attributes
    result.enable(Capability.NOMINAL_ATTRIBUTES);
    result.enable(Capability.NUMERIC_ATTRIBUTES);
    result.enable(Capability.DATE_ATTRIBUTES);
    result.enable(Capability.MISSING_VALUES);

    // class
    result.enable(Capability.NOMINAL_CLASS);
    result.enable(Capability.MISSING_CLASS_VALUES);

    // instances
    result.setMinimumNumberInstances(0);

    return result;
  }

  /**
   * Method for building a pruneable classifier tree.
   *
   * @param data the data to build the tree from
   * @throws Exception if tree can't be built successfully
   */
  public void buildClassifier(Instances data)
       throws Exception {

    // can classifier tree handle the data?
    getCapabilities().testWithFail(data);

    // remove instances with missing class
    data = new Instances(data);
    data.deleteWithMissingClass();

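    // Stratify the data into numSets folds; the first numSets - 1 folds are
    // used to grow the tree and the last fold is held out as the pruning set.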
    Random random = new Random(m_seed);
    data.stratify(numSets);
    buildTree(data.trainCV(numSets, numSets - 1, random),
              data.testCV(numSets, numSets - 1), false);
    if (pruneTheTree) {
      prune();
    }
    if (m_cleanup) {
      cleanup(new Instances(data, 0));
    }
  }

  /**
   * Prunes a tree.
   *
   * @throws Exception if tree can't be pruned successfully
   */
  public void prune() throws Exception {

    if (!m_isLeaf) {

      // Prune all subtrees.
      for (int i = 0; i < m_sons.length; i++)
        son(i).prune();

      // Decide if leaf is best choice.
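      // Reduced-error pruning: turn this node into a leaf if the leaf would make
      // no more errors on the pruning set than the subtree rooted here.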
      if (Utils.smOrEq(errorsForLeaf(), errorsForTree())) {

        // Free son Trees
        m_sons = null;
        m_isLeaf = true;

        // Get NoSplit Model for node.
        m_localModel = new NoSplit(localModel().distribution());
      }
    }
  }

  /**
   * Returns a newly created tree.
   *
   * @param train the training data
   * @param test the test data
   * @return the generated tree
   * @throws Exception if something goes wrong
   */
  protected ClassifierTree getNewTree(Instances train, Instances test)
       throws Exception {

    PruneableClassifierTree newTree =
      new PruneableClassifierTree(m_toSelectModel, pruneTheTree, numSets, m_cleanup,
                                  m_seed);
    newTree.buildTree(train, test, false);
    return newTree;
  }

  /**
   * Computes estimated errors for tree.
   *
   * @return the estimated errors
   * @throws Exception if error estimate can't be computed
   */
  private double errorsForTree() throws Exception {

    double errors = 0;

    if (m_isLeaf)
      return errorsForLeaf();
    else {
      for (int i = 0; i < m_sons.length; i++)
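        // A branch that received no training instances has no subtree to consult:
        // charge the pruning-set instances in that bag against this node's majority class.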
        if (Utils.eq(localModel().distribution().perBag(i), 0)) {
          errors += m_test.perBag(i)-
            m_test.perClassPerBag(i,localModel().distribution().
                                maxClass());
        } else
          errors += son(i).errorsForTree();

      return errors;
    }
  }

  /**
   * Computes estimated errors for leaf.
   *
   * @return the estimated errors
   * @throws Exception if error estimate can't be computed
   */
  private double errorsForLeaf() throws Exception {

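    // Pruning-set instances at this node that do not belong to its majority class.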
    return m_test.total()-
      m_test.perClass(localModel().distribution().maxClass());
  }

  /**
   * Method just exists to make program easier to read.
   */
  private ClassifierSplitModel localModel() {

    return (ClassifierSplitModel)m_localModel;
  }

  /**
   * Method just exists to make program easier to read.
   */
  private PruneableClassifierTree son(int index) {

    return (PruneableClassifierTree)m_sons[index];
  }

  /**
   * Returns the revision string.
   *
   * @return            the revision
   */
  public String getRevision() {
    return RevisionUtils.extract("$Revision: 5533 $");
  }
}
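
For context, a minimal usage sketch (not part of this file): when reduced-error pruning is enabled, J48 builds its tree through PruneableClassifierTree, so the options below map onto this class's constructor arguments (numFolds corresponds to numSets, seed to m_seed). The class name and dataset path in the sketch are illustrative only, and it assumes a Weka build from the same era as this file.

import weka.classifiers.trees.J48;
import weka.core.Instances;
import weka.core.converters.ConverterUtils.DataSource;

public class ReducedErrorPruningExample {

  public static void main(String[] args) throws Exception {
    // Load a dataset with a nominal class (the path is illustrative).
    Instances data = DataSource.read("/path/to/dataset.arff");
    data.setClassIndex(data.numAttributes() - 1);

    // With reduced-error pruning enabled, J48 grows and prunes its tree
    // via PruneableClassifierTree; numFolds and seed are passed through.
    J48 tree = new J48();
    tree.setReducedErrorPruning(true);
    tree.setNumFolds(3);   // one of the three folds becomes the pruning set
    tree.setSeed(1);
    tree.buildClassifier(data);

    System.out.println(tree);
  }
}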