source: src/main/java/weka/classifiers/mi/MIDD.java @ 20

Last change on this file since 20 was 4, checked in by gnappo, 14 years ago

Import of weka.

/*
 *    This program is free software; you can redistribute it and/or modify
 *    it under the terms of the GNU General Public License as published by
 *    the Free Software Foundation; either version 2 of the License, or
 *    (at your option) any later version.
 *
 *    This program is distributed in the hope that it will be useful,
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *    GNU General Public License for more details.
 *
 *    You should have received a copy of the GNU General Public License
 *    along with this program; if not, write to the Free Software
 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 */

/*
 * MIDD.java
 * Copyright (C) 2005 University of Waikato, Hamilton, New Zealand
 *
 */

package weka.classifiers.mi;

import weka.classifiers.Classifier;
import weka.classifiers.AbstractClassifier;
import weka.core.Capabilities;
import weka.core.FastVector;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.MultiInstanceCapabilitiesHandler;
import weka.core.Optimization;
import weka.core.Option;
import weka.core.OptionHandler;
import weka.core.RevisionUtils;
import weka.core.SelectedTag;
import weka.core.Tag;
import weka.core.TechnicalInformation;
import weka.core.TechnicalInformationHandler;
import weka.core.Utils;
import weka.core.Capabilities.Capability;
import weka.core.TechnicalInformation.Field;
import weka.core.TechnicalInformation.Type;
import weka.filters.Filter;
import weka.filters.unsupervised.attribute.Normalize;
import weka.filters.unsupervised.attribute.ReplaceMissingValues;
import weka.filters.unsupervised.attribute.Standardize;

import java.util.Enumeration;
import java.util.Vector;

/**
 <!-- globalinfo-start -->
 * Re-implementation of the Diverse Density algorithm, with a modified testing procedure.<br/>
 * <br/>
 * Oded Maron (1998). Learning from ambiguity.<br/>
 * <br/>
 * O. Maron, T. Lozano-Perez (1998). A Framework for Multiple Instance Learning. Neural Information Processing Systems. 10.
 * <p/>
 <!-- globalinfo-end -->
 *
 <!-- technical-bibtex-start -->
 * BibTeX:
 * <pre>
 * &#64;phdthesis{Maron1998,
 *    author = {Oded Maron},
 *    school = {Massachusetts Institute of Technology},
 *    title = {Learning from ambiguity},
 *    year = {1998}
 * }
 *
 * &#64;article{Maron1998,
 *    author = {O. Maron and T. Lozano-Perez},
 *    journal = {Neural Information Processing Systems},
 *    title = {A Framework for Multiple Instance Learning},
 *    volume = {10},
 *    year = {1998}
 * }
 * </pre>
 * <p/>
 <!-- technical-bibtex-end -->
 *
 <!-- options-start -->
 * Valid options are: <p/>
 *
 * <pre> -D
 *  Turn on debugging output.</pre>
 *
 * <pre> -N &lt;num&gt;
 *  Whether to 0=normalize/1=standardize/2=neither.
 *  (default 1=standardize)</pre>
 *
 <!-- options-end -->
 *
 * @author Eibe Frank (eibe@cs.waikato.ac.nz)
 * @author Xin Xu (xx5@cs.waikato.ac.nz)
 * @version $Revision: 5928 $
 */
public class MIDD
  extends AbstractClassifier
  implements OptionHandler, MultiInstanceCapabilitiesHandler,
             TechnicalInformationHandler {

  /** for serialization */
  static final long serialVersionUID = 4263507733600536168L;

  /** The index of the class attribute */
  protected int m_ClassIndex;

  /** The fitted parameters: point x[2k] and scale x[2k+1] for each attribute k */
  protected double[] m_Par;

  /** The number of class labels */
  protected int m_NumClasses;

  /** Class labels for each bag */
  protected int[] m_Classes;

  /** MI data */
  protected double[][][] m_Data;

  /** All attribute names */
  protected Instances m_Attributes;

  /** The filter used to standardize/normalize all values. */
  protected Filter m_Filter = null;

  /** Whether to normalize/standardize/neither, default: standardize */
  protected int m_filterType = FILTER_STANDARDIZE;

  /** Normalize training data */
  public static final int FILTER_NORMALIZE = 0;
  /** Standardize training data */
  public static final int FILTER_STANDARDIZE = 1;
  /** No normalization/standardization */
  public static final int FILTER_NONE = 2;
  /** The filter to apply to the training data */
  public static final Tag [] TAGS_FILTER = {
    new Tag(FILTER_NORMALIZE, "Normalize training data"),
    new Tag(FILTER_STANDARDIZE, "Standardize training data"),
    new Tag(FILTER_NONE, "No normalization/standardization"),
  };

  /** The filter used to get rid of missing values. */
  protected ReplaceMissingValues m_Missing = new ReplaceMissingValues();

  /**
   * Returns a string describing this classifier
   *
   * @return a description of the classifier suitable for
   * displaying in the explorer/experimenter gui
   */
  public String globalInfo() {
    return
        "Re-implementation of the Diverse Density algorithm, with a modified "
      + "testing procedure.\n\n"
      + getTechnicalInformation().toString();
  }

  /**
   * Returns an instance of a TechnicalInformation object, containing
   * detailed information about the technical background of this class,
   * e.g., paper reference or book this class is based on.
   *
   * @return the technical information about this class
   */
  public TechnicalInformation getTechnicalInformation() {
    TechnicalInformation        result;
    TechnicalInformation        additional;

    result = new TechnicalInformation(Type.PHDTHESIS);
    result.setValue(Field.AUTHOR, "Oded Maron");
    result.setValue(Field.YEAR, "1998");
    result.setValue(Field.TITLE, "Learning from ambiguity");
    result.setValue(Field.SCHOOL, "Massachusetts Institute of Technology");

    additional = result.add(Type.ARTICLE);
    additional.setValue(Field.AUTHOR, "O. Maron and T. Lozano-Perez");
    additional.setValue(Field.YEAR, "1998");
    additional.setValue(Field.TITLE, "A Framework for Multiple Instance Learning");
    additional.setValue(Field.JOURNAL, "Neural Information Processing Systems");
    additional.setValue(Field.VOLUME, "10");

    return result;
  }

  /**
   * Returns an enumeration describing the available options
   *
   * @return an enumeration of all the available options
   */
  public Enumeration listOptions() {
    Vector result = new Vector();

    result.addElement(new Option(
          "\tTurn on debugging output.",
          "D", 0, "-D"));

    result.addElement(new Option(
          "\tWhether to 0=normalize/1=standardize/2=neither.\n"
          + "\t(default 1=standardize)",
          "N", 1, "-N <num>"));

    return result.elements();
  }

  /**
   * Parses a given list of options. <p/>
   *
   <!-- options-start -->
   * Valid options are: <p/>
   *
   * <pre> -D
   *  Turn on debugging output.</pre>
   *
   * <pre> -N &lt;num&gt;
   *  Whether to 0=normalize/1=standardize/2=neither.
   *  (default 1=standardize)</pre>
   *
   <!-- options-end -->
   *
   * @param options the list of options as an array of strings
   * @throws Exception if an option is not supported
   */
  public void setOptions(String[] options) throws Exception {
    setDebug(Utils.getFlag('D', options));

    String nString = Utils.getOption('N', options);
    if (nString.length() != 0) {
      setFilterType(new SelectedTag(Integer.parseInt(nString), TAGS_FILTER));
    } else {
      setFilterType(new SelectedTag(FILTER_STANDARDIZE, TAGS_FILTER));
    }
  }

  /**
   * Gets the current settings of the classifier.
   *
   * @return an array of strings suitable for passing to setOptions
   */
  public String[] getOptions() {
    Vector        result;

    result = new Vector();

    if (getDebug())
      result.add("-D");

    result.add("-N");
    result.add("" + m_filterType);

    return (String[]) result.toArray(new String[result.size()]);
  }

  /**
   * Returns the tip text for this property
   *
   * @return tip text for this property suitable for
   * displaying in the explorer/experimenter gui
   */
  public String filterTypeTipText() {
    return "The filter type for transforming the training data.";
  }

  /**
   * Gets how the training data will be transformed. Will be one of
   * FILTER_NORMALIZE, FILTER_STANDARDIZE, FILTER_NONE.
   *
   * @return the filtering mode
   */
  public SelectedTag getFilterType() {
    return new SelectedTag(m_filterType, TAGS_FILTER);
  }

  /**
   * Sets how the training data will be transformed. Should be one of
   * FILTER_NORMALIZE, FILTER_STANDARDIZE, FILTER_NONE.
   *
   * @param newType the new filtering mode
   */
  public void setFilterType(SelectedTag newType) {

    if (newType.getTags() == TAGS_FILTER) {
      m_filterType = newType.getSelectedTag().getID();
    }
  }

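  // The inner class below minimizes the diverse-density negative
  // log-likelihood. With point t_k = x[2k] and scale s_k = x[2k+1] for
  // attribute k, instance a matches the target concept with probability
  // p = exp(-sum_k s_k^2 * (a_k - t_k)^2). A positive bag is explained with
  // noisy-or probability 1 - prod_j (1 - p_j) over its instances j, while a
  // negative bag contributes prod_j (1 - p_j); the objective is minus the
  // log of this likelihood summed over all bags.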
  private class OptEng
    extends Optimization {

    /**
     * Evaluate objective function
     * @param x the current values of variables
     * @return the value of the objective function
     */
    protected double objectiveFunction(double[] x){
      double nll = 0; // -LogLikelihood
      for(int i=0; i<m_Classes.length; i++){ // ith bag
        int nI = m_Data[i][0].length; // numInstances in ith bag
        double bag = 0.0;  // accumulates log prod_j (1-p_j) for a positive bag

        for(int j=0; j<nI; j++){
          double ins=0.0;
          for(int k=0; k<m_Data[i].length; k++)
            ins += (m_Data[i][k][j]-x[k*2])*(m_Data[i][k][j]-x[k*2])*
              x[k*2+1]*x[k*2+1];
          ins = Math.exp(-ins); // p_j: instance j matches the concept
          ins = 1.0-ins;        // 1 - p_j

          if(m_Classes[i] == 1)
            bag += Math.log(ins);
          else{
            if(ins<=m_Zero) ins=m_Zero;
            nll -= Math.log(ins);
          }
        }

        if(m_Classes[i] == 1){
          bag = 1.0 - Math.exp(bag); // noisy-or: 1 - prod_j (1-p_j)
          if(bag<=m_Zero) bag=m_Zero;
          nll -= Math.log(bag);
        }
      }
      return nll;
    }

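    // The gradient entries are the analytic partial derivatives of the
    // objective above with respect to each point x[2k] and scale x[2k+1];
    // per-instance terms are accumulated in numrt[] and then combined with
    // the bag-level factor (1-denom)/denom for positive bags.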
    /**
     * Evaluate Jacobian vector
     * @param x the current values of variables
     * @return the gradient vector
     */
    protected double[] evaluateGradient(double[] x){
      double[] grad = new double[x.length];
      for(int i=0; i<m_Classes.length; i++){ // ith bag
        int nI = m_Data[i][0].length; // numInstances in ith bag

        double denom=0.0;
        double[] numrt = new double[x.length];

        for(int j=0; j<nI; j++){
          double exp=0.0;
          for(int k=0; k<m_Data[i].length; k++)
            exp += (m_Data[i][k][j]-x[k*2])*(m_Data[i][k][j]-x[k*2])
              *x[k*2+1]*x[k*2+1];
          exp = Math.exp(-exp);
          exp = 1.0-exp;
          if(m_Classes[i]==1)
            denom += Math.log(exp);

          if(exp<=m_Zero) exp=m_Zero;
          // Instance-wise update
          for(int p=0; p<m_Data[i].length; p++){  // pth variable
            numrt[2*p] += (1.0-exp)*2.0*(x[2*p]-m_Data[i][p][j])*x[p*2+1]*x[p*2+1]
              /exp;
            numrt[2*p+1] += 2.0*(1.0-exp)*(x[2*p]-m_Data[i][p][j])*(x[2*p]-m_Data[i][p][j])
              *x[p*2+1]/exp;
          }
        }

        // Bag-wise update
        denom = 1.0-Math.exp(denom);
        if(denom <= m_Zero) denom = m_Zero;
        for(int q=0; q<m_Data[i].length; q++){
          if(m_Classes[i]==1){
            grad[2*q] += numrt[2*q]*(1.0-denom)/denom;
            grad[2*q+1] += numrt[2*q+1]*(1.0-denom)/denom;
          }else{
            grad[2*q] -= numrt[2*q];
            grad[2*q+1] -= numrt[2*q+1];
          }
        }
      } // one bag

      return grad;
    }

    /**
     * Returns the revision string.
     *
     * @return          the revision
     */
    public String getRevision() {
      return RevisionUtils.extract("$Revision: 5928 $");
    }
  }

  /**
   * Returns default capabilities of the classifier.
   *
   * @return      the capabilities of this classifier
   */
  public Capabilities getCapabilities() {
    Capabilities result = super.getCapabilities();
    result.disableAll();

    // attributes
    result.enable(Capability.NOMINAL_ATTRIBUTES);
    result.enable(Capability.RELATIONAL_ATTRIBUTES);
    result.enable(Capability.MISSING_VALUES);

    // class
    result.enable(Capability.BINARY_CLASS);
    result.enable(Capability.MISSING_CLASS_VALUES);

    // other
    result.enable(Capability.ONLY_MULTIINSTANCE);

    return result;
  }

  /**
   * Returns the capabilities of this multi-instance classifier for the
   * relational data.
   *
   * @return            the capabilities of this object
   * @see               Capabilities
   */
  public Capabilities getMultiInstanceCapabilities() {
    Capabilities result = super.getCapabilities();
    result.disableAll();

    // attributes
    result.enable(Capability.NOMINAL_ATTRIBUTES);
    result.enable(Capability.NUMERIC_ATTRIBUTES);
    result.enable(Capability.DATE_ATTRIBUTES);
    result.enable(Capability.MISSING_VALUES);

    // class
    result.disableAllClasses();
    result.enable(Capability.NO_CLASS);

    return result;
  }

  /**
   * Builds the classifier
   *
   * @param train the training data to be used for generating the
   * classifier.
   * @throws Exception if the classifier could not be built successfully
   */
  public void buildClassifier(Instances train) throws Exception {
    // can classifier handle the data?
    getCapabilities().testWithFail(train);

    // remove instances with missing class
    train = new Instances(train);
    train.deleteWithMissingClass();

    m_ClassIndex = train.classIndex();
    m_NumClasses = train.numClasses();

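    // In WEKA's multi-instance format each row of `train` is one bag:
    // attribute 0 is the bag ID, attribute 1 holds the relational
    // (instance-level) data, and the class attribute gives the bag's label.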
    int nR = train.attribute(1).relation().numAttributes();
    int nC = train.numInstances();
    FastVector maxSzIdx=new FastVector();
    int maxSz=0;
    int [] bagSize=new int [nC];
    Instances datasets= new Instances(train.attribute(1).relation(),0);

    m_Data  = new double [nC][nR][];              // Data values
    m_Classes  = new int [nC];                    // Class values
    m_Attributes = datasets.stringFreeStructure();
    if (m_Debug) {
      System.out.println("Extracting data...");
    }

    for(int h=0; h<nC; h++)  {//h_th bag
      Instance current = train.instance(h);
      m_Classes[h] = (int)current.classValue();  // Class value starts from 0
      Instances currInsts = current.relationalValue(1);
      for (int i=0; i<currInsts.numInstances();i++){
        Instance inst=currInsts.instance(i);
        datasets.add(inst);
      }

      int nI = currInsts.numInstances();
      bagSize[h]=nI;
      if(m_Classes[h]==1){
        if(nI>maxSz){
          maxSz=nI;
          maxSzIdx=new FastVector(1);
          maxSzIdx.addElement(new Integer(h));
        }
        else if(nI == maxSz)
          maxSzIdx.addElement(new Integer(h));
      }

    }

    /* filter the training data */
    if (m_filterType == FILTER_STANDARDIZE)
      m_Filter = new Standardize();
    else if (m_filterType == FILTER_NORMALIZE)
      m_Filter = new Normalize();
    else
      m_Filter = null;

    if (m_Filter!=null) {
      m_Filter.setInputFormat(datasets);
      datasets = Filter.useFilter(datasets, m_Filter);
    }

    m_Missing.setInputFormat(datasets);
    datasets = Filter.useFilter(datasets, m_Missing);

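    // Copy the filtered values back into the bag-major layout
    // m_Data[bag][attribute][instanceWithinBag]; `start` marks where each
    // bag's instances begin in the flattened `datasets`.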
    int instIndex=0;
    int start=0;
    for(int h=0; h<nC; h++)  {
      for (int i = 0; i < datasets.numAttributes(); i++) {
        // initialize m_Data[][][]
        m_Data[h][i] = new double[bagSize[h]];
        instIndex=start;
        for (int k=0; k<bagSize[h]; k++){
          m_Data[h][i][k]=datasets.instance(instIndex).value(i);
          instIndex ++;
        }
      }
      start=instIndex;
    }

    if (m_Debug) {
      System.out.println("\nIteration History..." );
    }

    double[] x = new double[nR*2], tmp = new double[x.length];
    double[][] b = new double[2][x.length];

    OptEng opt;
    double nll, bestnll = Double.MAX_VALUE;
    for (int t=0; t<x.length; t++){
      b[0][t] = Double.NaN;
      b[1][t] = Double.NaN;
    }

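    // The negative log-likelihood is non-convex, so each instance of the
    // largest positive bag(s) is tried as a starting point (point = instance
    // values, all scales = 1.0) for the quasi-Newton search, and the
    // solution with the smallest NLL is kept in m_Par.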
    // Largest Positive exemplar
    for(int s=0; s<maxSzIdx.size(); s++){
      int exIdx = ((Integer)maxSzIdx.elementAt(s)).intValue();
      for(int p=0; p<m_Data[exIdx][0].length; p++){
        for (int q=0; q < nR;q++){
          x[2*q] = m_Data[exIdx][q][p];  // pick one instance
          x[2*q+1] = 1.0;
        }

        opt = new OptEng();
        //opt.setDebug(m_Debug);
        tmp = opt.findArgmin(x, b);
        while(tmp==null){
          tmp = opt.getVarbValues();
          if (m_Debug)
            System.out.println("200 iterations finished, not enough!");
          tmp = opt.findArgmin(tmp, b);
        }
        nll = opt.getMinFunction();

        if(nll < bestnll){
          bestnll = nll;
          m_Par = tmp;
          tmp = new double[x.length]; // Save memory
          if (m_Debug)
            System.out.println("!!!!!!!!!!!!!!!!Smaller NLL found: "+nll);
        }
        if (m_Debug)
          System.out.println(exIdx+":  -------------<Converged>--------------");
      }
    }
  }

  /**
   * Computes the distribution for a given exemplar
   *
   * @param exmp the exemplar for which distribution is computed
   * @return the distribution
   * @throws Exception if the distribution can't be computed successfully
   */
  public double[] distributionForInstance(Instance exmp)
    throws Exception {

    // Extract the data
    Instances ins = exmp.relationalValue(1);
    if(m_Filter!=null)
      ins = Filter.useFilter(ins, m_Filter);

    ins = Filter.useFilter(ins, m_Missing);

    int nI = ins.numInstances(), nA = ins.numAttributes();
    double[][] dat = new double [nI][nA];
    for(int j=0; j<nI; j++){
      for(int k=0; k<nA; k++){
        dat[j][k] = ins.instance(j).value(k);
      }
    }

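    // Testing uses the same noisy-or model as training: the probability of
    // class 0 is prod_j (1 - p_j) over the bag's instances, accumulated in
    // log space before exponentiating.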
    // Compute the probability of the bag
    double [] distribution = new double[2];
    distribution[0]=0.0;  // log-Prob. for class 0

    for(int i=0; i<nI; i++){
      double exp = 0.0;
      for(int r=0; r<nA; r++)
        exp += (m_Par[r*2]-dat[i][r])*(m_Par[r*2]-dat[i][r])*
          m_Par[r*2+1]*m_Par[r*2+1];
      exp = Math.exp(-exp);

      // Prob. updated for one instance
      distribution[0] += Math.log(1.0-exp);
    }

    distribution[0] = Math.exp(distribution[0]);
    distribution[1] = 1.0-distribution[0];

    return distribution;
  }

  /**
   * Gets a string describing the classifier.
   *
   * @return a string describing the classifier built.
   */
  public String toString() {

    //double CSq = m_LLn - m_LL;
    //int df = m_NumPredictors;
    String result = "Diverse Density";
    if (m_Par == null) {
      return result + ": No model built yet.";
    }

    result += "\nCoefficients...\n"
      + "Variable       Point       Scale\n";
    for (int j = 0, idx=0; j < m_Par.length/2; j++, idx++) {
      result += m_Attributes.attribute(idx).name();
      result += " "+Utils.doubleToString(m_Par[j*2], 12, 4);
      result += " "+Utils.doubleToString(m_Par[j*2+1], 12, 4)+"\n";
    }

    return result;
  }

  /**
   * Returns the revision string.
   *
   * @return            the revision
   */
  public String getRevision() {
    return RevisionUtils.extract("$Revision: 5928 $");
  }

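  // Usage sketch (the dataset path is hypothetical; any multi-instance ARFF
  // with a bag-id attribute, a relational attribute, and a binary class
  // should work):
  //
  //   From the command line, via WEKA's standard evaluation:
  //     java weka.classifiers.mi.MIDD -t musk1.arff -N 1
  //
  //   Or programmatically:
  //     Instances data = new Instances(new java.io.FileReader("musk1.arff"));
  //     data.setClassIndex(data.numAttributes() - 1);
  //     MIDD midd = new MIDD();
  //     midd.buildClassifier(data);
  //     double[] dist = midd.distributionForInstance(data.instance(0));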
  /**
   * Main method for testing this class.
   *
   * @param argv should contain the command line arguments to the
   * scheme (see Evaluation)
   */
  public static void main(String[] argv) {
    runClassifier(new MIDD(), argv);
  }
}