source: src/main/java/weka/classifiers/trees/DecisionStump.java @ 15

Last change on this file since 15 was 4, checked in by gnappo, 14 years ago

Import di weka.

File size: 23.7 KB
Line 
1/*
2 *    This program is free software; you can redistribute it and/or modify
3 *    it under the terms of the GNU General Public License as published by
4 *    the Free Software Foundation; either version 2 of the License, or
5 *    (at your option) any later version.
6 *
7 *    This program is distributed in the hope that it will be useful,
8 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
9 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
10 *    GNU General Public License for more details.
11 *
12 *    You should have received a copy of the GNU General Public License
13 *    along with this program; if not, write to the Free Software
14 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
15 */
16
17/*
18 *    DecisionStump.java
19 *    Copyright (C) 1999 University of Waikato, Hamilton, New Zealand
20 *
21 */
22
23package weka.classifiers.trees;
24
25import weka.classifiers.Classifier;
26import weka.classifiers.AbstractClassifier;
27import weka.classifiers.Sourcable;
28import weka.core.Attribute;
29import weka.core.Capabilities;
30import weka.core.ContingencyTables;
31import weka.core.Instance;
32import weka.core.Instances;
33import weka.core.RevisionUtils;
34import weka.core.Utils;
35import weka.core.WeightedInstancesHandler;
36import weka.core.Capabilities.Capability;
37
38/**
39 <!-- globalinfo-start -->
40 * Class for building and using a decision stump. Usually used in conjunction with a boosting algorithm. Does regression (based on mean-squared error) or classification (based on entropy). Missing is treated as a separate value.
41 * <p/>
42 <!-- globalinfo-end -->
43 *
44 * Typical usage: <p>
45 * <code>java weka.classifiers.meta.LogitBoost -I 100 -W weka.classifiers.trees.DecisionStump
46 * -t training_data </code><p>
47 *
48 <!-- options-start -->
49 * Valid options are: <p/>
50 *
51 * <pre> -D
52 *  If set, classifier is run in debug mode and
53 *  may output additional info to the console</pre>
54 *
55 <!-- options-end -->
56 *
57 * @author Eibe Frank (eibe@cs.waikato.ac.nz)
58 * @version $Revision: 5928 $
59 */
60public class DecisionStump 
61  extends AbstractClassifier
62  implements WeightedInstancesHandler, Sourcable {
63
64  /** for serialization */
65  static final long serialVersionUID = 1618384535950391L;
66 
67  /** The attribute used for classification. */
68  private int m_AttIndex;
69
70  /** The split point (index respectively). */
71  private double m_SplitPoint;
72
73  /** The distribution of class values or the means in each subset. */
74  private double[][] m_Distribution;
75
76  /** The instances used for training. */
77  private Instances m_Instances;
78
79  /** a ZeroR model in case no model can be built from the data */
80  private Classifier m_ZeroR;
81   
82  /**
83   * Returns a string describing classifier
84   * @return a description suitable for
85   * displaying in the explorer/experimenter gui
86   */
87  public String globalInfo() {
88
89    return  "Class for building and using a decision stump. Usually used in "
90      + "conjunction with a boosting algorithm. Does regression (based on "
91      + "mean-squared error) or classification (based on entropy). Missing "
92      + "is treated as a separate value.";
93  }
94
95  /**
96   * Returns default capabilities of the classifier.
97   *
98   * @return      the capabilities of this classifier
99   */
100  public Capabilities getCapabilities() {
101    Capabilities result = super.getCapabilities();
102    result.disableAll();
103
104    // attributes
105    result.enable(Capability.NOMINAL_ATTRIBUTES);
106    result.enable(Capability.NUMERIC_ATTRIBUTES);
107    result.enable(Capability.DATE_ATTRIBUTES);
108    result.enable(Capability.MISSING_VALUES);
109
110    // class
111    result.enable(Capability.NOMINAL_CLASS);
112    result.enable(Capability.NUMERIC_CLASS);
113    result.enable(Capability.DATE_CLASS);
114    result.enable(Capability.MISSING_CLASS_VALUES);
115   
116    return result;
117  }
118
119  /**
120   * Generates the classifier.
121   *
122   * @param instances set of instances serving as training data
123   * @throws Exception if the classifier has not been generated successfully
124   */
125  public void buildClassifier(Instances instances) throws Exception {
126   
127    double bestVal = Double.MAX_VALUE, currVal;
128    double bestPoint = -Double.MAX_VALUE;
129    int bestAtt = -1, numClasses;
130
131    // can classifier handle the data?
132    getCapabilities().testWithFail(instances);
133
134    // remove instances with missing class
135    instances = new Instances(instances);
136    instances.deleteWithMissingClass();
137   
138    // only class? -> build ZeroR model
139    if (instances.numAttributes() == 1) {
140      System.err.println(
141          "Cannot build model (only class attribute present in data!), "
142          + "using ZeroR model instead!");
143      m_ZeroR = new weka.classifiers.rules.ZeroR();
144      m_ZeroR.buildClassifier(instances);
145      return;
146    }
147    else {
148      m_ZeroR = null;
149    }
150   
151    double[][] bestDist = new double[3][instances.numClasses()];
152
153    m_Instances = new Instances(instances);
154
155    if (m_Instances.classAttribute().isNominal()) {
156      numClasses = m_Instances.numClasses();
157    } else {
158      numClasses = 1;
159    }
160
161    // For each attribute
162    boolean first = true;
163    for (int i = 0; i < m_Instances.numAttributes(); i++) {
164      if (i != m_Instances.classIndex()) {
165
166        // Reserve space for distribution.
167        m_Distribution = new double[3][numClasses];
168
169        // Compute value of criterion for best split on attribute
170        if (m_Instances.attribute(i).isNominal()) {
171          currVal = findSplitNominal(i);
172        } else {
173          currVal = findSplitNumeric(i);
174        }
175        if ((first) || (currVal < bestVal)) {
176          bestVal = currVal;
177          bestAtt = i;
178          bestPoint = m_SplitPoint;
179          for (int j = 0; j < 3; j++) {
180            System.arraycopy(m_Distribution[j], 0, bestDist[j], 0, 
181                             numClasses);
182          }
183        }
184       
185        // First attribute has been investigated
186        first = false;
187      }
188    }
189   
190    // Set attribute, split point and distribution.
191    m_AttIndex = bestAtt;
192    m_SplitPoint = bestPoint;
193    m_Distribution = bestDist;
194    if (m_Instances.classAttribute().isNominal()) {
195      for (int i = 0; i < m_Distribution.length; i++) {
196        double sumCounts = Utils.sum(m_Distribution[i]);
197        if (sumCounts == 0) { // This means there were only missing attribute values
198          System.arraycopy(m_Distribution[2], 0, m_Distribution[i], 0, 
199                           m_Distribution[2].length);
200          Utils.normalize(m_Distribution[i]);
201        } else {
202          Utils.normalize(m_Distribution[i], sumCounts); 
203        }
204      }
205    }
206   
207    // Save memory
208    m_Instances = new Instances(m_Instances, 0);
209  }
210
211  /**
212   * Calculates the class membership probabilities for the given test instance.
213   *
214   * @param instance the instance to be classified
215   * @return predicted class probability distribution
216   * @throws Exception if distribution can't be computed
217   */
218  public double[] distributionForInstance(Instance instance) throws Exception {
219
220    // default model?
221    if (m_ZeroR != null) {
222      return m_ZeroR.distributionForInstance(instance);
223    }
224   
225    return m_Distribution[whichSubset(instance)];
226  }
227
228  /**
229   * Returns the decision tree as Java source code.
230   *
231   * @param className the classname of the generated code
232   * @return the tree as Java source code
233   * @throws Exception if something goes wrong
234   */
235  public String toSource(String className) throws Exception {
236
237    StringBuffer text = new StringBuffer("class ");
238    Attribute c = m_Instances.classAttribute();
239    text.append(className)
240      .append(" {\n"
241              +"  public static double classify(Object[] i) {\n");
242    text.append("    /* " + m_Instances.attribute(m_AttIndex).name() + " */\n");
243    text.append("    if (i[").append(m_AttIndex);
244    text.append("] == null) { return ");
245    text.append(sourceClass(c, m_Distribution[2])).append(";");
246    if (m_Instances.attribute(m_AttIndex).isNominal()) {
247      text.append(" } else if (((String)i[").append(m_AttIndex);
248      text.append("]).equals(\"");
249      text.append(m_Instances.attribute(m_AttIndex).value((int)m_SplitPoint));
250      text.append("\")");
251    } else {
252      text.append(" } else if (((Double)i[").append(m_AttIndex);
253      text.append("]).doubleValue() <= ").append(m_SplitPoint);
254    }
255    text.append(") { return ");
256    text.append(sourceClass(c, m_Distribution[0])).append(";");
257    text.append(" } else { return ");
258    text.append(sourceClass(c, m_Distribution[1])).append(";");
259    text.append(" }\n  }\n}\n");
260    return text.toString();
261  }
262
263  /**
264   * Returns the value as string out of the given distribution
265   *
266   * @param c the attribute to get the value for
267   * @param dist the distribution to extract the value
268   * @return the value
269   */
270  private String sourceClass(Attribute c, double []dist) {
271
272    if (c.isNominal()) {
273      return Integer.toString(Utils.maxIndex(dist));
274    } else {
275      return Double.toString(dist[0]);
276    }
277  }
278
279  /**
280   * Returns a description of the classifier.
281   *
282   * @return a description of the classifier as a string.
283   */
284  public String toString(){
285
286    // only ZeroR model?
287    if (m_ZeroR != null) {
288      StringBuffer buf = new StringBuffer();
289      buf.append(this.getClass().getName().replaceAll(".*\\.", "") + "\n");
290      buf.append(this.getClass().getName().replaceAll(".*\\.", "").replaceAll(".", "=") + "\n\n");
291      buf.append("Warning: No model could be built, hence ZeroR model is used:\n\n");
292      buf.append(m_ZeroR.toString());
293      return buf.toString();
294    }
295   
296    if (m_Instances == null) {
297      return "Decision Stump: No model built yet.";
298    }
299    try {
300      StringBuffer text = new StringBuffer();
301     
302      text.append("Decision Stump\n\n");
303      text.append("Classifications\n\n");
304      Attribute att = m_Instances.attribute(m_AttIndex);
305      if (att.isNominal()) {
306        text.append(att.name() + " = " + att.value((int)m_SplitPoint) + 
307                    " : ");
308        text.append(printClass(m_Distribution[0]));
309        text.append(att.name() + " != " + att.value((int)m_SplitPoint) + 
310                    " : ");
311        text.append(printClass(m_Distribution[1]));
312      } else {
313        text.append(att.name() + " <= " + m_SplitPoint + " : ");
314        text.append(printClass(m_Distribution[0]));
315        text.append(att.name() + " > " + m_SplitPoint + " : ");
316        text.append(printClass(m_Distribution[1]));
317      }
318      text.append(att.name() + " is missing : ");
319      text.append(printClass(m_Distribution[2]));
320
321      if (m_Instances.classAttribute().isNominal()) {
322        text.append("\nClass distributions\n\n");
323        if (att.isNominal()) {
324          text.append(att.name() + " = " + att.value((int)m_SplitPoint) + 
325                      "\n");
326          text.append(printDist(m_Distribution[0]));
327          text.append(att.name() + " != " + att.value((int)m_SplitPoint) + 
328                      "\n");
329          text.append(printDist(m_Distribution[1]));
330        } else {
331          text.append(att.name() + " <= " + m_SplitPoint + "\n");
332          text.append(printDist(m_Distribution[0]));
333          text.append(att.name() + " > " + m_SplitPoint + "\n");
334          text.append(printDist(m_Distribution[1]));
335        }
336        text.append(att.name() + " is missing\n");
337        text.append(printDist(m_Distribution[2]));
338      }
339
340      return text.toString();
341    } catch (Exception e) {
342      return "Can't print decision stump classifier!";
343    }
344  }
345
346  /**
347   * Prints a class distribution.
348   *
349   * @param dist the class distribution to print
350   * @return the distribution as a string
351   * @throws Exception if distribution can't be printed
352   */
353  private String printDist(double[] dist) throws Exception {
354
355    StringBuffer text = new StringBuffer();
356   
357    if (m_Instances.classAttribute().isNominal()) {
358      for (int i = 0; i < m_Instances.numClasses(); i++) {
359        text.append(m_Instances.classAttribute().value(i) + "\t");
360      }
361      text.append("\n");
362      for (int i = 0; i < m_Instances.numClasses(); i++) {
363        text.append(dist[i] + "\t");
364      }
365      text.append("\n");
366    }
367   
368    return text.toString();
369  }
370
371  /**
372   * Prints a classification.
373   *
374   * @param dist the class distribution
375   * @return the classificationn as a string
376   * @throws Exception if the classification can't be printed
377   */
378  private String printClass(double[] dist) throws Exception {
379
380    StringBuffer text = new StringBuffer();
381   
382    if (m_Instances.classAttribute().isNominal()) {
383      text.append(m_Instances.classAttribute().value(Utils.maxIndex(dist)));
384    } else {
385      text.append(dist[0]);
386    }
387   
388    return text.toString() + "\n";
389  }
390
391  /**
392   * Finds best split for nominal attribute and returns value.
393   *
394   * @param index attribute index
395   * @return value of criterion for the best split
396   * @throws Exception if something goes wrong
397   */
398  private double findSplitNominal(int index) throws Exception {
399
400    if (m_Instances.classAttribute().isNominal()) {
401      return findSplitNominalNominal(index);
402    } else {
403      return findSplitNominalNumeric(index);
404    }
405  }
406
407  /**
408   * Finds best split for nominal attribute and nominal class
409   * and returns value.
410   *
411   * @param index attribute index
412   * @return value of criterion for the best split
413   * @throws Exception if something goes wrong
414   */
415  private double findSplitNominalNominal(int index) throws Exception {
416
417    double bestVal = Double.MAX_VALUE, currVal;
418    double[][] counts = new double[m_Instances.attribute(index).numValues() 
419                                  + 1][m_Instances.numClasses()];
420    double[] sumCounts = new double[m_Instances.numClasses()];
421    double[][] bestDist = new double[3][m_Instances.numClasses()];
422    int numMissing = 0;
423
424    // Compute counts for all the values
425    for (int i = 0; i < m_Instances.numInstances(); i++) {
426      Instance inst = m_Instances.instance(i);
427      if (inst.isMissing(index)) {
428        numMissing++;
429        counts[m_Instances.attribute(index).numValues()]
430          [(int)inst.classValue()] += inst.weight();
431      } else {
432        counts[(int)inst.value(index)][(int)inst.classValue()] += inst
433          .weight();
434      }
435    }
436
437    // Compute sum of counts
438    for (int i = 0; i < m_Instances.attribute(index).numValues(); i++) {
439      for (int j = 0; j < m_Instances.numClasses(); j++) {
440        sumCounts[j] += counts[i][j];
441      }
442    }
443   
444    // Make split counts for each possible split and evaluate
445    System.arraycopy(counts[m_Instances.attribute(index).numValues()], 0,
446                     m_Distribution[2], 0, m_Instances.numClasses());
447    for (int i = 0; i < m_Instances.attribute(index).numValues(); i++) {
448      for (int j = 0; j < m_Instances.numClasses(); j++) {
449        m_Distribution[0][j] = counts[i][j];
450        m_Distribution[1][j] = sumCounts[j] - counts[i][j];
451      }
452      currVal = ContingencyTables.entropyConditionedOnRows(m_Distribution);
453      if (currVal < bestVal) {
454        bestVal = currVal;
455        m_SplitPoint = (double)i;
456        for (int j = 0; j < 3; j++) {
457          System.arraycopy(m_Distribution[j], 0, bestDist[j], 0, 
458                           m_Instances.numClasses());
459        }
460      }
461    }
462
463    // No missing values in training data.
464    if (numMissing == 0) {
465      System.arraycopy(sumCounts, 0, bestDist[2], 0, 
466                       m_Instances.numClasses());
467    }
468   
469    m_Distribution = bestDist;
470    return bestVal;
471  }
472
473  /**
474   * Finds best split for nominal attribute and numeric class
475   * and returns value.
476   *
477   * @param index attribute index
478   * @return value of criterion for the best split
479   * @throws Exception if something goes wrong
480   */
481  private double findSplitNominalNumeric(int index) throws Exception {
482
483    double bestVal = Double.MAX_VALUE, currVal;
484    double[] sumsSquaresPerValue = 
485      new double[m_Instances.attribute(index).numValues()], 
486      sumsPerValue = new double[m_Instances.attribute(index).numValues()], 
487      weightsPerValue = new double[m_Instances.attribute(index).numValues()];
488    double totalSumSquaresW = 0, totalSumW = 0, totalSumOfWeightsW = 0,
489      totalSumOfWeights = 0, totalSum = 0;
490    double[] sumsSquares = new double[3], sumOfWeights = new double[3];
491    double[][] bestDist = new double[3][1];
492
493    // Compute counts for all the values
494    for (int i = 0; i < m_Instances.numInstances(); i++) {
495      Instance inst = m_Instances.instance(i);
496      if (inst.isMissing(index)) {
497        m_Distribution[2][0] += inst.classValue() * inst.weight();
498        sumsSquares[2] += inst.classValue() * inst.classValue() 
499          * inst.weight();
500        sumOfWeights[2] += inst.weight();
501      } else {
502        weightsPerValue[(int)inst.value(index)] += inst.weight();
503        sumsPerValue[(int)inst.value(index)] += inst.classValue() 
504          * inst.weight();
505        sumsSquaresPerValue[(int)inst.value(index)] += 
506          inst.classValue() * inst.classValue() * inst.weight();
507      }
508      totalSumOfWeights += inst.weight();
509      totalSum += inst.classValue() * inst.weight();
510    }
511
512    // Check if the total weight is zero
513    if (totalSumOfWeights <= 0) {
514      return bestVal;
515    }
516
517    // Compute sum of counts without missing ones
518    for (int i = 0; i < m_Instances.attribute(index).numValues(); i++) {
519      totalSumOfWeightsW += weightsPerValue[i];
520      totalSumSquaresW += sumsSquaresPerValue[i];
521      totalSumW += sumsPerValue[i];
522    }
523   
524    // Make split counts for each possible split and evaluate
525    for (int i = 0; i < m_Instances.attribute(index).numValues(); i++) {
526     
527      m_Distribution[0][0] = sumsPerValue[i];
528      sumsSquares[0] = sumsSquaresPerValue[i];
529      sumOfWeights[0] = weightsPerValue[i];
530      m_Distribution[1][0] = totalSumW - sumsPerValue[i];
531      sumsSquares[1] = totalSumSquaresW - sumsSquaresPerValue[i];
532      sumOfWeights[1] = totalSumOfWeightsW - weightsPerValue[i];
533
534      currVal = variance(m_Distribution, sumsSquares, sumOfWeights);
535     
536      if (currVal < bestVal) {
537        bestVal = currVal;
538        m_SplitPoint = (double)i;
539        for (int j = 0; j < 3; j++) {
540          if (sumOfWeights[j] > 0) {
541            bestDist[j][0] = m_Distribution[j][0] / sumOfWeights[j];
542          } else {
543            bestDist[j][0] = totalSum / totalSumOfWeights;
544          }
545        }
546      }
547    }
548
549    m_Distribution = bestDist;
550    return bestVal;
551  }
552
553  /**
554   * Finds best split for numeric attribute and returns value.
555   *
556   * @param index attribute index
557   * @return value of criterion for the best split
558   * @throws Exception if something goes wrong
559   */
560  private double findSplitNumeric(int index) throws Exception {
561
562    if (m_Instances.classAttribute().isNominal()) {
563      return findSplitNumericNominal(index);
564    } else {
565      return findSplitNumericNumeric(index);
566    }
567  }
568
569  /**
570   * Finds best split for numeric attribute and nominal class
571   * and returns value.
572   *
573   * @param index attribute index
574   * @return value of criterion for the best split
575   * @throws Exception if something goes wrong
576   */
577  private double findSplitNumericNominal(int index) throws Exception {
578
579    double bestVal = Double.MAX_VALUE, currVal, currCutPoint;
580    int numMissing = 0;
581    double[] sum = new double[m_Instances.numClasses()];
582    double[][] bestDist = new double[3][m_Instances.numClasses()];
583
584    // Compute counts for all the values
585    for (int i = 0; i < m_Instances.numInstances(); i++) {
586      Instance inst = m_Instances.instance(i);
587      if (!inst.isMissing(index)) {
588        m_Distribution[1][(int)inst.classValue()] += inst.weight();
589      } else {
590        m_Distribution[2][(int)inst.classValue()] += inst.weight();
591        numMissing++;
592      }
593    }
594    System.arraycopy(m_Distribution[1], 0, sum, 0, m_Instances.numClasses());
595
596    // Save current distribution as best distribution
597    for (int j = 0; j < 3; j++) {
598      System.arraycopy(m_Distribution[j], 0, bestDist[j], 0, 
599                       m_Instances.numClasses());
600    }
601
602    // Sort instances
603    m_Instances.sort(index);
604   
605    // Make split counts for each possible split and evaluate
606    for (int i = 0; i < m_Instances.numInstances() - (numMissing + 1); i++) {
607      Instance inst = m_Instances.instance(i);
608      Instance instPlusOne = m_Instances.instance(i + 1);
609      m_Distribution[0][(int)inst.classValue()] += inst.weight();
610      m_Distribution[1][(int)inst.classValue()] -= inst.weight();
611      if (inst.value(index) < instPlusOne.value(index)) {
612        currCutPoint = (inst.value(index) + instPlusOne.value(index)) / 2.0;
613        currVal = ContingencyTables.entropyConditionedOnRows(m_Distribution);
614        if (currVal < bestVal) {
615          m_SplitPoint = currCutPoint;
616          bestVal = currVal;
617          for (int j = 0; j < 3; j++) {
618            System.arraycopy(m_Distribution[j], 0, bestDist[j], 0, 
619                             m_Instances.numClasses());
620          }
621        }
622      }
623    }
624
625    // No missing values in training data.
626    if (numMissing == 0) {
627      System.arraycopy(sum, 0, bestDist[2], 0, m_Instances.numClasses());
628    }
629 
630    m_Distribution = bestDist;
631    return bestVal;
632  }
633
634  /**
635   * Finds best split for numeric attribute and numeric class
636   * and returns value.
637   *
638   * @param index attribute index
639   * @return value of criterion for the best split
640   * @throws Exception if something goes wrong
641   */
642  private double findSplitNumericNumeric(int index) throws Exception {
643
644    double bestVal = Double.MAX_VALUE, currVal, currCutPoint;
645    int numMissing = 0;
646    double[] sumsSquares = new double[3], sumOfWeights = new double[3];
647    double[][] bestDist = new double[3][1];
648    double totalSum = 0, totalSumOfWeights = 0;
649
650    // Compute counts for all the values
651    for (int i = 0; i < m_Instances.numInstances(); i++) {
652      Instance inst = m_Instances.instance(i);
653      if (!inst.isMissing(index)) {
654        m_Distribution[1][0] += inst.classValue() * inst.weight();
655        sumsSquares[1] += inst.classValue() * inst.classValue() 
656          * inst.weight();
657        sumOfWeights[1] += inst.weight();
658      } else {
659        m_Distribution[2][0] += inst.classValue() * inst.weight();
660        sumsSquares[2] += inst.classValue() * inst.classValue() 
661          * inst.weight();
662        sumOfWeights[2] += inst.weight();
663        numMissing++;
664      }
665      totalSumOfWeights += inst.weight();
666      totalSum += inst.classValue() * inst.weight();
667    }
668
669    // Check if the total weight is zero
670    if (totalSumOfWeights <= 0) {
671      return bestVal;
672    }
673
674    // Sort instances
675    m_Instances.sort(index);
676   
677    // Make split counts for each possible split and evaluate
678    for (int i = 0; i < m_Instances.numInstances() - (numMissing + 1); i++) {
679      Instance inst = m_Instances.instance(i);
680      Instance instPlusOne = m_Instances.instance(i + 1);
681      m_Distribution[0][0] += inst.classValue() * inst.weight();
682      sumsSquares[0] += inst.classValue() * inst.classValue() * inst.weight();
683      sumOfWeights[0] += inst.weight();
684      m_Distribution[1][0] -= inst.classValue() * inst.weight();
685      sumsSquares[1] -= inst.classValue() * inst.classValue() * inst.weight();
686      sumOfWeights[1] -= inst.weight();
687      if (inst.value(index) < instPlusOne.value(index)) {
688        currCutPoint = (inst.value(index) + instPlusOne.value(index)) / 2.0;
689        currVal = variance(m_Distribution, sumsSquares, sumOfWeights);
690        if (currVal < bestVal) {
691          m_SplitPoint = currCutPoint;
692          bestVal = currVal;
693          for (int j = 0; j < 3; j++) {
694            if (sumOfWeights[j] > 0) {
695              bestDist[j][0] = m_Distribution[j][0] / sumOfWeights[j];
696            } else {
697              bestDist[j][0] = totalSum / totalSumOfWeights;
698            }
699          }
700        }
701      }
702    }
703
704    m_Distribution = bestDist;
705    return bestVal;
706  }
707
708  /**
709   * Computes variance for subsets.
710   *
711   * @param s
712   * @param sS
713   * @param sumOfWeights
714   * @return the variance
715   */
716  private double variance(double[][] s,double[] sS,double[] sumOfWeights) {
717
718    double var = 0;
719
720    for (int i = 0; i < s.length; i++) {
721      if (sumOfWeights[i] > 0) {
722        var += sS[i] - ((s[i][0] * s[i][0]) / (double) sumOfWeights[i]);
723      }
724    }
725   
726    return var;
727  }
728
729  /**
730   * Returns the subset an instance falls into.
731   *
732   * @param instance the instance to check
733   * @return the subset the instance falls into
734   * @throws Exception if something goes wrong
735   */
736  private int whichSubset(Instance instance) throws Exception {
737
738    if (instance.isMissing(m_AttIndex)) {
739      return 2;
740    } else if (instance.attribute(m_AttIndex).isNominal()) {
741      if ((int)instance.value(m_AttIndex) == m_SplitPoint) {
742        return 0;
743      } else {
744        return 1;
745      }
746    } else {
747      if (instance.value(m_AttIndex) <= m_SplitPoint) {
748        return 0;
749      } else {
750        return 1;
751      }
752    }
753  }
754 
755  /**
756   * Returns the revision string.
757   *
758   * @return            the revision
759   */
760  public String getRevision() {
761    return RevisionUtils.extract("$Revision: 5928 $");
762  }
763 
764  /**
765   * Main method for testing this class.
766   *
767   * @param argv the options
768   */
769  public static void main(String [] argv) {
770    runClassifier(new DecisionStump(), argv);
771  }
772}
Note: See TracBrowser for help on using the repository browser.