source: src/main/java/weka/classifiers/mi/MILR.java @ 8

Last change on this file since 8 was 4, checked in by gnappo, 14 years ago

Import di weka.

File size: 24.0 KB
Line 
1/*
2 *    This program is free software; you can redistribute it and/or modify
3 *    it under the terms of the GNU General Public License as published by
4 *    the Free Software Foundation; either version 2 of the License, or
5 *    (at your option) any later version.
6 *
7 *    This program is distributed in the hope that it will be useful,
8 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
9 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
10 *    GNU General Public License for more details.
11 *
12 *    You should have received a copy of the GNU General Public License
13 *    along with this program; if not, write to the Free Software
14 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
15 */
16
17/*
18 * MILR.java
19 * Copyright (C) 2005 University of Waikato, Hamilton, New Zealand
20 *
21 */
22
23package weka.classifiers.mi;
24
25import weka.classifiers.Classifier;
26import weka.classifiers.AbstractClassifier;
27import weka.core.Capabilities;
28import weka.core.Instance;
29import weka.core.Instances;
30import weka.core.MultiInstanceCapabilitiesHandler;
31import weka.core.Optimization;
32import weka.core.Option;
33import weka.core.OptionHandler;
34import weka.core.RevisionUtils;
35import weka.core.SelectedTag;
36import weka.core.Tag;
37import weka.core.Utils;
38import weka.core.Capabilities.Capability;
39
40import java.util.Enumeration;
41import java.util.Vector;
42
43/**
44 <!-- globalinfo-start -->
45 * Uses either standard or collective multi-instance assumption, but within linear regression. For the collective assumption, it offers arithmetic or geometric mean for the posteriors.
46 * <p/>
47 <!-- globalinfo-end -->
48 *
49 <!-- options-start -->
50 * Valid options are: <p/>
51 *
52 * <pre> -D
53 *  Turn on debugging output.</pre>
54 *
55 * <pre> -R &lt;ridge&gt;
56 *  Set the ridge in the log-likelihood.</pre>
57 *
58 * <pre> -A [0|1|2]
59 *  Defines the type of algorithm:
60 *   0. standard MI assumption
61 *   1. collective MI assumption, arithmetic mean for posteriors
62 *   2. collective MI assumption, geometric mean for posteriors</pre>
63 *
64 <!-- options-end -->
65 *
66 * @author Eibe Frank (eibe@cs.waikato.ac.nz)
67 * @author Xin Xu (xx5@cs.waikato.ac.nz)
68 * @version $Revision: 5928 $
69 */
70public class MILR
71  extends AbstractClassifier
72  implements OptionHandler, MultiInstanceCapabilitiesHandler {
73
74  /** for serialization */
75  static final long serialVersionUID = 1996101190172373826L;
76 
77  protected double[] m_Par;
78
79  /** The number of the class labels */
80  protected int m_NumClasses;
81
82  /** The ridge parameter. */
83  protected double m_Ridge = 1e-6;
84
85  /** Class labels for each bag */
86  protected int[] m_Classes;
87
88  /** MI data */ 
89  protected double[][][] m_Data;
90
91  /** All attribute names */
92  protected Instances m_Attributes;
93
94  protected double[] xMean = null, xSD = null;
95
96  /** the type of processing */
97  protected int m_AlgorithmType = ALGORITHMTYPE_DEFAULT;
98
99  /** standard MI assumption */
100  public static final int ALGORITHMTYPE_DEFAULT = 0;
101  /** collective MI assumption, arithmetic mean for posteriors */
102  public static final int ALGORITHMTYPE_ARITHMETIC = 1;
103  /** collective MI assumption, geometric mean for posteriors */
104  public static final int ALGORITHMTYPE_GEOMETRIC = 2;
105  /** the types of algorithms */
106  public static final Tag [] TAGS_ALGORITHMTYPE = {
107    new Tag(ALGORITHMTYPE_DEFAULT, "standard MI assumption"),
108    new Tag(ALGORITHMTYPE_ARITHMETIC, "collective MI assumption, arithmetic mean for posteriors"),
109    new Tag(ALGORITHMTYPE_GEOMETRIC, "collective MI assumption, geometric mean for posteriors"),
110  };
111
112  /**
113   * Returns the tip text for this property
114   *
115   * @return tip text for this property suitable for
116   * displaying in the explorer/experimenter gui
117   */
118  public String globalInfo() {
119    return 
120        "Uses either standard or collective multi-instance assumption, but "
121      + "within linear regression. For the collective assumption, it offers "
122      + "arithmetic or geometric mean for the posteriors.";
123  }
124
125  /**
126   * Returns an enumeration describing the available options
127   *
128   * @return an enumeration of all the available options
129   */
130  public Enumeration listOptions() {
131    Vector result = new Vector();
132   
133    result.addElement(new Option(
134          "\tTurn on debugging output.",
135          "D", 0, "-D"));
136   
137    result.addElement(new Option(
138        "\tSet the ridge in the log-likelihood.",
139        "R", 1, "-R <ridge>"));
140
141    result.addElement(new Option(
142        "\tDefines the type of algorithm:\n"
143        + "\t 0. standard MI assumption\n"
144        + "\t 1. collective MI assumption, arithmetic mean for posteriors\n"
145        + "\t 2. collective MI assumption, geometric mean for posteriors",
146        "A", 1, "-A [0|1|2]"));
147
148    return result.elements();
149  }
150
151  /**
152   * Parses a given list of options.
153   *
154   * @param options the list of options as an array of strings
155   * @throws Exception if an option is not supported
156   */
157  public void setOptions(String[] options) throws Exception {
158    String      tmpStr;
159
160    setDebug(Utils.getFlag('D', options));
161
162    tmpStr = Utils.getOption('R', options);
163    if (tmpStr.length() != 0) 
164      setRidge(Double.parseDouble(tmpStr));
165    else 
166      setRidge(1.0e-6);
167
168    tmpStr = Utils.getOption('A', options);
169    if (tmpStr.length() != 0) {
170      setAlgorithmType(new SelectedTag(Integer.parseInt(tmpStr), TAGS_ALGORITHMTYPE));
171    } else {
172      setAlgorithmType(new SelectedTag(ALGORITHMTYPE_DEFAULT, TAGS_ALGORITHMTYPE));
173    }     
174  }
175
176  /**
177   * Gets the current settings of the classifier.
178   *
179   * @return an array of strings suitable for passing to setOptions
180   */
181  public String[] getOptions() {
182    Vector        result;
183   
184    result = new Vector();
185
186    if (getDebug())
187      result.add("-D");
188   
189    result.add("-R");
190    result.add("" + getRidge());
191   
192    result.add("-A");
193    result.add("" + m_AlgorithmType);
194
195    return (String[]) result.toArray(new String[result.size()]);
196  }
197
198  /**
199   * Returns the tip text for this property
200   *
201   * @return tip text for this property suitable for
202   * displaying in the explorer/experimenter gui
203   */
204  public String ridgeTipText() {
205    return "The ridge in the log-likelihood.";
206  }
207
208  /**
209   * Sets the ridge in the log-likelihood.
210   *
211   * @param ridge the ridge
212   */
213  public void setRidge(double ridge) {
214    m_Ridge = ridge;
215  }
216
217  /**
218   * Gets the ridge in the log-likelihood.
219   *
220   * @return the ridge
221   */
222  public double getRidge() {
223    return m_Ridge;
224  }
225
226  /**
227   * Returns the tip text for this property
228   *
229   * @return tip text for this property suitable for
230   * displaying in the explorer/experimenter gui
231   */
232  public String algorithmTypeTipText() {
233    return "The mean type for the posteriors.";
234  }
235
236  /**
237   * Gets the type of algorithm.
238   *
239   * @return the algorithm type
240   */
241  public SelectedTag getAlgorithmType() {
242    return new SelectedTag(m_AlgorithmType, TAGS_ALGORITHMTYPE);
243  }
244
245  /**
246   * Sets the algorithm type.
247   *
248   * @param newType the new algorithm type
249   */
250  public void setAlgorithmType(SelectedTag newType) {
251    if (newType.getTags() == TAGS_ALGORITHMTYPE) {
252      m_AlgorithmType = newType.getSelectedTag().getID();
253    }
254  }
255
256  private class OptEng 
257    extends Optimization {
258   
259    /** the type to use
260     * @see MILR#TAGS_ALGORITHMTYPE */
261    private int m_Type;
262   
263    /**
264     * initializes the object
265     *
266     * @param type      the type top use
267     * @see MILR#TAGS_ALGORITHMTYPE
268     */
269    public OptEng(int type) {
270      super();
271     
272      m_Type = type;
273    }
274   
275    /**
276     * Evaluate objective function
277     * @param x the current values of variables
278     * @return the value of the objective function
279     */
280    protected double objectiveFunction(double[] x){
281      double nll = 0; // -LogLikelihood
282     
283      switch (m_Type) {
284        case ALGORITHMTYPE_DEFAULT:
285          for(int i=0; i<m_Classes.length; i++){ // ith bag
286            int nI = m_Data[i][0].length; // numInstances in ith bag
287            double bag = 0.0, // NLL of each bag
288                   prod = 0.0;   // Log-prob.
289
290            for(int j=0; j<nI; j++){
291              double exp=0.0;
292              for(int k=m_Data[i].length-1; k>=0; k--)
293                exp += m_Data[i][k][j]*x[k+1];
294              exp += x[0];
295              exp = Math.exp(exp);
296
297              if(m_Classes[i]==1)
298                prod -= Math.log(1.0+exp);
299              else
300                bag += Math.log(1.0+exp);
301            }
302
303            if(m_Classes[i]==1)
304              bag = -Math.log(1.0-Math.exp(prod));
305
306            nll += bag;
307          }   
308          break;
309       
310        case ALGORITHMTYPE_ARITHMETIC:
311          for(int i=0; i<m_Classes.length; i++){ // ith bag
312            int nI = m_Data[i][0].length; // numInstances in ith bag
313            double bag = 0;  // NLL of each bag
314
315            for(int j=0; j<nI; j++){
316              double exp=0.0;
317              for(int k=m_Data[i].length-1; k>=0; k--)
318                exp += m_Data[i][k][j]*x[k+1];
319              exp += x[0];
320              exp = Math.exp(exp);
321
322              if(m_Classes[i] == 1)
323                bag += 1.0-1.0/(1.0+exp); // To avoid exp infinite
324              else
325                bag += 1.0/(1.0+exp);                 
326            }   
327            bag /= (double)nI;
328
329            nll -= Math.log(bag);
330          }   
331          break;
332         
333        case ALGORITHMTYPE_GEOMETRIC:
334          for(int i=0; i<m_Classes.length; i++){ // ith bag
335            int nI = m_Data[i][0].length; // numInstances in ith bag
336            double bag = 0;   // Log-prob.
337
338            for(int j=0; j<nI; j++){
339              double exp=0.0;
340              for(int k=m_Data[i].length-1; k>=0; k--)
341                exp += m_Data[i][k][j]*x[k+1];
342              exp += x[0];
343
344              if(m_Classes[i]==1)
345                bag -= exp/(double)nI;
346              else
347                bag += exp/(double)nI;
348            }
349
350            nll += Math.log(1.0+Math.exp(bag));
351          }   
352          break;
353      }
354
355      // ridge: note that intercepts NOT included
356      for(int r=1; r<x.length; r++)
357        nll += m_Ridge*x[r]*x[r];
358
359      return nll;
360    }
361
362    /**
363     * Evaluate Jacobian vector
364     * @param x the current values of variables
365     * @return the gradient vector
366     */
367    protected double[] evaluateGradient(double[] x){
368      double[] grad = new double[x.length];
369     
370      switch (m_Type) {
371        case ALGORITHMTYPE_DEFAULT:
372          for(int i=0; i<m_Classes.length; i++){ // ith bag
373            int nI = m_Data[i][0].length; // numInstances in ith bag
374
375            double denom = 0.0; // denominator, in log-scale       
376            double[] bag = new double[grad.length]; //gradient update with ith bag
377
378            for(int j=0; j<nI; j++){
379              // Compute exp(b0+b1*Xi1j+...)/[1+exp(b0+b1*Xi1j+...)]
380              double exp=0.0;
381              for(int k=m_Data[i].length-1; k>=0; k--)
382                exp += m_Data[i][k][j]*x[k+1];
383              exp += x[0];
384              exp = Math.exp(exp)/(1.0+Math.exp(exp));
385
386              if(m_Classes[i]==1)
387                // Bug fix: it used to be denom += Math.log(1.0+exp);
388                // Fixed 21 Jan 2005 (Eibe)
389                denom -= Math.log(1.0-exp);
390
391              // Instance-wise update of dNLL/dBk
392              for(int p=0; p<x.length; p++){  // pth variable
393                double m = 1.0;
394                if(p>0) m=m_Data[i][p-1][j];
395                bag[p] += m*exp;
396              }     
397            }
398
399            denom = Math.exp(denom);
400
401            // Bag-wise update of dNLL/dBk
402            for(int q=0; q<grad.length; q++){
403              if(m_Classes[i]==1)
404                grad[q] -= bag[q]/(denom-1.0);
405              else
406                grad[q] += bag[q];
407            }   
408          }
409          break;
410       
411        case ALGORITHMTYPE_ARITHMETIC:
412          for(int i=0; i<m_Classes.length; i++){ // ith bag
413            int nI = m_Data[i][0].length; // numInstances in ith bag
414
415            double denom=0.0;
416            double[] numrt = new double[x.length];
417
418            for(int j=0; j<nI; j++){
419              // Compute exp(b0+b1*Xi1j+...)/[1+exp(b0+b1*Xi1j+...)]
420              double exp=0.0;
421              for(int k=m_Data[i].length-1; k>=0; k--)
422                exp += m_Data[i][k][j]*x[k+1];
423              exp += x[0];
424              exp = Math.exp(exp);
425              if(m_Classes[i]==1)
426                denom += exp/(1.0+exp);
427              else
428                denom += 1.0/(1.0+exp);     
429
430              // Instance-wise update of dNLL/dBk
431              for(int p=0; p<x.length; p++){  // pth variable
432                double m = 1.0;
433                if(p>0) m=m_Data[i][p-1][j];
434                numrt[p] += m*exp/((1.0+exp)*(1.0+exp));   
435              }     
436            }
437
438            // Bag-wise update of dNLL/dBk
439            for(int q=0; q<grad.length; q++){
440              if(m_Classes[i]==1)
441                grad[q] -= numrt[q]/denom;
442              else
443                grad[q] += numrt[q]/denom;         
444            }
445          }
446          break;
447
448        case ALGORITHMTYPE_GEOMETRIC:
449          for(int i=0; i<m_Classes.length; i++){ // ith bag
450            int nI = m_Data[i][0].length; // numInstances in ith bag   
451            double bag = 0;
452            double[] sumX = new double[x.length];
453            for(int j=0; j<nI; j++){
454              // Compute exp(b0+b1*Xi1j+...)/[1+exp(b0+b1*Xi1j+...)]
455              double exp=0.0;
456              for(int k=m_Data[i].length-1; k>=0; k--)
457                exp += m_Data[i][k][j]*x[k+1];
458              exp += x[0];
459
460              if(m_Classes[i]==1){
461                bag -= exp/(double)nI;
462                for(int q=0; q<grad.length; q++){
463                  double m = 1.0;
464                  if(q>0) m=m_Data[i][q-1][j];
465                  sumX[q] -= m/(double)nI;
466                }
467              }
468              else{
469                bag += exp/(double)nI;
470                for(int q=0; q<grad.length; q++){
471                  double m = 1.0;
472                  if(q>0) m=m_Data[i][q-1][j];
473                  sumX[q] += m/(double)nI;
474                }     
475              }
476            }
477
478            for(int p=0; p<x.length; p++)
479              grad[p] += Math.exp(bag)*sumX[p]/(1.0+Math.exp(bag));
480          }
481          break;
482      }
483
484      // ridge: note that intercepts NOT included
485      for(int r=1; r<x.length; r++){
486        grad[r] += 2.0*m_Ridge*x[r];
487      }
488
489      return grad;
490    }
491   
492    /**
493     * Returns the revision string.
494     *
495     * @return          the revision
496     */
497    public String getRevision() {
498      return RevisionUtils.extract("$Revision: 5928 $");
499    }
500  }
501
502  /**
503   * Returns default capabilities of the classifier.
504   *
505   * @return      the capabilities of this classifier
506   */
507  public Capabilities getCapabilities() {
508    Capabilities result = super.getCapabilities();
509    result.disableAll();
510
511    // attributes
512    result.enable(Capability.NOMINAL_ATTRIBUTES);
513    result.enable(Capability.RELATIONAL_ATTRIBUTES);
514    result.enable(Capability.MISSING_VALUES);
515
516    // class
517    result.enable(Capability.BINARY_CLASS);
518    result.enable(Capability.MISSING_CLASS_VALUES);
519   
520    // other
521    result.enable(Capability.ONLY_MULTIINSTANCE);
522   
523    return result;
524  }
525
526  /**
527   * Returns the capabilities of this multi-instance classifier for the
528   * relational data.
529   *
530   * @return            the capabilities of this object
531   * @see               Capabilities
532   */
533  public Capabilities getMultiInstanceCapabilities() {
534    Capabilities result = super.getCapabilities();
535    result.disableAll();
536   
537    // attributes
538    result.enable(Capability.NOMINAL_ATTRIBUTES);
539    result.enable(Capability.NUMERIC_ATTRIBUTES);
540    result.enable(Capability.DATE_ATTRIBUTES);
541    result.enable(Capability.MISSING_VALUES);
542
543    // class
544    result.disableAllClasses();
545    result.enable(Capability.NO_CLASS);
546   
547    return result;
548  }
549
550  /**
551   * Builds the classifier
552   *
553   * @param train the training data to be used for generating the
554   * boosted classifier.
555   * @throws Exception if the classifier could not be built successfully
556   */
557  public void buildClassifier(Instances train) throws Exception {
558    // can classifier handle the data?
559    getCapabilities().testWithFail(train);
560
561    // remove instances with missing class
562    train = new Instances(train);
563    train.deleteWithMissingClass();
564
565    m_NumClasses = train.numClasses();
566
567    int nR = train.attribute(1).relation().numAttributes();
568    int nC = train.numInstances();
569
570    m_Data  = new double [nC][nR][];              // Data values
571    m_Classes  = new int [nC];                    // Class values
572    m_Attributes = train.attribute(1).relation();
573
574    xMean = new double [nR];             // Mean of mean
575    xSD   = new double [nR];             // Mode of stddev
576
577    double sY1=0, sY0=0, totIns=0;                          // Number of classes
578    int[] missingbags = new int[nR];
579
580    if (m_Debug) {
581      System.out.println("Extracting data...");
582    }
583
584    for(int h=0; h<m_Data.length; h++){
585      Instance current = train.instance(h);
586      m_Classes[h] = (int)current.classValue();  // Class value starts from 0
587      Instances currInsts = current.relationalValue(1);
588      int nI = currInsts.numInstances();
589      totIns += (double)nI;
590
591      for (int i = 0; i < nR; i++) {           
592        // initialize m_data[][][]             
593        m_Data[h][i] = new double[nI];
594        double avg=0, std=0, num=0;
595        for (int k=0; k<nI; k++){
596          if(!currInsts.instance(k).isMissing(i)){
597            m_Data[h][i][k] = currInsts.instance(k).value(i);
598            avg += m_Data[h][i][k];
599            std += m_Data[h][i][k]*m_Data[h][i][k];
600            num++;
601          }
602          else
603            m_Data[h][i][k] = Double.NaN;
604        }
605       
606        if(num > 0){
607          xMean[i] += avg/num;
608          xSD[i] += std/num;
609        }
610        else
611          missingbags[i]++;
612      }     
613
614      // Class count   
615      if (m_Classes[h] == 1)
616        sY1++;
617      else
618        sY0++;
619    }
620
621    for (int j = 0; j < nR; j++) {
622      xMean[j] = xMean[j]/(double)(nC-missingbags[j]);
623      xSD[j] = Math.sqrt(Math.abs(xSD[j]/((double)(nC-missingbags[j])-1.0)
624            -xMean[j]*xMean[j]*(double)(nC-missingbags[j])/
625            ((double)(nC-missingbags[j])-1.0)));
626    }
627
628    if (m_Debug) {         
629      // Output stats about input data
630      System.out.println("Descriptives...");
631      System.out.println(sY0 + " bags have class 0 and " +
632          sY1 + " bags have class 1");
633      System.out.println("\n Variable     Avg       SD    ");
634      for (int j = 0; j < nR; j++) 
635        System.out.println(Utils.doubleToString(j,8,4) 
636            + Utils.doubleToString(xMean[j], 10, 4) 
637            + Utils.doubleToString(xSD[j], 10,4));
638    }
639
640    // Normalise input data and remove ignored attributes
641    for (int i = 0; i < nC; i++) {
642      for (int j = 0; j < nR; j++) {
643        for(int k=0; k < m_Data[i][j].length; k++){
644          if(xSD[j] != 0){
645            if(!Double.isNaN(m_Data[i][j][k]))
646              m_Data[i][j][k] = (m_Data[i][j][k] - xMean[j]) / xSD[j];
647            else
648              m_Data[i][j][k] = 0;
649          }
650        }
651      }
652    }
653
654    if (m_Debug) {
655      System.out.println("\nIteration History..." );
656    }
657
658    double x[] = new double[nR + 1];
659    x[0] =  Math.log((sY1+1.0) / (sY0+1.0));
660    double[][] b = new double[2][x.length];
661    b[0][0] = Double.NaN;
662    b[1][0] = Double.NaN;
663    for (int q=1; q < x.length;q++){
664      x[q] = 0.0;               
665      b[0][q] = Double.NaN;
666      b[1][q] = Double.NaN;
667    }
668
669    OptEng opt = new OptEng(m_AlgorithmType);   
670    opt.setDebug(m_Debug);
671    m_Par = opt.findArgmin(x, b);
672    while(m_Par==null){
673      m_Par = opt.getVarbValues();
674      if (m_Debug)
675        System.out.println("200 iterations finished, not enough!");
676      m_Par = opt.findArgmin(m_Par, b);
677    }
678    if (m_Debug)
679      System.out.println(" -------------<Converged>--------------");
680
681    // feature selection use
682    if (m_AlgorithmType == ALGORITHMTYPE_ARITHMETIC) {
683      double[] fs = new double[nR];
684      for(int k=1; k<nR+1; k++)
685        fs[k-1] = Math.abs(m_Par[k]);
686      int[] idx = Utils.sort(fs);
687      double max = fs[idx[idx.length-1]];
688      for(int k=idx.length-1; k>=0; k--)
689        System.out.println(m_Attributes.attribute(idx[k]).name()+"\t"+(fs[idx[k]]*100/max));
690    }
691
692    // Convert coefficients back to non-normalized attribute units
693    for(int j = 1; j < nR+1; j++) {
694      if (xSD[j-1] != 0) {
695        m_Par[j] /= xSD[j-1];
696        m_Par[0] -= m_Par[j] * xMean[j-1];
697      }
698    }
699  }             
700
701  /**
702   * Computes the distribution for a given exemplar
703   *
704   * @param exmp the exemplar for which distribution is computed
705   * @return the distribution
706   * @throws Exception if the distribution can't be computed successfully
707   */
708  public double[] distributionForInstance(Instance exmp) 
709    throws Exception {
710
711    // Extract the data
712    Instances ins = exmp.relationalValue(1);
713    int nI = ins.numInstances(), nA = ins.numAttributes();
714    double[][] dat = new double [nI][nA+1];
715    for(int j=0; j<nI; j++){
716      dat[j][0]=1.0;
717      int idx=1;
718      for(int k=0; k<nA; k++){ 
719        if(!ins.instance(j).isMissing(k))
720          dat[j][idx] = ins.instance(j).value(k);
721        else
722          dat[j][idx] = xMean[idx-1];
723        idx++;
724      }
725    }
726
727    // Compute the probability of the bag
728    double [] distribution = new double[2];
729    switch (m_AlgorithmType) {
730      case ALGORITHMTYPE_DEFAULT:
731        distribution[0]=0.0;  // Log-Prob. for class 0
732
733        for(int i=0; i<nI; i++){
734          double exp = 0.0; 
735          for(int r=0; r<m_Par.length; r++)
736            exp += m_Par[r]*dat[i][r];
737          exp = Math.exp(exp);
738
739          // Prob. updated for one instance
740          distribution[0] -= Math.log(1.0+exp);
741        }
742
743        // Prob. for class 0
744        distribution[0] = Math.exp(distribution[0]);
745        // Prob. for class 1
746        distribution[1] = 1.0 - distribution[0];
747        break;
748     
749      case ALGORITHMTYPE_ARITHMETIC:
750        distribution[0]=0.0;  // Prob. for class 0
751
752        for(int i=0; i<nI; i++){
753          double exp = 0.0;
754          for(int r=0; r<m_Par.length; r++)
755            exp += m_Par[r]*dat[i][r];
756          exp = Math.exp(exp);
757
758          // Prob. updated for one instance
759          distribution[0] += 1.0/(1.0+exp);
760        }
761
762        // Prob. for class 0
763        distribution[0] /= (double)nI;
764        // Prob. for class 1
765        distribution[1] = 1.0 - distribution[0];
766        break;
767
768      case ALGORITHMTYPE_GEOMETRIC:
769        for(int i=0; i<nI; i++){
770          double exp = 0.0;
771          for(int r=0; r<m_Par.length; r++)
772            exp += m_Par[r]*dat[i][r];
773          distribution[1] += exp/(double)nI; 
774        }
775
776        // Prob. for class 1
777        distribution[1] = 1.0/(1.0+Math.exp(-distribution[1]));
778        // Prob. for class 0
779        distribution[0] = 1-distribution[1];
780        break;
781    }
782
783    return distribution;
784  }
785
786  /**
787   * Gets a string describing the classifier.
788   *
789   * @return a string describing the classifer built.
790   */
791  public String toString() {
792
793    String result = "Modified Logistic Regression";
794    if (m_Par == null) {
795      return result + ": No model built yet.";
796    }
797
798    result += "\nMean type: " + getAlgorithmType().getSelectedTag().getReadable() + "\n";
799    result += "\nCoefficients...\n"
800      + "Variable      Coeff.\n";
801    for (int j = 1, idx=0; j < m_Par.length; j++, idx++) {
802      result += m_Attributes.attribute(idx).name();
803      result += " "+Utils.doubleToString(m_Par[j], 12, 4); 
804      result += "\n";
805    }
806
807    result += "Intercept:";
808    result += " "+Utils.doubleToString(m_Par[0], 10, 4); 
809    result += "\n";
810
811    result += "\nOdds Ratios...\n"
812      + "Variable         O.R.\n";
813    for (int j = 1, idx=0; j < m_Par.length; j++, idx++) {
814      result += " " + m_Attributes.attribute(idx).name(); 
815      double ORc = Math.exp(m_Par[j]);
816      result += " " + ((ORc > 1e10) ?  "" + ORc : Utils.doubleToString(ORc, 12, 4));
817    }
818    result += "\n";
819    return result;
820  }
821 
822  /**
823   * Returns the revision string.
824   *
825   * @return            the revision
826   */
827  public String getRevision() {
828    return RevisionUtils.extract("$Revision: 5928 $");
829  }
830
831  /**
832   * Main method for testing this class.
833   *
834   * @param argv should contain the command line arguments to the
835   * scheme (see Evaluation)
836   */
837  public static void main(String[] argv) {
838    runClassifier(new MILR(), argv);
839  }
840}
Note: See TracBrowser for help on using the repository browser.