1 | /* |
---|
2 | * This program is free software; you can redistribute it and/or modify |
---|
3 | * it under the terms of the GNU General Public License as published by |
---|
4 | * the Free Software Foundation; either version 2 of the License, or |
---|
5 | * (at your option) any later version. |
---|
6 | * |
---|
7 | * This program is distributed in the hope that it will be useful, |
---|
8 | * but WITHOUT ANY WARRANTY; without even the implied warranty of |
---|
9 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the |
---|
10 | * GNU General Public License for more details. |
---|
11 | * |
---|
12 | * You should have received a copy of the GNU General Public License |
---|
13 | * along with this program; if not, write to the Free Software |
---|
14 | * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA. |
---|
15 | */ |
---|
16 | |
---|
17 | /* |
---|
18 | * SigmoidUnit.java |
---|
19 | * Copyright (C) 2001 University of Waikato, Hamilton, New Zealand |
---|
20 | */ |
---|
21 | |
---|
22 | package weka.classifiers.functions.neural; |
---|
23 | |
---|
24 | import weka.core.RevisionHandler; |
---|
25 | import weka.core.RevisionUtils; |
---|
26 | |
---|
27 | /** |
---|
28 | * This can be used by the |
---|
29 | * neuralnode to perform all it's computations (as a sigmoid unit). |
---|
30 | * |
---|
31 | * @author Malcolm Ware (mfw4@cs.waikato.ac.nz) |
---|
32 | * @version $Revision: 1.7 $ |
---|
33 | */ |
---|
34 | public class SigmoidUnit |
---|
35 | implements NeuralMethod, RevisionHandler { |
---|
36 | |
---|
37 | /** for serialization */ |
---|
38 | private static final long serialVersionUID = -5162958458177475652L; |
---|
39 | |
---|
40 | /** |
---|
41 | * This function calculates what the output value should be. |
---|
42 | * @param node The node to calculate the value for. |
---|
43 | * @return The value. |
---|
44 | */ |
---|
45 | public double outputValue(NeuralNode node) { |
---|
46 | double[] weights = node.getWeights(); |
---|
47 | NeuralConnection[] inputs = node.getInputs(); |
---|
48 | double value = weights[0]; |
---|
49 | for (int noa = 0; noa < node.getNumInputs(); noa++) { |
---|
50 | |
---|
51 | value += inputs[noa].outputValue(true) |
---|
52 | * weights[noa+1]; |
---|
53 | } |
---|
54 | |
---|
55 | //this I got from the Neural Network faq to combat overflow |
---|
56 | //pretty simple solution really :) |
---|
57 | if (value < -45) { |
---|
58 | value = 0; |
---|
59 | } |
---|
60 | else if (value > 45) { |
---|
61 | value = 1; |
---|
62 | } |
---|
63 | else { |
---|
64 | value = 1 / (1 + Math.exp(-value)); |
---|
65 | } |
---|
66 | return value; |
---|
67 | } |
---|
68 | |
---|
69 | /** |
---|
70 | * This function calculates what the error value should be. |
---|
71 | * @param node The node to calculate the error for. |
---|
72 | * @return The error. |
---|
73 | */ |
---|
74 | public double errorValue(NeuralNode node) { |
---|
75 | //then calculate the error. |
---|
76 | |
---|
77 | NeuralConnection[] outputs = node.getOutputs(); |
---|
78 | int[] oNums = node.getOutputNums(); |
---|
79 | double error = 0; |
---|
80 | |
---|
81 | for (int noa = 0; noa < node.getNumOutputs(); noa++) { |
---|
82 | error += outputs[noa].errorValue(true) |
---|
83 | * outputs[noa].weightValue(oNums[noa]); |
---|
84 | } |
---|
85 | double value = node.outputValue(false); |
---|
86 | error *= value * (1 - value); |
---|
87 | |
---|
88 | return error; |
---|
89 | } |
---|
90 | |
---|
91 | /** |
---|
92 | * This function will calculate what the change in weights should be |
---|
93 | * and also update them. |
---|
94 | * @param node The node to update the weights for. |
---|
95 | * @param learn The learning rate to use. |
---|
96 | * @param momentum The momentum to use. |
---|
97 | */ |
---|
98 | public void updateWeights(NeuralNode node, double learn, double momentum) { |
---|
99 | |
---|
100 | NeuralConnection[] inputs = node.getInputs(); |
---|
101 | double[] cWeights = node.getChangeInWeights(); |
---|
102 | double[] weights = node.getWeights(); |
---|
103 | double learnTimesError = 0; |
---|
104 | learnTimesError = learn * node.errorValue(false); |
---|
105 | double c = learnTimesError + momentum * cWeights[0]; |
---|
106 | weights[0] += c; |
---|
107 | cWeights[0] = c; |
---|
108 | |
---|
109 | int stopValue = node.getNumInputs() + 1; |
---|
110 | for (int noa = 1; noa < stopValue; noa++) { |
---|
111 | |
---|
112 | c = learnTimesError * inputs[noa-1].outputValue(false); |
---|
113 | c += momentum * cWeights[noa]; |
---|
114 | |
---|
115 | weights[noa] += c; |
---|
116 | cWeights[noa] = c; |
---|
117 | } |
---|
118 | } |
---|
119 | |
---|
120 | /** |
---|
121 | * Returns the revision string. |
---|
122 | * |
---|
123 | * @return the revision |
---|
124 | */ |
---|
125 | public String getRevision() { |
---|
126 | return RevisionUtils.extract("$Revision: 1.7 $"); |
---|
127 | } |
---|
128 | } |
---|