[29] | 1 | /* |
---|
| 2 | * This program is free software; you can redistribute it and/or modify |
---|
| 3 | * it under the terms of the GNU General Public License as published by |
---|
| 4 | * the Free Software Foundation; either version 2 of the License, or |
---|
| 5 | * (at your option) any later version. |
---|
| 6 | * |
---|
| 7 | * This program is distributed in the hope that it will be useful, |
---|
| 8 | * but WITHOUT ANY WARRANTY; without even the implied warranty of |
---|
| 9 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the |
---|
| 10 | * GNU General Public License for more details. |
---|
| 11 | * |
---|
| 12 | * You should have received a copy of the GNU General Public License |
---|
| 13 | * along with this program; if not, write to the Free Software |
---|
| 14 | * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA. |
---|
| 15 | */ |
---|
| 16 | |
---|
| 17 | /* |
---|
| 18 | * MIDD.java |
---|
| 19 | * Copyright (C) 2005 University of Waikato, Hamilton, New Zealand |
---|
| 20 | * |
---|
| 21 | */ |
---|
| 22 | |
---|
| 23 | package weka.classifiers.mi; |
---|
| 24 | |
---|
| 25 | import weka.classifiers.Classifier; |
---|
| 26 | import weka.classifiers.AbstractClassifier; |
---|
| 27 | import weka.core.Capabilities; |
---|
| 28 | import weka.core.FastVector; |
---|
| 29 | import weka.core.Instance; |
---|
| 30 | import weka.core.Instances; |
---|
| 31 | import weka.core.MultiInstanceCapabilitiesHandler; |
---|
| 32 | import weka.core.Optimization; |
---|
| 33 | import weka.core.Option; |
---|
| 34 | import weka.core.OptionHandler; |
---|
| 35 | import weka.core.RevisionUtils; |
---|
| 36 | import weka.core.SelectedTag; |
---|
| 37 | import weka.core.Tag; |
---|
| 38 | import weka.core.TechnicalInformation; |
---|
| 39 | import weka.core.TechnicalInformationHandler; |
---|
| 40 | import weka.core.Utils; |
---|
| 41 | import weka.core.Capabilities.Capability; |
---|
| 42 | import weka.core.TechnicalInformation.Field; |
---|
| 43 | import weka.core.TechnicalInformation.Type; |
---|
| 44 | import weka.filters.Filter; |
---|
| 45 | import weka.filters.unsupervised.attribute.Normalize; |
---|
| 46 | import weka.filters.unsupervised.attribute.ReplaceMissingValues; |
---|
| 47 | import weka.filters.unsupervised.attribute.Standardize; |
---|
| 48 | |
---|
| 49 | import java.util.Enumeration; |
---|
| 50 | import java.util.Vector; |
---|
| 51 | |
---|
| 52 | /** |
---|
| 53 | <!-- globalinfo-start --> |
---|
| 54 | * Re-implement the Diverse Density algorithm, changes the testing procedure.<br/> |
---|
| 55 | * <br/> |
---|
| 56 | * Oded Maron (1998). Learning from ambiguity.<br/> |
---|
| 57 | * <br/> |
---|
| 58 | * O. Maron, T. Lozano-Perez (1998). A Framework for Multiple Instance Learning. Neural Information Processing Systems. 10. |
---|
| 59 | * <p/> |
---|
| 60 | <!-- globalinfo-end --> |
---|
| 61 | * |
---|
| 62 | <!-- technical-bibtex-start --> |
---|
| 63 | * BibTeX: |
---|
| 64 | * <pre> |
---|
| 65 | * @phdthesis{Maron1998, |
---|
| 66 | * author = {Oded Maron}, |
---|
| 67 | * school = {Massachusetts Institute of Technology}, |
---|
| 68 | * title = {Learning from ambiguity}, |
---|
| 69 | * year = {1998} |
---|
| 70 | * } |
---|
| 71 | * |
---|
| 72 | * @article{Maron1998, |
---|
| 73 | * author = {O. Maron and T. Lozano-Perez}, |
---|
| 74 | * journal = {Neural Information Processing Systems}, |
---|
| 75 | * title = {A Framework for Multiple Instance Learning}, |
---|
| 76 | * volume = {10}, |
---|
| 77 | * year = {1998} |
---|
| 78 | * } |
---|
| 79 | * </pre> |
---|
| 80 | * <p/> |
---|
| 81 | <!-- technical-bibtex-end --> |
---|
| 82 | * |
---|
| 83 | <!-- options-start --> |
---|
| 84 | * Valid options are: <p/> |
---|
| 85 | * |
---|
| 86 | * <pre> -D |
---|
| 87 | * Turn on debugging output.</pre> |
---|
| 88 | * |
---|
| 89 | * <pre> -N <num> |
---|
| 90 | * Whether to 0=normalize/1=standardize/2=neither. |
---|
| 91 | * (default 1=standardize)</pre> |
---|
| 92 | * |
---|
| 93 | <!-- options-end --> |
---|
| 94 | * |
---|
| 95 | * @author Eibe Frank (eibe@cs.waikato.ac.nz) |
---|
| 96 | * @author Xin Xu (xx5@cs.waikato.ac.nz) |
---|
| 97 | * @version $Revision: 5928 $ |
---|
| 98 | */ |
---|
| 99 | public class MIDD |
---|
| 100 | extends AbstractClassifier |
---|
| 101 | implements OptionHandler, MultiInstanceCapabilitiesHandler, |
---|
| 102 | TechnicalInformationHandler { |
---|
| 103 | |
---|
  /** for serialization */
  static final long serialVersionUID = 4263507733600536168L;

  /** The index of the class attribute */
  protected int m_ClassIndex;

  /** The fitted diverse-density parameters, two entries per attribute:
   *  m_Par[2*k] is the target-point coordinate for attribute k and
   *  m_Par[2*k+1] its scale weight (see OptEng and distributionForInstance). */
  protected double[] m_Par;

  /** The number of the class labels */
  protected int m_NumClasses;

  /** Class labels for each bag */
  protected int[] m_Classes;

  /** MI data, indexed as [bag][attribute][instanceWithinBag]
   *  (filled in buildClassifier after filtering) */
  protected double[][][] m_Data;

  /** Header-only copy of the relational (instance-level) attributes,
   *  used for attribute names in toString() */
  protected Instances m_Attributes;

  /** The filter used to standardize/normalize all values. */
  protected Filter m_Filter = null;

  /** Whether to normalize/standardize/neither, default:standardize */
  protected int m_filterType = FILTER_STANDARDIZE;

  /** Normalize training data */
  public static final int FILTER_NORMALIZE = 0;
  /** Standardize training data */
  public static final int FILTER_STANDARDIZE = 1;
  /** No normalization/standardization */
  public static final int FILTER_NONE = 2;
  /** The filter to apply to the training data */
  public static final Tag [] TAGS_FILTER = {
    new Tag(FILTER_NORMALIZE, "Normalize training data"),
    new Tag(FILTER_STANDARDIZE, "Standardize training data"),
    new Tag(FILTER_NONE, "No normalization/standardization"),
  };

  /** The filter used to get rid of missing values. */
  protected ReplaceMissingValues m_Missing = new ReplaceMissingValues();
---|
| 146 | /** |
---|
| 147 | * Returns a string describing this filter |
---|
| 148 | * |
---|
| 149 | * @return a description of the filter suitable for |
---|
| 150 | * displaying in the explorer/experimenter gui |
---|
| 151 | */ |
---|
| 152 | public String globalInfo() { |
---|
| 153 | return |
---|
| 154 | "Re-implement the Diverse Density algorithm, changes the testing " |
---|
| 155 | + "procedure.\n\n" |
---|
| 156 | + getTechnicalInformation().toString(); |
---|
| 157 | } |
---|
| 158 | |
---|
| 159 | /** |
---|
| 160 | * Returns an instance of a TechnicalInformation object, containing |
---|
| 161 | * detailed information about the technical background of this class, |
---|
| 162 | * e.g., paper reference or book this class is based on. |
---|
| 163 | * |
---|
| 164 | * @return the technical information about this class |
---|
| 165 | */ |
---|
| 166 | public TechnicalInformation getTechnicalInformation() { |
---|
| 167 | TechnicalInformation result; |
---|
| 168 | TechnicalInformation additional; |
---|
| 169 | |
---|
| 170 | result = new TechnicalInformation(Type.PHDTHESIS); |
---|
| 171 | result.setValue(Field.AUTHOR, "Oded Maron"); |
---|
| 172 | result.setValue(Field.YEAR, "1998"); |
---|
| 173 | result.setValue(Field.TITLE, "Learning from ambiguity"); |
---|
| 174 | result.setValue(Field.SCHOOL, "Massachusetts Institute of Technology"); |
---|
| 175 | |
---|
| 176 | additional = result.add(Type.ARTICLE); |
---|
| 177 | additional.setValue(Field.AUTHOR, "O. Maron and T. Lozano-Perez"); |
---|
| 178 | additional.setValue(Field.YEAR, "1998"); |
---|
| 179 | additional.setValue(Field.TITLE, "A Framework for Multiple Instance Learning"); |
---|
| 180 | additional.setValue(Field.JOURNAL, "Neural Information Processing Systems"); |
---|
| 181 | additional.setValue(Field.VOLUME, "10"); |
---|
| 182 | |
---|
| 183 | return result; |
---|
| 184 | } |
---|
| 185 | |
---|
| 186 | /** |
---|
| 187 | * Returns an enumeration describing the available options |
---|
| 188 | * |
---|
| 189 | * @return an enumeration of all the available options |
---|
| 190 | */ |
---|
| 191 | public Enumeration listOptions() { |
---|
| 192 | Vector result = new Vector(); |
---|
| 193 | |
---|
| 194 | result.addElement(new Option( |
---|
| 195 | "\tTurn on debugging output.", |
---|
| 196 | "D", 0, "-D")); |
---|
| 197 | |
---|
| 198 | result.addElement(new Option( |
---|
| 199 | "\tWhether to 0=normalize/1=standardize/2=neither.\n" |
---|
| 200 | + "\t(default 1=standardize)", |
---|
| 201 | "N", 1, "-N <num>")); |
---|
| 202 | |
---|
| 203 | return result.elements(); |
---|
| 204 | } |
---|
| 205 | |
---|
| 206 | /** |
---|
| 207 | * Parses a given list of options. <p/> |
---|
| 208 | * |
---|
| 209 | <!-- options-start --> |
---|
| 210 | * Valid options are: <p/> |
---|
| 211 | * |
---|
| 212 | * <pre> -D |
---|
| 213 | * Turn on debugging output.</pre> |
---|
| 214 | * |
---|
| 215 | * <pre> -N <num> |
---|
| 216 | * Whether to 0=normalize/1=standardize/2=neither. |
---|
| 217 | * (default 1=standardize)</pre> |
---|
| 218 | * |
---|
| 219 | <!-- options-end --> |
---|
| 220 | * |
---|
| 221 | * @param options the list of options as an array of strings |
---|
| 222 | * @throws Exception if an option is not supported |
---|
| 223 | */ |
---|
| 224 | public void setOptions(String[] options) throws Exception { |
---|
| 225 | setDebug(Utils.getFlag('D', options)); |
---|
| 226 | |
---|
| 227 | String nString = Utils.getOption('N', options); |
---|
| 228 | if (nString.length() != 0) { |
---|
| 229 | setFilterType(new SelectedTag(Integer.parseInt(nString), TAGS_FILTER)); |
---|
| 230 | } else { |
---|
| 231 | setFilterType(new SelectedTag(FILTER_STANDARDIZE, TAGS_FILTER)); |
---|
| 232 | } |
---|
| 233 | } |
---|
| 234 | |
---|
| 235 | /** |
---|
| 236 | * Gets the current settings of the classifier. |
---|
| 237 | * |
---|
| 238 | * @return an array of strings suitable for passing to setOptions |
---|
| 239 | */ |
---|
| 240 | public String[] getOptions() { |
---|
| 241 | Vector result; |
---|
| 242 | |
---|
| 243 | result = new Vector(); |
---|
| 244 | |
---|
| 245 | if (getDebug()) |
---|
| 246 | result.add("-D"); |
---|
| 247 | |
---|
| 248 | result.add("-N"); |
---|
| 249 | result.add("" + m_filterType); |
---|
| 250 | |
---|
| 251 | return (String[]) result.toArray(new String[result.size()]); |
---|
| 252 | } |
---|
| 253 | |
---|
| 254 | /** |
---|
| 255 | * Returns the tip text for this property |
---|
| 256 | * |
---|
| 257 | * @return tip text for this property suitable for |
---|
| 258 | * displaying in the explorer/experimenter gui |
---|
| 259 | */ |
---|
| 260 | public String filterTypeTipText() { |
---|
| 261 | return "The filter type for transforming the training data."; |
---|
| 262 | } |
---|
| 263 | |
---|
| 264 | /** |
---|
| 265 | * Gets how the training data will be transformed. Will be one of |
---|
| 266 | * FILTER_NORMALIZE, FILTER_STANDARDIZE, FILTER_NONE. |
---|
| 267 | * |
---|
| 268 | * @return the filtering mode |
---|
| 269 | */ |
---|
| 270 | public SelectedTag getFilterType() { |
---|
| 271 | return new SelectedTag(m_filterType, TAGS_FILTER); |
---|
| 272 | } |
---|
| 273 | |
---|
| 274 | /** |
---|
| 275 | * Sets how the training data will be transformed. Should be one of |
---|
| 276 | * FILTER_NORMALIZE, FILTER_STANDARDIZE, FILTER_NONE. |
---|
| 277 | * |
---|
| 278 | * @param newType the new filtering mode |
---|
| 279 | */ |
---|
| 280 | public void setFilterType(SelectedTag newType) { |
---|
| 281 | |
---|
| 282 | if (newType.getTags() == TAGS_FILTER) { |
---|
| 283 | m_filterType = newType.getSelectedTag().getID(); |
---|
| 284 | } |
---|
| 285 | } |
---|
| 286 | |
---|
  /**
   * Optimization engine for the diverse-density objective. A candidate
   * concept is encoded in x as [point_0, scale_0, point_1, scale_1, ...]:
   * one target point coordinate and one scale weight per attribute.
   * An instance's "non-concept" probability is
   * 1 - exp(-sum_k ((v_k - point_k) * scale_k)^2); bags combine
   * instance probabilities multiplicatively (noisy-or style).
   * Reads m_Data and m_Classes from the enclosing classifier.
   */
  private class OptEng
    extends Optimization {

    /**
     * Evaluate objective function
     * @param x the current values of variables
     * @return the value of the objective function
     */
    protected double objectiveFunction(double[] x){
      double nll = 0; // -LogLikelihood
      for(int i=0; i<m_Classes.length; i++){ // ith bag
        int nI = m_Data[i][0].length; // numInstances in ith bag
        double bag = 0.0; // NLL of pos bag

        for(int j=0; j<nI; j++){
          // weighted squared distance of instance j to the candidate point
          double ins=0.0;
          for(int k=0; k<m_Data[i].length; k++)
            ins += (m_Data[i][k][j]-x[k*2])*(m_Data[i][k][j]-x[k*2])*
              x[k*2+1]*x[k*2+1];
          ins = Math.exp(-ins);
          ins = 1.0-ins; // probability the instance is NOT the concept

          if(m_Classes[i] == 1)
            // positive bag: accumulate log-prob that ALL instances miss
            bag += Math.log(ins);
          else{
            // negative bag: each instance must miss; clamp to avoid log(0)
            if(ins<=m_Zero) ins=m_Zero;
            nll -= Math.log(ins);
          }
        }

        if(m_Classes[i] == 1){
          // positive bag: at least one instance must hit the concept
          bag = 1.0 - Math.exp(bag);
          if(bag<=m_Zero) bag=m_Zero;
          nll -= Math.log(bag);
        }
      }
      return nll;
    }

    /**
     * Evaluate Jacobian vector
     * @param x the current values of variables
     * @return the gradient vector
     */
    protected double[] evaluateGradient(double[] x){
      double[] grad = new double[x.length];
      for(int i=0; i<m_Classes.length; i++){ // ith bag
        int nI = m_Data[i][0].length; // numInstances in ith bag

        double denom=0.0;                      // log-prob all instances miss
        double[] numrt = new double[x.length]; // per-parameter partials

        for(int j=0; j<nI; j++){
          // same per-instance quantity as in objectiveFunction
          double exp=0.0;
          for(int k=0; k<m_Data[i].length; k++)
            exp += (m_Data[i][k][j]-x[k*2])*(m_Data[i][k][j]-x[k*2])
              *x[k*2+1]*x[k*2+1];
          exp = Math.exp(-exp);
          exp = 1.0-exp;
          if(m_Classes[i]==1)
            denom += Math.log(exp);

          if(exp<=m_Zero) exp=m_Zero;
          // Instance-wise update
          for(int p=0; p<m_Data[i].length; p++){ // pth variable
            // partials of -log(exp) w.r.t. point (2p) and scale (2p+1)
            numrt[2*p] += (1.0-exp)*2.0*(x[2*p]-m_Data[i][p][j])*x[p*2+1]*x[p*2+1]
              /exp;
            numrt[2*p+1] += 2.0*(1.0-exp)*(x[2*p]-m_Data[i][p][j])*(x[2*p]-m_Data[i][p][j])
              *x[p*2+1]/exp;
          }
        }

        // Bag-wise update
        denom = 1.0-Math.exp(denom); // prob at least one instance hits
        if(denom <= m_Zero) denom = m_Zero;
        for(int q=0; q<m_Data[i].length; q++){
          if(m_Classes[i]==1){
            // chain rule through the noisy-or for positive bags
            grad[2*q] += numrt[2*q]*(1.0-denom)/denom;
            grad[2*q+1] += numrt[2*q+1]*(1.0-denom)/denom;
          }else{
            // negative bags contribute directly with opposite sign
            grad[2*q] -= numrt[2*q];
            grad[2*q+1] -= numrt[2*q+1];
          }
        }
      } // one bag

      return grad;
    }

    /**
     * Returns the revision string.
     *
     * @return the revision
     */
    public String getRevision() {
      return RevisionUtils.extract("$Revision: 5928 $");
    }
  }
---|
| 385 | |
---|
| 386 | /** |
---|
| 387 | * Returns default capabilities of the classifier. |
---|
| 388 | * |
---|
| 389 | * @return the capabilities of this classifier |
---|
| 390 | */ |
---|
| 391 | public Capabilities getCapabilities() { |
---|
| 392 | Capabilities result = super.getCapabilities(); |
---|
| 393 | result.disableAll(); |
---|
| 394 | |
---|
| 395 | // attributes |
---|
| 396 | result.enable(Capability.NOMINAL_ATTRIBUTES); |
---|
| 397 | result.enable(Capability.RELATIONAL_ATTRIBUTES); |
---|
| 398 | result.enable(Capability.MISSING_VALUES); |
---|
| 399 | |
---|
| 400 | // class |
---|
| 401 | result.enable(Capability.BINARY_CLASS); |
---|
| 402 | result.enable(Capability.MISSING_CLASS_VALUES); |
---|
| 403 | |
---|
| 404 | // other |
---|
| 405 | result.enable(Capability.ONLY_MULTIINSTANCE); |
---|
| 406 | |
---|
| 407 | return result; |
---|
| 408 | } |
---|
| 409 | |
---|
| 410 | /** |
---|
| 411 | * Returns the capabilities of this multi-instance classifier for the |
---|
| 412 | * relational data. |
---|
| 413 | * |
---|
| 414 | * @return the capabilities of this object |
---|
| 415 | * @see Capabilities |
---|
| 416 | */ |
---|
| 417 | public Capabilities getMultiInstanceCapabilities() { |
---|
| 418 | Capabilities result = super.getCapabilities(); |
---|
| 419 | result.disableAll(); |
---|
| 420 | |
---|
| 421 | // attributes |
---|
| 422 | result.enable(Capability.NOMINAL_ATTRIBUTES); |
---|
| 423 | result.enable(Capability.NUMERIC_ATTRIBUTES); |
---|
| 424 | result.enable(Capability.DATE_ATTRIBUTES); |
---|
| 425 | result.enable(Capability.MISSING_VALUES); |
---|
| 426 | |
---|
| 427 | // class |
---|
| 428 | result.disableAllClasses(); |
---|
| 429 | result.enable(Capability.NO_CLASS); |
---|
| 430 | |
---|
| 431 | return result; |
---|
| 432 | } |
---|
| 433 | |
---|
  /**
   * Builds the classifier. Extracts the per-bag instance values into
   * m_Data, filters/standardizes them, then runs the diverse-density
   * optimization restarted from every instance of the largest positive
   * bag(s), keeping the parameters with the smallest negative
   * log-likelihood.
   *
   * @param train the training data to be used for generating the
   * boosted classifier.
   * @throws Exception if the classifier could not be built successfully
   */
  public void buildClassifier(Instances train) throws Exception {
    // can classifier handle the data?
    getCapabilities().testWithFail(train);

    // remove instances with missing class
    train = new Instances(train);
    train.deleteWithMissingClass();

    m_ClassIndex = train.classIndex();
    m_NumClasses = train.numClasses();

    // attribute 1 holds the relational (bag) data in the MI format
    int nR = train.attribute(1).relation().numAttributes();
    int nC = train.numInstances();
    FastVector maxSzIdx=new FastVector(); // indices of largest positive bags
    int maxSz=0;                          // size of the largest positive bag
    int [] bagSize=new int [nC];
    Instances datasets= new Instances(train.attribute(1).relation(),0);

    m_Data = new double [nC][nR][]; // Data values
    m_Classes = new int [nC]; // Class values
    m_Attributes = datasets.stringFreeStructure();
    if (m_Debug) {
      System.out.println("Extracting data...");
    }

    for(int h=0; h<nC; h++) {//h_th bag
      Instance current = train.instance(h);
      m_Classes[h] = (int)current.classValue(); // Class value starts from 0
      Instances currInsts = current.relationalValue(1);
      // flatten all bag instances into one dataset so they can be filtered together
      for (int i=0; i<currInsts.numInstances();i++){
        Instance inst=currInsts.instance(i);
        datasets.add(inst);
      }

      int nI = currInsts.numInstances();
      bagSize[h]=nI;
      // track the positive bag(s) with the most instances; these seed
      // the optimizer restarts below
      if(m_Classes[h]==1){
        if(nI>maxSz){
          maxSz=nI;
          maxSzIdx=new FastVector(1);
          maxSzIdx.addElement(new Integer(h));
        }
        else if(nI == maxSz)
          maxSzIdx.addElement(new Integer(h));
      }

    }

    /* filter the training data */
    if (m_filterType == FILTER_STANDARDIZE)
      m_Filter = new Standardize();
    else if (m_filterType == FILTER_NORMALIZE)
      m_Filter = new Normalize();
    else
      m_Filter = null;

    if (m_Filter!=null) {
      m_Filter.setInputFormat(datasets);
      datasets = Filter.useFilter(datasets, m_Filter);
    }

    // replace missing values after (optional) normalization
    m_Missing.setInputFormat(datasets);
    datasets = Filter.useFilter(datasets, m_Missing);


    // copy the filtered values back into m_Data[bag][attribute][instance];
    // bags occupy consecutive ranges of `datasets` in insertion order
    int instIndex=0;
    int start=0;
    for(int h=0; h<nC; h++) {
      for (int i = 0; i < datasets.numAttributes(); i++) {
        // initialize m_data[][][]
        m_Data[h][i] = new double[bagSize[h]];
        instIndex=start;
        for (int k=0; k<bagSize[h]; k++){
          m_Data[h][i][k]=datasets.instance(instIndex).value(i);
          instIndex ++;
        }
      }
      start=instIndex;
    }


    if (m_Debug) {
      System.out.println("\nIteration History..." );
    }

    // x = [point_0, scale_0, point_1, scale_1, ...]; b = variable bounds
    double[] x = new double[nR*2], tmp = new double[x.length];
    double[][] b = new double[2][x.length];

    OptEng opt;
    double nll, bestnll = Double.MAX_VALUE;
    for (int t=0; t<x.length; t++){
      // NaN bounds mean the variables are unconstrained
      b[0][t] = Double.NaN;
      b[1][t] = Double.NaN;
    }

    // Largest Positive exemplar
    for(int s=0; s<maxSzIdx.size(); s++){
      int exIdx = ((Integer)maxSzIdx.elementAt(s)).intValue();
      for(int p=0; p<m_Data[exIdx][0].length; p++){
        // restart the search from each instance of this bag with unit scales
        for (int q=0; q < nR;q++){
          x[2*q] = m_Data[exIdx][q][p]; // pick one instance
          x[2*q+1] = 1.0;
        }

        opt = new OptEng();
        //opt.setDebug(m_Debug);
        tmp = opt.findArgmin(x, b);
        // null return means the iteration cap was hit; resume from the
        // current variable values until convergence
        while(tmp==null){
          tmp = opt.getVarbValues();
          if (m_Debug)
            System.out.println("200 iterations finished, not enough!");
          tmp = opt.findArgmin(tmp, b);
        }
        nll = opt.getMinFunction();

        // keep the best (lowest NLL) solution over all restarts
        if(nll < bestnll){
          bestnll = nll;
          m_Par = tmp;
          tmp = new double[x.length]; // Save memory
          if (m_Debug)
            System.out.println("!!!!!!!!!!!!!!!!Smaller NLL found: "+nll);
        }
        if (m_Debug)
          System.out.println(exIdx+": -------------<Converged>--------------");
      }
    }
  }
---|
| 568 | |
---|
| 569 | /** |
---|
| 570 | * Computes the distribution for a given exemplar |
---|
| 571 | * |
---|
| 572 | * @param exmp the exemplar for which distribution is computed |
---|
| 573 | * @return the distribution |
---|
| 574 | * @throws Exception if the distribution can't be computed successfully |
---|
| 575 | */ |
---|
| 576 | public double[] distributionForInstance(Instance exmp) |
---|
| 577 | throws Exception { |
---|
| 578 | |
---|
| 579 | // Extract the data |
---|
| 580 | Instances ins = exmp.relationalValue(1); |
---|
| 581 | if(m_Filter!=null) |
---|
| 582 | ins = Filter.useFilter(ins, m_Filter); |
---|
| 583 | |
---|
| 584 | ins = Filter.useFilter(ins, m_Missing); |
---|
| 585 | |
---|
| 586 | int nI = ins.numInstances(), nA = ins.numAttributes(); |
---|
| 587 | double[][] dat = new double [nI][nA]; |
---|
| 588 | for(int j=0; j<nI; j++){ |
---|
| 589 | for(int k=0; k<nA; k++){ |
---|
| 590 | dat[j][k] = ins.instance(j).value(k); |
---|
| 591 | } |
---|
| 592 | } |
---|
| 593 | |
---|
| 594 | // Compute the probability of the bag |
---|
| 595 | double [] distribution = new double[2]; |
---|
| 596 | distribution[0]=0.0; // log-Prob. for class 0 |
---|
| 597 | |
---|
| 598 | for(int i=0; i<nI; i++){ |
---|
| 599 | double exp = 0.0; |
---|
| 600 | for(int r=0; r<nA; r++) |
---|
| 601 | exp += (m_Par[r*2]-dat[i][r])*(m_Par[r*2]-dat[i][r])* |
---|
| 602 | m_Par[r*2+1]*m_Par[r*2+1]; |
---|
| 603 | exp = Math.exp(-exp); |
---|
| 604 | |
---|
| 605 | // Prob. updated for one instance |
---|
| 606 | distribution[0] += Math.log(1.0-exp); |
---|
| 607 | } |
---|
| 608 | |
---|
| 609 | distribution[0] = Math.exp(distribution[0]); |
---|
| 610 | distribution[1] = 1.0-distribution[0]; |
---|
| 611 | |
---|
| 612 | return distribution; |
---|
| 613 | } |
---|
| 614 | |
---|
| 615 | /** |
---|
| 616 | * Gets a string describing the classifier. |
---|
| 617 | * |
---|
| 618 | * @return a string describing the classifer built. |
---|
| 619 | */ |
---|
| 620 | public String toString() { |
---|
| 621 | |
---|
| 622 | //double CSq = m_LLn - m_LL; |
---|
| 623 | //int df = m_NumPredictors; |
---|
| 624 | String result = "Diverse Density"; |
---|
| 625 | if (m_Par == null) { |
---|
| 626 | return result + ": No model built yet."; |
---|
| 627 | } |
---|
| 628 | |
---|
| 629 | result += "\nCoefficients...\n" |
---|
| 630 | + "Variable Point Scale\n"; |
---|
| 631 | for (int j = 0, idx=0; j < m_Par.length/2; j++, idx++) { |
---|
| 632 | result += m_Attributes.attribute(idx).name(); |
---|
| 633 | result += " "+Utils.doubleToString(m_Par[j*2], 12, 4); |
---|
| 634 | result += " "+Utils.doubleToString(m_Par[j*2+1], 12, 4)+"\n"; |
---|
| 635 | } |
---|
| 636 | |
---|
| 637 | return result; |
---|
| 638 | } |
---|
| 639 | |
---|
| 640 | /** |
---|
| 641 | * Returns the revision string. |
---|
| 642 | * |
---|
| 643 | * @return the revision |
---|
| 644 | */ |
---|
| 645 | public String getRevision() { |
---|
| 646 | return RevisionUtils.extract("$Revision: 5928 $"); |
---|
| 647 | } |
---|
| 648 | |
---|
| 649 | /** |
---|
| 650 | * Main method for testing this class. |
---|
| 651 | * |
---|
| 652 | * @param argv should contain the command line arguments to the |
---|
| 653 | * scheme (see Evaluation) |
---|
| 654 | */ |
---|
| 655 | public static void main(String[] argv) { |
---|
| 656 | runClassifier(new MIDD(), argv); |
---|
| 657 | } |
---|
| 658 | } |
---|