source: src/main/java/weka/classifiers/trees/lmt/LogisticBase.java @ 7

Last change on this file since 7 was 4, checked in by gnappo, 14 years ago

Import di weka.

File size: 32.9 KB
Line 
1/*
2 *    This program is free software; you can redistribute it and/or modify
3 *    it under the terms of the GNU General Public License as published by
4 *    the Free Software Foundation; either version 2 of the License, or
5 *    (at your option) any later version.
6 *
7 *    This program is distributed in the hope that it will be useful,
8 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
9 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
10 *    GNU General Public License for more details.
11 *
12 *    You should have received a copy of the GNU General Public License
13 *    along with this program; if not, write to the Free Software
14 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
15 */
16
17/*
18 *    LogisticBase.java
19 *    Copyright (C) 2003 University of Waikato, Hamilton, New Zealand
20 *
21 */
22
23package weka.classifiers.trees.lmt;
24
25import weka.classifiers.Classifier;
26import weka.classifiers.AbstractClassifier;
27import weka.classifiers.Evaluation;
28import weka.classifiers.functions.SimpleLinearRegression;
29import weka.core.Attribute;
30import weka.core.Instance;
31import weka.core.Instances;
32import weka.core.RevisionUtils;
33import weka.core.Utils;
34import weka.core.WeightedInstancesHandler;
35
36/**
37 * Base/helper class for building logistic regression models with the LogitBoost algorithm.
38 * Used for building logistic model trees (weka.classifiers.trees.lmt.LMT)
39 * and standalone logistic regression (weka.classifiers.functions.SimpleLogistic).
40 *
41 <!-- options-start -->
42 * Valid options are: <p/>
43 *
44 * <pre> -D
45 *  If set, classifier is run in debug mode and
46 *  may output additional info to the console</pre>
47 *
48 <!-- options-end -->
49 *
50 * @author Niels Landwehr
51 * @author Marc Sumner
52 * @version $Revision: 5928 $
53 */
54public class LogisticBase 
55    extends AbstractClassifier
56    implements WeightedInstancesHandler {
57
58    /** for serialization */
59    static final long serialVersionUID = 168765678097825064L;
60 
61    /** Header-only version of the numeric version of the training data*/
62    protected Instances m_numericDataHeader;
63    /**
64     * Numeric version of the training data. Original class is replaced by a numeric pseudo-class.
65     */
66    protected Instances m_numericData;
67   
68    /** Training data */
69    protected Instances m_train;
70   
71    /** Use cross-validation to determine best number of LogitBoost iterations ?*/
72    protected boolean m_useCrossValidation;
73
74    /**Use error on probabilities for stopping criterion of LogitBoost? */
75    protected boolean m_errorOnProbabilities;
76
77    /**Use fixed number of iterations for LogitBoost? (if negative, cross-validate number of iterations)*/
78    protected int m_fixedNumIterations;
79   
80    /**Use heuristic to stop performing LogitBoost iterations earlier?
81     * If enabled, LogitBoost is stopped if the current (local) minimum of the error on a test set as
82     * a function of the number of iterations has not changed for m_heuristicStop iterations.
83     */
84    protected int m_heuristicStop = 50;
85 
86    /**The number of LogitBoost iterations performed.*/
87    protected int m_numRegressions = 0;
88   
89    /**The maximum number of LogitBoost iterations*/
90    protected int m_maxIterations;
91   
92    /**The number of different classes*/
93    protected int m_numClasses;
94
95    /**Array holding the simple regression functions fit by LogitBoost*/
96    protected SimpleLinearRegression[][] m_regressions;
97       
98    /**Number of folds for cross-validating number of LogitBoost iterations*/
99    protected static int m_numFoldsBoosting = 5;
100
101    /**Threshold on the Z-value for LogitBoost*/
102    protected static final double Z_MAX = 3;
103   
104    /** If true, the AIC is used to choose the best iteration*/
105    private boolean m_useAIC = false;
106   
107    /** Effective number of parameters used for AIC / BIC automatic stopping */
108    protected double m_numParameters = 0;
109   
110    /**Threshold for trimming weights. Instances with a weight lower than this (as a percentage
111     * of total weights) are not included in the regression fit.
112     **/
113    protected double m_weightTrimBeta = 0;
114
115    /**
116     * Constructor that creates LogisticBase object with standard options.
117     */
118    public LogisticBase(){
119        m_fixedNumIterations = -1;
120        m_useCrossValidation = true;
121        m_errorOnProbabilities = false; 
122        m_maxIterations = 500;
123        m_useAIC = false;
124        m_numParameters = 0;
125    }
126   
127    /**
128     * Constructor to create LogisticBase object.
129     * @param numBoostingIterations fixed number of iterations for LogitBoost (if negative, use cross-validation or
130     * stopping criterion on the training data).
131     * @param useCrossValidation cross-validate number of LogitBoost iterations (if false, use stopping
132     * criterion on the training data).
133     * @param errorOnProbabilities if true, use error on probabilities
134     * instead of misclassification for stopping criterion of LogitBoost
135     */
136    public LogisticBase(int numBoostingIterations, boolean useCrossValidation, boolean errorOnProbabilities){
137        m_fixedNumIterations = numBoostingIterations;
138        m_useCrossValidation = useCrossValidation;
139        m_errorOnProbabilities = errorOnProbabilities; 
140        m_maxIterations = 500;
141        m_useAIC = false;
142        m_numParameters = 0;
143    }   
144
145    /**
146     * Builds the logistic regression model usiing LogitBoost.
147     *
148     * @param data the training data
149     * @throws Exception if something goes wrong
150     */
151    public void buildClassifier(Instances data) throws Exception {                     
152
153        m_train = new Instances(data);
154       
155        m_numClasses = m_train.numClasses();
156       
157        //init the array of simple regression functions
158        m_regressions = initRegressions();
159        m_numRegressions = 0;
160
161        //get numeric version of the training data (class variable replaced  by numeric pseudo-class)
162        m_numericData = getNumericData(m_train);       
163       
164        //save header info
165        m_numericDataHeader = new Instances(m_numericData, 0);
166       
167       
168        if (m_fixedNumIterations > 0) {
169            //run LogitBoost for fixed number of iterations
170            performBoosting(m_fixedNumIterations);
171        } else if (m_useAIC) { // Marc had this after the test for m_useCrossValidation. Changed by Eibe.
172            //run LogitBoost using information criterion for stopping
173            performBoostingInfCriterion();
174        } else if (m_useCrossValidation) {
175            //cross-validate number of LogitBoost iterations
176            performBoostingCV();
177        } else {
178            //run LogitBoost with number of iterations that minimizes error on the training set
179            performBoosting();
180        }       
181       
182        //only keep the simple regression functions that correspond to the selected number of LogitBoost iterations
183        m_regressions = selectRegressions(m_regressions);       
184    }   
185
186    /**
187     * Runs LogitBoost, determining the best number of iterations by cross-validation.
188     *
189     * @throws Exception if something goes wrong
190     */
191    protected void performBoostingCV() throws Exception{                       
192       
193        //completed iteration keeps track of the number of iterations that have been
194        //performed in every fold (some might stop earlier than others).
195        //Best iteration is selected only from these.
196        int completedIterations = m_maxIterations;
197       
198        Instances allData = new Instances(m_train);
199       
200        allData.stratify(m_numFoldsBoosting);         
201
202        double[] error = new double[m_maxIterations + 1];       
203       
204        for (int i = 0; i < m_numFoldsBoosting; i++) {
205            //split into training/test data in fold
206            Instances train = allData.trainCV(m_numFoldsBoosting,i);
207            Instances test = allData.testCV(m_numFoldsBoosting,i);
208
209            //initialize LogitBoost
210            m_numRegressions = 0;
211            m_regressions = initRegressions();
212
213            //run LogitBoost iterations
214            int iterations = performBoosting(train,test,error,completedIterations);         
215            if (iterations < completedIterations) completedIterations = iterations;         
216        }
217       
218        //determine iteration with minimum error over the folds
219        int bestIteration = getBestIteration(error,completedIterations);
220       
221        //rebuild model on all of the training data
222        m_numRegressions = 0;
223        performBoosting(bestIteration);
224    }   
225   
226    /**
227     * Runs LogitBoost, determining the best number of iterations by an information criterion (currently AIC).
228     */
229    protected void performBoostingInfCriterion() throws Exception{
230       
231        double criterion = 0.0;
232        double bestCriterion = Double.MAX_VALUE;
233        int bestIteration = 0;
234        int noMin = 0;
235       
236        // Variable to keep track of criterion values (AIC)
237        double criterionValue = Double.MAX_VALUE;
238       
239        // initialize Ys/Fs/ps
240        double[][] trainYs = getYs(m_train);
241        double[][] trainFs = getFs(m_numericData);
242        double[][] probs = getProbs(trainFs);
243       
244        // Array with true/false if the attribute is included in the model or not
245        boolean[][] attributes = new boolean[m_numClasses][m_numericDataHeader.numAttributes()];
246       
247        int iteration = 0;
248        while (iteration < m_maxIterations) {
249           
250            //perform single LogitBoost iteration
251            boolean foundAttribute = performIteration(iteration, trainYs, trainFs, probs, m_numericData);
252            if (foundAttribute) {
253                iteration++;
254                m_numRegressions = iteration;
255            } else {
256                //could not fit simple linear regression: stop LogitBoost
257                break;
258            }
259           
260            double numberOfAttributes = m_numParameters + iteration;
261           
262            // Fill criterion array values
263            criterionValue = 2.0 * negativeLogLikelihood(trainYs, probs) +
264              2.0 * numberOfAttributes;
265
266            //heuristic: stop LogitBoost if the current minimum has not changed for <m_heuristicStop> iterations
267            if (noMin > m_heuristicStop) break;
268            if (criterionValue < bestCriterion) {
269                bestCriterion = criterionValue;
270                bestIteration = iteration;
271                noMin = 0;
272            } else {
273                noMin++;
274            }
275        }
276
277        m_numRegressions = 0;
278        performBoosting(bestIteration);
279    }
280
281    /**
282     * Runs LogitBoost on a training set and monitors the error on a test set.
283     * Used for running one fold when cross-validating the number of LogitBoost iterations.
284     * @param train the training set
285     * @param test the test set
286     * @param error array to hold the logged error values
287     * @param maxIterations the maximum number of LogitBoost iterations to run
288     * @return the number of completed LogitBoost iterations (can be smaller than maxIterations
289     * if the heuristic for early stopping is active or there is a problem while fitting the regressions
290     * in LogitBoost).
291     * @throws Exception if something goes wrong
292     */
293    protected int performBoosting(Instances train, Instances test, 
294                                  double[] error, int maxIterations) throws Exception{
295       
296        //get numeric version of the (sub)set of training instances
297        Instances numericTrain = getNumericData(train);         
298
299        //initialize Ys/Fs/ps
300        double[][] trainYs = getYs(train);
301        double[][] trainFs = getFs(numericTrain);               
302        double[][] probs = getProbs(trainFs);   
303
304        int iteration = 0;
305
306        int noMin = 0;
307        double lastMin = Double.MAX_VALUE;     
308       
309        if (m_errorOnProbabilities) error[0] += getMeanAbsoluteError(test);
310        else error[0] += getErrorRate(test);
311       
312        while (iteration < maxIterations) {
313         
314            //perform single LogitBoost iteration
315            boolean foundAttribute = performIteration(iteration, trainYs, trainFs, probs, numericTrain);
316            if (foundAttribute) {
317                iteration++;
318                m_numRegressions = iteration;
319            } else {
320                //could not fit simple linear regression: stop LogitBoost
321                break;
322            }
323           
324            if (m_errorOnProbabilities) error[iteration] += getMeanAbsoluteError(test);
325            else error[iteration] += getErrorRate(test);
326         
327            //heuristic: stop LogitBoost if the current minimum has not changed for <m_heuristicStop> iterations
328            if (noMin > m_heuristicStop) break;
329            if (error[iteration] < lastMin) {
330                lastMin = error[iteration];
331                noMin = 0;
332            } else {
333                noMin++;
334            }                       
335        }
336
337        return iteration;
338    }
339
340    /**
341     * Runs LogitBoost with a fixed number of iterations.
342     * @param numIterations the number of iterations to run
343     * @throws Exception if something goes wrong
344     */
345    protected void performBoosting(int numIterations) throws Exception{
346
347        //initialize Ys/Fs/ps
348        double[][] trainYs = getYs(m_train);
349        double[][] trainFs = getFs(m_numericData);             
350        double[][] probs = getProbs(trainFs);
351       
352        int iteration = 0;
353
354        //run iterations
355        while (iteration < numIterations) {
356            boolean foundAttribute = performIteration(iteration, trainYs, trainFs, probs, m_numericData);
357            if (foundAttribute) iteration++;
358            else break;
359        }
360       
361        m_numRegressions = iteration;
362    }
363   
364    /**
365     * Runs LogitBoost using the stopping criterion on the training set.
366     * The number of iterations is used that gives the lowest error on the training set, either misclassification
367     * or error on probabilities (depending on the errorOnProbabilities option).
368     * @throws Exception if something goes wrong
369     */
370    protected void performBoosting() throws Exception{
371       
372        //initialize Ys/Fs/ps
373        double[][] trainYs = getYs(m_train);
374        double[][] trainFs = getFs(m_numericData);             
375        double[][] probs = getProbs(trainFs);   
376
377        int iteration = 0;
378
379        double[] trainErrors = new double[m_maxIterations+1];
380        trainErrors[0] = getErrorRate(m_train);
381       
382        int noMin = 0;
383        double lastMin = Double.MAX_VALUE;
384       
385        while (iteration < m_maxIterations) {
386            boolean foundAttribute = performIteration(iteration, trainYs, trainFs, probs, m_numericData);
387            if (foundAttribute) {
388                iteration++;
389                m_numRegressions = iteration;
390            } else {           
391                //could not fit simple regression
392                break;
393            }
394           
395            trainErrors[iteration] = getErrorRate(m_train);         
396         
397            //heuristic: stop LogitBoost if the current minimum has not changed for <m_heuristicStop> iterations
398            if (noMin > m_heuristicStop) break;     
399            if (trainErrors[iteration] < lastMin) {
400                lastMin = trainErrors[iteration];
401                noMin = 0;
402            } else {
403                noMin++;
404            }
405        }
406       
407        //find iteration with best error
408        m_numRegressions = getBestIteration(trainErrors, iteration);   
409    }
410
411    /**
412     * Returns the misclassification error of the current model on a set of instances.
413     * @param data the set of instances
414     * @return the error rate
415     * @throws Exception if something goes wrong
416     */
417    protected double getErrorRate(Instances data) throws Exception {
418        Evaluation eval = new Evaluation(data);
419        eval.evaluateModel(this,data);
420        return eval.errorRate();
421    }
422
423    /**
424     * Returns the error of the probability estimates for the current model on a set of instances.
425     * @param data the set of instances
426     * @return the error
427     * @throws Exception if something goes wrong
428     */
429    protected double getMeanAbsoluteError(Instances data) throws Exception {
430        Evaluation eval = new Evaluation(data);
431        eval.evaluateModel(this,data);
432        return eval.meanAbsoluteError();
433    }
434
435    /**
436     * Helper function to find the minimum in an array of error values.
437     *
438     * @param errors an array containing errors
439     * @param maxIteration the maximum of iterations
440     * @return the minimum
441     */
442    protected int getBestIteration(double[] errors, int maxIteration) {
443        double bestError = errors[0];
444        int bestIteration = 0;
445        for (int i = 1; i <= maxIteration; i++) {           
446            if (errors[i] < bestError) {
447                bestError = errors[i];
448                bestIteration = i;             
449            }
450        } 
451        return bestIteration;
452    }
453
454    /**
455     * Performs a single iteration of LogitBoost, and updates the model accordingly.
456     * A simple regression function is fit to the response and added to the m_regressions array.
457     * @param iteration the current iteration
458     * @param trainYs the y-values (see description of LogitBoost) for the model trained so far
459     * @param trainFs the F-values (see description of LogitBoost) for the model trained so far
460     * @param probs the p-values (see description of LogitBoost) for the model trained so far
461     * @param trainNumeric numeric version of the training data
462     * @return returns true if iteration performed successfully, false if no simple regression function
463     * could be fitted.
464     * @throws Exception if something goes wrong
465     */
466    protected boolean performIteration(int iteration, 
467                                       double[][] trainYs,
468                                       double[][] trainFs,
469                                       double[][] probs,
470                                       Instances trainNumeric) throws Exception {
471       
472        for (int j = 0; j < m_numClasses; j++) {
473            // Keep track of sum of weights
474            double[] weights = new double[trainNumeric.numInstances()];
475            double weightSum = 0.0;
476           
477            //make copy of data (need to save the weights)
478            Instances boostData = new Instances(trainNumeric);
479           
480            for (int i = 0; i < trainNumeric.numInstances(); i++) {
481               
482                //compute response and weight
483                double p = probs[i][j];
484                double actual = trainYs[i][j];
485                double z = getZ(actual, p);
486                double w = (actual - p) / z;
487               
488                //set values for instance
489                Instance current = boostData.instance(i);
490                current.setValue(boostData.classIndex(), z);
491                current.setWeight(current.weight() * w);                               
492               
493                weights[i] = current.weight();
494                weightSum += current.weight();
495            }
496           
497            Instances instancesCopy = new Instances(boostData);
498           
499            if (weightSum > 0) {
500                // Only the (1-beta)th quantile of instances are sent to the base classifier
501                if (m_weightTrimBeta > 0) {
502                    double weightPercentage = 0.0;
503                    int[] weightsOrder = new int[trainNumeric.numInstances()];
504                    weightsOrder = Utils.sort(weights);
505                    instancesCopy.delete();
506                   
507                   
508                    for (int i = weightsOrder.length-1; (i >= 0) && (weightPercentage < (1-m_weightTrimBeta)); i--) {
509                        instancesCopy.add(boostData.instance(weightsOrder[i]));
510                        weightPercentage += (weights[weightsOrder[i]] / weightSum);
511                       
512                    }
513                }
514               
515                //Scale the weights
516                weightSum = instancesCopy.sumOfWeights();
517                for (int i = 0; i < instancesCopy.numInstances(); i++) {
518                    Instance current = instancesCopy.instance(i);
519                    current.setWeight(current.weight() * (double)instancesCopy.numInstances() / weightSum);
520                }
521            }
522           
523            //fit simple regression function
524            m_regressions[j][iteration].buildClassifier(instancesCopy);
525           
526            boolean foundAttribute = m_regressions[j][iteration].foundUsefulAttribute();
527            if (!foundAttribute) {
528                //could not fit simple regression function
529                return false;
530            }
531           
532        }
533       
534        // Evaluate / increment trainFs from the classifier
535        for (int i = 0; i < trainFs.length; i++) {
536            double [] pred = new double [m_numClasses];
537            double predSum = 0;
538            for (int j = 0; j < m_numClasses; j++) {
539                pred[j] = m_regressions[j][iteration]
540                    .classifyInstance(trainNumeric.instance(i));
541                predSum += pred[j];
542            }
543            predSum /= m_numClasses;
544            for (int j = 0; j < m_numClasses; j++) {
545                trainFs[i][j] += (pred[j] - predSum) * (m_numClasses - 1) 
546                    / m_numClasses;
547            }
548        }
549       
550        // Compute the current probability estimates
551        for (int i = 0; i < trainYs.length; i++) {
552            probs[i] = probs(trainFs[i]);
553        }
554        return true;
555    }   
556
557    /**
558     * Helper function to initialize m_regressions.
559     *
560     * @return the generated classifiers
561     */
562    protected SimpleLinearRegression[][] initRegressions(){
563        SimpleLinearRegression[][] classifiers =   
564            new SimpleLinearRegression[m_numClasses][m_maxIterations];
565        for (int j = 0; j < m_numClasses; j++) {
566            for (int i = 0; i < m_maxIterations; i++) {
567                classifiers[j][i] = new SimpleLinearRegression();
568                classifiers[j][i].setSuppressErrorMessage(true);
569            }
570        }
571        return classifiers;
572    }
573
574    /**
575     * Converts training data to numeric version. The class variable is replaced by a pseudo-class
576     * used by LogitBoost.
577     *
578     * @param data the data to convert
579     * @return the converted data
580     * @throws Exception if something goes wrong
581     */
582    protected Instances getNumericData(Instances data) throws Exception{
583        Instances numericData = new Instances(data);
584       
585        int classIndex = numericData.classIndex();
586        numericData.setClassIndex(-1);
587        numericData.deleteAttributeAt(classIndex);
588        numericData.insertAttributeAt(new Attribute("'pseudo class'"), classIndex);
589        numericData.setClassIndex(classIndex);
590        return numericData;
591    }
592   
593    /**
594     * Helper function for cutting back m_regressions to the set of classifiers
595     * (corresponsing to the number of LogitBoost iterations) that gave the
596     * smallest error.
597     *
598     * @param classifiers the original set of classifiers
599     * @return the cut back set of classifiers
600     */
601    protected SimpleLinearRegression[][] selectRegressions(SimpleLinearRegression[][] classifiers){
602        SimpleLinearRegression[][] goodClassifiers = 
603            new SimpleLinearRegression[m_numClasses][m_numRegressions];
604       
605        for (int j = 0; j < m_numClasses; j++) {
606            for (int i = 0; i < m_numRegressions; i++) {
607                goodClassifiers[j][i] = classifiers[j][i];
608            }
609        }
610        return goodClassifiers;
611    }           
612   
613    /**
614     * Computes the LogitBoost response variable from y/p values
615     * (actual/estimated class probabilities).
616     *
617     * @param actual the actual class probability
618     * @param p the estimated class probability
619     * @return the LogitBoost response
620     */
621    protected double getZ(double actual, double p) {
622        double z;
623        if (actual == 1) {
624            z = 1.0 / p;
625            if (z > Z_MAX) { // threshold
626                z = Z_MAX;
627            }
628        } else {
629            z = -1.0 / (1.0 - p);
630            if (z < -Z_MAX) { // threshold
631                z = -Z_MAX;
632            }
633        }
634        return z;
635    }
636   
637    /**
638     * Computes the LogitBoost response for an array of y/p values
639     * (actual/estimated class probabilities).
640     *
641     * @param dataYs the actual class probabilities
642     * @param probs the estimated class probabilities
643     * @return the LogitBoost response
644     */
645    protected double[][] getZs(double[][] probs, double[][] dataYs) {
646       
647        double[][] dataZs = new double[probs.length][m_numClasses];
648        for (int j = 0; j < m_numClasses; j++) 
649            for (int i = 0; i < probs.length; i++) dataZs[i][j] = getZ(dataYs[i][j], probs[i][j]);
650        return dataZs;
651    }
652   
653    /**
654     * Computes the LogitBoost weights from an array of y/p values
655     * (actual/estimated class probabilities).
656     *
657     * @param dataYs the actual class probabilities
658     * @param probs the estimated class probabilities
659     * @return the LogitBoost weights
660     */
661    protected double[][] getWs(double[][] probs, double[][] dataYs) {
662       
663        double[][] dataWs = new double[probs.length][m_numClasses];
664        for (int j = 0; j < m_numClasses; j++) 
665            for (int i = 0; i < probs.length; i++){
666            double z = getZ(dataYs[i][j], probs[i][j]);
667            dataWs[i][j] = (dataYs[i][j] - probs[i][j]) / z;
668            }
669        return dataWs;
670    }
671
672    /**
673     * Computes the p-values (probabilities for the classes) from the F-values
674     * of the logistic model.
675     *
676     * @param Fs the F-values
677     * @return the p-values
678     */
679    protected double[] probs(double[] Fs) {
680       
681        double maxF = -Double.MAX_VALUE;
682        for (int i = 0; i < Fs.length; i++) {
683            if (Fs[i] > maxF) {
684                maxF = Fs[i];
685            }
686        }   
687        double sum = 0;
688        double[] probs = new double[Fs.length];
689        for (int i = 0; i < Fs.length; i++) {
690            probs[i] = Math.exp(Fs[i] - maxF);   
691            sum += probs[i];
692        }
693       
694        Utils.normalize(probs, sum);
695        return probs;
696    }
697
698    /**
699     * Computes the Y-values (actual class probabilities) for a set of instances.
700     *
701     * @param data the data to compute the Y-values from
702     * @return the Y-values
703     */
704    protected double[][] getYs(Instances data){
705       
706        double [][] dataYs = new double [data.numInstances()][m_numClasses];
707        for (int j = 0; j < m_numClasses; j++) {
708            for (int k = 0; k < data.numInstances(); k++) {
709                dataYs[k][j] = (data.instance(k).classValue() == j) ? 
710                    1.0: 0.0;
711            }
712        }
713        return dataYs;
714    }
715
716    /**
717     * Computes the F-values for a single instance.
718     *
719     * @param instance the instance to compute the F-values for
720     * @return the F-values
721     * @throws Exception if something goes wrong
722     */
723    protected double[] getFs(Instance instance) throws Exception{
724       
725        double [] pred = new double [m_numClasses];
726        double [] instanceFs = new double [m_numClasses]; 
727       
728        //add up the predictions from the simple regression functions
729        for (int i = 0; i < m_numRegressions; i++) {
730            double predSum = 0;
731            for (int j = 0; j < m_numClasses; j++) {
732                pred[j] = m_regressions[j][i].classifyInstance(instance);
733                predSum += pred[j];
734            }
735            predSum /= m_numClasses;
736            for (int j = 0; j < m_numClasses; j++) {
737                instanceFs[j] += (pred[j] - predSum) * (m_numClasses - 1) 
738                    / m_numClasses;
739            }
740        }       
741       
742        return instanceFs; 
743    } 
744   
745    /**
746     * Computes the F-values for a set of instances.
747     *
748     * @param data the data to work on
749     * @return the F-values
750     * @throws Exception if something goes wrong
751     */
752    protected double[][] getFs(Instances data) throws Exception{
753       
754        double[][] dataFs = new double[data.numInstances()][];
755       
756        for (int k = 0; k < data.numInstances(); k++) {
757            dataFs[k] = getFs(data.instance(k));
758        }
759       
760        return dataFs; 
761    }   
762
763    /**
764     * Computes the p-values (probabilities for the different classes) from
765     * the F-values for a set of instances.
766     *
767     * @param dataFs the F-values
768     * @return the p-values
769     */
770    protected double[][] getProbs(double[][] dataFs){
771       
772        int numInstances = dataFs.length;
773        double[][] probs = new double[numInstances][];
774       
775        for (int k = 0; k < numInstances; k++) {       
776            probs[k] = probs(dataFs[k]);
777        }
778        return probs;
779    }
780   
781    /**
782     * Returns the negative loglikelihood of the Y-values (actual class probabilities) given the
783     * p-values (current probability estimates).
784     *
785     * @param dataYs the Y-values
786     * @param probs the p-values
787     * @return the likelihood
788     */
789    protected double negativeLogLikelihood(double[][] dataYs, double[][] probs) {
790       
791        double logLikelihood = 0;
792        for (int i = 0; i < dataYs.length; i++) {
793            for (int j = 0; j < m_numClasses; j++) {
794                if (dataYs[i][j] == 1.0) {
795                    logLikelihood -= Math.log(probs[i][j]);
796                }
797            }
798        }
799        return logLikelihood;// / (double)dataYs.length;
800    }
801
802    /**
803     * Returns an array of the indices of the attributes used in the logistic model.
804     * The first dimension is the class, the second dimension holds a list of attribute indices.
805     * Attribute indices start at zero.
806     * @return the array of attribute indices
807     */
808    public int[][] getUsedAttributes(){
809       
810        int[][] usedAttributes = new int[m_numClasses][];
811       
812        //first extract coefficients
813        double[][] coefficients = getCoefficients();
814       
815        for (int j = 0; j < m_numClasses; j++){
816           
817            //boolean array indicating if attribute used
818            boolean[] attributes = new boolean[m_numericDataHeader.numAttributes()];
819            for (int i = 0; i < attributes.length; i++) {
820                //attribute used if coefficient > 0
821                if (!Utils.eq(coefficients[j][i + 1],0)) attributes[i] = true;
822            }
823                   
824            int numAttributes = 0;
825            for (int i = 0; i < m_numericDataHeader.numAttributes(); i++) if (attributes[i]) numAttributes++;
826           
827            //"collect" all attributes into array of indices
828            int[] usedAttributesClass = new int[numAttributes];
829            int count = 0;
830            for (int i = 0; i < m_numericDataHeader.numAttributes(); i++) {
831                if (attributes[i]) {
832                usedAttributesClass[count] = i;
833                count++;
834                } 
835            }
836           
837            usedAttributes[j] = usedAttributesClass;
838        }
839       
840        return usedAttributes;
841    }
842
843    /**
844     * The number of LogitBoost iterations performed (= the number of simple
845     * regression functions fit).
846     *
847     * @return the number of LogitBoost iterations performed
848     */
849    public int getNumRegressions() {
850        return m_numRegressions;
851    }
852   
853    /**
854     * Get the value of weightTrimBeta.
855     *
856     * @return Value of weightTrimBeta.
857     */
858    public double getWeightTrimBeta(){
859        return m_weightTrimBeta;
860    }
861   
862    /**
863     * Get the value of useAIC.
864     *
865     * @return Value of useAIC.
866     */
867    public boolean getUseAIC(){
868        return m_useAIC;
869    }
870
871    /**
872     * Sets the parameter "maxIterations".
873     *
874     * @param maxIterations the maximum iterations
875     */
876    public void setMaxIterations(int maxIterations) {
877        m_maxIterations = maxIterations;
878    }
879   
880    /**
881     * Sets the option "heuristicStop".
882     *
883     * @param heuristicStop the heuristic stop to use
884     */
885    public void setHeuristicStop(int heuristicStop){
886        m_heuristicStop = heuristicStop;
887    }
888   
889    /**
890     * Sets the option "weightTrimBeta".
891     */
892    public void setWeightTrimBeta(double w){
893        m_weightTrimBeta = w;
894    }
895   
896    /**
897     * Set the value of useAIC.
898     *
899     * @param c Value to assign to useAIC.
900     */
901    public void setUseAIC(boolean c){
902        m_useAIC = c;
903    }
904
905    /**
906     * Returns the maxIterations parameter.
907     *
908     * @return the maximum iteration
909     */
910    public int getMaxIterations(){
911        return m_maxIterations;
912    }
913       
914    /**
915     * Returns an array holding the coefficients of the logistic model.
916     * First dimension is the class, the second one holds a list of coefficients.
917     * At position zero, the constant term of the model is stored, then, the coefficients for
918     * the attributes in ascending order.
919     * @return the array of coefficients
920     */
921    protected double[][] getCoefficients(){
922        double[][] coefficients = new double[m_numClasses][m_numericDataHeader.numAttributes() + 1];
923        for (int j = 0; j < m_numClasses; j++) {
924            //go through simple regression functions and add their coefficient to the coefficient of
925            //the attribute they are built on.
926            for (int i = 0; i < m_numRegressions; i++) {
927               
928                double slope = m_regressions[j][i].getSlope();
929                double intercept = m_regressions[j][i].getIntercept();
930                int attribute = m_regressions[j][i].getAttributeIndex();
931               
932                coefficients[j][0] += intercept;
933                coefficients[j][attribute + 1] += slope;
934            }
935        }
936       
937        // Need to multiply all coefficients by (J-1) / J
938        for (int j = 0; j < coefficients.length; j++) {
939          for (int i = 0; i < coefficients[0].length; i++) {
940            coefficients[j][i] *= (double)(m_numClasses - 1) / (double)m_numClasses;
941          }
942        }
943
944        return coefficients;
945    }
946
947    /**
948     * Returns the fraction of all attributes in the data that are used in the
949     * logistic model (in percent).
950     * An attribute is used in the model if it is used in any of the models for
951     * the different classes.
952     *
953     * @return the fraction of all attributes that are used
954     */
955    public double percentAttributesUsed(){     
956        boolean[] attributes = new boolean[m_numericDataHeader.numAttributes()];
957       
958        double[][] coefficients = getCoefficients();
959        for (int j = 0; j < m_numClasses; j++){
960            for (int i = 1; i < m_numericDataHeader.numAttributes() + 1; i++) {
961                //attribute used if it is used in any class, note coefficients are shifted by one (because
962                //of constant term).
963                if (!Utils.eq(coefficients[j][i],0)) attributes[i - 1] = true;
964            }
965        }
966       
967        //count number of used attributes (without the class attribute)
968        double count = 0;
969        for (int i = 0; i < attributes.length; i++) if (attributes[i]) count++;
970        return count / (double)(m_numericDataHeader.numAttributes() - 1) * 100.0;
971    }
972   
973    /**
974     * Returns a description of the logistic model (i.e., attributes and
975     * coefficients).
976     *
977     * @return the description of the model
978     */
979    public String toString(){
980       
981        StringBuffer s = new StringBuffer();   
982
983        //get used attributes
984        int[][] attributes = getUsedAttributes();
985       
986        //get coefficients
987        double[][] coefficients = getCoefficients();
988       
989        for (int j = 0; j < m_numClasses; j++) {
990            s.append("\nClass "+j+" :\n");
991            //constant term
992            s.append(Utils.doubleToString(coefficients[j][0],4,2)+" + \n");
993            for (int i = 0; i < attributes[j].length; i++) {           
994                //attribute/coefficient pairs
995                s.append("["+m_numericDataHeader.attribute(attributes[j][i]).name()+"]");
996                s.append(" * " + Utils.doubleToString(coefficients[j][attributes[j][i]+1],4,2));
997                if (i != attributes[j].length - 1) s.append(" +");
998                s.append("\n");     
999            }
1000        }       
1001        return new String(s);
1002    }
1003
1004    /**
1005     * Returns class probabilities for an instance.
1006     *
1007     * @param instance the instance to compute the distribution for
1008     * @return the class probabilities
1009     * @throws Exception if distribution can't be computed successfully
1010     */
1011    public double[] distributionForInstance(Instance instance) throws Exception {
1012       
1013        instance = (Instance)instance.copy();   
1014
1015        //set to numeric pseudo-class
1016        instance.setDataset(m_numericDataHeader);               
1017       
1018        //calculate probs via Fs
1019        return probs(getFs(instance));
1020    }
1021
1022    /**
1023     * Cleanup in order to save memory.
1024     */
1025    public void cleanup() {
1026        //save just header info
1027        m_train = new Instances(m_train,0);
1028        m_numericData = null;   
1029    }
1030   
1031    /**
1032     * Returns the revision string.
1033     *
1034     * @return          the revision
1035     */
1036    public String getRevision() {
1037      return RevisionUtils.extract("$Revision: 5928 $");
1038    }
1039}
Note: See TracBrowser for help on using the repository browser.