/*
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 2 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 */

/*
 * RegressionByDiscretization.java
 * Copyright (C) 1999 University of Waikato, Hamilton, New Zealand
 *
 */

package weka.classifiers.meta;

import weka.classifiers.SingleClassifierEnhancer;
import weka.classifiers.IntervalEstimator;
import weka.classifiers.ConditionalDensityEstimator;

import weka.core.Capabilities;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Attribute;
import weka.core.FastVector;
import weka.core.Option;
import weka.core.RevisionUtils;
import weka.core.Utils;
import weka.core.Capabilities.Capability;
import weka.core.Tag;
import weka.core.SelectedTag;
import weka.core.TechnicalInformation;
import weka.core.TechnicalInformation.Field;
import weka.core.TechnicalInformation.Type;

import weka.filters.Filter;
import weka.filters.unsupervised.attribute.Discretize;

import weka.estimators.UnivariateDensityEstimator;
import weka.estimators.UnivariateIntervalEstimator;
import weka.estimators.UnivariateQuantileEstimator;
import weka.estimators.UnivariateEqualFrequencyHistogramEstimator;
import weka.estimators.UnivariateKernelEstimator;
import weka.estimators.UnivariateNormalEstimator;

import java.util.Enumeration;
import java.util.Vector;

/**
 <!-- globalinfo-start -->
 * A regression scheme that employs any classifier on a copy of the data that has the class attribute (equal-width) discretized. The predicted value is the expected value of the mean class value for each discretized interval (based on the predicted probabilities for each interval).
 * <p/>
 <!-- globalinfo-end -->
 *
 <!-- options-start -->
 * Valid options are: <p/>
 *
 * <pre> -B &lt;int&gt;
 *  Number of bins for equal-width discretization
 *  (default 10).
 * </pre>
 *
 * <pre> -E
 *  Whether to delete empty bins after discretization
 *  (default false).
 * </pre>
 *
 * <pre> -A
 *  Whether to minimize absolute error, rather than squared error.
 *  (default false).
 * </pre>
 *
 * <pre> -F
 *  Use equal-frequency instead of equal-width discretization.</pre>
 *
 * <pre> -K &lt;int&gt;
 *  What type of density estimator to use:
 *  0=histogram/1=kernel/2=normal (default: 0).</pre>
 *
 * <pre> -D
 *  If set, classifier is run in debug mode and
 *  may output additional info to the console</pre>
 *
 * <pre> -W
 *  Full name of base classifier.
 *  (default: weka.classifiers.trees.J48)</pre>
 *
 * <pre>
 * Options specific to classifier weka.classifiers.trees.J48:
 * </pre>
 *
 * <pre> -U
 *  Use unpruned tree.</pre>
 *
 * <pre> -C &lt;pruning confidence&gt;
 *  Set confidence threshold for pruning.
 *  (default 0.25)</pre>
 *
 * <pre> -M &lt;minimum number of instances&gt;
 *  Set minimum number of instances per leaf.
 *  (default 2)</pre>
 *
 * <pre> -R
 *  Use reduced error pruning.</pre>
 *
 * <pre> -N &lt;number of folds&gt;
 *  Set number of folds for reduced error
 *  pruning. One fold is used as pruning set.
 *  (default 3)</pre>
 *
 * <pre> -B
 *  Use binary splits only.</pre>
 *
 * <pre> -S
 *  Don't perform subtree raising.</pre>
 *
 * <pre> -L
 *  Do not clean up after the tree has been built.</pre>
 *
 * <pre> -A
 *  Laplace smoothing for predicted probabilities.</pre>
 *
 * <pre> -Q &lt;seed&gt;
 *  Seed for random data shuffling (default 1).</pre>
 *
 <!-- options-end -->
 *
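 * Example command line for training and evaluating this scheme from a shell
 * (the dataset file name is illustrative):<p/>
 *
 * <pre> java weka.classifiers.meta.RegressionByDiscretization \
 *   -t numeric_data.arff -B 10 -W weka.classifiers.trees.J48</pre>
 *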
 * @author Len Trigg (trigg@cs.waikato.ac.nz)
 * @author Eibe Frank (eibe@cs.waikato.ac.nz)
 * @version $Revision: 5925 $
 */
public class RegressionByDiscretization
  extends SingleClassifierEnhancer implements IntervalEstimator, ConditionalDensityEstimator {

  /** for serialization */
  static final long serialVersionUID = 5066426153134050378L;

  /** The discretization filter. */
  protected Discretize m_Discretizer = new Discretize();

  /** The number of discretization intervals. */
  protected int m_NumBins = 10;

  /** The mean values for each discretized class interval. */
  protected double [] m_ClassMeans;

  /** The class counts for each discretized class interval. */
  protected int [] m_ClassCounts;

  /** Whether to delete empty intervals. */
  protected boolean m_DeleteEmptyBins;

  /** Header of discretized data. */
  protected Instances m_DiscretizedHeader = null;

  /** Use equal-frequency binning */
  protected boolean m_UseEqualFrequency = false;

  /** Whether to minimize absolute error, rather than squared error. */
  protected boolean m_MinimizeAbsoluteError = false;

  /** Use histogram estimator */
  public static final int ESTIMATOR_HISTOGRAM = 0;
  /** Use kernel estimator */
  public static final int ESTIMATOR_KERNEL = 1;
  /** Use normal estimator */
  public static final int ESTIMATOR_NORMAL = 2;
  /** The tags for the available density estimator types */
  public static final Tag [] TAGS_ESTIMATOR = {
    new Tag(ESTIMATOR_HISTOGRAM, "Histogram density estimator"),
    new Tag(ESTIMATOR_KERNEL, "Kernel density estimator"),
    new Tag(ESTIMATOR_NORMAL, "Normal density estimator"),
  };

  /** Which estimator to use (default: histogram) */
  protected int m_estimatorType = ESTIMATOR_HISTOGRAM;

  /** The original target values in the training data */
  protected double[] m_OriginalTargetValues = null;

  /** The converted target values in the training data */
  protected int[] m_NewTargetValues = null;

  /**
   * Returns a string describing this classifier.
   *
   * @return a description suitable for
   * displaying in the explorer/experimenter gui
   */
  public String globalInfo() {

    return "A regression scheme that employs any "
      + "classifier on a copy of the data that has the class attribute "
      + "discretized. The predicted value is the expected value of the "
      + "mean class value for each discretized interval (based on the "
      + "predicted probabilities for each interval). This class now "
      + "also supports conditional density estimation by building "
      + "a univariate density estimator from the target values in "
      + "the training data, weighted by the class probabilities.\n\n"
      + "For more information on this process, see\n\n"
      + getTechnicalInformation().toString();
  }

  /**
   * Returns an instance of a TechnicalInformation object, containing
   * detailed information about the technical background of this class,
   * e.g., paper reference or book this class is based on.
   *
   * @return the technical information about this class
   */
  public TechnicalInformation getTechnicalInformation() {
    TechnicalInformation result;

    result = new TechnicalInformation(Type.INPROCEEDINGS);
    result.setValue(Field.AUTHOR, "Eibe Frank and Remco R. Bouckaert");
    result.setValue(Field.TITLE, "Conditional Density Estimation with Class Probability Estimators");
    result.setValue(Field.BOOKTITLE, "First Asian Conference on Machine Learning");
    result.setValue(Field.YEAR, "2009");
    result.setValue(Field.PAGES, "65-81");
    result.setValue(Field.PUBLISHER, "Springer Verlag");
    result.setValue(Field.ADDRESS, "Berlin");

    return result;
  }

  /**
   * String describing default classifier.
   *
   * @return the default classifier classname
   */
  protected String defaultClassifierString() {

    return "weka.classifiers.trees.J48";
  }

  /**
   * Default constructor.
   */
  public RegressionByDiscretization() {

    m_Classifier = new weka.classifiers.trees.J48();
  }

  /**
   * Returns default capabilities of the classifier.
   *
   * @return the capabilities of this classifier
   */
  public Capabilities getCapabilities() {
    Capabilities result = super.getCapabilities();

    // class
    result.disableAllClasses();
    result.disableAllClassDependencies();
    result.enable(Capability.NUMERIC_CLASS);
    result.enable(Capability.DATE_CLASS);

    result.setMinimumNumberInstances(2);

    return result;
  }

  /**
   * Generates the classifier.
   *
   * @param instances set of instances serving as training data
   * @throws Exception if the classifier has not been generated successfully
   */
  public void buildClassifier(Instances instances) throws Exception {

    // can classifier handle the data?
    getCapabilities().testWithFail(instances);

    // remove instances with missing class
    instances = new Instances(instances);
    instances.deleteWithMissingClass();

    // Discretize the training data
    m_Discretizer.setIgnoreClass(true);
    m_Discretizer.setAttributeIndices("" + (instances.classIndex() + 1));
    m_Discretizer.setBins(getNumBins());
    m_Discretizer.setUseEqualFrequency(getUseEqualFrequency());
    m_Discretizer.setInputFormat(instances);
    Instances newTrain = Filter.useFilter(instances, m_Discretizer);

    // Should empty bins be deleted?
    if (m_DeleteEmptyBins) {

      // Figure out which classes are empty after discretization
      int numNonEmptyClasses = 0;
      boolean[] notEmptyClass = new boolean[newTrain.numClasses()];
      for (int i = 0; i < newTrain.numInstances(); i++) {
        if (!notEmptyClass[(int)newTrain.instance(i).classValue()]) {
          numNonEmptyClasses++;
          notEmptyClass[(int)newTrain.instance(i).classValue()] = true;
        }
      }

      // Compute new list of non-empty classes and mapping of indices
      FastVector newClassVals = new FastVector(numNonEmptyClasses);
      int[] oldIndexToNewIndex = new int[newTrain.numClasses()];
      for (int i = 0; i < newTrain.numClasses(); i++) {
        if (notEmptyClass[i]) {
          oldIndexToNewIndex[i] = newClassVals.size();
          newClassVals.addElement(newTrain.classAttribute().value(i));
        }
      }

      // Compute new header information
      Attribute newClass = new Attribute(newTrain.classAttribute().name(),
                                         newClassVals);
      FastVector newAttributes = new FastVector(newTrain.numAttributes());
      for (int i = 0; i < newTrain.numAttributes(); i++) {
        if (i != newTrain.classIndex()) {
          newAttributes.addElement(newTrain.attribute(i).copy());
        } else {
          newAttributes.addElement(newClass);
        }
      }

      // Create new header and modify instances
      Instances newTrainTransformed = new Instances(newTrain.relationName(),
                                                    newAttributes,
                                                    newTrain.numInstances());
      newTrainTransformed.setClassIndex(newTrain.classIndex());
      for (int i = 0; i < newTrain.numInstances(); i++) {
        Instance inst = newTrain.instance(i);
        newTrainTransformed.add(inst);
        newTrainTransformed.lastInstance().
          setClassValue(oldIndexToNewIndex[(int)inst.classValue()]);
      }
      newTrain = newTrainTransformed;
    }

    // Store target values, in case a prediction interval or computation of median is required
    m_OriginalTargetValues = new double[instances.numInstances()];
    m_NewTargetValues = new int[instances.numInstances()];
    for (int i = 0; i < m_OriginalTargetValues.length; i++) {
      m_OriginalTargetValues[i] = instances.instance(i).classValue();
      m_NewTargetValues[i] = (int)newTrain.instance(i).classValue();
    }

    m_DiscretizedHeader = new Instances(newTrain, 0);

    int numClasses = newTrain.numClasses();

    // Calculate the mean value for each bin of the new class attribute
    m_ClassMeans = new double [numClasses];
    m_ClassCounts = new int [numClasses];
    for (int i = 0; i < instances.numInstances(); i++) {
      Instance inst = newTrain.instance(i);
      if (!inst.classIsMissing()) {
        int classVal = (int) inst.classValue();
        m_ClassCounts[classVal]++;
        m_ClassMeans[classVal] += instances.instance(i).classValue();
      }
    }

    for (int i = 0; i < numClasses; i++) {
      if (m_ClassCounts[i] > 0) {
        m_ClassMeans[i] /= m_ClassCounts[i];
      }
    }
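
    // At this point m_ClassMeans[j] holds the average original target value
    // of the training instances in discretized interval j (0 for empty bins).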

    if (m_Debug) {
      System.out.println("Bin Means");
      System.out.println("==========");
      for (int i = 0; i < m_ClassMeans.length; i++) {
        System.out.println(m_ClassMeans[i]);
      }
      System.out.println();
    }

    // Train the sub-classifier
    m_Classifier.buildClassifier(newTrain);
  }

  /**
   * Get the density estimator for the given instance.
   *
   * @param instance the instance
   * @param print whether to print debugging information (currently unused)
   * @return the univariate density estimator
   * @exception Exception if the estimator can't be computed
   */
  protected UnivariateDensityEstimator getDensityEstimator(Instance instance, boolean print) throws Exception {

    // Initialize estimator
    UnivariateDensityEstimator e;

    if (m_estimatorType == ESTIMATOR_KERNEL) {
      e = new UnivariateKernelEstimator();
    } else if (m_estimatorType == ESTIMATOR_NORMAL) {
      e = new UnivariateNormalEstimator();
    } else {
      e = new UnivariateEqualFrequencyHistogramEstimator();

      // Set the number of bins appropriately
      ((UnivariateEqualFrequencyHistogramEstimator)e).setNumBins(getNumBins());

      // Initialize boundaries of equal frequency estimator
      for (int i = 0; i < m_OriginalTargetValues.length; i++) {
        e.addValue(m_OriginalTargetValues[i], 1.0);
      }

      // Construct estimator, then initialize statistics, so that only boundaries will be kept
      ((UnivariateEqualFrequencyHistogramEstimator)e).initializeStatistics();

      // Now that boundaries have been determined, we only need to update the bin weights
      ((UnivariateEqualFrequencyHistogramEstimator)e).setUpdateWeightsOnly(true);
    }

    // Make sure structure of class attribute is correct
    Instance newInstance = (Instance)instance.copy();
    newInstance.setDataset(m_DiscretizedHeader);
    double [] probs = m_Classifier.distributionForInstance(newInstance);

    // Add values to estimator
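    // Each training target y_i gets weight P(bin(y_i) | x) * N / count(bin(y_i)),
    // so the total weight assigned to a bin is N times that bin's predicted
    // probability for this particular test instance.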
    for (int i = 0; i < m_OriginalTargetValues.length; i++) {
      e.addValue(m_OriginalTargetValues[i], probs[m_NewTargetValues[i]] *
                 m_OriginalTargetValues.length / m_ClassCounts[m_NewTargetValues[i]]);
    }

    // Return estimator
    return e;
  }

  /**
   * Returns an N * 2 array, where N is the number of prediction
   * intervals. In each row, the first element contains the lower
   * boundary of the corresponding prediction interval and the second
   * element the upper boundary.
   *
   * @param instance the instance to make the prediction for.
   * @param confidenceLevel the percentage of cases that the interval should cover.
   * @return an array of prediction intervals
   * @exception Exception if the intervals can't be computed
   */
  public double[][] predictIntervals(Instance instance, double confidenceLevel) throws Exception {

    // Get density estimator
    UnivariateIntervalEstimator e = (UnivariateIntervalEstimator)getDensityEstimator(instance, false);

    // Return intervals
    return e.predictIntervals(confidenceLevel);
  }

  /**
   * Returns the natural logarithm of the density estimate for the given
   * value, based on the given instance.
   *
   * @param instance the instance to make the prediction for.
   * @param value the value to make the prediction for.
   * @return the natural logarithm of the density estimate
   * @exception Exception if the density estimate can't be computed
   */
  public double logDensity(Instance instance, double value) throws Exception {

    // Get density estimator
    UnivariateDensityEstimator e = getDensityEstimator(instance, true);

    // Return estimate
    return e.logDensity(value);
  }

  /**
   * Returns a predicted class for the test instance.
   *
   * @param instance the instance to be classified
   * @return predicted class value
   * @throws Exception if the prediction couldn't be made
   */
  public double classifyInstance(Instance instance) throws Exception {

    // Make sure structure of class attribute is correct
    Instance newInstance = (Instance)instance.copy();
    newInstance.setDataset(m_DiscretizedHeader);
    double [] probs = m_Classifier.distributionForInstance(newInstance);

    if (!m_MinimizeAbsoluteError) {

      // Compute actual prediction
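      // Expected value: sum_j P(bin j | x) * mean_j, renormalized by the
      // summed probability mass in case the estimates do not sum to one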
      double prediction = 0, probSum = 0;
      for (int j = 0; j < probs.length; j++) {
        prediction += probs[j] * m_ClassMeans[j];
        probSum += probs[j];
      }

      return prediction / probSum;
    } else {
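
      // The median of the conditional density minimizes the expected
      // absolute error, so we return the estimated 0.5 quantile.
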
      // Get density estimator
      UnivariateQuantileEstimator e = (UnivariateQuantileEstimator)getDensityEstimator(instance, true);

      // Return estimate
      return e.predictQuantile(0.5);
    }
  }

  /**
   * Returns an enumeration describing the available options.
   *
   * @return an enumeration of all the available options.
   */
  public Enumeration listOptions() {

    Vector newVector = new Vector(5);

    newVector.addElement(new Option(
              "\tNumber of bins for equal-width discretization\n"
              + "\t(default 10).\n",
              "B", 1, "-B <int>"));

    newVector.addElement(new Option(
              "\tWhether to delete empty bins after discretization\n"
              + "\t(default false).\n",
              "E", 0, "-E"));

    newVector.addElement(new Option(
              "\tWhether to minimize absolute error, rather than squared error.\n"
              + "\t(default false).\n",
              "A", 0, "-A"));

    newVector.addElement(new Option(
              "\tUse equal-frequency instead of equal-width discretization.",
              "F", 0, "-F"));

    newVector.addElement(new Option(
              "\tWhat type of density estimator to use: 0=histogram/1=kernel/2=normal (default: 0).",
              "K", 1, "-K <int>"));

    Enumeration enu = super.listOptions();
    while (enu.hasMoreElements()) {
      newVector.addElement(enu.nextElement());
    }

    return newVector.elements();
  }

  /**
   * Parses a given list of options. <p/>
   *
   <!-- options-start -->
   <!-- options-end -->
   *
   * @param options the list of options as an array of strings
   * @throws Exception if an option is not supported
   */
  public void setOptions(String[] options) throws Exception {

    String binsString = Utils.getOption('B', options);
    if (binsString.length() != 0) {
      setNumBins(Integer.parseInt(binsString));
    } else {
      setNumBins(10);
    }

    setDeleteEmptyBins(Utils.getFlag('E', options));
    setUseEqualFrequency(Utils.getFlag('F', options));
    setMinimizeAbsoluteError(Utils.getFlag('A', options));

    String tmpStr = Utils.getOption('K', options);
    if (tmpStr.length() != 0)
      setEstimatorType(new SelectedTag(Integer.parseInt(tmpStr), TAGS_ESTIMATOR));
    else
      setEstimatorType(new SelectedTag(ESTIMATOR_HISTOGRAM, TAGS_ESTIMATOR));

    super.setOptions(options);
  }

  /**
   * Gets the current settings of the Classifier.
   *
   * @return an array of strings suitable for passing to setOptions
   */
  public String [] getOptions() {

    String [] superOptions = super.getOptions();
    String [] options = new String [superOptions.length + 7];
    int current = 0;

    options[current++] = "-B";
    options[current++] = "" + getNumBins();

    if (getDeleteEmptyBins()) {
      options[current++] = "-E";
    }

    if (getUseEqualFrequency()) {
      options[current++] = "-F";
    }

    if (getMinimizeAbsoluteError()) {
      options[current++] = "-A";
    }

    options[current++] = "-K";
    options[current++] = "" + m_estimatorType;

    System.arraycopy(superOptions, 0, options, current,
                     superOptions.length);

    current += superOptions.length;
    while (current < options.length) {
      options[current++] = "";
    }

    return options;
  }

  /**
   * Returns the tip text for this property
   *
   * @return tip text for this property suitable for
   * displaying in the explorer/experimenter gui
   */
  public String numBinsTipText() {

    return "Number of bins for discretization.";
  }

  /**
   * Gets the number of bins numeric attributes will be divided into
   *
   * @return the number of bins.
   */
  public int getNumBins() {

    return m_NumBins;
  }

  /**
   * Sets the number of bins to divide each selected numeric attribute into
   *
   * @param numBins the number of bins
   */
  public void setNumBins(int numBins) {

    m_NumBins = numBins;
  }


  /**
   * Returns the tip text for this property
   *
   * @return tip text for this property suitable for
   * displaying in the explorer/experimenter gui
   */
  public String deleteEmptyBinsTipText() {

    return "Whether to delete empty bins after discretization.";
  }


  /**
   * Gets whether empty bins are deleted.
   *
   * @return true if empty bins get deleted.
   */
  public boolean getDeleteEmptyBins() {

    return m_DeleteEmptyBins;
  }

  /**
   * Sets whether to delete empty bins.
   *
   * @param b if true, empty bins will be deleted
   */
  public void setDeleteEmptyBins(boolean b) {

    m_DeleteEmptyBins = b;
  }

  /**
   * Returns the tip text for this property
   *
   * @return tip text for this property suitable for
   * displaying in the explorer/experimenter gui
   */
  public String minimizeAbsoluteErrorTipText() {

    return "Whether to minimize absolute error.";
  }

  /**
   * Gets whether absolute error is to be minimized.
   *
   * @return true if absolute error is to be minimized
   */
  public boolean getMinimizeAbsoluteError() {

    return m_MinimizeAbsoluteError;
  }

  /**
   * Sets whether absolute error is to be minimized.
   *
   * @param b if true, absolute error is minimized
   */
  public void setMinimizeAbsoluteError(boolean b) {

    m_MinimizeAbsoluteError = b;
  }

  /**
   * Returns the tip text for this property
   *
   * @return tip text for this property suitable for
   * displaying in the explorer/experimenter gui
   */
  public String useEqualFrequencyTipText() {

    return "If set to true, equal-frequency binning will be used instead of" +
      " equal-width binning.";
  }

  /**
   * Get the value of UseEqualFrequency.
   *
   * @return Value of UseEqualFrequency.
   */
  public boolean getUseEqualFrequency() {

    return m_UseEqualFrequency;
  }

  /**
   * Set the value of UseEqualFrequency.
   *
   * @param newUseEqualFrequency Value to assign to UseEqualFrequency.
   */
  public void setUseEqualFrequency(boolean newUseEqualFrequency) {

    m_UseEqualFrequency = newUseEqualFrequency;
  }

  /**
   * Returns the tip text for this property
   *
   * @return tip text for this property suitable for
   * displaying in the explorer/experimenter gui
   */
  public String estimatorTypeTipText() {

    return "The density estimator to use.";
  }

  /**
   * Get the estimator type
   *
   * @return the estimator type
   */
  public SelectedTag getEstimatorType() {

    return new SelectedTag(m_estimatorType, TAGS_ESTIMATOR);
  }

  /**
   * Set the estimator type.
   *
   * @param newEstimator the estimator type to use
   */
  public void setEstimatorType(SelectedTag newEstimator) {

    if (newEstimator.getTags() == TAGS_ESTIMATOR) {
      m_estimatorType = newEstimator.getSelectedTag().getID();
    }
  }

  /**
   * Returns a description of the classifier.
   *
   * @return a description of the classifier as a string.
   */
  public String toString() {

    StringBuffer text = new StringBuffer();

    text.append("Regression by discretization");
    if (m_ClassMeans == null) {
      text.append(": No model built yet.");
    } else {
      text.append("\n\nClass attribute discretized into "
                  + m_ClassMeans.length + " values\n");

      text.append("\nClassifier spec: " + getClassifierSpec()
                  + "\n");
      text.append(m_Classifier.toString());
    }
    return text.toString();
  }

  /**
   * Returns the revision string.
   *
   * @return the revision
   */
  public String getRevision() {
    return RevisionUtils.extract("$Revision: 5925 $");
  }

  /**
   * Main method for testing this class.
   *
   * @param argv the options
   */
  public static void main(String [] argv) {
    runClassifier(new RegressionByDiscretization(), argv);
  }
}