source: branches/MetisMQI/src/main/java/weka/core/matrix/LinearRegression.java

Last change on this file was 29, checked in by gnappo, 15 years ago

Taggata versione per la demo e aggiunto branch. (Italian commit message: "Tagged the demo version and added a branch.")

File size: 3.9 KB
Line 
1/*
2 * This software is a cooperative product of The MathWorks and the National
3 * Institute of Standards and Technology (NIST) which has been released to the
4 * public domain. Neither The MathWorks nor NIST assumes any responsibility
5 * whatsoever for its use by other parties, and makes no guarantees, expressed
6 * or implied, about its quality, reliability, or any other characteristic.
7 */
8
9/*
10 * LinearRegression.java
11 * Copyright (C) 2005 University of Waikato, Hamilton, New Zealand
12 *
13 */
14
15package weka.core.matrix;
16
17import weka.core.RevisionHandler;
18import weka.core.RevisionUtils;
19import weka.core.Utils;
20
21/**
22 * Class for performing (ridged) linear regression.
23 *
24 * @author Fracpete (fracpete at waikato dot ac dot nz)
25 * @version $Revision: 5953 $
26 */
27 
28public class LinearRegression
29  implements RevisionHandler {
30
31  /** the coefficients */
32  protected double[] m_Coefficients = null;
33
34  /**
35   * Performs a (ridged) linear regression.
36   *
37   * @param a the matrix to perform the regression on
38   * @param y the dependent variable vector
39   * @param ridge the ridge parameter
40   * @throws IllegalArgumentException if not successful
41   */
42  public LinearRegression(Matrix a, Matrix y, double ridge) {
43    calculate(a, y, ridge);
44  }
45
46  /**
47   * Performs a weighted (ridged) linear regression.
48   *
49   * @param a the matrix to perform the regression on
50   * @param y the dependent variable vector
51   * @param w the array of data point weights
52   * @param ridge the ridge parameter
53   * @throws IllegalArgumentException if the wrong number of weights were
54   * provided.
55   */
56  public LinearRegression(Matrix a, Matrix y, double[] w, double ridge) {
57
58    if (w.length != a.getRowDimension())
59      throw new IllegalArgumentException("Incorrect number of weights provided");
60    Matrix weightedThis = new Matrix(
61                              a.getRowDimension(), a.getColumnDimension());
62    Matrix weightedDep = new Matrix(a.getRowDimension(), 1);
63    for (int i = 0; i < w.length; i++) {
64      double sqrt_weight = Math.sqrt(w[i]);
65      for (int j = 0; j < a.getColumnDimension(); j++)
66        weightedThis.set(i, j, a.get(i, j) * sqrt_weight);
67      weightedDep.set(i, 0, y.get(i, 0) * sqrt_weight);
68    }
69
70    calculate(weightedThis, weightedDep, ridge);
71  }
72
73  /**
74   * performs the actual regression.
75   *
76   * @param a the matrix to perform the regression on
77   * @param y the dependent variable vector
78   * @param ridge the ridge parameter
79   * @throws IllegalArgumentException if not successful
80   */
81  protected void calculate(Matrix a, Matrix y, double ridge) {
82
83    if (y.getColumnDimension() > 1)
84      throw new IllegalArgumentException("Only one dependent variable allowed");
85
86    int nc = a.getColumnDimension();
87    m_Coefficients = new double[nc];
88    Matrix xt = a.transpose();
89    Matrix solution;
90
91    boolean success = true;
92
93    do {
94      Matrix ss = xt.times(a);
95
96      // Set ridge regression adjustment
97      for (int i = 0; i < nc; i++)
98        ss.set(i, i, ss.get(i, i) + ridge);
99
100      // Carry out the regression
101      Matrix bb = xt.times(y);
102      for(int i = 0; i < nc; i++)
103        m_Coefficients[i] = bb.get(i, 0);
104
105      try {
106        solution = ss.solve(new Matrix(m_Coefficients, m_Coefficients.length));
107        for (int i = 0; i < nc; i++)
108          m_Coefficients[i] = solution.get(i, 0);
109        success = true;
110      } 
111      catch (Exception ex) {
112        ridge *= 10;
113        success = false;
114      }
115    } while (!success);
116  }
117
118  /**
119   * returns the calculated coefficients
120   *
121   * @return the coefficients
122   */
123  public final double[] getCoefficients() {
124    return m_Coefficients;
125  }
126
127  /**
128   * returns the coefficients in a string representation
129   */
130  public String toString() {
131    return Utils.arrayToString(getCoefficients());
132  }
133 
134  /**
135   * Returns the revision string.
136   *
137   * @return            the revision
138   */
139  public String getRevision() {
140    return RevisionUtils.extract("$Revision: 5953 $");
141  }
142}
Note: See TracBrowser for help on using the repository browser.