/*
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 2 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 */
|---|

/*
 * SigmoidUnit.java
 * Copyright (C) 2001 University of Waikato, Hamilton, New Zealand
 */

package weka.classifiers.functions.neural;

import weka.core.RevisionHandler;
import weka.core.RevisionUtils;

|---|
/**
 * This can be used by a NeuralNode to perform all of its computations
 * (as a sigmoid unit).
 *
 * @author Malcolm Ware (mfw4@cs.waikato.ac.nz)
 * @version $Revision: 1.7 $
 */
|---|
public class SigmoidUnit
  implements NeuralMethod, RevisionHandler {

  /** for serialization */
  private static final long serialVersionUID = -5162958458177475652L;

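  /*
   * Usage sketch: a SigmoidUnit is normally supplied to a NeuralNode, which
   * then delegates its output, error and weight-update calculations to this
   * object. Assumes NeuralNode exposes a (String, java.util.Random,
   * NeuralMethod) constructor:
   *
   *   NeuralNode hidden = new NeuralNode("h0", new java.util.Random(1),
   *                                      new SigmoidUnit());
   */
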
|---|
  /**
   * This function calculates what the output value should be: the weighted
   * sum of the node's inputs (plus the bias weight) passed through a sigmoid.
   * @param node The node to calculate the value for.
   * @return The value.
   */
  public double outputValue(NeuralNode node) {
    double[] weights = node.getWeights();
    NeuralConnection[] inputs = node.getInputs();
    double value = weights[0];
    for (int noa = 0; noa < node.getNumInputs(); noa++) {
      value += inputs[noa].outputValue(true) * weights[noa + 1];
    }

    // this I got from the Neural Network FAQ to combat overflow;
    // pretty simple solution really :)
    if (value < -45) {
      value = 0;
    }
    else if (value > 45) {
      value = 1;
    }
    else {
      value = 1 / (1 + Math.exp(-value));
    }
    return value;
  }
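
  /*
   * The computation above in formula form: with bias b = weights[0],
   * input activations x_i and weights w_i = weights[i + 1],
   *
   *   net    = b + sum_i (w_i * x_i)
   *   output = 1 / (1 + e^(-net))
   *
   * The cut-off at |net| > 45 is purely a numerical shortcut: there the
   * sigmoid is already indistinguishable from 0 or 1 in double precision
   * (1 / (1 + e^45) is roughly 2.9e-20).
   */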
|---|

  /**
   * This function calculates what the error value should be: the errors fed
   * back from the node's outputs, weighted by the connecting weights and
   * scaled by the derivative of the sigmoid.
   * @param node The node to calculate the error for.
   * @return The error.
   */
  public double errorValue(NeuralNode node) {
    // sum the errors fed back from the connected output units.
    NeuralConnection[] outputs = node.getOutputs();
    int[] oNums = node.getOutputNums();
    double error = 0;

    for (int noa = 0; noa < node.getNumOutputs(); noa++) {
      error += outputs[noa].errorValue(true)
        * outputs[noa].weightValue(oNums[noa]);
    }
    // scale by the derivative of the sigmoid at this node's output.
    double value = node.outputValue(false);
    error *= value * (1 - value);

    return error;
  }
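
  /*
   * The error above is the standard backpropagated delta for a sigmoid unit:
   *
   *   delta = (sum over outgoing connections k of delta_k * w_k)
   *           * output * (1 - output)
   *
   * where output * (1 - output) is the derivative of the sigmoid evaluated
   * at this node's cached output value.
   */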
|---|

  /**
   * This function calculates what the change in weights should be (delta
   * rule with momentum) and also applies it.
   * @param node The node to update the weights for.
   * @param learn The learning rate to use.
   * @param momentum The momentum to use.
   */
  public void updateWeights(NeuralNode node, double learn, double momentum) {

    NeuralConnection[] inputs = node.getInputs();
    double[] cWeights = node.getChangeInWeights();
    double[] weights = node.getWeights();
    double learnTimesError = learn * node.errorValue(false);

    // the bias weight (weights[0]) sees a constant input of 1.
    double c = learnTimesError + momentum * cWeights[0];
    weights[0] += c;
    cWeights[0] = c;

    int stopValue = node.getNumInputs() + 1;
    for (int noa = 1; noa < stopValue; noa++) {
      c = learnTimesError * inputs[noa - 1].outputValue(false);
      c += momentum * cWeights[noa];

      weights[noa] += c;
      cWeights[noa] = c;
    }
  }
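
  /*
   * The update above is the delta rule with momentum:
   *
   *   deltaW_i(t) = learn * delta * x_i  +  momentum * deltaW_i(t - 1)
   *   w_i        += deltaW_i(t)
   *
   * with x_0 treated as a constant 1 for the bias weight weights[0].
   */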
|---|

  /**
   * Returns the revision string.
   *
   * @return the revision
   */
  public String getRevision() {
    return RevisionUtils.extract("$Revision: 1.7 $");
  }
}
|---|