source: src/main/java/weka/classifiers/functions/SimpleLinearRegression.java @ 15

Last change on this file since 15 was 4, checked in by gnappo, 14 years ago

Import di weka.

File size: 8.0 KB
RevLine 
/*
 *    This program is free software; you can redistribute it and/or modify
 *    it under the terms of the GNU General Public License as published by
 *    the Free Software Foundation; either version 2 of the License, or
 *    (at your option) any later version.
 *
 *    This program is distributed in the hope that it will be useful,
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *    GNU General Public License for more details.
 *
 *    You should have received a copy of the GNU General Public License
 *    along with this program; if not, write to the Free Software
 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 */

/*
 *    SimpleLinearRegression.java
 *    Copyright (C) 2002 University of Waikato, Hamilton, New Zealand
 *
 */
22
23package weka.classifiers.functions;
24
25import weka.classifiers.Classifier;
26import weka.classifiers.AbstractClassifier;
27import weka.core.Attribute;
28import weka.core.Capabilities;
29import weka.core.Instance;
30import weka.core.Instances;
31import weka.core.RevisionUtils;
32import weka.core.Utils;
33import weka.core.WeightedInstancesHandler;
34import weka.core.Capabilities.Capability;
35
36/**
37 <!-- globalinfo-start -->
38 * Learns a simple linear regression model. Picks the attribute that results in the lowest squared error. Missing values are not allowed. Can only deal with numeric attributes.
39 * <p/>
40 <!-- globalinfo-end -->
41 *
42 <!-- options-start -->
43 * Valid options are: <p/>
44 *
45 * <pre> -D
46 *  If set, classifier is run in debug mode and
47 *  may output additional info to the console</pre>
48 *
49 <!-- options-end -->
50 *
51 * @author Eibe Frank (eibe@cs.waikato.ac.nz)
52 * @version $Revision: 5928 $
53 */
54public class SimpleLinearRegression extends AbstractClassifier
55  implements WeightedInstancesHandler {
56
57  /** for serialization */
58  static final long serialVersionUID = 1679336022895414137L;
59 
60  /** The chosen attribute */
61  private Attribute m_attribute;
62
63  /** The index of the chosen attribute */
64  private int m_attributeIndex;
65
66  /** The slope */
67  private double m_slope;
68 
69  /** The intercept */
70  private double m_intercept;
71
72  /** If true, suppress error message if no useful attribute was found*/   
73  private boolean m_suppressErrorMessage = false; 
74
75  /**
76   * Returns a string describing this classifier
77   * @return a description of the classifier suitable for
78   * displaying in the explorer/experimenter gui
79   */
80  public String globalInfo() {
81    return "Learns a simple linear regression model. "
82      +"Picks the attribute that results in the lowest squared error. "
83      +"Missing values are not allowed. Can only deal with numeric attributes.";
84  }
85
86  /**
87   * Generate a prediction for the supplied instance.
88   *
89   * @param inst the instance to predict.
90   * @return the prediction
91   * @throws Exception if an error occurs
92   */
93  public double classifyInstance(Instance inst) throws Exception {
94   
95    if (m_attribute == null) {
96      return m_intercept;
97    } else {
98      if (inst.isMissing(m_attribute.index())) {
99        throw new Exception("SimpleLinearRegression: No missing values!");
100      }
101      return m_intercept + m_slope * inst.value(m_attribute.index());
102    }
103  }
104
105  /**
106   * Returns default capabilities of the classifier.
107   *
108   * @return      the capabilities of this classifier
109   */
110  public Capabilities getCapabilities() {
111    Capabilities result = super.getCapabilities();
112    result.disableAll();
113
114    // attributes
115    result.enable(Capability.NUMERIC_ATTRIBUTES);
116    result.enable(Capability.DATE_ATTRIBUTES);
117
118    // class
119    result.enable(Capability.NUMERIC_CLASS);
120    result.enable(Capability.DATE_CLASS);
121    result.enable(Capability.MISSING_CLASS_VALUES);
122   
123    return result;
124  }
125 
126  /**
127   * Builds a simple linear regression model given the supplied training data.
128   *
129   * @param insts the training data.
130   * @throws Exception if an error occurs
131   */
132  public void buildClassifier(Instances insts) throws Exception {
133
134    // can classifier handle the data?
135    getCapabilities().testWithFail(insts);
136
137    // remove instances with missing class
138    insts = new Instances(insts);
139    insts.deleteWithMissingClass();
140   
141    // Compute mean of target value
142    double yMean = insts.meanOrMode(insts.classIndex());
143
144    // Choose best attribute
145    double minMsq = Double.MAX_VALUE;
146    m_attribute = null;
147    int chosen = -1;
148    double chosenSlope = Double.NaN;
149    double chosenIntercept = Double.NaN;
150    for (int i = 0; i < insts.numAttributes(); i++) {
151      if (i != insts.classIndex()) {
152        m_attribute = insts.attribute(i);
153       
154        // Compute slope and intercept
155        double xMean = insts.meanOrMode(i);
156        double sumWeightedXDiffSquared = 0;
157        double sumWeightedYDiffSquared = 0;
158        m_slope = 0;
159        for (int j = 0; j < insts.numInstances(); j++) {
160          Instance inst = insts.instance(j);
161          if (!inst.isMissing(i) && !inst.classIsMissing()) {
162            double xDiff = inst.value(i) - xMean;
163            double yDiff = inst.classValue() - yMean;
164            double weightedXDiff = inst.weight() * xDiff;
165            double weightedYDiff = inst.weight() * yDiff;
166            m_slope += weightedXDiff * yDiff;
167            sumWeightedXDiffSquared += weightedXDiff * xDiff;
168            sumWeightedYDiffSquared += weightedYDiff * yDiff;
169          }
170        }
171
172        // Skip attribute if not useful
173        if (sumWeightedXDiffSquared == 0) {
174          continue;
175        }
176        double numerator = m_slope;
177        m_slope /= sumWeightedXDiffSquared;
178        m_intercept = yMean - m_slope * xMean;
179
180        // Compute sum of squared errors
181        double msq = sumWeightedYDiffSquared - m_slope * numerator;
182
183        // Check whether this is the best attribute
184        if (msq < minMsq) {
185          minMsq = msq;
186          chosen = i;
187          chosenSlope = m_slope;
188          chosenIntercept = m_intercept;
189        }
190      }
191    }
192
193    // Set parameters
194    if (chosen == -1) {
195      if (!m_suppressErrorMessage) System.err.println("----- no useful attribute found");
196      m_attribute = null;
197      m_attributeIndex = 0;
198      m_slope = 0;
199      m_intercept = yMean;
200    } else {
201      m_attribute = insts.attribute(chosen);
202      m_attributeIndex = chosen;
203      m_slope = chosenSlope;
204      m_intercept = chosenIntercept;
205    }
206  }
207
208  /**
209   * Returns true if a usable attribute was found.
210   *
211   * @return true if a usable attribute was found.
212   */
213  public boolean foundUsefulAttribute(){
214      return (m_attribute != null); 
215  } 
216
217  /**
218   * Returns the index of the attribute used in the regression.
219   *
220   * @return the index of the attribute.
221   */
222  public int getAttributeIndex(){
223      return m_attributeIndex;
224  }
225
226  /**
227   * Returns the slope of the function.
228   *
229   * @return the slope.
230   */
231  public double getSlope(){
232      return m_slope;
233  }
234   
235  /**
236   * Returns the intercept of the function.
237   *
238   * @return the intercept.
239   */
240  public double getIntercept(){
241      return m_intercept;
242  } 
243
244  /**
245   * Turn off the error message that is reported when no useful attribute is found.
246   *
247   * @param s if set to true turns off the error message
248   */
249  public void setSuppressErrorMessage(boolean s){
250      m_suppressErrorMessage = s;
251  }   
252
253  /**
254   * Returns a description of this classifier as a string
255   *
256   * @return a description of the classifier.
257   */
258  public String toString() {
259
260    StringBuffer text = new StringBuffer();
261    if (m_attribute == null) {
262      text.append("Predicting constant " + m_intercept);
263    } else {
264      text.append("Linear regression on " + m_attribute.name() + "\n\n");
265      text.append(Utils.doubleToString(m_slope,2) + " * " + 
266                m_attribute.name());
267      if (m_intercept > 0) {
268        text.append(" + " + Utils.doubleToString(m_intercept, 2));
269      } else {
270      text.append(" - " + Utils.doubleToString((-m_intercept), 2)); 
271      }
272    }
273    text.append("\n");
274    return text.toString();
275  }
276 
277  /**
278   * Returns the revision string.
279   *
280   * @return            the revision
281   */
282  public String getRevision() {
283    return RevisionUtils.extract("$Revision: 5928 $");
284  }
285
286  /**
287   * Main method for testing this class
288   *
289   * @param argv options
290   */
291  public static void main(String [] argv){
292    runClassifier(new SimpleLinearRegression(), argv);
293  } 
294}
Note: See TracBrowser for help on using the repository browser.