source: src/main/java/weka/classifiers/trees/lmt/LMTNode.java @ 28

Last change on this file since 28 was 4, checked in by gnappo, 14 years ago

Import di weka.

File size: 29.2 KB
Line 
1/*
2 *    This program is free software; you can redistribute it and/or modify
3 *    it under the terms of the GNU General Public License as published by
4 *    the Free Software Foundation; either version 2 of the License, or
5 *    (at your option) any later version.
6 *
7 *    This program is distributed in the hope that it will be useful,
8 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
9 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
10 *    GNU General Public License for more details.
11 *
12 *    You should have received a copy of the GNU General Public License
13 *    along with this program; if not, write to the Free Software
14 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
15 */
16
17/*
18 *    LMTNode.java
19 *    Copyright (C) 2003 University of Waikato, Hamilton, New Zealand
20 *
21 */
22
23package weka.classifiers.trees.lmt;
24
25import weka.classifiers.Evaluation;
26import weka.classifiers.functions.SimpleLinearRegression;
27import weka.classifiers.trees.j48.ClassifierSplitModel;
28import weka.classifiers.trees.j48.ModelSelection;
29import weka.core.Instance;
30import weka.core.Instances;
31import weka.core.RevisionHandler;
32import weka.core.RevisionUtils;
33import weka.filters.Filter;
34import weka.filters.supervised.attribute.NominalToBinary;
35
36import java.util.Collections;
37import java.util.Comparator;
38import java.util.Vector;
39
40/**
41 * Auxiliary class for list of LMTNodes
42 */
43class CompareNode 
44    implements Comparator, RevisionHandler {
45
46    /**
47     * Compares its two arguments for order.
48     *
49     * @param o1 first object
50     * @param o2 second object
51     * @return a negative integer, zero, or a positive integer as the first
52     *         argument is less than, equal to, or greater than the second.
53     */
54    public int compare(Object o1, Object o2) {         
55        if ( ((LMTNode)o1).m_alpha < ((LMTNode)o2).m_alpha) return -1;
56        if ( ((LMTNode)o1).m_alpha > ((LMTNode)o2).m_alpha) return 1;
57        return 0;       
58    }       
59   
60    /**
61     * Returns the revision string.
62     *
63     * @return          the revision
64     */
65    public String getRevision() {
66      return RevisionUtils.extract("$Revision: 1.8 $");
67    }
68}
69
70/**
71 * Class for logistic model tree structure.
72 *
73 *
74 * @author Niels Landwehr
75 * @author Marc Sumner
76 * @version $Revision: 1.8 $
77 */
78public class LMTNode 
79    extends LogisticBase {
80 
81    /** for serialization */
82    static final long serialVersionUID = 1862737145870398755L;
83   
84    /** Total number of training instances. */
85    protected double m_totalInstanceWeight;
86   
87    /** Node id*/
88    protected int m_id;
89   
90    /** ID of logistic model at leaf*/
91    protected int m_leafModelNum;
92 
93    /** Alpha-value (for pruning) at the node*/
94    public double m_alpha;
95   
96    /** Weighted number of training examples currently misclassified by the logistic model at the node*/ 
97    public double m_numIncorrectModel;
98
99    /** Weighted number of training examples currently misclassified by the subtree rooted at the node*/
100    public double m_numIncorrectTree;
101
102    /**minimum number of instances at which a node is considered for splitting*/
103    protected int m_minNumInstances;
104   
105    /**ModelSelection object (for splitting)*/
106    protected ModelSelection m_modelSelection;     
107
108    /**Filter to convert nominal attributes to binary*/
109    protected NominalToBinary m_nominalToBinary; 
110   
111    /**Simple regression functions fit by LogitBoost at higher levels in the tree*/
112    protected SimpleLinearRegression[][] m_higherRegressions;
113   
114    /**Number of simple regression functions fit by LogitBoost at higher levels in the tree*/
115    protected int m_numHigherRegressions = 0;
116   
117    /**Number of folds for CART pruning*/
118    protected static int m_numFoldsPruning = 5;
119
120    /**Use heuristic that determines the number of LogitBoost iterations only once in the beginning? */
121    protected boolean m_fastRegression;
122   
123    /**Number of instances at the node*/
124    protected int m_numInstances;   
125
126    /**The ClassifierSplitModel (for splitting)*/
127    protected ClassifierSplitModel m_localModel; 
128 
129    /**Array of children of the node*/
130    protected LMTNode[] m_sons;           
131
132    /**True if node is leaf*/
133    protected boolean m_isLeaf;                   
134
135    /**
136     * Constructor for logistic model tree node.
137     *
138     * @param modelSelection selection method for local splitting model
139     * @param numBoostingIterations sets the numBoostingIterations parameter
140     * @param fastRegression sets the fastRegression parameter
141     * @param errorOnProbabilities Use error on probabilities for stopping criterion of LogitBoost?
142     * @param minNumInstances minimum number of instances at which a node is considered for splitting
143     */
144    public LMTNode(ModelSelection modelSelection, int numBoostingIterations, 
145                   boolean fastRegression, 
146                   boolean errorOnProbabilities, int minNumInstances,
147                   double weightTrimBeta, boolean useAIC) {
148        m_modelSelection = modelSelection;
149        m_fixedNumIterations = numBoostingIterations;     
150        m_fastRegression = fastRegression;
151        m_errorOnProbabilities = errorOnProbabilities;
152        m_minNumInstances = minNumInstances;
153        m_maxIterations = 200;
154        setWeightTrimBeta(weightTrimBeta);
155        setUseAIC(useAIC);
156    }         
157   
158    /**
159     * Method for building a logistic model tree (only called for the root node).
160     * Grows an initial logistic model tree and prunes it back using the CART pruning scheme.
161     *
162     * @param data the data to train with
163     * @throws Exception if something goes wrong
164     */
165    public void buildClassifier(Instances data) throws Exception{
166       
167        //heuristic to avoid cross-validating the number of LogitBoost iterations
168        //at every node: build standalone logistic model and take its optimum number
169        //of iteration everywhere in the tree.
170        if (m_fastRegression && (m_fixedNumIterations < 0)) m_fixedNumIterations = tryLogistic(data);
171       
172        //Need to cross-validate alpha-parameter for CART-pruning
173        Instances cvData = new Instances(data);
174        cvData.stratify(m_numFoldsPruning);
175       
176        double[][] alphas = new double[m_numFoldsPruning][];
177        double[][] errors = new double[m_numFoldsPruning][];
178       
179        for (int i = 0; i < m_numFoldsPruning; i++) {
180            //for every fold, grow tree on training set...
181            Instances train = cvData.trainCV(m_numFoldsPruning, i);
182            Instances test = cvData.testCV(m_numFoldsPruning, i);
183           
184            buildTree(train, null, train.numInstances() , 0);   
185           
186            int numNodes = getNumInnerNodes();     
187            alphas[i] = new double[numNodes + 2];
188            errors[i] = new double[numNodes + 2];
189           
190            //... then prune back and log alpha-values and errors on test set
191            prune(alphas[i], errors[i], test);             
192        }
193       
194        //build tree using all the data
195        buildTree(data, null, data.numInstances(), 0);
196        int numNodes = getNumInnerNodes();
197
198        double[] treeAlphas = new double[numNodes + 2]; 
199       
200        //prune back and log alpha-values     
201        int iterations = prune(treeAlphas, null, null);
202       
203        double[] treeErrors = new double[numNodes + 2];
204       
205        for (int i = 0; i <= iterations; i++){
206            //compute midpoint alphas
207            double alpha = Math.sqrt(treeAlphas[i] * treeAlphas[i+1]);
208            double error = 0;
209           
210            //compute error estimate for final trees from the midpoint-alphas and the error estimates gotten in
211            //the cross-validation
212            for (int k = 0; k < m_numFoldsPruning; k++) {
213                int l = 0;
214                while (alphas[k][l] <= alpha) l++;
215                error += errors[k][l - 1];
216            }
217
218            treeErrors[i] = error;                         
219        }
220       
221        //find best alpha
222        int best = -1;
223        double bestError = Double.MAX_VALUE;
224        for (int i = iterations; i >= 0; i--) {
225            if (treeErrors[i] < bestError) {
226                bestError = treeErrors[i];
227                best = i;
228            }       
229        }
230
231        double bestAlpha = Math.sqrt(treeAlphas[best] * treeAlphas[best + 1]);         
232       
233        //"unprune" final tree (faster than regrowing it)
234        unprune();
235
236        //CART-prune it with best alpha
237        prune(bestAlpha);                       
238        cleanup();     
239    }
240
241    /**
242     * Method for building the tree structure.
243     * Builds a logistic model, splits the node and recursively builds tree for child nodes.
244     * @param data the training data passed on to this node
245     * @param higherRegressions An array of regression functions produced by LogitBoost at higher
246     * levels in the tree. They represent a logistic regression model that is refined locally
247     * at this node.
248     * @param totalInstanceWeight the total number of training examples
249     * @param higherNumParameters effective number of parameters in the logistic regression model built
250     * in parent nodes
251     * @throws Exception if something goes wrong
252     */
253    public void buildTree(Instances data, SimpleLinearRegression[][] higherRegressions, 
254                          double totalInstanceWeight, double higherNumParameters) throws Exception{
255
256        //save some stuff
257        m_totalInstanceWeight = totalInstanceWeight;
258        m_train = new Instances(data);
259       
260        m_isLeaf = true;
261        m_sons = null;
262       
263        m_numInstances = m_train.numInstances();
264        m_numClasses = m_train.numClasses();                           
265       
266        //init
267        m_numericData = getNumericData(m_train);                 
268        m_numericDataHeader = new Instances(m_numericData, 0);
269       
270        m_regressions = initRegressions();
271        m_numRegressions = 0;
272       
273        if (higherRegressions != null) m_higherRegressions = higherRegressions;
274        else m_higherRegressions = new SimpleLinearRegression[m_numClasses][0]; 
275
276        m_numHigherRegressions = m_higherRegressions[0].length; 
277       
278        m_numParameters = higherNumParameters;
279       
280        //build logistic model
281        if (m_numInstances >= m_numFoldsBoosting) {
282            if (m_fixedNumIterations > 0){
283                performBoosting(m_fixedNumIterations);
284            } else if (getUseAIC()) {
285                performBoostingInfCriterion();
286            } else {
287                performBoostingCV();
288            }
289        }
290       
291        m_numParameters += m_numRegressions;
292       
293        //only keep the simple regression functions that correspond to the selected number of LogitBoost iterations
294        m_regressions = selectRegressions(m_regressions);
295
296        boolean grow;
297        //split node if more than minNumInstances...
298        if (m_numInstances > m_minNumInstances) {
299            //split node: either splitting on class value (a la C4.5) or splitting on residuals
300            if (m_modelSelection instanceof ResidualModelSelection) {   
301                //need ps/Ys/Zs/weights
302                double[][] probs = getProbs(getFs(m_numericData));
303                double[][] trainYs = getYs(m_train);
304                double[][] dataZs = getZs(probs, trainYs);
305                double[][] dataWs = getWs(probs, trainYs);
306                m_localModel = ((ResidualModelSelection)m_modelSelection).selectModel(m_train, dataZs, dataWs); 
307            } else {
308                m_localModel = m_modelSelection.selectModel(m_train);   
309            }
310            //... and valid split found
311            grow = (m_localModel.numSubsets() > 1);
312        } else {
313            grow = false;
314        }
315       
316        if (grow) {     
317            //create and build children of node
318            m_isLeaf = false;               
319            Instances[] localInstances = m_localModel.split(m_train);       
320            m_sons = new LMTNode[m_localModel.numSubsets()];
321            for (int i = 0; i < m_sons.length; i++) {
322                m_sons[i] = new LMTNode(m_modelSelection, m_fixedNumIterations, 
323                                         m_fastRegression, 
324                                         m_errorOnProbabilities,m_minNumInstances,
325                                        getWeightTrimBeta(), getUseAIC());
326                //the "higherRegressions" (partial logistic model fit at higher levels in the tree) passed
327                //on to the children are the "higherRegressions" at this node plus the regressions added
328                //at this node (m_regressions).
329                m_sons[i].buildTree(localInstances[i],
330                                  mergeArrays(m_regressions, m_higherRegressions), m_totalInstanceWeight, m_numParameters);             
331                localInstances[i] = null;
332            }       
333        } 
334    }
335
336    /**
337     * Prunes a logistic model tree using the CART pruning scheme, given a
338     * cost-complexity parameter alpha.
339     *
340     * @param alpha the cost-complexity measure 
341     * @throws Exception if something goes wrong
342     */
343    public void prune(double alpha) throws Exception {
344       
345        Vector nodeList;       
346        CompareNode comparator = new CompareNode();     
347       
348        //determine training error of logistic models and subtrees, and calculate alpha-values from them
349        modelErrors();
350        treeErrors();
351        calculateAlphas();
352       
353        //get list of all inner nodes in the tree
354        nodeList = getNodes();
355               
356        boolean prune = (nodeList.size() > 0);
357       
358        while (prune) {
359           
360            //select node with minimum alpha
361            LMTNode nodeToPrune = (LMTNode)Collections.min(nodeList,comparator);
362           
363            //want to prune if its alpha is smaller than alpha
364            if (nodeToPrune.m_alpha > alpha) break; 
365           
366            nodeToPrune.m_isLeaf = true;
367            nodeToPrune.m_sons = null;
368           
369            //update tree errors and alphas
370            treeErrors();
371            calculateAlphas();
372
373            nodeList = getNodes();
374            prune = (nodeList.size() > 0);       
375        } 
376    }
377
378    /**
379     * Method for performing one fold in the cross-validation of the cost-complexity parameter.
380     * Generates a sequence of alpha-values with error estimates for the corresponding (partially pruned)
381     * trees, given the test set of that fold.
382     * @param alphas array to hold the generated alpha-values
383     * @param errors array to hold the corresponding error estimates
384     * @param test test set of that fold (to obtain error estimates)
385     * @throws Exception if something goes wrong
386     */
387    public int prune(double[] alphas, double[] errors, Instances test) throws Exception {
388       
389        Vector nodeList; 
390       
391        CompareNode comparator = new CompareNode();     
392
393        //determine training error of logistic models and subtrees, and calculate alpha-values from them
394        modelErrors();
395        treeErrors();
396        calculateAlphas();
397
398        //get list of all inner nodes in the tree
399        nodeList = getNodes();
400       
401        boolean prune = (nodeList.size() > 0);                         
402
403        //alpha_0 is always zero (unpruned tree)
404        alphas[0] = 0;
405
406        Evaluation eval;
407
408        //error of unpruned tree
409        if (errors != null) {
410            eval = new Evaluation(test);
411            eval.evaluateModel(this, test);
412            errors[0] = eval.errorRate(); 
413        }       
414       
415        int iteration = 0;
416        while (prune) {
417
418            iteration++;
419           
420            //get node with minimum alpha
421            LMTNode nodeToPrune = (LMTNode)Collections.min(nodeList,comparator);
422
423            nodeToPrune.m_isLeaf = true;
424            //Do not set m_sons null, want to unprune
425           
426            //get alpha-value of node
427            alphas[iteration] = nodeToPrune.m_alpha;
428           
429            //log error
430            if (errors != null) {
431                eval = new Evaluation(test);
432                eval.evaluateModel(this, test);
433                errors[iteration] = eval.errorRate(); 
434            }
435
436            //update errors/alphas
437            treeErrors();
438            calculateAlphas();
439
440            nodeList = getNodes();         
441            prune = (nodeList.size() > 0);         
442        } 
443       
444        //set last alpha 1 to indicate end
445        alphas[iteration + 1] = 1.0;   
446        return iteration;
447    }
448
449
450    /**
451     *Method to "unprune" a logistic model tree.
452     *Sets all leaf-fields to false.
453     *Faster than re-growing the tree because the logistic models do not have to be fit again.
454     */
455    protected void unprune() {
456        if (m_sons != null) {
457            m_isLeaf = false;
458            for (int i = 0; i < m_sons.length; i++) m_sons[i].unprune();
459        }
460    }
461
462    /**
463     *Determines the optimum number of LogitBoost iterations to perform by building a standalone logistic
464     *regression function on the training data. Used for the heuristic that avoids cross-validating this
465     *number again at every node.
466     *@param data training instances for the logistic model
467     *@throws Exception if something goes wrong
468     */
469    protected int tryLogistic(Instances data) throws Exception{
470       
471        //convert nominal attributes
472        Instances filteredData = new Instances(data);   
473        NominalToBinary nominalToBinary = new NominalToBinary();                       
474        nominalToBinary.setInputFormat(filteredData);
475        filteredData = Filter.useFilter(filteredData, nominalToBinary); 
476       
477        LogisticBase logistic = new LogisticBase(0,true,m_errorOnProbabilities);
478       
479        //limit LogitBoost to 200 iterations (speed)
480        logistic.setMaxIterations(200);
481        logistic.setWeightTrimBeta(getWeightTrimBeta()); // Not in Marc's code. Added by Eibe.
482        logistic.setUseAIC(getUseAIC());
483        logistic.buildClassifier(filteredData);
484       
485        //return best number of iterations
486        return logistic.getNumRegressions(); 
487    }
488
489    /**
490     * Method to count the number of inner nodes in the tree
491     * @return the number of inner nodes
492     */
493    public int getNumInnerNodes(){
494        if (m_isLeaf) return 0;
495        int numNodes = 1;
496        for (int i = 0; i < m_sons.length; i++) numNodes += m_sons[i].getNumInnerNodes();
497        return numNodes;
498    }
499
500    /**
501     * Returns the number of leaves in the tree.
502     * Leaves are only counted if their logistic model has changed compared to the one of the parent node.
503     * @return the number of leaves
504     */
505     public int getNumLeaves(){
506        int numLeaves;
507        if (!m_isLeaf) {
508            numLeaves = 0;
509            int numEmptyLeaves = 0;
510            for (int i = 0; i < m_sons.length; i++) {
511                numLeaves += m_sons[i].getNumLeaves();
512                if (m_sons[i].m_isLeaf && !m_sons[i].hasModels()) numEmptyLeaves++;
513            }
514            if (numEmptyLeaves > 1) {
515                numLeaves -= (numEmptyLeaves - 1);
516            }
517        } else {
518            numLeaves = 1;
519        }         
520        return numLeaves;       
521    }
522
523    /**
524     *Updates the numIncorrectModel field for all nodes. This is needed for calculating the alpha-values.
525     */
526    public void modelErrors() throws Exception{
527               
528        Evaluation eval = new Evaluation(m_train);
529               
530        if (!m_isLeaf) {
531            m_isLeaf = true;
532            eval.evaluateModel(this, m_train);
533            m_isLeaf = false;
534            m_numIncorrectModel = eval.incorrect();
535            for (int i = 0; i < m_sons.length; i++) m_sons[i].modelErrors();
536        } else {
537            eval.evaluateModel(this, m_train);
538            m_numIncorrectModel = eval.incorrect();
539        }
540    }
541   
542    /**
543     *Updates the numIncorrectTree field for all nodes. This is needed for calculating the alpha-values.
544     */
545    public void treeErrors(){
546        if (m_isLeaf) {
547            m_numIncorrectTree = m_numIncorrectModel;
548        } else {
549            m_numIncorrectTree = 0;
550            for (int i = 0; i < m_sons.length; i++) {
551                m_sons[i].treeErrors();
552                m_numIncorrectTree += m_sons[i].m_numIncorrectTree;
553            }   
554        }       
555    }
556
557    /**
558     *Updates the alpha field for all nodes.
559     */
560    public void calculateAlphas() throws Exception {           
561               
562        if (!m_isLeaf) {       
563            double errorDiff = m_numIncorrectModel - m_numIncorrectTree;                   
564           
565            if (errorDiff <= 0) {
566                //split increases training error (should not normally happen).
567                //prune it instantly.
568                m_isLeaf = true;
569                m_sons = null;
570                m_alpha = Double.MAX_VALUE;             
571            } else {
572                //compute alpha
573                errorDiff /= m_totalInstanceWeight;             
574                m_alpha = errorDiff / (double)(getNumLeaves() - 1);
575               
576                for (int i = 0; i < m_sons.length; i++) m_sons[i].calculateAlphas();
577            }
578        } else {           
579            //alpha = infinite for leaves (do not want to prune)
580            m_alpha = Double.MAX_VALUE;
581        }
582    }
583   
584    /**
585     * Merges two arrays of regression functions into one
586     * @param a1 one array
587     * @param a2 the other array
588     *
589     * @return an array that contains all entries from both input arrays
590     */
591    protected SimpleLinearRegression[][] mergeArrays(SimpleLinearRegression[][] a1,     
592                                                           SimpleLinearRegression[][] a2){
593        int numModels1 = a1[0].length;
594        int numModels2 = a2[0].length;         
595       
596        SimpleLinearRegression[][] result =
597            new SimpleLinearRegression[m_numClasses][numModels1 + numModels2];
598       
599        for (int i = 0; i < m_numClasses; i++)
600            for (int j = 0; j < numModels1; j++) {
601                result[i][j]  = a1[i][j];
602            }
603        for (int i = 0; i < m_numClasses; i++)
604            for (int j = 0; j < numModels2; j++) result[i][j+numModels1] = a2[i][j];
605        return result;
606    }
607
608    /**
609     * Return a list of all inner nodes in the tree
610     * @return the list of nodes
611     */
612    public Vector getNodes(){
613        Vector nodeList = new Vector();
614        getNodes(nodeList);
615        return nodeList;
616    }
617
618    /**
619     * Fills a list with all inner nodes in the tree
620     *
621     * @param nodeList the list to be filled
622     */
623    public void getNodes(Vector nodeList) {
624        if (!m_isLeaf) {
625            nodeList.add(this);
626            for (int i = 0; i < m_sons.length; i++) m_sons[i].getNodes(nodeList);
627        }       
628    }
629   
630    /**
631     * Returns a numeric version of a set of instances.
632     * All nominal attributes are replaced by binary ones, and the class variable is replaced
633     * by a pseudo-class variable that is used by LogitBoost.
634     */
635    protected Instances getNumericData(Instances train) throws Exception{
636       
637        Instances filteredData = new Instances(train); 
638        m_nominalToBinary = new NominalToBinary();                     
639        m_nominalToBinary.setInputFormat(filteredData);
640        filteredData = Filter.useFilter(filteredData, m_nominalToBinary);       
641
642        return super.getNumericData(filteredData);
643    }
644
645    /**
646     * Computes the F-values of LogitBoost for an instance from the current logistic model at the node
647     * Note that this also takes into account the (partial) logistic model fit at higher levels in
648     * the tree.
649     * @param instance the instance
650     * @return the array of F-values
651     */
652    protected double[] getFs(Instance instance) throws Exception{
653       
654        double [] pred = new double [m_numClasses];
655       
656        //Need to take into account partial model fit at higher levels in the tree (m_higherRegressions)
657        //and the part of the model fit at this node (m_regressions).
658
659        //Fs from m_regressions (use method of LogisticBase)
660        double [] instanceFs = super.getFs(instance);           
661
662        //Fs from m_higherRegressions
663        for (int i = 0; i < m_numHigherRegressions; i++) {
664            double predSum = 0;
665            for (int j = 0; j < m_numClasses; j++) {
666                pred[j] = m_higherRegressions[j][i].classifyInstance(instance);
667                predSum += pred[j];
668            }
669            predSum /= m_numClasses;
670            for (int j = 0; j < m_numClasses; j++) {
671                instanceFs[j] += (pred[j] - predSum) * (m_numClasses - 1) 
672                    / m_numClasses;
673            }
674        }
675        return instanceFs; 
676    }
677   
678    /**
679     *Returns true if the logistic regression model at this node has changed compared to the
680     *one at the parent node.
681     *@return whether it has changed
682     */
683    public boolean hasModels() {
684        return (m_numRegressions > 0);
685    }
686
687    /**
688     * Returns the class probabilities for an instance according to the logistic model at the node.
689     * @param instance the instance
690     * @return the array of probabilities
691     */
692    public double[] modelDistributionForInstance(Instance instance) throws Exception {
693       
694        //make copy and convert nominal attributes
695        instance = (Instance)instance.copy();           
696        m_nominalToBinary.input(instance);
697        instance = m_nominalToBinary.output(); 
698       
699        //saet numeric pseudo-class
700        instance.setDataset(m_numericDataHeader);               
701       
702        return probs(getFs(instance));
703    }
704
705    /**
706     * Returns the class probabilities for an instance given by the logistic model tree.
707     * @param instance the instance
708     * @return the array of probabilities
709     */
710    public double[] distributionForInstance(Instance instance) throws Exception {
711       
712        double[] probs;
713       
714        if (m_isLeaf) {     
715            //leaf: use logistic model
716            probs = modelDistributionForInstance(instance);
717        } else {
718            //sort into appropiate child node
719            int branch = m_localModel.whichSubset(instance);
720            probs = m_sons[branch].distributionForInstance(instance);
721        }                       
722        return probs;
723    }
724
725    /**
726     * Returns the number of leaves (normal count).
727     * @return the number of leaves
728     */
729    public int numLeaves() {   
730        if (m_isLeaf) return 1; 
731        int numLeaves = 0;
732        for (int i = 0; i < m_sons.length; i++) numLeaves += m_sons[i].numLeaves();
733        return numLeaves;
734    }
735   
736    /**
737     * Returns the number of nodes.
738     * @return the number of nodes
739     */
740    public int numNodes() {
741        if (m_isLeaf) return 1; 
742        int numNodes = 1;
743        for (int i = 0; i < m_sons.length; i++) numNodes += m_sons[i].numNodes();
744        return numNodes;
745    }
746
747    /**
748     * Returns a description of the logistic model tree (tree structure and logistic models)
749     * @return describing string
750     */
751    public String toString(){   
752        //assign numbers to logistic regression functions at leaves
753        assignLeafModelNumbers(0);     
754        try{
755            StringBuffer text = new StringBuffer();
756           
757            if (m_isLeaf) {
758                text.append(": ");
759                text.append("LM_"+m_leafModelNum+":"+getModelParameters());
760            } else {
761                dumpTree(0,text);                   
762            }
763            text.append("\n\nNumber of Leaves  : \t"+numLeaves()+"\n");
764            text.append("\nSize of the Tree : \t"+numNodes()+"\n");     
765               
766            //This prints logistic models after the tree, comment out if only tree should be printed
767            text.append(modelsToString());
768            return text.toString();
769        } catch (Exception e){
770            return "Can't print logistic model tree";
771        }
772       
773       
774    }
775
776    /**
777     * Returns a string describing the number of LogitBoost iterations performed at this node, the total number
778     * of LogitBoost iterations performed (including iterations at higher levels in the tree), and the number
779     * of training examples at this node.
780     * @return the describing string
781     */
782    public String getModelParameters(){
783       
784        StringBuffer text = new StringBuffer();
785        int numModels = m_numRegressions+m_numHigherRegressions;
786        text.append(m_numRegressions+"/"+numModels+" ("+m_numInstances+")");
787        return text.toString();
788    }
789   
790   
791    /**
792     * Help method for printing tree structure.
793     *
794     * @throws Exception if something goes wrong
795     */
796    protected void dumpTree(int depth,StringBuffer text) 
797        throws Exception {
798       
799        for (int i = 0; i < m_sons.length; i++) {
800            text.append("\n");
801            for (int j = 0; j < depth; j++)
802                text.append("|   ");
803            text.append(m_localModel.leftSide(m_train));
804            text.append(m_localModel.rightSide(i, m_train));
805            if (m_sons[i].m_isLeaf) {
806                text.append(": ");
807                text.append("LM_"+m_sons[i].m_leafModelNum+":"+m_sons[i].getModelParameters());
808            }else
809                m_sons[i].dumpTree(depth+1,text);
810        }
811    }
812
813    /**
814     * Assigns unique IDs to all nodes in the tree
815     */
816    public int assignIDs(int lastID) {
817       
818        int currLastID = lastID + 1;
819       
820        m_id = currLastID;
821        if (m_sons != null) {
822            for (int i = 0; i < m_sons.length; i++) {
823                currLastID = m_sons[i].assignIDs(currLastID);
824            }
825        }
826        return currLastID;
827    }
828   
829    /**
830     * Assigns numbers to the logistic regression models at the leaves of the tree
831     */
832    public int assignLeafModelNumbers(int leafCounter) {
833        if (!m_isLeaf) {
834            m_leafModelNum = 0;
835            for (int i = 0; i < m_sons.length; i++){
836                leafCounter = m_sons[i].assignLeafModelNumbers(leafCounter);
837            }
838        } else {
839            leafCounter++;
840            m_leafModelNum = leafCounter;
841        } 
842        return leafCounter;
843    }
844
845    /**
846     * Returns an array containing the coefficients of the logistic regression function at this node.
847     * @return the array of coefficients, first dimension is the class, second the attribute.
848     */
849    protected double[][] getCoefficients(){
850       
851        //Need to take into account partial model fit at higher levels in the tree (m_higherRegressions)
852        //and the part of the model fit at this node (m_regressions).
853       
854        //get coefficients from m_regressions: use method of LogisticBase
855        double[][] coefficients = super.getCoefficients();
856        //get coefficients from m_higherRegressions:
857        double constFactor = (double)(m_numClasses - 1) / (double)m_numClasses; // (J - 1)/J
858        for (int j = 0; j < m_numClasses; j++) {
859            for (int i = 0; i < m_numHigherRegressions; i++) {         
860                double slope = m_higherRegressions[j][i].getSlope();
861                double intercept = m_higherRegressions[j][i].getIntercept();
862                int attribute = m_higherRegressions[j][i].getAttributeIndex();
863                coefficients[j][0] += constFactor * intercept;
864                coefficients[j][attribute + 1] += constFactor * slope;
865            }
866        }
867
868        return coefficients;
869    }
870   
871    /**
872     * Returns a string describing the logistic regression function at the node.
873     */
874    public String modelsToString(){
875       
876        StringBuffer text = new StringBuffer();
877        if (m_isLeaf) {
878            text.append("LM_"+m_leafModelNum+":"+super.toString());
879        } else {
880            for (int i = 0; i < m_sons.length; i++) {
881                text.append("\n"+m_sons[i].modelsToString());
882            }
883        }
884        return text.toString();     
885    }
886
887    /**
888     * Returns graph describing the tree.
889     *
890     * @throws Exception if something goes wrong
891     */
892    public String graph() throws Exception {
893       
894        StringBuffer text = new StringBuffer();
895       
896        assignIDs(-1);
897        assignLeafModelNumbers(0);
898        text.append("digraph LMTree {\n");
899        if (m_isLeaf) {
900            text.append("N" + m_id + " [label=\"LM_"+m_leafModelNum+":"+getModelParameters()+"\" " + 
901                        "shape=box style=filled");
902            text.append("]\n");
903        }else {
904            text.append("N" + m_id
905                        + " [label=\"" + 
906                        m_localModel.leftSide(m_train) + "\" ");
907            text.append("]\n");
908            graphTree(text);
909        }
910   
911        return text.toString() +"}\n";
912    }
913
914    /**
915     * Helper function for graph description of tree
916     *
917     * @throws Exception if something goes wrong
918     */
919    private void graphTree(StringBuffer text) throws Exception {
920       
921        for (int i = 0; i < m_sons.length; i++) {
922            text.append("N" + m_id 
923                        + "->" + 
924                        "N" + m_sons[i].m_id +
925                        " [label=\"" + m_localModel.rightSide(i,m_train).trim() + 
926                        "\"]\n");
927            if (m_sons[i].m_isLeaf) {
928                text.append("N" +m_sons[i].m_id + " [label=\"LM_"+m_sons[i].m_leafModelNum+":"+
929                            m_sons[i].getModelParameters()+"\" " + "shape=box style=filled");
930                text.append("]\n");
931            } else {
932                text.append("N" + m_sons[i].m_id +
933                            " [label=\""+m_sons[i].m_localModel.leftSide(m_train) + 
934                            "\" ");
935                text.append("]\n");
936                m_sons[i].graphTree(text);
937            }
938        }
939    } 
940   
941    /**
942     * Cleanup in order to save memory.
943     */
944    public void cleanup() {
945        super.cleanup();
946        if (!m_isLeaf) {
947            for (int i = 0; i < m_sons.length; i++) m_sons[i].cleanup();
948        }
949    }
950   
951    /**
952     * Returns the revision string.
953     *
954     * @return          the revision
955     */
956    public String getRevision() {
957      return RevisionUtils.extract("$Revision: 1.8 $");
958    }
959}
Note: See TracBrowser for help on using the repository browser.