[4] | 1 | /* |
---|
| 2 | * This program is free software; you can redistribute it and/or modify |
---|
| 3 | * it under the terms of the GNU General Public License as published by |
---|
| 4 | * the Free Software Foundation; either version 2 of the License, or |
---|
| 5 | * (at your option) any later version. |
---|
| 6 | * |
---|
| 7 | * This program is distributed in the hope that it will be useful, |
---|
| 8 | * but WITHOUT ANY WARRANTY; without even the implied warranty of |
---|
| 9 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the |
---|
| 10 | * GNU General Public License for more details. |
---|
| 11 | * |
---|
| 12 | * You should have received a copy of the GNU General Public License |
---|
| 13 | * along with this program; if not, write to the Free Software |
---|
| 14 | * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA. |
---|
| 15 | */ |
---|
| 16 | |
---|
| 17 | /* |
---|
| 18 | * ReliefFAttributeEval.java |
---|
| 19 | * Copyright (C) 1999 University of Waikato, Hamilton, New Zealand |
---|
| 20 | * |
---|
| 21 | */ |
---|
| 22 | |
---|
| 23 | package weka.attributeSelection; |
---|
| 24 | |
---|
| 25 | import weka.core.Attribute; |
---|
| 26 | import weka.core.Capabilities; |
---|
| 27 | import weka.core.Instance; |
---|
| 28 | import weka.core.Instances; |
---|
| 29 | import weka.core.Option; |
---|
| 30 | import weka.core.OptionHandler; |
---|
| 31 | import weka.core.RevisionUtils; |
---|
| 32 | import weka.core.TechnicalInformation; |
---|
| 33 | import weka.core.TechnicalInformationHandler; |
---|
| 34 | import weka.core.Utils; |
---|
| 35 | import weka.core.Capabilities.Capability; |
---|
| 36 | import weka.core.TechnicalInformation.Field; |
---|
| 37 | import weka.core.TechnicalInformation.Type; |
---|
| 38 | |
---|
| 39 | import java.util.Enumeration; |
---|
| 40 | import java.util.Random; |
---|
| 41 | import java.util.Vector; |
---|
| 42 | |
---|
| 43 | /** |
---|
| 44 | <!-- globalinfo-start --> |
---|
| 45 | * ReliefFAttributeEval :<br/> |
---|
| 46 | * <br/> |
---|
| 47 | * Evaluates the worth of an attribute by repeatedly sampling an instance and considering the value of the given attribute for the nearest instance of the same and different class. Can operate on both discrete and continuous class data.<br/> |
---|
| 48 | * <br/> |
---|
| 49 | * For more information see:<br/> |
---|
| 50 | * <br/> |
---|
| 51 | * Kenji Kira, Larry A. Rendell: A Practical Approach to Feature Selection. In: Ninth International Workshop on Machine Learning, 249-256, 1992.<br/> |
---|
| 52 | * <br/> |
---|
| 53 | * Igor Kononenko: Estimating Attributes: Analysis and Extensions of RELIEF. In: European Conference on Machine Learning, 171-182, 1994.<br/> |
---|
| 54 | * <br/> |
---|
| 55 | * Marko Robnik-Sikonja, Igor Kononenko: An adaptation of Relief for attribute estimation in regression. In: Fourteenth International Conference on Machine Learning, 296-304, 1997. |
---|
| 56 | * <p/> |
---|
| 57 | <!-- globalinfo-end --> |
---|
| 58 | * |
---|
| 59 | <!-- technical-bibtex-start --> |
---|
| 60 | * BibTeX: |
---|
| 61 | * <pre> |
---|
| 62 | * @inproceedings{Kira1992, |
---|
| 63 | * author = {Kenji Kira and Larry A. Rendell}, |
---|
| 64 | * booktitle = {Ninth International Workshop on Machine Learning}, |
---|
| 65 | * editor = {Derek H. Sleeman and Peter Edwards}, |
---|
| 66 | * pages = {249-256}, |
---|
| 67 | * publisher = {Morgan Kaufmann}, |
---|
| 68 | * title = {A Practical Approach to Feature Selection}, |
---|
| 69 | * year = {1992} |
---|
| 70 | * } |
---|
| 71 | * |
---|
| 72 | * @inproceedings{Kononenko1994, |
---|
| 73 | * author = {Igor Kononenko}, |
---|
| 74 | * booktitle = {European Conference on Machine Learning}, |
---|
| 75 | * editor = {Francesco Bergadano and Luc De Raedt}, |
---|
| 76 | * pages = {171-182}, |
---|
| 77 | * publisher = {Springer}, |
---|
| 78 | * title = {Estimating Attributes: Analysis and Extensions of RELIEF}, |
---|
| 79 | * year = {1994} |
---|
| 80 | * } |
---|
| 81 | * |
---|
| 82 | * @inproceedings{Robnik-Sikonja1997, |
---|
| 83 | * author = {Marko Robnik-Sikonja and Igor Kononenko}, |
---|
| 84 | * booktitle = {Fourteenth International Conference on Machine Learning}, |
---|
| 85 | * editor = {Douglas H. Fisher}, |
---|
| 86 | * pages = {296-304}, |
---|
| 87 | * publisher = {Morgan Kaufmann}, |
---|
| 88 | * title = {An adaptation of Relief for attribute estimation in regression}, |
---|
| 89 | * year = {1997} |
---|
| 90 | * } |
---|
| 91 | * </pre> |
---|
| 92 | * <p/> |
---|
| 93 | <!-- technical-bibtex-end --> |
---|
| 94 | * |
---|
| 95 | <!-- options-start --> |
---|
| 96 | * Valid options are: <p/> |
---|
| 97 | * |
---|
| 98 | * <pre> -M <num instances> |
---|
| 99 | * Specify the number of instances to |
---|
| 100 | * sample when estimating attributes. |
---|
| 101 | * If not specified, then all instances |
---|
| 102 | * will be used.</pre> |
---|
| 103 | * |
---|
| 104 | * <pre> -D <seed> |
---|
| 105 | * Seed for randomly sampling instances. |
---|
| 106 | * (Default = 1)</pre> |
---|
| 107 | * |
---|
| 108 | * <pre> -K <number of neighbours> |
---|
| 109 | * Number of nearest neighbours (k) used |
---|
| 110 | * to estimate attribute relevances |
---|
| 111 | * (Default = 10).</pre> |
---|
| 112 | * |
---|
| 113 | * <pre> -W |
---|
| 114 | * Weight nearest neighbours by distance</pre> |
---|
| 115 | * |
---|
| 116 | * <pre> -A <num> |
---|
| 117 | * Specify sigma value (used in an exp |
---|
| 118 | * function to control how quickly |
---|
| 119 | * weights for more distant instances |
---|
| 120 | * decrease. Use in conjunction with -W. |
---|
| 121 | * Sensible value=1/5 to 1/10 of the |
---|
| 122 | * number of nearest neighbours. |
---|
| 123 | * (Default = 2)</pre> |
---|
| 124 | * |
---|
| 125 | <!-- options-end --> |
---|
| 126 | * |
---|
| 127 | * @author Mark Hall (mhall@cs.waikato.ac.nz) |
---|
| 128 | * @version $Revision: 5987 $ |
---|
| 129 | */ |
---|
| 130 | public class ReliefFAttributeEval |
---|
| 131 | extends ASEvaluation |
---|
| 132 | implements AttributeEvaluator, |
---|
| 133 | OptionHandler, |
---|
| 134 | TechnicalInformationHandler { |
---|
| 135 | |
---|
| 136 | /** for serialization */ |
---|
| 137 | static final long serialVersionUID = -8422186665795839379L; |
---|
| 138 | |
---|
| 139 | /** The training instances */ |
---|
| 140 | private Instances m_trainInstances; |
---|
| 141 | |
---|
| 142 | /** The class index */ |
---|
| 143 | private int m_classIndex; |
---|
| 144 | |
---|
| 145 | /** The number of attributes */ |
---|
| 146 | private int m_numAttribs; |
---|
| 147 | |
---|
| 148 | /** The number of instances */ |
---|
| 149 | private int m_numInstances; |
---|
| 150 | |
---|
| 151 | /** Numeric class */ |
---|
| 152 | private boolean m_numericClass; |
---|
| 153 | |
---|
| 154 | /** The number of classes if class is nominal */ |
---|
| 155 | private int m_numClasses; |
---|
| 156 | |
---|
| 157 | /** |
---|
| 158 | * Used to hold the probability of a different class val given nearest |
---|
| 159 | * instances (numeric class) |
---|
| 160 | */ |
---|
| 161 | private double m_ndc; |
---|
| 162 | |
---|
| 163 | /** |
---|
| 164 | * Used to hold the prob of different value of an attribute given |
---|
| 165 | * nearest instances (numeric class case) |
---|
| 166 | */ |
---|
| 167 | private double[] m_nda; |
---|
| 168 | |
---|
| 169 | /** |
---|
| 170 | * Used to hold the prob of a different class val and different att |
---|
| 171 | * val given nearest instances (numeric class case) |
---|
| 172 | */ |
---|
| 173 | private double[] m_ndcda; |
---|
| 174 | |
---|
| 175 | /** Holds the weights that relief assigns to attributes */ |
---|
| 176 | private double[] m_weights; |
---|
| 177 | |
---|
| 178 | /** Prior class probabilities (discrete class case) */ |
---|
| 179 | private double[] m_classProbs; |
---|
| 180 | |
---|
| 181 | /** |
---|
| 182 | * The number of instances to sample when estimating attributes |
---|
| 183 | * default == -1, use all instances |
---|
| 184 | */ |
---|
| 185 | private int m_sampleM; |
---|
| 186 | |
---|
| 187 | /** The number of nearest hits/misses */ |
---|
| 188 | private int m_Knn; |
---|
| 189 | |
---|
| 190 | /** k nearest scores + instance indexes for n classes */ |
---|
| 191 | private double[][][] m_karray; |
---|
| 192 | |
---|
| 193 | /** Upper bound for numeric attributes */ |
---|
| 194 | private double[] m_maxArray; |
---|
| 195 | |
---|
| 196 | /** Lower bound for numeric attributes */ |
---|
| 197 | private double[] m_minArray; |
---|
| 198 | |
---|
| 199 | /** Keep track of the farthest instance for each class */ |
---|
| 200 | private double[] m_worst; |
---|
| 201 | |
---|
| 202 | /** Index in the m_karray of the farthest instance for each class */ |
---|
| 203 | private int[] m_index; |
---|
| 204 | |
---|
| 205 | /** Number of nearest neighbours stored of each class */ |
---|
| 206 | private int[] m_stored; |
---|
| 207 | |
---|
| 208 | /** Random number seed used for sampling instances */ |
---|
| 209 | private int m_seed; |
---|
| 210 | |
---|
| 211 | /** |
---|
| 212 | * used to (optionally) weight nearest neighbours by their distance |
---|
| 213 | * from the instance in question. Each entry holds |
---|
| 214 | * exp(-((rank(r_i, i_j)/sigma)^2)) where rank(r_i,i_j) is the rank of |
---|
| 215 | * instance i_j in a sequence of instances ordered by the distance |
---|
| 216 | * from r_i. sigma is a user defined parameter, default=2 |
---|
| 217 | **/ |
---|
| 218 | private double[] m_weightsByRank; |
---|
| 219 | private int m_sigma; |
---|
| 220 | |
---|
| 221 | /** Weight by distance rather than equal weights */ |
---|
| 222 | private boolean m_weightByDistance; |
---|
| 223 | |
---|
/**
 * Creates a ReliefF evaluator with every option at its default value.
 * The defaults themselves are assigned by {@code resetOptions()}.
 */
public ReliefFAttributeEval () {
  this.resetOptions();
}
---|
| 230 | |
---|
| 231 | /** |
---|
| 232 | * Returns a string describing this attribute evaluator |
---|
| 233 | * @return a description of the evaluator suitable for |
---|
| 234 | * displaying in the explorer/experimenter gui |
---|
| 235 | */ |
---|
| 236 | public String globalInfo() { |
---|
| 237 | return "ReliefFAttributeEval :\n\nEvaluates the worth of an attribute by " |
---|
| 238 | +"repeatedly sampling an instance and considering the value of the " |
---|
| 239 | +"given attribute for the nearest instance of the same and different " |
---|
| 240 | +"class. Can operate on both discrete and continuous class data.\n\n" |
---|
| 241 | + "For more information see:\n\n" |
---|
| 242 | + getTechnicalInformation().toString(); |
---|
| 243 | } |
---|
| 244 | |
---|
| 245 | /** |
---|
| 246 | * Returns an instance of a TechnicalInformation object, containing |
---|
| 247 | * detailed information about the technical background of this class, |
---|
| 248 | * e.g., paper reference or book this class is based on. |
---|
| 249 | * |
---|
| 250 | * @return the technical information about this class |
---|
| 251 | */ |
---|
| 252 | public TechnicalInformation getTechnicalInformation() { |
---|
| 253 | TechnicalInformation result; |
---|
| 254 | TechnicalInformation additional; |
---|
| 255 | |
---|
| 256 | result = new TechnicalInformation(Type.INPROCEEDINGS); |
---|
| 257 | result.setValue(Field.AUTHOR, "Kenji Kira and Larry A. Rendell"); |
---|
| 258 | result.setValue(Field.TITLE, "A Practical Approach to Feature Selection"); |
---|
| 259 | result.setValue(Field.BOOKTITLE, "Ninth International Workshop on Machine Learning"); |
---|
| 260 | result.setValue(Field.EDITOR, "Derek H. Sleeman and Peter Edwards"); |
---|
| 261 | result.setValue(Field.YEAR, "1992"); |
---|
| 262 | result.setValue(Field.PAGES, "249-256"); |
---|
| 263 | result.setValue(Field.PUBLISHER, "Morgan Kaufmann"); |
---|
| 264 | |
---|
| 265 | additional = result.add(Type.INPROCEEDINGS); |
---|
| 266 | additional.setValue(Field.AUTHOR, "Igor Kononenko"); |
---|
| 267 | additional.setValue(Field.TITLE, "Estimating Attributes: Analysis and Extensions of RELIEF"); |
---|
| 268 | additional.setValue(Field.BOOKTITLE, "European Conference on Machine Learning"); |
---|
| 269 | additional.setValue(Field.EDITOR, "Francesco Bergadano and Luc De Raedt"); |
---|
| 270 | additional.setValue(Field.YEAR, "1994"); |
---|
| 271 | additional.setValue(Field.PAGES, "171-182"); |
---|
| 272 | additional.setValue(Field.PUBLISHER, "Springer"); |
---|
| 273 | |
---|
| 274 | additional = result.add(Type.INPROCEEDINGS); |
---|
| 275 | additional.setValue(Field.AUTHOR, "Marko Robnik-Sikonja and Igor Kononenko"); |
---|
| 276 | additional.setValue(Field.TITLE, "An adaptation of Relief for attribute estimation in regression"); |
---|
| 277 | additional.setValue(Field.BOOKTITLE, "Fourteenth International Conference on Machine Learning"); |
---|
| 278 | additional.setValue(Field.EDITOR, "Douglas H. Fisher"); |
---|
| 279 | additional.setValue(Field.YEAR, "1997"); |
---|
| 280 | additional.setValue(Field.PAGES, "296-304"); |
---|
| 281 | additional.setValue(Field.PUBLISHER, "Morgan Kaufmann"); |
---|
| 282 | |
---|
| 283 | return result; |
---|
| 284 | } |
---|
| 285 | |
---|
| 286 | /** |
---|
| 287 | * Returns an enumeration describing the available options. |
---|
| 288 | * @return an enumeration of all the available options. |
---|
| 289 | **/ |
---|
| 290 | public Enumeration listOptions () { |
---|
| 291 | Vector newVector = new Vector(4); |
---|
| 292 | newVector |
---|
| 293 | .addElement(new Option("\tSpecify the number of instances to\n" |
---|
| 294 | + "\tsample when estimating attributes.\n" |
---|
| 295 | + "\tIf not specified, then all instances\n" |
---|
| 296 | + "\twill be used.", "M", 1 |
---|
| 297 | , "-M <num instances>")); |
---|
| 298 | newVector. |
---|
| 299 | addElement(new Option("\tSeed for randomly sampling instances.\n" |
---|
| 300 | + "\t(Default = 1)", "D", 1 |
---|
| 301 | , "-D <seed>")); |
---|
| 302 | newVector. |
---|
| 303 | addElement(new Option("\tNumber of nearest neighbours (k) used\n" |
---|
| 304 | + "\tto estimate attribute relevances\n" |
---|
| 305 | + "\t(Default = 10).", "K", 1 |
---|
| 306 | , "-K <number of neighbours>")); |
---|
| 307 | newVector. |
---|
| 308 | addElement(new Option("\tWeight nearest neighbours by distance", "W" |
---|
| 309 | , 0, "-W")); |
---|
| 310 | newVector. |
---|
| 311 | addElement(new Option("\tSpecify sigma value (used in an exp\n" |
---|
| 312 | + "\tfunction to control how quickly\n" |
---|
| 313 | + "\tweights for more distant instances\n" |
---|
| 314 | + "\tdecrease. Use in conjunction with -W.\n" |
---|
| 315 | + "\tSensible value=1/5 to 1/10 of the\n" |
---|
| 316 | + "\tnumber of nearest neighbours.\n" |
---|
| 317 | + "\t(Default = 2)", "A", 1, "-A <num>")); |
---|
| 318 | return newVector.elements(); |
---|
| 319 | } |
---|
| 320 | |
---|
| 321 | |
---|
| 322 | /** |
---|
| 323 | * Parses a given list of options. <p/> |
---|
| 324 | * |
---|
| 325 | <!-- options-start --> |
---|
| 326 | * Valid options are: <p/> |
---|
| 327 | * |
---|
| 328 | * <pre> -M <num instances> |
---|
| 329 | * Specify the number of instances to |
---|
| 330 | * sample when estimating attributes. |
---|
| 331 | * If not specified, then all instances |
---|
| 332 | * will be used.</pre> |
---|
| 333 | * |
---|
| 334 | * <pre> -D <seed> |
---|
| 335 | * Seed for randomly sampling instances. |
---|
| 336 | * (Default = 1)</pre> |
---|
| 337 | * |
---|
| 338 | * <pre> -K <number of neighbours> |
---|
| 339 | * Number of nearest neighbours (k) used |
---|
| 340 | * to estimate attribute relevances |
---|
| 341 | * (Default = 10).</pre> |
---|
| 342 | * |
---|
| 343 | * <pre> -W |
---|
| 344 | * Weight nearest neighbours by distance</pre> |
---|
| 345 | * |
---|
| 346 | * <pre> -A <num> |
---|
| 347 | * Specify sigma value (used in an exp |
---|
| 348 | * function to control how quickly |
---|
| 349 | * weights for more distant instances |
---|
| 350 | * decrease. Use in conjunction with -W. |
---|
| 351 | * Sensible value=1/5 to 1/10 of the |
---|
| 352 | * number of nearest neighbours. |
---|
| 353 | * (Default = 2)</pre> |
---|
| 354 | * |
---|
| 355 | <!-- options-end --> |
---|
| 356 | * |
---|
| 357 | * @param options the list of options as an array of strings |
---|
| 358 | * @throws Exception if an option is not supported |
---|
| 359 | */ |
---|
| 360 | public void setOptions (String[] options) |
---|
| 361 | throws Exception { |
---|
| 362 | String optionString; |
---|
| 363 | resetOptions(); |
---|
| 364 | setWeightByDistance(Utils.getFlag('W', options)); |
---|
| 365 | optionString = Utils.getOption('M', options); |
---|
| 366 | |
---|
| 367 | if (optionString.length() != 0) { |
---|
| 368 | setSampleSize(Integer.parseInt(optionString)); |
---|
| 369 | } |
---|
| 370 | |
---|
| 371 | optionString = Utils.getOption('D', options); |
---|
| 372 | |
---|
| 373 | if (optionString.length() != 0) { |
---|
| 374 | setSeed(Integer.parseInt(optionString)); |
---|
| 375 | } |
---|
| 376 | |
---|
| 377 | optionString = Utils.getOption('K', options); |
---|
| 378 | |
---|
| 379 | if (optionString.length() != 0) { |
---|
| 380 | setNumNeighbours(Integer.parseInt(optionString)); |
---|
| 381 | } |
---|
| 382 | |
---|
| 383 | optionString = Utils.getOption('A', options); |
---|
| 384 | |
---|
| 385 | if (optionString.length() != 0) { |
---|
| 386 | setWeightByDistance(true); // turn on weighting by distance |
---|
| 387 | setSigma(Integer.parseInt(optionString)); |
---|
| 388 | } |
---|
| 389 | } |
---|
| 390 | |
---|
| 391 | /** |
---|
| 392 | * Returns the tip text for this property |
---|
| 393 | * @return tip text for this property suitable for |
---|
| 394 | * displaying in the explorer/experimenter gui |
---|
| 395 | */ |
---|
| 396 | public String sigmaTipText() { |
---|
| 397 | return "Set influence of nearest neighbours. Used in an exp function to " |
---|
| 398 | +"control how quickly weights decrease for more distant instances. " |
---|
| 399 | +"Use in conjunction with weightByDistance. Sensible values = 1/5 to " |
---|
| 400 | +"1/10 the number of nearest neighbours."; |
---|
| 401 | } |
---|
| 402 | |
---|
| 403 | /** |
---|
| 404 | * Sets the sigma value. |
---|
| 405 | * |
---|
| 406 | * @param s the value of sigma (> 0) |
---|
| 407 | * @throws Exception if s is not positive |
---|
| 408 | */ |
---|
| 409 | public void setSigma (int s) |
---|
| 410 | throws Exception { |
---|
| 411 | if (s <= 0) { |
---|
| 412 | throw new Exception("value of sigma must be > 0!"); |
---|
| 413 | } |
---|
| 414 | |
---|
| 415 | m_sigma = s; |
---|
| 416 | } |
---|
| 417 | |
---|
| 418 | |
---|
| 419 | /** |
---|
| 420 | * Get the value of sigma. |
---|
| 421 | * |
---|
| 422 | * @return the sigma value. |
---|
| 423 | */ |
---|
| 424 | public int getSigma () { |
---|
| 425 | return m_sigma; |
---|
| 426 | } |
---|
| 427 | |
---|
| 428 | /** |
---|
| 429 | * Returns the tip text for this property |
---|
| 430 | * @return tip text for this property suitable for |
---|
| 431 | * displaying in the explorer/experimenter gui |
---|
| 432 | */ |
---|
| 433 | public String numNeighboursTipText() { |
---|
| 434 | return "Number of nearest neighbours for attribute estimation."; |
---|
| 435 | } |
---|
| 436 | |
---|
| 437 | /** |
---|
| 438 | * Set the number of nearest neighbours |
---|
| 439 | * |
---|
| 440 | * @param n the number of nearest neighbours. |
---|
| 441 | */ |
---|
| 442 | public void setNumNeighbours (int n) { |
---|
| 443 | m_Knn = n; |
---|
| 444 | } |
---|
| 445 | |
---|
| 446 | |
---|
| 447 | /** |
---|
| 448 | * Get the number of nearest neighbours |
---|
| 449 | * |
---|
| 450 | * @return the number of nearest neighbours |
---|
| 451 | */ |
---|
| 452 | public int getNumNeighbours () { |
---|
| 453 | return m_Knn; |
---|
| 454 | } |
---|
| 455 | |
---|
| 456 | /** |
---|
| 457 | * Returns the tip text for this property |
---|
| 458 | * @return tip text for this property suitable for |
---|
| 459 | * displaying in the explorer/experimenter gui |
---|
| 460 | */ |
---|
| 461 | public String seedTipText() { |
---|
| 462 | return "Random seed for sampling instances."; |
---|
| 463 | } |
---|
| 464 | |
---|
| 465 | /** |
---|
| 466 | * Set the random number seed for randomly sampling instances. |
---|
| 467 | * |
---|
| 468 | * @param s the random number seed. |
---|
| 469 | */ |
---|
| 470 | public void setSeed (int s) { |
---|
| 471 | m_seed = s; |
---|
| 472 | } |
---|
| 473 | |
---|
| 474 | |
---|
| 475 | /** |
---|
| 476 | * Get the seed used for randomly sampling instances. |
---|
| 477 | * |
---|
| 478 | * @return the random number seed. |
---|
| 479 | */ |
---|
| 480 | public int getSeed () { |
---|
| 481 | return m_seed; |
---|
| 482 | } |
---|
| 483 | |
---|
| 484 | /** |
---|
| 485 | * Returns the tip text for this property |
---|
| 486 | * @return tip text for this property suitable for |
---|
| 487 | * displaying in the explorer/experimenter gui |
---|
| 488 | */ |
---|
| 489 | public String sampleSizeTipText() { |
---|
| 490 | return "Number of instances to sample. Default (-1) indicates that all " |
---|
| 491 | +"instances will be used for attribute estimation."; |
---|
| 492 | } |
---|
| 493 | |
---|
| 494 | /** |
---|
| 495 | * Set the number of instances to sample for attribute estimation |
---|
| 496 | * |
---|
| 497 | * @param s the number of instances to sample. |
---|
| 498 | */ |
---|
| 499 | public void setSampleSize (int s) { |
---|
| 500 | m_sampleM = s; |
---|
| 501 | } |
---|
| 502 | |
---|
| 503 | |
---|
| 504 | /** |
---|
| 505 | * Get the number of instances used for estimating attributes |
---|
| 506 | * |
---|
| 507 | * @return the number of instances. |
---|
| 508 | */ |
---|
| 509 | public int getSampleSize () { |
---|
| 510 | return m_sampleM; |
---|
| 511 | } |
---|
| 512 | |
---|
| 513 | /** |
---|
| 514 | * Returns the tip text for this property |
---|
| 515 | * @return tip text for this property suitable for |
---|
| 516 | * displaying in the explorer/experimenter gui |
---|
| 517 | */ |
---|
| 518 | public String weightByDistanceTipText() { |
---|
| 519 | return "Weight nearest neighbours by their distance."; |
---|
| 520 | } |
---|
| 521 | |
---|
| 522 | /** |
---|
| 523 | * Set the nearest neighbour weighting method |
---|
| 524 | * |
---|
| 525 | * @param b true nearest neighbours are to be weighted by distance. |
---|
| 526 | */ |
---|
| 527 | public void setWeightByDistance (boolean b) { |
---|
| 528 | m_weightByDistance = b; |
---|
| 529 | } |
---|
| 530 | |
---|
| 531 | |
---|
| 532 | /** |
---|
| 533 | * Get whether nearest neighbours are being weighted by distance |
---|
| 534 | * |
---|
| 535 | * @return m_weightByDiffernce |
---|
| 536 | */ |
---|
| 537 | public boolean getWeightByDistance () { |
---|
| 538 | return m_weightByDistance; |
---|
| 539 | } |
---|
| 540 | |
---|
| 541 | |
---|
| 542 | /** |
---|
| 543 | * Gets the current settings of ReliefFAttributeEval. |
---|
| 544 | * |
---|
| 545 | * @return an array of strings suitable for passing to setOptions() |
---|
| 546 | */ |
---|
| 547 | public String[] getOptions () { |
---|
| 548 | String[] options = new String[9]; |
---|
| 549 | int current = 0; |
---|
| 550 | |
---|
| 551 | if (getWeightByDistance()) { |
---|
| 552 | options[current++] = "-W"; |
---|
| 553 | } |
---|
| 554 | |
---|
| 555 | options[current++] = "-M"; |
---|
| 556 | options[current++] = "" + getSampleSize(); |
---|
| 557 | options[current++] = "-D"; |
---|
| 558 | options[current++] = "" + getSeed(); |
---|
| 559 | options[current++] = "-K"; |
---|
| 560 | options[current++] = "" + getNumNeighbours(); |
---|
| 561 | |
---|
| 562 | if (getWeightByDistance()) { |
---|
| 563 | options[current++] = "-A"; |
---|
| 564 | options[current++] = "" + getSigma(); |
---|
| 565 | } |
---|
| 566 | |
---|
| 567 | while (current < options.length) { |
---|
| 568 | options[current++] = ""; |
---|
| 569 | } |
---|
| 570 | |
---|
| 571 | return options; |
---|
| 572 | } |
---|
| 573 | |
---|
| 574 | |
---|
| 575 | /** |
---|
| 576 | * Return a description of the ReliefF attribute evaluator. |
---|
| 577 | * |
---|
| 578 | * @return a description of the evaluator as a String. |
---|
| 579 | */ |
---|
| 580 | public String toString () { |
---|
| 581 | StringBuffer text = new StringBuffer(); |
---|
| 582 | |
---|
| 583 | if (m_trainInstances == null) { |
---|
| 584 | text.append("ReliefF feature evaluator has not been built yet\n"); |
---|
| 585 | } |
---|
| 586 | else { |
---|
| 587 | text.append("\tReliefF Ranking Filter"); |
---|
| 588 | text.append("\n\tInstances sampled: "); |
---|
| 589 | |
---|
| 590 | if (m_sampleM == -1) { |
---|
| 591 | text.append("all\n"); |
---|
| 592 | } |
---|
| 593 | else { |
---|
| 594 | text.append(m_sampleM + "\n"); |
---|
| 595 | } |
---|
| 596 | |
---|
| 597 | text.append("\tNumber of nearest neighbours (k): " + m_Knn + "\n"); |
---|
| 598 | |
---|
| 599 | if (m_weightByDistance) { |
---|
| 600 | text.append("\tExponentially decreasing (with distance) " |
---|
| 601 | + "influence for\n" |
---|
| 602 | + "\tnearest neighbours. Sigma: " |
---|
| 603 | + m_sigma + "\n"); |
---|
| 604 | } |
---|
| 605 | else { |
---|
| 606 | text.append("\tEqual influence nearest neighbours\n"); |
---|
| 607 | } |
---|
| 608 | } |
---|
| 609 | |
---|
| 610 | return text.toString(); |
---|
| 611 | } |
---|
| 612 | |
---|
| 613 | /** |
---|
| 614 | * Returns the capabilities of this evaluator. |
---|
| 615 | * |
---|
| 616 | * @return the capabilities of this evaluator |
---|
| 617 | * @see Capabilities |
---|
| 618 | */ |
---|
| 619 | public Capabilities getCapabilities() { |
---|
| 620 | Capabilities result = super.getCapabilities(); |
---|
| 621 | result.disableAll(); |
---|
| 622 | |
---|
| 623 | // attributes |
---|
| 624 | result.enable(Capability.NOMINAL_ATTRIBUTES); |
---|
| 625 | result.enable(Capability.NUMERIC_ATTRIBUTES); |
---|
| 626 | result.enable(Capability.DATE_ATTRIBUTES); |
---|
| 627 | result.enable(Capability.MISSING_VALUES); |
---|
| 628 | |
---|
| 629 | // class |
---|
| 630 | result.enable(Capability.NOMINAL_CLASS); |
---|
| 631 | result.enable(Capability.NUMERIC_CLASS); |
---|
| 632 | result.enable(Capability.DATE_CLASS); |
---|
| 633 | result.enable(Capability.MISSING_CLASS_VALUES); |
---|
| 634 | |
---|
| 635 | return result; |
---|
| 636 | } |
---|
| 637 | |
---|
| 638 | /** |
---|
| 639 | * Initializes a ReliefF attribute evaluator. |
---|
| 640 | * |
---|
| 641 | * @param data set of instances serving as training data |
---|
| 642 | * @throws Exception if the evaluator has not been |
---|
| 643 | * generated successfully |
---|
| 644 | */ |
---|
| 645 | public void buildEvaluator (Instances data) |
---|
| 646 | throws Exception { |
---|
| 647 | |
---|
| 648 | int z, totalInstances; |
---|
| 649 | Random r = new Random(m_seed); |
---|
| 650 | |
---|
| 651 | // can evaluator handle data? |
---|
| 652 | getCapabilities().testWithFail(data); |
---|
| 653 | |
---|
| 654 | m_trainInstances = data; |
---|
| 655 | m_classIndex = m_trainInstances.classIndex(); |
---|
| 656 | m_numAttribs = m_trainInstances.numAttributes(); |
---|
| 657 | m_numInstances = m_trainInstances.numInstances(); |
---|
| 658 | |
---|
| 659 | if (m_trainInstances.attribute(m_classIndex).isNumeric()) { |
---|
| 660 | m_numericClass = true; |
---|
| 661 | } |
---|
| 662 | else { |
---|
| 663 | m_numericClass = false; |
---|
| 664 | } |
---|
| 665 | |
---|
| 666 | if (!m_numericClass) { |
---|
| 667 | m_numClasses = m_trainInstances.attribute(m_classIndex).numValues(); |
---|
| 668 | } |
---|
| 669 | else { |
---|
| 670 | m_ndc = 0; |
---|
| 671 | m_numClasses = 1; |
---|
| 672 | m_nda = new double[m_numAttribs]; |
---|
| 673 | m_ndcda = new double[m_numAttribs]; |
---|
| 674 | } |
---|
| 675 | |
---|
| 676 | if (m_weightByDistance) // set up the rank based weights |
---|
| 677 | { |
---|
| 678 | m_weightsByRank = new double[m_Knn]; |
---|
| 679 | |
---|
| 680 | for (int i = 0; i < m_Knn; i++) { |
---|
| 681 | m_weightsByRank[i] = |
---|
| 682 | Math.exp(-((i/(double)m_sigma)*(i/(double)m_sigma))); |
---|
| 683 | } |
---|
| 684 | } |
---|
| 685 | |
---|
| 686 | // the final attribute weights |
---|
| 687 | m_weights = new double[m_numAttribs]; |
---|
| 688 | // num classes (1 for numeric class) knn neighbours, |
---|
| 689 | // and 0 = distance, 1 = instance index |
---|
| 690 | m_karray = new double[m_numClasses][m_Knn][2]; |
---|
| 691 | |
---|
| 692 | if (!m_numericClass) { |
---|
| 693 | m_classProbs = new double[m_numClasses]; |
---|
| 694 | |
---|
| 695 | for (int i = 0; i < m_numInstances; i++) { |
---|
| 696 | m_classProbs[(int)m_trainInstances.instance(i).value(m_classIndex)]++; |
---|
| 697 | } |
---|
| 698 | |
---|
| 699 | for (int i = 0; i < m_numClasses; i++) { |
---|
| 700 | m_classProbs[i] /= m_numInstances; |
---|
| 701 | } |
---|
| 702 | } |
---|
| 703 | |
---|
| 704 | m_worst = new double[m_numClasses]; |
---|
| 705 | m_index = new int[m_numClasses]; |
---|
| 706 | m_stored = new int[m_numClasses]; |
---|
| 707 | m_minArray = new double[m_numAttribs]; |
---|
| 708 | m_maxArray = new double[m_numAttribs]; |
---|
| 709 | |
---|
| 710 | for (int i = 0; i < m_numAttribs; i++) { |
---|
| 711 | m_minArray[i] = m_maxArray[i] = Double.NaN; |
---|
| 712 | } |
---|
| 713 | |
---|
| 714 | for (int i = 0; i < m_numInstances; i++) { |
---|
| 715 | updateMinMax(m_trainInstances.instance(i)); |
---|
| 716 | } |
---|
| 717 | |
---|
| 718 | if ((m_sampleM > m_numInstances) || (m_sampleM < 0)) { |
---|
| 719 | totalInstances = m_numInstances; |
---|
| 720 | } |
---|
| 721 | else { |
---|
| 722 | totalInstances = m_sampleM; |
---|
| 723 | } |
---|
| 724 | |
---|
| 725 | // process each instance, updating attribute weights |
---|
| 726 | for (int i = 0; i < totalInstances; i++) { |
---|
| 727 | if (totalInstances == m_numInstances) { |
---|
| 728 | z = i; |
---|
| 729 | } |
---|
| 730 | else { |
---|
| 731 | z = r.nextInt()%m_numInstances; |
---|
| 732 | } |
---|
| 733 | |
---|
| 734 | if (z < 0) { |
---|
| 735 | z *= -1; |
---|
| 736 | } |
---|
| 737 | |
---|
| 738 | if (!(m_trainInstances.instance(z).isMissing(m_classIndex))) { |
---|
| 739 | // first clear the knn and worst index stuff for the classes |
---|
| 740 | for (int j = 0; j < m_numClasses; j++) { |
---|
| 741 | m_index[j] = m_stored[j] = 0; |
---|
| 742 | |
---|
| 743 | for (int k = 0; k < m_Knn; k++) { |
---|
| 744 | m_karray[j][k][0] = m_karray[j][k][1] = 0; |
---|
| 745 | } |
---|
| 746 | } |
---|
| 747 | |
---|
| 748 | findKHitMiss(z); |
---|
| 749 | |
---|
| 750 | if (m_numericClass) { |
---|
| 751 | updateWeightsNumericClass(z); |
---|
| 752 | } |
---|
| 753 | else { |
---|
| 754 | updateWeightsDiscreteClass(z); |
---|
| 755 | } |
---|
| 756 | } |
---|
| 757 | } |
---|
| 758 | |
---|
| 759 | // now scale weights by 1/m_numInstances (nominal class) or |
---|
| 760 | // calculate weights numeric class |
---|
| 761 | // System.out.println("num inst:"+m_numInstances+" r_ndc:"+r_ndc); |
---|
| 762 | for (int i = 0; i < m_numAttribs; i++) {if (i != m_classIndex) { |
---|
| 763 | if (m_numericClass) { |
---|
| 764 | m_weights[i] = m_ndcda[i]/m_ndc - |
---|
| 765 | ((m_nda[i] - m_ndcda[i])/((double)totalInstances - m_ndc)); |
---|
| 766 | } |
---|
| 767 | else { |
---|
| 768 | m_weights[i] *= (1.0/(double)totalInstances); |
---|
| 769 | } |
---|
| 770 | |
---|
| 771 | // System.out.println(r_weights[i]); |
---|
| 772 | } |
---|
| 773 | } |
---|
| 774 | } |
---|
| 775 | |
---|
| 776 | |
---|
| 777 | /** |
---|
| 778 | * Evaluates an individual attribute using ReliefF's instance based approach. |
---|
| 779 | * The actual work is done by buildEvaluator which evaluates all features. |
---|
| 780 | * |
---|
| 781 | * @param attribute the index of the attribute to be evaluated |
---|
| 782 | * @throws Exception if the attribute could not be evaluated |
---|
| 783 | */ |
---|
| 784 | public double evaluateAttribute (int attribute) |
---|
| 785 | throws Exception { |
---|
| 786 | return m_weights[attribute]; |
---|
| 787 | } |
---|
| 788 | |
---|
| 789 | |
---|
| 790 | /** |
---|
| 791 | * Reset options to their default values |
---|
| 792 | */ |
---|
| 793 | protected void resetOptions () { |
---|
| 794 | m_trainInstances = null; |
---|
| 795 | m_sampleM = -1; |
---|
| 796 | m_Knn = 10; |
---|
| 797 | m_sigma = 2; |
---|
| 798 | m_weightByDistance = false; |
---|
| 799 | m_seed = 1; |
---|
| 800 | } |
---|
| 801 | |
---|
| 802 | |
---|
| 803 | /** |
---|
| 804 | * Normalizes a given value of a numeric attribute. |
---|
| 805 | * |
---|
| 806 | * @param x the value to be normalized |
---|
| 807 | * @param i the attribute's index |
---|
| 808 | * @return the normalized value |
---|
| 809 | */ |
---|
| 810 | private double norm (double x, int i) { |
---|
| 811 | if (Double.isNaN(m_minArray[i]) || |
---|
| 812 | Utils.eq(m_maxArray[i], m_minArray[i])) { |
---|
| 813 | return 0; |
---|
| 814 | } |
---|
| 815 | else { |
---|
| 816 | return (x - m_minArray[i])/(m_maxArray[i] - m_minArray[i]); |
---|
| 817 | } |
---|
| 818 | } |
---|
| 819 | |
---|
| 820 | |
---|
| 821 | /** |
---|
| 822 | * Updates the minimum and maximum values for all the attributes |
---|
| 823 | * based on a new instance. |
---|
| 824 | * |
---|
| 825 | * @param instance the new instance |
---|
| 826 | */ |
---|
| 827 | private void updateMinMax (Instance instance) { |
---|
| 828 | // for (int j = 0; j < m_numAttribs; j++) { |
---|
| 829 | try { |
---|
| 830 | for (int j = 0; j < instance.numValues(); j++) { |
---|
| 831 | if ((instance.attributeSparse(j).isNumeric()) && |
---|
| 832 | (!instance.isMissingSparse(j))) { |
---|
| 833 | if (Double.isNaN(m_minArray[instance.index(j)])) { |
---|
| 834 | m_minArray[instance.index(j)] = instance.valueSparse(j); |
---|
| 835 | m_maxArray[instance.index(j)] = instance.valueSparse(j); |
---|
| 836 | } |
---|
| 837 | else { |
---|
| 838 | if (instance.valueSparse(j) < m_minArray[instance.index(j)]) { |
---|
| 839 | m_minArray[instance.index(j)] = instance.valueSparse(j); |
---|
| 840 | } |
---|
| 841 | else { |
---|
| 842 | if (instance.valueSparse(j) > m_maxArray[instance.index(j)]) { |
---|
| 843 | m_maxArray[instance.index(j)] = instance.valueSparse(j); |
---|
| 844 | } |
---|
| 845 | } |
---|
| 846 | } |
---|
| 847 | } |
---|
| 848 | } |
---|
| 849 | } catch (Exception ex) { |
---|
| 850 | System.err.println(ex); |
---|
| 851 | ex.printStackTrace(); |
---|
| 852 | } |
---|
| 853 | } |
---|
| 854 | |
---|
| 855 | /** |
---|
| 856 | * Computes the difference between two given attribute |
---|
| 857 | * values. |
---|
| 858 | */ |
---|
| 859 | private double difference(int index, double val1, double val2) { |
---|
| 860 | |
---|
| 861 | switch (m_trainInstances.attribute(index).type()) { |
---|
| 862 | case Attribute.NOMINAL: |
---|
| 863 | |
---|
| 864 | // If attribute is nominal |
---|
| 865 | if (Utils.isMissingValue(val1) || |
---|
| 866 | Utils.isMissingValue(val2)) { |
---|
| 867 | return (1.0 - (1.0/((double)m_trainInstances. |
---|
| 868 | attribute(index).numValues()))); |
---|
| 869 | } else if ((int)val1 != (int)val2) { |
---|
| 870 | return 1; |
---|
| 871 | } else { |
---|
| 872 | return 0; |
---|
| 873 | } |
---|
| 874 | case Attribute.NUMERIC: |
---|
| 875 | |
---|
| 876 | // If attribute is numeric |
---|
| 877 | if (Utils.isMissingValue(val1) || |
---|
| 878 | Utils.isMissingValue(val2)) { |
---|
| 879 | if (Utils.isMissingValue(val1) && |
---|
| 880 | Utils.isMissingValue(val2)) { |
---|
| 881 | return 1; |
---|
| 882 | } else { |
---|
| 883 | double diff; |
---|
| 884 | if (Utils.isMissingValue(val2)) { |
---|
| 885 | diff = norm(val1, index); |
---|
| 886 | } else { |
---|
| 887 | diff = norm(val2, index); |
---|
| 888 | } |
---|
| 889 | if (diff < 0.5) { |
---|
| 890 | diff = 1.0 - diff; |
---|
| 891 | } |
---|
| 892 | return diff; |
---|
| 893 | } |
---|
| 894 | } else { |
---|
| 895 | return Math.abs(norm(val1, index) - norm(val2, index)); |
---|
| 896 | } |
---|
| 897 | default: |
---|
| 898 | return 0; |
---|
| 899 | } |
---|
| 900 | } |
---|
| 901 | |
---|
| 902 | /** |
---|
| 903 | * Calculates the distance between two instances |
---|
| 904 | * |
---|
| 905 | * @param first the first instance |
---|
| 906 | * @param second the second instance |
---|
| 907 | * @return the distance between the two given instances, between 0 and 1 |
---|
| 908 | */ |
---|
| 909 | private double distance(Instance first, Instance second) { |
---|
| 910 | |
---|
| 911 | double distance = 0; |
---|
| 912 | int firstI, secondI; |
---|
| 913 | |
---|
| 914 | for (int p1 = 0, p2 = 0; |
---|
| 915 | p1 < first.numValues() || p2 < second.numValues();) { |
---|
| 916 | if (p1 >= first.numValues()) { |
---|
| 917 | firstI = m_trainInstances.numAttributes(); |
---|
| 918 | } else { |
---|
| 919 | firstI = first.index(p1); |
---|
| 920 | } |
---|
| 921 | if (p2 >= second.numValues()) { |
---|
| 922 | secondI = m_trainInstances.numAttributes(); |
---|
| 923 | } else { |
---|
| 924 | secondI = second.index(p2); |
---|
| 925 | } |
---|
| 926 | if (firstI == m_trainInstances.classIndex()) { |
---|
| 927 | p1++; continue; |
---|
| 928 | } |
---|
| 929 | if (secondI == m_trainInstances.classIndex()) { |
---|
| 930 | p2++; continue; |
---|
| 931 | } |
---|
| 932 | double diff; |
---|
| 933 | if (firstI == secondI) { |
---|
| 934 | diff = difference(firstI, |
---|
| 935 | first.valueSparse(p1), |
---|
| 936 | second.valueSparse(p2)); |
---|
| 937 | p1++; p2++; |
---|
| 938 | } else if (firstI > secondI) { |
---|
| 939 | diff = difference(secondI, |
---|
| 940 | 0, second.valueSparse(p2)); |
---|
| 941 | p2++; |
---|
| 942 | } else { |
---|
| 943 | diff = difference(firstI, |
---|
| 944 | first.valueSparse(p1), 0); |
---|
| 945 | p1++; |
---|
| 946 | } |
---|
| 947 | // distance += diff * diff; |
---|
| 948 | distance += diff; |
---|
| 949 | } |
---|
| 950 | |
---|
| 951 | // return Math.sqrt(distance / m_NumAttributesUsed); |
---|
| 952 | return distance; |
---|
| 953 | } |
---|
| 954 | |
---|
| 955 | |
---|
| 956 | /** |
---|
| 957 | * update attribute weights given an instance when the class is numeric |
---|
| 958 | * |
---|
| 959 | * @param instNum the index of the instance to use when updating weights |
---|
| 960 | */ |
---|
| 961 | private void updateWeightsNumericClass (int instNum) { |
---|
| 962 | int i, j; |
---|
| 963 | double temp,temp2; |
---|
| 964 | int[] tempSorted = null; |
---|
| 965 | double[] tempDist = null; |
---|
| 966 | double distNorm = 1.0; |
---|
| 967 | int firstI, secondI; |
---|
| 968 | |
---|
| 969 | Instance inst = m_trainInstances.instance(instNum); |
---|
| 970 | |
---|
| 971 | // sort nearest neighbours and set up normalization variable |
---|
| 972 | if (m_weightByDistance) { |
---|
| 973 | tempDist = new double[m_stored[0]]; |
---|
| 974 | |
---|
| 975 | for (j = 0, distNorm = 0; j < m_stored[0]; j++) { |
---|
| 976 | // copy the distances |
---|
| 977 | tempDist[j] = m_karray[0][j][0]; |
---|
| 978 | // sum normalizer |
---|
| 979 | distNorm += m_weightsByRank[j]; |
---|
| 980 | } |
---|
| 981 | |
---|
| 982 | tempSorted = Utils.sort(tempDist); |
---|
| 983 | } |
---|
| 984 | |
---|
| 985 | for (i = 0; i < m_stored[0]; i++) { |
---|
| 986 | // P diff prediction (class) given nearest instances |
---|
| 987 | if (m_weightByDistance) { |
---|
| 988 | temp = difference(m_classIndex, |
---|
| 989 | inst.value(m_classIndex), |
---|
| 990 | m_trainInstances. |
---|
| 991 | instance((int)m_karray[0][tempSorted[i]][1]). |
---|
| 992 | value(m_classIndex)); |
---|
| 993 | temp *= (m_weightsByRank[i]/distNorm); |
---|
| 994 | } |
---|
| 995 | else { |
---|
| 996 | temp = difference(m_classIndex, |
---|
| 997 | inst.value(m_classIndex), |
---|
| 998 | m_trainInstances. |
---|
| 999 | instance((int)m_karray[0][i][1]). |
---|
| 1000 | value(m_classIndex)); |
---|
| 1001 | temp *= (1.0/(double)m_stored[0]); // equal influence |
---|
| 1002 | } |
---|
| 1003 | |
---|
| 1004 | m_ndc += temp; |
---|
| 1005 | |
---|
| 1006 | Instance cmp; |
---|
| 1007 | cmp = (m_weightByDistance) |
---|
| 1008 | ? m_trainInstances.instance((int)m_karray[0][tempSorted[i]][1]) |
---|
| 1009 | : m_trainInstances.instance((int)m_karray[0][i][1]); |
---|
| 1010 | |
---|
| 1011 | double temp_diffP_diffA_givNearest = |
---|
| 1012 | difference(m_classIndex, inst.value(m_classIndex), |
---|
| 1013 | cmp.value(m_classIndex)); |
---|
| 1014 | // now the attributes |
---|
| 1015 | for (int p1 = 0, p2 = 0; |
---|
| 1016 | p1 < inst.numValues() || p2 < cmp.numValues();) { |
---|
| 1017 | if (p1 >= inst.numValues()) { |
---|
| 1018 | firstI = m_trainInstances.numAttributes(); |
---|
| 1019 | } else { |
---|
| 1020 | firstI = inst.index(p1); |
---|
| 1021 | } |
---|
| 1022 | if (p2 >= cmp.numValues()) { |
---|
| 1023 | secondI = m_trainInstances.numAttributes(); |
---|
| 1024 | } else { |
---|
| 1025 | secondI = cmp.index(p2); |
---|
| 1026 | } |
---|
| 1027 | if (firstI == m_trainInstances.classIndex()) { |
---|
| 1028 | p1++; continue; |
---|
| 1029 | } |
---|
| 1030 | if (secondI == m_trainInstances.classIndex()) { |
---|
| 1031 | p2++; continue; |
---|
| 1032 | } |
---|
| 1033 | temp = 0.0; |
---|
| 1034 | temp2 = 0.0; |
---|
| 1035 | |
---|
| 1036 | if (firstI == secondI) { |
---|
| 1037 | j = firstI; |
---|
| 1038 | temp = difference(j, inst.valueSparse(p1), cmp.valueSparse(p2)); |
---|
| 1039 | p1++;p2++; |
---|
| 1040 | } else if (firstI > secondI) { |
---|
| 1041 | j = secondI; |
---|
| 1042 | temp = difference(j, 0, cmp.valueSparse(p2)); |
---|
| 1043 | p2++; |
---|
| 1044 | } else { |
---|
| 1045 | j = firstI; |
---|
| 1046 | temp = difference(j, inst.valueSparse(p1), 0); |
---|
| 1047 | p1++; |
---|
| 1048 | } |
---|
| 1049 | |
---|
| 1050 | temp2 = temp_diffP_diffA_givNearest * temp; |
---|
| 1051 | // P of different prediction and different att value given |
---|
| 1052 | // nearest instances |
---|
| 1053 | if (m_weightByDistance) { |
---|
| 1054 | temp2 *= (m_weightsByRank[i]/distNorm); |
---|
| 1055 | } |
---|
| 1056 | else { |
---|
| 1057 | temp2 *= (1.0/(double)m_stored[0]); // equal influence |
---|
| 1058 | } |
---|
| 1059 | |
---|
| 1060 | m_ndcda[j] += temp2; |
---|
| 1061 | |
---|
| 1062 | // P of different attribute val given nearest instances |
---|
| 1063 | if (m_weightByDistance) { |
---|
| 1064 | temp *= (m_weightsByRank[i]/distNorm); |
---|
| 1065 | } |
---|
| 1066 | else { |
---|
| 1067 | temp *= (1.0/(double)m_stored[0]); // equal influence |
---|
| 1068 | } |
---|
| 1069 | |
---|
| 1070 | m_nda[j] += temp; |
---|
| 1071 | } |
---|
| 1072 | } |
---|
| 1073 | } |
---|
| 1074 | |
---|
| 1075 | |
---|
| 1076 | /** |
---|
| 1077 | * update attribute weights given an instance when the class is discrete |
---|
| 1078 | * |
---|
| 1079 | * @param instNum the index of the instance to use when updating weights |
---|
| 1080 | */ |
---|
| 1081 | private void updateWeightsDiscreteClass (int instNum) { |
---|
| 1082 | int i, j, k; |
---|
| 1083 | int cl; |
---|
| 1084 | double temp_diff, w_norm = 1.0; |
---|
| 1085 | double[] tempDistClass; |
---|
| 1086 | int[] tempSortedClass = null; |
---|
| 1087 | double distNormClass = 1.0; |
---|
| 1088 | double[] tempDistAtt; |
---|
| 1089 | int[][] tempSortedAtt = null; |
---|
| 1090 | double[] distNormAtt = null; |
---|
| 1091 | int firstI, secondI; |
---|
| 1092 | |
---|
| 1093 | // store the indexes (sparse instances) of non-zero elements |
---|
| 1094 | Instance inst = m_trainInstances.instance(instNum); |
---|
| 1095 | |
---|
| 1096 | // get the class of this instance |
---|
| 1097 | cl = (int)m_trainInstances.instance(instNum).value(m_classIndex); |
---|
| 1098 | |
---|
| 1099 | // sort nearest neighbours and set up normalization variables |
---|
| 1100 | if (m_weightByDistance) { |
---|
| 1101 | // do class (hits) first |
---|
| 1102 | // sort the distances |
---|
| 1103 | tempDistClass = new double[m_stored[cl]]; |
---|
| 1104 | |
---|
| 1105 | for (j = 0, distNormClass = 0; j < m_stored[cl]; j++) { |
---|
| 1106 | // copy the distances |
---|
| 1107 | tempDistClass[j] = m_karray[cl][j][0]; |
---|
| 1108 | // sum normalizer |
---|
| 1109 | distNormClass += m_weightsByRank[j]; |
---|
| 1110 | } |
---|
| 1111 | |
---|
| 1112 | tempSortedClass = Utils.sort(tempDistClass); |
---|
| 1113 | // do misses (other classes) |
---|
| 1114 | tempSortedAtt = new int[m_numClasses][1]; |
---|
| 1115 | distNormAtt = new double[m_numClasses]; |
---|
| 1116 | |
---|
| 1117 | for (k = 0; k < m_numClasses; k++) { |
---|
| 1118 | if (k != cl) // already done cl |
---|
| 1119 | { |
---|
| 1120 | // sort the distances |
---|
| 1121 | tempDistAtt = new double[m_stored[k]]; |
---|
| 1122 | |
---|
| 1123 | for (j = 0, distNormAtt[k] = 0; j < m_stored[k]; j++) { |
---|
| 1124 | // copy the distances |
---|
| 1125 | tempDistAtt[j] = m_karray[k][j][0]; |
---|
| 1126 | // sum normalizer |
---|
| 1127 | distNormAtt[k] += m_weightsByRank[j]; |
---|
| 1128 | } |
---|
| 1129 | |
---|
| 1130 | tempSortedAtt[k] = Utils.sort(tempDistAtt); |
---|
| 1131 | } |
---|
| 1132 | } |
---|
| 1133 | } |
---|
| 1134 | |
---|
| 1135 | if (m_numClasses > 2) { |
---|
| 1136 | // the amount of probability space left after removing the |
---|
| 1137 | // probability of this instance's class value |
---|
| 1138 | w_norm = (1.0 - m_classProbs[cl]); |
---|
| 1139 | } |
---|
| 1140 | |
---|
| 1141 | // do the k nearest hits of the same class |
---|
| 1142 | for (j = 0, temp_diff = 0.0; j < m_stored[cl]; j++) { |
---|
| 1143 | Instance cmp; |
---|
| 1144 | cmp = (m_weightByDistance) |
---|
| 1145 | ? m_trainInstances. |
---|
| 1146 | instance((int)m_karray[cl][tempSortedClass[j]][1]) |
---|
| 1147 | : m_trainInstances.instance((int)m_karray[cl][j][1]); |
---|
| 1148 | |
---|
| 1149 | for (int p1 = 0, p2 = 0; |
---|
| 1150 | p1 < inst.numValues() || p2 < cmp.numValues();) { |
---|
| 1151 | if (p1 >= inst.numValues()) { |
---|
| 1152 | firstI = m_trainInstances.numAttributes(); |
---|
| 1153 | } else { |
---|
| 1154 | firstI = inst.index(p1); |
---|
| 1155 | } |
---|
| 1156 | if (p2 >= cmp.numValues()) { |
---|
| 1157 | secondI = m_trainInstances.numAttributes(); |
---|
| 1158 | } else { |
---|
| 1159 | secondI = cmp.index(p2); |
---|
| 1160 | } |
---|
| 1161 | if (firstI == m_trainInstances.classIndex()) { |
---|
| 1162 | p1++; continue; |
---|
| 1163 | } |
---|
| 1164 | if (secondI == m_trainInstances.classIndex()) { |
---|
| 1165 | p2++; continue; |
---|
| 1166 | } |
---|
| 1167 | if (firstI == secondI) { |
---|
| 1168 | i = firstI; |
---|
| 1169 | temp_diff = difference(i, inst.valueSparse(p1), |
---|
| 1170 | cmp.valueSparse(p2)); |
---|
| 1171 | p1++;p2++; |
---|
| 1172 | } else if (firstI > secondI) { |
---|
| 1173 | i = secondI; |
---|
| 1174 | temp_diff = difference(i, 0, cmp.valueSparse(p2)); |
---|
| 1175 | p2++; |
---|
| 1176 | } else { |
---|
| 1177 | i = firstI; |
---|
| 1178 | temp_diff = difference(i, inst.valueSparse(p1), 0); |
---|
| 1179 | p1++; |
---|
| 1180 | } |
---|
| 1181 | |
---|
| 1182 | if (m_weightByDistance) { |
---|
| 1183 | temp_diff *= |
---|
| 1184 | (m_weightsByRank[j]/distNormClass); |
---|
| 1185 | } else { |
---|
| 1186 | if (m_stored[cl] > 0) { |
---|
| 1187 | temp_diff /= (double)m_stored[cl]; |
---|
| 1188 | } |
---|
| 1189 | } |
---|
| 1190 | m_weights[i] -= temp_diff; |
---|
| 1191 | |
---|
| 1192 | } |
---|
| 1193 | } |
---|
| 1194 | |
---|
| 1195 | |
---|
| 1196 | // now do k nearest misses from each of the other classes |
---|
| 1197 | temp_diff = 0.0; |
---|
| 1198 | |
---|
| 1199 | for (k = 0; k < m_numClasses; k++) { |
---|
| 1200 | if (k != cl) // already done cl |
---|
| 1201 | { |
---|
| 1202 | for (j = 0; j < m_stored[k]; j++) { |
---|
| 1203 | Instance cmp; |
---|
| 1204 | cmp = (m_weightByDistance) |
---|
| 1205 | ? m_trainInstances. |
---|
| 1206 | instance((int)m_karray[k][tempSortedAtt[k][j]][1]) |
---|
| 1207 | : m_trainInstances.instance((int)m_karray[k][j][1]); |
---|
| 1208 | |
---|
| 1209 | for (int p1 = 0, p2 = 0; |
---|
| 1210 | p1 < inst.numValues() || p2 < cmp.numValues();) { |
---|
| 1211 | if (p1 >= inst.numValues()) { |
---|
| 1212 | firstI = m_trainInstances.numAttributes(); |
---|
| 1213 | } else { |
---|
| 1214 | firstI = inst.index(p1); |
---|
| 1215 | } |
---|
| 1216 | if (p2 >= cmp.numValues()) { |
---|
| 1217 | secondI = m_trainInstances.numAttributes(); |
---|
| 1218 | } else { |
---|
| 1219 | secondI = cmp.index(p2); |
---|
| 1220 | } |
---|
| 1221 | if (firstI == m_trainInstances.classIndex()) { |
---|
| 1222 | p1++; continue; |
---|
| 1223 | } |
---|
| 1224 | if (secondI == m_trainInstances.classIndex()) { |
---|
| 1225 | p2++; continue; |
---|
| 1226 | } |
---|
| 1227 | if (firstI == secondI) { |
---|
| 1228 | i = firstI; |
---|
| 1229 | temp_diff = difference(i, inst.valueSparse(p1), |
---|
| 1230 | cmp.valueSparse(p2)); |
---|
| 1231 | p1++;p2++; |
---|
| 1232 | } else if (firstI > secondI) { |
---|
| 1233 | i = secondI; |
---|
| 1234 | temp_diff = difference(i, 0, cmp.valueSparse(p2)); |
---|
| 1235 | p2++; |
---|
| 1236 | } else { |
---|
| 1237 | i = firstI; |
---|
| 1238 | temp_diff = difference(i, inst.valueSparse(p1), 0); |
---|
| 1239 | p1++; |
---|
| 1240 | } |
---|
| 1241 | |
---|
| 1242 | if (m_weightByDistance) { |
---|
| 1243 | temp_diff *= |
---|
| 1244 | (m_weightsByRank[j]/distNormAtt[k]); |
---|
| 1245 | } |
---|
| 1246 | else { |
---|
| 1247 | if (m_stored[k] > 0) { |
---|
| 1248 | temp_diff /= (double)m_stored[k]; |
---|
| 1249 | } |
---|
| 1250 | } |
---|
| 1251 | if (m_numClasses > 2) { |
---|
| 1252 | m_weights[i] += ((m_classProbs[k]/w_norm)*temp_diff); |
---|
| 1253 | } else { |
---|
| 1254 | m_weights[i] += temp_diff; |
---|
| 1255 | } |
---|
| 1256 | } |
---|
| 1257 | } |
---|
| 1258 | } |
---|
| 1259 | } |
---|
| 1260 | } |
---|
| 1261 | |
---|
| 1262 | |
---|
| 1263 | /** |
---|
| 1264 | * Find the K nearest instances to supplied instance if the class is numeric, |
---|
| 1265 | * or the K nearest Hits (same class) and Misses (K from each of the other |
---|
| 1266 | * classes) if the class is discrete. |
---|
| 1267 | * |
---|
| 1268 | * @param instNum the index of the instance to find nearest neighbours of |
---|
| 1269 | */ |
---|
| 1270 | private void findKHitMiss (int instNum) { |
---|
| 1271 | int i, j; |
---|
| 1272 | int cl; |
---|
| 1273 | double ww; |
---|
| 1274 | double temp_diff = 0.0; |
---|
| 1275 | Instance thisInst = m_trainInstances.instance(instNum); |
---|
| 1276 | |
---|
| 1277 | for (i = 0; i < m_numInstances; i++) { |
---|
| 1278 | if (i != instNum) { |
---|
| 1279 | Instance cmpInst = m_trainInstances.instance(i); |
---|
| 1280 | temp_diff = distance(cmpInst, thisInst); |
---|
| 1281 | |
---|
| 1282 | // class of this training instance or 0 if numeric |
---|
| 1283 | if (m_numericClass) { |
---|
| 1284 | cl = 0; |
---|
| 1285 | } |
---|
| 1286 | else { |
---|
| 1287 | cl = (int)m_trainInstances.instance(i).value(m_classIndex); |
---|
| 1288 | } |
---|
| 1289 | |
---|
| 1290 | // add this diff to the list for the class of this instance |
---|
| 1291 | if (m_stored[cl] < m_Knn) { |
---|
| 1292 | m_karray[cl][m_stored[cl]][0] = temp_diff; |
---|
| 1293 | m_karray[cl][m_stored[cl]][1] = i; |
---|
| 1294 | m_stored[cl]++; |
---|
| 1295 | |
---|
| 1296 | // note the worst diff for this class |
---|
| 1297 | for (j = 0, ww = -1.0; j < m_stored[cl]; j++) { |
---|
| 1298 | if (m_karray[cl][j][0] > ww) { |
---|
| 1299 | ww = m_karray[cl][j][0]; |
---|
| 1300 | m_index[cl] = j; |
---|
| 1301 | } |
---|
| 1302 | } |
---|
| 1303 | |
---|
| 1304 | m_worst[cl] = ww; |
---|
| 1305 | } |
---|
| 1306 | else |
---|
| 1307 | /* if we already have stored knn for this class then check to |
---|
| 1308 | see if this instance is better than the worst */ |
---|
| 1309 | { |
---|
| 1310 | if (temp_diff < m_karray[cl][m_index[cl]][0]) { |
---|
| 1311 | m_karray[cl][m_index[cl]][0] = temp_diff; |
---|
| 1312 | m_karray[cl][m_index[cl]][1] = i; |
---|
| 1313 | |
---|
| 1314 | for (j = 0, ww = -1.0; j < m_stored[cl]; j++) { |
---|
| 1315 | if (m_karray[cl][j][0] > ww) { |
---|
| 1316 | ww = m_karray[cl][j][0]; |
---|
| 1317 | m_index[cl] = j; |
---|
| 1318 | } |
---|
| 1319 | } |
---|
| 1320 | |
---|
| 1321 | m_worst[cl] = ww; |
---|
| 1322 | } |
---|
| 1323 | } |
---|
| 1324 | } |
---|
| 1325 | } |
---|
| 1326 | } |
---|
| 1327 | |
---|
| 1328 | /** |
---|
| 1329 | * Returns the revision string. |
---|
| 1330 | * |
---|
| 1331 | * @return the revision |
---|
| 1332 | */ |
---|
| 1333 | public String getRevision() { |
---|
| 1334 | return RevisionUtils.extract("$Revision: 5987 $"); |
---|
| 1335 | } |
---|
| 1336 | |
---|
| 1337 | // ============ |
---|
| 1338 | // Test method. |
---|
| 1339 | // ============ |
---|
| 1340 | /** |
---|
| 1341 | * Main method for testing this class. |
---|
| 1342 | * |
---|
| 1343 | * @param args the options |
---|
| 1344 | */ |
---|
| 1345 | public static void main (String[] args) { |
---|
| 1346 | runEvaluator(new ReliefFAttributeEval(), args); |
---|
| 1347 | } |
---|
| 1348 | } |
---|