source: branches/MetisMQI/src/main/java/weka/filters/supervised/attribute/AddClassification.java

Last change on this file was 29, checked in by gnappo, 15 years ago

Taggata versione per la demo e aggiunto branch.

File size: 21.3 KB
Line 
1/*
2 *    This program is free software; you can redistribute it and/or modify
3 *    it under the terms of the GNU General Public License as published by
4 *    the Free Software Foundation; either version 2 of the License, or
5 *    (at your option) any later version.
6 *
7 *    This program is distributed in the hope that it will be useful,
8 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
9 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
10 *    GNU General Public License for more details.
11 *
12 *    You should have received a copy of the GNU General Public License
13 *    along with this program; if not, write to the Free Software
14 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
15 */
16
17/*
18 * AddClassification.java
19 * Copyright (C) 2006 University of Waikato, Hamilton, New Zealand
20 */
21
22package weka.filters.supervised.attribute;
23
24import weka.classifiers.Classifier;
25import weka.classifiers.AbstractClassifier;
26import weka.core.Attribute;
27import weka.core.Capabilities;
28import weka.core.FastVector;
29import weka.core.Instance;
30import weka.core.DenseInstance;
31import weka.core.Instances;
32import weka.core.Option;
33import weka.core.OptionHandler;
34import weka.core.RevisionUtils;
35import weka.core.SparseInstance;
36import weka.core.Utils;
37import weka.core.WekaException;
38import weka.filters.SimpleBatchFilter;
39
40import java.io.File;
41import java.io.FileInputStream;
42import java.io.FileNotFoundException;
43import java.io.ObjectInputStream;
44import java.util.Enumeration;
45import java.util.Vector;
46
47/**
48 <!-- globalinfo-start -->
49 * A filter for adding the classification, the class distribution and an error flag to a dataset with a classifier. The classifier is either trained on the data itself or provided as serialized model.
50 * <p/>
51 <!-- globalinfo-end -->
52 *
53 <!-- options-start -->
54 * Valid options are: <p/>
55 *
56 * <pre> -D
57 *  Turns on output of debugging information.</pre>
58 *
59 * <pre> -W &lt;classifier specification&gt;
60 *  Full class name of classifier to use, followed
61 *  by scheme options. eg:
62 *   "weka.classifiers.bayes.NaiveBayes -D"
63 *  (default: weka.classifiers.rules.ZeroR)</pre>
64 *
65 * <pre> -serialized &lt;file&gt;
66 *  Instead of training a classifier on the data, one can also provide
67 *  a serialized model and use that for tagging the data.</pre>
68 *
69 * <pre> -classification
70 *  Adds an attribute with the actual classification.
71 *  (default: off)</pre>
72 *
73 * <pre> -remove-old-class
74 *  Removes the old class attribute.
75 *  (default: off)</pre>
76 *
77 * <pre> -distribution
78 *  Adds attributes with the distribution for all classes
79 *  (for numeric classes this will be identical to the attribute
80 *  output with '-classification').
81 *  (default: off)</pre>
82 *
83 * <pre> -error
84 *  Adds an attribute indicating whether the classifier output
85 *  a wrong classification (for numeric classes this is the numeric
86 *  difference).
87 *  (default: off)</pre>
88 *
89 <!-- options-end -->
90 *
91 * @author  fracpete (fracpete at waikato dot ac dot nz)
92 * @version $Revision: 5987 $
93 */
94public class AddClassification
95  extends SimpleBatchFilter {
96
97  /** for serialization */
98  private static final long serialVersionUID = -1931467132568441909L;
99
100  /** The classifier template used to do the classification */
101  protected Classifier m_Classifier = new weka.classifiers.rules.ZeroR();
102
103  /** The file from which to load a serialized classifier */
104  protected File m_SerializedClassifierFile = new File(System.getProperty("user.dir"));
105 
106  /** The actual classifier used to do the classification */
107  protected Classifier m_ActualClassifier = null;
108
109  /** whether to output the classification */
110  protected boolean m_OutputClassification = false;
111
112  /** whether to remove the old class attribute */
113  protected boolean m_RemoveOldClass = false;
114 
115  /** whether to output the class distribution */
116  protected boolean m_OutputDistribution = false;
117 
118  /** whether to output the error flag */
119  protected boolean m_OutputErrorFlag = false;
120
121  /**
122   * Returns a string describing this filter
123   *
124   * @return            a description of the filter suitable for
125   *                    displaying in the explorer/experimenter gui
126   */
127  public String globalInfo() {
128    return 
129        "A filter for adding the classification, the class distribution and "
130      + "an error flag to a dataset with a classifier. The classifier is "
131      + "either trained on the data itself or provided as serialized model.";
132  }
133
134  /**
135   * Returns an enumeration describing the available options.
136   *
137   * @return            an enumeration of all the available options.
138   */
139  public Enumeration listOptions() {
140    Vector              result;
141    Enumeration         en;
142
143    result = new Vector();
144
145    en = super.listOptions();
146    while (en.hasMoreElements())
147      result.addElement(en.nextElement());
148
149    result.addElement(new Option(
150        "\tFull class name of classifier to use, followed\n"
151        + "\tby scheme options. eg:\n"
152        + "\t\t\"weka.classifiers.bayes.NaiveBayes -D\"\n"
153        + "\t(default: weka.classifiers.rules.ZeroR)",
154        "W", 1, "-W <classifier specification>"));
155
156    result.addElement(new Option(
157        "\tInstead of training a classifier on the data, one can also provide\n"
158        + "\ta serialized model and use that for tagging the data.",
159        "serialized", 1, "-serialized <file>"));
160
161    result.addElement(new Option(
162        "\tAdds an attribute with the actual classification.\n"
163        + "\t(default: off)",
164        "classification", 0, "-classification"));
165
166    result.addElement(new Option(
167        "\tRemoves the old class attribute.\n"
168        + "\t(default: off)",
169        "remove-old-class", 0, "-remove-old-class"));
170
171    result.addElement(new Option(
172        "\tAdds attributes with the distribution for all classes \n"
173        + "\t(for numeric classes this will be identical to the attribute \n"
174        + "\toutput with '-classification').\n"
175        + "\t(default: off)",
176        "distribution", 0, "-distribution"));
177
178    result.addElement(new Option(
179        "\tAdds an attribute indicating whether the classifier output \n"
180        + "\ta wrong classification (for numeric classes this is the numeric \n"
181        + "\tdifference).\n"
182        + "\t(default: off)",
183        "error", 0, "-error"));
184
185    return result.elements();
186  }
187
188  /**
189   * Parses the options for this object. <p/>
190   *
191   <!-- options-start -->
192   * Valid options are: <p/>
193   *
194   * <pre> -D
195   *  Turns on output of debugging information.</pre>
196   *
197   * <pre> -W &lt;classifier specification&gt;
198   *  Full class name of classifier to use, followed
199   *  by scheme options. eg:
200   *   "weka.classifiers.bayes.NaiveBayes -D"
201   *  (default: weka.classifiers.rules.ZeroR)</pre>
202   *
203   * <pre> -serialized &lt;file&gt;
204   *  Instead of training a classifier on the data, one can also provide
205   *  a serialized model and use that for tagging the data.</pre>
206   *
207   * <pre> -classification
208   *  Adds an attribute with the actual classification.
209   *  (default: off)</pre>
210   *
211   * <pre> -remove-old-class
212   *  Removes the old class attribute.
213   *  (default: off)</pre>
214   *
215   * <pre> -distribution
216   *  Adds attributes with the distribution for all classes
217   *  (for numeric classes this will be identical to the attribute
218   *  output with '-classification').
219   *  (default: off)</pre>
220   *
221   * <pre> -error
222   *  Adds an attribute indicating whether the classifier output
223   *  a wrong classification (for numeric classes this is the numeric
224   *  difference).
225   *  (default: off)</pre>
226   *
227   <!-- options-end -->
228   *
229   * @param options     the options to use
230   * @throws Exception  if setting of options fails
231   */
232  public void setOptions(String[] options) throws Exception {
233    String      tmpStr;
234    String[]    tmpOptions;
235    File        file;
236    boolean     serializedModel;
237
238    setOutputClassification(Utils.getFlag("classification", options));
239   
240    setRemoveOldClass(Utils.getFlag("remove-old-class", options));
241   
242    setOutputDistribution(Utils.getFlag("distribution", options));
243
244    setOutputErrorFlag(Utils.getFlag("error", options));
245   
246    serializedModel = false;
247    tmpStr = Utils.getOption("serialized", options);
248    if (tmpStr.length() != 0) {
249      file = new File(tmpStr);
250      if (!file.exists())
251        throw new FileNotFoundException(
252            "File '" + file.getAbsolutePath() + "' not found!");
253      if (file.isDirectory())
254        throw new FileNotFoundException(
255            "'" + file.getAbsolutePath() + "' points to a directory not a file!");
256      setSerializedClassifierFile(file);
257      serializedModel = true;
258    }
259    else {
260      setSerializedClassifierFile(null);
261    }
262   
263    if (!serializedModel) {
264      tmpStr = Utils.getOption('W', options);
265      if (tmpStr.length() == 0)
266        tmpStr = weka.classifiers.rules.ZeroR.class.getName();
267      tmpOptions = Utils.splitOptions(tmpStr);
268      if (tmpOptions.length == 0)
269        throw new Exception("Invalid classifier specification string");
270      tmpStr = tmpOptions[0];
271      tmpOptions[0] = "";
272      setClassifier(AbstractClassifier.forName(tmpStr, tmpOptions));
273    }
274
275    super.setOptions(options);
276  }
277
278  /**
279   * Gets the current settings of the classifier.
280   *
281   * @return            an array of strings suitable for passing to setOptions
282   */
283  public String[] getOptions() {
284    int         i;
285    Vector      result;
286    String[]    options;
287    File        file;
288
289    result = new Vector();
290
291    options = super.getOptions();
292    for (i = 0; i < options.length; i++)
293      result.add(options[i]);
294
295    if (getOutputClassification())
296      result.add("-classification");
297
298    if (getRemoveOldClass())
299      result.add("-remove-old-class");
300
301    if (getOutputDistribution())
302      result.add("-distribution");
303
304    if (getOutputErrorFlag())
305      result.add("-error");
306
307    file = getSerializedClassifierFile();
308    if ((file != null) && (!file.isDirectory())) {
309      result.add("-serialized");
310      result.add(file.getAbsolutePath());
311    }
312    else {
313      result.add("-W");
314      result.add(getClassifierSpec());
315    }
316   
317    return (String[]) result.toArray(new String[result.size()]);         
318  }
319
320  /**
321   * Returns the Capabilities of this filter.
322   *
323   * @return            the capabilities of this object
324   * @see               Capabilities
325   */
326  public Capabilities getCapabilities() {
327    Capabilities        result;
328   
329    if (getClassifier() == null) {
330      result = super.getCapabilities();
331      result.disableAll();
332    } else {
333      result = getClassifier().getCapabilities();
334    }
335   
336    result.setMinimumNumberInstances(0);
337   
338    return result;
339  }
340
341  /**
342   * Returns the tip text for this property
343   *
344   * @return            tip text for this property suitable for
345   *                    displaying in the explorer/experimenter gui
346   */
347  public String classifierTipText() {
348    return "The classifier to use for classification.";
349  }
350
351  /**
352   * Sets the classifier to classify instances with.
353   *
354   * @param value       The classifier to be used (with its options set).
355   */
356  public void setClassifier(Classifier value) {
357    m_Classifier = value;
358  }
359 
360  /**
361   * Gets the classifier used by the filter.
362   *
363   * @return            The classifier to be used.
364   */
365  public Classifier getClassifier() {
366    return m_Classifier;
367  }
368
369  /**
370   * Gets the classifier specification string, which contains the class name of
371   * the classifier and any options to the classifier.
372   *
373   * @return            the classifier string.
374   */
375  protected String getClassifierSpec() {
376    String      result;
377    Classifier  c;
378   
379    c      = getClassifier();
380    result = c.getClass().getName();
381    if (c instanceof OptionHandler)
382      result += " " + Utils.joinOptions(((OptionHandler) c).getOptions());
383   
384    return result;
385  }
386 
387  /**
388   * Returns the tip text for this property
389   *
390   * @return            tip text for this property suitable for
391   *                    displaying in the explorer/experimenter gui
392   */
393  public String serializedClassifierFileTipText() {
394    return "A file containing the serialized model of a trained classifier.";
395  }
396
397  /**
398   * Gets the file pointing to a serialized, trained classifier. If it is
399   * null or pointing to a directory it will not be used.
400   *
401   * @return            the file the serialized, trained classifier is located
402   *                    in
403   */
404  public File getSerializedClassifierFile() {
405    return m_SerializedClassifierFile;
406  }
407
408  /**
409   * Sets the file pointing to a serialized, trained classifier. If the
410   * argument is null, doesn't exist or pointing to a directory, then the
411   * value is ignored.
412   *
413   * @param value       the file pointing to the serialized, trained classifier
414   */
415  public void setSerializedClassifierFile(File value) {
416    if ((value == null) || (!value.exists()))
417      value = new File(System.getProperty("user.dir"));
418
419    m_SerializedClassifierFile = value;
420  }
421 
422  /**
423   * Returns the tip text for this property
424   *
425   * @return            tip text for this property suitable for
426   *                    displaying in the explorer/experimenter gui
427   */
428  public String outputClassificationTipText() {
429    return "Whether to add an attribute with the actual classification.";
430  }
431
432  /**
433   * Get whether the classifiction of the classifier is output.
434   *
435   * @return            true if the classification of the classifier is output.
436   */
437  public boolean getOutputClassification() {
438    return m_OutputClassification;
439  }
440 
441  /**
442   * Set whether the classification of the classifier is output.
443   *
444   * @param value       whether the classification of the classifier is output.
445   */
446  public void setOutputClassification(boolean value) {
447    m_OutputClassification = value;
448  }
449 
450  /**
451   * Returns the tip text for this property
452   *
453   * @return            tip text for this property suitable for
454   *                    displaying in the explorer/experimenter gui
455   */
456  public String removeOldClassTipText() {
457    return "Whether to remove the old class attribute.";
458  }
459
460  /**
461   * Get whether the old class attribute is removed.
462   *
463   * @return            true if the old class attribute is removed.
464   */
465  public boolean getRemoveOldClass() {
466    return m_RemoveOldClass;
467  }
468 
469  /**
470   * Set whether the old class attribute is removed.
471   *
472   * @param value       whether the old class attribute is removed.
473   */
474  public void setRemoveOldClass(boolean value) {
475    m_RemoveOldClass = value;
476  }
477 
478  /**
479   * Returns the tip text for this property
480   *
481   * @return            tip text for this property suitable for
482   *                    displaying in the explorer/experimenter gui
483   */
484  public String outputDistributionTipText() {
485    return 
486        "Whether to add attributes with the distribution for all classes "
487      + "(for numeric classes this will be identical to the attribute output "
488      + "with 'outputClassification').";
489  }
490
491  /**
492   * Get whether the classifiction of the classifier is output.
493   *
494   * @return            true if the distribution of the classifier is output.
495   */
496  public boolean getOutputDistribution() {
497    return m_OutputDistribution;
498  }
499 
500  /**
501   * Set whether the Distribution of the classifier is output.
502   *
503   * @param value       whether the distribution of the classifier is output.
504   */
505  public void setOutputDistribution(boolean value) {
506    m_OutputDistribution = value;
507  }
508 
509  /**
510   * Returns the tip text for this property
511   *
512   * @return            tip text for this property suitable for
513   *                    displaying in the explorer/experimenter gui
514   */
515  public String outputErrorFlagTipText() {
516    return 
517        "Whether to add an attribute indicating whether the classifier output "
518      + "a wrong classification (for numeric classes this is the numeric "
519      + "difference).";
520  }
521
522  /**
523   * Get whether the classifiction of the classifier is output.
524   *
525   * @return            true if the classification of the classifier is output.
526   */
527  public boolean getOutputErrorFlag() {
528    return m_OutputErrorFlag;
529  }
530 
531  /**
532   * Set whether the classification of the classifier is output.
533   *
534   * @param value       whether the classification of the classifier is output.
535   */
536  public void setOutputErrorFlag(boolean value) {
537    m_OutputErrorFlag = value;
538  }
539
540  /**
541   * Determines the output format based on the input format and returns
542   * this. In case the output format cannot be returned immediately, i.e.,
543   * immediateOutputFormat() returns false, then this method will be called
544   * from batchFinished().
545   *
546   * @param inputFormat     the input format to base the output format on
547   * @return                the output format
548   * @throws Exception      in case the determination goes wrong
549   * @see   #hasImmediateOutputFormat()
550   * @see   #batchFinished()
551   */
552  protected Instances determineOutputFormat(Instances inputFormat)
553      throws Exception {
554   
555    Instances   result;
556    FastVector  atts;
557    int         i;
558    FastVector  values;
559    int         classindex;
560   
561    classindex = -1;
562   
563    // copy old attributes
564    atts = new FastVector();
565    for (i = 0; i < inputFormat.numAttributes(); i++) {
566      // remove class?
567      if ((i == inputFormat.classIndex()) && (getRemoveOldClass()) )
568        continue;
569      // record class index
570      if (i == inputFormat.classIndex())
571        classindex = i;
572      atts.addElement(inputFormat.attribute(i).copy());
573    }
574   
575    // add new attributes
576    // 1. classification?
577    if (getOutputClassification()) {
578      // if old class got removed, use this one
579      if (classindex == -1)
580        classindex = atts.size();
581      atts.addElement(inputFormat.classAttribute().copy("classification"));
582    }
583   
584    // 2. distribution?
585    if (getOutputDistribution()) {
586      if (inputFormat.classAttribute().isNominal()) {
587        for (i = 0; i < inputFormat.classAttribute().numValues(); i++) {
588          atts.addElement(new Attribute("distribution_" + inputFormat.classAttribute().value(i)));
589        }
590      }
591      else {
592        atts.addElement(new Attribute("distribution"));
593      }
594    }
595   
596    // 2. error flag?
597    if (getOutputErrorFlag()) {
598      if (inputFormat.classAttribute().isNominal()) {
599        values = new FastVector();
600        values.addElement("no");
601        values.addElement("yes");
602        atts.addElement(new Attribute("error", values));
603      }
604      else {
605        atts.addElement(new Attribute("error"));
606      }
607    }
608   
609    // generate new header
610    result = new Instances(inputFormat.relationName(), atts, 0);
611    result.setClassIndex(classindex);
612   
613    return result;
614  }
615
616  /**
617   * Processes the given data (may change the provided dataset) and returns
618   * the modified version. This method is called in batchFinished().
619   *
620   * @param instances   the data to process
621   * @return            the modified data
622   * @throws Exception  in case the processing goes wrong
623   * @see               #batchFinished()
624   */
625  protected Instances process(Instances instances) throws Exception {
626    Instances           result;
627    double[]            newValues;
628    double[]            oldValues;
629    int                 i;
630    int                 start;
631    int                 n;
632    Instance            newInstance;
633    Instance            oldInstance;
634    Instances           header;
635    double[]            distribution;
636    File                file;
637    ObjectInputStream   ois;
638   
639    // load or train classifier
640    if (!isFirstBatchDone()) {
641      file = getSerializedClassifierFile();
642      if (!file.isDirectory()) {
643        ois = new ObjectInputStream(new FileInputStream(file));
644        m_ActualClassifier = (Classifier) ois.readObject();
645        header = null;
646        // let's see whether there's an Instances header stored as well
647        try {
648          header = (Instances) ois.readObject();
649        }
650        catch (Exception e) {
651          // ignored
652        }
653        ois.close();
654        // same dataset format?
655        if ((header != null) && (!header.equalHeaders(instances)))
656          throw new WekaException(
657              "Training header of classifier and filter dataset don't match:\n"
658              + header.equalHeadersMsg(instances));
659      }
660      else {
661        m_ActualClassifier = AbstractClassifier.makeCopy(m_Classifier);
662        m_ActualClassifier.buildClassifier(instances);
663      }
664    }
665   
666    result = getOutputFormat();
667   
668    // traverse all instances
669    for (i = 0; i < instances.numInstances(); i++) {
670      oldInstance = instances.instance(i);
671      oldValues   = oldInstance.toDoubleArray();
672      newValues   = new double[result.numAttributes()];
673     
674      start = oldValues.length;
675      if (getRemoveOldClass())
676        start--;
677
678      // copy old values
679      System.arraycopy(oldValues, 0, newValues, 0, start);
680     
681      // add new values:
682      // 1. classification?
683      if (getOutputClassification()) {
684        newValues[start] = m_ActualClassifier.classifyInstance(oldInstance);
685        start++;
686      }
687     
688      // 2. distribution?
689      if (getOutputDistribution()) {
690        distribution = m_ActualClassifier.distributionForInstance(oldInstance);
691        for (n = 0; n < distribution.length; n++) {
692          newValues[start] = distribution[n];
693          start++;
694        }
695      }
696     
697      // 3. error flag?
698      if (getOutputErrorFlag()) {
699        if (result.classAttribute().isNominal()) {
700          if (oldInstance.classValue() == m_ActualClassifier.classifyInstance(oldInstance))
701            newValues[start] = 0;
702          else
703            newValues[start] = 1;
704        }
705        else {
706          newValues[start] = m_ActualClassifier.classifyInstance(oldInstance) - oldInstance.classValue();
707        }
708        start++;
709      }
710     
711      // create new instance
712      if (oldInstance instanceof SparseInstance)
713        newInstance = new SparseInstance(oldInstance.weight(), newValues);
714      else
715        newInstance = new DenseInstance(oldInstance.weight(), newValues);
716
717      // copy string/relational values from input to output
718      copyValues(newInstance, false, oldInstance.dataset(), getOutputFormat());
719
720      result.add(newInstance);
721    }
722   
723    return result;
724  }
725 
726  /**
727   * Returns the revision string.
728   *
729   * @return            the revision
730   */
731  public String getRevision() {
732    return RevisionUtils.extract("$Revision: 5987 $");
733  }
734
735  /**
736   * runs the filter with the given arguments
737   *
738   * @param args      the commandline arguments
739   */
740  public static void main(String[] args) {
741    runFilter(new AddClassification(), args);
742  }
743}
Note: See TracBrowser for help on using the repository browser.