source: src/main/java/weka/classifiers/trees/j48/NBTreeSplit.java @ 9

Last change on this file since 9 was 4, checked in by gnappo, 14 years ago

Import of weka.

File size: 11.6 KB
/*
 *    This program is free software; you can redistribute it and/or modify
 *    it under the terms of the GNU General Public License as published by
 *    the Free Software Foundation; either version 2 of the License, or
 *    (at your option) any later version.
 *
 *    This program is distributed in the hope that it will be useful,
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *    GNU General Public License for more details.
 *
 *    You should have received a copy of the GNU General Public License
 *    along with this program; if not, write to the Free Software
 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 */

/*
 *    NBTreeSplit.java
 *    Copyright (C) 2004 University of Waikato, Hamilton, New Zealand
 *
 */

package weka.classifiers.trees.j48;

import weka.classifiers.bayes.NaiveBayesUpdateable;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.RevisionUtils;
import weka.filters.Filter;
import weka.filters.supervised.attribute.Discretize;

import java.util.Random;

/**
 * Class implementing an NBTree split on an attribute.
 *
 * @author Mark Hall (mhall@cs.waikato.ac.nz)
 * @version $Revision: 6088 $
 */
public class NBTreeSplit
  extends ClassifierSplitModel {

  /** for serialization */
  private static final long serialVersionUID = 8922627123884975070L;

  /** Desired number of branches. */
  private int m_complexityIndex;

  /** Attribute to split on. */
  private int m_attIndex;

  /** Minimum number of objects in a split. */
  private int m_minNoObj;

  /** Value of split point. */
  private double m_splitPoint;

  /** The sum of the weights of the instances. */
  private double m_sumOfWeights;

  /** The weight of the instances incorrectly classified by the
      naive bayes models arising from this split */
  private double m_errors;

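  /** The underlying C4.5-type split that defines the branches of this split. */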
  private C45Split m_c45S;

  /** The global naive bayes model for this node */
  NBTreeNoSplit m_globalNB;

  /**
   * Initializes the split model.
   *
   * @param attIndex index of the attribute to split on
   * @param minNoObj minimum number of objects in a split
   * @param sumOfWeights the sum of the weights of the instances
   */
  public NBTreeSplit(int attIndex, int minNoObj, double sumOfWeights) {

    // Get index of attribute to split on.
    m_attIndex = attIndex;

    // Set minimum number of objects.
    m_minNoObj = minNoObj;

    // Set the sum of the weights.
    m_sumOfWeights = sumOfWeights;
  }

  /**
   * Creates an NBTree-type split on the given data. Assumes that none of
   * the class values is missing.
   *
   * @param trainInstances the data to create the split on
   * @exception Exception if something goes wrong
   */
  public void buildClassifier(Instances trainInstances)
       throws Exception {

    // Initialize the remaining instance variables.
    m_numSubsets = 0;
    m_splitPoint = Double.MAX_VALUE;
    m_errors = 0;
    if (m_globalNB != null) {
      m_errors = m_globalNB.getErrors();
    }

    // Different treatment for enumerated and numeric
    // attributes.
    if (trainInstances.attribute(m_attIndex).isNominal()) {
      m_complexityIndex = trainInstances.attribute(m_attIndex).numValues();
      handleEnumeratedAttribute(trainInstances);
    } else {
      m_complexityIndex = 2;
      trainInstances.sort(trainInstances.attribute(m_attIndex));
      handleNumericAttribute(trainInstances);
    }
  }
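
  /*
   * Usage sketch (an assumption, not part of the original source): the
   * model selection code in the enclosing tree learner is expected to
   * build one NBTreeSplit per candidate attribute and keep the one with
   * the fewest cross-validated errors, roughly:
   *
   *   NBTreeSplit split = new NBTreeSplit(attIndex, minNoObj, data.sumOfWeights());
   *   split.buildClassifier(data);
   *   if (split.checkModel()) {       // at least two usable subsets were found
   *     double errors = split.getErrors();
   *   }
   *
   * checkModel() is inherited from ClassifierSplitModel; the surrounding
   * selection logic is only sketched here.
   */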

  /**
   * Returns index of attribute for which split was generated.
   */
  public final int attIndex() {

    return m_attIndex;
  }

  /**
   * Creates split on enumerated attribute.
   *
   * @exception Exception if something goes wrong
   */
  private void handleEnumeratedAttribute(Instances trainInstances)
       throws Exception {

    m_c45S = new C45Split(m_attIndex, 2, m_sumOfWeights, true);
    m_c45S.buildClassifier(trainInstances);
    if (m_c45S.numSubsets() == 0) {
      return;
    }
    m_errors = 0;
    Instance instance;

    Instances [] trainingSets = new Instances [m_complexityIndex];
    for (int i = 0; i < m_complexityIndex; i++) {
      trainingSets[i] = new Instances(trainInstances, 0);
    }
    /*    m_distribution = new Distribution(m_complexityIndex,
          trainInstances.numClasses()); */
    int subset;
    for (int i = 0; i < trainInstances.numInstances(); i++) {
      instance = trainInstances.instance(i);
      subset = m_c45S.whichSubset(instance);
      if (subset > -1) {
        trainingSets[subset].add((Instance)instance.copy());
      } else {
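        // The split attribute is missing for this instance: copy it into
        // every branch, scaling its weight by the branch weights from the
        // C4.5 split (or dividing the weight evenly as a fallback).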
        double [] weights = m_c45S.weights(instance);
        for (int j = 0; j < m_complexityIndex; j++) {
          try {
            Instance temp = (Instance) instance.copy();
            if (weights.length == m_complexityIndex) {
              temp.setWeight(temp.weight() * weights[j]);
            } else {
              temp.setWeight(temp.weight() / m_complexityIndex);
            }
            trainingSets[j].add(temp);
          } catch (Exception ex) {
            ex.printStackTrace();
            System.err.println("*** "+m_complexityIndex);
            System.err.println(weights.length);
            System.exit(1);
          }
        }
      }
    }

    /*    // compute weights (weights of instances per subset
    m_weights = new double [m_complexityIndex];
    for (int i = 0; i < m_complexityIndex; i++) {
      m_weights[i] = trainingSets[i].sumOfWeights();
    }
    Utils.normalize(m_weights); */

    /*
    // Only Instances with known values are relevant.
    Enumeration enu = trainInstances.enumerateInstances();
    while (enu.hasMoreElements()) {
      instance = (Instance) enu.nextElement();
      if (!instance.isMissing(m_attIndex)) {
        //      m_distribution.add((int)instance.value(m_attIndex),instance);
        trainingSets[(int)instances.value(m_attIndex)].add(instance);
      } else {
        // add these to the error count
        m_errors += instance.weight();
      }
      } */

    Random r = new Random(1);
    int minNumCount = 0;
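    // For each branch that received at least five instances: apply
    // supervised discretization, then estimate the error of a naive Bayes
    // model on that branch via stratified 5-fold cross-validation
    // (NBTreeNoSplit.crossValidate). Smaller branches are counted
    // entirely as errors.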
    for (int i = 0; i < m_complexityIndex; i++) {
      if (trainingSets[i].numInstances() >= 5) {
        minNumCount++;
        // Discretize the sets
        Discretize disc = new Discretize();
        disc.setInputFormat(trainingSets[i]);
        trainingSets[i] = Filter.useFilter(trainingSets[i], disc);

        trainingSets[i].randomize(r);
        trainingSets[i].stratify(5);
        NaiveBayesUpdateable fullModel = new NaiveBayesUpdateable();
        fullModel.buildClassifier(trainingSets[i]);

        // add the errors for this branch of the split
        m_errors += NBTreeNoSplit.crossValidate(fullModel, trainingSets[i], r);
      } else {
        // if fewer than min obj then just count them as errors
        for (int j = 0; j < trainingSets[i].numInstances(); j++) {
          m_errors += trainingSets[i].instance(j).weight();
        }
      }
    }

    // Check if there are at least five instances in at least two of the
    // subsets.
    if (minNumCount > 1) {
      m_numSubsets = m_complexityIndex;
    }
  }

  /**
   * Creates split on numeric attribute.
   *
   * @exception Exception if something goes wrong
   */
  private void handleNumericAttribute(Instances trainInstances)
       throws Exception {

    m_c45S = new C45Split(m_attIndex, 2, m_sumOfWeights, true);
    m_c45S.buildClassifier(trainInstances);
    if (m_c45S.numSubsets() == 0) {
      return;
    }
    m_errors = 0;

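    // A numeric attribute always gives a binary split around the C4.5
    // split point, so exactly two subsets are built here.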
    Instances [] trainingSets = new Instances [m_complexityIndex];
    trainingSets[0] = new Instances(trainInstances, 0);
    trainingSets[1] = new Instances(trainInstances, 0);
    int subset = -1;

    // populate the subsets
    for (int i = 0; i < trainInstances.numInstances(); i++) {
      Instance instance = trainInstances.instance(i);
      subset = m_c45S.whichSubset(instance);
      if (subset != -1) {
        trainingSets[subset].add((Instance)instance.copy());
      } else {
        double [] weights = m_c45S.weights(instance);
        for (int j = 0; j < m_complexityIndex; j++) {
          Instance temp = (Instance)instance.copy();
          if (weights.length == m_complexityIndex) {
            temp.setWeight(temp.weight() * weights[j]);
          } else {
            temp.setWeight(temp.weight() / m_complexityIndex);
          }
          trainingSets[j].add(temp);
        }
      }
    }

    /*    // compute weights (weights of instances per subset
    m_weights = new double [m_complexityIndex];
    for (int i = 0; i < m_complexityIndex; i++) {
      m_weights[i] = trainingSets[i].sumOfWeights();
    }
    Utils.normalize(m_weights); */

    Random r = new Random(1);
    int minNumCount = 0;
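    // As in the nominal case, each sufficiently large branch is discretized
    // and evaluated with 5-fold cross-validation of a naive Bayes model;
    // smaller branches are counted entirely as errors.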
    for (int i = 0; i < m_complexityIndex; i++) {
      if (trainingSets[i].numInstances() > 5) {
        minNumCount++;
        // Discretize the sets
        Discretize disc = new Discretize();
        disc.setInputFormat(trainingSets[i]);
        trainingSets[i] = Filter.useFilter(trainingSets[i], disc);

        trainingSets[i].randomize(r);
        trainingSets[i].stratify(5);
        NaiveBayesUpdateable fullModel = new NaiveBayesUpdateable();
        fullModel.buildClassifier(trainingSets[i]);

        // add the errors for this branch of the split
        m_errors += NBTreeNoSplit.crossValidate(fullModel, trainingSets[i], r);
      } else {
        for (int j = 0; j < trainingSets[i].numInstances(); j++) {
          m_errors += trainingSets[i].instance(j).weight();
        }
      }
    }

    // Check if minimum number of Instances in at least two
    // subsets.
    if (minNumCount > 1) {
      m_numSubsets = m_complexityIndex;
    }
  }

  /**
   * Returns index of subset instance is assigned to.
   * Returns -1 if instance is assigned to more than one subset.
   *
   * @exception Exception if something goes wrong
   */
  public final int whichSubset(Instance instance)
    throws Exception {

    return m_c45S.whichSubset(instance);
  }

  /**
   * Returns weights if instance is assigned to more than one subset.
   * Returns null if instance is only assigned to one subset.
   */
  public final double [] weights(Instance instance) {
    return m_c45S.weights(instance);
    //     return m_weights;
  }

  /**
   * Returns a string containing java source code equivalent to the test
   * made at this node. The instance being tested is called "i".
   *
   * @param index index of the nominal value tested
   * @param data the data containing instance structure info
   * @return a value of type 'String'
   */
  public final String sourceExpression(int index, Instances data) {
    return m_c45S.sourceExpression(index, data);
  }

  /**
   * Prints the condition satisfied by instances in a subset.
   *
   * @param index index of the subset
   * @param data training set.
   */
  public final String rightSide(int index, Instances data) {
    return m_c45S.rightSide(index, data);
  }

  /**
   * Prints left side of condition.
   *
   * @param data training set.
   */
  public final String leftSide(Instances data) {

    return m_c45S.leftSide(data);
  }

  /**
   * Return the probability for a class value
   *
   * @param classIndex the index of the class value
   * @param instance the instance to generate a probability for
   * @param theSubset the subset to consider
   * @return a probability
   * @exception Exception if an error occurs
   */
  public double classProb(int classIndex, Instance instance, int theSubset)
    throws Exception {

    // use the global naive bayes model
    if (theSubset > -1) {
      return m_globalNB.classProb(classIndex, instance, theSubset);
    } else {
      throw new Exception("This shouldn't happen!!!");
    }
  }

  /**
   * Return the global naive bayes model for this node
   *
   * @return a <code>NBTreeNoSplit</code> value
   */
  public NBTreeNoSplit getGlobalModel() {
    return m_globalNB;
  }

  /**
   * Set the global naive bayes model for this node
   *
   * @param global a <code>NBTreeNoSplit</code> value
   */
  public void setGlobalModel(NBTreeNoSplit global) {
    m_globalNB = global;
  }

  /**
   * Return the errors made by the naive bayes models arising
   * from this split.
   *
   * @return a <code>double</code> value
   */
  public double getErrors() {
    return m_errors;
  }

  /**
   * Returns the revision string.
   *
   * @return            the revision
   */
  public String getRevision() {
    return RevisionUtils.extract("$Revision: 6088 $");
  }
}