source: src/main/java/weka/classifiers/meta/nestedDichotomies/ND.java @ 11

Last change on this file since 11 was 4, checked in by gnappo, 14 years ago

Import di weka.

File size: 17.5 KB
Line 
1/*
2 *    This program is free software; you can redistribute it and/or modify
3 *    it under the terms of the GNU General Public License as published by
4 *    the Free Software Foundation; either version 2 of the License, or
5 *    (at your option) any later version.
6 *
7 *    This program is distributed in the hope that it will be useful,
8 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
9 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
10 *    GNU General Public License for more details.
11 *
12 *    You should have received a copy of the GNU General Public License
13 *    along with this program; if not, write to the Free Software
14 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
15 */
16
17/*
18 *    ND.java
19 *    Copyright (C) 2003-2005 University of Waikato, Hamilton, New Zealand
20 *
21 */
22
23package weka.classifiers.meta.nestedDichotomies;
24
25import weka.classifiers.Classifier;
26import weka.classifiers.AbstractClassifier;
27import weka.classifiers.RandomizableSingleClassifierEnhancer;
28import weka.classifiers.meta.FilteredClassifier;
29import weka.classifiers.rules.ZeroR;
30import weka.core.Capabilities;
31import weka.core.FastVector;
32import weka.core.Instance;
33import weka.core.Instances;
34import weka.core.RevisionHandler;
35import weka.core.RevisionUtils;
36import weka.core.TechnicalInformation;
37import weka.core.TechnicalInformationHandler;
38import weka.core.Capabilities.Capability;
39import weka.core.TechnicalInformation.Field;
40import weka.core.TechnicalInformation.Type;
41import weka.filters.Filter;
42import weka.filters.unsupervised.attribute.MakeIndicator;
43import weka.filters.unsupervised.instance.RemoveWithValues;
44
45import java.io.Serializable;
46import java.util.Hashtable;
47import java.util.Random;
48
49/**
50 <!-- globalinfo-start -->
51 * A meta classifier for handling multi-class datasets with 2-class classifiers by building a random tree structure.<br/>
52 * <br/>
53 * For more info, check<br/>
54 * <br/>
55 * Lin Dong, Eibe Frank, Stefan Kramer: Ensembles of Balanced Nested Dichotomies for Multi-class Problems. In: PKDD, 84-95, 2005.<br/>
56 * <br/>
57 * Eibe Frank, Stefan Kramer: Ensembles of nested dichotomies for multi-class problems. In: Twenty-first International Conference on Machine Learning, 2004.
58 * <p/>
59 <!-- globalinfo-end -->
60 *
61 <!-- technical-bibtex-start -->
62 * BibTeX:
63 * <pre>
64 * &#64;inproceedings{Dong2005,
65 *    author = {Lin Dong and Eibe Frank and Stefan Kramer},
66 *    booktitle = {PKDD},
67 *    pages = {84-95},
68 *    publisher = {Springer},
69 *    title = {Ensembles of Balanced Nested Dichotomies for Multi-class Problems},
70 *    year = {2005}
71 * }
72 *
73 * &#64;inproceedings{Frank2004,
74 *    author = {Eibe Frank and Stefan Kramer},
75 *    booktitle = {Twenty-first International Conference on Machine Learning},
76 *    publisher = {ACM},
77 *    title = {Ensembles of nested dichotomies for multi-class problems},
78 *    year = {2004}
79 * }
80 * </pre>
81 * <p/>
82 <!-- technical-bibtex-end -->
83 *
84 <!-- options-start -->
85 * Valid options are: <p/>
86 *
87 * <pre> -S &lt;num&gt;
88 *  Random number seed.
89 *  (default 1)</pre>
90 *
91 * <pre> -D
92 *  If set, classifier is run in debug mode and
93 *  may output additional info to the console</pre>
94 *
95 * <pre> -W
96 *  Full name of base classifier.
97 *  (default: weka.classifiers.trees.J48)</pre>
98 *
99 * <pre>
100 * Options specific to classifier weka.classifiers.trees.J48:
101 * </pre>
102 *
103 * <pre> -U
104 *  Use unpruned tree.</pre>
105 *
106 * <pre> -C &lt;pruning confidence&gt;
107 *  Set confidence threshold for pruning.
108 *  (default 0.25)</pre>
109 *
110 * <pre> -M &lt;minimum number of instances&gt;
111 *  Set minimum number of instances per leaf.
112 *  (default 2)</pre>
113 *
114 * <pre> -R
115 *  Use reduced error pruning.</pre>
116 *
117 * <pre> -N &lt;number of folds&gt;
118 *  Set number of folds for reduced error
119 *  pruning. One fold is used as pruning set.
120 *  (default 3)</pre>
121 *
122 * <pre> -B
123 *  Use binary splits only.</pre>
124 *
125 * <pre> -S
126 *  Don't perform subtree raising.</pre>
127 *
128 * <pre> -L
129 *  Do not clean up after the tree has been built.</pre>
130 *
131 * <pre> -A
132 *  Laplace smoothing for predicted probabilities.</pre>
133 *
134 * <pre> -Q &lt;seed&gt;
135 *  Seed for random data shuffling (default 1).</pre>
136 *
137 <!-- options-end -->
138 *
139 * @author Eibe Frank
140 * @author Lin Dong
141 */
142public class ND 
143  extends RandomizableSingleClassifierEnhancer
144  implements TechnicalInformationHandler {
145 
146  /** for serialization */
147  static final long serialVersionUID = -6355893369855683820L;
148
149  /**
150   * a node class
151   */
152  protected class NDTree
153    implements Serializable, RevisionHandler {
154
155    /** for serialization */
156    private static final long serialVersionUID = 4284655952754474880L;
157   
158    /** The indices associated with this node */
159    protected FastVector m_indices = null;
160   
161    /** The parent */
162    protected NDTree m_parent = null;
163   
164    /** The left successor */
165    protected NDTree m_left = null;
166   
167    /** The right successor */
168    protected NDTree m_right = null;
169   
170    /**
171     * Constructor.
172     */
173    protected NDTree() {
174     
175      m_indices = new FastVector(1);
176      m_indices.addElement(new Integer(Integer.MAX_VALUE));
177    }
178   
179    /**
180     * Locates the node with the given index (depth-first traversal).
181     */
182    protected NDTree locateNode(int nodeIndex, int[] currentIndex) {
183     
184      if (nodeIndex == currentIndex[0]) {
185        return this;
186      } else if (m_left == null) {
187        return null;
188      } else {
189        currentIndex[0]++;
190        NDTree leftresult = m_left.locateNode(nodeIndex, currentIndex);
191        if (leftresult != null) {
192          return leftresult;
193        } else {
194          currentIndex[0]++;
195          return m_right.locateNode(nodeIndex, currentIndex);
196        }
197      }
198    }
199     
200    /**
201     * Inserts a class index into the tree.
202     *
203     * @param classIndex the class index to insert
204     */
205    protected void insertClassIndex(int classIndex) {
206
207      // Create new nodes
208      NDTree right = new NDTree();
209      if (m_left != null) {
210        m_right.m_parent = right;
211        m_left.m_parent = right;
212        right.m_right = m_right;
213        right.m_left = m_left;
214      }
215      m_right = right;
216      m_right.m_indices = (FastVector)m_indices.copy();
217      m_right.m_parent = this;
218      m_left = new NDTree();
219      m_left.insertClassIndexAtNode(classIndex);
220      m_left.m_parent = this; 
221
222      // Propagate class Index
223      propagateClassIndex(classIndex);
224    }
225
226    /**
227     * Propagates class index to the root.
228     *
229     * @param classIndex the index to propagate to the root
230     */
231    protected void propagateClassIndex(int classIndex) {
232
233      insertClassIndexAtNode(classIndex);
234      if (m_parent != null) {
235        m_parent.propagateClassIndex(classIndex);
236      }
237    }
238   
239    /**
240     * Inserts the class index at a given node.
241     *
242     * @param classIndex the classIndex to insert
243     */
244    protected void insertClassIndexAtNode(int classIndex) {
245
246      int i = 0;
247      while (classIndex > ((Integer)m_indices.elementAt(i)).intValue()) {
248        i++;
249      }
250      m_indices.insertElementAt(new Integer(classIndex), i);
251    }
252
253    /**
254     * Gets the indices in an array of ints.
255     *
256     * @return the indices
257     */
258    protected int[] getIndices() {
259
260      int[] ints = new int[m_indices.size() - 1];
261      for (int i = 0; i < m_indices.size() - 1; i++) {
262        ints[i] = ((Integer)m_indices.elementAt(i)).intValue();
263      }
264      return ints;
265    }
266
267    /**
268     * Checks whether an index is in the array.
269     *
270     * @param index the index to check
271     * @return true of the index is in the array
272     */
273    protected boolean contains(int index) {
274
275      for (int i = 0; i < m_indices.size() - 1; i++) {
276        if (index == ((Integer)m_indices.elementAt(i)).intValue()) {
277          return true;
278        }
279      }
280      return false;
281    }
282
283    /**
284     * Returns the list of indices as a string.
285     *
286     * @return the indices as string
287     */
288    protected String getString() {
289
290      StringBuffer string = new StringBuffer();
291      for (int i = 0; i < m_indices.size() - 1; i++) {
292        if (i > 0) {
293          string.append(',');
294        }
295        string.append(((Integer)m_indices.elementAt(i)).intValue() + 1);
296      }
297      return string.toString();
298    }
299
300    /**
301     * Unifies tree for improve hashing.
302     */
303    protected void unifyTree() {
304
305      if (m_left != null) {
306        if (((Integer)m_left.m_indices.elementAt(0)).intValue() >
307            ((Integer)m_right.m_indices.elementAt(0)).intValue()) {
308          NDTree temp = m_left;
309          m_left = m_right;
310          m_right = temp;
311        }
312        m_left.unifyTree();
313        m_right.unifyTree();
314      }
315    }
316
317    /**
318     * Returns a description of the tree rooted at this node.
319     *
320     * @param text the buffer to add the node to
321     * @param id the node id
322     * @param level the level of the tree
323     */
324    protected void toString(StringBuffer text, int[] id, int level) {
325
326      for (int i = 0; i < level; i++) {
327        text.append("   | ");
328      }
329      text.append(id[0] + ": " + getString() + "\n");
330      if (m_left != null) {
331        id[0]++;
332        m_left.toString(text, id, level + 1);
333        id[0]++;
334        m_right.toString(text, id, level + 1);
335      }
336    }
337   
338    /**
339     * Returns the revision string.
340     *
341     * @return          the revision
342     */
343    public String getRevision() {
344      return RevisionUtils.extract("$Revision: 5928 $");
345    }
346  }
347
348  /** The tree of classes */
349  protected NDTree m_ndtree = null;
350 
351  /** The hashtable containing all the classifiers */
352  protected Hashtable m_classifiers = null;
353
354  /** Is Hashtable given from END? */
355  protected boolean m_hashtablegiven = false;
356   
357  /**
358   * Constructor.
359   */
360  public ND() {
361   
362    m_Classifier = new weka.classifiers.trees.J48();
363  }
364 
365  /**
366   * String describing default classifier.
367   *
368   * @return the default classifier classname
369   */
370  protected String defaultClassifierString() {
371   
372    return "weka.classifiers.trees.J48";
373  }
374
375  /**
376   * Returns an instance of a TechnicalInformation object, containing
377   * detailed information about the technical background of this class,
378   * e.g., paper reference or book this class is based on.
379   *
380   * @return the technical information about this class
381   */
382  public TechnicalInformation getTechnicalInformation() {
383    TechnicalInformation        result;
384    TechnicalInformation        additional;
385   
386    result = new TechnicalInformation(Type.INPROCEEDINGS);
387    result.setValue(Field.AUTHOR, "Lin Dong and Eibe Frank and Stefan Kramer");
388    result.setValue(Field.TITLE, "Ensembles of Balanced Nested Dichotomies for Multi-class Problems");
389    result.setValue(Field.BOOKTITLE, "PKDD");
390    result.setValue(Field.YEAR, "2005");
391    result.setValue(Field.PAGES, "84-95");
392    result.setValue(Field.PUBLISHER, "Springer");
393
394    additional = result.add(Type.INPROCEEDINGS);
395    additional.setValue(Field.AUTHOR, "Eibe Frank and Stefan Kramer");
396    additional.setValue(Field.TITLE, "Ensembles of nested dichotomies for multi-class problems");
397    additional.setValue(Field.BOOKTITLE, "Twenty-first International Conference on Machine Learning");
398    additional.setValue(Field.YEAR, "2004");
399    additional.setValue(Field.PUBLISHER, "ACM");
400   
401    return result;
402  }
403
404  /**
405   * Set hashtable from END.
406   *
407   * @param table the hashtable to use
408   */
409  public void setHashtable(Hashtable table) {
410
411    m_hashtablegiven = true;
412    m_classifiers = table;
413  }
414
415  /**
416   * Returns default capabilities of the classifier.
417   *
418   * @return      the capabilities of this classifier
419   */
420  public Capabilities getCapabilities() {
421    Capabilities result = super.getCapabilities();
422
423    // class
424    result.disableAllClasses();
425    result.enable(Capability.NOMINAL_CLASS);
426    result.enable(Capability.MISSING_CLASS_VALUES);
427
428    // instances
429    result.setMinimumNumberInstances(1);
430   
431    return result;
432  }
433
434  /**
435   * Builds the classifier.
436   *
437   * @param data the data to train the classifier with
438   * @throws Exception if anything goes wrong
439   */
440  public void buildClassifier(Instances data) throws Exception {
441
442    // can classifier handle the data?
443    getCapabilities().testWithFail(data);
444
445    // remove instances with missing class
446    data = new Instances(data);
447    data.deleteWithMissingClass();
448   
449    Random random = data.getRandomNumberGenerator(m_Seed);
450
451    if (!m_hashtablegiven) {
452      m_classifiers = new Hashtable();
453    }
454
455    // Generate random class hierarchy
456    int[] indices = new int[data.numClasses()];
457    for (int i = 0; i < indices.length; i++) {
458      indices[i] = i;
459    }
460
461    // Randomize list of class indices
462    for (int i = indices.length - 1; i > 0; i--) {
463      int help = indices[i];
464      int index = random.nextInt(i + 1);
465      indices[i] = indices[index];
466      indices[index] = help;
467    }
468
469    // Insert random class index at randomly chosen node
470    m_ndtree = new NDTree();
471    m_ndtree.insertClassIndexAtNode(indices[0]);
472    for (int i = 1; i < indices.length; i++) {
473      int nodeIndex = random.nextInt(2 * i - 1);
474     
475      NDTree node = m_ndtree.locateNode(nodeIndex, new int[1]);
476      node.insertClassIndex(indices[i]);
477    }
478    m_ndtree.unifyTree();
479   
480
481    // Build classifiers
482    buildClassifierForNode(m_ndtree, data);
483  }
484
485  /**
486   * Builds the classifier for one node.
487   *
488   * @param node the node to build the classifier for
489   * @param data the data to work with
490   * @throws Exception if anything goes wrong
491   */
492  public void buildClassifierForNode(NDTree node, Instances data) throws Exception {
493
494    // Are we at a leaf node ?
495    if (node.m_left != null) {
496     
497      // Create classifier
498      MakeIndicator filter = new MakeIndicator();
499      filter.setAttributeIndex("" + (data.classIndex() + 1));
500      filter.setValueIndices(node.m_right.getString());
501      filter.setNumeric(false);
502      filter.setInputFormat(data);
503      FilteredClassifier classifier = new FilteredClassifier();
504      if (data.numInstances() > 0) {
505        classifier.setClassifier(AbstractClassifier.makeCopies(m_Classifier, 1)[0]);
506      } else {
507        classifier.setClassifier(new ZeroR());
508      }
509      classifier.setFilter(filter);
510     
511      if (!m_classifiers.containsKey(node.m_left.getString() + "|" + node.m_right.getString())) {
512        classifier.buildClassifier(data);
513        m_classifiers.put(node.m_left.getString() + "|" + node.m_right.getString(), classifier);
514      } else {
515        classifier=(FilteredClassifier)m_classifiers.get(node.m_left.getString() + "|" + 
516                                                         node.m_right.getString());
517      }
518     
519      // Generate successors
520      if (node.m_left.m_left != null) {
521        RemoveWithValues rwv = new RemoveWithValues();
522        rwv.setInvertSelection(true);
523        rwv.setNominalIndices(node.m_left.getString());
524        rwv.setAttributeIndex("" + (data.classIndex() + 1));
525        rwv.setInputFormat(data);
526        Instances firstSubset = Filter.useFilter(data, rwv);
527        buildClassifierForNode(node.m_left, firstSubset);
528      }
529      if (node.m_right.m_left != null) {
530        RemoveWithValues rwv = new RemoveWithValues();
531        rwv.setInvertSelection(true);
532        rwv.setNominalIndices(node.m_right.getString());
533        rwv.setAttributeIndex("" + (data.classIndex() + 1));
534        rwv.setInputFormat(data);
535        Instances secondSubset = Filter.useFilter(data, rwv);
536        buildClassifierForNode(node.m_right, secondSubset);
537      }
538    }
539  }
540   
541  /**
542   * Predicts the class distribution for a given instance
543   *
544   * @param inst the (multi-class) instance to be classified
545   * @return the class distribution
546   * @throws Exception if computing fails
547   */
548  public double[] distributionForInstance(Instance inst) throws Exception {
549       
550    return distributionForInstance(inst, m_ndtree);
551  }
552
553  /**
554   * Predicts the class distribution for a given instance
555   *
556   * @param inst the (multi-class) instance to be classified
557   * @param node the node to do get the distribution for
558   * @return the class distribution
559   * @throws Exception if computing fails
560   */
561  protected double[] distributionForInstance(Instance inst, NDTree node) throws Exception {
562
563    double[] newDist = new double[inst.numClasses()];
564    if (node.m_left == null) {
565      newDist[node.getIndices()[0]] = 1.0;
566      return newDist;
567    } else {
568      Classifier classifier = (Classifier)m_classifiers.get(node.m_left.getString() + "|" +
569                                                            node.m_right.getString());
570      double[] leftDist = distributionForInstance(inst, node.m_left);
571      double[] rightDist = distributionForInstance(inst, node.m_right);
572      double[] dist = classifier.distributionForInstance(inst);
573
574      for (int i = 0; i < inst.numClasses(); i++) {
575        if (node.m_right.contains(i)) {
576          newDist[i] = dist[1] * rightDist[i];
577        } else {
578          newDist[i] = dist[0] * leftDist[i];
579        }
580      }
581      return newDist;
582    }
583  }
584
585  /**
586   * Outputs the classifier as a string.
587   *
588   * @return a string representation of the classifier
589   */
590  public String toString() {
591       
592    if (m_classifiers == null) {
593      return "ND: No model built yet.";
594    }
595    StringBuffer text = new StringBuffer();
596    text.append("ND\n\n");
597    m_ndtree.toString(text, new int[1], 0);
598       
599    return text.toString();
600  }
601       
602  /**
603   * @return a description of the classifier suitable for
604   * displaying in the explorer/experimenter gui
605   */
606  public String globalInfo() {
607           
608    return 
609        "A meta classifier for handling multi-class datasets with 2-class "
610      + "classifiers by building a random tree structure.\n\n"
611      + "For more info, check\n\n"
612      + getTechnicalInformation().toString();
613  }
614 
615  /**
616   * Returns the revision string.
617   *
618   * @return            the revision
619   */
620  public String getRevision() {
621    return RevisionUtils.extract("$Revision: 5928 $");
622  }
623   
624  /**
625   * Main method for testing this class.
626   *
627   * @param argv the options
628   */
629  public static void main(String [] argv) {
630    runClassifier(new ND(), argv);
631  }
632}
Note: See TracBrowser for help on using the repository browser.