source: src/main/java/weka/classifiers/bayes/NaiveBayes.java @ 19

Last change on this file since 19 was 4, checked in by gnappo, 14 years ago

Import di weka.

File size: 31.5 KB
Line 
1/*
2 *    This program is free software; you can redistribute it and/or modify
3 *    it under the terms of the GNU General Public License as published by
4 *    the Free Software Foundation; either version 2 of the License, or
5 *    (at your option) any later version.
6 *
7 *    This program is distributed in the hope that it will be useful,
8 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
9 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
10 *    GNU General Public License for more details.
11 *
12 *    You should have received a copy of the GNU General Public License
13 *    along with this program; if not, write to the Free Software
14 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
15 */
16
17/*
18 *    NaiveBayes.java
19 *    Copyright (C) 1999 University of Waikato, Hamilton, New Zealand
20 *
21 */
22
23package weka.classifiers.bayes;
24
25import weka.classifiers.Classifier;
26import weka.classifiers.AbstractClassifier;
27import weka.core.Attribute;
28import weka.core.Capabilities;
29import weka.core.Instance;
30import weka.core.Instances;
31import weka.core.Option;
32import weka.core.OptionHandler;
33import weka.core.RevisionUtils;
34import weka.core.TechnicalInformation;
35import weka.core.TechnicalInformationHandler;
36import weka.core.Utils;
37import weka.core.WeightedInstancesHandler;
38import weka.core.Capabilities.Capability;
39import weka.core.TechnicalInformation.Field;
40import weka.core.TechnicalInformation.Type;
41import weka.estimators.DiscreteEstimator;
42import weka.estimators.Estimator;
43import weka.estimators.KernelEstimator;
44import weka.estimators.NormalEstimator;
45
46import java.util.Enumeration;
47import java.util.Vector;
48
49/**
50 <!-- globalinfo-start -->
51 * Class for a Naive Bayes classifier using estimator classes. Numeric estimator precision values are chosen based on analysis of the  training data. For this reason, the classifier is not an UpdateableClassifier (which in typical usage are initialized with zero training instances) -- if you need the UpdateableClassifier functionality, use the NaiveBayesUpdateable classifier. The NaiveBayesUpdateable classifier will  use a default precision of 0.1 for numeric attributes when buildClassifier is called with zero training instances.<br/>
52 * <br/>
53 * For more information on Naive Bayes classifiers, see<br/>
54 * <br/>
55 * George H. John, Pat Langley: Estimating Continuous Distributions in Bayesian Classifiers. In: Eleventh Conference on Uncertainty in Artificial Intelligence, San Mateo, 338-345, 1995.
56 * <p/>
57 <!-- globalinfo-end -->
58 *
59 <!-- technical-bibtex-start -->
60 * BibTeX:
61 * <pre>
62 * &#64;inproceedings{John1995,
63 *    address = {San Mateo},
64 *    author = {George H. John and Pat Langley},
65 *    booktitle = {Eleventh Conference on Uncertainty in Artificial Intelligence},
66 *    pages = {338-345},
67 *    publisher = {Morgan Kaufmann},
68 *    title = {Estimating Continuous Distributions in Bayesian Classifiers},
69 *    year = {1995}
70 * }
71 * </pre>
72 * <p/>
73 <!-- technical-bibtex-end -->
74 *
75 <!-- options-start -->
76 * Valid options are: <p/>
77 *
78 * <pre> -K
79 *  Use kernel density estimator rather than normal
80 *  distribution for numeric attributes</pre>
81 *
82 * <pre> -D
83 *  Use supervised discretization to process numeric attributes
84 * </pre>
85 *
86 * <pre> -O
87 *  Display model in old format (good when there are many classes)
88 * </pre>
89 *
90 <!-- options-end -->
91 *
92 * @author Len Trigg (trigg@cs.waikato.ac.nz)
93 * @author Eibe Frank (eibe@cs.waikato.ac.nz)
94 * @version $Revision: 5928 $
95 */
96public class NaiveBayes extends AbstractClassifier
97implements OptionHandler, WeightedInstancesHandler, 
98           TechnicalInformationHandler {
99
100  /** for serialization */
101  static final long serialVersionUID = 5995231201785697655L;
102
103  /** The attribute estimators. */
104  protected Estimator [][] m_Distributions;
105
106  /** The class estimator. */
107  protected Estimator m_ClassDistribution;
108
109  /**
110   * Whether to use kernel density estimator rather than normal distribution
111   * for numeric attributes
112   */
113  protected boolean m_UseKernelEstimator = false;
114
115  /**
116   * Whether to use discretization than normal distribution
117   * for numeric attributes
118   */
119  protected boolean m_UseDiscretization = false;
120
121  /** The number of classes (or 1 for numeric class) */
122  protected int m_NumClasses;
123
124  /**
125   * The dataset header for the purposes of printing out a semi-intelligible
126   * model
127   */
128  protected Instances m_Instances;
129
130  /*** The precision parameter used for numeric attributes */
131  protected static final double DEFAULT_NUM_PRECISION = 0.01;
132
133  /**
134   * The discretization filter.
135   */
136  protected weka.filters.supervised.attribute.Discretize m_Disc = null;
137
138  protected boolean m_displayModelInOldFormat = false;
139
140  /**
141   * Returns a string describing this classifier
142   * @return a description of the classifier suitable for
143   * displaying in the explorer/experimenter gui
144   */
145  public String globalInfo() {
146    return "Class for a Naive Bayes classifier using estimator classes. Numeric"
147      +" estimator precision values are chosen based on analysis of the "
148      +" training data. For this reason, the classifier is not an"
149      +" UpdateableClassifier (which in typical usage are initialized with zero"
150      +" training instances) -- if you need the UpdateableClassifier functionality,"
151      +" use the NaiveBayesUpdateable classifier. The NaiveBayesUpdateable"
152      +" classifier will  use a default precision of 0.1 for numeric attributes"
153      +" when buildClassifier is called with zero training instances.\n\n"
154      +"For more information on Naive Bayes classifiers, see\n\n"
155      + getTechnicalInformation().toString();
156  }
157
158  /**
159   * Returns an instance of a TechnicalInformation object, containing
160   * detailed information about the technical background of this class,
161   * e.g., paper reference or book this class is based on.
162   *
163   * @return the technical information about this class
164   */
165  public TechnicalInformation getTechnicalInformation() {
166    TechnicalInformation        result;
167
168    result = new TechnicalInformation(Type.INPROCEEDINGS);
169    result.setValue(Field.AUTHOR, "George H. John and Pat Langley");
170    result.setValue(Field.TITLE, "Estimating Continuous Distributions in Bayesian Classifiers");
171    result.setValue(Field.BOOKTITLE, "Eleventh Conference on Uncertainty in Artificial Intelligence");
172    result.setValue(Field.YEAR, "1995");
173    result.setValue(Field.PAGES, "338-345");
174    result.setValue(Field.PUBLISHER, "Morgan Kaufmann");
175    result.setValue(Field.ADDRESS, "San Mateo");
176
177    return result;
178  }
179
180  /**
181   * Returns default capabilities of the classifier.
182   *
183   * @return      the capabilities of this classifier
184   */
185  public Capabilities getCapabilities() {
186    Capabilities result = super.getCapabilities();
187    result.disableAll();
188
189    // attributes
190    result.enable(Capability.NOMINAL_ATTRIBUTES);
191    result.enable(Capability.NUMERIC_ATTRIBUTES);
192    result.enable(Capability.MISSING_VALUES);
193
194    // class
195    result.enable(Capability.NOMINAL_CLASS);
196    result.enable(Capability.MISSING_CLASS_VALUES);
197
198    // instances
199    result.setMinimumNumberInstances(0);
200
201    return result;
202  }
203
204  /**
205   * Generates the classifier.
206   *
207   * @param instances set of instances serving as training data
208   * @exception Exception if the classifier has not been generated
209   * successfully
210   */
211  public void buildClassifier(Instances instances) throws Exception {
212
213    // can classifier handle the data?
214    getCapabilities().testWithFail(instances);
215
216    // remove instances with missing class
217    instances = new Instances(instances);
218    instances.deleteWithMissingClass();
219
220    m_NumClasses = instances.numClasses();
221
222    // Copy the instances
223    m_Instances = new Instances(instances);
224
225    // Discretize instances if required
226    if (m_UseDiscretization) {
227      m_Disc = new weka.filters.supervised.attribute.Discretize();
228      m_Disc.setInputFormat(m_Instances);
229      m_Instances = weka.filters.Filter.useFilter(m_Instances, m_Disc);
230    } else {
231      m_Disc = null;
232    }
233
234    // Reserve space for the distributions
235    m_Distributions = new Estimator[m_Instances.numAttributes() - 1]
236      [m_Instances.numClasses()];
237    m_ClassDistribution = new DiscreteEstimator(m_Instances.numClasses(), 
238                                                true);
239    int attIndex = 0;
240    Enumeration enu = m_Instances.enumerateAttributes();
241    while (enu.hasMoreElements()) {
242      Attribute attribute = (Attribute) enu.nextElement();
243
244      // If the attribute is numeric, determine the estimator
245      // numeric precision from differences between adjacent values
246      double numPrecision = DEFAULT_NUM_PRECISION;
247      if (attribute.type() == Attribute.NUMERIC) {
248        m_Instances.sort(attribute);
249        if ((m_Instances.numInstances() > 0)
250            && !m_Instances.instance(0).isMissing(attribute)) {
251          double lastVal = m_Instances.instance(0).value(attribute);
252          double currentVal, deltaSum = 0;
253          int distinct = 0;
254          for (int i = 1; i < m_Instances.numInstances(); i++) {
255            Instance currentInst = m_Instances.instance(i);
256            if (currentInst.isMissing(attribute)) {
257              break;
258            }
259            currentVal = currentInst.value(attribute);
260            if (currentVal != lastVal) {
261              deltaSum += currentVal - lastVal;
262              lastVal = currentVal;
263              distinct++;
264            }
265          }
266          if (distinct > 0) {
267            numPrecision = deltaSum / distinct;
268          }
269        }
270      }
271
272
273      for (int j = 0; j < m_Instances.numClasses(); j++) {
274        switch (attribute.type()) {
275        case Attribute.NUMERIC: 
276          if (m_UseKernelEstimator) {
277            m_Distributions[attIndex][j] = 
278              new KernelEstimator(numPrecision);
279          } else {
280            m_Distributions[attIndex][j] = 
281              new NormalEstimator(numPrecision);
282          }
283          break;
284        case Attribute.NOMINAL:
285          m_Distributions[attIndex][j] = 
286            new DiscreteEstimator(attribute.numValues(), true);
287          break;
288        default:
289          throw new Exception("Attribute type unknown to NaiveBayes");
290        }
291      }
292      attIndex++;
293    }
294
295    // Compute counts
296    Enumeration enumInsts = m_Instances.enumerateInstances();
297    while (enumInsts.hasMoreElements()) {
298      Instance instance = 
299        (Instance) enumInsts.nextElement();
300      updateClassifier(instance);
301    }
302
303    // Save space
304    m_Instances = new Instances(m_Instances, 0);
305  }
306
307
308  /**
309   * Updates the classifier with the given instance.
310   *
311   * @param instance the new training instance to include in the model
312   * @exception Exception if the instance could not be incorporated in
313   * the model.
314   */
315  public void updateClassifier(Instance instance) throws Exception {
316
317    if (!instance.classIsMissing()) {
318      Enumeration enumAtts = m_Instances.enumerateAttributes();
319      int attIndex = 0;
320      while (enumAtts.hasMoreElements()) {
321        Attribute attribute = (Attribute) enumAtts.nextElement();
322        if (!instance.isMissing(attribute)) {
323          m_Distributions[attIndex][(int)instance.classValue()].
324            addValue(instance.value(attribute), instance.weight());
325        }
326        attIndex++;
327      }
328      m_ClassDistribution.addValue(instance.classValue(),
329                                   instance.weight());
330    }
331  }
332
333
334  /**
335   * Calculates the class membership probabilities for the given test
336   * instance.
337   *
338   * @param instance the instance to be classified
339   * @return predicted class probability distribution
340   * @exception Exception if there is a problem generating the prediction
341   */
342  public double [] distributionForInstance(Instance instance) 
343    throws Exception { 
344
345    if (m_UseDiscretization) {
346      m_Disc.input(instance);
347      instance = m_Disc.output();
348    }
349    double [] probs = new double[m_NumClasses];
350    for (int j = 0; j < m_NumClasses; j++) {
351      probs[j] = m_ClassDistribution.getProbability(j);
352    }
353    Enumeration enumAtts = instance.enumerateAttributes();
354    int attIndex = 0;
355    while (enumAtts.hasMoreElements()) {
356      Attribute attribute = (Attribute) enumAtts.nextElement();
357      if (!instance.isMissing(attribute)) {
358        double temp, max = 0;
359        for (int j = 0; j < m_NumClasses; j++) {
360          temp = Math.max(1e-75, Math.pow(m_Distributions[attIndex][j].
361                                          getProbability(instance.value(attribute)), 
362                                          m_Instances.attribute(attIndex).weight()));
363          probs[j] *= temp;
364          if (probs[j] > max) {
365            max = probs[j];
366          }
367          if (Double.isNaN(probs[j])) {
368            throw new Exception("NaN returned from estimator for attribute "
369                                + attribute.name() + ":\n"
370                                + m_Distributions[attIndex][j].toString());
371          }
372        }
373        if ((max > 0) && (max < 1e-75)) { // Danger of probability underflow
374          for (int j = 0; j < m_NumClasses; j++) {
375            probs[j] *= 1e75;
376          }
377        }
378      }
379      attIndex++;
380    }
381
382    // Display probabilities
383    Utils.normalize(probs);
384    return probs;
385  }
386
387  /**
388   * Returns an enumeration describing the available options.
389   *
390   * @return an enumeration of all the available options.
391   */
392  public Enumeration listOptions() {
393
394    Vector newVector = new Vector(3);
395
396    newVector.addElement(
397              new Option("\tUse kernel density estimator rather than normal\n"
398                         +"\tdistribution for numeric attributes",
399                         "K", 0,"-K"));
400    newVector.addElement(
401              new Option("\tUse supervised discretization to process numeric attributes\n",
402                         "D", 0,"-D"));
403   
404    newVector.addElement(
405              new Option("\tDisplay model in old format (good when there are "
406                         + "many classes)\n",
407                         "O", 0, "-O"));
408   
409    return newVector.elements();
410  }
411
412  /**
413   * Parses a given list of options. <p/>
414   *
415   <!-- options-start -->
416   * Valid options are: <p/>
417   *
418   * <pre> -K
419   *  Use kernel density estimator rather than normal
420   *  distribution for numeric attributes</pre>
421   *
422   * <pre> -D
423   *  Use supervised discretization to process numeric attributes
424   * </pre>
425   *
426   * <pre> -O
427   *  Display model in old format (good when there are many classes)
428   * </pre>
429   *
430   <!-- options-end -->
431   *
432   * @param options the list of options as an array of strings
433   * @exception Exception if an option is not supported
434   */
435  public void setOptions(String[] options) throws Exception {
436
437    boolean k = Utils.getFlag('K', options);
438    boolean d = Utils.getFlag('D', options);
439    if (k && d) {
440      throw new IllegalArgumentException("Can't use both kernel density " +
441                                         "estimation and discretization!");
442    }
443    setUseSupervisedDiscretization(d);
444    setUseKernelEstimator(k);
445    setDisplayModelInOldFormat(Utils.getFlag('O', options));
446    Utils.checkForRemainingOptions(options);
447  }
448
449  /**
450   * Gets the current settings of the classifier.
451   *
452   * @return an array of strings suitable for passing to setOptions
453   */
454  public String [] getOptions() {
455
456    String [] options = new String [3];
457    int current = 0;
458
459    if (m_UseKernelEstimator) {
460      options[current++] = "-K";
461    }
462
463    if (m_UseDiscretization) {
464      options[current++] = "-D";
465    }
466
467    if (m_displayModelInOldFormat) {
468      options[current++] = "-O";
469    }
470
471    while (current < options.length) {
472      options[current++] = "";
473    }
474    return options;
475  }
476
477  /**
478   * Returns a description of the classifier.
479   *
480   * @return a description of the classifier as a string.
481   */
482  public String toString() {
483    if (m_displayModelInOldFormat) {
484      return toStringOriginal();
485    }
486
487    StringBuffer temp = new StringBuffer();
488    temp.append("Naive Bayes Classifier");
489    if (m_Instances == null) {
490      temp.append(": No model built yet.");
491    } else {
492
493      int maxWidth = 0;
494      int maxAttWidth = 0;
495      boolean containsKernel = false;
496
497      // set up max widths
498      // class values
499      for (int i = 0; i < m_Instances.numClasses(); i++) {
500        if (m_Instances.classAttribute().value(i).length() > maxWidth) {
501          maxWidth = m_Instances.classAttribute().value(i).length();
502        }
503      }
504      // attributes
505      for (int i = 0; i < m_Instances.numAttributes(); i++) {
506        if (i != m_Instances.classIndex()) {
507          Attribute a = m_Instances.attribute(i);
508          if (a.name().length() > maxAttWidth) {
509            maxAttWidth = m_Instances.attribute(i).name().length();
510          }
511          if (a.isNominal()) {
512            // check values
513            for (int j = 0; j < a.numValues(); j++) {
514              String val = a.value(j) + "  ";
515              if (val.length() > maxAttWidth) {
516                maxAttWidth = val.length();
517              }
518            }
519          }
520        }
521      }
522
523      for (int i = 0; i < m_Distributions.length; i++) {
524        for (int j = 0; j < m_Instances.numClasses(); j++) {
525          if (m_Distributions[i][0] instanceof NormalEstimator) {
526            // check mean/precision dev against maxWidth
527            NormalEstimator n = (NormalEstimator)m_Distributions[i][j];
528            double mean = Math.log(Math.abs(n.getMean())) / Math.log(10.0);
529            double precision = Math.log(Math.abs(n.getPrecision())) / Math.log(10.0);
530            double width = (mean > precision)
531              ? mean
532              : precision;
533            if (width < 0) {
534              width = 1;
535            }
536            // decimal + # decimal places + 1
537            width += 6.0;
538            if ((int)width > maxWidth) {
539              maxWidth = (int)width;
540            }
541          } else if (m_Distributions[i][0] instanceof KernelEstimator) {
542            containsKernel = true;
543            KernelEstimator ke = (KernelEstimator)m_Distributions[i][j];
544            int numK = ke.getNumKernels();
545            String temps = "K" + numK + ": mean (weight)";
546            if (maxAttWidth < temps.length()) {
547              maxAttWidth = temps.length();
548            }
549            // check means + weights against maxWidth
550            if (ke.getNumKernels() > 0) {
551              double[] means = ke.getMeans();
552              double[] weights = ke.getWeights();
553              for (int k = 0; k < ke.getNumKernels(); k++) {
554                String m = Utils.doubleToString(means[k], maxWidth, 4).trim();
555                m += " (" + Utils.doubleToString(weights[k], maxWidth, 1).trim() + ")";
556                if (maxWidth < m.length()) {
557                  maxWidth = m.length();
558                }
559              }
560            }
561          } else if (m_Distributions[i][0] instanceof DiscreteEstimator) {
562            DiscreteEstimator d = (DiscreteEstimator)m_Distributions[i][j];
563            for (int k = 0; k < d.getNumSymbols(); k++) {
564              String size = "" + d.getCount(k);
565              if (size.length() > maxWidth) {
566                maxWidth = size.length();
567              }
568            }
569            int sum = ("" + d.getSumOfCounts()).length();
570            if (sum > maxWidth) {
571              maxWidth = sum;
572            }
573          }
574        }
575      }
576
577      // Check width of class labels
578      for (int i = 0; i < m_Instances.numClasses(); i++) {
579        String cSize = m_Instances.classAttribute().value(i);
580        if (cSize.length() > maxWidth) {
581          maxWidth = cSize.length();
582        }
583      }
584
585      // Check width of class priors
586      for (int i = 0; i < m_Instances.numClasses(); i++) {
587        String priorP = 
588          Utils.doubleToString(((DiscreteEstimator)m_ClassDistribution).getProbability(i),
589                               maxWidth, 2).trim();
590        priorP = "(" + priorP + ")";
591        if (priorP.length() > maxWidth) {
592          maxWidth = priorP.length();
593        }
594      }
595   
596      if (maxAttWidth < "Attribute".length()) {
597        maxAttWidth = "Attribute".length();
598      }
599
600      if (maxAttWidth < "  weight sum".length()) {
601        maxAttWidth = "  weight sum".length();
602      }
603
604      if (containsKernel) {
605        if (maxAttWidth < "  [precision]".length()) {
606          maxAttWidth = "  [precision]".length();
607        }
608      }
609
610      maxAttWidth += 2;
611   
612
613
614      temp.append("\n\n");
615      temp.append(pad("Class", " ", 
616                      (maxAttWidth + maxWidth + 1) - "Class".length(), 
617                      true));
618
619      temp.append("\n");
620      temp.append(pad("Attribute", " ", maxAttWidth - "Attribute".length(), false));
621      // class labels
622      for (int i = 0; i < m_Instances.numClasses(); i++) {
623        String classL = m_Instances.classAttribute().value(i);
624        temp.append(pad(classL, " ", maxWidth + 1 - classL.length(), true));
625      }
626      temp.append("\n");
627      // class priors
628      temp.append(pad("", " ", maxAttWidth, true));
629      for (int i = 0; i < m_Instances.numClasses(); i++) {
630        String priorP = 
631          Utils.doubleToString(((DiscreteEstimator)m_ClassDistribution).getProbability(i),
632                               maxWidth, 2).trim();
633        priorP = "(" + priorP + ")";
634        temp.append(pad(priorP, " ", maxWidth + 1 - priorP.length(), true));
635      }
636      temp.append("\n");
637      temp.append(pad("", "=", maxAttWidth + 
638                      (maxWidth * m_Instances.numClasses()) 
639                      + m_Instances.numClasses() + 1, true));
640      temp.append("\n");
641
642      // loop over the attributes
643      int counter = 0;
644      for (int i = 0; i < m_Instances.numAttributes(); i++) {
645        if (i == m_Instances.classIndex()) {
646          continue;
647        }
648        String attName = m_Instances.attribute(i).name();
649        temp.append(attName + "\n");
650         
651        if (m_Distributions[counter][0] instanceof NormalEstimator) {
652          String meanL = "  mean";
653          temp.append(pad(meanL, " ", maxAttWidth + 1 - meanL.length(), false));
654          for (int j = 0; j < m_Instances.numClasses(); j++) {           
655            // means
656            NormalEstimator n = (NormalEstimator)m_Distributions[counter][j];
657            String mean = 
658              Utils.doubleToString(n.getMean(), maxWidth, 4).trim();
659            temp.append(pad(mean, " ", maxWidth + 1 - mean.length(), true));
660          }
661          temp.append("\n");           
662          // now do std deviations
663          String stdDevL = "  std. dev.";
664          temp.append(pad(stdDevL, " ", maxAttWidth + 1 - stdDevL.length(), false));
665          for (int j = 0; j < m_Instances.numClasses(); j++) {
666            NormalEstimator n = (NormalEstimator)m_Distributions[counter][j];
667            String stdDev = 
668              Utils.doubleToString(n.getStdDev(), maxWidth, 4).trim();
669            temp.append(pad(stdDev, " ", maxWidth + 1 - stdDev.length(), true));
670          }
671          temp.append("\n");
672          // now the weight sums
673          String weightL = "  weight sum";
674          temp.append(pad(weightL, " ", maxAttWidth + 1 - weightL.length(), false));
675          for (int j = 0; j < m_Instances.numClasses(); j++) {
676            NormalEstimator n = (NormalEstimator)m_Distributions[counter][j];
677            String weight = 
678              Utils.doubleToString(n.getSumOfWeights(), maxWidth, 4).trim();
679            temp.append(pad(weight, " ", maxWidth + 1 - weight.length(), true));
680          }
681          temp.append("\n");
682          // now the precisions
683          String precisionL = "  precision";
684          temp.append(pad(precisionL, " ", maxAttWidth + 1 - precisionL.length(), false));
685          for (int j = 0; j < m_Instances.numClasses(); j++) {
686            NormalEstimator n = (NormalEstimator)m_Distributions[counter][j];
687            String precision = 
688              Utils.doubleToString(n.getPrecision(), maxWidth, 4).trim();
689            temp.append(pad(precision, " ", maxWidth + 1 - precision.length(), true));
690          }
691          temp.append("\n\n");
692           
693        } else if (m_Distributions[counter][0] instanceof DiscreteEstimator) {
694          Attribute a = m_Instances.attribute(i);
695          for (int j = 0; j < a.numValues(); j++) {
696            String val = "  " + a.value(j);
697            temp.append(pad(val, " ", maxAttWidth + 1 - val.length(), false));
698            for (int k = 0; k < m_Instances.numClasses(); k++) {
699              DiscreteEstimator d = (DiscreteEstimator)m_Distributions[counter][k];
700              String count = "" + d.getCount(j);
701              temp.append(pad(count, " ", maxWidth + 1 - count.length(), true));
702            }
703            temp.append("\n");
704          }
705          // do the totals
706          String total = "  [total]";
707          temp.append(pad(total, " ", maxAttWidth + 1 - total.length(), false));
708          for (int k = 0; k < m_Instances.numClasses(); k++) {
709            DiscreteEstimator d = (DiscreteEstimator)m_Distributions[counter][k];
710            String count = "" + d.getSumOfCounts();
711            temp.append(pad(count, " ", maxWidth + 1 - count.length(), true));
712          }
713          temp.append("\n\n");
714        } else if (m_Distributions[counter][0] instanceof KernelEstimator) {
715          String kL = "  [# kernels]";
716          temp.append(pad(kL, " ", maxAttWidth + 1 - kL.length(), false));
717          for (int k = 0; k < m_Instances.numClasses(); k++) {
718            KernelEstimator ke = (KernelEstimator)m_Distributions[counter][k];
719            String nk = "" + ke.getNumKernels();
720            temp.append(pad(nk, " ", maxWidth + 1 - nk.length(), true));
721          }
722          temp.append("\n");
723          // do num kernels, std. devs and precisions
724          String stdDevL = "  [std. dev]";
725          temp.append(pad(stdDevL, " ", maxAttWidth + 1 - stdDevL.length(), false));
726          for (int k = 0; k < m_Instances.numClasses(); k++) {
727            KernelEstimator ke = (KernelEstimator)m_Distributions[counter][k];
728            String stdD = Utils.doubleToString(ke.getStdDev(), maxWidth, 4).trim(); 
729            temp.append(pad(stdD, " ", maxWidth + 1 - stdD.length(), true));
730          }
731          temp.append("\n");
732          String precL = "  [precision]";
733          temp.append(pad(precL, " ", maxAttWidth + 1 - precL.length(), false));
734          for (int k = 0; k < m_Instances.numClasses(); k++) {
735            KernelEstimator ke = (KernelEstimator)m_Distributions[counter][k];
736            String prec = Utils.doubleToString(ke.getPrecision(), maxWidth, 4).trim(); 
737            temp.append(pad(prec, " ", maxWidth + 1 - prec.length(), true));
738          }
739          temp.append("\n");
740          // first determine max number of kernels accross the classes
741          int maxK = 0;
742          for (int k = 0; k < m_Instances.numClasses(); k++) {
743            KernelEstimator ke = (KernelEstimator)m_Distributions[counter][k];
744            if (ke.getNumKernels() > maxK) {
745              maxK = ke.getNumKernels();
746            }
747          }
748          for (int j = 0; j < maxK; j++) {
749            // means first
750            String meanL = "  K" + (j+1) + ": mean (weight)";
751            temp.append(pad(meanL, " ", maxAttWidth + 1 - meanL.length(), false));
752            for (int k = 0; k < m_Instances.numClasses(); k++) {
753              KernelEstimator ke = (KernelEstimator)m_Distributions[counter][k];
754              double[] means = ke.getMeans();
755              double[] weights = ke.getWeights();
756              String m = "--";
757              if (ke.getNumKernels() == 0) {
758                m = "" + 0;
759              } else if (j < ke.getNumKernels()) {
760                m = Utils.doubleToString(means[j], maxWidth, 4).trim();
761                m += " (" + Utils.doubleToString(weights[j], maxWidth, 1).trim() + ")";
762              }
763              temp.append(pad(m, " ", maxWidth + 1 - m.length(), true));
764            }
765            temp.append("\n");             
766          }
767          temp.append("\n");
768        }
769
770
771        counter++;
772      }
773    }
774     
775    return temp.toString();
776  }
777
778  /**
779   * Returns a description of the classifier in the old format.
780   *
781   * @return a description of the classifier as a string.
782   */
783  protected String toStringOriginal() {
784   
785    StringBuffer text = new StringBuffer();
786
787    text.append("Naive Bayes Classifier");
788    if (m_Instances == null) {
789      text.append(": No model built yet.");
790    } else {
791      try {
792        for (int i = 0; i < m_Distributions[0].length; i++) {
793          text.append("\n\nClass " + m_Instances.classAttribute().value(i) +
794                      ": Prior probability = " + Utils.
795                      doubleToString(m_ClassDistribution.getProbability(i),
796                                     4, 2) + "\n\n");
797          Enumeration enumAtts = m_Instances.enumerateAttributes();
798          int attIndex = 0;
799          while (enumAtts.hasMoreElements()) {
800            Attribute attribute = (Attribute) enumAtts.nextElement();
801            if (attribute.weight() > 0) {
802              text.append(attribute.name() + ":  " 
803                          + m_Distributions[attIndex][i]);
804            }
805            attIndex++;
806          }
807        }
808      } catch (Exception ex) {
809        text.append(ex.getMessage());
810      }
811    }
812
813    return text.toString();
814  } 
815
816  private String pad(String source, String padChar, 
817                     int length, boolean leftPad) {
818    StringBuffer temp = new StringBuffer();
819
820    if (leftPad) {
821      for (int i = 0; i< length; i++) {
822        temp.append(padChar);
823      }
824      temp.append(source);
825    } else {
826      temp.append(source);
827      for (int i = 0; i< length; i++) {
828        temp.append(padChar);
829      }
830    }
831    return temp.toString();
832  }
833
834  /**
835   * Returns the tip text for this property
836   * @return tip text for this property suitable for
837   * displaying in the explorer/experimenter gui
838   */
839  public String useKernelEstimatorTipText() {
840    return "Use a kernel estimator for numeric attributes rather than a "
841      +"normal distribution.";
842  }
843  /**
844   * Gets if kernel estimator is being used.
845   *
846   * @return Value of m_UseKernelEstimatory.
847   */
848  public boolean getUseKernelEstimator() {
849
850    return m_UseKernelEstimator;
851  }
852
853  /**
854   * Sets if kernel estimator is to be used.
855   *
856   * @param v  Value to assign to m_UseKernelEstimatory.
857   */
858  public void setUseKernelEstimator(boolean v) {
859
860    m_UseKernelEstimator = v;
861    if (v) {
862      setUseSupervisedDiscretization(false);
863    }
864  }
865
866  /**
867   * Returns the tip text for this property
868   * @return tip text for this property suitable for
869   * displaying in the explorer/experimenter gui
870   */
871  public String useSupervisedDiscretizationTipText() {
872    return "Use supervised discretization to convert numeric attributes to nominal "
873      +"ones.";
874  }
875
876  /**
877   * Get whether supervised discretization is to be used.
878   *
879   * @return true if supervised discretization is to be used.
880   */
881  public boolean getUseSupervisedDiscretization() {
882
883    return m_UseDiscretization;
884  }
885
886  /**
887   * Set whether supervised discretization is to be used.
888   *
889   * @param newblah true if supervised discretization is to be used.
890   */
891  public void setUseSupervisedDiscretization(boolean newblah) {
892
893    m_UseDiscretization = newblah;
894    if (newblah) {
895      setUseKernelEstimator(false);
896    }
897  }
898
899  /**
900   * Returns the tip text for this property
901   * @return tip text for this property suitable for
902   * displaying in the explorer/experimenter gui
903   */
904  public String displayModelInOldFormatTipText() {
905    return "Use old format for model output. The old format is "
906      + "better when there are many class values. The new format "
907      + "is better when there are fewer classes and many attributes.";
908  }
909
910  /**
911   * Set whether to display model output in the old, original
912   * format.
913   *
914   * @param d true if model ouput is to be shown in the old format
915   */
916  public void setDisplayModelInOldFormat(boolean d) {
917    m_displayModelInOldFormat = d;
918  }
919
920  /**
921   * Get whether to display model output in the old, original
922   * format.
923   *
924   * @return true if model ouput is to be shown in the old format
925   */
926  public boolean getDisplayModelInOldFormat() {
927    return m_displayModelInOldFormat;
928  }
929 
930  /**
931   * Returns the revision string.
932   *
933   * @return            the revision
934   */
935  public String getRevision() {
936    return RevisionUtils.extract("$Revision: 5928 $");
937  }
938
939  /**
940   * Main method for testing this class.
941   *
942   * @param argv the options
943   */
944  public static void main(String [] argv) {
945    runClassifier(new NaiveBayes(), argv);
946  }
947}
948
Note: See TracBrowser for help on using the repository browser.