source: src/main/java/weka/classifiers/functions/neural/LinearUnit.java @ 19

Last change on this file since 19 was 4, checked in by gnappo, 14 years ago

Import di weka.

File size: 3.4 KB
Line 
1/*
2 *    This program is free software; you can redistribute it and/or modify
3 *    it under the terms of the GNU General Public License as published by
4 *    the Free Software Foundation; either version 2 of the License, or
5 *    (at your option) any later version.
6 *
7 *    This program is distributed in the hope that it will be useful,
8 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
9 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
10 *    GNU General Public License for more details.
11 *
12 *    You should have received a copy of the GNU General Public License
13 *    along with this program; if not, write to the Free Software
14 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
15 */
16
17/*
18 *    LinearUnit.java
19 *    Copyright (C) 2001 University of Waikato, Hamilton, New Zealand
20 */
21
22package weka.classifiers.functions.neural;
23
24import weka.core.RevisionHandler;
25import weka.core.RevisionUtils;
26
27/**
28 * This can be used by the
29 * neuralnode to perform all it's computations (as a Linear unit).
30 *
31 * @author Malcolm Ware (mfw4@cs.waikato.ac.nz)
32 * @version $Revision: 5928 $
33 */
34public class LinearUnit
35  implements NeuralMethod, RevisionHandler {
36
37  /** for serialization */
38  private static final long serialVersionUID = 8572152807755673630L;
39 
40  /**
41   * This function calculates what the output value should be.
42   * @param node The node to calculate the value for.
43   * @return The value.
44   */
45  public double outputValue(NeuralNode node) {
46    double[] weights = node.getWeights();
47    NeuralConnection[] inputs = node.getInputs();
48    double value = weights[0];
49    for (int noa = 0; noa < node.getNumInputs(); noa++) {
50     
51      value += inputs[noa].outputValue(true) 
52        * weights[noa+1];
53    }
54     
55    return value;
56  }
57 
58  /**
59   * This function calculates what the error value should be.
60   * @param node The node to calculate the error for.
61   * @return The error.
62   */
63  public double errorValue(NeuralNode node) {
64    //then calculate the error.
65   
66    NeuralConnection[] outputs = node.getOutputs();
67    int[] oNums = node.getOutputNums();
68    double error = 0;
69 
70    for (int noa = 0; noa < node.getNumOutputs(); noa++) {
71      error += outputs[noa].errorValue(true) 
72        * outputs[noa].weightValue(oNums[noa]);
73    }
74    return error;
75  }
76
77  /**
78   * This function will calculate what the change in weights should be
79   * and also update them.
80   * @param node The node to update the weights for.
81   * @param learn The learning rate to use.
82   * @param momentum The momentum to use.
83   */
84  public void updateWeights(NeuralNode node, double learn, double momentum) {
85
86    NeuralConnection[] inputs = node.getInputs();
87    double[] cWeights = node.getChangeInWeights();
88    double[] weights = node.getWeights();
89   
90    double learnTimesError = 0;
91    learnTimesError = learn * node.errorValue(false);
92   
93    double c = learnTimesError + momentum * cWeights[0];
94    weights[0] += c;
95    cWeights[0] = c;
96     
97    int stopValue = node.getNumInputs() + 1;
98    for (int noa = 1; noa < stopValue; noa++) {
99     
100      c = learnTimesError * inputs[noa-1].outputValue(false);
101      c += momentum * cWeights[noa];
102     
103      weights[noa] += c;
104      cWeights[noa] = c; 
105    }
106  }
107 
108  /**
109   * Returns the revision string.
110   *
111   * @return            the revision
112   */
113  public String getRevision() {
114    return RevisionUtils.extract("$Revision: 5928 $");
115  }
116}
Note: See TracBrowser for help on using the repository browser.