source: src/main/java/weka/classifiers/bayes/DMNBtext.java @ 11

Last change on this file since 11 was 4, checked in by gnappo, 14 years ago

Import di weka.

File size: 16.4 KB
RevLine 
[4]1/*
2 *    This program is free software; you can redistribute it and/or modify
3 *    it under the terms of the GNU General Public License as published by
4 *    the Free Software Foundation; either version 2 of the License, or
5 *    (at your option) any later version.
6 *
7 *    This program is distributed in the hope that it will be useful,
8 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
9 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
10 *    GNU General Public License for more details.
11 *
12 *    You should have received a copy of the GNU General Public License
13 *    along with this program; if not, write to the Free Software
14 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
15 */
16
17/*
18 *    Discriminative Multinomial Naive Bayes for Text Classification
19 *    Copyright (C) 2008 Jiang Su
20 */
21
22package weka.classifiers.bayes;
23
24
25import weka.classifiers.Classifier;
26import weka.classifiers.AbstractClassifier;
27import weka.core.Instance;
28import weka.core.Instances;
29import weka.core.TechnicalInformation;
30import weka.core.TechnicalInformationHandler;
31import weka.core.Utils;
32import weka.core.WeightedInstancesHandler;
33import weka.core.Capabilities.Capability;
34import weka.core.TechnicalInformation.Field;
35import weka.core.TechnicalInformation.Type;
36import weka.classifiers.UpdateableClassifier;
37import java.util.*;
38import java.io.Serializable;
39import weka.core.Capabilities;
40import weka.core.OptionHandler;
41
42
43/**
44 <!-- globalinfo-start -->
45 * Class for building and using a Discriminative Multinomial Naive Bayes classifier. For more information see,<br/>
46 * <br/>
47 * Jiang Su,Harry Zhang,Charles X. Ling,Stan Matwin: Discriminative Parameter Learning for Bayesian Networks. In: ICML 2008', 2008.<br/>
48 * <br/>
49 * The core equation for this classifier:<br/>
50 * <br/>
51 * P[Ci|D] = (P[D|Ci] x P[Ci]) / P[D] (Bayes rule)<br/>
52 * <br/>
53 * where Ci is class i and D is a document.
54 * <p/>
55 <!-- globalinfo-end -->
56 *
57 <!-- technical-bibtex-start -->
58 * BibTeX:
59 * <pre>
60 * &#64;inproceedings{JiangSu2008,
61 *    author = {Jiang Su,Harry Zhang,Charles X. Ling,Stan Matwin},
62 *    booktitle = {ICML 2008'},
63 *    title = {Discriminative Parameter Learning for Bayesian Networks},
64 *    year = {2008}
65 * }
66 * </pre>
67 * <p/>
68 <!-- technical-bibtex-end -->
69 *
70 <!-- options-start -->
71 * Valid options are: <p/>
72 *
73 * <pre> -D
74 *  If set, classifier is run in debug mode and
75 *  may output additional info to the console</pre>
76 *
77 <!-- options-end -->
78 *
79 * @author Jiang Su (Jiang.Su@unb.ca) 2008
80 * @version $Revision: 5928 $
81 */
82public class DMNBtext extends AbstractClassifier
83    implements OptionHandler, WeightedInstancesHandler, 
84               TechnicalInformationHandler, UpdateableClassifier {
85
86  /** for serialization */
87  static final long serialVersionUID = 5932177450183457085L;
88  /** The number of iterations. */
89  protected int m_NumIterations = 1;
90  protected boolean m_BinaryWord = true;
91  int m_numClasses=-1;
92  protected Instances m_headerInfo;
93  DNBBinary[] m_binaryClassifiers = null;
94
95  /**
96   * Returns a string describing this classifier
97   * @return a description of the classifier suitable for
98   * displaying in the explorer/experimenter gui
99   */
100  public String globalInfo() {
101    return
102      "Class for building and using a Discriminative Multinomial Naive Bayes classifier. "
103      + "For more information see,\n\n"
104      + getTechnicalInformation().toString() + "\n\n"
105      + "The core equation for this classifier:\n\n"
106      + "P[Ci|D] = (P[D|Ci] x P[Ci]) / P[D] (Bayes rule)\n\n"
107      + "where Ci is class i and D is a document.";
108  }
109
110  /**
111   * Returns an instance of a TechnicalInformation object, containing
112   * detailed information about the technical background of this class,
113   * e.g., paper reference or book this class is based on.
114   *
115   * @return the technical information about this class
116   */
117  public TechnicalInformation getTechnicalInformation() {
118    TechnicalInformation        result;
119
120    result = new TechnicalInformation(Type.INPROCEEDINGS);
121    result.setValue(Field.AUTHOR, "Jiang Su,Harry Zhang,Charles X. Ling,Stan Matwin");
122    result.setValue(Field.YEAR, "2008");
123    result.setValue(Field.TITLE, "Discriminative Parameter Learning for Bayesian Networks");
124    result.setValue(Field.BOOKTITLE, "ICML 2008'");
125
126    return result;
127  }
128
129  /**
130   * Returns default capabilities of the classifier.
131   *
132   * @return      the capabilities of this classifier
133   */
134  public Capabilities getCapabilities() {
135    Capabilities result = super.getCapabilities();
136    result.disableAll();
137
138    // attributes
139    result.enable(Capability.NUMERIC_ATTRIBUTES);
140
141    // class
142    result.enable(Capability.NOMINAL_CLASS);
143    result.enable(Capability.MISSING_CLASS_VALUES);
144
145    return result;
146  }
147
148  /**
149   * Generates the classifier.
150   *
151   * @param data set of instances serving as training data
152   * @exception Exception if the classifier has not been generated successfully
153   */
154  public void buildClassifier(Instances data) throws Exception {
155    // can classifier handle the data?
156    getCapabilities().testWithFail(data);
157    // remove instances with missing class
158    Instances instances =  new Instances(data);
159    instances.deleteWithMissingClass();
160
161    m_binaryClassifiers = new DNBBinary[instances.numClasses()];
162    m_numClasses=instances.numClasses();
163    m_headerInfo = new Instances(instances, 0);
164    for (int i = 0; i < instances.numClasses(); i++) {
165      m_binaryClassifiers[i] = new DNBBinary();
166      m_binaryClassifiers[i].setTargetClass(i);
167      m_binaryClassifiers[i].initClassifier(instances);
168    }
169
170    if (instances.numInstances() == 0)
171      return;
172    //Iterative update
173    Random random = new Random();
174    for (int it = 0; it < m_NumIterations; it++) {
175      for (int i = 0; i < instances.numInstances(); i++) {
176        updateClassifier(instances.instance(i));
177      }
178    }
179
180    //  Utils.normalize(m_oldClassDis);
181    // Utils.normalize(m_ClassDis);
182    // m_originalPositive = m_oldClassDis[0];
183    //   m_positive = m_ClassDis[0];
184
185  }
186
187  /**
188   * Updates the classifier with the given instance.
189   *
190   * @param instance the new training instance to include in the model
191   * @exception Exception if the instance could not be incorporated in
192   * the model.
193   */
194
195  public void updateClassifier(Instance instance) throws Exception {
196
197    if (m_numClasses == 2) {
198      m_binaryClassifiers[0].updateClassifier(instance);
199    } else {
200      for (int i = 0; i < instance.numClasses(); i++)
201        m_binaryClassifiers[i].updateClassifier(instance);
202    }
203  }
204
205  /**
206   * Calculates the class membership probabilities for the given test
207   * instance.
208   *
209   * @param instance the instance to be classified
210   * @return predicted class probability distribution
211   * @exception Exception if there is a problem generating the prediction
212   */
213  public double[] distributionForInstance(Instance instance) throws Exception {
214    if (m_numClasses == 2) {
215      // System.out.println(m_binaryClassifiers[0].getProbForTargetClass(instance));
216      return m_binaryClassifiers[0].distributionForInstance(instance);
217    }
218    double[] logDocGivenClass = new double[instance.numClasses()];
219    for (int i = 0; i < m_numClasses; i++)
220      logDocGivenClass[i] = m_binaryClassifiers[i].getLogProbForTargetClass(instance);
221
222
223    double max = logDocGivenClass[Utils.maxIndex(logDocGivenClass)];
224    for(int i = 0; i<m_numClasses; i++)
225      logDocGivenClass[i] = Math.exp(logDocGivenClass[i] - max);
226
227
228    try {
229      Utils.normalize(logDocGivenClass);
230    } catch (Exception e) {
231      e.printStackTrace();
232
233
234    }
235
236    return logDocGivenClass;
237  }
238  /**
239   * Returns a string representation of the classifier.
240   *
241   * @return a string representation of the classifier
242   */
243  public String toString() {
244    StringBuffer result = new StringBuffer("");
245    result.append("The log ratio of two conditional probabilities of a word w_i: log(p(w_i)|+)/p(w_i)|-)) in decent order based on their absolute values\n");
246    result.append("Can be used to measure the discriminative power of each word.\n");
247    if (m_numClasses == 2) {
248      // System.out.println(m_binaryClassifiers[0].getProbForTargetClass(instance));
249      return result.append(m_binaryClassifiers[0].toString()).toString();
250    }
251    for (int i = 0; i < m_numClasses; i++)
252      { result.append(i+" against the rest classes\n");
253        result.append(m_binaryClassifiers[i].toString()+"\n");
254      }
255    return result.toString();
256  }
257
258  /*
259   * Options after -- are passed to the designated classifier.<p>
260   *
261   * @param options the list of options as an array of strings
262   * @exception Exception if an option is not supported
263   */
264  public void setOptions(String[] options) throws Exception {
265
266    String iterations = Utils.getOption('I', options);
267    if (iterations.length() != 0) {
268      setNumIterations(Integer.parseInt(iterations));
269    } else {
270      setNumIterations(m_NumIterations);
271    }
272    iterations = Utils.getOption('B', options);
273    if (iterations.length() != 0) {
274      setBinaryWord(Boolean.parseBoolean(iterations));
275    } else {
276      setBinaryWord(m_BinaryWord);
277    }
278
279  }
280
281  /**
282   * Gets the current settings of the classifier.
283   *
284   * @return an array of strings suitable for passing to setOptions
285   */
286  public String[] getOptions() {
287
288    String[] options = new String[4];
289
290    int current = 0;
291    options[current++] = "-I";
292    options[current++] = "" + getNumIterations();
293
294    options[current++] = "-B";
295    options[current++] = "" + getBinaryWord();
296
297    return options;
298  }
299
300  /**
301   * Returns the tip text for this property
302   * @return tip text for this property suitable for
303   * displaying in the explorer/experimenter gui
304   */
305  public String numIterationsTipText() {
306    return "The number of iterations that the classifier will scan the training data";
307  }
308
309  /**
310   * Sets the number of iterations to be performed
311   */
312  public void setNumIterations(int numIterations) {
313
314    m_NumIterations = numIterations;
315  }
316
317  /**
318   * Gets the number of iterations to be performed
319   *
320   * @return the iterations to be performed
321   */
322  public int getNumIterations() {
323
324    return m_NumIterations;
325  }
326  /**
327   * Returns the tip text for this property
328   * @return tip text for this property suitable for
329   * displaying in the explorer/experimenter gui
330   */
331  public String binaryWordTipText() {
332    return " whether ingore the frequency information in data";
333  }
334  /**
335   * Sets whether use binary text representation
336   */
337  public void setBinaryWord(boolean val) {
338
339    m_BinaryWord = val;
340  }
341
342  /**
343   * Gets whether use binary text representation
344   *
345   * @return whether use binary text representation
346   */
347  public boolean getBinaryWord() {
348
349    return m_BinaryWord;
350  }
351
352  /**
353   * Returns the revision string.
354   *
355   * @return            the revision
356   */
357  public String getRevision() {
358    return "$Revision: 1.0";
359  }
360
361  public class DNBBinary implements Serializable {
362
363    /** The number of iterations. */
364    private double[][] m_perWordPerClass;
365    private double[] m_wordsPerClass;
366    int m_classIndex = -1;
367    private double[] m_classDistribution;
368    /** number of unique words */
369    private int m_numAttributes;
370    //set the target class
371    private int m_targetClass = -1;
372
373    private double m_WordLaplace=1;
374
375    private double[] m_coefficient;
376    private double m_classRatio;
377    private double m_wordRatio;
378
379    public void initClassifier(Instances instances) throws Exception {
380      m_numAttributes = instances.numAttributes();
381      m_perWordPerClass = new double[2][m_numAttributes];
382      m_coefficient = new double[m_numAttributes];
383      m_wordsPerClass = new double[2];
384      m_classDistribution = new double[2];
385      m_WordLaplace = Math.log(m_numAttributes);
386      m_classIndex = instances.classIndex();
387
388      //Laplace
389      for (int c = 0; c < 2; c++) {
390        m_classDistribution[c] = 1;
391        m_wordsPerClass[c] = m_WordLaplace * m_numAttributes;
392        java.util.Arrays.fill(m_perWordPerClass[c], m_WordLaplace);
393      }
394
395    }
396
397    public void updateClassifier(Instance ins) throws
398      Exception {
399      //c=0 is 1, which is the target class, and c=1 is the rest
400      int classIndex = 0;
401      if (ins.value(ins.classIndex()) != m_targetClass)
402        classIndex = 1;
403      double prob = 1 -
404        distributionForInstance(ins)[classIndex];
405
406
407      double weight = prob * ins.weight();
408
409      for (int a = 0; a < ins.numValues(); a++) {
410        if (ins.index(a) != m_classIndex )
411          {
412
413            if (m_BinaryWord) {
414              if (ins.valueSparse(a) > 0) {
415                m_wordsPerClass[classIndex] +=
416                  weight;
417                m_perWordPerClass[classIndex][ins.
418                                              index(a)] +=
419                  weight;
420              }
421            } else {
422              double t = ins.valueSparse(a) * weight;
423              m_wordsPerClass[classIndex] += t;
424              m_perWordPerClass[classIndex][ins.index(a)] += t;
425            }
426            //update coefficient
427            m_coefficient[ins.index(a)] = Math.log(m_perWordPerClass[0][
428                                                                        ins.index(a)] /
429                                                   m_perWordPerClass[1][ins.index(a)]);
430          }
431      }
432      m_wordRatio = Math.log(m_wordsPerClass[0] / m_wordsPerClass[1]);
433      m_classDistribution[classIndex] += weight;
434      m_classRatio = Math.log(m_classDistribution[0] /
435                              m_classDistribution[1]);
436    }
437
438
439    /**
440     * Calculates the class membership probabilities for the given test
441     * instance.
442     *
443     * @param ins the instance to be classified
444     * @return predicted class probability distribution
445     * @exception Exception if there is a problem generating the prediction
446     */
447    public double getLogProbForTargetClass(Instance ins) throws Exception {
448
449      double probLog = m_classRatio;
450      for (int a = 0; a < ins.numValues(); a++) {
451        if (ins.index(a) != m_classIndex )
452          {
453
454            if (m_BinaryWord) {
455              if (ins.valueSparse(a) > 0) {
456                probLog += m_coefficient[ins.index(a)] -
457                  m_wordRatio;
458              }
459            } else {
460              probLog += ins.valueSparse(a) *
461                (m_coefficient[ins.index(a)] - m_wordRatio);
462            }
463          }
464      }
465      return probLog;
466    }
467
468    /**
469     * Calculates the class membership probabilities for the given test
470     * instance.
471     *
472     * @param instance the instance to be classified
473     * @return predicted class probability distribution
474     * @exception Exception if there is a problem generating the prediction
475     */
476    public double[] distributionForInstance(Instance instance) throws
477      Exception {
478      double[] probOfClassGivenDoc = new double[2];
479      double ratio=getLogProbForTargetClass(instance);
480      if (ratio > 709)
481        probOfClassGivenDoc[0]=1;
482      else
483        {
484          ratio = Math.exp(ratio);
485          probOfClassGivenDoc[0]=ratio / (1 + ratio);
486        }
487
488      probOfClassGivenDoc[1] = 1 - probOfClassGivenDoc[0];
489      return probOfClassGivenDoc;
490    }
491
492    /**
493     * Returns a string representation of the classifier.
494     *
495     * @return a string representation of the classifier
496     */
497    public String toString() {
498      //            StringBuffer result = new StringBuffer("The cofficiency of a naive Bayes classifier, can be considered as the discriminative power of a word\n--------------------------------------\n");
499      StringBuffer result = new StringBuffer();
500
501      result.append("\n");
502      TreeMap sort=new TreeMap();
503      double[] absCoeff=new double[m_numAttributes];
504      for(int w = 0; w<m_numAttributes; w++)
505        {
506          if(w==m_headerInfo.classIndex())continue;
507          String val= m_headerInfo.attribute(w).name()+": "+m_coefficient[w];
508          sort.put((-1)*Math.abs(m_coefficient[w]),val);
509        }
510      Iterator it=sort.values().iterator();
511      while(it.hasNext())
512        {
513          result.append((String)it.next());
514          result.append("\n");
515        }
516
517      return result.toString();
518    }
519
520    /**
521     * Sets the Target Class
522     */
523    public void setTargetClass(int targetClass) {
524
525      m_targetClass = targetClass;
526    }
527
528    /**
529     * Gets the Target Class
530     *
531     * @return the Target Class Index
532     */
533    public int getTargetClass() {
534
535      return m_targetClass;
536    }
537
538  }
539
540
541  /**
542   * Main method for testing this class.
543   *
544   * @param argv the options
545   */
546  public static void main(String[] argv) {
547
548    DMNBtext c = new DMNBtext();
549
550    runClassifier(c, argv);
551  }
552}
553
Note: See TracBrowser for help on using the repository browser.