/*
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 2 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 */

/*
 * MDD.java
 * Copyright (C) 2005 University of Waikato, Hamilton, New Zealand
 *
 */

package weka.classifiers.mi;

import weka.classifiers.Classifier;
import weka.classifiers.AbstractClassifier;
import weka.core.Capabilities;
import weka.core.FastVector;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.MultiInstanceCapabilitiesHandler;
import weka.core.Optimization;
import weka.core.Option;
import weka.core.OptionHandler;
import weka.core.RevisionUtils;
import weka.core.SelectedTag;
import weka.core.Tag;
import weka.core.TechnicalInformation;
import weka.core.TechnicalInformationHandler;
import weka.core.Utils;
import weka.core.Capabilities.Capability;
import weka.core.TechnicalInformation.Field;
import weka.core.TechnicalInformation.Type;
import weka.filters.Filter;
import weka.filters.unsupervised.attribute.Normalize;
import weka.filters.unsupervised.attribute.ReplaceMissingValues;
import weka.filters.unsupervised.attribute.Standardize;

import java.util.Enumeration;
import java.util.Vector;

/**
 <!-- globalinfo-start -->
 * Modified Diverse Density algorithm, with collective assumption.<br/>
 * <br/>
 * More information about DD:<br/>
 * <br/>
 * Oded Maron (1998). Learning from ambiguity.<br/>
 * <br/>
 * O. Maron, T. Lozano-Perez (1998). A Framework for Multiple Instance Learning. Neural Information Processing Systems. 10.
 * <p/>
 <!-- globalinfo-end -->
 *
 <!-- technical-bibtex-start -->
 * BibTeX:
 * <pre>
 * &#64;phdthesis{Maron1998,
 *    author = {Oded Maron},
 *    school = {Massachusetts Institute of Technology},
 *    title = {Learning from ambiguity},
 *    year = {1998}
 * }
 *
 * &#64;article{Maron1998,
 *    author = {O. Maron and T. Lozano-Perez},
 *    journal = {Neural Information Processing Systems},
 *    title = {A Framework for Multiple Instance Learning},
 *    volume = {10},
 *    year = {1998}
 * }
 * </pre>
 * <p/>
 <!-- technical-bibtex-end -->
 *
 <!-- options-start -->
 * Valid options are: <p/>
 *
 * <pre> -D
 *  Turn on debugging output.</pre>
 *
 * <pre> -N &lt;num&gt;
 *  Whether to 0=normalize/1=standardize/2=neither.
 *  (default 1=standardize)</pre>
 *
 <!-- options-end -->
 *
 * @author Eibe Frank (eibe@cs.waikato.ac.nz)
 * @author Xin Xu (xx5@cs.waikato.ac.nz)
 * @version $Revision: 5928 $
 */
public class MDD
  extends AbstractClassifier
  implements OptionHandler, MultiInstanceCapabilitiesHandler,
             TechnicalInformationHandler {

  /** for serialization */
  static final long serialVersionUID = -7273119490545290581L;

  /** The index of the class attribute */
  protected int m_ClassIndex;

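  /** The fitted parameters: for each bag attribute k, entry 2*k holds the
      target point's value and entry 2*k+1 the corresponding scale. */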
  protected double[] m_Par;

  /** The number of the class labels */
  protected int m_NumClasses;

  /** Class labels for each bag */
  protected int[] m_Classes;

  /** MI data */
  protected double[][][] m_Data;

  /** All attribute names */
  protected Instances m_Attributes;

  /** The filter used to standardize/normalize all values. */
  protected Filter m_Filter = null;

  /** Whether to normalize/standardize/neither, default: standardize */
  protected int m_filterType = FILTER_STANDARDIZE;

  /** Normalize training data */
  public static final int FILTER_NORMALIZE = 0;
  /** Standardize training data */
  public static final int FILTER_STANDARDIZE = 1;
  /** No normalization/standardization */
  public static final int FILTER_NONE = 2;
  /** The tags describing the available filter types */
  public static final Tag[] TAGS_FILTER = {
    new Tag(FILTER_NORMALIZE, "Normalize training data"),
    new Tag(FILTER_STANDARDIZE, "Standardize training data"),
    new Tag(FILTER_NONE, "No normalization/standardization"),
  };

  /** The filter used to get rid of missing values. */
  protected ReplaceMissingValues m_Missing = new ReplaceMissingValues();

  /**
   * Returns a string describing this classifier
   *
   * @return a description of the classifier suitable for
   * displaying in the explorer/experimenter gui
   */
  public String globalInfo() {
    return
        "Modified Diverse Density algorithm, with collective assumption.\n\n"
      + "More information about DD:\n\n"
      + getTechnicalInformation().toString();
  }

  /**
   * Returns an instance of a TechnicalInformation object, containing
   * detailed information about the technical background of this class,
   * e.g., paper reference or book this class is based on.
   *
   * @return the technical information about this class
   */
  public TechnicalInformation getTechnicalInformation() {
    TechnicalInformation result;
    TechnicalInformation additional;

    result = new TechnicalInformation(Type.PHDTHESIS);
    result.setValue(Field.AUTHOR, "Oded Maron");
    result.setValue(Field.YEAR, "1998");
    result.setValue(Field.TITLE, "Learning from ambiguity");
    result.setValue(Field.SCHOOL, "Massachusetts Institute of Technology");

    additional = result.add(Type.ARTICLE);
    additional.setValue(Field.AUTHOR, "O. Maron and T. Lozano-Perez");
    additional.setValue(Field.YEAR, "1998");
    additional.setValue(Field.TITLE, "A Framework for Multiple Instance Learning");
    additional.setValue(Field.JOURNAL, "Neural Information Processing Systems");
    additional.setValue(Field.VOLUME, "10");

    return result;
  }

  /**
   * Returns an enumeration describing the available options
   *
   * @return an enumeration of all the available options
   */
  public Enumeration listOptions() {
    Vector result = new Vector();

    result.addElement(new Option(
        "\tTurn on debugging output.",
        "D", 0, "-D"));

    result.addElement(new Option(
        "\tWhether to 0=normalize/1=standardize/2=neither.\n"
        + "\t(default 1=standardize)",
        "N", 1, "-N <num>"));

    return result.elements();
  }

  /**
   * Parses a given list of options.
   *
   * @param options the list of options as an array of strings
   * @throws Exception if an option is not supported
   */
  public void setOptions(String[] options) throws Exception {
    setDebug(Utils.getFlag('D', options));

    String nString = Utils.getOption('N', options);
    if (nString.length() != 0) {
      setFilterType(new SelectedTag(Integer.parseInt(nString), TAGS_FILTER));
    } else {
      setFilterType(new SelectedTag(FILTER_STANDARDIZE, TAGS_FILTER));
    }
  }

  /**
   * Gets the current settings of the classifier.
   *
   * @return an array of strings suitable for passing to setOptions
   */
  public String[] getOptions() {
    Vector result;

    result = new Vector();

    if (getDebug())
      result.add("-D");

    result.add("-N");
    result.add("" + m_filterType);

    return (String[]) result.toArray(new String[result.size()]);
  }

  /**
   * Returns the tip text for this property
   *
   * @return tip text for this property suitable for
   * displaying in the explorer/experimenter gui
   */
  public String filterTypeTipText() {
    return "The filter type for transforming the training data.";
  }

  /**
   * Gets how the training data will be transformed. Will be one of
   * FILTER_NORMALIZE, FILTER_STANDARDIZE, FILTER_NONE.
   *
   * @return the filtering mode
   */
  public SelectedTag getFilterType() {
    return new SelectedTag(m_filterType, TAGS_FILTER);
  }

  /**
   * Sets how the training data will be transformed. Should be one of
   * FILTER_NORMALIZE, FILTER_STANDARDIZE, FILTER_NONE.
   *
   * @param newType the new filtering mode
   */
  public void setFilterType(SelectedTag newType) {
    if (newType.getTags() == TAGS_FILTER) {
      m_filterType = newType.getSelectedTag().getID();
    }
  }
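  /** Helper class that plugs the MDD objective function and its gradient
      into weka.core.Optimization for minimization. */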
  private class OptEng
    extends Optimization {

    /**
     * Evaluate objective function
     * @param x the current values of variables
     * @return the value of the objective function
     */
    protected double objectiveFunction(double[] x) {
      double nll = 0; // -LogLikelihood
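      // Collective assumption: a bag's probability is the average of its
      // instances' contributions, where each instance contributes the
      // Gaussian-like response exp(-sum_k ((v_k - x[2k]) / x[2k+1])^2)
      // around the candidate target point (1 minus that for negative bags).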
      for (int i = 0; i < m_Classes.length; i++) { // ith bag
        int nI = m_Data[i][0].length; // numInstances in ith bag
        double bag = 0; // Probability of the ith bag

        for (int j = 0; j < nI; j++) {
          double ins = 0.0;
          for (int k = 0; k < m_Data[i].length; k++) {
            ins += (m_Data[i][k][j] - x[k * 2]) * (m_Data[i][k][j] - x[k * 2])
                 / (x[k * 2 + 1] * x[k * 2 + 1]);
          }
          ins = Math.exp(-ins);

          if (m_Classes[i] == 1)
            bag += ins / (double) nI;
          else
            bag += (1.0 - ins) / (double) nI;
        }
        if (bag <= m_Zero) bag = m_Zero;
        nll -= Math.log(bag);
      }

      return nll;
    }

    /**
     * Evaluate the gradient vector
     * @param x the current values of variables
     * @return the gradient vector
     */
    protected double[] evaluateGradient(double[] x) {
      double[] grad = new double[x.length];
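      // For each bag, accumulate per-instance partial derivatives in numrt,
      // then normalize by the bag probability (denom); the sign of the
      // bag-wise update depends on the bag's class label.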
      for (int i = 0; i < m_Classes.length; i++) { // ith bag
        int nI = m_Data[i][0].length; // numInstances in ith bag

        double denom = 0.0;
        double[] numrt = new double[x.length];

        for (int j = 0; j < nI; j++) {
          double exp = 0.0;
          for (int k = 0; k < m_Data[i].length; k++)
            exp += (m_Data[i][k][j] - x[k * 2]) * (m_Data[i][k][j] - x[k * 2])
                 / (x[k * 2 + 1] * x[k * 2 + 1]);
          exp = Math.exp(-exp);
          if (m_Classes[i] == 1)
            denom += exp;
          else
            denom += (1.0 - exp);

          // Instance-wise update
          for (int p = 0; p < m_Data[i].length; p++) { // pth variable
            numrt[2 * p] += exp * 2.0 * (x[2 * p] - m_Data[i][p][j])
                          / (x[2 * p + 1] * x[2 * p + 1]);
            numrt[2 * p + 1] +=
              exp * (x[2 * p] - m_Data[i][p][j]) * (x[2 * p] - m_Data[i][p][j])
              / (x[2 * p + 1] * x[2 * p + 1] * x[2 * p + 1]);
          }
        }

        if (denom <= m_Zero) {
          denom = m_Zero;
        }

        // Bag-wise update
        for (int q = 0; q < m_Data[i].length; q++) {
          if (m_Classes[i] == 1) {
            grad[2 * q] += numrt[2 * q] / denom;
            grad[2 * q + 1] -= numrt[2 * q + 1] / denom;
          } else {
            grad[2 * q] -= numrt[2 * q] / denom;
            grad[2 * q + 1] += numrt[2 * q + 1] / denom;
          }
        }
      }

      return grad;
    }

    /**
     * Returns the revision string.
     *
     * @return the revision
     */
    public String getRevision() {
      return RevisionUtils.extract("$Revision: 5928 $");
    }
  }

  /**
   * Returns default capabilities of the classifier.
   *
   * @return the capabilities of this classifier
   */
  public Capabilities getCapabilities() {
    Capabilities result = super.getCapabilities();
    result.disableAll();

    // attributes
    result.enable(Capability.NOMINAL_ATTRIBUTES);
    result.enable(Capability.RELATIONAL_ATTRIBUTES);
    result.enable(Capability.MISSING_VALUES);

    // class
    result.enable(Capability.BINARY_CLASS);
    result.enable(Capability.MISSING_CLASS_VALUES);

    // other
    result.enable(Capability.ONLY_MULTIINSTANCE);

    return result;
  }

  /**
   * Returns the capabilities of this multi-instance classifier for the
   * relational data.
   *
   * @return the capabilities of this object
   * @see Capabilities
   */
  public Capabilities getMultiInstanceCapabilities() {
    Capabilities result = super.getCapabilities();
    result.disableAll();

    // attributes
    result.enable(Capability.NOMINAL_ATTRIBUTES);
    result.enable(Capability.NUMERIC_ATTRIBUTES);
    result.enable(Capability.DATE_ATTRIBUTES);
    result.enable(Capability.MISSING_VALUES);

    // class
    result.disableAllClasses();
    result.enable(Capability.NO_CLASS);

    return result;
  }

  /**
   * Builds the classifier
   *
   * @param train the training data to be used for generating the classifier
   * @throws Exception if the classifier could not be built successfully
   */
  public void buildClassifier(Instances train) throws Exception {
    // can classifier handle the data?
    getCapabilities().testWithFail(train);

    // remove instances with missing class
    train = new Instances(train);
    train.deleteWithMissingClass();

    m_ClassIndex = train.classIndex();
    m_NumClasses = train.numClasses();

    int nR = train.attribute(1).relation().numAttributes();
    int nC = train.numInstances();
    int[] bagSize = new int[nC];
    Instances datasets = new Instances(train.attribute(1).relation(), 0);

    m_Data = new double[nC][nR][];    // Data values
    m_Classes = new int[nC];          // Class values
    m_Attributes = datasets.stringFreeStructure();
    double sY1 = 0, sY0 = 0;          // Number of bags in each class

    if (m_Debug) {
      System.out.println("Extracting data...");
    }
    FastVector maxSzIdx = new FastVector();
    int maxSz = 0;
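    // Remember the indices of the largest positive bag(s); each of their
    // instances later seeds one run of the optimization.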

    for (int h = 0; h < nC; h++) {
      Instance current = train.instance(h);
      m_Classes[h] = (int) current.classValue(); // Class value starts from 0
      Instances currInsts = current.relationalValue(1);
      int nI = currInsts.numInstances();
      bagSize[h] = nI;

      for (int i = 0; i < nI; i++) {
        Instance inst = currInsts.instance(i);
        datasets.add(inst);
      }

      if (m_Classes[h] == 1) {
        if (nI > maxSz) {
          maxSz = nI;
          maxSzIdx = new FastVector(1);
          maxSzIdx.addElement(new Integer(h));
        }
        else if (nI == maxSz)
          maxSzIdx.addElement(new Integer(h));
      }
    }

    /* filter the training data */
    if (m_filterType == FILTER_STANDARDIZE)
      m_Filter = new Standardize();
    else if (m_filterType == FILTER_NORMALIZE)
      m_Filter = new Normalize();
    else
      m_Filter = null;

    if (m_Filter != null) {
      m_Filter.setInputFormat(datasets);
      datasets = Filter.useFilter(datasets, m_Filter);
    }

    m_Missing.setInputFormat(datasets);
    datasets = Filter.useFilter(datasets, m_Missing);

    int instIndex = 0;
    int start = 0;
    for (int h = 0; h < nC; h++) {
      for (int i = 0; i < datasets.numAttributes(); i++) {
        // initialize m_Data[][][]
        m_Data[h][i] = new double[bagSize[h]];
        instIndex = start;
        for (int k = 0; k < bagSize[h]; k++) {
          m_Data[h][i][k] = datasets.instance(instIndex).value(i);
          instIndex++;
        }
      }
      start = instIndex;

      // Class count
      if (m_Classes[h] == 1)
        sY1++;
      else
        sY0++;
    }

    if (m_Debug) {
      System.out.println("\nIteration History...");
    }

    double[] x = new double[nR * 2], tmp = new double[x.length];
    double[][] b = new double[2][x.length];

    OptEng opt;
    double nll, bestnll = Double.MAX_VALUE;
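    // NaN bounds tell the optimizer that the variables are unconstrained.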
    for (int t = 0; t < x.length; t++) {
      b[0][t] = Double.NaN;
      b[1][t] = Double.NaN;
    }

    // Largest positive exemplar
    for (int s = 0; s < maxSzIdx.size(); s++) {
      int exIdx = ((Integer) maxSzIdx.elementAt(s)).intValue();
      for (int p = 0; p < m_Data[exIdx][0].length; p++) {
        for (int q = 0; q < nR; q++) {
          x[2 * q] = m_Data[exIdx][q][p]; // pick one instance
          x[2 * q + 1] = 1.0;
        }

        opt = new OptEng();
        tmp = opt.findArgmin(x, b);
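        // findArgmin returns null when it stops at the iteration limit;
        // restart from the current variable values until it converges.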
        while (tmp == null) {
          tmp = opt.getVarbValues();
          if (m_Debug)
            System.out.println("200 iterations finished, not enough!");
          tmp = opt.findArgmin(tmp, b);
        }
        nll = opt.getMinFunction();

        if (nll < bestnll) {
          bestnll = nll;
          m_Par = tmp;
          if (m_Debug)
            System.out.println("!!!!!!!!!!!!!!!!Smaller NLL found: " + nll);
        }
        if (m_Debug)
          System.out.println(exIdx + ": -------------<Converged>--------------");
      }
    }
  }

  /**
   * Computes the distribution for a given exemplar
   *
   * @param exmp the exemplar for which distribution is computed
   * @return the distribution
   * @throws Exception if the distribution can't be computed successfully
   */
  public double[] distributionForInstance(Instance exmp)
    throws Exception {

    // Extract the data
    Instances ins = exmp.relationalValue(1);
    if (m_Filter != null)
      ins = Filter.useFilter(ins, m_Filter);

    ins = Filter.useFilter(ins, m_Missing);

    int nI = ins.numInstances(), nA = ins.numAttributes();
    double[][] dat = new double[nI][nA];
    for (int j = 0; j < nI; j++) {
      for (int k = 0; k < nA; k++) {
        dat[j][k] = ins.instance(j).value(k);
      }
    }

    // Compute the probability of the bag
    double[] distribution = new double[2];
    distribution[1] = 0.0; // Prob. for class 1
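    // As in training, the bag-level probability of class 1 is the average of
    // the instances' Gaussian-like responses around the learned point m_Par.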

    for (int i = 0; i < nI; i++) {
      double exp = 0.0;
      for (int r = 0; r < nA; r++)
        exp += (m_Par[r * 2] - dat[i][r]) * (m_Par[r * 2] - dat[i][r])
             / ((m_Par[r * 2 + 1]) * (m_Par[r * 2 + 1]));
      exp = Math.exp(-exp);

      // Prob. updated for one instance
      distribution[1] += exp / (double) nI;
      distribution[0] += (1.0 - exp) / (double) nI;
    }

    return distribution;
  }

  /**
   * Gets a string describing the classifier.
   *
   * @return a string describing the classifier built.
   */
  public String toString() {

    String result = "Modified Diverse Density";
    if (m_Par == null) {
      return result + ": No model built yet.";
    }

    result += "\nCoefficients...\n"
      + "Variable       Point       Scale\n";
    for (int j = 0, idx = 0; j < m_Par.length / 2; j++, idx++) {
      result += m_Attributes.attribute(idx).name();
      result += " " + Utils.doubleToString(m_Par[j * 2], 12, 4);
      result += " " + Utils.doubleToString(m_Par[j * 2 + 1], 12, 4) + "\n";
    }

    return result;
  }

  /**
   * Returns the revision string.
   *
   * @return the revision
   */
  public String getRevision() {
    return RevisionUtils.extract("$Revision: 5928 $");
  }

  /**
   * Main method for testing this class.
   *
   * @param argv should contain the command line arguments to the
   * scheme (see Evaluation)
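   *
   * For example (the dataset path is a placeholder):
   * <pre> java weka.classifiers.mi.MDD -t MI-data.arff </pre>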
   */
  public static void main(String[] argv) {
    runClassifier(new MDD(), argv);
  }
}