source: src/main/java/weka/classifiers/trees/j48/ClassifierTree.java @ 7

Last change on this file since 7 was 4, checked in by gnappo, 14 years ago

Import di weka.

File size: 17.9 KB
Line 
1/*
2 *    This program is free software; you can redistribute it and/or modify
3 *    it under the terms of the GNU General Public License as published by
4 *    the Free Software Foundation; either version 2 of the License, or
5 *    (at your option) any later version.
6 *
7 *    This program is distributed in the hope that it will be useful,
8 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
9 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
10 *    GNU General Public License for more details.
11 *
12 *    You should have received a copy of the GNU General Public License
13 *    along with this program; if not, write to the Free Software
14 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
15 */
16
17/*
18 *    ClassifierTree.java
19 *    Copyright (C) 1999 University of Waikato, Hamilton, New Zealand
20 *
21 */
22
23package weka.classifiers.trees.j48;
24
25import weka.core.Capabilities;
26import weka.core.CapabilitiesHandler;
27import weka.core.Drawable;
28import weka.core.Instance;
29import weka.core.Instances;
30import weka.core.RevisionHandler;
31import weka.core.RevisionUtils;
32import weka.core.Utils;
33
34import java.io.Serializable;
35
36/**
37 * Class for handling a tree structure used for
38 * classification.
39 *
40 * @author Eibe Frank (eibe@cs.waikato.ac.nz)
41 * @version $Revision: 5530 $
42 */
43public class ClassifierTree 
44  implements Drawable, Serializable, CapabilitiesHandler, RevisionHandler {
45
46  /** for serialization */
47  static final long serialVersionUID = -8722249377542734193L;
48 
49  /** The model selection method. */ 
50  protected ModelSelection m_toSelectModel;     
51
52  /** Local model at node. */
53  protected ClassifierSplitModel m_localModel; 
54
55  /** References to sons. */
56  protected ClassifierTree [] m_sons;           
57
58  /** True if node is leaf. */
59  protected boolean m_isLeaf;                   
60
61  /** True if node is empty. */
62  protected boolean m_isEmpty;                 
63
64  /** The training instances. */
65  protected Instances m_train;                 
66
67  /** The pruning instances. */
68  protected Distribution m_test;     
69
70  /** The id for the node. */
71  protected int m_id;
72
73  /**
74   * For getting a unique ID when outputting the tree (hashcode isn't
75   * guaranteed unique)
76   */
77  private static long PRINTED_NODES = 0;
78
79  /**
80   * Gets the next unique node ID.
81   *
82   * @return the next unique node ID.
83   */
84  protected static long nextID() {
85
86    return PRINTED_NODES ++;
87  }
88
89  /**
90   * Resets the unique node ID counter (e.g.
91   * between repeated separate print types)
92   */
93  protected static void resetID() {
94
95    PRINTED_NODES = 0;
96  }
97
98  /**
99   * Constructor.
100   */
101  public ClassifierTree(ModelSelection toSelectLocModel) {
102   
103    m_toSelectModel = toSelectLocModel;
104  }
105
106  /**
107   * Returns default capabilities of the classifier tree.
108   *
109   * @return      the capabilities of this classifier tree
110   */
111  public Capabilities getCapabilities() {
112    Capabilities result = new Capabilities(this);
113    result.enableAll();
114   
115    return result;
116  }
117
118  /**
119   * Method for building a classifier tree.
120   *
121   * @param data the data to build the tree from
122   * @throws Exception if something goes wrong
123   */
124  public void buildClassifier(Instances data) throws Exception {
125
126    // can classifier tree handle the data?
127    getCapabilities().testWithFail(data);
128
129    // remove instances with missing class
130    data = new Instances(data);
131    data.deleteWithMissingClass();
132   
133    buildTree(data, false);
134  }
135
136  /**
137   * Builds the tree structure.
138   *
139   * @param data the data for which the tree structure is to be
140   * generated.
141   * @param keepData is training data to be kept?
142   * @throws Exception if something goes wrong
143   */
144  public void buildTree(Instances data, boolean keepData) throws Exception {
145   
146    Instances [] localInstances;
147
148    if (keepData) {
149      m_train = data;
150    }
151    m_test = null;
152    m_isLeaf = false;
153    m_isEmpty = false;
154    m_sons = null;
155    m_localModel = m_toSelectModel.selectModel(data);
156    if (m_localModel.numSubsets() > 1) {
157      localInstances = m_localModel.split(data);
158      data = null;
159      m_sons = new ClassifierTree [m_localModel.numSubsets()];
160      for (int i = 0; i < m_sons.length; i++) {
161        m_sons[i] = getNewTree(localInstances[i]);
162        localInstances[i] = null;
163      }
164    }else{
165      m_isLeaf = true;
166      if (Utils.eq(data.sumOfWeights(), 0))
167        m_isEmpty = true;
168      data = null;
169    }
170  }
171
172  /**
173   * Builds the tree structure with hold out set
174   *
175   * @param train the data for which the tree structure is to be
176   * generated.
177   * @param test the test data for potential pruning
178   * @param keepData is training Data to be kept?
179   * @throws Exception if something goes wrong
180   */
181  public void buildTree(Instances train, Instances test, boolean keepData)
182       throws Exception {
183   
184    Instances [] localTrain, localTest;
185    int i;
186   
187    if (keepData) {
188      m_train = train;
189    }
190    m_isLeaf = false;
191    m_isEmpty = false;
192    m_sons = null;
193    m_localModel = m_toSelectModel.selectModel(train, test);
194    m_test = new Distribution(test, m_localModel);
195    if (m_localModel.numSubsets() > 1) {
196      localTrain = m_localModel.split(train);
197      localTest = m_localModel.split(test);
198      train = test = null;
199      m_sons = new ClassifierTree [m_localModel.numSubsets()];
200      for (i=0;i<m_sons.length;i++) {
201        m_sons[i] = getNewTree(localTrain[i], localTest[i]);
202        localTrain[i] = null;
203        localTest[i] = null;
204      }
205    }else{
206      m_isLeaf = true;
207      if (Utils.eq(train.sumOfWeights(), 0))
208        m_isEmpty = true;
209      train = test = null;
210    }
211  }
212
213  /**
214   * Classifies an instance.
215   *
216   * @param instance the instance to classify
217   * @return the classification
218   * @throws Exception if something goes wrong
219   */
220  public double classifyInstance(Instance instance) 
221    throws Exception {
222
223    double maxProb = -1;
224    double currentProb;
225    int maxIndex = 0;
226    int j;
227
228    for (j = 0; j < instance.numClasses(); j++) {
229      currentProb = getProbs(j, instance, 1);
230      if (Utils.gr(currentProb,maxProb)) {
231        maxIndex = j;
232        maxProb = currentProb;
233      }
234    }
235
236    return (double)maxIndex;
237  }
238
239  /**
240   * Cleanup in order to save memory.
241   *
242   * @param justHeaderInfo
243   */
244  public final void cleanup(Instances justHeaderInfo) {
245
246    m_train = justHeaderInfo;
247    m_test = null;
248    if (!m_isLeaf)
249      for (int i = 0; i < m_sons.length; i++)
250        m_sons[i].cleanup(justHeaderInfo);
251  }
252
253  /**
254   * Returns class probabilities for a weighted instance.
255   *
256   * @param instance the instance to get the distribution for
257   * @param useLaplace whether to use laplace or not
258   * @return the distribution
259   * @throws Exception if something goes wrong
260   */
261  public final double [] distributionForInstance(Instance instance,
262                                                 boolean useLaplace) 
263       throws Exception {
264
265    double [] doubles = new double[instance.numClasses()];
266
267    for (int i = 0; i < doubles.length; i++) {
268      if (!useLaplace) {
269        doubles[i] = getProbs(i, instance, 1);
270      } else {
271        doubles[i] = getProbsLaplace(i, instance, 1);
272      }
273    }
274
275    return doubles;
276  }
277
278  /**
279   * Assigns a uniqe id to every node in the tree.
280   *
281   * @param lastID the last ID that was assign
282   * @return the new current ID
283   */
284  public int assignIDs(int lastID) {
285
286    int currLastID = lastID + 1;
287
288    m_id = currLastID;
289    if (m_sons != null) {
290      for (int i = 0; i < m_sons.length; i++) {
291        currLastID = m_sons[i].assignIDs(currLastID);
292      }
293    }
294    return currLastID;
295  }
296
297  /**
298   *  Returns the type of graph this classifier
299   *  represents.
300   *  @return Drawable.TREE
301   */   
302  public int graphType() {
303      return Drawable.TREE;
304  }
305
306  /**
307   * Returns graph describing the tree.
308   *
309   * @throws Exception if something goes wrong
310   * @return the tree as graph
311   */
312  public String graph() throws Exception {
313
314    StringBuffer text = new StringBuffer();
315
316    assignIDs(-1);
317    text.append("digraph J48Tree {\n");
318    if (m_isLeaf) {
319      text.append("N" + m_id
320                  + " [label=\"" + 
321                  m_localModel.dumpLabel(0,m_train) + "\" " + 
322                  "shape=box style=filled ");
323      if (m_train != null && m_train.numInstances() > 0) {
324        text.append("data =\n" + m_train + "\n");
325        text.append(",\n");
326
327      }
328      text.append("]\n");
329    }else {
330      text.append("N" + m_id
331                  + " [label=\"" + 
332                  m_localModel.leftSide(m_train) + "\" ");
333      if (m_train != null && m_train.numInstances() > 0) {
334        text.append("data =\n" + m_train + "\n");
335        text.append(",\n");
336     }
337      text.append("]\n");
338      graphTree(text);
339    }
340   
341    return text.toString() +"}\n";
342  }
343
344  /**
345   * Returns tree in prefix order.
346   *
347   * @throws Exception if something goes wrong
348   * @return the prefix order
349   */
350  public String prefix() throws Exception {
351   
352    StringBuffer text;
353
354    text = new StringBuffer();
355    if (m_isLeaf) {
356      text.append("["+m_localModel.dumpLabel(0,m_train)+"]");
357    }else {
358      prefixTree(text);
359    }
360   
361    return text.toString();
362  }
363
364  /**
365   * Returns source code for the tree as an if-then statement. The
366   * class is assigned to variable "p", and assumes the tested
367   * instance is named "i". The results are returned as two stringbuffers:
368   * a section of code for assignment of the class, and a section of
369   * code containing support code (eg: other support methods).
370   *
371   * @param className the classname that this static classifier has
372   * @return an array containing two stringbuffers, the first string containing
373   * assignment code, and the second containing source for support code.
374   * @throws Exception if something goes wrong
375   */
376  public StringBuffer [] toSource(String className) throws Exception {
377   
378    StringBuffer [] result = new StringBuffer [2];
379    if (m_isLeaf) {
380      result[0] = new StringBuffer("    p = " 
381        + m_localModel.distribution().maxClass(0) + ";\n");
382      result[1] = new StringBuffer("");
383    } else {
384      StringBuffer text = new StringBuffer();
385      StringBuffer atEnd = new StringBuffer();
386
387      long printID = ClassifierTree.nextID();
388
389      text.append("  static double N") 
390        .append(Integer.toHexString(m_localModel.hashCode()) + printID)
391        .append("(Object []i) {\n")
392        .append("    double p = Double.NaN;\n");
393
394      text.append("    if (")
395        .append(m_localModel.sourceExpression(-1, m_train))
396        .append(") {\n");
397      text.append("      p = ")
398        .append(m_localModel.distribution().maxClass(0))
399        .append(";\n");
400      text.append("    } ");
401      for (int i = 0; i < m_sons.length; i++) {
402        text.append("else if (" + m_localModel.sourceExpression(i, m_train) 
403                    + ") {\n");
404        if (m_sons[i].m_isLeaf) {
405          text.append("      p = " 
406                      + m_localModel.distribution().maxClass(i) + ";\n");
407        } else {
408          StringBuffer [] sub = m_sons[i].toSource(className);
409          text.append(sub[0]);
410          atEnd.append(sub[1]);
411        }
412        text.append("    } ");
413        if (i == m_sons.length - 1) {
414          text.append('\n');
415        }
416      }
417
418      text.append("    return p;\n  }\n");
419
420      result[0] = new StringBuffer("    p = " + className + ".N");
421      result[0].append(Integer.toHexString(m_localModel.hashCode()) +  printID)
422        .append("(i);\n");
423      result[1] = text.append(atEnd);
424    }
425    return result;
426  }
427
428  /**
429   * Returns number of leaves in tree structure.
430   *
431   * @return the number of leaves
432   */
433  public int numLeaves() {
434   
435    int num = 0;
436    int i;
437   
438    if (m_isLeaf)
439      return 1;
440    else
441      for (i=0;i<m_sons.length;i++)
442        num = num+m_sons[i].numLeaves();
443       
444    return num;
445  }
446
447  /**
448   * Returns number of nodes in tree structure.
449   *
450   * @return the number of nodes
451   */
452  public int numNodes() {
453   
454    int no = 1;
455    int i;
456   
457    if (!m_isLeaf)
458      for (i=0;i<m_sons.length;i++)
459        no = no+m_sons[i].numNodes();
460   
461    return no;
462  }
463
464  /**
465   * Prints tree structure.
466   *
467   * @return the tree structure
468   */
469  public String toString() {
470
471    try {
472      StringBuffer text = new StringBuffer();
473     
474      if (m_isLeaf) {
475        text.append(": ");
476        text.append(m_localModel.dumpLabel(0,m_train));
477      }else
478        dumpTree(0,text);
479      text.append("\n\nNumber of Leaves  : \t"+numLeaves()+"\n");
480      text.append("\nSize of the tree : \t"+numNodes()+"\n");
481 
482      return text.toString();
483    } catch (Exception e) {
484      return "Can't print classification tree.";
485    }
486  }
487
488  /**
489   * Returns a newly created tree.
490   *
491   * @param data the training data
492   * @return the generated tree
493   * @throws Exception if something goes wrong
494   */
495  protected ClassifierTree getNewTree(Instances data) throws Exception {
496         
497    ClassifierTree newTree = new ClassifierTree(m_toSelectModel);
498    newTree.buildTree(data, false);
499   
500    return newTree;
501  }
502
503  /**
504   * Returns a newly created tree.
505   *
506   * @param train the training data
507   * @param test the pruning data.
508   * @return the generated tree
509   * @throws Exception if something goes wrong
510   */
511  protected ClassifierTree getNewTree(Instances train, Instances test) 
512       throws Exception {
513         
514    ClassifierTree newTree = new ClassifierTree(m_toSelectModel);
515    newTree.buildTree(train, test, false);
516   
517    return newTree;
518  }
519
520  /**
521   * Help method for printing tree structure.
522   *
523   * @param depth the current depth
524   * @param text for outputting the structure
525   * @throws Exception if something goes wrong
526   */
527  private void dumpTree(int depth, StringBuffer text) 
528       throws Exception {
529   
530    int i,j;
531   
532    for (i=0;i<m_sons.length;i++) {
533      text.append("\n");;
534      for (j=0;j<depth;j++)
535        text.append("|   ");
536      text.append(m_localModel.leftSide(m_train));
537      text.append(m_localModel.rightSide(i, m_train));
538      if (m_sons[i].m_isLeaf) {
539        text.append(": ");
540        text.append(m_localModel.dumpLabel(i,m_train));
541      }else
542        m_sons[i].dumpTree(depth+1,text);
543    }
544  }
545
546  /**
547   * Help method for printing tree structure as a graph.
548   *
549   * @param text for outputting the tree
550   * @throws Exception if something goes wrong
551   */
552  private void graphTree(StringBuffer text) throws Exception {
553   
554    for (int i = 0; i < m_sons.length; i++) {
555      text.append("N" + m_id 
556                  + "->" + 
557                  "N" + m_sons[i].m_id +
558                  " [label=\"" + m_localModel.rightSide(i,m_train).trim() + 
559                  "\"]\n");
560      if (m_sons[i].m_isLeaf) {
561        text.append("N" + m_sons[i].m_id +
562                    " [label=\""+m_localModel.dumpLabel(i,m_train)+"\" "+ 
563                    "shape=box style=filled ");
564        if (m_train != null && m_train.numInstances() > 0) {
565          text.append("data =\n" + m_sons[i].m_train + "\n");
566          text.append(",\n");
567        }
568        text.append("]\n");
569      } else {
570        text.append("N" + m_sons[i].m_id +
571                    " [label=\""+m_sons[i].m_localModel.leftSide(m_train) + 
572                    "\" ");
573        if (m_train != null && m_train.numInstances() > 0) {
574          text.append("data =\n" + m_sons[i].m_train + "\n");
575          text.append(",\n");
576        }
577        text.append("]\n");
578        m_sons[i].graphTree(text);
579      }
580    }
581  }
582
583  /**
584   * Prints the tree in prefix form
585   *
586   * @param text the buffer to output the prefix form to
587   * @throws Exception if something goes wrong
588   */
589  private void prefixTree(StringBuffer text) throws Exception {
590
591    text.append("[");
592    text.append(m_localModel.leftSide(m_train)+":");
593    for (int i = 0; i < m_sons.length; i++) {
594      if (i > 0) {
595        text.append(",\n");
596      }
597      text.append(m_localModel.rightSide(i, m_train));
598    }
599    for (int i = 0; i < m_sons.length; i++) {
600      if (m_sons[i].m_isLeaf) {
601        text.append("[");
602        text.append(m_localModel.dumpLabel(i,m_train));
603        text.append("]");
604      } else {
605        m_sons[i].prefixTree(text);
606      }
607    }
608    text.append("]");
609  }
610
611  /**
612   * Help method for computing class probabilities of
613   * a given instance.
614   *
615   * @param classIndex the class index
616   * @param instance the instance to compute the probabilities for
617   * @param weight the weight to use
618   * @return the laplace probs
619   * @throws Exception if something goes wrong
620   */
621  private double getProbsLaplace(int classIndex, Instance instance, double weight) 
622    throws Exception {
623   
624    double prob = 0;
625   
626    if (m_isLeaf) {
627      return weight * localModel().classProbLaplace(classIndex, instance, -1);
628    } else {
629      int treeIndex = localModel().whichSubset(instance);
630      if (treeIndex == -1) {
631        double[] weights = localModel().weights(instance);
632        for (int i = 0; i < m_sons.length; i++) {
633          if (!son(i).m_isEmpty) {
634        prob += son(i).getProbsLaplace(classIndex, instance, 
635                                             weights[i] * weight);
636          }
637        }
638        return prob;
639      } else {
640        if (son(treeIndex).m_isEmpty) {
641          return weight * localModel().classProbLaplace(classIndex, instance, 
642                                                        treeIndex);
643        } else {
644          return son(treeIndex).getProbsLaplace(classIndex, instance, weight);
645        }
646      }
647    }
648  }
649
650  /**
651   * Help method for computing class probabilities of
652   * a given instance.
653   *
654   * @param classIndex the class index
655   * @param instance the instance to compute the probabilities for
656   * @param weight the weight to use
657   * @return the probs
658   * @throws Exception if something goes wrong
659   */
660  private double getProbs(int classIndex, Instance instance, double weight) 
661    throws Exception {
662   
663    double prob = 0;
664   
665    if (m_isLeaf) {
666      return weight * localModel().classProb(classIndex, instance, -1);
667    } else {
668      int treeIndex = localModel().whichSubset(instance);
669      if (treeIndex == -1) {
670        double[] weights = localModel().weights(instance);
671        for (int i = 0; i < m_sons.length; i++) {
672          if (!son(i).m_isEmpty) {
673            prob += son(i).getProbs(classIndex, instance, 
674                                    weights[i] * weight);
675          }
676        }
677        return prob;
678      } else {
679        if (son(treeIndex).m_isEmpty) {
680          return weight * localModel().classProb(classIndex, instance, 
681                                                 treeIndex);
682        } else {
683          return son(treeIndex).getProbs(classIndex, instance, weight);
684        }
685      }
686    }
687  }
688
689  /**
690   * Method just exists to make program easier to read.
691   */
692  private ClassifierSplitModel localModel() {
693   
694    return (ClassifierSplitModel)m_localModel;
695  }
696 
697  /**
698   * Method just exists to make program easier to read.
699   */
700  private ClassifierTree son(int index) {
701   
702    return (ClassifierTree)m_sons[index];
703  }
704 
705  /**
706   * Returns the revision string.
707   *
708   * @return            the revision
709   */
710  public String getRevision() {
711    return RevisionUtils.extract("$Revision: 5530 $");
712  }
713}
Note: See TracBrowser for help on using the repository browser.