[29] | 1 | /* |
---|
| 2 | * This program is free software; you can redistribute it and/or modify |
---|
| 3 | * it under the terms of the GNU General Public License as published by |
---|
| 4 | * the Free Software Foundation; either version 2 of the License, or |
---|
| 5 | * (at your option) any later version. |
---|
| 6 | * |
---|
| 7 | * This program is distributed in the hope that it will be useful, |
---|
| 8 | * but WITHOUT ANY WARRANTY; without even the implied warranty of |
---|
| 9 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the |
---|
| 10 | * GNU General Public License for more details. |
---|
| 11 | * |
---|
| 12 | * You should have received a copy of the GNU General Public License |
---|
| 13 | * along with this program; if not, write to the Free Software |
---|
| 14 | * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA. |
---|
| 15 | */ |
---|
| 16 | |
---|
| 17 | /* |
---|
| 18 | * RaceSearch.java |
---|
| 19 | * Copyright (C) 2000 University of Waikato, Hamilton, New Zealand |
---|
| 20 | * |
---|
| 21 | */ |
---|
| 22 | |
---|
| 23 | package weka.attributeSelection; |
---|
| 24 | |
---|
| 25 | import weka.core.Instance; |
---|
| 26 | import weka.core.Instances; |
---|
| 27 | import weka.core.Option; |
---|
| 28 | import weka.core.OptionHandler; |
---|
| 29 | import weka.core.RevisionUtils; |
---|
| 30 | import weka.core.SelectedTag; |
---|
| 31 | import weka.core.Statistics; |
---|
| 32 | import weka.core.Tag; |
---|
| 33 | import weka.core.TechnicalInformation; |
---|
| 34 | import weka.core.TechnicalInformation.Type; |
---|
| 35 | import weka.core.TechnicalInformation.Field; |
---|
| 36 | import weka.core.TechnicalInformationHandler; |
---|
| 37 | import weka.core.Utils; |
---|
| 38 | import weka.experiment.PairedStats; |
---|
| 39 | import weka.experiment.Stats; |
---|
| 40 | |
---|
| 41 | import java.util.BitSet; |
---|
| 42 | import java.util.Enumeration; |
---|
| 43 | import java.util.Random; |
---|
| 44 | import java.util.Vector; |
---|
| 45 | |
---|
| 46 | /** |
---|
| 47 | <!-- globalinfo-start --> |
---|
| 48 | * Races the cross validation error of competing attribute subsets. Use in conjunction with a ClassifierSubsetEval. RaceSearch has four modes:<br/> |
---|
| 49 | * <br/> |
---|
| 50 | * Forward selection races all single attribute additions to a base set (initially no attributes), selects the winner to become the new base set and then iterates until there is no improvement over the base set. <br/> |
---|
| 51 | * <br/> |
---|
| 52 | * Backward elimination is similar but the initial base set has all attributes included and races all single attribute deletions. <br/> |
---|
| 53 | * <br/> |
---|
| 54 | * Schemata search is a bit different. Each iteration a series of races are run in parallel. Each race in a set determines whether a particular attribute should be included or not---ie the race is between the attribute being "in" or "out". The other attributes for this race are included or excluded randomly at each point in the evaluation. As soon as one race has a clear winner (ie it has been decided whether a particular attribute should be in or not) then the next set of races begins, using the result of the winning race from the previous iteration as new base set.<br/> |
---|
| 55 | * <br/> |
---|
| 56 | * Rank race first ranks the attributes using an attribute evaluator and then races the ranking. The race includes no attributes, the top ranked attribute, the top two attributes, the top three attributes, etc.<br/> |
---|
| 57 | * <br/> |
---|
| 58 | * It is also possible to generate a ranked list of attributes through the forward racing process. If generateRanking is set to true then a complete forward race will be run---that is, racing continues until all attributes have been selected. The order that they are added in determines a complete ranking of all the attributes.<br/> |
---|
| 59 | * <br/> |
---|
| 60 | * Racing uses paired and unpaired t-tests on cross-validation errors of competing subsets. When there is a significant difference between the means of the errors of two competing subsets then the poorer of the two can be eliminated from the race. Similarly, if there is no significant difference between the mean errors of two competing subsets and they are within some threshold of each other, then one can be eliminated from the race.<br/> |
---|
| 61 | * <br/> |
---|
| 62 | * For more information see:<br/> |
---|
| 63 | * <br/> |
---|
| 64 | * Andrew W. Moore, Mary S. Lee: Efficient Algorithms for Minimizing Cross Validation Error. In: Eleventh International Conference on Machine Learning, 190-198, 1994. |
---|
| 65 | * <p/> |
---|
| 66 | <!-- globalinfo-end --> |
---|
| 67 | * |
---|
| 68 | <!-- technical-bibtex-start --> |
---|
| 69 | * BibTeX: |
---|
| 70 | * <pre> |
---|
| 71 | * @inproceedings{Moore1994, |
---|
| 72 | * author = {Andrew W. Moore and Mary S. Lee}, |
---|
| 73 | * booktitle = {Eleventh International Conference on Machine Learning}, |
---|
| 74 | * pages = {190-198}, |
---|
| 75 | * publisher = {Morgan Kaufmann}, |
---|
| 76 | * title = {Efficient Algorithms for Minimizing Cross Validation Error}, |
---|
| 77 | * year = {1994} |
---|
| 78 | * } |
---|
| 79 | * </pre> |
---|
| 80 | * <p/> |
---|
| 81 | <!-- technical-bibtex-end --> |
---|
| 82 | * |
---|
| 83 | <!-- options-start --> |
---|
| 84 | * Valid options are: <p/> |
---|
| 85 | * |
---|
| 86 | * <pre> -R <0 = forward | 1 = backward race | 2 = schemata | 3 = rank> |
---|
| 87 | * Type of race to perform. |
---|
| 88 | * (default = 0).</pre> |
---|
| 89 | * |
---|
| 90 | * <pre> -L <significance> |
---|
| 91 | * Significance level for comparisons |
---|
| 92 | * (default = 0.001(forward/backward/rank)/0.01(schemata)).</pre> |
---|
| 93 | * |
---|
| 94 | * <pre> -T <threshold> |
---|
| 95 | * Threshold for error comparison. |
---|
| 96 | * (default = 0.001).</pre> |
---|
| 97 | * |
---|
| 98 | * <pre> -A <attribute evaluator> |
---|
| 99 | * Attribute ranker to use if doing a |
---|
| 100 | * rank search. Place any |
---|
| 101 | * evaluator options LAST on |
---|
| 102 | * the command line following a "--". |
---|
| 103 | * eg. -A weka.attributeSelection.GainRatioAttributeEval ... -- -M. |
---|
| 104 | * (default = GainRatioAttributeEval)</pre> |
---|
| 105 | * |
---|
| 106 | * <pre> -F <0 = 10 fold | 1 = leave-one-out> |
---|
| 107 | * Folds for cross validation |
---|
| 108 | * (default = 0 (1 if schemata race))</pre> |
---|
| 109 | * |
---|
| 110 | * <pre> -Q |
---|
| 111 | * Generate a ranked list of attributes. |
---|
| 112 | * Forces the search to be forward |
---|
| 113 | * and races until all attributes have |
---|
| 114 | * been selected, thus producing a ranking.</pre> |
---|
| 115 | * |
---|
| 116 | * <pre> -N <num to select> |
---|
| 117 | * Specify number of attributes to retain from |
---|
| 118 | * the ranking. Overrides -J. Use in conjunction with -Q</pre> |
---|
| 119 | * |
---|
| 120 | * <pre> -J <threshold> |
---|
| 121 | * Specify a threshold by which attributes |
---|
| 122 | * may be discarded from the ranking. |
---|
| 123 | * Use in conjunction with -Q</pre> |
---|
| 124 | * |
---|
| 125 | * <pre> -Z |
---|
| 126 | * Verbose output for monitoring the search.</pre> |
---|
| 127 | * |
---|
| 128 | * <pre> |
---|
| 129 | * Options specific to evaluator weka.attributeSelection.GainRatioAttributeEval: |
---|
| 130 | * </pre> |
---|
| 131 | * |
---|
| 132 | * <pre> -M |
---|
| 133 | * treat missing values as a separate value.</pre> |
---|
| 134 | * |
---|
| 135 | <!-- options-end --> |
---|
| 136 | * |
---|
| 137 | * @author Mark Hall (mhall@cs.waikato.ac.nz) |
---|
| 138 | * @version $Revision: 1.26 $ |
---|
| 139 | */ |
---|
| 140 | public class RaceSearch |
---|
| 141 | extends ASSearch |
---|
| 142 | implements RankedOutputSearch, OptionHandler, TechnicalInformationHandler { |
---|
| 143 | |
---|
/** for serialization */
static final long serialVersionUID = 4015453851212985720L;

/** the training instances */
private Instances m_Instances = null;

/** search types (the race strategies selectable via -R / setRaceType) */
private static final int FORWARD_RACE = 0;
private static final int BACKWARD_RACE = 1;
private static final int SCHEMATA_RACE = 2;
private static final int RANK_RACE = 3;
/**
 * Tags describing the race types. setRaceType compares the tag array of
 * its SelectedTag argument against this array by reference, so callers
 * must build their SelectedTag from this exact array.
 */
public static final Tag [] TAGS_SELECTION = {
  new Tag(FORWARD_RACE, "Forward selection race"),
  new Tag(BACKWARD_RACE, "Backward elimination race"),
  new Tag(SCHEMATA_RACE, "Schemata race"),
  new Tag(RANK_RACE, "Rank race")
};

/** the selected search type (one of the race-type constants above) */
private int m_raceType = FORWARD_RACE;

/** xval types (the values selectable via -F / setFoldsType) */
private static final int TEN_FOLD = 0;
private static final int LEAVE_ONE_OUT = 1;
/**
 * Tags describing the cross-validation types. setFoldsType compares the
 * tag array of its SelectedTag argument against this array by reference.
 */
public static final Tag [] XVALTAGS_SELECTION = {
  new Tag(TEN_FOLD, "10 Fold"),
  new Tag(LEAVE_ONE_OUT, "Leave-one-out"),
};

/**
 * the selected xval type. setRaceType overwrites it: LEAVE_ONE_OUT for a
 * schemata race when no ranking is requested, TEN_FOLD otherwise.
 */
private int m_xvalType = TEN_FOLD;

/** the class index */
private int m_classIndex;

/** the number of attributes in the data */
private int m_numAttribs;

/** the total number of partially/fully evaluated subsets */
private int m_totalEvals;

/** holds the merit of the best subset found */
private double m_bestMerit = -Double.MAX_VALUE;

/**
 * the subset evaluator to use.
 * NOTE(review): presumably the evaluator handed to the search, which is
 * required to be a HoldOutSubsetEvaluator -- confirm against search().
 */
private HoldOutSubsetEvaluator m_theEvaluator = null;

/**
 * the significance level for the t-test comparisons. setRaceType
 * overwrites it: 0.01 for a schemata race when no ranking is requested,
 * 0.001 otherwise.
 */
private double m_sigLevel = 0.001;

/**
 * threshold for error comparisons: the error threshold by which two
 * competing subsets are considered equivalent (set via -T / setThreshold)
 */
private double m_delta = 0.001;

/** the number of samples above which to begin testing for similarity
    between competing subsets */
private int m_samples = 20;

/** number of cross validation folds---equal to the number of instances
    for leave-one-out cv */
private int m_numFolds = 10;

/** the attribute evaluator to generate the initial ranking when
    doing a rank race */
private ASEvaluation m_ASEval = new GainRatioAttributeEval();

/** will hold the attribute ranking produced by the above attribute
    evaluator if doing a rank search */
private int [] m_Ranking;

/** verbose output for monitoring the search and debugging */
private boolean m_debug = false;

/** If true then produce a ranked list of attributes by fully traversing
    a forward hillclimb race */
private boolean m_rankingRequested = false;

/** The ranked list of attributes produced if m_rankingRequested is true */
private double [][] m_rankedAtts;

/** The number of attributes ranked so far (if ranking is requested) */
private int m_rankedSoFar;

/** The number of attributes to retain if a ranking is requested. -1
    indicates that all attributes are to be retained. Has precedence over
    m_threshold */
private int m_numToSelect = -1;

/**
 * The actual number of attributes to retain. getCalculatedNumToSelect
 * overwrites it with m_numToSelect whenever that is >= 0; otherwise the
 * stored value is returned unchanged.
 * NOTE(review): presumably computed from the data during the search when
 * m_numToSelect is -1 -- confirm.
 */
private int m_calculatedNumToSelect = -1;

/**
 * the threshold for removing attributes if ranking is requested. The
 * default of -Double.MAX_VALUE results in no attribute being discarded.
 */
private double m_threshold = -Double.MAX_VALUE;
---|
| 235 | |
---|
| 236 | /** |
---|
| 237 | * Returns a string describing this search method |
---|
| 238 | * @return a description of the search method suitable for |
---|
| 239 | * displaying in the explorer/experimenter gui |
---|
| 240 | */ |
---|
| 241 | public String globalInfo() { |
---|
| 242 | return "Races the cross validation error of competing " |
---|
| 243 | +"attribute subsets. Use in conjuction with a ClassifierSubsetEval. " |
---|
| 244 | +"RaceSearch has four modes:\n\nforward selection " |
---|
| 245 | +"races all single attribute additions to a base set (initially " |
---|
| 246 | +" no attributes), selects the winner to become the new base set " |
---|
| 247 | +"and then iterates until there is no improvement over the base set. " |
---|
| 248 | +"\n\nBackward elimination is similar but the initial base set has all " |
---|
| 249 | +"attributes included and races all single attribute deletions. " |
---|
| 250 | +"\n\nSchemata search is a bit different. Each iteration a series of " |
---|
| 251 | +"races are run in parallel. Each race in a set determines whether " |
---|
| 252 | +"a particular attribute should be included or not---ie the race is " |
---|
| 253 | +"between the attribute being \"in\" or \"out\". The other attributes " |
---|
| 254 | +"for this race are included or excluded randomly at each point in the " |
---|
| 255 | +"evaluation. As soon as one race " |
---|
| 256 | +"has a clear winner (ie it has been decided whether a particular " |
---|
| 257 | +"attribute should be inor not) then the next set of races begins, " |
---|
| 258 | +"using the result of the winning race from the previous iteration as " |
---|
| 259 | +"new base set.\n\nRank race first ranks the attributes using an " |
---|
| 260 | +"attribute evaluator and then races the ranking. The race includes " |
---|
| 261 | +"no attributes, the top ranked attribute, the top two attributes, the " |
---|
| 262 | +"top three attributes, etc.\n\nIt is also possible to generate a " |
---|
| 263 | +"raked list of attributes through the forward racing process. " |
---|
| 264 | +"If generateRanking is set to true then a complete forward race will " |
---|
| 265 | +"be run---that is, racing continues until all attributes have been " |
---|
| 266 | +"selected. The order that they are added in determines a complete " |
---|
| 267 | +"ranking of all the attributes.\n\nRacing uses paired and unpaired " |
---|
| 268 | +"t-tests on cross-validation errors of competing subsets. When there " |
---|
| 269 | +"is a significant difference between the means of the errors of two " |
---|
| 270 | +"competing subsets then the poorer of the two can be eliminated from " |
---|
| 271 | +"the race. Similarly, if there is no significant difference between " |
---|
| 272 | +"the mean errors of two competing subsets and they are within some " |
---|
| 273 | +"threshold of each other, then one can be eliminated from the race.\n\n" |
---|
| 274 | + "For more information see:\n\n" |
---|
| 275 | + getTechnicalInformation().toString(); |
---|
| 276 | } |
---|
| 277 | |
---|
| 278 | /** |
---|
| 279 | * Returns an instance of a TechnicalInformation object, containing |
---|
| 280 | * detailed information about the technical background of this class, |
---|
| 281 | * e.g., paper reference or book this class is based on. |
---|
| 282 | * |
---|
| 283 | * @return the technical information about this class |
---|
| 284 | */ |
---|
| 285 | public TechnicalInformation getTechnicalInformation() { |
---|
| 286 | TechnicalInformation result; |
---|
| 287 | |
---|
| 288 | result = new TechnicalInformation(Type.INPROCEEDINGS); |
---|
| 289 | result.setValue(Field.AUTHOR, "Andrew W. Moore and Mary S. Lee"); |
---|
| 290 | result.setValue(Field.TITLE, "Efficient Algorithms for Minimizing Cross Validation Error"); |
---|
| 291 | result.setValue(Field.BOOKTITLE, "Eleventh International Conference on Machine Learning"); |
---|
| 292 | result.setValue(Field.YEAR, "1994"); |
---|
| 293 | result.setValue(Field.PAGES, "190-198"); |
---|
| 294 | result.setValue(Field.PUBLISHER, "Morgan Kaufmann"); |
---|
| 295 | |
---|
| 296 | return result; |
---|
| 297 | } |
---|
| 298 | |
---|
| 299 | /** |
---|
| 300 | * Returns the tip text for this property |
---|
| 301 | * @return tip text for this property suitable for |
---|
| 302 | * displaying in the explorer/experimenter gui |
---|
| 303 | */ |
---|
| 304 | public String raceTypeTipText() { |
---|
| 305 | return "Set the type of search."; |
---|
| 306 | } |
---|
| 307 | |
---|
| 308 | /** |
---|
| 309 | * Set the race type |
---|
| 310 | * |
---|
| 311 | * @param d the type of race |
---|
| 312 | */ |
---|
| 313 | public void setRaceType (SelectedTag d) { |
---|
| 314 | |
---|
| 315 | if (d.getTags() == TAGS_SELECTION) { |
---|
| 316 | m_raceType = d.getSelectedTag().getID(); |
---|
| 317 | } |
---|
| 318 | if (m_raceType == SCHEMATA_RACE && !m_rankingRequested) { |
---|
| 319 | try { |
---|
| 320 | setFoldsType(new SelectedTag(LEAVE_ONE_OUT, |
---|
| 321 | XVALTAGS_SELECTION)); |
---|
| 322 | setSignificanceLevel(0.01); |
---|
| 323 | } catch (Exception ex) { |
---|
| 324 | } |
---|
| 325 | } else { |
---|
| 326 | try { |
---|
| 327 | setFoldsType(new SelectedTag(TEN_FOLD, |
---|
| 328 | XVALTAGS_SELECTION)); |
---|
| 329 | setSignificanceLevel(0.001); |
---|
| 330 | } catch (Exception ex) { |
---|
| 331 | } |
---|
| 332 | } |
---|
| 333 | } |
---|
| 334 | |
---|
| 335 | /** |
---|
| 336 | * Get the race type |
---|
| 337 | * |
---|
| 338 | * @return the type of race |
---|
| 339 | */ |
---|
| 340 | public SelectedTag getRaceType() { |
---|
| 341 | return new SelectedTag(m_raceType, TAGS_SELECTION); |
---|
| 342 | } |
---|
| 343 | |
---|
| 344 | /** |
---|
| 345 | * Returns the tip text for this property |
---|
| 346 | * @return tip text for this property suitable for |
---|
| 347 | * displaying in the explorer/experimenter gui |
---|
| 348 | */ |
---|
| 349 | public String significanceLevelTipText() { |
---|
| 350 | return "Set the significance level to use for t-test comparisons."; |
---|
| 351 | } |
---|
| 352 | |
---|
| 353 | /** |
---|
| 354 | * Sets the significance level to use |
---|
| 355 | * @param sig the significance level |
---|
| 356 | */ |
---|
| 357 | public void setSignificanceLevel(double sig) { |
---|
| 358 | m_sigLevel = sig; |
---|
| 359 | } |
---|
| 360 | |
---|
| 361 | /** |
---|
| 362 | * Get the significance level |
---|
| 363 | * @return the current significance level |
---|
| 364 | */ |
---|
| 365 | public double getSignificanceLevel() { |
---|
| 366 | return m_sigLevel; |
---|
| 367 | } |
---|
| 368 | |
---|
| 369 | /** |
---|
| 370 | * Returns the tip text for this property |
---|
| 371 | * @return tip text for this property suitable for |
---|
| 372 | * displaying in the explorer/experimenter gui |
---|
| 373 | */ |
---|
| 374 | public String thresholdTipText() { |
---|
| 375 | return "Set the error threshold by which to consider two subsets " |
---|
| 376 | +"equivalent."; |
---|
| 377 | } |
---|
| 378 | |
---|
| 379 | /** |
---|
| 380 | * Sets the threshold for comparisons |
---|
| 381 | * @param t the threshold to use |
---|
| 382 | */ |
---|
| 383 | public void setThreshold(double t) { |
---|
| 384 | m_delta = t; |
---|
| 385 | } |
---|
| 386 | |
---|
| 387 | /** |
---|
| 388 | * Get the threshold |
---|
| 389 | * @return the current threshold |
---|
| 390 | */ |
---|
| 391 | public double getThreshold() { |
---|
| 392 | return m_delta; |
---|
| 393 | } |
---|
| 394 | |
---|
| 395 | /** |
---|
| 396 | * Returns the tip text for this property |
---|
| 397 | * @return tip text for this property suitable for |
---|
| 398 | * displaying in the explorer/experimenter gui |
---|
| 399 | */ |
---|
| 400 | public String foldsTypeTipText() { |
---|
| 401 | return "Set the number of folds to use for x-val error estimation; " |
---|
| 402 | +"leave-one-out is selected automatically for schemata search."; |
---|
| 403 | } |
---|
| 404 | |
---|
| 405 | /** |
---|
| 406 | * Set the xfold type |
---|
| 407 | * |
---|
| 408 | * @param d the type of xval |
---|
| 409 | */ |
---|
| 410 | public void setFoldsType (SelectedTag d) { |
---|
| 411 | |
---|
| 412 | if (d.getTags() == XVALTAGS_SELECTION) { |
---|
| 413 | m_xvalType = d.getSelectedTag().getID(); |
---|
| 414 | } |
---|
| 415 | } |
---|
| 416 | |
---|
| 417 | /** |
---|
| 418 | * Get the xfold type |
---|
| 419 | * |
---|
| 420 | * @return the type of xval |
---|
| 421 | */ |
---|
| 422 | public SelectedTag getFoldsType () { |
---|
| 423 | return new SelectedTag(m_xvalType, XVALTAGS_SELECTION); |
---|
| 424 | } |
---|
| 425 | |
---|
| 426 | /** |
---|
| 427 | * Returns the tip text for this property |
---|
| 428 | * @return tip text for this property suitable for |
---|
| 429 | * displaying in the explorer/experimenter gui |
---|
| 430 | */ |
---|
| 431 | public String debugTipText() { |
---|
| 432 | return "Turn on verbose output for monitoring the search's progress."; |
---|
| 433 | } |
---|
| 434 | |
---|
| 435 | /** |
---|
| 436 | * Set whether verbose output should be generated. |
---|
| 437 | * @param d true if output is to be verbose. |
---|
| 438 | */ |
---|
| 439 | public void setDebug(boolean d) { |
---|
| 440 | m_debug = d; |
---|
| 441 | } |
---|
| 442 | |
---|
| 443 | /** |
---|
| 444 | * Get whether output is to be verbose |
---|
| 445 | * @return true if output will be verbose |
---|
| 446 | */ |
---|
| 447 | public boolean getDebug() { |
---|
| 448 | return m_debug; |
---|
| 449 | } |
---|
| 450 | |
---|
| 451 | /** |
---|
| 452 | * Returns the tip text for this property |
---|
| 453 | * @return tip text for this property suitable for |
---|
| 454 | * displaying in the explorer/experimenter gui |
---|
| 455 | */ |
---|
| 456 | public String attributeEvaluatorTipText() { |
---|
| 457 | return "Attribute evaluator to use for generating an initial ranking. " |
---|
| 458 | +"Use in conjunction with a rank race"; |
---|
| 459 | } |
---|
| 460 | |
---|
| 461 | /** |
---|
| 462 | * Set the attribute evaluator to use for generating the ranking. |
---|
| 463 | * @param newEvaluator the attribute evaluator to use. |
---|
| 464 | */ |
---|
| 465 | public void setAttributeEvaluator(ASEvaluation newEvaluator) { |
---|
| 466 | m_ASEval = newEvaluator; |
---|
| 467 | } |
---|
| 468 | |
---|
| 469 | /** |
---|
| 470 | * Get the attribute evaluator used to generate the ranking. |
---|
| 471 | * @return the evaluator used to generate the ranking. |
---|
| 472 | */ |
---|
| 473 | public ASEvaluation getAttributeEvaluator() { |
---|
| 474 | return m_ASEval; |
---|
| 475 | } |
---|
| 476 | |
---|
| 477 | /** |
---|
| 478 | * Returns the tip text for this property |
---|
| 479 | * @return tip text for this property suitable for |
---|
| 480 | * displaying in the explorer/experimenter gui |
---|
| 481 | */ |
---|
| 482 | public String generateRankingTipText() { |
---|
| 483 | return "Use the racing process to generate a ranked list of attributes. " |
---|
| 484 | +"Using this mode forces the race to be a forward type and then races " |
---|
| 485 | +"until all attributes have been added, thus giving a ranked list"; |
---|
| 486 | } |
---|
| 487 | |
---|
| 488 | /** |
---|
| 489 | * Records whether the user has requested a ranked list of attributes. |
---|
| 490 | * @param doRank true if ranking is requested |
---|
| 491 | */ |
---|
| 492 | public void setGenerateRanking(boolean doRank) { |
---|
| 493 | m_rankingRequested = doRank; |
---|
| 494 | if (m_rankingRequested) { |
---|
| 495 | try { |
---|
| 496 | setRaceType(new SelectedTag(FORWARD_RACE, |
---|
| 497 | TAGS_SELECTION)); |
---|
| 498 | } catch (Exception ex) { |
---|
| 499 | } |
---|
| 500 | } |
---|
| 501 | } |
---|
| 502 | |
---|
| 503 | /** |
---|
| 504 | * Gets whether ranking has been requested. This is used by the |
---|
| 505 | * AttributeSelection module to determine if rankedAttributes() |
---|
| 506 | * should be called. |
---|
| 507 | * @return true if ranking has been requested. |
---|
| 508 | */ |
---|
| 509 | public boolean getGenerateRanking() { |
---|
| 510 | return m_rankingRequested; |
---|
| 511 | } |
---|
| 512 | |
---|
| 513 | /** |
---|
| 514 | * Returns the tip text for this property |
---|
| 515 | * @return tip text for this property suitable for |
---|
| 516 | * displaying in the explorer/experimenter gui |
---|
| 517 | */ |
---|
| 518 | public String numToSelectTipText() { |
---|
| 519 | return "Specify the number of attributes to retain. Use in conjunction " |
---|
| 520 | +"with generateRanking. The default value " |
---|
| 521 | +"(-1) indicates that all attributes are to be retained. Use either " |
---|
| 522 | +"this option or a threshold to reduce the attribute set."; |
---|
| 523 | } |
---|
| 524 | |
---|
| 525 | /** |
---|
| 526 | * Specify the number of attributes to select from the ranked list |
---|
| 527 | * (if generating a ranking). -1 |
---|
| 528 | * indicates that all attributes are to be retained. |
---|
| 529 | * @param n the number of attributes to retain |
---|
| 530 | */ |
---|
| 531 | public void setNumToSelect(int n) { |
---|
| 532 | m_numToSelect = n; |
---|
| 533 | } |
---|
| 534 | |
---|
| 535 | /** |
---|
| 536 | * Gets the number of attributes to be retained. |
---|
| 537 | * @return the number of attributes to retain |
---|
| 538 | */ |
---|
| 539 | public int getNumToSelect() { |
---|
| 540 | return m_numToSelect; |
---|
| 541 | } |
---|
| 542 | |
---|
| 543 | /** |
---|
| 544 | * Gets the calculated number of attributes to retain. This is the |
---|
| 545 | * actual number of attributes to retain. This is the same as |
---|
| 546 | * getNumToSelect if the user specifies a number which is not less |
---|
| 547 | * than zero. Otherwise it should be the number of attributes in the |
---|
| 548 | * (potentially transformed) data. |
---|
| 549 | */ |
---|
| 550 | public int getCalculatedNumToSelect() { |
---|
| 551 | if (m_numToSelect >= 0) { |
---|
| 552 | m_calculatedNumToSelect = m_numToSelect; |
---|
| 553 | } |
---|
| 554 | return m_calculatedNumToSelect; |
---|
| 555 | } |
---|
| 556 | |
---|
| 557 | /** |
---|
| 558 | * Returns the tip text for this property |
---|
| 559 | * @return tip text for this property suitable for |
---|
| 560 | * displaying in the explorer/experimenter gui |
---|
| 561 | */ |
---|
| 562 | public String selectionThresholdTipText() { |
---|
| 563 | return "Set threshold by which attributes can be discarded. Default value " |
---|
| 564 | + "results in no attributes being discarded. Use in conjunction with " |
---|
| 565 | + "generateRanking"; |
---|
| 566 | } |
---|
| 567 | |
---|
| 568 | /** |
---|
| 569 | * Set the threshold by which the AttributeSelection module can discard |
---|
| 570 | * attributes. |
---|
| 571 | * @param threshold the threshold. |
---|
| 572 | */ |
---|
| 573 | public void setSelectionThreshold(double threshold) { |
---|
| 574 | m_threshold = threshold; |
---|
| 575 | } |
---|
| 576 | |
---|
| 577 | /** |
---|
| 578 | * Returns the threshold so that the AttributeSelection module can |
---|
| 579 | * discard attributes from the ranking. |
---|
| 580 | */ |
---|
| 581 | public double getSelectionThreshold() { |
---|
| 582 | return m_threshold; |
---|
| 583 | } |
---|
| 584 | |
---|
| 585 | |
---|
| 586 | /** |
---|
| 587 | * Returns an enumeration describing the available options. |
---|
| 588 | * @return an enumeration of all the available options. |
---|
| 589 | **/ |
---|
| 590 | public Enumeration listOptions () { |
---|
| 591 | Vector newVector = new Vector(); |
---|
| 592 | |
---|
| 593 | newVector.addElement(new Option( |
---|
| 594 | "\tType of race to perform.\n" |
---|
| 595 | + "\t(default = 0).", |
---|
| 596 | "R", 1 ,"-R <0 = forward | 1 = backward race | 2 = schemata | 3 = rank>")); |
---|
| 597 | |
---|
| 598 | newVector.addElement(new Option( |
---|
| 599 | "\tSignificance level for comaparisons\n" |
---|
| 600 | + "\t(default = 0.001(forward/backward/rank)/0.01(schemata)).", |
---|
| 601 | "L",1,"-L <significance>")); |
---|
| 602 | |
---|
| 603 | newVector.addElement(new Option( |
---|
| 604 | "\tThreshold for error comparison.\n" |
---|
| 605 | + "\t(default = 0.001).", |
---|
| 606 | "T",1,"-T <threshold>")); |
---|
| 607 | |
---|
| 608 | newVector.addElement(new Option( |
---|
| 609 | "\tAttribute ranker to use if doing a \n" |
---|
| 610 | + "\trank search. Place any\n" |
---|
| 611 | + "\tevaluator options LAST on \n" |
---|
| 612 | + "\tthe command line following a \"--\".\n" |
---|
| 613 | + "\teg. -A weka.attributeSelection.GainRatioAttributeEval ... -- -M.\n" |
---|
| 614 | + "\t(default = GainRatioAttributeEval)", |
---|
| 615 | "A", 1, "-A <attribute evaluator>")); |
---|
| 616 | |
---|
| 617 | newVector.addElement(new Option( |
---|
| 618 | "\tFolds for cross validation\n" |
---|
| 619 | + "\t(default = 0 (1 if schemata race)", |
---|
| 620 | "F",1,"-F <0 = 10 fold | 1 = leave-one-out>")); |
---|
| 621 | |
---|
| 622 | newVector.addElement(new Option( |
---|
| 623 | "\tGenerate a ranked list of attributes.\n" |
---|
| 624 | +"\tForces the search to be forward\n" |
---|
| 625 | +"\tand races until all attributes have\n" |
---|
| 626 | +"\tselected, thus producing a ranking.", |
---|
| 627 | "Q",0,"-Q")); |
---|
| 628 | |
---|
| 629 | newVector.addElement(new Option( |
---|
| 630 | "\tSpecify number of attributes to retain from \n" |
---|
| 631 | + "\tthe ranking. Overides -T. Use in conjunction with -Q", |
---|
| 632 | "N", 1, "-N <num to select>")); |
---|
| 633 | |
---|
| 634 | newVector.addElement(new Option( |
---|
| 635 | "\tSpecify a theshold by which attributes\n" |
---|
| 636 | + "\tmay be discarded from the ranking.\n" |
---|
| 637 | +"\tUse in conjuction with -Q", |
---|
| 638 | "J",1, "-J <threshold>")); |
---|
| 639 | |
---|
| 640 | newVector.addElement(new Option( |
---|
| 641 | "\tVerbose output for monitoring the search.", |
---|
| 642 | "Z",0,"-Z")); |
---|
| 643 | |
---|
| 644 | if ((m_ASEval != null) && |
---|
| 645 | (m_ASEval instanceof OptionHandler)) { |
---|
| 646 | newVector.addElement(new Option( |
---|
| 647 | "", |
---|
| 648 | "", 0, "\nOptions specific to evaluator " |
---|
| 649 | + m_ASEval.getClass().getName() + ":")); |
---|
| 650 | |
---|
| 651 | Enumeration enu = ((OptionHandler)m_ASEval).listOptions(); |
---|
| 652 | while (enu.hasMoreElements()) { |
---|
| 653 | newVector.addElement(enu.nextElement()); |
---|
| 654 | } |
---|
| 655 | } |
---|
| 656 | |
---|
| 657 | return newVector.elements(); |
---|
| 658 | } |
---|
| 659 | |
---|
| 660 | /** |
---|
| 661 | * Parses a given list of options. <p/> |
---|
| 662 | * |
---|
| 663 | <!-- options-start --> |
---|
| 664 | * Valid options are: <p/> |
---|
| 665 | * |
---|
| 666 | * <pre> -R <0 = forward | 1 = backward race | 2 = schemata | 3 = rank> |
---|
| 667 | * Type of race to perform. |
---|
| 668 | * (default = 0).</pre> |
---|
| 669 | * |
---|
| 670 | * <pre> -L <significance> |
---|
| 671 | * Significance level for comparisons |
---|
| 672 | * (default = 0.001(forward/backward/rank)/0.01(schemata)).</pre> |
---|
| 673 | * |
---|
| 674 | * <pre> -T <threshold> |
---|
| 675 | * Threshold for error comparison. |
---|
| 676 | * (default = 0.001).</pre> |
---|
| 677 | * |
---|
| 678 | * <pre> -A <attribute evaluator> |
---|
| 679 | * Attribute ranker to use if doing a |
---|
| 680 | * rank search. Place any |
---|
| 681 | * evaluator options LAST on |
---|
| 682 | * the command line following a "--". |
---|
| 683 | * eg. -A weka.attributeSelection.GainRatioAttributeEval ... -- -M. |
---|
| 684 | * (default = GainRatioAttributeEval)</pre> |
---|
| 685 | * |
---|
| 686 | * <pre> -F <0 = 10 fold | 1 = leave-one-out> |
---|
| 687 | * Folds for cross validation |
---|
| 688 | * (default = 0 (1 if schemata race))</pre> |
---|
| 689 | * |
---|
| 690 | * <pre> -Q |
---|
| 691 | * Generate a ranked list of attributes. |
---|
| 692 | * Forces the search to be forward |
---|
| 693 | * and races until all attributes have |
---|
| 694 | * selected, thus producing a ranking.</pre> |
---|
| 695 | * |
---|
| 696 | * <pre> -N <num to select> |
---|
| 697 | * Specify number of attributes to retain from |
---|
| 698 | * the ranking. Overides -T. Use in conjunction with -Q</pre> |
---|
| 699 | * |
---|
| 700 | * <pre> -J <threshold> |
---|
| 701 | * Specify a theshold by which attributes |
---|
| 702 | * may be discarded from the ranking. |
---|
| 703 | * Use in conjuction with -Q</pre> |
---|
| 704 | * |
---|
| 705 | * <pre> -Z |
---|
| 706 | * Verbose output for monitoring the search.</pre> |
---|
| 707 | * |
---|
| 708 | * <pre> |
---|
| 709 | * Options specific to evaluator weka.attributeSelection.GainRatioAttributeEval: |
---|
| 710 | * </pre> |
---|
| 711 | * |
---|
| 712 | * <pre> -M |
---|
| 713 | * treat missing values as a seperate value.</pre> |
---|
| 714 | * |
---|
| 715 | <!-- options-end --> |
---|
| 716 | * |
---|
| 717 | * @param options the list of options as an array of strings |
---|
| 718 | * @throws Exception if an option is not supported |
---|
| 719 | */ |
---|
| 720 | public void setOptions (String[] options) |
---|
| 721 | throws Exception { |
---|
| 722 | String optionString; |
---|
| 723 | resetOptions(); |
---|
| 724 | |
---|
| 725 | optionString = Utils.getOption('R', options); |
---|
| 726 | if (optionString.length() != 0) { |
---|
| 727 | setRaceType(new SelectedTag(Integer.parseInt(optionString), |
---|
| 728 | TAGS_SELECTION)); |
---|
| 729 | } |
---|
| 730 | |
---|
| 731 | optionString = Utils.getOption('F', options); |
---|
| 732 | if (optionString.length() != 0) { |
---|
| 733 | setFoldsType(new SelectedTag(Integer.parseInt(optionString), |
---|
| 734 | XVALTAGS_SELECTION)); |
---|
| 735 | } |
---|
| 736 | |
---|
| 737 | optionString = Utils.getOption('L', options); |
---|
| 738 | if (optionString.length() !=0) { |
---|
| 739 | setSignificanceLevel(Double.parseDouble(optionString)); |
---|
| 740 | } |
---|
| 741 | |
---|
| 742 | optionString = Utils.getOption('T', options); |
---|
| 743 | if (optionString.length() !=0) { |
---|
| 744 | setThreshold(Double.parseDouble(optionString)); |
---|
| 745 | } |
---|
| 746 | |
---|
| 747 | optionString = Utils.getOption('A', options); |
---|
| 748 | if (optionString.length() != 0) { |
---|
| 749 | setAttributeEvaluator(ASEvaluation.forName(optionString, |
---|
| 750 | Utils.partitionOptions(options))); |
---|
| 751 | } |
---|
| 752 | |
---|
| 753 | setGenerateRanking(Utils.getFlag('Q', options)); |
---|
| 754 | |
---|
| 755 | optionString = Utils.getOption('J', options); |
---|
| 756 | if (optionString.length() != 0) { |
---|
| 757 | setSelectionThreshold(Double.parseDouble(optionString)); |
---|
| 758 | } |
---|
| 759 | |
---|
| 760 | optionString = Utils.getOption('N', options); |
---|
| 761 | if (optionString.length() != 0) { |
---|
| 762 | setNumToSelect(Integer.parseInt(optionString)); |
---|
| 763 | } |
---|
| 764 | |
---|
| 765 | setDebug(Utils.getFlag('Z', options)); |
---|
| 766 | } |
---|
| 767 | |
---|
| 768 | /** |
---|
| 769 | * Gets the current settings of BestFirst. |
---|
| 770 | * @return an array of strings suitable for passing to setOptions() |
---|
| 771 | */ |
---|
| 772 | public String[] getOptions () { |
---|
| 773 | int current = 0; |
---|
| 774 | String[] evaluatorOptions = new String[0]; |
---|
| 775 | |
---|
| 776 | if ((m_ASEval != null) && |
---|
| 777 | (m_ASEval instanceof OptionHandler)) { |
---|
| 778 | evaluatorOptions = ((OptionHandler)m_ASEval).getOptions(); |
---|
| 779 | } |
---|
| 780 | String[] options = new String[17+evaluatorOptions.length]; |
---|
| 781 | |
---|
| 782 | options[current++] = "-R"; options[current++] = ""+m_raceType; |
---|
| 783 | options[current++] = "-L"; options[current++] = ""+getSignificanceLevel(); |
---|
| 784 | options[current++] = "-T"; options[current++] = ""+getThreshold(); |
---|
| 785 | options[current++] = "-F"; options[current++] = ""+m_xvalType; |
---|
| 786 | if (getGenerateRanking()) { |
---|
| 787 | options[current++] = "-Q"; |
---|
| 788 | } |
---|
| 789 | options[current++] = "-N"; options[current++] = ""+getNumToSelect(); |
---|
| 790 | options[current++] = "-J"; options[current++] = ""+getSelectionThreshold(); |
---|
| 791 | if (getDebug()) { |
---|
| 792 | options[current++] = "-Z"; |
---|
| 793 | } |
---|
| 794 | |
---|
| 795 | if (getAttributeEvaluator() != null) { |
---|
| 796 | options[current++] = "-A"; |
---|
| 797 | options[current++] = getAttributeEvaluator().getClass().getName(); |
---|
| 798 | options[current++] = "--"; |
---|
| 799 | System.arraycopy(evaluatorOptions, 0, options, current, |
---|
| 800 | evaluatorOptions.length); |
---|
| 801 | current += evaluatorOptions.length; |
---|
| 802 | } |
---|
| 803 | |
---|
| 804 | |
---|
| 805 | while (current < options.length) { |
---|
| 806 | options[current++] = ""; |
---|
| 807 | } |
---|
| 808 | |
---|
| 809 | return options; |
---|
| 810 | } |
---|
| 811 | |
---|
| 812 | |
---|
| 813 | |
---|
| 814 | |
---|
| 815 | /** |
---|
| 816 | * Searches the attribute subset space by racing cross validation |
---|
| 817 | * errors of competing subsets |
---|
| 818 | * |
---|
| 819 | * @param ASEval the attribute evaluator to guide the search |
---|
| 820 | * @param data the training instances. |
---|
| 821 | * @return an array (not necessarily ordered) of selected attribute indexes |
---|
| 822 | * @throws Exception if the search can't be completed |
---|
| 823 | */ |
---|
| 824 | public int[] search (ASEvaluation ASEval, Instances data) |
---|
| 825 | throws Exception { |
---|
| 826 | if (!(ASEval instanceof SubsetEvaluator)) { |
---|
| 827 | throw new Exception(ASEval.getClass().getName() |
---|
| 828 | + " is not a " |
---|
| 829 | + "Subset evaluator! (RaceSearch)"); |
---|
| 830 | } |
---|
| 831 | |
---|
| 832 | if (ASEval instanceof UnsupervisedSubsetEvaluator) { |
---|
| 833 | throw new Exception("Can't use an unsupervised subset evaluator " |
---|
| 834 | +"(RaceSearch)."); |
---|
| 835 | } |
---|
| 836 | |
---|
| 837 | if (!(ASEval instanceof HoldOutSubsetEvaluator)) { |
---|
| 838 | throw new Exception("Must use a HoldOutSubsetEvaluator, eg. " |
---|
| 839 | +"weka.attributeSelection.ClassifierSubsetEval " |
---|
| 840 | +"(RaceSearch)"); |
---|
| 841 | } |
---|
| 842 | |
---|
| 843 | if (!(ASEval instanceof ErrorBasedMeritEvaluator)) { |
---|
| 844 | throw new Exception("Only error based subset evaluators can be used, " |
---|
| 845 | +"eg. weka.attributeSelection.ClassifierSubsetEval " |
---|
| 846 | +"(RaceSearch)"); |
---|
| 847 | } |
---|
| 848 | |
---|
| 849 | m_Instances = new Instances(data); |
---|
| 850 | m_Instances.deleteWithMissingClass(); |
---|
| 851 | if (m_Instances.numInstances() == 0) { |
---|
| 852 | throw new Exception("All train instances have missing class! (RaceSearch)"); |
---|
| 853 | } |
---|
| 854 | if (m_rankingRequested && m_numToSelect > m_Instances.numAttributes()-1) { |
---|
| 855 | throw new Exception("More attributes requested than exist in the data " |
---|
| 856 | +"(RaceSearch)."); |
---|
| 857 | } |
---|
| 858 | m_theEvaluator = (HoldOutSubsetEvaluator)ASEval; |
---|
| 859 | m_numAttribs = m_Instances.numAttributes(); |
---|
| 860 | m_classIndex = m_Instances.classIndex(); |
---|
| 861 | |
---|
| 862 | if (m_rankingRequested) { |
---|
| 863 | m_rankedAtts = new double[m_numAttribs-1][2]; |
---|
| 864 | m_rankedSoFar = 0; |
---|
| 865 | } |
---|
| 866 | |
---|
| 867 | if (m_xvalType == LEAVE_ONE_OUT) { |
---|
| 868 | m_numFolds = m_Instances.numInstances(); |
---|
| 869 | } else { |
---|
| 870 | m_numFolds = 10; |
---|
| 871 | } |
---|
| 872 | |
---|
| 873 | Random random = new Random(1); // I guess this should really be a parameter? |
---|
| 874 | m_Instances.randomize(random); |
---|
| 875 | int [] bestSubset=null; |
---|
| 876 | |
---|
| 877 | switch (m_raceType) { |
---|
| 878 | case FORWARD_RACE: |
---|
| 879 | case BACKWARD_RACE: |
---|
| 880 | bestSubset = hillclimbRace(m_Instances, random); |
---|
| 881 | break; |
---|
| 882 | case SCHEMATA_RACE: |
---|
| 883 | bestSubset = schemataRace(m_Instances, random); |
---|
| 884 | break; |
---|
| 885 | case RANK_RACE: |
---|
| 886 | bestSubset = rankRace(m_Instances, random); |
---|
| 887 | break; |
---|
| 888 | } |
---|
| 889 | |
---|
| 890 | return bestSubset; |
---|
| 891 | } |
---|
| 892 | |
---|
| 893 | public double [][] rankedAttributes() throws Exception { |
---|
| 894 | if (!m_rankingRequested) { |
---|
| 895 | throw new Exception("Need to request a ranked list of attributes " |
---|
| 896 | +"before attributes can be ranked (RaceSearch)."); |
---|
| 897 | } |
---|
| 898 | if (m_rankedAtts == null) { |
---|
| 899 | throw new Exception("Search must be performed before attributes " |
---|
| 900 | +"can be ranked (RaceSearch)."); |
---|
| 901 | } |
---|
| 902 | |
---|
| 903 | double [][] final_rank = new double [m_rankedSoFar][2]; |
---|
| 904 | for (int i=0;i<m_rankedSoFar;i++) { |
---|
| 905 | final_rank[i][0] = m_rankedAtts[i][0]; |
---|
| 906 | final_rank[i][1] = m_rankedAtts[i][1]; |
---|
| 907 | } |
---|
| 908 | |
---|
| 909 | if (m_numToSelect <= 0) { |
---|
| 910 | if (m_threshold == -Double.MAX_VALUE) { |
---|
| 911 | m_calculatedNumToSelect = final_rank.length; |
---|
| 912 | } else { |
---|
| 913 | determineNumToSelectFromThreshold(final_rank); |
---|
| 914 | } |
---|
| 915 | } |
---|
| 916 | |
---|
| 917 | return final_rank; |
---|
| 918 | } |
---|
| 919 | |
---|
| 920 | private void determineNumToSelectFromThreshold(double [][] ranking) { |
---|
| 921 | int count = 0; |
---|
| 922 | for (int i = 0; i < ranking.length; i++) { |
---|
| 923 | if (ranking[i][1] > m_threshold) { |
---|
| 924 | count++; |
---|
| 925 | } |
---|
| 926 | } |
---|
| 927 | m_calculatedNumToSelect = count; |
---|
| 928 | } |
---|
| 929 | |
---|
| 930 | /** |
---|
| 931 | * Print an attribute set. |
---|
| 932 | */ |
---|
| 933 | private String printSets(char [][]raceSets) { |
---|
| 934 | StringBuffer temp = new StringBuffer(); |
---|
| 935 | for (int i=0;i<raceSets.length;i++) { |
---|
| 936 | for (int j=0;j<m_numAttribs;j++) { |
---|
| 937 | temp.append(raceSets[i][j]); |
---|
| 938 | } |
---|
| 939 | temp.append('\n'); |
---|
| 940 | } |
---|
| 941 | return temp.toString(); |
---|
| 942 | } |
---|
| 943 | |
---|
| 944 | /** |
---|
| 945 | * Performs a schemata race---a series of races in parallel. |
---|
| 946 | * @param data the instances to estimate accuracy over. |
---|
| 947 | * @param random a random number generator |
---|
| 948 | * @return an array of selected attribute indices. |
---|
| 949 | */ |
---|
| 950 | private int [] schemataRace(Instances data, Random random) throws Exception { |
---|
| 951 | // # races, 2 (competitors in each race), # attributes |
---|
| 952 | char [][][] parallelRaces; |
---|
| 953 | int numRaces = m_numAttribs-1; |
---|
| 954 | Random r = new Random(42); |
---|
| 955 | int numInstances = data.numInstances(); |
---|
| 956 | Instances trainCV; Instances testCV; |
---|
| 957 | Instance testInstance; |
---|
| 958 | |
---|
| 959 | // statistics on the racers |
---|
| 960 | Stats [][] raceStats = new Stats[numRaces][2]; |
---|
| 961 | |
---|
| 962 | parallelRaces = new char [numRaces][2][m_numAttribs-1]; |
---|
| 963 | char [] base = new char [m_numAttribs]; |
---|
| 964 | for (int i=0;i<m_numAttribs;i++) { |
---|
| 965 | base[i] = '*'; |
---|
| 966 | } |
---|
| 967 | |
---|
| 968 | int count=0; |
---|
| 969 | // set up initial races |
---|
| 970 | for (int i=0;i<m_numAttribs;i++) { |
---|
| 971 | if (i != m_classIndex) { |
---|
| 972 | parallelRaces[count][0] = (char [])base.clone(); |
---|
| 973 | parallelRaces[count][1] = (char [])base.clone(); |
---|
| 974 | parallelRaces[count][0][i] = '1'; |
---|
| 975 | parallelRaces[count++][1][i] = '0'; |
---|
| 976 | } |
---|
| 977 | } |
---|
| 978 | |
---|
| 979 | if (m_debug) { |
---|
| 980 | System.err.println("Initial sets:\n"); |
---|
| 981 | for (int i=0;i<numRaces;i++) { |
---|
| 982 | System.err.print(printSets(parallelRaces[i])+"--------------\n"); |
---|
| 983 | } |
---|
| 984 | } |
---|
| 985 | |
---|
| 986 | BitSet randomB = new BitSet(m_numAttribs); |
---|
| 987 | char [] randomBC = new char [m_numAttribs]; |
---|
| 988 | |
---|
| 989 | // notes which bit positions have been decided |
---|
| 990 | boolean [] attributeConstraints = new boolean[m_numAttribs]; |
---|
| 991 | double error; |
---|
| 992 | int evaluationCount = 0; |
---|
| 993 | raceSet: while (numRaces > 0) { |
---|
| 994 | boolean won = false; |
---|
| 995 | for (int i=0;i<numRaces;i++) { |
---|
| 996 | raceStats[i][0] = new Stats(); |
---|
| 997 | raceStats[i][1] = new Stats(); |
---|
| 998 | } |
---|
| 999 | |
---|
| 1000 | // keep an eye on how many test instances have been randomly sampled |
---|
| 1001 | int sampleCount = 0; |
---|
| 1002 | // run the current set of races |
---|
| 1003 | while (!won) { |
---|
| 1004 | // generate a random binary string |
---|
| 1005 | for (int i=0;i<m_numAttribs;i++) { |
---|
| 1006 | if (i != m_classIndex) { |
---|
| 1007 | if (!attributeConstraints[i]) { |
---|
| 1008 | if (r.nextDouble() < 0.5) { |
---|
| 1009 | randomB.set(i); |
---|
| 1010 | } else { |
---|
| 1011 | randomB.clear(i); |
---|
| 1012 | } |
---|
| 1013 | } else { // this position has been decided from previous races |
---|
| 1014 | if (base[i] == '1') { |
---|
| 1015 | randomB.set(i); |
---|
| 1016 | } else { |
---|
| 1017 | randomB.clear(i); |
---|
| 1018 | } |
---|
| 1019 | } |
---|
| 1020 | } |
---|
| 1021 | } |
---|
| 1022 | |
---|
| 1023 | // randomly select an instance to test on |
---|
| 1024 | int testIndex = Math.abs(r.nextInt() % numInstances); |
---|
| 1025 | |
---|
| 1026 | |
---|
| 1027 | // We want to randomize the data the same way for every |
---|
| 1028 | // learning scheme. |
---|
| 1029 | trainCV = data.trainCV(numInstances, testIndex, new Random (1)); |
---|
| 1030 | testCV = data.testCV(numInstances, testIndex); |
---|
| 1031 | testInstance = testCV.instance(0); |
---|
| 1032 | sampleCount++; |
---|
| 1033 | /* if (sampleCount > numInstances) { |
---|
| 1034 | throw new Exception("raceSchemata: No clear winner after sampling " |
---|
| 1035 | +sampleCount+" instances."); |
---|
| 1036 | } */ |
---|
| 1037 | |
---|
| 1038 | m_theEvaluator.buildEvaluator(trainCV); |
---|
| 1039 | |
---|
| 1040 | // the evaluator must retrain for every test point |
---|
| 1041 | error = -((HoldOutSubsetEvaluator)m_theEvaluator). |
---|
| 1042 | evaluateSubset(randomB, |
---|
| 1043 | testInstance, |
---|
| 1044 | true); |
---|
| 1045 | evaluationCount++; |
---|
| 1046 | |
---|
| 1047 | // see which racers match this random subset |
---|
| 1048 | for (int i=0;i<m_numAttribs;i++) { |
---|
| 1049 | if (randomB.get(i)) { |
---|
| 1050 | randomBC[i] = '1'; |
---|
| 1051 | } else { |
---|
| 1052 | randomBC[i] = '0'; |
---|
| 1053 | } |
---|
| 1054 | } |
---|
| 1055 | // System.err.println("Random subset: "+(new String(randomBC))); |
---|
| 1056 | |
---|
| 1057 | checkRaces: for (int i=0;i<numRaces;i++) { |
---|
| 1058 | // if a pair of racers has evaluated more than num instances |
---|
| 1059 | // then bail out---unlikely that having any more atts is any |
---|
| 1060 | // better than the current base set. |
---|
| 1061 | if (((raceStats[i][0].count + raceStats[i][1].count) / 2) > |
---|
| 1062 | (numInstances)) { |
---|
| 1063 | break raceSet; |
---|
| 1064 | } |
---|
| 1065 | for (int j=0;j<2;j++) { |
---|
| 1066 | boolean matched = true; |
---|
| 1067 | for (int k =0;k<m_numAttribs;k++) { |
---|
| 1068 | if (parallelRaces[i][j][k] != '*') { |
---|
| 1069 | if (parallelRaces[i][j][k] != randomBC[k]) { |
---|
| 1070 | matched = false; |
---|
| 1071 | break; |
---|
| 1072 | } |
---|
| 1073 | } |
---|
| 1074 | } |
---|
| 1075 | if (matched) { // update the stats for this racer |
---|
| 1076 | // System.err.println("Matched "+i+" "+j); |
---|
| 1077 | raceStats[i][j].add(error); |
---|
| 1078 | |
---|
| 1079 | // does this race have a clear winner, meaning we can |
---|
| 1080 | // terminate the whole set of parallel races? |
---|
| 1081 | if (raceStats[i][0].count > m_samples && |
---|
| 1082 | raceStats[i][1].count > m_samples) { |
---|
| 1083 | raceStats[i][0].calculateDerived(); |
---|
| 1084 | raceStats[i][1].calculateDerived(); |
---|
| 1085 | // System.err.println(j+" : "+(new String(parallelRaces[i][j]))); |
---|
| 1086 | // System.err.println(raceStats[i][0]); |
---|
| 1087 | // System.err.println(raceStats[i][1]); |
---|
| 1088 | // check the ttest |
---|
| 1089 | double prob = ttest(raceStats[i][0], raceStats[i][1]); |
---|
| 1090 | // System.err.println("Prob :"+prob); |
---|
| 1091 | if (prob < m_sigLevel) { // stop the races we have a winner! |
---|
| 1092 | if (raceStats[i][0].mean < raceStats[i][1].mean) { |
---|
| 1093 | base = (char [])parallelRaces[i][0].clone(); |
---|
| 1094 | m_bestMerit = raceStats[i][0].mean; |
---|
| 1095 | if (m_debug) { |
---|
| 1096 | System.err.println("contender 0 won "); |
---|
| 1097 | } |
---|
| 1098 | } else { |
---|
| 1099 | base = (char [])parallelRaces[i][1].clone(); |
---|
| 1100 | m_bestMerit = raceStats[i][1].mean; |
---|
| 1101 | if (m_debug) { |
---|
| 1102 | System.err.println("contender 1 won"); |
---|
| 1103 | } |
---|
| 1104 | } |
---|
| 1105 | if (m_debug) { |
---|
| 1106 | System.err.println((new String(parallelRaces[i][0])) |
---|
| 1107 | +" "+(new String(parallelRaces[i][1]))); |
---|
| 1108 | System.err.println("Means : "+raceStats[i][0].mean |
---|
| 1109 | +" vs"+raceStats[i][1].mean); |
---|
| 1110 | System.err.println("Evaluations so far : " |
---|
| 1111 | +evaluationCount); |
---|
| 1112 | } |
---|
| 1113 | won = true; |
---|
| 1114 | break checkRaces; |
---|
| 1115 | } |
---|
| 1116 | } |
---|
| 1117 | |
---|
| 1118 | } |
---|
| 1119 | } |
---|
| 1120 | } |
---|
| 1121 | } |
---|
| 1122 | |
---|
| 1123 | numRaces--; |
---|
| 1124 | // set up the next set of races if necessary |
---|
| 1125 | if (numRaces > 0 && won) { |
---|
| 1126 | parallelRaces = new char [numRaces][2][m_numAttribs-1]; |
---|
| 1127 | raceStats = new Stats[numRaces][2]; |
---|
| 1128 | // update the attribute constraints |
---|
| 1129 | for (int i=0;i<m_numAttribs;i++) { |
---|
| 1130 | if (i != m_classIndex && !attributeConstraints[i] && |
---|
| 1131 | base[i] != '*') { |
---|
| 1132 | attributeConstraints[i] = true; |
---|
| 1133 | break; |
---|
| 1134 | } |
---|
| 1135 | } |
---|
| 1136 | count=0; |
---|
| 1137 | for (int i=0;i<numRaces;i++) { |
---|
| 1138 | parallelRaces[i][0] = (char [])base.clone(); |
---|
| 1139 | parallelRaces[i][1] = (char [])base.clone(); |
---|
| 1140 | for (int j=count;j<m_numAttribs;j++) { |
---|
| 1141 | if (j != m_classIndex && parallelRaces[i][0][j] == '*') { |
---|
| 1142 | parallelRaces[i][0][j] = '1'; |
---|
| 1143 | parallelRaces[i][1][j] = '0'; |
---|
| 1144 | count = j+1; |
---|
| 1145 | break; |
---|
| 1146 | } |
---|
| 1147 | } |
---|
| 1148 | } |
---|
| 1149 | |
---|
| 1150 | if (m_debug) { |
---|
| 1151 | System.err.println("Next sets:\n"); |
---|
| 1152 | for (int i=0;i<numRaces;i++) { |
---|
| 1153 | System.err.print(printSets(parallelRaces[i])+"--------------\n"); |
---|
| 1154 | } |
---|
| 1155 | } |
---|
| 1156 | } |
---|
| 1157 | } |
---|
| 1158 | |
---|
| 1159 | if (m_debug) { |
---|
| 1160 | System.err.println("Total evaluations : " |
---|
| 1161 | +evaluationCount); |
---|
| 1162 | } |
---|
| 1163 | return attributeList(base); |
---|
| 1164 | } |
---|
| 1165 | |
---|
| 1166 | /** |
---|
| 1167 | * t-test for unequal sample sizes and same variance. Returns probability |
---|
| 1168 | * that observed difference in means is due to chance. |
---|
| 1169 | */ |
---|
| 1170 | private double ttest(Stats c1, Stats c2) throws Exception { |
---|
| 1171 | double n1 = c1.count; double n2 = c2.count; |
---|
| 1172 | double v1 = c1.stdDev * c1.stdDev; |
---|
| 1173 | double v2 = c2.stdDev * c2.stdDev; |
---|
| 1174 | double av1 = c1.mean; |
---|
| 1175 | double av2 = c2.mean; |
---|
| 1176 | |
---|
| 1177 | double df = n1 + n2 - 2; |
---|
| 1178 | double cv = (((n1 - 1) * v1) + ((n2 - 1) * v2)) /df; |
---|
| 1179 | double t = (av1 - av2) / Math.sqrt(cv * ((1.0 / n1) + (1.0 / n2))); |
---|
| 1180 | |
---|
| 1181 | return Statistics.incompleteBeta(df / 2.0, 0.5, |
---|
| 1182 | df / (df + (t * t))); |
---|
| 1183 | } |
---|
| 1184 | |
---|
| 1185 | /** |
---|
| 1186 | * Performs a rank race---race consisting of no attributes, the top |
---|
| 1187 | * ranked attribute, the top two attributes etc. The initial ranking |
---|
| 1188 | * is determined by an attribute evaluator. |
---|
| 1189 | * @param data the instances to estimate accuracy over |
---|
| 1190 | * @param random a random number generator |
---|
| 1191 | * @return an array of selected attribute indices. |
---|
| 1192 | */ |
---|
| 1193 | private int [] rankRace(Instances data, Random random) throws Exception { |
---|
| 1194 | char [] baseSet = new char [m_numAttribs]; |
---|
| 1195 | char [] bestSet; |
---|
| 1196 | double bestSetError; |
---|
| 1197 | for (int i=0;i<m_numAttribs;i++) { |
---|
| 1198 | if (i == m_classIndex) { |
---|
| 1199 | baseSet[i] = '-'; |
---|
| 1200 | } else { |
---|
| 1201 | baseSet[i] = '0'; |
---|
| 1202 | } |
---|
| 1203 | } |
---|
| 1204 | |
---|
| 1205 | int numCompetitors = m_numAttribs-1; |
---|
| 1206 | char [][] raceSets = new char [numCompetitors+1][m_numAttribs]; |
---|
| 1207 | |
---|
| 1208 | if (m_ASEval instanceof AttributeEvaluator) { |
---|
| 1209 | // generate the attribute ranking first |
---|
| 1210 | Ranker ranker = new Ranker(); |
---|
| 1211 | m_ASEval.buildEvaluator(data); |
---|
| 1212 | m_Ranking = ranker.search(m_ASEval,data); |
---|
| 1213 | } else { |
---|
| 1214 | GreedyStepwise fs = new GreedyStepwise(); |
---|
| 1215 | double [][]rankres; |
---|
| 1216 | fs.setGenerateRanking(true); |
---|
| 1217 | ((ASEvaluation)m_ASEval).buildEvaluator(data); |
---|
| 1218 | fs.search(m_ASEval, data); |
---|
| 1219 | rankres = fs.rankedAttributes(); |
---|
| 1220 | m_Ranking = new int[rankres.length]; |
---|
| 1221 | for (int i=0;i<rankres.length;i++) { |
---|
| 1222 | m_Ranking[i] = (int)rankres[i][0]; |
---|
| 1223 | } |
---|
| 1224 | } |
---|
| 1225 | |
---|
| 1226 | // set up the race |
---|
| 1227 | raceSets[0] = (char [])baseSet.clone(); |
---|
| 1228 | for (int i=0;i<m_Ranking.length;i++) { |
---|
| 1229 | raceSets[i+1] = (char [])raceSets[i].clone(); |
---|
| 1230 | raceSets[i+1][m_Ranking[i]] = '1'; |
---|
| 1231 | } |
---|
| 1232 | |
---|
| 1233 | if (m_debug) { |
---|
| 1234 | System.err.println("Initial sets:\n"+printSets(raceSets)); |
---|
| 1235 | } |
---|
| 1236 | |
---|
| 1237 | // run the race |
---|
| 1238 | double [] winnerInfo = raceSubsets(raceSets, data, true, random); |
---|
| 1239 | bestSetError = winnerInfo[1]; |
---|
| 1240 | bestSet = (char [])raceSets[(int)winnerInfo[0]].clone(); |
---|
| 1241 | m_bestMerit = bestSetError; |
---|
| 1242 | return attributeList(bestSet); |
---|
| 1243 | } |
---|
| 1244 | |
---|
| 1245 | /** |
---|
| 1246 | * Performs a hill climbing race---all single attribute changes to a |
---|
| 1247 | * base subset are raced in parallel. The winner is chosen and becomes |
---|
| 1248 | * the new base subset and the process is repeated until there is no |
---|
| 1249 | * improvement in error over the base subset. |
---|
| 1250 | * @param data the instances to estimate accuracy over |
---|
| 1251 | * @param random a random number generator |
---|
| 1252 | * @return an array of selected attribute indices. |
---|
| 1253 | * @throws Exception if something goes wrong |
---|
| 1254 | */ |
---|
| 1255 | private int [] hillclimbRace(Instances data, Random random) throws Exception { |
---|
| 1256 | double baseSetError; |
---|
| 1257 | char [] baseSet = new char [m_numAttribs]; |
---|
| 1258 | |
---|
| 1259 | for (int i=0;i<m_numAttribs;i++) { |
---|
| 1260 | if (i != m_classIndex) { |
---|
| 1261 | if (m_raceType == FORWARD_RACE) { |
---|
| 1262 | baseSet[i] = '0'; |
---|
| 1263 | } else { |
---|
| 1264 | baseSet[i] = '1'; |
---|
| 1265 | } |
---|
| 1266 | } else { |
---|
| 1267 | baseSet[i] = '-'; |
---|
| 1268 | } |
---|
| 1269 | } |
---|
| 1270 | |
---|
| 1271 | int numCompetitors = m_numAttribs-1; |
---|
| 1272 | char [][] raceSets = new char [numCompetitors+1][m_numAttribs]; |
---|
| 1273 | |
---|
| 1274 | raceSets[0] = (char [])baseSet.clone(); |
---|
| 1275 | int count = 1; |
---|
| 1276 | // initialize each race set to 1 attribute |
---|
| 1277 | for (int i=0;i<m_numAttribs;i++) { |
---|
| 1278 | if (i != m_classIndex) { |
---|
| 1279 | raceSets[count] = (char [])baseSet.clone(); |
---|
| 1280 | if (m_raceType == BACKWARD_RACE) { |
---|
| 1281 | raceSets[count++][i] = '0'; |
---|
| 1282 | } else { |
---|
| 1283 | raceSets[count++][i] = '1'; |
---|
| 1284 | } |
---|
| 1285 | } |
---|
| 1286 | } |
---|
| 1287 | |
---|
| 1288 | if (m_debug) { |
---|
| 1289 | System.err.println("Initial sets:\n"+printSets(raceSets)); |
---|
| 1290 | } |
---|
| 1291 | |
---|
| 1292 | // race the initial sets (base set either no or all features) |
---|
| 1293 | double [] winnerInfo = raceSubsets(raceSets, data, true, random); |
---|
| 1294 | baseSetError = winnerInfo[1]; |
---|
| 1295 | m_bestMerit = baseSetError; |
---|
| 1296 | baseSet = (char [])raceSets[(int)winnerInfo[0]].clone(); |
---|
| 1297 | if (m_rankingRequested) { |
---|
| 1298 | m_rankedAtts[m_rankedSoFar][0] = (int)(winnerInfo[0]-1); |
---|
| 1299 | m_rankedAtts[m_rankedSoFar][1] = winnerInfo[1]; |
---|
| 1300 | m_rankedSoFar++; |
---|
| 1301 | } |
---|
| 1302 | |
---|
| 1303 | boolean improved = true; |
---|
| 1304 | int j; |
---|
| 1305 | // now race until there is no improvement over the base set or only |
---|
| 1306 | // one competitor remains |
---|
| 1307 | while (improved) { |
---|
| 1308 | // generate the next set of competitors |
---|
| 1309 | numCompetitors--; |
---|
| 1310 | if (numCompetitors == 0) { //race finished! |
---|
| 1311 | break; |
---|
| 1312 | } |
---|
| 1313 | j=0; |
---|
| 1314 | // +1. we'll race against the base set---might be able to bail out |
---|
| 1315 | // of the race if none from the new set are statistically better |
---|
| 1316 | // than the base set. Base set is stored in loc 0. |
---|
| 1317 | raceSets = new char [numCompetitors+1][m_numAttribs]; |
---|
| 1318 | for (int i=0;i<numCompetitors+1;i++) { |
---|
| 1319 | raceSets[i] = (char [])baseSet.clone(); |
---|
| 1320 | if (i > 0) { |
---|
| 1321 | for (int k=j;k<m_numAttribs;k++) { |
---|
| 1322 | if (m_raceType == 1) { |
---|
| 1323 | if (k != m_classIndex && raceSets[i][k] != '0') { |
---|
| 1324 | raceSets[i][k] = '0'; |
---|
| 1325 | j = k+1; |
---|
| 1326 | break; |
---|
| 1327 | } |
---|
| 1328 | } else { |
---|
| 1329 | if (k != m_classIndex && raceSets[i][k] != '1') { |
---|
| 1330 | raceSets[i][k] = '1'; |
---|
| 1331 | j = k+1; |
---|
| 1332 | break; |
---|
| 1333 | } |
---|
| 1334 | } |
---|
| 1335 | } |
---|
| 1336 | } |
---|
| 1337 | } |
---|
| 1338 | |
---|
| 1339 | if (m_debug) { |
---|
| 1340 | System.err.println("Next set : \n"+printSets(raceSets)); |
---|
| 1341 | } |
---|
| 1342 | improved = false; |
---|
| 1343 | winnerInfo = raceSubsets(raceSets, data, true, random); |
---|
| 1344 | String bs = new String(baseSet); |
---|
| 1345 | String win = new String(raceSets[(int)winnerInfo[0]]); |
---|
| 1346 | if (bs.compareTo(win) == 0) { |
---|
| 1347 | // race finished |
---|
| 1348 | } else { |
---|
| 1349 | if (winnerInfo[1] < baseSetError || m_rankingRequested) { |
---|
| 1350 | improved = true; |
---|
| 1351 | baseSetError = winnerInfo[1]; |
---|
| 1352 | m_bestMerit = baseSetError; |
---|
| 1353 | // find which att is different |
---|
| 1354 | if (m_rankingRequested) { |
---|
| 1355 | for (int i = 0; i < baseSet.length; i++) { |
---|
| 1356 | if (win.charAt(i) != bs.charAt(i)) { |
---|
| 1357 | m_rankedAtts[m_rankedSoFar][0] = i; |
---|
| 1358 | m_rankedAtts[m_rankedSoFar][1] = winnerInfo[1]; |
---|
| 1359 | m_rankedSoFar++; |
---|
| 1360 | } |
---|
| 1361 | } |
---|
| 1362 | } |
---|
| 1363 | baseSet = (char [])raceSets[(int)winnerInfo[0]].clone(); |
---|
| 1364 | } else { |
---|
| 1365 | // Will get here for a subset whose error is outside the delta |
---|
| 1366 | // threshold but is not *significantly* worse than the base |
---|
| 1367 | // subset |
---|
| 1368 | //throw new Exception("RaceSearch: problem in hillClimbRace"); |
---|
| 1369 | } |
---|
| 1370 | } |
---|
| 1371 | } |
---|
| 1372 | return attributeList(baseSet); |
---|
| 1373 | } |
---|
| 1374 | |
---|
| 1375 | /** |
---|
| 1376 | * Convert an attribute set to an array of indices |
---|
| 1377 | */ |
---|
| 1378 | private int [] attributeList(char [] list) { |
---|
| 1379 | int count = 0; |
---|
| 1380 | |
---|
| 1381 | for (int i=0;i<m_numAttribs;i++) { |
---|
| 1382 | if (list[i] == '1') { |
---|
| 1383 | count++; |
---|
| 1384 | } |
---|
| 1385 | } |
---|
| 1386 | |
---|
| 1387 | int [] rlist = new int[count]; |
---|
| 1388 | count = 0; |
---|
| 1389 | for (int i=0;i<m_numAttribs;i++) { |
---|
| 1390 | if (list[i] == '1') { |
---|
| 1391 | rlist[count++] = i; |
---|
| 1392 | } |
---|
| 1393 | } |
---|
| 1394 | |
---|
| 1395 | return rlist; |
---|
| 1396 | } |
---|
| 1397 | |
---|
| 1398 | /** |
---|
| 1399 | * Races the leave-one-out cross validation errors of a set of |
---|
| 1400 | * attribute subsets on a set of instances. |
---|
| 1401 | * @param raceSets a set of attribute subset specifications |
---|
| 1402 | * @param data the instances to use when cross validating |
---|
| 1403 | * @param baseSetIncluded true if the first attribute set is a |
---|
| 1404 | * base set generated from the previous race |
---|
| 1405 | * @param random a random number generator |
---|
| 1406 | * @return the index of the winning subset |
---|
| 1407 | * @throws Exception if an error occurs during cross validation |
---|
| 1408 | */ |
---|
| 1409 | private double [] raceSubsets(char [][]raceSets, Instances data, |
---|
| 1410 | boolean baseSetIncluded, Random random) |
---|
| 1411 | throws Exception { |
---|
| 1412 | // the evaluators --- one for each subset |
---|
| 1413 | ASEvaluation [] evaluators = |
---|
| 1414 | ASEvaluation.makeCopies(m_theEvaluator, raceSets.length); |
---|
| 1415 | |
---|
| 1416 | // array of subsets eliminated from the race |
---|
| 1417 | boolean [] eliminated = new boolean [raceSets.length]; |
---|
| 1418 | |
---|
| 1419 | // individual statistics |
---|
| 1420 | Stats [] individualStats = new Stats [raceSets.length]; |
---|
| 1421 | |
---|
| 1422 | // pairwise statistics |
---|
| 1423 | PairedStats [][] testers = |
---|
| 1424 | new PairedStats[raceSets.length][raceSets.length]; |
---|
| 1425 | |
---|
| 1426 | /** do we ignore the base set or not? */ |
---|
| 1427 | int startPt = m_rankingRequested ? 1 : 0; |
---|
| 1428 | |
---|
| 1429 | for (int i=0;i<raceSets.length;i++) { |
---|
| 1430 | individualStats[i] = new Stats(); |
---|
| 1431 | for (int j=i+1;j<raceSets.length;j++) { |
---|
| 1432 | testers[i][j] = new PairedStats(m_sigLevel); |
---|
| 1433 | } |
---|
| 1434 | } |
---|
| 1435 | |
---|
| 1436 | BitSet [] raceBitSets = new BitSet[raceSets.length]; |
---|
| 1437 | for (int i=0;i<raceSets.length;i++) { |
---|
| 1438 | raceBitSets[i] = new BitSet(m_numAttribs); |
---|
| 1439 | for (int j=0;j<m_numAttribs;j++) { |
---|
| 1440 | if (raceSets[i][j] == '1') { |
---|
| 1441 | raceBitSets[i].set(j); |
---|
| 1442 | } |
---|
| 1443 | } |
---|
| 1444 | } |
---|
| 1445 | |
---|
| 1446 | // now loop over the data points collecting leave-one-out errors for |
---|
| 1447 | // each attribute set |
---|
| 1448 | Instances trainCV; |
---|
| 1449 | Instances testCV; |
---|
| 1450 | Instance testInst; |
---|
| 1451 | double [] errors = new double [raceSets.length]; |
---|
| 1452 | int eliminatedCount = 0; |
---|
| 1453 | int processedCount = 0; |
---|
| 1454 | // if there is one set left in the race then we need to continue to |
---|
| 1455 | // evaluate it for the remaining instances in order to get an |
---|
| 1456 | // accurate error estimate |
---|
| 1457 | processedCount = 0; |
---|
| 1458 | race: for (int i=0;i<m_numFolds;i++) { |
---|
| 1459 | |
---|
| 1460 | // We want to randomize the data the same way for every |
---|
| 1461 | // learning scheme. |
---|
| 1462 | trainCV = data.trainCV(m_numFolds, i, new Random (1)); |
---|
| 1463 | testCV = data.testCV(m_numFolds, i); |
---|
| 1464 | |
---|
| 1465 | // loop over the surviving attribute sets building classifiers for this |
---|
| 1466 | // training set |
---|
| 1467 | for (int j=startPt;j<raceSets.length;j++) { |
---|
| 1468 | if (!eliminated[j]) { |
---|
| 1469 | evaluators[j].buildEvaluator(trainCV); |
---|
| 1470 | } |
---|
| 1471 | } |
---|
| 1472 | |
---|
| 1473 | for (int z=0;z<testCV.numInstances();z++) { |
---|
| 1474 | testInst = testCV.instance(z); |
---|
| 1475 | processedCount++; |
---|
| 1476 | |
---|
| 1477 | // loop over surviving attribute sets computing errors for this |
---|
| 1478 | // test point |
---|
| 1479 | for (int zz=startPt;zz<raceSets.length;zz++) { |
---|
| 1480 | if (!eliminated[zz]) { |
---|
| 1481 | if (z == 0) {// first test instance---make sure classifier is built |
---|
| 1482 | errors[zz] = -((HoldOutSubsetEvaluator)evaluators[zz]). |
---|
| 1483 | evaluateSubset(raceBitSets[zz], |
---|
| 1484 | testInst, |
---|
| 1485 | true); |
---|
| 1486 | } else { // must be k fold rather than leave one out |
---|
| 1487 | errors[zz] = -((HoldOutSubsetEvaluator)evaluators[zz]). |
---|
| 1488 | evaluateSubset(raceBitSets[zz], |
---|
| 1489 | testInst, |
---|
| 1490 | false); |
---|
| 1491 | } |
---|
| 1492 | } |
---|
| 1493 | } |
---|
| 1494 | |
---|
| 1495 | // now update the stats |
---|
| 1496 | for (int j=startPt;j<raceSets.length;j++) { |
---|
| 1497 | if (!eliminated[j]) { |
---|
| 1498 | individualStats[j].add(errors[j]); |
---|
| 1499 | for (int k=j+1;k<raceSets.length;k++) { |
---|
| 1500 | if (!eliminated[k]) { |
---|
| 1501 | testers[j][k].add(errors[j], errors[k]); |
---|
| 1502 | } |
---|
| 1503 | } |
---|
| 1504 | } |
---|
| 1505 | } |
---|
| 1506 | |
---|
| 1507 | // test for near identical models and models that are significantly |
---|
| 1508 | // worse than some other model |
---|
| 1509 | if (processedCount > m_samples-1 && |
---|
| 1510 | (eliminatedCount < raceSets.length-1)) { |
---|
| 1511 | for (int j=0;j<raceSets.length;j++) { |
---|
| 1512 | if (!eliminated[j]) { |
---|
| 1513 | for (int k=j+1;k<raceSets.length;k++) { |
---|
| 1514 | if (!eliminated[k]) { |
---|
| 1515 | testers[j][k].calculateDerived(); |
---|
| 1516 | // near identical ? |
---|
| 1517 | if ((testers[j][k].differencesSignificance == 0) && |
---|
| 1518 | (Utils.eq(testers[j][k].differencesStats.mean, 0.0) || |
---|
| 1519 | (Utils.gr(m_delta, Math.abs(testers[j][k]. |
---|
| 1520 | differencesStats.mean))))) { |
---|
| 1521 | // if they're exactly the same and there is a base set |
---|
| 1522 | // in this race, make sure that the base set is NOT the |
---|
| 1523 | // one eliminated. |
---|
| 1524 | if (Utils.eq(testers[j][k].differencesStats.mean, 0.0)) { |
---|
| 1525 | |
---|
| 1526 | if (baseSetIncluded) { |
---|
| 1527 | if (j != 0) { |
---|
| 1528 | eliminated[j] = true; |
---|
| 1529 | } else { |
---|
| 1530 | eliminated[k] = true; |
---|
| 1531 | } |
---|
| 1532 | eliminatedCount++; |
---|
| 1533 | } else { |
---|
| 1534 | eliminated[j] = true; |
---|
| 1535 | } |
---|
| 1536 | if (m_debug) { |
---|
| 1537 | System.err.println("Eliminating (identical) " |
---|
| 1538 | +j+" "+raceBitSets[j].toString() |
---|
| 1539 | +" vs "+k+" " |
---|
| 1540 | +raceBitSets[k].toString() |
---|
| 1541 | +" after " |
---|
| 1542 | +processedCount |
---|
| 1543 | +" evaluations\n" |
---|
| 1544 | +"\nerror "+j+" : " |
---|
| 1545 | +testers[j][k].xStats.mean |
---|
| 1546 | +" vs "+k+" : " |
---|
| 1547 | +testers[j][k].yStats.mean |
---|
| 1548 | +" diff : " |
---|
| 1549 | +testers[j][k].differencesStats |
---|
| 1550 | .mean); |
---|
| 1551 | } |
---|
| 1552 | } else { |
---|
| 1553 | // eliminate the one with the higer error |
---|
| 1554 | if (testers[j][k].xStats.mean > |
---|
| 1555 | testers[j][k].yStats.mean) { |
---|
| 1556 | eliminated[j] = true; |
---|
| 1557 | eliminatedCount++; |
---|
| 1558 | if (m_debug) { |
---|
| 1559 | System.err.println("Eliminating (near identical) " |
---|
| 1560 | +j+" "+raceBitSets[j].toString() |
---|
| 1561 | +" vs "+k+" " |
---|
| 1562 | +raceBitSets[k].toString() |
---|
| 1563 | +" after " |
---|
| 1564 | +processedCount |
---|
| 1565 | +" evaluations\n" |
---|
| 1566 | +"\nerror "+j+" : " |
---|
| 1567 | +testers[j][k].xStats.mean |
---|
| 1568 | +" vs "+k+" : " |
---|
| 1569 | +testers[j][k].yStats.mean |
---|
| 1570 | +" diff : " |
---|
| 1571 | +testers[j][k].differencesStats |
---|
| 1572 | .mean); |
---|
| 1573 | } |
---|
| 1574 | break; |
---|
| 1575 | } else { |
---|
| 1576 | eliminated[k] = true; |
---|
| 1577 | eliminatedCount++; |
---|
| 1578 | if (m_debug) { |
---|
| 1579 | System.err.println("Eliminating (near identical) " |
---|
| 1580 | +k+" "+raceBitSets[k].toString() |
---|
| 1581 | +" vs "+j+" " |
---|
| 1582 | +raceBitSets[j].toString() |
---|
| 1583 | +" after " |
---|
| 1584 | +processedCount |
---|
| 1585 | +" evaluations\n" |
---|
| 1586 | +"\nerror "+k+" : " |
---|
| 1587 | +testers[j][k].yStats.mean |
---|
| 1588 | +" vs "+j+" : " |
---|
| 1589 | +testers[j][k].xStats.mean |
---|
| 1590 | +" diff : " |
---|
| 1591 | +testers[j][k].differencesStats |
---|
| 1592 | .mean); |
---|
| 1593 | } |
---|
| 1594 | } |
---|
| 1595 | } |
---|
| 1596 | } else { |
---|
| 1597 | // significantly worse ? |
---|
| 1598 | if (testers[j][k].differencesSignificance != 0) { |
---|
| 1599 | if (testers[j][k].differencesSignificance > 0) { |
---|
| 1600 | eliminated[j] = true; |
---|
| 1601 | eliminatedCount++; |
---|
| 1602 | if (m_debug) { |
---|
| 1603 | System.err.println("Eliminating (-worse) " |
---|
| 1604 | +j+" "+raceBitSets[j].toString() |
---|
| 1605 | +" vs "+k+" " |
---|
| 1606 | +raceBitSets[k].toString() |
---|
| 1607 | +" after " |
---|
| 1608 | +processedCount |
---|
| 1609 | +" evaluations" |
---|
| 1610 | +"\nerror "+j+" : " |
---|
| 1611 | +testers[j][k].xStats.mean |
---|
| 1612 | +" vs "+k+" : " |
---|
| 1613 | +testers[j][k].yStats.mean); |
---|
| 1614 | } |
---|
| 1615 | break; |
---|
| 1616 | } else { |
---|
| 1617 | eliminated[k] = true; |
---|
| 1618 | eliminatedCount++; |
---|
| 1619 | if (m_debug) { |
---|
| 1620 | System.err.println("Eliminating (worse) " |
---|
| 1621 | +k+" "+raceBitSets[k].toString() |
---|
| 1622 | +" vs "+j+" " |
---|
| 1623 | +raceBitSets[j].toString() |
---|
| 1624 | +" after " |
---|
| 1625 | +processedCount |
---|
| 1626 | +" evaluations" |
---|
| 1627 | +"\nerror "+k+" : " |
---|
| 1628 | +testers[j][k].yStats.mean |
---|
| 1629 | +" vs "+j+" : " |
---|
| 1630 | +testers[j][k].xStats.mean); |
---|
| 1631 | } |
---|
| 1632 | } |
---|
| 1633 | } |
---|
| 1634 | } |
---|
| 1635 | } |
---|
| 1636 | } |
---|
| 1637 | } |
---|
| 1638 | } |
---|
| 1639 | } |
---|
| 1640 | // if there is a base set from the previous race and it's the |
---|
| 1641 | // only remaining subset then terminate the race. |
---|
| 1642 | if (eliminatedCount == raceSets.length-1 && baseSetIncluded && |
---|
| 1643 | !eliminated[0] && !m_rankingRequested) { |
---|
| 1644 | break race; |
---|
| 1645 | } |
---|
| 1646 | } |
---|
| 1647 | } |
---|
| 1648 | |
---|
| 1649 | if (m_debug) { |
---|
| 1650 | System.err.println("*****eliminated count: "+eliminatedCount); |
---|
| 1651 | } |
---|
| 1652 | double bestError = Double.MAX_VALUE; |
---|
| 1653 | int bestIndex=0; |
---|
| 1654 | // return the index of the winner |
---|
| 1655 | for (int i=startPt;i<raceSets.length;i++) { |
---|
| 1656 | if (!eliminated[i]) { |
---|
| 1657 | individualStats[i].calculateDerived(); |
---|
| 1658 | if (m_debug) { |
---|
| 1659 | System.err.println("Remaining error: "+raceBitSets[i].toString() |
---|
| 1660 | +" "+individualStats[i].mean); |
---|
| 1661 | } |
---|
| 1662 | if (individualStats[i].mean < bestError) { |
---|
| 1663 | bestError = individualStats[i].mean; |
---|
| 1664 | bestIndex = i; |
---|
| 1665 | } |
---|
| 1666 | } |
---|
| 1667 | } |
---|
| 1668 | |
---|
| 1669 | double [] retInfo = new double[2]; |
---|
| 1670 | retInfo[0] = bestIndex; |
---|
| 1671 | retInfo[1] = bestError; |
---|
| 1672 | |
---|
| 1673 | if (m_debug) { |
---|
| 1674 | System.err.print("Best set from race : "); |
---|
| 1675 | |
---|
| 1676 | for (int i=0;i<m_numAttribs;i++) { |
---|
| 1677 | if (raceSets[bestIndex][i] == '1') { |
---|
| 1678 | System.err.print('1'); |
---|
| 1679 | } else { |
---|
| 1680 | System.err.print('0'); |
---|
| 1681 | } |
---|
| 1682 | } |
---|
| 1683 | System.err.println(" :"+bestError+" Processed : "+(processedCount) |
---|
| 1684 | +"\n"+individualStats[bestIndex].toString()); |
---|
| 1685 | } |
---|
| 1686 | return retInfo; |
---|
| 1687 | } |
---|
| 1688 | |
---|
| 1689 | /** |
---|
| 1690 | * Returns a string represenation |
---|
| 1691 | * |
---|
| 1692 | * @return a string representation |
---|
| 1693 | */ |
---|
| 1694 | public String toString() { |
---|
| 1695 | StringBuffer text = new StringBuffer(); |
---|
| 1696 | |
---|
| 1697 | text.append("\tRaceSearch.\n\tRace type : "); |
---|
| 1698 | switch (m_raceType) { |
---|
| 1699 | case FORWARD_RACE: |
---|
| 1700 | text.append("forward selection race\n\tBase set : no attributes"); |
---|
| 1701 | break; |
---|
| 1702 | case BACKWARD_RACE: |
---|
| 1703 | text.append("backward elimination race\n\tBase set : all attributes"); |
---|
| 1704 | break; |
---|
| 1705 | case SCHEMATA_RACE: |
---|
| 1706 | text.append("schemata race\n\tBase set : no attributes"); |
---|
| 1707 | break; |
---|
| 1708 | case RANK_RACE: |
---|
| 1709 | text.append("rank race\n\tBase set : no attributes\n\t"); |
---|
| 1710 | text.append("Attribute evaluator : " |
---|
| 1711 | + getAttributeEvaluator().getClass().getName() +" "); |
---|
| 1712 | if (m_ASEval instanceof OptionHandler) { |
---|
| 1713 | String[] evaluatorOptions = new String[0]; |
---|
| 1714 | evaluatorOptions = ((OptionHandler)m_ASEval).getOptions(); |
---|
| 1715 | for (int i=0;i<evaluatorOptions.length;i++) { |
---|
| 1716 | text.append(evaluatorOptions[i]+' '); |
---|
| 1717 | } |
---|
| 1718 | } |
---|
| 1719 | text.append("\n"); |
---|
| 1720 | text.append("\tAttribute ranking : \n"); |
---|
| 1721 | int rlength = (int)(Math.log(m_Ranking.length) / Math.log(10) + 1); |
---|
| 1722 | for (int i=0;i<m_Ranking.length;i++) { |
---|
| 1723 | text.append("\t "+Utils.doubleToString((double)(m_Ranking[i]+1), |
---|
| 1724 | rlength,0) |
---|
| 1725 | +" "+m_Instances.attribute(m_Ranking[i]).name()+'\n'); |
---|
| 1726 | } |
---|
| 1727 | break; |
---|
| 1728 | } |
---|
| 1729 | text.append("\n\tCross validation mode : "); |
---|
| 1730 | if (m_xvalType == TEN_FOLD) { |
---|
| 1731 | text.append("10 fold"); |
---|
| 1732 | } else { |
---|
| 1733 | text.append("Leave-one-out"); |
---|
| 1734 | } |
---|
| 1735 | |
---|
| 1736 | text.append("\n\tMerit of best subset found : "); |
---|
| 1737 | int fieldwidth = 3; |
---|
| 1738 | double precision = (m_bestMerit - (int)m_bestMerit); |
---|
| 1739 | if (Math.abs(m_bestMerit) > 0) { |
---|
| 1740 | fieldwidth = (int)Math.abs((Math.log(Math.abs(m_bestMerit)) / |
---|
| 1741 | Math.log(10)))+2; |
---|
| 1742 | } |
---|
| 1743 | if (Math.abs(precision) > 0) { |
---|
| 1744 | precision = Math.abs((Math.log(Math.abs(precision)) / Math.log(10)))+3; |
---|
| 1745 | } else { |
---|
| 1746 | precision = 2; |
---|
| 1747 | } |
---|
| 1748 | |
---|
| 1749 | text.append(Utils.doubleToString(Math.abs(m_bestMerit), |
---|
| 1750 | fieldwidth+(int)precision, |
---|
| 1751 | (int)precision)+"\n"); |
---|
| 1752 | return text.toString(); |
---|
| 1753 | |
---|
| 1754 | } |
---|
| 1755 | |
---|
| 1756 | /** |
---|
| 1757 | * Reset the search method. |
---|
| 1758 | */ |
---|
| 1759 | protected void resetOptions () { |
---|
| 1760 | m_sigLevel = 0.001; |
---|
| 1761 | m_delta = 0.001; |
---|
| 1762 | m_ASEval = new GainRatioAttributeEval(); |
---|
| 1763 | m_Ranking = null; |
---|
| 1764 | m_raceType = FORWARD_RACE; |
---|
| 1765 | m_debug = false; |
---|
| 1766 | m_theEvaluator = null; |
---|
| 1767 | m_bestMerit = -Double.MAX_VALUE; |
---|
| 1768 | m_numFolds = 10; |
---|
| 1769 | } |
---|
| 1770 | |
---|
| 1771 | /** |
---|
| 1772 | * Returns the revision string. |
---|
| 1773 | * |
---|
| 1774 | * @return the revision |
---|
| 1775 | */ |
---|
| 1776 | public String getRevision() { |
---|
| 1777 | return RevisionUtils.extract("$Revision: 1.26 $"); |
---|
| 1778 | } |
---|
| 1779 | } |
---|