source: src/main/java/weka/classifiers/functions/Logistic.java @ 18

Last change on this file since 18 was 4, checked in by gnappo, 14 years ago

Import di weka.

File size: 27.4 KB
RevLine 
[4]1/*
2 *    This program is free software; you can redistribute it and/or modify
3 *    it under the terms of the GNU General Public License as published by
4 *    the Free Software Foundation; either version 2 of the License, or
5 *    (at your option) any later version.
6 *
7 *    This program is distributed in the hope that it will be useful,
8 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
9 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
10 *    GNU General Public License for more details.
11 *
12 *    You should have received a copy of the GNU General Public License
13 *    along with this program; if not, write to the Free Software
14 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
15 */
16
17/*
18 *    Logistic.java
19 *    Copyright (C) 2003 University of Waikato, Hamilton, New Zealand
20 *
21 */
22
23package weka.classifiers.functions;
24
25import weka.classifiers.Classifier;
26import weka.classifiers.AbstractClassifier;
27import weka.core.Capabilities;
28import weka.core.Instance;
29import weka.core.Instances;
30import weka.core.Optimization;
31import weka.core.Option;
32import weka.core.OptionHandler;
33import weka.core.RevisionUtils;
34import weka.core.TechnicalInformation;
35import weka.core.TechnicalInformationHandler;
36import weka.core.Utils;
37import weka.core.WeightedInstancesHandler;
38import weka.core.Capabilities.Capability;
39import weka.core.TechnicalInformation.Field;
40import weka.core.TechnicalInformation.Type;
41import weka.filters.Filter;
42import weka.filters.unsupervised.attribute.NominalToBinary;
43import weka.filters.unsupervised.attribute.RemoveUseless;
44import weka.filters.unsupervised.attribute.ReplaceMissingValues;
45
46import java.util.Enumeration;
47import java.util.Vector;
48
49/**
50 <!-- globalinfo-start -->
51 * Class for building and using a multinomial logistic regression model with a ridge estimator.<br/>
52 * <br/>
53 * There are some modifications, however, compared to the paper of leCessie and van Houwelingen(1992): <br/>
54 * <br/>
55 * If there are k classes for n instances with m attributes, the parameter matrix B to be calculated will be an m*(k-1) matrix.<br/>
56 * <br/>
57 * The probability for class j with the exception of the last class is<br/>
58 * <br/>
59 * Pj(Xi) = exp(XiBj)/((sum[j=1..(k-1)]exp(Xi*Bj))+1) <br/>
60 * <br/>
61 * The last class has probability<br/>
62 * <br/>
63 * 1-(sum[j=1..(k-1)]Pj(Xi)) <br/>
64 *      = 1/((sum[j=1..(k-1)]exp(Xi*Bj))+1)<br/>
65 * <br/>
66 * The (negative) multinomial log-likelihood is thus: <br/>
67 * <br/>
68 * L = -sum[i=1..n]{<br/>
69 *      sum[j=1..(k-1)](Yij * ln(Pj(Xi)))<br/>
70 *      +(1 - (sum[j=1..(k-1)]Yij)) <br/>
71 *      * ln(1 - sum[j=1..(k-1)]Pj(Xi))<br/>
72 *      } + ridge * (B^2)<br/>
73 * <br/>
74 * In order to find the matrix B for which L is minimised, a Quasi-Newton Method is used to search for the optimized values of the m*(k-1) variables.  Note that before we use the optimization procedure, we 'squeeze' the matrix B into a m*(k-1) vector.  For details of the optimization procedure, please check weka.core.Optimization class.<br/>
75 * <br/>
76 * Although original Logistic Regression does not deal with instance weights, we modify the algorithm a little bit to handle the instance weights.<br/>
77 * <br/>
78 * For more information see:<br/>
79 * <br/>
80 * le Cessie, S., van Houwelingen, J.C. (1992). Ridge Estimators in Logistic Regression. Applied Statistics. 41(1):191-201.<br/>
81 * <br/>
82 * Note: Missing values are replaced using a ReplaceMissingValuesFilter, and nominal attributes are transformed into numeric attributes using a NominalToBinaryFilter.
83 * <p/>
84 <!-- globalinfo-end -->
85 *
86 <!-- technical-bibtex-start -->
87 * BibTeX:
88 * <pre>
89 * &#64;article{leCessie1992,
90 *    author = {le Cessie, S. and van Houwelingen, J.C.},
91 *    journal = {Applied Statistics},
92 *    number = {1},
93 *    pages = {191-201},
94 *    title = {Ridge Estimators in Logistic Regression},
95 *    volume = {41},
96 *    year = {1992}
97 * }
98 * </pre>
99 * <p/>
100 <!-- technical-bibtex-end -->
101 *
102 <!-- options-start -->
103 * Valid options are: <p/>
104 *
105 * <pre> -D
106 *  Turn on debugging output.</pre>
107 *
108 * <pre> -R &lt;ridge&gt;
109 *  Set the ridge in the log-likelihood.</pre>
110 *
111 * <pre> -M &lt;number&gt;
112 *  Set the maximum number of iterations (default -1, until convergence).</pre>
113 *
114 <!-- options-end -->
115 *
116 * @author Xin Xu (xx5@cs.waikato.ac.nz)
117 * @version $Revision: 5928 $
118 */
119public class Logistic extends AbstractClassifier
120  implements OptionHandler, WeightedInstancesHandler, TechnicalInformationHandler {
121 
122  /** for serialization */
123  static final long serialVersionUID = 3932117032546553727L;
124 
125  /** The coefficients (optimized parameters) of the model */
126  protected double [][] m_Par;
127   
128  /** The data saved as a matrix */
129  protected double [][] m_Data;
130   
131  /** The number of attributes in the model */
132  protected int m_NumPredictors;
133   
134  /** The index of the class attribute */
135  protected int m_ClassIndex;
136   
137  /** The number of the class labels */
138  protected int m_NumClasses;
139   
140  /** The ridge parameter. */
141  protected double m_Ridge = 1e-8;
142   
143  /** An attribute filter */
144  private RemoveUseless m_AttFilter;
145   
146  /** The filter used to make attributes numeric. */
147  private NominalToBinary m_NominalToBinary;
148   
149  /** The filter used to get rid of missing values. */
150  private ReplaceMissingValues m_ReplaceMissingValues;
151   
152  /** Debugging output */
153  protected boolean m_Debug;
154
155  /** Log-likelihood of the searched model */
156  protected double m_LL;
157   
158  /** The maximum number of iterations. */
159  private int m_MaxIts = -1;
160
161  private Instances m_structure;
162   
163  /**
164   * Returns a string describing this classifier
165   * @return a description of the classifier suitable for
166   * displaying in the explorer/experimenter gui
167   */
168  public String globalInfo() {
169    return "Class for building and using a multinomial logistic "
170      +"regression model with a ridge estimator.\n\n"
171      +"There are some modifications, however, compared to the paper of "
172      +"leCessie and van Houwelingen(1992): \n\n" 
173      +"If there are k classes for n instances with m attributes, the "
174      +"parameter matrix B to be calculated will be an m*(k-1) matrix.\n\n"
175      +"The probability for class j with the exception of the last class is\n\n"
176      +"Pj(Xi) = exp(XiBj)/((sum[j=1..(k-1)]exp(Xi*Bj))+1) \n\n"
177      +"The last class has probability\n\n"
178      +"1-(sum[j=1..(k-1)]Pj(Xi)) \n\t= 1/((sum[j=1..(k-1)]exp(Xi*Bj))+1)\n\n"
179      +"The (negative) multinomial log-likelihood is thus: \n\n"
180      +"L = -sum[i=1..n]{\n\tsum[j=1..(k-1)](Yij * ln(Pj(Xi)))"
181      +"\n\t+(1 - (sum[j=1..(k-1)]Yij)) \n\t* ln(1 - sum[j=1..(k-1)]Pj(Xi))"
182      +"\n\t} + ridge * (B^2)\n\n"
183      +"In order to find the matrix B for which L is minimised, a "
184      +"Quasi-Newton Method is used to search for the optimized values of "
185      +"the m*(k-1) variables.  Note that before we use the optimization "
186      +"procedure, we 'squeeze' the matrix B into a m*(k-1) vector.  For "
187      +"details of the optimization procedure, please check "
188      +"weka.core.Optimization class.\n\n"
189      +"Although original Logistic Regression does not deal with instance "
190      +"weights, we modify the algorithm a little bit to handle the "
191      +"instance weights.\n\n"
192      +"For more information see:\n\n"
193      + getTechnicalInformation().toString() + "\n\n"
194      +"Note: Missing values are replaced using a ReplaceMissingValuesFilter, and "
195      +"nominal attributes are transformed into numeric attributes using a "
196      +"NominalToBinaryFilter.";
197  }
198
199  /**
200   * Returns an instance of a TechnicalInformation object, containing
201   * detailed information about the technical background of this class,
202   * e.g., paper reference or book this class is based on.
203   *
204   * @return the technical information about this class
205   */
206  public TechnicalInformation getTechnicalInformation() {
207    TechnicalInformation        result;
208   
209    result = new TechnicalInformation(Type.ARTICLE);
210    result.setValue(Field.AUTHOR, "le Cessie, S. and van Houwelingen, J.C.");
211    result.setValue(Field.YEAR, "1992");
212    result.setValue(Field.TITLE, "Ridge Estimators in Logistic Regression");
213    result.setValue(Field.JOURNAL, "Applied Statistics");
214    result.setValue(Field.VOLUME, "41");
215    result.setValue(Field.NUMBER, "1");
216    result.setValue(Field.PAGES, "191-201");
217   
218    return result;
219  }
220
221  /**
222   * Returns an enumeration describing the available options
223   *
224   * @return an enumeration of all the available options
225   */
226  public Enumeration listOptions() {
227    Vector newVector = new Vector(3);
228    newVector.addElement(new Option("\tTurn on debugging output.",
229                                    "D", 0, "-D"));
230    newVector.addElement(new Option("\tSet the ridge in the log-likelihood.",
231                                    "R", 1, "-R <ridge>"));
232    newVector.addElement(new Option("\tSet the maximum number of iterations"+
233                                    " (default -1, until convergence).",
234                                    "M", 1, "-M <number>"));
235    return newVector.elements();
236  }
237   
238  /**
239   * Parses a given list of options. <p/>
240   *
241   <!-- options-start -->
242   * Valid options are: <p/>
243   *
244   * <pre> -D
245   *  Turn on debugging output.</pre>
246   *
247   * <pre> -R &lt;ridge&gt;
248   *  Set the ridge in the log-likelihood.</pre>
249   *
250   * <pre> -M &lt;number&gt;
251   *  Set the maximum number of iterations (default -1, until convergence).</pre>
252   *
253   <!-- options-end -->
254   *
255   * @param options the list of options as an array of strings
256   * @throws Exception if an option is not supported
257   */
258  public void setOptions(String[] options) throws Exception {
259    setDebug(Utils.getFlag('D', options));
260
261    String ridgeString = Utils.getOption('R', options);
262    if (ridgeString.length() != 0) 
263      m_Ridge = Double.parseDouble(ridgeString);
264    else 
265      m_Ridge = 1.0e-8;
266       
267    String maxItsString = Utils.getOption('M', options);
268    if (maxItsString.length() != 0) 
269      m_MaxIts = Integer.parseInt(maxItsString);
270    else 
271      m_MaxIts = -1;
272  }
273   
274  /**
275   * Gets the current settings of the classifier.
276   *
277   * @return an array of strings suitable for passing to setOptions
278   */
279  public String [] getOptions() {
280       
281    String [] options = new String [5];
282    int current = 0;
283       
284    if (getDebug()) 
285      options[current++] = "-D";
286    options[current++] = "-R";
287    options[current++] = ""+m_Ridge;   
288    options[current++] = "-M";
289    options[current++] = ""+m_MaxIts;
290    while (current < options.length) 
291      options[current++] = "";
292    return options;
293  }
294   
295  /**
296   * Returns the tip text for this property
297   * @return tip text for this property suitable for
298   * displaying in the explorer/experimenter gui
299   */
300  public String debugTipText() {
301    return "Output debug information to the console.";
302  }
303
304  /**
305   * Sets whether debugging output will be printed.
306   *
307   * @param debug true if debugging output should be printed
308   */
309  public void setDebug(boolean debug) {
310    m_Debug = debug;
311  }
312   
313  /**
314   * Gets whether debugging output will be printed.
315   *
316   * @return true if debugging output will be printed
317   */
318  public boolean getDebug() {
319    return m_Debug;
320  }     
321
322  /**
323   * Returns the tip text for this property
324   * @return tip text for this property suitable for
325   * displaying in the explorer/experimenter gui
326   */
327  public String ridgeTipText() {
328    return "Set the Ridge value in the log-likelihood.";
329  }
330
331  /**
332   * Sets the ridge in the log-likelihood.
333   *
334   * @param ridge the ridge
335   */
336  public void setRidge(double ridge) {
337    m_Ridge = ridge;
338  }
339   
340  /**
341   * Gets the ridge in the log-likelihood.
342   *
343   * @return the ridge
344   */
345  public double getRidge() {
346    return m_Ridge;
347  }
348   
349  /**
350   * Returns the tip text for this property
351   * @return tip text for this property suitable for
352   * displaying in the explorer/experimenter gui
353   */
354  public String maxItsTipText() {
355    return "Maximum number of iterations to perform.";
356  }
357
358  /**
359   * Get the value of MaxIts.
360   *
361   * @return Value of MaxIts.
362   */
363  public int getMaxIts() {
364       
365    return m_MaxIts;
366  }
367   
368  /**
369   * Set the value of MaxIts.
370   *
371   * @param newMaxIts Value to assign to MaxIts.
372   */
373  public void setMaxIts(int newMaxIts) {
374       
375    m_MaxIts = newMaxIts;
376  }   
377   
378  private class OptEng extends Optimization{
379    /** Weights of instances in the data */
380    private double[] weights;
381
382    /** Class labels of instances */
383    private int[] cls;
384       
385    /**
386     * Set the weights of instances
387     * @param w the weights to be set
388     */ 
389    public void setWeights(double[] w) {
390      weights = w;
391    }
392       
393    /**
394     * Set the class labels of instances
395     * @param c the class labels to be set
396     */ 
397    public void setClassLabels(int[] c) {
398      cls = c;
399    }
400       
401    /**
402     * Evaluate objective function
403     * @param x the current values of variables
404     * @return the value of the objective function
405     */
406    protected double objectiveFunction(double[] x){
407      double nll = 0; // -LogLikelihood
408      int dim = m_NumPredictors+1; // Number of variables per class
409           
410      for(int i=0; i<cls.length; i++){ // ith instance
411
412        double[] exp = new double[m_NumClasses-1];
413        int index;
414        for(int offset=0; offset<m_NumClasses-1; offset++){ 
415          index = offset * dim;
416          for(int j=0; j<dim; j++)
417            exp[offset] += m_Data[i][j]*x[index + j];
418        }
419        double max = exp[Utils.maxIndex(exp)];
420        double denom = Math.exp(-max);
421        double num;
422        if (cls[i] == m_NumClasses - 1) { // Class of this instance
423          num = -max;
424        } else {
425          num = exp[cls[i]] - max;
426        }
427        for(int offset=0; offset<m_NumClasses-1; offset++){
428          denom += Math.exp(exp[offset] - max);
429        }
430               
431        nll -= weights[i]*(num - Math.log(denom)); // Weighted NLL
432      }
433           
434      // Ridge: note that intercepts NOT included
435      for(int offset=0; offset<m_NumClasses-1; offset++){
436        for(int r=1; r<dim; r++)
437          nll += m_Ridge*x[offset*dim+r]*x[offset*dim+r];
438      }
439           
440      return nll;
441    }
442
443    /**
444     * Evaluate Jacobian vector
445     * @param x the current values of variables
446     * @return the gradient vector
447     */
448    protected double[] evaluateGradient(double[] x){
449      double[] grad = new double[x.length];
450      int dim = m_NumPredictors+1; // Number of variables per class
451           
452      for(int i=0; i<cls.length; i++){ // ith instance
453        double[] num=new double[m_NumClasses-1]; // numerator of [-log(1+sum(exp))]'
454        int index;
455        for(int offset=0; offset<m_NumClasses-1; offset++){ // Which part of x
456          double exp=0.0;
457          index = offset * dim;
458          for(int j=0; j<dim; j++)
459            exp += m_Data[i][j]*x[index + j];
460          num[offset] = exp;
461        }
462
463        double max = num[Utils.maxIndex(num)];
464        double denom = Math.exp(-max); // Denominator of [-log(1+sum(exp))]'
465        for(int offset=0; offset<m_NumClasses-1; offset++){
466          num[offset] = Math.exp(num[offset] - max);
467          denom += num[offset];
468        }
469        Utils.normalize(num, denom);
470               
471        // Update denominator of the gradient of -log(Posterior)
472        double firstTerm;
473        for(int offset=0; offset<m_NumClasses-1; offset++){ // Which part of x
474          index = offset * dim;
475          firstTerm = weights[i] * num[offset];
476          for(int q=0; q<dim; q++){
477            grad[index + q] += firstTerm * m_Data[i][q];
478          }
479        }
480               
481        if(cls[i] != m_NumClasses-1){ // Not the last class
482          for(int p=0; p<dim; p++){
483            grad[cls[i]*dim+p] -= weights[i]*m_Data[i][p]; 
484          }
485        }
486      }
487           
488      // Ridge: note that intercepts NOT included
489      for(int offset=0; offset<m_NumClasses-1; offset++){
490        for(int r=1; r<dim; r++)
491          grad[offset*dim+r] += 2*m_Ridge*x[offset*dim+r];
492      }
493           
494      return grad;
495    }
496   
497    /**
498     * Returns the revision string.
499     *
500     * @return          the revision
501     */
502    public String getRevision() {
503      return RevisionUtils.extract("$Revision: 5928 $");
504    }
505  }
506
507  /**
508   * Returns default capabilities of the classifier.
509   *
510   * @return      the capabilities of this classifier
511   */
512  public Capabilities getCapabilities() {
513    Capabilities result = super.getCapabilities();
514    result.disableAll();
515
516    // attributes
517    result.enable(Capability.NOMINAL_ATTRIBUTES);
518    result.enable(Capability.NUMERIC_ATTRIBUTES);
519    result.enable(Capability.DATE_ATTRIBUTES);
520    result.enable(Capability.MISSING_VALUES);
521
522    // class
523    result.enable(Capability.NOMINAL_CLASS);
524    result.enable(Capability.MISSING_CLASS_VALUES);
525   
526    return result;
527  }
528   
529  /**
530   * Builds the classifier
531   *
532   * @param train the training data to be used for generating the
533   * boosted classifier.
534   * @throws Exception if the classifier could not be built successfully
535   */
536  public void buildClassifier(Instances train) throws Exception {
537    // can classifier handle the data?
538    getCapabilities().testWithFail(train);
539
540    // remove instances with missing class
541    train = new Instances(train);
542    train.deleteWithMissingClass();
543   
544    // Replace missing values   
545    m_ReplaceMissingValues = new ReplaceMissingValues();
546    m_ReplaceMissingValues.setInputFormat(train);
547    train = Filter.useFilter(train, m_ReplaceMissingValues);
548
549    // Remove useless attributes
550    m_AttFilter = new RemoveUseless();
551    m_AttFilter.setInputFormat(train);
552    train = Filter.useFilter(train, m_AttFilter);
553       
554    // Transform attributes
555    m_NominalToBinary = new NominalToBinary();
556    m_NominalToBinary.setInputFormat(train);
557    train = Filter.useFilter(train, m_NominalToBinary);
558   
559    // Save the structure for printing the model
560    m_structure = new Instances(train, 0);
561       
562    // Extract data
563    m_ClassIndex = train.classIndex();
564    m_NumClasses = train.numClasses();
565
566    int nK = m_NumClasses - 1;                     // Only K-1 class labels needed
567    int nR = m_NumPredictors = train.numAttributes() - 1;
568    int nC = train.numInstances();
569       
570    m_Data = new double[nC][nR + 1];               // Data values
571    int [] Y  = new int[nC];                       // Class labels
572    double [] xMean= new double[nR + 1];           // Attribute means
573    double [] xSD  = new double[nR + 1];           // Attribute stddev's
574    double [] sY = new double[nK + 1];             // Number of classes
575    double [] weights = new double[nC];            // Weights of instances
576    double totWeights = 0;                         // Total weights of the instances
577    m_Par = new double[nR + 1][nK];                // Optimized parameter values
578       
579    if (m_Debug) {
580      System.out.println("Extracting data...");
581    }
582       
583    for (int i = 0; i < nC; i++) {
584      // initialize X[][]
585      Instance current = train.instance(i);
586      Y[i] = (int)current.classValue();  // Class value starts from 0
587      weights[i] = current.weight();     // Dealing with weights
588      totWeights += weights[i];
589           
590      m_Data[i][0] = 1;
591      int j = 1;
592      for (int k = 0; k <= nR; k++) {
593        if (k != m_ClassIndex) {
594          double x = current.value(k);
595          m_Data[i][j] = x;
596          xMean[j] += weights[i]*x;
597          xSD[j] += weights[i]*x*x;
598          j++;
599        }
600      }
601           
602      // Class count
603      sY[Y[i]]++;       
604    }
605       
606    if((totWeights <= 1) && (nC > 1))
607      throw new Exception("Sum of weights of instances less than 1, please reweight!");
608
609    xMean[0] = 0; xSD[0] = 1;
610    for (int j = 1; j <= nR; j++) {
611      xMean[j] = xMean[j] / totWeights;
612      if(totWeights > 1)
613        xSD[j] = Math.sqrt(Math.abs(xSD[j] - totWeights*xMean[j]*xMean[j])/(totWeights-1));
614      else
615        xSD[j] = 0;
616    }
617
618    if (m_Debug) {         
619      // Output stats about input data
620      System.out.println("Descriptives...");
621      for (int m = 0; m <= nK; m++)
622        System.out.println(sY[m] + " cases have class " + m);
623      System.out.println("\n Variable     Avg       SD    ");
624      for (int j = 1; j <= nR; j++) 
625        System.out.println(Utils.doubleToString(j,8,4) 
626                           + Utils.doubleToString(xMean[j], 10, 4) 
627                           + Utils.doubleToString(xSD[j], 10, 4)
628                           );
629    }
630       
631    // Normalise input data
632    for (int i = 0; i < nC; i++) {
633      for (int j = 0; j <= nR; j++) {
634        if (xSD[j] != 0) {
635          m_Data[i][j] = (m_Data[i][j] - xMean[j]) / xSD[j];
636        }
637      }
638    }
639       
640    if (m_Debug) {
641      System.out.println("\nIteration History..." );
642    }
643       
644    double x[] = new double[(nR+1)*nK];
645    double[][] b = new double[2][x.length]; // Boundary constraints, N/A here
646
647    // Initialize
648    for(int p=0; p<nK; p++){
649      int offset=p*(nR+1);       
650      x[offset] =  Math.log(sY[p]+1.0) - Math.log(sY[nK]+1.0); // Null model
651      b[0][offset] = Double.NaN;
652      b[1][offset] = Double.NaN;   
653      for (int q=1; q <= nR; q++){
654        x[offset+q] = 0.0;             
655        b[0][offset+q] = Double.NaN;
656        b[1][offset+q] = Double.NaN;
657      } 
658    }
659       
660    OptEng opt = new OptEng(); 
661    opt.setDebug(m_Debug);
662    opt.setWeights(weights);
663    opt.setClassLabels(Y);
664
665    if(m_MaxIts == -1){  // Search until convergence
666      x = opt.findArgmin(x, b);
667      while(x==null){
668        x = opt.getVarbValues();
669        if (m_Debug)
670          System.out.println("200 iterations finished, not enough!");
671        x = opt.findArgmin(x, b);
672      }
673      if (m_Debug)
674        System.out.println(" -------------<Converged>--------------");
675    }
676    else{
677      opt.setMaxIteration(m_MaxIts);
678      x = opt.findArgmin(x, b);
679      if(x==null) // Not enough, but use the current value
680        x = opt.getVarbValues();
681    }
682       
683    m_LL = -opt.getMinFunction(); // Log-likelihood
684
685    // Don't need data matrix anymore
686    m_Data = null;
687           
688    // Convert coefficients back to non-normalized attribute units
689    for(int i=0; i < nK; i++){
690      m_Par[0][i] = x[i*(nR+1)];
691      for(int j = 1; j <= nR; j++) {
692        m_Par[j][i] = x[i*(nR+1)+j];
693        if (xSD[j] != 0) {
694          m_Par[j][i] /= xSD[j];
695          m_Par[0][i] -= m_Par[j][i] * xMean[j];
696        }
697      }
698    }
699  }             
700   
701  /**
702   * Computes the distribution for a given instance
703   *
704   * @param instance the instance for which distribution is computed
705   * @return the distribution
706   * @throws Exception if the distribution can't be computed successfully
707   */
708  public double [] distributionForInstance(Instance instance) 
709    throws Exception {
710       
711    m_ReplaceMissingValues.input(instance);
712    instance = m_ReplaceMissingValues.output();
713    m_AttFilter.input(instance);
714    instance = m_AttFilter.output();
715    m_NominalToBinary.input(instance);
716    instance = m_NominalToBinary.output();
717       
718    // Extract the predictor columns into an array
719    double [] instDat = new double [m_NumPredictors + 1];
720    int j = 1;
721    instDat[0] = 1;
722    for (int k = 0; k <= m_NumPredictors; k++) {
723      if (k != m_ClassIndex) {
724        instDat[j++] = instance.value(k);
725      }
726    }
727       
728    double [] distribution = evaluateProbability(instDat);
729    return distribution;
730  }
731
732  /**
733   * Compute the posterior distribution using optimized parameter values
734   * and the testing instance.
735   * @param data the testing instance
736   * @return the posterior probability distribution
737   */ 
738  private double[] evaluateProbability(double[] data){
739    double[] prob = new double[m_NumClasses],
740      v = new double[m_NumClasses];
741
742    // Log-posterior before normalizing
743    for(int j = 0; j < m_NumClasses-1; j++){
744      for(int k = 0; k <= m_NumPredictors; k++){
745        v[j] += m_Par[k][j] * data[k];
746      }
747    }
748    v[m_NumClasses-1] = 0;
749       
750    // Do so to avoid scaling problems
751    for(int m=0; m < m_NumClasses; m++){
752      double sum = 0;
753      for(int n=0; n < m_NumClasses-1; n++)
754        sum += Math.exp(v[n] - v[m]);
755      prob[m] = 1 / (sum + Math.exp(-v[m]));
756    }
757       
758    return prob;
759  } 
760
761  /**
762   * Returns the coefficients for this logistic model.
763   * The first dimension indexes the attributes, and
764   * the second the classes.
765   *
766   * @return the coefficients for this logistic model
767   */
768  public double [][] coefficients() {
769    return m_Par;
770  }
771   
772  /**
773   * Gets a string describing the classifier.
774   *
775   * @return a string describing the classifer built.
776   */
777  public String toString() {
778    StringBuffer temp = new StringBuffer();
779
780    String result = "";
781    temp.append("Logistic Regression with ridge parameter of " + m_Ridge);
782    if (m_Par == null) {
783      return result + ": No model built yet.";
784    }
785
786    // find longest attribute name
787    int attLength = 0;
788    for (int i = 0; i < m_structure.numAttributes(); i++) {
789      if (i != m_structure.classIndex() && 
790          m_structure.attribute(i).name().length() > attLength) {
791        attLength = m_structure.attribute(i).name().length();
792      }
793    }
794
795    if ("Intercept".length() > attLength) {
796      attLength = "Intercept".length();
797    }
798
799    if ("Variable".length() > attLength) {
800      attLength = "Variable".length();
801    }
802    attLength += 2;
803
804    int colWidth = 0;
805    // check length of class names
806    for (int i = 0; i < m_structure.classAttribute().numValues() - 1; i++) {
807      if (m_structure.classAttribute().value(i).length() > colWidth) {
808        colWidth = m_structure.classAttribute().value(i).length();
809      }
810    }
811
812    // check against coefficients and odds ratios
813    for (int j = 1; j <= m_NumPredictors; j++) {
814      for (int k = 0; k < m_NumClasses - 1; k++) {
815        if (Utils.doubleToString(m_Par[j][k], 12, 4).trim().length() > colWidth) {
816          colWidth = Utils.doubleToString(m_Par[j][k], 12, 4).trim().length();
817        }
818        double ORc = Math.exp(m_Par[j][k]);
819        String t = " " + ((ORc > 1e10) ?  "" + ORc : Utils.doubleToString(ORc, 12, 4));
820        if (t.trim().length() > colWidth) {
821          colWidth = t.trim().length();
822        }
823      }
824    }
825
826    if ("Class".length() > colWidth) {
827      colWidth = "Class".length();
828    }
829    colWidth += 2;
830   
831   
832    temp.append("\nCoefficients...\n");
833    temp.append(Utils.padLeft(" ", attLength) + Utils.padLeft("Class", colWidth) + "\n");
834    temp.append(Utils.padRight("Variable", attLength));
835
836    for (int i = 0; i < m_NumClasses - 1; i++) {
837      String className = m_structure.classAttribute().value(i);
838      temp.append(Utils.padLeft(className, colWidth));
839    }
840    temp.append("\n");
841    int separatorL = attLength + ((m_NumClasses - 1) * colWidth);
842    for (int i = 0; i < separatorL; i++) {
843      temp.append("=");
844    }
845    temp.append("\n");
846               
847    int j = 1;
848    for (int i = 0; i < m_structure.numAttributes(); i++) {
849      if (i != m_structure.classIndex()) {
850        temp.append(Utils.padRight(m_structure.attribute(i).name(), attLength));
851        for (int k = 0; k < m_NumClasses-1; k++) {
852          temp.append(Utils.padLeft(Utils.doubleToString(m_Par[j][k], 12, 4).trim(), colWidth));
853        }
854        temp.append("\n");
855        j++;
856      }
857    }
858       
859    temp.append(Utils.padRight("Intercept", attLength));
860    for (int k = 0; k < m_NumClasses-1; k++) {
861      temp.append(Utils.padLeft(Utils.doubleToString(m_Par[0][k], 10, 4).trim(), colWidth)); 
862    }
863    temp.append("\n");
864       
865    temp.append("\n\nOdds Ratios...\n");
866    temp.append(Utils.padLeft(" ", attLength) + Utils.padLeft("Class", colWidth) + "\n");
867    temp.append(Utils.padRight("Variable", attLength));
868
869    for (int i = 0; i < m_NumClasses - 1; i++) {
870      String className = m_structure.classAttribute().value(i);
871      temp.append(Utils.padLeft(className, colWidth));
872    }
873    temp.append("\n");
874    for (int i = 0; i < separatorL; i++) {
875      temp.append("=");
876    }
877    temp.append("\n");
878
879    j = 1;
880    for (int i = 0; i < m_structure.numAttributes(); i++) {
881      if (i != m_structure.classIndex()) {
882        temp.append(Utils.padRight(m_structure.attribute(i).name(), attLength));
883        for (int k = 0; k < m_NumClasses-1; k++) {
884          double ORc = Math.exp(m_Par[j][k]);
885          String ORs = " " + ((ORc > 1e10) ?  "" + ORc : Utils.doubleToString(ORc, 12, 4));
886          temp.append(Utils.padLeft(ORs.trim(), colWidth));
887        }
888        temp.append("\n");
889        j++;
890      }
891    }
892
893    return temp.toString();
894  }
895 
896  /**
897   * Returns the revision string.
898   *
899   * @return            the revision
900   */
901  public String getRevision() {
902    return RevisionUtils.extract("$Revision: 5928 $");
903  }
904   
905  /**
906   * Main method for testing this class.
907   *
908   * @param argv should contain the command line arguments to the
909   * scheme (see Evaluation)
910   */
911  public static void main(String [] argv) {
912    runClassifier(new Logistic(), argv);
913  }
914}
Note: See TracBrowser for help on using the repository browser.