source: branches/MetisMQI/src/main/java/weka/classifiers/rules/part/PruneableDecList.java

Last change on this file was 29, checked in by gnappo, 15 years ago

Taggata versione per la demo e aggiunto branch.

File size: 5.6 KB
Line 
/*
 *    This program is free software; you can redistribute it and/or modify
 *    it under the terms of the GNU General Public License as published by
 *    the Free Software Foundation; either version 2 of the License, or
 *    (at your option) any later version.
 *
 *    This program is distributed in the hope that it will be useful,
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *    GNU General Public License for more details.
 *
 *    You should have received a copy of the GNU General Public License
 *    along with this program; if not, write to the Free Software
 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 */
16
/*
 *    PruneableDecList.java
 *    Copyright (C) 1999 University of Waikato, Hamilton, New Zealand
 *
 */
22
23package weka.classifiers.rules.part;
24
25import weka.classifiers.trees.j48.Distribution;
26import weka.classifiers.trees.j48.ModelSelection;
27import weka.classifiers.trees.j48.NoSplit;
28import weka.core.Instances;
29import weka.core.RevisionUtils;
30import weka.core.Utils;
31
32/**
33 * Class for handling a partial tree structure that
34 * can be pruned using a pruning set.
35 *
36 * @author Eibe Frank (eibe@cs.waikato.ac.nz)
37 * @version $Revision: 1.10 $
38 */
39public class PruneableDecList
40  extends ClassifierDecList {
41
42  /** for serialization */
43  private static final long serialVersionUID = -7228103346297172921L;
44 
45  /**
46   * Constructor for pruneable partial tree structure.
47   *
48   * @param toSelectLocModel selection method for local splitting model
49   * @param minNum minimum number of objects in leaf
50   */
51  public PruneableDecList(ModelSelection toSelectLocModel,
52                          int minNum) {
53                               
54    super(toSelectLocModel, minNum);
55  }
56 
57  /**
58   * Method for building a pruned partial tree.
59   *
60   * @throws Exception if tree can't be built successfully
61   */
62  public void buildRule(Instances train,
63                        Instances test) throws Exception { 
64   
65    buildDecList(train, test, false);
66
67    cleanup(new Instances(train, 0));
68  }
69
70  /**
71   * Builds the partial tree with hold out set
72   *
73   * @throws Exception if something goes wrong
74   */
75  public void buildDecList(Instances train, Instances test, 
76                           boolean leaf) throws Exception {
77   
78    Instances [] localTrain,localTest;
79    int index,ind;
80    int i,j;
81    double sumOfWeights;
82    NoSplit noSplit;
83   
84    m_train = null;
85    m_isLeaf = false;
86    m_isEmpty = false;
87    m_sons = null;
88    indeX = 0;
89    sumOfWeights = train.sumOfWeights();
90    noSplit = new NoSplit (new Distribution((Instances)train));
91    if (leaf)
92      m_localModel = noSplit;
93    else
94      m_localModel = m_toSelectModel.selectModel(train, test);
95    m_test = new Distribution(test, m_localModel);
96    if (m_localModel.numSubsets() > 1) {
97      localTrain = m_localModel.split(train);
98      localTest = m_localModel.split(test);
99      train = null;
100      test = null;
101      m_sons = new ClassifierDecList [m_localModel.numSubsets()];
102      i = 0;
103      do {
104        i++;
105        ind = chooseIndex();
106        if (ind == -1) {
107          for (j = 0; j < m_sons.length; j++) 
108            if (m_sons[j] == null)
109              m_sons[j] = getNewDecList(localTrain[j],localTest[j],true);
110          if (i < 2) {
111            m_localModel = noSplit;
112            m_isLeaf = true;
113            m_sons = null;
114            if (Utils.eq(sumOfWeights,0))
115              m_isEmpty = true;
116            return;
117          }
118          ind = 0;
119          break;
120        } else 
121          m_sons[ind] = getNewDecList(localTrain[ind],localTest[ind],false);
122      } while ((i < m_sons.length) && (m_sons[ind].m_isLeaf));
123     
124      // Check if all successors are leaves
125      for (j = 0; j < m_sons.length; j++) 
126        if ((m_sons[j] == null) || (!m_sons[j].m_isLeaf))
127          break;
128      if (j == m_sons.length) {
129        pruneEnd();
130        if (!m_isLeaf) 
131          indeX = chooseLastIndex();
132      }else 
133        indeX = chooseLastIndex();
134    }else{
135      m_isLeaf = true;
136      if (Utils.eq(sumOfWeights, 0))
137        m_isEmpty = true;
138    }
139  }
140 
141  /**
142   * Returns a newly created tree.
143   *
144   * @param train train data
145   * @param test test data
146   * @param leaf
147   * @throws Exception if something goes wrong
148   */
149  protected ClassifierDecList getNewDecList(Instances train, Instances test, 
150                                            boolean leaf) throws Exception {
151         
152    PruneableDecList newDecList = 
153      new PruneableDecList(m_toSelectModel, m_minNumObj);
154   
155    newDecList.buildDecList((Instances)train, test, leaf);
156   
157    return newDecList;
158  }
159
160  /**
161   * Prunes the end of the rule.
162   */
163  protected void pruneEnd() throws Exception {
164   
165    double errorsLeaf, errorsTree;
166   
167    errorsTree = errorsForTree();
168    errorsLeaf = errorsForLeaf();
169    if (Utils.smOrEq(errorsLeaf,errorsTree)){ 
170      m_isLeaf = true;
171      m_sons = null;
172      m_localModel = new NoSplit(localModel().distribution());
173    }
174  }
175
176  /**
177   * Computes error estimate for tree.
178   */
179  private double errorsForTree() throws Exception {
180
181    Distribution test;
182
183    if (m_isLeaf)
184      return errorsForLeaf();
185    else {
186      double error = 0;
187      for (int i = 0; i < m_sons.length; i++) 
188        if (Utils.eq(son(i).localModel().distribution().total(),0)) {
189          error += m_test.perBag(i)-
190            m_test.perClassPerBag(i,localModel().distribution().
191                                maxClass());
192        } else
193          error += ((PruneableDecList)son(i)).errorsForTree();
194
195      return error;
196    }
197  }
198
199  /**
200   * Computes estimated errors for leaf.
201   */
202  private double errorsForLeaf() throws Exception {
203
204    return m_test.total()-
205            m_test.perClass(localModel().distribution().maxClass());
206  }
207 
208  /**
209   * Returns the revision string.
210   *
211   * @return            the revision
212   */
213  public String getRevision() {
214    return RevisionUtils.extract("$Revision: 1.10 $");
215  }
216}
Note: See TracBrowser for help on using the repository browser.