source: src/main/java/weka/classifiers/trees/lmt/ResidualSplit.java @ 17

Last change on this file since 17 was 4, checked in by gnappo, 14 years ago

Import di weka.

File size: 8.3 KB
Line 
/*
 *    This program is free software; you can redistribute it and/or modify
 *    it under the terms of the GNU General Public License as published by
 *    the Free Software Foundation; either version 2 of the License, or
 *    (at your option) any later version.
 *
 *    This program is distributed in the hope that it will be useful,
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *    GNU General Public License for more details.
 *
 *    You should have received a copy of the GNU General Public License
 *    along with this program; if not, write to the Free Software
 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 */

/*
 *    ResidualSplit.java
 *    Copyright (C) 2003 University of Waikato, Hamilton, New Zealand
 *
 */
22
23package weka.classifiers.trees.lmt;
24
25import weka.classifiers.trees.j48.ClassifierSplitModel;
26import weka.classifiers.trees.j48.Distribution;
27import weka.core.Attribute;
28import weka.core.Instance;
29import weka.core.Instances;
30import weka.core.RevisionUtils;
31import weka.core.Utils;
32
33/**
34 * Helper class for logistic model trees (weka.classifiers.trees.lmt.LMT) to implement the
35 * splitting criterion based on residuals of the LogitBoost algorithm.
36 *
37 * @author Niels Landwehr
38 * @version $Revision: 1.4 $
39 */
40public class ResidualSplit
41  extends ClassifierSplitModel{
42
43  /** for serialization */
44  private static final long serialVersionUID = -5055883734183713525L;
45 
46  /**The attribute selected for the split*/
47  protected Attribute m_attribute;
48
49  /**The index of the attribute selected for the split*/
50  protected int m_attIndex;
51
52  /**Number of instances in the set*/
53  protected int m_numInstances;
54
55  /**Number of classed*/
56  protected int m_numClasses;
57
58  /**The set of instances*/
59  protected Instances m_data;
60
61  /**The Z-values (LogitBoost response) for the set of instances*/
62  protected double[][] m_dataZs;
63
64  /**The LogitBoost-weights for the set of instances*/
65  protected double[][] m_dataWs; 
66
67  /**The split point (for numeric attributes)*/
68  protected double m_splitPoint;
69
70  /**
71   *Creates a split object
72   *@param attIndex the index of the attribute to split on
73   */   
74  public ResidualSplit(int attIndex) { 
75    m_attIndex = attIndex;             
76  }
77
78  /**
79   * Builds the split.
80   * Needs the Z/W values of LogitBoost for the set of instances.
81   */
82  public void buildClassifier(Instances data, double[][] dataZs, double[][] dataWs) 
83    throws Exception {
84
85    m_numClasses = data.numClasses();   
86    m_numInstances = data.numInstances();
87    if (m_numInstances == 0) throw new Exception("Can't build split on 0 instances");
88
89    //save data/Zs/Ws
90    m_data = data;
91    m_dataZs = dataZs;
92    m_dataWs = dataWs;
93    m_attribute = data.attribute(m_attIndex);
94
95    //determine number of subsets and split point for numeric attributes
96    if (m_attribute.isNominal()) {
97      m_splitPoint = 0.0;
98      m_numSubsets = m_attribute.numValues();
99    } else {
100      getSplitPoint();
101      m_numSubsets = 2;
102    }
103    //create distribution for data
104    m_distribution = new Distribution(data, this);     
105  }
106
107
108  /**
109   * Selects split point for numeric attribute.
110   */
111  protected boolean getSplitPoint() throws Exception{
112
113    //compute possible split points
114    double[] splitPoints = new double[m_numInstances];
115    int numSplitPoints = 0;
116
117    Instances sortedData = new Instances(m_data);
118    sortedData.sort(sortedData.attribute(m_attIndex));
119
120    double last, current;
121
122    last = sortedData.instance(0).value(m_attIndex);   
123
124    for (int i = 0; i < m_numInstances - 1; i++) {
125      current = sortedData.instance(i+1).value(m_attIndex);     
126      if (!Utils.eq(current, last)){
127        splitPoints[numSplitPoints++] = (last + current) / 2.0;
128      }
129      last = current;
130    }
131
132    //compute entropy for all split points
133    double[] entropyGain = new double[numSplitPoints];
134
135    for (int i = 0; i < numSplitPoints; i++) {
136      m_splitPoint = splitPoints[i];
137      entropyGain[i] = entropyGain();
138    }
139
140    //get best entropy gain
141    int bestSplit = -1;
142    double bestGain = -Double.MAX_VALUE;
143
144    for (int i = 0; i < numSplitPoints; i++) {
145      if (entropyGain[i] > bestGain) {
146        bestGain = entropyGain[i];
147        bestSplit = i;
148      }
149    }
150
151    if (bestSplit < 0) return false;
152
153    m_splitPoint = splitPoints[bestSplit];     
154    return true;
155  }
156
157  /**
158   * Computes entropy gain for current split.
159   */
160  public double entropyGain() throws Exception{
161
162    int numSubsets;
163    if (m_attribute.isNominal()) {
164      numSubsets = m_attribute.numValues();
165    } else {
166      numSubsets = 2;
167    }
168
169    double[][][] splitDataZs = new double[numSubsets][][];
170    double[][][] splitDataWs = new double[numSubsets][][];
171
172    //determine size of the subsets
173    int[] subsetSize = new int[numSubsets];
174    for (int i = 0; i < m_numInstances; i++) {
175      int subset = whichSubset(m_data.instance(i));
176      if (subset < 0) throw new Exception("ResidualSplit: no support for splits on missing values");
177      subsetSize[subset]++;
178    }
179
180    for (int i = 0; i < numSubsets; i++) {
181      splitDataZs[i] = new double[subsetSize[i]][];
182      splitDataWs[i] = new double[subsetSize[i]][];
183    }
184
185
186    int[] subsetCount = new int[numSubsets];
187
188    //sort Zs/Ws into subsets
189    for (int i = 0; i < m_numInstances; i++) {
190      int subset = whichSubset(m_data.instance(i));
191      splitDataZs[subset][subsetCount[subset]] = m_dataZs[i];
192      splitDataWs[subset][subsetCount[subset]] = m_dataWs[i];
193      subsetCount[subset]++;
194    }
195
196    //calculate entropy gain
197    double entropyOrig = entropy(m_dataZs, m_dataWs);
198
199    double entropySplit = 0.0;
200
201    for (int i = 0; i < numSubsets; i++) {
202      entropySplit += entropy(splitDataZs[i], splitDataWs[i]);
203    }
204
205    return entropyOrig - entropySplit;
206  }
207
208  /**
209   * Helper function to compute entropy from Z/W values.
210   */
211  protected double entropy(double[][] dataZs, double[][] dataWs){
212    //method returns entropy * sumOfWeights
213    double entropy = 0.0;
214    int numInstances = dataZs.length;
215
216    for (int j = 0; j < m_numClasses; j++) {
217
218      //compute mean for class
219      double m = 0.0;
220      double sum = 0.0;
221      for (int i = 0; i < numInstances; i++) {
222        m += dataZs[i][j] * dataWs[i][j];
223        sum += dataWs[i][j];
224      }
225      m /= sum;
226
227      //sum up entropy for class
228      for (int i = 0; i < numInstances; i++) {
229        entropy += dataWs[i][j] * Math.pow(dataZs[i][j] - m,2);
230      }
231
232    }
233
234    return entropy;
235  }
236
237  /**
238   * Checks if there are at least 2 subsets that contain >= minNumInstances.
239   */
240  public boolean checkModel(int minNumInstances){
241    //checks if there are at least 2 subsets that contain >= minNumInstances
242    int count = 0;
243    for (int i = 0; i < m_distribution.numBags(); i++) {
244      if (m_distribution.perBag(i) >= minNumInstances) count++; 
245    }
246    return (count >= 2);
247  }
248
249  /**
250   * Returns name of splitting attribute (left side of condition).
251   */
252  public final String leftSide(Instances data) {
253
254    return data.attribute(m_attIndex).name();
255  }
256
257  /**
258   * Prints the condition satisfied by instances in a subset.
259   */
260  public final String rightSide(int index,Instances data) {
261
262    StringBuffer text;
263
264    text = new StringBuffer();
265    if (data.attribute(m_attIndex).isNominal())
266      text.append(" = "+
267          data.attribute(m_attIndex).value(index));
268    else
269      if (index == 0)
270        text.append(" <= "+
271            Utils.doubleToString(m_splitPoint,6));
272      else
273        text.append(" > "+
274            Utils.doubleToString(m_splitPoint,6));
275    return text.toString();
276  }
277
278  public final int whichSubset(Instance instance) 
279  throws Exception {
280
281    if (instance.isMissing(m_attIndex))
282      return -1;
283    else{
284      if (instance.attribute(m_attIndex).isNominal())
285        return (int)instance.value(m_attIndex);
286      else
287        if (Utils.smOrEq(instance.value(m_attIndex),m_splitPoint))
288          return 0;
289        else
290          return 1;
291    }
292  }   
293
294  /** Method not in use*/
295  public void buildClassifier(Instances data) {
296    //method not in use
297  }
298
299  /**Method not in use*/
300  public final double [] weights(Instance instance){
301    //method not in use
302    return null;
303  } 
304
305  /**Method not in use*/
306  public final String sourceExpression(int index, Instances data) {
307    //method not in use
308    return "";
309  }
310 
311  /**
312   * Returns the revision string.
313   *
314   * @return            the revision
315   */
316  public String getRevision() {
317    return RevisionUtils.extract("$Revision: 1.4 $");
318  }
319}
Note: See TracBrowser for help on using the repository browser.