source: src/main/java/weka/classifiers/trees/j48/NBTreeModelSelection.java @ 7

Last change on this file since 7 was 4, checked in by gnappo, 14 years ago

Import of Weka.

File size: 5.6 KB
Line 
1/*
2 *    This program is free software; you can redistribute it and/or modify
3 *    it under the terms of the GNU General Public License as published by
4 *    the Free Software Foundation; either version 2 of the License, or
5 *    (at your option) any later version.
6 *
7 *    This program is distributed in the hope that it will be useful,
8 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
9 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
10 *    GNU General Public License for more details.
11 *
12 *    You should have received a copy of the GNU General Public License
13 *    along with this program; if not, write to the Free Software
14 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
15 */
16
17/*
18 *    NBTreeModelSelection.java
19 *    Copyright (C) 2004 University of Waikato, Hamilton, New Zealand
20 *
21 */
22
23package weka.classifiers.trees.j48;
24
25import weka.core.Attribute;
26import weka.core.Instances;
27import weka.core.RevisionUtils;
28import weka.core.Utils;
29
30import java.util.Enumeration;
31
32/**
33 * Class for selecting a NB tree split.
34 *
35 * @author Mark Hall (mhall@cs.waikato.ac.nz)
36 * @version $Revision: 1.5 $
37 */
38public class NBTreeModelSelection
39  extends ModelSelection {
40
41  /** for serialization */
42  private static final long serialVersionUID = 990097748931976704L;
43
44  /** Minimum number of objects in interval. */
45  private int m_minNoObj;               
46
47  /** All the training data */
48  private Instances m_allData; //
49
50  /**
51   * Initializes the split selection method with the given parameters.
52   *
53   * @param minNoObj minimum number of instances that have to occur in at least two
54   * subsets induced by split
55   * @param allData FULL training dataset (necessary for
56   * selection of split points).
57   */
58  public NBTreeModelSelection(int minNoObj, Instances allData) {
59    m_minNoObj = minNoObj;
60    m_allData = allData;
61  }
62
63  /**
64   * Sets reference to training data to null.
65   */
66  public void cleanup() {
67
68    m_allData = null;
69  }
70
71  /**
72   * Selects NBTree-type split for the given dataset.
73   */
74  public final ClassifierSplitModel selectModel(Instances data){
75
76    double globalErrors = 0;
77
78    double minResult;
79    double currentResult;
80    NBTreeSplit [] currentModel;
81    NBTreeSplit bestModel = null;
82    NBTreeNoSplit noSplitModel = null;
83    int validModels = 0;
84    boolean multiVal = true;
85    Distribution checkDistribution;
86    Attribute attribute;
87    double sumOfWeights;
88    int i;
89   
90    try{
91      // build the global model at this node
92      noSplitModel = new NBTreeNoSplit();
93      noSplitModel.buildClassifier(data);
94      if (data.numInstances() < 5) {
95        return noSplitModel;
96      }
97
98      // evaluate it
99      globalErrors = noSplitModel.getErrors();
100      if (globalErrors == 0) {
101        return noSplitModel;
102      }
103
104      // Check if all Instances belong to one class or if not
105      // enough Instances to split.
106      checkDistribution = new Distribution(data);
107      if (Utils.sm(checkDistribution.total(), m_minNoObj) ||
108          Utils.eq(checkDistribution.total(),
109                   checkDistribution.perClass(checkDistribution.maxClass()))) {
110        return noSplitModel;
111      }
112
113      // Check if all attributes are nominal and have a
114      // lot of values.
115      if (m_allData != null) {
116        Enumeration enu = data.enumerateAttributes();
117        while (enu.hasMoreElements()) {
118          attribute = (Attribute) enu.nextElement();
119          if ((attribute.isNumeric()) ||
120              (Utils.sm((double)attribute.numValues(),
121                        (0.3*(double)m_allData.numInstances())))){
122            multiVal = false;
123            break;
124          }
125        }
126      }
127
128      currentModel = new NBTreeSplit[data.numAttributes()];
129      sumOfWeights = data.sumOfWeights();
130
131      // For each attribute.
132      for (i = 0; i < data.numAttributes(); i++){
133       
134        // Apart from class attribute.
135        if (i != (data).classIndex()){
136         
137          // Get models for current attribute.
138          currentModel[i] = new NBTreeSplit(i,m_minNoObj,sumOfWeights);
139          currentModel[i].setGlobalModel(noSplitModel);
140          currentModel[i].buildClassifier(data);
141         
142          // Check if useful split for current attribute
143          // exists and check for enumerated attributes with
144          // a lot of values.
145          if (currentModel[i].checkModel()){
146            validModels++;
147          }
148        } else {
149          currentModel[i] = null;
150        }
151      }
152     
153      // Check if any useful split was found.
154      if (validModels == 0) {
155        return noSplitModel;
156      }
157     
158     // Find "best" attribute to split on.
159      minResult = globalErrors;
160      for (i=0;i<data.numAttributes();i++){
161        if ((i != (data).classIndex()) &&
162            (currentModel[i].checkModel())) {
163          /*  System.err.println("Errors for "+data.attribute(i).name()+" "+
164              currentModel[i].getErrors()); */
165          if (currentModel[i].getErrors() < minResult) {
166            bestModel = currentModel[i];
167            minResult = currentModel[i].getErrors();
168          }
169        }
170      }
171      //      System.exit(1);
172      // Check if useful split was found.
173     
174
175      if (((globalErrors - minResult) / globalErrors) < 0.05) {
176        return noSplitModel;
177      }
178     
179      /*      if (bestModel == null) {
180        System.err.println("This shouldn't happen! glob : "+globalErrors+
181                           " minRes : "+minResult);
182        System.exit(1);
183        } */
184      // Set the global model for the best split
185      //      bestModel.setGlobalModel(noSplitModel);
186
187      return bestModel;
188    }catch(Exception e){
189      e.printStackTrace();
190    }
191    return null;
192  }
193
194  /**
195   * Selects NBTree-type split for the given dataset.
196   */
197  public final ClassifierSplitModel selectModel(Instances train, Instances test) {
198
199    return selectModel(train);
200  }
201 
202  /**
203   * Returns the revision string.
204   *
205   * @return            the revision
206   */
207  public String getRevision() {
208    return RevisionUtils.extract("$Revision: 1.5 $");
209  }
210}
Note: See TracBrowser for help on using the repository browser.