| 1 | /* |
|---|
| 2 | * This software is a cooperative product of The MathWorks and the National |
|---|
| 3 | * Institute of Standards and Technology (NIST) which has been released to the |
|---|
| 4 | * public domain. Neither The MathWorks nor NIST assumes any responsibility |
|---|
| 5 | * whatsoever for its use by other parties, and makes no guarantees, expressed |
|---|
| 6 | * or implied, about its quality, reliability, or any other characteristic. |
|---|
| 7 | */ |
|---|
| 8 | |
|---|
| 9 | /* |
|---|
| 10 | * LinearRegression.java |
|---|
| 11 | * Copyright (C) 2005 University of Waikato, Hamilton, New Zealand |
|---|
| 12 | * |
|---|
| 13 | */ |
|---|
| 14 | |
|---|
| 15 | package weka.core.matrix; |
|---|
| 16 | |
|---|
| 17 | import weka.core.RevisionHandler; |
|---|
| 18 | import weka.core.RevisionUtils; |
|---|
| 19 | import weka.core.Utils; |
|---|
| 20 | |
|---|
| 21 | /** |
|---|
| 22 | * Class for performing (ridged) linear regression. |
|---|
| 23 | * |
|---|
| 24 | * @author Fracpete (fracpete at waikato dot ac dot nz) |
|---|
| 25 | * @version $Revision: 5953 $ |
|---|
| 26 | */ |
|---|
| 27 | |
|---|
| 28 | public class LinearRegression |
|---|
| 29 | implements RevisionHandler { |
|---|
| 30 | |
|---|
| 31 | /** the coefficients */ |
|---|
| 32 | protected double[] m_Coefficients = null; |
|---|
| 33 | |
|---|
| 34 | /** |
|---|
| 35 | * Performs a (ridged) linear regression. |
|---|
| 36 | * |
|---|
| 37 | * @param a the matrix to perform the regression on |
|---|
| 38 | * @param y the dependent variable vector |
|---|
| 39 | * @param ridge the ridge parameter |
|---|
| 40 | * @throws IllegalArgumentException if not successful |
|---|
| 41 | */ |
|---|
| 42 | public LinearRegression(Matrix a, Matrix y, double ridge) { |
|---|
| 43 | calculate(a, y, ridge); |
|---|
| 44 | } |
|---|
| 45 | |
|---|
| 46 | /** |
|---|
| 47 | * Performs a weighted (ridged) linear regression. |
|---|
| 48 | * |
|---|
| 49 | * @param a the matrix to perform the regression on |
|---|
| 50 | * @param y the dependent variable vector |
|---|
| 51 | * @param w the array of data point weights |
|---|
| 52 | * @param ridge the ridge parameter |
|---|
| 53 | * @throws IllegalArgumentException if the wrong number of weights were |
|---|
| 54 | * provided. |
|---|
| 55 | */ |
|---|
| 56 | public LinearRegression(Matrix a, Matrix y, double[] w, double ridge) { |
|---|
| 57 | |
|---|
| 58 | if (w.length != a.getRowDimension()) |
|---|
| 59 | throw new IllegalArgumentException("Incorrect number of weights provided"); |
|---|
| 60 | Matrix weightedThis = new Matrix( |
|---|
| 61 | a.getRowDimension(), a.getColumnDimension()); |
|---|
| 62 | Matrix weightedDep = new Matrix(a.getRowDimension(), 1); |
|---|
| 63 | for (int i = 0; i < w.length; i++) { |
|---|
| 64 | double sqrt_weight = Math.sqrt(w[i]); |
|---|
| 65 | for (int j = 0; j < a.getColumnDimension(); j++) |
|---|
| 66 | weightedThis.set(i, j, a.get(i, j) * sqrt_weight); |
|---|
| 67 | weightedDep.set(i, 0, y.get(i, 0) * sqrt_weight); |
|---|
| 68 | } |
|---|
| 69 | |
|---|
| 70 | calculate(weightedThis, weightedDep, ridge); |
|---|
| 71 | } |
|---|
| 72 | |
|---|
| 73 | /** |
|---|
| 74 | * performs the actual regression. |
|---|
| 75 | * |
|---|
| 76 | * @param a the matrix to perform the regression on |
|---|
| 77 | * @param y the dependent variable vector |
|---|
| 78 | * @param ridge the ridge parameter |
|---|
| 79 | * @throws IllegalArgumentException if not successful |
|---|
| 80 | */ |
|---|
| 81 | protected void calculate(Matrix a, Matrix y, double ridge) { |
|---|
| 82 | |
|---|
| 83 | if (y.getColumnDimension() > 1) |
|---|
| 84 | throw new IllegalArgumentException("Only one dependent variable allowed"); |
|---|
| 85 | |
|---|
| 86 | int nc = a.getColumnDimension(); |
|---|
| 87 | m_Coefficients = new double[nc]; |
|---|
| 88 | Matrix xt = a.transpose(); |
|---|
| 89 | Matrix solution; |
|---|
| 90 | |
|---|
| 91 | boolean success = true; |
|---|
| 92 | |
|---|
| 93 | do { |
|---|
| 94 | Matrix ss = xt.times(a); |
|---|
| 95 | |
|---|
| 96 | // Set ridge regression adjustment |
|---|
| 97 | for (int i = 0; i < nc; i++) |
|---|
| 98 | ss.set(i, i, ss.get(i, i) + ridge); |
|---|
| 99 | |
|---|
| 100 | // Carry out the regression |
|---|
| 101 | Matrix bb = xt.times(y); |
|---|
| 102 | for(int i = 0; i < nc; i++) |
|---|
| 103 | m_Coefficients[i] = bb.get(i, 0); |
|---|
| 104 | |
|---|
| 105 | try { |
|---|
| 106 | solution = ss.solve(new Matrix(m_Coefficients, m_Coefficients.length)); |
|---|
| 107 | for (int i = 0; i < nc; i++) |
|---|
| 108 | m_Coefficients[i] = solution.get(i, 0); |
|---|
| 109 | success = true; |
|---|
| 110 | } |
|---|
| 111 | catch (Exception ex) { |
|---|
| 112 | ridge *= 10; |
|---|
| 113 | success = false; |
|---|
| 114 | } |
|---|
| 115 | } while (!success); |
|---|
| 116 | } |
|---|
| 117 | |
|---|
| 118 | /** |
|---|
| 119 | * returns the calculated coefficients |
|---|
| 120 | * |
|---|
| 121 | * @return the coefficients |
|---|
| 122 | */ |
|---|
| 123 | public final double[] getCoefficients() { |
|---|
| 124 | return m_Coefficients; |
|---|
| 125 | } |
|---|
| 126 | |
|---|
| 127 | /** |
|---|
| 128 | * returns the coefficients in a string representation |
|---|
| 129 | */ |
|---|
| 130 | public String toString() { |
|---|
| 131 | return Utils.arrayToString(getCoefficients()); |
|---|
| 132 | } |
|---|
| 133 | |
|---|
| 134 | /** |
|---|
| 135 | * Returns the revision string. |
|---|
| 136 | * |
|---|
| 137 | * @return the revision |
|---|
| 138 | */ |
|---|
| 139 | public String getRevision() { |
|---|
| 140 | return RevisionUtils.extract("$Revision: 5953 $"); |
|---|
| 141 | } |
|---|
| 142 | } |
|---|