source: src/test/java/weka/classifiers/misc/SerializedClassifierTest.java @ 28

Last change on this file since 28 was 4, checked in by gnappo, 14 years ago

Import di weka.

File size: 9.6 KB
RevLine 
[4]1/*
2 *    This program is free software; you can redistribute it and/or modify
3 *    it under the terms of the GNU General Public License as published by
4 *    the Free Software Foundation; either version 2 of the License, or
5 *    (at your option) any later version.
6 *
7 *    This program is distributed in the hope that it will be useful,
8 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
9 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
10 *    GNU General Public License for more details.
11 *
12 *    You should have received a copy of the GNU General Public License
13 *    along with this program; if not, write to the Free Software
14 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
15 */
16
17/*
18 * Copyright (C) 2007 University of Waikato, Hamilton, New Zealand
19 */
20
21package weka.classifiers.misc;
22
23import weka.classifiers.Classifier;
24import weka.classifiers.evaluation.EvaluationUtils;
25import weka.core.Attribute;
26import weka.core.CheckOptionHandler;
27import weka.core.FastVector;
28import weka.core.Instances;
29import weka.core.SerializationHelper;
30import weka.core.TestInstances;
31import weka.test.Regression;
32
33import java.io.File;
34
35import junit.framework.Test;
36import junit.framework.TestCase;
37import junit.framework.TestSuite;
38
39/**
40 * Tests SerializedClassifier. Run from the command line with:<p>
41 * java weka.classifiers.misc.SerializedClassifierTest
42 *
43 * @author FracPete (fracpete at waikato dot ac dot nz)
44 * @version $Revision: 1.1 $
45 */
46public class SerializedClassifierTest
47  extends TestCase {
48
49  /** the filename for temporary serialized models */
50  public final static String MODEL_FILENAME = System.getProperty("user.dir") + "/" + "temp.model";
51 
52  /** the setup classifier */
53  protected SerializedClassifier m_Classifier;
54 
55  /** the OptionHandler tester */
56  protected CheckOptionHandler m_OptionTester;
57
58  /**
59   * initializes the test
60   *
61   * @param name the name of the test
62   */
63  public SerializedClassifierTest(String name) {
64    super(name);
65  }
66 
67  /**
68   * Called by JUnit before each test method.
69   *
70   * @throws Exception if an error occurs reading the example instances.
71   */
72  protected void setUp() throws Exception {
73    m_Classifier   = null;
74    m_OptionTester = new CheckOptionHandler();
75    m_OptionTester.setSilent(true);
76
77    // delete temp file
78    File file = new File(MODEL_FILENAME);
79    if (file.exists())
80      file.delete();
81  }
82
83  /**
84   * Called by JUnit after each test method
85   */
86  protected void tearDown() {
87    m_Classifier   = null;
88    m_OptionTester = null;
89
90    // delete temp file
91    File file = new File(MODEL_FILENAME);
92    if (file.exists())
93      file.delete();
94  }
95
96  /**
97   * creates a classifier, trains and serializes it
98   *
99   * @param data        the data to use (J48 with nominal class, M5P with
100   *                    numeric class)
101   * @return            the results for the data
102   */
103  protected double[] trainAndSerializeClassifier(Instances data) {
104    Classifier  classifier;
105    double[]    result;
106    int         i;
107   
108    try {
109      // build
110      if (data.classAttribute().isNominal())
111        classifier = new weka.classifiers.trees.J48();
112      else
113        classifier = new weka.classifiers.trees.M5P();
114      classifier.buildClassifier(data);
115     
116      // record predictions
117      result = new double[data.numInstances()];
118      for (i = 0; i < result.length; i++)
119        result[i] = classifier.classifyInstance(data.instance(i));
120     
121      // save
122      SerializationHelper.write(MODEL_FILENAME, classifier);
123    }
124    catch (Exception e) {
125      fail("Training base classifier failed: " + e);
126      return null;
127    }
128   
129    return result;
130  }
131 
132  /**
133   * performs the actual test
134   *
135   * @param nomClass    whether to use a nominal class with J48 (TRUE) or
136   *                    a numeric one with M5P (FALSE)
137   */
138  protected void performTest(boolean nomClass) {
139    TestInstances       test;
140    Instances           data;
141    double[]            originalResults;
142    double[]            testResults;
143    int                 i;
144
145    // generate data
146    try {
147      test = new TestInstances();
148      if (nomClass) {
149        test.setClassType(Attribute.NOMINAL);
150        test.setNumNominal(5);
151        test.setNumNominalValues(4);
152        test.setNumNumeric(0);
153      }
154      else {
155        test.setClassType(Attribute.NUMERIC);
156        test.setNumNominal(0);
157        test.setNumNumeric(5);
158      }
159      test.setNumDate(0);
160      test.setNumString(0);
161      test.setNumRelational(0);
162      test.setNumInstances(100);
163      data = test.generate();
164    }
165    catch (Exception e) {
166      fail("Generating test data failed: " + e);
167      return;
168    }
169   
170    // train and save base classifier
171    try {
172      originalResults = trainAndSerializeClassifier(data);
173    }
174    catch (Exception e) {
175      fail("Training base classifier failed: " + e);
176      return;
177    }
178   
179    // test loading
180    try {
181      m_Classifier = new SerializedClassifier();
182      m_Classifier.setModelFile(new File(MODEL_FILENAME));
183      m_Classifier.buildClassifier(data);
184    }
185    catch (Exception e) {
186      fail("Loading/testing of classifier failed: " + e);
187    }
188   
189    // compare results
190    try {
191      // get results from serialized model
192      testResults = new double[data.numInstances()];
193      for (i = 0; i < testResults.length; i++)
194        testResults[i] = m_Classifier.classifyInstance(data.instance(i));
195     
196      // compare
197      for (i = 0; i < originalResults.length; i++) {
198        if (originalResults[i] != testResults[i])
199          throw new Exception("Result #" + (i+1) + " differs!");
200      }
201    }
202    catch (Exception e) {
203      fail("Comparing results failed: " + e);
204    }
205  }
206 
207  /**
208   * tests a serialized classifier (J48) handling nominal classes
209   */
210  public void testNominalClass() {
211    performTest(true);
212  }
213 
214  /**
215   * tests a serialized classifier (M5P) handling numeric classes
216   */
217  public void testNumericClass() {
218    performTest(true);
219  }
220
221  /**
222   * Returns a string containing all the predictions.
223   *
224   * @param predictions a <code>FastVector</code> containing the predictions
225   * @return a <code>String</code> representing the vector of predictions.
226   */
227  protected String predictionsToString(FastVector predictions) {
228    StringBuffer sb = new StringBuffer();
229    sb.append(predictions.size()).append(" predictions\n");
230    for (int i = 0; i < predictions.size(); i++) {
231      sb.append(predictions.elementAt(i)).append('\n');
232    }
233    return sb.toString();
234  }
235
236  /**
237   * Runs a regression test -- this checks that the output of the tested
238   * object matches that in a reference version. When this test is
239   * run without any pre-existing reference output, the reference version
240   * is created. Uses J48 for this purpose.
241   */
242  public void testRegression() {
243    Regression          reg;
244    Instances           train;
245    Instances           test;
246    Instances           data;
247    TestInstances       testInst;
248    int                 tot;
249    int                 mid;
250    EvaluationUtils     evaluation;
251    FastVector          regressionResults;
252   
253    reg = new Regression(this.getClass());
254
255    // generate test data
256    try {
257      testInst = new TestInstances();
258      testInst.setClassType(Attribute.NOMINAL);
259      testInst.setNumNominal(5);
260      testInst.setNumNominalValues(4);
261      testInst.setNumNumeric(0);
262      testInst.setNumDate(0);
263      testInst.setNumString(0);
264      testInst.setNumRelational(0);
265      testInst.setNumInstances(100);
266      data = testInst.generate();
267    }
268    catch (Exception e) {
269      fail("Failed generating data: " + e);
270      return;
271    }
272   
273    // split data into train/test
274    tot = data.numInstances();
275    mid = tot / 2;
276    train = null;
277    test = null;
278   
279    try {
280      train = new Instances(data, 0, mid);
281      test = new Instances(data, mid, tot - mid);
282      m_Classifier = new SerializedClassifier();
283      m_Classifier.setModelFile(new File(MODEL_FILENAME));
284    } 
285    catch (Exception e) {
286      e.printStackTrace();
287      fail("Problem setting up to use classifier: " + e);
288    }
289
290    evaluation = new EvaluationUtils();
291    try {
292      trainAndSerializeClassifier(train);
293      regressionResults = evaluation.getTrainTestPredictions(m_Classifier, train, test);
294      reg.println(predictionsToString(regressionResults));
295    }
296    catch (Exception e) {
297      fail("Failed obtaining classifier predictions: " + e);
298    }
299   
300    try {
301      String diff = reg.diff();
302      if (diff == null) {
303        System.err.println("Warning: No reference available, creating."); 
304      } else if (!diff.equals("")) {
305        fail("Regression test failed. Difference:\n" + diff);
306      }
307    } 
308    catch (java.io.IOException ex) {
309      fail("Problem during regression testing.\n" + ex);
310    }
311  }
312 
313  /**
314   * tests the listing of the options
315   */
316  public void testListOptions() {
317    if (!m_OptionTester.checkListOptions())
318      fail("Options cannot be listed via listOptions.");
319  }
320 
321  /**
322   * tests the setting of the options
323   */
324  public void testSetOptions() {
325    if (!m_OptionTester.checkSetOptions())
326      fail("setOptions method failed.");
327  }
328 
329  /**
330   * tests whether there are any remaining options
331   */
332  public void testRemainingOptions() {
333    if (!m_OptionTester.checkRemainingOptions())
334      fail("There were 'left-over' options.");
335  }
336 
337  /**
338   * tests the whether the user-supplied options stay the same after setting.
339   * getting, and re-setting again.
340   */
341  public void testCanonicalUserOptions() {
342    if (!m_OptionTester.checkCanonicalUserOptions())
343      fail("setOptions method failed");
344  }
345 
346  /**
347   * tests the resetting of the options to the default ones
348   */
349  public void testResettingOptions() {
350    if (!m_OptionTester.checkSetOptions())
351      fail("Resetting of options failed");
352  }
353 
354  public static Test suite() {
355    return new TestSuite(SerializedClassifierTest.class);
356  }
357
358  public static void main(String[] args){
359    junit.textui.TestRunner.run(suite());
360  }
361}
Note: See TracBrowser for help on using the repository browser.