source: src/main/java/weka/classifiers/trees/j48/BinC45Split.java @ 21

Last change on this file since 21 was 4, checked in by gnappo, 14 years ago

Import di weka. (Italian commit message: "Import of Weka.")

File size: 13.7 KB
Line 
1/*
2 *    This program is free software; you can redistribute it and/or modify
3 *    it under the terms of the GNU General Public License as published by
4 *    the Free Software Foundation; either version 2 of the License, or
5 *    (at your option) any later version.
6 *
7 *    This program is distributed in the hope that it will be useful,
8 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
9 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
10 *    GNU General Public License for more details.
11 *
12 *    You should have received a copy of the GNU General Public License
13 *    along with this program; if not, write to the Free Software
14 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
15 */
16
17/*
18 *    BinC45Split.java
19 *    Copyright (C) 1999 University of Waikato, Hamilton, New Zealand
20 *
21 */
22
23package weka.classifiers.trees.j48;
24
25import weka.core.Instance;
26import weka.core.Instances;
27import weka.core.RevisionUtils;
28import weka.core.Utils;
29
30import java.util.Enumeration;
31
32/**
33 * Class implementing a binary C4.5-like split on an attribute.
34 *
35 * @author Eibe Frank (eibe@cs.waikato.ac.nz)
36 * @version $Revision: 6073 $
37 */
38public class BinC45Split
39  extends ClassifierSplitModel {
40
41  /** for serialization */
42  private static final long serialVersionUID = -1278776919563022474L;
43
44  /** Attribute to split on. */
45  private int m_attIndex;       
46
47  /** Minimum number of objects in a split.   */ 
48  private int m_minNoObj;         
49
50  /** Use MDL correction? */
51  private boolean m_useMDLcorrection;         
52
53  /** Value of split point. */
54  private double m_splitPoint; 
55
56  /** InfoGain of split. */
57  private double m_infoGain; 
58
59  /** GainRatio of split.  */
60  private double m_gainRatio; 
61
62  /** The sum of the weights of the instances. */
63  private double m_sumOfWeights; 
64
65  /** Static reference to splitting criterion. */
66  private static InfoGainSplitCrit m_infoGainCrit = new InfoGainSplitCrit();
67
68  /** Static reference to splitting criterion. */
69  private static GainRatioSplitCrit m_gainRatioCrit = new GainRatioSplitCrit();
70
71  /**
72   * Initializes the split model.
73   */
74  public BinC45Split(int attIndex,int minNoObj,double sumOfWeights,
75                     boolean useMDLcorrection) {
76
77    // Get index of attribute to split on.
78    m_attIndex = attIndex;
79       
80    // Set minimum number of objects.
81    m_minNoObj = minNoObj;
82
83    // Set sum of weights;
84    m_sumOfWeights = sumOfWeights;
85
86    // Whether to use the MDL correction for numeric attributes
87    m_useMDLcorrection = useMDLcorrection;
88  }
89
90  /**
91   * Creates a C4.5-type split on the given data.
92   *
93   * @exception Exception if something goes wrong
94   */
95  public void buildClassifier(Instances trainInstances)
96       throws Exception {
97
98    // Initialize the remaining instance variables.
99    m_numSubsets = 0;
100    m_splitPoint = Double.MAX_VALUE;
101    m_infoGain = 0;
102    m_gainRatio = 0;
103
104    // Different treatment for enumerated and numeric
105    // attributes.
106    if (trainInstances.attribute(m_attIndex).isNominal()){
107      handleEnumeratedAttribute(trainInstances);
108    }else{
109      trainInstances.sort(trainInstances.attribute(m_attIndex));
110      handleNumericAttribute(trainInstances);
111    }
112  }   
113
114  /**
115   * Returns index of attribute for which split was generated.
116   */
117  public final int attIndex(){
118
119    return m_attIndex;
120  }
121 
122  /**
123   * Returns the split point (numeric attribute only).
124   *
125   * @return the split point used for a test on a numeric attribute
126   */
127  public double splitPoint() {
128    return m_splitPoint;
129  }
130 
131  /**
132   * Returns (C4.5-type) gain ratio for the generated split.
133   */
134  public final double gainRatio(){
135    return m_gainRatio;
136  }
137
138  /**
139   * Gets class probability for instance.
140   *
141   * @exception Exception if something goes wrong
142   */
143  public final double classProb(int classIndex,Instance instance,
144                                int theSubset) throws Exception {
145
146    if (theSubset <= -1) {
147      double [] weights = weights(instance);
148      if (weights == null) {
149        return m_distribution.prob(classIndex);
150      } else {
151        double prob = 0;
152        for (int i = 0; i < weights.length; i++) {
153          prob += weights[i] * m_distribution.prob(classIndex, i);
154        }
155        return prob;
156      }
157    } else {
158      if (Utils.gr(m_distribution.perBag(theSubset), 0)) {
159        return m_distribution.prob(classIndex, theSubset);
160      } else {
161        return m_distribution.prob(classIndex);
162      }
163    }
164  }
165 
166  /**
167   * Creates split on enumerated attribute.
168   *
169   * @exception Exception if something goes wrong
170   */
171  private void handleEnumeratedAttribute(Instances trainInstances)
172       throws Exception {
173   
174    Distribution newDistribution,secondDistribution;
175    int numAttValues;
176    double currIG,currGR;
177    Instance instance;
178    int i;
179
180    numAttValues = trainInstances.attribute(m_attIndex).numValues();
181    newDistribution = new Distribution(numAttValues,
182                                       trainInstances.numClasses());
183   
184    // Only Instances with known values are relevant.
185    Enumeration enu = trainInstances.enumerateInstances();
186    while (enu.hasMoreElements()) {
187      instance = (Instance) enu.nextElement();
188      if (!instance.isMissing(m_attIndex))
189        newDistribution.add((int)instance.value(m_attIndex),instance);
190    }
191    m_distribution = newDistribution;
192
193    // For all values
194    for (i = 0; i < numAttValues; i++){
195
196      if (Utils.grOrEq(newDistribution.perBag(i),m_minNoObj)){
197        secondDistribution = new Distribution(newDistribution,i);
198       
199        // Check if minimum number of Instances in the two
200        // subsets.
201        if (secondDistribution.check(m_minNoObj)){
202          m_numSubsets = 2;
203          currIG = m_infoGainCrit.splitCritValue(secondDistribution,
204                                               m_sumOfWeights);
205          currGR = m_gainRatioCrit.splitCritValue(secondDistribution,
206                                                m_sumOfWeights,
207                                                currIG);
208          if ((i == 0) || Utils.gr(currGR,m_gainRatio)){
209            m_gainRatio = currGR;
210            m_infoGain = currIG;
211            m_splitPoint = (double)i;
212            m_distribution = secondDistribution;
213          }
214        }
215      }
216    }
217  }
218 
219  /**
220   * Creates split on numeric attribute.
221   *
222   * @exception Exception if something goes wrong
223   */
224  private void handleNumericAttribute(Instances trainInstances)
225       throws Exception {
226 
227    int firstMiss;
228    int next = 1;
229    int last = 0;
230    int index = 0;
231    int splitIndex = -1;
232    double currentInfoGain;
233    double defaultEnt;
234    double minSplit;
235    Instance instance;
236    int i;
237
238    // Current attribute is a numeric attribute.
239    m_distribution = new Distribution(2,trainInstances.numClasses());
240   
241    // Only Instances with known values are relevant.
242    Enumeration enu = trainInstances.enumerateInstances();
243    i = 0;
244    while (enu.hasMoreElements()) {
245      instance = (Instance) enu.nextElement();
246      if (instance.isMissing(m_attIndex))
247        break;
248      m_distribution.add(1,instance);
249      i++;
250    }
251    firstMiss = i;
252
253    // Compute minimum number of Instances required in each
254    // subset.
255    minSplit =  0.1*(m_distribution.total())/
256      ((double)trainInstances.numClasses());
257    if (Utils.smOrEq(minSplit,m_minNoObj)) 
258      minSplit = m_minNoObj;
259    else
260      if (Utils.gr(minSplit,25)) 
261        minSplit = 25;
262
263    // Enough Instances with known values?
264    if (Utils.sm((double)firstMiss,2*minSplit))
265      return;
266   
267    // Compute values of criteria for all possible split
268    // indices.
269    defaultEnt = m_infoGainCrit.oldEnt(m_distribution);
270    while (next < firstMiss){
271         
272      if (trainInstances.instance(next-1).value(m_attIndex)+1e-5 < 
273          trainInstances.instance(next).value(m_attIndex)){ 
274       
275        // Move class values for all Instances up to next
276        // possible split point.
277        m_distribution.shiftRange(1,0,trainInstances,last,next);
278       
279        // Check if enough Instances in each subset and compute
280        // values for criteria.
281        if (Utils.grOrEq(m_distribution.perBag(0),minSplit) && 
282            Utils.grOrEq(m_distribution.perBag(1),minSplit)){
283          currentInfoGain = m_infoGainCrit.
284            splitCritValue(m_distribution,m_sumOfWeights,
285                           defaultEnt);
286          if (Utils.gr(currentInfoGain,m_infoGain)){
287            m_infoGain = currentInfoGain;
288            splitIndex = next-1;
289          }
290          index++;
291        }
292        last = next;
293      }
294      next++;
295    }
296   
297    // Was there any useful split?
298    if (index == 0)
299      return;
300   
301    // Compute modified information gain for best split.
302    if (m_useMDLcorrection) {
303      m_infoGain = m_infoGain-(Utils.log2(index)/m_sumOfWeights);
304    }
305    if (Utils.smOrEq(m_infoGain,0))
306      return;
307   
308    // Set instance variables' values to values for
309    // best split.
310    m_numSubsets = 2;
311    m_splitPoint = 
312      (trainInstances.instance(splitIndex+1).value(m_attIndex)+
313       trainInstances.instance(splitIndex).value(m_attIndex))/2;
314
315    // In case we have a numerical precision problem we need to choose the
316    // smaller value
317    if (m_splitPoint == trainInstances.instance(splitIndex + 1).value(m_attIndex)) {
318      m_splitPoint = trainInstances.instance(splitIndex).value(m_attIndex);
319    }
320
321    // Restore distributioN for best split.
322    m_distribution = new Distribution(2,trainInstances.numClasses());
323    m_distribution.addRange(0,trainInstances,0,splitIndex+1);
324    m_distribution.addRange(1,trainInstances,splitIndex+1,firstMiss);
325
326    // Compute modified gain ratio for best split.
327    m_gainRatio = m_gainRatioCrit.
328      splitCritValue(m_distribution,m_sumOfWeights,
329                     m_infoGain);
330  }
331
332  /**
333   * Returns (C4.5-type) information gain for the generated split.
334   */
335  public final double infoGain(){
336
337    return m_infoGain;
338  }
339
340  /**
341   * Prints left side of condition.
342   *
343   * @param data the data to get the attribute name from.
344   * @return the attribute name
345   */
346  public final String leftSide(Instances data){
347
348    return data.attribute(m_attIndex).name();
349  }
350
351  /**
352   * Prints the condition satisfied by instances in a subset.
353   *
354   * @param index of subset and training set.
355   */
356  public final String rightSide(int index,Instances data){
357
358    StringBuffer text;
359
360    text = new StringBuffer();
361    if (data.attribute(m_attIndex).isNominal()){
362      if (index == 0)
363        text.append(" = "+
364                    data.attribute(m_attIndex).value((int)m_splitPoint));
365      else
366        text.append(" != "+
367                    data.attribute(m_attIndex).value((int)m_splitPoint));
368    }else
369      if (index == 0)
370        text.append(" <= "+m_splitPoint);
371      else
372        text.append(" > "+m_splitPoint);
373   
374    return text.toString();
375  }
376
377  /**
378   * Returns a string containing java source code equivalent to the test
379   * made at this node. The instance being tested is called "i".
380   *
381   * @param index index of the nominal value tested
382   * @param data the data containing instance structure info
383   * @return a value of type 'String'
384   */
385  public final String sourceExpression(int index, Instances data) {
386
387    StringBuffer expr = null;
388    if (index < 0) {
389      return "i[" + m_attIndex + "] == null";
390    }
391    if (data.attribute(m_attIndex).isNominal()) {
392      if (index == 0) {
393        expr = new StringBuffer("i[");
394      } else {
395        expr = new StringBuffer("!i[");
396      }
397      expr.append(m_attIndex).append("]");
398      expr.append(".equals(\"").append(data.attribute(m_attIndex)
399                                     .value((int)m_splitPoint)).append("\")");
400    } else {
401      expr = new StringBuffer("((Double) i[");
402      expr.append(m_attIndex).append("])");
403      if (index == 0) {
404        expr.append(".doubleValue() <= ").append(m_splitPoint);
405      } else {
406        expr.append(".doubleValue() > ").append(m_splitPoint);
407      }
408    }
409    return expr.toString();
410  } 
411
412  /**
413   * Sets split point to greatest value in given data smaller or equal to
414   * old split point.
415   * (C4.5 does this for some strange reason).
416   */
417  public final void setSplitPoint(Instances allInstances){
418   
419    double newSplitPoint = -Double.MAX_VALUE;
420    double tempValue;
421    Instance instance;
422   
423    if ((!allInstances.attribute(m_attIndex).isNominal()) &&
424        (m_numSubsets > 1)){
425      Enumeration enu = allInstances.enumerateInstances();
426      while (enu.hasMoreElements()) {
427        instance = (Instance) enu.nextElement();
428        if (!instance.isMissing(m_attIndex)){
429          tempValue = instance.value(m_attIndex);
430          if (Utils.gr(tempValue,newSplitPoint) && 
431              Utils.smOrEq(tempValue,m_splitPoint))
432            newSplitPoint = tempValue;
433        }
434      }
435      m_splitPoint = newSplitPoint;
436    }
437  }
438 
439  /**
440   * Sets distribution associated with model.
441   */
442  public void resetDistribution(Instances data) throws Exception {
443
444    Instances insts = new Instances(data, data.numInstances());
445    for (int i = 0; i < data.numInstances(); i++) {
446      if (whichSubset(data.instance(i)) > -1) {
447        insts.add(data.instance(i));
448      }
449    }
450    Distribution newD = new Distribution(insts, this);
451    newD.addInstWithUnknown(data, m_attIndex);
452    m_distribution = newD;
453  }
454
455  /**
456   * Returns weights if instance is assigned to more than one subset.
457   * Returns null if instance is only assigned to one subset.
458   */
459  public final double [] weights(Instance instance){
460   
461    double [] weights;
462    int i;
463   
464    if (instance.isMissing(m_attIndex)){
465      weights = new double [m_numSubsets];
466      for (i=0;i<m_numSubsets;i++)
467        weights [i] = m_distribution.perBag(i)/m_distribution.total();
468      return weights;
469    }else{
470      return null;
471    }
472  }
473 
474  /**
475   * Returns index of subset instance is assigned to.
476   * Returns -1 if instance is assigned to more than one subset.
477   *
478   * @exception Exception if something goes wrong
479   */
480
481  public final int whichSubset(Instance instance) throws Exception {
482   
483    if (instance.isMissing(m_attIndex))
484      return -1;
485    else{
486      if (instance.attribute(m_attIndex).isNominal()){
487        if ((int)m_splitPoint == (int)instance.value(m_attIndex))
488          return 0;
489        else
490          return 1;
491      }else
492        if (Utils.smOrEq(instance.value(m_attIndex),m_splitPoint))
493          return 0;
494        else
495          return 1;
496    }
497  }
498 
499  /**
500   * Returns the revision string.
501   *
502   * @return            the revision
503   */
504  public String getRevision() {
505    return RevisionUtils.extract("$Revision: 6073 $");
506  }
507}
Note: See TracBrowser for help on using the repository browser.