| 1 | /* |
|---|
| 2 | * This program is free software; you can redistribute it and/or modify |
|---|
| 3 | * it under the terms of the GNU General Public License as published by |
|---|
| 4 | * the Free Software Foundation; either version 2 of the License, or |
|---|
| 5 | * (at your option) any later version. |
|---|
| 6 | * |
|---|
| 7 | * This program is distributed in the hope that it will be useful, |
|---|
| 8 | * but WITHOUT ANY WARRANTY; without even the implied warranty of |
|---|
| 9 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the |
|---|
| 10 | * GNU General Public License for more details. |
|---|
| 11 | * |
|---|
| 12 | * You should have received a copy of the GNU General Public License |
|---|
| 13 | * along with this program; if not, write to the Free Software |
|---|
| 14 | * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA. |
|---|
| 15 | */ |
|---|
| 16 | |
|---|
| 17 | /* |
|---|
| 18 | * Copyright (C) 2002 University of Waikato |
|---|
| 19 | */ |
|---|
| 20 | |
|---|
| 21 | package weka.classifiers.meta; |
|---|
| 22 | |
|---|
| 23 | import weka.classifiers.AbstractClassifierTest; |
|---|
| 24 | import weka.classifiers.Classifier; |
|---|
| 25 | import weka.classifiers.evaluation.EvaluationUtils; |
|---|
| 26 | import weka.classifiers.evaluation.NominalPrediction; |
|---|
| 27 | import weka.core.Attribute; |
|---|
| 28 | import weka.core.FastVector; |
|---|
| 29 | import weka.core.Instances; |
|---|
| 30 | import weka.core.NoSupportForMissingValuesException; |
|---|
| 31 | import weka.core.SelectedTag; |
|---|
| 32 | import weka.core.UnsupportedAttributeTypeException; |
|---|
| 33 | import weka.filters.Filter; |
|---|
| 34 | import weka.filters.unsupervised.attribute.RemoveType; |
|---|
| 35 | import weka.filters.unsupervised.attribute.ReplaceMissingValues; |
|---|
| 36 | |
|---|
| 37 | import java.io.BufferedReader; |
|---|
| 38 | import java.io.InputStreamReader; |
|---|
| 39 | |
|---|
| 40 | import junit.framework.Test; |
|---|
| 41 | import junit.framework.TestSuite; |
|---|
| 42 | |
|---|
| 43 | /** |
|---|
| 44 | * Tests ThresholdSelector. Run from the command line with:<p> |
|---|
| 45 | * java weka.classifiers.meta.ThresholdSelectorTest |
|---|
| 46 | * |
|---|
| 47 | * @author <a href="mailto:len@reeltwo.com">Len Trigg</a> |
|---|
| 48 | * @author FracPete (fracpete at waikato dot ac dot nz) |
|---|
| 49 | * @version $Revision: 1.8 $ |
|---|
| 50 | */ |
|---|
| 51 | public class ThresholdSelectorTest |
|---|
| 52 | extends AbstractClassifierTest { |
|---|
| 53 | |
|---|
| 54 | private static double[] DIST1 = new double [] { |
|---|
| 55 | 0.25, |
|---|
| 56 | 0.375, |
|---|
| 57 | 0.5, |
|---|
| 58 | 0.625, |
|---|
| 59 | 0.75, |
|---|
| 60 | 0.875, |
|---|
| 61 | 1.0 |
|---|
| 62 | }; |
|---|
| 63 | |
|---|
| 64 | /** A set of instances to test with */ |
|---|
| 65 | protected transient Instances m_Instances; |
|---|
| 66 | |
|---|
| 67 | /** Used to generate various types of predictions */ |
|---|
| 68 | protected transient EvaluationUtils m_Evaluation; |
|---|
| 69 | |
|---|
| 70 | public ThresholdSelectorTest(String name) { |
|---|
| 71 | super(name); |
|---|
| 72 | } |
|---|
| 73 | |
|---|
| 74 | /** |
|---|
| 75 | * Called by JUnit before each test method. This implementation creates |
|---|
| 76 | * the default classifier to test and loads a test set of Instances. |
|---|
| 77 | * |
|---|
| 78 | * @exception Exception if an error occurs reading the example instances. |
|---|
| 79 | */ |
|---|
| 80 | protected void setUp() throws Exception { |
|---|
| 81 | super.setUp(); |
|---|
| 82 | |
|---|
| 83 | m_Evaluation = new EvaluationUtils(); |
|---|
| 84 | m_Instances = new Instances( |
|---|
| 85 | new BufferedReader( |
|---|
| 86 | new InputStreamReader( |
|---|
| 87 | ClassLoader.getSystemResourceAsStream( |
|---|
| 88 | "weka/classifiers/data/ClassifierTest.arff")))); |
|---|
| 89 | } |
|---|
| 90 | |
|---|
| 91 | /** Creates a default ThresholdSelector */ |
|---|
| 92 | public Classifier getClassifier() { |
|---|
| 93 | return getClassifier(DIST1); |
|---|
| 94 | } |
|---|
| 95 | |
|---|
| 96 | /** Called by JUnit after each test method */ |
|---|
| 97 | protected void tearDown() { |
|---|
| 98 | super.tearDown(); |
|---|
| 99 | |
|---|
| 100 | m_Evaluation = null; |
|---|
| 101 | } |
|---|
| 102 | |
|---|
| 103 | /** |
|---|
| 104 | * Creates a ThresholdSelector that returns predictions from a |
|---|
| 105 | * given distribution |
|---|
| 106 | */ |
|---|
| 107 | public Classifier getClassifier(double[] dist) { |
|---|
| 108 | return getClassifier(new ThresholdSelectorDummyClassifier(dist)); |
|---|
| 109 | } |
|---|
| 110 | |
|---|
| 111 | /** |
|---|
| 112 | * Creates a ThresholdSelector with the given subclassifier. |
|---|
| 113 | * |
|---|
| 114 | * @param classifier a <code>Classifier</code> to use as the |
|---|
| 115 | * subclassifier |
|---|
| 116 | * @return a new <code>ThresholdSelector</code> |
|---|
| 117 | */ |
|---|
| 118 | public Classifier getClassifier(Classifier classifier) { |
|---|
| 119 | ThresholdSelector t = new ThresholdSelector(); |
|---|
| 120 | t.setClassifier(classifier); |
|---|
| 121 | return t; |
|---|
| 122 | } |
|---|
| 123 | |
|---|
| 124 | /** |
|---|
| 125 | * Builds a model using the current classifier using the first |
|---|
| 126 | * half of the current data for training, and generates a bunch of |
|---|
| 127 | * predictions using the remaining half of the data for testing. |
|---|
| 128 | * |
|---|
| 129 | * @return a <code>FastVector</code> containing the predictions. |
|---|
| 130 | */ |
|---|
| 131 | protected FastVector useClassifier() throws Exception { |
|---|
| 132 | |
|---|
| 133 | Classifier dc = null; |
|---|
| 134 | int tot = m_Instances.numInstances(); |
|---|
| 135 | int mid = tot / 2; |
|---|
| 136 | Instances train = null; |
|---|
| 137 | Instances test = null; |
|---|
| 138 | try { |
|---|
| 139 | train = new Instances(m_Instances, 0, mid); |
|---|
| 140 | test = new Instances(m_Instances, mid, tot - mid); |
|---|
| 141 | dc = m_Classifier; |
|---|
| 142 | } catch (Exception ex) { |
|---|
| 143 | ex.printStackTrace(); |
|---|
| 144 | fail("Problem setting up to use classifier: " + ex); |
|---|
| 145 | } |
|---|
| 146 | int counter = 0; |
|---|
| 147 | do { |
|---|
| 148 | try { |
|---|
| 149 | return m_Evaluation.getTrainTestPredictions(dc, train, test); |
|---|
| 150 | } catch (UnsupportedAttributeTypeException ex) { |
|---|
| 151 | SelectedTag tag = null; |
|---|
| 152 | boolean invert = false; |
|---|
| 153 | String msg = ex.getMessage(); |
|---|
| 154 | if ((msg.indexOf("string") != -1) && |
|---|
| 155 | (msg.indexOf("attributes") != -1)) { |
|---|
| 156 | System.err.println("\nDeleting string attributes."); |
|---|
| 157 | tag = new SelectedTag(Attribute.STRING, |
|---|
| 158 | RemoveType.TAGS_ATTRIBUTETYPE); |
|---|
| 159 | } else if ((msg.indexOf("only") != -1) && |
|---|
| 160 | (msg.indexOf("nominal") != -1)) { |
|---|
| 161 | System.err.println("\nDeleting non-nominal attributes."); |
|---|
| 162 | tag = new SelectedTag(Attribute.NOMINAL, |
|---|
| 163 | RemoveType.TAGS_ATTRIBUTETYPE); |
|---|
| 164 | invert = true; |
|---|
| 165 | } else if ((msg.indexOf("only") != -1) && |
|---|
| 166 | (msg.indexOf("numeric") != -1)) { |
|---|
| 167 | System.err.println("\nDeleting non-numeric attributes."); |
|---|
| 168 | tag = new SelectedTag(Attribute.NUMERIC, |
|---|
| 169 | RemoveType.TAGS_ATTRIBUTETYPE); |
|---|
| 170 | invert = true; |
|---|
| 171 | } else { |
|---|
| 172 | throw ex; |
|---|
| 173 | } |
|---|
| 174 | RemoveType attFilter = new RemoveType(); |
|---|
| 175 | attFilter.setAttributeType(tag); |
|---|
| 176 | attFilter.setInvertSelection(invert); |
|---|
| 177 | attFilter.setInputFormat(train); |
|---|
| 178 | train = Filter.useFilter(train, attFilter); |
|---|
| 179 | attFilter.batchFinished(); |
|---|
| 180 | test = Filter.useFilter(test, attFilter); |
|---|
| 181 | counter++; |
|---|
| 182 | if (counter > 2) { |
|---|
| 183 | throw ex; |
|---|
| 184 | } |
|---|
| 185 | } catch (NoSupportForMissingValuesException ex2) { |
|---|
| 186 | System.err.println("\nReplacing missing values."); |
|---|
| 187 | ReplaceMissingValues rmFilter = new ReplaceMissingValues(); |
|---|
| 188 | rmFilter.setInputFormat(train); |
|---|
| 189 | train = Filter.useFilter(train, rmFilter); |
|---|
| 190 | rmFilter.batchFinished(); |
|---|
| 191 | test = Filter.useFilter(test, rmFilter); |
|---|
| 192 | } catch (IllegalArgumentException ex3) { |
|---|
| 193 | String msg = ex3.getMessage(); |
|---|
| 194 | if (msg.indexOf("Not enough instances") != -1) { |
|---|
| 195 | System.err.println("\nInflating training data."); |
|---|
| 196 | Instances trainNew = new Instances(train); |
|---|
| 197 | for (int i = 0; i < train.numInstances(); i++) { |
|---|
| 198 | trainNew.add(train.instance(i)); |
|---|
| 199 | } |
|---|
| 200 | train = trainNew; |
|---|
| 201 | } else { |
|---|
| 202 | throw ex3; |
|---|
| 203 | } |
|---|
| 204 | } |
|---|
| 205 | } while (true); |
|---|
| 206 | } |
|---|
| 207 | |
|---|
| 208 | public void testRangeNone() throws Exception { |
|---|
| 209 | |
|---|
| 210 | int cind = 0; |
|---|
| 211 | ((ThresholdSelector)m_Classifier).setDesignatedClass(new SelectedTag(ThresholdSelector.OPTIMIZE_0, ThresholdSelector.TAGS_OPTIMIZE)); |
|---|
| 212 | ((ThresholdSelector)m_Classifier).setRangeCorrection(new SelectedTag(ThresholdSelector.RANGE_NONE, ThresholdSelector.TAGS_RANGE)); |
|---|
| 213 | FastVector result = null; |
|---|
| 214 | m_Instances.setClassIndex(1); |
|---|
| 215 | result = useClassifier(); |
|---|
| 216 | assertTrue(result.size() != 0); |
|---|
| 217 | double minp = 0; |
|---|
| 218 | double maxp = 0; |
|---|
| 219 | for (int i = 0; i < result.size(); i++) { |
|---|
| 220 | NominalPrediction p = (NominalPrediction)result.elementAt(i); |
|---|
| 221 | double prob = p.distribution()[cind]; |
|---|
| 222 | if ((i == 0) || (prob < minp)) minp = prob; |
|---|
| 223 | if ((i == 0) || (prob > maxp)) maxp = prob; |
|---|
| 224 | } |
|---|
| 225 | assertTrue("Upper limit shouldn't increase", maxp <= 1.0); |
|---|
| 226 | assertTrue("Lower limit shouldn'd decrease", minp >= 0.25); |
|---|
| 227 | } |
|---|
| 228 | |
|---|
| 229 | public void testDesignatedClass() throws Exception { |
|---|
| 230 | |
|---|
| 231 | int cind = 0; |
|---|
| 232 | for (int i = 0; i < ThresholdSelector.TAGS_OPTIMIZE.length; i++) { |
|---|
| 233 | ((ThresholdSelector)m_Classifier).setDesignatedClass(new SelectedTag(ThresholdSelector.TAGS_OPTIMIZE[i].getID(), ThresholdSelector.TAGS_OPTIMIZE)); |
|---|
| 234 | m_Instances.setClassIndex(1); |
|---|
| 235 | FastVector result = useClassifier(); |
|---|
| 236 | assertTrue(result.size() != 0); |
|---|
| 237 | } |
|---|
| 238 | } |
|---|
| 239 | |
|---|
| 240 | public void testEvaluationMode() throws Exception { |
|---|
| 241 | |
|---|
| 242 | int cind = 0; |
|---|
| 243 | for (int i = 0; i < ThresholdSelector.TAGS_EVAL.length; i++) { |
|---|
| 244 | ((ThresholdSelector)m_Classifier).setEvaluationMode(new SelectedTag(ThresholdSelector.TAGS_EVAL[i].getID(), ThresholdSelector.TAGS_EVAL)); |
|---|
| 245 | m_Instances.setClassIndex(1); |
|---|
| 246 | FastVector result = useClassifier(); |
|---|
| 247 | assertTrue(result.size() != 0); |
|---|
| 248 | } |
|---|
| 249 | } |
|---|
| 250 | |
|---|
| 251 | public void testNumXValFolds() throws Exception { |
|---|
| 252 | |
|---|
| 253 | try { |
|---|
| 254 | ((ThresholdSelector)m_Classifier).setNumXValFolds(0); |
|---|
| 255 | fail("Expected IllegalArgumentException"); |
|---|
| 256 | } catch (IllegalArgumentException e) { |
|---|
| 257 | // OK |
|---|
| 258 | } |
|---|
| 259 | |
|---|
| 260 | int cind = 0; |
|---|
| 261 | for (int i = 2; i < 20; i += 2) { |
|---|
| 262 | ((ThresholdSelector)m_Classifier).setNumXValFolds(i); |
|---|
| 263 | m_Instances.setClassIndex(1); |
|---|
| 264 | FastVector result = useClassifier(); |
|---|
| 265 | assertTrue(result.size() != 0); |
|---|
| 266 | } |
|---|
| 267 | } |
|---|
| 268 | |
|---|
| 269 | public static Test suite() { |
|---|
| 270 | return new TestSuite(ThresholdSelectorTest.class); |
|---|
| 271 | } |
|---|
| 272 | |
|---|
| 273 | public static void main(String[] args){ |
|---|
| 274 | junit.textui.TestRunner.run(suite()); |
|---|
| 275 | } |
|---|
| 276 | } |
|---|