source: src/main/java/weka/classifiers/pmml/consumer/GeneralRegression.java @ 26

Last change on this file since 26 was 4, checked in by gnappo, 14 years ago

Import di weka.

File size: 51.7 KB
Line 
1/*
2 *    This program is free software; you can redistribute it and/or modify
3 *    it under the terms of the GNU General Public License as published by
4 *    the Free Software Foundation; either version 2 of the License, or
5 *    (at your option) any later version.
6 *
7 *    This program is distributed in the hope that it will be useful,
8 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
9 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
10 *    GNU General Public License for more details.
11 *
12 *    You should have received a copy of the GNU General Public License
13 *    along with this program; if not, write to the Free Software
14 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
15 */
16
17/*
18 *    GeneralRegression.java
19 *    Copyright (C) 2008 University of Waikato, Hamilton, New Zealand
20 *
21 */
22
23package weka.classifiers.pmml.consumer;
24
25import java.io.Serializable;
26import java.util.ArrayList;
27import org.w3c.dom.Element;
28import org.w3c.dom.Node;
29import org.w3c.dom.NodeList;
30
31import weka.core.Attribute;
32import weka.core.Instance;
33import weka.core.Instances;
34import weka.core.RevisionUtils;
35import weka.core.Utils;
36import weka.core.pmml.*;
37
38/**
39 * Class implementing import of PMML General Regression model. Can be
40 * used as a Weka classifier for prediction (buildClassifier()
41 * raises an Exception).
42 *
43 * @author Mark Hall (mhall{[at]}pentaho{[dot]}com)
44 * @version $Revision: 5987 $
45 */
46public class GeneralRegression extends PMMLClassifier
47  implements Serializable {
48
49  /**
50   * For serialization
51   */
52  private static final long serialVersionUID = 2583880411828388959L;
53
54  /**
55   * Enumerated type for the model type.
56   */
57  enum ModelType {
58
59    // same type of model
60    REGRESSION ("regression"), 
61      GENERALLINEAR ("generalLinear"), 
62      MULTINOMIALLOGISTIC ("multinomialLogistic"),
63      ORDINALMULTINOMIAL ("ordinalMultinomial"), 
64      GENERALIZEDLINEAR ("generalizedLinear");
65
66    private final String m_stringVal;
67    ModelType(String name) {
68      m_stringVal = name;
69    }
70   
71    public String toString() {
72      return m_stringVal;
73    }
74  }
75 
76  // the model type
77  protected ModelType m_modelType = ModelType.REGRESSION;
78
79  // the model name (if defined)
80  protected String m_modelName;
81   
82  // the algorithm name (if defined)
83  protected String m_algorithmName;
84
85  // the function type (regression or classification)
86  protected int m_functionType = Regression.RegressionTable.REGRESSION;
87
88  /**
89   * Enumerated type for the cumulative link function
90   * (ordinal multinomial model type only).
91   */
92  enum CumulativeLinkFunction {
93    NONE ("none") {
94      double eval(double value, double offset) {
95        return Double.NaN; // no evaluation defined in this case!
96      }
97    },
98    LOGIT ("logit") {
99      double eval(double value, double offset) {
100        return 1.0 / (1.0 + Math.exp(-(value + offset)));
101      }
102    },
103    PROBIT ("probit") {
104      double eval(double value, double offset) {
105        return weka.core.matrix.Maths.pnorm(value + offset); 
106      }
107    },
108    CLOGLOG ("cloglog") {
109      double eval(double value, double offset) {
110        return 1.0 - Math.exp(-Math.exp(value + offset));
111      }
112    },
113    LOGLOG ("loglog") {
114      double eval(double value, double offset) {
115        return Math.exp(-Math.exp(-(value + offset))); 
116      }
117    },
118    CAUCHIT ("cauchit") {
119      double eval(double value, double offset) {
120        return 0.5 + (1.0 / Math.PI) * Math.atan(value + offset);
121      }
122    };
123
124    /**
125     * Evaluation function.
126     *
127     * @param value the raw response value
128     * @param offset the offset to add to the raw value
129     * @return the result of the link function
130     */
131    abstract double eval(double value, double offset);
132   
133    private final String m_stringVal;
134   
135    /**
136     * Constructor
137     *
138     * @param name textual name for this enum
139     */
140    CumulativeLinkFunction(String name) {
141      m_stringVal = name;
142    }
143   
144    /* (non-Javadoc)
145     * @see java.lang.Enum#toString()
146     */
147    public String toString() {
148      return m_stringVal;
149    }
150  }
151 
152  // cumulative link function (ordinal multinomial only)
153  protected CumulativeLinkFunction m_cumulativeLinkFunction
154    = CumulativeLinkFunction.NONE;
155
156
157  /**
158   * Enumerated type for the link function (general linear and
159   * generalized linear model types only).
160   */
161  enum LinkFunction {
162    NONE ("none") {
163      double eval(double value, double offset, double trials,
164                  double distParam, double linkParam) {
165        return Double.NaN; // no evaluation defined in this case!
166      }
167    },
168    CLOGLOG ("cloglog") {
169      double eval(double value, double offset, double trials,
170                  double distParam, double linkParam) {
171        return (1.0 - Math.exp(-Math.exp(value + offset))) * trials;
172      }
173    },
174    IDENTITY ("identity") {
175      double eval(double value, double offset, double trials,
176                  double distParam, double linkParam) {
177        return (value + offset) * trials;
178      }
179    },
180    LOG ("log") {
181      double eval(double value, double offset, double trials,
182                  double distParam, double linkParam) {
183        return Math.exp(value + offset) * trials;
184      }
185    },
186    LOGC ("logc") {
187      double eval(double value, double offset, double trials,
188                  double distParam, double linkParam) {
189        return (1.0 - Math.exp(value + offset)) * trials;
190      }
191    },
192    LOGIT ("logit") {
193      double eval(double value, double offset, double trials,
194                  double distParam, double linkParam) {
195        return (1.0 / (1.0 + Math.exp(-(value + offset)))) * trials;
196      }
197    },
198    LOGLOG ("loglog") {
199      double eval(double value, double offset, double trials,
200                  double distParam, double linkParam) {
201        return Math.exp(-Math.exp(-(value + offset))) * trials;
202      }
203    },
204    NEGBIN ("negbin") {
205      double eval(double value, double offset, double trials,
206                  double distParam, double linkParam) {
207        return (1.0 / (distParam * (Math.exp(-(value + offset)) - 1.0))) * trials;
208      }
209    },
210    ODDSPOWER ("oddspower") {
211      double eval(double value, double offset, double trials,
212                  double distParam, double linkParam) {
213        return (linkParam < 0.0 || linkParam > 0.0)
214        ? (1.0 / (1.0 + Math.pow(1.0 + linkParam * (value + offset), (-1.0 / linkParam)))) * trials
215        : (1.0 / (1.0 + Math.exp(-(value + offset)))) * trials;
216      }
217    },
218    POWER ("power") {
219      double eval(double value, double offset, double trials,
220                  double distParam, double linkParam) {
221        return (linkParam < 0.0 || linkParam > 0.0)
222        ? Math.pow(value + offset, (1.0 / linkParam)) * trials
223            : Math.exp(value + offset) * trials;
224      }
225    },
226    PROBIT ("probit") {
227      double eval(double value, double offset, double trials,
228                  double distParam, double linkParam) {
229        return weka.core.matrix.Maths.pnorm(value + offset) * trials;
230      }
231    };
232
233    /**
234     * Evaluation function.
235     *
236     * @param value the raw response value
237     * @param offset the offset to add to the raw value
238     * @param trials the trials value to multiply the result by
239     * @param distParam the distribution parameter (negbin only)
240     * @param linkParam the link parameter (power and oddspower only)
241     * @return the result of the link function
242     */
243    abstract double eval(double value, double offset, double trials, 
244                         double distParam, double linkParam);
245   
246    private final String m_stringVal;
247   
248    /**
249     * Constructor.
250     *
251     * @param name the textual name of this link function
252     */
253    LinkFunction(String name) {
254      m_stringVal = name;
255    }
256
257    /* (non-Javadoc)
258     * @see java.lang.Enum#toString()
259     */
260    public String toString() {
261      return m_stringVal;
262    }
263  }
264 
265  // link function (generalLinear model type only)
266  protected LinkFunction m_linkFunction = LinkFunction.NONE;
267  protected double m_linkParameter = Double.NaN;
268  protected String m_trialsVariable;
269  protected double m_trialsValue = Double.NaN;
270
271  /**
272   * Enumerated type for the distribution (general linear
273   * and generalized linear model types only).
274   */
275  enum Distribution {
276    NONE ("none"),
277    NORMAL ("normal"),
278    BINOMIAL ("binomial"),
279    GAMMA ("gamma"),
280    INVGAUSSIAN ("igauss"),
281    NEGBINOMIAL ("negbin"),
282    POISSON ("poisson");
283
284    private final String m_stringVal;
285    Distribution(String name) {
286      m_stringVal = name;
287    }
288
289    /* (non-Javadoc)
290     * @see java.lang.Enum#toString()
291     */
292    public String toString() {
293      return m_stringVal;
294    }
295  }
296 
297  // generalLinear and generalizedLinear model type only
298  protected Distribution m_distribution = Distribution.NORMAL;
299
300  // ancillary parameter value for the negative binomial distribution
301  protected double m_distParameter = Double.NaN;
302
303  // if present, this variable is used during scoring generalizedLinear/generalLinear or
304  // ordinalMultinomial models
305  protected String m_offsetVariable;
306
307  // if present, this variable is used during scoring generalizedLinear/generalLinear or
308  // ordinalMultinomial models. It works like a user-specified intercept.
309  // At most, only one of offsetVariable or offsetValue may be specified.
310  protected double m_offsetValue = Double.NaN;
311
312  /**
313   * Small inner class to hold the name of a parameter plus
314   * its optional descriptive label
315   */
316  static class Parameter implements Serializable {
317    // ESCA-JAVA0096:
318    /** For serialization */
319    // CHECK ME WITH serialver
320    private static final long serialVersionUID = 6502780192411755341L;
321
322    protected String m_name = null;
323    protected String m_label = null;
324  }
325
326  // List of model parameters
327  protected ArrayList<Parameter> m_parameterList = new ArrayList<Parameter>();
328
329  /**
330   * Small inner class to hold the name of a factor or covariate,
331   * plus the index of the attribute it corresponds to in the
332   * mining schema.
333   */
334  static class Predictor implements Serializable {
335    /** For serialization */
336    // CHECK ME WITH serialver
337    private static final long serialVersionUID = 6502780192411755341L;
338
339    protected String m_name = null;
340    protected int m_miningSchemaIndex = -1;
341   
342    public String toString() {
343      return m_name;
344    }
345  }
346 
347  // FactorList
348  protected ArrayList<Predictor> m_factorList = new ArrayList<Predictor>();
349
350  // CovariateList
351  protected ArrayList<Predictor> m_covariateList = new ArrayList<Predictor>();
352
353  /**
354   * Small inner class to hold details on a predictor-to-parameter
355   * correlation.
356   */
357  static class PPCell implements Serializable {
358    /** For serialization */
359    // CHECK ME WITH serialver
360    private static final long serialVersionUID = 6502780192411755341L;
361   
362    protected String m_predictorName = null;
363    protected String m_parameterName = null;
364
365    // either the exponent of a numeric attribute or the index of
366    // a discrete value
367    protected double m_value = 0;
368
369    // optional. The default is for all target categories to
370    // share the same PPMatrix.
371    // TO-DO: implement multiple PPMatrixes
372    protected String m_targetCategory = null;
373   
374  }
375 
376  // PPMatrix (predictor-to-parameter matrix)
377  // rows = parameters, columns = predictors (attributes)
378  protected PPCell[][] m_ppMatrix;
379
380  /**
381   * Small inner class to hold a single entry in the
382   * ParamMatrix (parameter matrix).
383   */
384  static class PCell implements Serializable {
385   
386    /** For serialization */
387    // CHECK ME WITH serialver
388    private static final long serialVersionUID = 6502780192411755341L;
389
390    // may be null for numeric target. May also be null if this coefficent
391    // applies to all target categories.
392    protected String m_targetCategory = null;
393    protected String m_parameterName = null;
394    // coefficient
395    protected double m_beta = 0.0;
396    // optional degrees of freedom
397    protected int m_df = -1;
398  }
399 
400  // ParamMatrix. rows = target categories (only one if target is numeric),
401  // columns = parameters (in order that they occur in the parameter list).
402  protected PCell[][] m_paramMatrix;
403
404  /**
405   * Constructs a GeneralRegression classifier.
406   *
407   * @param model the Element that holds the model definition
408   * @param dataDictionary the data dictionary as a set of Instances
409   * @param miningSchema the mining schema
410   * @throws Exception if there is a problem constructing the general regression
411   * object from the PMML.
412   */
413  public GeneralRegression(Element model, Instances dataDictionary,
414                           MiningSchema miningSchema) throws Exception {
415
416    super(dataDictionary, miningSchema);
417 
418    // get the model type
419    String mType = model.getAttribute("modelType");
420    boolean found = false;
421    for (ModelType m : ModelType.values()) {
422      if (m.toString().equals(mType)) {
423        m_modelType = m;
424        found = true;
425        break;
426      }     
427    }
428    if (!found) {
429      throw new Exception("[GeneralRegression] unknown model type: " + mType);
430    }
431
432    if (m_modelType == ModelType.ORDINALMULTINOMIAL) {
433      // get the cumulative link function
434      String cLink = model.getAttribute("cumulativeLink");
435      found = false;
436      for (CumulativeLinkFunction c : CumulativeLinkFunction.values()) {
437        if (c.toString().equals(cLink)) {
438          m_cumulativeLinkFunction = c;
439          found = true;
440          break;
441        }
442      }
443      if (!found) {
444        throw new Exception("[GeneralRegression] cumulative link function " + cLink);
445      }
446    } else if (m_modelType == ModelType.GENERALIZEDLINEAR || 
447                m_modelType == ModelType.GENERALLINEAR) {
448      // get the link function
449      String link = model.getAttribute("linkFunction");
450      found = false;
451      for (LinkFunction l : LinkFunction.values()) {
452        if (l.toString().equals(link)) {
453          m_linkFunction = l;
454          found = true;
455          break;
456        }
457      }
458      if (!found) {
459        throw new Exception("[GeneralRegression] unknown link function " + link);
460      }
461
462      // get the link parameter
463      String linkP = model.getAttribute("linkParameter");
464      if (linkP != null && linkP.length() > 0) {
465        try {
466          m_linkParameter = Double.parseDouble(linkP);
467        } catch (IllegalArgumentException ex) {
468          throw new Exception("[GeneralRegression] unable to parse the link parameter");
469        }
470      }
471
472      // get the trials variable
473      String trials = model.getAttribute("trialsVariable");
474      if (trials != null && trials.length() > 0) {
475        m_trialsVariable = trials;
476      }
477
478      // get the trials value
479      String trialsV = model.getAttribute("trialsValue");
480      if (trialsV != null && trialsV.length() > 0) {
481        try {
482          m_trialsValue = Double.parseDouble(trialsV);
483        } catch (IllegalArgumentException ex) {
484          throw new Exception("[GeneralRegression] unable to parse the trials value"); 
485        }
486      }
487    }
488 
489    String mName = model.getAttribute("modelName");
490    if (mName != null && mName.length() > 0) {
491      m_modelName = mName;
492    }
493
494    String fName = model.getAttribute("functionName");
495    if (fName.equals("classification")) {
496      m_functionType = Regression.RegressionTable.CLASSIFICATION;
497    }
498
499    String algName = model.getAttribute("algorithmName");
500    if (algName != null && algName.length() > 0) {
501      m_algorithmName = algName;
502    }
503
504    String distribution = model.getAttribute("distribution");
505    if (distribution != null && distribution.length() > 0) {
506      found = false;
507      for (Distribution d : Distribution.values()) {
508        if (d.toString().equals(distribution)) {
509          m_distribution = d;
510          found = true;
511          break;
512        }
513      }
514      if (!found) {
515        throw new Exception("[GeneralRegression] unknown distribution type " + distribution);
516      }
517    }
518
519    String distP = model.getAttribute("distParameter");
520    if (distP != null && distP.length() > 0) {
521      try {
522        m_distParameter = Double.parseDouble(distP);
523      } catch (IllegalArgumentException ex) {
524        throw new Exception("[GeneralRegression] unable to parse the distribution parameter");
525      }
526    }
527
528    String offsetV = model.getAttribute("offsetVariable");
529    if (offsetV != null && offsetV.length() > 0) {
530       m_offsetVariable = offsetV;
531    }
532
533    String offsetVal = model.getAttribute("offsetValue");
534    if (offsetVal != null && offsetVal.length() > 0) {
535      try {
536        m_offsetValue = Double.parseDouble(offsetVal);
537      } catch (IllegalArgumentException ex) {
538        throw new Exception("[GeneralRegression] unable to parse the offset value");
539      }
540    }
541
542    // get the parameter list
543    readParameterList(model);
544   
545    // get the factors and covariates
546    readFactorsAndCovariates(model, "FactorList");
547    readFactorsAndCovariates(model, "CovariateList");
548
549    // read the PPMatrix
550    readPPMatrix(model);
551
552    // read the parameter estimates
553    readParamMatrix(model);
554  }
555
556  /**
557   * Read the list of parameters.
558   *
559   * @param model the Element that contains the model
560   * @throws Exception if there is some problem with extracting the
561   * parameters.
562   */
563  protected void readParameterList(Element model) throws Exception {
564    NodeList paramL = model.getElementsByTagName("ParameterList");
565
566    // should be just one parameter list
567    if (paramL.getLength() == 1) {
568      Node paramN = paramL.item(0);
569      if (paramN.getNodeType() == Node.ELEMENT_NODE) {
570        NodeList parameterList = ((Element)paramN).getElementsByTagName("Parameter");
571        for (int i = 0; i < parameterList.getLength(); i++) {
572          Node parameter = parameterList.item(i);
573          if (parameter.getNodeType() == Node.ELEMENT_NODE) {
574            Parameter p = new Parameter();
575            p.m_name = ((Element)parameter).getAttribute("name");
576            String label = ((Element)parameter).getAttribute("label");
577            if (label != null && label.length() > 0) {
578              p.m_label = label;
579            }
580            m_parameterList.add(p);
581          }
582        }
583      }
584    } else {
585      throw new Exception("[GeneralRegression] more than one parameter list!");
586    }
587  }
588
589  /**
590   * Read the lists of factors and covariates.
591   *
592   * @param model the Element that contains the model
593   * @param factorOrCovariate holds the String "FactorList" or
594   * "CovariateList"
595   * @throws Exception if there is a factor or covariate listed
596   * that isn't in the mining schema
597   */
598  protected void readFactorsAndCovariates(Element model, 
599                                          String factorOrCovariate) 
600    throws Exception {
601    Instances miningSchemaI = m_miningSchema.getFieldsAsInstances();
602
603    NodeList factorL = model.getElementsByTagName(factorOrCovariate);
604    if (factorL.getLength() == 1) { // should be 0 or 1 FactorList element
605      Node factor = factorL.item(0);
606      if (factor.getNodeType() == Node.ELEMENT_NODE) {
607        NodeList predL = ((Element)factor).getElementsByTagName("Predictor");
608        for (int i = 0; i < predL.getLength(); i++) {
609          Node pred = predL.item(i);
610          if (pred.getNodeType() == Node.ELEMENT_NODE) {
611            Predictor p = new Predictor();
612            p.m_name = ((Element)pred).getAttribute("name");
613            // find the index of this predictor in the mining schema
614            boolean found = false;
615            for (int j = 0; j < miningSchemaI.numAttributes(); j++) {
616              if (miningSchemaI.attribute(j).name().equals(p.m_name)) {
617                found = true;
618                p.m_miningSchemaIndex = j;
619                break;
620              }
621            }
622            if (found) {
623              if (factorOrCovariate.equals("FactorList")) {
624                m_factorList.add(p);
625              } else {
626                m_covariateList.add(p);
627              }
628            } else {
629              throw new Exception("[GeneralRegression] reading factors and covariates - "
630                                  + "unable to find predictor " +
631                                  p.m_name + " in the mining schema");
632            }
633          }
634        }
635      }
636    } else if (factorL.getLength() > 1){
637      throw new Exception("[GeneralRegression] more than one " + factorOrCovariate
638                          + "! ");
639    }
640  }
641
642  /**
643   * Read the PPMatrix from the xml. Does not handle multiple PPMatrixes yet.
644   *
645   * @param model the Element that contains the model
646   * @throws Exception if there is a problem parsing cell values.
647   */
648  protected void readPPMatrix(Element model) throws Exception {
649    Instances miningSchemaI = m_miningSchema.getFieldsAsInstances();
650   
651    NodeList matrixL = model.getElementsByTagName("PPMatrix");
652
653    // should be exactly one PPMatrix
654    if (matrixL.getLength() == 1) {
655      // allocate space for the matrix
656      // column that corresponds to the class will be empty (and will be missed out
657      // when printing the model).
658      m_ppMatrix = new PPCell[m_parameterList.size()][miningSchemaI.numAttributes()];
659
660      Node ppM = matrixL.item(0);
661      if (ppM.getNodeType() == Node.ELEMENT_NODE) {
662        NodeList cellL = ((Element)ppM).getElementsByTagName("PPCell");
663        for (int i = 0; i < cellL.getLength(); i++) {
664          Node cell = cellL.item(i);
665          if (cell.getNodeType() == Node.ELEMENT_NODE) {
666            String predictorName = ((Element)cell).getAttribute("predictorName");
667            String parameterName = ((Element)cell).getAttribute("parameterName");
668            String value = ((Element)cell).getAttribute("value");
669            double expOrIndex = -1;
670            int predictorIndex = -1;
671            int parameterIndex = -1;
672            for (int j = 0; j < m_parameterList.size(); j++) {
673              if (m_parameterList.get(j).m_name.equals(parameterName)) {
674                parameterIndex = j;
675                break;
676              }
677            }
678            if (parameterIndex == -1) {
679              throw new Exception("[GeneralRegression] unable to find parameter name "
680                                  + parameterName + " in parameter list");
681            }
682
683            Predictor p = getCovariate(predictorName);
684            if (p != null) {
685              try {
686                expOrIndex = Double.parseDouble(value);
687                predictorIndex = p.m_miningSchemaIndex;
688              } catch (IllegalArgumentException ex) {
689                throw new Exception("[GeneralRegression] unable to parse PPCell value: "
690                                    + value);
691              }
692            } else {
693              // try as a factor
694              p = getFactor(predictorName);
695              if (p != null) {
696                // An example pmml file from DMG seems to suggest that it
697                // is possible for a continuous variable in the mining schema
698                // to be treated as a factor, so we have to check for this
699                if (miningSchemaI.attribute(p.m_miningSchemaIndex).isNumeric()) {
700                  // parse this value as a double. It will be treated as a value
701                  // to match rather than an exponent since we are dealing with
702                  // a factor here
703                  try {
704                    expOrIndex = Double.parseDouble(value);
705                  } catch (IllegalArgumentException ex) {
706                    throw new Exception("[GeneralRegresion] unable to parse PPCell value: "
707                                        + value);
708                  }
709                } else {
710                  // it is a nominal attribute in the mining schema so find
711                  // the index that correponds to this value
712                  Attribute att = miningSchemaI.attribute(p.m_miningSchemaIndex); 
713                  expOrIndex = att.indexOfValue(value);
714                  if (expOrIndex == -1) {
715                    throw new Exception("[GeneralRegression] unable to find PPCell value "
716                                        + value + " in mining schema attribute "
717                                        + att.name());
718                  }
719                }
720              } else {
721                throw new Exception("[GeneralRegression] cant find predictor "
722                                    + predictorName + "in either the factors list "
723                                    + "or the covariates list");
724              }
725              predictorIndex = p.m_miningSchemaIndex;
726            }
727
728            // fill in cell value
729            PPCell ppc = new PPCell();
730            ppc.m_predictorName = predictorName; ppc.m_parameterName = parameterName;
731            ppc.m_value = expOrIndex;
732
733            // TO-DO: ppc.m_targetCategory (when handling for multiple PPMatrixes is implemented)
734            m_ppMatrix[parameterIndex][predictorIndex] = ppc;
735          }
736        }
737      }
738    } else {
739      throw new Exception("[GeneralRegression] more than one PPMatrix!");
740    }
741  }
742
743  private Predictor getCovariate(String predictorName) {
744    for (int i = 0; i < m_covariateList.size(); i++) {
745      if (predictorName.equals(m_covariateList.get(i).m_name)) {
746        return m_covariateList.get(i);
747      }
748    }
749    return null;
750  }
751
752  private Predictor getFactor(String predictorName) {
753    for (int i = 0; i < m_factorList.size(); i++) {
754      if (predictorName.equals(m_factorList.get(i).m_name)) {
755        return m_factorList.get(i);
756      }
757    }
758    return null;
759  }
760
761  /**
762   * Read the parameter matrix from the xml.
763   *
764   * @param model Element that holds the model
765   * @throws Exception if a problem is encountered during extraction of
766   * the parameter matrix
767   */
768  private void readParamMatrix(Element model) throws Exception {
769
770    Instances miningSchemaI = m_miningSchema.getFieldsAsInstances();
771    Attribute classAtt = miningSchemaI.classAttribute();
772    // used when function type is classification but class attribute is numeric
773    // in the mining schema. We will assume that there is a Target specified in
774    // the pmml that defines the legal values for this class.
775    ArrayList<String> targetVals = null;
776
777    NodeList matrixL = model.getElementsByTagName("ParamMatrix");
778    if (matrixL.getLength() != 1) {
779      throw new Exception("[GeneralRegression] more than one ParamMatrix!");
780    }
781    Element matrix = (Element)matrixL.item(0);
782
783
784    // check for the case where the class in the mining schema is numeric,
785    // but this attribute is treated as discrete
786    if (m_functionType == Regression.RegressionTable.CLASSIFICATION &&
787        classAtt.isNumeric()) {
788      // try and convert the class attribute to nominal. For this to succeed
789      // there has to be a Target element defined in the PMML.
790      if (!m_miningSchema.hasTargetMetaData()) {
791        throw new Exception("[GeneralRegression] function type is classification and "
792                            + "class attribute in mining schema is numeric, however, "
793                            + "there is no Target element "
794                            + "specifying legal discrete values for the target!");
795
796      }
797
798      if (m_miningSchema.getTargetMetaData().getOptype() 
799          != TargetMetaInfo.Optype.CATEGORICAL) {
800        throw new Exception("[GeneralRegression] function type is classification and "
801                            + "class attribute in mining schema is numeric, however "
802                            + "Target element in PMML does not have optype categorical!");
803      }
804
805      // OK now get legal values
806      targetVals = m_miningSchema.getTargetMetaData().getValues();
807      if (targetVals.size() == 0) {
808        throw new Exception("[GeneralRegression] function type is classification and "
809                            + "class attribute in mining schema is numeric, however "
810                            + "Target element in PMML does not have any discrete values "
811                            + "defined!");
812      }
813
814      // Finally, convert the class in the mining schema to nominal
815      m_miningSchema.convertNumericAttToNominal(miningSchemaI.classIndex(), targetVals);
816    }
817   
818    // allocate space for the matrix
819    m_paramMatrix = 
820        new PCell[(classAtt.isNumeric())
821                  ? 1
822                  : classAtt.numValues()][m_parameterList.size()];
823
824    NodeList pcellL = matrix.getElementsByTagName("PCell");
825    for (int i = 0; i < pcellL.getLength(); i++) {
826      // indicates that that this beta applies to all target categories
827      // or target is numeric
828      int targetCategoryIndex = -1;
829      int parameterIndex = -1;
830      Node pcell = pcellL.item(i);
831      if (pcell.getNodeType() == Node.ELEMENT_NODE) {
832        String paramName = ((Element)pcell).getAttribute("parameterName");
833        String targetCatName = ((Element)pcell).getAttribute("targetCategory");
834        String coefficient = ((Element)pcell).getAttribute("beta");
835        String df = ((Element)pcell).getAttribute("df");
836
837        for (int j = 0; j < m_parameterList.size(); j++) {
838          if (m_parameterList.get(j).m_name.equals(paramName)) {
839            parameterIndex = j;
840            // use the label if defined
841            if (m_parameterList.get(j).m_label != null) {
842              paramName = m_parameterList.get(j).m_label;
843            }
844            break;
845          }
846        }
847        if (parameterIndex == -1) {
848          throw new Exception("[GeneralRegression] unable to find parameter name "
849                              + paramName + " in parameter list");
850        }
851
852        if (targetCatName != null && targetCatName.length() > 0) {
853          if (classAtt.isNominal() || classAtt.isString()) {
854            targetCategoryIndex = classAtt.indexOfValue(targetCatName);
855          } else {
856            throw new Exception("[GeneralRegression] found a PCell with a named "
857                                + "target category: " + targetCatName
858                                + " but class attribute is numeric in "
859                                + "mining schema");
860          }
861        }
862
863        PCell p = new PCell();
864        if (targetCategoryIndex != -1) {
865          p.m_targetCategory = targetCatName;
866        }
867        p.m_parameterName = paramName;
868        try {
869          p.m_beta = Double.parseDouble(coefficient);
870        } catch (IllegalArgumentException ex) {
871          throw new Exception("[GeneralRegression] unable to parse beta value "
872                              + coefficient + " as a double from PCell");
873        }
874        if (df != null && df.length() > 0) {
875          try {
876            p.m_df = Integer.parseInt(df);
877          } catch (IllegalArgumentException ex) {
878            throw new Exception("[GeneralRegression] unable to parse df value "
879                              + df + " as an int from PCell");
880          }
881        }
882       
883        if (targetCategoryIndex != -1) {
884          m_paramMatrix[targetCategoryIndex][parameterIndex] = p;
885        } else {
886          // this PCell to all target categories (covers numeric class, in
887          // which case there will be only one row in the matrix anyway)
888          for (int j = 0; j < m_paramMatrix.length; j++) {
889            m_paramMatrix[j][parameterIndex] = p;
890          }
891        }
892      }
893    }
894  }
895
896  /**
897   * Return a textual description of this general regression.
898   *
899   * @return a description of this general regression
900   */
901  public String toString() {
902    StringBuffer temp = new StringBuffer();
903    temp.append("PMML version " + getPMMLVersion());
904    if (!getCreatorApplication().equals("?")) {
905      temp.append("\nApplication: " + getCreatorApplication());
906    }
907    temp.append("\nPMML Model: " + m_modelType);
908    temp.append("\n\n");
909    temp.append(m_miningSchema);
910
911    if (m_factorList.size() > 0) {
912      temp.append("Factors:\n");
913      for (Predictor p : m_factorList) {
914        temp.append("\t" + p + "\n");
915      }
916    }
917    temp.append("\n");
918    if (m_covariateList.size() > 0) {
919      temp.append("Covariates:\n");
920      for (Predictor p : m_covariateList) {
921        temp.append("\t" + p + "\n");
922      }
923    }
924    temp.append("\n");
925   
926    printPPMatrix(temp);
927    temp.append("\n");
928    printParameterMatrix(temp);
929   
930    // do the link function stuff
931    temp.append("\n");
932   
933    if (m_linkFunction != LinkFunction.NONE) {
934      temp.append("Link function: " + m_linkFunction);
935      if (m_offsetVariable != null) {
936        temp.append("\n\tOffset variable " + m_offsetVariable);
937      } else if (!Double.isNaN(m_offsetValue)) {
938        temp.append("\n\tOffset value " + m_offsetValue);
939      }
940     
941      if (m_trialsVariable != null) {
942        temp.append("\n\tTrials variable " + m_trialsVariable);
943      } else if (!Double.isNaN(m_trialsValue)) {
944        temp.append("\n\tTrials value " + m_trialsValue);
945      }
946     
947      if (m_distribution != Distribution.NONE) {
948        temp.append("\nDistribution: " + m_distribution);
949      }
950     
951      if (m_linkFunction == LinkFunction.NEGBIN &&
952          m_distribution == Distribution.NEGBINOMIAL &&
953          !Double.isNaN(m_distParameter)) {
954        temp.append("\n\tDistribution parameter " + m_distParameter);
955      }
956     
957      if (m_linkFunction == LinkFunction.POWER ||
958          m_linkFunction == LinkFunction.ODDSPOWER) {
959        if (!Double.isNaN(m_linkParameter)) {
960          temp.append("\n\nLink parameter " + m_linkParameter);
961        }
962      }
963    }
964   
965    if (m_cumulativeLinkFunction != CumulativeLinkFunction.NONE) {
966      temp.append("Cumulative link function: " + m_cumulativeLinkFunction);
967     
968      if (m_offsetVariable != null) {
969        temp.append("\n\tOffset variable " + m_offsetVariable);
970      } else if (!Double.isNaN(m_offsetValue)) {
971        temp.append("\n\tOffset value " + m_offsetValue);
972      }
973    }
974    temp.append("\n");
975   
976    return temp.toString();
977  }
978 
979  /**
980   * Format and print the PPMatrix to the supplied StringBuffer.
981   *
982   * @param buff the StringBuffer to append to
983   */
984  protected void printPPMatrix(StringBuffer buff) {
985    Instances miningSchemaI = m_miningSchema.getFieldsAsInstances();
986    int maxAttWidth = 0;
987    for (int i = 0; i < miningSchemaI.numAttributes(); i++) {
988      Attribute a = miningSchemaI.attribute(i);
989      if (a.name().length() > maxAttWidth) {
990        maxAttWidth = a.name().length();
991      }
992    }
993
994    // check the width of the values
995    for (int i = 0; i < m_parameterList.size(); i++) {
996      for (int j = 0; j < miningSchemaI.numAttributes(); j++) {
997        if (m_ppMatrix[i][j] != null) {
998          double width = Math.log(Math.abs(m_ppMatrix[i][j].m_value)) /
999            Math.log(10.0);
1000          if (width < 0) {
1001            width = 1;
1002          }
1003          // decimal + # decimal places + 1
1004          width += 2.0;
1005          if ((int)width > maxAttWidth) {
1006            maxAttWidth = (int)width;
1007          }
1008          if (miningSchemaI.attribute(j).isNominal() || 
1009              miningSchemaI.attribute(j).isString()) {
1010            // check the width of this value
1011            String val = miningSchemaI.attribute(j).value((int)m_ppMatrix[i][j].m_value) + " ";
1012            if (val.length() > maxAttWidth) {
1013              maxAttWidth = val.length();
1014            }
1015          }
1016        }
1017      }
1018    }
1019
1020    // get the max parameter width
1021    int maxParamWidth = "Parameter  ".length();
1022    for (Parameter p : m_parameterList) {
1023      String temp = (p.m_label != null)
1024        ? p.m_label + " "
1025        : p.m_name + " ";
1026
1027      if (temp.length() > maxParamWidth) {
1028        maxParamWidth = temp.length();
1029      }
1030    }
1031
1032    buff.append("Predictor-to-Parameter matrix:\n");
1033    buff.append(PMMLUtils.pad("Predictor", " ", (maxParamWidth + (maxAttWidth * 2 + 2))
1034                              - "Predictor".length(), true));
1035    buff.append("\n" + PMMLUtils.pad("Parameter", " ", maxParamWidth - "Parameter".length(), false));
1036    // attribute names
1037    for (int i = 0; i < miningSchemaI.numAttributes(); i++) {
1038      if (i != miningSchemaI.classIndex()) {
1039        String attName = miningSchemaI.attribute(i).name();
1040        buff.append(PMMLUtils.pad(attName, " ", maxAttWidth + 1 - attName.length(), true));
1041      }
1042    }
1043    buff.append("\n");
1044
1045    for (int i = 0; i < m_parameterList.size(); i++) {
1046      Parameter param = m_parameterList.get(i);
1047      String paramS = (param.m_label != null)
1048        ? param.m_label
1049        : param.m_name;
1050      buff.append(PMMLUtils.pad(paramS, " ", 
1051                                maxParamWidth - paramS.length(), false));
1052      for (int j = 0; j < miningSchemaI.numAttributes(); j++) {
1053        if (j != miningSchemaI.classIndex()) {
1054          PPCell p = m_ppMatrix[i][j];
1055          String val = " ";
1056          if (p != null) {
1057            if (miningSchemaI.attribute(j).isNominal() ||
1058                miningSchemaI.attribute(j).isString()) {
1059              val = miningSchemaI.attribute(j).value((int)p.m_value);
1060            } else {
1061              val = "" + Utils.doubleToString(p.m_value, maxAttWidth, 4).trim();
1062            }
1063          }
1064          buff.append(PMMLUtils.pad(val, " ", maxAttWidth + 1 - val.length(), true));
1065        }
1066      }
1067      buff.append("\n");
1068    }
1069  }
1070
1071  /**
1072   * Format and print the parameter matrix to the supplied StringBuffer.
1073   *
1074   * @param buff the StringBuffer to append to
1075   */
1076  protected void printParameterMatrix(StringBuffer buff) {
1077    Instances miningSchemaI = m_miningSchema.getFieldsAsInstances();
1078
1079    // get the maximum class value width (nominal)
1080    int maxClassWidth = miningSchemaI.classAttribute().name().length();
1081    if (miningSchemaI.classAttribute().isNominal()
1082        || miningSchemaI.classAttribute().isString()) {
1083      for (int i = 0; i < miningSchemaI.classAttribute().numValues(); i++) {
1084        if (miningSchemaI.classAttribute().value(i).length() > maxClassWidth) {
1085          maxClassWidth = miningSchemaI.classAttribute().value(i).length();
1086        }
1087      }
1088    }
1089
1090    // get the maximum parameter name/label width
1091    int maxParamWidth = 0;
1092    for (int i = 0; i < m_parameterList.size(); i++) {
1093      Parameter p = m_parameterList.get(i);
1094      String val = (p.m_label != null)
1095        ? p.m_label + " "
1096        : p.m_name + " ";
1097      if (val.length() > maxParamWidth) {
1098        maxParamWidth = val.length();
1099      }
1100    }
1101
1102    // get the max beta value width
1103    int maxBetaWidth = "Coeff.".length();
1104    for (int i = 0; i < m_paramMatrix.length; i++) {
1105      for (int j = 0; j < m_parameterList.size(); j++) {
1106        PCell p = m_paramMatrix[i][j];
1107        if (p != null) {
1108          double width = Math.log(Math.abs(p.m_beta)) / Math.log(10);
1109          if (width < 0) {
1110            width = 1;
1111          }
1112          // decimal + # decimal places + 1
1113          width += 7.0;
1114          if ((int)width > maxBetaWidth) {
1115            maxBetaWidth = (int)width;
1116          }
1117        }
1118      }
1119    }
1120
1121    buff.append("Parameter estimates:\n");
1122    buff.append(PMMLUtils.pad(miningSchemaI.classAttribute().name(), " ", 
1123                              maxClassWidth + maxParamWidth + 2 - 
1124                              miningSchemaI.classAttribute().name().length(), false));
1125    buff.append(PMMLUtils.pad("Coeff.", " ", maxBetaWidth + 1 - "Coeff.".length(), true));
1126    buff.append(PMMLUtils.pad("df", " ", maxBetaWidth - "df".length(), true));
1127    buff.append("\n");
1128    for (int i = 0; i < m_paramMatrix.length; i++) {
1129      // scan for non-null entry for this class value
1130      boolean ok = false;
1131      for (int j = 0; j < m_parameterList.size(); j++) {
1132        if (m_paramMatrix[i][j] != null) {
1133          ok = true;
1134        }
1135      }
1136      if (!ok) {
1137        continue;
1138      }
1139      // first the class value (if nominal)
1140      String cVal = (miningSchemaI.classAttribute().isNominal() || 
1141          miningSchemaI.classAttribute().isString())
1142        ? miningSchemaI.classAttribute().value(i)
1143        : " ";
1144      buff.append(PMMLUtils.pad(cVal, " ", maxClassWidth - cVal.length(), false));     
1145      buff.append("\n");
1146      for (int j = 0; j < m_parameterList.size(); j++) {
1147        PCell p = m_paramMatrix[i][j];
1148        if (p != null) {
1149          String label = p.m_parameterName;
1150          buff.append(PMMLUtils.pad(label, " ", maxClassWidth + maxParamWidth + 2 -
1151                                    label.length(), true));
1152          String betaS = Utils.doubleToString(p.m_beta, maxBetaWidth, 4).trim();
1153          buff.append(PMMLUtils.pad(betaS, " ", maxBetaWidth + 1 - betaS.length(), true));
1154          String dfS = Utils.doubleToString(p.m_df, maxBetaWidth, 4).trim();
1155          buff.append(PMMLUtils.pad(dfS, " ", maxBetaWidth - dfS.length(), true));
1156          buff.append("\n");
1157        }
1158      }
1159    }
1160  }
1161 
1162  /**
1163   * Construct the incoming parameter vector based on the values
1164   * in the incoming test instance.
1165   *
1166   * @param incomingInst the values of the incoming test instance
1167   * @return the populated parameter vector ready to be multiplied against
1168   * the vector of coefficients.
1169   * @throws Exception if there is some problem whilst constructing the
1170   * parameter vector
1171   */
1172  private double[] incomingParamVector(double[] incomingInst) throws Exception {
1173    Instances miningSchemaI = m_miningSchema.getFieldsAsInstances();
1174    double[] incomingPV = new double[m_parameterList.size()];
1175   
1176    for (int i = 0; i < m_parameterList.size(); i++) {
1177      //
1178      // default is that this row represents the intercept.
1179      // this will be the case if there are all null entries in this row
1180      incomingPV[i] = 1.0;
1181
1182      // loop over the attributes (predictors)
1183      for (int j = 0; j < miningSchemaI.numAttributes(); j++) {       
1184        PPCell cellEntry = m_ppMatrix[i][j];
1185        Predictor p = null;
1186        if (cellEntry != null) {
1187          if ((p = getFactor(cellEntry.m_predictorName)) != null) {
1188            if ((int)incomingInst[p.m_miningSchemaIndex] == (int)cellEntry.m_value) {
1189              incomingPV[i] *= 1.0; // we have a match
1190            } else {
1191              incomingPV[i] *= 0.0;
1192            }
1193          } else if ((p = getCovariate(cellEntry.m_predictorName)) != null) {
1194              incomingPV[i] *= Math.pow(incomingInst[p.m_miningSchemaIndex], cellEntry.m_value);
1195          } else {
1196            throw new Exception("[GeneralRegression] can't find predictor "
1197                + cellEntry.m_predictorName + " in either the list of factors or covariates");
1198          }
1199        }
1200      }
1201    }
1202   
1203    return incomingPV;
1204  }
1205
1206  /**                                                                                                             
1207   * Classifies the given test instance. The instance has to belong to a                                         
1208   * dataset when it's being classified.                                                         
1209   *                                                                                                             
1210   * @param inst the instance to be classified                                                               
1211   * @return the predicted most likely class for the instance or                                                 
1212   * Utils.missingValue() if no prediction is made                                                             
1213   * @exception Exception if an error occurred during the prediction                                             
1214   */
1215  public double[] distributionForInstance(Instance inst) throws Exception {
1216    if (!m_initialized) {
1217      mapToMiningSchema(inst.dataset());
1218    }
1219    double[] preds = null;
1220    if (m_miningSchema.getFieldsAsInstances().classAttribute().isNumeric()) {
1221      preds = new double[1];
1222    } else {
1223      preds = new double[m_miningSchema.getFieldsAsInstances().classAttribute().numValues()];
1224    }
1225   
1226    // create an array of doubles that holds values from the incoming
1227    // instance; in order of the fields in the mining schema. We will
1228    // also handle missing values and outliers here.
1229    double[] incoming = m_fieldsMap.instanceToSchema(inst, m_miningSchema);
1230   
1231    // In this implementation we will default to information in the Target element (default
1232    // value for numeric prediction and prior probabilities for classification). If there is
1233    // no Target element defined, then an Exception is thrown.
1234
1235    boolean hasMissing = false;
1236    for (int i = 0; i < incoming.length; i++) {
1237      if (i != m_miningSchema.getFieldsAsInstances().classIndex() && 
1238          Double.isNaN(incoming[i])) {
1239        hasMissing = true;
1240        break;
1241      }
1242    }
1243   
1244    if (hasMissing) {
1245      if (!m_miningSchema.hasTargetMetaData()) {
1246        String message = "[GeneralRegression] WARNING: Instance to predict has missing value(s) but "
1247          + "there is no missing value handling meta data and no "
1248          + "prior probabilities/default value to fall back to. No "
1249          + "prediction will be made (" 
1250          + ((m_miningSchema.getFieldsAsInstances().classAttribute().isNominal()
1251              || m_miningSchema.getFieldsAsInstances().classAttribute().isString())
1252              ? "zero probabilities output)."
1253              : "NaN output).");
1254        if (m_log == null) {
1255          System.err.println(message);
1256        } else {
1257          m_log.logMessage(message);
1258        }
1259       
1260        if (m_miningSchema.getFieldsAsInstances().classAttribute().isNumeric()) {
1261          preds[0] = Utils.missingValue();
1262        }
1263        return preds;
1264      } else {
1265        // use prior probablilities/default value
1266        TargetMetaInfo targetData = m_miningSchema.getTargetMetaData();
1267        if (m_miningSchema.getFieldsAsInstances().classAttribute().isNumeric()) {
1268          preds[0] = targetData.getDefaultValue();
1269        } else {
1270          Instances miningSchemaI = m_miningSchema.getFieldsAsInstances();
1271          for (int i = 0; i < miningSchemaI.classAttribute().numValues(); i++) {
1272            preds[i] = targetData.getPriorProbability(miningSchemaI.classAttribute().value(i));
1273          }
1274        }
1275        return preds;
1276      }
1277    } else {
1278      // construct input parameter vector here
1279      double[] inputParamVector = incomingParamVector(incoming);
1280      computeResponses(incoming, inputParamVector, preds);
1281    }
1282   
1283    return preds;
1284  }
1285 
1286  /**
1287   * Compute the responses for the function given the parameter values corresponding
1288   * to the current incoming instance.
1289   *
1290   * @param incomingInst raw incoming instance values (after missing value
1291   * replacement and outlier treatment)
1292   * @param incomingParamVector incoming instance values mapped to parameters
1293   * @param responses will contain the responses computed by the function
1294   * @throws Exception if something goes wrong
1295   */
1296  private void computeResponses(double[] incomingInst, 
1297                                double[] incomingParamVector,
1298                                double[] responses) throws Exception {
1299    for (int i = 0; i < responses.length; i++) {
1300      for (int j = 0; j < m_parameterList.size(); j++) {
1301        // a row of the parameter matrix should have all non-null entries
1302        // except for the last class (in the case of classification) which
1303        // should have just an intercept of 0. Need to handle the case where
1304        // no intercept has been defined in the pmml file for the last class
1305        PCell p = m_paramMatrix[i][j];
1306        if (p == null) {
1307          responses[i] += 0.0 * incomingParamVector[j];
1308        } else {
1309          responses[i] += incomingParamVector[j] * p.m_beta;
1310        }
1311      }
1312    }
1313   
1314    switch(m_modelType) {
1315    case MULTINOMIALLOGISTIC:
1316      computeProbabilitiesMultinomialLogistic(responses);
1317      break;
1318    case REGRESSION:
1319      // nothing to be done
1320      break;
1321    case GENERALLINEAR:
1322    case GENERALIZEDLINEAR:
1323      if (m_linkFunction != LinkFunction.NONE) {
1324        computeResponseGeneralizedLinear(incomingInst, responses);
1325      } else {
1326        throw new Exception("[GeneralRegression] no link function specified!");
1327      }
1328      break;
1329    case ORDINALMULTINOMIAL:
1330      if (m_cumulativeLinkFunction != CumulativeLinkFunction.NONE) {
1331        computeResponseOrdinalMultinomial(incomingInst, responses);
1332      } else {
1333        throw new Exception("[GeneralRegression] no cumulative link function specified!");
1334      }
1335      break;
1336      default:
1337        throw new Exception("[GeneralRegression] unknown model type");
1338    }
1339  }
1340 
1341  /**
1342   * Computes probabilities for the multinomial logistic model type.
1343   *
1344   * @param responses will hold the responses computed by the function.
1345   */
1346  private static void computeProbabilitiesMultinomialLogistic(double[] responses) {
1347    double[] r = responses.clone();
1348    for (int j = 0; j < r.length; j++) {
1349      double sum = 0;
1350      boolean overflow = false;
1351      for (int k = 0; k < r.length; k++) {
1352        if (r[k] - r[j] > 700) {
1353          overflow = true;
1354          break;
1355        }
1356        sum += Math.exp(r[k] - r[j]);
1357      }
1358      if (overflow) {
1359        responses[j] = 0.0;
1360      } else {
1361        responses[j] = 1.0 / sum;
1362      }
1363    }
1364  }
1365 
1366  /**
1367   * Computes responses for the general linear and generalized linear model
1368   * types.
1369   *
1370   * @param incomingInst the raw incoming instance values (after missing value
1371   * replacement and outlier treatment etc).
1372   * @param responses will hold the responses computed by the function
1373   * @throws Exception if a problem occurs.
1374   */
1375  private void computeResponseGeneralizedLinear(double[] incomingInst, 
1376                                                double[] responses) 
1377    throws Exception {
1378    double[] r = responses.clone();
1379   
1380    double offset = 0;
1381    if (m_offsetVariable != null) {
1382      Attribute offsetAtt = 
1383        m_miningSchema.getFieldsAsInstances().attribute(m_offsetVariable);
1384      if (offsetAtt == null) {
1385        throw new Exception("[GeneralRegression] unable to find offset variable "
1386            + m_offsetVariable + " in the mining schema!");
1387      }
1388      offset = incomingInst[offsetAtt.index()];
1389    } else if (!Double.isNaN(m_offsetValue)) {
1390      offset = m_offsetValue;
1391    }
1392   
1393    double trials = 1;
1394    if (m_trialsVariable != null) {
1395      Attribute trialsAtt = m_miningSchema.getFieldsAsInstances().attribute(m_trialsVariable);
1396      if (trialsAtt == null) {
1397        throw new Exception("[GeneralRegression] unable to find trials variable "
1398            + m_trialsVariable + " in the mining schema!");
1399      }
1400      trials = incomingInst[trialsAtt.index()];
1401    } else if (!Double.isNaN(m_trialsValue)) {
1402      trials = m_trialsValue;
1403    }
1404   
1405    double distParam = 0;
1406    if (m_linkFunction == LinkFunction.NEGBIN && 
1407        m_distribution == Distribution.NEGBINOMIAL) {
1408      if (Double.isNaN(m_distParameter)) {
1409        throw new Exception("[GeneralRegression] no distribution parameter defined!");
1410      }
1411      distParam = m_distParameter;
1412    }
1413   
1414    double linkParam = 0;
1415    if (m_linkFunction == LinkFunction.POWER || 
1416        m_linkFunction == LinkFunction.ODDSPOWER) {
1417      if (Double.isNaN(m_linkParameter)) {
1418        throw new Exception("[GeneralRegression] no link parameter defined!");
1419      }
1420      linkParam = m_linkParameter;
1421    }
1422   
1423    for (int i = 0; i < r.length; i++) {
1424      responses[i] = m_linkFunction.eval(r[i], offset, trials, distParam, linkParam);
1425    }
1426  }
1427   
1428  /**
1429   * Computes responses for the ordinal multinomial model type.
1430   *
1431   * @param incomingInst the raw incoming instance values (after missing value
1432   * replacement and outlier treatment etc).
1433   * @param responses will hold the responses computed by the function
1434   * @throws Exception if a problem occurs.
1435   */
1436  private void computeResponseOrdinalMultinomial(double[] incomingInst, 
1437                                                  double[] responses) throws Exception {
1438   
1439    double[] r = responses.clone();
1440   
1441    double offset = 0;
1442    if (m_offsetVariable != null) {
1443      Attribute offsetAtt = 
1444        m_miningSchema.getFieldsAsInstances().attribute(m_offsetVariable);
1445      if (offsetAtt == null) {
1446        throw new Exception("[GeneralRegression] unable to find offset variable "
1447            + m_offsetVariable + " in the mining schema!");
1448      }
1449      offset = incomingInst[offsetAtt.index()];
1450    } else if (!Double.isNaN(m_offsetValue)) {
1451      offset = m_offsetValue;
1452    }
1453   
1454    for (int i = 0; i < r.length; i++) {
1455      if (i == 0) {
1456        responses[i] = m_cumulativeLinkFunction.eval(r[i], offset);
1457   
1458      } else if (i == (r.length - 1)) {
1459        responses[i] = 1.0 - responses[i - 1];
1460      } else {
1461        responses[i] = m_cumulativeLinkFunction.eval(r[i], offset) - responses[i - 1];
1462      }
1463    }
1464  }
1465
1466  /* (non-Javadoc)
1467   * @see weka.core.RevisionHandler#getRevision()
1468   */
1469  public String getRevision() {
1470    return RevisionUtils.extract("$Revision: 5987 $");
1471  }
1472}
Note: See TracBrowser for help on using the repository browser.