source: src/main/java/weka/classifiers/mi/MDD.java @ 10

Last change on this file since 10 was 4, checked in by gnappo, 14 years ago

Import di weka.

File size: 18.3 KB
Line 
1/*
2 *    This program is free software; you can redistribute it and/or modify
3 *    it under the terms of the GNU General Public License as published by
4 *    the Free Software Foundation; either version 2 of the License, or
5 *    (at your option) any later version.
6 *
7 *    This program is distributed in the hope that it will be useful,
8 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
9 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
10 *    GNU General Public License for more details.
11 *
12 *    You should have received a copy of the GNU General Public License
13 *    along with this program; if not, write to the Free Software
14 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
15 */
16
17/*
18 * MDD.java
19 * Copyright (C) 2005 University of Waikato, Hamilton, New Zealand
20 *
21 */
22
23package weka.classifiers.mi;
24
25import weka.classifiers.Classifier;
26import weka.classifiers.AbstractClassifier;
27import weka.core.Capabilities;
28import weka.core.FastVector;
29import weka.core.Instance;
30import weka.core.Instances;
31import weka.core.MultiInstanceCapabilitiesHandler;
32import weka.core.Optimization;
33import weka.core.Option;
34import weka.core.OptionHandler;
35import weka.core.RevisionUtils;
36import weka.core.SelectedTag;
37import weka.core.Tag;
38import weka.core.TechnicalInformation;
39import weka.core.TechnicalInformationHandler;
40import weka.core.Utils;
41import weka.core.Capabilities.Capability;
42import weka.core.TechnicalInformation.Field;
43import weka.core.TechnicalInformation.Type;
44import weka.filters.Filter;
45import weka.filters.unsupervised.attribute.Normalize;
46import weka.filters.unsupervised.attribute.ReplaceMissingValues;
47import weka.filters.unsupervised.attribute.Standardize;
48
49import java.util.Enumeration;
50import java.util.Vector;
51
52/**
53 <!-- globalinfo-start -->
54 * Modified Diverse Density algorithm, with collective assumption.<br/>
55 * <br/>
56 * More information about DD:<br/>
57 * <br/>
58 * Oded Maron (1998). Learning from ambiguity.<br/>
59 * <br/>
60 * O. Maron, T. Lozano-Perez (1998). A Framework for Multiple Instance Learning. Neural Information Processing Systems. 10.
61 * <p/>
62 <!-- globalinfo-end -->
63 *
64 <!-- technical-bibtex-start -->
65 * BibTeX:
66 * <pre>
67 * &#64;phdthesis{Maron1998,
68 *    author = {Oded Maron},
69 *    school = {Massachusetts Institute of Technology},
70 *    title = {Learning from ambiguity},
71 *    year = {1998}
72 * }
73 *
74 * &#64;article{Maron1998,
75 *    author = {O. Maron and T. Lozano-Perez},
76 *    journal = {Neural Information Processing Systems},
77 *    title = {A Framework for Multiple Instance Learning},
78 *    volume = {10},
79 *    year = {1998}
80 * }
81 * </pre>
82 * <p/>
83 <!-- technical-bibtex-end -->
84 *
85 <!-- options-start -->
86 * Valid options are: <p/>
87 *
88 * <pre> -D
89 *  Turn on debugging output.</pre>
90 *
91 * <pre> -N &lt;num&gt;
92 *  Whether to 0=normalize/1=standardize/2=neither.
93 *  (default 1=standardize)</pre>
94 *
95 <!-- options-end -->
96 *   
97 * @author Eibe Frank (eibe@cs.waikato.ac.nz)
98 * @author Xin Xu (xx5@cs.waikato.ac.nz)
99 * @version $Revision: 5928 $
100 */
101public class MDD 
102  extends AbstractClassifier
103  implements OptionHandler, MultiInstanceCapabilitiesHandler,
104             TechnicalInformationHandler {
105 
106  /** for serialization */
107  static final long serialVersionUID = -7273119490545290581L;
108
109  /** The index of the class attribute */
110  protected int m_ClassIndex;
111
112  protected double[] m_Par;
113
114  /** The number of the class labels */
115  protected int m_NumClasses;
116
117  /** Class labels for each bag */
118  protected int[] m_Classes;
119
120  /** MI data */ 
121  protected double[][][] m_Data;
122
123  /** All attribute names */
124  protected Instances m_Attributes;
125
126  /** The filter used to standardize/normalize all values. */
127  protected Filter m_Filter =null;
128
129  /** Whether to normalize/standardize/neither, default:standardize */
130  protected int m_filterType = FILTER_STANDARDIZE;
131
132  /** Normalize training data */
133  public static final int FILTER_NORMALIZE = 0;
134  /** Standardize training data */
135  public static final int FILTER_STANDARDIZE = 1;
136  /** No normalization/standardization */
137  public static final int FILTER_NONE = 2;
138  /** The filter to apply to the training data */
139  public static final Tag [] TAGS_FILTER = {
140    new Tag(FILTER_NORMALIZE, "Normalize training data"),
141    new Tag(FILTER_STANDARDIZE, "Standardize training data"),
142    new Tag(FILTER_NONE, "No normalization/standardization"),
143  };
144
145  /** The filter used to get rid of missing values. */
146  protected ReplaceMissingValues m_Missing = new ReplaceMissingValues();
147
148  /**
149   * Returns a string describing this filter
150   *
151   * @return a description of the filter suitable for
152   * displaying in the explorer/experimenter gui
153   */
154  public String globalInfo() {
155    return 
156        "Modified Diverse Density algorithm, with collective assumption.\n\n"
157      + "More information about DD:\n\n"
158      + getTechnicalInformation().toString();
159  }
160
161  /**
162   * Returns an instance of a TechnicalInformation object, containing
163   * detailed information about the technical background of this class,
164   * e.g., paper reference or book this class is based on.
165   *
166   * @return the technical information about this class
167   */
168  public TechnicalInformation getTechnicalInformation() {
169    TechnicalInformation        result;
170    TechnicalInformation        additional;
171   
172    result = new TechnicalInformation(Type.PHDTHESIS);
173    result.setValue(Field.AUTHOR, "Oded Maron");
174    result.setValue(Field.YEAR, "1998");
175    result.setValue(Field.TITLE, "Learning from ambiguity");
176    result.setValue(Field.SCHOOL, "Massachusetts Institute of Technology");
177   
178    additional = result.add(Type.ARTICLE);
179    additional.setValue(Field.AUTHOR, "O. Maron and T. Lozano-Perez");
180    additional.setValue(Field.YEAR, "1998");
181    additional.setValue(Field.TITLE, "A Framework for Multiple Instance Learning");
182    additional.setValue(Field.JOURNAL, "Neural Information Processing Systems");
183    additional.setValue(Field.VOLUME, "10");
184   
185    return result;
186  }
187
188  /**
189   * Returns an enumeration describing the available options
190   *
191   * @return an enumeration of all the available options
192   */
193  public Enumeration listOptions() {
194    Vector result = new Vector();
195   
196    result.addElement(new Option(
197          "\tTurn on debugging output.",
198          "D", 0, "-D"));
199   
200    result.addElement(new Option(
201          "\tWhether to 0=normalize/1=standardize/2=neither.\n"
202          + "\t(default 1=standardize)",
203          "N", 1, "-N <num>"));
204
205    return result.elements();
206  }
207
208  /**
209   * Parses a given list of options.
210   *
211   * @param options the list of options as an array of strings
212   * @throws Exception if an option is not supported
213   */
214  public void setOptions(String[] options) throws Exception {
215    setDebug(Utils.getFlag('D', options));
216
217    String nString = Utils.getOption('N', options);
218    if (nString.length() != 0) {
219      setFilterType(new SelectedTag(Integer.parseInt(nString), TAGS_FILTER));
220    } else {
221      setFilterType(new SelectedTag(FILTER_STANDARDIZE, TAGS_FILTER));
222    }     
223  }
224
225  /**
226   * Gets the current settings of the classifier.
227   *
228   * @return an array of strings suitable for passing to setOptions
229   */
230  public String[] getOptions() {
231    Vector        result;
232   
233    result = new Vector();
234
235    if (getDebug())
236      result.add("-D");
237   
238    result.add("-N");
239    result.add("" + m_filterType);
240
241    return (String[]) result.toArray(new String[result.size()]);
242  }
243
244  /**
245   * Returns the tip text for this property
246   *
247   * @return tip text for this property suitable for
248   * displaying in the explorer/experimenter gui
249   */
250  public String filterTypeTipText() {
251    return "The filter type for transforming the training data.";
252  }
253
254  /**
255   * Gets how the training data will be transformed. Will be one of
256   * FILTER_NORMALIZE, FILTER_STANDARDIZE, FILTER_NONE.
257   *
258   * @return the filtering mode
259   */
260  public SelectedTag getFilterType() {
261    return new SelectedTag(m_filterType, TAGS_FILTER);
262  }
263
264  /**
265   * Sets how the training data will be transformed. Should be one of
266   * FILTER_NORMALIZE, FILTER_STANDARDIZE, FILTER_NONE.
267   *
268   * @param newType the new filtering mode
269   */
270  public void setFilterType(SelectedTag newType) {
271
272    if (newType.getTags() == TAGS_FILTER) {
273      m_filterType = newType.getSelectedTag().getID();
274    }
275  }
276
277
278  private class OptEng 
279    extends Optimization {
280   
281    /**
282     * Evaluate objective function
283     * @param x the current values of variables
284     * @return the value of the objective function
285     */
286    protected double objectiveFunction(double[] x){
287      double nll = 0; // -LogLikelihood
288      for(int i=0; i<m_Classes.length; i++){ // ith bag
289        int nI = m_Data[i][0].length; // numInstances in ith bag
290        double bag = 0;  // NLL of each bag
291
292        for(int j=0; j<nI; j++){
293          double ins=0.0; 
294          for(int k=0; k<m_Data[i].length; k++) {
295            ins += (m_Data[i][k][j]-x[k*2])*(m_Data[i][k][j]-x[k*2])/
296              (x[k*2+1]*x[k*2+1]);
297          }
298          ins = Math.exp(-ins); 
299
300          if(m_Classes[i] == 1)
301            bag += ins/(double)nI;
302          else
303            bag += (1.0-ins)/(double)nI;   
304        }               
305        if(bag<=m_Zero) bag=m_Zero; 
306        nll -= Math.log(bag);
307      }         
308
309      return nll;
310    }
311
312    /**
313     * Evaluate Jacobian vector
314     * @param x the current values of variables
315     * @return the gradient vector
316     */
317    protected double[] evaluateGradient(double[] x){
318      double[] grad = new double[x.length];
319      for(int i=0; i<m_Classes.length; i++){ // ith bag
320        int nI = m_Data[i][0].length; // numInstances in ith bag
321
322        double denom=0.0;
323        double[] numrt = new double[x.length];
324
325        for(int j=0; j<nI; j++){
326          double exp=0.0;
327          for(int k=0; k<m_Data[i].length; k++)
328            exp += (m_Data[i][k][j]-x[k*2])*(m_Data[i][k][j]-x[k*2])/
329              (x[k*2+1]*x[k*2+1]);                     
330          exp = Math.exp(-exp);
331          if(m_Classes[i]==1)
332            denom += exp;
333          else
334            denom += (1.0-exp);           
335
336          // Instance-wise update
337          for(int p=0; p<m_Data[i].length; p++){  // pth variable
338            numrt[2*p] += exp*2.0*(x[2*p]-m_Data[i][p][j])/
339              (x[2*p+1]*x[2*p+1]);
340            numrt[2*p+1] += 
341              exp*(x[2*p]-m_Data[i][p][j])*(x[2*p]-m_Data[i][p][j])/
342              (x[2*p+1]*x[2*p+1]*x[2*p+1]);
343          }                     
344        }
345
346        if(denom <= m_Zero){
347          denom = m_Zero;
348        }
349
350        // Bag-wise update
351        for(int q=0; q<m_Data[i].length; q++){
352          if(m_Classes[i]==1){
353            grad[2*q] += numrt[2*q]/denom;
354            grad[2*q+1] -= numrt[2*q+1]/denom;
355          }else{
356            grad[2*q] -= numrt[2*q]/denom;
357            grad[2*q+1] += numrt[2*q+1]/denom;
358          }
359        }
360      }
361
362      return grad;
363    }
364   
365    /**
366     * Returns the revision string.
367     *
368     * @return          the revision
369     */
370    public String getRevision() {
371      return RevisionUtils.extract("$Revision: 5928 $");
372    }
373  }
374
375  /**
376   * Returns default capabilities of the classifier.
377   *
378   * @return      the capabilities of this classifier
379   */
380  public Capabilities getCapabilities() {
381    Capabilities result = super.getCapabilities();
382    result.disableAll();
383
384    // attributes
385    result.enable(Capability.NOMINAL_ATTRIBUTES);
386    result.enable(Capability.RELATIONAL_ATTRIBUTES);
387    result.enable(Capability.MISSING_VALUES);
388
389    // class
390    result.enable(Capability.BINARY_CLASS);
391    result.enable(Capability.MISSING_CLASS_VALUES);
392   
393    // other
394    result.enable(Capability.ONLY_MULTIINSTANCE);
395   
396    return result;
397  }
398
399  /**
400   * Returns the capabilities of this multi-instance classifier for the
401   * relational data.
402   *
403   * @return            the capabilities of this object
404   * @see               Capabilities
405   */
406  public Capabilities getMultiInstanceCapabilities() {
407    Capabilities result = super.getCapabilities();
408    result.disableAll();
409   
410    // attributes
411    result.enable(Capability.NOMINAL_ATTRIBUTES);
412    result.enable(Capability.NUMERIC_ATTRIBUTES);
413    result.enable(Capability.DATE_ATTRIBUTES);
414    result.enable(Capability.MISSING_VALUES);
415
416    // class
417    result.disableAllClasses();
418    result.enable(Capability.NO_CLASS);
419   
420    return result;
421  }
422
423  /**
424   * Builds the classifier
425   *
426   * @param train the training data to be used for generating the
427   * boosted classifier.
428   * @throws Exception if the classifier could not be built successfully
429   */
430  public void buildClassifier(Instances train) throws Exception {
431    // can classifier handle the data?
432    getCapabilities().testWithFail(train);
433
434    // remove instances with missing class
435    train = new Instances(train);
436    train.deleteWithMissingClass();
437   
438    m_ClassIndex = train.classIndex();
439    m_NumClasses = train.numClasses();
440
441    int nR = train.attribute(1).relation().numAttributes();
442    int nC = train.numInstances();
443    int [] bagSize=new int [nC];
444    Instances datasets= new Instances(train.attribute(1).relation(),0);
445
446    m_Data  = new double [nC][nR][];              // Data values
447    m_Classes  = new int [nC];                    // Class values
448    m_Attributes = datasets.stringFreeStructure();             
449    double sY1=0, sY0=0;                          // Number of classes
450
451    if (m_Debug) {
452      System.out.println("Extracting data...");
453    }
454    FastVector maxSzIdx=new FastVector();
455    int maxSz=0;
456
457    for(int h=0; h<nC; h++){
458      Instance current = train.instance(h);
459      m_Classes[h] = (int)current.classValue();  // Class value starts from 0
460      Instances currInsts = current.relationalValue(1);
461      int nI = currInsts.numInstances();
462      bagSize[h]=nI;
463
464      for (int i=0; i<nI;i++){
465        Instance inst=currInsts.instance(i);
466        datasets.add(inst);
467      }
468
469      if(m_Classes[h]==1){
470        if(nI>maxSz){
471          maxSz=nI;
472          maxSzIdx=new FastVector(1);
473          maxSzIdx.addElement(new Integer(h));
474        }
475        else if(nI == maxSz)
476          maxSzIdx.addElement(new Integer(h));
477      }
478    }
479
480    /* filter the training data */
481    if (m_filterType == FILTER_STANDARDIZE) 
482      m_Filter = new Standardize();
483    else if (m_filterType == FILTER_NORMALIZE)
484      m_Filter = new Normalize();
485    else 
486      m_Filter = null; 
487
488    if (m_Filter!=null) {
489      m_Filter.setInputFormat(datasets);
490      datasets = Filter.useFilter(datasets, m_Filter); 
491    }
492
493    m_Missing.setInputFormat(datasets);
494    datasets = Filter.useFilter(datasets, m_Missing);
495
496    int instIndex=0;
497    int start=0;       
498    for(int h=0; h<nC; h++)  { 
499      for (int i = 0; i < datasets.numAttributes(); i++) {
500        // initialize m_data[][][]
501        m_Data[h][i] = new double[bagSize[h]];
502        instIndex=start;
503        for (int k=0; k<bagSize[h]; k++){
504          m_Data[h][i][k]=datasets.instance(instIndex).value(i);
505          instIndex ++;
506        }
507      }
508      start=instIndex;
509
510      // Class count   
511      if (m_Classes[h] == 1)
512        sY1++;         
513      else
514        sY0++;
515    }
516
517    if (m_Debug) {
518      System.out.println("\nIteration History..." );
519    }
520
521    double[] x = new double[nR*2], tmp = new double[x.length];
522    double[][] b = new double[2][x.length]; 
523
524    OptEng opt;
525    double nll, bestnll = Double.MAX_VALUE;
526    for (int t=0; t<x.length; t++){
527      b[0][t] = Double.NaN;
528      b[1][t] = Double.NaN; 
529    }
530
531    // Largest positive exemplar
532    for(int s=0; s<maxSzIdx.size(); s++){
533      int exIdx = ((Integer)maxSzIdx.elementAt(s)).intValue();
534      for(int p=0; p<m_Data[exIdx][0].length; p++){
535        for (int q=0; q < nR;q++){
536          x[2*q] = m_Data[exIdx][q][p];  // pick one instance
537          x[2*q+1] = 1.0;
538        }               
539
540        opt = new OptEng();     
541        tmp = opt.findArgmin(x, b);
542        while(tmp==null){
543          tmp = opt.getVarbValues();
544          if (m_Debug)
545            System.out.println("200 iterations finished, not enough!");
546          tmp = opt.findArgmin(tmp, b);
547        }
548        nll = opt.getMinFunction();
549
550        if(nll < bestnll){
551          bestnll = nll;
552          m_Par = tmp;
553          if (m_Debug)
554            System.out.println("!!!!!!!!!!!!!!!!Smaller NLL found: "+nll);
555        }
556        if (m_Debug)
557          System.out.println(exIdx+":  -------------<Converged>--------------");
558      }
559    }
560    }           
561
562  /**
563   * Computes the distribution for a given exemplar
564   *
565   * @param exmp the exemplar for which distribution is computed
566   * @return the distribution
567   * @throws Exception if the distribution can't be computed successfully
568   */
569  public double[] distributionForInstance(Instance exmp) 
570    throws Exception {
571
572    // Extract the data
573    Instances ins = exmp.relationalValue(1);
574    if(m_Filter!=null)
575      ins = Filter.useFilter(ins, m_Filter);
576
577    ins = Filter.useFilter(ins, m_Missing);
578
579    int nI = ins.numInstances(), nA = ins.numAttributes();
580    double[][] dat = new double [nI][nA];
581    for(int j=0; j<nI; j++){
582      for(int k=0; k<nA; k++){ 
583        dat[j][k] = ins.instance(j).value(k);
584      }
585    }
586
587    // Compute the probability of the bag
588    double [] distribution = new double[2];
589    distribution[1]=0.0;  // Prob. for class 1
590
591    for(int i=0; i<nI; i++){
592      double exp = 0.0;
593      for(int r=0; r<nA; r++)
594        exp += (m_Par[r*2]-dat[i][r])*(m_Par[r*2]-dat[i][r])/
595          ((m_Par[r*2+1])*(m_Par[r*2+1]));
596      exp = Math.exp(-exp);
597
598      // Prob. updated for one instance
599      distribution[1] += exp/(double)nI;
600      distribution[0] += (1.0-exp)/(double)nI;
601    }
602
603    return distribution;
604  }
605
606  /**
607   * Gets a string describing the classifier.
608   *
609   * @return a string describing the classifer built.
610   */
611  public String toString() {
612
613    String result = "Modified Logistic Regression";
614    if (m_Par == null) {
615      return result + ": No model built yet.";
616    }
617
618    result += "\nCoefficients...\n"
619      + "Variable      Coeff.\n";
620    for (int j = 0, idx=0; j < m_Par.length/2; j++, idx++) {
621
622      result += m_Attributes.attribute(idx).name();
623      result += " "+Utils.doubleToString(m_Par[j*2], 12, 4); 
624      result += " "+Utils.doubleToString(m_Par[j*2+1], 12, 4)+"\n";
625    }
626
627    return result;
628  }
629 
630  /**
631   * Returns the revision string.
632   *
633   * @return            the revision
634   */
635  public String getRevision() {
636    return RevisionUtils.extract("$Revision: 5928 $");
637  }
638
639  /**
640   * Main method for testing this class.
641   *
642   * @param argv should contain the command line arguments to the
643   * scheme (see Evaluation)
644   */
645  public static void main(String[] argv) {
646    runClassifier(new MDD(), argv);
647  }
648}
Note: See TracBrowser for help on using the repository browser.