source: src/main/java/weka/classifiers/bayes/NaiveBayesSimple.java @ 24

Last change on this file since 24 was 4, checked in by gnappo, 14 years ago

Import di weka.

File size: 12.7 KB
Line 
1/*
2 *    This program is free software; you can redistribute it and/or modify
3 *    it under the terms of the GNU General Public License as published by
4 *    the Free Software Foundation; either version 2 of the License, or
5 *    (at your option) any later version.
6 *
7 *    This program is distributed in the hope that it will be useful,
8 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
9 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
10 *    GNU General Public License for more details.
11 *
12 *    You should have received a copy of the GNU General Public License
13 *    along with this program; if not, write to the Free Software
14 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
15 */
16
17/*
18 *    NaiveBayesSimple.java
19 *    Copyright (C) 1999 University of Waikato, Hamilton, New Zealand
20 *
21 */
22
23package weka.classifiers.bayes;
24
25import weka.classifiers.Classifier;
26import weka.classifiers.AbstractClassifier;
27import weka.core.Attribute;
28import weka.core.Capabilities;
29import weka.core.Instance;
30import weka.core.Instances;
31import weka.core.RevisionUtils;
32import weka.core.TechnicalInformation;
33import weka.core.TechnicalInformationHandler;
34import weka.core.Utils;
35import weka.core.Capabilities.Capability;
36import weka.core.TechnicalInformation.Field;
37import weka.core.TechnicalInformation.Type;
38
39import java.util.Enumeration;
40
41/**
42 <!-- globalinfo-start -->
43 * Class for building and using a simple Naive Bayes classifier.Numeric attributes are modelled by a normal distribution.<br/>
44 * <br/>
45 * For more information, see<br/>
46 * <br/>
47 * Richard Duda, Peter Hart (1973). Pattern Classification and Scene Analysis. Wiley, New York.
48 * <p/>
49 <!-- globalinfo-end -->
50 *
51 <!-- technical-bibtex-start -->
52 * BibTeX:
53 * <pre>
54 * &#64;book{Duda1973,
55 *    address = {New York},
56 *    author = {Richard Duda and Peter Hart},
57 *    publisher = {Wiley},
58 *    title = {Pattern Classification and Scene Analysis},
59 *    year = {1973}
60 * }
61 * </pre>
62 * <p/>
63 <!-- technical-bibtex-end -->
64 *
65 <!-- options-start -->
66 * Valid options are: <p/>
67 *
68 * <pre> -D
69 *  If set, classifier is run in debug mode and
70 *  may output additional info to the console</pre>
71 *
72 <!-- options-end -->
73 *
74 * @author Eibe Frank (eibe@cs.waikato.ac.nz)
75 * @version $Revision: 5928 $
76*/
77public class NaiveBayesSimple 
78  extends AbstractClassifier
79  implements TechnicalInformationHandler {
80 
81  /** for serialization */
82  static final long serialVersionUID = -1478242251770381214L;
83
84  /** All the counts for nominal attributes. */
85  protected double [][][] m_Counts;
86 
87  /** The means for numeric attributes. */
88  protected double [][] m_Means;
89
90  /** The standard deviations for numeric attributes. */
91  protected double [][] m_Devs;
92
93  /** The prior probabilities of the classes. */
94  protected double [] m_Priors;
95
96  /** The instances used for training. */
97  protected Instances m_Instances;
98
99  /** Constant for normal distribution. */
100  protected static double NORM_CONST = Math.sqrt(2 * Math.PI);
101
102  /**
103   * Returns a string describing this classifier
104   * @return a description of the classifier suitable for
105   * displaying in the explorer/experimenter gui
106   */
107  public String globalInfo() {
108    return 
109        "Class for building and using a simple Naive Bayes classifier."
110      + "Numeric attributes are modelled by a normal distribution.\n\n"
111      + "For more information, see\n\n"
112      + getTechnicalInformation().toString();
113  }
114
115  /**
116   * Returns an instance of a TechnicalInformation object, containing
117   * detailed information about the technical background of this class,
118   * e.g., paper reference or book this class is based on.
119   *
120   * @return the technical information about this class
121   */
122  public TechnicalInformation getTechnicalInformation() {
123    TechnicalInformation        result;
124   
125    result = new TechnicalInformation(Type.BOOK);
126    result.setValue(Field.AUTHOR, "Richard Duda and Peter Hart");
127    result.setValue(Field.YEAR, "1973");
128    result.setValue(Field.TITLE, "Pattern Classification and Scene Analysis");
129    result.setValue(Field.PUBLISHER, "Wiley");
130    result.setValue(Field.ADDRESS, "New York");
131   
132    return result;
133  }
134
135  /**
136   * Returns default capabilities of the classifier.
137   *
138   * @return      the capabilities of this classifier
139   */
140  public Capabilities getCapabilities() {
141    Capabilities result = super.getCapabilities();
142    result.disableAll();
143
144    // attributes
145    result.enable(Capability.NOMINAL_ATTRIBUTES);
146    result.enable(Capability.NUMERIC_ATTRIBUTES);
147    result.enable(Capability.DATE_ATTRIBUTES);
148    result.enable(Capability.MISSING_VALUES);
149
150    // class
151    result.enable(Capability.NOMINAL_CLASS);
152    result.enable(Capability.MISSING_CLASS_VALUES);
153   
154    return result;
155  }
156
157  /**
158   * Generates the classifier.
159   *
160   * @param instances set of instances serving as training data
161   * @exception Exception if the classifier has not been generated successfully
162   */
163  public void buildClassifier(Instances instances) throws Exception {
164
165    int attIndex = 0;
166    double sum;
167   
168    // can classifier handle the data?
169    getCapabilities().testWithFail(instances);
170
171    // remove instances with missing class
172    instances = new Instances(instances);
173    instances.deleteWithMissingClass();
174   
175    m_Instances = new Instances(instances, 0);
176   
177    // Reserve space
178    m_Counts = new double[instances.numClasses()]
179      [instances.numAttributes() - 1][0];
180    m_Means = new double[instances.numClasses()]
181      [instances.numAttributes() - 1];
182    m_Devs = new double[instances.numClasses()]
183      [instances.numAttributes() - 1];
184    m_Priors = new double[instances.numClasses()];
185    Enumeration enu = instances.enumerateAttributes();
186    while (enu.hasMoreElements()) {
187      Attribute attribute = (Attribute) enu.nextElement();
188      if (attribute.isNominal()) {
189        for (int j = 0; j < instances.numClasses(); j++) {
190          m_Counts[j][attIndex] = new double[attribute.numValues()];
191        }
192      } else {
193        for (int j = 0; j < instances.numClasses(); j++) {
194          m_Counts[j][attIndex] = new double[1];
195        }
196      }
197      attIndex++;
198    }
199   
200    // Compute counts and sums
201    Enumeration enumInsts = instances.enumerateInstances();
202    while (enumInsts.hasMoreElements()) {
203      Instance instance = (Instance) enumInsts.nextElement();
204      if (!instance.classIsMissing()) {
205        Enumeration enumAtts = instances.enumerateAttributes();
206        attIndex = 0;
207        while (enumAtts.hasMoreElements()) {
208          Attribute attribute = (Attribute) enumAtts.nextElement();
209          if (!instance.isMissing(attribute)) {
210            if (attribute.isNominal()) {
211              m_Counts[(int)instance.classValue()][attIndex]
212                [(int)instance.value(attribute)]++;
213            } else {
214              m_Means[(int)instance.classValue()][attIndex] +=
215                instance.value(attribute);
216              m_Counts[(int)instance.classValue()][attIndex][0]++;
217            }
218          }
219          attIndex++;
220        }
221        m_Priors[(int)instance.classValue()]++;
222      }
223    }
224   
225    // Compute means
226    Enumeration enumAtts = instances.enumerateAttributes();
227    attIndex = 0;
228    while (enumAtts.hasMoreElements()) {
229      Attribute attribute = (Attribute) enumAtts.nextElement();
230      if (attribute.isNumeric()) {
231        for (int j = 0; j < instances.numClasses(); j++) {
232          if (m_Counts[j][attIndex][0] < 2) {
233            throw new Exception("attribute " + attribute.name() +
234                                ": less than two values for class " +
235                                instances.classAttribute().value(j));
236          }
237          m_Means[j][attIndex] /= m_Counts[j][attIndex][0];
238        }
239      }
240      attIndex++;
241    }   
242   
243    // Compute standard deviations
244    enumInsts = instances.enumerateInstances();
245    while (enumInsts.hasMoreElements()) {
246      Instance instance = 
247        (Instance) enumInsts.nextElement();
248      if (!instance.classIsMissing()) {
249        enumAtts = instances.enumerateAttributes();
250        attIndex = 0;
251        while (enumAtts.hasMoreElements()) {
252          Attribute attribute = (Attribute) enumAtts.nextElement();
253          if (!instance.isMissing(attribute)) {
254            if (attribute.isNumeric()) {
255              m_Devs[(int)instance.classValue()][attIndex] +=
256                (m_Means[(int)instance.classValue()][attIndex]-
257                 instance.value(attribute))*
258                (m_Means[(int)instance.classValue()][attIndex]-
259                 instance.value(attribute));
260            }
261          }
262          attIndex++;
263        }
264      }
265    }
266    enumAtts = instances.enumerateAttributes();
267    attIndex = 0;
268    while (enumAtts.hasMoreElements()) {
269      Attribute attribute = (Attribute) enumAtts.nextElement();
270      if (attribute.isNumeric()) {
271        for (int j = 0; j < instances.numClasses(); j++) {
272          if (m_Devs[j][attIndex] <= 0) {
273            throw new Exception("attribute " + attribute.name() +
274                                ": standard deviation is 0 for class " +
275                                instances.classAttribute().value(j));
276          }
277          else {
278            m_Devs[j][attIndex] /= m_Counts[j][attIndex][0] - 1;
279            m_Devs[j][attIndex] = Math.sqrt(m_Devs[j][attIndex]);
280          }
281        }
282      }
283      attIndex++;
284    } 
285   
286    // Normalize counts
287    enumAtts = instances.enumerateAttributes();
288    attIndex = 0;
289    while (enumAtts.hasMoreElements()) {
290      Attribute attribute = (Attribute) enumAtts.nextElement();
291      if (attribute.isNominal()) {
292        for (int j = 0; j < instances.numClasses(); j++) {
293          sum = Utils.sum(m_Counts[j][attIndex]);
294          for (int i = 0; i < attribute.numValues(); i++) {
295            m_Counts[j][attIndex][i] =
296              (m_Counts[j][attIndex][i] + 1) 
297              / (sum + (double)attribute.numValues());
298          }
299        }
300      }
301      attIndex++;
302    }
303   
304    // Normalize priors
305    sum = Utils.sum(m_Priors);
306    for (int j = 0; j < instances.numClasses(); j++)
307      m_Priors[j] = (m_Priors[j] + 1) 
308        / (sum + (double)instances.numClasses());
309  }
310
311  /**
312   * Calculates the class membership probabilities for the given test instance.
313   *
314   * @param instance the instance to be classified
315   * @return predicted class probability distribution
316   * @exception Exception if distribution can't be computed
317   */
318  public double[] distributionForInstance(Instance instance) throws Exception {
319   
320    double [] probs = new double[instance.numClasses()];
321    int attIndex;
322   
323    for (int j = 0; j < instance.numClasses(); j++) {
324      probs[j] = 1;
325      Enumeration enumAtts = instance.enumerateAttributes();
326      attIndex = 0;
327      while (enumAtts.hasMoreElements()) {
328        Attribute attribute = (Attribute) enumAtts.nextElement();
329        if (!instance.isMissing(attribute)) {
330          if (attribute.isNominal()) {
331            probs[j] *= m_Counts[j][attIndex][(int)instance.value(attribute)];
332          } else {
333            probs[j] *= normalDens(instance.value(attribute),
334                                   m_Means[j][attIndex],
335                                   m_Devs[j][attIndex]);}
336        }
337        attIndex++;
338      }
339      probs[j] *= m_Priors[j];
340    }
341
342    // Normalize probabilities
343    Utils.normalize(probs);
344
345    return probs;
346  }
347
348  /**
349   * Returns a description of the classifier.
350   *
351   * @return a description of the classifier as a string.
352   */
353  public String toString() {
354
355    if (m_Instances == null) {
356      return "Naive Bayes (simple): No model built yet.";
357    }
358    try {
359      StringBuffer text = new StringBuffer("Naive Bayes (simple)");
360      int attIndex;
361     
362      for (int i = 0; i < m_Instances.numClasses(); i++) {
363        text.append("\n\nClass " + m_Instances.classAttribute().value(i) 
364                    + ": P(C) = " 
365                    + Utils.doubleToString(m_Priors[i], 10, 8)
366                    + "\n\n");
367        Enumeration enumAtts = m_Instances.enumerateAttributes();
368        attIndex = 0;
369        while (enumAtts.hasMoreElements()) {
370          Attribute attribute = (Attribute) enumAtts.nextElement();
371          text.append("Attribute " + attribute.name() + "\n");
372          if (attribute.isNominal()) {
373            for (int j = 0; j < attribute.numValues(); j++) {
374              text.append(attribute.value(j) + "\t");
375            }
376            text.append("\n");
377            for (int j = 0; j < attribute.numValues(); j++)
378              text.append(Utils.
379                          doubleToString(m_Counts[i][attIndex][j], 10, 8)
380                          + "\t");
381          } else {
382            text.append("Mean: " + Utils.
383                        doubleToString(m_Means[i][attIndex], 10, 8) + "\t");
384            text.append("Standard Deviation: " 
385                        + Utils.doubleToString(m_Devs[i][attIndex], 10, 8));
386          }
387          text.append("\n\n");
388          attIndex++;
389        }
390      }
391     
392      return text.toString();
393    } catch (Exception e) {
394      return "Can't print Naive Bayes classifier!";
395    }
396  }
397
398  /**
399   * Density function of normal distribution.
400   *
401   * @param x the value to get the density for
402   * @param mean the mean
403   * @param stdDev the standard deviation
404   * @return the density
405   */
406  protected double normalDens(double x, double mean, double stdDev) {
407   
408    double diff = x - mean;
409   
410    return (1 / (NORM_CONST * stdDev)) 
411      * Math.exp(-(diff * diff / (2 * stdDev * stdDev)));
412  }
413 
414  /**
415   * Returns the revision string.
416   *
417   * @return            the revision
418   */
419  public String getRevision() {
420    return RevisionUtils.extract("$Revision: 5928 $");
421  }
422
423  /**
424   * Main method for testing this class.
425   *
426   * @param argv the options
427   */
428  public static void main(String [] argv) {
429    runClassifier(new NaiveBayesSimple(), argv);
430  }
431}
Note: See TracBrowser for help on using the repository browser.