1 | /* |
---|
2 | * This program is free software; you can redistribute it and/or modify |
---|
3 | * it under the terms of the GNU General Public License as published by |
---|
4 | * the Free Software Foundation; either version 2 of the License, or |
---|
5 | * (at your option) any later version. |
---|
6 | * |
---|
7 | * This program is distributed in the hope that it will be useful, |
---|
8 | * but WITHOUT ANY WARRANTY; without even the implied warranty of |
---|
9 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the |
---|
10 | * GNU General Public License for more details. |
---|
11 | * |
---|
12 | * You should have received a copy of the GNU General Public License |
---|
13 | * along with this program; if not, write to the Free Software |
---|
14 | * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA. |
---|
15 | */ |
---|
16 | |
---|
17 | /* |
---|
18 | * LinearRegression.java |
---|
19 | * Copyright (C) 1999 University of Waikato, Hamilton, New Zealand |
---|
20 | * |
---|
21 | */ |
---|
22 | |
---|
23 | package weka.classifiers.functions; |
---|
24 | |
---|
25 | import weka.classifiers.Classifier; |
---|
26 | import weka.classifiers.AbstractClassifier; |
---|
27 | import weka.core.Capabilities; |
---|
28 | import weka.core.Instance; |
---|
29 | import weka.core.Instances; |
---|
30 | import weka.core.Matrix; |
---|
31 | import weka.core.Option; |
---|
32 | import weka.core.OptionHandler; |
---|
33 | import weka.core.RevisionUtils; |
---|
34 | import weka.core.SelectedTag; |
---|
35 | import weka.core.Tag; |
---|
36 | import weka.core.Utils; |
---|
37 | import weka.core.WeightedInstancesHandler; |
---|
38 | import weka.core.Capabilities.Capability; |
---|
39 | import weka.filters.Filter; |
---|
40 | import weka.filters.supervised.attribute.NominalToBinary; |
---|
41 | import weka.filters.unsupervised.attribute.ReplaceMissingValues; |
---|
42 | |
---|
43 | import java.util.Enumeration; |
---|
44 | import java.util.Vector; |
---|
45 | |
---|
46 | /** |
---|
47 | <!-- globalinfo-start --> |
---|
48 | * Class for using linear regression for prediction. Uses the Akaike criterion for model selection, and is able to deal with weighted instances. |
---|
49 | * <p/> |
---|
50 | <!-- globalinfo-end --> |
---|
51 | * |
---|
52 | <!-- options-start --> |
---|
53 | * Valid options are: <p/> |
---|
54 | * |
---|
55 | * <pre> -D |
---|
56 | * Produce debugging output. |
---|
57 | * (default no debugging output)</pre> |
---|
58 | * |
---|
59 | * <pre> -S <number of selection method> |
---|
60 | * Set the attribute selection method to use. 1 = None, 2 = Greedy. |
---|
61 | * (default 0 = M5' method)</pre> |
---|
62 | * |
---|
63 | * <pre> -C |
---|
64 | * Do not try to eliminate colinear attributes. |
---|
65 | * </pre> |
---|
66 | * |
---|
67 | * <pre> -R <double> |
---|
68 | * Set ridge parameter (default 1.0e-8). |
---|
69 | * </pre> |
---|
70 | * |
---|
71 | <!-- options-end --> |
---|
72 | * |
---|
73 | * @author Eibe Frank (eibe@cs.waikato.ac.nz) |
---|
74 | * @author Len Trigg (trigg@cs.waikato.ac.nz) |
---|
75 | * @version $Revision: 5928 $ |
---|
76 | */ |
---|
77 | public class LinearRegression extends AbstractClassifier implements OptionHandler, |
---|
78 | WeightedInstancesHandler { |
---|
79 | |
---|
  /** for serialization */
  static final long serialVersionUID = -3364580862046573747L;

  /** Coefficients of the fitted model, one per selected attribute (in
      attribute order), followed by the intercept in the last slot. */
  private double[] m_Coefficients;

  /** Flags indicating which attributes are part of the current model;
      the class attribute's entry is always false. */
  private boolean[] m_SelectedAttributes;

  /** The transformed training data; replaced by an empty copy (header only)
      after training to save memory. */
  private Instances m_TransformedData;

  /** The filter for removing missing values (null when checks are off). */
  private ReplaceMissingValues m_MissingFilter;

  /** The filter storing the transformation from nominal to
      binary attributes (null when checks are off). */
  private NominalToBinary m_TransformFilter;

  /** The standard deviation of the class attribute */
  private double m_ClassStdDev;

  /** The mean of the class attribute */
  private double m_ClassMean;

  /** The index of the class attribute in the transformed data */
  private int m_ClassIndex;

  /** The attribute means (the class attribute's entry is unused) */
  private double[] m_Means;

  /** The attribute standard deviations (the class attribute's entry is unused) */
  private double[] m_StdDevs;

  /** True if debug output will be printed */
  private boolean b_Debug;

  /** The current attribute selection method (one of the SELECTION_* IDs) */
  private int m_AttributeSelection;

  /** Attribute selection method: M5 method */
  public static final int SELECTION_M5 = 0;
  /** Attribute selection method: No attribute selection */
  public static final int SELECTION_NONE = 1;
  /** Attribute selection method: Greedy method */
  public static final int SELECTION_GREEDY = 2;
  /** Attribute selection methods */
  public static final Tag [] TAGS_SELECTION = {
    new Tag(SELECTION_NONE, "No attribute selection"),
    new Tag(SELECTION_M5, "M5 method"),
    new Tag(SELECTION_GREEDY, "Greedy method")
  };

  /** Try to eliminate correlated (colinear) attributes? */
  private boolean m_EliminateColinearAttributes = true;

  /** Turn off all checks and conversions? Also disables input scaling. */
  private boolean m_checksTurnedOff = false;

  /** The ridge parameter used when computing the regression */
  private double m_Ridge = 1.0e-8;
---|
142 | /** |
---|
143 | * Turns off checks for missing values, etc. Use with caution. |
---|
144 | * Also turns off scaling. |
---|
145 | */ |
---|
146 | public void turnChecksOff() { |
---|
147 | |
---|
148 | m_checksTurnedOff = true; |
---|
149 | } |
---|
150 | |
---|
151 | /** |
---|
152 | * Turns on checks for missing values, etc. Also turns |
---|
153 | * on scaling. |
---|
154 | */ |
---|
155 | public void turnChecksOn() { |
---|
156 | |
---|
157 | m_checksTurnedOff = false; |
---|
158 | } |
---|
159 | |
---|
160 | /** |
---|
161 | * Returns a string describing this classifier |
---|
162 | * @return a description of the classifier suitable for |
---|
163 | * displaying in the explorer/experimenter gui |
---|
164 | */ |
---|
165 | public String globalInfo() { |
---|
166 | return "Class for using linear regression for prediction. Uses the Akaike " |
---|
167 | +"criterion for model selection, and is able to deal with weighted " |
---|
168 | +"instances."; |
---|
169 | } |
---|
170 | |
---|
171 | /** |
---|
172 | * Returns default capabilities of the classifier. |
---|
173 | * |
---|
174 | * @return the capabilities of this classifier |
---|
175 | */ |
---|
176 | public Capabilities getCapabilities() { |
---|
177 | Capabilities result = super.getCapabilities(); |
---|
178 | result.disableAll(); |
---|
179 | |
---|
180 | // attributes |
---|
181 | result.enable(Capability.NOMINAL_ATTRIBUTES); |
---|
182 | result.enable(Capability.NUMERIC_ATTRIBUTES); |
---|
183 | result.enable(Capability.DATE_ATTRIBUTES); |
---|
184 | result.enable(Capability.MISSING_VALUES); |
---|
185 | |
---|
186 | // class |
---|
187 | result.enable(Capability.NUMERIC_CLASS); |
---|
188 | result.enable(Capability.DATE_CLASS); |
---|
189 | result.enable(Capability.MISSING_CLASS_VALUES); |
---|
190 | |
---|
191 | return result; |
---|
192 | } |
---|
193 | |
---|
  /**
   * Builds a regression model for the given data.
   *
   * @param data the training data to be used for generating the
   * linear regression function
   * @throws Exception if the classifier could not be built successfully
   */
  public void buildClassifier(Instances data) throws Exception {

    if (!m_checksTurnedOff) {
      // can classifier handle the data?
      getCapabilities().testWithFail(data);

      // remove instances with missing class (work on a copy so the
      // caller's dataset is not modified)
      data = new Instances(data);
      data.deleteWithMissingClass();
    }

    // Preprocess instances
    if (!m_checksTurnedOff) {
      // Binarise nominal attributes first, then fill in missing values;
      // the same filters are reapplied in classifyInstance()
      m_TransformFilter = new NominalToBinary();
      m_TransformFilter.setInputFormat(data);
      data = Filter.useFilter(data, m_TransformFilter);
      m_MissingFilter = new ReplaceMissingValues();
      m_MissingFilter.setInputFormat(data);
      data = Filter.useFilter(data, m_MissingFilter);
      data.deleteWithMissingClass();
    } else {
      // Checks are off: classifyInstance() will also skip filtering
      m_TransformFilter = null;
      m_MissingFilter = null;
    }

    m_ClassIndex = data.classIndex();
    m_TransformedData = data;

    // Turn all attributes on for a start
    m_SelectedAttributes = new boolean[data.numAttributes()];
    for (int i = 0; i < data.numAttributes(); i++) {
      if (i != m_ClassIndex) {
        m_SelectedAttributes[i] = true;
      }
    }
    m_Coefficients = null;

    // Compute means and standard deviations
    m_Means = new double[data.numAttributes()];
    m_StdDevs = new double[data.numAttributes()];
    for (int j = 0; j < data.numAttributes(); j++) {
      if (j != data.classIndex()) {
        m_Means[j] = data.meanOrMode(j);
        m_StdDevs[j] = Math.sqrt(data.variance(j));
        // A constant attribute carries no information (and would cause a
        // division by zero when scaling), so deselect it up front
        if (m_StdDevs[j] == 0) {
          m_SelectedAttributes[j] = false;
        }
      }
    }

    m_ClassStdDev = Math.sqrt(data.variance(m_TransformedData.classIndex()));
    m_ClassMean = data.meanOrMode(m_TransformedData.classIndex());

    // Perform the regression
    findBestModel();

    // Save memory: keep only an empty header of the transformed data
    m_TransformedData = new Instances(data, 0);
  }
---|
260 | |
---|
261 | /** |
---|
262 | * Classifies the given instance using the linear regression function. |
---|
263 | * |
---|
264 | * @param instance the test instance |
---|
265 | * @return the classification |
---|
266 | * @throws Exception if classification can't be done successfully |
---|
267 | */ |
---|
268 | public double classifyInstance(Instance instance) throws Exception { |
---|
269 | |
---|
270 | // Transform the input instance |
---|
271 | Instance transformedInstance = instance; |
---|
272 | if (!m_checksTurnedOff) { |
---|
273 | m_TransformFilter.input(transformedInstance); |
---|
274 | m_TransformFilter.batchFinished(); |
---|
275 | transformedInstance = m_TransformFilter.output(); |
---|
276 | m_MissingFilter.input(transformedInstance); |
---|
277 | m_MissingFilter.batchFinished(); |
---|
278 | transformedInstance = m_MissingFilter.output(); |
---|
279 | } |
---|
280 | |
---|
281 | // Calculate the dependent variable from the regression model |
---|
282 | return regressionPrediction(transformedInstance, |
---|
283 | m_SelectedAttributes, |
---|
284 | m_Coefficients); |
---|
285 | } |
---|
286 | |
---|
287 | /** |
---|
288 | * Outputs the linear regression model as a string. |
---|
289 | * |
---|
290 | * @return the model as string |
---|
291 | */ |
---|
292 | public String toString() { |
---|
293 | |
---|
294 | if (m_TransformedData == null) { |
---|
295 | return "Linear Regression: No model built yet."; |
---|
296 | } |
---|
297 | try { |
---|
298 | StringBuffer text = new StringBuffer(); |
---|
299 | int column = 0; |
---|
300 | boolean first = true; |
---|
301 | |
---|
302 | text.append("\nLinear Regression Model\n\n"); |
---|
303 | |
---|
304 | text.append(m_TransformedData.classAttribute().name()+" =\n\n"); |
---|
305 | for (int i = 0; i < m_TransformedData.numAttributes(); i++) { |
---|
306 | if ((i != m_ClassIndex) |
---|
307 | && (m_SelectedAttributes[i])) { |
---|
308 | if (!first) |
---|
309 | text.append(" +\n"); |
---|
310 | else |
---|
311 | first = false; |
---|
312 | text.append(Utils.doubleToString(m_Coefficients[column], 12, 4) |
---|
313 | + " * "); |
---|
314 | text.append(m_TransformedData.attribute(i).name()); |
---|
315 | column++; |
---|
316 | } |
---|
317 | } |
---|
318 | text.append(" +\n" + |
---|
319 | Utils.doubleToString(m_Coefficients[column], 12, 4)); |
---|
320 | return text.toString(); |
---|
321 | } catch (Exception e) { |
---|
322 | return "Can't print Linear Regression!"; |
---|
323 | } |
---|
324 | } |
---|
325 | |
---|
326 | /** |
---|
327 | * Returns an enumeration describing the available options. |
---|
328 | * |
---|
329 | * @return an enumeration of all the available options. |
---|
330 | */ |
---|
331 | public Enumeration listOptions() { |
---|
332 | |
---|
333 | Vector newVector = new Vector(4); |
---|
334 | newVector.addElement(new Option("\tProduce debugging output.\n" |
---|
335 | + "\t(default no debugging output)", |
---|
336 | "D", 0, "-D")); |
---|
337 | newVector.addElement(new Option("\tSet the attribute selection method" |
---|
338 | + " to use. 1 = None, 2 = Greedy.\n" |
---|
339 | + "\t(default 0 = M5' method)", |
---|
340 | "S", 1, "-S <number of selection method>")); |
---|
341 | newVector.addElement(new Option("\tDo not try to eliminate colinear" |
---|
342 | + " attributes.\n", |
---|
343 | "C", 0, "-C")); |
---|
344 | newVector.addElement(new Option("\tSet ridge parameter (default 1.0e-8).\n", |
---|
345 | "R", 1, "-R <double>")); |
---|
346 | return newVector.elements(); |
---|
347 | } |
---|
348 | |
---|
349 | /** |
---|
350 | * Parses a given list of options. <p/> |
---|
351 | * |
---|
352 | <!-- options-start --> |
---|
353 | * Valid options are: <p/> |
---|
354 | * |
---|
355 | * <pre> -D |
---|
356 | * Produce debugging output. |
---|
357 | * (default no debugging output)</pre> |
---|
358 | * |
---|
359 | * <pre> -S <number of selection method> |
---|
360 | * Set the attribute selection method to use. 1 = None, 2 = Greedy. |
---|
361 | * (default 0 = M5' method)</pre> |
---|
362 | * |
---|
363 | * <pre> -C |
---|
364 | * Do not try to eliminate colinear attributes. |
---|
365 | * </pre> |
---|
366 | * |
---|
367 | * <pre> -R <double> |
---|
368 | * Set ridge parameter (default 1.0e-8). |
---|
369 | * </pre> |
---|
370 | * |
---|
371 | <!-- options-end --> |
---|
372 | * |
---|
373 | * @param options the list of options as an array of strings |
---|
374 | * @throws Exception if an option is not supported |
---|
375 | */ |
---|
376 | public void setOptions(String[] options) throws Exception { |
---|
377 | |
---|
378 | String selectionString = Utils.getOption('S', options); |
---|
379 | if (selectionString.length() != 0) { |
---|
380 | setAttributeSelectionMethod(new SelectedTag(Integer |
---|
381 | .parseInt(selectionString), |
---|
382 | TAGS_SELECTION)); |
---|
383 | } else { |
---|
384 | setAttributeSelectionMethod(new SelectedTag(SELECTION_M5, |
---|
385 | TAGS_SELECTION)); |
---|
386 | } |
---|
387 | String ridgeString = Utils.getOption('R', options); |
---|
388 | if (ridgeString.length() != 0) { |
---|
389 | setRidge(new Double(ridgeString).doubleValue()); |
---|
390 | } else { |
---|
391 | setRidge(1.0e-8); |
---|
392 | } |
---|
393 | setDebug(Utils.getFlag('D', options)); |
---|
394 | setEliminateColinearAttributes(!Utils.getFlag('C', options)); |
---|
395 | } |
---|
396 | |
---|
397 | /** |
---|
398 | * Returns the coefficients for this linear model. |
---|
399 | * |
---|
400 | * @return the coefficients for this linear model |
---|
401 | */ |
---|
402 | public double[] coefficients() { |
---|
403 | |
---|
404 | double[] coefficients = new double[m_SelectedAttributes.length + 1]; |
---|
405 | int counter = 0; |
---|
406 | for (int i = 0; i < m_SelectedAttributes.length; i++) { |
---|
407 | if ((m_SelectedAttributes[i]) && ((i != m_ClassIndex))) { |
---|
408 | coefficients[i] = m_Coefficients[counter++]; |
---|
409 | } |
---|
410 | } |
---|
411 | coefficients[m_SelectedAttributes.length] = m_Coefficients[counter]; |
---|
412 | return coefficients; |
---|
413 | } |
---|
414 | |
---|
415 | /** |
---|
416 | * Gets the current settings of the classifier. |
---|
417 | * |
---|
418 | * @return an array of strings suitable for passing to setOptions |
---|
419 | */ |
---|
420 | public String [] getOptions() { |
---|
421 | |
---|
422 | String [] options = new String [6]; |
---|
423 | int current = 0; |
---|
424 | |
---|
425 | options[current++] = "-S"; |
---|
426 | options[current++] = "" + getAttributeSelectionMethod() |
---|
427 | .getSelectedTag().getID(); |
---|
428 | if (getDebug()) { |
---|
429 | options[current++] = "-D"; |
---|
430 | } |
---|
431 | if (!getEliminateColinearAttributes()) { |
---|
432 | options[current++] = "-C"; |
---|
433 | } |
---|
434 | options[current++] = "-R"; |
---|
435 | options[current++] = "" + getRidge(); |
---|
436 | |
---|
437 | while (current < options.length) { |
---|
438 | options[current++] = ""; |
---|
439 | } |
---|
440 | return options; |
---|
441 | } |
---|
442 | |
---|
443 | /** |
---|
444 | * Returns the tip text for this property |
---|
445 | * @return tip text for this property suitable for |
---|
446 | * displaying in the explorer/experimenter gui |
---|
447 | */ |
---|
448 | public String ridgeTipText() { |
---|
449 | return "The value of the Ridge parameter."; |
---|
450 | } |
---|
451 | |
---|
452 | /** |
---|
453 | * Get the value of Ridge. |
---|
454 | * |
---|
455 | * @return Value of Ridge. |
---|
456 | */ |
---|
457 | public double getRidge() { |
---|
458 | |
---|
459 | return m_Ridge; |
---|
460 | } |
---|
461 | |
---|
462 | /** |
---|
463 | * Set the value of Ridge. |
---|
464 | * |
---|
465 | * @param newRidge Value to assign to Ridge. |
---|
466 | */ |
---|
467 | public void setRidge(double newRidge) { |
---|
468 | |
---|
469 | m_Ridge = newRidge; |
---|
470 | } |
---|
471 | |
---|
472 | /** |
---|
473 | * Returns the tip text for this property |
---|
474 | * @return tip text for this property suitable for |
---|
475 | * displaying in the explorer/experimenter gui |
---|
476 | */ |
---|
477 | public String eliminateColinearAttributesTipText() { |
---|
478 | return "Eliminate colinear attributes."; |
---|
479 | } |
---|
480 | |
---|
481 | /** |
---|
482 | * Get the value of EliminateColinearAttributes. |
---|
483 | * |
---|
484 | * @return Value of EliminateColinearAttributes. |
---|
485 | */ |
---|
486 | public boolean getEliminateColinearAttributes() { |
---|
487 | |
---|
488 | return m_EliminateColinearAttributes; |
---|
489 | } |
---|
490 | |
---|
491 | /** |
---|
492 | * Set the value of EliminateColinearAttributes. |
---|
493 | * |
---|
494 | * @param newEliminateColinearAttributes Value to assign to EliminateColinearAttributes. |
---|
495 | */ |
---|
496 | public void setEliminateColinearAttributes(boolean newEliminateColinearAttributes) { |
---|
497 | |
---|
498 | m_EliminateColinearAttributes = newEliminateColinearAttributes; |
---|
499 | } |
---|
500 | |
---|
501 | /** |
---|
502 | * Get the number of coefficients used in the model |
---|
503 | * |
---|
504 | * @return the number of coefficients |
---|
505 | */ |
---|
506 | public int numParameters() |
---|
507 | { |
---|
508 | return m_Coefficients.length-1; |
---|
509 | } |
---|
510 | |
---|
511 | /** |
---|
512 | * Returns the tip text for this property |
---|
513 | * @return tip text for this property suitable for |
---|
514 | * displaying in the explorer/experimenter gui |
---|
515 | */ |
---|
516 | public String attributeSelectionMethodTipText() { |
---|
517 | return "Set the method used to select attributes for use in the linear " |
---|
518 | +"regression. Available methods are: no attribute selection, attribute " |
---|
519 | +"selection using M5's method (step through the attributes removing the one " |
---|
520 | +"with the smallest standardised coefficient until no improvement is observed " |
---|
521 | +"in the estimate of the error given by the Akaike " |
---|
522 | +"information criterion), and a greedy selection using the Akaike information " |
---|
523 | +"metric."; |
---|
524 | } |
---|
525 | |
---|
526 | /** |
---|
527 | * Sets the method used to select attributes for use in the |
---|
528 | * linear regression. |
---|
529 | * |
---|
530 | * @param method the attribute selection method to use. |
---|
531 | */ |
---|
532 | public void setAttributeSelectionMethod(SelectedTag method) { |
---|
533 | |
---|
534 | if (method.getTags() == TAGS_SELECTION) { |
---|
535 | m_AttributeSelection = method.getSelectedTag().getID(); |
---|
536 | } |
---|
537 | } |
---|
538 | |
---|
539 | /** |
---|
540 | * Gets the method used to select attributes for use in the |
---|
541 | * linear regression. |
---|
542 | * |
---|
543 | * @return the method to use. |
---|
544 | */ |
---|
545 | public SelectedTag getAttributeSelectionMethod() { |
---|
546 | |
---|
547 | return new SelectedTag(m_AttributeSelection, TAGS_SELECTION); |
---|
548 | } |
---|
549 | |
---|
550 | /** |
---|
551 | * Returns the tip text for this property |
---|
552 | * @return tip text for this property suitable for |
---|
553 | * displaying in the explorer/experimenter gui |
---|
554 | */ |
---|
555 | public String debugTipText() { |
---|
556 | return "Outputs debug information to the console."; |
---|
557 | } |
---|
558 | |
---|
559 | /** |
---|
560 | * Controls whether debugging output will be printed |
---|
561 | * |
---|
562 | * @param debug true if debugging output should be printed |
---|
563 | */ |
---|
564 | public void setDebug(boolean debug) { |
---|
565 | |
---|
566 | b_Debug = debug; |
---|
567 | } |
---|
568 | |
---|
569 | /** |
---|
570 | * Controls whether debugging output will be printed |
---|
571 | * |
---|
572 | * @return true if debugging output is printed |
---|
573 | */ |
---|
574 | public boolean getDebug() { |
---|
575 | |
---|
576 | return b_Debug; |
---|
577 | } |
---|
578 | |
---|
579 | /** |
---|
580 | * Removes the attribute with the highest standardised coefficient |
---|
581 | * greater than 1.5 from the selected attributes. |
---|
582 | * |
---|
583 | * @param selectedAttributes an array of flags indicating which |
---|
584 | * attributes are included in the regression model |
---|
585 | * @param coefficients an array of coefficients for the regression |
---|
586 | * model |
---|
587 | * @return true if an attribute was removed |
---|
588 | */ |
---|
589 | private boolean deselectColinearAttributes(boolean [] selectedAttributes, |
---|
590 | double [] coefficients) { |
---|
591 | |
---|
592 | double maxSC = 1.5; |
---|
593 | int maxAttr = -1, coeff = 0; |
---|
594 | for (int i = 0; i < selectedAttributes.length; i++) { |
---|
595 | if (selectedAttributes[i]) { |
---|
596 | double SC = Math.abs(coefficients[coeff] * m_StdDevs[i] |
---|
597 | / m_ClassStdDev); |
---|
598 | if (SC > maxSC) { |
---|
599 | maxSC = SC; |
---|
600 | maxAttr = i; |
---|
601 | } |
---|
602 | coeff++; |
---|
603 | } |
---|
604 | } |
---|
605 | if (maxAttr >= 0) { |
---|
606 | selectedAttributes[maxAttr] = false; |
---|
607 | if (b_Debug) { |
---|
608 | System.out.println("Deselected colinear attribute:" + (maxAttr + 1) |
---|
609 | + " with standardised coefficient: " + maxSC); |
---|
610 | } |
---|
611 | return true; |
---|
612 | } |
---|
613 | return false; |
---|
614 | } |
---|
615 | |
---|
  /**
   * Performs a greedy search for the best regression model using
   * Akaike's criterion.
   *
   * @throws Exception if regression can't be done
   */
  private void findBestModel() throws Exception {

    // For the weighted case we still use numInstances in
    // the calculation of the Akaike criterion.
    int numInstances = m_TransformedData.numInstances();

    if (b_Debug) {
      System.out.println((new Instances(m_TransformedData, 0)).toString());
    }

    // Perform a regression for the full model, and remove colinear attributes
    // (every removal requires refitting the remaining coefficients)
    do {
      m_Coefficients = doRegression(m_SelectedAttributes);
    } while (m_EliminateColinearAttributes &&
             deselectColinearAttributes(m_SelectedAttributes, m_Coefficients));

    // Figure out current number of attributes + 1. (We treat this model
    // as the full model for the Akaike-based methods.)
    int numAttributes = 1;
    for (int i = 0; i < m_SelectedAttributes.length; i++) {
      if (m_SelectedAttributes[i]) {
        numAttributes++;
      }
    }

    double fullMSE = calculateSE(m_SelectedAttributes, m_Coefficients);
    // Full model's score: MSE ratio is 1, so (n - p) + 2p
    double akaike = (numInstances - numAttributes) + 2 * numAttributes;
    if (b_Debug) {
      System.out.println("Initial Akaike value: " + akaike);
    }

    boolean improved;
    int currentNumAttributes = numAttributes;
    switch (m_AttributeSelection) {

    case SELECTION_GREEDY:

      // Greedy attribute removal: each pass tentatively drops every
      // remaining attribute in turn and keeps any drop that lowers the
      // Akaike score; stops when a full pass brings no improvement
      do {
        boolean [] currentSelected = (boolean []) m_SelectedAttributes.clone();
        improved = false;
        currentNumAttributes--;

        for (int i = 0; i < m_SelectedAttributes.length; i++) {
          if (currentSelected[i]) {

            // Calculate the akaike rating without this attribute
            currentSelected[i] = false;
            double [] currentCoeffs = doRegression(currentSelected);
            double currentMSE = calculateSE(currentSelected, currentCoeffs);
            double currentAkaike = currentMSE / fullMSE
              * (numInstances - numAttributes)
              + 2 * currentNumAttributes;
            if (b_Debug) {
              System.out.println("(akaike: " + currentAkaike);
            }

            // If it is better than the current best
            if (currentAkaike < akaike) {
              if (b_Debug) {
                System.err.println("Removing attribute " + (i + 1)
                                   + " improved Akaike: " + currentAkaike);
              }
              improved = true;
              akaike = currentAkaike;
              System.arraycopy(currentSelected, 0,
                               m_SelectedAttributes, 0,
                               m_SelectedAttributes.length);
              m_Coefficients = currentCoeffs;
            }
            // Restore the flag before trying the next candidate
            currentSelected[i] = true;
          }
        }
      } while (improved);
      break;

    case SELECTION_M5:

      // Step through the attributes removing the one with the smallest
      // standardised coefficient until no improvement in Akaike
      do {
        improved = false;
        currentNumAttributes--;

        // Find attribute with smallest SC
        double minSC = 0;
        int minAttr = -1, coeff = 0;
        for (int i = 0; i < m_SelectedAttributes.length; i++) {
          if (m_SelectedAttributes[i]) {
            double SC = Math.abs(m_Coefficients[coeff] * m_StdDevs[i]
                                 / m_ClassStdDev);
            // coeff == 0 seeds minSC with the first selected attribute
            if ((coeff == 0) || (SC < minSC)) {
              minSC = SC;
              minAttr = i;
            }
            coeff++;
          }
        }

        // See whether removing it improves the Akaike score
        if (minAttr >= 0) {
          m_SelectedAttributes[minAttr] = false;
          double [] currentCoeffs = doRegression(m_SelectedAttributes);
          double currentMSE = calculateSE(m_SelectedAttributes, currentCoeffs);
          double currentAkaike = currentMSE / fullMSE
            * (numInstances - numAttributes)
            + 2 * currentNumAttributes;
          if (b_Debug) {
            System.out.println("(akaike: " + currentAkaike);
          }

          // If it is better than the current best
          if (currentAkaike < akaike) {
            if (b_Debug) {
              System.err.println("Removing attribute " + (minAttr + 1)
                                 + " improved Akaike: " + currentAkaike);
            }
            improved = true;
            akaike = currentAkaike;
            m_Coefficients = currentCoeffs;
          } else {
            // No improvement: reinstate the attribute and stop
            m_SelectedAttributes[minAttr] = true;
          }
        }
      } while (improved);
      break;

    case SELECTION_NONE:
      break;
    }
  }
---|
753 | |
---|
754 | /** |
---|
755 | * Calculate the squared error of a regression model on the |
---|
756 | * training data |
---|
757 | * |
---|
758 | * @param selectedAttributes an array of flags indicating which |
---|
759 | * attributes are included in the regression model |
---|
760 | * @param coefficients an array of coefficients for the regression |
---|
761 | * model |
---|
762 | * @return the mean squared error on the training data |
---|
763 | * @throws Exception if there is a missing class value in the training |
---|
764 | * data |
---|
765 | */ |
---|
766 | private double calculateSE(boolean [] selectedAttributes, |
---|
767 | double [] coefficients) throws Exception { |
---|
768 | |
---|
769 | double mse = 0; |
---|
770 | for (int i = 0; i < m_TransformedData.numInstances(); i++) { |
---|
771 | double prediction = regressionPrediction(m_TransformedData.instance(i), |
---|
772 | selectedAttributes, |
---|
773 | coefficients); |
---|
774 | double error = prediction - m_TransformedData.instance(i).classValue(); |
---|
775 | mse += error * error; |
---|
776 | } |
---|
777 | return mse; |
---|
778 | } |
---|
779 | |
---|
780 | /** |
---|
781 | * Calculate the dependent value for a given instance for a |
---|
782 | * given regression model. |
---|
783 | * |
---|
784 | * @param transformedInstance the input instance |
---|
785 | * @param selectedAttributes an array of flags indicating which |
---|
786 | * attributes are included in the regression model |
---|
787 | * @param coefficients an array of coefficients for the regression |
---|
788 | * model |
---|
789 | * @return the regression value for the instance. |
---|
790 | * @throws Exception if the class attribute of the input instance |
---|
791 | * is not assigned |
---|
792 | */ |
---|
793 | private double regressionPrediction(Instance transformedInstance, |
---|
794 | boolean [] selectedAttributes, |
---|
795 | double [] coefficients) |
---|
796 | throws Exception { |
---|
797 | |
---|
798 | double result = 0; |
---|
799 | int column = 0; |
---|
800 | for (int j = 0; j < transformedInstance.numAttributes(); j++) { |
---|
801 | if ((m_ClassIndex != j) |
---|
802 | && (selectedAttributes[j])) { |
---|
803 | result += coefficients[column] * transformedInstance.value(j); |
---|
804 | column++; |
---|
805 | } |
---|
806 | } |
---|
807 | result += coefficients[column]; |
---|
808 | |
---|
809 | return result; |
---|
810 | } |
---|
811 | |
---|
  /**
   * Calculate a linear regression using the selected attributes
   *
   * @param selectedAttributes an array of booleans where each element
   * is true if the corresponding attribute should be included in the
   * regression.
   * @return an array of coefficients for the linear regression model.
   * @throws Exception if an error occurred during the regression.
   */
  private double [] doRegression(boolean [] selectedAttributes)
    throws Exception {

    if (b_Debug) {
      System.out.print("doRegression(");
      for (int i = 0; i < selectedAttributes.length; i++) {
        System.out.print(" " + selectedAttributes[i]);
      }
      System.out.println(" )");
    }
    // Count the attributes taking part in the regression
    int numAttributes = 0;
    for (int i = 0; i < selectedAttributes.length; i++) {
      if (selectedAttributes[i]) {
        numAttributes++;
      }
    }

    // Check whether there are still attributes left
    Matrix independent = null, dependent = null;
    double[] weights = null;
    if (numAttributes > 0) {
      independent = new Matrix(m_TransformedData.numInstances(),
                               numAttributes);
      dependent = new Matrix(m_TransformedData.numInstances(), 1);
      for (int i = 0; i < m_TransformedData.numInstances(); i ++) {
        Instance inst = m_TransformedData.instance(i);
        int column = 0;
        for (int j = 0; j < m_TransformedData.numAttributes(); j++) {
          if (j == m_ClassIndex) {
            dependent.setElement(i, 0, inst.classValue());
          } else {
            if (selectedAttributes[j]) {
              // Centre each input on its mean
              double value = inst.value(j) - m_Means[j];

              // We only need to do this if we want to
              // scale the input
              if (!m_checksTurnedOff) {
                value /= m_StdDevs[j];
              }
              independent.setElement(i, column, value);
              column++;
            }
          }
        }
      }

      // Grab instance weights
      weights = new double [m_TransformedData.numInstances()];
      for (int i = 0; i < weights.length; i++) {
        weights[i] = m_TransformedData.instance(i).weight();
      }
    }

    // Compute coefficients (note that we have to treat the
    // intercept separately so that it doesn't get affected
    // by the ridge constant.)
    double[] coefficients = new double[numAttributes + 1];
    if (numAttributes > 0) {
      double[] coeffsWithoutIntercept =
        independent.regression(dependent, weights, m_Ridge);
      System.arraycopy(coeffsWithoutIntercept, 0, coefficients, 0,
                       numAttributes);
    }
    // With centred inputs the intercept starts at the class mean
    coefficients[numAttributes] = m_ClassMean;

    // Convert coefficients into original scale
    int column = 0;
    for(int i = 0; i < m_TransformedData.numAttributes(); i++) {
      if ((i != m_TransformedData.classIndex()) &&
          (selectedAttributes[i])) {

        // We only need to do this if we have scaled the
        // input.
        if (!m_checksTurnedOff) {
          coefficients[column] /= m_StdDevs[i];
        }

        // We have centred the input, so the intercept must absorb
        // each coefficient times its attribute mean
        coefficients[coefficients.length - 1] -=
          coefficients[column] * m_Means[i];
        column++;
      }
    }

    return coefficients;
  }
---|
907 | |
---|
908 | /** |
---|
909 | * Returns the revision string. |
---|
910 | * |
---|
911 | * @return the revision |
---|
912 | */ |
---|
913 | public String getRevision() { |
---|
914 | return RevisionUtils.extract("$Revision: 5928 $"); |
---|
915 | } |
---|
916 | |
---|
917 | /** |
---|
918 | * Generates a linear regression function predictor. |
---|
919 | * |
---|
920 | * @param argv the options |
---|
921 | */ |
---|
922 | public static void main(String argv[]) { |
---|
923 | runClassifier(new LinearRegression(), argv); |
---|
924 | } |
---|
925 | } |
---|