1 | /* |
---|
2 | * This program is free software; you can redistribute it and/or modify |
---|
3 | * it under the terms of the GNU General Public License as published by |
---|
4 | * the Free Software Foundation; either version 2 of the License, or |
---|
5 | * (at your option) any later version. |
---|
6 | * |
---|
7 | * This program is distributed in the hope that it will be useful, |
---|
8 | * but WITHOUT ANY WARRANTY; without even the implied warranty of |
---|
9 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the |
---|
10 | * GNU General Public License for more details. |
---|
11 | * |
---|
12 | * You should have received a copy of the GNU General Public License |
---|
13 | * along with this program; if not, write to the Free Software |
---|
14 | * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA. |
---|
15 | */ |
---|
16 | |
---|
17 | /* |
---|
18 | * PreConstructedLinearModel.java |
---|
19 | * Copyright (C) 2000 University of Waikato, Hamilton, New Zealand |
---|
20 | * |
---|
21 | */ |
---|
22 | |
---|
23 | package weka.classifiers.trees.m5; |
---|
24 | |
---|
25 | import weka.classifiers.Classifier; |
---|
26 | import weka.classifiers.AbstractClassifier; |
---|
27 | import weka.core.Instance; |
---|
28 | import weka.core.Instances; |
---|
29 | import weka.core.RevisionUtils; |
---|
30 | import weka.core.Utils; |
---|
31 | |
---|
32 | import java.io.Serializable; |
---|
33 | |
---|
34 | /** |
---|
35 | * This class encapsulates a linear regression function. It is a classifier |
---|
36 | * but does not learn the function itself, instead it is constructed with |
---|
37 | * coefficients and intercept obtained elsewhere. The buildClassifier method |
---|
38 | * must still be called however as this stores a copy of the training data's |
---|
39 | * header for use in printing the model to the console. |
---|
40 | * |
---|
41 | * @author Mark Hall (mhall@cs.waikato.ac.nz) |
---|
42 | * @version $Revision: 5928 $ |
---|
43 | */ |
---|
44 | public class PreConstructedLinearModel |
---|
45 | extends AbstractClassifier |
---|
46 | implements Serializable { |
---|
47 | |
---|
48 | /** for serialization */ |
---|
49 | static final long serialVersionUID = 2030974097051713247L; |
---|
50 | |
---|
51 | /** The coefficients */ |
---|
52 | private double [] m_coefficients; |
---|
53 | |
---|
54 | /** The intercept */ |
---|
55 | private double m_intercept; |
---|
56 | |
---|
57 | /** Holds the instances header for printing the model */ |
---|
58 | private Instances m_instancesHeader; |
---|
59 | |
---|
60 | /** number of coefficients in the model */ |
---|
61 | private int m_numParameters; |
---|
62 | |
---|
63 | /** |
---|
64 | * Constructor |
---|
65 | * |
---|
66 | * @param coeffs an array of coefficients |
---|
67 | * @param intercept the intercept |
---|
68 | */ |
---|
69 | public PreConstructedLinearModel(double [] coeffs, double intercept) { |
---|
70 | m_coefficients = coeffs; |
---|
71 | m_intercept = intercept; |
---|
72 | int count = 0; |
---|
73 | for (int i = 0; i < coeffs.length; i++) { |
---|
74 | if (coeffs[i] != 0) { |
---|
75 | count++; |
---|
76 | } |
---|
77 | } |
---|
78 | m_numParameters = count; |
---|
79 | } |
---|
80 | |
---|
81 | /** |
---|
82 | * Builds the classifier. In this case all that is done is that a |
---|
83 | * copy of the training instances header is saved. |
---|
84 | * |
---|
85 | * @param instances an <code>Instances</code> value |
---|
86 | * @exception Exception if an error occurs |
---|
87 | */ |
---|
88 | public void buildClassifier(Instances instances) throws Exception { |
---|
89 | m_instancesHeader = new Instances(instances, 0); |
---|
90 | } |
---|
91 | |
---|
92 | /** |
---|
93 | * Predicts the class of the supplied instance using the linear model. |
---|
94 | * |
---|
95 | * @param inst the instance to make a prediction for |
---|
96 | * @return the prediction |
---|
97 | * @exception Exception if an error occurs |
---|
98 | */ |
---|
99 | public double classifyInstance(Instance inst) throws Exception { |
---|
100 | double result = 0; |
---|
101 | |
---|
102 | // System.out.println(inst); |
---|
103 | for (int i = 0; i < m_coefficients.length; i++) { |
---|
104 | if (i != inst.classIndex() && !inst.isMissing(i)) { |
---|
105 | // System.out.println(inst.value(i)+" "+m_coefficients[i]); |
---|
106 | result += m_coefficients[i] * inst.value(i); |
---|
107 | } |
---|
108 | } |
---|
109 | |
---|
110 | result += m_intercept; |
---|
111 | return result; |
---|
112 | } |
---|
113 | |
---|
114 | /** |
---|
115 | * Return the number of parameters (coefficients) in the linear model |
---|
116 | * |
---|
117 | * @return the number of parameters |
---|
118 | */ |
---|
119 | public int numParameters() { |
---|
120 | return m_numParameters; |
---|
121 | } |
---|
122 | |
---|
123 | /** |
---|
124 | * Return the array of coefficients |
---|
125 | * |
---|
126 | * @return the coefficients |
---|
127 | */ |
---|
128 | public double [] coefficients() { |
---|
129 | return m_coefficients; |
---|
130 | } |
---|
131 | |
---|
132 | /** |
---|
133 | * Return the intercept |
---|
134 | * |
---|
135 | * @return the intercept |
---|
136 | */ |
---|
137 | public double intercept() { |
---|
138 | return m_intercept; |
---|
139 | } |
---|
140 | |
---|
141 | /** |
---|
142 | * Returns a textual description of this linear model |
---|
143 | * |
---|
144 | * @return String containing a description of this linear model |
---|
145 | */ |
---|
146 | public String toString() { |
---|
147 | StringBuffer b = new StringBuffer(); |
---|
148 | b.append("\n"+m_instancesHeader.classAttribute().name() + " = "); |
---|
149 | boolean first = true; |
---|
150 | for (int i = 0; i < m_coefficients.length; i++) { |
---|
151 | if (m_coefficients[i] != 0.0) { |
---|
152 | double c = m_coefficients[i]; |
---|
153 | if (first) { |
---|
154 | b.append("\n\t" + Utils.doubleToString(c, 12, 4).trim() + " * " |
---|
155 | + m_instancesHeader.attribute(i).name() + " "); |
---|
156 | first = false; |
---|
157 | } else { |
---|
158 | b.append("\n\t" + ((m_coefficients[i] < 0) ? |
---|
159 | "- " + Utils.doubleToString(Math.abs(c), 12, 4).trim() : "+ " |
---|
160 | + Utils.doubleToString(Math.abs(c), 12, 4).trim()) + " * " |
---|
161 | + m_instancesHeader.attribute(i).name() + " "); |
---|
162 | } |
---|
163 | } |
---|
164 | } |
---|
165 | |
---|
166 | b.append("\n\t" + ((m_intercept < 0) ? "- " : "+ ") |
---|
167 | + Utils.doubleToString(Math.abs(m_intercept), 12, 4).trim()); |
---|
168 | return b.toString(); |
---|
169 | } |
---|
170 | |
---|
171 | /** |
---|
172 | * Returns the revision string. |
---|
173 | * |
---|
174 | * @return the revision |
---|
175 | */ |
---|
176 | public String getRevision() { |
---|
177 | return RevisionUtils.extract("$Revision: 5928 $"); |
---|
178 | } |
---|
179 | } |
---|