| 1 | /* | 
|---|
| 2 |  *    This program is free software; you can redistribute it and/or modify | 
|---|
| 3 |  *    it under the terms of the GNU General Public License as published by | 
|---|
| 4 |  *    the Free Software Foundation; either version 2 of the License, or | 
|---|
| 5 |  *    (at your option) any later version. | 
|---|
| 6 |  * | 
|---|
| 7 |  *    This program is distributed in the hope that it will be useful, | 
|---|
| 8 |  *    but WITHOUT ANY WARRANTY; without even the implied warranty of | 
|---|
| 9 |  *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the | 
|---|
| 10 |  *    GNU General Public License for more details. | 
|---|
| 11 |  * | 
|---|
| 12 |  *    You should have received a copy of the GNU General Public License | 
|---|
| 13 |  *    along with this program; if not, write to the Free Software | 
|---|
| 14 |  *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA. | 
|---|
| 15 |  */ | 
|---|
| 16 |  | 
|---|
| 17 | /* | 
|---|
| 18 |  * PLSFilter.java | 
|---|
| 19 |  * Copyright (C) 2006 University of Waikato, Hamilton, New Zealand | 
|---|
| 20 |  * | 
|---|
| 21 |  */ | 
|---|
| 22 |  | 
|---|
| 23 | package weka.filters.supervised.attribute; | 
|---|
| 24 |  | 
|---|
| 25 | import weka.core.Attribute; | 
|---|
| 26 | import weka.core.Capabilities; | 
|---|
| 27 | import weka.core.FastVector; | 
|---|
| 28 | import weka.core.Instance; | 
|---|
| 29 | import weka.core.DenseInstance; | 
|---|
| 30 | import weka.core.Instances; | 
|---|
| 31 | import weka.core.Option; | 
|---|
| 32 | import weka.core.RevisionUtils; | 
|---|
| 33 | import weka.core.SelectedTag; | 
|---|
| 34 | import weka.core.Tag; | 
|---|
| 35 | import weka.core.TechnicalInformation; | 
|---|
| 36 | import weka.core.TechnicalInformationHandler; | 
|---|
| 37 | import weka.core.Utils; | 
|---|
| 38 | import weka.core.Capabilities.Capability; | 
|---|
| 39 | import weka.core.TechnicalInformation.Field; | 
|---|
| 40 | import weka.core.TechnicalInformation.Type; | 
|---|
| 41 | import weka.core.matrix.EigenvalueDecomposition; | 
|---|
| 42 | import weka.core.matrix.Matrix; | 
|---|
| 43 | import weka.filters.Filter; | 
|---|
| 44 | import weka.filters.SimpleBatchFilter; | 
|---|
| 45 | import weka.filters.SupervisedFilter; | 
|---|
| 46 | import weka.filters.unsupervised.attribute.Center; | 
|---|
| 47 | import weka.filters.unsupervised.attribute.ReplaceMissingValues; | 
|---|
| 48 | import weka.filters.unsupervised.attribute.Standardize; | 
|---|
| 49 |  | 
|---|
| 50 | import java.util.Enumeration; | 
|---|
| 51 | import java.util.Vector; | 
|---|
| 52 |  | 
|---|
| 53 | /**  | 
|---|
| 54 |  <!-- globalinfo-start --> | 
|---|
| 55 |  * Runs Partial Least Square Regression over the given instances and computes the resulting beta matrix for prediction.<br/> | 
|---|
| 56 |  * By default it replaces missing values and centers the data.<br/> | 
|---|
| 57 |  * <br/> | 
|---|
| 58 |  * For more information see:<br/> | 
|---|
| 59 |  * <br/> | 
|---|
| 60 |  * Tormod Naes, Tomas Isaksson, Tom Fearn, Tony Davies (2002). A User Friendly Guide to Multivariate Calibration and Classification. NIR Publications.<br/> | 
|---|
| 61 |  * <br/> | 
|---|
| 62 |  * StatSoft, Inc.. Partial Least Squares (PLS).<br/> | 
|---|
| 63 |  * <br/> | 
|---|
| 64 |  * Bent Jorgensen, Yuri Goegebeur. Module 7: Partial least squares regression I.<br/> | 
|---|
| 65 |  * <br/> | 
|---|
| 66 |  * S. de Jong (1993). SIMPLS: an alternative approach to partial least squares regression. Chemometrics and Intelligent Laboratory Systems. 18:251-263. | 
|---|
| 67 |  * <p/> | 
|---|
| 68 |  <!-- globalinfo-end --> | 
|---|
| 69 |  * | 
|---|
| 70 |  <!-- technical-bibtex-start --> | 
|---|
| 71 |  * BibTeX: | 
|---|
| 72 |  * <pre> | 
|---|
| 73 |  * @book{Naes2002, | 
|---|
| 74 |  *    author = {Tormod Naes and Tomas Isaksson and Tom Fearn and Tony Davies}, | 
|---|
| 75 |  *    publisher = {NIR Publications}, | 
|---|
| 76 |  *    title = {A User Friendly Guide to Multivariate Calibration and Classification}, | 
|---|
| 77 |  *    year = {2002}, | 
|---|
| 78 |  *    ISBN = {0-9528666-2-5} | 
|---|
| 79 |  * } | 
|---|
| 80 |  *  | 
|---|
| 81 |  * @misc{missing_id, | 
|---|
| 82 |  *    author = {StatSoft, Inc.}, | 
|---|
| 83 |  *    booktitle = {Electronic Textbook StatSoft}, | 
|---|
| 84 |  *    title = {Partial Least Squares (PLS)}, | 
|---|
| 85 |  *    HTTP = {http://www.statsoft.com/textbook/stpls.html} | 
|---|
| 86 |  * } | 
|---|
| 87 |  *  | 
|---|
| 88 |  * @misc{missing_id, | 
|---|
| 89 |  *    author = {Bent Jorgensen and Yuri Goegebeur}, | 
|---|
| 90 |  *    booktitle = {ST02: Multivariate Data Analysis and Chemometrics}, | 
|---|
| 91 |  *    title = {Module 7: Partial least squares regression I}, | 
|---|
| 92 |  *    HTTP = {http://statmaster.sdu.dk/courses/ST02/module07/} | 
|---|
| 93 |  * } | 
|---|
| 94 |  *  | 
|---|
| 95 |  * @article{Jong1993, | 
|---|
| 96 |  *    author = {S. de Jong}, | 
|---|
| 97 |  *    journal = {Chemometrics and Intelligent Laboratory Systems}, | 
|---|
| 98 |  *    pages = {251-263}, | 
|---|
| 99 |  *    title = {SIMPLS: an alternative approach to partial least squares regression}, | 
|---|
| 100 |  *    volume = {18}, | 
|---|
| 101 |  *    year = {1993} | 
|---|
| 102 |  * } | 
|---|
| 103 |  * </pre> | 
|---|
| 104 |  * <p/> | 
|---|
| 105 |  <!-- technical-bibtex-end --> | 
|---|
| 106 |  * | 
|---|
| 107 |  <!-- options-start --> | 
|---|
| 108 |  * Valid options are: <p/> | 
|---|
| 109 |  *  | 
|---|
| 110 |  * <pre> -D | 
|---|
| 111 |  *  Turns on output of debugging information.</pre> | 
|---|
| 112 |  *  | 
|---|
| 113 |  * <pre> -C <num> | 
|---|
| 114 |  *  The number of components to compute. | 
|---|
| 115 |  *  (default: 20)</pre> | 
|---|
| 116 |  *  | 
|---|
| 117 |  * <pre> -U | 
|---|
| 118 |  *  Updates the class attribute as well. | 
|---|
| 119 |  *  (default: off)</pre> | 
|---|
| 120 |  *  | 
|---|
| 121 |  * <pre> -M | 
|---|
| 122 |  *  Turns replacing of missing values on. | 
|---|
| 123 |  *  (default: off)</pre> | 
|---|
| 124 |  *  | 
|---|
| 125 |  * <pre> -A <SIMPLS|PLS1> | 
|---|
| 126 |  *  The algorithm to use. | 
|---|
| 127 |  *  (default: PLS1)</pre> | 
|---|
| 128 |  *  | 
|---|
| 129 |  * <pre> -P <none|center|standardize> | 
|---|
| 130 |  *  The type of preprocessing that is applied to the data. | 
|---|
| 131 |  *  (default: center)</pre> | 
|---|
| 132 |  *  | 
|---|
| 133 |  <!-- options-end --> | 
|---|
| 134 |  * | 
|---|
| 135 |  * @author FracPete (fracpete at waikato dot ac dot nz) | 
|---|
| 136 |  * @version $Revision: 5987 $ | 
|---|
| 137 |  */ | 
|---|
| 138 | public class PLSFilter | 
|---|
| 139 |   extends SimpleBatchFilter  | 
|---|
| 140 |   implements SupervisedFilter, TechnicalInformationHandler { | 
|---|
| 141 |  | 
|---|
| 142 |   /** for serialization */ | 
|---|
| 143 |   static final long serialVersionUID = -3335106965521265631L; | 
|---|
| 144 |  | 
|---|
| 145 |   /** the type of algorithm: SIMPLS */ | 
|---|
| 146 |   public static final int ALGORITHM_SIMPLS = 1; | 
|---|
| 147 |   /** the type of algorithm: PLS1 */ | 
|---|
| 148 |   public static final int ALGORITHM_PLS1 = 2; | 
|---|
| 149 |   /** the types of algorithm */ | 
|---|
| 150 |   public static final Tag[] TAGS_ALGORITHM = { | 
|---|
| 151 |     new Tag(ALGORITHM_SIMPLS, "SIMPLS"), | 
|---|
| 152 |     new Tag(ALGORITHM_PLS1, "PLS1") | 
|---|
| 153 |   }; | 
|---|
| 154 |  | 
|---|
| 155 |   /** the type of preprocessing: None */ | 
|---|
| 156 |   public static final int PREPROCESSING_NONE = 0; | 
|---|
| 157 |   /** the type of preprocessing: Center */ | 
|---|
| 158 |   public static final int PREPROCESSING_CENTER = 1; | 
|---|
| 159 |   /** the type of preprocessing: Standardize */ | 
|---|
| 160 |   public static final int PREPROCESSING_STANDARDIZE = 2; | 
|---|
| 161 |   /** the types of preprocessing */ | 
|---|
| 162 |   public static final Tag[] TAGS_PREPROCESSING = { | 
|---|
| 163 |     new Tag(PREPROCESSING_NONE, "none"), | 
|---|
| 164 |     new Tag(PREPROCESSING_CENTER, "center"), | 
|---|
| 165 |     new Tag(PREPROCESSING_STANDARDIZE, "standardize") | 
|---|
| 166 |   }; | 
|---|
| 167 |  | 
|---|
| 168 |   /** the maximum number of components to generate */ | 
|---|
| 169 |   protected int m_NumComponents = 20; | 
|---|
| 170 |    | 
|---|
| 171 |   /** the type of algorithm */ | 
|---|
| 172 |   protected int m_Algorithm = ALGORITHM_PLS1; | 
|---|
| 173 |  | 
|---|
| 174 |   /** the regression vector "r-hat" for PLS1 */ | 
|---|
| 175 |   protected Matrix m_PLS1_RegVector = null; | 
|---|
| 176 |  | 
|---|
| 177 |   /** the P matrix for PLS1 */ | 
|---|
| 178 |   protected Matrix m_PLS1_P = null; | 
|---|
| 179 |  | 
|---|
| 180 |   /** the W matrix for PLS1 */ | 
|---|
| 181 |   protected Matrix m_PLS1_W = null; | 
|---|
| 182 |  | 
|---|
| 183 |   /** the b-hat vector for PLS1 */ | 
|---|
| 184 |   protected Matrix m_PLS1_b_hat = null; | 
|---|
| 185 |    | 
|---|
| 186 |   /** the W matrix for SIMPLS */ | 
|---|
| 187 |   protected Matrix m_SIMPLS_W = null; | 
|---|
| 188 |    | 
|---|
| 189 |   /** the B matrix for SIMPLS (used for prediction) */ | 
|---|
| 190 |   protected Matrix m_SIMPLS_B = null; | 
|---|
| 191 |    | 
|---|
| 192 |   /** whether to include the prediction, i.e., modifying the class attribute */ | 
|---|
| 193 |   protected boolean m_PerformPrediction = false; | 
|---|
| 194 |  | 
|---|
| 195 |   /** for replacing missing values */ | 
|---|
| 196 |   protected Filter m_Missing = null; | 
|---|
| 197 |    | 
|---|
| 198 |   /** whether to replace missing values */ | 
|---|
| 199 |   protected boolean m_ReplaceMissing = true; | 
|---|
| 200 |    | 
|---|
| 201 |   /** for centering the data */ | 
|---|
| 202 |   protected Filter m_Filter = null; | 
|---|
| 203 |    | 
|---|
| 204 |   /** the type of preprocessing */ | 
|---|
| 205 |   protected int m_Preprocessing = PREPROCESSING_CENTER; | 
|---|
| 206 |  | 
|---|
| 207 |   /** the mean of the class */ | 
|---|
| 208 |   protected double m_ClassMean = 0; | 
|---|
| 209 |  | 
|---|
| 210 |   /** the standard deviation of the class */ | 
|---|
| 211 |   protected double m_ClassStdDev = 0; | 
|---|
| 212 |    | 
|---|
| 213 |   /** | 
|---|
| 214 |    * default constructor | 
|---|
| 215 |    */ | 
|---|
| 216 |   public PLSFilter() { | 
|---|
| 217 |     super(); | 
|---|
| 218 |      | 
|---|
| 219 |     // setup pre-processing | 
|---|
| 220 |     m_Missing = new ReplaceMissingValues(); | 
|---|
| 221 |     m_Filter  = new Center(); | 
|---|
| 222 |   } | 
|---|
| 223 |    | 
|---|
| 224 |   /** | 
|---|
| 225 |    * Returns a string describing this classifier. | 
|---|
| 226 |    * | 
|---|
| 227 |    * @return      a description of the classifier suitable for | 
|---|
| 228 |    *              displaying in the explorer/experimenter gui | 
|---|
| 229 |    */ | 
|---|
| 230 |   public String globalInfo() { | 
|---|
| 231 |     return  | 
|---|
| 232 |         "Runs Partial Least Square Regression over the given instances " | 
|---|
| 233 |       + "and computes the resulting beta matrix for prediction.\n" | 
|---|
| 234 |       + "By default it replaces missing values and centers the data.\n\n" | 
|---|
| 235 |       + "For more information see:\n\n" | 
|---|
| 236 |       + getTechnicalInformation().toString(); | 
|---|
| 237 |   } | 
|---|
| 238 |  | 
|---|
| 239 |   /** | 
|---|
| 240 |    * Returns an instance of a TechnicalInformation object, containing  | 
|---|
| 241 |    * detailed information about the technical background of this class, | 
|---|
| 242 |    * e.g., paper reference or book this class is based on. | 
|---|
| 243 |    *  | 
|---|
| 244 |    * @return the technical information about this class | 
|---|
| 245 |    */ | 
|---|
| 246 |   public TechnicalInformation getTechnicalInformation() { | 
|---|
| 247 |     TechnicalInformation        result; | 
|---|
| 248 |     TechnicalInformation        additional; | 
|---|
| 249 |      | 
|---|
| 250 |     result = new TechnicalInformation(Type.BOOK); | 
|---|
| 251 |     result.setValue(Field.AUTHOR, "Tormod Naes and Tomas Isaksson and Tom Fearn and Tony Davies"); | 
|---|
| 252 |     result.setValue(Field.YEAR, "2002"); | 
|---|
| 253 |     result.setValue(Field.TITLE, "A User Friendly Guide to Multivariate Calibration and Classification"); | 
|---|
| 254 |     result.setValue(Field.PUBLISHER, "NIR Publications"); | 
|---|
| 255 |     result.setValue(Field.ISBN, "0-9528666-2-5"); | 
|---|
| 256 |      | 
|---|
| 257 |     additional = result.add(Type.MISC); | 
|---|
| 258 |     additional.setValue(Field.AUTHOR, "StatSoft, Inc."); | 
|---|
| 259 |     additional.setValue(Field.TITLE, "Partial Least Squares (PLS)"); | 
|---|
| 260 |     additional.setValue(Field.BOOKTITLE, "Electronic Textbook StatSoft"); | 
|---|
| 261 |     additional.setValue(Field.HTTP, "http://www.statsoft.com/textbook/stpls.html"); | 
|---|
| 262 |      | 
|---|
| 263 |     additional = result.add(Type.MISC); | 
|---|
| 264 |     additional.setValue(Field.AUTHOR, "Bent Jorgensen and Yuri Goegebeur"); | 
|---|
| 265 |     additional.setValue(Field.TITLE, "Module 7: Partial least squares regression I"); | 
|---|
| 266 |     additional.setValue(Field.BOOKTITLE, "ST02: Multivariate Data Analysis and Chemometrics"); | 
|---|
| 267 |     additional.setValue(Field.HTTP, "http://statmaster.sdu.dk/courses/ST02/module07/"); | 
|---|
| 268 |      | 
|---|
| 269 |     additional = result.add(Type.ARTICLE); | 
|---|
| 270 |     additional.setValue(Field.AUTHOR, "S. de Jong"); | 
|---|
| 271 |     additional.setValue(Field.YEAR, "1993"); | 
|---|
| 272 |     additional.setValue(Field.TITLE, "SIMPLS: an alternative approach to partial least squares regression"); | 
|---|
| 273 |     additional.setValue(Field.JOURNAL, "Chemometrics and Intelligent Laboratory Systems"); | 
|---|
| 274 |     additional.setValue(Field.VOLUME, "18"); | 
|---|
| 275 |     additional.setValue(Field.PAGES, "251-263"); | 
|---|
| 276 |      | 
|---|
| 277 |     return result; | 
|---|
| 278 |   } | 
|---|
| 279 |  | 
|---|
| 280 |   /** | 
|---|
| 281 |    * Gets an enumeration describing the available options. | 
|---|
| 282 |    * | 
|---|
| 283 |    * @return an enumeration of all the available options. | 
|---|
| 284 |    */ | 
|---|
| 285 |   public Enumeration listOptions() { | 
|---|
| 286 |     Vector              result; | 
|---|
| 287 |     Enumeration         enm; | 
|---|
| 288 |     String              param; | 
|---|
| 289 |     SelectedTag         tag; | 
|---|
| 290 |     int                 i; | 
|---|
| 291 |  | 
|---|
| 292 |     result = new Vector(); | 
|---|
| 293 |  | 
|---|
| 294 |     enm = super.listOptions(); | 
|---|
| 295 |     while (enm.hasMoreElements()) | 
|---|
| 296 |       result.addElement(enm.nextElement()); | 
|---|
| 297 |  | 
|---|
| 298 |     result.addElement(new Option( | 
|---|
| 299 |         "\tThe number of components to compute.\n" | 
|---|
| 300 |         + "\t(default: 20)", | 
|---|
| 301 |         "C", 1, "-C <num>")); | 
|---|
| 302 |  | 
|---|
| 303 |     result.addElement(new Option( | 
|---|
| 304 |         "\tUpdates the class attribute as well.\n" | 
|---|
| 305 |         + "\t(default: off)", | 
|---|
| 306 |         "U", 0, "-U")); | 
|---|
| 307 |  | 
|---|
| 308 |     result.addElement(new Option( | 
|---|
| 309 |         "\tTurns replacing of missing values on.\n" | 
|---|
| 310 |         + "\t(default: off)", | 
|---|
| 311 |         "M", 0, "-M")); | 
|---|
| 312 |  | 
|---|
| 313 |     param = ""; | 
|---|
| 314 |     for (i = 0; i < TAGS_ALGORITHM.length; i++) { | 
|---|
| 315 |       if (i > 0) | 
|---|
| 316 |         param += "|"; | 
|---|
| 317 |       tag = new SelectedTag(TAGS_ALGORITHM[i].getID(), TAGS_ALGORITHM); | 
|---|
| 318 |       param += tag.getSelectedTag().getReadable(); | 
|---|
| 319 |     } | 
|---|
| 320 |     result.addElement(new Option( | 
|---|
| 321 |         "\tThe algorithm to use.\n" | 
|---|
| 322 |         + "\t(default: PLS1)", | 
|---|
| 323 |         "A", 1, "-A <" + param + ">")); | 
|---|
| 324 |  | 
|---|
| 325 |     param = ""; | 
|---|
| 326 |     for (i = 0; i < TAGS_PREPROCESSING.length; i++) { | 
|---|
| 327 |       if (i > 0) | 
|---|
| 328 |         param += "|"; | 
|---|
| 329 |       tag = new SelectedTag(TAGS_PREPROCESSING[i].getID(), TAGS_PREPROCESSING); | 
|---|
| 330 |       param += tag.getSelectedTag().getReadable(); | 
|---|
| 331 |     } | 
|---|
| 332 |     result.addElement(new Option( | 
|---|
| 333 |         "\tThe type of preprocessing that is applied to the data.\n" | 
|---|
| 334 |         + "\t(default: center)", | 
|---|
| 335 |         "P", 1, "-P <" + param + ">")); | 
|---|
| 336 |  | 
|---|
| 337 |     return result.elements(); | 
|---|
| 338 |   } | 
|---|
| 339 |  | 
|---|
| 340 |   /** | 
|---|
| 341 |    * returns the options of the current setup | 
|---|
| 342 |    * | 
|---|
| 343 |    * @return      the current options | 
|---|
| 344 |    */ | 
|---|
| 345 |   public String[] getOptions() { | 
|---|
| 346 |     int       i; | 
|---|
| 347 |     Vector    result; | 
|---|
| 348 |     String[]  options; | 
|---|
| 349 |  | 
|---|
| 350 |     result = new Vector(); | 
|---|
| 351 |     options = super.getOptions(); | 
|---|
| 352 |     for (i = 0; i < options.length; i++) | 
|---|
| 353 |       result.add(options[i]); | 
|---|
| 354 |  | 
|---|
| 355 |     result.add("-C"); | 
|---|
| 356 |     result.add("" + getNumComponents()); | 
|---|
| 357 |  | 
|---|
| 358 |     if (getPerformPrediction()) | 
|---|
| 359 |       result.add("-U"); | 
|---|
| 360 |      | 
|---|
| 361 |     if (getReplaceMissing()) | 
|---|
| 362 |       result.add("-M"); | 
|---|
| 363 |      | 
|---|
| 364 |     result.add("-A"); | 
|---|
| 365 |     result.add("" + getAlgorithm().getSelectedTag().getReadable()); | 
|---|
| 366 |  | 
|---|
| 367 |     result.add("-P"); | 
|---|
| 368 |     result.add("" + getPreprocessing().getSelectedTag().getReadable()); | 
|---|
| 369 |  | 
|---|
| 370 |     return (String[]) result.toArray(new String[result.size()]);           | 
|---|
| 371 |   } | 
|---|
| 372 |  | 
|---|
| 373 |   /** | 
|---|
| 374 |    * Parses the options for this object. <p/> | 
|---|
| 375 |    * | 
|---|
| 376 |    <!-- options-start --> | 
|---|
| 377 |    * Valid options are: <p/> | 
|---|
| 378 |    *  | 
|---|
| 379 |    * <pre> -D | 
|---|
| 380 |    *  Turns on output of debugging information.</pre> | 
|---|
| 381 |    *  | 
|---|
| 382 |    * <pre> -C <num> | 
|---|
| 383 |    *  The number of components to compute. | 
|---|
| 384 |    *  (default: 20)</pre> | 
|---|
| 385 |    *  | 
|---|
| 386 |    * <pre> -U | 
|---|
| 387 |    *  Updates the class attribute as well. | 
|---|
| 388 |    *  (default: off)</pre> | 
|---|
| 389 |    *  | 
|---|
| 390 |    * <pre> -M | 
|---|
| 391 |    *  Turns replacing of missing values on. | 
|---|
| 392 |    *  (default: off)</pre> | 
|---|
| 393 |    *  | 
|---|
| 394 |    * <pre> -A <SIMPLS|PLS1> | 
|---|
| 395 |    *  The algorithm to use. | 
|---|
| 396 |    *  (default: PLS1)</pre> | 
|---|
| 397 |    *  | 
|---|
| 398 |    * <pre> -P <none|center|standardize> | 
|---|
| 399 |    *  The type of preprocessing that is applied to the data. | 
|---|
| 400 |    *  (default: center)</pre> | 
|---|
| 401 |    *  | 
|---|
| 402 |    <!-- options-end --> | 
|---|
| 403 |    * | 
|---|
| 404 |    * @param options     the options to use | 
|---|
| 405 |    * @throws Exception  if the option setting fails | 
|---|
| 406 |    */ | 
|---|
| 407 |   public void setOptions(String[] options) throws Exception { | 
|---|
| 408 |     String      tmpStr; | 
|---|
| 409 |  | 
|---|
| 410 |     super.setOptions(options); | 
|---|
| 411 |  | 
|---|
| 412 |     tmpStr = Utils.getOption("C", options); | 
|---|
| 413 |     if (tmpStr.length() != 0) | 
|---|
| 414 |       setNumComponents(Integer.parseInt(tmpStr)); | 
|---|
| 415 |     else | 
|---|
| 416 |       setNumComponents(20); | 
|---|
| 417 |  | 
|---|
| 418 |     setPerformPrediction(Utils.getFlag("U", options)); | 
|---|
| 419 |      | 
|---|
| 420 |     setReplaceMissing(Utils.getFlag("M", options)); | 
|---|
| 421 |      | 
|---|
| 422 |     tmpStr = Utils.getOption("A", options); | 
|---|
| 423 |     if (tmpStr.length() != 0) | 
|---|
| 424 |       setAlgorithm(new SelectedTag(tmpStr, TAGS_ALGORITHM)); | 
|---|
| 425 |     else | 
|---|
| 426 |       setAlgorithm(new SelectedTag(ALGORITHM_PLS1, TAGS_ALGORITHM)); | 
|---|
| 427 |      | 
|---|
| 428 |     tmpStr = Utils.getOption("P", options); | 
|---|
| 429 |     if (tmpStr.length() != 0) | 
|---|
| 430 |       setPreprocessing(new SelectedTag(tmpStr, TAGS_PREPROCESSING)); | 
|---|
| 431 |     else | 
|---|
| 432 |       setPreprocessing(new SelectedTag(PREPROCESSING_CENTER, TAGS_PREPROCESSING)); | 
|---|
| 433 |   } | 
|---|
| 434 |  | 
|---|
| 435 |   /** | 
|---|
| 436 |    * Returns the tip text for this property | 
|---|
| 437 |    * | 
|---|
| 438 |    * @return            tip text for this property suitable for | 
|---|
| 439 |    *                    displaying in the explorer/experimenter gui | 
|---|
| 440 |    */ | 
|---|
| 441 |   public String numComponentsTipText() { | 
|---|
| 442 |     return "The number of components to compute."; | 
|---|
| 443 |   } | 
|---|
| 444 |  | 
|---|
| 445 |   /** | 
|---|
| 446 |    * sets the maximum number of attributes to use. | 
|---|
| 447 |    *  | 
|---|
| 448 |    * @param value       the maximum number of attributes | 
|---|
| 449 |    */ | 
|---|
| 450 |   public void setNumComponents(int value) { | 
|---|
| 451 |     m_NumComponents = value; | 
|---|
| 452 |   } | 
|---|
| 453 |  | 
|---|
| 454 |   /** | 
|---|
| 455 |    * returns the maximum number of attributes to use. | 
|---|
| 456 |    *  | 
|---|
| 457 |    * @return            the current maximum number of attributes | 
|---|
| 458 |    */ | 
|---|
| 459 |   public int getNumComponents() { | 
|---|
| 460 |     return m_NumComponents; | 
|---|
| 461 |   } | 
|---|
| 462 |  | 
|---|
| 463 |   /** | 
|---|
| 464 |    * Returns the tip text for this property | 
|---|
| 465 |    * | 
|---|
| 466 |    * @return            tip text for this property suitable for | 
|---|
| 467 |    *                    displaying in the explorer/experimenter gui | 
|---|
| 468 |    */ | 
|---|
| 469 |   public String performPredictionTipText() { | 
|---|
| 470 |     return "Whether to update the class attribute with the predicted value."; | 
|---|
| 471 |   } | 
|---|
| 472 |  | 
|---|
| 473 |   /** | 
|---|
| 474 |    * Sets whether to update the class attribute with the predicted value. | 
|---|
| 475 |    *  | 
|---|
| 476 |    * @param value       if true the class value will be replaced by the  | 
|---|
| 477 |    *                    predicted value. | 
|---|
| 478 |    */ | 
|---|
| 479 |   public void setPerformPrediction(boolean value) { | 
|---|
| 480 |     m_PerformPrediction = value; | 
|---|
| 481 |   } | 
|---|
| 482 |  | 
|---|
| 483 |   /** | 
|---|
| 484 |    * Gets whether the class attribute is updated with the predicted value. | 
|---|
| 485 |    *  | 
|---|
| 486 |    * @return            true if the class attribute is updated | 
|---|
| 487 |    */ | 
|---|
| 488 |   public boolean getPerformPrediction() { | 
|---|
| 489 |     return m_PerformPrediction; | 
|---|
| 490 |   } | 
|---|
| 491 |  | 
|---|
| 492 |   /** | 
|---|
| 493 |    * Returns the tip text for this property | 
|---|
| 494 |    *  | 
|---|
| 495 |    * @return            tip text for this property suitable for | 
|---|
| 496 |    *                    displaying in the explorer/experimenter gui | 
|---|
| 497 |    */ | 
|---|
| 498 |   public String algorithmTipText() { | 
|---|
| 499 |     return "Sets the type of algorithm to use."; | 
|---|
| 500 |   } | 
|---|
| 501 |  | 
|---|
| 502 |   /** | 
|---|
| 503 |    * Sets the type of algorithm to use  | 
|---|
| 504 |    * | 
|---|
| 505 |    * @param value       the algorithm type | 
|---|
| 506 |    */ | 
|---|
| 507 |   public void setAlgorithm(SelectedTag value) { | 
|---|
| 508 |     if (value.getTags() == TAGS_ALGORITHM) { | 
|---|
| 509 |       m_Algorithm = value.getSelectedTag().getID(); | 
|---|
| 510 |     } | 
|---|
| 511 |   } | 
|---|
| 512 |  | 
|---|
| 513 |   /** | 
|---|
| 514 |    * Gets the type of algorithm to use  | 
|---|
| 515 |    * | 
|---|
| 516 |    * @return            the current algorithm type. | 
|---|
| 517 |    */ | 
|---|
| 518 |   public SelectedTag getAlgorithm() { | 
|---|
| 519 |     return new SelectedTag(m_Algorithm, TAGS_ALGORITHM); | 
|---|
| 520 |   } | 
|---|
| 521 |  | 
|---|
| 522 |   /** | 
|---|
| 523 |    * Returns the tip text for this property | 
|---|
| 524 |    * | 
|---|
| 525 |    * @return            tip text for this property suitable for | 
|---|
| 526 |    *                    displaying in the explorer/experimenter gui | 
|---|
| 527 |    */ | 
|---|
| 528 |   public String replaceMissingTipText() { | 
|---|
| 529 |     return "Whether to replace missing values."; | 
|---|
| 530 |   } | 
|---|
| 531 |  | 
|---|
| 532 |   /** | 
|---|
| 533 |    * Sets whether to replace missing values. | 
|---|
| 534 |    *  | 
|---|
| 535 |    * @param value       if true missing values are replaced with the | 
|---|
| 536 |    *                    ReplaceMissingValues filter. | 
|---|
| 537 |    */ | 
|---|
| 538 |   public void setReplaceMissing(boolean value) { | 
|---|
| 539 |     m_ReplaceMissing = value; | 
|---|
| 540 |   } | 
|---|
| 541 |  | 
|---|
| 542 |   /** | 
|---|
| 543 |    * Gets whether missing values are replace. | 
|---|
| 544 |    *  | 
|---|
| 545 |    * @return            true if missing values are replaced with the  | 
|---|
| 546 |    *                    ReplaceMissingValues filter | 
|---|
| 547 |    */ | 
|---|
| 548 |   public boolean getReplaceMissing() { | 
|---|
| 549 |     return m_ReplaceMissing; | 
|---|
| 550 |   } | 
|---|
| 551 |  | 
|---|
| 552 |   /** | 
|---|
| 553 |    * Returns the tip text for this property | 
|---|
| 554 |    *  | 
|---|
| 555 |    * @return            tip text for this property suitable for | 
|---|
| 556 |    *                    displaying in the explorer/experimenter gui | 
|---|
| 557 |    */ | 
|---|
| 558 |   public String preprocessingTipText() { | 
|---|
| 559 |     return "Sets the type of preprocessing to use."; | 
|---|
| 560 |   } | 
|---|
| 561 |  | 
|---|
| 562 |   /** | 
|---|
| 563 |    * Sets the type of preprocessing to use  | 
|---|
| 564 |    * | 
|---|
| 565 |    * @param value       the preprocessing type | 
|---|
| 566 |    */ | 
|---|
| 567 |   public void setPreprocessing(SelectedTag value) { | 
|---|
| 568 |     if (value.getTags() == TAGS_PREPROCESSING) { | 
|---|
| 569 |       m_Preprocessing = value.getSelectedTag().getID(); | 
|---|
| 570 |     } | 
|---|
| 571 |   } | 
|---|
| 572 |  | 
|---|
| 573 |   /** | 
|---|
| 574 |    * Gets the type of preprocessing to use  | 
|---|
| 575 |    * | 
|---|
| 576 |    * @return            the current preprocessing type. | 
|---|
| 577 |    */ | 
|---|
| 578 |   public SelectedTag getPreprocessing() { | 
|---|
| 579 |     return new SelectedTag(m_Preprocessing, TAGS_PREPROCESSING); | 
|---|
| 580 |   } | 
|---|
| 581 |  | 
|---|
| 582 |   /** | 
|---|
| 583 |    * Determines the output format based on the input format and returns  | 
|---|
| 584 |    * this. In case the output format cannot be returned immediately, i.e., | 
|---|
| 585 |    * immediateOutputFormat() returns false, then this method will be called | 
|---|
| 586 |    * from batchFinished(). | 
|---|
| 587 |    * | 
|---|
| 588 |    * @param inputFormat     the input format to base the output format on | 
|---|
| 589 |    * @return                the output format | 
|---|
| 590 |    * @throws Exception      in case the determination goes wrong | 
|---|
| 591 |    * @see   #hasImmediateOutputFormat() | 
|---|
| 592 |    * @see   #batchFinished() | 
|---|
| 593 |    */ | 
|---|
| 594 |   protected Instances determineOutputFormat(Instances inputFormat)  | 
|---|
| 595 |     throws Exception { | 
|---|
| 596 |  | 
|---|
| 597 |     // generate header | 
|---|
| 598 |     FastVector atts = new FastVector(); | 
|---|
| 599 |     String prefix = getAlgorithm().getSelectedTag().getReadable(); | 
|---|
| 600 |     for (int i = 0; i < getNumComponents(); i++) | 
|---|
| 601 |       atts.addElement(new Attribute(prefix + "_" + (i+1))); | 
|---|
| 602 |     atts.addElement(new Attribute("Class")); | 
|---|
| 603 |     Instances result = new Instances(prefix, atts, 0); | 
|---|
| 604 |     result.setClassIndex(result.numAttributes() - 1); | 
|---|
| 605 |      | 
|---|
| 606 |     return result; | 
|---|
| 607 |   } | 
|---|
| 608 |    | 
|---|
| 609 |   /** | 
|---|
| 610 |    * returns the data minus the class column as matrix | 
|---|
| 611 |    *  | 
|---|
| 612 |    * @param instances   the data to work on | 
|---|
| 613 |    * @return            the data without class attribute | 
|---|
| 614 |    */ | 
|---|
| 615 |   protected Matrix getX(Instances instances) { | 
|---|
| 616 |     double[][]  x; | 
|---|
| 617 |     double[]    values; | 
|---|
| 618 |     Matrix      result; | 
|---|
| 619 |     int         i; | 
|---|
| 620 |     int         n; | 
|---|
| 621 |     int         j; | 
|---|
| 622 |     int         clsIndex; | 
|---|
| 623 |      | 
|---|
| 624 |     clsIndex = instances.classIndex(); | 
|---|
| 625 |     x        = new double[instances.numInstances()][]; | 
|---|
| 626 |      | 
|---|
| 627 |     for (i = 0; i < instances.numInstances(); i++) { | 
|---|
| 628 |       values = instances.instance(i).toDoubleArray(); | 
|---|
| 629 |       x[i]   = new double[values.length - 1]; | 
|---|
| 630 |        | 
|---|
| 631 |       j = 0; | 
|---|
| 632 |       for (n = 0; n < values.length; n++) { | 
|---|
| 633 |         if (n != clsIndex) { | 
|---|
| 634 |           x[i][j] = values[n]; | 
|---|
| 635 |           j++; | 
|---|
| 636 |         } | 
|---|
| 637 |       } | 
|---|
| 638 |     } | 
|---|
| 639 |      | 
|---|
| 640 |     result = new Matrix(x); | 
|---|
| 641 |      | 
|---|
| 642 |     return result; | 
|---|
| 643 |   } | 
|---|
| 644 |    | 
|---|
| 645 |   /** | 
|---|
| 646 |    * returns the data minus the class column as matrix | 
|---|
| 647 |    *  | 
|---|
| 648 |    * @param instance    the instance to work on | 
|---|
| 649 |    * @return            the data without the class attribute | 
|---|
| 650 |    */ | 
|---|
| 651 |   protected Matrix getX(Instance instance) { | 
|---|
| 652 |     double[][]  x; | 
|---|
| 653 |     double[]    values; | 
|---|
| 654 |     Matrix      result; | 
|---|
| 655 |      | 
|---|
| 656 |     x = new double[1][]; | 
|---|
| 657 |     values = instance.toDoubleArray(); | 
|---|
| 658 |     x[0] = new double[values.length - 1]; | 
|---|
| 659 |     System.arraycopy(values, 0, x[0], 0, values.length - 1); | 
|---|
| 660 |      | 
|---|
| 661 |     result = new Matrix(x); | 
|---|
| 662 |      | 
|---|
| 663 |     return result; | 
|---|
| 664 |   } | 
|---|
| 665 |    | 
|---|
| 666 |   /** | 
|---|
| 667 |    * returns the data class column as matrix | 
|---|
| 668 |    *  | 
|---|
| 669 |    * @param instances   the data to work on | 
|---|
| 670 |    * @return            the class attribute | 
|---|
| 671 |    */ | 
|---|
| 672 |   protected Matrix getY(Instances instances) { | 
|---|
| 673 |     double[][]  y; | 
|---|
| 674 |     Matrix      result; | 
|---|
| 675 |     int         i; | 
|---|
| 676 |      | 
|---|
| 677 |     y = new double[instances.numInstances()][1]; | 
|---|
| 678 |     for (i = 0; i < instances.numInstances(); i++) | 
|---|
| 679 |       y[i][0] = instances.instance(i).classValue(); | 
|---|
| 680 |      | 
|---|
| 681 |     result = new Matrix(y); | 
|---|
| 682 |      | 
|---|
| 683 |     return result; | 
|---|
| 684 |   } | 
|---|
| 685 |    | 
|---|
| 686 |   /** | 
|---|
| 687 |    * returns the data class column as matrix | 
|---|
| 688 |    *  | 
|---|
| 689 |    * @param instance    the instance to work on | 
|---|
| 690 |    * @return            the class attribute | 
|---|
| 691 |    */ | 
|---|
| 692 |   protected Matrix getY(Instance instance) { | 
|---|
| 693 |     double[][]  y; | 
|---|
| 694 |     Matrix      result; | 
|---|
| 695 |      | 
|---|
| 696 |     y = new double[1][1]; | 
|---|
| 697 |     y[0][0] = instance.classValue(); | 
|---|
| 698 |      | 
|---|
| 699 |     result = new Matrix(y); | 
|---|
| 700 |      | 
|---|
| 701 |     return result; | 
|---|
| 702 |   } | 
|---|
| 703 |    | 
|---|
| 704 |   /** | 
|---|
| 705 |    * returns the X and Y matrix again as Instances object, based on the given | 
|---|
| 706 |    * header (must have a class attribute set). | 
|---|
| 707 |    *  | 
|---|
| 708 |    * @param header      the format of the instance object | 
|---|
| 709 |    * @param x           the X matrix (data) | 
|---|
| 710 |    * @param y           the Y matrix (class) | 
|---|
| 711 |    * @return            the assembled data | 
|---|
| 712 |    */ | 
|---|
| 713 |   protected Instances toInstances(Instances header, Matrix x, Matrix y) { | 
|---|
| 714 |     double[]    values; | 
|---|
| 715 |     int         i; | 
|---|
| 716 |     int         n; | 
|---|
| 717 |     Instances   result; | 
|---|
| 718 |     int         rows; | 
|---|
| 719 |     int         cols; | 
|---|
| 720 |     int         offset; | 
|---|
| 721 |     int         clsIdx; | 
|---|
| 722 |      | 
|---|
| 723 |     result = new Instances(header, 0); | 
|---|
| 724 |      | 
|---|
| 725 |     rows   = x.getRowDimension(); | 
|---|
| 726 |     cols   = x.getColumnDimension(); | 
|---|
| 727 |     clsIdx = header.classIndex(); | 
|---|
| 728 |      | 
|---|
| 729 |     for (i = 0; i < rows; i++) { | 
|---|
| 730 |       values = new double[cols + 1]; | 
|---|
| 731 |       offset = 0; | 
|---|
| 732 |  | 
|---|
| 733 |       for (n = 0; n < values.length; n++) { | 
|---|
| 734 |         if (n == clsIdx) { | 
|---|
| 735 |           offset--; | 
|---|
| 736 |           values[n] = y.get(i, 0); | 
|---|
| 737 |         } | 
|---|
| 738 |         else { | 
|---|
| 739 |           values[n] = x.get(i, n + offset); | 
|---|
| 740 |         } | 
|---|
| 741 |       } | 
|---|
| 742 |        | 
|---|
| 743 |       result.add(new DenseInstance(1.0, values)); | 
|---|
| 744 |     } | 
|---|
| 745 |      | 
|---|
| 746 |     return result; | 
|---|
| 747 |   } | 
|---|
| 748 |    | 
|---|
| 749 |   /** | 
|---|
| 750 |    * returns the given column as a vector (actually a n x 1 matrix) | 
|---|
| 751 |    *  | 
|---|
| 752 |    * @param m           the matrix to work on | 
|---|
| 753 |    * @param columnIndex the column to return | 
|---|
| 754 |    * @return            the column as n x 1 matrix | 
|---|
| 755 |    */ | 
|---|
| 756 |   protected Matrix columnAsVector(Matrix m, int columnIndex) { | 
|---|
| 757 |     Matrix      result; | 
|---|
| 758 |     int         i; | 
|---|
| 759 |      | 
|---|
| 760 |     result = new Matrix(m.getRowDimension(), 1); | 
|---|
| 761 |      | 
|---|
| 762 |     for (i = 0; i < m.getRowDimension(); i++) | 
|---|
| 763 |       result.set(i, 0, m.get(i, columnIndex)); | 
|---|
| 764 |      | 
|---|
| 765 |     return result; | 
|---|
| 766 |   } | 
|---|
| 767 |    | 
|---|
| 768 |   /** | 
|---|
| 769 |    * stores the data from the (column) vector in the matrix at the specified  | 
|---|
| 770 |    * index | 
|---|
| 771 |    *  | 
|---|
| 772 |    * @param v           the vector to store in the matrix | 
|---|
| 773 |    * @param m           the receiving matrix | 
|---|
| 774 |    * @param columnIndex the column to store the values in | 
|---|
| 775 |    */ | 
|---|
| 776 |   protected void setVector(Matrix v, Matrix m, int columnIndex) { | 
|---|
| 777 |     m.setMatrix(0, m.getRowDimension() - 1, columnIndex, columnIndex, v); | 
|---|
| 778 |   } | 
|---|
| 779 |    | 
|---|
| 780 |   /** | 
|---|
| 781 |    * returns the (column) vector of the matrix at the specified index | 
|---|
| 782 |    *  | 
|---|
| 783 |    * @param m           the matrix to work on | 
|---|
| 784 |    * @param columnIndex the column to get the values from | 
|---|
| 785 |    * @return            the column vector | 
|---|
| 786 |    */ | 
|---|
| 787 |   protected Matrix getVector(Matrix m, int columnIndex) { | 
|---|
| 788 |     return m.getMatrix(0, m.getRowDimension() - 1, columnIndex, columnIndex); | 
|---|
| 789 |   } | 
|---|
| 790 |  | 
|---|
| 791 |   /** | 
|---|
| 792 |    * determines the dominant eigenvector for the given matrix and returns it | 
|---|
| 793 |    *  | 
|---|
| 794 |    * @param m           the matrix to determine the dominant eigenvector for | 
|---|
| 795 |    * @return            the dominant eigenvector | 
|---|
| 796 |    */ | 
|---|
| 797 |   protected Matrix getDominantEigenVector(Matrix m) { | 
|---|
| 798 |     EigenvalueDecomposition     eigendecomp; | 
|---|
| 799 |     double[]                    eigenvalues; | 
|---|
| 800 |     int                         index; | 
|---|
| 801 |     Matrix                      result; | 
|---|
| 802 |      | 
|---|
| 803 |     eigendecomp = m.eig(); | 
|---|
| 804 |     eigenvalues = eigendecomp.getRealEigenvalues(); | 
|---|
| 805 |     index       = Utils.maxIndex(eigenvalues); | 
|---|
| 806 |     result      = columnAsVector(eigendecomp.getV(), index); | 
|---|
| 807 |      | 
|---|
| 808 |     return result; | 
|---|
| 809 |   } | 
|---|
| 810 |    | 
|---|
| 811 |   /** | 
|---|
| 812 |    * normalizes the given vector (inplace)  | 
|---|
| 813 |    *  | 
|---|
| 814 |    * @param v           the vector to normalize | 
|---|
| 815 |    */ | 
|---|
| 816 |   protected void normalizeVector(Matrix v) { | 
|---|
| 817 |     double      sum; | 
|---|
| 818 |     int         i; | 
|---|
| 819 |      | 
|---|
| 820 |     // determine length | 
|---|
| 821 |     sum = 0; | 
|---|
| 822 |     for (i = 0; i < v.getRowDimension(); i++) | 
|---|
| 823 |       sum += v.get(i, 0) * v.get(i, 0); | 
|---|
| 824 |     sum = StrictMath.sqrt(sum); | 
|---|
| 825 |      | 
|---|
| 826 |     // normalize content | 
|---|
| 827 |     for (i = 0; i < v.getRowDimension(); i++) | 
|---|
| 828 |       v.set(i, 0, v.get(i, 0) / sum); | 
|---|
| 829 |   } | 
|---|
| 830 |  | 
|---|
| 831 |   /** | 
|---|
| 832 |    * processes the instances using the PLS1 algorithm | 
|---|
| 833 |    * | 
|---|
| 834 |    * @param instances   the data to process | 
|---|
| 835 |    * @return            the modified data | 
|---|
| 836 |    * @throws Exception  in case the processing goes wrong | 
|---|
| 837 |    */ | 
|---|
| 838 |   protected Instances processPLS1(Instances instances) throws Exception { | 
|---|
| 839 |     Matrix      X, X_trans, x; | 
|---|
| 840 |     Matrix      y; | 
|---|
| 841 |     Matrix      W, w; | 
|---|
| 842 |     Matrix      T, t, t_trans; | 
|---|
| 843 |     Matrix      P, p, p_trans; | 
|---|
| 844 |     double      b; | 
|---|
| 845 |     Matrix      b_hat; | 
|---|
| 846 |     int         i; | 
|---|
| 847 |     int         j; | 
|---|
| 848 |     Matrix      X_new; | 
|---|
| 849 |     Matrix      tmp; | 
|---|
| 850 |     Instances   result; | 
|---|
| 851 |     Instances   tmpInst; | 
|---|
| 852 |  | 
|---|
| 853 |     // initialization | 
|---|
| 854 |     if (!isFirstBatchDone()) { | 
|---|
| 855 |       // split up data | 
|---|
| 856 |       X       = getX(instances); | 
|---|
| 857 |       y       = getY(instances); | 
|---|
| 858 |       X_trans = X.transpose(); | 
|---|
| 859 |        | 
|---|
| 860 |       // init | 
|---|
| 861 |       W     = new Matrix(instances.numAttributes() - 1, getNumComponents()); | 
|---|
| 862 |       P     = new Matrix(instances.numAttributes() - 1, getNumComponents()); | 
|---|
| 863 |       T     = new Matrix(instances.numInstances(), getNumComponents()); | 
|---|
| 864 |       b_hat = new Matrix(getNumComponents(), 1); | 
|---|
| 865 |        | 
|---|
| 866 |       for (j = 0; j < getNumComponents(); j++) { | 
|---|
| 867 |         // 1. step: wj | 
|---|
| 868 |         w = X_trans.times(y); | 
|---|
| 869 |         normalizeVector(w); | 
|---|
| 870 |         setVector(w, W, j); | 
|---|
| 871 |          | 
|---|
| 872 |         // 2. step: tj | 
|---|
| 873 |         t       = X.times(w); | 
|---|
| 874 |         t_trans = t.transpose(); | 
|---|
| 875 |         setVector(t, T, j); | 
|---|
| 876 |          | 
|---|
| 877 |         // 3. step: ^bj | 
|---|
| 878 |         b = t_trans.times(y).get(0, 0) / t_trans.times(t).get(0, 0); | 
|---|
| 879 |         b_hat.set(j, 0, b); | 
|---|
| 880 |          | 
|---|
| 881 |         // 4. step: pj | 
|---|
| 882 |         p       = X_trans.times(t).times((double) 1 / t_trans.times(t).get(0, 0)); | 
|---|
| 883 |         p_trans = p.transpose(); | 
|---|
| 884 |         setVector(p, P, j); | 
|---|
| 885 |          | 
|---|
| 886 |         // 5. step: Xj+1 | 
|---|
| 887 |         X = X.minus(t.times(p_trans)); | 
|---|
| 888 |         y = y.minus(t.times(b)); | 
|---|
| 889 |       } | 
|---|
| 890 |        | 
|---|
| 891 |       // W*(P^T*W)^-1 | 
|---|
| 892 |       tmp = W.times(((P.transpose()).times(W)).inverse()); | 
|---|
| 893 |        | 
|---|
| 894 |       // X_new = X*W*(P^T*W)^-1 | 
|---|
| 895 |       X_new = getX(instances).times(tmp); | 
|---|
| 896 |        | 
|---|
| 897 |       // factor = W*(P^T*W)^-1 * b_hat | 
|---|
| 898 |       m_PLS1_RegVector = tmp.times(b_hat); | 
|---|
| 899 |     | 
|---|
| 900 |       // save matrices | 
|---|
| 901 |       m_PLS1_P     = P; | 
|---|
| 902 |       m_PLS1_W     = W; | 
|---|
| 903 |       m_PLS1_b_hat = b_hat; | 
|---|
| 904 |        | 
|---|
| 905 |       if (getPerformPrediction()) | 
|---|
| 906 |         result = toInstances(getOutputFormat(), X_new, y); | 
|---|
| 907 |       else | 
|---|
| 908 |         result = toInstances(getOutputFormat(), X_new, getY(instances)); | 
|---|
| 909 |     } | 
|---|
| 910 |     // prediction | 
|---|
| 911 |     else { | 
|---|
| 912 |       result = new Instances(getOutputFormat()); | 
|---|
| 913 |        | 
|---|
| 914 |       for (i = 0; i < instances.numInstances(); i++) { | 
|---|
| 915 |         // work on each instance | 
|---|
| 916 |         tmpInst = new Instances(instances, 0); | 
|---|
| 917 |         tmpInst.add((Instance) instances.instance(i).copy()); | 
|---|
| 918 |         x = getX(tmpInst); | 
|---|
| 919 |         X = new Matrix(1, getNumComponents()); | 
|---|
| 920 |         T = new Matrix(1, getNumComponents()); | 
|---|
| 921 |          | 
|---|
| 922 |         for (j = 0; j < getNumComponents(); j++) { | 
|---|
| 923 |           setVector(x, X, j); | 
|---|
| 924 |           // 1. step: tj = xj * wj | 
|---|
| 925 |           t = x.times(getVector(m_PLS1_W, j)); | 
|---|
| 926 |           setVector(t, T, j); | 
|---|
| 927 |           // 2. step: xj+1 = xj - tj*pj^T (tj is 1x1 matrix!) | 
|---|
| 928 |           x = x.minus(getVector(m_PLS1_P, j).transpose().times(t.get(0, 0))); | 
|---|
| 929 |         } | 
|---|
| 930 |          | 
|---|
| 931 |         if (getPerformPrediction()) | 
|---|
| 932 |           tmpInst = toInstances(getOutputFormat(), T, T.times(m_PLS1_b_hat)); | 
|---|
| 933 |         else | 
|---|
| 934 |           tmpInst = toInstances(getOutputFormat(), T, getY(tmpInst)); | 
|---|
| 935 |          | 
|---|
| 936 |         result.add(tmpInst.instance(0)); | 
|---|
| 937 |       } | 
|---|
| 938 |     } | 
|---|
| 939 |      | 
|---|
| 940 |     return result; | 
|---|
| 941 |   } | 
|---|
| 942 |  | 
|---|
| 943 |   /** | 
|---|
| 944 |    * processes the instances using the SIMPLS algorithm | 
|---|
| 945 |    * | 
|---|
| 946 |    * @param instances   the data to process | 
|---|
| 947 |    * @return            the modified data | 
|---|
| 948 |    * @throws Exception  in case the processing goes wrong | 
|---|
| 949 |    */ | 
|---|
| 950 |   protected Instances processSIMPLS(Instances instances) throws Exception { | 
|---|
| 951 |     Matrix      A, A_trans; | 
|---|
| 952 |     Matrix      M; | 
|---|
| 953 |     Matrix      X, X_trans; | 
|---|
| 954 |     Matrix      X_new; | 
|---|
| 955 |     Matrix      Y, y; | 
|---|
| 956 |     Matrix      C, c; | 
|---|
| 957 |     Matrix      Q, q; | 
|---|
| 958 |     Matrix      W, w; | 
|---|
| 959 |     Matrix      P, p, p_trans; | 
|---|
| 960 |     Matrix      v, v_trans; | 
|---|
| 961 |     Matrix      T; | 
|---|
| 962 |     Instances   result; | 
|---|
| 963 |     int         h; | 
|---|
| 964 |      | 
|---|
| 965 |     if (!isFirstBatchDone()) { | 
|---|
| 966 |       // init | 
|---|
| 967 |       X       = getX(instances); | 
|---|
| 968 |       X_trans = X.transpose(); | 
|---|
| 969 |       Y       = getY(instances); | 
|---|
| 970 |       A       = X_trans.times(Y); | 
|---|
| 971 |       M       = X_trans.times(X); | 
|---|
| 972 |       C       = Matrix.identity(instances.numAttributes() - 1, instances.numAttributes() - 1); | 
|---|
| 973 |       W       = new Matrix(instances.numAttributes() - 1, getNumComponents()); | 
|---|
| 974 |       P       = new Matrix(instances.numAttributes() - 1, getNumComponents()); | 
|---|
| 975 |       Q       = new Matrix(1, getNumComponents()); | 
|---|
| 976 |        | 
|---|
| 977 |       for (h = 0; h < getNumComponents(); h++) { | 
|---|
| 978 |         // 1. qh as dominant EigenVector of Ah'*Ah | 
|---|
| 979 |         A_trans = A.transpose(); | 
|---|
| 980 |         q       = getDominantEigenVector(A_trans.times(A)); | 
|---|
| 981 |          | 
|---|
| 982 |         // 2. wh=Ah*qh, ch=wh'*Mh*wh, wh=wh/sqrt(ch), store wh in W as column | 
|---|
| 983 |         w       = A.times(q); | 
|---|
| 984 |         c       = w.transpose().times(M).times(w); | 
|---|
| 985 |         w       = w.times(1.0 / StrictMath.sqrt(c.get(0, 0))); | 
|---|
| 986 |         setVector(w, W, h); | 
|---|
| 987 |          | 
|---|
| 988 |         // 3. ph=Mh*wh, store ph in P as column | 
|---|
| 989 |         p       = M.times(w); | 
|---|
| 990 |         p_trans = p.transpose(); | 
|---|
| 991 |         setVector(p, P, h); | 
|---|
| 992 |          | 
|---|
| 993 |         // 4. qh=Ah'*wh, store qh in Q as column | 
|---|
| 994 |         q = A_trans.times(w); | 
|---|
| 995 |         setVector(q, Q, h); | 
|---|
| 996 |          | 
|---|
| 997 |         // 5. vh=Ch*ph, vh=vh/||vh|| | 
|---|
| 998 |         v       = C.times(p); | 
|---|
| 999 |         normalizeVector(v); | 
|---|
| 1000 |         v_trans = v.transpose(); | 
|---|
| 1001 |          | 
|---|
| 1002 |         // 6. Ch+1=Ch-vh*vh', Mh+1=Mh-ph*ph' | 
|---|
| 1003 |         C = C.minus(v.times(v_trans)); | 
|---|
| 1004 |         M = M.minus(p.times(p_trans)); | 
|---|
| 1005 |          | 
|---|
| 1006 |         // 7. Ah+1=ChAh (actually Ch+1) | 
|---|
| 1007 |         A = C.times(A); | 
|---|
| 1008 |       } | 
|---|
| 1009 |        | 
|---|
| 1010 |       // finish | 
|---|
| 1011 |       m_SIMPLS_W = W; | 
|---|
| 1012 |       T          = X.times(m_SIMPLS_W); | 
|---|
| 1013 |       X_new      = T; | 
|---|
| 1014 |       m_SIMPLS_B = W.times(Q.transpose()); | 
|---|
| 1015 |        | 
|---|
| 1016 |       if (getPerformPrediction()) | 
|---|
| 1017 |         y = T.times(P.transpose()).times(m_SIMPLS_B); | 
|---|
| 1018 |       else | 
|---|
| 1019 |         y = getY(instances); | 
|---|
| 1020 |  | 
|---|
| 1021 |       result = toInstances(getOutputFormat(), X_new, y); | 
|---|
| 1022 |     } | 
|---|
| 1023 |     else { | 
|---|
| 1024 |       result = new Instances(getOutputFormat()); | 
|---|
| 1025 |        | 
|---|
| 1026 |       X     = getX(instances); | 
|---|
| 1027 |       X_new = X.times(m_SIMPLS_W); | 
|---|
| 1028 |        | 
|---|
| 1029 |       if (getPerformPrediction()) | 
|---|
| 1030 |         y = X.times(m_SIMPLS_B); | 
|---|
| 1031 |       else | 
|---|
| 1032 |         y = getY(instances); | 
|---|
| 1033 |        | 
|---|
| 1034 |       result = toInstances(getOutputFormat(), X_new, y); | 
|---|
| 1035 |     } | 
|---|
| 1036 |      | 
|---|
| 1037 |     return result; | 
|---|
| 1038 |   } | 
|---|
| 1039 |  | 
|---|
| 1040 |   /**  | 
|---|
| 1041 |    * Returns the Capabilities of this filter. | 
|---|
| 1042 |    * | 
|---|
| 1043 |    * @return            the capabilities of this object | 
|---|
| 1044 |    * @see               Capabilities | 
|---|
| 1045 |    */ | 
|---|
| 1046 |   public Capabilities getCapabilities() { | 
|---|
| 1047 |     Capabilities result = super.getCapabilities(); | 
|---|
| 1048 |     result.disableAll(); | 
|---|
| 1049 |  | 
|---|
| 1050 |     // attributes | 
|---|
| 1051 |     result.enable(Capability.NUMERIC_ATTRIBUTES); | 
|---|
| 1052 |     result.enable(Capability.DATE_ATTRIBUTES); | 
|---|
| 1053 |     result.enable(Capability.MISSING_VALUES); | 
|---|
| 1054 |      | 
|---|
| 1055 |     // class | 
|---|
| 1056 |     result.enable(Capability.NUMERIC_CLASS); | 
|---|
| 1057 |     result.enable(Capability.DATE_CLASS); | 
|---|
| 1058 |      | 
|---|
| 1059 |     return result; | 
|---|
| 1060 |   } | 
|---|
| 1061 |    | 
|---|
| 1062 |   /** | 
|---|
| 1063 |    * Processes the given data (may change the provided dataset) and returns | 
|---|
| 1064 |    * the modified version. This method is called in batchFinished(). | 
|---|
| 1065 |    * | 
|---|
| 1066 |    * @param instances   the data to process | 
|---|
| 1067 |    * @return            the modified data | 
|---|
| 1068 |    * @throws Exception  in case the processing goes wrong | 
|---|
| 1069 |    * @see               #batchFinished() | 
|---|
| 1070 |    */ | 
|---|
| 1071 |   protected Instances process(Instances instances) throws Exception { | 
|---|
| 1072 |     Instances   result; | 
|---|
| 1073 |     int         i; | 
|---|
| 1074 |     double      clsValue; | 
|---|
| 1075 |     double[]    clsValues; | 
|---|
| 1076 |      | 
|---|
| 1077 |     result = null; | 
|---|
| 1078 |  | 
|---|
| 1079 |     // save original class values if no prediction is performed | 
|---|
| 1080 |     if (!getPerformPrediction()) | 
|---|
| 1081 |       clsValues = instances.attributeToDoubleArray(instances.classIndex()); | 
|---|
| 1082 |     else | 
|---|
| 1083 |       clsValues = null; | 
|---|
| 1084 |      | 
|---|
| 1085 |     if (!isFirstBatchDone()) { | 
|---|
| 1086 |       // init filters | 
|---|
| 1087 |       if (m_ReplaceMissing) | 
|---|
| 1088 |         m_Missing.setInputFormat(instances); | 
|---|
| 1089 |        | 
|---|
| 1090 |       switch (m_Preprocessing) { | 
|---|
| 1091 |         case PREPROCESSING_CENTER: | 
|---|
| 1092 |           m_ClassMean   = instances.meanOrMode(instances.classIndex()); | 
|---|
| 1093 |           m_ClassStdDev = 1; | 
|---|
| 1094 |           m_Filter      = new Center(); | 
|---|
| 1095 |           ((Center) m_Filter).setIgnoreClass(true); | 
|---|
| 1096 |           break; | 
|---|
| 1097 |         case PREPROCESSING_STANDARDIZE: | 
|---|
| 1098 |           m_ClassMean   = instances.meanOrMode(instances.classIndex()); | 
|---|
| 1099 |           m_ClassStdDev = StrictMath.sqrt(instances.variance(instances.classIndex())); | 
|---|
| 1100 |           m_Filter      = new Standardize(); | 
|---|
| 1101 |           ((Standardize) m_Filter).setIgnoreClass(true); | 
|---|
| 1102 |           break; | 
|---|
| 1103 |         default: | 
|---|
| 1104 |           m_ClassMean   = 0; | 
|---|
| 1105 |           m_ClassStdDev = 1; | 
|---|
| 1106 |           m_Filter      = null; | 
|---|
| 1107 |       } | 
|---|
| 1108 |       if (m_Filter != null) | 
|---|
| 1109 |         m_Filter.setInputFormat(instances); | 
|---|
| 1110 |     } | 
|---|
| 1111 |      | 
|---|
| 1112 |     // filter data | 
|---|
| 1113 |     if (m_ReplaceMissing) | 
|---|
| 1114 |       instances = Filter.useFilter(instances, m_Missing); | 
|---|
| 1115 |     if (m_Filter != null) | 
|---|
| 1116 |       instances = Filter.useFilter(instances, m_Filter); | 
|---|
| 1117 |      | 
|---|
| 1118 |     switch (m_Algorithm) { | 
|---|
| 1119 |       case ALGORITHM_SIMPLS: | 
|---|
| 1120 |         result = processSIMPLS(instances); | 
|---|
| 1121 |         break; | 
|---|
| 1122 |       case ALGORITHM_PLS1: | 
|---|
| 1123 |         result = processPLS1(instances); | 
|---|
| 1124 |         break; | 
|---|
| 1125 |       default: | 
|---|
| 1126 |         throw new IllegalStateException( | 
|---|
| 1127 |             "Algorithm type '" + m_Algorithm + "' is not recognized!"); | 
|---|
| 1128 |     } | 
|---|
| 1129 |  | 
|---|
| 1130 |     // add the mean to the class again if predictions are to be performed, | 
|---|
| 1131 |     // otherwise restore original class values | 
|---|
| 1132 |     for (i = 0; i < result.numInstances(); i++) { | 
|---|
| 1133 |       if (!getPerformPrediction()) { | 
|---|
| 1134 |         result.instance(i).setClassValue(clsValues[i]); | 
|---|
| 1135 |       } | 
|---|
| 1136 |       else { | 
|---|
| 1137 |         clsValue = result.instance(i).classValue(); | 
|---|
| 1138 |         result.instance(i).setClassValue(clsValue*m_ClassStdDev + m_ClassMean); | 
|---|
| 1139 |       } | 
|---|
| 1140 |     } | 
|---|
| 1141 |      | 
|---|
| 1142 |     return result; | 
|---|
| 1143 |   } | 
|---|
| 1144 |    | 
|---|
| 1145 |   /** | 
|---|
| 1146 |    * Returns the revision string. | 
|---|
| 1147 |    *  | 
|---|
| 1148 |    * @return            the revision | 
|---|
| 1149 |    */ | 
|---|
| 1150 |   public String getRevision() { | 
|---|
| 1151 |     return RevisionUtils.extract("$Revision: 5987 $"); | 
|---|
| 1152 |   } | 
|---|
| 1153 |  | 
|---|
| 1154 |   /** | 
|---|
| 1155 |    * runs the filter with the given arguments. | 
|---|
| 1156 |    * | 
|---|
| 1157 |    * @param args      the commandline arguments | 
|---|
| 1158 |    */ | 
|---|
| 1159 |   public static void main(String[] args) { | 
|---|
| 1160 |     runFilter(new PLSFilter(), args); | 
|---|
| 1161 |   } | 
|---|
| 1162 | } | 
|---|