source: src/main/java/weka/classifiers/functions/neural/SigmoidUnit.java @ 14

Last change on this file since 14 was 4, checked in by gnappo, 14 years ago

Import di weka.

File size: 3.7 KB
Line 
1/*
2 *    This program is free software; you can redistribute it and/or modify
3 *    it under the terms of the GNU General Public License as published by
4 *    the Free Software Foundation; either version 2 of the License, or
5 *    (at your option) any later version.
6 *
7 *    This program is distributed in the hope that it will be useful,
8 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
9 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
10 *    GNU General Public License for more details.
11 *
12 *    You should have received a copy of the GNU General Public License
13 *    along with this program; if not, write to the Free Software
14 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
15 */
16
17/*
18 *    SigmoidUnit.java
19 *    Copyright (C) 2001 University of Waikato, Hamilton, New Zealand
20 */
21
22package weka.classifiers.functions.neural;
23
24import weka.core.RevisionHandler;
25import weka.core.RevisionUtils;
26
27/**
28 * This can be used by the
29 * neuralnode to perform all it's computations (as a sigmoid unit).
30 *
31 * @author Malcolm Ware (mfw4@cs.waikato.ac.nz)
32 * @version $Revision: 1.7 $
33 */
34public class SigmoidUnit
35  implements NeuralMethod, RevisionHandler {
36
37  /** for serialization */
38  private static final long serialVersionUID = -5162958458177475652L;
39 
40  /**
41   * This function calculates what the output value should be.
42   * @param node The node to calculate the value for.
43   * @return The value.
44   */
45  public double outputValue(NeuralNode node) {
46    double[] weights = node.getWeights();
47    NeuralConnection[] inputs = node.getInputs();
48    double value = weights[0];
49    for (int noa = 0; noa < node.getNumInputs(); noa++) {
50     
51      value += inputs[noa].outputValue(true) 
52        * weights[noa+1];
53    }
54     
55    //this I got from the Neural Network faq to combat overflow
56    //pretty simple solution really :)
57    if (value < -45) {
58      value = 0;
59    }
60    else if (value > 45) {
61      value = 1;
62    }
63    else {
64      value = 1 / (1 + Math.exp(-value));
65    } 
66    return value;
67  }
68 
69  /**
70   * This function calculates what the error value should be.
71   * @param node The node to calculate the error for.
72   * @return The error.
73   */
74  public double errorValue(NeuralNode node) {
75    //then calculate the error.
76   
77    NeuralConnection[] outputs = node.getOutputs();
78    int[] oNums = node.getOutputNums();
79    double error = 0;
80   
81    for (int noa = 0; noa < node.getNumOutputs(); noa++) {
82      error += outputs[noa].errorValue(true) 
83        * outputs[noa].weightValue(oNums[noa]);
84    }
85    double value = node.outputValue(false);
86    error *= value * (1 - value);
87   
88    return error;
89  }
90
91  /**
92   * This function will calculate what the change in weights should be
93   * and also update them.
94   * @param node The node to update the weights for.
95   * @param learn The learning rate to use.
96   * @param momentum The momentum to use.
97   */
98  public void updateWeights(NeuralNode node, double learn, double momentum) {
99
100    NeuralConnection[] inputs = node.getInputs();
101    double[] cWeights = node.getChangeInWeights();
102    double[] weights = node.getWeights();
103    double learnTimesError = 0;
104    learnTimesError = learn * node.errorValue(false);
105    double c = learnTimesError + momentum * cWeights[0];
106    weights[0] += c;
107    cWeights[0] = c;
108 
109    int stopValue = node.getNumInputs() + 1;
110    for (int noa = 1; noa < stopValue; noa++) {
111     
112      c = learnTimesError * inputs[noa-1].outputValue(false);
113      c += momentum * cWeights[noa];
114     
115      weights[noa] += c;
116      cWeights[noa] = c; 
117    }
118  }
119 
120  /**
121   * Returns the revision string.
122   *
123   * @return            the revision
124   */
125  public String getRevision() {
126    return RevisionUtils.extract("$Revision: 1.7 $");
127  }
128}
Note: See TracBrowser for help on using the repository browser.