source: branches/MetisMQI/src/main/java/weka/classifiers/trees/m5/CorrelationSplitInfo.java

Last change on this file was 29, checked in by gnappo, 15 years ago

Taggata versione per la demo e aggiunto branch.

File size: 6.5 KB
Line 
1/*
2 *    This program is free software; you can redistribute it and/or modify
3 *    it under the terms of the GNU General Public License as published by
4 *    the Free Software Foundation; either version 2 of the License, or
5 *    (at your option) any later version.
6 *
7 *    This program is distributed in the hope that it will be useful,
8 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
9 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
10 *    GNU General Public License for more details.
11 *
12 *    You should have received a copy of the GNU General Public License
13 *    along with this program; if not, write to the Free Software
14 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
15 */
16
17/*
18 * CorrelationSplitInfo.java
19 * Copyright (C) 2000 University of Waikato, Hamilton, New Zealand
20 *
21 */
22
23package weka.classifiers.trees.m5;
24
25import weka.core.Instances;
26import weka.core.RevisionHandler;
27import weka.core.RevisionUtils;
28import weka.core.Utils;
29import weka.experiment.PairedStats;
30
31import java.io.Serializable;
32
33/**
34 * Finds split points using correlation.
35 *
36 * @author Mark Hall (mhall@cs.waikato.ac.nz)
37 * @version $Revision: 1.4 $
38 */
39public final class CorrelationSplitInfo
40  implements Cloneable, Serializable, SplitEvaluate, RevisionHandler {
41
42  /** for serialization */
43  private static final long serialVersionUID = 4212734895125452770L;
44
45  /**
46   * the first instance
47   */
48  private int    m_first;
49
50  /**
51   * the last instance
52   */
53  private int    m_last;
54  private int    m_position;
55
56  /**
57   * the maximum impurity reduction
58   */
59  private double m_maxImpurity;
60
61  /**
62   * the attribute being tested
63   */
64  private int    m_splitAttr;
65
66  /**
67   * the best value on which to split
68   */
69  private double m_splitValue;
70
71  /**
72   * the number of instances
73   */
74  private int    m_number;
75
76  /**
77   * Constructs an object which contains the split information
78   *
79   * @param low the index of the first instance
80   * @param high the index of the last instance
81   * @param attr an attribute
82   */
83  public CorrelationSplitInfo(int low, int high, int attr) {
84    initialize(low, high, attr);
85  }
86
87  /**
88   * Makes a copy of this CorrelationSplitInfo object
89   */
90  public final SplitEvaluate copy() throws Exception {
91    CorrelationSplitInfo s = (CorrelationSplitInfo) this.clone();
92
93    return s;
94  } 
95
96  /**
97   * Resets the object of split information
98   *
99   * @param low the index of the first instance
100   * @param high the index of the last instance
101   * @param attr the attribute
102   */
103  public final void initialize(int low, int high, int attr) {
104    m_number = high - low + 1;
105    m_first = low;
106    m_last = high;
107    m_position = -1;
108    m_maxImpurity = -Double.MAX_VALUE;
109    m_splitAttr = attr;
110    m_splitValue = 0.0;
111  } 
112
113  /**
114   * Finds the best splitting point for an attribute in the instances
115   *
116   * @param attr the splitting attribute
117   * @param inst the instances
118   * @exception Exception if something goes wrong
119   */
120  public final void attrSplit(int attr, Instances inst) throws Exception {
121    int         i;
122    int         len;
123    int         part;
124    int         low = 0;
125    int         high = inst.numInstances() - 1;
126    PairedStats full = new PairedStats(0.01);
127    PairedStats leftSubset = new PairedStats(0.01);
128    PairedStats rightSubset = new PairedStats(0.01);
129    int         classIndex = inst.classIndex();
130    double      leftCorr, rightCorr;
131    double      leftVar, rightVar, allVar;
132    double      order = 2.0;
133
134    initialize(low, high, attr);
135
136    if (m_number < 4) {
137      return;
138    } 
139
140    len = ((high - low + 1) < 5) ? 1 : (high - low + 1) / 5;
141    m_position = low;
142    part = low + len - 1;
143
144    // prime the subsets
145    for (i = low; i < len; i++) {
146      full.add(inst.instance(i).value(attr), 
147               inst.instance(i).value(classIndex));
148      leftSubset.add(inst.instance(i).value(attr), 
149                     inst.instance(i).value(classIndex));
150    } 
151
152    for (i = len; i < inst.numInstances(); i++) {
153      full.add(inst.instance(i).value(attr), 
154               inst.instance(i).value(classIndex));
155      rightSubset.add(inst.instance(i).value(attr), 
156                      inst.instance(i).value(classIndex));
157    } 
158
159    full.calculateDerived();
160
161    allVar = (full.yStats.stdDev * full.yStats.stdDev);
162    allVar = Math.abs(allVar);
163    allVar = Math.pow(allVar, (1.0 / order));
164
165    for (i = low + len; i < high - len - 1; i++) {
166      rightSubset.subtract(inst.instance(i).value(attr), 
167                           inst.instance(i).value(classIndex));
168      leftSubset.add(inst.instance(i).value(attr), 
169                     inst.instance(i).value(classIndex));
170
171      if (!Utils.eq(inst.instance(i + 1).value(attr), 
172                    inst.instance(i).value(attr))) {
173        leftSubset.calculateDerived();
174        rightSubset.calculateDerived();
175
176        leftCorr = Math.abs(leftSubset.correlation);
177        rightCorr = Math.abs(rightSubset.correlation);
178        leftVar = (leftSubset.yStats.stdDev * leftSubset.yStats.stdDev);
179        leftVar = Math.abs(leftVar);
180        leftVar = Math.pow(leftVar, (1.0 / order));
181        rightVar = (rightSubset.yStats.stdDev * rightSubset.yStats.stdDev);
182        rightVar = Math.abs(rightVar);
183        rightVar = Math.pow(rightVar, (1.0 / order));
184
185        double score = allVar - ((leftSubset.count / full.count) * leftVar) 
186                       - ((rightSubset.count / full.count) * rightVar);
187
188        // score /= allVar;
189        leftCorr = (leftSubset.count / full.count) * leftCorr;
190        rightCorr = (rightSubset.count / full.count) * rightCorr;
191
192        double c_score = (leftCorr + rightCorr) - Math.abs(full.correlation);
193
194        // c_score += score;
195        if (!Utils.eq(score, 0.0)) {
196          if (score > m_maxImpurity) {
197            m_maxImpurity = score;
198            m_splitValue = 
199              (inst.instance(i).value(attr) + inst.instance(i + 1)
200              .value(attr)) * 0.5;
201            m_position = i;
202          } 
203        } 
204      } 
205    } 
206  } 
207
208  /**
209   * Returns the impurity of this split
210   *
211   * @return the impurity of this split
212   */
213  public double maxImpurity() {
214    return m_maxImpurity;
215  } 
216
217  /**
218   * Returns the attribute used in this split
219   *
220   * @return the attribute used in this split
221   */
222  public int splitAttr() {
223    return m_splitAttr;
224  } 
225
226  /**
227   * Returns the position of the split in the sorted values. -1 indicates that
228   * a split could not be found.
229   *
230   * @return an <code>int</code> value
231   */
232  public int position() {
233    return m_position;
234  } 
235
236  /**
237   * Returns the split value
238   *
239   * @return the split value
240   */
241  public double splitValue() {
242    return m_splitValue;
243  } 
244 
245  /**
246   * Returns the revision string.
247   *
248   * @return            the revision
249   */
250  public String getRevision() {
251    return RevisionUtils.extract("$Revision: 1.4 $");
252  }
253}
Note: See TracBrowser for help on using the repository browser.