source: src/test/java/weka/classifiers/meta/ThresholdSelectorTest.java @ 10

Last change on this file since 10 was 4, checked in by gnappo, 14 years ago

Import di weka.

File size: 8.7 KB
Line 
1/*
2 *    This program is free software; you can redistribute it and/or modify
3 *    it under the terms of the GNU General Public License as published by
4 *    the Free Software Foundation; either version 2 of the License, or
5 *    (at your option) any later version.
6 *
7 *    This program is distributed in the hope that it will be useful,
8 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
9 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
10 *    GNU General Public License for more details.
11 *
12 *    You should have received a copy of the GNU General Public License
13 *    along with this program; if not, write to the Free Software
14 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
15 */
16
17/*
18 * Copyright (C) 2002 University of Waikato
19 */
20
21package weka.classifiers.meta;
22
23import weka.classifiers.AbstractClassifierTest;
24import weka.classifiers.Classifier;
25import weka.classifiers.evaluation.EvaluationUtils;
26import weka.classifiers.evaluation.NominalPrediction;
27import weka.core.Attribute;
28import weka.core.FastVector;
29import weka.core.Instances;
30import weka.core.NoSupportForMissingValuesException;
31import weka.core.SelectedTag;
32import weka.core.UnsupportedAttributeTypeException;
33import weka.filters.Filter;
34import weka.filters.unsupervised.attribute.RemoveType;
35import weka.filters.unsupervised.attribute.ReplaceMissingValues;
36
37import java.io.BufferedReader;
38import java.io.InputStreamReader;
39
40import junit.framework.Test;
41import junit.framework.TestSuite;
42
43/**
44 * Tests ThresholdSelector. Run from the command line with:<p>
45 * java weka.classifiers.meta.ThresholdSelectorTest
46 *
47 * @author <a href="mailto:len@reeltwo.com">Len Trigg</a>
48 * @author FracPete (fracpete at waikato dot ac dot nz)
49 * @version $Revision: 1.8 $
50 */
51public class ThresholdSelectorTest 
52  extends AbstractClassifierTest {
53
54  private static double[] DIST1 = new double [] {
55    0.25,
56    0.375,
57    0.5,
58    0.625,
59    0.75,
60    0.875,
61    1.0
62  };
63
64  /** A set of instances to test with */
65  protected transient Instances m_Instances;
66
67  /** Used to generate various types of predictions */
68  protected transient EvaluationUtils m_Evaluation;
69
70  public ThresholdSelectorTest(String name) { 
71    super(name); 
72  }
73
74  /**
75   * Called by JUnit before each test method. This implementation creates
76   * the default classifier to test and loads a test set of Instances.
77   *
78   * @exception Exception if an error occurs reading the example instances.
79   */
80  protected void setUp() throws Exception {
81    super.setUp();
82
83    m_Evaluation = new EvaluationUtils();
84    m_Instances = new Instances(
85                    new BufferedReader(
86                      new InputStreamReader(
87                        ClassLoader.getSystemResourceAsStream(
88                          "weka/classifiers/data/ClassifierTest.arff"))));
89  }
90
91  /** Creates a default ThresholdSelector */
92  public Classifier getClassifier() {
93    return getClassifier(DIST1);
94  }
95
96  /** Called by JUnit after each test method */
97  protected void tearDown() {
98    super.tearDown();
99
100    m_Evaluation = null;
101  }
102
103  /**
104   * Creates a ThresholdSelector that returns predictions from a
105   * given distribution
106   */
107  public Classifier getClassifier(double[] dist) {
108    return getClassifier(new ThresholdSelectorDummyClassifier(dist));
109  }
110
111  /**
112   * Creates a ThresholdSelector with the given subclassifier.
113   *
114   * @param classifier a <code>Classifier</code> to use as the
115   * subclassifier
116   * @return a new <code>ThresholdSelector</code>
117   */
118  public Classifier getClassifier(Classifier classifier) {
119    ThresholdSelector t = new ThresholdSelector();
120    t.setClassifier(classifier);
121    return t;
122  }
123
124  /**
125   * Builds a model using the current classifier using the first
126   * half of the current data for training, and generates a bunch of
127   * predictions using the remaining half of the data for testing.
128   *
129   * @return a <code>FastVector</code> containing the predictions.
130   */
131  protected FastVector useClassifier() throws Exception {
132
133    Classifier dc = null;
134    int tot = m_Instances.numInstances();
135    int mid = tot / 2;
136    Instances train = null;
137    Instances test = null;
138    try {
139      train = new Instances(m_Instances, 0, mid);
140      test = new Instances(m_Instances, mid, tot - mid);
141      dc = m_Classifier;
142    } catch (Exception ex) {
143      ex.printStackTrace();
144      fail("Problem setting up to use classifier: " + ex);
145    }
146    int counter = 0;
147    do {
148      try {
149        return m_Evaluation.getTrainTestPredictions(dc, train, test);
150      } catch (UnsupportedAttributeTypeException ex) {
151        SelectedTag tag = null;
152        boolean invert = false;
153        String msg = ex.getMessage();
154        if ((msg.indexOf("string") != -1) && 
155            (msg.indexOf("attributes") != -1)) {
156          System.err.println("\nDeleting string attributes.");
157          tag = new SelectedTag(Attribute.STRING,
158                                RemoveType.TAGS_ATTRIBUTETYPE);
159        } else if ((msg.indexOf("only") != -1) && 
160                   (msg.indexOf("nominal") != -1)) {
161          System.err.println("\nDeleting non-nominal attributes.");
162          tag = new SelectedTag(Attribute.NOMINAL,
163                                RemoveType.TAGS_ATTRIBUTETYPE);
164          invert = true;
165        } else if ((msg.indexOf("only") != -1) && 
166                   (msg.indexOf("numeric") != -1)) {
167          System.err.println("\nDeleting non-numeric attributes.");
168          tag = new SelectedTag(Attribute.NUMERIC,
169                                RemoveType.TAGS_ATTRIBUTETYPE);
170          invert = true;
171        }  else {
172          throw ex;
173        }
174        RemoveType attFilter = new RemoveType();
175        attFilter.setAttributeType(tag);
176        attFilter.setInvertSelection(invert);
177        attFilter.setInputFormat(train);
178        train = Filter.useFilter(train, attFilter);
179        attFilter.batchFinished();
180        test = Filter.useFilter(test, attFilter);
181        counter++;
182        if (counter > 2) {
183          throw ex;
184        }
185      } catch (NoSupportForMissingValuesException ex2) {
186        System.err.println("\nReplacing missing values.");
187        ReplaceMissingValues rmFilter = new ReplaceMissingValues();
188        rmFilter.setInputFormat(train);
189        train = Filter.useFilter(train, rmFilter);
190        rmFilter.batchFinished();
191        test = Filter.useFilter(test, rmFilter);
192      } catch (IllegalArgumentException ex3) {
193        String msg = ex3.getMessage();
194        if (msg.indexOf("Not enough instances") != -1) {
195          System.err.println("\nInflating training data.");
196          Instances trainNew = new Instances(train);
197          for (int i = 0; i < train.numInstances(); i++) {
198            trainNew.add(train.instance(i));
199          }
200          train = trainNew;
201        } else {
202          throw ex3;
203        }
204      }
205    } while (true);
206  }
207
208  public void testRangeNone() throws Exception {
209   
210    int cind = 0;
211    ((ThresholdSelector)m_Classifier).setDesignatedClass(new SelectedTag(ThresholdSelector.OPTIMIZE_0, ThresholdSelector.TAGS_OPTIMIZE));
212    ((ThresholdSelector)m_Classifier).setRangeCorrection(new SelectedTag(ThresholdSelector.RANGE_NONE, ThresholdSelector.TAGS_RANGE));
213    FastVector result = null;
214    m_Instances.setClassIndex(1);
215    result = useClassifier();
216    assertTrue(result.size() != 0);
217    double minp = 0;
218    double maxp = 0;
219    for (int i = 0; i < result.size(); i++) {
220      NominalPrediction p = (NominalPrediction)result.elementAt(i);
221      double prob = p.distribution()[cind];
222      if ((i == 0) || (prob < minp)) minp = prob;
223      if ((i == 0) || (prob > maxp)) maxp = prob;
224    }
225    assertTrue("Upper limit shouldn't increase", maxp <= 1.0);
226    assertTrue("Lower limit shouldn'd decrease", minp >= 0.25);
227  }
228 
229  public void testDesignatedClass() throws Exception {
230   
231    int cind = 0;
232    for (int i = 0; i < ThresholdSelector.TAGS_OPTIMIZE.length; i++) {
233      ((ThresholdSelector)m_Classifier).setDesignatedClass(new SelectedTag(ThresholdSelector.TAGS_OPTIMIZE[i].getID(), ThresholdSelector.TAGS_OPTIMIZE));
234      m_Instances.setClassIndex(1);
235      FastVector result = useClassifier();
236      assertTrue(result.size() != 0);
237    }
238  }
239
240  public void testEvaluationMode() throws Exception {
241   
242    int cind = 0;
243    for (int i = 0; i < ThresholdSelector.TAGS_EVAL.length; i++) {
244      ((ThresholdSelector)m_Classifier).setEvaluationMode(new SelectedTag(ThresholdSelector.TAGS_EVAL[i].getID(), ThresholdSelector.TAGS_EVAL));
245      m_Instances.setClassIndex(1);
246      FastVector result = useClassifier();
247      assertTrue(result.size() != 0);
248    }
249  }
250
251  public void testNumXValFolds() throws Exception {
252   
253    try {
254      ((ThresholdSelector)m_Classifier).setNumXValFolds(0);
255      fail("Expected IllegalArgumentException");
256    } catch (IllegalArgumentException e) {
257      // OK
258    }
259
260    int cind = 0;
261    for (int i = 2; i < 20; i += 2) {
262      ((ThresholdSelector)m_Classifier).setNumXValFolds(i);
263      m_Instances.setClassIndex(1);
264      FastVector result = useClassifier();
265      assertTrue(result.size() != 0);
266    }
267  }
268
269  public static Test suite() {
270    return new TestSuite(ThresholdSelectorTest.class);
271  }
272
273  public static void main(String[] args){
274    junit.textui.TestRunner.run(suite());
275  }
276}
Note: See TracBrowser for help on using the repository browser.