source: src/main/java/weka/classifiers/trees/j48/C45Split.java @ 12

Last change on this file since 12 was 4, checked in by gnappo, 14 years ago

Import di weka.

File size: 14.0 KB
Line 
1/*
2 *    This program is free software; you can redistribute it and/or modify
3 *    it under the terms of the GNU General Public License as published by
4 *    the Free Software Foundation; either version 2 of the License, or
5 *    (at your option) any later version.
6 *
7 *    This program is distributed in the hope that it will be useful,
8 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
9 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
10 *    GNU General Public License for more details.
11 *
12 *    You should have received a copy of the GNU General Public License
13 *    along with this program; if not, write to the Free Software
14 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
15 */
16
17/*
18 *    C45Split.java
19 *    Copyright (C) 1999 University of Waikato, Hamilton, New Zealand
20 *
21 */
22
23package weka.classifiers.trees.j48;
24
25import weka.core.Instance;
26import weka.core.Instances;
27import weka.core.RevisionUtils;
28import weka.core.Utils;
29
30import java.util.Enumeration;
31
32/**
33 * Class implementing a C4.5-type split on an attribute.
34 *
35 * @author Eibe Frank (eibe@cs.waikato.ac.nz)
36 * @version $Revision: 6073 $
37 */
38public class C45Split
39  extends ClassifierSplitModel{
40
41  /** for serialization */
42  private static final long serialVersionUID = 3064079330067903161L;
43
44  /** Desired number of branches. */
45  private int m_complexityIndex; 
46
47  /** Attribute to split on. */
48  private int m_attIndex;         
49
50  /** Minimum number of objects in a split.   */
51  private int m_minNoObj;         
52
53  /** Use MDL correction? */
54  private boolean m_useMDLcorrection;         
55
56  /** Value of split point. */
57  private double m_splitPoint;   
58
59  /** InfoGain of split. */ 
60  private double m_infoGain; 
61
62  /** GainRatio of split.  */
63  private double m_gainRatio; 
64
65  /** The sum of the weights of the instances. */
66  private double m_sumOfWeights; 
67
68  /** Number of split points. */
69  private int m_index;           
70
71  /** Static reference to splitting criterion. */
72  private static InfoGainSplitCrit infoGainCrit = new InfoGainSplitCrit();
73
74  /** Static reference to splitting criterion. */
75  private static GainRatioSplitCrit gainRatioCrit = new GainRatioSplitCrit();
76
77  /**
78   * Initializes the split model.
79   */
80  public C45Split(int attIndex,int minNoObj, double sumOfWeights,
81                  boolean useMDLcorrection) {
82
83    // Get index of attribute to split on.
84    m_attIndex = attIndex;
85       
86    // Set minimum number of objects.
87    m_minNoObj = minNoObj;
88
89    // Set the sum of the weights
90    m_sumOfWeights = sumOfWeights;
91
92    // Whether to use the MDL correction for numeric attributes
93    m_useMDLcorrection = useMDLcorrection;
94  }
95
96  /**
97   * Creates a C4.5-type split on the given data. Assumes that none of
98   * the class values is missing.
99   *
100   * @exception Exception if something goes wrong
101   */
102  public void buildClassifier(Instances trainInstances) 
103       throws Exception {
104
105    // Initialize the remaining instance variables.
106    m_numSubsets = 0;
107    m_splitPoint = Double.MAX_VALUE;
108    m_infoGain = 0;
109    m_gainRatio = 0;
110
111    // Different treatment for enumerated and numeric
112    // attributes.
113    if (trainInstances.attribute(m_attIndex).isNominal()) {
114      m_complexityIndex = trainInstances.attribute(m_attIndex).numValues();
115      m_index = m_complexityIndex;
116      handleEnumeratedAttribute(trainInstances);
117    }else{
118      m_complexityIndex = 2;
119      m_index = 0;
120      trainInstances.sort(trainInstances.attribute(m_attIndex));
121      handleNumericAttribute(trainInstances);
122    }
123  }   
124
125  /**
126   * Returns index of attribute for which split was generated.
127   */
128  public final int attIndex() {
129
130    return m_attIndex;
131  }
132 
133  /**
134   * Returns the split point (numeric attribute only).
135   *
136   * @return the split point used for a test on a numeric attribute
137   */
138  public double splitPoint() {
139    return m_splitPoint;
140  }
141
142  /**
143   * Gets class probability for instance.
144   *
145   * @exception Exception if something goes wrong
146   */
147  public final double classProb(int classIndex,Instance instance,
148                                int theSubset) throws Exception {
149
150    if (theSubset <= -1) {
151      double [] weights = weights(instance);
152      if (weights == null) {
153        return m_distribution.prob(classIndex);
154      } else {
155        double prob = 0;
156        for (int i = 0; i < weights.length; i++) {
157          prob += weights[i] * m_distribution.prob(classIndex, i);
158        }
159        return prob;
160      }
161    } else {
162      if (Utils.gr(m_distribution.perBag(theSubset), 0)) {
163        return m_distribution.prob(classIndex, theSubset);
164      } else {
165        return m_distribution.prob(classIndex);
166      }
167    }
168  }
169 
170  /**
171   * Returns coding cost for split (used in rule learner).
172   */
173  public final double codingCost() {
174
175    return Utils.log2(m_index);
176  }
177 
178  /**
179   * Returns (C4.5-type) gain ratio for the generated split.
180   */
181  public final double gainRatio() {
182    return m_gainRatio;
183  }
184
185  /**
186   * Creates split on enumerated attribute.
187   *
188   * @exception Exception if something goes wrong
189   */
190  private void handleEnumeratedAttribute(Instances trainInstances)
191       throws Exception {
192   
193    Instance instance;
194
195    m_distribution = new Distribution(m_complexityIndex,
196                              trainInstances.numClasses());
197   
198    // Only Instances with known values are relevant.
199    Enumeration enu = trainInstances.enumerateInstances();
200    while (enu.hasMoreElements()) {
201      instance = (Instance) enu.nextElement();
202      if (!instance.isMissing(m_attIndex))
203        m_distribution.add((int)instance.value(m_attIndex),instance);
204    }
205   
206    // Check if minimum number of Instances in at least two
207    // subsets.
208    if (m_distribution.check(m_minNoObj)) {
209      m_numSubsets = m_complexityIndex;
210      m_infoGain = infoGainCrit.
211        splitCritValue(m_distribution,m_sumOfWeights);
212      m_gainRatio = 
213        gainRatioCrit.splitCritValue(m_distribution,m_sumOfWeights,
214                                     m_infoGain);
215    }
216  }
217 
218  /**
219   * Creates split on numeric attribute.
220   *
221   * @exception Exception if something goes wrong
222   */
223  private void handleNumericAttribute(Instances trainInstances)
224       throws Exception {
225 
226    int firstMiss;
227    int next = 1;
228    int last = 0;
229    int splitIndex = -1;
230    double currentInfoGain;
231    double defaultEnt;
232    double minSplit;
233    Instance instance;
234    int i;
235
236    // Current attribute is a numeric attribute.
237    m_distribution = new Distribution(2,trainInstances.numClasses());
238   
239    // Only Instances with known values are relevant.
240    Enumeration enu = trainInstances.enumerateInstances();
241    i = 0;
242    while (enu.hasMoreElements()) {
243      instance = (Instance) enu.nextElement();
244      if (instance.isMissing(m_attIndex))
245        break;
246      m_distribution.add(1,instance);
247      i++;
248    }
249    firstMiss = i;
250       
251    // Compute minimum number of Instances required in each
252    // subset.
253    minSplit =  0.1*(m_distribution.total())/
254      ((double)trainInstances.numClasses());
255    if (Utils.smOrEq(minSplit,m_minNoObj)) 
256      minSplit = m_minNoObj;
257    else
258      if (Utils.gr(minSplit,25)) 
259        minSplit = 25;
260       
261    // Enough Instances with known values?
262    if (Utils.sm((double)firstMiss,2*minSplit))
263      return;
264   
265    // Compute values of criteria for all possible split
266    // indices.
267    defaultEnt = infoGainCrit.oldEnt(m_distribution);
268    while (next < firstMiss) {
269         
270      if (trainInstances.instance(next-1).value(m_attIndex)+1e-5 < 
271          trainInstances.instance(next).value(m_attIndex)) { 
272       
273        // Move class values for all Instances up to next
274        // possible split point.
275        m_distribution.shiftRange(1,0,trainInstances,last,next);
276       
277        // Check if enough Instances in each subset and compute
278        // values for criteria.
279        if (Utils.grOrEq(m_distribution.perBag(0),minSplit) &&
280            Utils.grOrEq(m_distribution.perBag(1),minSplit)) {
281          currentInfoGain = infoGainCrit.
282            splitCritValue(m_distribution,m_sumOfWeights,
283                           defaultEnt);
284          if (Utils.gr(currentInfoGain,m_infoGain)) {
285            m_infoGain = currentInfoGain;
286            splitIndex = next-1;
287          }
288          m_index++;
289        }
290        last = next;
291      }
292      next++;
293    }
294   
295    // Was there any useful split?
296    if (m_index == 0)
297      return;
298   
299    // Compute modified information gain for best split.
300    if (m_useMDLcorrection) {
301      m_infoGain = m_infoGain-(Utils.log2(m_index)/m_sumOfWeights);
302    }
303    if (Utils.smOrEq(m_infoGain,0))
304      return;
305   
306    // Set instance variables' values to values for
307    // best split.
308    m_numSubsets = 2;
309    m_splitPoint = 
310      (trainInstances.instance(splitIndex+1).value(m_attIndex)+
311       trainInstances.instance(splitIndex).value(m_attIndex))/2;
312
313    // In case we have a numerical precision problem we need to choose the
314    // smaller value
315    if (m_splitPoint == trainInstances.instance(splitIndex + 1).value(m_attIndex)) {
316      m_splitPoint = trainInstances.instance(splitIndex).value(m_attIndex);
317    }
318
319    // Restore distributioN for best split.
320    m_distribution = new Distribution(2,trainInstances.numClasses());
321    m_distribution.addRange(0,trainInstances,0,splitIndex+1);
322    m_distribution.addRange(1,trainInstances,splitIndex+1,firstMiss);
323
324    // Compute modified gain ratio for best split.
325    m_gainRatio = gainRatioCrit.
326      splitCritValue(m_distribution,m_sumOfWeights,
327                     m_infoGain);
328  }
329
330  /**
331   * Returns (C4.5-type) information gain for the generated split.
332   */
333  public final double infoGain() {
334
335    return m_infoGain;
336  }
337
338  /**
339   * Prints left side of condition..
340   *
341   * @param data training set.
342   */
343  public final String leftSide(Instances data) {
344
345    return data.attribute(m_attIndex).name();
346  }
347
348  /**
349   * Prints the condition satisfied by instances in a subset.
350   *
351   * @param index of subset
352   * @param data training set.
353   */
354  public final String rightSide(int index,Instances data) {
355
356    StringBuffer text;
357
358    text = new StringBuffer();
359    if (data.attribute(m_attIndex).isNominal())
360      text.append(" = "+
361                  data.attribute(m_attIndex).value(index));
362    else
363      if (index == 0)
364        text.append(" <= "+
365                    Utils.doubleToString(m_splitPoint,6));
366      else
367        text.append(" > "+
368                    Utils.doubleToString(m_splitPoint,6));
369    return text.toString();
370  }
371 
372  /**
373   * Returns a string containing java source code equivalent to the test
374   * made at this node. The instance being tested is called "i".
375   *
376   * @param index index of the nominal value tested
377   * @param data the data containing instance structure info
378   * @return a value of type 'String'
379   */
380  public final String sourceExpression(int index, Instances data) {
381
382    StringBuffer expr = null;
383    if (index < 0) {
384      return "i[" + m_attIndex + "] == null";
385    }
386    if (data.attribute(m_attIndex).isNominal()) {
387      expr = new StringBuffer("i[");
388      expr.append(m_attIndex).append("]");
389      expr.append(".equals(\"").append(data.attribute(m_attIndex)
390                                     .value(index)).append("\")");
391    } else {
392      expr = new StringBuffer("((Double) i[");
393      expr.append(m_attIndex).append("])");
394      if (index == 0) {
395        expr.append(".doubleValue() <= ").append(m_splitPoint);
396      } else {
397        expr.append(".doubleValue() > ").append(m_splitPoint);
398      }
399    }
400    return expr.toString();
401  } 
402
403  /**
404   * Sets split point to greatest value in given data smaller or equal to
405   * old split point.
406   * (C4.5 does this for some strange reason).
407   */
408  public final void setSplitPoint(Instances allInstances) {
409   
410    double newSplitPoint = -Double.MAX_VALUE;
411    double tempValue;
412    Instance instance;
413   
414    if ((allInstances.attribute(m_attIndex).isNumeric()) &&
415        (m_numSubsets > 1)) {
416      Enumeration enu = allInstances.enumerateInstances();
417      while (enu.hasMoreElements()) {
418        instance = (Instance) enu.nextElement();
419        if (!instance.isMissing(m_attIndex)) {
420          tempValue = instance.value(m_attIndex);
421          if (Utils.gr(tempValue,newSplitPoint) && 
422              Utils.smOrEq(tempValue,m_splitPoint))
423            newSplitPoint = tempValue;
424        }
425      }
426      m_splitPoint = newSplitPoint;
427    }
428  }
429 
430  /**
431   * Returns the minsAndMaxs of the index.th subset.
432   */
433  public final double [][] minsAndMaxs(Instances data, double [][] minsAndMaxs,
434                                       int index) {
435
436    double [][] newMinsAndMaxs = new double[data.numAttributes()][2];
437   
438    for (int i = 0; i < data.numAttributes(); i++) {
439      newMinsAndMaxs[i][0] = minsAndMaxs[i][0];
440      newMinsAndMaxs[i][1] = minsAndMaxs[i][1];
441      if (i == m_attIndex)
442        if (data.attribute(m_attIndex).isNominal())
443          newMinsAndMaxs[m_attIndex][1] = 1;
444        else
445          newMinsAndMaxs[m_attIndex][1-index] = m_splitPoint;
446    }
447
448    return newMinsAndMaxs;
449  }
450 
451  /**
452   * Sets distribution associated with model.
453   */
454  public void resetDistribution(Instances data) throws Exception {
455   
456    Instances insts = new Instances(data, data.numInstances());
457    for (int i = 0; i < data.numInstances(); i++) {
458      if (whichSubset(data.instance(i)) > -1) {
459        insts.add(data.instance(i));
460      }
461    }
462    Distribution newD = new Distribution(insts, this);
463    newD.addInstWithUnknown(data, m_attIndex);
464    m_distribution = newD;
465  }
466
467  /**
468   * Returns weights if instance is assigned to more than one subset.
469   * Returns null if instance is only assigned to one subset.
470   */
471  public final double [] weights(Instance instance) {
472   
473    double [] weights;
474    int i;
475   
476    if (instance.isMissing(m_attIndex)) {
477      weights = new double [m_numSubsets];
478      for (i=0;i<m_numSubsets;i++)
479        weights [i] = m_distribution.perBag(i)/m_distribution.total();
480      return weights;
481    }else{
482      return null;
483    }
484  }
485 
486  /**
487   * Returns index of subset instance is assigned to.
488   * Returns -1 if instance is assigned to more than one subset.
489   *
490   * @exception Exception if something goes wrong
491   */
492  public final int whichSubset(Instance instance) 
493       throws Exception {
494   
495    if (instance.isMissing(m_attIndex))
496      return -1;
497    else{
498      if (instance.attribute(m_attIndex).isNominal())
499        return (int)instance.value(m_attIndex);
500      else
501        if (Utils.smOrEq(instance.value(m_attIndex),m_splitPoint))
502          return 0;
503        else
504          return 1;
505    }
506  }
507 
508  /**
509   * Returns the revision string.
510   *
511   * @return            the revision
512   */
513  public String getRevision() {
514    return RevisionUtils.extract("$Revision: 6073 $");
515  }
516}
Note: See TracBrowser for help on using the repository browser.