source: src/main/java/weka/attributeSelection/ClassifierAttributeEval.java @ 9

Last change on this file since 9 was 4, checked in by gnappo, 14 years ago

Weka import.

File size: 12.1 KB
/*
 *    This program is free software; you can redistribute it and/or modify
 *    it under the terms of the GNU General Public License as published by
 *    the Free Software Foundation; either version 2 of the License, or
 *    (at your option) any later version.
 *
 *    This program is distributed in the hope that it will be useful,
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *    GNU General Public License for more details.
 *
 *    You should have received a copy of the GNU General Public License
 *    along with this program; if not, write to the Free Software
 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 */

/*
 *    ClassifierAttributeEval.java
 *    Copyright (C) 2009 University of Waikato, Hamilton, New Zealand
 *
 */

package weka.attributeSelection;

import weka.classifiers.Classifier;
import weka.classifiers.AbstractClassifier;
import weka.classifiers.Evaluation;
import weka.classifiers.rules.OneR;
import weka.core.Capabilities;
import weka.core.Instances;
import weka.core.Option;
import weka.core.OptionHandler;
import weka.core.RevisionUtils;
import weka.core.Utils;
import weka.filters.Filter;
import weka.filters.unsupervised.attribute.Remove;

import java.util.Enumeration;
import java.util.Random;
import java.util.Vector;

/**
 <!-- globalinfo-start -->
 * ClassifierAttributeEval :<br/>
 * <br/>
 * Evaluates the worth of an attribute by using a user-specified classifier.<br/>
 * <p/>
 <!-- globalinfo-end -->
 *
 <!-- options-start -->
 * Valid options are: <p/>
 *
 * <pre> -S &lt;seed&gt;
 *  Random number seed for cross validation.
 *  (default = 1)</pre>
 *
 * <pre> -F &lt;folds&gt;
 *  Number of folds for cross validation.
 *  (default = 10)</pre>
 *
 * <pre> -D
 *  Use training data for evaluation rather than cross validation.</pre>
 *
 * <pre> -B &lt;classname + options&gt;
 *  Classifier to use.
 *  (default = OneR)</pre>
 *
 <!-- options-end -->
 *
 * @author Mark Hall (mhall@cs.waikato.ac.nz)
 * @author FracPete (fracpete at waikato dot ac dot nz)
 * @version $Revision: 5928 $
 */
public class ClassifierAttributeEval
  extends ASEvaluation
  implements AttributeEvaluator, OptionHandler {

  /** for serialization. */
  private static final long serialVersionUID = 2442390690522602284L;

  /** The training instances. */
  protected Instances m_trainInstances;

  /** Random number seed. */
  protected int m_randomSeed;

  /** Number of folds for cross validation. */
  protected int m_folds;

  /** Use training data to evaluate merit rather than x-val. */
  protected boolean m_evalUsingTrainingData;

  /** The classifier to use for evaluating the attribute. */
  protected Classifier m_Classifier;

  /**
   * Constructor.
   */
  public ClassifierAttributeEval () {
    resetOptions();
  }

  /**
   * Returns a string describing this attribute evaluator.
   *
   * @return            a description of the evaluator suitable for
   *                    displaying in the explorer/experimenter gui
   */
  public String globalInfo() {
    return "ClassifierAttributeEval :\n\nEvaluates the worth of an attribute by "
      + "using a user-specified classifier.\n";
  }

  /**
   * Returns an enumeration describing the available options.
   *
   * @return            an enumeration of all the available options.
   */
  public Enumeration listOptions() {
    Vector result = new Vector();

    result.addElement(new Option(
        "\tRandom number seed for cross validation.\n"
        + "\t(default = 1)",
        "S", 1, "-S <seed>"));

    result.addElement(new Option(
        "\tNumber of folds for cross validation.\n"
        + "\t(default = 10)",
        "F", 1, "-F <folds>"));

    result.addElement(new Option(
        "\tUse training data for evaluation rather than cross validation.",
        "D", 0, "-D"));

    result.addElement(new Option(
        "\tClassifier to use.\n"
        + "\t(default = OneR)",
        "B", 1, "-B <classname + options>"));

    return result.elements();
  }

  /**
   * Parses a given list of options. <p/>
   *
   <!-- options-start -->
   * Valid options are: <p/>
   *
   * <pre> -S &lt;seed&gt;
   *  Random number seed for cross validation.
   *  (default = 1)</pre>
   *
   * <pre> -F &lt;folds&gt;
   *  Number of folds for cross validation.
   *  (default = 10)</pre>
   *
   * <pre> -D
   *  Use training data for evaluation rather than cross validation.</pre>
   *
   * <pre> -B &lt;classname + options&gt;
   *  Classifier to use.
   *  (default = OneR)</pre>
   *
   <!-- options-end -->
   *
   * @param options the list of options as an array of strings
   * @throws Exception if an option is not supported
   */
  public void setOptions(String [] options) throws Exception {
    String      tmpStr;
    String[]    tmpOptions;

    tmpStr = Utils.getOption('S', options);
    if (tmpStr.length() != 0)
      setSeed(Integer.parseInt(tmpStr));

    tmpStr = Utils.getOption('F', options);
    if (tmpStr.length() != 0)
      setFolds(Integer.parseInt(tmpStr));

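    // The -B value bundles the classifier class name and its own options in one
    // quoted string, e.g. -B "weka.classifiers.trees.J48 -C 0.25" (illustrative
    // value only): splitOptions() tokenizes it, the first token names the class
    // and the remaining tokens are passed on to that classifier.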
    tmpStr = Utils.getOption('B', options);
    if (tmpStr.length() != 0) {
      tmpOptions    = Utils.splitOptions(tmpStr);
      tmpStr        = tmpOptions[0];
      tmpOptions[0] = "";
      setClassifier((Classifier) Utils.forName(Classifier.class, tmpStr, tmpOptions));
    }

    setEvalUsingTrainingData(Utils.getFlag('D', options));
    Utils.checkForRemainingOptions(options);
  }

  /**
   * Returns the current setup.
   *
   * @return the options of the current setup
   */
  public String[] getOptions() {
    Vector<String>      result;

    result = new Vector<String>();

    if (getEvalUsingTrainingData())
      result.add("-D");

    result.add("-S");
    result.add("" + getSeed());

    result.add("-F");
    result.add("" + getFolds());

    result.add("-B");
    result.add(
        new String(
            m_Classifier.getClass().getName() + " "
            + Utils.joinOptions(((OptionHandler) m_Classifier).getOptions())).trim());

    return result.toArray(new String[result.size()]);
  }

  /**
   * Set the random number seed for cross validation.
   *
   * @param value       the seed to use
   */
  public void setSeed(int value) {
    m_randomSeed = value;
  }

  /**
   * Get the random number seed.
   *
   * @return            an <code>int</code> value
   */
  public int getSeed() {
    return m_randomSeed;
  }

  /**
   * Returns a string for this option suitable for display in the gui
   * as a tip text.
   *
   * @return            a string describing this option
   */
  public String seedTipText() {
    return "Set the seed for use in cross validation.";
  }

  /**
   * Set the number of folds to use for cross validation.
   *
   * @param value       the number of folds
   */
  public void setFolds(int value) {
    m_folds = value;
    if (m_folds < 2)
      m_folds = 2;
  }

  /**
   * Get the number of folds used for cross validation.
   *
   * @return            the number of folds
   */
  public int getFolds() {
    return m_folds;
  }

  /**
   * Returns a string for this option suitable for display in the gui
   * as a tip text.
   *
   * @return            a string describing this option
   */
  public String foldsTipText() {
    return "Set the number of folds for cross validation.";
  }

  /**
   * Use the training data to evaluate attributes rather than cross validation.
   *
   * @param value       true if training data is to be used for evaluation
   */
  public void setEvalUsingTrainingData(boolean value) {
    m_evalUsingTrainingData = value;
  }

  /**
   * Returns true if the training data is to be used for evaluation.
   *
   * @return            true if training data is to be used for evaluation
   */
  public boolean getEvalUsingTrainingData() {
    return m_evalUsingTrainingData;
  }

  /**
   * Returns a string for this option suitable for display in the gui
   * as a tip text.
   *
   * @return            a string describing this option
   */
  public String evalUsingTrainingDataTipText() {
    return "Use the training data to evaluate attributes rather than "
      + "cross validation.";
  }

  /**
   * Set the classifier to use for evaluating the attribute.
   *
   * @param value       the classifier to use
   */
  public void setClassifier(Classifier value) {
    m_Classifier = value;
  }

  /**
   * Returns the classifier to use for evaluating the attribute.
   *
   * @return            the classifier in use
   */
  public Classifier getClassifier() {
    return m_Classifier;
  }

  /**
   * Returns a string for this option suitable for display in the gui
   * as a tip text.
   *
   * @return            a string describing this option
   */
  public String classifierTipText() {
    return "The classifier to use for evaluating the attribute.";
  }

  /**
   * Returns the capabilities of this evaluator.
   *
   * @return            the capabilities of this evaluator
   * @see               Capabilities
   */
  public Capabilities getCapabilities() {
    Capabilities        result;

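    // delegate to the wrapped classifier's capabilities; without a classifier
    // there is nothing this evaluator can handle, so everything stays disabled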
    if (m_Classifier != null) {
      result = m_Classifier.getCapabilities();
      result.setOwner(this);
    }
    else {
      result = super.getCapabilities();
      result.disableAll();
    }

    return result;
  }

  /**
   * Initializes a ClassifierAttributeEval attribute evaluator.
   *
   * @param data        set of instances serving as training data
   * @throws Exception  if the evaluator has not been generated successfully
   */
  public void buildEvaluator (Instances data) throws Exception {
    // can evaluator handle data?
    getCapabilities().testWithFail(data);

    m_trainInstances = data;
  }


  /**
   * Resets to defaults.
   */
  protected void resetOptions () {
    m_trainInstances        = null;
    m_randomSeed            = 1;
    m_folds                 = 10;
    m_evalUsingTrainingData = false;
    m_Classifier            = new OneR();
  }


  /**
   * Evaluates an individual attribute by building the user-specified
   * classifier on just that attribute (plus the class) and measuring
   * the classifier's accuracy.
   *
   * @param attribute   the index of the attribute to be evaluated
   * @return            the evaluation
   * @throws Exception  if the attribute could not be evaluated
   */
  public double evaluateAttribute(int attribute) throws Exception {
    int[]       featArray;
    double      errorRate;
    Evaluation  eval;
    Remove      delTransform;
    Instances   train;
    Classifier  cls;

    // create tmp dataset
    featArray    = new int[2]; // feat + class
    delTransform = new Remove();
    delTransform.setInvertSelection(true);
    train        = new Instances(m_trainInstances);
    featArray[0] = attribute;
    featArray[1] = train.classIndex();
    delTransform.setAttributeIndicesArray(featArray);
    delTransform.setInputFormat(train);
    train = Filter.useFilter(train, delTransform);
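    // because setInvertSelection(true) was used, the Remove filter keeps (rather
    // than deletes) the listed indices, so 'train' now holds only the evaluated
    // attribute and the class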

    // evaluate classifier
    eval = new Evaluation(train);
    cls  = AbstractClassifier.makeCopy(m_Classifier);
    if (m_evalUsingTrainingData) {
      cls.buildClassifier(train);
      eval.evaluateModel(cls, train);
    }
    else {
      eval.crossValidateModel(cls, train, m_folds, new Random(m_randomSeed));
    }
    errorRate = eval.errorRate();

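    // merit = classification accuracy as a percentage (higher is better)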
    return (1 - errorRate) * 100.0;
  }

  /**
   * Return a description of the evaluator.
   *
   * @return            description as a string
   */
  public String toString () {
    StringBuffer text = new StringBuffer();

    if (m_trainInstances == null) {
      text.append("\tClassifier feature evaluator has not been built yet");
    }
    else {
      text.append("\tClassifier feature evaluator.\n\n");
      text.append("\tUsing ");
      if (m_evalUsingTrainingData)
        text.append("training data for evaluation of attributes.\n");
      else
        text.append(getFolds() + " fold cross validation for evaluating attributes.\n");
      text.append("\tClassifier in use: " + m_Classifier.getClass().getName() + " "
        + Utils.joinOptions(((OptionHandler) m_Classifier).getOptions()));
    }
    text.append("\n");

    return text.toString();
  }

  /**
   * Returns the revision string.
   *
   * @return            the revision
   */
  public String getRevision() {
    return RevisionUtils.extract("$Revision: 5928 $");
  }

  /**
   * Main method for executing this class.
   *
   * @param args        the options
   */
  public static void main (String[] args) {
    runEvaluator(new ClassifierAttributeEval(), args);
  }
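
  // Minimal programmatic usage sketch (assumes a weka.core.Instances object
  // named "data" has already been loaded and its class index set; the variable
  // name is illustrative only):
  //
  //   ClassifierAttributeEval eval = new ClassifierAttributeEval();
  //   eval.setClassifier(new weka.classifiers.rules.OneR()); // or any classifier
  //   eval.setFolds(10);
  //   eval.buildEvaluator(data);
  //   double merit = eval.evaluateAttribute(0); // % accuracy using only attribute 0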
}