source: src/main/java/weka/classifiers/rules/OneR.java @ 7

Last change on this file since 7 was 4, checked in by gnappo, 14 years ago

Import di weka.

File size: 20.8 KB
RevLine 
/*
 *    This program is free software; you can redistribute it and/or modify
 *    it under the terms of the GNU General Public License as published by
 *    the Free Software Foundation; either version 2 of the License, or
 *    (at your option) any later version.
 *
 *    This program is distributed in the hope that it will be useful,
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *    GNU General Public License for more details.
 *
 *    You should have received a copy of the GNU General Public License
 *    along with this program; if not, write to the Free Software
 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 */

/*
 *    OneR.java
 *    Copyright (C) 1999 University of Waikato, Hamilton, New Zealand
 *
 */
22
23package weka.classifiers.rules;
24
25import weka.classifiers.Classifier;
26import weka.classifiers.AbstractClassifier;
27import weka.classifiers.Sourcable;
28import weka.core.Attribute;
29import weka.core.Capabilities;
30import weka.core.Instance;
31import weka.core.Instances;
32import weka.core.Option;
33import weka.core.RevisionHandler;
34import weka.core.RevisionUtils;
35import weka.core.TechnicalInformation;
36import weka.core.TechnicalInformationHandler;
37import weka.core.Utils;
38import weka.core.WekaException;
39import weka.core.Capabilities.Capability;
40import weka.core.TechnicalInformation.Field;
41import weka.core.TechnicalInformation.Type;
42
43import java.io.Serializable;
44import java.util.Enumeration;
45import java.util.Vector;
46
47/**
48 <!-- globalinfo-start -->
49 * Class for building and using a 1R classifier; in other words, uses the minimum-error attribute for prediction, discretizing numeric attributes. For more information, see:<br/>
50 * <br/>
51 * R.C. Holte (1993). Very simple classification rules perform well on most commonly used datasets. Machine Learning. 11:63-91.
52 * <p/>
53 <!-- globalinfo-end -->
54 *
55 <!-- technical-bibtex-start -->
56 * BibTeX:
57 * <pre>
58 * &#64;article{Holte1993,
59 *    author = {R.C. Holte},
60 *    journal = {Machine Learning},
61 *    pages = {63-91},
62 *    title = {Very simple classification rules perform well on most commonly used datasets},
63 *    volume = {11},
64 *    year = {1993}
65 * }
66 * </pre>
67 * <p/>
68 <!-- technical-bibtex-end -->
69 *
70 <!-- options-start -->
71 * Valid options are: <p/>
72 *
73 * <pre> -B &lt;minimum bucket size&gt;
74 *  The minimum number of objects in a bucket (default: 6).</pre>
75 *
76 <!-- options-end -->
77 *
78 * @author Ian H. Witten (ihw@cs.waikato.ac.nz)
79 * @version $Revision: 5928 $
80*/
81public class OneR 
82  extends AbstractClassifier
83  implements TechnicalInformationHandler, Sourcable {
84   
85  /** for serialization */
86  static final long serialVersionUID = -2459427002147861445L;
87 
88  /**
89   * Returns a string describing classifier
90   * @return a description suitable for
91   * displaying in the explorer/experimenter gui
92   */
93  public String globalInfo() {
94
95    return "Class for building and using a 1R classifier; in other words, uses "
96      + "the minimum-error attribute for prediction, discretizing numeric "
97      + "attributes. For more information, see:\n\n"
98      + getTechnicalInformation().toString();
99  }
100
101  /**
102   * Returns an instance of a TechnicalInformation object, containing
103   * detailed information about the technical background of this class,
104   * e.g., paper reference or book this class is based on.
105   *
106   * @return the technical information about this class
107   */
108  public TechnicalInformation getTechnicalInformation() {
109    TechnicalInformation        result;
110   
111    result = new TechnicalInformation(Type.ARTICLE);
112    result.setValue(Field.AUTHOR, "R.C. Holte");
113    result.setValue(Field.YEAR, "1993");
114    result.setValue(Field.TITLE, "Very simple classification rules perform well on most commonly used datasets");
115    result.setValue(Field.JOURNAL, "Machine Learning");
116    result.setValue(Field.VOLUME, "11");
117    result.setValue(Field.PAGES, "63-91");
118   
119    return result;
120  }
121
122  /**
123   * Class for storing store a 1R rule.
124   */
125  private class OneRRule 
126    implements Serializable, RevisionHandler {
127   
128    /** for serialization */
129    static final long serialVersionUID = 1152814630957092281L;
130
131    /** The class attribute. */
132    private Attribute m_class;
133
134    /** The number of instances used for building the rule. */
135    private int m_numInst;
136
137    /** Attribute to test */
138    private Attribute m_attr; 
139
140    /** Training set examples this rule gets right */
141    private int m_correct; 
142
143    /** Predicted class for each value of attr */
144    private int[] m_classifications; 
145
146    /** Predicted class for missing values */
147    private int m_missingValueClass = -1; 
148
149    /** Breakpoints (numeric attributes only) */
150    private double[] m_breakpoints; 
151 
152    /**
153     * Constructor for nominal attribute.
154     *
155     * @param data the data to work with
156     * @param attribute the attribute to use
157     * @throws Exception if something goes wrong
158     */
159    public OneRRule(Instances data, Attribute attribute) throws Exception {
160
161      m_class = data.classAttribute();
162      m_numInst = data.numInstances();
163      m_attr = attribute;
164      m_correct = 0;
165      m_classifications = new int[m_attr.numValues()];
166    }
167
168    /**
169     * Constructor for numeric attribute.
170     *
171     * @param data the data to work with
172     * @param attribute the attribute to use
173     * @param nBreaks the break point
174     * @throws Exception if something goes wrong
175     */
176    public OneRRule(Instances data, Attribute attribute, int nBreaks) throws Exception {
177
178      m_class = data.classAttribute();
179      m_numInst = data.numInstances();
180      m_attr = attribute;
181      m_correct = 0;
182      m_classifications = new int[nBreaks];
183      m_breakpoints = new double[nBreaks - 1]; // last breakpoint is infinity
184    }
185   
186    /**
187     * Returns a description of the rule.
188     *
189     * @return a string representation of the rule
190     */
191    public String toString() {
192
193      try {
194        StringBuffer text = new StringBuffer();
195        text.append(m_attr.name() + ":\n");
196        for (int v = 0; v < m_classifications.length; v++) {
197          text.append("\t");
198          if (m_attr.isNominal()) {
199            text.append(m_attr.value(v));
200          } else if (v < m_breakpoints.length) {
201            text.append("< " + m_breakpoints[v]);
202          } else if (v > 0) {
203            text.append(">= " + m_breakpoints[v - 1]);
204          } else {
205            text.append("not ?");
206          }
207          text.append("\t-> " + m_class.value(m_classifications[v]) + "\n");
208        }
209        if (m_missingValueClass != -1) {
210          text.append("\t?\t-> " + m_class.value(m_missingValueClass) + "\n");
211        }
212        text.append("(" + m_correct + "/" + m_numInst + " instances correct)\n");
213        return text.toString();
214      } catch (Exception e) {
215        return "Can't print OneR classifier!";
216      }
217    }
218   
219    /**
220     * Returns the revision string.
221     *
222     * @return          the revision
223     */
224    public String getRevision() {
225      return RevisionUtils.extract("$Revision: 5928 $");
226    }
227  }
228 
229  /** A 1-R rule */
230  private OneRRule m_rule;
231
232  /** The minimum bucket size */
233  private int m_minBucketSize = 6;
234
235  /** a ZeroR model in case no model can be built from the data */
236  private Classifier m_ZeroR;
237   
238  /**
239   * Classifies a given instance.
240   *
241   * @param inst the instance to be classified
242   * @return the classification of the instance
243   */
244  public double classifyInstance(Instance inst) throws Exception {
245
246    // default model?
247    if (m_ZeroR != null) {
248      return m_ZeroR.classifyInstance(inst);
249    }
250   
251    int v = 0;
252    if (inst.isMissing(m_rule.m_attr)) {
253      if (m_rule.m_missingValueClass != -1) {
254        return m_rule.m_missingValueClass;
255      } else {
256        return 0;  // missing values occur in test but not training set   
257      }
258    }
259    if (m_rule.m_attr.isNominal()) {
260      v = (int) inst.value(m_rule.m_attr);
261    } else {
262      while (v < m_rule.m_breakpoints.length &&
263             inst.value(m_rule.m_attr) >= m_rule.m_breakpoints[v]) {
264        v++;
265      }
266    }
267    return m_rule.m_classifications[v];
268  }
269
270  /**
271   * Returns default capabilities of the classifier.
272   *
273   * @return      the capabilities of this classifier
274   */
275  public Capabilities getCapabilities() {
276    Capabilities result = super.getCapabilities();
277    result.disableAll();
278
279    // attributes
280    result.enable(Capability.NOMINAL_ATTRIBUTES);
281    result.enable(Capability.NUMERIC_ATTRIBUTES);
282    result.enable(Capability.DATE_ATTRIBUTES);
283    result.enable(Capability.MISSING_VALUES);
284
285    // class
286    result.enable(Capability.NOMINAL_CLASS);
287    result.enable(Capability.MISSING_CLASS_VALUES);
288
289    return result;
290  }
291
292  /**
293   * Generates the classifier.
294   *
295   * @param instances the instances to be used for building the classifier
296   * @throws Exception if the classifier can't be built successfully
297   */
298  public void buildClassifier(Instances instances) 
299    throws Exception {
300   
301    boolean noRule = true;
302
303    // can classifier handle the data?
304    getCapabilities().testWithFail(instances);
305
306    // remove instances with missing class
307    Instances data = new Instances(instances);
308    data.deleteWithMissingClass();
309
310    // only class? -> build ZeroR model
311    if (data.numAttributes() == 1) {
312      System.err.println(
313          "Cannot build model (only class attribute present in data!), "
314          + "using ZeroR model instead!");
315      m_ZeroR = new weka.classifiers.rules.ZeroR();
316      m_ZeroR.buildClassifier(data);
317      return;
318    }
319    else {
320      m_ZeroR = null;
321    }
322   
323    // for each attribute ...
324    Enumeration enu = instances.enumerateAttributes();
325    while (enu.hasMoreElements()) {
326      try {
327        OneRRule r = newRule((Attribute) enu.nextElement(), data);
328
329        // if this attribute is the best so far, replace the rule
330        if (noRule || r.m_correct > m_rule.m_correct) {
331          m_rule = r;
332        }
333        noRule = false;
334      } catch (Exception ex) {
335      }
336    }
337   
338    if (noRule)
339      throw new WekaException("No attributes found to work with!");
340  }
341
342  /**
343   * Create a rule branching on this attribute.
344   *
345   * @param attr the attribute to branch on
346   * @param data the data to be used for creating the rule
347   * @return the generated rule
348   * @throws Exception if the rule can't be built successfully
349   */
350  public OneRRule newRule(Attribute attr, Instances data) throws Exception {
351
352    OneRRule r;
353
354    // ... create array to hold the missing value counts
355    int[] missingValueCounts =
356      new int [data.classAttribute().numValues()];
357   
358    if (attr.isNominal()) {
359      r = newNominalRule(attr, data, missingValueCounts);
360    } else {
361      r = newNumericRule(attr, data, missingValueCounts);
362    }
363    r.m_missingValueClass = Utils.maxIndex(missingValueCounts);
364    if (missingValueCounts[r.m_missingValueClass] == 0) {
365      r.m_missingValueClass = -1; // signal for no missing value class
366    } else {
367      r.m_correct += missingValueCounts[r.m_missingValueClass];
368    }
369    return r;
370  }
371
372  /**
373   * Create a rule branching on this nominal attribute.
374   *
375   * @param attr the attribute to branch on
376   * @param data the data to be used for creating the rule
377   * @param missingValueCounts to be filled in
378   * @return the generated rule
379   * @throws Exception if the rule can't be built successfully
380   */
381  public OneRRule newNominalRule(Attribute attr, Instances data,
382                                 int[] missingValueCounts) throws Exception {
383
384    // ... create arrays to hold the counts
385    int[][] counts = new int [attr.numValues()]
386                             [data.classAttribute().numValues()];
387     
388    // ... calculate the counts
389    Enumeration enu = data.enumerateInstances();
390    while (enu.hasMoreElements()) {
391      Instance i = (Instance) enu.nextElement();
392      if (i.isMissing(attr)) {
393        missingValueCounts[(int) i.classValue()]++; 
394      } else {
395        counts[(int) i.value(attr)][(int) i.classValue()]++;
396      }
397    }
398
399    OneRRule r = new OneRRule(data, attr); // create a new rule
400    for (int value = 0; value < attr.numValues(); value++) {
401      int best = Utils.maxIndex(counts[value]);
402      r.m_classifications[value] = best;
403      r.m_correct += counts[value][best];
404    }
405    return r;
406  }
407
408  /**
409   * Create a rule branching on this numeric attribute
410   *
411   * @param attr the attribute to branch on
412   * @param data the data to be used for creating the rule
413   * @param missingValueCounts to be filled in
414   * @return the generated rule
415   * @throws Exception if the rule can't be built successfully
416   */
417  public OneRRule newNumericRule(Attribute attr, Instances data,
418                             int[] missingValueCounts) throws Exception {
419
420
421    // ... can't be more than numInstances buckets
422    int [] classifications = new int[data.numInstances()];
423    double [] breakpoints = new double[data.numInstances()];
424
425    // create array to hold the counts
426    int [] counts = new int[data.classAttribute().numValues()];
427    int correct = 0;
428    int lastInstance = data.numInstances();
429
430    // missing values get sorted to the end of the instances
431    data.sort(attr);
432    while (lastInstance > 0 && 
433           data.instance(lastInstance-1).isMissing(attr)) {
434      lastInstance--;
435      missingValueCounts[(int) data.instance(lastInstance).
436                         classValue()]++; 
437    }
438    int i = 0; 
439    int cl = 0; // index of next bucket to create
440    int it;
441    while (i < lastInstance) { // start a new bucket
442      for (int j = 0; j < counts.length; j++) counts[j] = 0;
443      do { // fill it until it has enough of the majority class
444        it = (int) data.instance(i++).classValue();
445        counts[it]++;
446      } while (counts[it] < m_minBucketSize && i < lastInstance);
447
448      // while class remains the same, keep on filling
449      while (i < lastInstance && 
450             (int) data.instance(i).classValue() == it) { 
451        counts[it]++; 
452        i++;
453      }
454      while (i < lastInstance && // keep on while attr value is the same
455             (data.instance(i - 1).value(attr) 
456              == data.instance(i).value(attr))) {
457        counts[(int) data.instance(i++).classValue()]++;
458      }
459      for (int j = 0; j < counts.length; j++) {
460        if (counts[j] > counts[it]) { 
461          it = j;
462        }
463      }
464      if (cl > 0) { // can we coalesce with previous class?
465        if (counts[classifications[cl - 1]] == counts[it]) {
466          it = classifications[cl - 1];
467        }
468        if (it == classifications[cl - 1]) {
469          cl--; // yes!
470        }
471      }
472      correct += counts[it];
473      classifications[cl] = it;
474      if (i < lastInstance) {
475        breakpoints[cl] = (data.instance(i - 1).value(attr)
476                           + data.instance(i).value(attr)) / 2;
477      }
478      cl++;
479    }
480    if (cl == 0) {
481      throw new Exception("Only missing values in the training data!");
482    }
483    OneRRule r = new OneRRule(data, attr, cl); // new rule with cl branches
484    r.m_correct = correct;
485    for (int v = 0; v < cl; v++) {
486      r.m_classifications[v] = classifications[v];
487      if (v < cl-1) {
488        r.m_breakpoints[v] = breakpoints[v];
489      }
490    }
491
492    return r;
493  }
494
495  /**
496   * Returns an enumeration describing the available options..
497   *
498   * @return an enumeration of all the available options.
499   */
500  public Enumeration listOptions() {
501
502    String string = "\tThe minimum number of objects in a bucket (default: 6).";
503
504    Vector newVector = new Vector(1);
505
506    newVector.addElement(new Option(string, "B", 1, 
507                                    "-B <minimum bucket size>"));
508
509    return newVector.elements();
510  }
511
512  /**
513   * Parses a given list of options. <p/>
514   *
515   <!-- options-start -->
516   * Valid options are: <p/>
517   *
518   * <pre> -B &lt;minimum bucket size&gt;
519   *  The minimum number of objects in a bucket (default: 6).</pre>
520   *
521   <!-- options-end -->
522   *
523   * @param options the list of options as an array of strings
524   * @throws Exception if an option is not supported
525   */
526  public void setOptions(String[] options) throws Exception {
527   
528    String bucketSizeString = Utils.getOption('B', options);
529    if (bucketSizeString.length() != 0) {
530      m_minBucketSize = Integer.parseInt(bucketSizeString);
531    } else {
532      m_minBucketSize = 6;
533    }
534  }
535
536  /**
537   * Gets the current settings of the OneR classifier.
538   *
539   * @return an array of strings suitable for passing to setOptions
540   */
541  public String [] getOptions() {
542
543    String [] options = new String [2];
544    int current = 0;
545
546    options[current++] = "-B"; options[current++] = "" + m_minBucketSize;
547
548    while (current < options.length) {
549      options[current++] = "";
550    }
551    return options;
552  }
553
554  /**
555   * Returns a string that describes the classifier as source. The
556   * classifier will be contained in a class with the given name (there may
557   * be auxiliary classes),
558   * and will contain a method with the signature:
559   * <pre><code>
560   * public static double classify(Object[] i);
561   * </code></pre>
562   * where the array <code>i</code> contains elements that are either
563   * Double, String, with missing values represented as null. The generated
564   * code is public domain and comes with no warranty.
565   *
566   * @param className the name that should be given to the source class.
567   * @return the object source described by a string
568   * @throws Exception if the souce can't be computed
569   */
570  public String toSource(String className) throws Exception {
571    StringBuffer        result;
572    int                 i;
573   
574    result = new StringBuffer();
575   
576    if (m_ZeroR != null) {
577      result.append(((ZeroR) m_ZeroR).toSource(className));
578    }
579    else {
580      result.append("class " + className + " {\n");
581      result.append("  public static double classify(Object[] i) {\n");
582      result.append("    // chosen attribute: " + m_rule.m_attr.name() + " (" + m_rule.m_attr.index() + ")\n");
583      result.append("\n");
584      // missing values
585      result.append("    // missing value?\n");
586      result.append("    if (i[" + m_rule.m_attr.index() + "] == null)\n");
587      if (m_rule.m_missingValueClass != -1)
588        result.append("      return Double.NaN;\n");
589      else
590        result.append("      return 0;\n");
591      result.append("\n");
592     
593      // actual prediction
594      result.append("    // prediction\n");
595      result.append("    double v = 0;\n");
596      result.append("    double[] classifications = new double[]{" + Utils.arrayToString(m_rule.m_classifications) + "};");
597      result.append(" // ");
598      for (i = 0; i < m_rule.m_classifications.length; i++) {
599        if (i > 0)
600          result.append(", ");
601        result.append(m_rule.m_class.value(m_rule.m_classifications[i]));
602      }
603      result.append("\n");
604      if (m_rule.m_attr.isNominal()) {
605        for (i = 0; i < m_rule.m_attr.numValues(); i++) {
606          result.append("    ");
607          if (i > 0)
608            result.append("else ");
609          result.append("if (((String) i[" + m_rule.m_attr.index() + "]).equals(\"" + m_rule.m_attr.value(i) + "\"))\n");
610          result.append("      v = " + i + "; // " + m_rule.m_class.value(m_rule.m_classifications[i]) + "\n");
611        }
612      }
613      else {
614        result.append("    double[] breakpoints = new double[]{" + Utils.arrayToString(m_rule.m_breakpoints) + "};\n");
615        result.append("    while (v < breakpoints.length && \n");
616        result.append("           ((Double) i[" + m_rule.m_attr.index() + "]) >= breakpoints[(int) v]) {\n");
617        result.append("      v++;\n");
618        result.append("    }\n");
619      }
620      result.append("    return classifications[(int) v];\n");
621     
622      result.append("  }\n");
623      result.append("}\n");
624    }
625   
626    return result.toString();
627  }
628
629  /**
630   * Returns a description of the classifier
631   *
632   * @return a string representation of the classifier
633   */
634  public String toString() {
635
636    // only ZeroR model?
637    if (m_ZeroR != null) {
638      StringBuffer buf = new StringBuffer();
639      buf.append(this.getClass().getName().replaceAll(".*\\.", "") + "\n");
640      buf.append(this.getClass().getName().replaceAll(".*\\.", "").replaceAll(".", "=") + "\n\n");
641      buf.append("Warning: No model could be built, hence ZeroR model is used:\n\n");
642      buf.append(m_ZeroR.toString());
643      return buf.toString();
644    }
645   
646    if (m_rule == null) {
647      return "OneR: No model built yet.";
648    }
649    return m_rule.toString();
650  }
651
652  /**
653   * Returns the tip text for this property
654   * @return tip text for this property suitable for
655   * displaying in the explorer/experimenter gui
656   */
657  public String minBucketSizeTipText() {
658    return "The minimum bucket size used for discretizing numeric "
659      + "attributes.";
660  }
661 
662  /**
663   * Get the value of minBucketSize.
664   * @return Value of minBucketSize.
665   */
666  public int getMinBucketSize() {
667   
668    return m_minBucketSize;
669  }
670 
671  /**
672   * Set the value of minBucketSize.
673   * @param v  Value to assign to minBucketSize.
674   */
675  public void setMinBucketSize(int v) {
676   
677    m_minBucketSize = v;
678  }
679 
680  /**
681   * Returns the revision string.
682   *
683   * @return            the revision
684   */
685  public String getRevision() {
686    return RevisionUtils.extract("$Revision: 5928 $");
687  }
688 
689  /**
690   * Main method for testing this class
691   *
692   * @param argv the commandline options
693   */
694  public static void main(String [] argv) {
695    runClassifier(new OneR(), argv);
696  }
697}
Note: See TracBrowser for help on using the repository browser.