/* * This program is free software; you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by * the Free Software Foundation; either version 2 of the License, or * (at your option) any later version. * * This program is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU General Public License for more details. * * You should have received a copy of the GNU General Public License * along with this program; if not, write to the Free Software * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA. */ /* * Copyright (C) 2007 University of Waikato, Hamilton, New Zealand */ package weka.classifiers.misc; import weka.classifiers.Classifier; import weka.classifiers.evaluation.EvaluationUtils; import weka.core.Attribute; import weka.core.CheckOptionHandler; import weka.core.FastVector; import weka.core.Instances; import weka.core.SerializationHelper; import weka.core.TestInstances; import weka.test.Regression; import java.io.File; import junit.framework.Test; import junit.framework.TestCase; import junit.framework.TestSuite; /** * Tests SerializedClassifier. Run from the command line with:
* java weka.classifiers.misc.SerializedClassifierTest
*
* @author FracPete (fracpete at waikato dot ac dot nz)
* @version $Revision: 1.1 $
*/
public class SerializedClassifierTest
  extends TestCase {

  /** the filename for temporary serialized models */
  public final static String MODEL_FILENAME = System.getProperty("user.dir") + "/" + "temp.model";

  /** the setup classifier */
  protected SerializedClassifier m_Classifier;

  /** the OptionHandler tester */
  protected CheckOptionHandler m_OptionTester;

  /**
   * initializes the test
   *
   * @param name the name of the test
   */
  public SerializedClassifierTest(String name) {
    super(name);
  }

  /**
   * Called by JUnit before each test method. Resets the classifier,
   * creates a fresh (silent) option tester and removes any temporary
   * model file left over from a previous run.
   *
   * @throws Exception if an error occurs during setup.
   */
  protected void setUp() throws Exception {
    m_Classifier = null;

    // NOTE(review): no option handler is registered on the tester here, so
    // the option tests below run against whatever handler CheckOptionHandler
    // uses by default -- confirm that this is intended.
    m_OptionTester = new CheckOptionHandler();
    m_OptionTester.setSilent(true);

    deleteModelFile();
  }

  /**
   * Called by JUnit after each test method. Releases the classifier and
   * the option tester and removes the temporary model file.
   */
  protected void tearDown() {
    m_Classifier = null;
    m_OptionTester = null;

    deleteModelFile();
  }

  /**
   * Deletes the temporary model file (see {@link #MODEL_FILENAME}) if it
   * exists. Shared by {@link #setUp()} and {@link #tearDown()}.
   */
  private static void deleteModelFile() {
    File file = new File(MODEL_FILENAME);
    if (file.exists())
      file.delete();
  }

  /**
   * Generates the artificial dataset used by the tests: 100 instances with
   * five predictor attributes and no date, string or relational attributes.
   *
   * @param nomClass if TRUE, a nominal class and five nominal attributes
   *                 (four values each) are generated; if FALSE, a numeric
   *                 class and five numeric attributes
   * @return the generated data
   * @throws Exception if the data generation fails
   */
  private static Instances generateData(boolean nomClass) throws Exception {
    TestInstances generator = new TestInstances();

    if (nomClass) {
      generator.setClassType(Attribute.NOMINAL);
      generator.setNumNominal(5);
      generator.setNumNominalValues(4);
      generator.setNumNumeric(0);
    }
    else {
      generator.setClassType(Attribute.NUMERIC);
      generator.setNumNominal(0);
      generator.setNumNumeric(5);
    }
    generator.setNumDate(0);
    generator.setNumString(0);
    generator.setNumRelational(0);
    generator.setNumInstances(100);

    return generator.generate();
  }

  /**
   * creates a classifier, trains and serializes it to
   * {@link #MODEL_FILENAME}. Any exception is turned into a test failure.
   *
   * @param data the data to use (J48 with nominal class, M5P with
   *             numeric class)
   * @return the predictions of the trained classifier for each instance
   *         of the data
   */
  protected double[] trainAndSerializeClassifier(Instances data) {
    Classifier classifier;
    double[] result;
    int i;

    try {
      // build
      if (data.classAttribute().isNominal())
        classifier = new weka.classifiers.trees.J48();
      else
        classifier = new weka.classifiers.trees.M5P();
      classifier.buildClassifier(data);

      // record predictions
      result = new double[data.numInstances()];
      for (i = 0; i < result.length; i++)
        result[i] = classifier.classifyInstance(data.instance(i));

      // save
      SerializationHelper.write(MODEL_FILENAME, classifier);
    }
    catch (Exception e) {
      fail("Training base classifier failed: " + e);
      // only reached if fail() returns; keeps the compiler satisfied
      return null;
    }

    return result;
  }

  /**
   * performs the actual test: trains and serializes a base classifier,
   * loads it through a SerializedClassifier and checks that both produce
   * the same predictions on the training data.
   *
   * @param nomClass whether to use a nominal class with J48 (TRUE) or
   *                 a numeric one with M5P (FALSE)
   */
  protected void performTest(boolean nomClass) {
    Instances data;
    double[] originalResults;
    double[] testResults;
    int i;

    // generate data
    try {
      data = generateData(nomClass);
    }
    catch (Exception e) {
      fail("Generating test data failed: " + e);
      return;
    }

    // train and save base classifier
    try {
      originalResults = trainAndSerializeClassifier(data);
    }
    catch (Exception e) {
      fail("Training base classifier failed: " + e);
      return;
    }

    // test loading
    try {
      m_Classifier = new SerializedClassifier();
      m_Classifier.setModelFile(new File(MODEL_FILENAME));
      m_Classifier.buildClassifier(data);
    }
    catch (Exception e) {
      fail("Loading/testing of classifier failed: " + e);
      return;
    }

    // compare results
    try {
      // get results from serialized model
      testResults = new double[data.numInstances()];
      for (i = 0; i < testResults.length; i++)
        testResults[i] = m_Classifier.classifyInstance(data.instance(i));

      // compare
      for (i = 0; i < originalResults.length; i++) {
        if (originalResults[i] != testResults[i])
          throw new Exception("Result #" + (i+1) + " differs!");
      }
    }
    catch (Exception e) {
      fail("Comparing results failed: " + e);
    }
  }

  /**
   * tests a serialized classifier (J48) handling nominal classes
   */
  public void testNominalClass() {
    performTest(true);
  }

  /**
   * tests a serialized classifier (M5P) handling numeric classes
   */
  public void testNumericClass() {
    // FALSE selects the numeric class / M5P branch
    performTest(false);
  }

  /**
   * Returns a string containing all the predictions.
   *
   * @param predictions a FastVector containing the predictions
   * @return a String representing the vector of predictions.
   */
  protected String predictionsToString(FastVector predictions) {
    StringBuffer sb = new StringBuffer();

    sb.append(predictions.size()).append(" predictions\n");
    for (int i = 0; i < predictions.size(); i++) {
      sb.append(predictions.elementAt(i)).append('\n');
    }

    return sb.toString();
  }

  /**
   * Runs a regression test -- this checks that the output of the tested
   * object matches that in a reference version. When this test is
   * run without any pre-existing reference output, the reference version
   * is created. Uses J48 for this purpose.
   */
  public void testRegression() {
    Regression reg;
    Instances train;
    Instances test;
    Instances data;
    int tot;
    int mid;
    EvaluationUtils evaluation;
    FastVector regressionResults;

    reg = new Regression(this.getClass());

    // generate test data (nominal class, hence J48)
    try {
      data = generateData(true);
    }
    catch (Exception e) {
      fail("Failed generating data: " + e);
      return;
    }

    // split data into train/test
    tot = data.numInstances();
    mid = tot / 2;
    train = null;
    test = null;
    try {
      train = new Instances(data, 0, mid);
      test = new Instances(data, mid, tot - mid);
      // the model file is only written further below by
      // trainAndSerializeClassifier(train)
      m_Classifier = new SerializedClassifier();
      m_Classifier.setModelFile(new File(MODEL_FILENAME));
    }
    catch (Exception e) {
      e.printStackTrace();
      fail("Problem setting up to use classifier: " + e);
    }

    evaluation = new EvaluationUtils();
    try {
      trainAndSerializeClassifier(train);
      regressionResults = evaluation.getTrainTestPredictions(m_Classifier, train, test);
      reg.println(predictionsToString(regressionResults));
    }
    catch (Exception e) {
      fail("Failed obtaining classifier predictions: " + e);
    }

    try {
      String diff = reg.diff();
      if (diff == null) {
        System.err.println("Warning: No reference available, creating.");
      } else if (!diff.equals("")) {
        fail("Regression test failed. Difference:\n" + diff);
      }
    }
    catch (java.io.IOException ex) {
      fail("Problem during regression testing.\n" + ex);
    }
  }

  /**
   * tests the listing of the options
   */
  public void testListOptions() {
    if (!m_OptionTester.checkListOptions())
      fail("Options cannot be listed via listOptions.");
  }

  /**
   * tests the setting of the options
   */
  public void testSetOptions() {
    if (!m_OptionTester.checkSetOptions())
      fail("setOptions method failed.");
  }

  /**
   * tests whether there are any remaining options
   */
  public void testRemainingOptions() {
    if (!m_OptionTester.checkRemainingOptions())
      fail("There were 'left-over' options.");
  }

  /**
   * tests whether the user-supplied options stay the same after setting,
   * getting, and re-setting again.
   */
  public void testCanonicalUserOptions() {
    if (!m_OptionTester.checkCanonicalUserOptions())
      fail("setOptions method failed");
  }

  /**
   * tests the resetting of the options to the default ones
   */
  public void testResettingOptions() {
    // uses the dedicated reset check; checkSetOptions() is already covered
    // by testSetOptions()
    if (!m_OptionTester.checkResettingOptions())
      fail("Resetting of options failed");
  }

  /**
   * Returns the test suite containing all tests of this class.
   *
   * @return the test suite
   */
  public static Test suite() {
    return new TestSuite(SerializedClassifierTest.class);
  }

  /**
   * Runs the test suite from the command line.
   *
   * @param args ignored
   */
  public static void main(String[] args){
    junit.textui.TestRunner.run(suite());
  }
}