source: branches/MetisMQI/src/main/java/weka/classifiers/rules/part/C45PruneableDecList.java

Last change on this file was changeset 29, checked in by gnappo 15 years ago.

Taggata versione per la demo e aggiunto branch. (English: "Tagged the demo version and added a branch.")

File size: 5.4 KB
Line 
1/*
2 *    This program is free software; you can redistribute it and/or modify
3 *    it under the terms of the GNU General Public License as published by
4 *    the Free Software Foundation; either version 2 of the License, or
5 *    (at your option) any later version.
6 *
7 *    This program is distributed in the hope that it will be useful,
8 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
9 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
10 *    GNU General Public License for more details.
11 *
12 *    You should have received a copy of the GNU General Public License
13 *    along with this program; if not, write to the Free Software
14 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
15 */
16
17/*
18 *    C45PruneableDecList.java
19 *    Copyright (C) 1999 University of Waikato, Hamilton, New Zealand
20 *
21 */
22
23package weka.classifiers.rules.part;
24
25import weka.classifiers.trees.j48.Distribution;
26import weka.classifiers.trees.j48.ModelSelection;
27import weka.classifiers.trees.j48.NoSplit;
28import weka.classifiers.trees.j48.Stats;
29import weka.core.Instances;
30import weka.core.RevisionUtils;
31import weka.core.Utils;
32
33/**
34 * Class for handling a partial tree structure pruned using C4.5's
35 * pruning heuristic.
36 *
37 * @author Eibe Frank (eibe@cs.waikato.ac.nz)
38 * @version $Revision: 1.9 $
39 */
40public class C45PruneableDecList
41  extends ClassifierDecList{
42
43  /** for serialization */
44  private static final long serialVersionUID = -2757684345218324559L;
45   
46  /** CF */
47  private double CF = 0.25;
48 
49  /**
50   * Constructor for pruneable tree structure. Stores reference
51   * to associated training data at each node.
52   *
53   * @param toSelectLocModel selection method for local splitting model
54   * @param cf the confidence factor for pruning
55   * @param minNum the minimum number of objects in a leaf
56   * @exception Exception if something goes wrong
57   */
58  public C45PruneableDecList(ModelSelection toSelectLocModel, 
59                             double cf, int minNum) 
60       throws Exception {
61                               
62    super(toSelectLocModel, minNum);
63   
64    CF = cf;
65  }
66 
67  /**
68   * Builds the partial tree without hold out set.
69   *
70   * @exception Exception if something goes wrong
71   */
72  public void buildDecList(Instances data, boolean leaf) throws Exception {
73   
74    Instances [] localInstances,localPruneInstances;
75    int index,ind;
76    int i,j;
77    double sumOfWeights;
78    NoSplit noSplit;
79   
80    m_train = null;
81    m_test = null;
82    m_isLeaf = false;
83    m_isEmpty = false;
84    m_sons = null;
85    indeX = 0;
86    sumOfWeights = data.sumOfWeights();
87    noSplit = new NoSplit (new Distribution((Instances)data));
88    if (leaf)
89      m_localModel = noSplit;
90    else
91      m_localModel = m_toSelectModel.selectModel(data);
92    if (m_localModel.numSubsets() > 1) {
93      localInstances = m_localModel.split(data);
94      data = null;
95      m_sons = new ClassifierDecList [m_localModel.numSubsets()];
96      i = 0;
97      do {
98        i++;
99        ind = chooseIndex();
100        if (ind == -1) {
101          for (j = 0; j < m_sons.length; j++) 
102            if (m_sons[j] == null)
103              m_sons[j] = getNewDecList(localInstances[j],true);
104          if (i < 2) {
105            m_localModel = noSplit;
106            m_isLeaf = true;
107            m_sons = null;
108            if (Utils.eq(sumOfWeights,0))
109              m_isEmpty = true;
110            return;
111          }
112          ind = 0;
113          break;
114        } else 
115          m_sons[ind] = getNewDecList(localInstances[ind],false);
116      } while ((i < m_sons.length) && (m_sons[ind].m_isLeaf));
117     
118      // Check if all successors are leaves
119      for (j = 0; j < m_sons.length; j++) 
120        if ((m_sons[j] == null) || (!m_sons[j].m_isLeaf))
121          break;
122      if (j == m_sons.length) {
123        pruneEnd();
124        if (!m_isLeaf) 
125          indeX = chooseLastIndex();
126      }else 
127        indeX = chooseLastIndex();
128    }else{
129      m_isLeaf = true;
130      if (Utils.eq(sumOfWeights, 0))
131        m_isEmpty = true;
132    }
133  }
134 
135  /**
136   * Returns a newly created tree.
137   *
138   * @exception Exception if something goes wrong
139   */
140  protected ClassifierDecList getNewDecList(Instances data, boolean leaf) 
141       throws Exception {
142         
143    C45PruneableDecList newDecList = 
144      new C45PruneableDecList(m_toSelectModel,CF, m_minNumObj);
145   
146    newDecList.buildDecList((Instances)data, leaf);
147   
148    return newDecList;
149  }
150
151  /**
152   * Prunes the end of the rule.
153   */
154  protected void pruneEnd() {
155   
156    double errorsLeaf, errorsTree;
157   
158    errorsTree = getEstimatedErrorsForTree();
159    errorsLeaf = getEstimatedErrorsForLeaf();
160    if (Utils.smOrEq(errorsLeaf,errorsTree+0.1)) { // +0.1 as in C4.5
161      m_isLeaf = true;
162      m_sons = null;
163      m_localModel = new NoSplit(localModel().distribution());
164    }
165  }
166 
167  /**
168   * Computes estimated errors for tree.
169   */
170  private double getEstimatedErrorsForTree() {
171
172    if (m_isLeaf)
173      return getEstimatedErrorsForLeaf();
174    else {
175      double error = 0;
176      for (int i = 0; i < m_sons.length; i++) 
177        if (!Utils.eq(son(i).localModel().distribution().total(),0))
178          error += ((C45PruneableDecList)son(i)).getEstimatedErrorsForTree();
179      return error;
180    }
181  }
182 
183  /**
184   * Computes estimated errors for leaf.
185   */
186  public double getEstimatedErrorsForLeaf() {
187 
188    double errors = localModel().distribution().numIncorrect();
189
190    return errors+Stats.addErrs(localModel().distribution().total(),
191                                errors,(float)CF);
192  }
193 
194  /**
195   * Returns the revision string.
196   *
197   * @return            the revision
198   */
199  public String getRevision() {
200    return RevisionUtils.extract("$Revision: 1.9 $");
201  }
202}
Note: See TracBrowser for help on using the repository browser.