source: src/main/java/weka/classifiers/functions/neural/NeuralNode.java @ 17

Last change on this file since 17 was 4, checked in by gnappo, 14 years ago

Import di weka.

File size: 9.9 KB
Line 
1/*
2 *    This program is free software; you can redistribute it and/or modify
3 *    it under the terms of the GNU General Public License as published by
4 *    the Free Software Foundation; either version 2 of the License, or
5 *    (at your option) any later version.
6 *
7 *    This program is distributed in the hope that it will be useful,
8 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
9 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
10 *    GNU General Public License for more details.
11 *
12 *    You should have received a copy of the GNU General Public License
13 *    along with this program; if not, write to the Free Software
14 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
15 */
16
17/*
18 *    NeuralNode.java
19 *    Copyright (C) 2000 University of Waikato, Hamilton, New Zealand
20 */
21
22package weka.classifiers.functions.neural;
23
24import weka.core.RevisionUtils;
25
26import java.util.Random;
27
28/**
29 * This class is used to represent a node in the neuralnet.
30 *
31 * @author Malcolm Ware (mfw4@cs.waikato.ac.nz)
32 * @version $Revision: 5402 $
33 */
34public class NeuralNode
35  extends NeuralConnection {
36
37  /** for serialization */
38  private static final long serialVersionUID = -1085750607680839163L;
39   
40  /** The weights for each of the input connections, and the threshold. */
41  private double[] m_weights;
42 
43  /** The best (lowest error) weights. Only used when validation set is used */
44  private double[] m_bestWeights;
45 
46  /** The change in the weights. */
47  private double[] m_changeInWeights;
48 
49  private Random m_random;
50
51  /** Performs the operations for this node. Currently this
52   * defines that the node is either a sigmoid or a linear unit. */
53  private NeuralMethod m_methods;
54
55  /**
56   * @param id The string name for this node (used to id this node).
57   * @param r A random number generator used to generate initial weights.
58   * @param m The methods this node should use to update.
59   */
60  public NeuralNode(String id, Random r, NeuralMethod m) {
61    super(id);
62    m_weights = new double[1];
63    m_bestWeights = new double[1];
64    m_changeInWeights = new double[1];
65   
66    m_random = r;
67   
68    m_weights[0] = m_random.nextDouble() * .1 - .05;
69    m_changeInWeights[0] = 0;
70
71    m_methods = m;
72  }
73 
74  /**
75   * Set how this node should operate (note that the neural method has no
76   * internal state, so the same object can be used by any number of nodes.
77   * @param m The new method.
78   */
79  public void setMethod(NeuralMethod m) {
80    m_methods = m;
81  } 
82
83  public NeuralMethod getMethod() {
84    return m_methods;
85  }
86
87  /**
88   * Call this to get the output value of this unit.
89   * @param calculate True if the value should be calculated if it hasn't been
90   * already.
91   * @return The output value, or NaN, if the value has not been calculated.
92   */
93  public double outputValue(boolean calculate) {
94   
95    if (Double.isNaN(m_unitValue) && calculate) {
96      //then calculate the output value;
97      m_unitValue = m_methods.outputValue(this);
98    }
99   
100    return m_unitValue;
101  }
102
103 
104  /**
105   * Call this to get the error value of this unit.
106   * @param calculate True if the value should be calculated if it hasn't been
107   * already.
108   * @return The error value, or NaN, if the value has not been calculated.
109   */
110  public double errorValue(boolean calculate) {
111
112    if (!Double.isNaN(m_unitValue) && Double.isNaN(m_unitError) && calculate) {
113      //then calculate the error.
114      m_unitError = m_methods.errorValue(this);
115    }
116    return m_unitError;
117  }
118
119  /**
120   * Call this to reset the value and error for this unit, ready for the next
121   * run. This will also call the reset function of all units that are
122   * connected as inputs to this one.
123   * This is also the time that the update for the listeners will be performed.
124   */
125  public void reset() {
126   
127    if (!Double.isNaN(m_unitValue) || !Double.isNaN(m_unitError)) {
128      m_unitValue = Double.NaN;
129      m_unitError = Double.NaN;
130      m_weightsUpdated = false;
131      for (int noa = 0; noa < m_numInputs; noa++) {
132        m_inputList[noa].reset();
133      }
134    }
135  }
136 
137  /**
138   * Call this to have the connection save the current
139   * weights.
140   */
141  public void saveWeights() {
142    // copy the current weights
143    System.arraycopy(m_weights, 0, m_bestWeights, 0, m_weights.length);
144   
145    // tell inputs to save weights
146    for (int i = 0; i < m_numInputs; i++) {
147      m_inputList[i].saveWeights();
148    }
149  }
150 
151  /**
152   * Call this to have the connection restore from the saved
153   * weights.
154   */
155  public void restoreWeights() {
156    // copy the saved best weights back into the weights
157    System.arraycopy(m_bestWeights, 0, m_weights, 0, m_weights.length);
158   
159    // tell inputs to restore weights
160    for (int i = 0; i < m_numInputs; i++) {
161      m_inputList[i].restoreWeights();
162    }
163  }
164
165  /**
166   * Call this to get the weight value on a particular connection.
167   * @param n The connection number to get the weight for, -1 if The threshold
168   * weight should be returned.
169   * @return The value for the specified connection or if -1 then it should
170   * return the threshold value. If no value exists for the specified
171   * connection, NaN will be returned.
172   */
173  public double weightValue(int n) {
174    if (n >= m_numInputs || n < -1) {
175      return Double.NaN;
176    }
177    return m_weights[n + 1];
178  }
179
180  /**
181   * call this function to get the weights array.
182   * This will also allow the weights to be updated.
183   * @return The weights array.
184   */
185  public double[] getWeights() {
186    return m_weights;
187  }
188
189  /**
190   * call this function to get the chnage in weights array.
191   * This will also allow the change in weights to be updated.
192   * @return The change in weights array.
193   */
194  public double[] getChangeInWeights() {
195    return m_changeInWeights;
196  }
197
198  /**
199   * Call this function to update the weight values at this unit.
200   * After the weights have been updated at this unit, All the
201   * input connections will then be called from this to have their
202   * weights updated.
203   * @param l The learning rate to use.
204   * @param m The momentum to use.
205   */
206  public void updateWeights(double l, double m) {
207   
208    if (!m_weightsUpdated && !Double.isNaN(m_unitError)) {
209      m_methods.updateWeights(this, l, m);
210     
211      //note that the super call to update the inputs is done here and
212      //not in the m_method updateWeights, because it is not deemed to be
213      //required to update the weights at this node (while the error and output
214      //value ao need to be recursively calculated)
215      super.updateWeights(l, m); //to call all of the inputs.
216    }
217   
218  }
219
220  /**
221   * This will connect the specified unit to be an input to this unit.
222   * @param i The unit.
223   * @param n It's connection number for this connection.
224   * @return True if the connection was made, false otherwise.
225   */
226  protected boolean connectInput(NeuralConnection i, int n) {
227   
228    //the function that this overrides can do most of the work.
229    if (!super.connectInput(i, n)) {
230      return false;
231    }
232   
233    //note that the weights are shifted 1 forward in the array so
234    //it leaves the numinputs aligned on the space the weight needs to go.
235    m_weights[m_numInputs] = m_random.nextDouble() * .1 - .05;
236    m_changeInWeights[m_numInputs] = 0;
237   
238    return true;
239  }
240
241  /**
242   * This will allocate more space for input connection information
243   * if the arrays for this have been filled up.
244   */
245  protected void allocateInputs() {
246   
247    NeuralConnection[] temp1 = new NeuralConnection[m_inputList.length + 15];
248    int[] temp2 = new int[m_inputNums.length + 15];
249    double[] temp4 = new double[m_weights.length + 15];
250    double[] temp5 = new double[m_changeInWeights.length + 15];
251    double[] temp6 = new double[m_bestWeights.length + 15];
252
253    temp4[0] = m_weights[0];
254    temp5[0] = m_changeInWeights[0];
255    temp6[0] = m_bestWeights[0];
256    for (int noa = 0; noa < m_numInputs; noa++) {
257      temp1[noa] = m_inputList[noa];
258      temp2[noa] = m_inputNums[noa];
259      temp4[noa+1] = m_weights[noa+1];
260      temp5[noa+1] = m_changeInWeights[noa+1];
261      temp6[noa+1] = m_bestWeights[noa+1];
262    }
263   
264    m_inputList = temp1;
265    m_inputNums = temp2;
266    m_weights = temp4;
267    m_changeInWeights = temp5;
268    m_bestWeights = temp6;
269  }
270
271 
272 
273
274  /**
275   * This will disconnect the input with the specific connection number
276   * From this node (only on this end however).
277   * @param i The unit to disconnect.
278   * @param n The connection number at the other end, -1 if all the connections
279   * to this unit should be severed (not the same as removeAllInputs).
280   * @return True if the connection was removed, false if the connection was
281   * not found.
282   */
283  protected boolean disconnectInput(NeuralConnection i, int n) {
284   
285    int loc = -1;
286    boolean removed = false;
287    do {
288      loc = -1;
289      for (int noa = 0; noa < m_numInputs; noa++) {
290        if (i == m_inputList[noa] && (n == -1 || n == m_inputNums[noa])) {
291          loc = noa;
292          break;
293        }
294      }
295     
296      if (loc >= 0) {
297        for (int noa = loc+1; noa < m_numInputs; noa++) {
298          m_inputList[noa-1] = m_inputList[noa];
299          m_inputNums[noa-1] = m_inputNums[noa];
300         
301          m_weights[noa] = m_weights[noa+1];
302          m_changeInWeights[noa] = m_changeInWeights[noa+1];
303         
304          m_inputList[noa-1].changeOutputNum(m_inputNums[noa-1], noa-1);
305        }
306        m_numInputs--;
307        removed = true;
308      }     
309    } while (n == -1 && loc != -1);
310    return removed;
311  }
312 
313  /**
314   * This function will remove all the inputs to this unit.
315   * In doing so it will also terminate the connections at the other end.
316   */
317  public void removeAllInputs() {
318    super.removeAllInputs();
319   
320    double temp1 = m_weights[0];
321    double temp2 = m_changeInWeights[0];
322
323    m_weights = new double[1];
324    m_changeInWeights = new double[1];
325
326    m_weights[0] = temp1;
327    m_changeInWeights[0] = temp2;
328   
329  } 
330 
331  /**
332   * Returns the revision string.
333   *
334   * @return            the revision
335   */
336  public String getRevision() {
337    return RevisionUtils.extract("$Revision: 5402 $");
338  }
339}
Note: See TracBrowser for help on using the repository browser.