source: branches/MetisMQI/src/main/java/weka/filters/supervised/attribute/PLSFilter.java

Last change on this file was 29, checked in by gnappo, 15 years ago

Taggata versione per la demo e aggiunto branch.

File size: 31.9 KB
Line 
1/*
2 *    This program is free software; you can redistribute it and/or modify
3 *    it under the terms of the GNU General Public License as published by
4 *    the Free Software Foundation; either version 2 of the License, or
5 *    (at your option) any later version.
6 *
7 *    This program is distributed in the hope that it will be useful,
8 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
9 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
10 *    GNU General Public License for more details.
11 *
12 *    You should have received a copy of the GNU General Public License
13 *    along with this program; if not, write to the Free Software
14 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
15 */
16
17/*
18 * PLSFilter.java
19 * Copyright (C) 2006 University of Waikato, Hamilton, New Zealand
20 *
21 */
22
23package weka.filters.supervised.attribute;
24
25import weka.core.Attribute;
26import weka.core.Capabilities;
27import weka.core.FastVector;
28import weka.core.Instance;
29import weka.core.DenseInstance;
30import weka.core.Instances;
31import weka.core.Option;
32import weka.core.RevisionUtils;
33import weka.core.SelectedTag;
34import weka.core.Tag;
35import weka.core.TechnicalInformation;
36import weka.core.TechnicalInformationHandler;
37import weka.core.Utils;
38import weka.core.Capabilities.Capability;
39import weka.core.TechnicalInformation.Field;
40import weka.core.TechnicalInformation.Type;
41import weka.core.matrix.EigenvalueDecomposition;
42import weka.core.matrix.Matrix;
43import weka.filters.Filter;
44import weka.filters.SimpleBatchFilter;
45import weka.filters.SupervisedFilter;
46import weka.filters.unsupervised.attribute.Center;
47import weka.filters.unsupervised.attribute.ReplaceMissingValues;
48import weka.filters.unsupervised.attribute.Standardize;
49
50import java.util.Enumeration;
51import java.util.Vector;
52
53/**
54 <!-- globalinfo-start -->
55 * Runs Partial Least Square Regression over the given instances and computes the resulting beta matrix for prediction.<br/>
56 * By default it replaces missing values and centers the data.<br/>
57 * <br/>
58 * For more information see:<br/>
59 * <br/>
60 * Tormod Naes, Tomas Isaksson, Tom Fearn, Tony Davies (2002). A User Friendly Guide to Multivariate Calibration and Classification. NIR Publications.<br/>
61 * <br/>
62 * StatSoft, Inc.. Partial Least Squares (PLS).<br/>
63 * <br/>
64 * Bent Jorgensen, Yuri Goegebeur. Module 7: Partial least squares regression I.<br/>
65 * <br/>
66 * S. de Jong (1993). SIMPLS: an alternative approach to partial least squares regression. Chemometrics and Intelligent Laboratory Systems. 18:251-263.
67 * <p/>
68 <!-- globalinfo-end -->
69 *
70 <!-- technical-bibtex-start -->
71 * BibTeX:
72 * <pre>
73 * &#64;book{Naes2002,
74 *    author = {Tormod Naes and Tomas Isaksson and Tom Fearn and Tony Davies},
75 *    publisher = {NIR Publications},
76 *    title = {A User Friendly Guide to Multivariate Calibration and Classification},
77 *    year = {2002},
78 *    ISBN = {0-9528666-2-5}
79 * }
80 *
81 * &#64;misc{missing_id,
82 *    author = {StatSoft, Inc.},
83 *    booktitle = {Electronic Textbook StatSoft},
84 *    title = {Partial Least Squares (PLS)},
85 *    HTTP = {http://www.statsoft.com/textbook/stpls.html}
86 * }
87 *
88 * &#64;misc{missing_id,
89 *    author = {Bent Jorgensen and Yuri Goegebeur},
90 *    booktitle = {ST02: Multivariate Data Analysis and Chemometrics},
91 *    title = {Module 7: Partial least squares regression I},
92 *    HTTP = {http://statmaster.sdu.dk/courses/ST02/module07/}
93 * }
94 *
95 * &#64;article{Jong1993,
96 *    author = {S. de Jong},
97 *    journal = {Chemometrics and Intelligent Laboratory Systems},
98 *    pages = {251-263},
99 *    title = {SIMPLS: an alternative approach to partial least squares regression},
100 *    volume = {18},
101 *    year = {1993}
102 * }
103 * </pre>
104 * <p/>
105 <!-- technical-bibtex-end -->
106 *
107 <!-- options-start -->
108 * Valid options are: <p/>
109 *
110 * <pre> -D
111 *  Turns on output of debugging information.</pre>
112 *
113 * <pre> -C &lt;num&gt;
114 *  The number of components to compute.
115 *  (default: 20)</pre>
116 *
117 * <pre> -U
118 *  Updates the class attribute as well.
119 *  (default: off)</pre>
120 *
121 * <pre> -M
122 *  Turns replacing of missing values on.
123 *  (default: off)</pre>
124 *
125 * <pre> -A &lt;SIMPLS|PLS1&gt;
126 *  The algorithm to use.
127 *  (default: PLS1)</pre>
128 *
129 * <pre> -P &lt;none|center|standardize&gt;
130 *  The type of preprocessing that is applied to the data.
131 *  (default: center)</pre>
132 *
133 <!-- options-end -->
134 *
135 * @author FracPete (fracpete at waikato dot ac dot nz)
136 * @version $Revision: 5987 $
137 */
138public class PLSFilter
139  extends SimpleBatchFilter
140  implements SupervisedFilter, TechnicalInformationHandler {
141
142  /** for serialization */
143  static final long serialVersionUID = -3335106965521265631L;
144
145  /** the type of algorithm: SIMPLS */
146  public static final int ALGORITHM_SIMPLS = 1;
147  /** the type of algorithm: PLS1 */
148  public static final int ALGORITHM_PLS1 = 2;
149  /** the types of algorithm */
150  public static final Tag[] TAGS_ALGORITHM = {
151    new Tag(ALGORITHM_SIMPLS, "SIMPLS"),
152    new Tag(ALGORITHM_PLS1, "PLS1")
153  };
154
155  /** the type of preprocessing: None */
156  public static final int PREPROCESSING_NONE = 0;
157  /** the type of preprocessing: Center */
158  public static final int PREPROCESSING_CENTER = 1;
159  /** the type of preprocessing: Standardize */
160  public static final int PREPROCESSING_STANDARDIZE = 2;
161  /** the types of preprocessing */
162  public static final Tag[] TAGS_PREPROCESSING = {
163    new Tag(PREPROCESSING_NONE, "none"),
164    new Tag(PREPROCESSING_CENTER, "center"),
165    new Tag(PREPROCESSING_STANDARDIZE, "standardize")
166  };
167
168  /** the maximum number of components to generate */
169  protected int m_NumComponents = 20;
170 
171  /** the type of algorithm */
172  protected int m_Algorithm = ALGORITHM_PLS1;
173
174  /** the regression vector "r-hat" for PLS1 */
175  protected Matrix m_PLS1_RegVector = null;
176
177  /** the P matrix for PLS1 */
178  protected Matrix m_PLS1_P = null;
179
180  /** the W matrix for PLS1 */
181  protected Matrix m_PLS1_W = null;
182
183  /** the b-hat vector for PLS1 */
184  protected Matrix m_PLS1_b_hat = null;
185 
186  /** the W matrix for SIMPLS */
187  protected Matrix m_SIMPLS_W = null;
188 
189  /** the B matrix for SIMPLS (used for prediction) */
190  protected Matrix m_SIMPLS_B = null;
191 
192  /** whether to include the prediction, i.e., modifying the class attribute */
193  protected boolean m_PerformPrediction = false;
194
195  /** for replacing missing values */
196  protected Filter m_Missing = null;
197 
198  /** whether to replace missing values */
199  protected boolean m_ReplaceMissing = true;
200 
201  /** for centering the data */
202  protected Filter m_Filter = null;
203 
204  /** the type of preprocessing */
205  protected int m_Preprocessing = PREPROCESSING_CENTER;
206
207  /** the mean of the class */
208  protected double m_ClassMean = 0;
209
210  /** the standard deviation of the class */
211  protected double m_ClassStdDev = 0;
212 
213  /**
214   * default constructor
215   */
216  public PLSFilter() {
217    super();
218   
219    // setup pre-processing
220    m_Missing = new ReplaceMissingValues();
221    m_Filter  = new Center();
222  }
223 
224  /**
225   * Returns a string describing this classifier.
226   *
227   * @return      a description of the classifier suitable for
228   *              displaying in the explorer/experimenter gui
229   */
230  public String globalInfo() {
231    return 
232        "Runs Partial Least Square Regression over the given instances "
233      + "and computes the resulting beta matrix for prediction.\n"
234      + "By default it replaces missing values and centers the data.\n\n"
235      + "For more information see:\n\n"
236      + getTechnicalInformation().toString();
237  }
238
239  /**
240   * Returns an instance of a TechnicalInformation object, containing
241   * detailed information about the technical background of this class,
242   * e.g., paper reference or book this class is based on.
243   *
244   * @return the technical information about this class
245   */
246  public TechnicalInformation getTechnicalInformation() {
247    TechnicalInformation        result;
248    TechnicalInformation        additional;
249   
250    result = new TechnicalInformation(Type.BOOK);
251    result.setValue(Field.AUTHOR, "Tormod Naes and Tomas Isaksson and Tom Fearn and Tony Davies");
252    result.setValue(Field.YEAR, "2002");
253    result.setValue(Field.TITLE, "A User Friendly Guide to Multivariate Calibration and Classification");
254    result.setValue(Field.PUBLISHER, "NIR Publications");
255    result.setValue(Field.ISBN, "0-9528666-2-5");
256   
257    additional = result.add(Type.MISC);
258    additional.setValue(Field.AUTHOR, "StatSoft, Inc.");
259    additional.setValue(Field.TITLE, "Partial Least Squares (PLS)");
260    additional.setValue(Field.BOOKTITLE, "Electronic Textbook StatSoft");
261    additional.setValue(Field.HTTP, "http://www.statsoft.com/textbook/stpls.html");
262   
263    additional = result.add(Type.MISC);
264    additional.setValue(Field.AUTHOR, "Bent Jorgensen and Yuri Goegebeur");
265    additional.setValue(Field.TITLE, "Module 7: Partial least squares regression I");
266    additional.setValue(Field.BOOKTITLE, "ST02: Multivariate Data Analysis and Chemometrics");
267    additional.setValue(Field.HTTP, "http://statmaster.sdu.dk/courses/ST02/module07/");
268   
269    additional = result.add(Type.ARTICLE);
270    additional.setValue(Field.AUTHOR, "S. de Jong");
271    additional.setValue(Field.YEAR, "1993");
272    additional.setValue(Field.TITLE, "SIMPLS: an alternative approach to partial least squares regression");
273    additional.setValue(Field.JOURNAL, "Chemometrics and Intelligent Laboratory Systems");
274    additional.setValue(Field.VOLUME, "18");
275    additional.setValue(Field.PAGES, "251-263");
276   
277    return result;
278  }
279
280  /**
281   * Gets an enumeration describing the available options.
282   *
283   * @return an enumeration of all the available options.
284   */
285  public Enumeration listOptions() {
286    Vector              result;
287    Enumeration         enm;
288    String              param;
289    SelectedTag         tag;
290    int                 i;
291
292    result = new Vector();
293
294    enm = super.listOptions();
295    while (enm.hasMoreElements())
296      result.addElement(enm.nextElement());
297
298    result.addElement(new Option(
299        "\tThe number of components to compute.\n"
300        + "\t(default: 20)",
301        "C", 1, "-C <num>"));
302
303    result.addElement(new Option(
304        "\tUpdates the class attribute as well.\n"
305        + "\t(default: off)",
306        "U", 0, "-U"));
307
308    result.addElement(new Option(
309        "\tTurns replacing of missing values on.\n"
310        + "\t(default: off)",
311        "M", 0, "-M"));
312
313    param = "";
314    for (i = 0; i < TAGS_ALGORITHM.length; i++) {
315      if (i > 0)
316        param += "|";
317      tag = new SelectedTag(TAGS_ALGORITHM[i].getID(), TAGS_ALGORITHM);
318      param += tag.getSelectedTag().getReadable();
319    }
320    result.addElement(new Option(
321        "\tThe algorithm to use.\n"
322        + "\t(default: PLS1)",
323        "A", 1, "-A <" + param + ">"));
324
325    param = "";
326    for (i = 0; i < TAGS_PREPROCESSING.length; i++) {
327      if (i > 0)
328        param += "|";
329      tag = new SelectedTag(TAGS_PREPROCESSING[i].getID(), TAGS_PREPROCESSING);
330      param += tag.getSelectedTag().getReadable();
331    }
332    result.addElement(new Option(
333        "\tThe type of preprocessing that is applied to the data.\n"
334        + "\t(default: center)",
335        "P", 1, "-P <" + param + ">"));
336
337    return result.elements();
338  }
339
340  /**
341   * returns the options of the current setup
342   *
343   * @return      the current options
344   */
345  public String[] getOptions() {
346    int       i;
347    Vector    result;
348    String[]  options;
349
350    result = new Vector();
351    options = super.getOptions();
352    for (i = 0; i < options.length; i++)
353      result.add(options[i]);
354
355    result.add("-C");
356    result.add("" + getNumComponents());
357
358    if (getPerformPrediction())
359      result.add("-U");
360   
361    if (getReplaceMissing())
362      result.add("-M");
363   
364    result.add("-A");
365    result.add("" + getAlgorithm().getSelectedTag().getReadable());
366
367    result.add("-P");
368    result.add("" + getPreprocessing().getSelectedTag().getReadable());
369
370    return (String[]) result.toArray(new String[result.size()]);         
371  }
372
373  /**
374   * Parses the options for this object. <p/>
375   *
376   <!-- options-start -->
377   * Valid options are: <p/>
378   *
379   * <pre> -D
380   *  Turns on output of debugging information.</pre>
381   *
382   * <pre> -C &lt;num&gt;
383   *  The number of components to compute.
384   *  (default: 20)</pre>
385   *
386   * <pre> -U
387   *  Updates the class attribute as well.
388   *  (default: off)</pre>
389   *
390   * <pre> -M
391   *  Turns replacing of missing values on.
392   *  (default: off)</pre>
393   *
394   * <pre> -A &lt;SIMPLS|PLS1&gt;
395   *  The algorithm to use.
396   *  (default: PLS1)</pre>
397   *
398   * <pre> -P &lt;none|center|standardize&gt;
399   *  The type of preprocessing that is applied to the data.
400   *  (default: center)</pre>
401   *
402   <!-- options-end -->
403   *
404   * @param options     the options to use
405   * @throws Exception  if the option setting fails
406   */
407  public void setOptions(String[] options) throws Exception {
408    String      tmpStr;
409
410    super.setOptions(options);
411
412    tmpStr = Utils.getOption("C", options);
413    if (tmpStr.length() != 0)
414      setNumComponents(Integer.parseInt(tmpStr));
415    else
416      setNumComponents(20);
417
418    setPerformPrediction(Utils.getFlag("U", options));
419   
420    setReplaceMissing(Utils.getFlag("M", options));
421   
422    tmpStr = Utils.getOption("A", options);
423    if (tmpStr.length() != 0)
424      setAlgorithm(new SelectedTag(tmpStr, TAGS_ALGORITHM));
425    else
426      setAlgorithm(new SelectedTag(ALGORITHM_PLS1, TAGS_ALGORITHM));
427   
428    tmpStr = Utils.getOption("P", options);
429    if (tmpStr.length() != 0)
430      setPreprocessing(new SelectedTag(tmpStr, TAGS_PREPROCESSING));
431    else
432      setPreprocessing(new SelectedTag(PREPROCESSING_CENTER, TAGS_PREPROCESSING));
433  }
434
435  /**
436   * Returns the tip text for this property
437   *
438   * @return            tip text for this property suitable for
439   *                    displaying in the explorer/experimenter gui
440   */
441  public String numComponentsTipText() {
442    return "The number of components to compute.";
443  }
444
445  /**
446   * sets the maximum number of attributes to use.
447   *
448   * @param value       the maximum number of attributes
449   */
450  public void setNumComponents(int value) {
451    m_NumComponents = value;
452  }
453
454  /**
455   * returns the maximum number of attributes to use.
456   *
457   * @return            the current maximum number of attributes
458   */
459  public int getNumComponents() {
460    return m_NumComponents;
461  }
462
463  /**
464   * Returns the tip text for this property
465   *
466   * @return            tip text for this property suitable for
467   *                    displaying in the explorer/experimenter gui
468   */
469  public String performPredictionTipText() {
470    return "Whether to update the class attribute with the predicted value.";
471  }
472
473  /**
474   * Sets whether to update the class attribute with the predicted value.
475   *
476   * @param value       if true the class value will be replaced by the
477   *                    predicted value.
478   */
479  public void setPerformPrediction(boolean value) {
480    m_PerformPrediction = value;
481  }
482
483  /**
484   * Gets whether the class attribute is updated with the predicted value.
485   *
486   * @return            true if the class attribute is updated
487   */
488  public boolean getPerformPrediction() {
489    return m_PerformPrediction;
490  }
491
492  /**
493   * Returns the tip text for this property
494   *
495   * @return            tip text for this property suitable for
496   *                    displaying in the explorer/experimenter gui
497   */
498  public String algorithmTipText() {
499    return "Sets the type of algorithm to use.";
500  }
501
502  /**
503   * Sets the type of algorithm to use
504   *
505   * @param value       the algorithm type
506   */
507  public void setAlgorithm(SelectedTag value) {
508    if (value.getTags() == TAGS_ALGORITHM) {
509      m_Algorithm = value.getSelectedTag().getID();
510    }
511  }
512
513  /**
514   * Gets the type of algorithm to use
515   *
516   * @return            the current algorithm type.
517   */
518  public SelectedTag getAlgorithm() {
519    return new SelectedTag(m_Algorithm, TAGS_ALGORITHM);
520  }
521
522  /**
523   * Returns the tip text for this property
524   *
525   * @return            tip text for this property suitable for
526   *                    displaying in the explorer/experimenter gui
527   */
528  public String replaceMissingTipText() {
529    return "Whether to replace missing values.";
530  }
531
532  /**
533   * Sets whether to replace missing values.
534   *
535   * @param value       if true missing values are replaced with the
536   *                    ReplaceMissingValues filter.
537   */
538  public void setReplaceMissing(boolean value) {
539    m_ReplaceMissing = value;
540  }
541
542  /**
543   * Gets whether missing values are replace.
544   *
545   * @return            true if missing values are replaced with the
546   *                    ReplaceMissingValues filter
547   */
548  public boolean getReplaceMissing() {
549    return m_ReplaceMissing;
550  }
551
552  /**
553   * Returns the tip text for this property
554   *
555   * @return            tip text for this property suitable for
556   *                    displaying in the explorer/experimenter gui
557   */
558  public String preprocessingTipText() {
559    return "Sets the type of preprocessing to use.";
560  }
561
562  /**
563   * Sets the type of preprocessing to use
564   *
565   * @param value       the preprocessing type
566   */
567  public void setPreprocessing(SelectedTag value) {
568    if (value.getTags() == TAGS_PREPROCESSING) {
569      m_Preprocessing = value.getSelectedTag().getID();
570    }
571  }
572
573  /**
574   * Gets the type of preprocessing to use
575   *
576   * @return            the current preprocessing type.
577   */
578  public SelectedTag getPreprocessing() {
579    return new SelectedTag(m_Preprocessing, TAGS_PREPROCESSING);
580  }
581
582  /**
583   * Determines the output format based on the input format and returns
584   * this. In case the output format cannot be returned immediately, i.e.,
585   * immediateOutputFormat() returns false, then this method will be called
586   * from batchFinished().
587   *
588   * @param inputFormat     the input format to base the output format on
589   * @return                the output format
590   * @throws Exception      in case the determination goes wrong
591   * @see   #hasImmediateOutputFormat()
592   * @see   #batchFinished()
593   */
594  protected Instances determineOutputFormat(Instances inputFormat) 
595    throws Exception {
596
597    // generate header
598    FastVector atts = new FastVector();
599    String prefix = getAlgorithm().getSelectedTag().getReadable();
600    for (int i = 0; i < getNumComponents(); i++)
601      atts.addElement(new Attribute(prefix + "_" + (i+1)));
602    atts.addElement(new Attribute("Class"));
603    Instances result = new Instances(prefix, atts, 0);
604    result.setClassIndex(result.numAttributes() - 1);
605   
606    return result;
607  }
608 
609  /**
610   * returns the data minus the class column as matrix
611   *
612   * @param instances   the data to work on
613   * @return            the data without class attribute
614   */
615  protected Matrix getX(Instances instances) {
616    double[][]  x;
617    double[]    values;
618    Matrix      result;
619    int         i;
620    int         n;
621    int         j;
622    int         clsIndex;
623   
624    clsIndex = instances.classIndex();
625    x        = new double[instances.numInstances()][];
626   
627    for (i = 0; i < instances.numInstances(); i++) {
628      values = instances.instance(i).toDoubleArray();
629      x[i]   = new double[values.length - 1];
630     
631      j = 0;
632      for (n = 0; n < values.length; n++) {
633        if (n != clsIndex) {
634          x[i][j] = values[n];
635          j++;
636        }
637      }
638    }
639   
640    result = new Matrix(x);
641   
642    return result;
643  }
644 
645  /**
646   * returns the data minus the class column as matrix
647   *
648   * @param instance    the instance to work on
649   * @return            the data without the class attribute
650   */
651  protected Matrix getX(Instance instance) {
652    double[][]  x;
653    double[]    values;
654    Matrix      result;
655   
656    x = new double[1][];
657    values = instance.toDoubleArray();
658    x[0] = new double[values.length - 1];
659    System.arraycopy(values, 0, x[0], 0, values.length - 1);
660   
661    result = new Matrix(x);
662   
663    return result;
664  }
665 
666  /**
667   * returns the data class column as matrix
668   *
669   * @param instances   the data to work on
670   * @return            the class attribute
671   */
672  protected Matrix getY(Instances instances) {
673    double[][]  y;
674    Matrix      result;
675    int         i;
676   
677    y = new double[instances.numInstances()][1];
678    for (i = 0; i < instances.numInstances(); i++)
679      y[i][0] = instances.instance(i).classValue();
680   
681    result = new Matrix(y);
682   
683    return result;
684  }
685 
686  /**
687   * returns the data class column as matrix
688   *
689   * @param instance    the instance to work on
690   * @return            the class attribute
691   */
692  protected Matrix getY(Instance instance) {
693    double[][]  y;
694    Matrix      result;
695   
696    y = new double[1][1];
697    y[0][0] = instance.classValue();
698   
699    result = new Matrix(y);
700   
701    return result;
702  }
703 
704  /**
705   * returns the X and Y matrix again as Instances object, based on the given
706   * header (must have a class attribute set).
707   *
708   * @param header      the format of the instance object
709   * @param x           the X matrix (data)
710   * @param y           the Y matrix (class)
711   * @return            the assembled data
712   */
713  protected Instances toInstances(Instances header, Matrix x, Matrix y) {
714    double[]    values;
715    int         i;
716    int         n;
717    Instances   result;
718    int         rows;
719    int         cols;
720    int         offset;
721    int         clsIdx;
722   
723    result = new Instances(header, 0);
724   
725    rows   = x.getRowDimension();
726    cols   = x.getColumnDimension();
727    clsIdx = header.classIndex();
728   
729    for (i = 0; i < rows; i++) {
730      values = new double[cols + 1];
731      offset = 0;
732
733      for (n = 0; n < values.length; n++) {
734        if (n == clsIdx) {
735          offset--;
736          values[n] = y.get(i, 0);
737        }
738        else {
739          values[n] = x.get(i, n + offset);
740        }
741      }
742     
743      result.add(new DenseInstance(1.0, values));
744    }
745   
746    return result;
747  }
748 
749  /**
750   * returns the given column as a vector (actually a n x 1 matrix)
751   *
752   * @param m           the matrix to work on
753   * @param columnIndex the column to return
754   * @return            the column as n x 1 matrix
755   */
756  protected Matrix columnAsVector(Matrix m, int columnIndex) {
757    Matrix      result;
758    int         i;
759   
760    result = new Matrix(m.getRowDimension(), 1);
761   
762    for (i = 0; i < m.getRowDimension(); i++)
763      result.set(i, 0, m.get(i, columnIndex));
764   
765    return result;
766  }
767 
768  /**
769   * stores the data from the (column) vector in the matrix at the specified
770   * index
771   *
772   * @param v           the vector to store in the matrix
773   * @param m           the receiving matrix
774   * @param columnIndex the column to store the values in
775   */
776  protected void setVector(Matrix v, Matrix m, int columnIndex) {
777    m.setMatrix(0, m.getRowDimension() - 1, columnIndex, columnIndex, v);
778  }
779 
780  /**
781   * returns the (column) vector of the matrix at the specified index
782   *
783   * @param m           the matrix to work on
784   * @param columnIndex the column to get the values from
785   * @return            the column vector
786   */
787  protected Matrix getVector(Matrix m, int columnIndex) {
788    return m.getMatrix(0, m.getRowDimension() - 1, columnIndex, columnIndex);
789  }
790
791  /**
792   * determines the dominant eigenvector for the given matrix and returns it
793   *
794   * @param m           the matrix to determine the dominant eigenvector for
795   * @return            the dominant eigenvector
796   */
797  protected Matrix getDominantEigenVector(Matrix m) {
798    EigenvalueDecomposition     eigendecomp;
799    double[]                    eigenvalues;
800    int                         index;
801    Matrix                      result;
802   
803    eigendecomp = m.eig();
804    eigenvalues = eigendecomp.getRealEigenvalues();
805    index       = Utils.maxIndex(eigenvalues);
806    result      = columnAsVector(eigendecomp.getV(), index);
807   
808    return result;
809  }
810 
811  /**
812   * normalizes the given vector (inplace)
813   *
814   * @param v           the vector to normalize
815   */
816  protected void normalizeVector(Matrix v) {
817    double      sum;
818    int         i;
819   
820    // determine length
821    sum = 0;
822    for (i = 0; i < v.getRowDimension(); i++)
823      sum += v.get(i, 0) * v.get(i, 0);
824    sum = StrictMath.sqrt(sum);
825   
826    // normalize content
827    for (i = 0; i < v.getRowDimension(); i++)
828      v.set(i, 0, v.get(i, 0) / sum);
829  }
830
831  /**
832   * processes the instances using the PLS1 algorithm
833   *
834   * @param instances   the data to process
835   * @return            the modified data
836   * @throws Exception  in case the processing goes wrong
837   */
838  protected Instances processPLS1(Instances instances) throws Exception {
839    Matrix      X, X_trans, x;
840    Matrix      y;
841    Matrix      W, w;
842    Matrix      T, t, t_trans;
843    Matrix      P, p, p_trans;
844    double      b;
845    Matrix      b_hat;
846    int         i;
847    int         j;
848    Matrix      X_new;
849    Matrix      tmp;
850    Instances   result;
851    Instances   tmpInst;
852
853    // initialization
854    if (!isFirstBatchDone()) {
855      // split up data
856      X       = getX(instances);
857      y       = getY(instances);
858      X_trans = X.transpose();
859     
860      // init
861      W     = new Matrix(instances.numAttributes() - 1, getNumComponents());
862      P     = new Matrix(instances.numAttributes() - 1, getNumComponents());
863      T     = new Matrix(instances.numInstances(), getNumComponents());
864      b_hat = new Matrix(getNumComponents(), 1);
865     
866      for (j = 0; j < getNumComponents(); j++) {
867        // 1. step: wj
868        w = X_trans.times(y);
869        normalizeVector(w);
870        setVector(w, W, j);
871       
872        // 2. step: tj
873        t       = X.times(w);
874        t_trans = t.transpose();
875        setVector(t, T, j);
876       
877        // 3. step: ^bj
878        b = t_trans.times(y).get(0, 0) / t_trans.times(t).get(0, 0);
879        b_hat.set(j, 0, b);
880       
881        // 4. step: pj
882        p       = X_trans.times(t).times((double) 1 / t_trans.times(t).get(0, 0));
883        p_trans = p.transpose();
884        setVector(p, P, j);
885       
886        // 5. step: Xj+1
887        X = X.minus(t.times(p_trans));
888        y = y.minus(t.times(b));
889      }
890     
891      // W*(P^T*W)^-1
892      tmp = W.times(((P.transpose()).times(W)).inverse());
893     
894      // X_new = X*W*(P^T*W)^-1
895      X_new = getX(instances).times(tmp);
896     
897      // factor = W*(P^T*W)^-1 * b_hat
898      m_PLS1_RegVector = tmp.times(b_hat);
899   
900      // save matrices
901      m_PLS1_P     = P;
902      m_PLS1_W     = W;
903      m_PLS1_b_hat = b_hat;
904     
905      if (getPerformPrediction())
906        result = toInstances(getOutputFormat(), X_new, y);
907      else
908        result = toInstances(getOutputFormat(), X_new, getY(instances));
909    }
910    // prediction
911    else {
912      result = new Instances(getOutputFormat());
913     
914      for (i = 0; i < instances.numInstances(); i++) {
915        // work on each instance
916        tmpInst = new Instances(instances, 0);
917        tmpInst.add((Instance) instances.instance(i).copy());
918        x = getX(tmpInst);
919        X = new Matrix(1, getNumComponents());
920        T = new Matrix(1, getNumComponents());
921       
922        for (j = 0; j < getNumComponents(); j++) {
923          setVector(x, X, j);
924          // 1. step: tj = xj * wj
925          t = x.times(getVector(m_PLS1_W, j));
926          setVector(t, T, j);
927          // 2. step: xj+1 = xj - tj*pj^T (tj is 1x1 matrix!)
928          x = x.minus(getVector(m_PLS1_P, j).transpose().times(t.get(0, 0)));
929        }
930       
931        if (getPerformPrediction())
932          tmpInst = toInstances(getOutputFormat(), T, T.times(m_PLS1_b_hat));
933        else
934          tmpInst = toInstances(getOutputFormat(), T, getY(tmpInst));
935       
936        result.add(tmpInst.instance(0));
937      }
938    }
939   
940    return result;
941  }
942
943  /**
944   * processes the instances using the SIMPLS algorithm
945   *
946   * @param instances   the data to process
947   * @return            the modified data
948   * @throws Exception  in case the processing goes wrong
949   */
950  protected Instances processSIMPLS(Instances instances) throws Exception {
951    Matrix      A, A_trans;
952    Matrix      M;
953    Matrix      X, X_trans;
954    Matrix      X_new;
955    Matrix      Y, y;
956    Matrix      C, c;
957    Matrix      Q, q;
958    Matrix      W, w;
959    Matrix      P, p, p_trans;
960    Matrix      v, v_trans;
961    Matrix      T;
962    Instances   result;
963    int         h;
964   
965    if (!isFirstBatchDone()) {
966      // init
967      X       = getX(instances);
968      X_trans = X.transpose();
969      Y       = getY(instances);
970      A       = X_trans.times(Y);
971      M       = X_trans.times(X);
972      C       = Matrix.identity(instances.numAttributes() - 1, instances.numAttributes() - 1);
973      W       = new Matrix(instances.numAttributes() - 1, getNumComponents());
974      P       = new Matrix(instances.numAttributes() - 1, getNumComponents());
975      Q       = new Matrix(1, getNumComponents());
976     
977      for (h = 0; h < getNumComponents(); h++) {
978        // 1. qh as dominant EigenVector of Ah'*Ah
979        A_trans = A.transpose();
980        q       = getDominantEigenVector(A_trans.times(A));
981       
982        // 2. wh=Ah*qh, ch=wh'*Mh*wh, wh=wh/sqrt(ch), store wh in W as column
983        w       = A.times(q);
984        c       = w.transpose().times(M).times(w);
985        w       = w.times(1.0 / StrictMath.sqrt(c.get(0, 0)));
986        setVector(w, W, h);
987       
988        // 3. ph=Mh*wh, store ph in P as column
989        p       = M.times(w);
990        p_trans = p.transpose();
991        setVector(p, P, h);
992       
993        // 4. qh=Ah'*wh, store qh in Q as column
994        q = A_trans.times(w);
995        setVector(q, Q, h);
996       
997        // 5. vh=Ch*ph, vh=vh/||vh||
998        v       = C.times(p);
999        normalizeVector(v);
1000        v_trans = v.transpose();
1001       
1002        // 6. Ch+1=Ch-vh*vh', Mh+1=Mh-ph*ph'
1003        C = C.minus(v.times(v_trans));
1004        M = M.minus(p.times(p_trans));
1005       
1006        // 7. Ah+1=ChAh (actually Ch+1)
1007        A = C.times(A);
1008      }
1009     
1010      // finish
1011      m_SIMPLS_W = W;
1012      T          = X.times(m_SIMPLS_W);
1013      X_new      = T;
1014      m_SIMPLS_B = W.times(Q.transpose());
1015     
1016      if (getPerformPrediction())
1017        y = T.times(P.transpose()).times(m_SIMPLS_B);
1018      else
1019        y = getY(instances);
1020
1021      result = toInstances(getOutputFormat(), X_new, y);
1022    }
1023    else {
1024      result = new Instances(getOutputFormat());
1025     
1026      X     = getX(instances);
1027      X_new = X.times(m_SIMPLS_W);
1028     
1029      if (getPerformPrediction())
1030        y = X.times(m_SIMPLS_B);
1031      else
1032        y = getY(instances);
1033     
1034      result = toInstances(getOutputFormat(), X_new, y);
1035    }
1036   
1037    return result;
1038  }
1039
1040  /**
1041   * Returns the Capabilities of this filter.
1042   *
1043   * @return            the capabilities of this object
1044   * @see               Capabilities
1045   */
1046  public Capabilities getCapabilities() {
1047    Capabilities result = super.getCapabilities();
1048    result.disableAll();
1049
1050    // attributes
1051    result.enable(Capability.NUMERIC_ATTRIBUTES);
1052    result.enable(Capability.DATE_ATTRIBUTES);
1053    result.enable(Capability.MISSING_VALUES);
1054   
1055    // class
1056    result.enable(Capability.NUMERIC_CLASS);
1057    result.enable(Capability.DATE_CLASS);
1058   
1059    return result;
1060  }
1061 
1062  /**
1063   * Processes the given data (may change the provided dataset) and returns
1064   * the modified version. This method is called in batchFinished().
1065   *
1066   * @param instances   the data to process
1067   * @return            the modified data
1068   * @throws Exception  in case the processing goes wrong
1069   * @see               #batchFinished()
1070   */
1071  protected Instances process(Instances instances) throws Exception {
1072    Instances   result;
1073    int         i;
1074    double      clsValue;
1075    double[]    clsValues;
1076   
1077    result = null;
1078
1079    // save original class values if no prediction is performed
1080    if (!getPerformPrediction())
1081      clsValues = instances.attributeToDoubleArray(instances.classIndex());
1082    else
1083      clsValues = null;
1084   
1085    if (!isFirstBatchDone()) {
1086      // init filters
1087      if (m_ReplaceMissing)
1088        m_Missing.setInputFormat(instances);
1089     
1090      switch (m_Preprocessing) {
1091        case PREPROCESSING_CENTER:
1092          m_ClassMean   = instances.meanOrMode(instances.classIndex());
1093          m_ClassStdDev = 1;
1094          m_Filter      = new Center();
1095          ((Center) m_Filter).setIgnoreClass(true);
1096          break;
1097        case PREPROCESSING_STANDARDIZE:
1098          m_ClassMean   = instances.meanOrMode(instances.classIndex());
1099          m_ClassStdDev = StrictMath.sqrt(instances.variance(instances.classIndex()));
1100          m_Filter      = new Standardize();
1101          ((Standardize) m_Filter).setIgnoreClass(true);
1102          break;
1103        default:
1104          m_ClassMean   = 0;
1105          m_ClassStdDev = 1;
1106          m_Filter      = null;
1107      }
1108      if (m_Filter != null)
1109        m_Filter.setInputFormat(instances);
1110    }
1111   
1112    // filter data
1113    if (m_ReplaceMissing)
1114      instances = Filter.useFilter(instances, m_Missing);
1115    if (m_Filter != null)
1116      instances = Filter.useFilter(instances, m_Filter);
1117   
1118    switch (m_Algorithm) {
1119      case ALGORITHM_SIMPLS:
1120        result = processSIMPLS(instances);
1121        break;
1122      case ALGORITHM_PLS1:
1123        result = processPLS1(instances);
1124        break;
1125      default:
1126        throw new IllegalStateException(
1127            "Algorithm type '" + m_Algorithm + "' is not recognized!");
1128    }
1129
1130    // add the mean to the class again if predictions are to be performed,
1131    // otherwise restore original class values
1132    for (i = 0; i < result.numInstances(); i++) {
1133      if (!getPerformPrediction()) {
1134        result.instance(i).setClassValue(clsValues[i]);
1135      }
1136      else {
1137        clsValue = result.instance(i).classValue();
1138        result.instance(i).setClassValue(clsValue*m_ClassStdDev + m_ClassMean);
1139      }
1140    }
1141   
1142    return result;
1143  }
1144 
1145  /**
1146   * Returns the revision string.
1147   *
1148   * @return            the revision
1149   */
1150  public String getRevision() {
1151    return RevisionUtils.extract("$Revision: 5987 $");
1152  }
1153
1154  /**
1155   * runs the filter with the given arguments.
1156   *
1157   * @param args      the commandline arguments
1158   */
1159  public static void main(String[] args) {
1160    runFilter(new PLSFilter(), args);
1161  }
1162}
Note: See TracBrowser for help on using the repository browser.