source: src/main/java/weka/classifiers/pmml/consumer/Regression.java @ 11

Last change on this file since 11 was 4, checked in by gnappo, 14 years ago

Import di weka.

File size: 29.0 KB
Line 
1/*
2 *    This program is free software; you can redistribute it and/or modify
3 *    it under the terms of the GNU General Public License as published by
4 *    the Free Software Foundation; either version 2 of the License, or
5 *    (at your option) any later version.
6 *
7 *    This program is distributed in the hope that it will be useful,
8 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
9 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
10 *    GNU General Public License for more details.
11 *
12 *    You should have received a copy of the GNU General Public License
13 *    along with this program; if not, write to the Free Software
14 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
15 */
16
17/*
18 *    Regression.java
19 *    Copyright (C) 2008 University of Waikato, Hamilton, New Zealand
20 *
21 */
22
23package weka.classifiers.pmml.consumer;
24
25import java.io.Serializable;
26import java.util.ArrayList;
27import org.w3c.dom.Element;
28import org.w3c.dom.Node;
29import org.w3c.dom.NodeList;
30
31import weka.core.Attribute;
32import weka.core.Instance;
33import weka.core.Instances;
34import weka.core.RevisionUtils;
35import weka.core.Utils;
36import weka.core.pmml.*;
37
38/**
39 * Class implementing import of PMML Regression model. Can be
40 * used as a Weka classifier for prediction (buildClassifier()
41 * raises an Exception).
42 *
43 * @author Mark Hall (mhall{[at]}pentaho{[dot]}com
44 * @version $Revision: 6018 $
45 */
46public class Regression extends PMMLClassifier
47  implements Serializable {
48
49  /** For serialization */
50  private static final long serialVersionUID = -5551125528409488634L;
51
52  /**
53   * Inner class for encapsulating a regression table
54   */
55  static class RegressionTable implements Serializable {
56
57    /** For serialization */
58    private static final long serialVersionUID = -5259866093996338995L;
59
60    /**
61     * Abstract inner base class for different predictor types.
62     */
63    abstract static class Predictor implements Serializable {
64
65      /** For serialization */
66      private static final long serialVersionUID = 7043831847273383618L;
67     
68      /** Name of this predictor */
69      protected String m_name;
70     
71      /**
72       * Index of the attribute in the mining schema that corresponds to this
73       * predictor
74       */
75      protected int m_miningSchemaAttIndex = -1;
76     
77      /** Coefficient for this predictor */
78      protected double m_coefficient = 1.0;
79     
80      /**
81       * Constructs a new Predictor.
82       *
83       * @param predictor the <code>Element</code> encapsulating this predictor
84       * @param miningSchema the mining schema as an Instances object
85       * @throws Exception if there is a problem constructing this Predictor
86       */
87      protected Predictor(Element predictor, Instances miningSchema) throws Exception {
88        m_name = predictor.getAttribute("name");
89        for (int i = 0; i < miningSchema.numAttributes(); i++) {
90          Attribute temp = miningSchema.attribute(i);
91          if (temp.name().equals(m_name)) {
92            m_miningSchemaAttIndex = i;
93          }
94        }
95       
96        if (m_miningSchemaAttIndex == -1) {
97          throw new Exception("[Predictor] unable to find matching attribute for "
98                              + "predictor " + m_name);
99        }
100
101        String coeff = predictor.getAttribute("coefficient");
102        if (coeff.length() > 0) {
103          m_coefficient = Double.parseDouble(coeff);
104        }
105      }
106
107      /**
108       * Returns a textual description of this predictor applicable
109       * to all sub classes.
110       */
111      public String toString() {
112        return Utils.doubleToString(m_coefficient, 12, 4) + " * ";
113      }
114
115      /**
116       * Abstract add method. Adds this predictor into the sum for the
117       * current prediction.
118       *
119       * @param preds the prediction computed so far. For regression, it is a
120       * single element array; for classification it is a multi-element array
121       * @param input the input instance's values
122       */
123      public abstract void add(double[] preds, double[] input);
124    }
125
126    /**
127     * Inner class for a numeric predictor
128     */
129    protected class NumericPredictor extends Predictor {
130      /**
131       * For serialization
132       */
133      private static final long serialVersionUID = -4335075205696648273L;
134     
135      /** The exponent*/
136      protected double m_exponent = 1.0;
137
138      /**
139       * Constructs a NumericPredictor.
140       *
141       * @param predictor the <code>Element</code> holding the predictor
142       * @param miningSchema the mining schema as an Instances object
143       * @throws Exception if something goes wrong while constructing this
144       * predictor
145       */
146      protected NumericPredictor(Element predictor, 
147                              Instances miningSchema) throws Exception {
148        super(predictor, miningSchema);
149       
150        String exponent = predictor.getAttribute("exponent");
151        if (exponent.length() > 0) {
152          m_exponent = Double.parseDouble(exponent);
153        }
154      }
155
156      /**
157       * Return a textual description of this predictor.
158       */
159      public String toString() {
160        String output = super.toString();
161        output += m_name;
162        if (m_exponent > 1.0 || m_exponent < 1.0) {
163          output += "^" + Utils.doubleToString(m_exponent, 4);
164        }
165        return output;
166      }
167
168      /**
169       * Adds this predictor into the sum for the
170       * current prediction.
171       *
172       * @param preds the prediction computed so far. For regression, it is a
173       * single element array; for classification it is a multi-element array
174       * @param input the input instance's values
175       */
176      public void add(double[] preds, double[] input) {
177        if (m_targetCategory == -1) {
178          preds[0] += m_coefficient * Math.pow(input[m_miningSchemaAttIndex], m_exponent);
179        } else {
180          preds[m_targetCategory] += 
181            m_coefficient * Math.pow(input[m_miningSchemaAttIndex], m_exponent);
182        }
183      }
184    }
185
186    /**
187     * Inner class encapsulating a categorical predictor.
188     */
189    protected class CategoricalPredictor extends Predictor {
190     
191      /**For serialization */
192      private static final long serialVersionUID = 3077920125549906819L;
193     
194      /** The attribute value for this predictor */
195      protected String m_valueName;
196     
197      /** The index of the attribute value for this predictor */
198      protected int m_valueIndex = -1;
199
200      /**
201       * Constructs a CategoricalPredictor.
202       *
203       * @param predictor the <code>Element</code> containing the predictor
204       * @param miningSchema the mining schema as an Instances object
205       * @throws Exception if something goes wrong while constructing
206       * this predictor
207       */
208      protected CategoricalPredictor(Element predictor,
209                                  Instances miningSchema) throws Exception {
210        super(predictor, miningSchema);
211       
212        String valName = predictor.getAttribute("value");
213        if (valName.length() == 0) {
214          throw new Exception("[CategoricalPredictor] attribute value not specified!");
215        }
216       
217        m_valueName = valName;
218
219        Attribute att = miningSchema.attribute(m_miningSchemaAttIndex);
220        if (att.isString()) {
221          // means that there were no Value elements defined in the
222          // data dictionary (and hence the mining schema).
223          // We add our value here.
224          att.addStringValue(m_valueName);
225        }
226        m_valueIndex = att.indexOfValue(m_valueName);
227        /*        for (int i = 0; i < att.numValues(); i++) {
228          if (att.value(i).equals(m_valueName)) {
229            m_valueIndex = i;
230          }
231          }*/
232
233        if (m_valueIndex == -1) {
234          throw new Exception("[CategoricalPredictor] unable to find value "
235                              + m_valueName + " in mining schema attribute "
236                              + att.name());
237        }
238      }
239
240      /**
241       * Return a textual description of this predictor.
242       */
243      public String toString() {
244        String output = super.toString();
245        output += m_name + "=" + m_valueName;
246        return output;
247      }
248
249      /**
250       * Adds this predictor into the sum for the
251       * current prediction.
252       *
253       * @param preds the prediction computed so far. For regression, it is a
254       * single element array; for classification it is a multi-element array
255       * @param input the input instance's values
256       */
257      public void add(double[] preds, double[] input) {
258       
259        // if the value is equal to the one in the input then add the coefficient
260        if (m_valueIndex == (int)input[m_miningSchemaAttIndex]) {
261          if (m_targetCategory == -1) {
262            preds[0] += m_coefficient;
263          } else {
264            preds[m_targetCategory] += m_coefficient;
265          }
266        }
267      }
268    }
269
270    /**
271     * Inner class to handle PredictorTerms.
272     */
273    protected class PredictorTerm implements Serializable {
274
275      /** For serialization */
276      private static final long serialVersionUID = 5493100145890252757L;
277
278      /** The coefficient for this predictor term */
279      protected double m_coefficient = 1.0;
280
281      /** the indexes of the terms to be multiplied */
282      protected int[] m_indexes;
283
284      /** The names of the terms (attributes) to be multiplied */
285      protected String[] m_fieldNames;
286
287      /**
288       * Construct a new PredictorTerm.
289       *
290       * @param predictorTerm the <code>Element</code> describing the predictor term
291       * @param miningSchema the mining schema as an Instances object
292       * @throws Exception if something goes wrong while constructing this
293       * predictor term
294       */
295      protected PredictorTerm(Element predictorTerm, 
296                              Instances miningSchema) throws Exception {
297
298        String coeff = predictorTerm.getAttribute("coefficient");
299        if (coeff != null && coeff.length() > 0) {
300          try {
301            m_coefficient = Double.parseDouble(coeff);
302          } catch (IllegalArgumentException ex) {
303            throw new Exception("[PredictorTerm] unable to parse coefficient");
304          }
305        }
306       
307        NodeList fields = predictorTerm.getElementsByTagName("FieldRef");
308        if (fields.getLength() > 0) {
309          m_indexes = new int[fields.getLength()];
310          m_fieldNames = new String[fields.getLength()];
311
312          for (int i = 0; i < fields.getLength(); i++) {
313            Node fieldRef = fields.item(i);
314            if (fieldRef.getNodeType() == Node.ELEMENT_NODE) {
315              String fieldName = ((Element)fieldRef).getAttribute("field");
316              if (fieldName != null && fieldName.length() > 0) {
317                boolean found = false;
318                // look for this field in the mining schema
319                for (int j = 0; j < miningSchema.numAttributes(); j++) {
320                  if (miningSchema.attribute(j).name().equals(fieldName)) {
321                   
322                    // all referenced fields MUST be numeric
323                    if (!miningSchema.attribute(j).isNumeric()) {
324                      throw new Exception("[PredictorTerm] field is not continuous: "
325                                          + fieldName);
326                    }
327                    found = true;
328                    m_indexes[i] = j;
329                    m_fieldNames[i] = fieldName;
330                    break;
331                  }
332                }
333                if (!found) {
334                  throw new Exception("[PredictorTerm] Unable to find field "
335                                      + fieldName + " in mining schema!");
336                }
337              }
338            }
339          }
340        }
341      }
342
343      /**
344       * Return a textual description of this predictor term.
345       */
346      public String toString() {
347        StringBuffer result = new StringBuffer();
348        result.append("(" + Utils.doubleToString(m_coefficient, 12, 4));
349        for (int i = 0; i < m_fieldNames.length; i++) {
350          result.append(" * " + m_fieldNames[i]);
351        }
352        result.append(")");
353        return result.toString();
354      }
355
356      /**
357       * Adds this predictor term into the sum for the
358       * current prediction.
359       *
360       * @param preds the prediction computed so far. For regression, it is a
361       * single element array; for classification it is a multi-element array
362       * @param input the input instance's values
363       */
364      public void add(double[] preds, double[] input) {
365        int indx = 0;
366        if (m_targetCategory != -1) {
367          indx = m_targetCategory;
368        }
369
370        double result = m_coefficient;
371        for (int i = 0; i < m_indexes.length; i++) {
372          result *= input[m_indexes[i]];
373        }
374        preds[indx] += result;
375      }
376    }
377   
378    /** Constant for regression model type */
379    public static final int REGRESSION = 0;
380   
381    /** Constant for classification model type */
382    public static final int CLASSIFICATION = 1;
383
384    /** The type of function - regression or classification */
385    protected int m_functionType = REGRESSION;
386   
387    /** The mining schema */
388    protected MiningSchema m_miningSchema;
389       
390    /** The intercept */
391    protected double m_intercept = 0.0;
392   
393    /** classification only */
394    protected int m_targetCategory = -1;
395
396    /** Numeric and categorical predictors */
397    protected ArrayList<Predictor> m_predictors = 
398      new ArrayList<Predictor>();
399
400    /** Interaction terms */
401    protected ArrayList<PredictorTerm> m_predictorTerms =
402      new ArrayList<PredictorTerm>();
403
404    /**
405     * Return a textual description of this RegressionTable.
406     */
407    public String toString() {
408      Instances miningSchema = m_miningSchema.getFieldsAsInstances();
409      StringBuffer temp = new StringBuffer();
410      temp.append("Regression table:\n");
411      temp.append(miningSchema.classAttribute().name());
412      if (m_functionType == CLASSIFICATION) {
413        temp.append("=" + miningSchema.
414                    classAttribute().value(m_targetCategory));
415      }
416
417      temp.append(" =\n\n");
418     
419      // do the predictors
420      for (int i = 0; i < m_predictors.size(); i++) {
421        temp.append(m_predictors.get(i).toString() + " +\n");
422      }
423     
424      // do the predictor terms
425      for (int i = 0; i < m_predictorTerms.size(); i++) {
426        temp.append(m_predictorTerms.get(i).toString() + " +\n");
427      }
428
429      temp.append(Utils.doubleToString(m_intercept, 12, 4));
430      temp.append("\n\n");
431
432      return temp.toString();
433    }
434
435    /**
436     * Construct a regression table from an <code>Element</code>
437     *
438     * @param table the table to encapsulate
439     * @param functionType the type of function
440     * (regression or classification)
441     * to use
442     * @param mSchema the mining schema
443     * @throws Exception if there is a problem while constructing
444     * this regression table
445     */
446    protected RegressionTable(Element table, 
447                           int functionType,
448                           MiningSchema mSchema) throws Exception {
449
450      m_miningSchema = mSchema;
451      m_functionType = functionType;
452
453      Instances miningSchema = m_miningSchema.getFieldsAsInstances();
454
455      // get the intercept
456      String intercept = table.getAttribute("intercept");
457      if (intercept.length() > 0) {
458        m_intercept = Double.parseDouble(intercept);
459      }
460
461      // get the target category (if classification)
462      if (m_functionType == CLASSIFICATION) {
463        // target category MUST be defined
464        String targetCat = table.getAttribute("targetCategory");
465        if (targetCat.length() > 0) {
466          Attribute classA = miningSchema.classAttribute();
467          for (int i = 0; i < classA.numValues(); i++) {
468            if (classA.value(i).equals(targetCat)) {
469              m_targetCategory = i;
470            }
471          }
472        } 
473        if (m_targetCategory == -1) {
474          throw new Exception("[RegressionTable] No target categories defined for classification");
475        }
476      }
477
478      // read all the numeric predictors
479      NodeList numericPs = table.getElementsByTagName("NumericPredictor");
480      for (int i = 0; i < numericPs.getLength(); i++) {
481        Node nP = numericPs.item(i);
482        if (nP.getNodeType() == Node.ELEMENT_NODE) {
483          NumericPredictor numP = new NumericPredictor((Element)nP, miningSchema);
484          m_predictors.add(numP);
485        }
486      }
487
488      // read all the categorical predictors
489      NodeList categoricalPs = table.getElementsByTagName("CategoricalPredictor");
490      for (int i = 0; i < categoricalPs.getLength(); i++) {
491        Node cP = categoricalPs.item(i);
492        if (cP.getNodeType() == Node.ELEMENT_NODE) {
493          CategoricalPredictor catP = new CategoricalPredictor((Element)cP, miningSchema);
494          m_predictors.add(catP);
495        }
496      }
497
498      // read all the PredictorTerms
499      NodeList predictorTerms = table.getElementsByTagName("PredictorTerm");
500      for (int i = 0; i < predictorTerms.getLength(); i++) {
501        Node pT = predictorTerms.item(i);
502        PredictorTerm predT = new PredictorTerm((Element)pT, miningSchema);
503        m_predictorTerms.add(predT);
504      }
505    }
506
507    public void predict(double[] preds, double[] input) {
508      if (m_targetCategory == -1) {
509        preds[0] = m_intercept;
510      } else {
511        preds[m_targetCategory] = m_intercept;
512      }
513     
514      // add the predictors
515      for (int i = 0; i < m_predictors.size(); i++) {
516        Predictor p = m_predictors.get(i);
517        p.add(preds, input);
518      }
519
520      // add the PredictorTerms
521      for (int i = 0; i < m_predictorTerms.size(); i++) {
522        PredictorTerm pt = m_predictorTerms.get(i);
523        pt.add(preds, input);
524      }
525    }
526  }
527
528  /** Description of the algorithm */
529  protected String m_algorithmName;
530
531  /** The regression tables for this regression */
532  protected RegressionTable[] m_regressionTables;
533
534  /**
535   * Enum for the normalization methods.
536   */
537  enum Normalization {
538    NONE, SIMPLEMAX, SOFTMAX, LOGIT, PROBIT, CLOGLOG,
539      EXP, LOGLOG, CAUCHIT}
540
541  /** The normalization to use */
542  protected Normalization m_normalizationMethod = Normalization.NONE;
543
544  /**
545   * Constructs a new PMML Regression.
546   *
547   * @param model the <code>Element</code> containing the regression model
548   * @param dataDictionary the data dictionary as an Instances object
549   * @param miningSchema the mining schema
550   * @throws Exception if there is a problem constructing this Regression
551   */
552  public Regression(Element model, Instances dataDictionary,
553                    MiningSchema miningSchema) throws Exception {
554    super(dataDictionary, miningSchema);
555   
556    int functionType = RegressionTable.REGRESSION;
557
558    // determine function name first
559    String fName = model.getAttribute("functionName");
560   
561    if (fName.equals("regression")) {
562      functionType = RegressionTable.REGRESSION;
563    } else if (fName.equals("classification")) {
564      functionType = RegressionTable.CLASSIFICATION;
565    } else {
566      throw new Exception("[PMML Regression] Function name not defined in pmml!");
567    }
568
569    // do we have an algorithm name?
570    String algName = model.getAttribute("algorithmName");
571    if (algName != null && algName.length() > 0) {
572      m_algorithmName = algName;
573    }
574
575    // determine normalization method (if any)
576    m_normalizationMethod = determineNormalization(model);
577
578    setUpRegressionTables(model, functionType);
579
580    // convert any string attributes in the mining schema
581    //miningSchema.convertStringAttsToNominal();
582  }
583
584  /**
585   * Create all the RegressionTables for this model.
586   *
587   * @param model the <code>Element</code> holding this regression model
588   * @param functionType the type of function (regression or
589   * classification)
590   * @throws Exception if there is a problem setting up the regression
591   * tables
592   */
593  private void setUpRegressionTables(Element model,
594                                     int functionType) throws Exception {
595    NodeList tableList = model.getElementsByTagName("RegressionTable");
596   
597    if (tableList.getLength() == 0) {
598      throw new Exception("[Regression] no regression tables defined!");
599    }
600
601    m_regressionTables = new RegressionTable[tableList.getLength()];
602   
603    for (int i = 0; i < tableList.getLength(); i++) {
604      Node table = tableList.item(i);
605      if (table.getNodeType() == Node.ELEMENT_NODE) {
606        RegressionTable tempRTable = 
607          new RegressionTable((Element)table, 
608                              functionType, 
609                              m_miningSchema);
610        m_regressionTables[i] = tempRTable;
611      }
612    }
613  }
614
615  /**
616   * Return the type of normalization used for this regression
617   *
618   * @param model the <code>Element</code> holding the model
619   * @return the normalization used in this regression
620   */
621  private static Normalization determineNormalization(Element model) {
622   
623    Normalization normMethod = Normalization.NONE;
624
625    String normName = model.getAttribute("normalizationMethod");
626    if (normName.equals("simplemax")) {
627      normMethod = Normalization.SIMPLEMAX;
628    } else if (normName.equals("softmax")) {
629      normMethod = Normalization.SOFTMAX;
630    } else if (normName.equals("logit")) {
631      normMethod = Normalization.LOGIT;
632    } else if (normName.equals("probit")) {
633      normMethod = Normalization.PROBIT;
634    } else if (normName.equals("cloglog")) {
635      normMethod = Normalization.CLOGLOG;
636    } else if (normName.equals("exp")) {
637      normMethod = Normalization.EXP;
638    } else if (normName.equals("loglog")) {
639      normMethod = Normalization.LOGLOG;
640    } else if (normName.equals("cauchit")) {
641      normMethod = Normalization.CAUCHIT;
642    } 
643    return normMethod;
644  }
645
646  /**
647   * Return a textual description of this Regression model.
648   */
649  public String toString() {
650    StringBuffer temp = new StringBuffer();
651    temp.append("PMML version " + getPMMLVersion());
652    if (!getCreatorApplication().equals("?")) {
653      temp.append("\nApplication: " + getCreatorApplication());
654    }
655    if (m_algorithmName != null) {
656      temp.append("\nPMML Model: " + m_algorithmName);
657    }
658    temp.append("\n\n");
659    temp.append(m_miningSchema);
660
661    for (RegressionTable table : m_regressionTables) {
662      temp.append(table);
663    }
664   
665    if (m_normalizationMethod != Normalization.NONE) {
666      temp.append("Normalization: " + m_normalizationMethod);
667    }
668    temp.append("\n");
669
670    return temp.toString();
671  }
672
673  /**                                                                                                             
674   * Classifies the given test instance. The instance has to belong to a                                         
675   * dataset when it's being classified.                                                         
676   *                                                                                                             
677   * @param inst the instance to be classified                                                               
678   * @return the predicted most likely class for the instance or                                                 
679   * Utils.missingValue() if no prediction is made                                                             
680   * @exception Exception if an error occurred during the prediction                                             
681   */
682  public double[] distributionForInstance(Instance inst) throws Exception {
683    if (!m_initialized) {
684      mapToMiningSchema(inst.dataset());
685    }
686    double[] preds = null;
687    if (m_miningSchema.getFieldsAsInstances().classAttribute().isNumeric()) {
688      preds = new double[1];
689    } else {
690      preds = new double[m_miningSchema.getFieldsAsInstances().classAttribute().numValues()];
691    }
692
693    // create an array of doubles that holds values from the incoming
694    // instance; in order of the fields in the mining schema. We will
695    // also handle missing values and outliers here.
696    //    System.err.println(inst);
697    double[] incoming = m_fieldsMap.instanceToSchema(inst, m_miningSchema);
698
699    // scan for missing values. If there are still missing values after instanceToSchema(),
700    // then missing value handling has been deferred to the PMML scheme. The specification
701    // (Regression PMML 3.2) seems to contradict itself with regards to classification and categorical
702    // variables. In one place it states that if a categorical variable is missing then
703    // variable_name=value is 0 for any value. Further down in the document it states: "if
704    // one or more of the y_j cannot be evaluated because the value in one of the referenced
705    // fields is missing, then the following formulas (for computing p_j) do not apply. In
706    // that case the predictions are defined by the priorProbability values in the Target
707    // element".
708
709    // In this implementation we will default to information in the Target element (default
710    // value for numeric prediction and prior probabilities for classification). If there is
711    // no Target element defined, then an Exception is thrown.
712
713    boolean hasMissing = false;
714    for (int i = 0; i < incoming.length; i++) {
715      if (i != m_miningSchema.getFieldsAsInstances().classIndex() && 
716          Utils.isMissingValue(incoming[i])) {
717        hasMissing = true;
718        break;
719      }
720    }
721
722    if (hasMissing) {
723      if (!m_miningSchema.hasTargetMetaData()) {
724        String message = "[Regression] WARNING: Instance to predict has missing value(s) but "
725          + "there is no missing value handling meta data and no "
726          + "prior probabilities/default value to fall back to. No "
727          + "prediction will be made (" 
728          + ((m_miningSchema.getFieldsAsInstances().classAttribute().isNominal() ||
729              m_miningSchema.getFieldsAsInstances().classAttribute().isString())
730              ? "zero probabilities output)."
731              : "NaN output).");
732        if (m_log == null) {
733          System.err.println(message);
734        } else {
735          m_log.logMessage(message);
736        }
737        if (m_miningSchema.getFieldsAsInstances().classAttribute().isNumeric()) {
738          preds[0] = Utils.missingValue();
739        }
740        return preds;
741      } else {
742        // use prior probablilities/default value
743        TargetMetaInfo targetData = m_miningSchema.getTargetMetaData();
744        if (m_miningSchema.getFieldsAsInstances().classAttribute().isNumeric()) {
745          preds[0] = targetData.getDefaultValue();
746        } else {
747          Instances miningSchemaI = m_miningSchema.getFieldsAsInstances();
748          for (int i = 0; i < miningSchemaI.classAttribute().numValues(); i++) {
749            preds[i] = targetData.getPriorProbability(miningSchemaI.classAttribute().value(i));
750          }
751        }
752        return preds;
753      }
754    } else {
755      // loop through the RegressionTables
756      for (int i = 0; i < m_regressionTables.length; i++) {
757        m_regressionTables[i].predict(preds, incoming);
758      }
759 
760      // Now apply the normalization
761      switch (m_normalizationMethod) {
762      case NONE:
763        // nothing to be done
764        break;
765      case SIMPLEMAX:
766        Utils.normalize(preds);
767        break;
768      case SOFTMAX:
769        for (int i = 0; i < preds.length; i++) {
770          preds[i] = Math.exp(preds[i]);
771        }
772        if (preds.length == 1) {
773          // hack for those models that do binary logistic regression as
774          // a numeric prediction model
775          preds[0] = preds[0] / (preds[0] + 1.0);
776        } else {
777          Utils.normalize(preds);
778        }
779        break;
780      case LOGIT:
781        for (int i = 0; i < preds.length; i++) {
782          preds[i] = 1.0 / (1.0 + Math.exp(-preds[i]));
783        }
784        Utils.normalize(preds);
785        break;
786      case PROBIT:
787        for (int i = 0; i < preds.length; i++) {
788          preds[i] = weka.core.matrix.Maths.pnorm(preds[i]);
789        }
790        Utils.normalize(preds);
791        break;
792      case CLOGLOG:
793        // note this is supposed to be illegal for regression
794        for (int i = 0; i < preds.length; i++) {
795          preds[i] = 1.0 - Math.exp(-Math.exp(-preds[i]));
796        }
797        Utils.normalize(preds);
798        break;
799      case EXP:
800        for (int i = 0; i < preds.length; i++) {
801          preds[i] = Math.exp(preds[i]);
802        }
803        Utils.normalize(preds);
804        break;
805      case LOGLOG:
806        // note this is supposed to be illegal for regression
807        for (int i = 0; i < preds.length; i++) {
808          preds[i] = Math.exp(-Math.exp(-preds[i]));
809        }
810        Utils.normalize(preds);
811        break;
812      case CAUCHIT:
813        for (int i = 0; i < preds.length; i++) {
814          preds[i] = 0.5 + (1.0 / Math.PI) * Math.atan(preds[i]);
815        }
816        Utils.normalize(preds);
817        break;
818      default:
819          throw new Exception("[Regression] unknown normalization method");
820      }
821
822      // If there is a Target defined, and this is a numeric prediction problem,
823      // then apply any min, max, rescaling etc.
824      if (m_miningSchema.getFieldsAsInstances().classAttribute().isNumeric()
825          && m_miningSchema.hasTargetMetaData()) {
826        TargetMetaInfo targetData = m_miningSchema.getTargetMetaData();
827        preds[0] = targetData.applyMinMaxRescaleCast(preds[0]);
828      }
829    }
830   
831    return preds;
832  }
833
834  /* (non-Javadoc)
835   * @see weka.core.RevisionHandler#getRevision()
836   */
837  public String getRevision() {
838    return RevisionUtils.extract("$Revision: 6018 $");
839  }
840}
Note: See TracBrowser for help on using the repository browser.