/*
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 2 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 */

/*
 * LogisticBase.java
 * Copyright (C) 2003 University of Waikato, Hamilton, New Zealand
 *
 */

package weka.classifiers.trees.lmt;

import weka.classifiers.Classifier;
import weka.classifiers.AbstractClassifier;
import weka.classifiers.Evaluation;
import weka.classifiers.functions.SimpleLinearRegression;
import weka.core.Attribute;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.RevisionUtils;
import weka.core.Utils;
import weka.core.WeightedInstancesHandler;

/**
 * Base/helper class for building logistic regression models with the LogitBoost algorithm.
 * Used for building logistic model trees (weka.classifiers.trees.lmt.LMT)
 * and standalone logistic regression (weka.classifiers.functions.SimpleLogistic).
 *
 <!-- options-start -->
 * Valid options are: <p/>
 *
 * <pre> -D
 *  If set, classifier is run in debug mode and
 *  may output additional info to the console</pre>
 *
 <!-- options-end -->
 *
 * @author Niels Landwehr
 * @author Marc Sumner
 * @version $Revision: 5928 $
 */
public class LogisticBase
  extends AbstractClassifier
  implements WeightedInstancesHandler {

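  /*
   * Usage sketch (illustrative only; assumes the standard Weka API and a
   * dataset whose class attribute has already been set):
   *
   *   Instances data = ...;                      // training instances, nominal class
   *   LogisticBase logistic =
   *     new LogisticBase(-1, true, false);       // cross-validate number of iterations
   *   logistic.setMaxIterations(200);
   *   logistic.buildClassifier(data);
   *   double[] dist = logistic.distributionForInstance(data.instance(0));
   */
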
  /** for serialization */
  static final long serialVersionUID = 168765678097825064L;

  /** Header-only version of the numeric version of the training data */
  protected Instances m_numericDataHeader;

  /**
   * Numeric version of the training data. Original class is replaced by a numeric pseudo-class.
   */
  protected Instances m_numericData;

  /** Training data */
  protected Instances m_train;

  /** Use cross-validation to determine the best number of LogitBoost iterations? */
  protected boolean m_useCrossValidation;

  /** Use error on probabilities for the stopping criterion of LogitBoost? */
  protected boolean m_errorOnProbabilities;

  /** Use a fixed number of iterations for LogitBoost? (if negative, cross-validate the number of iterations) */
  protected int m_fixedNumIterations;

  /**
   * Use a heuristic to stop performing LogitBoost iterations earlier?
   * If enabled, LogitBoost is stopped if the current (local) minimum of the error on a test set as
   * a function of the number of iterations has not changed for m_heuristicStop iterations.
   */
  protected int m_heuristicStop = 50;

  /** The number of LogitBoost iterations performed. */
  protected int m_numRegressions = 0;

  /** The maximum number of LogitBoost iterations */
  protected int m_maxIterations;

  /** The number of different classes */
  protected int m_numClasses;

  /** Array holding the simple regression functions fit by LogitBoost */
  protected SimpleLinearRegression[][] m_regressions;

  /** Number of folds for cross-validating the number of LogitBoost iterations */
  protected static int m_numFoldsBoosting = 5;

  /** Threshold on the Z-value for LogitBoost */
  protected static final double Z_MAX = 3;

  /** If true, the AIC is used to choose the best iteration */
  private boolean m_useAIC = false;

  /** Effective number of parameters used for AIC / BIC automatic stopping */
  protected double m_numParameters = 0;

  /**
   * Threshold for trimming weights. Instances with a weight lower than this (as a percentage
   * of total weights) are not included in the regression fit.
   */
  protected double m_weightTrimBeta = 0;

  /**
   * Constructor that creates a LogisticBase object with standard options.
   */
  public LogisticBase() {
    m_fixedNumIterations = -1;
    m_useCrossValidation = true;
    m_errorOnProbabilities = false;
    m_maxIterations = 500;
    m_useAIC = false;
    m_numParameters = 0;
  }

  /**
   * Constructor to create a LogisticBase object.
   * @param numBoostingIterations fixed number of iterations for LogitBoost (if negative, use cross-validation or
   * the stopping criterion on the training data).
   * @param useCrossValidation cross-validate the number of LogitBoost iterations (if false, use the stopping
   * criterion on the training data).
   * @param errorOnProbabilities if true, use error on probabilities
   * instead of misclassification for the stopping criterion of LogitBoost
   */
  public LogisticBase(int numBoostingIterations, boolean useCrossValidation, boolean errorOnProbabilities) {
    m_fixedNumIterations = numBoostingIterations;
    m_useCrossValidation = useCrossValidation;
    m_errorOnProbabilities = errorOnProbabilities;
    m_maxIterations = 500;
    m_useAIC = false;
    m_numParameters = 0;
  }

  /**
   * Builds the logistic regression model using LogitBoost.
   *
   * @param data the training data
   * @throws Exception if something goes wrong
   */
  public void buildClassifier(Instances data) throws Exception {

    m_train = new Instances(data);

    m_numClasses = m_train.numClasses();

    //init the array of simple regression functions
    m_regressions = initRegressions();
    m_numRegressions = 0;

    //get numeric version of the training data (class variable replaced by numeric pseudo-class)
    m_numericData = getNumericData(m_train);

    //save header info
    m_numericDataHeader = new Instances(m_numericData, 0);

    if (m_fixedNumIterations > 0) {
      //run LogitBoost for fixed number of iterations
      performBoosting(m_fixedNumIterations);
    } else if (m_useAIC) { // Marc had this after the test for m_useCrossValidation. Changed by Eibe.
      //run LogitBoost using information criterion for stopping
      performBoostingInfCriterion();
    } else if (m_useCrossValidation) {
      //cross-validate number of LogitBoost iterations
      performBoostingCV();
    } else {
      //run LogitBoost with number of iterations that minimizes error on the training set
      performBoosting();
    }

    //only keep the simple regression functions that correspond to the selected number of LogitBoost iterations
    m_regressions = selectRegressions(m_regressions);
  }
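
  /*
   * Note on model selection: buildClassifier chooses among four stopping
   * strategies for LogitBoost, checked in this order:
   *   1. a fixed number of iterations (m_fixedNumIterations > 0),
   *   2. AIC-based stopping (m_useAIC),
   *   3. cross-validation of the number of iterations (m_useCrossValidation),
   *   4. the iteration with the lowest training error.
   * Exactly one strategy is applied per call.
   */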

  /**
   * Runs LogitBoost, determining the best number of iterations by cross-validation.
   *
   * @throws Exception if something goes wrong
   */
  protected void performBoostingCV() throws Exception {

    //completedIterations keeps track of the number of iterations that have been
    //performed in every fold (some might stop earlier than others).
    //The best iteration is selected only from these.
    int completedIterations = m_maxIterations;

    Instances allData = new Instances(m_train);

    allData.stratify(m_numFoldsBoosting);

    double[] error = new double[m_maxIterations + 1];

    for (int i = 0; i < m_numFoldsBoosting; i++) {
      //split into training/test data in fold
      Instances train = allData.trainCV(m_numFoldsBoosting, i);
      Instances test = allData.testCV(m_numFoldsBoosting, i);

      //initialize LogitBoost
      m_numRegressions = 0;
      m_regressions = initRegressions();

      //run LogitBoost iterations
      int iterations = performBoosting(train, test, error, completedIterations);
      if (iterations < completedIterations) completedIterations = iterations;
    }

    //determine iteration with minimum error over the folds
    int bestIteration = getBestIteration(error, completedIterations);

    //rebuild model on all of the training data
    m_numRegressions = 0;
    performBoosting(bestIteration);
  }

  /**
   * Runs LogitBoost, determining the best number of iterations by an information criterion (currently AIC).
   *
   * @throws Exception if something goes wrong
   */
  protected void performBoostingInfCriterion() throws Exception {

    double bestCriterion = Double.MAX_VALUE;
    int bestIteration = 0;
    int noMin = 0;

    // Variable to keep track of the criterion value (AIC)
    double criterionValue = Double.MAX_VALUE;

    // initialize Ys/Fs/ps
    double[][] trainYs = getYs(m_train);
    double[][] trainFs = getFs(m_numericData);
    double[][] probs = getProbs(trainFs);

    int iteration = 0;
    while (iteration < m_maxIterations) {

      //perform single LogitBoost iteration
      boolean foundAttribute = performIteration(iteration, trainYs, trainFs, probs, m_numericData);
      if (foundAttribute) {
        iteration++;
        m_numRegressions = iteration;
      } else {
        //could not fit simple linear regression: stop LogitBoost
        break;
      }

      double numberOfAttributes = m_numParameters + iteration;

      // Compute the criterion value (AIC)
      criterionValue = 2.0 * negativeLogLikelihood(trainYs, probs) +
        2.0 * numberOfAttributes;

      //heuristic: stop LogitBoost if the current minimum has not changed for <m_heuristicStop> iterations
      if (noMin > m_heuristicStop) break;
      if (criterionValue < bestCriterion) {
        bestCriterion = criterionValue;
        bestIteration = iteration;
        noMin = 0;
      } else {
        noMin++;
      }
    }

    m_numRegressions = 0;
    performBoosting(bestIteration);
  }
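
  /*
   * The criterion computed above is the Akaike Information Criterion,
   *
   *   AIC = 2 * negLogLikelihood + 2 * k,
   *
   * where k = m_numParameters + iteration is the effective number of
   * parameters of the model after the current iteration. Boosting stops
   * once the best (lowest) AIC has not improved for m_heuristicStop
   * iterations, and the model is then rebuilt up to the best iteration.
   */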

  /**
   * Runs LogitBoost on a training set and monitors the error on a test set.
   * Used for running one fold when cross-validating the number of LogitBoost iterations.
   * @param train the training set
   * @param test the test set
   * @param error array to hold the logged error values
   * @param maxIterations the maximum number of LogitBoost iterations to run
   * @return the number of completed LogitBoost iterations (can be smaller than maxIterations
   * if the heuristic for early stopping is active or there is a problem while fitting the regressions
   * in LogitBoost).
   * @throws Exception if something goes wrong
   */
  protected int performBoosting(Instances train, Instances test,
                                double[] error, int maxIterations) throws Exception {

    //get numeric version of the (sub)set of training instances
    Instances numericTrain = getNumericData(train);

    //initialize Ys/Fs/ps
    double[][] trainYs = getYs(train);
    double[][] trainFs = getFs(numericTrain);
    double[][] probs = getProbs(trainFs);

    int iteration = 0;

    int noMin = 0;
    double lastMin = Double.MAX_VALUE;

    if (m_errorOnProbabilities) error[0] += getMeanAbsoluteError(test);
    else error[0] += getErrorRate(test);

    while (iteration < maxIterations) {

      //perform single LogitBoost iteration
      boolean foundAttribute = performIteration(iteration, trainYs, trainFs, probs, numericTrain);
      if (foundAttribute) {
        iteration++;
        m_numRegressions = iteration;
      } else {
        //could not fit simple linear regression: stop LogitBoost
        break;
      }

      if (m_errorOnProbabilities) error[iteration] += getMeanAbsoluteError(test);
      else error[iteration] += getErrorRate(test);

      //heuristic: stop LogitBoost if the current minimum has not changed for <m_heuristicStop> iterations
      if (noMin > m_heuristicStop) break;
      if (error[iteration] < lastMin) {
        lastMin = error[iteration];
        noMin = 0;
      } else {
        noMin++;
      }
    }

    return iteration;
  }

  /**
   * Runs LogitBoost with a fixed number of iterations.
   * @param numIterations the number of iterations to run
   * @throws Exception if something goes wrong
   */
  protected void performBoosting(int numIterations) throws Exception {

    //initialize Ys/Fs/ps
    double[][] trainYs = getYs(m_train);
    double[][] trainFs = getFs(m_numericData);
    double[][] probs = getProbs(trainFs);

    int iteration = 0;

    //run iterations
    while (iteration < numIterations) {
      boolean foundAttribute = performIteration(iteration, trainYs, trainFs, probs, m_numericData);
      if (foundAttribute) iteration++;
      else break;
    }

    m_numRegressions = iteration;
  }

  /**
   * Runs LogitBoost using the stopping criterion on the training set.
   * Uses the number of iterations that gives the lowest error on the training set, either
   * misclassification error or error on probabilities (depending on the errorOnProbabilities option).
   * @throws Exception if something goes wrong
   */
  protected void performBoosting() throws Exception {

    //initialize Ys/Fs/ps
    double[][] trainYs = getYs(m_train);
    double[][] trainFs = getFs(m_numericData);
    double[][] probs = getProbs(trainFs);

    int iteration = 0;

    double[] trainErrors = new double[m_maxIterations + 1];
    trainErrors[0] = getErrorRate(m_train);

    int noMin = 0;
    double lastMin = Double.MAX_VALUE;

    while (iteration < m_maxIterations) {
      boolean foundAttribute = performIteration(iteration, trainYs, trainFs, probs, m_numericData);
      if (foundAttribute) {
        iteration++;
        m_numRegressions = iteration;
      } else {
        //could not fit simple regression
        break;
      }

      trainErrors[iteration] = getErrorRate(m_train);

      //heuristic: stop LogitBoost if the current minimum has not changed for <m_heuristicStop> iterations
      if (noMin > m_heuristicStop) break;
      if (trainErrors[iteration] < lastMin) {
        lastMin = trainErrors[iteration];
        noMin = 0;
      } else {
        noMin++;
      }
    }

    //find iteration with best error
    m_numRegressions = getBestIteration(trainErrors, iteration);
  }

  /**
   * Returns the misclassification error of the current model on a set of instances.
   * @param data the set of instances
   * @return the error rate
   * @throws Exception if something goes wrong
   */
  protected double getErrorRate(Instances data) throws Exception {
    Evaluation eval = new Evaluation(data);
    eval.evaluateModel(this, data);
    return eval.errorRate();
  }

  /**
   * Returns the error of the probability estimates for the current model on a set of instances.
   * @param data the set of instances
   * @return the error
   * @throws Exception if something goes wrong
   */
  protected double getMeanAbsoluteError(Instances data) throws Exception {
    Evaluation eval = new Evaluation(data);
    eval.evaluateModel(this, data);
    return eval.meanAbsoluteError();
  }

  /**
   * Helper function to find the minimum in an array of error values.
   *
   * @param errors an array containing errors
   * @param maxIteration the maximum number of iterations
   * @return the index of the iteration with the smallest error
   */
  protected int getBestIteration(double[] errors, int maxIteration) {
    double bestError = errors[0];
    int bestIteration = 0;
    for (int i = 1; i <= maxIteration; i++) {
      if (errors[i] < bestError) {
        bestError = errors[i];
        bestIteration = i;
      }
    }
    return bestIteration;
  }

  /**
   * Performs a single iteration of LogitBoost, and updates the model accordingly.
   * A simple regression function is fit to the response and added to the m_regressions array.
   * @param iteration the current iteration
   * @param trainYs the y-values (see description of LogitBoost) for the model trained so far
   * @param trainFs the F-values (see description of LogitBoost) for the model trained so far
   * @param probs the p-values (see description of LogitBoost) for the model trained so far
   * @param trainNumeric numeric version of the training data
   * @return true if the iteration was performed successfully, false if no simple regression function
   * could be fitted.
   * @throws Exception if something goes wrong
   */
  protected boolean performIteration(int iteration,
                                     double[][] trainYs,
                                     double[][] trainFs,
                                     double[][] probs,
                                     Instances trainNumeric) throws Exception {

    for (int j = 0; j < m_numClasses; j++) {
      // Keep track of the sum of weights
      double[] weights = new double[trainNumeric.numInstances()];
      double weightSum = 0.0;

      //make copy of data (need to save the weights)
      Instances boostData = new Instances(trainNumeric);

      for (int i = 0; i < trainNumeric.numInstances(); i++) {

        //compute response and weight
        double p = probs[i][j];
        double actual = trainYs[i][j];
        double z = getZ(actual, p);
        double w = (actual - p) / z;

        //set values for instance
        Instance current = boostData.instance(i);
        current.setValue(boostData.classIndex(), z);
        current.setWeight(current.weight() * w);

        weights[i] = current.weight();
        weightSum += current.weight();
      }

      Instances instancesCopy = new Instances(boostData);

      if (weightSum > 0) {
        // Only the (1-beta)th quantile of instances is sent to the base classifier
        if (m_weightTrimBeta > 0) {
          double weightPercentage = 0.0;
          int[] weightsOrder = Utils.sort(weights);
          instancesCopy.delete();

          for (int i = weightsOrder.length - 1; (i >= 0) && (weightPercentage < (1 - m_weightTrimBeta)); i--) {
            instancesCopy.add(boostData.instance(weightsOrder[i]));
            weightPercentage += (weights[weightsOrder[i]] / weightSum);
          }
        }

        //Scale the weights
        weightSum = instancesCopy.sumOfWeights();
        for (int i = 0; i < instancesCopy.numInstances(); i++) {
          Instance current = instancesCopy.instance(i);
          current.setWeight(current.weight() * (double)instancesCopy.numInstances() / weightSum);
        }
      }

      //fit simple regression function
      m_regressions[j][iteration].buildClassifier(instancesCopy);

      boolean foundAttribute = m_regressions[j][iteration].foundUsefulAttribute();
      if (!foundAttribute) {
        //could not fit simple regression function
        return false;
      }
    }

    // Evaluate / increment trainFs from the classifier
    for (int i = 0; i < trainFs.length; i++) {
      double[] pred = new double[m_numClasses];
      double predSum = 0;
      for (int j = 0; j < m_numClasses; j++) {
        pred[j] = m_regressions[j][iteration]
          .classifyInstance(trainNumeric.instance(i));
        predSum += pred[j];
      }
      predSum /= m_numClasses;
      for (int j = 0; j < m_numClasses; j++) {
        trainFs[i][j] += (pred[j] - predSum) * (m_numClasses - 1)
          / m_numClasses;
      }
    }

    // Compute the current probability estimates
    for (int i = 0; i < trainYs.length; i++) {
      probs[i] = probs(trainFs[i]);
    }
    return true;
  }
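
  /*
   * Worked form of the update (following the multiclass LogitBoost
   * description of Friedman, Hastie & Tibshirani, 2000): for class j and
   * instance i with target y and current probability estimate p,
   *
   *   z = (y - p) / (p * (1 - p))   (= 1/p if y = 1, -1/(1-p) if y = 0;
   *                                  capped at +/- Z_MAX in getZ()),
   *   w = (y - p) / z               (= p * (1 - p) when z is not capped),
   *
   * a simple linear regression is fit to the responses z with weights w,
   * and the centered prediction (pred - mean) * (J-1)/J is added to F_j,
   * which keeps the F-values summing to zero over the J classes.
   */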

  /**
   * Helper function to initialize m_regressions.
   *
   * @return the generated classifiers
   */
  protected SimpleLinearRegression[][] initRegressions() {
    SimpleLinearRegression[][] classifiers =
      new SimpleLinearRegression[m_numClasses][m_maxIterations];
    for (int j = 0; j < m_numClasses; j++) {
      for (int i = 0; i < m_maxIterations; i++) {
        classifiers[j][i] = new SimpleLinearRegression();
        classifiers[j][i].setSuppressErrorMessage(true);
      }
    }
    return classifiers;
  }

  /**
   * Converts training data to numeric version. The class variable is replaced by a pseudo-class
   * used by LogitBoost.
   *
   * @param data the data to convert
   * @return the converted data
   * @throws Exception if something goes wrong
   */
  protected Instances getNumericData(Instances data) throws Exception {
    Instances numericData = new Instances(data);

    int classIndex = numericData.classIndex();
    numericData.setClassIndex(-1);
    numericData.deleteAttributeAt(classIndex);
    numericData.insertAttributeAt(new Attribute("'pseudo class'"), classIndex);
    numericData.setClassIndex(classIndex);
    return numericData;
  }

  /**
   * Helper function for cutting back m_regressions to the set of classifiers
   * (corresponding to the number of LogitBoost iterations) that gave the
   * smallest error.
   *
   * @param classifiers the original set of classifiers
   * @return the cut back set of classifiers
   */
  protected SimpleLinearRegression[][] selectRegressions(SimpleLinearRegression[][] classifiers) {
    SimpleLinearRegression[][] goodClassifiers =
      new SimpleLinearRegression[m_numClasses][m_numRegressions];

    for (int j = 0; j < m_numClasses; j++) {
      for (int i = 0; i < m_numRegressions; i++) {
        goodClassifiers[j][i] = classifiers[j][i];
      }
    }
    return goodClassifiers;
  }

  /**
   * Computes the LogitBoost response variable from y/p values
   * (actual/estimated class probabilities).
   *
   * @param actual the actual class probability
   * @param p the estimated class probability
   * @return the LogitBoost response
   */
  protected double getZ(double actual, double p) {
    double z;
    if (actual == 1) {
      z = 1.0 / p;
      if (z > Z_MAX) { // threshold
        z = Z_MAX;
      }
    } else {
      z = -1.0 / (1.0 - p);
      if (z < -Z_MAX) { // threshold
        z = -Z_MAX;
      }
    }
    return z;
  }
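
  /*
   * Example values (Z_MAX = 3): for actual = 1 and p = 0.5, z = 1/0.5 = 2;
   * for actual = 1 and p = 0.25, z = 1/0.25 = 4, which is capped to 3.
   * For actual = 0 and p = 0.5, z = -1/0.5 = -2; for p = 0.8,
   * z = -1/0.2 = -5, which is capped to -3. The cap bounds the influence of
   * instances whose probability estimates are close to 0 or 1.
   */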

  /**
   * Computes the LogitBoost responses for an array of y/p values
   * (actual/estimated class probabilities).
   *
   * @param probs the estimated class probabilities
   * @param dataYs the actual class probabilities
   * @return the LogitBoost responses
   */
  protected double[][] getZs(double[][] probs, double[][] dataYs) {

    double[][] dataZs = new double[probs.length][m_numClasses];
    for (int j = 0; j < m_numClasses; j++)
      for (int i = 0; i < probs.length; i++) dataZs[i][j] = getZ(dataYs[i][j], probs[i][j]);
    return dataZs;
  }

  /**
   * Computes the LogitBoost weights from an array of y/p values
   * (actual/estimated class probabilities).
   *
   * @param probs the estimated class probabilities
   * @param dataYs the actual class probabilities
   * @return the LogitBoost weights
   */
  protected double[][] getWs(double[][] probs, double[][] dataYs) {

    double[][] dataWs = new double[probs.length][m_numClasses];
    for (int j = 0; j < m_numClasses; j++)
      for (int i = 0; i < probs.length; i++) {
        double z = getZ(dataYs[i][j], probs[i][j]);
        dataWs[i][j] = (dataYs[i][j] - probs[i][j]) / z;
      }
    return dataWs;
  }

  /**
   * Computes the p-values (probabilities for the classes) from the F-values
   * of the logistic model.
   *
   * @param Fs the F-values
   * @return the p-values
   */
  protected double[] probs(double[] Fs) {

    double maxF = -Double.MAX_VALUE;
    for (int i = 0; i < Fs.length; i++) {
      if (Fs[i] > maxF) {
        maxF = Fs[i];
      }
    }
    double sum = 0;
    double[] probs = new double[Fs.length];
    for (int i = 0; i < Fs.length; i++) {
      probs[i] = Math.exp(Fs[i] - maxF);
      sum += probs[i];
    }

    Utils.normalize(probs, sum);
    return probs;
  }
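
  /*
   * probs() is the multiclass logistic (softmax) transform
   *
   *   p_j = exp(F_j) / sum_k exp(F_k).
   *
   * Subtracting maxF from every F-value before exponentiating leaves the
   * result unchanged (numerator and denominator are scaled by the same
   * factor exp(-maxF)) but avoids overflow for large F-values. For example,
   * Fs = {2, 0} gives p = {e^0 / (e^0 + e^-2), e^-2 / (e^0 + e^-2)}
   * = {0.88, 0.12} (rounded).
   */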

  /**
   * Computes the Y-values (actual class probabilities) for a set of instances.
   *
   * @param data the data to compute the Y-values from
   * @return the Y-values
   */
  protected double[][] getYs(Instances data) {

    double[][] dataYs = new double[data.numInstances()][m_numClasses];
    for (int j = 0; j < m_numClasses; j++) {
      for (int k = 0; k < data.numInstances(); k++) {
        dataYs[k][j] = (data.instance(k).classValue() == j) ?
          1.0 : 0.0;
      }
    }
    return dataYs;
  }

  /**
   * Computes the F-values for a single instance.
   *
   * @param instance the instance to compute the F-values for
   * @return the F-values
   * @throws Exception if something goes wrong
   */
  protected double[] getFs(Instance instance) throws Exception {

    double[] pred = new double[m_numClasses];
    double[] instanceFs = new double[m_numClasses];

    //add up the predictions from the simple regression functions
    for (int i = 0; i < m_numRegressions; i++) {
      double predSum = 0;
      for (int j = 0; j < m_numClasses; j++) {
        pred[j] = m_regressions[j][i].classifyInstance(instance);
        predSum += pred[j];
      }
      predSum /= m_numClasses;
      for (int j = 0; j < m_numClasses; j++) {
        instanceFs[j] += (pred[j] - predSum) * (m_numClasses - 1)
          / m_numClasses;
      }
    }

    return instanceFs;
  }

  /**
   * Computes the F-values for a set of instances.
   *
   * @param data the data to work on
   * @return the F-values
   * @throws Exception if something goes wrong
   */
  protected double[][] getFs(Instances data) throws Exception {

    double[][] dataFs = new double[data.numInstances()][];

    for (int k = 0; k < data.numInstances(); k++) {
      dataFs[k] = getFs(data.instance(k));
    }

    return dataFs;
  }

  /**
   * Computes the p-values (probabilities for the different classes) from
   * the F-values for a set of instances.
   *
   * @param dataFs the F-values
   * @return the p-values
   */
  protected double[][] getProbs(double[][] dataFs) {

    int numInstances = dataFs.length;
    double[][] probs = new double[numInstances][];

    for (int k = 0; k < numInstances; k++) {
      probs[k] = probs(dataFs[k]);
    }
    return probs;
  }

  /**
   * Returns the negative log-likelihood of the Y-values (actual class probabilities) given the
   * p-values (current probability estimates).
   *
   * @param dataYs the Y-values
   * @param probs the p-values
   * @return the negative log-likelihood
   */
  protected double negativeLogLikelihood(double[][] dataYs, double[][] probs) {

    double logLikelihood = 0;
    for (int i = 0; i < dataYs.length; i++) {
      for (int j = 0; j < m_numClasses; j++) {
        if (dataYs[i][j] == 1.0) {
          logLikelihood -= Math.log(probs[i][j]);
        }
      }
    }
    return logLikelihood; // / (double)dataYs.length;
  }

  /**
   * Returns an array of the indices of the attributes used in the logistic model.
   * The first dimension is the class, the second dimension holds a list of attribute indices.
   * Attribute indices start at zero.
   * @return the array of attribute indices
   */
  public int[][] getUsedAttributes() {

    int[][] usedAttributes = new int[m_numClasses][];

    //first extract coefficients
    double[][] coefficients = getCoefficients();

    for (int j = 0; j < m_numClasses; j++) {

      //boolean array indicating if attribute used
      boolean[] attributes = new boolean[m_numericDataHeader.numAttributes()];
      for (int i = 0; i < attributes.length; i++) {
        //attribute used if its coefficient is nonzero
        if (!Utils.eq(coefficients[j][i + 1], 0)) attributes[i] = true;
      }

      int numAttributes = 0;
      for (int i = 0; i < m_numericDataHeader.numAttributes(); i++) if (attributes[i]) numAttributes++;

      //"collect" all attributes into array of indices
      int[] usedAttributesClass = new int[numAttributes];
      int count = 0;
      for (int i = 0; i < m_numericDataHeader.numAttributes(); i++) {
        if (attributes[i]) {
          usedAttributesClass[count] = i;
          count++;
        }
      }

      usedAttributes[j] = usedAttributesClass;
    }

    return usedAttributes;
  }

  /**
   * The number of LogitBoost iterations performed (= the number of simple
   * regression functions fit).
   *
   * @return the number of LogitBoost iterations performed
   */
  public int getNumRegressions() {
    return m_numRegressions;
  }

  /**
   * Get the value of weightTrimBeta.
   *
   * @return Value of weightTrimBeta.
   */
  public double getWeightTrimBeta() {
    return m_weightTrimBeta;
  }

  /**
   * Get the value of useAIC.
   *
   * @return Value of useAIC.
   */
  public boolean getUseAIC() {
    return m_useAIC;
  }

  /**
   * Sets the parameter "maxIterations".
   *
   * @param maxIterations the maximum number of iterations
   */
  public void setMaxIterations(int maxIterations) {
    m_maxIterations = maxIterations;
  }

  /**
   * Sets the option "heuristicStop".
   *
   * @param heuristicStop the heuristic stop to use
   */
  public void setHeuristicStop(int heuristicStop) {
    m_heuristicStop = heuristicStop;
  }

  /**
   * Sets the option "weightTrimBeta".
   *
   * @param w the weight trimming threshold to use
   */
  public void setWeightTrimBeta(double w) {
    m_weightTrimBeta = w;
  }

  /**
   * Set the value of useAIC.
   *
   * @param c Value to assign to useAIC.
   */
  public void setUseAIC(boolean c) {
    m_useAIC = c;
  }

  /**
   * Returns the maxIterations parameter.
   *
   * @return the maximum number of iterations
   */
  public int getMaxIterations() {
    return m_maxIterations;
  }

  /**
   * Returns an array holding the coefficients of the logistic model.
   * The first dimension is the class, the second one holds a list of coefficients.
   * At position zero, the constant term of the model is stored, followed by the coefficients for
   * the attributes in ascending order.
   * @return the array of coefficients
   */
  protected double[][] getCoefficients() {
    double[][] coefficients = new double[m_numClasses][m_numericDataHeader.numAttributes() + 1];
    for (int j = 0; j < m_numClasses; j++) {
      //go through the simple regression functions and add their coefficient to the coefficient of
      //the attribute they are built on.
      for (int i = 0; i < m_numRegressions; i++) {

        double slope = m_regressions[j][i].getSlope();
        double intercept = m_regressions[j][i].getIntercept();
        int attribute = m_regressions[j][i].getAttributeIndex();

        coefficients[j][0] += intercept;
        coefficients[j][attribute + 1] += slope;
      }
    }

    // Need to multiply all coefficients by (J-1) / J
    for (int j = 0; j < coefficients.length; j++) {
      for (int i = 0; i < coefficients[0].length; i++) {
        coefficients[j][i] *= (double)(m_numClasses - 1) / (double)m_numClasses;
      }
    }

    return coefficients;
  }
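
  /*
   * Because every simple regression function contributes an intercept and a
   * slope for exactly one attribute, summing them per class collapses the
   * boosted ensemble into a single linear function
   *   F_j(x) = b_j0 + sum_i b_ji * x_i
   * per class. The final (J-1)/J factor matches the scaling applied to the
   * predictions when the F-values are updated in performIteration() and
   * getFs().
   */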

  /**
   * Returns the fraction of all attributes in the data that are used in the
   * logistic model (in percent).
   * An attribute is used in the model if it is used in any of the models for
   * the different classes.
   *
   * @return the fraction of all attributes that are used
   */
  public double percentAttributesUsed() {
    boolean[] attributes = new boolean[m_numericDataHeader.numAttributes()];

    double[][] coefficients = getCoefficients();
    for (int j = 0; j < m_numClasses; j++) {
      for (int i = 1; i < m_numericDataHeader.numAttributes() + 1; i++) {
        //attribute used if it is used in any class, note coefficients are shifted by one (because
        //of the constant term).
        if (!Utils.eq(coefficients[j][i], 0)) attributes[i - 1] = true;
      }
    }

    //count number of used attributes (without the class attribute)
    double count = 0;
    for (int i = 0; i < attributes.length; i++) if (attributes[i]) count++;
    return count / (double)(m_numericDataHeader.numAttributes() - 1) * 100.0;
  }

  /**
   * Returns a description of the logistic model (i.e., attributes and
   * coefficients).
   *
   * @return the description of the model
   */
  public String toString() {

    StringBuffer s = new StringBuffer();

    //get used attributes
    int[][] attributes = getUsedAttributes();

    //get coefficients
    double[][] coefficients = getCoefficients();

    for (int j = 0; j < m_numClasses; j++) {
      s.append("\nClass " + j + " :\n");
      //constant term
      s.append(Utils.doubleToString(coefficients[j][0], 4, 2) + " + \n");
      for (int i = 0; i < attributes[j].length; i++) {
        //attribute/coefficient pairs
        s.append("[" + m_numericDataHeader.attribute(attributes[j][i]).name() + "]");
        s.append(" * " + Utils.doubleToString(coefficients[j][attributes[j][i] + 1], 4, 2));
        if (i != attributes[j].length - 1) s.append(" +");
        s.append("\n");
      }
    }
    return s.toString();
  }

  /**
   * Returns class probabilities for an instance.
   *
   * @param instance the instance to compute the distribution for
   * @return the class probabilities
   * @throws Exception if the distribution can't be computed successfully
   */
  public double[] distributionForInstance(Instance instance) throws Exception {

    instance = (Instance)instance.copy();

    //set to numeric pseudo-class
    instance.setDataset(m_numericDataHeader);

    //calculate probs via Fs
    return probs(getFs(instance));
  }

  /**
   * Cleanup in order to save memory.
   */
  public void cleanup() {
    //save just the header info
    m_train = new Instances(m_train, 0);
    m_numericData = null;
  }

  /**
   * Returns the revision string.
   *
   * @return the revision
   */
  public String getRevision() {
    return RevisionUtils.extract("$Revision: 5928 $");
  }
}