source: src/main/java/weka/classifiers/functions/IsotonicRegression.java @ 18

Last change on this file since 18 was 4, checked in by gnappo, 14 years ago

Import di weka.

File size: 9.5 KB
RevLine 
/*
 *    This program is free software; you can redistribute it and/or modify
 *    it under the terms of the GNU General Public License as published by
 *    the Free Software Foundation; either version 2 of the License, or
 *    (at your option) any later version.
 *
 *    This program is distributed in the hope that it will be useful,
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *    GNU General Public License for more details.
 *
 *    You should have received a copy of the GNU General Public License
 *    along with this program; if not, write to the Free Software
 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 */

/*
 *    IsotonicRegression.java
 *    Copyright (C) 2006 University of Waikato, Hamilton, New Zealand
 *
 */
22
23package weka.classifiers.functions;
24
25import weka.classifiers.Classifier;
26import weka.classifiers.AbstractClassifier;
27import weka.classifiers.Evaluation;
28import weka.core.Attribute;
29import weka.core.Capabilities;
30import weka.core.Instance;
31import weka.core.Instances;
32import weka.core.RevisionUtils;
33import weka.core.Utils;
34import weka.core.WeightedInstancesHandler;
35import weka.core.Capabilities.Capability;
36
37import java.util.Arrays;
38
39/**
40 <!-- globalinfo-start -->
41 * Learns an isotonic regression model. Picks the attribute that results in the lowest squared error. Missing values are not allowed. Can only deal with numeric attributes.Considers the monotonically increasing case as well as the monotonicallydecreasing case
42 * <p/>
43 <!-- globalinfo-end -->
44 *
45 <!-- options-start -->
46 * Valid options are: <p/>
47 *
48 * <pre> -D
49 *  If set, classifier is run in debug mode and
50 *  may output additional info to the console</pre>
51 *
52 <!-- options-end -->
53 *
54 * @author Eibe Frank (eibe@cs.waikato.ac.nz)
55 * @version $Revision: 5928 $
56 */
57public class IsotonicRegression extends AbstractClassifier implements WeightedInstancesHandler {
58
59  /** for serialization */
60  static final long serialVersionUID = 1679336022835454137L;
61 
62  /** The chosen attribute */
63  private Attribute m_attribute;
64
65  /** The array of cut points */
66  private double[] m_cuts;
67 
68  /** The predicted value in each interval. */
69  private double[] m_values;
70
71  /** The minimum mean squared error that has been achieved. */
72  private double m_minMsq;
73
74  /** a ZeroR model in case no model can be built from the data */
75  private Classifier m_ZeroR;
76
77  /**
78   * Returns a string describing this classifier
79   * @return a description of the classifier suitable for
80   * displaying in the explorer/experimenter gui
81   */
82  public String globalInfo() {
83    return "Learns an isotonic regression model. "
84      +"Picks the attribute that results in the lowest squared error. "
85      +"Missing values are not allowed. Can only deal with numeric attributes."
86      +"Considers the monotonically increasing case as well as the monotonically"
87      +"decreasing case";
88  }
89
90  /**
91   * Generate a prediction for the supplied instance.
92   *
93   * @param inst the instance to predict.
94   * @return the prediction
95   * @throws Exception if an error occurs
96   */
97  public double classifyInstance(Instance inst) throws Exception {
98   
99    // default model?
100    if (m_ZeroR != null) {
101      return m_ZeroR.classifyInstance(inst);
102    }
103   
104    if (inst.isMissing(m_attribute.index())) {
105      throw new Exception("IsotonicRegression: No missing values!");
106    }
107    int index = Arrays.binarySearch(m_cuts, inst.value(m_attribute));
108    if (index < 0) {
109      return m_values[-index - 1];
110    } else { 
111      return m_values[index + 1];
112    }
113  }
114
115  /**
116   * Returns default capabilities of the classifier.
117   *
118   * @return      the capabilities of this classifier
119   */
120  public Capabilities getCapabilities() {
121    Capabilities result = super.getCapabilities();
122    result.disableAll();
123
124    // attributes
125    result.enable(Capability.NUMERIC_ATTRIBUTES);
126    result.enable(Capability.DATE_ATTRIBUTES);
127
128    // class
129    result.enable(Capability.NUMERIC_CLASS);
130    result.enable(Capability.DATE_CLASS);
131    result.enable(Capability.MISSING_CLASS_VALUES);
132   
133    return result;
134  }
135
136  /**
137   * Does the actual regression.
138   */
139  protected void regress(Attribute attribute, Instances insts, boolean ascending) 
140    throws Exception {
141
142    // Sort values according to current attribute
143    insts.sort(attribute);
144   
145    // Initialize arrays
146    double[] values = new double[insts.numInstances()];
147    double[] weights = new double[insts.numInstances()];
148    double[] cuts = new double[insts.numInstances() - 1];
149    int size = 0;
150    values[0] = insts.instance(0).classValue();
151    weights[0] = insts.instance(0).weight();
152    for (int i = 1; i < insts.numInstances(); i++) {
153      if (insts.instance(i).value(attribute) >
154          insts.instance(i - 1).value(attribute)) {
155        cuts[size] = (insts.instance(i).value(attribute) +
156                      insts.instance(i - 1).value(attribute)) / 2;
157        size++;
158      }
159      values[size] += insts.instance(i).classValue();
160      weights[size] += insts.instance(i).weight();
161    }
162    size++;
163   
164    // While there is a pair of adjacent violators
165    boolean violators;
166    do {
167      violators = false;
168     
169      // Initialize arrays
170      double[] tempValues = new double[size];
171      double[] tempWeights = new double[size];
172      double[] tempCuts = new double[size - 1];
173     
174      // Merge adjacent violators
175      int newSize = 0;
176      tempValues[0] = values[0];
177      tempWeights[0] = weights[0];
178      for (int j = 1; j < size; j++) {
179        if ((ascending && (values[j] / weights[j] > 
180                           tempValues[newSize] / tempWeights[newSize])) ||
181            (!ascending && (values[j] / weights[j] < 
182                            tempValues[newSize] / tempWeights[newSize]))) {
183          tempCuts[newSize] = cuts[j - 1];
184          newSize++;
185          tempValues[newSize] = values[j];
186          tempWeights[newSize] = weights[j];
187        } else {
188          tempWeights[newSize] += weights[j];
189          tempValues[newSize] += values[j];
190          violators = true;
191        }
192      }
193      newSize++;
194     
195      // Copy references
196      values = tempValues;
197      weights = tempWeights;
198      cuts = tempCuts;
199      size = newSize;
200    } while (violators);
201   
202    // Compute actual predictions
203    for (int i = 0; i < size; i++) {
204      values[i] /= weights[i];
205    }
206   
207    // Backup best instance variables
208    Attribute attributeBackedup = m_attribute;
209    double[] cutsBackedup = m_cuts;
210    double[] valuesBackedup = m_values;
211   
212    // Set instance variables to values computed for this attribute
213    m_attribute = attribute;
214    m_cuts = cuts;
215    m_values = values;
216   
217    // Compute sum of squared errors
218    Evaluation eval = new Evaluation(insts);
219    eval.evaluateModel(this, insts);
220    double msq = eval.rootMeanSquaredError();
221   
222    // Check whether this is the best attribute
223    if (msq < m_minMsq) {
224      m_minMsq = msq;
225    } else {
226      m_attribute = attributeBackedup;
227      m_cuts = cutsBackedup;
228      m_values = valuesBackedup;
229    }
230  }
231 
232  /**
233   * Builds an isotonic regression model given the supplied training data.
234   *
235   * @param insts the training data.
236   * @throws Exception if an error occurs
237   */
238  public void buildClassifier(Instances insts) throws Exception {
239
240    // can classifier handle the data?
241    getCapabilities().testWithFail(insts);
242
243    // remove instances with missing class
244    insts = new Instances(insts);
245    insts.deleteWithMissingClass();
246
247    // only class? -> build ZeroR model
248    if (insts.numAttributes() == 1) {
249      System.err.println(
250          "Cannot build model (only class attribute present in data!), "
251          + "using ZeroR model instead!");
252      m_ZeroR = new weka.classifiers.rules.ZeroR();
253      m_ZeroR.buildClassifier(insts);
254      return;
255    }
256    else {
257      m_ZeroR = null;
258    }
259
260    // Choose best attribute and mode
261    m_minMsq = Double.MAX_VALUE;
262    m_attribute = null;
263    for (int a = 0; a < insts.numAttributes(); a++) {
264      if (a != insts.classIndex()) {
265        regress(insts.attribute(a), insts, true);
266        regress(insts.attribute(a), insts, false);
267      }
268    }
269  }
270
271  /**
272   * Returns a description of this classifier as a string
273   *
274   * @return a description of the classifier.
275   */
276  public String toString() {
277
278    // only ZeroR model?
279    if (m_ZeroR != null) {
280      StringBuffer buf = new StringBuffer();
281      buf.append(this.getClass().getName().replaceAll(".*\\.", "") + "\n");
282      buf.append(this.getClass().getName().replaceAll(".*\\.", "").replaceAll(".", "=") + "\n\n");
283      buf.append("Warning: No model could be built, hence ZeroR model is used:\n\n");
284      buf.append(m_ZeroR.toString());
285      return buf.toString();
286    }
287   
288    StringBuffer text = new StringBuffer();
289    text.append("Isotonic regression\n\n");
290    if (m_attribute == null) {
291      text.append("No model built yet!");
292    }
293    else {
294      text.append("Based on attribute: " + m_attribute.name() + "\n\n");
295      for (int i = 0; i < m_values.length; i++) {
296        text.append("prediction: " + Utils.doubleToString(m_values[i], 10, 2));
297        if (i < m_cuts.length) {
298          text.append("\t\tcut point: " + Utils.doubleToString(m_cuts[i], 10, 2) + "\n");
299        }
300      }
301    }
302    return text.toString();
303  }
304 
305  /**
306   * Returns the revision string.
307   *
308   * @return            the revision
309   */
310  public String getRevision() {
311    return RevisionUtils.extract("$Revision: 5928 $");
312  }
313
314  /**
315   * Main method for testing this class
316   *
317   * @param argv options
318   */
319  public static void main(String [] argv){
320    runClassifier(new IsotonicRegression(), argv);
321  } 
322}
Note: See TracBrowser for help on using the repository browser.