source: src/main/java/weka/classifiers/trees/j48/NBTreeClassifierTree.java @ 21

Last change on this file since 21 was 4, checked in by gnappo, 14 years ago

Import di weka.

File size: 7.1 KB
Line 
/*
 *    This program is free software; you can redistribute it and/or modify
 *    it under the terms of the GNU General Public License as published by
 *    the Free Software Foundation; either version 2 of the License, or
 *    (at your option) any later version.
 *
 *    This program is distributed in the hope that it will be useful,
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *    GNU General Public License for more details.
 *
 *    You should have received a copy of the GNU General Public License
 *    along with this program; if not, write to the Free Software
 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 */
18
/*
 *    NBTreeClassifierTree.java
 *    Copyright (C) 2004 University of Waikato, Hamilton, New Zealand
 *
 */
24
25package weka.classifiers.trees.j48;
26
27import weka.core.Capabilities;
28import weka.core.Instances;
29import weka.core.RevisionUtils;
30import weka.core.Capabilities.Capability;
31
32/**
33 * Class for handling a naive bayes tree structure used for
34 * classification.
35 *
36 * @author Mark Hall (mhall@cs.waikato.ac.nz)
37 * @version $Revision: 5534 $
38 */
39public class NBTreeClassifierTree
40  extends ClassifierTree {
41
42  /** for serialization */
43  private static final long serialVersionUID = -4472639447877404786L;
44
45  public NBTreeClassifierTree(ModelSelection toSelectLocModel) {
46    super(toSelectLocModel);
47  }
48
49  /**
50   * Returns default capabilities of the classifier tree.
51   *
52   * @return      the capabilities of this classifier tree
53   */
54  public Capabilities getCapabilities() {
55    Capabilities result = super.getCapabilities();
56    result.disableAll();
57
58    // attributes
59    result.enable(Capability.NOMINAL_ATTRIBUTES);
60    result.enable(Capability.NUMERIC_ATTRIBUTES);
61    result.enable(Capability.DATE_ATTRIBUTES);
62    result.enable(Capability.MISSING_VALUES);
63
64    // class
65    result.enable(Capability.NOMINAL_CLASS);
66    result.enable(Capability.MISSING_CLASS_VALUES);
67
68    // instances
69    result.setMinimumNumberInstances(0);
70   
71    return result;
72  }
73
74  /**
75   * Method for building a naive bayes classifier tree
76   *
77   * @exception Exception if something goes wrong
78   */
79  public void buildClassifier(Instances data) throws Exception {
80   super.buildClassifier(data);
81   cleanup(new Instances(data, 0));
82   assignIDs(-1);
83  }
84
85  /**
86   * Assigns a uniqe id to every node in the tree.
87   *
88  public int assignIDs(int lastID) {
89
90    int currLastID = lastID + 1;
91
92    m_id = currLastID;
93    if (m_sons != null) {
94      for (int i = 0; i < m_sons.length; i++) {
95        currLastID = m_sons[i].assignIDs(currLastID);
96      }
97    }
98    return currLastID;
99    } */
100
101  /**
102   * Returns a newly created tree.
103   *
104   * @param data the training data
105   * @exception Exception if something goes wrong
106   */
107  protected ClassifierTree getNewTree(Instances data) throws Exception {
108         
109    ClassifierTree newTree = new NBTreeClassifierTree(m_toSelectModel);
110    newTree.buildTree(data, false);
111   
112    return newTree;
113  }
114
115  /**
116   * Returns a newly created tree.
117   *
118   * @param train the training data
119   * @param test the pruning data.
120   * @exception Exception if something goes wrong
121   */
122  protected ClassifierTree getNewTree(Instances train, Instances test) 
123       throws Exception {
124         
125    ClassifierTree newTree = new NBTreeClassifierTree(m_toSelectModel);
126    newTree.buildTree(train, test, false);
127   
128    return newTree;
129  }
130
131  /**
132   * Print the models at the leaves
133   *
134   * @return textual description of the leaf models
135   */
136  public String printLeafModels() {
137    StringBuffer text = new StringBuffer();
138
139    if (m_isLeaf) {
140      text.append("\nLeaf number: " + m_id+" ");
141      text.append(m_localModel.toString());
142      text.append("\n");
143    } else {
144       for (int i=0;i<m_sons.length;i++) {
145         text.append(((NBTreeClassifierTree)m_sons[i]).printLeafModels());
146       }
147    } 
148    return text.toString();
149  }
150
151  /**
152   * Prints tree structure.
153   */
154  public String toString() {
155
156    try {
157      StringBuffer text = new StringBuffer();
158     
159      if (m_isLeaf) {
160        text.append(": NB");
161        text.append(m_id);
162      }else
163        dumpTreeNB(0,text);
164
165      text.append("\n"+printLeafModels());
166      text.append("\n\nNumber of Leaves  : \t"+numLeaves()+"\n");
167      text.append("\nSize of the tree : \t"+numNodes()+"\n");
168 
169      return text.toString();
170    } catch (Exception e) {
171      e.printStackTrace();
172      return "Can't print nb tree.";
173    }
174  }
175
176  /**
177   * Help method for printing tree structure.
178   *
179   * @exception Exception if something goes wrong
180   */
181  private void dumpTreeNB(int depth,StringBuffer text) 
182       throws Exception {
183   
184    int i,j;
185   
186    for (i=0;i<m_sons.length;i++) {
187      text.append("\n");;
188      for (j=0;j<depth;j++)
189        text.append("|   ");
190      text.append(m_localModel.leftSide(m_train));
191      text.append(m_localModel.rightSide(i, m_train));
192      if (m_sons[i].m_isLeaf) {
193        text.append(": NB ");
194        text.append(m_sons[i].m_id);
195      }else
196        ((NBTreeClassifierTree)m_sons[i]).dumpTreeNB(depth+1,text);
197    }
198  }
199
200  /**
201   * Returns graph describing the tree.
202   *
203   * @exception Exception if something goes wrong
204   */
205  public String graph() throws Exception {
206
207    StringBuffer text = new StringBuffer();
208
209    text.append("digraph J48Tree {\n");
210    if (m_isLeaf) {
211      text.append("N" + m_id
212                  + " [label=\"" + 
213                  "NB model" + "\" " + 
214                  "shape=box style=filled ");
215      if (m_train != null && m_train.numInstances() > 0) {
216        text.append("data =\n" + m_train + "\n");
217        text.append(",\n");
218
219      }
220      text.append("]\n");
221    }else {
222      text.append("N" + m_id
223                  + " [label=\"" + 
224                  m_localModel.leftSide(m_train) + "\" ");
225      if (m_train != null && m_train.numInstances() > 0) {
226        text.append("data =\n" + m_train + "\n");
227        text.append(",\n");
228     }
229      text.append("]\n");
230      graphTree(text);
231    }
232   
233    return text.toString() +"}\n";
234  }
235
236  /**
237   * Help method for printing tree structure as a graph.
238   *
239   * @exception Exception if something goes wrong
240   */
241  private void graphTree(StringBuffer text) throws Exception {
242   
243    for (int i = 0; i < m_sons.length; i++) {
244      text.append("N" + m_id 
245                  + "->" + 
246                  "N" + m_sons[i].m_id +
247                  " [label=\"" + m_localModel.rightSide(i,m_train).trim() + 
248                  "\"]\n");
249      if (m_sons[i].m_isLeaf) {
250        text.append("N" + m_sons[i].m_id +
251                    " [label=\""+"NB Model"+"\" "+ 
252                    "shape=box style=filled ");
253        if (m_train != null && m_train.numInstances() > 0) {
254          text.append("data =\n" + m_sons[i].m_train + "\n");
255          text.append(",\n");
256        }
257        text.append("]\n");
258      } else {
259        text.append("N" + m_sons[i].m_id +
260                    " [label=\""+m_sons[i].m_localModel.leftSide(m_train) + 
261                    "\" ");
262        if (m_train != null && m_train.numInstances() > 0) {
263          text.append("data =\n" + m_sons[i].m_train + "\n");
264          text.append(",\n");
265        }
266        text.append("]\n");
267        ((NBTreeClassifierTree)m_sons[i]).graphTree(text);
268      }
269    }
270  }
271 
272  /**
273   * Returns the revision string.
274   *
275   * @return            the revision
276   */
277  public String getRevision() {
278    return RevisionUtils.extract("$Revision: 5534 $");
279  }
280}
Note: See TracBrowser for help on using the repository browser.