source: src/main/java/weka/classifiers/trees/ft/FTLeavesNode.java @ 22

Last change on this file since 22 was 4, checked in by gnappo, 14 years ago

Import of weka.

File size: 8.7 KB
/*
 *    This program is free software; you can redistribute it and/or modify
 *    it under the terms of the GNU General Public License as published by
 *    the Free Software Foundation; either version 2 of the License, or
 *    (at your option) any later version.
 *
 *    This program is distributed in the hope that it will be useful,
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *    GNU General Public License for more details.
 *
 *    You should have received a copy of the GNU General Public License
 *    along with this program; if not, write to the Free Software
 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 */

/*
 *    FTLeavesNode.java
 *    Copyright (C) 2007 University of Porto, Porto, Portugal
 *
 */

package weka.classifiers.trees.ft;

import weka.classifiers.functions.SimpleLinearRegression;
import weka.classifiers.trees.j48.C45ModelSelection;
import weka.classifiers.trees.j48.NoSplit;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.RevisionUtils;
import weka.core.Utils;

/**
 * Class for a node of the Functional Leaves version of a Functional Tree.
 *
 * @author João Gama
 * @author Carlos Ferreira
 *
 * @version $Revision: 6088 $
 */
public class FTLeavesNode
  extends FTtree {

  /** for serialization. */
  private static final long serialVersionUID = 950601378326259315L;

  /**
   * Constructor for Functional Leaves tree node.
   *
   * @param errorOnProbabilities use error on probabilities as the stopping criterion for LogitBoost?
   * @param numBoostingIterations fixed number of LogitBoost iterations (used when greater than 0)
   * @param minNumInstances minimum number of instances at which a node is considered for splitting
   * @param weightTrimBeta the weight trimming parameter for LogitBoost
   * @param useAIC use the AIC-based stopping criterion for LogitBoost?
   */
  public FTLeavesNode(boolean errorOnProbabilities, int numBoostingIterations, int minNumInstances,
                      double weightTrimBeta, boolean useAIC) {
    m_errorOnProbabilities = errorOnProbabilities;
    m_fixedNumIterations = numBoostingIterations;
    m_minNumInstances = minNumInstances;
    m_maxIterations = 200;
    setWeightTrimBeta(weightTrimBeta);
    setUseAIC(useAIC);
  }

  /**
   * Method for building a Functional Leaves tree (only called for the root node).
   * Grows an initial Functional Tree.
   *
   * @param data the data to train with
   * @throws Exception if something goes wrong
   */
  public void buildClassifier(Instances data) throws Exception {
    buildTree(data, null, data.numInstances(), 0);
  }

  /**
   * Method for building the tree structure.
   * Builds a logistic model, splits the node and recursively builds the tree for the child nodes.
   *
   * @param data the training data passed on to this node
   * @param higherRegressions an array of regression functions produced by LogitBoost at higher
   * levels in the tree; they represent a logistic regression model that is refined locally
   * at this node
   * @param totalInstanceWeight the total number of training examples
   * @param higherNumParameters effective number of parameters of the logistic regression model built
   * in parent nodes
   * @throws Exception if something goes wrong
   */
  public void buildTree(Instances data, SimpleLinearRegression[][] higherRegressions,
                        double totalInstanceWeight, double higherNumParameters) throws Exception {

    // store the training data and bookkeeping fields
    m_totalInstanceWeight = totalInstanceWeight;
    m_train = new Instances(data);

    m_isLeaf = true;
    m_sons = null;

    m_numInstances = m_train.numInstances();
    m_numClasses = m_train.numClasses();

    // initialize the numeric view of the data and the regression lists
    m_numericData = getNumericData(m_train);
    m_numericDataHeader = new Instances(m_numericData, 0);

    m_regressions = initRegressions();
    m_numRegressions = 0;

    if (higherRegressions != null) m_higherRegressions = higherRegressions;
    else m_higherRegressions = new SimpleLinearRegression[m_numClasses][0];

    m_numHigherRegressions = m_higherRegressions[0].length;

    m_numParameters = higherNumParameters;

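    // Fit this node's logistic (constructor) model with LogitBoost. The number of
    // iterations is chosen in one of three ways: a fixed, user-supplied count; the
    // AIC-based stopping criterion; or otherwise cross-validation.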
    // build logistic model
    if (m_numInstances >= m_numFoldsBoosting) {
      if (m_fixedNumIterations > 0) {
        performBoosting(m_fixedNumIterations);
      } else if (getUseAIC()) {
        performBoostingInfCriterion();
      } else {
        performBoostingCV();
      }
    }

    m_numParameters += m_numRegressions;

    // only keep the simple regression functions that correspond to the selected number of LogitBoost iterations
    m_regressions = selectRegressions(m_regressions);

    boolean grow;

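    // Evaluate the constructor model on this node's training data: m_constError
    // counts the instances whose class differs from the model's prediction; this
    // count is used later by the pruning procedure.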
    // compute the logistic class probabilities
    double[][] FsConst = getFs(m_numericData);
    double[] probsConst;

    for (int j = 0; j < data.numInstances(); j++) {
      probsConst = probs(FsConst[j]);
      // tally the constructor model's misclassifications
      if (data.instance(j).classValue() != getConstError(probsConst)) m_constError = m_constError + 1;
    }

    // choose the split point on the node's data
    m_modelSelection = new C45ModelSelection(m_minNumInstances, data, true);
    m_localModel = m_modelSelection.selectModel(data);

    // split the node only if it holds more than minNumInstances instances...
    if (m_numInstances > m_minNumInstances) {
      grow = (m_localModel.numSubsets() > 1);
    } else {
      grow = false;
    }

    // no constructor model is attached to this node yet
    m_hasConstr = false;
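    // When the node is split, each child starts from the logistic model built so
    // far: the regressions fitted here are merged with those inherited from the
    // ancestors and passed down, so each child only refines the model locally.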
    if (grow) {
      // create and build the children of this node
      m_isLeaf = false;
      Instances[] localInstances = m_localModel.split(data);
      m_sons = new FTLeavesNode[m_localModel.numSubsets()];

      for (int i = 0; i < m_sons.length; i++) {
        m_sons[i] = new FTLeavesNode(m_errorOnProbabilities, m_fixedNumIterations,
                                     m_minNumInstances, getWeightTrimBeta(), getUseAIC());
        m_sons[i].buildTree(localInstances[i],
                            mergeArrays(m_regressions, m_higherRegressions), m_totalInstanceWeight, m_numParameters);
        localInstances[i] = null;
      }
    } else {
      m_leafclass = m_localModel.distribution().maxClass();
    }
  }

  /**
   * Prunes a tree using the C4.5 pruning procedure.
   *
   * @return the estimated error of the pruned subtree
   * @exception Exception if something goes wrong
   */
  public double prune() throws Exception {

    double errorsLeaf;
    double errorsTree;
    double errorsConstModel;
    double treeError = 0;
    int i;
    double probBranch;

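    // C4.5-style bottom-up pruning compares three candidates at each internal node:
    // (1) replace the subtree by a majority-class leaf, (2) replace it by a leaf
    // that keeps the constructor (logistic) model, or (3) keep the subtree. The
    // candidate with the smallest estimated error wins, with ties broken in the
    // order listed.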
    // compute the error if this tree were a leaf without the constructor model
    errorsLeaf = getEstimatedErrorsForDistribution(m_localModel.distribution());
    if (m_isLeaf) {
      return errorsLeaf;
    } else {
      // compute the error of the constructor model
      errorsConstModel = getEtimateConstModel(m_localModel.distribution());
      errorsTree = 0;
      for (i = 0; i < m_sons.length; i++) {
        probBranch = m_localModel.distribution().perBag(i) /
          m_localModel.distribution().total();
        errorsTree += probBranch * m_sons[i].prune();
      }

      // decide if a plain leaf is the best choice
      if (Utils.smOrEq(errorsLeaf, errorsTree) && Utils.smOrEq(errorsLeaf, errorsConstModel)) {
        // free the son trees
        m_sons = null;
        m_isLeaf = true;
        m_hasConstr = false;
        m_leafclass = m_localModel.distribution().maxClass();
        // get a NoSplit model for the node
        m_localModel = new NoSplit(m_localModel.distribution());
        treeError = errorsLeaf;
      } else {
        // decide if the constructor model is the best choice
        if (Utils.smOrEq(errorsConstModel, errorsTree)) {
          // free the son trees
          m_sons = null;
          m_isLeaf = true;
          m_hasConstr = true;
          // get a NoSplit model for the node
          m_localModel = new NoSplit(m_localModel.distribution());
          treeError = errorsConstModel;
        } else {
          treeError = errorsTree;
        }
      }
    }
    return treeError;
  }

  /**
   * Returns the class probabilities for an instance given by the Functional Leaves tree.
   *
   * @param instance the instance
   * @return the array of probabilities
   * @throws Exception if the distribution can't be computed successfully
   */
  public double[] distributionForInstance(Instance instance) throws Exception {
    double[] probs;

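    // Three cases: a leaf that kept its constructor model predicts with the
    // logistic model; a leaf without one predicts its majority class with
    // probability 1; an internal node routes the instance to the matching child.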
    if (m_isLeaf && m_hasConstr) {
      // leaf with a constructor model: use the logistic model's distribution
      probs = modelDistributionForInstance(instance);
    } else if (m_isLeaf && !m_hasConstr) {
      // leaf without a constructor model: predict the majority class
      probs = new double[instance.numClasses()];
      probs[m_leafclass] = 1;
    } else {
      // internal node: follow the branch chosen by the split
      int branch = m_localModel.whichSubset(instance);
      probs = m_sons[branch].distributionForInstance(instance);
    }
    return probs;
  }

  /**
   * Returns the revision string.
   *
   * @return the revision
   */
  public String getRevision() {
    return RevisionUtils.extract("$Revision: 6088 $");
  }
}