source: src/main/java/weka/classifiers/mi/MIOptimalBall.java @ 17

Last change on this file since 17 was 4, checked in by gnappo, 14 years ago

Import di weka.

File size: 17.9 KB
Line 
1/*
2 *    This program is free software; you can redistribute it and/or modify
3 *    it under the terms of the GNU General Public License as published by
4 *    the Free Software Foundation; either version 2 of the License, or
5 *    (at your option) any later version.
6 *
7 *    This program is distributed in the hope that it will be useful,
8 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
9 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
10 *    GNU General Public License for more details.
11 *
12 *    You should have received a copy of the GNU General Public License
13 *    along with this program; if not, write to the Free Software
14 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
15 */
16
17/*
18 * MIOptimalBall.java
19 * Copyright (C) 2005 University of Waikato, Hamilton, New Zealand
20 *
21 */
22
23package weka.classifiers.mi;
24
25import weka.classifiers.Classifier;
26import weka.classifiers.AbstractClassifier;
27import weka.core.Capabilities;
28import weka.core.Instance;
29import weka.core.Instances;
30import weka.core.MultiInstanceCapabilitiesHandler;
31import weka.core.Option;
32import weka.core.OptionHandler;
33import weka.core.RevisionUtils;
34import weka.core.SelectedTag;
35import weka.core.Tag;
36import weka.core.TechnicalInformation;
37import weka.core.TechnicalInformationHandler;
38import weka.core.Utils;
39import weka.core.WeightedInstancesHandler;
40import weka.core.Capabilities.Capability;
41import weka.core.TechnicalInformation.Field;
42import weka.core.TechnicalInformation.Type;
43import weka.core.matrix.DoubleVector;
44import weka.filters.Filter;
45import weka.filters.unsupervised.attribute.MultiInstanceToPropositional;
46import weka.filters.unsupervised.attribute.Normalize;
47import weka.filters.unsupervised.attribute.PropositionalToMultiInstance;
48import weka.filters.unsupervised.attribute.Standardize;
49
50import java.util.Enumeration;
51import java.util.Vector;
52
53/**
54 <!-- globalinfo-start -->
55 * This classifier tries to find a suitable ball in the multiple-instance space, with a certain data point in the instance space as a ball center. The possible ball center is a certain instance in a positive bag. The possible radiuses are those which can achieve the highest classification accuracy. The model selects the maximum radius as the radius of the optimal ball.<br/>
56 * <br/>
57 * For more information about this algorithm, see:<br/>
58 * <br/>
59 * Peter Auer, Ronald Ortner: A Boosting Approach to Multiple Instance Learning. In: 15th European Conference on Machine Learning, 63-74, 2004.
60 * <p/>
61 <!-- globalinfo-end -->
62 *
63 <!-- technical-bibtex-start -->
64 * BibTeX:
65 * <pre>
66 * &#64;inproceedings{Auer2004,
67 *    author = {Peter Auer and Ronald Ortner},
68 *    booktitle = {15th European Conference on Machine Learning},
69 *    note = {LNAI 3201},
70 *    pages = {63-74},
71 *    publisher = {Springer},
72 *    title = {A Boosting Approach to Multiple Instance Learning},
73 *    year = {2004}
74 * }
75 * </pre>
76 * <p/>
77 <!-- technical-bibtex-end -->
78 *
79 <!-- options-start -->
80 * Valid options are: <p/>
81 *
82 * <pre> -N &lt;num&gt;
83 *  Whether to 0=normalize/1=standardize/2=neither.
84 *  (default 0=normalize)</pre>
85 *
86 <!-- options-end -->
87 *
88 * @author Lin Dong (ld21@cs.waikato.ac.nz)
89 * @version $Revision: 5928 $
90 */
91public class MIOptimalBall 
92  extends AbstractClassifier
93  implements OptionHandler, WeightedInstancesHandler, 
94             MultiInstanceCapabilitiesHandler, TechnicalInformationHandler { 
95
96  /** for serialization */
97  static final long serialVersionUID = -6465750129576777254L;
98 
99  /** center of the optimal ball */
100  protected double[] m_Center;
101
102  /** radius of the optimal ball */
103  protected double m_Radius;
104
105  /** the distances from each instance in a positive bag to each bag*/
106  protected double [][][]m_Distance;
107
108  /** The filter used to standardize/normalize all values. */
109  protected Filter m_Filter = null;
110
111  /** Whether to normalize/standardize/neither */
112  protected int m_filterType = FILTER_NORMALIZE;
113
114  /** Normalize training data */
115  public static final int FILTER_NORMALIZE = 0;
116  /** Standardize training data */
117  public static final int FILTER_STANDARDIZE = 1;
118  /** No normalization/standardization */
119  public static final int FILTER_NONE = 2;
120  /** The filter to apply to the training data */
121  public static final Tag [] TAGS_FILTER = {
122    new Tag(FILTER_NORMALIZE, "Normalize training data"),
123    new Tag(FILTER_STANDARDIZE, "Standardize training data"),
124    new Tag(FILTER_NONE, "No normalization/standardization"),
125  };
126
127  /** filter used to convert the MI dataset into single-instance dataset */
128  protected MultiInstanceToPropositional m_ConvertToSI = new MultiInstanceToPropositional();
129
130  /** filter used to convert the single-instance dataset into MI dataset */
131  protected PropositionalToMultiInstance m_ConvertToMI = new PropositionalToMultiInstance();
132
133  /**
134   * Returns a string describing this filter
135   *
136   * @return a description of the filter suitable for
137   * displaying in the explorer/experimenter gui
138   */
139  public String globalInfo() {
140    return
141         "This classifier tries to find a suitable ball in the "
142       + "multiple-instance space, with a certain data point in the instance "
143       + "space as a ball center. The possible ball center is a certain "
144       + "instance in a positive bag. The possible radiuses are those which can "
145       + "achieve the highest classification accuracy. The model selects the "
146       + "maximum radius as the radius of the optimal ball.\n\n"
147       + "For more information about this algorithm, see:\n\n"
148       + getTechnicalInformation().toString();
149  }
150
151  /**
152   * Returns an instance of a TechnicalInformation object, containing
153   * detailed information about the technical background of this class,
154   * e.g., paper reference or book this class is based on.
155   *
156   * @return the technical information about this class
157   */
158  public TechnicalInformation getTechnicalInformation() {
159    TechnicalInformation        result;
160   
161    result = new TechnicalInformation(Type.INPROCEEDINGS);
162    result.setValue(Field.AUTHOR, "Peter Auer and Ronald Ortner");
163    result.setValue(Field.TITLE, "A Boosting Approach to Multiple Instance Learning");
164    result.setValue(Field.BOOKTITLE, "15th European Conference on Machine Learning");
165    result.setValue(Field.YEAR, "2004");
166    result.setValue(Field.PAGES, "63-74");
167    result.setValue(Field.PUBLISHER, "Springer");
168    result.setValue(Field.NOTE, "LNAI 3201");
169   
170    return result;
171  }
172
173  /**
174   * Returns default capabilities of the classifier.
175   *
176   * @return      the capabilities of this classifier
177   */
178  public Capabilities getCapabilities() {
179    Capabilities result = super.getCapabilities();
180    result.disableAll();
181
182    // attributes
183    result.enable(Capability.NOMINAL_ATTRIBUTES);
184    result.enable(Capability.RELATIONAL_ATTRIBUTES);
185    result.enable(Capability.MISSING_VALUES);
186
187    // class
188    result.enable(Capability.BINARY_CLASS);
189    result.enable(Capability.MISSING_CLASS_VALUES);
190   
191    // other
192    result.enable(Capability.ONLY_MULTIINSTANCE);
193   
194    return result;
195  }
196
197  /**
198   * Returns the capabilities of this multi-instance classifier for the
199   * relational data.
200   *
201   * @return            the capabilities of this object
202   * @see               Capabilities
203   */
204  public Capabilities getMultiInstanceCapabilities() {
205    Capabilities result = super.getCapabilities();
206    result.disableAll();
207   
208    // attributes
209    result.enable(Capability.NOMINAL_ATTRIBUTES);
210    result.enable(Capability.NUMERIC_ATTRIBUTES);
211    result.enable(Capability.DATE_ATTRIBUTES);
212    result.enable(Capability.MISSING_VALUES);
213
214    // class
215    result.disableAllClasses();
216    result.enable(Capability.NO_CLASS);
217   
218    return result;
219  }
220
221  /**
222   * Builds the classifier
223   *
224   * @param data the training data to be used for generating the
225   * boosted classifier.
226   * @throws Exception if the classifier could not be built successfully
227   */
228  public void buildClassifier(Instances data) throws Exception {
229    // can classifier handle the data?
230    getCapabilities().testWithFail(data);
231
232    // remove instances with missing class
233    Instances train = new Instances(data);
234    train.deleteWithMissingClass();
235   
236    int numAttributes = train.attribute(1).relation().numAttributes(); 
237    m_Center = new double[numAttributes];
238
239    if (getDebug())
240      System.out.println("Start training ..."); 
241
242    // convert the training dataset into single-instance dataset
243    m_ConvertToSI.setInputFormat(train);       
244    train = Filter.useFilter( train, m_ConvertToSI);
245
246    if (m_filterType == FILTER_STANDARDIZE) 
247      m_Filter = new Standardize();
248    else if (m_filterType == FILTER_NORMALIZE)
249      m_Filter = new Normalize();
250    else 
251      m_Filter = null;
252
253    if (m_Filter!=null) {
254      // normalize/standardize the converted training dataset
255      m_Filter.setInputFormat(train);
256      train = Filter.useFilter(train, m_Filter);
257    }
258
259    // convert the single-instance dataset into multi-instance dataset
260    m_ConvertToMI.setInputFormat(train);
261    train = Filter.useFilter(train, m_ConvertToMI);
262
263    /*calculate all the distances (and store them in m_Distance[][][]), which
264      are from each instance in all positive bags to all bags */
265    calculateDistance(train);
266
267    /*find the suitable ball center (m_Center) and the corresponding radius (m_Radius)*/
268    findRadius(train); 
269
270    if (getDebug())
271      System.out.println("Finish building optimal ball model");
272  }             
273
274
275
276  /**
277   * calculate the distances from each instance in a positive bag to each bag.
278   * All result distances are stored in m_Distance[i][j][k], where
279   * m_Distance[i][j][k] refers the distances from the jth instance in ith bag
280   * to the kth bag
281   *
282   * @param train the multi-instance dataset (with relational attribute)   
283   */
284  public void calculateDistance (Instances train) {
285    int numBags =train.numInstances();
286    int numInstances;
287    Instance tempCenter;
288
289    m_Distance = new double [numBags][][];
290    for (int i=0; i<numBags; i++) {
291      if (train.instance(i).classValue() == 1.0) { //positive bag
292        numInstances = train.instance(i).relationalValue(1).numInstances();
293        m_Distance[i]= new double[numInstances][];
294        for (int j=0; j<numInstances; j++) {
295          tempCenter = train.instance(i).relationalValue(1).instance(j);
296          m_Distance[i][j]=new double [numBags];  //store the distance from one center to all the bags
297          for (int k=0; k<numBags; k++){
298            if (i==k)
299              m_Distance[i][j][k]= 0;     
300            else 
301              m_Distance[i][j][k]= minBagDistance (tempCenter, train.instance(k));         
302          }
303        }
304      } 
305    }
306  } 
307
308  /**
309   * Calculate the distance from one data point to a bag
310   *
311   * @param center the data point in instance space
312   * @param bag the bag
313   * @return the double value as the distance.
314   */
315  public double minBagDistance (Instance center, Instance bag){
316    double distance;
317    double minDistance = Double.MAX_VALUE;
318    Instances temp = bag.relationalValue(1); 
319    //calculate the distance from the data point to each instance in the bag and return the minimum distance
320    for (int i=0; i<temp.numInstances(); i++){
321      distance =0;
322      for (int j=0; j<center.numAttributes(); j++)
323        distance += (center.value(j)-temp.instance(i).value(j))*(center.value(j)-temp.instance(i).value(j));
324
325      if (minDistance>distance)
326        minDistance = distance;
327    }
328    return Math.sqrt(minDistance); 
329  }
330
331  /**
332   * Find the maximum radius for the optimal ball.
333   *
334   * @param train the multi-instance data
335   */ 
336  public void findRadius(Instances train) {
337    int numBags, numInstances;
338    double radius, bagDistance;
339    int highestCount=0;
340
341    numBags = train.numInstances();
342    //try each instance in all positive bag as a ball center (tempCenter),     
343    for (int i=0; i<numBags; i++) {
344      if (train.instance(i).classValue()== 1.0) {//positive bag   
345        numInstances = train.instance(i).relationalValue(1).numInstances();
346        for (int j=0; j<numInstances; j++) {                   
347          Instance tempCenter = train.instance(i).relationalValue(1).instance(j);
348
349          //set the possible set of ball radius corresponding to each tempCenter,
350          double sortedDistance[] = sortArray(m_Distance[i][j]); //sort the distance value               
351          for (int k=1; k<sortedDistance.length; k++){
352            radius = sortedDistance[k]-(sortedDistance[k]-sortedDistance[k-1])/2.0 ;
353
354            //evaluate the performance on the training data according to
355            //the curren selected tempCenter and the set of radius   
356            int correctCount =0;
357            for (int n=0; n<numBags; n++){
358              bagDistance=m_Distance[i][j][n]; 
359              if ((bagDistance <= radius && train.instance(n).classValue()==1.0) 
360                  ||(bagDistance > radius && train.instance(n).classValue ()==0.0))
361                correctCount += train.instance(n).weight();
362
363            }
364
365            //and keep the track of the ball center and the maximum radius which can achieve the highest accuracy.
366            if (correctCount > highestCount || (correctCount==highestCount && radius > m_Radius)){
367              highestCount = correctCount;
368              m_Radius = radius;
369              for (int p=0; p<tempCenter.numAttributes(); p++)
370                m_Center[p]= tempCenter.value(p);
371            }     
372          }
373        }
374      }
375    } 
376  }
377
378  /**
379   * Sort the array.
380   *
381   * @param distance the array need to be sorted
382   * @return sorted array
383   */ 
384  public double [] sortArray(double [] distance) {
385    double [] sorted = new double [distance.length];
386
387    //make a copy of the array
388    double []disCopy = new double[distance.length];
389    for (int i=0;i<distance.length; i++)
390      disCopy[i]= distance[i];
391
392    DoubleVector sortVector = new DoubleVector(disCopy);
393    sortVector.sort();
394    sorted = sortVector.getArrayCopy(); 
395    return sorted;
396  }
397
398
399  /**
400   * Computes the distribution for a given multiple instance
401   *
402   * @param newBag the instance for which distribution is computed
403   * @return the distribution
404   * @throws Exception if the distribution can't be computed successfully
405   */
406  public double[] distributionForInstance(Instance newBag)
407    throws Exception { 
408
409    double [] distribution = new double[2];     
410    double distance; 
411    distribution[0]=0;   
412    distribution[1]=0;
413
414    Instances insts = new Instances(newBag.dataset(),0);
415    insts.add(newBag); 
416
417    // Filter instances
418    insts= Filter.useFilter( insts, m_ConvertToSI);     
419    if (m_Filter!=null) 
420      insts = Filter.useFilter(insts, m_Filter);     
421
422    //calculate the distance from each single instance to the ball center
423    int numInsts = insts.numInstances();               
424    insts.deleteAttributeAt(0); //remove the bagIndex attribute, no use for the distance calculation
425
426    for (int i=0; i<numInsts; i++){
427      distance =0;         
428      for (int j=0; j<insts.numAttributes()-1; j++)
429        distance += (insts.instance(i).value(j) - m_Center[j])*(insts.instance(i).value(j)-m_Center[j]); 
430
431      if (distance <=m_Radius*m_Radius){  // check whether this single instance is inside the ball
432        distribution[1]=1.0;  //predicted as a positive bag       
433        break;
434      } 
435    }
436
437    distribution[0]= 1-distribution[1]; 
438
439    return distribution; 
440  }
441
442  /**
443   * Returns an enumeration describing the available options.
444   *
445   * @return an enumeration of all the available options.
446   */
447  public Enumeration listOptions() {
448    Vector result = new Vector();
449
450    result.addElement(new Option(
451          "\tWhether to 0=normalize/1=standardize/2=neither. \n"
452          + "\t(default 0=normalize)",
453          "N", 1, "-N <num>"));
454
455    return result.elements();
456  }
457
458  /**
459   * Gets the current settings of the classifier.
460   *
461   * @return an array of strings suitable for passing to setOptions
462   */
463  public String[] getOptions() {
464    Vector        result;
465   
466    result = new Vector();
467
468    if (getDebug())
469      result.add("-D");
470   
471    result.add("-N");
472    result.add("" + m_filterType);
473
474    return (String[]) result.toArray(new String[result.size()]);
475  }
476
477  /**
478   * Parses a given list of options. <p/>
479   *
480   <!-- options-start -->
481   * Valid options are: <p/>
482   *
483   * <pre> -N &lt;num&gt;
484   *  Whether to 0=normalize/1=standardize/2=neither.
485   *  (default 0=normalize)</pre>
486   *
487   <!-- options-end -->
488   *
489   * @param options the list of options as an array of strings
490   * @throws Exception if an option is not supported
491   */
492  public void setOptions(String[] options) throws Exception {
493    setDebug(Utils.getFlag('D', options));
494
495    String nString = Utils.getOption('N', options);
496    if (nString.length() != 0) {
497      setFilterType(new SelectedTag(Integer.parseInt(nString), TAGS_FILTER));
498    } else {
499      setFilterType(new SelectedTag(FILTER_NORMALIZE, TAGS_FILTER));
500    }
501  }
502
503  /**
504   * Returns the tip text for this property
505   *
506   * @return tip text for this property suitable for
507   * displaying in the explorer/experimenter gui
508   */
509  public String filterTypeTipText() {
510    return "The filter type for transforming the training data.";
511  }
512
513  /**
514   * Sets how the training data will be transformed. Should be one of
515   * FILTER_NORMALIZE, FILTER_STANDARDIZE, FILTER_NONE.
516   *
517   * @param newType the new filtering mode
518   */
519  public void setFilterType(SelectedTag newType) {
520
521    if (newType.getTags() == TAGS_FILTER) {
522      m_filterType = newType.getSelectedTag().getID();
523    }
524  }
525
526  /**
527   * Gets how the training data will be transformed. Will be one of
528   * FILTER_NORMALIZE, FILTER_STANDARDIZE, FILTER_NONE.
529   *
530   * @return the filtering mode
531   */
532  public SelectedTag getFilterType() {
533
534    return new SelectedTag(m_filterType, TAGS_FILTER);
535  }
536 
537  /**
538   * Returns the revision string.
539   *
540   * @return            the revision
541   */
542  public String getRevision() {
543    return RevisionUtils.extract("$Revision: 5928 $");
544  }
545
546  /**
547   * Main method for testing this class.
548   *
549   * @param argv should contain the command line arguments to the
550   * scheme (see Evaluation)
551   */
552  public static void main(String[] argv) {
553    runClassifier(new MIOptimalBall(), argv);
554  }
555}
Note: See TracBrowser for help on using the repository browser.