source: src/main/java/weka/attributeSelection/OneRAttributeEval.java @ 23

Last change on this file since 23 was 4, checked in by gnappo, 14 years ago

Import di weka.

File size: 12.5 KB
Line 
1/*
2 *    This program is free software; you can redistribute it and/or modify
3 *    it under the terms of the GNU General Public License as published by
4 *    the Free Software Foundation; either version 2 of the License, or
5 *    (at your option) any later version.
6 *
7 *    This program is distributed in the hope that it will be useful,
8 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
9 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
10 *    GNU General Public License for more details.
11 *
12 *    You should have received a copy of the GNU General Public License
13 *    along with this program; if not, write to the Free Software
14 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
15 */
16
17/*
18 *    OneRAttributeEval.java
19 *    Copyright (C) 1999 University of Waikato, Hamilton, New Zealand
20 *
21 */
22
23package weka.attributeSelection;
24
25import weka.classifiers.Classifier;
26import weka.classifiers.AbstractClassifier;
27import weka.classifiers.Evaluation;
28import weka.core.Capabilities;
29import weka.core.Instances;
30import weka.core.Option;
31import weka.core.OptionHandler;
32import weka.core.RevisionUtils;
33import weka.core.Utils;
34import weka.core.Capabilities.Capability;
35import weka.filters.Filter;
36import weka.filters.unsupervised.attribute.Remove;
37
38import java.util.Enumeration;
39import java.util.Random;
40import java.util.Vector;
41
42/**
43 <!-- globalinfo-start -->
44 * OneRAttributeEval :<br/>
45 * <br/>
46 * Evaluates the worth of an attribute by using the OneR classifier.<br/>
47 * <p/>
48 <!-- globalinfo-end -->
49 *
50 <!-- options-start -->
51 * Valid options are: <p/>
52 *
53 * <pre> -S &lt;seed&gt;
54 *  Random number seed for cross validation
55 *  (default = 1)</pre>
56 *
57 * <pre> -F &lt;folds&gt;
58 *  Number of folds for cross validation
59 *  (default = 10)</pre>
60 *
61 * <pre> -D
62 *  Use training data for evaluation rather than cross validaton</pre>
63 *
64 * <pre> -B &lt;minimum bucket size&gt;
65 *  Minimum number of objects in a bucket
66 *  (passed on to OneR, default = 6)</pre>
67 *
68 <!-- options-end -->
69 *
70 * @author Mark Hall (mhall@cs.waikato.ac.nz)
71 * @version $Revision: 5928 $
72 */
73public class OneRAttributeEval
74  extends ASEvaluation
75  implements AttributeEvaluator, OptionHandler {
76 
77  /** for serialization */
78  static final long serialVersionUID = 4386514823886856980L;
79
80  /** The training instances */
81  private Instances m_trainInstances;
82
83  /** The class index */
84  private int m_classIndex;
85
86  /** The number of attributes */
87  private int m_numAttribs;
88
89  /** The number of instances */
90  private int m_numInstances;
91
92  /** Random number seed */
93  private int m_randomSeed;
94
95  /** Number of folds for cross validation */
96  private int m_folds;
97
98  /** Use training data to evaluate merit rather than x-val */
99  private boolean m_evalUsingTrainingData;
100
101  /** Passed on to OneR */
102  private int m_minBucketSize;
103
104
105  /**
106   * Returns a string describing this attribute evaluator
107   * @return a description of the evaluator suitable for
108   * displaying in the explorer/experimenter gui
109   */
110  public String globalInfo() {
111    return "OneRAttributeEval :\n\nEvaluates the worth of an attribute by "
112      +"using the OneR classifier.\n";
113  }
114
115  /**
116   * Returns a string for this option suitable for display in the gui
117   * as a tip text
118   *
119   * @return a string describing this option
120   */
121  public String seedTipText() {
122    return "Set the seed for use in cross validation.";
123  }
124
125  /**
126   * Set the random number seed for cross validation
127   *
128   * @param seed the seed to use
129   */
130  public void setSeed(int seed) {
131    m_randomSeed = seed;
132  }
133
134  /**
135   * Get the random number seed
136   *
137   * @return an <code>int</code> value
138   */
139  public int getSeed() {
140    return m_randomSeed;
141  }
142
143  /**
144   * Returns a string for this option suitable for display in the gui
145   * as a tip text
146   *
147   * @return a string describing this option
148   */
149  public String foldsTipText() {
150    return "Set the number of folds for cross validation.";
151  }
152
153  /**
154   * Set the number of folds to use for cross validation
155   *
156   * @param folds the number of folds
157   */
158  public void setFolds(int folds) {
159    m_folds = folds;
160    if (m_folds < 2) {
161      m_folds = 2;
162    }
163  }
164   
165  /**
166   * Get the number of folds used for cross validation
167   *
168   * @return the number of folds
169   */
170  public int getFolds() {
171    return m_folds;
172  }
173
174  /**
175   * Returns a string for this option suitable for display in the gui
176   * as a tip text
177   *
178   * @return a string describing this option
179   */
180  public String evalUsingTrainingDataTipText() {
181    return "Use the training data to evaluate attributes rather than "
182      + "cross validation.";
183  }
184
185  /**
186   * Use the training data to evaluate attributes rather than cross validation
187   *
188   * @param e true if training data is to be used for evaluation
189   */
190  public void setEvalUsingTrainingData(boolean e) {
191    m_evalUsingTrainingData = e;
192  }
193
194  /**
195   * Returns a string for this option suitable for display in the gui
196   * as a tip text
197   *
198   * @return a string describing this option
199   */
200  public String minimumBucketSizeTipText() {
201    return "The minimum number of objects in a bucket "
202      + "(passed to OneR).";
203  }
204
205  /**
206   * Set the minumum bucket size used by OneR
207   *
208   * @param minB the minimum bucket size to use
209   */
210  public void setMinimumBucketSize(int minB) {
211    m_minBucketSize = minB;
212  }
213
214  /**
215   * Get the minimum bucket size used by oneR
216   *
217   * @return the minimum bucket size used
218   */
219  public int getMinimumBucketSize() {
220    return m_minBucketSize;
221  }
222
223  /**
224   * Returns true if the training data is to be used for evaluation
225   *
226   * @return true if training data is to be used for evaluation
227   */
228  public boolean getEvalUsingTrainingData() {
229    return m_evalUsingTrainingData;
230  }
231
232  /**
233   * Returns an enumeration describing the available options.
234   *
235   * @return an enumeration of all the available options.
236   */
237  public Enumeration listOptions() {
238
239    Vector newVector = new Vector(4);
240
241    newVector.addElement(new Option(
242        "\tRandom number seed for cross validation\n"
243        + "\t(default = 1)",
244        "S", 1, "-S <seed>"));
245
246    newVector.addElement(new Option(
247        "\tNumber of folds for cross validation\n"
248        + "\t(default = 10)",
249        "F", 1, "-F <folds>"));
250
251    newVector.addElement(new Option(
252        "\tUse training data for evaluation rather than cross validaton",
253        "D", 0, "-D"));
254
255    newVector.addElement(new Option(
256        "\tMinimum number of objects in a bucket\n"
257        + "\t(passed on to "
258        +"OneR, default = 6)",
259        "B", 1, "-B <minimum bucket size>"));
260
261    return newVector.elements();
262  }
263
264  /**
265   * Parses a given list of options. <p/>
266   *
267   <!-- options-start -->
268   * Valid options are: <p/>
269   *
270   * <pre> -S &lt;seed&gt;
271   *  Random number seed for cross validation
272   *  (default = 1)</pre>
273   *
274   * <pre> -F &lt;folds&gt;
275   *  Number of folds for cross validation
276   *  (default = 10)</pre>
277   *
278   * <pre> -D
279   *  Use training data for evaluation rather than cross validaton</pre>
280   *
281   * <pre> -B &lt;minimum bucket size&gt;
282   *  Minimum number of objects in a bucket
283   *  (passed on to OneR, default = 6)</pre>
284   *
285   <!-- options-end -->
286   *
287   * @param options the list of options as an array of strings
288   * @throws Exception if an option is not supported
289   */
290  public void setOptions(String [] options) throws Exception {
291    String temp = Utils.getOption('S', options);
292
293    if (temp.length() != 0) {
294      setSeed(Integer.parseInt(temp));
295    }
296   
297    temp = Utils.getOption('F', options);
298    if (temp.length() != 0) {
299      setFolds(Integer.parseInt(temp));
300    }
301
302    temp = Utils.getOption('B', options);
303    if (temp.length() != 0) {
304      setMinimumBucketSize(Integer.parseInt(temp));
305    }
306   
307    setEvalUsingTrainingData(Utils.getFlag('D', options));
308    Utils.checkForRemainingOptions(options);
309  }
310
311  /**
312   * returns the current setup.
313   *
314   * @return the options of the current setup
315   */
316  public String[] getOptions() {
317    String [] options = new String [7];
318    int current = 0;
319   
320    if (getEvalUsingTrainingData()) {
321      options[current++] = "-D";
322    }
323   
324    options[current++] = "-S";
325    options[current++] = "" + getSeed();
326    options[current++] = "-F";
327    options[current++] = "" + getFolds();
328    options[current++] = "-B";
329    options[current++] = "" + getMinimumBucketSize();
330
331    while (current < options.length) {
332      options[current++] = "";
333    }
334    return options;
335  }
336
337  /**
338   * Constructor
339   */
340  public OneRAttributeEval () {
341    resetOptions();
342  }
343
344  /**
345   * Returns the capabilities of this evaluator.
346   *
347   * @return            the capabilities of this evaluator
348   * @see               Capabilities
349   */
350  public Capabilities getCapabilities() {
351    Capabilities result = super.getCapabilities();
352    result.disableAll();
353   
354    // attributes
355    result.enable(Capability.NOMINAL_ATTRIBUTES);
356    result.enable(Capability.NUMERIC_ATTRIBUTES);
357    result.enable(Capability.DATE_ATTRIBUTES);
358    result.enable(Capability.MISSING_VALUES);
359   
360    // class
361    result.enable(Capability.NOMINAL_CLASS);
362    result.enable(Capability.MISSING_CLASS_VALUES);
363   
364    return result;
365  }
366
367  /**
368   * Initializes a OneRAttribute attribute evaluator.
369   * Discretizes all attributes that are numeric.
370   *
371   * @param data set of instances serving as training data
372   * @throws Exception if the evaluator has not been
373   * generated successfully
374   */
375  public void buildEvaluator (Instances data)
376    throws Exception {
377   
378    // can evaluator handle data?
379    getCapabilities().testWithFail(data);
380
381    m_trainInstances = data;
382    m_classIndex = m_trainInstances.classIndex();
383    m_numAttribs = m_trainInstances.numAttributes();
384    m_numInstances = m_trainInstances.numInstances();
385  }
386
387
388  /**
389   * rests to defaults.
390   */
391  protected void resetOptions () {
392    m_trainInstances = null;
393    m_randomSeed = 1;
394    m_folds = 10;
395    m_evalUsingTrainingData = false;
396    m_minBucketSize = 6; // default used by OneR
397  }
398
399
400  /**
401   * evaluates an individual attribute by measuring the amount
402   * of information gained about the class given the attribute.
403   *
404   * @param attribute the index of the attribute to be evaluated
405   * @throws Exception if the attribute could not be evaluated
406   */
407  public double evaluateAttribute (int attribute)
408    throws Exception {
409    int[] featArray = new int[2]; // feat + class
410    double errorRate;
411    Evaluation o_Evaluation;
412    Remove delTransform = new Remove();
413    delTransform.setInvertSelection(true);
414    // copy the instances
415    Instances trainCopy = new Instances(m_trainInstances);
416    featArray[0] = attribute;
417    featArray[1] = trainCopy.classIndex();
418    delTransform.setAttributeIndicesArray(featArray);
419    delTransform.setInputFormat(trainCopy);
420    trainCopy = Filter.useFilter(trainCopy, delTransform);
421    o_Evaluation = new Evaluation(trainCopy);
422    String [] oneROpts = { "-B", ""+getMinimumBucketSize()};
423    Classifier oneR = AbstractClassifier.forName("weka.classifiers.rules.OneR", oneROpts);
424    if (m_evalUsingTrainingData) {
425      oneR.buildClassifier(trainCopy);
426      o_Evaluation.evaluateModel(oneR, trainCopy);
427    } else {
428      /*      o_Evaluation.crossValidateModel("weka.classifiers.rules.OneR",
429              trainCopy, 10,
430              null, new Random(m_randomSeed)); */
431      o_Evaluation.crossValidateModel(oneR, trainCopy, m_folds, new Random(m_randomSeed));
432    }
433    errorRate = o_Evaluation.errorRate();
434    return  (1 - errorRate)*100.0;
435  }
436
437
438  /**
439   * Return a description of the evaluator
440   * @return description as a string
441   */
442  public String toString () {
443    StringBuffer text = new StringBuffer();
444
445    if (m_trainInstances == null) {
446      text.append("\tOneR feature evaluator has not been built yet");
447    }
448    else {
449      text.append("\tOneR feature evaluator.\n\n");
450      text.append("\tUsing ");
451      if (m_evalUsingTrainingData) {
452        text.append("training data for evaluation of attributes.");
453      } else {
454        text.append(""+getFolds()+" fold cross validation for evaluating "
455                    +"attributes.");
456      }
457      text.append("\n\tMinimum bucket size for OneR: "
458                  +getMinimumBucketSize());
459    }
460
461    text.append("\n");
462    return  text.toString();
463  }
464 
465  /**
466   * Returns the revision string.
467   *
468   * @return            the revision
469   */
470  public String getRevision() {
471    return RevisionUtils.extract("$Revision: 5928 $");
472  }
473
474  // ============
475  // Test method.
476  // ============
477  /**
478   * Main method for testing this class.
479   *
480   * @param args the options
481   */
482  public static void main (String[] args) {
483    runEvaluator(new OneRAttributeEval(), args);
484  }
485}
Note: See TracBrowser for help on using the repository browser.