source: src/main/java/weka/classifiers/trees/j48/Distribution.java @ 16

Last change on this file since 16 was 4, checked in by gnappo, 14 years ago

Import di weka.

File size: 18.6 KB
Line 
1/*
2 *    This program is free software; you can redistribute it and/or modify
3 *    it under the terms of the GNU General Public License as published by
4 *    the Free Software Foundation; either version 2 of the License, or
5 *    (at your option) any later version.
6 *
7 *    This program is distributed in the hope that it will be useful,
8 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
9 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
10 *    GNU General Public License for more details.
11 *
12 *    You should have received a copy of the GNU General Public License
13 *    along with this program; if not, write to the Free Software
14 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
15 */
16
17/*
18 *    Distribution.java
19 *    Copyright (C) 1999 University of Waikato, Hamilton, New Zealand
20 *
21 */
22
23package weka.classifiers.trees.j48;
24
25import weka.core.Instance;
26import weka.core.Instances;
27import weka.core.RevisionHandler;
28import weka.core.RevisionUtils;
29import weka.core.Utils;
30
31import java.io.Serializable;
32import java.util.Enumeration;
33
34/**
35 * Class for handling a distribution of class values.
36 *
37 * @author Eibe Frank (eibe@cs.waikato.ac.nz)
38 * @version $Revision: 1.12 $
39 */
40public class Distribution
41  implements Cloneable, Serializable, RevisionHandler {
42
43  /** for serialization */
44  private static final long serialVersionUID = 8526859638230806576L;
45
46  /** Weight of instances per class per bag. */
47  private double m_perClassPerBag[][]; 
48
49  /** Weight of instances per bag. */
50  private double m_perBag[];           
51
52  /** Weight of instances per class. */
53  private double m_perClass[];         
54
55  /** Total weight of instances. */
56  private double totaL;           
57
58  /**
59   * Creates and initializes a new distribution.
60   */
61  public Distribution(int numBags,int numClasses) {
62
63    int i;
64
65    m_perClassPerBag = new double [numBags][0];
66    m_perBag = new double [numBags];
67    m_perClass = new double [numClasses];
68    for (i=0;i<numBags;i++)
69      m_perClassPerBag[i] = new double [numClasses];
70    totaL = 0;
71  }
72
73  /**
74   * Creates and initializes a new distribution using the given
75   * array. WARNING: it just copies a reference to this array.
76   */
77  public Distribution(double [][] table) {
78
79    int i, j;
80
81    m_perClassPerBag = table;
82    m_perBag = new double [table.length];
83    m_perClass = new double [table[0].length];
84    for (i = 0; i < table.length; i++) 
85      for (= 0; j < table[i].length; j++) {
86        m_perBag[i] += table[i][j];
87        m_perClass[j] += table[i][j];
88        totaL += table[i][j];
89      }
90  }
91
92  /**
93   * Creates a distribution with only one bag according
94   * to instances in source.
95   *
96   * @exception Exception if something goes wrong
97   */
98  public Distribution(Instances source) throws Exception {
99   
100    m_perClassPerBag = new double [1][0];
101    m_perBag = new double [1];
102    totaL = 0;
103    m_perClass = new double [source.numClasses()];
104    m_perClassPerBag[0] = new double [source.numClasses()];
105    Enumeration enu = source.enumerateInstances();
106    while (enu.hasMoreElements())
107      add(0,(Instance) enu.nextElement());
108  }
109
110  /**
111   * Creates a distribution according to given instances and
112   * split model.
113   *
114   * @exception Exception if something goes wrong
115   */
116
117  public Distribution(Instances source, 
118                      ClassifierSplitModel modelToUse)
119       throws Exception {
120
121    int index;
122    Instance instance;
123    double[] weights;
124
125    m_perClassPerBag = new double [modelToUse.numSubsets()][0];
126    m_perBag = new double [modelToUse.numSubsets()];
127    totaL = 0;
128    m_perClass = new double [source.numClasses()];
129    for (int i = 0; i < modelToUse.numSubsets(); i++)
130      m_perClassPerBag[i] = new double [source.numClasses()];
131    Enumeration enu = source.enumerateInstances();
132    while (enu.hasMoreElements()) {
133      instance = (Instance) enu.nextElement();
134      index = modelToUse.whichSubset(instance);
135      if (index != -1)
136        add(index, instance);
137      else {
138        weights = modelToUse.weights(instance);
139        addWeights(instance, weights);
140      }
141    }
142  }
143
144  /**
145   * Creates distribution with only one bag by merging all
146   * bags of given distribution.
147   */
148  public Distribution(Distribution toMerge) {
149
150    totaL = toMerge.totaL;
151    m_perClass = new double [toMerge.numClasses()];
152    System.arraycopy(toMerge.m_perClass,0,m_perClass,0,toMerge.numClasses());
153    m_perClassPerBag = new double [1] [0];
154    m_perClassPerBag[0] = new double [toMerge.numClasses()];
155    System.arraycopy(toMerge.m_perClass,0,m_perClassPerBag[0],0,
156                     toMerge.numClasses());
157    m_perBag = new double [1];
158    m_perBag[0] = totaL;
159  }
160
161  /**
162   * Creates distribution with two bags by merging all bags apart of
163   * the indicated one.
164   */
165  public Distribution(Distribution toMerge, int index) {
166
167    int i;
168
169    totaL = toMerge.totaL;
170    m_perClass = new double [toMerge.numClasses()];
171    System.arraycopy(toMerge.m_perClass,0,m_perClass,0,toMerge.numClasses());
172    m_perClassPerBag = new double [2] [0];
173    m_perClassPerBag[0] = new double [toMerge.numClasses()];
174    System.arraycopy(toMerge.m_perClassPerBag[index],0,m_perClassPerBag[0],0,
175                     toMerge.numClasses());
176    m_perClassPerBag[1] = new double [toMerge.numClasses()];
177    for (i=0;i<toMerge.numClasses();i++)
178      m_perClassPerBag[1][i] = toMerge.m_perClass[i]-m_perClassPerBag[0][i];
179    m_perBag = new double [2];
180    m_perBag[0] = toMerge.m_perBag[index];
181    m_perBag[1] = totaL-m_perBag[0];
182  }
183 
184  /**
185   * Returns number of non-empty bags of distribution.
186   */
187  public final int actualNumBags() {
188   
189    int returnValue = 0;
190    int i;
191
192    for (i=0;i<m_perBag.length;i++)
193      if (Utils.gr(m_perBag[i],0))
194        returnValue++;
195   
196    return returnValue;
197  }
198
199  /**
200   * Returns number of classes actually occuring in distribution.
201   */
202  public final int actualNumClasses() {
203
204    int returnValue = 0;
205    int i;
206
207    for (i=0;i<m_perClass.length;i++)
208      if (Utils.gr(m_perClass[i],0))
209        returnValue++;
210   
211    return returnValue;
212  }
213
214  /**
215   * Returns number of classes actually occuring in given bag.
216   */
217  public final int actualNumClasses(int bagIndex) {
218
219    int returnValue = 0;
220    int i;
221
222    for (i=0;i<m_perClass.length;i++)
223      if (Utils.gr(m_perClassPerBag[bagIndex][i],0))
224        returnValue++;
225   
226    return returnValue;
227  }
228
229  /**
230   * Adds given instance to given bag.
231   *
232   * @exception Exception if something goes wrong
233   */
234  public final void add(int bagIndex,Instance instance) 
235       throws Exception {
236   
237    int classIndex;
238    double weight;
239
240    classIndex = (int)instance.classValue();
241    weight = instance.weight();
242    m_perClassPerBag[bagIndex][classIndex] = 
243      m_perClassPerBag[bagIndex][classIndex]+weight;
244    m_perBag[bagIndex] = m_perBag[bagIndex]+weight;
245    m_perClass[classIndex] = m_perClass[classIndex]+weight;
246    totaL = totaL+weight;
247  }
248
249  /**
250   * Subtracts given instance from given bag.
251   *
252   * @exception Exception if something goes wrong
253   */
254  public final void sub(int bagIndex,Instance instance) 
255       throws Exception {
256   
257    int classIndex;
258    double weight;
259
260    classIndex = (int)instance.classValue();
261    weight = instance.weight();
262    m_perClassPerBag[bagIndex][classIndex] = 
263      m_perClassPerBag[bagIndex][classIndex]-weight;
264    m_perBag[bagIndex] = m_perBag[bagIndex]-weight;
265    m_perClass[classIndex] = m_perClass[classIndex]-weight;
266    totaL = totaL-weight;
267  }
268
269  /**
270   * Adds counts to given bag.
271   */
272  public final void add(int bagIndex, double[] counts) {
273   
274    double sum = Utils.sum(counts);
275
276    for (int i = 0; i < counts.length; i++)
277      m_perClassPerBag[bagIndex][i] += counts[i];
278    m_perBag[bagIndex] = m_perBag[bagIndex]+sum;
279    for (int i = 0; i < counts.length; i++)
280      m_perClass[i] = m_perClass[i]+counts[i];
281    totaL = totaL+sum;
282  }
283
284  /**
285   * Adds all instances with unknown values for given attribute, weighted
286   * according to frequency of instances in each bag.
287   *
288   * @exception Exception if something goes wrong
289   */
290  public final void addInstWithUnknown(Instances source,
291                                       int attIndex)
292       throws Exception {
293
294    double [] probs;
295    double weight,newWeight;
296    int classIndex;
297    Instance instance;
298    int j;
299
300    probs = new double [m_perBag.length];
301    for (j=0;j<m_perBag.length;j++) {
302      if (Utils.eq(totaL, 0)) {
303        probs[j] = 1.0 / probs.length;
304      } else {
305        probs[j] = m_perBag[j]/totaL;
306      }
307    }
308    Enumeration enu = source.enumerateInstances();
309    while (enu.hasMoreElements()) {
310      instance = (Instance) enu.nextElement();
311      if (instance.isMissing(attIndex)) {
312        classIndex = (int)instance.classValue();
313        weight = instance.weight();
314        m_perClass[classIndex] = m_perClass[classIndex]+weight;
315        totaL = totaL+weight;
316        for (j = 0; j < m_perBag.length; j++) {
317          newWeight = probs[j]*weight;
318          m_perClassPerBag[j][classIndex] = m_perClassPerBag[j][classIndex]+
319            newWeight;
320          m_perBag[j] = m_perBag[j]+newWeight;
321        }
322      }
323    }
324  }
325
326  /**
327   * Adds all instances in given range to given bag.
328   *
329   * @exception Exception if something goes wrong
330   */
331  public final void addRange(int bagIndex,Instances source,
332                             int startIndex, int lastPlusOne)
333       throws Exception {
334
335    double sumOfWeights = 0;
336    int classIndex;
337    Instance instance;
338    int i;
339
340    for (i = startIndex; i < lastPlusOne; i++) {
341      instance = (Instance) source.instance(i);
342      classIndex = (int)instance.classValue();
343      sumOfWeights = sumOfWeights+instance.weight();
344      m_perClassPerBag[bagIndex][classIndex] += instance.weight();
345      m_perClass[classIndex] += instance.weight();
346    }
347    m_perBag[bagIndex] += sumOfWeights;
348    totaL += sumOfWeights;
349  }
350
351  /**
352   * Adds given instance to all bags weighting it according to given weights.
353   *
354   * @exception Exception if something goes wrong
355   */
356  public final void addWeights(Instance instance, 
357                               double [] weights)
358       throws Exception {
359
360    int classIndex;
361    int i;
362
363    classIndex = (int)instance.classValue();
364    for (i=0;i<m_perBag.length;i++) {
365      double weight = instance.weight() * weights[i];
366      m_perClassPerBag[i][classIndex] = m_perClassPerBag[i][classIndex] + weight;
367      m_perBag[i] = m_perBag[i] + weight;
368      m_perClass[classIndex] = m_perClass[classIndex] + weight;
369      totaL = totaL + weight;
370    }
371  }
372
373  /**
374   * Checks if at least two bags contain a minimum number of instances.
375   */
376  public final boolean check(double minNoObj) {
377
378    int counter = 0;
379    int i;
380
381    for (i=0;i<m_perBag.length;i++)
382      if (Utils.grOrEq(m_perBag[i],minNoObj))
383        counter++;
384    if (counter > 1)
385      return true;
386    else
387      return false;
388  }
389
390  /**
391   * Clones distribution (Deep copy of distribution).
392   */
393  public final Object clone() {
394
395    int i,j;
396
397    Distribution newDistribution = new Distribution (m_perBag.length,
398                                                     m_perClass.length);
399    for (i=0;i<m_perBag.length;i++) {
400      newDistribution.m_perBag[i] = m_perBag[i];
401      for (j=0;j<m_perClass.length;j++)
402        newDistribution.m_perClassPerBag[i][j] = m_perClassPerBag[i][j];
403    }
404    for (j=0;j<m_perClass.length;j++)
405      newDistribution.m_perClass[j] = m_perClass[j];
406    newDistribution.totaL = totaL;
407 
408    return newDistribution;
409  }
410
411  /**
412   * Deletes given instance from given bag.
413   *
414   * @exception Exception if something goes wrong
415   */
416  public final void del(int bagIndex,Instance instance) 
417       throws Exception {
418
419    int classIndex;
420    double weight;
421
422    classIndex = (int)instance.classValue();
423    weight = instance.weight();
424    m_perClassPerBag[bagIndex][classIndex] = 
425      m_perClassPerBag[bagIndex][classIndex]-weight;
426    m_perBag[bagIndex] = m_perBag[bagIndex]-weight;
427    m_perClass[classIndex] = m_perClass[classIndex]-weight;
428    totaL = totaL-weight;
429  }
430
431  /**
432   * Deletes all instances in given range from given bag.
433   *
434   * @exception Exception if something goes wrong
435   */
436  public final void delRange(int bagIndex,Instances source,
437                             int startIndex, int lastPlusOne)
438       throws Exception {
439
440    double sumOfWeights = 0;
441    int classIndex;
442    Instance instance;
443    int i;
444
445    for (i = startIndex; i < lastPlusOne; i++) {
446      instance = (Instance) source.instance(i);
447      classIndex = (int)instance.classValue();
448      sumOfWeights = sumOfWeights+instance.weight();
449      m_perClassPerBag[bagIndex][classIndex] -= instance.weight();
450      m_perClass[classIndex] -= instance.weight();
451    }
452    m_perBag[bagIndex] -= sumOfWeights;
453    totaL -= sumOfWeights;
454  }
455
456  /**
457   * Prints distribution.
458   */
459 
460  public final String dumpDistribution() {
461
462    StringBuffer text;
463    int i,j;
464
465    text = new StringBuffer();
466    for (i=0;i<m_perBag.length;i++) {
467      text.append("Bag num "+i+"\n");
468      for (j=0;j<m_perClass.length;j++)
469        text.append("Class num "+j+" "+m_perClassPerBag[i][j]+"\n");
470    }
471    return text.toString();
472  }
473
474  /**
475   * Sets all counts to zero.
476   */
477  public final void initialize() {
478
479    for (int i = 0; i < m_perClass.length; i++) 
480      m_perClass[i] = 0;
481    for (int i = 0; i < m_perBag.length; i++)
482      m_perBag[i] = 0;
483    for (int i = 0; i < m_perBag.length; i++)
484      for (int j = 0; j < m_perClass.length; j++)
485        m_perClassPerBag[i][j] = 0;
486    totaL = 0;
487  }
488
489  /**
490   * Returns matrix with distribution of class values.
491   */
492  public final double[][] matrix() {
493
494    return m_perClassPerBag;
495  }
496 
497  /**
498   * Returns index of bag containing maximum number of instances.
499   */
500  public final int maxBag() {
501
502    double max;
503    int maxIndex;
504    int i;
505   
506    max = 0;
507    maxIndex = -1;
508    for (i=0;i<m_perBag.length;i++)
509      if (Utils.grOrEq(m_perBag[i],max)) {
510        max = m_perBag[i];
511        maxIndex = i;
512      }
513    return maxIndex;
514  }
515
516  /**
517   * Returns class with highest frequency over all bags.
518   */
519  public final int maxClass() {
520
521    double maxCount = 0;
522    int maxIndex = 0;
523    int i;
524
525    for (i=0;i<m_perClass.length;i++)
526      if (Utils.gr(m_perClass[i],maxCount)) {
527        maxCount = m_perClass[i];
528        maxIndex = i;
529      }
530
531    return maxIndex;
532  }
533
534  /**
535   * Returns class with highest frequency for given bag.
536   */
537  public final int maxClass(int index) {
538
539    double maxCount = 0;
540    int maxIndex = 0;
541    int i;
542
543    if (Utils.gr(m_perBag[index],0)) {
544      for (i=0;i<m_perClass.length;i++)
545        if (Utils.gr(m_perClassPerBag[index][i],maxCount)) {
546          maxCount = m_perClassPerBag[index][i];
547          maxIndex = i;
548        }
549      return maxIndex;
550    }else
551      return maxClass();
552  }
553
554  /**
555   * Returns number of bags.
556   */
557  public final int numBags() {
558   
559    return m_perBag.length;
560  }
561
562  /**
563   * Returns number of classes.
564   */
565  public final int numClasses() {
566
567    return m_perClass.length;
568  }
569
570  /**
571   * Returns perClass(maxClass()).
572   */
573  public final double numCorrect() {
574
575    return m_perClass[maxClass()];
576  }
577
578  /**
579   * Returns perClassPerBag(index,maxClass(index)).
580   */
581  public final double numCorrect(int index) {
582
583    return m_perClassPerBag[index][maxClass(index)];
584  }
585
586  /**
587   * Returns total-numCorrect().
588   */
589  public final double numIncorrect() {
590
591    return totaL-numCorrect();
592  }
593
594  /**
595   * Returns perBag(index)-numCorrect(index).
596   */
597  public final double numIncorrect(int index) {
598
599    return m_perBag[index]-numCorrect(index);
600  }
601
602  /**
603   * Returns number of (possibly fractional) instances of given class in
604   * given bag.
605   */
606  public final double perClassPerBag(int bagIndex, int classIndex) {
607
608    return m_perClassPerBag[bagIndex][classIndex];
609  }
610
611  /**
612   * Returns number of (possibly fractional) instances in given bag.
613   */
614  public final double perBag(int bagIndex) {
615
616    return m_perBag[bagIndex];
617  }
618
619  /**
620   * Returns number of (possibly fractional) instances of given class.
621   */
622  public final double perClass(int classIndex) {
623
624    return m_perClass[classIndex];
625  }
626
627  /**
628   * Returns relative frequency of class over all bags with
629   * Laplace correction.
630   */
631  public final double laplaceProb(int classIndex) {
632
633    return (m_perClass[classIndex] + 1) / 
634      (totaL + (double) m_perClass.length);
635  }
636
637  /**
638   * Returns relative frequency of class for given bag.
639   */
640  public final double laplaceProb(int classIndex, int intIndex) {
641
642          if (Utils.gr(m_perBag[intIndex],0))
643                return (m_perClassPerBag[intIndex][classIndex] + 1.0) /
644                   (m_perBag[intIndex] + (double) m_perClass.length);
645          else
646            return laplaceProb(classIndex);
647         
648  }
649
650  /**
651   * Returns relative frequency of class over all bags.
652   */
653  public final double prob(int classIndex) {
654
655    if (!Utils.eq(totaL, 0)) {
656      return m_perClass[classIndex]/totaL;
657    } else {
658      return 0;
659    }
660  }
661
662  /**
663   * Returns relative frequency of class for given bag.
664   */
665  public final double prob(int classIndex,int intIndex) {
666
667    if (Utils.gr(m_perBag[intIndex],0))
668      return m_perClassPerBag[intIndex][classIndex]/m_perBag[intIndex];
669    else
670      return prob(classIndex);
671  }
672
673  /**
674   * Subtracts the given distribution from this one. The results
675   * has only one bag.
676   */
677  public final Distribution subtract(Distribution toSubstract) {
678
679    Distribution newDist = new Distribution(1,m_perClass.length);
680
681    newDist.m_perBag[0] = totaL-toSubstract.totaL;
682    newDist.totaL = newDist.m_perBag[0];
683    for (int i = 0; i < m_perClass.length; i++) {
684      newDist.m_perClassPerBag[0][i] = m_perClass[i] - toSubstract.m_perClass[i];
685      newDist.m_perClass[i] = newDist.m_perClassPerBag[0][i];
686    }
687    return newDist;
688  }
689
690  /**
691   * Returns total number of (possibly fractional) instances.
692   */
693  public final double total() {
694
695    return totaL;
696  }
697
698  /**
699   * Shifts given instance from one bag to another one.
700   *
701   * @exception Exception if something goes wrong
702   */
703  public final void shift(int from,int to,Instance instance) 
704       throws Exception {
705   
706    int classIndex;
707    double weight;
708
709    classIndex = (int)instance.classValue();
710    weight = instance.weight();
711    m_perClassPerBag[from][classIndex] -= weight;
712    m_perClassPerBag[to][classIndex] += weight;
713    m_perBag[from] -= weight;
714    m_perBag[to] += weight;
715  }
716
717  /**
718   * Shifts all instances in given range from one bag to another one.
719   *
720   * @exception Exception if something goes wrong
721   */
722  public final void shiftRange(int from,int to,Instances source,
723                               int startIndex,int lastPlusOne) 
724       throws Exception {
725   
726    int classIndex;
727    double weight;
728    Instance instance;
729    int i;
730
731    for (i = startIndex; i < lastPlusOne; i++) {
732      instance = (Instance) source.instance(i);
733      classIndex = (int)instance.classValue();
734      weight = instance.weight();
735      m_perClassPerBag[from][classIndex] -= weight;
736      m_perClassPerBag[to][classIndex] += weight;
737      m_perBag[from] -= weight;
738      m_perBag[to] += weight;
739    }
740  }
741 
742  /**
743   * Returns the revision string.
744   *
745   * @return            the revision
746   */
747  public String getRevision() {
748    return RevisionUtils.extract("$Revision: 1.12 $");
749  }
750}
Note: See TracBrowser for help on using the repository browser.