source: src/main/java/weka/classifiers/trees/ft/FTNode.java @ 21

Last change on this file since 21 was 4, checked in by gnappo, 14 years ago

Import di weka.

File size: 10.1 KB
Line 
1/*
2 *    This program is free software; you can redistribute it and/or modify
3 *    it under the terms of the GNU General Public License as published by
4 *    the Free Software Foundation; either version 2 of the License, or
5 *    (at your option) any later version.
6 *
7 *    This program is distributed in the hope that it will be useful,
8 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
9 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
10 *    GNU General Public License for more details.
11 *
12 *    You should have received a copy of the GNU General Public License
13 *    along with this program; if not, write to the Free Software
14 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
15 */
16
17/*
18 *    FTNode.java
19 *    Copyright (C) 2007 University of Porto, Porto, Portugal
20 *
21 */
22
23package weka.classifiers.trees.ft;
24
25import weka.classifiers.functions.SimpleLinearRegression;
26import weka.classifiers.trees.j48.C45ModelSelection;
27import weka.classifiers.trees.j48.C45Split;
28import weka.classifiers.trees.j48.NoSplit;
29import weka.core.Attribute;
30import weka.core.Instance;
31import weka.core.DenseInstance;
32import weka.core.Instances;
33import weka.core.RevisionUtils;
34import weka.core.Utils;
35
36/**
37 * Class for Functional tree structure.
38 *
39 * @author Jo\~{a}o Gama
40 * @author Carlos Ferreira
41 *
42 * @version $Revision: 6088 $
43 */
/**
 * A single node of a Functional Tree (FT). Each node may carry a logistic
 * regression model ("constructor") built by LogitBoost, in addition to a
 * C4.5-style univariate split; pruning decides per node whether to keep the
 * subtree, the constructor model, or a plain majority-class leaf.
 */
public class FTNode 
  extends FTtree {   
 
  /** for serialization. */
  private static final long serialVersionUID = 2317688685139295063L;

  /**
   * Constructor for Functional tree node.
   *
   * @param errorOnProbabilities use error on probabilities for the stopping criterion of LogitBoost?
   * @param numBoostingIterations fixed number of LogitBoost iterations; a non-positive value
   * means the number is chosen automatically (by AIC or cross-validation, see buildTree)
   * @param minNumInstances minimum number of instances at which a node is considered for splitting
   * @param weightTrimBeta weight trimming threshold handed on to the LogitBoost procedure
   * @param useAIC if true, boosting is stopped via an information criterion instead of CV
   */
  public FTNode( boolean errorOnProbabilities, int numBoostingIterations, 
                 int minNumInstances, double weightTrimBeta, boolean useAIC) {
    m_errorOnProbabilities = errorOnProbabilities;
    m_fixedNumIterations = numBoostingIterations;     
    m_minNumInstances = minNumInstances;
    // hard cap on LogitBoost iterations when the count is selected automatically
    m_maxIterations = 200;
    setWeightTrimBeta(weightTrimBeta);
    setUseAIC(useAIC);
  }         
   
  /**
   * Method for building a Functional tree (only called for the root node).
   * Grows an initial Functional Tree.
   *
   * @param data the data to train with
   * @throws Exception if something goes wrong
   */
  public void buildClassifier(Instances data) throws Exception{
       
    // Extend the dataset with the extra attributes used by the functional
    // splits (inherited helper; see also removeExtAttributes in buildTree).
    data= insertNewAttr(data); 
         
    // Build the tree recursively using all the data; no inherited regressions
    // and zero inherited parameters at the root.
    buildTree(data, null, data.numInstances(), 0);
       
  }

  /**
   * Method for building the tree structure.
   * Builds a logistic model, splits the node and recursively builds tree for child nodes.
   * @param data the training data passed on to this node
   * @param higherRegressions An array of regression functions produced by LogitBoost at higher
   * levels in the tree. They represent a logistic regression model that is refined locally
   * at this node.
   * @param totalInstanceWeight the total number of training examples
   * @param higherNumParameters effective number of parameters in the logistic regression model built
   * in parent nodes
   * @throws Exception if something goes wrong
   */
  public void buildTree(Instances data, SimpleLinearRegression[][] higherRegressions, 
                        double totalInstanceWeight, double higherNumParameters) throws Exception{

    // Save training data and weights; m_train is a copy with the extended
    // attributes stripped, used for fitting the logistic model only.
    m_totalInstanceWeight = totalInstanceWeight;
    m_train = new Instances(data);
    m_train= removeExtAttributes( m_train);
       
    m_isLeaf = true;
    m_sons = null;
       
    m_numInstances = m_train.numInstances();
    m_numClasses = m_train.numClasses();                               
       
    // Init: numeric-only view of the data used by the LogitBoost regressions.
    m_numericData = getNumericData(m_train);             
    m_numericDataHeader = new Instances(m_numericData, 0);
       
    m_regressions = initRegressions();
    m_numRegressions = 0;
       
    // Inherit the regression functions fitted at ancestor nodes (refined
    // locally below); the root starts with an empty set per class.
    if (higherRegressions != null) m_higherRegressions = higherRegressions;
    else m_higherRegressions = new SimpleLinearRegression[m_numClasses][0];     

    m_numHigherRegressions = m_higherRegressions[0].length;     
       
    m_numParameters = higherNumParameters;
       
    // Build the logistic model, but only if there are enough instances to
    // run the boosting procedure at all.
    if (m_numInstances >= m_numFoldsBoosting) {
      if (m_fixedNumIterations > 0){
        // user-specified fixed number of LogitBoost iterations
        performBoosting(m_fixedNumIterations);
      } else if (getUseAIC()) {
        // stop boosting via information criterion
        performBoostingInfCriterion();
      } else {
        // default: choose the iteration count by cross-validation
        performBoostingCV();
      }
    }
       
    m_numParameters += m_numRegressions;
       
    // Only keep the simple regression functions that correspond to the
    // selected number of LogitBoost iterations.
    m_regressions = selectRegressions(m_regressions);
         
    boolean grow;
       
    // Compute the logistic class probabilities for every training instance
    // and write them into the first numClasses attributes of 'data' (the
    // extended attributes inserted at the root), so the split selection below
    // can also split on the constructor's output.
    double[][] FsConst;
    double[] probsConst;
    int j;
    FsConst = getFs(m_numericData);
       
    for (j = 0; j < data.numInstances(); j++)
      {
        probsConst=probs(FsConst[j]);
        // Tally the constructor model's training error (used by pruning).
        if (data.instance(j).classValue()!=getConstError(probsConst)) m_constError=m_constError +1;
        for (int i = 0; i<data.classAttribute().numValues() ; i++)
          data.instance(j).setValue(i,probsConst[i]);
      } 
 
       
    // A fresh model selection per node is needed because the attribute values
    // of 'data' are rewritten above (dynamic data).
    m_modelSelection=new  C45ModelSelection(m_minNumInstances, data, true);
     
    m_localModel = m_modelSelection.selectModel(data);
       
    // Split the node only if it holds more than minNumInstances and the
    // chosen split actually produces more than one subset.
    if (m_numInstances > m_minNumInstances) {
      grow = (m_localModel.numSubsets() > 1);
    } else {
      grow = false;
    }
       
    // logitboost uses distribution for instance
    m_hasConstr=false;
    m_train=data;
    if (grow) { 
      // Create and build the children of this node.
      m_isLeaf = false;
      Instances[] localInstances = m_localModel.split(data);
      // If the selected split attribute is one of the first numClasses
      // attributes — i.e. one of the constructor-probability attributes
      // written above — this node splits on the constructor model's output.
      if (((C45Split)m_localModel).attIndex() >=0 && ((C45Split)m_localModel).attIndex()< data.classAttribute().numValues()) 
        m_hasConstr=true;                         
               
      m_sons = new FTNode[m_localModel.numSubsets()];
      for (int i = 0; i < m_sons.length; i++) {
        m_sons[i] = new FTNode(m_errorOnProbabilities,m_fixedNumIterations, 
                               m_minNumInstances,getWeightTrimBeta(), getUseAIC());
        // Children inherit this node's regressions merged with the ancestors'.
        m_sons[i].buildTree(localInstances[i],
                            mergeArrays(m_regressions, m_higherRegressions), m_totalInstanceWeight, m_numParameters);           
        localInstances[i] = null;
      }     
    } 
    else{
      // No useful split: this node becomes a majority-class leaf.
      m_leafclass=m_localModel.distribution().maxClass();
    }
  }

  /**
   * Method for pruning a tree using the C4.5 pruning procedure: at each node
   * the estimated error of (a) a plain leaf, (b) the constructor model and
   * (c) the full subtree are compared, and the cheapest option is kept.
   *
   * @return the estimated error of the (possibly pruned) subtree rooted here
   * @exception Exception if something goes wrong
   */
  public double prune() throws Exception {

    double errorsLeaf;
    double errorsTree;
    double errorsConstModel;
    double treeError=0;
    int i;
    double probBranch;

    // Estimated error if this subtree were replaced by a leaf without the
    // constructor model.
    errorsLeaf = getEstimatedErrorsForDistribution(m_localModel.distribution());
    if (m_isLeaf ) { 
      return  errorsLeaf;
    } else {
      // Estimated error of the constructor (logistic) model at this node.
      errorsConstModel = getEtimateConstModel(m_localModel.distribution());
      errorsTree=0;
      // Subtree error: branch errors weighted by the fraction of instances
      // that reach each branch.
      for (i = 0; i < m_sons.length; i++) {
        probBranch = m_localModel.distribution().perBag(i) /
          m_localModel.distribution().total();
        errorsTree += probBranch* m_sons[i].prune();
      }
      // Decide if a plain leaf is the best choice.

      if (Utils.smOrEq(errorsLeaf, errorsTree) && Utils.smOrEq(errorsLeaf, errorsConstModel)) {
        // Free son trees and collapse to a majority-class leaf.
        m_sons = null;
        m_isLeaf = true;
        m_hasConstr=false;
        m_leafclass=m_localModel.distribution().maxClass();
        // Get NoSplit model for the node.
        m_localModel = new NoSplit(m_localModel.distribution());
        treeError=errorsLeaf;

      }else{
        // Decide if the constructor model is the best choice.
        if (Utils.smOrEq(errorsConstModel, errorsTree)) {
          // Free son trees; keep the constructor model as the leaf predictor.
          m_sons = null;
          m_isLeaf = true;
          m_hasConstr =true;
          // Get NoSplit model for the node.
          m_localModel = new NoSplit(m_localModel.distribution());
          treeError=errorsConstModel;
        } else
          // Keep the full subtree.
          treeError=errorsTree;
      }
    }
    return  treeError;
  }
 
  /**
   * Returns the class probabilities for an instance given by the Functional Tree.
   * @param instance the instance
   * @return the array of probabilities
   * @throws Exception if the distribution can't be computed successfully
   */
  public double[] distributionForInstance(Instance instance) throws Exception {
    double[] probs;

    if (m_isLeaf && m_hasConstr) { // leaf with constructor model
      probs = modelDistributionForInstance(instance);
    } else { 
      if (m_isLeaf && !m_hasConstr)
        {
          // Plain leaf: all probability mass on the majority class.
          probs=new double[instance.numClasses()];
          probs[m_leafclass]=(double)1;
        }else{
               
        // Inner node: first compute the constructor model's probabilities...
        probs = modelDistributionForInstance(instance);
        // ...then build an auxiliary "extended" instance whose first
        // numClasses values are those probabilities, mirroring the layout
        // used at training time, so the split can be evaluated on it.
        Instance instanceSplit=new DenseInstance(instance.numAttributes()+instance.numClasses());
        instanceSplit.setDataset(instance.dataset());
           
        // Insert the probability attributes and their values.
        // NOTE(review): this temporarily mutates the shared dataset header
        // (insert below, delete further down) — presumably restored exactly,
        // but not safe for concurrent classification; verify before reuse.
        for(int i=0; i< instance.numClasses();i++)
          {
            instanceSplit.dataset().insertAttributeAt( new Attribute("N"+ (instance.numClasses()-i)), 0);
            instanceSplit.setValue(i,probs[i]);
          }
        // Copy the original attribute values after the inserted ones.
        for(int i=0; i< instance.numAttributes();i++)
          instanceSplit.setValue(i+instance.numClasses(),instance.value(i));
           
        // Choose the branch the extended instance falls into.
        int branch = m_localModel.whichSubset(instanceSplit); //split
           
        // Remove the attributes added above, restoring the dataset header.
        for(int i=0; i< instance.numClasses();i++)
          instanceSplit.dataset().deleteAttributeAt(0);
           
        // Recurse into the selected child with the ORIGINAL instance.
        probs = m_sons[branch].distributionForInstance(instance);
      }
    }
    return probs;
       
  }
 
  /**
   * Returns the revision string.
   *
   * @return            the revision
   */
  public String getRevision() {
    return RevisionUtils.extract("$Revision: 6088 $");
  }
}
Note: See TracBrowser for help on using the repository browser.