source: src/main/java/weka/classifiers/mi/MIWrapper.java @ 9

Last change on this file since 9 was 4, checked in by gnappo, 14 years ago

Import di weka.

File size: 16.3 KB
RevLine 
[4]1/*
2 *    This program is free software; you can redistribute it and/or modify
3 *    it under the terms of the GNU General Public License as published by
4 *    the Free Software Foundation; either version 2 of the License, or
5 *    (at your option) any later version.
6 *
7 *    This program is distributed in the hope that it will be useful,
8 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
9 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
10 *    GNU General Public License for more details.
11 *
12 *    You should have received a copy of the GNU General Public License
13 *    along with this program; if not, write to the Free Software
14 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
15 */
16
17/*
18 * MIWrapper.java
19 * Copyright (C) 2005 University of Waikato, Hamilton, New Zealand
20 *
21 */
22
23package weka.classifiers.mi;
24
25import weka.classifiers.SingleClassifierEnhancer;
26import weka.core.Capabilities;
27import weka.core.Instance;
28import weka.core.Instances;
29import weka.core.MultiInstanceCapabilitiesHandler;
30import weka.core.Option;
31import weka.core.OptionHandler;
32import weka.core.RevisionUtils;
33import weka.core.SelectedTag;
34import weka.core.Tag;
35import weka.core.TechnicalInformation;
36import weka.core.TechnicalInformationHandler;
37import weka.core.Utils;
38import weka.core.Capabilities.Capability;
39import weka.core.TechnicalInformation.Field;
40import weka.core.TechnicalInformation.Type;
41import weka.filters.Filter;
42import weka.filters.unsupervised.attribute.MultiInstanceToPropositional;
43
44import java.util.Enumeration;
45import java.util.Vector;
46
47/**
48 <!-- globalinfo-start -->
49 * A simple Wrapper method for applying standard propositional learners to multi-instance data.<br/>
50 * <br/>
51 * For more information see:<br/>
52 * <br/>
53 * E. T. Frank, X. Xu (2003). Applying propositional learning algorithms to multi-instance data. Department of Computer Science, University of Waikato, Hamilton, NZ.
54 * <p/>
55 <!-- globalinfo-end -->
56 *
57 <!-- technical-bibtex-start -->
58 * BibTeX:
59 * <pre>
60 * &#64;techreport{Frank2003,
61 *    address = {Department of Computer Science, University of Waikato, Hamilton, NZ},
62 *    author = {E. T. Frank and X. Xu},
63 *    institution = {University of Waikato},
64 *    month = {06},
65 *    title = {Applying propositional learning algorithms to multi-instance data},
66 *    year = {2003}
67 * }
68 * </pre>
69 * <p/>
70 <!-- technical-bibtex-end -->
71 *
72 <!-- options-start -->
73 * Valid options are: <p/>
74 *
75 * <pre> -P [1|2|3]
76 *  The method used in testing:
77 *  1.arithmetic average
78 *  2.geometric average
79 *  3.max probability of positive bag.
80 *  (default: 1)</pre>
81 *
82 * <pre> -A [0|1|2|3]
83 *  The type of weight setting for each single-instance:
84 *  0.keep the weight to be the same as the original value;
85 *  1.weight = 1.0
86 *  2.weight = 1.0/Total number of single-instance in the
87 *   corresponding bag
88 *  3. weight = Total number of single-instance / (Total
89 *   number of bags * Total number of single-instance
90 *   in the corresponding bag).
91 *  (default: 3)</pre>
92 *
93 * <pre> -D
94 *  If set, classifier is run in debug mode and
95 *  may output additional info to the console</pre>
96 *
97 * <pre> -W
98 *  Full name of base classifier.
99 *  (default: weka.classifiers.rules.ZeroR)</pre>
100 *
101 * <pre>
102 * Options specific to classifier weka.classifiers.rules.ZeroR:
103 * </pre>
104 *
105 * <pre> -D
106 *  If set, classifier is run in debug mode and
107 *  may output additional info to the console</pre>
108 *
109 <!-- options-end -->
110 *
111 * @author Eibe Frank (eibe@cs.waikato.ac.nz)
112 * @author Xin Xu (xx5@cs.waikato.ac.nz)
113 * @version $Revision: 1.5 $
114 */
115public class MIWrapper 
116  extends SingleClassifierEnhancer
117  implements MultiInstanceCapabilitiesHandler, OptionHandler,
118             TechnicalInformationHandler { 
119
120  /** for serialization */
121  static final long serialVersionUID = -7707766152904315910L;
122 
123  /** The number of the class labels */
124  protected int m_NumClasses;
125
126  /** arithmetic average */
127  public static final int TESTMETHOD_ARITHMETIC = 1;
128  /** geometric average */
129  public static final int TESTMETHOD_GEOMETRIC = 2;
130  /** max probability of positive bag */
131  public static final int TESTMETHOD_MAXPROB = 3;
132  /** the test methods */
133  public static final Tag[] TAGS_TESTMETHOD = {
134    new Tag(TESTMETHOD_ARITHMETIC, "arithmetic average"),
135    new Tag(TESTMETHOD_GEOMETRIC, "geometric average"),
136    new Tag(TESTMETHOD_MAXPROB, "max probability of positive bag")
137  };
138
139  /** the test method  */
140  protected int m_Method = TESTMETHOD_GEOMETRIC;
141
142  /** Filter used to convert MI dataset into single-instance dataset */
143  protected MultiInstanceToPropositional m_ConvertToProp = new MultiInstanceToPropositional();
144
145  /** the single-instance weight setting method */
146  protected int m_WeightMethod = MultiInstanceToPropositional.WEIGHTMETHOD_INVERSE2;
147
148  /**
149   * Returns a string describing this filter
150   *
151   * @return a description of the filter suitable for
152   * displaying in the explorer/experimenter gui
153   */
154  public String globalInfo() {
155    return 
156         "A simple Wrapper method for applying standard propositional learners "
157       + "to multi-instance data.\n\n"
158       + "For more information see:\n\n"
159       + getTechnicalInformation().toString();
160  }
161
162  /**
163   * Returns an instance of a TechnicalInformation object, containing
164   * detailed information about the technical background of this class,
165   * e.g., paper reference or book this class is based on.
166   *
167   * @return the technical information about this class
168   */
169  public TechnicalInformation getTechnicalInformation() {
170    TechnicalInformation        result;
171   
172    result = new TechnicalInformation(Type.TECHREPORT);
173    result.setValue(Field.AUTHOR, "E. T. Frank and X. Xu");
174    result.setValue(Field.TITLE, "Applying propositional learning algorithms to multi-instance data");
175    result.setValue(Field.YEAR, "2003");
176    result.setValue(Field.MONTH, "06");
177    result.setValue(Field.INSTITUTION, "University of Waikato");
178    result.setValue(Field.ADDRESS, "Department of Computer Science, University of Waikato, Hamilton, NZ");
179   
180    return result;
181  }
182
183  /**
184   * Returns an enumeration describing the available options.
185   *
186   * @return an enumeration of all the available options.
187   */
188  public Enumeration listOptions() {
189    Vector result = new Vector();
190
191    result.addElement(new Option(
192          "\tThe method used in testing:\n"
193          + "\t1.arithmetic average\n"
194          + "\t2.geometric average\n"
195          + "\t3.max probability of positive bag.\n"
196          + "\t(default: 1)",
197          "P", 1, "-P [1|2|3]"));
198   
199    result.addElement(new Option(
200          "\tThe type of weight setting for each single-instance:\n"
201          + "\t0.keep the weight to be the same as the original value;\n"
202          + "\t1.weight = 1.0\n"
203          + "\t2.weight = 1.0/Total number of single-instance in the\n"
204          + "\t\tcorresponding bag\n"
205          + "\t3. weight = Total number of single-instance / (Total\n"
206          + "\t\tnumber of bags * Total number of single-instance \n"
207          + "\t\tin the corresponding bag).\n"
208          + "\t(default: 3)",
209          "A", 1, "-A [0|1|2|3]"));     
210
211    Enumeration enu = super.listOptions();
212    while (enu.hasMoreElements()) {
213      result.addElement(enu.nextElement());
214    }
215
216    return result.elements();
217  }
218
219
220  /**
221   * Parses a given list of options. <p/>
222   *
223   <!-- options-start -->
224   * Valid options are: <p/>
225   *
226   * <pre> -P [1|2|3]
227   *  The method used in testing:
228   *  1.arithmetic average
229   *  2.geometric average
230   *  3.max probability of positive bag.
231   *  (default: 1)</pre>
232   *
233   * <pre> -A [0|1|2|3]
234   *  The type of weight setting for each single-instance:
235   *  0.keep the weight to be the same as the original value;
236   *  1.weight = 1.0
237   *  2.weight = 1.0/Total number of single-instance in the
238   *   corresponding bag
239   *  3. weight = Total number of single-instance / (Total
240   *   number of bags * Total number of single-instance
241   *   in the corresponding bag).
242   *  (default: 3)</pre>
243   *
244   * <pre> -D
245   *  If set, classifier is run in debug mode and
246   *  may output additional info to the console</pre>
247   *
248   * <pre> -W
249   *  Full name of base classifier.
250   *  (default: weka.classifiers.rules.ZeroR)</pre>
251   *
252   * <pre>
253   * Options specific to classifier weka.classifiers.rules.ZeroR:
254   * </pre>
255   *
256   * <pre> -D
257   *  If set, classifier is run in debug mode and
258   *  may output additional info to the console</pre>
259   *
260   <!-- options-end -->
261   *
262   * @param options the list of options as an array of strings
263   * @throws Exception if an option is not supported
264   */
265  public void setOptions(String[] options) throws Exception {
266
267    setDebug(Utils.getFlag('D', options));
268
269    String methodString = Utils.getOption('P', options);
270    if (methodString.length() != 0) {
271      setMethod(
272          new SelectedTag(Integer.parseInt(methodString), TAGS_TESTMETHOD));
273    } else {
274      setMethod(
275          new SelectedTag(TESTMETHOD_ARITHMETIC, TAGS_TESTMETHOD));
276    }
277
278    String weightString = Utils.getOption('A', options);
279    if (weightString.length() != 0) {
280      setWeightMethod(
281          new SelectedTag(
282            Integer.parseInt(weightString), 
283            MultiInstanceToPropositional.TAGS_WEIGHTMETHOD));
284    } else {
285      setWeightMethod(
286          new SelectedTag(
287            MultiInstanceToPropositional.WEIGHTMETHOD_INVERSE2, 
288            MultiInstanceToPropositional.TAGS_WEIGHTMETHOD));
289    }   
290
291    super.setOptions(options);
292  }
293
294  /**
295   * Gets the current settings of the Classifier.
296   *
297   * @return an array of strings suitable for passing to setOptions
298   */
299  public String[] getOptions() {
300    Vector        result;
301    String[]      options;
302    int           i;
303   
304    result  = new Vector();
305
306    result.add("-P");
307    result.add("" + m_Method);
308
309    result.add("-A");
310    result.add("" + m_WeightMethod);
311
312    options = super.getOptions();
313    for (i = 0; i < options.length; i++)
314      result.add(options[i]);
315
316    return (String[]) result.toArray(new String[result.size()]);
317  }
318
319  /**
320   * Returns the tip text for this property
321   *
322   * @return tip text for this property suitable for
323   * displaying in the explorer/experimenter gui
324   */
325  public String weightMethodTipText() {
326    return "The method used for weighting the instances.";
327  }
328
329  /**
330   * The new method for weighting the instances.
331   *
332   * @param method      the new method
333   */
334  public void setWeightMethod(SelectedTag method){
335    if (method.getTags() == MultiInstanceToPropositional.TAGS_WEIGHTMETHOD)
336      m_WeightMethod = method.getSelectedTag().getID();
337  }
338
339  /**
340   * Returns the current weighting method for instances.
341   *
342   * @return the current weighting method
343   */
344  public SelectedTag getWeightMethod(){
345    return new SelectedTag(
346                  m_WeightMethod, MultiInstanceToPropositional.TAGS_WEIGHTMETHOD);
347  }
348
349  /**
350   * Returns the tip text for this property
351   *
352   * @return tip text for this property suitable for
353   * displaying in the explorer/experimenter gui
354   */
355  public String methodTipText() {
356    return "The method used for testing.";
357  }
358
359  /**
360   * Set the method used in testing.
361   *
362   * @param method the index of method to use.
363   */
364  public void setMethod(SelectedTag method) {
365    if (method.getTags() == TAGS_TESTMETHOD)
366      m_Method = method.getSelectedTag().getID();
367  }
368
369  /**
370   * Get the method used in testing.
371   *
372   * @return the index of method used in testing.
373   */
374  public SelectedTag getMethod() {
375    return new SelectedTag(m_Method, TAGS_TESTMETHOD);
376  }
377
378  /**
379   * Returns default capabilities of the classifier.
380   *
381   * @return      the capabilities of this classifier
382   */
383  public Capabilities getCapabilities() {
384    Capabilities result = super.getCapabilities();
385
386    // class
387    result.disableAllClasses();
388    result.disableAllClassDependencies();
389    if (super.getCapabilities().handles(Capability.NOMINAL_CLASS))
390      result.enable(Capability.NOMINAL_CLASS);
391    if (super.getCapabilities().handles(Capability.BINARY_CLASS))
392      result.enable(Capability.BINARY_CLASS);
393    result.enable(Capability.RELATIONAL_ATTRIBUTES);
394    result.enable(Capability.MISSING_CLASS_VALUES);
395   
396    // other
397    result.enable(Capability.ONLY_MULTIINSTANCE);
398   
399    return result;
400  }
401
402  /**
403   * Returns the capabilities of this multi-instance classifier for the
404   * relational data.
405   *
406   * @return            the capabilities of this object
407   * @see               Capabilities
408   */
409  public Capabilities getMultiInstanceCapabilities() {
410    Capabilities result = super.getCapabilities();
411   
412    // class
413    result.disableAllClasses();
414    result.enable(Capability.NO_CLASS);
415   
416    return result;
417  }
418
419  /**
420   * Builds the classifier
421   *
422   * @param data the training data to be used for generating the
423   * boosted classifier.
424   * @throws Exception if the classifier could not be built successfully
425   */
426  public void buildClassifier(Instances data) throws Exception {
427
428    // can classifier handle the data?
429    getCapabilities().testWithFail(data);
430
431    // remove instances with missing class
432    Instances train = new Instances(data);
433    train.deleteWithMissingClass();
434   
435    if (m_Classifier == null) {
436      throw new Exception("A base classifier has not been specified!");
437    }
438
439    if (getDebug())
440      System.out.println("Start training ...");
441    m_NumClasses = train.numClasses();
442
443    //convert the training dataset into single-instance dataset
444    m_ConvertToProp.setWeightMethod(getWeightMethod());
445    m_ConvertToProp.setInputFormat(train);
446    train = Filter.useFilter(train, m_ConvertToProp);
447    train.deleteAttributeAt(0); // remove the bag index attribute
448
449    m_Classifier.buildClassifier(train);
450  }             
451
452  /**
453   * Computes the distribution for a given exemplar
454   *
455   * @param exmp the exemplar for which distribution is computed
456   * @return the distribution
457   * @throws Exception if the distribution can't be computed successfully
458   */
459  public double[] distributionForInstance(Instance exmp) 
460    throws Exception { 
461
462    Instances testData = new Instances (exmp.dataset(),0);
463    testData.add(exmp);
464
465    // convert the training dataset into single-instance dataset
466    m_ConvertToProp.setWeightMethod(
467        new SelectedTag(
468          MultiInstanceToPropositional.WEIGHTMETHOD_ORIGINAL, 
469          MultiInstanceToPropositional.TAGS_WEIGHTMETHOD));
470    testData = Filter.useFilter(testData, m_ConvertToProp);
471    testData.deleteAttributeAt(0); //remove the bag index attribute
472
473    // Compute the log-probability of the bag
474    double [] distribution = new double[m_NumClasses];
475    double nI = (double)testData.numInstances();
476    double [] maxPr = new double [m_NumClasses];
477
478    for(int i=0; i<nI; i++){
479      double[] dist = m_Classifier.distributionForInstance(testData.instance(i));
480      for(int j=0; j<m_NumClasses; j++){
481
482        switch(m_Method){
483          case TESTMETHOD_ARITHMETIC:
484            distribution[j] += dist[j]/nI;
485            break;
486          case TESTMETHOD_GEOMETRIC:
487            // Avoid 0/1 probability
488            if(dist[j]<0.001)
489              dist[j] = 0.001;
490            else if(dist[j]>0.999)
491              dist[j] = 0.999;
492
493            distribution[j] += Math.log(dist[j])/nI;
494            break;
495          case TESTMETHOD_MAXPROB:
496            if (dist[j]>maxPr[j]) 
497              maxPr[j] = dist[j];
498            break;
499        }
500      }
501    }
502
503    if(m_Method == TESTMETHOD_GEOMETRIC)
504      for(int j=0; j<m_NumClasses; j++)
505        distribution[j] = Math.exp(distribution[j]);
506
507    if(m_Method == TESTMETHOD_MAXPROB){   // for positive bag
508      distribution[1] = maxPr[1];
509      distribution[0] = 1 - distribution[1];
510    }
511
512    if (Utils.eq(Utils.sum(distribution), 0)) {
513      for (int i = 0; i < distribution.length; i++)
514        distribution[i] = 1.0 / (double) distribution.length;
515    }
516    else {
517      Utils.normalize(distribution);
518    }
519   
520    return distribution;
521  }
522
523  /**
524   * Gets a string describing the classifier.
525   *
526   * @return a string describing the classifer built.
527   */
528  public String toString() {   
529    return "MIWrapper with base classifier: \n"+m_Classifier.toString();
530  }
531 
532  /**
533   * Returns the revision string.
534   *
535   * @return            the revision
536   */
537  public String getRevision() {
538    return RevisionUtils.extract("$Revision: 1.5 $");
539  }
540
541  /**
542   * Main method for testing this class.
543   *
544   * @param argv should contain the command line arguments to the
545   * scheme (see Evaluation)
546   */
547  public static void main(String[] argv) {
548    runClassifier(new MIWrapper(), argv);
549  }
550}
Note: See TracBrowser for help on using the repository browser.