| 1 | /* | 
|---|
| 2 | *    This program is free software; you can redistribute it and/or modify | 
|---|
| 3 | *    it under the terms of the GNU General Public License as published by | 
|---|
| 4 | *    the Free Software Foundation; either version 2 of the License, or | 
|---|
| 5 | *    (at your option) any later version. | 
|---|
| 6 | * | 
|---|
| 7 | *    This program is distributed in the hope that it will be useful, | 
|---|
| 8 | *    but WITHOUT ANY WARRANTY; without even the implied warranty of | 
|---|
| 9 | *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the | 
|---|
| 10 | *    GNU General Public License for more details. | 
|---|
| 11 | * | 
|---|
| 12 | *    You should have received a copy of the GNU General Public License | 
|---|
| 13 | *    along with this program; if not, write to the Free Software | 
|---|
| 14 | *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA. | 
|---|
| 15 | */ | 
|---|
| 16 |  | 
|---|
| 17 | /* | 
|---|
| 18 | *    Regression.java | 
|---|
| 19 | *    Copyright (C) 2008 University of Waikato, Hamilton, New Zealand | 
|---|
| 20 | * | 
|---|
| 21 | */ | 
|---|
| 22 |  | 
|---|
| 23 | package weka.classifiers.pmml.consumer; | 
|---|
| 24 |  | 
|---|
| 25 | import java.io.Serializable; | 
|---|
| 26 | import java.util.ArrayList; | 
|---|
| 27 | import org.w3c.dom.Element; | 
|---|
| 28 | import org.w3c.dom.Node; | 
|---|
| 29 | import org.w3c.dom.NodeList; | 
|---|
| 30 |  | 
|---|
| 31 | import weka.core.Attribute; | 
|---|
| 32 | import weka.core.Instance; | 
|---|
| 33 | import weka.core.Instances; | 
|---|
| 34 | import weka.core.RevisionUtils; | 
|---|
| 35 | import weka.core.Utils; | 
|---|
| 36 | import weka.core.pmml.*; | 
|---|
| 37 |  | 
|---|
| 38 | /** | 
|---|
| 39 | * Class implementing import of PMML Regression model. Can be | 
|---|
| 40 | * used as a Weka classifier for prediction (buildClassifier() | 
|---|
| 41 | * raises an Exception). | 
|---|
| 42 | * | 
|---|
| 43 | * @author Mark Hall (mhall{[at]}pentaho{[dot]}com) | 
|---|
| 44 | * @version $Revision: 6018 $ | 
|---|
| 45 | */ | 
|---|
| 46 | public class Regression extends PMMLClassifier | 
|---|
| 47 | implements Serializable { | 
|---|
| 48 |  | 
|---|
| 49 | /** For serialization */ | 
|---|
| 50 | private static final long serialVersionUID = -5551125528409488634L; | 
|---|
| 51 |  | 
|---|
| 52 | /** | 
|---|
| 53 | * Inner class for encapsulating a regression table | 
|---|
| 54 | */ | 
|---|
| 55 | static class RegressionTable implements Serializable { | 
|---|
| 56 |  | 
|---|
| 57 | /** For serialization */ | 
|---|
| 58 | private static final long serialVersionUID = -5259866093996338995L; | 
|---|
| 59 |  | 
|---|
| 60 | /** | 
|---|
| 61 | * Abstract inner base class for different predictor types. | 
|---|
| 62 | */ | 
|---|
| 63 | abstract static class Predictor implements Serializable { | 
|---|
| 64 |  | 
|---|
| 65 | /** For serialization */ | 
|---|
| 66 | private static final long serialVersionUID = 7043831847273383618L; | 
|---|
| 67 |  | 
|---|
| 68 | /** Name of this predictor */ | 
|---|
| 69 | protected String m_name; | 
|---|
| 70 |  | 
|---|
| 71 | /** | 
|---|
| 72 | * Index of the attribute in the mining schema that corresponds to this | 
|---|
| 73 | * predictor | 
|---|
| 74 | */ | 
|---|
| 75 | protected int m_miningSchemaAttIndex = -1; | 
|---|
| 76 |  | 
|---|
| 77 | /** Coefficient for this predictor */ | 
|---|
| 78 | protected double m_coefficient = 1.0; | 
|---|
| 79 |  | 
|---|
| 80 | /** | 
|---|
| 81 | * Constructs a new Predictor. | 
|---|
| 82 | * | 
|---|
| 83 | * @param predictor the <code>Element</code> encapsulating this predictor | 
|---|
| 84 | * @param miningSchema the mining schema as an Instances object | 
|---|
| 85 | * @throws Exception if there is a problem constructing this Predictor | 
|---|
| 86 | */ | 
|---|
| 87 | protected Predictor(Element predictor, Instances miningSchema) throws Exception { | 
|---|
| 88 | m_name = predictor.getAttribute("name"); | 
|---|
| 89 | for (int i = 0; i < miningSchema.numAttributes(); i++) { | 
|---|
| 90 | Attribute temp = miningSchema.attribute(i); | 
|---|
| 91 | if (temp.name().equals(m_name)) { | 
|---|
| 92 | m_miningSchemaAttIndex = i; | 
|---|
| 93 | } | 
|---|
| 94 | } | 
|---|
| 95 |  | 
|---|
| 96 | if (m_miningSchemaAttIndex == -1) { | 
|---|
| 97 | throw new Exception("[Predictor] unable to find matching attribute for " | 
|---|
| 98 | + "predictor " + m_name); | 
|---|
| 99 | } | 
|---|
| 100 |  | 
|---|
| 101 | String coeff = predictor.getAttribute("coefficient"); | 
|---|
| 102 | if (coeff.length() > 0) { | 
|---|
| 103 | m_coefficient = Double.parseDouble(coeff); | 
|---|
| 104 | } | 
|---|
| 105 | } | 
|---|
| 106 |  | 
|---|
| 107 | /** | 
|---|
| 108 | * Returns a textual description of this predictor applicable | 
|---|
| 109 | * to all sub classes. | 
|---|
| 110 | */ | 
|---|
| 111 | public String toString() { | 
|---|
| 112 | return Utils.doubleToString(m_coefficient, 12, 4) + " * "; | 
|---|
| 113 | } | 
|---|
| 114 |  | 
|---|
| 115 | /** | 
|---|
| 116 | * Abstract add method. Adds this predictor into the sum for the | 
|---|
| 117 | * current prediction. | 
|---|
| 118 | * | 
|---|
| 119 | * @param preds the prediction computed so far. For regression, it is a | 
|---|
| 120 | * single element array; for classification it is a multi-element array | 
|---|
| 121 | * @param input the input instance's values | 
|---|
| 122 | */ | 
|---|
| 123 | public abstract void add(double[] preds, double[] input); | 
|---|
| 124 | } | 
|---|
| 125 |  | 
|---|
| 126 | /** | 
|---|
| 127 | * Inner class for a numeric predictor | 
|---|
| 128 | */ | 
|---|
| 129 | protected class NumericPredictor extends Predictor { | 
|---|
| 130 | /** | 
|---|
| 131 | * For serialization | 
|---|
| 132 | */ | 
|---|
| 133 | private static final long serialVersionUID = -4335075205696648273L; | 
|---|
| 134 |  | 
|---|
| 135 | /** The exponent*/ | 
|---|
| 136 | protected double m_exponent = 1.0; | 
|---|
| 137 |  | 
|---|
| 138 | /** | 
|---|
| 139 | * Constructs a NumericPredictor. | 
|---|
| 140 | * | 
|---|
| 141 | * @param predictor the <code>Element</code> holding the predictor | 
|---|
| 142 | * @param miningSchema the mining schema as an Instances object | 
|---|
| 143 | * @throws Exception if something goes wrong while constructing this | 
|---|
| 144 | * predictor | 
|---|
| 145 | */ | 
|---|
| 146 | protected NumericPredictor(Element predictor, | 
|---|
| 147 | Instances miningSchema) throws Exception { | 
|---|
| 148 | super(predictor, miningSchema); | 
|---|
| 149 |  | 
|---|
| 150 | String exponent = predictor.getAttribute("exponent"); | 
|---|
| 151 | if (exponent.length() > 0) { | 
|---|
| 152 | m_exponent = Double.parseDouble(exponent); | 
|---|
| 153 | } | 
|---|
| 154 | } | 
|---|
| 155 |  | 
|---|
| 156 | /** | 
|---|
| 157 | * Return a textual description of this predictor. | 
|---|
| 158 | */ | 
|---|
| 159 | public String toString() { | 
|---|
| 160 | String output = super.toString(); | 
|---|
| 161 | output += m_name; | 
|---|
| 162 | if (m_exponent > 1.0 || m_exponent < 1.0) { | 
|---|
| 163 | output += "^" + Utils.doubleToString(m_exponent, 4); | 
|---|
| 164 | } | 
|---|
| 165 | return output; | 
|---|
| 166 | } | 
|---|
| 167 |  | 
|---|
| 168 | /** | 
|---|
| 169 | * Adds this predictor into the sum for the | 
|---|
| 170 | * current prediction. | 
|---|
| 171 | * | 
|---|
| 172 | * @param preds the prediction computed so far. For regression, it is a | 
|---|
| 173 | * single element array; for classification it is a multi-element array | 
|---|
| 174 | * @param input the input instance's values | 
|---|
| 175 | */ | 
|---|
| 176 | public void add(double[] preds, double[] input) { | 
|---|
| 177 | if (m_targetCategory == -1) { | 
|---|
| 178 | preds[0] += m_coefficient * Math.pow(input[m_miningSchemaAttIndex], m_exponent); | 
|---|
| 179 | } else { | 
|---|
| 180 | preds[m_targetCategory] += | 
|---|
| 181 | m_coefficient * Math.pow(input[m_miningSchemaAttIndex], m_exponent); | 
|---|
| 182 | } | 
|---|
| 183 | } | 
|---|
| 184 | } | 
|---|
| 185 |  | 
|---|
| 186 | /** | 
|---|
| 187 | * Inner class encapsulating a categorical predictor. | 
|---|
| 188 | */ | 
|---|
| 189 | protected class CategoricalPredictor extends Predictor { | 
|---|
| 190 |  | 
|---|
| 191 | /**For serialization */ | 
|---|
| 192 | private static final long serialVersionUID = 3077920125549906819L; | 
|---|
| 193 |  | 
|---|
| 194 | /** The attribute value for this predictor */ | 
|---|
| 195 | protected String m_valueName; | 
|---|
| 196 |  | 
|---|
| 197 | /** The index of the attribute value for this predictor */ | 
|---|
| 198 | protected int m_valueIndex = -1; | 
|---|
| 199 |  | 
|---|
| 200 | /** | 
|---|
| 201 | * Constructs a CategoricalPredictor. | 
|---|
| 202 | * | 
|---|
| 203 | * @param predictor the <code>Element</code> containing the predictor | 
|---|
| 204 | * @param miningSchema the mining schema as an Instances object | 
|---|
| 205 | * @throws Exception if something goes wrong while constructing | 
|---|
| 206 | * this predictor | 
|---|
| 207 | */ | 
|---|
| 208 | protected CategoricalPredictor(Element predictor, | 
|---|
| 209 | Instances miningSchema) throws Exception { | 
|---|
| 210 | super(predictor, miningSchema); | 
|---|
| 211 |  | 
|---|
| 212 | String valName = predictor.getAttribute("value"); | 
|---|
| 213 | if (valName.length() == 0) { | 
|---|
| 214 | throw new Exception("[CategoricalPredictor] attribute value not specified!"); | 
|---|
| 215 | } | 
|---|
| 216 |  | 
|---|
| 217 | m_valueName = valName; | 
|---|
| 218 |  | 
|---|
| 219 | Attribute att = miningSchema.attribute(m_miningSchemaAttIndex); | 
|---|
| 220 | if (att.isString()) { | 
|---|
| 221 | // means that there were no Value elements defined in the | 
|---|
| 222 | // data dictionary (and hence the mining schema). | 
|---|
| 223 | // We add our value here. | 
|---|
| 224 | att.addStringValue(m_valueName); | 
|---|
| 225 | } | 
|---|
| 226 | m_valueIndex = att.indexOfValue(m_valueName); | 
|---|
| 227 | /*        for (int i = 0; i < att.numValues(); i++) { | 
|---|
| 228 | if (att.value(i).equals(m_valueName)) { | 
|---|
| 229 | m_valueIndex = i; | 
|---|
| 230 | } | 
|---|
| 231 | }*/ | 
|---|
| 232 |  | 
|---|
| 233 | if (m_valueIndex == -1) { | 
|---|
| 234 | throw new Exception("[CategoricalPredictor] unable to find value " | 
|---|
| 235 | + m_valueName + " in mining schema attribute " | 
|---|
| 236 | + att.name()); | 
|---|
| 237 | } | 
|---|
| 238 | } | 
|---|
| 239 |  | 
|---|
| 240 | /** | 
|---|
| 241 | * Return a textual description of this predictor. | 
|---|
| 242 | */ | 
|---|
| 243 | public String toString() { | 
|---|
| 244 | String output = super.toString(); | 
|---|
| 245 | output += m_name + "=" + m_valueName; | 
|---|
| 246 | return output; | 
|---|
| 247 | } | 
|---|
| 248 |  | 
|---|
| 249 | /** | 
|---|
| 250 | * Adds this predictor into the sum for the | 
|---|
| 251 | * current prediction. | 
|---|
| 252 | * | 
|---|
| 253 | * @param preds the prediction computed so far. For regression, it is a | 
|---|
| 254 | * single element array; for classification it is a multi-element array | 
|---|
| 255 | * @param input the input instance's values | 
|---|
| 256 | */ | 
|---|
| 257 | public void add(double[] preds, double[] input) { | 
|---|
| 258 |  | 
|---|
| 259 | // if the value is equal to the one in the input then add the coefficient | 
|---|
| 260 | if (m_valueIndex == (int)input[m_miningSchemaAttIndex]) { | 
|---|
| 261 | if (m_targetCategory == -1) { | 
|---|
| 262 | preds[0] += m_coefficient; | 
|---|
| 263 | } else { | 
|---|
| 264 | preds[m_targetCategory] += m_coefficient; | 
|---|
| 265 | } | 
|---|
| 266 | } | 
|---|
| 267 | } | 
|---|
| 268 | } | 
|---|
| 269 |  | 
|---|
| 270 | /** | 
|---|
| 271 | * Inner class to handle PredictorTerms. | 
|---|
| 272 | */ | 
|---|
| 273 | protected class PredictorTerm implements Serializable { | 
|---|
| 274 |  | 
|---|
| 275 | /** For serialization */ | 
|---|
| 276 | private static final long serialVersionUID = 5493100145890252757L; | 
|---|
| 277 |  | 
|---|
| 278 | /** The coefficient for this predictor term */ | 
|---|
| 279 | protected double m_coefficient = 1.0; | 
|---|
| 280 |  | 
|---|
| 281 | /** the indexes of the terms to be multiplied */ | 
|---|
| 282 | protected int[] m_indexes; | 
|---|
| 283 |  | 
|---|
| 284 | /** The names of the terms (attributes) to be multiplied */ | 
|---|
| 285 | protected String[] m_fieldNames; | 
|---|
| 286 |  | 
|---|
| 287 | /** | 
|---|
| 288 | * Construct a new PredictorTerm. | 
|---|
| 289 | * | 
|---|
| 290 | * @param predictorTerm the <code>Element</code> describing the predictor term | 
|---|
| 291 | * @param miningSchema the mining schema as an Instances object | 
|---|
| 292 | * @throws Exception if something goes wrong while constructing this | 
|---|
| 293 | * predictor term | 
|---|
| 294 | */ | 
|---|
| 295 | protected PredictorTerm(Element predictorTerm, | 
|---|
| 296 | Instances miningSchema) throws Exception { | 
|---|
| 297 |  | 
|---|
| 298 | String coeff = predictorTerm.getAttribute("coefficient"); | 
|---|
| 299 | if (coeff != null && coeff.length() > 0) { | 
|---|
| 300 | try { | 
|---|
| 301 | m_coefficient = Double.parseDouble(coeff); | 
|---|
| 302 | } catch (IllegalArgumentException ex) { | 
|---|
| 303 | throw new Exception("[PredictorTerm] unable to parse coefficient"); | 
|---|
| 304 | } | 
|---|
| 305 | } | 
|---|
| 306 |  | 
|---|
| 307 | NodeList fields = predictorTerm.getElementsByTagName("FieldRef"); | 
|---|
| 308 | if (fields.getLength() > 0) { | 
|---|
| 309 | m_indexes = new int[fields.getLength()]; | 
|---|
| 310 | m_fieldNames = new String[fields.getLength()]; | 
|---|
| 311 |  | 
|---|
| 312 | for (int i = 0; i < fields.getLength(); i++) { | 
|---|
| 313 | Node fieldRef = fields.item(i); | 
|---|
| 314 | if (fieldRef.getNodeType() == Node.ELEMENT_NODE) { | 
|---|
| 315 | String fieldName = ((Element)fieldRef).getAttribute("field"); | 
|---|
| 316 | if (fieldName != null && fieldName.length() > 0) { | 
|---|
| 317 | boolean found = false; | 
|---|
| 318 | // look for this field in the mining schema | 
|---|
| 319 | for (int j = 0; j < miningSchema.numAttributes(); j++) { | 
|---|
| 320 | if (miningSchema.attribute(j).name().equals(fieldName)) { | 
|---|
| 321 |  | 
|---|
| 322 | // all referenced fields MUST be numeric | 
|---|
| 323 | if (!miningSchema.attribute(j).isNumeric()) { | 
|---|
| 324 | throw new Exception("[PredictorTerm] field is not continuous: " | 
|---|
| 325 | + fieldName); | 
|---|
| 326 | } | 
|---|
| 327 | found = true; | 
|---|
| 328 | m_indexes[i] = j; | 
|---|
| 329 | m_fieldNames[i] = fieldName; | 
|---|
| 330 | break; | 
|---|
| 331 | } | 
|---|
| 332 | } | 
|---|
| 333 | if (!found) { | 
|---|
| 334 | throw new Exception("[PredictorTerm] Unable to find field " | 
|---|
| 335 | + fieldName + " in mining schema!"); | 
|---|
| 336 | } | 
|---|
| 337 | } | 
|---|
| 338 | } | 
|---|
| 339 | } | 
|---|
| 340 | } | 
|---|
| 341 | } | 
|---|
| 342 |  | 
|---|
| 343 | /** | 
|---|
| 344 | * Return a textual description of this predictor term. | 
|---|
| 345 | */ | 
|---|
| 346 | public String toString() { | 
|---|
| 347 | StringBuffer result = new StringBuffer(); | 
|---|
| 348 | result.append("(" + Utils.doubleToString(m_coefficient, 12, 4)); | 
|---|
| 349 | for (int i = 0; i < m_fieldNames.length; i++) { | 
|---|
| 350 | result.append(" * " + m_fieldNames[i]); | 
|---|
| 351 | } | 
|---|
| 352 | result.append(")"); | 
|---|
| 353 | return result.toString(); | 
|---|
| 354 | } | 
|---|
| 355 |  | 
|---|
| 356 | /** | 
|---|
| 357 | * Adds this predictor term into the sum for the | 
|---|
| 358 | * current prediction. | 
|---|
| 359 | * | 
|---|
| 360 | * @param preds the prediction computed so far. For regression, it is a | 
|---|
| 361 | * single element array; for classification it is a multi-element array | 
|---|
| 362 | * @param input the input instance's values | 
|---|
| 363 | */ | 
|---|
| 364 | public void add(double[] preds, double[] input) { | 
|---|
| 365 | int indx = 0; | 
|---|
| 366 | if (m_targetCategory != -1) { | 
|---|
| 367 | indx = m_targetCategory; | 
|---|
| 368 | } | 
|---|
| 369 |  | 
|---|
| 370 | double result = m_coefficient; | 
|---|
| 371 | for (int i = 0; i < m_indexes.length; i++) { | 
|---|
| 372 | result *= input[m_indexes[i]]; | 
|---|
| 373 | } | 
|---|
| 374 | preds[indx] += result; | 
|---|
| 375 | } | 
|---|
| 376 | } | 
|---|
| 377 |  | 
|---|
| 378 | /** Constant for regression model type */ | 
|---|
| 379 | public static final int REGRESSION = 0; | 
|---|
| 380 |  | 
|---|
| 381 | /** Constant for classification model type */ | 
|---|
| 382 | public static final int CLASSIFICATION = 1; | 
|---|
| 383 |  | 
|---|
| 384 | /** The type of function - regression or classification */ | 
|---|
| 385 | protected int m_functionType = REGRESSION; | 
|---|
| 386 |  | 
|---|
| 387 | /** The mining schema */ | 
|---|
| 388 | protected MiningSchema m_miningSchema; | 
|---|
| 389 |  | 
|---|
| 390 | /** The intercept */ | 
|---|
| 391 | protected double m_intercept = 0.0; | 
|---|
| 392 |  | 
|---|
| 393 | /** classification only */ | 
|---|
| 394 | protected int m_targetCategory = -1; | 
|---|
| 395 |  | 
|---|
| 396 | /** Numeric and categorical predictors */ | 
|---|
| 397 | protected ArrayList<Predictor> m_predictors = | 
|---|
| 398 | new ArrayList<Predictor>(); | 
|---|
| 399 |  | 
|---|
| 400 | /** Interaction terms */ | 
|---|
| 401 | protected ArrayList<PredictorTerm> m_predictorTerms = | 
|---|
| 402 | new ArrayList<PredictorTerm>(); | 
|---|
| 403 |  | 
|---|
| 404 | /** | 
|---|
| 405 | * Return a textual description of this RegressionTable. | 
|---|
| 406 | */ | 
|---|
| 407 | public String toString() { | 
|---|
| 408 | Instances miningSchema = m_miningSchema.getFieldsAsInstances(); | 
|---|
| 409 | StringBuffer temp = new StringBuffer(); | 
|---|
| 410 | temp.append("Regression table:\n"); | 
|---|
| 411 | temp.append(miningSchema.classAttribute().name()); | 
|---|
| 412 | if (m_functionType == CLASSIFICATION) { | 
|---|
| 413 | temp.append("=" + miningSchema. | 
|---|
| 414 | classAttribute().value(m_targetCategory)); | 
|---|
| 415 | } | 
|---|
| 416 |  | 
|---|
| 417 | temp.append(" =\n\n"); | 
|---|
| 418 |  | 
|---|
| 419 | // do the predictors | 
|---|
| 420 | for (int i = 0; i < m_predictors.size(); i++) { | 
|---|
| 421 | temp.append(m_predictors.get(i).toString() + " +\n"); | 
|---|
| 422 | } | 
|---|
| 423 |  | 
|---|
| 424 | // do the predictor terms | 
|---|
| 425 | for (int i = 0; i < m_predictorTerms.size(); i++) { | 
|---|
| 426 | temp.append(m_predictorTerms.get(i).toString() + " +\n"); | 
|---|
| 427 | } | 
|---|
| 428 |  | 
|---|
| 429 | temp.append(Utils.doubleToString(m_intercept, 12, 4)); | 
|---|
| 430 | temp.append("\n\n"); | 
|---|
| 431 |  | 
|---|
| 432 | return temp.toString(); | 
|---|
| 433 | } | 
|---|
| 434 |  | 
|---|
| 435 | /** | 
|---|
| 436 | * Construct a regression table from an <code>Element</code> | 
|---|
| 437 | * | 
|---|
| 438 | * @param table the table to encapsulate | 
|---|
| 439 | * @param functionType the type of function | 
|---|
| 440 | * (regression or classification) | 
|---|
| 441 | * to use | 
|---|
| 442 | * @param mSchema the mining schema | 
|---|
| 443 | * @throws Exception if there is a problem while constructing | 
|---|
| 444 | * this regression table | 
|---|
| 445 | */ | 
|---|
| 446 | protected RegressionTable(Element table, | 
|---|
| 447 | int functionType, | 
|---|
| 448 | MiningSchema mSchema) throws Exception { | 
|---|
| 449 |  | 
|---|
| 450 | m_miningSchema = mSchema; | 
|---|
| 451 | m_functionType = functionType; | 
|---|
| 452 |  | 
|---|
| 453 | Instances miningSchema = m_miningSchema.getFieldsAsInstances(); | 
|---|
| 454 |  | 
|---|
| 455 | // get the intercept | 
|---|
| 456 | String intercept = table.getAttribute("intercept"); | 
|---|
| 457 | if (intercept.length() > 0) { | 
|---|
| 458 | m_intercept = Double.parseDouble(intercept); | 
|---|
| 459 | } | 
|---|
| 460 |  | 
|---|
| 461 | // get the target category (if classification) | 
|---|
| 462 | if (m_functionType == CLASSIFICATION) { | 
|---|
| 463 | // target category MUST be defined | 
|---|
| 464 | String targetCat = table.getAttribute("targetCategory"); | 
|---|
| 465 | if (targetCat.length() > 0) { | 
|---|
| 466 | Attribute classA = miningSchema.classAttribute(); | 
|---|
| 467 | for (int i = 0; i < classA.numValues(); i++) { | 
|---|
| 468 | if (classA.value(i).equals(targetCat)) { | 
|---|
| 469 | m_targetCategory = i; | 
|---|
| 470 | } | 
|---|
| 471 | } | 
|---|
| 472 | } | 
|---|
| 473 | if (m_targetCategory == -1) { | 
|---|
| 474 | throw new Exception("[RegressionTable] No target categories defined for classification"); | 
|---|
| 475 | } | 
|---|
| 476 | } | 
|---|
| 477 |  | 
|---|
| 478 | // read all the numeric predictors | 
|---|
| 479 | NodeList numericPs = table.getElementsByTagName("NumericPredictor"); | 
|---|
| 480 | for (int i = 0; i < numericPs.getLength(); i++) { | 
|---|
| 481 | Node nP = numericPs.item(i); | 
|---|
| 482 | if (nP.getNodeType() == Node.ELEMENT_NODE) { | 
|---|
| 483 | NumericPredictor numP = new NumericPredictor((Element)nP, miningSchema); | 
|---|
| 484 | m_predictors.add(numP); | 
|---|
| 485 | } | 
|---|
| 486 | } | 
|---|
| 487 |  | 
|---|
| 488 | // read all the categorical predictors | 
|---|
| 489 | NodeList categoricalPs = table.getElementsByTagName("CategoricalPredictor"); | 
|---|
| 490 | for (int i = 0; i < categoricalPs.getLength(); i++) { | 
|---|
| 491 | Node cP = categoricalPs.item(i); | 
|---|
| 492 | if (cP.getNodeType() == Node.ELEMENT_NODE) { | 
|---|
| 493 | CategoricalPredictor catP = new CategoricalPredictor((Element)cP, miningSchema); | 
|---|
| 494 | m_predictors.add(catP); | 
|---|
| 495 | } | 
|---|
| 496 | } | 
|---|
| 497 |  | 
|---|
| 498 | // read all the PredictorTerms | 
|---|
| 499 | NodeList predictorTerms = table.getElementsByTagName("PredictorTerm"); | 
|---|
| 500 | for (int i = 0; i < predictorTerms.getLength(); i++) { | 
|---|
| 501 | Node pT = predictorTerms.item(i); | 
|---|
| 502 | PredictorTerm predT = new PredictorTerm((Element)pT, miningSchema); | 
|---|
| 503 | m_predictorTerms.add(predT); | 
|---|
| 504 | } | 
|---|
| 505 | } | 
|---|
| 506 |  | 
|---|
| 507 | public void predict(double[] preds, double[] input) { | 
|---|
| 508 | if (m_targetCategory == -1) { | 
|---|
| 509 | preds[0] = m_intercept; | 
|---|
| 510 | } else { | 
|---|
| 511 | preds[m_targetCategory] = m_intercept; | 
|---|
| 512 | } | 
|---|
| 513 |  | 
|---|
| 514 | // add the predictors | 
|---|
| 515 | for (int i = 0; i < m_predictors.size(); i++) { | 
|---|
| 516 | Predictor p = m_predictors.get(i); | 
|---|
| 517 | p.add(preds, input); | 
|---|
| 518 | } | 
|---|
| 519 |  | 
|---|
| 520 | // add the PredictorTerms | 
|---|
| 521 | for (int i = 0; i < m_predictorTerms.size(); i++) { | 
|---|
| 522 | PredictorTerm pt = m_predictorTerms.get(i); | 
|---|
| 523 | pt.add(preds, input); | 
|---|
| 524 | } | 
|---|
| 525 | } | 
|---|
| 526 | } | 
|---|
| 527 |  | 
|---|
| 528 | /** Description of the algorithm */ | 
|---|
| 529 | protected String m_algorithmName; | 
|---|
| 530 |  | 
|---|
| 531 | /** The regression tables for this regression */ | 
|---|
| 532 | protected RegressionTable[] m_regressionTables; | 
|---|
| 533 |  | 
|---|
| 534 | /** | 
|---|
| 535 | * Enum for the normalization methods. | 
|---|
| 536 | */ | 
|---|
| 537 | enum Normalization { | 
|---|
| 538 | NONE, SIMPLEMAX, SOFTMAX, LOGIT, PROBIT, CLOGLOG, | 
|---|
| 539 | EXP, LOGLOG, CAUCHIT} | 
|---|
| 540 |  | 
|---|
| 541 | /** The normalization to use */ | 
|---|
| 542 | protected Normalization m_normalizationMethod = Normalization.NONE; | 
|---|
| 543 |  | 
|---|
| 544 | /** | 
|---|
| 545 | * Constructs a new PMML Regression. | 
|---|
| 546 | * | 
|---|
| 547 | * @param model the <code>Element</code> containing the regression model | 
|---|
| 548 | * @param dataDictionary the data dictionary as an Instances object | 
|---|
| 549 | * @param miningSchema the mining schema | 
|---|
| 550 | * @throws Exception if there is a problem constructing this Regression | 
|---|
| 551 | */ | 
|---|
| 552 | public Regression(Element model, Instances dataDictionary, | 
|---|
| 553 | MiningSchema miningSchema) throws Exception { | 
|---|
| 554 | super(dataDictionary, miningSchema); | 
|---|
| 555 |  | 
|---|
| 556 | int functionType = RegressionTable.REGRESSION; | 
|---|
| 557 |  | 
|---|
| 558 | // determine function name first | 
|---|
| 559 | String fName = model.getAttribute("functionName"); | 
|---|
| 560 |  | 
|---|
| 561 | if (fName.equals("regression")) { | 
|---|
| 562 | functionType = RegressionTable.REGRESSION; | 
|---|
| 563 | } else if (fName.equals("classification")) { | 
|---|
| 564 | functionType = RegressionTable.CLASSIFICATION; | 
|---|
| 565 | } else { | 
|---|
| 566 | throw new Exception("[PMML Regression] Function name not defined in pmml!"); | 
|---|
| 567 | } | 
|---|
| 568 |  | 
|---|
| 569 | // do we have an algorithm name? | 
|---|
| 570 | String algName = model.getAttribute("algorithmName"); | 
|---|
| 571 | if (algName != null && algName.length() > 0) { | 
|---|
| 572 | m_algorithmName = algName; | 
|---|
| 573 | } | 
|---|
| 574 |  | 
|---|
| 575 | // determine normalization method (if any) | 
|---|
| 576 | m_normalizationMethod = determineNormalization(model); | 
|---|
| 577 |  | 
|---|
| 578 | setUpRegressionTables(model, functionType); | 
|---|
| 579 |  | 
|---|
| 580 | // convert any string attributes in the mining schema | 
|---|
| 581 | //miningSchema.convertStringAttsToNominal(); | 
|---|
| 582 | } | 
|---|
| 583 |  | 
|---|
| 584 | /** | 
|---|
| 585 | * Create all the RegressionTables for this model. | 
|---|
| 586 | * | 
|---|
| 587 | * @param model the <code>Element</code> holding this regression model | 
|---|
| 588 | * @param functionType the type of function (regression or | 
|---|
| 589 | * classification) | 
|---|
| 590 | * @throws Exception if there is a problem setting up the regression | 
|---|
| 591 | * tables | 
|---|
| 592 | */ | 
|---|
| 593 | private void setUpRegressionTables(Element model, | 
|---|
| 594 | int functionType) throws Exception { | 
|---|
| 595 | NodeList tableList = model.getElementsByTagName("RegressionTable"); | 
|---|
| 596 |  | 
|---|
| 597 | if (tableList.getLength() == 0) { | 
|---|
| 598 | throw new Exception("[Regression] no regression tables defined!"); | 
|---|
| 599 | } | 
|---|
| 600 |  | 
|---|
| 601 | m_regressionTables = new RegressionTable[tableList.getLength()]; | 
|---|
| 602 |  | 
|---|
| 603 | for (int i = 0; i < tableList.getLength(); i++) { | 
|---|
| 604 | Node table = tableList.item(i); | 
|---|
| 605 | if (table.getNodeType() == Node.ELEMENT_NODE) { | 
|---|
| 606 | RegressionTable tempRTable = | 
|---|
| 607 | new RegressionTable((Element)table, | 
|---|
| 608 | functionType, | 
|---|
| 609 | m_miningSchema); | 
|---|
| 610 | m_regressionTables[i] = tempRTable; | 
|---|
| 611 | } | 
|---|
| 612 | } | 
|---|
| 613 | } | 
|---|
| 614 |  | 
|---|
| 615 | /** | 
|---|
| 616 | * Return the type of normalization used for this regression | 
|---|
| 617 | * | 
|---|
| 618 | * @param model the <code>Element</code> holding the model | 
|---|
| 619 | * @return the normalization used in this regression | 
|---|
| 620 | */ | 
|---|
| 621 | private static Normalization determineNormalization(Element model) { | 
|---|
| 622 |  | 
|---|
| 623 | Normalization normMethod = Normalization.NONE; | 
|---|
| 624 |  | 
|---|
| 625 | String normName = model.getAttribute("normalizationMethod"); | 
|---|
| 626 | if (normName.equals("simplemax")) { | 
|---|
| 627 | normMethod = Normalization.SIMPLEMAX; | 
|---|
| 628 | } else if (normName.equals("softmax")) { | 
|---|
| 629 | normMethod = Normalization.SOFTMAX; | 
|---|
| 630 | } else if (normName.equals("logit")) { | 
|---|
| 631 | normMethod = Normalization.LOGIT; | 
|---|
| 632 | } else if (normName.equals("probit")) { | 
|---|
| 633 | normMethod = Normalization.PROBIT; | 
|---|
| 634 | } else if (normName.equals("cloglog")) { | 
|---|
| 635 | normMethod = Normalization.CLOGLOG; | 
|---|
| 636 | } else if (normName.equals("exp")) { | 
|---|
| 637 | normMethod = Normalization.EXP; | 
|---|
| 638 | } else if (normName.equals("loglog")) { | 
|---|
| 639 | normMethod = Normalization.LOGLOG; | 
|---|
| 640 | } else if (normName.equals("cauchit")) { | 
|---|
| 641 | normMethod = Normalization.CAUCHIT; | 
|---|
| 642 | } | 
|---|
| 643 | return normMethod; | 
|---|
| 644 | } | 
|---|
| 645 |  | 
|---|
| 646 | /** | 
|---|
| 647 | * Return a textual description of this Regression model. | 
|---|
| 648 | */ | 
|---|
| 649 | public String toString() { | 
|---|
| 650 | StringBuffer temp = new StringBuffer(); | 
|---|
| 651 | temp.append("PMML version " + getPMMLVersion()); | 
|---|
| 652 | if (!getCreatorApplication().equals("?")) { | 
|---|
| 653 | temp.append("\nApplication: " + getCreatorApplication()); | 
|---|
| 654 | } | 
|---|
| 655 | if (m_algorithmName != null) { | 
|---|
| 656 | temp.append("\nPMML Model: " + m_algorithmName); | 
|---|
| 657 | } | 
|---|
| 658 | temp.append("\n\n"); | 
|---|
| 659 | temp.append(m_miningSchema); | 
|---|
| 660 |  | 
|---|
| 661 | for (RegressionTable table : m_regressionTables) { | 
|---|
| 662 | temp.append(table); | 
|---|
| 663 | } | 
|---|
| 664 |  | 
|---|
| 665 | if (m_normalizationMethod != Normalization.NONE) { | 
|---|
| 666 | temp.append("Normalization: " + m_normalizationMethod); | 
|---|
| 667 | } | 
|---|
| 668 | temp.append("\n"); | 
|---|
| 669 |  | 
|---|
| 670 | return temp.toString(); | 
|---|
| 671 | } | 
|---|
| 672 |  | 
|---|
| 673 | /** | 
|---|
| 674 | * Classifies the given test instance. The instance has to belong to a | 
|---|
| 675 | * dataset when it's being classified. | 
|---|
| 676 | * | 
|---|
| 677 | * @param inst the instance to be classified | 
|---|
| 678 | * @return the predicted most likely class for the instance or | 
|---|
| 679 | * Utils.missingValue() if no prediction is made | 
|---|
| 680 | * @exception Exception if an error occurred during the prediction | 
|---|
| 681 | */ | 
|---|
| 682 | public double[] distributionForInstance(Instance inst) throws Exception { | 
|---|
| 683 | if (!m_initialized) { | 
|---|
| 684 | mapToMiningSchema(inst.dataset()); | 
|---|
| 685 | } | 
|---|
| 686 | double[] preds = null; | 
|---|
| 687 | if (m_miningSchema.getFieldsAsInstances().classAttribute().isNumeric()) { | 
|---|
| 688 | preds = new double[1]; | 
|---|
| 689 | } else { | 
|---|
| 690 | preds = new double[m_miningSchema.getFieldsAsInstances().classAttribute().numValues()]; | 
|---|
| 691 | } | 
|---|
| 692 |  | 
|---|
| 693 | // create an array of doubles that holds values from the incoming | 
|---|
| 694 | // instance; in order of the fields in the mining schema. We will | 
|---|
| 695 | // also handle missing values and outliers here. | 
|---|
| 696 | //    System.err.println(inst); | 
|---|
| 697 | double[] incoming = m_fieldsMap.instanceToSchema(inst, m_miningSchema); | 
|---|
| 698 |  | 
|---|
| 699 | // scan for missing values. If there are still missing values after instanceToSchema(), | 
|---|
| 700 | // then missing value handling has been deferred to the PMML scheme. The specification | 
|---|
| 701 | // (Regression PMML 3.2) seems to contradict itself with regards to classification and categorical | 
|---|
| 702 | // variables. In one place it states that if a categorical variable is missing then | 
|---|
| 703 | // variable_name=value is 0 for any value. Further down in the document it states: "if | 
|---|
| 704 | // one or more of the y_j cannot be evaluated because the value in one of the referenced | 
|---|
| 705 | // fields is missing, then the following formulas (for computing p_j) do not apply. In | 
|---|
| 706 | // that case the predictions are defined by the priorProbability values in the Target | 
|---|
| 707 | // element". | 
|---|
| 708 |  | 
|---|
| 709 | // In this implementation we will default to information in the Target element (default | 
|---|
| 710 | // value for numeric prediction and prior probabilities for classification). If there is | 
|---|
| 711 | // no Target element defined, then an Exception is thrown. | 
|---|
| 712 |  | 
|---|
| 713 | boolean hasMissing = false; | 
|---|
| 714 | for (int i = 0; i < incoming.length; i++) { | 
|---|
| 715 | if (i != m_miningSchema.getFieldsAsInstances().classIndex() && | 
|---|
| 716 | Utils.isMissingValue(incoming[i])) { | 
|---|
| 717 | hasMissing = true; | 
|---|
| 718 | break; | 
|---|
| 719 | } | 
|---|
| 720 | } | 
|---|
| 721 |  | 
|---|
| 722 | if (hasMissing) { | 
|---|
| 723 | if (!m_miningSchema.hasTargetMetaData()) { | 
|---|
| 724 | String message = "[Regression] WARNING: Instance to predict has missing value(s) but " | 
|---|
| 725 | + "there is no missing value handling meta data and no " | 
|---|
| 726 | + "prior probabilities/default value to fall back to. No " | 
|---|
| 727 | + "prediction will be made (" | 
|---|
| 728 | + ((m_miningSchema.getFieldsAsInstances().classAttribute().isNominal() || | 
|---|
| 729 | m_miningSchema.getFieldsAsInstances().classAttribute().isString()) | 
|---|
| 730 | ? "zero probabilities output)." | 
|---|
| 731 | : "NaN output)."); | 
|---|
| 732 | if (m_log == null) { | 
|---|
| 733 | System.err.println(message); | 
|---|
| 734 | } else { | 
|---|
| 735 | m_log.logMessage(message); | 
|---|
| 736 | } | 
|---|
| 737 | if (m_miningSchema.getFieldsAsInstances().classAttribute().isNumeric()) { | 
|---|
| 738 | preds[0] = Utils.missingValue(); | 
|---|
| 739 | } | 
|---|
| 740 | return preds; | 
|---|
| 741 | } else { | 
|---|
| 742 | // use prior probablilities/default value | 
|---|
| 743 | TargetMetaInfo targetData = m_miningSchema.getTargetMetaData(); | 
|---|
| 744 | if (m_miningSchema.getFieldsAsInstances().classAttribute().isNumeric()) { | 
|---|
| 745 | preds[0] = targetData.getDefaultValue(); | 
|---|
| 746 | } else { | 
|---|
| 747 | Instances miningSchemaI = m_miningSchema.getFieldsAsInstances(); | 
|---|
| 748 | for (int i = 0; i < miningSchemaI.classAttribute().numValues(); i++) { | 
|---|
| 749 | preds[i] = targetData.getPriorProbability(miningSchemaI.classAttribute().value(i)); | 
|---|
| 750 | } | 
|---|
| 751 | } | 
|---|
| 752 | return preds; | 
|---|
| 753 | } | 
|---|
| 754 | } else { | 
|---|
| 755 | // loop through the RegressionTables | 
|---|
| 756 | for (int i = 0; i < m_regressionTables.length; i++) { | 
|---|
| 757 | m_regressionTables[i].predict(preds, incoming); | 
|---|
| 758 | } | 
|---|
| 759 |  | 
|---|
| 760 | // Now apply the normalization | 
|---|
| 761 | switch (m_normalizationMethod) { | 
|---|
| 762 | case NONE: | 
|---|
| 763 | // nothing to be done | 
|---|
| 764 | break; | 
|---|
| 765 | case SIMPLEMAX: | 
|---|
| 766 | Utils.normalize(preds); | 
|---|
| 767 | break; | 
|---|
| 768 | case SOFTMAX: | 
|---|
| 769 | for (int i = 0; i < preds.length; i++) { | 
|---|
| 770 | preds[i] = Math.exp(preds[i]); | 
|---|
| 771 | } | 
|---|
| 772 | if (preds.length == 1) { | 
|---|
| 773 | // hack for those models that do binary logistic regression as | 
|---|
| 774 | // a numeric prediction model | 
|---|
| 775 | preds[0] = preds[0] / (preds[0] + 1.0); | 
|---|
| 776 | } else { | 
|---|
| 777 | Utils.normalize(preds); | 
|---|
| 778 | } | 
|---|
| 779 | break; | 
|---|
| 780 | case LOGIT: | 
|---|
| 781 | for (int i = 0; i < preds.length; i++) { | 
|---|
| 782 | preds[i] = 1.0 / (1.0 + Math.exp(-preds[i])); | 
|---|
| 783 | } | 
|---|
| 784 | Utils.normalize(preds); | 
|---|
| 785 | break; | 
|---|
| 786 | case PROBIT: | 
|---|
| 787 | for (int i = 0; i < preds.length; i++) { | 
|---|
| 788 | preds[i] = weka.core.matrix.Maths.pnorm(preds[i]); | 
|---|
| 789 | } | 
|---|
| 790 | Utils.normalize(preds); | 
|---|
| 791 | break; | 
|---|
| 792 | case CLOGLOG: | 
|---|
| 793 | // note this is supposed to be illegal for regression | 
|---|
| 794 | for (int i = 0; i < preds.length; i++) { | 
|---|
| 795 | preds[i] = 1.0 - Math.exp(-Math.exp(-preds[i])); | 
|---|
| 796 | } | 
|---|
| 797 | Utils.normalize(preds); | 
|---|
| 798 | break; | 
|---|
| 799 | case EXP: | 
|---|
| 800 | for (int i = 0; i < preds.length; i++) { | 
|---|
| 801 | preds[i] = Math.exp(preds[i]); | 
|---|
| 802 | } | 
|---|
| 803 | Utils.normalize(preds); | 
|---|
| 804 | break; | 
|---|
| 805 | case LOGLOG: | 
|---|
| 806 | // note this is supposed to be illegal for regression | 
|---|
| 807 | for (int i = 0; i < preds.length; i++) { | 
|---|
| 808 | preds[i] = Math.exp(-Math.exp(-preds[i])); | 
|---|
| 809 | } | 
|---|
| 810 | Utils.normalize(preds); | 
|---|
| 811 | break; | 
|---|
| 812 | case CAUCHIT: | 
|---|
| 813 | for (int i = 0; i < preds.length; i++) { | 
|---|
| 814 | preds[i] = 0.5 + (1.0 / Math.PI) * Math.atan(preds[i]); | 
|---|
| 815 | } | 
|---|
| 816 | Utils.normalize(preds); | 
|---|
| 817 | break; | 
|---|
| 818 | default: | 
|---|
| 819 | throw new Exception("[Regression] unknown normalization method"); | 
|---|
| 820 | } | 
|---|
| 821 |  | 
|---|
| 822 | // If there is a Target defined, and this is a numeric prediction problem, | 
|---|
| 823 | // then apply any min, max, rescaling etc. | 
|---|
| 824 | if (m_miningSchema.getFieldsAsInstances().classAttribute().isNumeric() | 
|---|
| 825 | && m_miningSchema.hasTargetMetaData()) { | 
|---|
| 826 | TargetMetaInfo targetData = m_miningSchema.getTargetMetaData(); | 
|---|
| 827 | preds[0] = targetData.applyMinMaxRescaleCast(preds[0]); | 
|---|
| 828 | } | 
|---|
| 829 | } | 
|---|
| 830 |  | 
|---|
| 831 | return preds; | 
|---|
| 832 | } | 
|---|
| 833 |  | 
|---|
| 834 | /* (non-Javadoc) | 
|---|
| 835 | * @see weka.core.RevisionHandler#getRevision() | 
|---|
| 836 | */ | 
|---|
| 837 | public String getRevision() { | 
|---|
| 838 | return RevisionUtils.extract("$Revision: 6018 $"); | 
|---|
| 839 | } | 
|---|
| 840 | } | 
|---|