source: src/main/java/weka/classifiers/functions/neural/NeuralConnection.java @ 16

Last change on this file since 16 was 4, checked in by gnappo, 14 years ago

Import di weka.

File size: 19.6 KB
Line 
1/*
2 *    This program is free software; you can redistribute it and/or modify
3 *    it under the terms of the GNU General Public License as published by
4 *    the Free Software Foundation; either version 2 of the License, or
5 *    (at your option) any later version.
6 *
7 *    This program is distributed in the hope that it will be useful,
8 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
9 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
10 *    GNU General Public License for more details.
11 *
12 *    You should have received a copy of the GNU General Public License
13 *    along with this program; if not, write to the Free Software
14 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
15 */
16
17/*
18 *    NeuralConnection.java
19 *    Copyright (C) 2000 University of Waikato, Hamilton, New Zealand
20 */
21
22package weka.classifiers.functions.neural;
23
24import weka.core.RevisionHandler;
25
26import java.awt.Color;
27import java.awt.Graphics;
28import java.io.Serializable;
29
30/**
31 * Abstract unit in a NeuralNetwork.
32 *
33 * @author Malcolm Ware (mfw4@cs.waikato.ac.nz)
34 * @version $Revision: 5402 $
35 */
36public abstract class NeuralConnection
37  implements Serializable, RevisionHandler {
38
39  /** for serialization */
40  private static final long serialVersionUID = -286208828571059163L;
41
42  //bitwise flags for the types of unit.
43
44  /** This unit is not connected to any others. */
45  public static final int UNCONNECTED = 0;
46 
47  /** This unit is a pure input unit. */
48  public static final int PURE_INPUT = 1;
49 
50  /** This unit is a pure output unit. */
51  public static final int PURE_OUTPUT = 2;
52 
53  /** This unit is an input unit. */
54  public static final int INPUT = 4;
55 
56  /** This unit is an output unit. */
57  public static final int OUTPUT = 8;
58 
59  /** This flag is set once the unit has a connection. */
60  public static final int CONNECTED = 16;
61
62
63
64  /////The difference between pure and not is that pure is used to feed
65  /////the neural network the attribute values and the errors on the outputs
66  /////Beyond that they do no calculations, and have certain restrictions
67  /////on the connections they can make.
68
69
70
71  /** The list of inputs to this unit. */
72  protected NeuralConnection[] m_inputList;
73
74  /** The list of outputs from this unit. */
75  protected NeuralConnection[] m_outputList;
76
77  /** The numbering for the connections at the other end of the input lines. */
78  protected int[] m_inputNums;
79 
80  /** The numbering for the connections at the other end of the out lines. */
81  protected int[] m_outputNums;
82
83  /** The number of inputs. */
84  protected int m_numInputs;
85
86  /** The number of outputs. */
87  protected int m_numOutputs;
88
89  /** The output value for this unit, NaN if not calculated. */
90  protected double m_unitValue;
91
92  /** The error value for this unit, NaN if not calculated. */
93  protected double m_unitError;
94 
95  /** True if the weights have already been updated. */
96  protected boolean m_weightsUpdated;
97 
98  /** The string that uniquely (provided naming is done properly) identifies
99   * this unit. */
100  protected String m_id;
101
102  /** The type of unit this is. */
103  protected int m_type;
104
105  /** The x coord of this unit purely for displaying purposes. */
106  protected double m_x;
107 
108  /** The y coord of this unit purely for displaying purposes. */
109  protected double m_y;
110 
111
112 
113 
114  /**
115   * Constructs The unit with the basic connection information prepared for
116   * use.
117   *
118   * @param id the unique id of the unit
119   */
120  public NeuralConnection(String id) {
121   
122    m_id = id;
123    m_inputList = new NeuralConnection[0];
124    m_outputList = new NeuralConnection[0];
125    m_inputNums = new int[0];
126    m_outputNums = new int[0];
127
128    m_numInputs = 0;
129    m_numOutputs = 0;
130
131    m_unitValue = Double.NaN;
132    m_unitError = Double.NaN;
133
134    m_weightsUpdated = false;
135    m_x = 0;
136    m_y = 0;
137    m_type = UNCONNECTED;
138  }
139 
140 
141  /**
142   * @return The identity string of this unit.
143   */
144  public String getId() {
145    return m_id;
146  }
147
148  /**
149   * @return The type of this unit.
150   */
151  public int getType() {
152    return m_type;
153  }
154
155  /**
156   * @param t The new type of this unit.
157   */
158  public void setType(int t) {
159    m_type = t;
160  }
161
162  /**
163   * Call this to reset the unit for another run.
164   * It is expected by that this unit will call the reset functions of all
165   * input units to it. It is also expected that this will not be done
166   * if the unit has already been reset (or atleast appears to be).
167   */
168  public abstract void reset();
169
170  /**
171   * Call this to get the output value of this unit.
172   * @param calculate True if the value should be calculated if it hasn't been
173   * already.
174   * @return The output value, or NaN, if the value has not been calculated.
175   */
176  public abstract double outputValue(boolean calculate);
177
178  /**
179   * Call this to get the error value of this unit.
180   * @param calculate True if the value should be calculated if it hasn't been
181   * already.
182   * @return The error value, or NaN, if the value has not been calculated.
183   */
184  public abstract double errorValue(boolean calculate);
185 
186  /**
187   * Call this to have the connection save the current
188   * weights.
189   */
190  public abstract void saveWeights();
191 
192  /**
193   * Call this to have the connection restore from the saved
194   * weights.
195   */
196  public abstract void restoreWeights();
197
198  /**
199   * Call this to get the weight value on a particular connection.
200   * @param n The connection number to get the weight for, -1 if The threshold
201   * weight should be returned.
202   * @return This function will default to return 1. If overridden, it should
203   * return the value for the specified connection or if -1 then it should
204   * return the threshold value. If no value exists for the specified
205   * connection, NaN will be returned.
206   */
207  public double weightValue(int n) {
208    return 1;
209  }
210
211  /**
212   * Call this function to update the weight values at this unit.
213   * After the weights have been updated at this unit, All the
214   * input connections will then be called from this to have their
215   * weights updated.
216   * @param l The learning Rate to use.
217   * @param m The momentum to use.
218   */
219  public void updateWeights(double l, double m) {
220   
221    //the action the subclasses should perform is upto them
222    //but if they coverride they should make a call to this to
223    //call the method for all their inputs.
224   
225    if (!m_weightsUpdated) {
226      for (int noa = 0; noa < m_numInputs; noa++) {
227        m_inputList[noa].updateWeights(l, m);
228      }
229      m_weightsUpdated = true;
230    }
231   
232  }
233
234  /**
235   * Use this to get easy access to the inputs.
236   * It is not advised to change the entries in this list
237   * (use the connecting and disconnecting functions to do that)
238   * @return The inputs list.
239   */
240  public NeuralConnection[] getInputs() {
241    return m_inputList;
242  }
243
244  /**
245   * Use this to get easy access to the outputs.
246   * It is not advised to change the entries in this list
247   * (use the connecting and disconnecting functions to do that)
248   * @return The outputs list.
249   */
250  public NeuralConnection[] getOutputs() {
251    return m_outputList;
252  }
253
254  /**
255   * Use this to get easy access to the input numbers.
256   * It is not advised to change the entries in this list
257   * (use the connecting and disconnecting functions to do that)
258   * @return The input nums list.
259   */
260  public int[] getInputNums() {
261    return m_inputNums;
262  }
263
264  /**
265   * Use this to get easy access to the output numbers.
266   * It is not advised to change the entries in this list
267   * (use the connecting and disconnecting functions to do that)
268   * @return The outputs list.
269   */
270  public int[] getOutputNums() {
271    return m_outputNums;
272  }
273
274  /**
275   * @return the x coord.
276   */
277  public double getX() {
278    return m_x;
279  }
280 
281  /**
282   * @return the y coord.
283   */
284  public double getY() {
285    return m_y;
286  }
287 
288  /**
289   * @param x The new value for it's x pos.
290   */
291  public void setX(double x) {
292    m_x = x;
293  }
294 
295  /**
296   * @param y The new value for it's y pos.
297   */
298  public void setY(double y) {
299    m_y = y;
300  }
301 
302 
303  /**
304   * Call this function to determine if the point at x,y is on the unit.
305   * @param g The graphics context for font size info.
306   * @param x The x coord.
307   * @param y The y coord.
308   * @param w The width of the display.
309   * @param h The height of the display.
310   * @return True if the point is on the unit, false otherwise.
311   */
312  public boolean onUnit(Graphics g, int x, int y, int w, int h) {
313
314    int m = (int)(m_x * w);
315    int c = (int)(m_y * h);
316    if (x > m + 10 || x < m - 10 || y > c + 10 || y < c - 10) {
317      return false;
318    }
319    return true;
320
321  }
322 
323  /**
324   * Call this function to draw the node.
325   * @param g The graphics context.
326   * @param w The width of the drawing area.
327   * @param h The height of the drawing area.
328   */
329  public void drawNode(Graphics g, int w, int h) {
330   
331    if ((m_type & OUTPUT) == OUTPUT) {
332      g.setColor(Color.orange);
333    }
334    else {
335      g.setColor(Color.red);
336    }
337    g.fillOval((int)(m_x * w) - 9, (int)(m_y * h) - 9, 19, 19);
338    g.setColor(Color.gray);
339    g.fillOval((int)(m_x * w) - 5, (int)(m_y * h) - 5, 11, 11);
340  }
341
342  /**
343   * Call this function to draw the node highlighted.
344   * @param g The graphics context.
345   * @param w The width of the drawing area.
346   * @param h The height of the drawing area.
347   */
348  public void drawHighlight(Graphics g, int w, int h) {
349   
350    drawNode(g, w, h);
351    g.setColor(Color.yellow);
352    g.fillOval((int)(m_x * w) - 5, (int)(m_y * h) - 5, 11, 11);
353  }
354
355  /**
356   * Call this function to draw the nodes input connections.
357   * @param g The graphics context.
358   * @param w The width of the drawing area.
359   * @param h The height of the drawing area.
360   */
361  public void drawInputLines(Graphics g, int w, int h) {
362
363    g.setColor(Color.black);
364   
365    int px = (int)(m_x * w);
366    int py = (int)(m_y * h);
367    for (int noa = 0; noa < m_numInputs; noa++) {
368      g.drawLine((int)(m_inputList[noa].getX() * w)
369                 , (int)(m_inputList[noa].getY() * h)
370                 , px, py);
371    }
372  }
373
374  /**
375   * Call this function to draw the nodes output connections.
376   * @param g The graphics context.
377   * @param w The width of the drawing area.
378   * @param h The height of the drawing area.
379   */
380  public void drawOutputLines(Graphics g, int w, int h) {
381   
382    g.setColor(Color.black);
383   
384    int px = (int)(m_x * w);
385    int py = (int)(m_y * h);
386    for (int noa = 0; noa < m_numOutputs; noa++) {
387      g.drawLine(px, py
388                 , (int)(m_outputList[noa].getX() * w)
389                 , (int)(m_outputList[noa].getY() * h));
390    }
391  }
392
393
394  /**
395   * This will connect the specified unit to be an input to this unit.
396   * @param i The unit.
397   * @param n It's connection number for this connection.
398   * @return True if the connection was made, false otherwise.
399   */
400  protected boolean connectInput(NeuralConnection i, int n) {
401   
402    for (int noa = 0; noa < m_numInputs; noa++) {
403      if (i == m_inputList[noa]) {
404        return false;
405      }
406    }
407    if (m_numInputs >= m_inputList.length) {
408      //then allocate more space to it.
409      allocateInputs();
410    }
411    m_inputList[m_numInputs] = i;
412    m_inputNums[m_numInputs] = n;
413    m_numInputs++;
414    return true;
415  }
416 
417  /**
418   * This will allocate more space for input connection information
419   * if the arrays for this have been filled up.
420   */
421  protected void allocateInputs() {
422   
423    NeuralConnection[] temp1 = new NeuralConnection[m_inputList.length + 15];
424    int[] temp2 = new int[m_inputNums.length + 15];
425
426    for (int noa = 0; noa < m_numInputs; noa++) {
427      temp1[noa] = m_inputList[noa];
428      temp2[noa] = m_inputNums[noa];
429    }
430    m_inputList = temp1;
431    m_inputNums = temp2;
432  }
433
434  /**
435   * This will connect the specified unit to be an output to this unit.
436   * @param o The unit.
437   * @param n It's connection number for this connection.
438   * @return True if the connection was made, false otherwise.
439   */
440  protected boolean connectOutput(NeuralConnection o, int n) {
441   
442    for (int noa = 0; noa < m_numOutputs; noa++) {
443      if (o == m_outputList[noa]) {
444        return false;
445      }
446    }
447    if (m_numOutputs >= m_outputList.length) {
448      //then allocate more space to it.
449      allocateOutputs();
450    }
451    m_outputList[m_numOutputs] = o;
452    m_outputNums[m_numOutputs] = n;
453    m_numOutputs++;
454    return true;
455  }
456 
457  /**
458   * Allocates more space for output connection information
459   * if the arrays have been filled up.
460   */
461  protected void allocateOutputs() {
462   
463    NeuralConnection[] temp1
464      = new NeuralConnection[m_outputList.length + 15];
465   
466    int[] temp2 = new int[m_outputNums.length + 15];
467   
468    for (int noa = 0; noa < m_numOutputs; noa++) {
469      temp1[noa] = m_outputList[noa];
470      temp2[noa] = m_outputNums[noa];
471    }
472    m_outputList = temp1;
473    m_outputNums = temp2;
474  }
475 
476  /**
477   * This will disconnect the input with the specific connection number
478   * From this node (only on this end however).
479   * @param i The unit to disconnect.
480   * @param n The connection number at the other end, -1 if all the connections
481   * to this unit should be severed.
482   * @return True if the connection was removed, false if the connection was
483   * not found.
484   */
485  protected boolean disconnectInput(NeuralConnection i, int n) {
486   
487    int loc = -1;
488    boolean removed = false;
489    do {
490      loc = -1;
491      for (int noa = 0; noa < m_numInputs; noa++) {
492        if (i == m_inputList[noa] && (n == -1 || n == m_inputNums[noa])) {
493          loc = noa;
494          break;
495        }
496      }
497     
498      if (loc >= 0) {
499        for (int noa = loc+1; noa < m_numInputs; noa++) {
500          m_inputList[noa-1] = m_inputList[noa];
501          m_inputNums[noa-1] = m_inputNums[noa];
502          //set the other end to have the right connection number.
503          m_inputList[noa-1].changeOutputNum(m_inputNums[noa-1], noa-1);
504        }
505        m_numInputs--;
506        removed = true;
507      }
508    } while (n == -1 && loc != -1);
509
510    return removed;
511  }
512
513  /**
514   * This function will remove all the inputs to this unit.
515   * In doing so it will also terminate the connections at the other end.
516   */
517  public void removeAllInputs() {
518   
519    for (int noa = 0; noa < m_numInputs; noa++) {
520      //this command will simply remove any connections this node has
521      //with the other in 1 go, rather than seperately.
522      m_inputList[noa].disconnectOutput(this, -1);
523    }
524   
525    //now reset the inputs.
526    m_inputList = new NeuralConnection[0];
527    setType(getType() & (~INPUT));
528    if (getNumOutputs() == 0) {
529      setType(getType() & (~CONNECTED));
530    }
531    m_inputNums = new int[0];
532    m_numInputs = 0;
533   
534  }
535
536 
537
538  /**
539   * Changes the connection value information for one of the connections.
540   * @param n The connection number to change.
541   * @param v The value to change it to.
542   */
543  protected void changeInputNum(int n, int v) {
544   
545    if (n >= m_numInputs || n < 0) {
546      return;
547    }
548
549    m_inputNums[n] = v;
550  }
551 
552  /**
553   * This will disconnect the output with the specific connection number
554   * From this node (only on this end however).
555   * @param o The unit to disconnect.
556   * @param n The connection number at the other end, -1 if all the connections
557   * to this unit should be severed.
558   * @return True if the connection was removed, false if the connection was
559   * not found.
560   */ 
561  protected boolean disconnectOutput(NeuralConnection o, int n) {
562   
563    int loc = -1;
564    boolean removed = false;
565    do {
566      loc = -1;
567      for (int noa = 0; noa < m_numOutputs; noa++) {
568        if (o == m_outputList[noa] && (n == -1 || n == m_outputNums[noa])) {
569          loc =noa;
570          break;
571        }
572      }
573     
574      if (loc >= 0) {
575        for (int noa = loc+1; noa < m_numOutputs; noa++) {
576          m_outputList[noa-1] = m_outputList[noa];
577          m_outputNums[noa-1] = m_outputNums[noa];
578
579          //set the other end to have the right connection number
580          m_outputList[noa-1].changeInputNum(m_outputNums[noa-1], noa-1);
581        }
582        m_numOutputs--;
583        removed = true;
584      }
585    } while (n == -1 && loc != -1);
586   
587    return removed;
588  }
589
590  /**
591   * This function will remove all outputs to this unit.
592   * In doing so it will also terminate the connections at the other end.
593   */
594  public void removeAllOutputs() {
595   
596    for (int noa = 0; noa < m_numOutputs; noa++) {
597      //this command will simply remove any connections this node has
598      //with the other in 1 go, rather than seperately.
599      m_outputList[noa].disconnectInput(this, -1);
600    }
601   
602    //now reset the inputs.
603    m_outputList = new NeuralConnection[0];
604    m_outputNums = new int[0];
605    setType(getType() & (~OUTPUT));
606    if (getNumInputs() == 0) {
607      setType(getType() & (~CONNECTED));
608    }
609    m_numOutputs = 0;
610   
611  }
612
613  /**
614   * Changes the connection value information for one of the connections.
615   * @param n The connection number to change.
616   * @param v The value to change it to.
617   */
618  protected void changeOutputNum(int n, int v) {
619   
620    if (n >= m_numOutputs || n < 0) {
621      return;
622    }
623
624    m_outputNums[n] = v;
625  }
626 
627  /**
628   * @return The number of input connections.
629   */
630  public int getNumInputs() {
631    return m_numInputs;
632  }
633
634  /**
635   * @return The number of output connections.
636   */
637  public int getNumOutputs() {
638    return m_numOutputs;
639  }
640
641
642  /**
643   * Connects two units together.
644   * @param s The source unit.
645   * @param t The target unit.
646   * @return True if the units were connected, false otherwise.
647   */
648  public static boolean connect(NeuralConnection s, NeuralConnection t) {
649   
650    if (s == null || t == null) {
651      return false;
652    }
653    //this ensures that there is no existing connection between these
654    //two units already. This will also cause the current weight there to be
655    //lost
656 
657    disconnect(s, t);
658    if (s == t) {
659      return false;
660    }
661    if ((t.getType() & PURE_INPUT) == PURE_INPUT) {
662      return false;   //target is an input node.
663    }
664    if ((s.getType() & PURE_OUTPUT) == PURE_OUTPUT) {
665      return false;   //source is an output node
666    }
667    if ((s.getType() & PURE_INPUT) == PURE_INPUT
668        && (t.getType() & PURE_OUTPUT) == PURE_OUTPUT) {     
669      return false;   //there is no actual working node in use
670    }
671    if ((t.getType() & PURE_OUTPUT) == PURE_OUTPUT && t.getNumInputs() > 0) {
672      return false; //more than 1 node is trying to feed a particular output
673    }
674
675    if ((t.getType() & PURE_OUTPUT) == PURE_OUTPUT &&
676        (s.getType() & OUTPUT) == OUTPUT) {
677      return false; //an output node already feeding out a final answer
678    }
679
680    if (!s.connectOutput(t, t.getNumInputs())) {
681      return false;
682    }
683    if (!t.connectInput(s, s.getNumOutputs() - 1)) {
684     
685      s.disconnectOutput(t, t.getNumInputs());
686      return false;
687
688    }
689
690    //now ammend the type.
691    if ((s.getType() & PURE_INPUT) == PURE_INPUT) {
692      t.setType(t.getType() | INPUT);
693    }
694    else if ((t.getType() & PURE_OUTPUT) == PURE_OUTPUT) {
695      s.setType(s.getType() | OUTPUT);
696    }
697    t.setType(t.getType() | CONNECTED);
698    s.setType(s.getType() | CONNECTED);
699    return true;
700  }
701
702  /**
703   * Disconnects two units.
704   * @param s The source unit.
705   * @param t The target unit.
706   * @return True if the units were disconnected, false if they weren't
707   * (probably due to there being no connection).
708   */
709  public static boolean disconnect(NeuralConnection s, NeuralConnection t) {
710   
711    if (s == null || t == null) {
712      return false;
713    }
714
715    boolean stat1 = s.disconnectOutput(t, -1);
716    boolean stat2 = t.disconnectInput(s, -1);
717    if (stat1 && stat2) {
718      if ((s.getType() & PURE_INPUT) == PURE_INPUT) {
719        t.setType(t.getType() & (~INPUT));
720      }
721      else if ((t.getType() & (PURE_OUTPUT)) == PURE_OUTPUT) {
722        s.setType(s.getType() & (~OUTPUT));
723      }
724      if (s.getNumInputs() == 0 && s.getNumOutputs() == 0) {
725        s.setType(s.getType() & (~CONNECTED));
726      }
727      if (t.getNumInputs() == 0 && t.getNumOutputs() == 0) {
728        t.setType(t.getType() & (~CONNECTED));
729      }
730    }
731    return stat1 && stat2;
732  }
733}
734
735
736
737
738
739
740
741
742
743
744
745
Note: See TracBrowser for help on using the repository browser.