source: branches/MetisMQI/src/main/java/weka/filters/unsupervised/instance/RemoveMisclassified.java

Last change on this file was 29, checked in by gnappo, 15 years ago

Taggata versione per la demo e aggiunto branch.

File size: 21.1 KB
Line 
1/*
2 *    This program is free software; you can redistribute it and/or modify
3 *    it under the terms of the GNU General Public License as published by
4 *    the Free Software Foundation; either version 2 of the License, or
5 *    (at your option) any later version.
6 *
7 *    This program is distributed in the hope that it will be useful,
8 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
9 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
10 *    GNU General Public License for more details.
11 *
12 *    You should have received a copy of the GNU General Public License
13 *    along with this program; if not, write to the Free Software
14 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
15 */
16
17/*
18 *    RemoveMisclassified.java
19 *    Copyright (C) 2002 University of Waikato, Hamilton, New Zealand
20 *
21 */
22
23package weka.filters.unsupervised.instance;
24
25import weka.classifiers.Classifier;
26import weka.classifiers.AbstractClassifier;
27import weka.core.Capabilities;
28import weka.core.Instance;
29import weka.core.Instances;
30import weka.core.Option;
31import weka.core.OptionHandler;
32import weka.core.RevisionUtils;
33import weka.core.Utils;
34import weka.filters.Filter;
35import weka.filters.UnsupervisedFilter;
36
37import java.util.Enumeration;
38import java.util.Vector;
39
40/**
41 <!-- globalinfo-start -->
42 * A filter that removes instances which are incorrectly classified. Useful for removing outliers.
43 * <p/>
44 <!-- globalinfo-end -->
45 *
46 <!-- options-start -->
47 * Valid options are: <p/>
48 *
49 * <pre> -W &lt;classifier specification&gt;
50 *  Full class name of classifier to use, followed
51 *  by scheme options. eg:
52 *   "weka.classifiers.bayes.NaiveBayes -D"
53 *  (default: weka.classifiers.rules.ZeroR)</pre>
54 *
55 * <pre> -C &lt;class index&gt;
56 *  Attribute on which misclassifications are based.
57 *  If &lt; 0 will use any current set class or default to the last attribute.</pre>
58 *
59 * <pre> -F &lt;number of folds&gt;
60 *  The number of folds to use for cross-validation cleansing.
61 *  (&lt;2 = no cross-validation - default).</pre>
62 *
63 * <pre> -T &lt;threshold&gt;
64 *  Threshold for the max error when predicting numeric class.
65 *  (Value should be &gt;= 0, default = 0.1).</pre>
66 *
67 * <pre> -I
68 *  The maximum number of cleansing iterations to perform.
69 *  (&lt;1 = until fully cleansed - default)</pre>
70 *
71 * <pre> -V
72 *  Invert the match so that correctly classified instances are discarded.
73 * </pre>
74 *
75 <!-- options-end -->
76 *
77 * @author Richard Kirkby (rkirkby@cs.waikato.ac.nz)
78 * @author Malcolm Ware (mfw4@cs.waikato.ac.nz)
79 * @version $Revision: 5928 $
80 */
81public class RemoveMisclassified 
82  extends Filter
83  implements UnsupervisedFilter, OptionHandler {
84 
85  /** for serialization */
86  static final long serialVersionUID = 5469157004717663171L;
87
88  /** The classifier used to do the cleansing */
89  protected Classifier m_cleansingClassifier = new weka.classifiers.rules.ZeroR();
90
91  /** The attribute to treat as the class for purposes of cleansing. */
92  protected int m_classIndex = -1;
93
94  /** The number of cross validation folds to perform (&lt;2 = no cross validation)  */
95  protected int m_numOfCrossValidationFolds = 0;
96 
97  /** The maximum number of cleansing iterations to perform (&lt;1 = until fully cleansed)  */
98  protected int m_numOfCleansingIterations = 0;
99
100  /** The threshold for deciding when a numeric value is correctly classified */
101  protected double m_numericClassifyThreshold = 0.1;
102
103  /** Whether to invert the match so the correctly classified instances are discarded */
104  protected boolean m_invertMatching = false;
105
106  /** Have we processed the first batch (i.e. training data)? */
107  protected boolean m_firstBatchFinished = false;
108
109  /**
110   * Returns the Capabilities of this filter.
111   *
112   * @return            the capabilities of this object
113   * @see               Capabilities
114   */
115  public Capabilities getCapabilities() {
116    Capabilities        result;
117   
118    if (getClassifier() == null) {
119      result = super.getCapabilities();
120      result.disableAll();
121    } else {
122      result = getClassifier().getCapabilities();
123    }
124   
125    result.setMinimumNumberInstances(0);
126   
127    return result;
128  }
129
130  /**
131   * Sets the format of the input instances.
132   *
133   * @param instanceInfo an Instances object containing the input instance
134   * structure (any instances contained in the object are ignored - only the
135   * structure is required).
136   * @return true if the outputFormat may be collected immediately
137   * @throws Exception if the inputFormat can't be set successfully
138   */ 
139  public boolean setInputFormat(Instances instanceInfo) throws Exception {
140   
141    super.setInputFormat(instanceInfo);
142    setOutputFormat(instanceInfo);
143    m_firstBatchFinished = false;
144    return true;
145  }
146
147  /**
148   * Cleanses the data based on misclassifications when used training data.
149   *
150   * @param data the data to train with and cleanse
151   * @return the cleansed data
152   * @throws Exception if something goes wrong
153   */
154  private Instances cleanseTrain(Instances data) throws Exception {
155   
156    Instance inst;
157    Instances buildSet = new Instances(data); 
158    Instances temp = new Instances(data, data.numInstances());
159    Instances inverseSet = new Instances(data, data.numInstances()); 
160    int count = 0;
161    double ans;
162    int iterations = 0;
163    int classIndex = m_classIndex;
164    if (classIndex < 0) classIndex = data.classIndex();
165    if (classIndex < 0) classIndex = data.numAttributes()-1;
166
167    // loop until perfect
168    while(count != buildSet.numInstances()) {
169     
170      // check if hit maximum number of iterations
171      iterations++;
172      if (m_numOfCleansingIterations > 0 && iterations > m_numOfCleansingIterations) break;
173
174      // build classifier
175      count = buildSet.numInstances();
176      buildSet.setClassIndex(classIndex);
177      m_cleansingClassifier.buildClassifier(buildSet);
178
179      temp = new Instances(buildSet, buildSet.numInstances());
180
181      // test on training data
182      for (int i = 0; i < buildSet.numInstances(); i++) {
183        inst = buildSet.instance(i);
184        ans = m_cleansingClassifier.classifyInstance(inst);
185        if (buildSet.classAttribute().isNumeric()) {
186          if (ans >= inst.classValue() - m_numericClassifyThreshold &&
187              ans <= inst.classValue() + m_numericClassifyThreshold) {
188            temp.add(inst);
189          } else if (m_invertMatching) {
190            inverseSet.add(inst);
191          }
192        }
193        else { //class is nominal
194          if (ans == inst.classValue()) {
195            temp.add(inst);
196          } else if (m_invertMatching) {
197            inverseSet.add(inst);
198          }
199        }
200      }
201      buildSet = temp;
202    }
203
204    if (m_invertMatching) {
205      inverseSet.setClassIndex(data.classIndex());
206      return inverseSet;
207    }
208    else {
209      buildSet.setClassIndex(data.classIndex());
210      return buildSet;
211    }
212  }
213
214  /**
215   * Cleanses the data based on misclassifications when performing cross-validation.
216   *
217   * @param data the data to train with and cleanse
218   * @return the cleansed data
219   * @throws Exception if something goes wrong
220   */
221  private Instances cleanseCross(Instances data) throws Exception {
222   
223    Instance inst;
224    Instances crossSet = new Instances(data);
225    Instances temp = new Instances(data, data.numInstances());   
226    Instances inverseSet = new Instances(data, data.numInstances()); 
227    int count = 0;
228    double ans;
229    int iterations = 0;
230    int classIndex = m_classIndex;
231    if (classIndex < 0) classIndex = data.classIndex();
232    if (classIndex < 0) classIndex = data.numAttributes()-1;
233
234    // loop until perfect
235    while (count != crossSet.numInstances() && 
236           crossSet.numInstances() >= m_numOfCrossValidationFolds) {
237
238      count = crossSet.numInstances();
239     
240      // check if hit maximum number of iterations
241      iterations++;
242      if (m_numOfCleansingIterations > 0 && iterations > m_numOfCleansingIterations) break;
243
244      crossSet.setClassIndex(classIndex);
245
246      if (crossSet.classAttribute().isNominal()) {
247        crossSet.stratify(m_numOfCrossValidationFolds);
248      }
249      // do the folds
250      temp = new Instances(crossSet, crossSet.numInstances());
251     
252      for (int fold = 0; fold < m_numOfCrossValidationFolds; fold++) {
253        Instances train = crossSet.trainCV(m_numOfCrossValidationFolds, fold);
254        m_cleansingClassifier.buildClassifier(train);
255        Instances test = crossSet.testCV(m_numOfCrossValidationFolds, fold);
256        //now test
257        for (int i = 0; i < test.numInstances(); i++) {
258          inst = test.instance(i);
259          ans = m_cleansingClassifier.classifyInstance(inst);
260          if (crossSet.classAttribute().isNumeric()) {
261            if (ans >= inst.classValue() - m_numericClassifyThreshold &&
262                ans <= inst.classValue() + m_numericClassifyThreshold) {
263              temp.add(inst);
264            } else if (m_invertMatching) {
265              inverseSet.add(inst);
266            }
267          }
268          else { //class is nominal
269            if (ans == inst.classValue()) {
270              temp.add(inst);
271            } else if (m_invertMatching) {
272              inverseSet.add(inst);
273            }
274          }
275        }
276      }
277      crossSet = temp;
278    }
279
280    if (m_invertMatching) {
281      inverseSet.setClassIndex(data.classIndex());
282      return inverseSet;
283    }
284    else {
285      crossSet.setClassIndex(data.classIndex());
286      return crossSet;
287    }
288
289  }
290 
291  /**
292   * Input an instance for filtering.
293   *
294   * @param instance the input instance
295   * @return true if the filtered instance may now be
296   * collected with output().
297   * @throws NullPointerException if the input format has not been
298   * defined.
299   * @throws Exception if the input instance was not of the correct
300   * format or if there was a problem with the filtering. 
301   */
302  public boolean input(Instance instance) throws Exception {
303
304    if (inputFormatPeek() == null) {
305      throw new NullPointerException("No input instance format defined");
306    }
307
308    if (m_NewBatch) {
309      resetQueue();
310      m_NewBatch = false;
311    }
312    if (m_firstBatchFinished) {
313      push(instance);
314      return true;
315    } else {
316      bufferInput(instance);
317      return false;
318    }
319  }
320 
321  /**
322   * Signify that this batch of input to the filter is finished.
323   *
324   * @return true if there are instances pending output
325   * @throws IllegalStateException if no input structure has been defined
326   */ 
327  public boolean batchFinished() throws Exception {
328
329    if (getInputFormat() == null) {
330      throw new IllegalStateException("No input instance format defined");
331    }
332
333    if (!m_firstBatchFinished) {
334
335      Instances filtered;
336      if (m_numOfCrossValidationFolds < 2) {
337        filtered = cleanseTrain(getInputFormat());
338      } else {
339        filtered = cleanseCross(getInputFormat());
340      }
341     
342      for (int i=0; i<filtered.numInstances(); i++) {
343        push(filtered.instance(i));
344      }
345     
346      m_firstBatchFinished = true;
347      flushInput();
348    }
349    m_NewBatch = true;
350    return (numPendingOutput() != 0);
351  }
352
353  /**
354   * Returns an enumeration describing the available options.
355   *
356   * @return an enumeration of all the available options.
357   */
358  public Enumeration listOptions() {
359   
360    Vector newVector = new Vector(6);
361   
362    newVector.addElement(new Option(
363              "\tFull class name of classifier to use, followed\n"
364              + "\tby scheme options. eg:\n"
365              + "\t\t\"weka.classifiers.bayes.NaiveBayes -D\"\n"
366              + "\t(default: weka.classifiers.rules.ZeroR)",
367              "W", 1, "-W <classifier specification>"));
368    newVector.addElement(new Option(
369              "\tAttribute on which misclassifications are based.\n"
370              + "\tIf < 0 will use any current set class or default to the last attribute.",
371              "C", 1, "-C <class index>"));
372    newVector.addElement(new Option(
373              "\tThe number of folds to use for cross-validation cleansing.\n"
374              +"\t(<2 = no cross-validation - default).",
375              "F", 1, "-F <number of folds>"));
376    newVector.addElement(new Option(
377              "\tThreshold for the max error when predicting numeric class.\n"
378              +"\t(Value should be >= 0, default = 0.1).",
379              "T", 1, "-T <threshold>"));
380    newVector.addElement(new Option(
381              "\tThe maximum number of cleansing iterations to perform.\n"
382              +"\t(<1 = until fully cleansed - default)",
383              "I", 1,"-I"));
384    newVector.addElement(new Option(
385              "\tInvert the match so that correctly classified instances are discarded.\n",
386              "V", 0,"-V"));
387   
388    return newVector.elements();
389  }
390
391
392  /**
393   * Parses a given list of options. <p/>
394   *
395   <!-- options-start -->
396   * Valid options are: <p/>
397   *
398   * <pre> -W &lt;classifier specification&gt;
399   *  Full class name of classifier to use, followed
400   *  by scheme options. eg:
401   *   "weka.classifiers.bayes.NaiveBayes -D"
402   *  (default: weka.classifiers.rules.ZeroR)</pre>
403   *
404   * <pre> -C &lt;class index&gt;
405   *  Attribute on which misclassifications are based.
406   *  If &lt; 0 will use any current set class or default to the last attribute.</pre>
407   *
408   * <pre> -F &lt;number of folds&gt;
409   *  The number of folds to use for cross-validation cleansing.
410   *  (&lt;2 = no cross-validation - default).</pre>
411   *
412   * <pre> -T &lt;threshold&gt;
413   *  Threshold for the max error when predicting numeric class.
414   *  (Value should be &gt;= 0, default = 0.1).</pre>
415   *
416   * <pre> -I
417   *  The maximum number of cleansing iterations to perform.
418   *  (&lt;1 = until fully cleansed - default)</pre>
419   *
420   * <pre> -V
421   *  Invert the match so that correctly classified instances are discarded.
422   * </pre>
423   *
424   <!-- options-end -->
425   *
426   * @param options the list of options as an array of strings
427   * @throws Exception if an option is not supported
428   */
429  public void setOptions(String[] options) throws Exception {
430
431    String classifierString = Utils.getOption('W', options);
432    if (classifierString.length() == 0)
433      classifierString = weka.classifiers.rules.ZeroR.class.getName();
434    String[] classifierSpec = Utils.splitOptions(classifierString);
435    if (classifierSpec.length == 0) {
436      throw new Exception("Invalid classifier specification string");
437    }
438    String classifierName = classifierSpec[0];
439    classifierSpec[0] = "";
440    setClassifier(AbstractClassifier.forName(classifierName, classifierSpec));
441
442    String cString = Utils.getOption('C', options);
443    if (cString.length() != 0) {
444      setClassIndex((new Double(cString)).intValue());
445    } else {
446      setClassIndex(-1);
447    }
448
449    String fString = Utils.getOption('F', options);
450    if (fString.length() != 0) {
451      setNumFolds((new Double(fString)).intValue());
452    } else {
453      setNumFolds(0);
454    }
455
456    String tString = Utils.getOption('T', options);
457    if (tString.length() != 0) {
458      setThreshold((new Double(tString)).doubleValue());
459    } else {
460      setThreshold(0.1);
461    }
462
463    String iString = Utils.getOption('I', options);
464    if (iString.length() != 0) {
465      setMaxIterations((new Double(iString)).intValue());
466    } else {
467      setMaxIterations(0);
468    }
469   
470    if (Utils.getFlag('V', options)) {
471      setInvert(true);
472    } else {
473      setInvert(false);
474    }
475       
476    Utils.checkForRemainingOptions(options);
477
478  }
479
480  /**
481   * Gets the current settings of the filter.
482   *
483   * @return an array of strings suitable for passing to setOptions
484   */
485  public String [] getOptions() {
486
487    String [] options = new String [15];
488    int current = 0;
489
490    options[current++] = "-W"; options[current++] = "" + getClassifierSpec();
491    options[current++] = "-C"; options[current++] = "" + getClassIndex();
492    options[current++] = "-F"; options[current++] = "" + getNumFolds();
493    options[current++] = "-T"; options[current++] = "" + getThreshold();
494    options[current++] = "-I"; options[current++] = "" + getMaxIterations();
495    if (getInvert()) {
496      options[current++] = "-V";
497    }
498   
499    while (current < options.length) {
500      options[current++] = "";
501    }
502    return options;
503  }
504
505  /**
506   * Returns a string describing this filter
507   *
508   * @return a description of the filter suitable for
509   * displaying in the explorer/experimenter gui
510   */
511  public String globalInfo() {
512    return 
513        "A filter that removes instances which are incorrectly classified. "
514      + "Useful for removing outliers.";
515  }
516
517  /**
518   * Returns the tip text for this property
519   *
520   * @return tip text for this property suitable for
521   * displaying in the explorer/experimenter gui
522   */
523  public String classifierTipText() {
524
525    return "The classifier upon which to base the misclassifications.";
526  }
527
528  /**
529   * Sets the classifier to classify instances with.
530   *
531   * @param classifier The classifier to be used (with its options set).
532   */
533  public void setClassifier(Classifier classifier) {
534
535    m_cleansingClassifier = classifier;
536  }
537 
538  /**
539   * Gets the classifier used by the filter.
540   *
541   * @return The classifier to be used.
542   */
543  public Classifier getClassifier() {
544
545    return m_cleansingClassifier;
546  }
547
548  /**
549   * Gets the classifier specification string, which contains the class name of
550   * the classifier and any options to the classifier.
551   *
552   * @return the classifier string.
553   */
554  protected String getClassifierSpec() {
555   
556    Classifier c = getClassifier();
557    if (c instanceof OptionHandler) {
558      return c.getClass().getName() + " "
559        + Utils.joinOptions(((OptionHandler)c).getOptions());
560    }
561    return c.getClass().getName();
562  }
563
564  /**
565   * Returns the tip text for this property
566   *
567   * @return tip text for this property suitable for
568   * displaying in the explorer/experimenter gui
569   */
570  public String classIndexTipText() {
571
572    return "Index of the class upon which to base the misclassifications. "
573      + "If < 0 will use any current set class or default to the last attribute.";
574  }
575
576  /**
577   * Sets the attribute on which misclassifications are based.
578   * If &lt; 0 will use any current set class or default to the last attribute.
579   *
580   * @param classIndex the class index.
581   */
582  public void setClassIndex(int classIndex) {
583   
584    m_classIndex = classIndex;
585  }
586
587  /**
588   * Gets the attribute on which misclassifications are based.
589   *
590   * @return the class index.
591   */
592  public int getClassIndex() {
593
594    return m_classIndex;
595  }
596
597  /**
598   * Returns the tip text for this property
599   *
600   * @return tip text for this property suitable for
601   * displaying in the explorer/experimenter gui
602   */
603  public String numFoldsTipText() {
604
605    return "The number of cross-validation folds to use. If < 2 then no cross-validation will be performed.";
606  }
607
608  /**
609   * Sets the number of cross-validation folds to use
610   * - &lt; 2 means no cross-validation.
611   *
612   * @param numOfFolds the number of folds.
613   */
614  public void setNumFolds(int numOfFolds) {
615   
616    m_numOfCrossValidationFolds = numOfFolds;
617  }
618
619  /**
620   * Gets the number of cross-validation folds used by the filter.
621   *
622   * @return the number of folds.
623   */
624  public int getNumFolds() {
625
626    return m_numOfCrossValidationFolds;
627  }
628
629  /**
630   * Returns the tip text for this property
631   *
632   * @return tip text for this property suitable for
633   * displaying in the explorer/experimenter gui
634   */
635  public String thresholdTipText() {
636
637    return "Threshold for the max allowable error when predicting a numeric class. Should be >= 0.";
638  }
639
640  /**
641   * Sets the threshold for the max error when predicting a numeric class.
642   * The value should be &gt;= 0.
643   *
644   * @param threshold the numeric theshold.
645   */
646  public void setThreshold(double threshold) {
647   
648    m_numericClassifyThreshold = threshold;
649  }
650
651  /**
652   * Gets the threshold for the max error when predicting a numeric class.
653   *
654   * @return the numeric threshold.
655   */
656  public double getThreshold() {
657
658    return m_numericClassifyThreshold;
659  }
660
661  /**
662   * Returns the tip text for this property
663   *
664   * @return tip text for this property suitable for
665   * displaying in the explorer/experimenter gui
666   */
667  public String maxIterationsTipText() {
668
669    return "The maximum number of iterations to perform. < 1 means filter will go until fully cleansed.";
670  }
671
672  /**
673   * Sets the maximum number of cleansing iterations to perform
674   * - &lt; 1 means go until fully cleansed
675   *
676   * @param iterations the maximum number of iterations.
677   */
678  public void setMaxIterations(int iterations) {
679   
680    m_numOfCleansingIterations = iterations;
681  }
682
683  /**
684   * Gets the maximum number of cleansing iterations performed
685   *
686   * @return the maximum number of iterations.
687   */
688  public int getMaxIterations() {
689
690    return m_numOfCleansingIterations;
691  }
692
693  /**
694   * Returns the tip text for this property
695   *
696   * @return tip text for this property suitable for
697   * displaying in the explorer/experimenter gui
698   */
699  public String invertTipText() {
700
701    return "Whether or not to invert the selection. If true, correctly classified instances will be discarded.";
702  }
703
704  /**
705   * Set whether selection is inverted.
706   *
707   * @param invert whether or not to invert selection.
708   */
709  public void setInvert(boolean invert) {
710   
711    m_invertMatching = invert;
712  }
713
714  /**
715   * Get whether selection is inverted.
716   *
717   * @return whether or not selection is inverted.
718   */
719  public boolean getInvert() {
720   
721    return m_invertMatching;
722  }
723 
724  /**
725   * Returns the revision string.
726   *
727   * @return            the revision
728   */
729  public String getRevision() {
730    return RevisionUtils.extract("$Revision: 5928 $");
731  }
732
733  /**
734   * Main method for testing this class.
735   *
736   * @param argv should contain arguments to the filter: use -h for help
737   */
738  public static void main(String [] argv) {
739    runFilter(new RemoveMisclassified(), argv);
740  }
741}
Note: See TracBrowser for help on using the repository browser.