source: src/test/java/weka/classifiers/pmml/consumer/AbstractPMMLClassifierTest.java @ 23

Last change on this file since 23 was 4, checked in by gnappo, 14 years ago

Import di weka. (Italian commit message: "Import of Weka.")

File size: 3.0 KB
Line 
1package weka.classifiers.pmml.consumer;
2
3import weka.core.Instances;
4import weka.core.FastVector;
5import weka.core.Attribute;
6import weka.core.pmml.PMMLFactory;
7import weka.core.pmml.PMMLModel;
8import weka.test.Regression;
9import weka.classifiers.evaluation.EvaluationUtils;
10
11import java.io.*;
12
13import junit.framework.TestCase;
14import junit.framework.Test;
15import junit.framework.TestSuite;
16
17public abstract class AbstractPMMLClassifierTest extends TestCase {
18
19  protected FastVector m_modelNames = new FastVector();
20  protected FastVector m_dataSetNames = new FastVector();
21
22  public AbstractPMMLClassifierTest(String name) { 
23    super(name); 
24  }
25
26  public Instances getData(String name) {
27    Instances elnino = null;
28    try {
29      elnino = 
30        new Instances(new BufferedReader(new InputStreamReader(
31          ClassLoader.getSystemResourceAsStream("weka/classifiers/pmml/data/" + name))));
32    } catch (Exception ex) {
33      ex.printStackTrace();
34    }
35    return elnino;
36  }
37
38  public PMMLClassifier getClassifier(String name) {
39    PMMLClassifier regression = null;
40    try {
41      PMMLModel model = 
42        PMMLFactory.getPMMLModel(new BufferedInputStream(ClassLoader.getSystemResourceAsStream(
43                  "weka/classifiers/pmml/data/" + name)));
44
45      regression = (PMMLClassifier)model;
46
47    } catch (Exception ex) {
48      ex.printStackTrace();
49    }
50    return regression;
51  }
52
53  public void testRegression() throws Exception {
54
55    PMMLClassifier classifier = null;
56    Instances testData = null;
57    EvaluationUtils evalUtils = null; 
58    weka.test.Regression reg = new weka.test.Regression(this.getClass());
59
60    FastVector predictions = null;
61    boolean success = false;
62    for (int i = 0; i < m_modelNames.size(); i++) {
63      classifier = getClassifier((String)m_modelNames.elementAt(i));
64      testData = getData((String)m_dataSetNames.elementAt(i));
65      evalUtils = new EvaluationUtils();
66
67      try {
68        String  className = classifier.getMiningSchema().getFieldsAsInstances().classAttribute().name();
69        Attribute classAtt = testData.attribute(className);
70        testData.setClass(classAtt);
71        predictions = evalUtils.getTestPredictions(classifier, testData);
72        success = true;
73        String predsString = weka.classifiers.AbstractClassifierTest.predictionsToString(predictions);
74        reg.println(predsString);
75      } catch (Exception ex) {
76        ex.printStackTrace();
77        String msg = ex.getMessage().toLowerCase();
78        if (msg.indexOf("not in classpath") > -1) {
79          return;
80        }
81      }
82    }
83
84    if (!success) {
85      fail("Problem during regression testing: no successful predictions generated");
86    }
87
88    try {
89      String diff = reg.diff();
90      if (diff == null) {
91        System.err.println("Warning: No reference available, creating."); 
92      } else if (!diff.equals("")) {
93        fail("Regression test failed. Difference:\n" + diff);
94      }
95    }  catch (java.io.IOException ex) {
96      fail("Problem during regression testing.\n" + ex);
97    }   
98  }
99}
Note: See TracBrowser for help on using the repository browser.