| [4] | 1 | /* | 
|---|
 | 2 |  *    This program is free software; you can redistribute it and/or modify | 
|---|
 | 3 |  *    it under the terms of the GNU General Public License as published by | 
|---|
 | 4 |  *    the Free Software Foundation; either version 2 of the License, or | 
|---|
 | 5 |  *    (at your option) any later version. | 
|---|
 | 6 |  * | 
|---|
 | 7 |  *    This program is distributed in the hope that it will be useful, | 
|---|
 | 8 |  *    but WITHOUT ANY WARRANTY; without even the implied warranty of | 
|---|
 | 9 |  *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the | 
|---|
 | 10 |  *    GNU General Public License for more details. | 
|---|
 | 11 |  * | 
|---|
 | 12 |  *    You should have received a copy of the GNU General Public License | 
|---|
 | 13 |  *    along with this program; if not, write to the Free Software | 
|---|
 | 14 |  *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA. | 
|---|
 | 15 |  */ | 
|---|
 | 16 |  | 
|---|
 | 17 | /* | 
|---|
 | 18 |  * MDD.java | 
|---|
 | 19 |  * Copyright (C) 2005 University of Waikato, Hamilton, New Zealand | 
|---|
 | 20 |  * | 
|---|
 | 21 |  */ | 
|---|
 | 22 |  | 
|---|
 | 23 | package weka.classifiers.mi; | 
|---|
 | 24 |  | 
|---|
 | 25 | import weka.classifiers.Classifier; | 
|---|
 | 26 | import weka.classifiers.AbstractClassifier; | 
|---|
 | 27 | import weka.core.Capabilities; | 
|---|
 | 28 | import weka.core.FastVector; | 
|---|
 | 29 | import weka.core.Instance; | 
|---|
 | 30 | import weka.core.Instances; | 
|---|
 | 31 | import weka.core.MultiInstanceCapabilitiesHandler; | 
|---|
 | 32 | import weka.core.Optimization; | 
|---|
 | 33 | import weka.core.Option; | 
|---|
 | 34 | import weka.core.OptionHandler; | 
|---|
 | 35 | import weka.core.RevisionUtils; | 
|---|
 | 36 | import weka.core.SelectedTag; | 
|---|
 | 37 | import weka.core.Tag; | 
|---|
 | 38 | import weka.core.TechnicalInformation; | 
|---|
 | 39 | import weka.core.TechnicalInformationHandler; | 
|---|
 | 40 | import weka.core.Utils; | 
|---|
 | 41 | import weka.core.Capabilities.Capability; | 
|---|
 | 42 | import weka.core.TechnicalInformation.Field; | 
|---|
 | 43 | import weka.core.TechnicalInformation.Type; | 
|---|
 | 44 | import weka.filters.Filter; | 
|---|
 | 45 | import weka.filters.unsupervised.attribute.Normalize; | 
|---|
 | 46 | import weka.filters.unsupervised.attribute.ReplaceMissingValues; | 
|---|
 | 47 | import weka.filters.unsupervised.attribute.Standardize; | 
|---|
 | 48 |  | 
|---|
 | 49 | import java.util.Enumeration; | 
|---|
 | 50 | import java.util.Vector; | 
|---|
 | 51 |  | 
|---|
 | 52 | /** | 
|---|
 | 53 |  <!-- globalinfo-start --> | 
|---|
 | 54 |  * Modified Diverse Density algorithm, with collective assumption.<br/> | 
|---|
 | 55 |  * <br/> | 
|---|
 | 56 |  * More information about DD:<br/> | 
|---|
 | 57 |  * <br/> | 
|---|
 | 58 |  * Oded Maron (1998). Learning from ambiguity.<br/> | 
|---|
 | 59 |  * <br/> | 
|---|
 | 60 |  * O. Maron, T. Lozano-Perez (1998). A Framework for Multiple Instance Learning. Neural Information Processing Systems. 10. | 
|---|
 | 61 |  * <p/> | 
|---|
 | 62 |  <!-- globalinfo-end --> | 
|---|
 | 63 |  *  | 
|---|
 | 64 |  <!-- technical-bibtex-start --> | 
|---|
 | 65 |  * BibTeX: | 
|---|
 | 66 |  * <pre> | 
|---|
 | 67 |  * @phdthesis{Maron1998, | 
|---|
 | 68 |  *    author = {Oded Maron}, | 
|---|
 | 69 |  *    school = {Massachusetts Institute of Technology}, | 
|---|
 | 70 |  *    title = {Learning from ambiguity}, | 
|---|
 | 71 |  *    year = {1998} | 
|---|
 | 72 |  * } | 
|---|
 | 73 |  *  | 
|---|
 | 74 |  * @article{Maron1998, | 
|---|
 | 75 |  *    author = {O. Maron and T. Lozano-Perez}, | 
|---|
 | 76 |  *    journal = {Neural Information Processing Systems}, | 
|---|
 | 77 |  *    title = {A Framework for Multiple Instance Learning}, | 
|---|
 | 78 |  *    volume = {10}, | 
|---|
 | 79 |  *    year = {1998} | 
|---|
 | 80 |  * } | 
|---|
 | 81 |  * </pre> | 
|---|
 | 82 |  * <p/> | 
|---|
 | 83 |  <!-- technical-bibtex-end --> | 
|---|
 | 84 |  *  | 
|---|
 | 85 |  <!-- options-start --> | 
|---|
 | 86 |  * Valid options are: <p/> | 
|---|
 | 87 |  *  | 
|---|
 | 88 |  * <pre> -D | 
|---|
 | 89 |  *  Turn on debugging output.</pre> | 
|---|
 | 90 |  *  | 
|---|
 | 91 |  * <pre> -N <num> | 
|---|
 | 92 |  *  Whether to 0=normalize/1=standardize/2=neither. | 
|---|
 | 93 |  *  (default 1=standardize)</pre> | 
|---|
 | 94 |  *  | 
|---|
 | 95 |  <!-- options-end --> | 
|---|
 | 96 |  *     | 
|---|
 | 97 |  * @author Eibe Frank (eibe@cs.waikato.ac.nz) | 
|---|
 | 98 |  * @author Xin Xu (xx5@cs.waikato.ac.nz) | 
|---|
 | 99 |  * @version $Revision: 5928 $  | 
|---|
 | 100 |  */ | 
|---|
 | 101 | public class MDD  | 
|---|
 | 102 |   extends AbstractClassifier  | 
|---|
 | 103 |   implements OptionHandler, MultiInstanceCapabilitiesHandler, | 
|---|
 | 104 |              TechnicalInformationHandler { | 
|---|
 | 105 |    | 
|---|
 | 106 |   /** for serialization */ | 
|---|
 | 107 |   static final long serialVersionUID = -7273119490545290581L; | 
|---|
 | 108 |  | 
|---|
 | 109 |   /** The index of the class attribute */ | 
|---|
 | 110 |   protected int m_ClassIndex; | 
|---|
 | 111 |  | 
|---|
 | 112 |   protected double[] m_Par; | 
|---|
 | 113 |  | 
|---|
 | 114 |   /** The number of the class labels */ | 
|---|
 | 115 |   protected int m_NumClasses; | 
|---|
 | 116 |  | 
|---|
 | 117 |   /** Class labels for each bag */ | 
|---|
 | 118 |   protected int[] m_Classes; | 
|---|
 | 119 |  | 
|---|
 | 120 |   /** MI data */  | 
|---|
 | 121 |   protected double[][][] m_Data; | 
|---|
 | 122 |  | 
|---|
 | 123 |   /** All attribute names */ | 
|---|
 | 124 |   protected Instances m_Attributes; | 
|---|
 | 125 |  | 
|---|
 | 126 |   /** The filter used to standardize/normalize all values. */ | 
|---|
 | 127 |   protected Filter m_Filter =null; | 
|---|
 | 128 |  | 
|---|
 | 129 |   /** Whether to normalize/standardize/neither, default:standardize */ | 
|---|
 | 130 |   protected int m_filterType = FILTER_STANDARDIZE; | 
|---|
 | 131 |  | 
|---|
 | 132 |   /** Normalize training data */ | 
|---|
 | 133 |   public static final int FILTER_NORMALIZE = 0; | 
|---|
 | 134 |   /** Standardize training data */ | 
|---|
 | 135 |   public static final int FILTER_STANDARDIZE = 1; | 
|---|
 | 136 |   /** No normalization/standardization */ | 
|---|
 | 137 |   public static final int FILTER_NONE = 2; | 
|---|
 | 138 |   /** The filter to apply to the training data */ | 
|---|
 | 139 |   public static final Tag [] TAGS_FILTER = { | 
|---|
 | 140 |     new Tag(FILTER_NORMALIZE, "Normalize training data"), | 
|---|
 | 141 |     new Tag(FILTER_STANDARDIZE, "Standardize training data"), | 
|---|
 | 142 |     new Tag(FILTER_NONE, "No normalization/standardization"), | 
|---|
 | 143 |   }; | 
|---|
 | 144 |  | 
|---|
 | 145 |   /** The filter used to get rid of missing values. */ | 
|---|
 | 146 |   protected ReplaceMissingValues m_Missing = new ReplaceMissingValues(); | 
|---|
 | 147 |  | 
|---|
 | 148 |   /** | 
|---|
 | 149 |    * Returns a string describing this filter | 
|---|
 | 150 |    * | 
|---|
 | 151 |    * @return a description of the filter suitable for | 
|---|
 | 152 |    * displaying in the explorer/experimenter gui | 
|---|
 | 153 |    */ | 
|---|
 | 154 |   public String globalInfo() { | 
|---|
 | 155 |     return  | 
|---|
 | 156 |         "Modified Diverse Density algorithm, with collective assumption.\n\n" | 
|---|
 | 157 |       + "More information about DD:\n\n" | 
|---|
 | 158 |       + getTechnicalInformation().toString(); | 
|---|
 | 159 |   } | 
|---|
 | 160 |  | 
|---|
 | 161 |   /** | 
|---|
 | 162 |    * Returns an instance of a TechnicalInformation object, containing  | 
|---|
 | 163 |    * detailed information about the technical background of this class, | 
|---|
 | 164 |    * e.g., paper reference or book this class is based on. | 
|---|
 | 165 |    *  | 
|---|
 | 166 |    * @return the technical information about this class | 
|---|
 | 167 |    */ | 
|---|
 | 168 |   public TechnicalInformation getTechnicalInformation() { | 
|---|
 | 169 |     TechnicalInformation        result; | 
|---|
 | 170 |     TechnicalInformation        additional; | 
|---|
 | 171 |      | 
|---|
 | 172 |     result = new TechnicalInformation(Type.PHDTHESIS); | 
|---|
 | 173 |     result.setValue(Field.AUTHOR, "Oded Maron"); | 
|---|
 | 174 |     result.setValue(Field.YEAR, "1998"); | 
|---|
 | 175 |     result.setValue(Field.TITLE, "Learning from ambiguity"); | 
|---|
 | 176 |     result.setValue(Field.SCHOOL, "Massachusetts Institute of Technology"); | 
|---|
 | 177 |      | 
|---|
 | 178 |     additional = result.add(Type.ARTICLE); | 
|---|
 | 179 |     additional.setValue(Field.AUTHOR, "O. Maron and T. Lozano-Perez"); | 
|---|
 | 180 |     additional.setValue(Field.YEAR, "1998"); | 
|---|
 | 181 |     additional.setValue(Field.TITLE, "A Framework for Multiple Instance Learning"); | 
|---|
 | 182 |     additional.setValue(Field.JOURNAL, "Neural Information Processing Systems"); | 
|---|
 | 183 |     additional.setValue(Field.VOLUME, "10"); | 
|---|
 | 184 |      | 
|---|
 | 185 |     return result; | 
|---|
 | 186 |   } | 
|---|
 | 187 |  | 
|---|
 | 188 |   /** | 
|---|
 | 189 |    * Returns an enumeration describing the available options | 
|---|
 | 190 |    * | 
|---|
 | 191 |    * @return an enumeration of all the available options | 
|---|
 | 192 |    */ | 
|---|
 | 193 |   public Enumeration listOptions() { | 
|---|
 | 194 |     Vector result = new Vector(); | 
|---|
 | 195 |      | 
|---|
 | 196 |     result.addElement(new Option( | 
|---|
 | 197 |           "\tTurn on debugging output.", | 
|---|
 | 198 |           "D", 0, "-D")); | 
|---|
 | 199 |      | 
|---|
 | 200 |     result.addElement(new Option( | 
|---|
 | 201 |           "\tWhether to 0=normalize/1=standardize/2=neither.\n" | 
|---|
 | 202 |           + "\t(default 1=standardize)", | 
|---|
 | 203 |           "N", 1, "-N <num>")); | 
|---|
 | 204 |  | 
|---|
 | 205 |     return result.elements(); | 
|---|
 | 206 |   } | 
|---|
 | 207 |  | 
|---|
 | 208 |   /** | 
|---|
 | 209 |    * Parses a given list of options.  | 
|---|
 | 210 |    * | 
|---|
 | 211 |    * @param options the list of options as an array of strings | 
|---|
 | 212 |    * @throws Exception if an option is not supported | 
|---|
 | 213 |    */ | 
|---|
 | 214 |   public void setOptions(String[] options) throws Exception { | 
|---|
 | 215 |     setDebug(Utils.getFlag('D', options)); | 
|---|
 | 216 |  | 
|---|
 | 217 |     String nString = Utils.getOption('N', options); | 
|---|
 | 218 |     if (nString.length() != 0) { | 
|---|
 | 219 |       setFilterType(new SelectedTag(Integer.parseInt(nString), TAGS_FILTER)); | 
|---|
 | 220 |     } else { | 
|---|
 | 221 |       setFilterType(new SelectedTag(FILTER_STANDARDIZE, TAGS_FILTER)); | 
|---|
 | 222 |     }      | 
|---|
 | 223 |   } | 
|---|
 | 224 |  | 
|---|
 | 225 |   /** | 
|---|
 | 226 |    * Gets the current settings of the classifier. | 
|---|
 | 227 |    * | 
|---|
 | 228 |    * @return an array of strings suitable for passing to setOptions | 
|---|
 | 229 |    */ | 
|---|
 | 230 |   public String[] getOptions() { | 
|---|
 | 231 |     Vector        result; | 
|---|
 | 232 |      | 
|---|
 | 233 |     result = new Vector(); | 
|---|
 | 234 |  | 
|---|
 | 235 |     if (getDebug()) | 
|---|
 | 236 |       result.add("-D"); | 
|---|
 | 237 |      | 
|---|
 | 238 |     result.add("-N"); | 
|---|
 | 239 |     result.add("" + m_filterType); | 
|---|
 | 240 |  | 
|---|
 | 241 |     return (String[]) result.toArray(new String[result.size()]); | 
|---|
 | 242 |   } | 
|---|
 | 243 |  | 
|---|
 | 244 |   /** | 
|---|
 | 245 |    * Returns the tip text for this property | 
|---|
 | 246 |    * | 
|---|
 | 247 |    * @return tip text for this property suitable for | 
|---|
 | 248 |    * displaying in the explorer/experimenter gui | 
|---|
 | 249 |    */ | 
|---|
 | 250 |   public String filterTypeTipText() { | 
|---|
 | 251 |     return "The filter type for transforming the training data."; | 
|---|
 | 252 |   } | 
|---|
 | 253 |  | 
|---|
 | 254 |   /** | 
|---|
 | 255 |    * Gets how the training data will be transformed. Will be one of | 
|---|
 | 256 |    * FILTER_NORMALIZE, FILTER_STANDARDIZE, FILTER_NONE. | 
|---|
 | 257 |    * | 
|---|
 | 258 |    * @return the filtering mode | 
|---|
 | 259 |    */ | 
|---|
 | 260 |   public SelectedTag getFilterType() { | 
|---|
 | 261 |     return new SelectedTag(m_filterType, TAGS_FILTER); | 
|---|
 | 262 |   } | 
|---|
 | 263 |  | 
|---|
 | 264 |   /** | 
|---|
 | 265 |    * Sets how the training data will be transformed. Should be one of | 
|---|
 | 266 |    * FILTER_NORMALIZE, FILTER_STANDARDIZE, FILTER_NONE. | 
|---|
 | 267 |    * | 
|---|
 | 268 |    * @param newType the new filtering mode | 
|---|
 | 269 |    */ | 
|---|
 | 270 |   public void setFilterType(SelectedTag newType) { | 
|---|
 | 271 |  | 
|---|
 | 272 |     if (newType.getTags() == TAGS_FILTER) { | 
|---|
 | 273 |       m_filterType = newType.getSelectedTag().getID(); | 
|---|
 | 274 |     } | 
|---|
 | 275 |   } | 
|---|
 | 276 |  | 
|---|
 | 277 |  | 
|---|
 | 278 |   private class OptEng  | 
|---|
 | 279 |     extends Optimization { | 
|---|
 | 280 |      | 
|---|
 | 281 |     /**  | 
|---|
 | 282 |      * Evaluate objective function | 
|---|
 | 283 |      * @param x the current values of variables | 
|---|
 | 284 |      * @return the value of the objective function  | 
|---|
 | 285 |      */ | 
|---|
 | 286 |     protected double objectiveFunction(double[] x){ | 
|---|
 | 287 |       double nll = 0; // -LogLikelihood | 
|---|
 | 288 |       for(int i=0; i<m_Classes.length; i++){ // ith bag | 
|---|
 | 289 |         int nI = m_Data[i][0].length; // numInstances in ith bag | 
|---|
 | 290 |         double bag = 0;  // NLL of each bag | 
|---|
 | 291 |  | 
|---|
 | 292 |         for(int j=0; j<nI; j++){ | 
|---|
 | 293 |           double ins=0.0;  | 
|---|
 | 294 |           for(int k=0; k<m_Data[i].length; k++) { | 
|---|
 | 295 |             ins += (m_Data[i][k][j]-x[k*2])*(m_Data[i][k][j]-x[k*2])/ | 
|---|
 | 296 |               (x[k*2+1]*x[k*2+1]); | 
|---|
 | 297 |           } | 
|---|
 | 298 |           ins = Math.exp(-ins);  | 
|---|
 | 299 |  | 
|---|
 | 300 |           if(m_Classes[i] == 1) | 
|---|
 | 301 |             bag += ins/(double)nI; | 
|---|
 | 302 |           else | 
|---|
 | 303 |             bag += (1.0-ins)/(double)nI;    | 
|---|
 | 304 |         }                | 
|---|
 | 305 |         if(bag<=m_Zero) bag=m_Zero;  | 
|---|
 | 306 |         nll -= Math.log(bag); | 
|---|
 | 307 |       }          | 
|---|
 | 308 |  | 
|---|
 | 309 |       return nll; | 
|---|
 | 310 |     } | 
|---|
 | 311 |  | 
|---|
 | 312 |     /**  | 
|---|
 | 313 |      * Evaluate Jacobian vector | 
|---|
 | 314 |      * @param x the current values of variables | 
|---|
 | 315 |      * @return the gradient vector  | 
|---|
 | 316 |      */ | 
|---|
 | 317 |     protected double[] evaluateGradient(double[] x){ | 
|---|
 | 318 |       double[] grad = new double[x.length]; | 
|---|
 | 319 |       for(int i=0; i<m_Classes.length; i++){ // ith bag | 
|---|
 | 320 |         int nI = m_Data[i][0].length; // numInstances in ith bag  | 
|---|
 | 321 |  | 
|---|
 | 322 |         double denom=0.0; | 
|---|
 | 323 |         double[] numrt = new double[x.length]; | 
|---|
 | 324 |  | 
|---|
 | 325 |         for(int j=0; j<nI; j++){ | 
|---|
 | 326 |           double exp=0.0; | 
|---|
 | 327 |           for(int k=0; k<m_Data[i].length; k++) | 
|---|
 | 328 |             exp += (m_Data[i][k][j]-x[k*2])*(m_Data[i][k][j]-x[k*2])/ | 
|---|
 | 329 |               (x[k*2+1]*x[k*2+1]);                       | 
|---|
 | 330 |           exp = Math.exp(-exp); | 
|---|
 | 331 |           if(m_Classes[i]==1) | 
|---|
 | 332 |             denom += exp; | 
|---|
 | 333 |           else | 
|---|
 | 334 |             denom += (1.0-exp);             | 
|---|
 | 335 |  | 
|---|
 | 336 |           // Instance-wise update | 
|---|
 | 337 |           for(int p=0; p<m_Data[i].length; p++){  // pth variable | 
|---|
 | 338 |             numrt[2*p] += exp*2.0*(x[2*p]-m_Data[i][p][j])/ | 
|---|
 | 339 |               (x[2*p+1]*x[2*p+1]); | 
|---|
 | 340 |             numrt[2*p+1] +=  | 
|---|
 | 341 |               exp*(x[2*p]-m_Data[i][p][j])*(x[2*p]-m_Data[i][p][j])/ | 
|---|
 | 342 |               (x[2*p+1]*x[2*p+1]*x[2*p+1]); | 
|---|
 | 343 |           }                      | 
|---|
 | 344 |         } | 
|---|
 | 345 |  | 
|---|
 | 346 |         if(denom <= m_Zero){ | 
|---|
 | 347 |           denom = m_Zero; | 
|---|
 | 348 |         } | 
|---|
 | 349 |  | 
|---|
 | 350 |         // Bag-wise update  | 
|---|
 | 351 |         for(int q=0; q<m_Data[i].length; q++){ | 
|---|
 | 352 |           if(m_Classes[i]==1){ | 
|---|
 | 353 |             grad[2*q] += numrt[2*q]/denom; | 
|---|
 | 354 |             grad[2*q+1] -= numrt[2*q+1]/denom; | 
|---|
 | 355 |           }else{ | 
|---|
 | 356 |             grad[2*q] -= numrt[2*q]/denom; | 
|---|
 | 357 |             grad[2*q+1] += numrt[2*q+1]/denom; | 
|---|
 | 358 |           } | 
|---|
 | 359 |         } | 
|---|
 | 360 |       } | 
|---|
 | 361 |  | 
|---|
 | 362 |       return grad; | 
|---|
 | 363 |     } | 
|---|
 | 364 |      | 
|---|
 | 365 |     /** | 
|---|
 | 366 |      * Returns the revision string. | 
|---|
 | 367 |      *  | 
|---|
 | 368 |      * @return          the revision | 
|---|
 | 369 |      */ | 
|---|
 | 370 |     public String getRevision() { | 
|---|
 | 371 |       return RevisionUtils.extract("$Revision: 5928 $"); | 
|---|
 | 372 |     } | 
|---|
 | 373 |   } | 
|---|
 | 374 |  | 
|---|
 | 375 |   /** | 
|---|
 | 376 |    * Returns default capabilities of the classifier. | 
|---|
 | 377 |    * | 
|---|
 | 378 |    * @return      the capabilities of this classifier | 
|---|
 | 379 |    */ | 
|---|
 | 380 |   public Capabilities getCapabilities() { | 
|---|
 | 381 |     Capabilities result = super.getCapabilities(); | 
|---|
 | 382 |     result.disableAll(); | 
|---|
 | 383 |  | 
|---|
 | 384 |     // attributes | 
|---|
 | 385 |     result.enable(Capability.NOMINAL_ATTRIBUTES); | 
|---|
 | 386 |     result.enable(Capability.RELATIONAL_ATTRIBUTES); | 
|---|
 | 387 |     result.enable(Capability.MISSING_VALUES); | 
|---|
 | 388 |  | 
|---|
 | 389 |     // class | 
|---|
 | 390 |     result.enable(Capability.BINARY_CLASS); | 
|---|
 | 391 |     result.enable(Capability.MISSING_CLASS_VALUES); | 
|---|
 | 392 |      | 
|---|
 | 393 |     // other | 
|---|
 | 394 |     result.enable(Capability.ONLY_MULTIINSTANCE); | 
|---|
 | 395 |      | 
|---|
 | 396 |     return result; | 
|---|
 | 397 |   } | 
|---|
 | 398 |  | 
|---|
 | 399 |   /** | 
|---|
 | 400 |    * Returns the capabilities of this multi-instance classifier for the | 
|---|
 | 401 |    * relational data. | 
|---|
 | 402 |    * | 
|---|
 | 403 |    * @return            the capabilities of this object | 
|---|
 | 404 |    * @see               Capabilities | 
|---|
 | 405 |    */ | 
|---|
 | 406 |   public Capabilities getMultiInstanceCapabilities() { | 
|---|
 | 407 |     Capabilities result = super.getCapabilities(); | 
|---|
 | 408 |     result.disableAll(); | 
|---|
 | 409 |      | 
|---|
 | 410 |     // attributes | 
|---|
 | 411 |     result.enable(Capability.NOMINAL_ATTRIBUTES); | 
|---|
 | 412 |     result.enable(Capability.NUMERIC_ATTRIBUTES); | 
|---|
 | 413 |     result.enable(Capability.DATE_ATTRIBUTES); | 
|---|
 | 414 |     result.enable(Capability.MISSING_VALUES); | 
|---|
 | 415 |  | 
|---|
 | 416 |     // class | 
|---|
 | 417 |     result.disableAllClasses(); | 
|---|
 | 418 |     result.enable(Capability.NO_CLASS); | 
|---|
 | 419 |      | 
|---|
 | 420 |     return result; | 
|---|
 | 421 |   } | 
|---|
 | 422 |  | 
|---|
 | 423 |   /** | 
|---|
 | 424 |    * Builds the classifier | 
|---|
 | 425 |    * | 
|---|
 | 426 |    * @param train the training data to be used for generating the | 
|---|
 | 427 |    * boosted classifier. | 
|---|
 | 428 |    * @throws Exception if the classifier could not be built successfully | 
|---|
 | 429 |    */ | 
|---|
 | 430 |   public void buildClassifier(Instances train) throws Exception { | 
|---|
 | 431 |     // can classifier handle the data? | 
|---|
 | 432 |     getCapabilities().testWithFail(train); | 
|---|
 | 433 |  | 
|---|
 | 434 |     // remove instances with missing class | 
|---|
 | 435 |     train = new Instances(train); | 
|---|
 | 436 |     train.deleteWithMissingClass(); | 
|---|
 | 437 |      | 
|---|
 | 438 |     m_ClassIndex = train.classIndex(); | 
|---|
 | 439 |     m_NumClasses = train.numClasses(); | 
|---|
 | 440 |  | 
|---|
 | 441 |     int nR = train.attribute(1).relation().numAttributes(); | 
|---|
 | 442 |     int nC = train.numInstances(); | 
|---|
 | 443 |     int [] bagSize=new int [nC]; | 
|---|
 | 444 |     Instances datasets= new Instances(train.attribute(1).relation(),0); | 
|---|
 | 445 |  | 
|---|
 | 446 |     m_Data  = new double [nC][nR][];              // Data values | 
|---|
 | 447 |     m_Classes  = new int [nC];                    // Class values | 
|---|
 | 448 |     m_Attributes = datasets.stringFreeStructure();               | 
|---|
 | 449 |     double sY1=0, sY0=0;                          // Number of classes | 
|---|
 | 450 |  | 
|---|
 | 451 |     if (m_Debug) { | 
|---|
 | 452 |       System.out.println("Extracting data..."); | 
|---|
 | 453 |     } | 
|---|
 | 454 |     FastVector maxSzIdx=new FastVector(); | 
|---|
 | 455 |     int maxSz=0; | 
|---|
 | 456 |  | 
|---|
 | 457 |     for(int h=0; h<nC; h++){ | 
|---|
 | 458 |       Instance current = train.instance(h); | 
|---|
 | 459 |       m_Classes[h] = (int)current.classValue();  // Class value starts from 0 | 
|---|
 | 460 |       Instances currInsts = current.relationalValue(1); | 
|---|
 | 461 |       int nI = currInsts.numInstances(); | 
|---|
 | 462 |       bagSize[h]=nI; | 
|---|
 | 463 |  | 
|---|
 | 464 |       for (int i=0; i<nI;i++){ | 
|---|
 | 465 |         Instance inst=currInsts.instance(i); | 
|---|
 | 466 |         datasets.add(inst); | 
|---|
 | 467 |       } | 
|---|
 | 468 |  | 
|---|
 | 469 |       if(m_Classes[h]==1){ | 
|---|
 | 470 |         if(nI>maxSz){ | 
|---|
 | 471 |           maxSz=nI; | 
|---|
 | 472 |           maxSzIdx=new FastVector(1); | 
|---|
 | 473 |           maxSzIdx.addElement(new Integer(h)); | 
|---|
 | 474 |         } | 
|---|
 | 475 |         else if(nI == maxSz) | 
|---|
 | 476 |           maxSzIdx.addElement(new Integer(h)); | 
|---|
 | 477 |       } | 
|---|
 | 478 |     } | 
|---|
 | 479 |  | 
|---|
 | 480 |     /* filter the training data */ | 
|---|
 | 481 |     if (m_filterType == FILTER_STANDARDIZE)   | 
|---|
 | 482 |       m_Filter = new Standardize(); | 
|---|
 | 483 |     else if (m_filterType == FILTER_NORMALIZE) | 
|---|
 | 484 |       m_Filter = new Normalize(); | 
|---|
 | 485 |     else  | 
|---|
 | 486 |       m_Filter = null;  | 
|---|
 | 487 |  | 
|---|
 | 488 |     if (m_Filter!=null) { | 
|---|
 | 489 |       m_Filter.setInputFormat(datasets); | 
|---|
 | 490 |       datasets = Filter.useFilter(datasets, m_Filter);   | 
|---|
 | 491 |     } | 
|---|
 | 492 |  | 
|---|
 | 493 |     m_Missing.setInputFormat(datasets); | 
|---|
 | 494 |     datasets = Filter.useFilter(datasets, m_Missing); | 
|---|
 | 495 |  | 
|---|
 | 496 |     int instIndex=0; | 
|---|
 | 497 |     int start=0;         | 
|---|
 | 498 |     for(int h=0; h<nC; h++)  {   | 
|---|
 | 499 |       for (int i = 0; i < datasets.numAttributes(); i++) { | 
|---|
 | 500 |         // initialize m_data[][][] | 
|---|
 | 501 |         m_Data[h][i] = new double[bagSize[h]]; | 
|---|
 | 502 |         instIndex=start; | 
|---|
 | 503 |         for (int k=0; k<bagSize[h]; k++){ | 
|---|
 | 504 |           m_Data[h][i][k]=datasets.instance(instIndex).value(i); | 
|---|
 | 505 |           instIndex ++; | 
|---|
 | 506 |         } | 
|---|
 | 507 |       } | 
|---|
 | 508 |       start=instIndex; | 
|---|
 | 509 |  | 
|---|
 | 510 |       // Class count     | 
|---|
 | 511 |       if (m_Classes[h] == 1) | 
|---|
 | 512 |         sY1++;           | 
|---|
 | 513 |       else | 
|---|
 | 514 |         sY0++; | 
|---|
 | 515 |     } | 
|---|
 | 516 |  | 
|---|
 | 517 |     if (m_Debug) { | 
|---|
 | 518 |       System.out.println("\nIteration History..." ); | 
|---|
 | 519 |     } | 
|---|
 | 520 |  | 
|---|
 | 521 |     double[] x = new double[nR*2], tmp = new double[x.length]; | 
|---|
 | 522 |     double[][] b = new double[2][x.length];  | 
|---|
 | 523 |  | 
|---|
 | 524 |     OptEng opt; | 
|---|
 | 525 |     double nll, bestnll = Double.MAX_VALUE; | 
|---|
 | 526 |     for (int t=0; t<x.length; t++){ | 
|---|
 | 527 |       b[0][t] = Double.NaN; | 
|---|
 | 528 |       b[1][t] = Double.NaN;  | 
|---|
 | 529 |     } | 
|---|
 | 530 |  | 
|---|
 | 531 |     // Largest positive exemplar | 
|---|
 | 532 |     for(int s=0; s<maxSzIdx.size(); s++){ | 
|---|
 | 533 |       int exIdx = ((Integer)maxSzIdx.elementAt(s)).intValue(); | 
|---|
 | 534 |       for(int p=0; p<m_Data[exIdx][0].length; p++){ | 
|---|
 | 535 |         for (int q=0; q < nR;q++){ | 
|---|
 | 536 |           x[2*q] = m_Data[exIdx][q][p];  // pick one instance | 
|---|
 | 537 |           x[2*q+1] = 1.0; | 
|---|
 | 538 |         }                | 
|---|
 | 539 |  | 
|---|
 | 540 |         opt = new OptEng();      | 
|---|
 | 541 |         tmp = opt.findArgmin(x, b); | 
|---|
 | 542 |         while(tmp==null){ | 
|---|
 | 543 |           tmp = opt.getVarbValues(); | 
|---|
 | 544 |           if (m_Debug) | 
|---|
 | 545 |             System.out.println("200 iterations finished, not enough!"); | 
|---|
 | 546 |           tmp = opt.findArgmin(tmp, b); | 
|---|
 | 547 |         } | 
|---|
 | 548 |         nll = opt.getMinFunction(); | 
|---|
 | 549 |  | 
|---|
 | 550 |         if(nll < bestnll){ | 
|---|
 | 551 |           bestnll = nll; | 
|---|
 | 552 |           m_Par = tmp; | 
|---|
 | 553 |           if (m_Debug) | 
|---|
 | 554 |             System.out.println("!!!!!!!!!!!!!!!!Smaller NLL found: "+nll); | 
|---|
 | 555 |         } | 
|---|
 | 556 |         if (m_Debug) | 
|---|
 | 557 |           System.out.println(exIdx+":  -------------<Converged>--------------"); | 
|---|
 | 558 |       } | 
|---|
 | 559 |     } | 
|---|
 | 560 |     }            | 
|---|
 | 561 |  | 
|---|
 | 562 |   /** | 
|---|
 | 563 |    * Computes the distribution for a given exemplar | 
|---|
 | 564 |    * | 
|---|
 | 565 |    * @param exmp the exemplar for which distribution is computed | 
|---|
 | 566 |    * @return the distribution | 
|---|
 | 567 |    * @throws Exception if the distribution can't be computed successfully | 
|---|
 | 568 |    */ | 
|---|
 | 569 |   public double[] distributionForInstance(Instance exmp)  | 
|---|
 | 570 |     throws Exception { | 
|---|
 | 571 |  | 
|---|
 | 572 |     // Extract the data | 
|---|
 | 573 |     Instances ins = exmp.relationalValue(1); | 
|---|
 | 574 |     if(m_Filter!=null) | 
|---|
 | 575 |       ins = Filter.useFilter(ins, m_Filter); | 
|---|
 | 576 |  | 
|---|
 | 577 |     ins = Filter.useFilter(ins, m_Missing); | 
|---|
 | 578 |  | 
|---|
 | 579 |     int nI = ins.numInstances(), nA = ins.numAttributes(); | 
|---|
 | 580 |     double[][] dat = new double [nI][nA]; | 
|---|
 | 581 |     for(int j=0; j<nI; j++){ | 
|---|
 | 582 |       for(int k=0; k<nA; k++){  | 
|---|
 | 583 |         dat[j][k] = ins.instance(j).value(k); | 
|---|
 | 584 |       } | 
|---|
 | 585 |     } | 
|---|
 | 586 |  | 
|---|
 | 587 |     // Compute the probability of the bag | 
|---|
 | 588 |     double [] distribution = new double[2]; | 
|---|
 | 589 |     distribution[1]=0.0;  // Prob. for class 1 | 
|---|
 | 590 |  | 
|---|
 | 591 |     for(int i=0; i<nI; i++){ | 
|---|
 | 592 |       double exp = 0.0; | 
|---|
 | 593 |       for(int r=0; r<nA; r++) | 
|---|
 | 594 |         exp += (m_Par[r*2]-dat[i][r])*(m_Par[r*2]-dat[i][r])/ | 
|---|
 | 595 |           ((m_Par[r*2+1])*(m_Par[r*2+1])); | 
|---|
 | 596 |       exp = Math.exp(-exp); | 
|---|
 | 597 |  | 
|---|
 | 598 |       // Prob. updated for one instance | 
|---|
 | 599 |       distribution[1] += exp/(double)nI; | 
|---|
 | 600 |       distribution[0] += (1.0-exp)/(double)nI; | 
|---|
 | 601 |     } | 
|---|
 | 602 |  | 
|---|
 | 603 |     return distribution; | 
|---|
 | 604 |   } | 
|---|
 | 605 |  | 
|---|
 | 606 |   /** | 
|---|
 | 607 |    * Gets a string describing the classifier. | 
|---|
 | 608 |    * | 
|---|
 | 609 |    * @return a string describing the classifer built. | 
|---|
 | 610 |    */ | 
|---|
 | 611 |   public String toString() { | 
|---|
 | 612 |  | 
|---|
 | 613 |     String result = "Modified Logistic Regression"; | 
|---|
 | 614 |     if (m_Par == null) { | 
|---|
 | 615 |       return result + ": No model built yet."; | 
|---|
 | 616 |     } | 
|---|
 | 617 |  | 
|---|
 | 618 |     result += "\nCoefficients...\n" | 
|---|
 | 619 |       + "Variable      Coeff.\n"; | 
|---|
 | 620 |     for (int j = 0, idx=0; j < m_Par.length/2; j++, idx++) { | 
|---|
 | 621 |  | 
|---|
 | 622 |       result += m_Attributes.attribute(idx).name(); | 
|---|
 | 623 |       result += " "+Utils.doubleToString(m_Par[j*2], 12, 4);  | 
|---|
 | 624 |       result += " "+Utils.doubleToString(m_Par[j*2+1], 12, 4)+"\n"; | 
|---|
 | 625 |     } | 
|---|
 | 626 |  | 
|---|
 | 627 |     return result; | 
|---|
 | 628 |   } | 
|---|
 | 629 |    | 
|---|
 | 630 |   /** | 
|---|
 | 631 |    * Returns the revision string. | 
|---|
 | 632 |    *  | 
|---|
 | 633 |    * @return            the revision | 
|---|
 | 634 |    */ | 
|---|
 | 635 |   public String getRevision() { | 
|---|
 | 636 |     return RevisionUtils.extract("$Revision: 5928 $"); | 
|---|
 | 637 |   } | 
|---|
 | 638 |  | 
|---|
 | 639 |   /** | 
|---|
 | 640 |    * Main method for testing this class. | 
|---|
 | 641 |    * | 
|---|
 | 642 |    * @param argv should contain the command line arguments to the | 
|---|
 | 643 |    * scheme (see Evaluation) | 
|---|
 | 644 |    */ | 
|---|
 | 645 |   public static void main(String[] argv) { | 
|---|
 | 646 |     runClassifier(new MDD(), argv); | 
|---|
 | 647 |   } | 
|---|
 | 648 | } | 
|---|