1 | package weka.classifiers.pmml.consumer; |
---|
2 | |
---|
3 | import weka.core.Instances; |
---|
4 | import weka.core.FastVector; |
---|
5 | import weka.core.Attribute; |
---|
6 | import weka.core.pmml.PMMLFactory; |
---|
7 | import weka.core.pmml.PMMLModel; |
---|
8 | import weka.test.Regression; |
---|
9 | import weka.classifiers.evaluation.EvaluationUtils; |
---|
10 | |
---|
11 | import java.io.*; |
---|
12 | |
---|
13 | import junit.framework.TestCase; |
---|
14 | import junit.framework.Test; |
---|
15 | import junit.framework.TestSuite; |
---|
16 | |
---|
17 | public abstract class AbstractPMMLClassifierTest extends TestCase { |
---|
18 | |
---|
19 | protected FastVector m_modelNames = new FastVector(); |
---|
20 | protected FastVector m_dataSetNames = new FastVector(); |
---|
21 | |
---|
22 | public AbstractPMMLClassifierTest(String name) { |
---|
23 | super(name); |
---|
24 | } |
---|
25 | |
---|
26 | public Instances getData(String name) { |
---|
27 | Instances elnino = null; |
---|
28 | try { |
---|
29 | elnino = |
---|
30 | new Instances(new BufferedReader(new InputStreamReader( |
---|
31 | ClassLoader.getSystemResourceAsStream("weka/classifiers/pmml/data/" + name)))); |
---|
32 | } catch (Exception ex) { |
---|
33 | ex.printStackTrace(); |
---|
34 | } |
---|
35 | return elnino; |
---|
36 | } |
---|
37 | |
---|
38 | public PMMLClassifier getClassifier(String name) { |
---|
39 | PMMLClassifier regression = null; |
---|
40 | try { |
---|
41 | PMMLModel model = |
---|
42 | PMMLFactory.getPMMLModel(new BufferedInputStream(ClassLoader.getSystemResourceAsStream( |
---|
43 | "weka/classifiers/pmml/data/" + name))); |
---|
44 | |
---|
45 | regression = (PMMLClassifier)model; |
---|
46 | |
---|
47 | } catch (Exception ex) { |
---|
48 | ex.printStackTrace(); |
---|
49 | } |
---|
50 | return regression; |
---|
51 | } |
---|
52 | |
---|
53 | public void testRegression() throws Exception { |
---|
54 | |
---|
55 | PMMLClassifier classifier = null; |
---|
56 | Instances testData = null; |
---|
57 | EvaluationUtils evalUtils = null; |
---|
58 | weka.test.Regression reg = new weka.test.Regression(this.getClass()); |
---|
59 | |
---|
60 | FastVector predictions = null; |
---|
61 | boolean success = false; |
---|
62 | for (int i = 0; i < m_modelNames.size(); i++) { |
---|
63 | classifier = getClassifier((String)m_modelNames.elementAt(i)); |
---|
64 | testData = getData((String)m_dataSetNames.elementAt(i)); |
---|
65 | evalUtils = new EvaluationUtils(); |
---|
66 | |
---|
67 | try { |
---|
68 | String className = classifier.getMiningSchema().getFieldsAsInstances().classAttribute().name(); |
---|
69 | Attribute classAtt = testData.attribute(className); |
---|
70 | testData.setClass(classAtt); |
---|
71 | predictions = evalUtils.getTestPredictions(classifier, testData); |
---|
72 | success = true; |
---|
73 | String predsString = weka.classifiers.AbstractClassifierTest.predictionsToString(predictions); |
---|
74 | reg.println(predsString); |
---|
75 | } catch (Exception ex) { |
---|
76 | ex.printStackTrace(); |
---|
77 | String msg = ex.getMessage().toLowerCase(); |
---|
78 | if (msg.indexOf("not in classpath") > -1) { |
---|
79 | return; |
---|
80 | } |
---|
81 | } |
---|
82 | } |
---|
83 | |
---|
84 | if (!success) { |
---|
85 | fail("Problem during regression testing: no successful predictions generated"); |
---|
86 | } |
---|
87 | |
---|
88 | try { |
---|
89 | String diff = reg.diff(); |
---|
90 | if (diff == null) { |
---|
91 | System.err.println("Warning: No reference available, creating."); |
---|
92 | } else if (!diff.equals("")) { |
---|
93 | fail("Regression test failed. Difference:\n" + diff); |
---|
94 | } |
---|
95 | } catch (java.io.IOException ex) { |
---|
96 | fail("Problem during regression testing.\n" + ex); |
---|
97 | } |
---|
98 | } |
---|
99 | } |
---|