source: src/main/java/weka/classifiers/trees/j48/NBTreeNoSplit.java @ 9

Last change on this file since 9 was 4, checked in by gnappo, 14 years ago

Import di weka.

File size: 5.9 KB
Line 
1/*
2 *    This program is free software; you can redistribute it and/or modify
3 *    it under the terms of the GNU General Public License as published by
4 *    the Free Software Foundation; either version 2 of the License, or
5 *    (at your option) any later version.
6 *
7 *    This program is distributed in the hope that it will be useful,
8 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
9 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
10 *    GNU General Public License for more details.
11 *
12 *    You should have received a copy of the GNU General Public License
13 *    along with this program; if not, write to the Free Software
14 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
15 */
16
17/*
18 *    NBTreeNoSplit.java
19 *    Copyright (C) 2004 University of Waikato, Hamilton, New Zealand
20 *
21 */
22
23package weka.classifiers.trees.j48;
24
25import weka.classifiers.Classifier;
26import weka.classifiers.AbstractClassifier;
27import weka.classifiers.Evaluation;
28import weka.classifiers.bayes.NaiveBayesUpdateable;
29import weka.core.Instance;
30import weka.core.Instances;
31import weka.core.RevisionUtils;
32import weka.filters.Filter;
33import weka.filters.supervised.attribute.Discretize;
34
35import java.util.Random;
36
37/**
38 * Class implementing a "no-split"-split (leaf node) for naive bayes
39 * trees.
40 *
41 * @author Mark Hall (mhall@cs.waikato.ac.nz)
42 * @version $Revision: 5928 $
43 */
44public final class NBTreeNoSplit
45  extends ClassifierSplitModel {
46
47  /** for serialization */
48  private static final long serialVersionUID = 7824804381545259618L;
49
50  /** the naive bayes classifier */
51  private NaiveBayesUpdateable m_nb;
52
53  /** the discretizer used */
54  private Discretize m_disc;
55
56  /** errors on the training data at this node */
57  private double m_errors;
58
59  public NBTreeNoSplit() {
60    m_numSubsets = 1;
61  }
62
63  /**
64   * Build the no-split node
65   *
66   * @param instances an <code>Instances</code> value
67   * @exception Exception if an error occurs
68   */
69  public final void buildClassifier(Instances instances) throws Exception {
70    m_nb = new NaiveBayesUpdateable();
71    m_disc = new Discretize();
72    m_disc.setInputFormat(instances);
73    Instances temp = Filter.useFilter(instances, m_disc);
74    m_nb.buildClassifier(temp);
75    if (temp.numInstances() >= 5) {
76      m_errors = crossValidate(m_nb, temp, new Random(1));
77    }
78    m_numSubsets = 1;
79  }
80
81  /**
82   * Return the errors made by the naive bayes model at this node
83   *
84   * @return the number of errors made
85   */
86  public double getErrors() {
87    return m_errors;
88  }
89
90  /**
91   * Return the discretizer used at this node
92   *
93   * @return a <code>Discretize</code> value
94   */
95  public Discretize getDiscretizer() {
96    return m_disc;
97  }
98
99  /**
100   * Get the naive bayes model at this node
101   *
102   * @return a <code>NaiveBayesUpdateable</code> value
103   */
104  public NaiveBayesUpdateable getNaiveBayesModel() {
105    return m_nb;
106  }
107
108  /**
109   * Always returns 0 because only there is only one subset.
110   */
111  public final int whichSubset(Instance instance){
112   
113    return 0;
114  }
115
116  /**
117   * Always returns null because there is only one subset.
118   */
119  public final double [] weights(Instance instance){
120
121    return null;
122  }
123 
124  /**
125   * Does nothing because no condition has to be satisfied.
126   */
127  public final String leftSide(Instances instances){
128
129    return "";
130  }
131 
132  /**
133   * Does nothing because no condition has to be satisfied.
134   */
135  public final String rightSide(int index, Instances instances){
136
137    return "";
138  }
139
140  /**
141   * Returns a string containing java source code equivalent to the test
142   * made at this node. The instance being tested is called "i".
143   *
144   * @param index index of the nominal value tested
145   * @param data the data containing instance structure info
146   * @return a value of type 'String'
147   */
148  public final String sourceExpression(int index, Instances data) {
149
150    return "true";  // or should this be false??
151  }
152
153  /**
154   * Return the probability for a class value
155   *
156   * @param classIndex the index of the class value
157   * @param instance the instance to generate a probability for
158   * @param theSubset the subset to consider
159   * @return a probability
160   * @exception Exception if an error occurs
161   */
162  public double classProb(int classIndex, Instance instance, int theSubset) 
163    throws Exception {
164    m_disc.input(instance);
165    Instance temp = m_disc.output();
166    return m_nb.distributionForInstance(temp)[classIndex];
167  }
168
169  /**
170   * Return a textual description of the node
171   *
172   * @return a <code>String</code> value
173   */
174  public String toString() {
175    return m_nb.toString();
176  }
177
178  /**
179   * Utility method for fast 5-fold cross validation of a naive bayes
180   * model
181   *
182   * @param fullModel a <code>NaiveBayesUpdateable</code> value
183   * @param trainingSet an <code>Instances</code> value
184   * @param r a <code>Random</code> value
185   * @return a <code>double</code> value
186   * @exception Exception if an error occurs
187   */
188  public static double crossValidate(NaiveBayesUpdateable fullModel,
189                               Instances trainingSet,
190                               Random r) throws Exception {
191    // make some copies for fast evaluation of 5-fold xval
192    Classifier [] copies = AbstractClassifier.makeCopies(fullModel, 5);
193    Evaluation eval = new Evaluation(trainingSet);
194    // make some splits
195    for (int j = 0; j < 5; j++) {
196      Instances test = trainingSet.testCV(5, j);
197      // unlearn these test instances
198      for (int k = 0; k < test.numInstances(); k++) {
199        test.instance(k).setWeight(-test.instance(k).weight());
200        ((NaiveBayesUpdateable)copies[j]).updateClassifier(test.instance(k));
201        // reset the weight back to its original value
202        test.instance(k).setWeight(-test.instance(k).weight());
203      }
204      eval.evaluateModel(copies[j], test);
205    }
206    return eval.incorrect();
207  }
208 
209  /**
210   * Returns the revision string.
211   *
212   * @return            the revision
213   */
214  public String getRevision() {
215    return RevisionUtils.extract("$Revision: 5928 $");
216  }
217}
Note: See TracBrowser for help on using the repository browser.