source: src/main/java/weka/classifiers/trees/ft/FTtree.java @ 11

Last change on this file since 11 was 4, checked in by gnappo, 14 years ago

Import di weka.

File size: 21.1 KB
Line 
1/*
2 *    This program is free software; you can redistribute it and/or modify
3 *    it under the terms of the GNU General Public License as published by
4 *    the Free Software Foundation; either version 2 of the License, or
5 *    (at your option) any later version.
6 *
7 *    This program is distributed in the hope that it will be useful,
8 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
9 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
10 *    GNU General Public License for more details.
11 *
12 *    You should have received a copy of the GNU General Public License
13 *    along with this program; if not, write to the Free Software
14 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
15 */
16
17/*
18 *    FTNode.java
19 *    Copyright (C) 2007 University of Porto, Porto, Portugal
20 *
21 */
22
23package weka.classifiers.trees.ft;
24
25import weka.classifiers.functions.SimpleLinearRegression;
26import weka.classifiers.trees.j48.BinC45ModelSelection;
27import weka.classifiers.trees.j48.BinC45Split;
28import weka.classifiers.trees.j48.C45Split;
29import weka.classifiers.trees.j48.ClassifierSplitModel;
30import weka.classifiers.trees.j48.Distribution;
31import weka.classifiers.trees.j48.ModelSelection;
32import weka.classifiers.trees.j48.Stats;
33import weka.classifiers.trees.lmt.LogisticBase;
34import weka.core.Attribute;
35import weka.core.Instance;
36import weka.core.Instances;
37import weka.core.RevisionUtils;
38import weka.core.Utils;
39import weka.filters.Filter;
40import weka.filters.supervised.attribute.NominalToBinary;
41
42import java.util.Vector;
43
44/**
45 * Abstract class for Functional tree structure.
46 *
47 * @author Jo\~{a}o Gama
48 * @author Carlos Ferreira
49 *
50 * @version $Revision: 4899 $
51 */
52public abstract class FTtree 
53  extends LogisticBase {   
54 
55  /** for serialization */
56  static final long serialVersionUID = 1862737145870398755L;
57   
58  /** Total number of training instances. */
59  protected double m_totalInstanceWeight;
60   
61  /** Node id*/
62  protected int m_id;
63   
64  /** ID of logistic model at leaf*/
65  protected int m_leafModelNum;
66 
67  /**minimum number of instances at which a node is considered for splitting*/
68  protected int m_minNumInstances;
69
70  /**ModelSelection object (for splitting)*/
71  protected ModelSelection m_modelSelection;     
72
73  /**Filter to convert nominal attributes to binary*/
74  protected NominalToBinary m_nominalToBinary; 
75   
76  /**Simple regression functions fit by LogitBoost at higher levels in the tree*/
77  protected SimpleLinearRegression[][] m_higherRegressions;
78   
79  /**Number of simple regression functions fit by LogitBoost at higher levels in the tree*/
80  protected int m_numHigherRegressions = 0;
81   
82  /**Number of instances at the node*/
83  protected int m_numInstances;   
84
85  /**The ClassifierSplitModel (for splitting)*/
86  protected ClassifierSplitModel m_localModel; 
87   
88  /**Auxiliary copy ClassifierSplitModel (for splitting)*/
89  protected ClassifierSplitModel m_auxLocalModel; 
90 
91  /**Array of children of the node*/
92  protected FTtree[] m_sons; 
93   
94  /** Stores leaf class value */ 
95  protected int m_leafclass;
96   
97  /**True if node is leaf*/
98  protected boolean m_isLeaf;
99   
100  /**True if node has or splits on constructor */
101  protected boolean m_hasConstr=true;
102   
103  /** Constructor error */
104  protected double  m_constError=0;
105   
106  /** Confidence level */
107  protected float m_CF = 0.10f; 
108                       
109  /**
110   * Method for building a Functional Tree (only called for the root node).
111   * Grows an initial Functional Tree.
112   *
113   * @param data the data to train with
114   * @throws Exception if something goes wrong
115   */
116  public abstract void buildClassifier(Instances data) throws Exception;
117
118  /**
119   * Abstract method for building the tree structure.
120   * Builds a logistic model, splits the node and recursively builds tree for child nodes.
121   * @param data the training data passed on to this node
122   * @param higherRegressions An array of regression functions produced by LogitBoost at higher
123   * levels in the tree. They represent a logistic regression model that is refined locally
124   * at this node.
125   * @param totalInstanceWeight the total number of training examples
126   * @param higherNumParameters effective number of parameters in the logistic regression model built
127   * in parent nodes
128   * @throws Exception if something goes wrong
129   */
130  public abstract void buildTree(Instances data, SimpleLinearRegression[][] higherRegressions, 
131                                 double totalInstanceWeight, double higherNumParameters) throws Exception;
132   
133  /**
134   * Abstract Method that prunes a tree using C4.5 pruning procedure.
135   *
136   * @exception Exception if something goes wrong
137   */
138  public abstract double prune() throws Exception; 
139 
140  /** Inserts new attributes in current dataset or instance
141   *
142   * @exception Exception if something goes wrong
143   */
144  protected Instances insertNewAttr(Instances data) throws Exception{
145   
146    int i;
147    for (i=0; i<data.classAttribute().numValues(); i++)
148      {
149        data.insertAttributeAt( new Attribute("N"+ i), i); 
150      }
151    return data;
152  }
153
154  /** Removes extended attributes in current dataset or instance
155   *
156   * @exception Exception if something goes wrong
157   */
158  protected Instances removeExtAttributes(Instances  data) throws Exception{
159   
160    for (int i=0; i< data.classAttribute().numValues(); i++)
161      {
162        data.deleteAttributeAt(0);
163      }
164    return data;
165  }
166
167  /**
168   * Computes estimated errors for tree.
169   */
170  protected double getEstimatedErrors(){
171
172    double errors = 0;
173    int i;
174
175    if (m_isLeaf)
176      return getEstimatedErrorsForDistribution(m_localModel.distribution());
177    else{
178      for (i=0;i<m_sons.length;i++)
179        errors = errors+ m_sons[i].getEstimatedErrors();
180
181      return errors;
182    }
183  }
184
185  /**
186   * Computes estimated errors for one branch.
187   *
188   * @exception Exception if something goes wrong
189   */
190  protected double getEstimatedErrorsForBranch(Instances data)
191    throws Exception {
192
193    Instances [] localInstances;
194    double errors = 0;
195    int i;
196
197    if (m_isLeaf)
198      return getEstimatedErrorsForDistribution(new Distribution(data));
199    else{
200      Distribution savedDist = m_localModel.distribution();
201      m_localModel.resetDistribution(data);
202      localInstances = (Instances[])m_localModel.split(data);
203      //m_localModel.m_distribution=savedDist;
204      for (i=0;i<m_sons.length;i++)
205        errors = errors+
206          m_sons[i].getEstimatedErrorsForBranch(localInstances[i]);
207      return errors;
208    }
209  }
210
211  /**
212   * Computes estimated errors for leaf.
213   */
214  protected double getEstimatedErrorsForDistribution(Distribution
215                                                     theDistribution){
216    double numInc;
217    double numTotal;
218    if (Utils.eq(theDistribution.total(),0))
219      return 0;
220    else// stats.addErrs returns p - numberofincorrect.=p
221      {
222        numInc=theDistribution.numIncorrect();
223        numTotal=theDistribution.total();
224        return ((Stats.addErrs(numTotal, numInc,m_CF)) + numInc)/numTotal;
225      }
226
227  }
228
229  /**
230   * Computes estimated errors for Constructor Model.
231   */
232  protected double getEtimateConstModel(Distribution theDistribution){
233    double numInc;
234    double numTotal;
235    if (Utils.eq(theDistribution.total(),0))
236      return 0;
237    else// stats.addErrs returns p - numberofincorrect.=p
238      {
239        numTotal=theDistribution.total();
240        return ((Stats.addErrs(numTotal,m_constError,m_CF)) + m_constError)/numTotal;
241      }
242  }
243   
244
245  /**
246   * Method to count the number of inner nodes in the tree
247   * @return the number of inner nodes
248   */
249  public int getNumInnerNodes(){
250    if (m_isLeaf) return 0;
251    int numNodes = 1;
252    for (int i = 0; i < m_sons.length; i++) numNodes += m_sons[i].getNumInnerNodes();
253    return numNodes;
254  }
255
256  /**
257   * Returns the number of leaves in the tree.
258   * Leaves are only counted if their logistic model has changed compared to the one of the parent node.
259   * @return the number of leaves
260   */
261  public int getNumLeaves(){
262    int numLeaves;
263    if (!m_isLeaf) {
264      numLeaves = 0;
265      int numEmptyLeaves = 0;
266      for (int i = 0; i < m_sons.length; i++) {
267        numLeaves += m_sons[i].getNumLeaves();
268        if (m_sons[i].m_isLeaf && !m_sons[i].hasModels()) numEmptyLeaves++;
269      }
270      if (numEmptyLeaves > 1) {
271        numLeaves -= (numEmptyLeaves - 1);
272      }
273    } else {
274      numLeaves = 1;
275    }     
276    return numLeaves;   
277  }
278
279
280     
281  /**
282   * Merges two arrays of regression functions into one
283   * @param a1 one array
284   * @param a2 the other array
285   *
286   * @return an array that contains all entries from both input arrays
287   */
288  protected SimpleLinearRegression[][] mergeArrays(SimpleLinearRegression[][] a1,       
289                                                   SimpleLinearRegression[][] a2){
290    int numModels1 = a1[0].length;
291    int numModels2 = a2[0].length;             
292       
293    SimpleLinearRegression[][] result =
294      new SimpleLinearRegression[m_numClasses][numModels1 + numModels2];
295       
296    for (int i = 0; i < m_numClasses; i++)
297      for (int j = 0; j < numModels1; j++) {
298        result[i][j]  = a1[i][j];
299      }
300    for (int i = 0; i < m_numClasses; i++)
301      for (int j = 0; j < numModels2; j++) result[i][j+numModels1] = a2[i][j];
302    return result;
303  }
304
305  /**
306   * Return a list of all inner nodes in the tree
307   * @return the list of nodes
308   */
309  public Vector getNodes(){
310    Vector nodeList = new Vector();
311    getNodes(nodeList);
312    return nodeList;
313  }
314
315  /**
316   * Fills a list with all inner nodes in the tree
317   *
318   * @param nodeList the list to be filled
319   */
320  public void getNodes(Vector nodeList) {
321    if (!m_isLeaf) {
322      nodeList.add(this);
323      for (int i = 0; i < m_sons.length; i++) m_sons[i].getNodes(nodeList);
324    }   
325  }
326   
327  /**
328   * Returns a numeric version of a set of instances.
329   * All nominal attributes are replaced by binary ones, and the class variable is replaced
330   * by a pseudo-class variable that is used by LogitBoost.
331   */
332  protected Instances getNumericData(Instances train) throws Exception{
333       
334    Instances filteredData = new Instances(train);     
335    m_nominalToBinary = new NominalToBinary();                 
336    m_nominalToBinary.setInputFormat(filteredData);
337    filteredData = Filter.useFilter(filteredData, m_nominalToBinary);   
338
339    return super.getNumericData(filteredData);
340  }
341
342  /**
343   * Computes the F-values of LogitBoost for an instance from the current logistic model at the node
344   * Note that this also takes into account the (partial) logistic model fit at higher levels in
345   * the tree.
346   * @param instance the instance
347   * @return the array of F-values
348   */
349  protected double[] getFs(Instance instance) throws Exception{
350       
351    double [] pred = new double [m_numClasses];
352       
353    //Need to take into account partial model fit at higher levels in the tree (m_higherRegressions)
354    //and the part of the model fit at this node (m_regressions).
355
356    //Fs from m_regressions (use method of LogisticBase)
357    double [] instanceFs = super.getFs(instance);               
358
359    //Fs from m_higherRegressions
360    for (int i = 0; i < m_numHigherRegressions; i++) {
361      double predSum = 0;
362      for (int j = 0; j < m_numClasses; j++) {
363        pred[j] = m_higherRegressions[j][i].classifyInstance(instance);
364        predSum += pred[j];
365      }
366      predSum /= m_numClasses;
367      for (int j = 0; j < m_numClasses; j++) {
368        instanceFs[j] += (pred[j] - predSum) * (m_numClasses - 1) 
369          / m_numClasses;
370      }
371    }
372    return instanceFs; 
373  }
374     
375  /**
376   *
377   * @param probsConst
378   */
379  public int getConstError(double[] probsConst)
380  {
381    return Utils.maxIndex(probsConst);
382  }
383   
384  /**
385   *Returns true if the logistic regression model at this node has changed compared to the
386   *one at the parent node.
387   *@return whether it has changed
388   */
389  public boolean hasModels() {
390    return (m_numRegressions > 0);
391  }
392
393  /**
394   * Returns the class probabilities for an instance according to the logistic model at the node.
395   * @param instance the instance
396   * @return the array of probabilities
397   */
398  public double[] modelDistributionForInstance(Instance instance) throws Exception {
399       
400    //make copy and convert nominal attributes
401    instance = (Instance)instance.copy();               
402    m_nominalToBinary.input(instance);
403    instance = m_nominalToBinary.output();     
404       
405    //set numeric pseudo-class
406    instance.setDataset(m_numericDataHeader);           
407       
408    return probs(getFs(instance));
409  }
410
411  /**
412   * Returns the class probabilities for an instance given by the Functional tree.
413   * @param instance the instance
414   * @return the array of probabilities
415   */
416  public abstract double[] distributionForInstance(Instance instance) throws Exception;
417 
418 
419   
420  /**
421   * Returns a description of the Functional tree (tree structure and logistic models)
422   * @return describing string
423   */
424  public String toString(){     
425    //assign numbers to logistic regression functions at leaves
426    assignLeafModelNumbers(0); 
427    try{
428      StringBuffer text = new StringBuffer();
429           
430      if (m_isLeaf && !m_hasConstr) {
431        text.append(": ");
432        text.append("Class"+"="+ m_leafclass);
433        //text.append("FT_"+m_leafModelNum+":"+getModelParameters());
434      } else {
435               
436        if (m_isLeaf && m_hasConstr) {
437          text.append(": ");
438          text.append("FT_"+m_leafModelNum+":"+getModelParameters());
439                   
440        } else {
441          dumpTree(0,text); 
442        }                   
443      }
444      text.append("\n\nNumber of Leaves  : \t"+numLeaves()+"\n");
445      text.append("\nSize of the Tree : \t"+numNodes()+"\n");   
446               
447      //This prints logistic models after the tree, comment out if only tree should be printed
448      text.append(modelsToString());
449      return text.toString();
450    } catch (Exception e){
451      return "Can't print logistic model tree";
452    }
453  }
454   
455  /**
456   * Returns the number of leaves (normal count).
457   * @return the number of leaves
458   */
459  public int numLeaves() {     
460    if (m_isLeaf) return 1;     
461    int numLeaves = 0;
462    for (int i = 0; i < m_sons.length; i++) numLeaves += m_sons[i].numLeaves();
463    return numLeaves;
464  }
465   
466  /**
467   * Returns the number of nodes.
468   * @return the number of nodes
469   */
470  public int numNodes() {
471    if (m_isLeaf) return 1;     
472    int numNodes = 1;
473    for (int i = 0; i < m_sons.length; i++) numNodes += m_sons[i].numNodes();
474    return numNodes;
475  }
476   
477  /**
478   * Returns a string describing the number of LogitBoost iterations performed at this node, the total number
479   * of LogitBoost iterations performed (including iterations at higher levels in the tree), and the number
480   * of training examples at this node.
481   * @return the describing string
482   */
483  public String getModelParameters(){
484       
485    StringBuffer text = new StringBuffer();
486    int numModels = m_numRegressions+m_numHigherRegressions;
487    text.append(m_numRegressions+"/"+numModels+" ("+m_numInstances+")");
488    return text.toString();
489  }
490       
491  /**
492   * Help method for printing tree structure.
493   *
494   * @throws Exception if something goes wrong
495   */
496  protected void dumpTree(int depth,StringBuffer text) 
497    throws Exception {
498       
499    for (int i = 0; i < m_sons.length; i++) {
500      text.append("\n");
501      for (int j = 0; j < depth; j++)
502        text.append("|   ");
503      if(m_hasConstr)
504        text.append(m_localModel.leftSide(m_train)+ "#" + m_id);
505      else 
506        text.append(m_localModel.leftSide(m_train)); 
507      text.append(m_localModel.rightSide(i, m_train) );
508      if (m_sons[i].m_isLeaf && m_sons[i].m_hasConstr ) {
509        text.append(": ");
510        text.append("FT_"+m_sons[i].m_leafModelNum+":"+m_sons[i].getModelParameters());
511      }else {               
512        if(m_sons[i].m_isLeaf && !m_sons[i].m_hasConstr)
513          {
514            text.append(": ");
515            text.append("Class"+"="+ m_sons[i].m_leafclass); 
516          }
517        else{
518           
519          m_sons[i].dumpTree(depth+1,text);
520        }
521      }
522    }
523  }
524
525  /**
526   * Assigns unique IDs to all nodes in the tree
527   */
528  public int assignIDs(int lastID) {
529       
530    int currLastID = lastID + 1;
531       
532    m_id = currLastID;
533    if (m_sons != null) {
534      for (int i = 0; i < m_sons.length; i++) {
535        currLastID = m_sons[i].assignIDs(currLastID);
536      }
537    }
538    return currLastID;
539  }
540   
541  /**
542   * Assigns numbers to the logistic regression models at the leaves of the tree
543   */
544  public int assignLeafModelNumbers(int leafCounter) {
545    if (!m_isLeaf) {
546      m_leafModelNum = 0;
547      for (int i = 0; i < m_sons.length; i++){
548        leafCounter = m_sons[i].assignLeafModelNumbers(leafCounter);
549      }
550    } else {
551      leafCounter++;
552      m_leafModelNum = leafCounter;
553    } 
554    return leafCounter;
555  }
556
557  /**
558   * Returns an array containing the coefficients of the logistic regression function at this node.
559   * @return the array of coefficients, first dimension is the class, second the attribute.
560   */
561  protected double[][] getCoefficients(){
562       
563    //Need to take into account partial model fit at higher levels in the tree (m_higherRegressions)
564    //and the part of the model fit at this node (m_regressions).
565       
566    //get coefficients from m_regressions: use method of LogisticBase
567    double[][] coefficients = super.getCoefficients();
568    //get coefficients from m_higherRegressions:
569    double constFactor = (double)(m_numClasses - 1) / (double)m_numClasses; // (J - 1)/J
570    for (int j = 0; j < m_numClasses; j++) {
571      for (int i = 0; i < m_numHigherRegressions; i++) {               
572        double slope = m_higherRegressions[j][i].getSlope();
573        double intercept = m_higherRegressions[j][i].getIntercept();
574        int attribute = m_higherRegressions[j][i].getAttributeIndex();
575        coefficients[j][0] += constFactor * intercept;
576        coefficients[j][attribute + 1] += constFactor * slope;
577      }
578    }
579
580    return coefficients;
581  }
582   
583  /**
584   * Returns a string describing the logistic regression function at the node.
585   */
586  public String modelsToString(){
587       
588    StringBuffer text = new StringBuffer();
589    if (m_isLeaf && m_hasConstr) {
590      text.append("FT_"+m_leafModelNum+":"+super.toString());
591           
592    }else{
593      if (!m_isLeaf && m_hasConstr) {
594        if (m_modelSelection instanceof BinC45ModelSelection){
595          text.append("FT_N"+((BinC45Split)m_localModel).attIndex()+"#"+m_id +":"+super.toString()); 
596        }else{
597          text.append("FT_N"+((C45Split)m_localModel).attIndex()+"#"+m_id +":"+super.toString());
598        }
599        for (int i = 0; i < m_sons.length; i++) { 
600          text.append("\n"+ m_sons[i].modelsToString());
601        }
602      }else{
603        if (!m_isLeaf && !m_hasConstr) 
604          {
605            for (int i = 0; i < m_sons.length; i++) { 
606              text.append("\n"+ m_sons[i].modelsToString());
607            }
608          }else{
609          if (m_isLeaf && !m_hasConstr)
610            {
611              text.append("");
612            }
613        }
614               
615      }
616    }
617       
618    return text.toString();
619  }
620
621  /**
622   * Returns graph describing the tree.
623   *
624   * @throws Exception if something goes wrong
625   */
626  public String graph() throws Exception {
627       
628    StringBuffer text = new StringBuffer();
629       
630    assignIDs(-1);
631    assignLeafModelNumbers(0);
632    text.append("digraph FTree {\n");
633    if (m_isLeaf && m_hasConstr) {
634      text.append("N" + m_id + " [label=\"FT_"+m_leafModelNum+":"+getModelParameters()+"\" " + 
635                  "shape=box style=filled");
636      text.append("]\n");
637    }else{
638      if (m_isLeaf && !m_hasConstr){
639        text.append("N" + m_id + " [label=\"Class="+m_leafclass+ "\" " + 
640                    "shape=box style=filled");
641        text.append("]\n");
642             
643      }else {
644        text.append("N" + m_id
645                    + " [label=\"" + 
646                    m_localModel.leftSide(m_train) + "\" ");
647        text.append("]\n");
648        graphTree(text);
649      }
650    }
651    return text.toString() +"}\n";
652  }
653
654  /**
655   * Helper function for graph description of tree
656   *
657   * @throws Exception if something goes wrong
658   */
659  protected void graphTree(StringBuffer text) throws Exception {
660       
661    for (int i = 0; i < m_sons.length; i++) {
662      text.append("N" + m_id 
663                  + "->" + 
664                  "N" + m_sons[i].m_id +
665                  " [label=\"" + m_localModel.rightSide(i,m_train).trim() + 
666                  "\"]\n");
667      if (m_sons[i].m_isLeaf && m_sons[i].m_hasConstr) {
668        text.append("N" +m_sons[i].m_id + " [label=\"FT_"+m_sons[i].m_leafModelNum+":"+
669                    m_sons[i].getModelParameters()+"\" " + "shape=box style=filled");
670        text.append("]\n");
671      } else { 
672        if (m_sons[i].m_isLeaf && !m_sons[i].m_hasConstr) {
673          text.append("N" +m_sons[i].m_id + " [label=\"Class="+m_sons[i].m_leafclass+"\" " + "shape=box style=filled");
674          text.append("]\n");
675        }else{
676          text.append("N" + m_sons[i].m_id +
677                      " [label=\""+m_sons[i].m_localModel.leftSide(m_train) + 
678                      "\" ");
679          text.append("]\n");
680          m_sons[i].graphTree(text);
681        }
682      }
683    } 
684  } 
685
686  /**
687   * Cleanup in order to save memory.
688   */
689  public void cleanup() {
690    super.cleanup();
691    if (!m_isLeaf) {
692      for (int i = 0; i < m_sons.length; i++) m_sons[i].cleanup();
693    }
694  }
695 
696  /**
697   * Returns the revision string.
698   *
699   * @return            the revision
700   */
701  public String getRevision() {
702    return RevisionUtils.extract("$Revision: 4899 $");
703  }
704}
Note: See TracBrowser for help on using the repository browser.