| 1 | /* |
|---|
| 2 | * This program is free software; you can redistribute it and/or modify |
|---|
| 3 | * it under the terms of the GNU General Public License as published by |
|---|
| 4 | * the Free Software Foundation; either version 2 of the License, or |
|---|
| 5 | * (at your option) any later version. |
|---|
| 6 | * |
|---|
| 7 | * This program is distributed in the hope that it will be useful, |
|---|
| 8 | * but WITHOUT ANY WARRANTY; without even the implied warranty of |
|---|
| 9 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the |
|---|
| 10 | * GNU General Public License for more details. |
|---|
| 11 | * |
|---|
| 12 | * You should have received a copy of the GNU General Public License |
|---|
| 13 | * along with this program; if not, write to the Free Software |
|---|
| 14 | * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA. |
|---|
| 15 | */ |
|---|
| 16 | |
|---|
| 17 | /* |
|---|
| 18 | * Regression.java |
|---|
| 19 | * Copyright (C) 2008 University of Waikato, Hamilton, New Zealand |
|---|
| 20 | * |
|---|
| 21 | */ |
|---|
| 22 | |
|---|
| 23 | package weka.classifiers.pmml.consumer; |
|---|
| 24 | |
|---|
| 25 | import java.io.Serializable; |
|---|
| 26 | import java.util.ArrayList; |
|---|
| 27 | import org.w3c.dom.Element; |
|---|
| 28 | import org.w3c.dom.Node; |
|---|
| 29 | import org.w3c.dom.NodeList; |
|---|
| 30 | |
|---|
| 31 | import weka.core.Attribute; |
|---|
| 32 | import weka.core.Instance; |
|---|
| 33 | import weka.core.Instances; |
|---|
| 34 | import weka.core.RevisionUtils; |
|---|
| 35 | import weka.core.Utils; |
|---|
| 36 | import weka.core.pmml.*; |
|---|
| 37 | |
|---|
| 38 | /** |
|---|
| 39 | * Class implementing import of PMML Regression model. Can be |
|---|
| 40 | * used as a Weka classifier for prediction (buildClassifier() |
|---|
| 41 | * raises an Exception). |
|---|
| 42 | * |
|---|
| 43 | * @author Mark Hall (mhall{[at]}pentaho{[dot]}com) |
|---|
| 44 | * @version $Revision: 6018 $ |
|---|
| 45 | */ |
|---|
| 46 | public class Regression extends PMMLClassifier |
|---|
| 47 | implements Serializable { |
|---|
| 48 | |
|---|
| 49 | /** For serialization */ |
|---|
| 50 | private static final long serialVersionUID = -5551125528409488634L; |
|---|
| 51 | |
|---|
| 52 | /** |
|---|
| 53 | * Inner class for encapsulating a regression table |
|---|
| 54 | */ |
|---|
| 55 | static class RegressionTable implements Serializable { |
|---|
| 56 | |
|---|
| 57 | /** For serialization */ |
|---|
| 58 | private static final long serialVersionUID = -5259866093996338995L; |
|---|
| 59 | |
|---|
| 60 | /** |
|---|
| 61 | * Abstract inner base class for different predictor types. |
|---|
| 62 | */ |
|---|
| 63 | abstract static class Predictor implements Serializable { |
|---|
| 64 | |
|---|
| 65 | /** For serialization */ |
|---|
| 66 | private static final long serialVersionUID = 7043831847273383618L; |
|---|
| 67 | |
|---|
| 68 | /** Name of this predictor */ |
|---|
| 69 | protected String m_name; |
|---|
| 70 | |
|---|
| 71 | /** |
|---|
| 72 | * Index of the attribute in the mining schema that corresponds to this |
|---|
| 73 | * predictor |
|---|
| 74 | */ |
|---|
| 75 | protected int m_miningSchemaAttIndex = -1; |
|---|
| 76 | |
|---|
| 77 | /** Coefficient for this predictor */ |
|---|
| 78 | protected double m_coefficient = 1.0; |
|---|
| 79 | |
|---|
| 80 | /** |
|---|
| 81 | * Constructs a new Predictor. |
|---|
| 82 | * |
|---|
| 83 | * @param predictor the <code>Element</code> encapsulating this predictor |
|---|
| 84 | * @param miningSchema the mining schema as an Instances object |
|---|
| 85 | * @throws Exception if there is a problem constructing this Predictor |
|---|
| 86 | */ |
|---|
| 87 | protected Predictor(Element predictor, Instances miningSchema) throws Exception { |
|---|
| 88 | m_name = predictor.getAttribute("name"); |
|---|
| 89 | for (int i = 0; i < miningSchema.numAttributes(); i++) { |
|---|
| 90 | Attribute temp = miningSchema.attribute(i); |
|---|
| 91 | if (temp.name().equals(m_name)) { |
|---|
| 92 | m_miningSchemaAttIndex = i; |
|---|
| 93 | } |
|---|
| 94 | } |
|---|
| 95 | |
|---|
| 96 | if (m_miningSchemaAttIndex == -1) { |
|---|
| 97 | throw new Exception("[Predictor] unable to find matching attribute for " |
|---|
| 98 | + "predictor " + m_name); |
|---|
| 99 | } |
|---|
| 100 | |
|---|
| 101 | String coeff = predictor.getAttribute("coefficient"); |
|---|
| 102 | if (coeff.length() > 0) { |
|---|
| 103 | m_coefficient = Double.parseDouble(coeff); |
|---|
| 104 | } |
|---|
| 105 | } |
|---|
| 106 | |
|---|
| 107 | /** |
|---|
| 108 | * Returns a textual description of this predictor applicable |
|---|
| 109 | * to all sub classes. |
|---|
| 110 | */ |
|---|
| 111 | public String toString() { |
|---|
| 112 | return Utils.doubleToString(m_coefficient, 12, 4) + " * "; |
|---|
| 113 | } |
|---|
| 114 | |
|---|
| 115 | /** |
|---|
| 116 | * Abstract add method. Adds this predictor into the sum for the |
|---|
| 117 | * current prediction. |
|---|
| 118 | * |
|---|
| 119 | * @param preds the prediction computed so far. For regression, it is a |
|---|
| 120 | * single element array; for classification it is a multi-element array |
|---|
| 121 | * @param input the input instance's values |
|---|
| 122 | */ |
|---|
| 123 | public abstract void add(double[] preds, double[] input); |
|---|
| 124 | } |
|---|
| 125 | |
|---|
| 126 | /** |
|---|
| 127 | * Inner class for a numeric predictor |
|---|
| 128 | */ |
|---|
| 129 | protected class NumericPredictor extends Predictor { |
|---|
| 130 | /** |
|---|
| 131 | * For serialization |
|---|
| 132 | */ |
|---|
| 133 | private static final long serialVersionUID = -4335075205696648273L; |
|---|
| 134 | |
|---|
| 135 | /** The exponent*/ |
|---|
| 136 | protected double m_exponent = 1.0; |
|---|
| 137 | |
|---|
| 138 | /** |
|---|
| 139 | * Constructs a NumericPredictor. |
|---|
| 140 | * |
|---|
| 141 | * @param predictor the <code>Element</code> holding the predictor |
|---|
| 142 | * @param miningSchema the mining schema as an Instances object |
|---|
| 143 | * @throws Exception if something goes wrong while constructing this |
|---|
| 144 | * predictor |
|---|
| 145 | */ |
|---|
| 146 | protected NumericPredictor(Element predictor, |
|---|
| 147 | Instances miningSchema) throws Exception { |
|---|
| 148 | super(predictor, miningSchema); |
|---|
| 149 | |
|---|
| 150 | String exponent = predictor.getAttribute("exponent"); |
|---|
| 151 | if (exponent.length() > 0) { |
|---|
| 152 | m_exponent = Double.parseDouble(exponent); |
|---|
| 153 | } |
|---|
| 154 | } |
|---|
| 155 | |
|---|
| 156 | /** |
|---|
| 157 | * Return a textual description of this predictor. |
|---|
| 158 | */ |
|---|
| 159 | public String toString() { |
|---|
| 160 | String output = super.toString(); |
|---|
| 161 | output += m_name; |
|---|
| 162 | if (m_exponent > 1.0 || m_exponent < 1.0) { |
|---|
| 163 | output += "^" + Utils.doubleToString(m_exponent, 4); |
|---|
| 164 | } |
|---|
| 165 | return output; |
|---|
| 166 | } |
|---|
| 167 | |
|---|
| 168 | /** |
|---|
| 169 | * Adds this predictor into the sum for the |
|---|
| 170 | * current prediction. |
|---|
| 171 | * |
|---|
| 172 | * @param preds the prediction computed so far. For regression, it is a |
|---|
| 173 | * single element array; for classification it is a multi-element array |
|---|
| 174 | * @param input the input instance's values |
|---|
| 175 | */ |
|---|
| 176 | public void add(double[] preds, double[] input) { |
|---|
| 177 | if (m_targetCategory == -1) { |
|---|
| 178 | preds[0] += m_coefficient * Math.pow(input[m_miningSchemaAttIndex], m_exponent); |
|---|
| 179 | } else { |
|---|
| 180 | preds[m_targetCategory] += |
|---|
| 181 | m_coefficient * Math.pow(input[m_miningSchemaAttIndex], m_exponent); |
|---|
| 182 | } |
|---|
| 183 | } |
|---|
| 184 | } |
|---|
| 185 | |
|---|
| 186 | /** |
|---|
| 187 | * Inner class encapsulating a categorical predictor. |
|---|
| 188 | */ |
|---|
| 189 | protected class CategoricalPredictor extends Predictor { |
|---|
| 190 | |
|---|
| 191 | /**For serialization */ |
|---|
| 192 | private static final long serialVersionUID = 3077920125549906819L; |
|---|
| 193 | |
|---|
| 194 | /** The attribute value for this predictor */ |
|---|
| 195 | protected String m_valueName; |
|---|
| 196 | |
|---|
| 197 | /** The index of the attribute value for this predictor */ |
|---|
| 198 | protected int m_valueIndex = -1; |
|---|
| 199 | |
|---|
| 200 | /** |
|---|
| 201 | * Constructs a CategoricalPredictor. |
|---|
| 202 | * |
|---|
| 203 | * @param predictor the <code>Element</code> containing the predictor |
|---|
| 204 | * @param miningSchema the mining schema as an Instances object |
|---|
| 205 | * @throws Exception if something goes wrong while constructing |
|---|
| 206 | * this predictor |
|---|
| 207 | */ |
|---|
| 208 | protected CategoricalPredictor(Element predictor, |
|---|
| 209 | Instances miningSchema) throws Exception { |
|---|
| 210 | super(predictor, miningSchema); |
|---|
| 211 | |
|---|
| 212 | String valName = predictor.getAttribute("value"); |
|---|
| 213 | if (valName.length() == 0) { |
|---|
| 214 | throw new Exception("[CategoricalPredictor] attribute value not specified!"); |
|---|
| 215 | } |
|---|
| 216 | |
|---|
| 217 | m_valueName = valName; |
|---|
| 218 | |
|---|
| 219 | Attribute att = miningSchema.attribute(m_miningSchemaAttIndex); |
|---|
| 220 | if (att.isString()) { |
|---|
| 221 | // means that there were no Value elements defined in the |
|---|
| 222 | // data dictionary (and hence the mining schema). |
|---|
| 223 | // We add our value here. |
|---|
| 224 | att.addStringValue(m_valueName); |
|---|
| 225 | } |
|---|
| 226 | m_valueIndex = att.indexOfValue(m_valueName); |
|---|
| 227 | /* for (int i = 0; i < att.numValues(); i++) { |
|---|
| 228 | if (att.value(i).equals(m_valueName)) { |
|---|
| 229 | m_valueIndex = i; |
|---|
| 230 | } |
|---|
| 231 | }*/ |
|---|
| 232 | |
|---|
| 233 | if (m_valueIndex == -1) { |
|---|
| 234 | throw new Exception("[CategoricalPredictor] unable to find value " |
|---|
| 235 | + m_valueName + " in mining schema attribute " |
|---|
| 236 | + att.name()); |
|---|
| 237 | } |
|---|
| 238 | } |
|---|
| 239 | |
|---|
| 240 | /** |
|---|
| 241 | * Return a textual description of this predictor. |
|---|
| 242 | */ |
|---|
| 243 | public String toString() { |
|---|
| 244 | String output = super.toString(); |
|---|
| 245 | output += m_name + "=" + m_valueName; |
|---|
| 246 | return output; |
|---|
| 247 | } |
|---|
| 248 | |
|---|
| 249 | /** |
|---|
| 250 | * Adds this predictor into the sum for the |
|---|
| 251 | * current prediction. |
|---|
| 252 | * |
|---|
| 253 | * @param preds the prediction computed so far. For regression, it is a |
|---|
| 254 | * single element array; for classification it is a multi-element array |
|---|
| 255 | * @param input the input instance's values |
|---|
| 256 | */ |
|---|
| 257 | public void add(double[] preds, double[] input) { |
|---|
| 258 | |
|---|
| 259 | // if the value is equal to the one in the input then add the coefficient |
|---|
| 260 | if (m_valueIndex == (int)input[m_miningSchemaAttIndex]) { |
|---|
| 261 | if (m_targetCategory == -1) { |
|---|
| 262 | preds[0] += m_coefficient; |
|---|
| 263 | } else { |
|---|
| 264 | preds[m_targetCategory] += m_coefficient; |
|---|
| 265 | } |
|---|
| 266 | } |
|---|
| 267 | } |
|---|
| 268 | } |
|---|
| 269 | |
|---|
| 270 | /** |
|---|
| 271 | * Inner class to handle PredictorTerms. |
|---|
| 272 | */ |
|---|
| 273 | protected class PredictorTerm implements Serializable { |
|---|
| 274 | |
|---|
| 275 | /** For serialization */ |
|---|
| 276 | private static final long serialVersionUID = 5493100145890252757L; |
|---|
| 277 | |
|---|
| 278 | /** The coefficient for this predictor term */ |
|---|
| 279 | protected double m_coefficient = 1.0; |
|---|
| 280 | |
|---|
| 281 | /** the indexes of the terms to be multiplied */ |
|---|
| 282 | protected int[] m_indexes; |
|---|
| 283 | |
|---|
| 284 | /** The names of the terms (attributes) to be multiplied */ |
|---|
| 285 | protected String[] m_fieldNames; |
|---|
| 286 | |
|---|
| 287 | /** |
|---|
| 288 | * Construct a new PredictorTerm. |
|---|
| 289 | * |
|---|
| 290 | * @param predictorTerm the <code>Element</code> describing the predictor term |
|---|
| 291 | * @param miningSchema the mining schema as an Instances object |
|---|
| 292 | * @throws Exception if something goes wrong while constructing this |
|---|
| 293 | * predictor term |
|---|
| 294 | */ |
|---|
| 295 | protected PredictorTerm(Element predictorTerm, |
|---|
| 296 | Instances miningSchema) throws Exception { |
|---|
| 297 | |
|---|
| 298 | String coeff = predictorTerm.getAttribute("coefficient"); |
|---|
| 299 | if (coeff != null && coeff.length() > 0) { |
|---|
| 300 | try { |
|---|
| 301 | m_coefficient = Double.parseDouble(coeff); |
|---|
| 302 | } catch (IllegalArgumentException ex) { |
|---|
| 303 | throw new Exception("[PredictorTerm] unable to parse coefficient"); |
|---|
| 304 | } |
|---|
| 305 | } |
|---|
| 306 | |
|---|
| 307 | NodeList fields = predictorTerm.getElementsByTagName("FieldRef"); |
|---|
| 308 | if (fields.getLength() > 0) { |
|---|
| 309 | m_indexes = new int[fields.getLength()]; |
|---|
| 310 | m_fieldNames = new String[fields.getLength()]; |
|---|
| 311 | |
|---|
| 312 | for (int i = 0; i < fields.getLength(); i++) { |
|---|
| 313 | Node fieldRef = fields.item(i); |
|---|
| 314 | if (fieldRef.getNodeType() == Node.ELEMENT_NODE) { |
|---|
| 315 | String fieldName = ((Element)fieldRef).getAttribute("field"); |
|---|
| 316 | if (fieldName != null && fieldName.length() > 0) { |
|---|
| 317 | boolean found = false; |
|---|
| 318 | // look for this field in the mining schema |
|---|
| 319 | for (int j = 0; j < miningSchema.numAttributes(); j++) { |
|---|
| 320 | if (miningSchema.attribute(j).name().equals(fieldName)) { |
|---|
| 321 | |
|---|
| 322 | // all referenced fields MUST be numeric |
|---|
| 323 | if (!miningSchema.attribute(j).isNumeric()) { |
|---|
| 324 | throw new Exception("[PredictorTerm] field is not continuous: " |
|---|
| 325 | + fieldName); |
|---|
| 326 | } |
|---|
| 327 | found = true; |
|---|
| 328 | m_indexes[i] = j; |
|---|
| 329 | m_fieldNames[i] = fieldName; |
|---|
| 330 | break; |
|---|
| 331 | } |
|---|
| 332 | } |
|---|
| 333 | if (!found) { |
|---|
| 334 | throw new Exception("[PredictorTerm] Unable to find field " |
|---|
| 335 | + fieldName + " in mining schema!"); |
|---|
| 336 | } |
|---|
| 337 | } |
|---|
| 338 | } |
|---|
| 339 | } |
|---|
| 340 | } |
|---|
| 341 | } |
|---|
| 342 | |
|---|
| 343 | /** |
|---|
| 344 | * Return a textual description of this predictor term. |
|---|
| 345 | */ |
|---|
| 346 | public String toString() { |
|---|
| 347 | StringBuffer result = new StringBuffer(); |
|---|
| 348 | result.append("(" + Utils.doubleToString(m_coefficient, 12, 4)); |
|---|
| 349 | for (int i = 0; i < m_fieldNames.length; i++) { |
|---|
| 350 | result.append(" * " + m_fieldNames[i]); |
|---|
| 351 | } |
|---|
| 352 | result.append(")"); |
|---|
| 353 | return result.toString(); |
|---|
| 354 | } |
|---|
| 355 | |
|---|
| 356 | /** |
|---|
| 357 | * Adds this predictor term into the sum for the |
|---|
| 358 | * current prediction. |
|---|
| 359 | * |
|---|
| 360 | * @param preds the prediction computed so far. For regression, it is a |
|---|
| 361 | * single element array; for classification it is a multi-element array |
|---|
| 362 | * @param input the input instance's values |
|---|
| 363 | */ |
|---|
| 364 | public void add(double[] preds, double[] input) { |
|---|
| 365 | int indx = 0; |
|---|
| 366 | if (m_targetCategory != -1) { |
|---|
| 367 | indx = m_targetCategory; |
|---|
| 368 | } |
|---|
| 369 | |
|---|
| 370 | double result = m_coefficient; |
|---|
| 371 | for (int i = 0; i < m_indexes.length; i++) { |
|---|
| 372 | result *= input[m_indexes[i]]; |
|---|
| 373 | } |
|---|
| 374 | preds[indx] += result; |
|---|
| 375 | } |
|---|
| 376 | } |
|---|
| 377 | |
|---|
| 378 | /** Constant for regression model type */ |
|---|
| 379 | public static final int REGRESSION = 0; |
|---|
| 380 | |
|---|
| 381 | /** Constant for classification model type */ |
|---|
| 382 | public static final int CLASSIFICATION = 1; |
|---|
| 383 | |
|---|
| 384 | /** The type of function - regression or classification */ |
|---|
| 385 | protected int m_functionType = REGRESSION; |
|---|
| 386 | |
|---|
| 387 | /** The mining schema */ |
|---|
| 388 | protected MiningSchema m_miningSchema; |
|---|
| 389 | |
|---|
| 390 | /** The intercept */ |
|---|
| 391 | protected double m_intercept = 0.0; |
|---|
| 392 | |
|---|
| 393 | /** classification only */ |
|---|
| 394 | protected int m_targetCategory = -1; |
|---|
| 395 | |
|---|
| 396 | /** Numeric and categorical predictors */ |
|---|
| 397 | protected ArrayList<Predictor> m_predictors = |
|---|
| 398 | new ArrayList<Predictor>(); |
|---|
| 399 | |
|---|
| 400 | /** Interaction terms */ |
|---|
| 401 | protected ArrayList<PredictorTerm> m_predictorTerms = |
|---|
| 402 | new ArrayList<PredictorTerm>(); |
|---|
| 403 | |
|---|
| 404 | /** |
|---|
| 405 | * Return a textual description of this RegressionTable. |
|---|
| 406 | */ |
|---|
| 407 | public String toString() { |
|---|
| 408 | Instances miningSchema = m_miningSchema.getFieldsAsInstances(); |
|---|
| 409 | StringBuffer temp = new StringBuffer(); |
|---|
| 410 | temp.append("Regression table:\n"); |
|---|
| 411 | temp.append(miningSchema.classAttribute().name()); |
|---|
| 412 | if (m_functionType == CLASSIFICATION) { |
|---|
| 413 | temp.append("=" + miningSchema. |
|---|
| 414 | classAttribute().value(m_targetCategory)); |
|---|
| 415 | } |
|---|
| 416 | |
|---|
| 417 | temp.append(" =\n\n"); |
|---|
| 418 | |
|---|
| 419 | // do the predictors |
|---|
| 420 | for (int i = 0; i < m_predictors.size(); i++) { |
|---|
| 421 | temp.append(m_predictors.get(i).toString() + " +\n"); |
|---|
| 422 | } |
|---|
| 423 | |
|---|
| 424 | // do the predictor terms |
|---|
| 425 | for (int i = 0; i < m_predictorTerms.size(); i++) { |
|---|
| 426 | temp.append(m_predictorTerms.get(i).toString() + " +\n"); |
|---|
| 427 | } |
|---|
| 428 | |
|---|
| 429 | temp.append(Utils.doubleToString(m_intercept, 12, 4)); |
|---|
| 430 | temp.append("\n\n"); |
|---|
| 431 | |
|---|
| 432 | return temp.toString(); |
|---|
| 433 | } |
|---|
| 434 | |
|---|
| 435 | /** |
|---|
| 436 | * Construct a regression table from an <code>Element</code> |
|---|
| 437 | * |
|---|
| 438 | * @param table the table to encapsulate |
|---|
| 439 | * @param functionType the type of function |
|---|
| 440 | * (regression or classification) |
|---|
| 441 | * to use |
|---|
| 442 | * @param mSchema the mining schema |
|---|
| 443 | * @throws Exception if there is a problem while constructing |
|---|
| 444 | * this regression table |
|---|
| 445 | */ |
|---|
| 446 | protected RegressionTable(Element table, |
|---|
| 447 | int functionType, |
|---|
| 448 | MiningSchema mSchema) throws Exception { |
|---|
| 449 | |
|---|
| 450 | m_miningSchema = mSchema; |
|---|
| 451 | m_functionType = functionType; |
|---|
| 452 | |
|---|
| 453 | Instances miningSchema = m_miningSchema.getFieldsAsInstances(); |
|---|
| 454 | |
|---|
| 455 | // get the intercept |
|---|
| 456 | String intercept = table.getAttribute("intercept"); |
|---|
| 457 | if (intercept.length() > 0) { |
|---|
| 458 | m_intercept = Double.parseDouble(intercept); |
|---|
| 459 | } |
|---|
| 460 | |
|---|
| 461 | // get the target category (if classification) |
|---|
| 462 | if (m_functionType == CLASSIFICATION) { |
|---|
| 463 | // target category MUST be defined |
|---|
| 464 | String targetCat = table.getAttribute("targetCategory"); |
|---|
| 465 | if (targetCat.length() > 0) { |
|---|
| 466 | Attribute classA = miningSchema.classAttribute(); |
|---|
| 467 | for (int i = 0; i < classA.numValues(); i++) { |
|---|
| 468 | if (classA.value(i).equals(targetCat)) { |
|---|
| 469 | m_targetCategory = i; |
|---|
| 470 | } |
|---|
| 471 | } |
|---|
| 472 | } |
|---|
| 473 | if (m_targetCategory == -1) { |
|---|
| 474 | throw new Exception("[RegressionTable] No target categories defined for classification"); |
|---|
| 475 | } |
|---|
| 476 | } |
|---|
| 477 | |
|---|
| 478 | // read all the numeric predictors |
|---|
| 479 | NodeList numericPs = table.getElementsByTagName("NumericPredictor"); |
|---|
| 480 | for (int i = 0; i < numericPs.getLength(); i++) { |
|---|
| 481 | Node nP = numericPs.item(i); |
|---|
| 482 | if (nP.getNodeType() == Node.ELEMENT_NODE) { |
|---|
| 483 | NumericPredictor numP = new NumericPredictor((Element)nP, miningSchema); |
|---|
| 484 | m_predictors.add(numP); |
|---|
| 485 | } |
|---|
| 486 | } |
|---|
| 487 | |
|---|
| 488 | // read all the categorical predictors |
|---|
| 489 | NodeList categoricalPs = table.getElementsByTagName("CategoricalPredictor"); |
|---|
| 490 | for (int i = 0; i < categoricalPs.getLength(); i++) { |
|---|
| 491 | Node cP = categoricalPs.item(i); |
|---|
| 492 | if (cP.getNodeType() == Node.ELEMENT_NODE) { |
|---|
| 493 | CategoricalPredictor catP = new CategoricalPredictor((Element)cP, miningSchema); |
|---|
| 494 | m_predictors.add(catP); |
|---|
| 495 | } |
|---|
| 496 | } |
|---|
| 497 | |
|---|
| 498 | // read all the PredictorTerms |
|---|
| 499 | NodeList predictorTerms = table.getElementsByTagName("PredictorTerm"); |
|---|
| 500 | for (int i = 0; i < predictorTerms.getLength(); i++) { |
|---|
| 501 | Node pT = predictorTerms.item(i); |
|---|
| 502 | PredictorTerm predT = new PredictorTerm((Element)pT, miningSchema); |
|---|
| 503 | m_predictorTerms.add(predT); |
|---|
| 504 | } |
|---|
| 505 | } |
|---|
| 506 | |
|---|
| 507 | public void predict(double[] preds, double[] input) { |
|---|
| 508 | if (m_targetCategory == -1) { |
|---|
| 509 | preds[0] = m_intercept; |
|---|
| 510 | } else { |
|---|
| 511 | preds[m_targetCategory] = m_intercept; |
|---|
| 512 | } |
|---|
| 513 | |
|---|
| 514 | // add the predictors |
|---|
| 515 | for (int i = 0; i < m_predictors.size(); i++) { |
|---|
| 516 | Predictor p = m_predictors.get(i); |
|---|
| 517 | p.add(preds, input); |
|---|
| 518 | } |
|---|
| 519 | |
|---|
| 520 | // add the PredictorTerms |
|---|
| 521 | for (int i = 0; i < m_predictorTerms.size(); i++) { |
|---|
| 522 | PredictorTerm pt = m_predictorTerms.get(i); |
|---|
| 523 | pt.add(preds, input); |
|---|
| 524 | } |
|---|
| 525 | } |
|---|
| 526 | } |
|---|
| 527 | |
|---|
| 528 | /** Description of the algorithm */ |
|---|
| 529 | protected String m_algorithmName; |
|---|
| 530 | |
|---|
| 531 | /** The regression tables for this regression */ |
|---|
| 532 | protected RegressionTable[] m_regressionTables; |
|---|
| 533 | |
|---|
| 534 | /** |
|---|
| 535 | * Enum for the normalization methods. |
|---|
| 536 | */ |
|---|
| 537 | enum Normalization { |
|---|
| 538 | NONE, SIMPLEMAX, SOFTMAX, LOGIT, PROBIT, CLOGLOG, |
|---|
| 539 | EXP, LOGLOG, CAUCHIT} |
|---|
| 540 | |
|---|
| 541 | /** The normalization to use */ |
|---|
| 542 | protected Normalization m_normalizationMethod = Normalization.NONE; |
|---|
| 543 | |
|---|
| 544 | /** |
|---|
| 545 | * Constructs a new PMML Regression. |
|---|
| 546 | * |
|---|
| 547 | * @param model the <code>Element</code> containing the regression model |
|---|
| 548 | * @param dataDictionary the data dictionary as an Instances object |
|---|
| 549 | * @param miningSchema the mining schema |
|---|
| 550 | * @throws Exception if there is a problem constructing this Regression |
|---|
| 551 | */ |
|---|
| 552 | public Regression(Element model, Instances dataDictionary, |
|---|
| 553 | MiningSchema miningSchema) throws Exception { |
|---|
| 554 | super(dataDictionary, miningSchema); |
|---|
| 555 | |
|---|
| 556 | int functionType = RegressionTable.REGRESSION; |
|---|
| 557 | |
|---|
| 558 | // determine function name first |
|---|
| 559 | String fName = model.getAttribute("functionName"); |
|---|
| 560 | |
|---|
| 561 | if (fName.equals("regression")) { |
|---|
| 562 | functionType = RegressionTable.REGRESSION; |
|---|
| 563 | } else if (fName.equals("classification")) { |
|---|
| 564 | functionType = RegressionTable.CLASSIFICATION; |
|---|
| 565 | } else { |
|---|
| 566 | throw new Exception("[PMML Regression] Function name not defined in pmml!"); |
|---|
| 567 | } |
|---|
| 568 | |
|---|
| 569 | // do we have an algorithm name? |
|---|
| 570 | String algName = model.getAttribute("algorithmName"); |
|---|
| 571 | if (algName != null && algName.length() > 0) { |
|---|
| 572 | m_algorithmName = algName; |
|---|
| 573 | } |
|---|
| 574 | |
|---|
| 575 | // determine normalization method (if any) |
|---|
| 576 | m_normalizationMethod = determineNormalization(model); |
|---|
| 577 | |
|---|
| 578 | setUpRegressionTables(model, functionType); |
|---|
| 579 | |
|---|
| 580 | // convert any string attributes in the mining schema |
|---|
| 581 | //miningSchema.convertStringAttsToNominal(); |
|---|
| 582 | } |
|---|
| 583 | |
|---|
| 584 | /** |
|---|
| 585 | * Create all the RegressionTables for this model. |
|---|
| 586 | * |
|---|
| 587 | * @param model the <code>Element</code> holding this regression model |
|---|
| 588 | * @param functionType the type of function (regression or |
|---|
| 589 | * classification) |
|---|
| 590 | * @throws Exception if there is a problem setting up the regression |
|---|
| 591 | * tables |
|---|
| 592 | */ |
|---|
| 593 | private void setUpRegressionTables(Element model, |
|---|
| 594 | int functionType) throws Exception { |
|---|
| 595 | NodeList tableList = model.getElementsByTagName("RegressionTable"); |
|---|
| 596 | |
|---|
| 597 | if (tableList.getLength() == 0) { |
|---|
| 598 | throw new Exception("[Regression] no regression tables defined!"); |
|---|
| 599 | } |
|---|
| 600 | |
|---|
| 601 | m_regressionTables = new RegressionTable[tableList.getLength()]; |
|---|
| 602 | |
|---|
| 603 | for (int i = 0; i < tableList.getLength(); i++) { |
|---|
| 604 | Node table = tableList.item(i); |
|---|
| 605 | if (table.getNodeType() == Node.ELEMENT_NODE) { |
|---|
| 606 | RegressionTable tempRTable = |
|---|
| 607 | new RegressionTable((Element)table, |
|---|
| 608 | functionType, |
|---|
| 609 | m_miningSchema); |
|---|
| 610 | m_regressionTables[i] = tempRTable; |
|---|
| 611 | } |
|---|
| 612 | } |
|---|
| 613 | } |
|---|
| 614 | |
|---|
| 615 | /** |
|---|
| 616 | * Return the type of normalization used for this regression |
|---|
| 617 | * |
|---|
| 618 | * @param model the <code>Element</code> holding the model |
|---|
| 619 | * @return the normalization used in this regression |
|---|
| 620 | */ |
|---|
| 621 | private static Normalization determineNormalization(Element model) { |
|---|
| 622 | |
|---|
| 623 | Normalization normMethod = Normalization.NONE; |
|---|
| 624 | |
|---|
| 625 | String normName = model.getAttribute("normalizationMethod"); |
|---|
| 626 | if (normName.equals("simplemax")) { |
|---|
| 627 | normMethod = Normalization.SIMPLEMAX; |
|---|
| 628 | } else if (normName.equals("softmax")) { |
|---|
| 629 | normMethod = Normalization.SOFTMAX; |
|---|
| 630 | } else if (normName.equals("logit")) { |
|---|
| 631 | normMethod = Normalization.LOGIT; |
|---|
| 632 | } else if (normName.equals("probit")) { |
|---|
| 633 | normMethod = Normalization.PROBIT; |
|---|
| 634 | } else if (normName.equals("cloglog")) { |
|---|
| 635 | normMethod = Normalization.CLOGLOG; |
|---|
| 636 | } else if (normName.equals("exp")) { |
|---|
| 637 | normMethod = Normalization.EXP; |
|---|
| 638 | } else if (normName.equals("loglog")) { |
|---|
| 639 | normMethod = Normalization.LOGLOG; |
|---|
| 640 | } else if (normName.equals("cauchit")) { |
|---|
| 641 | normMethod = Normalization.CAUCHIT; |
|---|
| 642 | } |
|---|
| 643 | return normMethod; |
|---|
| 644 | } |
|---|
| 645 | |
|---|
| 646 | /** |
|---|
| 647 | * Return a textual description of this Regression model. |
|---|
| 648 | */ |
|---|
| 649 | public String toString() { |
|---|
| 650 | StringBuffer temp = new StringBuffer(); |
|---|
| 651 | temp.append("PMML version " + getPMMLVersion()); |
|---|
| 652 | if (!getCreatorApplication().equals("?")) { |
|---|
| 653 | temp.append("\nApplication: " + getCreatorApplication()); |
|---|
| 654 | } |
|---|
| 655 | if (m_algorithmName != null) { |
|---|
| 656 | temp.append("\nPMML Model: " + m_algorithmName); |
|---|
| 657 | } |
|---|
| 658 | temp.append("\n\n"); |
|---|
| 659 | temp.append(m_miningSchema); |
|---|
| 660 | |
|---|
| 661 | for (RegressionTable table : m_regressionTables) { |
|---|
| 662 | temp.append(table); |
|---|
| 663 | } |
|---|
| 664 | |
|---|
| 665 | if (m_normalizationMethod != Normalization.NONE) { |
|---|
| 666 | temp.append("Normalization: " + m_normalizationMethod); |
|---|
| 667 | } |
|---|
| 668 | temp.append("\n"); |
|---|
| 669 | |
|---|
| 670 | return temp.toString(); |
|---|
| 671 | } |
|---|
| 672 | |
|---|
| 673 | /** |
|---|
| 674 | * Classifies the given test instance. The instance has to belong to a |
|---|
| 675 | * dataset when it's being classified. |
|---|
| 676 | * |
|---|
| 677 | * @param inst the instance to be classified |
|---|
| 678 | * @return the predicted most likely class for the instance or |
|---|
| 679 | * Utils.missingValue() if no prediction is made |
|---|
| 680 | * @exception Exception if an error occurred during the prediction |
|---|
| 681 | */ |
|---|
| 682 | public double[] distributionForInstance(Instance inst) throws Exception { |
|---|
| 683 | if (!m_initialized) { |
|---|
| 684 | mapToMiningSchema(inst.dataset()); |
|---|
| 685 | } |
|---|
| 686 | double[] preds = null; |
|---|
| 687 | if (m_miningSchema.getFieldsAsInstances().classAttribute().isNumeric()) { |
|---|
| 688 | preds = new double[1]; |
|---|
| 689 | } else { |
|---|
| 690 | preds = new double[m_miningSchema.getFieldsAsInstances().classAttribute().numValues()]; |
|---|
| 691 | } |
|---|
| 692 | |
|---|
| 693 | // create an array of doubles that holds values from the incoming |
|---|
| 694 | // instance; in order of the fields in the mining schema. We will |
|---|
| 695 | // also handle missing values and outliers here. |
|---|
| 696 | // System.err.println(inst); |
|---|
| 697 | double[] incoming = m_fieldsMap.instanceToSchema(inst, m_miningSchema); |
|---|
| 698 | |
|---|
| 699 | // scan for missing values. If there are still missing values after instanceToSchema(), |
|---|
| 700 | // then missing value handling has been deferred to the PMML scheme. The specification |
|---|
| 701 | // (Regression PMML 3.2) seems to contradict itself with regards to classification and categorical |
|---|
| 702 | // variables. In one place it states that if a categorical variable is missing then |
|---|
| 703 | // variable_name=value is 0 for any value. Further down in the document it states: "if |
|---|
| 704 | // one or more of the y_j cannot be evaluated because the value in one of the referenced |
|---|
| 705 | // fields is missing, then the following formulas (for computing p_j) do not apply. In |
|---|
| 706 | // that case the predictions are defined by the priorProbability values in the Target |
|---|
| 707 | // element". |
|---|
| 708 | |
|---|
| 709 | // In this implementation we will default to information in the Target element (default |
|---|
| 710 | // value for numeric prediction and prior probabilities for classification). If there is |
|---|
| 711 | // no Target element defined, then an Exception is thrown. |
|---|
| 712 | |
|---|
| 713 | boolean hasMissing = false; |
|---|
| 714 | for (int i = 0; i < incoming.length; i++) { |
|---|
| 715 | if (i != m_miningSchema.getFieldsAsInstances().classIndex() && |
|---|
| 716 | Utils.isMissingValue(incoming[i])) { |
|---|
| 717 | hasMissing = true; |
|---|
| 718 | break; |
|---|
| 719 | } |
|---|
| 720 | } |
|---|
| 721 | |
|---|
| 722 | if (hasMissing) { |
|---|
| 723 | if (!m_miningSchema.hasTargetMetaData()) { |
|---|
| 724 | String message = "[Regression] WARNING: Instance to predict has missing value(s) but " |
|---|
| 725 | + "there is no missing value handling meta data and no " |
|---|
| 726 | + "prior probabilities/default value to fall back to. No " |
|---|
| 727 | + "prediction will be made (" |
|---|
| 728 | + ((m_miningSchema.getFieldsAsInstances().classAttribute().isNominal() || |
|---|
| 729 | m_miningSchema.getFieldsAsInstances().classAttribute().isString()) |
|---|
| 730 | ? "zero probabilities output)." |
|---|
| 731 | : "NaN output)."); |
|---|
| 732 | if (m_log == null) { |
|---|
| 733 | System.err.println(message); |
|---|
| 734 | } else { |
|---|
| 735 | m_log.logMessage(message); |
|---|
| 736 | } |
|---|
| 737 | if (m_miningSchema.getFieldsAsInstances().classAttribute().isNumeric()) { |
|---|
| 738 | preds[0] = Utils.missingValue(); |
|---|
| 739 | } |
|---|
| 740 | return preds; |
|---|
| 741 | } else { |
|---|
| 742 | // use prior probabilities/default value |
|---|
| 743 | TargetMetaInfo targetData = m_miningSchema.getTargetMetaData(); |
|---|
| 744 | if (m_miningSchema.getFieldsAsInstances().classAttribute().isNumeric()) { |
|---|
| 745 | preds[0] = targetData.getDefaultValue(); |
|---|
| 746 | } else { |
|---|
| 747 | Instances miningSchemaI = m_miningSchema.getFieldsAsInstances(); |
|---|
| 748 | for (int i = 0; i < miningSchemaI.classAttribute().numValues(); i++) { |
|---|
| 749 | preds[i] = targetData.getPriorProbability(miningSchemaI.classAttribute().value(i)); |
|---|
| 750 | } |
|---|
| 751 | } |
|---|
| 752 | return preds; |
|---|
| 753 | } |
|---|
| 754 | } else { |
|---|
| 755 | // loop through the RegressionTables |
|---|
| 756 | for (int i = 0; i < m_regressionTables.length; i++) { |
|---|
| 757 | m_regressionTables[i].predict(preds, incoming); |
|---|
| 758 | } |
|---|
| 759 | |
|---|
| 760 | // Now apply the normalization |
|---|
| 761 | switch (m_normalizationMethod) { |
|---|
| 762 | case NONE: |
|---|
| 763 | // nothing to be done |
|---|
| 764 | break; |
|---|
| 765 | case SIMPLEMAX: |
|---|
| 766 | Utils.normalize(preds); |
|---|
| 767 | break; |
|---|
| 768 | case SOFTMAX: |
|---|
| 769 | for (int i = 0; i < preds.length; i++) { |
|---|
| 770 | preds[i] = Math.exp(preds[i]); |
|---|
| 771 | } |
|---|
| 772 | if (preds.length == 1) { |
|---|
| 773 | // hack for those models that do binary logistic regression as |
|---|
| 774 | // a numeric prediction model |
|---|
| 775 | preds[0] = preds[0] / (preds[0] + 1.0); |
|---|
| 776 | } else { |
|---|
| 777 | Utils.normalize(preds); |
|---|
| 778 | } |
|---|
| 779 | break; |
|---|
| 780 | case LOGIT: |
|---|
| 781 | for (int i = 0; i < preds.length; i++) { |
|---|
| 782 | preds[i] = 1.0 / (1.0 + Math.exp(-preds[i])); |
|---|
| 783 | } |
|---|
| 784 | Utils.normalize(preds); |
|---|
| 785 | break; |
|---|
| 786 | case PROBIT: |
|---|
| 787 | for (int i = 0; i < preds.length; i++) { |
|---|
| 788 | preds[i] = weka.core.matrix.Maths.pnorm(preds[i]); |
|---|
| 789 | } |
|---|
| 790 | Utils.normalize(preds); |
|---|
| 791 | break; |
|---|
| 792 | case CLOGLOG: |
|---|
| 793 | // note this is supposed to be illegal for regression |
|---|
| 794 | for (int i = 0; i < preds.length; i++) { |
|---|
| 795 | preds[i] = 1.0 - Math.exp(-Math.exp(-preds[i])); |
|---|
| 796 | } |
|---|
| 797 | Utils.normalize(preds); |
|---|
| 798 | break; |
|---|
| 799 | case EXP: |
|---|
| 800 | for (int i = 0; i < preds.length; i++) { |
|---|
| 801 | preds[i] = Math.exp(preds[i]); |
|---|
| 802 | } |
|---|
| 803 | Utils.normalize(preds); |
|---|
| 804 | break; |
|---|
| 805 | case LOGLOG: |
|---|
| 806 | // note this is supposed to be illegal for regression |
|---|
| 807 | for (int i = 0; i < preds.length; i++) { |
|---|
| 808 | preds[i] = Math.exp(-Math.exp(-preds[i])); |
|---|
| 809 | } |
|---|
| 810 | Utils.normalize(preds); |
|---|
| 811 | break; |
|---|
| 812 | case CAUCHIT: |
|---|
| 813 | for (int i = 0; i < preds.length; i++) { |
|---|
| 814 | preds[i] = 0.5 + (1.0 / Math.PI) * Math.atan(preds[i]); |
|---|
| 815 | } |
|---|
| 816 | Utils.normalize(preds); |
|---|
| 817 | break; |
|---|
| 818 | default: |
|---|
| 819 | throw new Exception("[Regression] unknown normalization method"); |
|---|
| 820 | } |
|---|
| 821 | |
|---|
| 822 | // If there is a Target defined, and this is a numeric prediction problem, |
|---|
| 823 | // then apply any min, max, rescaling etc. |
|---|
| 824 | if (m_miningSchema.getFieldsAsInstances().classAttribute().isNumeric() |
|---|
| 825 | && m_miningSchema.hasTargetMetaData()) { |
|---|
| 826 | TargetMetaInfo targetData = m_miningSchema.getTargetMetaData(); |
|---|
| 827 | preds[0] = targetData.applyMinMaxRescaleCast(preds[0]); |
|---|
| 828 | } |
|---|
| 829 | } |
|---|
| 830 | |
|---|
| 831 | return preds; |
|---|
| 832 | } |
|---|
| 833 | |
|---|
| 834 | /* (non-Javadoc) |
|---|
| 835 | * @see weka.core.RevisionHandler#getRevision() |
|---|
| 836 | */ |
|---|
| 837 | public String getRevision() { |
|---|
| 838 | return RevisionUtils.extract("$Revision: 6018 $"); |
|---|
| 839 | } |
|---|
| 840 | } |
|---|