/*
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 2 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 */

/*
 * IsotonicRegression.java
 * Copyright (C) 2006 University of Waikato, Hamilton, New Zealand
 *
 */

package weka.classifiers.functions;

import weka.classifiers.Classifier;
import weka.classifiers.AbstractClassifier;
import weka.classifiers.Evaluation;
import weka.core.Attribute;
import weka.core.Capabilities;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.RevisionUtils;
import weka.core.Utils;
import weka.core.WeightedInstancesHandler;
import weka.core.Capabilities.Capability;

import java.util.Arrays;

/**
 <!-- globalinfo-start -->
 * Learns an isotonic regression model. Picks the attribute that results in the lowest squared error. Missing values are not allowed. Can only deal with numeric attributes. Considers the monotonically increasing case as well as the monotonically decreasing case.
 * <p/>
 <!-- globalinfo-end -->
 *
 <!-- options-start -->
 * Valid options are: <p/>
 *
 * <pre> -D
 *  If set, classifier is run in debug mode and
 *  may output additional info to the console</pre>
 *
 <!-- options-end -->
 *
 * @author Eibe Frank (eibe@cs.waikato.ac.nz)
 * @version $Revision: 5928 $
 */
public class IsotonicRegression extends AbstractClassifier implements WeightedInstancesHandler {
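
  /*
   * Minimal usage sketch, assuming a numeric-class dataset stored in a
   * hypothetical "data/train.arff" file (file name and path are illustrative only):
   *
   *   Instances data = new Instances(
   *       new java.io.BufferedReader(new java.io.FileReader("data/train.arff")));
   *   data.setClassIndex(data.numAttributes() - 1);
   *   IsotonicRegression ir = new IsotonicRegression();
   *   ir.buildClassifier(data);
   *   double prediction = ir.classifyInstance(data.instance(0));
   */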

  /** for serialization */
  static final long serialVersionUID = 1679336022835454137L;

  /** The chosen attribute */
  private Attribute m_attribute;

  /** The array of cut points */
  private double[] m_cuts;

  /** The predicted value in each interval. */
  private double[] m_values;

  /** The minimum root mean squared error that has been achieved. */
  private double m_minMsq;

  /** a ZeroR model in case no model can be built from the data */
  private Classifier m_ZeroR;

  /**
   * Returns a string describing this classifier
   * @return a description of the classifier suitable for
   * displaying in the explorer/experimenter gui
   */
  public String globalInfo() {
    return "Learns an isotonic regression model. "
      + "Picks the attribute that results in the lowest squared error. "
      + "Missing values are not allowed. Can only deal with numeric attributes. "
      + "Considers the monotonically increasing case as well as the monotonically "
      + "decreasing case.";
  }

  /**
   * Generate a prediction for the supplied instance.
   *
   * @param inst the instance to predict.
   * @return the prediction
   * @throws Exception if an error occurs
   */
  public double classifyInstance(Instance inst) throws Exception {

    // default model?
    if (m_ZeroR != null) {
      return m_ZeroR.classifyInstance(inst);
    }

    if (inst.isMissing(m_attribute.index())) {
      throw new Exception("IsotonicRegression: No missing values!");
    }
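    // Arrays.binarySearch returns (-(insertion point) - 1) when the value lies
    // between two cut points, so -index - 1 recovers the interval index into
    // m_values; an exact match on a cut point maps to the interval above it.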
    int index = Arrays.binarySearch(m_cuts, inst.value(m_attribute));
    if (index < 0) {
      return m_values[-index - 1];
    } else {
      return m_values[index + 1];
    }
  }

  /**
   * Returns default capabilities of the classifier.
   *
   * @return the capabilities of this classifier
   */
  public Capabilities getCapabilities() {
    Capabilities result = super.getCapabilities();
    result.disableAll();

    // attributes
    result.enable(Capability.NUMERIC_ATTRIBUTES);
    result.enable(Capability.DATE_ATTRIBUTES);

    // class
    result.enable(Capability.NUMERIC_CLASS);
    result.enable(Capability.DATE_CLASS);
    result.enable(Capability.MISSING_CLASS_VALUES);

    return result;
  }

  /**
   * Fits an isotonic regression (monotonically increasing or decreasing,
   * depending on the ascending flag) for the given attribute, and keeps the
   * result only if it improves on the lowest RMSE found so far.
   *
   * @param attribute the attribute to base the regression on
   * @param insts the training data
   * @param ascending true for a monotonically increasing fit, false for a decreasing one
   * @throws Exception if an error occurs
   */
  protected void regress(Attribute attribute, Instances insts, boolean ascending)
    throws Exception {

    // Sort values according to current attribute
    insts.sort(attribute);

    // Initialize arrays
    double[] values = new double[insts.numInstances()];
    double[] weights = new double[insts.numInstances()];
    double[] cuts = new double[insts.numInstances() - 1];
    int size = 0;
    values[0] = insts.instance(0).classValue();
    weights[0] = insts.instance(0).weight();
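    // Pool instances that share an attribute value into a single block of summed
    // class values and weights; a cut point is placed halfway between each pair
    // of distinct attribute values.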
    for (int i = 1; i < insts.numInstances(); i++) {
      if (insts.instance(i).value(attribute) >
          insts.instance(i - 1).value(attribute)) {
        cuts[size] = (insts.instance(i).value(attribute) +
                      insts.instance(i - 1).value(attribute)) / 2;
        size++;
      }
      values[size] += insts.instance(i).classValue();
      weights[size] += insts.instance(i).weight();
    }
    size++;

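    // Pool adjacent violators: repeatedly merge neighbouring blocks whose
    // weighted means violate the required ordering (increasing when ascending,
    // decreasing otherwise) until the sequence of block means is monotonic.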
    // While there is a pair of adjacent violators
    boolean violators;
    do {
      violators = false;

      // Initialize arrays
      double[] tempValues = new double[size];
      double[] tempWeights = new double[size];
      double[] tempCuts = new double[size - 1];

      // Merge adjacent violators
      int newSize = 0;
      tempValues[0] = values[0];
      tempWeights[0] = weights[0];
      for (int j = 1; j < size; j++) {
        if ((ascending && (values[j] / weights[j] >
                           tempValues[newSize] / tempWeights[newSize])) ||
            (!ascending && (values[j] / weights[j] <
                            tempValues[newSize] / tempWeights[newSize]))) {
          tempCuts[newSize] = cuts[j - 1];
          newSize++;
          tempValues[newSize] = values[j];
          tempWeights[newSize] = weights[j];
        } else {
          tempWeights[newSize] += weights[j];
          tempValues[newSize] += values[j];
          violators = true;
        }
      }
      newSize++;

      // Copy references
      values = tempValues;
      weights = tempWeights;
      cuts = tempCuts;
      size = newSize;
    } while (violators);

    // Compute actual predictions
    for (int i = 0; i < size; i++) {
      values[i] /= weights[i];
    }

    // Backup best instance variables
    Attribute attributeBackedup = m_attribute;
    double[] cutsBackedup = m_cuts;
    double[] valuesBackedup = m_values;

    // Set instance variables to values computed for this attribute
    m_attribute = attribute;
    m_cuts = cuts;
    m_values = values;

    // Compute the root mean squared error with this attribute
    Evaluation eval = new Evaluation(insts);
    eval.evaluateModel(this, insts);
    double msq = eval.rootMeanSquaredError();

    // Check whether this is the best attribute
    if (msq < m_minMsq) {
      m_minMsq = msq;
    } else {
      m_attribute = attributeBackedup;
      m_cuts = cutsBackedup;
      m_values = valuesBackedup;
    }
  }

  /**
   * Builds an isotonic regression model given the supplied training data.
   *
   * @param insts the training data.
   * @throws Exception if an error occurs
   */
  public void buildClassifier(Instances insts) throws Exception {

    // can classifier handle the data?
    getCapabilities().testWithFail(insts);

    // remove instances with missing class
    insts = new Instances(insts);
    insts.deleteWithMissingClass();

    // only class? -> build ZeroR model
    if (insts.numAttributes() == 1) {
      System.err.println(
          "Cannot build model (only class attribute present in data!), "
          + "using ZeroR model instead!");
      m_ZeroR = new weka.classifiers.rules.ZeroR();
      m_ZeroR.buildClassifier(insts);
      return;
    }
    else {
      m_ZeroR = null;
    }

    // Choose best attribute and mode
    m_minMsq = Double.MAX_VALUE;
    m_attribute = null;
    for (int a = 0; a < insts.numAttributes(); a++) {
      if (a != insts.classIndex()) {
        regress(insts.attribute(a), insts, true);
        regress(insts.attribute(a), insts, false);
      }
    }
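
    // At this point m_attribute, m_cuts and m_values hold the fit with the
    // lowest RMSE over all attributes and both directions, as tracked by regress().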
  }

  /**
   * Returns a description of this classifier as a string
   *
   * @return a description of the classifier.
   */
  public String toString() {

    // only ZeroR model?
    if (m_ZeroR != null) {
      StringBuffer buf = new StringBuffer();
      buf.append(this.getClass().getName().replaceAll(".*\\.", "") + "\n");
      buf.append(this.getClass().getName().replaceAll(".*\\.", "").replaceAll(".", "=") + "\n\n");
      buf.append("Warning: No model could be built, hence ZeroR model is used:\n\n");
      buf.append(m_ZeroR.toString());
      return buf.toString();
    }

    StringBuffer text = new StringBuffer();
    text.append("Isotonic regression\n\n");
    if (m_attribute == null) {
      text.append("No model built yet!");
    }
    else {
      text.append("Based on attribute: " + m_attribute.name() + "\n\n");
      for (int i = 0; i < m_values.length; i++) {
        text.append("prediction: " + Utils.doubleToString(m_values[i], 10, 2));
        if (i < m_cuts.length) {
          text.append("\t\tcut point: " + Utils.doubleToString(m_cuts[i], 10, 2) + "\n");
        }
      }
    }
    return text.toString();
  }

  /**
   * Returns the revision string.
   *
   * @return the revision
   */
  public String getRevision() {
    return RevisionUtils.extract("$Revision: 5928 $");
  }

  /**
   * Main method for testing this class
   *
   * @param argv options
   */
  public static void main(String [] argv) {
    runClassifier(new IsotonicRegression(), argv);
  }
}