source: src/main/java/weka/classifiers/trees/ft/FTInnerNode.java @ 9

Last change on this file since 9 was 4, checked in by gnappo, 14 years ago

Import di weka.

File size: 9.8 KB
Line 
1/*
2 *    This program is free software; you can redistribute it and/or modify
3 *    it under the terms of the GNU General Public License as published by
4 *    the Free Software Foundation; either version 2 of the License, or
5 *    (at your option) any later version.
6 *
7 *    This program is distributed in the hope that it will be useful,
8 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
9 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
10 *    GNU General Public License for more details.
11 *
12 *    You should have received a copy of the GNU General Public License
13 *    along with this program; if not, write to the Free Software
14 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
15 */
16
17/*
18 *    FTInnerNode.java
19 *    Copyright (C) 2007 University of Porto, Porto, Portugal
20 *
21 */
22
23package weka.classifiers.trees.ft;
24
25import weka.classifiers.functions.SimpleLinearRegression;
26import weka.classifiers.trees.j48.C45ModelSelection;
27import weka.classifiers.trees.j48.C45Split;
28import weka.classifiers.trees.j48.NoSplit;
29import weka.core.Attribute;
30import weka.core.Instance;
31import weka.core.DenseInstance;
32import weka.core.Instances;
33import weka.core.RevisionUtils;
34import weka.core.Utils;
35
36/**
37 * Class for Functional Inner tree structure.
38 *
39 * @author Jo\~{a}o Gama
40 * @author Carlos Ferreira
41 *
42 * @version $Revision: 6088 $
43 */
44public class FTInnerNode 
45  extends FTtree {   
46     
47  /** for serialization. */
48  private static final long serialVersionUID = -1125334488640233181L;
49
50  /**
51   * Constructor for Functional Inner tree node.
52   *
53   * @param errorOnProbabilities Use error on probabilities for stopping criterion of LogitBoost?
54   * @param numBoostingIterations sets the numBoostingIterations parameter
55   * @param minNumInstances minimum number of instances at which a node is considered for splitting
56   */
57  public FTInnerNode(boolean errorOnProbabilities,int numBoostingIterations,
58                     int minNumInstances, double weightTrimBeta, boolean useAIC) {
59    m_errorOnProbabilities = errorOnProbabilities;
60    m_fixedNumIterations = numBoostingIterations;     
61    m_minNumInstances = minNumInstances;
62    m_maxIterations = 200;
63    setWeightTrimBeta(weightTrimBeta);
64    setUseAIC(useAIC);
65  }         
66   
67  /**
68   * Method for building a Functional Inner tree (only called for the root node).
69   * Grows an initial Functional Tree.
70   *
71   * @param data the data to train with
72   * @throws Exception if something goes wrong
73   */
74  public void buildClassifier(Instances data) throws Exception{
75       
76    // add new attributes to original dataset
77    data= insertNewAttr(data); 
78       
79    //build tree using all the data
80    buildTree(data, null, data.numInstances(), 0);
81       
82  }
83
84  /**
85   * Method for building the tree structure.
86   * Builds a logistic model, splits the node and recursively builds tree for child nodes.
87   * @param data the training data passed on to this node
88   * @param higherRegressions An array of regression functions produced by LogitBoost at higher
89   * levels in the tree. They represent a logistic regression model that is refined locally
90   * at this node.
91   * @param totalInstanceWeight the total number of training examples
92   * @param higherNumParameters effective number of parameters in the logistic regression model built
93   * in parent nodes
94   * @throws Exception if something goes wrong
95   */
96  public void buildTree(Instances data, SimpleLinearRegression[][] higherRegressions, 
97                        double totalInstanceWeight, double higherNumParameters) throws Exception{
98
99    //save some stuff
100    m_totalInstanceWeight = totalInstanceWeight;
101    m_train = new Instances(data);
102       
103    m_train= removeExtAttributes( m_train);
104       
105    m_isLeaf = true;
106    m_sons = null;
107       
108    m_numInstances = m_train.numInstances();
109    m_numClasses = m_train.numClasses();                               
110       
111    //init
112    m_numericData = getNumericData(m_train);             
113    m_numericDataHeader = new Instances(m_numericData, 0);
114       
115    m_regressions = initRegressions();
116    m_numRegressions = 0;
117       
118    if (higherRegressions != null) m_higherRegressions = higherRegressions;
119    else m_higherRegressions = new SimpleLinearRegression[m_numClasses][0];     
120
121    m_numHigherRegressions = m_higherRegressions[0].length;     
122       
123    m_numParameters = higherNumParameters;
124       
125    //build logistic model
126    if (m_numInstances >= m_numFoldsBoosting) {
127      if (m_fixedNumIterations > 0){
128        performBoosting(m_fixedNumIterations);
129      } else if (getUseAIC()) {
130        performBoostingInfCriterion();
131      } else {
132        performBoostingCV();
133      }
134    }
135       
136    m_numParameters += m_numRegressions;
137       
138    //only keep the simple regression functions that correspond to the selected number of LogitBoost iterations
139    m_regressions = selectRegressions(m_regressions);
140         
141    boolean grow;
142       
143    //Compute logistic probs
144    double[][] FsConst;
145    double[] probsConst;
146    int j;
147    FsConst = getFs(m_numericData);
148       
149    for (j = 0; j < data.numInstances(); j++)
150      {
151        probsConst=probs(FsConst[j]);
152        // Computes constructor error
153        if (data.instance(j).classValue()!=getConstError(probsConst)) m_constError=m_constError +1;
154        for (int i = 0; i<data.classAttribute().numValues() ; i++)
155          data.instance(j).setValue(i,probsConst[i]);
156      }
157       
158    // needed by dynamic data
159    m_modelSelection=new  C45ModelSelection(m_minNumInstances, data, true);
160    m_localModel = m_modelSelection.selectModel(data);
161    //split node if more than minNumInstances...
162    if (m_numInstances > m_minNumInstances) {
163      grow = (m_localModel.numSubsets() > 1);
164    } else {
165      grow = false;
166    }
167         
168    // logitboost uses distribution for instance
169    m_hasConstr=false;
170    m_train=data;
171    if (grow) { 
172      //create and build children of node
173      m_isLeaf = false;             
174      Instances[] localInstances = m_localModel.split(data);
175      // If split attribute is a extended attribute, the node has a constructor
176      if (((C45Split)m_localModel).attIndex() >=0 && ((C45Split)m_localModel).attIndex()< data.classAttribute().numValues()) 
177        m_hasConstr=true;
178               
179      m_sons = new FTInnerNode[m_localModel.numSubsets()];
180      for (int i = 0; i < m_sons.length; i++) {
181        m_sons[i] = new FTInnerNode (m_errorOnProbabilities, m_fixedNumIterations, 
182                                     m_minNumInstances,getWeightTrimBeta(), getUseAIC());
183        m_sons[i].buildTree(localInstances[i],
184                            mergeArrays(m_regressions, m_higherRegressions), m_totalInstanceWeight, m_numParameters);           
185        localInstances[i] = null;
186      }     
187    } 
188    else{
189      m_leafclass=m_localModel.distribution().maxClass();
190    }
191  }
192
193   
194   
195  /**
196   * Prunes a tree using C4.5 pruning procedure.
197   *
198   * @exception Exception if something goes wrong
199   */
200  public double prune() throws Exception {
201
202    double errorsLeaf;
203    double errorsTree;
204    double errorsConstModel;
205    double treeError=0;
206    int i;
207    double probBranch;
208
209    // Compute error if this Tree would be leaf without contructor
210    errorsLeaf = getEstimatedErrorsForDistribution(m_localModel.distribution());
211    if (m_isLeaf ) { 
212      return  errorsLeaf;
213    } else {
214      //Computes da error of the constructor model
215      errorsConstModel = getEtimateConstModel(m_localModel.distribution());
216      errorsTree=0;
217      for (i = 0; i < m_sons.length; i++) {
218        probBranch = m_localModel.distribution().perBag(i) /
219          m_localModel.distribution().total();
220        errorsTree += probBranch* m_sons[i].prune();
221      }
222         
223      // Decide if leaf is best choice.
224      if (Utils.smOrEq(errorsLeaf, errorsTree)) {
225        // Free son Trees
226        m_sons = null;
227        m_isLeaf = true;
228        m_hasConstr=false;
229        m_leafclass=m_localModel.distribution().maxClass();
230        // Get NoSplit Model for node.
231        m_localModel = new NoSplit(m_localModel.distribution());
232        treeError=errorsLeaf;
233
234      } else{
235        treeError=errorsTree;
236      }
237    }
238    return  treeError;
239  }
240
241  /**
242   * Returns the class probabilities for an instance given by the Functional tree.
243   * @param instance the instance
244   * @return the array of probabilities
245   */
246  public double[] distributionForInstance(Instance instance) throws Exception {
247    double[] probs;
248                                               
249    //also needed for logitboost
250    if (m_isLeaf && m_hasConstr) { //leaf
251      //leaf: use majoraty class or constructor model
252      probs = modelDistributionForInstance(instance);
253    } else {
254      if (m_isLeaf && !m_hasConstr)
255        {
256          probs=new double[instance.numClasses()];
257          probs[m_leafclass]=(double)1; 
258        }else{
259               
260        probs = modelDistributionForInstance(instance);
261        //Built auxiliary split instance   
262        Instance instanceSplit=new DenseInstance(instance.numAttributes()+instance.numClasses());
263           
264        instanceSplit.setDataset(instance.dataset());
265     
266        for(int i=0; i< instance.numClasses();i++)
267          {
268            instanceSplit.dataset().insertAttributeAt( new Attribute("N"+ (instance.numClasses()-i)), 0);
269            instanceSplit.setValue(i,probs[i]);
270          }
271        for(int i=0; i< instance.numAttributes();i++)
272          instanceSplit.setValue(i+instance.numClasses(),instance.value(i));
273         
274           
275           
276        int branch = m_localModel.whichSubset(instanceSplit); //split
277        for(int i=0; i< instance.numClasses();i++)
278          instanceSplit.dataset().deleteAttributeAt(0);
279           
280        //probs = m_sons[branch].distributionForInstance(instance);
281        probs = m_sons[branch].distributionForInstance(instance);
282      }
283    }
284    return probs;       
285  }
286 
287  /**
288   * Returns the revision string.
289   *
290   * @return            the revision
291   */
292  public String getRevision() {
293    return RevisionUtils.extract("$Revision: 6088 $");
294  }
295}
Note: See TracBrowser for help on using the repository browser.