source: src/main/java/weka/classifiers/trees/j48/GraftSplit.java @ 18

Last change on this file since 18 was 4, checked in by gnappo, 14 years ago

Import di weka.

File size: 15.6 KB
Line 
1/*
2 *    This program is free software; you can redistribute it and/or modify
3 *    it under the terms of the GNU General Public License as published by
4 *    the Free Software Foundation; either version 2 of the License, or
5 *    (at your option) any later version.
6 *
7 *    This program is distributed in the hope that it will be useful,
8 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
9 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
10 *    GNU General Public License for more details.
11 *
12 *    You should have received a copy of the GNU General Public License
13 *    along with this program; if not, write to the Free Software
14 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
15 */
16
17/*
18 *  GraftSplit.java
19 *  Copyright (C) 2007 Geoff Webb & Janice Boughton
20 *  a split object for nodes added to a tree during grafting.
21 *  (used in classifier J48g).
22 */
23
24package weka.classifiers.trees.j48;
25
26import weka.core.*;
27
28/**
29 * Class implementing a split for nodes added to a tree during grafting.
30 *
31 * @author Janice Boughton (jrbought@infotech.monash.edu.au)
32 * @version $Revision 1.0 $
33 */
34public class GraftSplit
35  extends ClassifierSplitModel
36  implements Comparable {
37
38  /** for serialzation. */
39  private static final long serialVersionUID = 722773260393182051L;
40
41  /** the distribution for graft values, from cases in atbop */
42  private Distribution m_graftdistro;
43       
44  /** the attribute we are splitting on */
45  private int m_attIndex;
46
47  /** value of split point (if numeric attribute) */
48  private double m_splitPoint;
49
50  /** dominant class of the subset specified by m_testType */
51  private int m_maxClass;
52
53  /** dominant class of the subset not specified by m_testType */
54  private int m_otherLeafMaxClass;
55
56  /** laplace value of the subset specified by m_testType for m_maxClass */
57  private double m_laplace;
58
59  /** leaf for the subset specified by m_testType */
60  private Distribution m_leafdistro;
61
62  /**
63   * type of test:
64   * 0: <= test
65   * 1: > test
66   * 2: = test
67   * 3: != test
68   */
69  private int m_testType;
70
71
72  /**
73   * constructor
74   *
75   * @param a the attribute to split on
76   * @param v the value of a where split occurs
77   * @param t the test type (0 is <=, 1 is >, 2 is =, 3 is !)
78   * @param c the class to label the leaf node pointed to by test as.
79   * @param l the laplace value (needed when sorting GraftSplits)
80   */
81  public GraftSplit(int a, double v, int t, double c, double l) {
82
83    m_attIndex = a;
84    m_splitPoint = v;
85    m_testType = t;
86    m_maxClass = (int)c;
87    m_laplace = l;
88  }
89
90
91  /**
92   * constructor
93   *
94   * @param a the attribute to split on
95   * @param v the value of a where split occurs
96   * @param t the test type (0 is <=, 1 is >, 2 is =, 3 is !=)
97   * @param oC the class to label the leaf node not pointed to by test as.
98   * @param counts the distribution for this split
99   */
100  public GraftSplit(int a, double v, int t, double oC, double [][] counts)
101                                                           throws Exception {
102    m_attIndex = a;
103    m_splitPoint = v;
104    m_testType = t;
105    m_otherLeafMaxClass = (int)oC;
106
107    // only deal with binary cuts (<= and >; = and !=)
108    m_numSubsets = 2;
109
110    // which subset are we looking at for the graft?
111    int subset = subsetOfInterest();  // this is the subset for m_leaf
112
113    // create graft distribution, based on counts
114    m_distribution = new Distribution(counts);
115
116    // create a distribution object for m_leaf
117    double [][] lcounts = new double[1][m_distribution.numClasses()];
118    for(int c = 0; c < lcounts[0].length; c++) {
119       lcounts[0][c] = counts[subset][c];
120    }
121    m_leafdistro = new Distribution(lcounts);
122
123    // set the max class
124    m_maxClass = m_distribution.maxClass(subset);
125 
126    // set the laplace value (assumes binary class) for subset of interest
127    m_laplace = (m_distribution.perClassPerBag(subset, m_maxClass) + 1.0) 
128               / (m_distribution.perBag(subset) + 2.0);
129  }
130
131
132  /**
133   * deletes the cases in data that belong to leaf pointed to by
134   * the test (i.e. the subset of interest).  this is useful so
135   * the instances belonging to that leaf aren't passed down the
136   * other branch.
137   *
138   * @param data the instances to delete from
139   */
140  public void deleteGraftedCases(Instances data) {
141
142    int subOfInterest = subsetOfInterest();
143    for(int x = 0; x < data.numInstances(); x++) {
144       if(whichSubset(data.instance(x)) == subOfInterest) {
145          data.delete(x--);
146       }
147    }
148  }
149
150
151  /**
152   * builds m_graftdistro using the passed data
153   *
154   * @param data the instances to use when creating the distribution
155   */
156  public void buildClassifier(Instances data) throws Exception {
157
158    // distribution for the graft, not counting cases in atbop, only orig leaf
159    m_graftdistro = new Distribution(2, data.numClasses());
160 
161    // which subset are we looking at for the graft?
162    int subset = subsetOfInterest();  // this is the subset for m_leaf
163
164    double thisNodeCount = 0;
165    double knownCases = 0;
166    boolean allKnown = true;
167    // populate distribution
168    for(int x = 0; x < data.numInstances(); x++) {
169       Instance instance = data.instance(x);
170       if(instance.isMissing(m_attIndex)) {
171          allKnown = false;
172          continue;
173       }
174       knownCases += instance.weight();
175       int subst = whichSubset(instance);
176       if(subst == -1)
177          continue;
178       m_graftdistro.add(subst, instance);
179       if(subst == subset) {  // instance belongs at m_leaf
180          thisNodeCount += instance.weight();
181       }
182    }
183    double factor = (knownCases == 0) ? (1.0 / (double)2.0)
184                                      : (thisNodeCount / knownCases);
185    if(!allKnown) {
186       for(int x = 0; x < data.numInstances(); x++) {
187          if(data.instance(x).isMissing(m_attIndex)) {
188             Instance instance = data.instance(x);
189             int subst = whichSubset(instance);
190             if(subst == -1)
191                continue;
192             instance.setWeight(instance.weight() * factor);
193             m_graftdistro.add(subst, instance);
194          }
195       }
196    }
197
198    // if there are no cases at the leaf, make sure the desired
199    // class is chosen, by setting counts to 0.01
200    if(m_graftdistro.perBag(subset) == 0) {
201       double [] counts = new double[data.numClasses()];
202       counts[m_maxClass] = 0.01;
203       m_graftdistro.add(subset, counts);
204    }
205    if(m_graftdistro.perBag((subset == 0) ? 1 : 0) == 0) {
206       double [] counts = new double[data.numClasses()];
207       counts[(int)m_otherLeafMaxClass] = 0.01;
208       m_graftdistro.add((subset == 0) ? 1 : 0, counts);
209    }
210  }
211
212
213  /**
214   * @return the NoSplit object for the leaf pointed to by m_testType branch
215   */
216  public NoSplit getLeaf() {
217    return new NoSplit(m_leafdistro);
218  }
219
220
221  /**
222   * @return the NoSplit object for the leaf not pointed to by m_testType branch
223   */
224  public NoSplit getOtherLeaf() {
225
226    // the bag (subset) that isn't pointed to by m_testType branch
227    int bag = (subsetOfInterest() == 0) ? 1 : 0;
228
229    double [][] counts = new double[1][m_graftdistro.numClasses()];
230    double totals = 0;
231    for(int c = 0; c < counts[0].length; c++) {
232       counts[0][c] = m_graftdistro.perClassPerBag(bag, c);
233       totals += counts[0][c];
234    }
235    // if empty, make sure proper class gets chosen
236    if(totals == 0) {
237       counts[0][m_otherLeafMaxClass] += 0.01;
238    }
239    return new NoSplit(new Distribution(counts));
240  }
241
242
243  /**
244   * Prints label for subset index of instances (eg class).
245   *
246   * @param index the bag to dump label for
247   * @param data to get attribute names and such
248   * @return the label as a string
249   * @exception Exception if something goes wrong
250   */
251  public final String dumpLabelG(int index, Instances data) throws Exception {
252
253    StringBuffer text;
254
255    text = new StringBuffer();
256    text.append(((Instances)data).classAttribute().
257       value((index==subsetOfInterest()) ? m_maxClass : m_otherLeafMaxClass));
258    text.append(" ("+Utils.roundDouble(m_graftdistro.perBag(index),1));
259    if(Utils.gr(m_graftdistro.numIncorrect(index),0))
260       text.append("/"
261        +Utils.roundDouble(m_graftdistro.numIncorrect(index),2));
262
263    // show the graft values, only if this is subsetOfInterest()
264    if(index == subsetOfInterest()) {
265       text.append("|"+Utils.roundDouble(m_distribution.perBag(index),2));
266       if(Utils.gr(m_distribution.numIncorrect(index),0))
267          text.append("/"
268             +Utils.roundDouble(m_distribution.numIncorrect(index),2));
269    }
270    text.append(")");
271    return text.toString();
272  }
273
274
275  /**
276   * @return the subset that is specified by the test type
277   */
278  public int subsetOfInterest() {
279    if(m_testType == 2)
280       return 0;
281    if(m_testType == 3)
282       return 1;
283    return m_testType;
284  }
285
286
287  /**
288   * @return the number of positive cases in the subset of interest
289   */
290  public double positivesForSubsetOfInterest() {
291    return (m_distribution.perClassPerBag(subsetOfInterest(), m_maxClass));
292  }
293
294
295  /**
296   * @param subset the subset to get the positives for
297   * @return the number of positive cases in the specified subset
298   */
299  public double positives(int subset) {
300    return (m_distribution.perClassPerBag(subset, 
301                                    m_distribution.maxClass(subset)));
302  }
303
304
305  /**
306   * @return the number of instances in the subset of interest
307   */
308  public double totalForSubsetOfInterest() {
309    return (m_distribution.perBag(subsetOfInterest()));
310  }
311
312 
313  /**
314   * @param subset the index of the bag to get the total for
315   * @return the number of instances in the subset
316   */
317  public double totalForSubset(int subset) {
318    return (m_distribution.perBag(subset));
319  }
320
321
322  /**
323   * Prints left side of condition satisfied by instances.
324   *
325   * @param data the data.
326   */
327  public String leftSide(Instances data) {
328    return data.attribute(m_attIndex).name();
329  }
330
331
332  /**
333   * @return the index of the attribute to split on
334   */ 
335  public int attribute() {
336    return m_attIndex;
337  }
338
339
340  /**
341   * Prints condition satisfied by instances in subset index.
342   */
343  public final String rightSide(int index, Instances data) {
344
345    StringBuffer text;
346
347    text = new StringBuffer();
348    if(data.attribute(m_attIndex).isNominal())
349       if(index == 0)
350          text.append(" = "+
351                      data.attribute(m_attIndex).value((int)m_splitPoint));
352       else
353          text.append(" != "+
354                      data.attribute(m_attIndex).value((int)m_splitPoint));
355    else
356       if(index == 0)
357          text.append(" <= "+
358                      Utils.doubleToString(m_splitPoint,6));
359       else
360          text.append(" > "+
361                      Utils.doubleToString(m_splitPoint,6));
362    return text.toString();
363  }
364
365
366  /**
367   * Returns a string containing java source code equivalent to the test
368   * made at this node. The instance being tested is called "i".
369   *
370   * @param index index of the nominal value tested
371   * @param data the data containing instance structure info
372   * @return a value of type 'String'
373   */
374  public final String sourceExpression(int index, Instances data) {
375
376    StringBuffer expr = null;
377    if(index < 0) {
378       return "i[" + m_attIndex + "] == null";
379    }
380    if(data.attribute(m_attIndex).isNominal()) {
381       if(index == 0)
382          expr = new StringBuffer("i[");
383       else
384          expr = new StringBuffer("!i[");
385       expr.append(m_attIndex).append("]");
386       expr.append(".equals(\"").append(data.attribute(m_attIndex)
387                                      .value((int)m_splitPoint)).append("\")");
388    } else {
389       expr = new StringBuffer("((Double) i[");
390       expr.append(m_attIndex).append("])");
391       if(index == 0) {
392          expr.append(".doubleValue() <= ").append(m_splitPoint);
393       } else {
394          expr.append(".doubleValue() > ").append(m_splitPoint);
395       }
396    }
397    return expr.toString();
398  }
399
400
401  /**
402   * @param instance the instance to produce the weights for
403   * @return a double array of weights, null if only belongs to one subset
404   */
405  public double [] weights(Instance instance) {
406
407    double [] weights;
408    int i;
409
410    if(instance.isMissing(m_attIndex)) {
411       weights = new double [m_numSubsets];
412       for(i=0;i<m_numSubsets;i++) {
413          weights [i] = m_graftdistro.perBag(i)/m_graftdistro.total();
414       }
415       return weights;
416    } else {
417       return null;
418    }
419  }
420
421
422  /**
423   * @param instance the instance for which to determine the subset
424   * @return an int indicating the subset this instance belongs to
425   */
426  public int whichSubset(Instance instance) {
427
428    if(instance.isMissing(m_attIndex))
429       return -1;
430
431    if(instance.attribute(m_attIndex).isNominal()) {
432       // in the case of nominal, m_splitPoint is the = value, all else is !=
433       if(instance.value(m_attIndex) == m_splitPoint)
434          return 0;
435       else
436          return 1;
437    } else {
438       if(Utils.smOrEq(instance.value(m_attIndex), m_splitPoint))
439          return 0;
440       else
441          return 1;
442    }
443  }
444
445
446  /**
447   * @return the value of the split point
448   */
449  public double splitPoint() {
450    return m_splitPoint;
451  }
452
453  /**
454   * @return the dominate class for the subset of interest
455   */
456  public int maxClassForSubsetOfInterest() {
457    return m_maxClass;
458  }
459
460  /**
461   * @return the laplace value for maxClass of subset of interest
462   */
463  public double laplaceForSubsetOfInterest() {
464    return m_laplace;
465  }
466
467  /**
468   * returns the test type
469   * @return value of testtype
470   */
471  public int testType() {
472    return m_testType;
473  }
474
475  /**
476   * method needed for sorting a collection of GraftSplits by laplace value
477   * @param g the graft split to compare to this one
478   * @return -1, 0, or 1 if this GraftSplit laplace is <, = or > than that of g
479   */
480  public int compareTo(Object g) {
481
482    if(m_laplace > ((GraftSplit)g).laplaceForSubsetOfInterest())
483       return 1;
484    if(m_laplace < ((GraftSplit)g).laplaceForSubsetOfInterest())
485       return -1;
486    return 0;
487  }
488
489  /**
490   * returns the probability for instance for the specified class
491   * @param classIndex the index of the class
492   * @param instance the instance to get the probability for
493   * @param theSubset the subset
494   */
495  public final double classProb(int classIndex, Instance instance, 
496                            int theSubset) throws Exception {
497
498    if (theSubset <= -1) {
499       double [] weights = weights(instance);
500       if (weights == null) {
501          return m_distribution.prob(classIndex);
502       } else {
503          double prob = 0;
504          for (int i = 0; i < weights.length; i++) {
505             prob += weights[i] * m_distribution.prob(classIndex, i);
506          }
507          return prob;
508       }
509    } else {
510       if (Utils.gr(m_distribution.perBag(theSubset), 0)) {
511          return m_distribution.prob(classIndex, theSubset);
512       } else {
513          return m_distribution.prob(classIndex);
514       }
515    }
516  }
517
518
519  /**
520   * method for returning information about this GraftSplit
521   * @param data instances for determining names of attributes and values
522   * @return a string showing this GraftSplit's information
523   */
524  public String toString(Instances data) {
525
526    String theTest;
527    if(m_testType == 0)
528       theTest = " <= ";
529    else if(m_testType == 1)
530       theTest = " > ";
531    else if(m_testType == 2)
532       theTest = " = ";
533    else
534       theTest = " != ";
535
536    if(data.attribute(m_attIndex).isNominal())
537       theTest += data.attribute(m_attIndex).value((int)m_splitPoint);
538    else
539       theTest += Double.toString(m_splitPoint);
540
541    return data.attribute(m_attIndex).name() + theTest
542           + " (" + Double.toString(m_laplace) + ") --> " 
543           + data.attribute(data.classIndex()).value(m_maxClass);
544  }
545 
546  /**
547   * Returns the revision string.
548   *
549   * @return            the revision
550   */
551  public String getRevision() {
552    return RevisionUtils.extract("$Revision: 1.2 $");
553  }
554}
Note: See TracBrowser for help on using the repository browser.