source: src/main/java/weka/classifiers/functions/MultilayerPerceptron.java @ 15

Last change on this file since 15 was 4, checked in by gnappo, 14 years ago

Import of weka.

File size: 80.3 KB
/*
 *    This program is free software; you can redistribute it and/or modify
 *    it under the terms of the GNU General Public License as published by
 *    the Free Software Foundation; either version 2 of the License, or
 *    (at your option) any later version.
 *
 *    This program is distributed in the hope that it will be useful,
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *    GNU General Public License for more details.
 *
 *    You should have received a copy of the GNU General Public License
 *    along with this program; if not, write to the Free Software
 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 */

/*
 *    MultilayerPerceptron.java
 *    Copyright (C) 2000 University of Waikato, Hamilton, New Zealand
 */

package weka.classifiers.functions;

import weka.classifiers.Classifier;
import weka.classifiers.AbstractClassifier;
import weka.classifiers.functions.neural.LinearUnit;
import weka.classifiers.functions.neural.NeuralConnection;
import weka.classifiers.functions.neural.NeuralNode;
import weka.classifiers.functions.neural.SigmoidUnit;
import weka.core.Capabilities;
import weka.core.FastVector;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.OptionHandler;
import weka.core.Randomizable;
import weka.core.RevisionHandler;
import weka.core.RevisionUtils;
import weka.core.Utils;
import weka.core.WeightedInstancesHandler;
import weka.core.Capabilities.Capability;
import weka.filters.Filter;
import weka.filters.unsupervised.attribute.NominalToBinary;

import java.awt.BorderLayout;
import java.awt.Color;
import java.awt.Component;
import java.awt.Dimension;
import java.awt.FontMetrics;
import java.awt.Graphics;
import java.awt.event.ActionEvent;
import java.awt.event.ActionListener;
import java.awt.event.MouseAdapter;
import java.awt.event.MouseEvent;
import java.awt.event.WindowAdapter;
import java.awt.event.WindowEvent;
import java.util.Enumeration;
import java.util.Random;
import java.util.StringTokenizer;
import java.util.Vector;

import javax.swing.BorderFactory;
import javax.swing.Box;
import javax.swing.BoxLayout;
import javax.swing.JButton;
import javax.swing.JFrame;
import javax.swing.JLabel;
import javax.swing.JOptionPane;
import javax.swing.JPanel;
import javax.swing.JScrollPane;
import javax.swing.JTextField;

/**
 <!-- globalinfo-start -->
 * A Classifier that uses backpropagation to classify instances.<br/>
 * This network can be built by hand, created by an algorithm or both. The network can also be monitored and modified during training time. The nodes in this network are all sigmoid (except for when the class is numeric, in which case the output nodes become unthresholded linear units).
 * <p/>
 <!-- globalinfo-end -->
 *
 <!-- options-start -->
 * Valid options are: <p/>
 *
 * <pre> -L &lt;learning rate&gt;
 *  Learning Rate for the backpropagation algorithm.
 *  (Value should be between 0 - 1, Default = 0.3).</pre>
 *
 * <pre> -M &lt;momentum&gt;
 *  Momentum Rate for the backpropagation algorithm.
 *  (Value should be between 0 - 1, Default = 0.2).</pre>
 *
 * <pre> -N &lt;number of epochs&gt;
 *  Number of epochs to train through.
 *  (Default = 500).</pre>
 *
 * <pre> -V &lt;percentage size of validation set&gt;
 *  Percentage size of validation set to use to terminate
 *  training (if this is non zero it can pre-empt the number of epochs).
 *  (Value should be between 0 - 100, Default = 0).</pre>
 *
 * <pre> -S &lt;seed&gt;
 *  The value used to seed the random number generator.
 *  (Value should be &gt;= 0 and a long, Default = 0).</pre>
 *
 * <pre> -E &lt;threshold for number of consecutive errors&gt;
 *  The number of consecutive errors allowed for validation
 *  testing before the network terminates.
 *  (Value should be &gt; 0, Default = 20).</pre>
 *
 * <pre> -G
 *  GUI will be opened.
 *  (Use this to bring up a GUI).</pre>
 *
 * <pre> -A
 *  Autocreation of the network connections will NOT be done.
 *  (This will be ignored if -G is NOT set)</pre>
 *
 * <pre> -B
 *  A NominalToBinary filter will NOT automatically be used.
 *  (Set this to not use a NominalToBinary filter).</pre>
 *
 * <pre> -H &lt;comma separated numbers for nodes on each layer&gt;
 *  The hidden layers to be created for the network.
 *  (Value should be a list of comma separated natural
 *  numbers or the letters 'a' = (attribs + classes) / 2,
 *  'i' = attribs, 'o' = classes, 't' = attribs + classes)
 *  for wildcard values, Default = a).</pre>
 *
 * <pre> -C
 *  Normalizing a numeric class will NOT be done.
 *  (Set this to not normalize the class if it's numeric).</pre>
 *
 * <pre> -I
 *  Normalizing the attributes will NOT be done.
 *  (Set this to not normalize the attributes).</pre>
 *
 * <pre> -R
 *  Resetting the network will NOT be allowed.
 *  (Set this to not allow the network to reset).</pre>
 *
 * <pre> -D
 *  Learning rate decay will occur.
 *  (Set this to cause the learning rate to decay).</pre>
 *
 <!-- options-end -->
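 *
 * Example command line (a minimal sketch; weka.jar is assumed to be on the
 * classpath and "data.arff" is a hypothetical training file, passed via the
 * standard -t option handled by the classifier runner invoked from main()):
 * <pre>
 * java weka.classifiers.functions.MultilayerPerceptron \
 *     -L 0.3 -M 0.2 -N 500 -H a -t data.arff
 * </pre>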
 *
 * @author Malcolm Ware (mfw4@cs.waikato.ac.nz)
 * @version $Revision: 5928 $
 */
public class MultilayerPerceptron 
  extends AbstractClassifier
  implements OptionHandler, WeightedInstancesHandler, Randomizable {
 
  /** for serialization */
  static final long serialVersionUID = 572250905027665169L;
 
  /**
   * Main method for testing this class.
   *
   * @param argv should contain command line options (see setOptions)
   */
  public static void main(String [] argv) {
    runClassifier(new MultilayerPerceptron(), argv);
  }
 

  /**
   * This inner class is used to connect the nodes in the network up to
   * the data that they are classifying. Note that objects of this class are
   * only suitable to go on the attribute side or class side of the network
   * and not both.
   */
  protected class NeuralEnd 
    extends NeuralConnection {
   
    /** for serialization */
    static final long serialVersionUID = 7305185603191183338L;
 
    /**
     * The index of the instance value this node represents. For an input
     * it is the attribute number; for an output, if the class is nominal,
     * it is the class value.
     */
    private int m_link;
   
    /** True if node is an input, False if it's an output. */
    private boolean m_input;

    /**
     * Constructor
     */
    public NeuralEnd(String id) {
      super(id);

      m_link = 0;
      m_input = true;
     
    }
 
    /**
     * Call this function to determine if the point at x,y is on the unit.
     * @param g The graphics context for font size info.
     * @param x The x coord.
     * @param y The y coord.
     * @param w The width of the display.
     * @param h The height of the display.
     * @return True if the point is on the unit, false otherwise.
     */
    public boolean onUnit(Graphics g, int x, int y, int w, int h) {
     
      FontMetrics fm = g.getFontMetrics();
      int l = (int)(m_x * w) - fm.stringWidth(m_id) / 2;
      int t = (int)(m_y * h) - fm.getHeight() / 2;
      if (x < l || x > l + fm.stringWidth(m_id) + 4 
          || y < t || y > t + fm.getHeight() + fm.getDescent() + 4) {
        return false;
      }
      return true;
     
    }
   

    /**
     * This will draw the node id to the graphics context.
     * @param g The graphics context.
     * @param w The width of the drawing area.
     * @param h The height of the drawing area.
     */
    public void drawNode(Graphics g, int w, int h) {
     
      if ((m_type & PURE_INPUT) == PURE_INPUT) {
        g.setColor(Color.green);
      }
      else {
        g.setColor(Color.orange);
      }
     
      FontMetrics fm = g.getFontMetrics();
      int l = (int)(m_x * w) - fm.stringWidth(m_id) / 2;
      int t = (int)(m_y * h) - fm.getHeight() / 2;
      g.fill3DRect(l, t, fm.stringWidth(m_id) + 4
                   , fm.getHeight() + fm.getDescent() + 4
                   , true);
      g.setColor(Color.black);
     
      g.drawString(m_id, l + 2, t + fm.getHeight() + 2);

    }


    /**
     * Call this function to draw the node highlighted.
     * @param g The graphics context.
     * @param w The width of the drawing area.
     * @param h The height of the drawing area.
     */
    public void drawHighlight(Graphics g, int w, int h) {
     
      g.setColor(Color.black);
      FontMetrics fm = g.getFontMetrics();
      int l = (int)(m_x * w) - fm.stringWidth(m_id) / 2;
      int t = (int)(m_y * h) - fm.getHeight() / 2;
      g.fillRect(l - 2, t - 2, fm.stringWidth(m_id) + 8
                 , fm.getHeight() + fm.getDescent() + 8); 
      drawNode(g, w, h);
    }
   
    /**
     * Call this to get the output value of this unit.
     * @param calculate True if the value should be calculated if it hasn't
     * been already.
     * @return The output value, or NaN, if the value has not been calculated.
     */
    public double outputValue(boolean calculate) {
     
      if (Double.isNaN(m_unitValue) && calculate) {
        if (m_input) {
          if (m_currentInstance.isMissing(m_link)) {
            m_unitValue = 0;
          }
          else {
           
            m_unitValue = m_currentInstance.value(m_link);
          }
        }
        else {
          //node is an output.
          m_unitValue = 0;
          for (int noa = 0; noa < m_numInputs; noa++) {
            m_unitValue += m_inputList[noa].outputValue(true);
           
          }
          if (m_numeric && m_normalizeClass) {
            //then scale the value;
            //this scales linearly from between -1 and 1
            m_unitValue = m_unitValue * 
              m_attributeRanges[m_instances.classIndex()] + 
              m_attributeBases[m_instances.classIndex()];
          }
        }
      }
      return m_unitValue;
     
     
    }
   
    /**
     * Call this to get the error value of this unit, which in this case is
     * the difference between the predicted class, and the actual class.
     * @param calculate True if the value should be calculated if it hasn't
     * been already.
     * @return The error value, or NaN, if the value has not been calculated.
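     * For a nominal class the target is 1 for the output unit whose link
     * matches the class value and 0 for every other output unit, so, for
     * example, an output of 0.8 on the correct unit gives an error of
     * 1 - 0.8 = 0.2.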
     */
    public double errorValue(boolean calculate) {
     
      if (!Double.isNaN(m_unitValue) && Double.isNaN(m_unitError) 
          && calculate) {
       
        if (m_input) {
          m_unitError = 0;
          for (int noa = 0; noa < m_numOutputs; noa++) {
            m_unitError += m_outputList[noa].errorValue(true);
          }
        }
        else {
          if (m_currentInstance.classIsMissing()) {
            m_unitError = .1; 
          }
          else if (m_instances.classAttribute().isNominal()) {
            if (m_currentInstance.classValue() == m_link) {
              m_unitError = 1 - m_unitValue;
            }
            else {
              m_unitError = 0 - m_unitValue;
            }
          }
          else if (m_numeric) {
           
            if (m_normalizeClass) {
              if (m_attributeRanges[m_instances.classIndex()] == 0) {
                m_unitError = 0;
              }
              else {
                m_unitError = (m_currentInstance.classValue() - m_unitValue) /
                  m_attributeRanges[m_instances.classIndex()];
                //m_numericRange;
               
              }
            }
            else {
              m_unitError = m_currentInstance.classValue() - m_unitValue;
            }
          }
        }
      }
      return m_unitError;
    }
   
   
    /**
     * Call this to reset the value and error for this unit, ready for the next
     * run. This will also call the reset function of all units that are
     * connected as inputs to this one.
     * This is also the time that the update for the listeners will be
     * performed.
     */
    public void reset() {
     
      if (!Double.isNaN(m_unitValue) || !Double.isNaN(m_unitError)) {
        m_unitValue = Double.NaN;
        m_unitError = Double.NaN;
        m_weightsUpdated = false;
        for (int noa = 0; noa < m_numInputs; noa++) {
          m_inputList[noa].reset();
        }
      }
    }
   
    /**
     * Call this to have the connection save the current
     * weights.
     */
    public void saveWeights() {
      for (int i = 0; i < m_numInputs; i++) {
        m_inputList[i].saveWeights();
      }
    }
   
    /**
     * Call this to have the connection restore from the saved
     * weights.
     */
    public void restoreWeights() {
      for (int i = 0; i < m_numInputs; i++) {
        m_inputList[i].restoreWeights();
      }
    }
   
   
    /**
     * Call this function to set what this end unit represents.
     * @param input True if this unit is used for entering an attribute,
     * False if it's used for determining a class value.
     * @param val The attribute number or class type that this unit represents.
     * (for nominal attributes).
     */
    public void setLink(boolean input, int val) throws Exception {
      m_input = input;
     
      if (input) {
        m_type = PURE_INPUT;
      }
      else {
        m_type = PURE_OUTPUT;
      }
      if (val < 0 || (input && val > m_instances.numAttributes()) 
          || (!input && m_instances.classAttribute().isNominal() 
              && val > m_instances.classAttribute().numValues())) {
        m_link = 0;
      }
      else {
        m_link = val;
      }
    }
   
    /**
     * @return link for this node.
     */
    public int getLink() {
      return m_link;
    }
   
    /**
     * Returns the revision string.
     *
     * @return          the revision
     */
    public String getRevision() {
      return RevisionUtils.extract("$Revision: 5928 $");
    }
  }
 

 
  /** Inner class used to draw the nodes onto. (uses the node lists!!)
   * This will also handle the user input. */
  private class NodePanel 
    extends JPanel
    implements RevisionHandler {
   
    /** for serialization */
    static final long serialVersionUID = -3067621833388149984L;

    /**
     * The constructor.
     */
    public NodePanel() {
     

      addMouseListener(new MouseAdapter() {
         
          public void mousePressed(MouseEvent e) {
           
            if (!m_stopped) {
              return;
            }
            if ((e.getModifiers() & MouseEvent.BUTTON1_MASK) == MouseEvent.BUTTON1_MASK && 
                !e.isAltDown()) {
              Graphics g = NodePanel.this.getGraphics();
              int x = e.getX();
              int y = e.getY();
              int w = NodePanel.this.getWidth();
              int h = NodePanel.this.getHeight();
              FastVector tmp = new FastVector(4);
              for (int noa = 0; noa < m_numAttributes; noa++) {
                if (m_inputs[noa].onUnit(g, x, y, w, h)) {
                  tmp.addElement(m_inputs[noa]);
                  selection(tmp, 
                            (e.getModifiers() & MouseEvent.CTRL_MASK) == MouseEvent.CTRL_MASK
                            , true);
                  return;
                }
              }
              for (int noa = 0; noa < m_numClasses; noa++) {
                if (m_outputs[noa].onUnit(g, x, y, w, h)) {
                  tmp.addElement(m_outputs[noa]);
                  selection(tmp,
                            (e.getModifiers() & MouseEvent.CTRL_MASK) == MouseEvent.CTRL_MASK
                            , true);
                  return;
                }
              }
              for (int noa = 0; noa < m_neuralNodes.length; noa++) {
                if (m_neuralNodes[noa].onUnit(g, x, y, w, h)) {
                  tmp.addElement(m_neuralNodes[noa]);
                  selection(tmp,
                            (e.getModifiers() & MouseEvent.CTRL_MASK) == MouseEvent.CTRL_MASK
                            , true);
                  return;
                }

              }
              NeuralNode temp = new NeuralNode(String.valueOf(m_nextId), 
                                               m_random, m_sigmoidUnit);
              m_nextId++;
              temp.setX((double)e.getX() / w);
              temp.setY((double)e.getY() / h);
              tmp.addElement(temp);
              addNode(temp);
              selection(tmp, (e.getModifiers() & MouseEvent.CTRL_MASK) == MouseEvent.CTRL_MASK
                        , true);
            }
            else {
              //then right click
              Graphics g = NodePanel.this.getGraphics();
              int x = e.getX();
              int y = e.getY();
              int w = NodePanel.this.getWidth();
              int h = NodePanel.this.getHeight();
              FastVector tmp = new FastVector(4);
              for (int noa = 0; noa < m_numAttributes; noa++) {
                if (m_inputs[noa].onUnit(g, x, y, w, h)) {
                  tmp.addElement(m_inputs[noa]);
                  selection(tmp, 
                            (e.getModifiers() & MouseEvent.CTRL_MASK) == MouseEvent.CTRL_MASK
                            , false);
                  return;
                }
               
               
              }
              for (int noa = 0; noa < m_numClasses; noa++) {
                if (m_outputs[noa].onUnit(g, x, y, w, h)) {
                  tmp.addElement(m_outputs[noa]);
                  selection(tmp,
                            (e.getModifiers() & MouseEvent.CTRL_MASK) == MouseEvent.CTRL_MASK
                            , false);
                  return;
                }
              }
              for (int noa = 0; noa < m_neuralNodes.length; noa++) {
                if (m_neuralNodes[noa].onUnit(g, x, y, w, h)) {
                  tmp.addElement(m_neuralNodes[noa]);
                  selection(tmp,
                            (e.getModifiers() & MouseEvent.CTRL_MASK) == MouseEvent.CTRL_MASK
                            , false);
                  return;
                }
              }
              selection(null, (e.getModifiers() & MouseEvent.CTRL_MASK) == MouseEvent.CTRL_MASK
                        , false);
            }
          }
        });
    }
   
   
    /**
     * This function gets called when the user has clicked something.
     * It will amend the current selection or connect the current selection
     * to the new selection.
     * Or if nothing was selected and the right button was used it will
     * delete the node.
     * @param v The units that were selected.
     * @param ctrl True if ctrl was held down.
     * @param left True if it was the left mouse button.
     */
    private void selection(FastVector v, boolean ctrl, boolean left) {
     
      if (v == null) {
        //then unselect all.
        m_selected.removeAllElements();
        repaint();
        return;
      }
     

      //then exclusive or the new selection with the current one.
      if ((ctrl || m_selected.size() == 0) && left) {
        boolean removed = false;
        for (int noa = 0; noa < v.size(); noa++) {
          removed = false;
          for (int nob = 0; nob < m_selected.size(); nob++) {
            if (v.elementAt(noa) == m_selected.elementAt(nob)) {
              //then remove that element
              m_selected.removeElementAt(nob);
              removed = true;
              break;
            }
          }
          if (!removed) {
            m_selected.addElement(v.elementAt(noa));
          }
        }
        repaint();
        return;
      }

     
      if (left) {
        //then connect the current selection to the new one.
        for (int noa = 0; noa < m_selected.size(); noa++) {
          for (int nob = 0; nob < v.size(); nob++) {
            NeuralConnection
              .connect((NeuralConnection)m_selected.elementAt(noa)
                       , (NeuralConnection)v.elementAt(nob));
          }
        }
      }
      else if (m_selected.size() > 0) {
        //then disconnect the current selection from the new one.
       
        for (int noa = 0; noa < m_selected.size(); noa++) {
          for (int nob = 0; nob < v.size(); nob++) {
            NeuralConnection
              .disconnect((NeuralConnection)m_selected.elementAt(noa)
                          , (NeuralConnection)v.elementAt(nob));
           
            NeuralConnection
              .disconnect((NeuralConnection)v.elementAt(nob)
                          , (NeuralConnection)m_selected.elementAt(noa));
           
          }
        }
      }
      else {
        //then remove the selected node. (it was right clicked while
        //no other units were selected)
        for (int noa = 0; noa < v.size(); noa++) {
          ((NeuralConnection)v.elementAt(noa)).removeAllInputs();
          ((NeuralConnection)v.elementAt(noa)).removeAllOutputs();
          removeNode((NeuralConnection)v.elementAt(noa));
        }
      }
      repaint();
    }

    /**
     * This will paint the nodes onto the panel.
     * @param g The graphics context.
     */
    public void paintComponent(Graphics g) {

      super.paintComponent(g);
      int x = getWidth();
      int y = getHeight();
      if (25 * m_numAttributes > 25 * m_numClasses && 
          25 * m_numAttributes > y) {
        setSize(x, 25 * m_numAttributes);
      }
      else if (25 * m_numClasses > y) {
        setSize(x, 25 * m_numClasses);
      }
      else {
        setSize(x, y);
      }

      y = getHeight();
      for (int noa = 0; noa < m_numAttributes; noa++) {
        m_inputs[noa].drawInputLines(g, x, y);
      }
      for (int noa = 0; noa < m_numClasses; noa++) {
        m_outputs[noa].drawInputLines(g, x, y);
        m_outputs[noa].drawOutputLines(g, x, y);
      }
      for (int noa = 0; noa < m_neuralNodes.length; noa++) {
        m_neuralNodes[noa].drawInputLines(g, x, y);
      }
      for (int noa = 0; noa < m_numAttributes; noa++) {
        m_inputs[noa].drawNode(g, x, y);
      }
      for (int noa = 0; noa < m_numClasses; noa++) {
        m_outputs[noa].drawNode(g, x, y);
      }
      for (int noa = 0; noa < m_neuralNodes.length; noa++) {
        m_neuralNodes[noa].drawNode(g, x, y);
      }

      for (int noa = 0; noa < m_selected.size(); noa++) {
        ((NeuralConnection)m_selected.elementAt(noa)).drawHighlight(g, x, y);
      }
    }
   
    /**
     * Returns the revision string.
     *
     * @return          the revision
     */
    public String getRevision() {
      return RevisionUtils.extract("$Revision: 5928 $");
    }
  }

  /**
   * This provides the basic controls for working with the neural network.
   * @author Malcolm Ware (mfw4@cs.waikato.ac.nz)
   * @version $Revision: 5928 $
   */
  class ControlPanel 
    extends JPanel
    implements RevisionHandler {
   
    /** for serialization */
    static final long serialVersionUID = 7393543302294142271L;
   
    /** The start stop button. */
    public JButton m_startStop;
   
    /** The button to accept the network (even if it hasn't done all epochs). */
    public JButton m_acceptButton;
   
    /** A label to state the number of epochs processed so far. */
    public JPanel m_epochsLabel;
   
    /** A label to state the total number of epochs to be processed. */
    public JLabel m_totalEpochsLabel;
   
    /** A text field to allow the changing of the total number of epochs. */
    public JTextField m_changeEpochs;
   
    /** A label to state the learning rate. */
    public JLabel m_learningLabel;
   
    /** A label to state the momentum. */
    public JLabel m_momentumLabel;
   
    /** A text field to allow the changing of the learning rate. */
    public JTextField m_changeLearning;
   
    /** A text field to allow the changing of the momentum. */
    public JTextField m_changeMomentum;
   
    /** A label to state roughly the accuracy of the network. (because the
        accuracy is calculated per epoch, but the network is changing
        throughout each epoch of training).
    */
    public JPanel m_errorLabel;
   
    /** The constructor. */
    public ControlPanel() { 
      setBorder(BorderFactory.createTitledBorder("Controls"));
     
      m_totalEpochsLabel = new JLabel("Num Of Epochs  ");
      m_epochsLabel = new JPanel(){ 
          /** for serialization */
          private static final long serialVersionUID = 2562773937093221399L;

          public void paintComponent(Graphics g) {
            super.paintComponent(g);
            g.setColor(m_controlPanel.m_totalEpochsLabel.getForeground());
            g.drawString("Epoch  " + m_epoch, 0, 10);
          }
        };
      m_epochsLabel.setFont(m_totalEpochsLabel.getFont());
     
      m_changeEpochs = new JTextField();
      m_changeEpochs.setText("" + m_numEpochs);
      m_errorLabel = new JPanel(){
          /** for serialization */
          private static final long serialVersionUID = 4390239056336679189L;

          public void paintComponent(Graphics g) {
            super.paintComponent(g);
            g.setColor(m_controlPanel.m_totalEpochsLabel.getForeground());
            if (m_valSize == 0) {
              g.drawString("Error per Epoch = " + 
                           Utils.doubleToString(m_error, 7), 0, 10);
            }
            else {
              g.drawString("Validation Error per Epoch = "
                           + Utils.doubleToString(m_error, 7), 0, 10);
            }
          }
        };
      m_errorLabel.setFont(m_epochsLabel.getFont());
     
      m_learningLabel = new JLabel("Learning Rate = ");
      m_momentumLabel = new JLabel("Momentum = ");
      m_changeLearning = new JTextField();
      m_changeMomentum = new JTextField();
      m_changeLearning.setText("" + m_learningRate);
      m_changeMomentum.setText("" + m_momentum);
      setLayout(new BorderLayout(15, 10));

      m_stopIt = true;
      m_accepted = false;
      m_startStop = new JButton("Start");
      m_startStop.setActionCommand("Start");
     
      m_acceptButton = new JButton("Accept");
      m_acceptButton.setActionCommand("Accept");
     
      JPanel buttons = new JPanel();
      buttons.setLayout(new BoxLayout(buttons, BoxLayout.Y_AXIS));
      buttons.add(m_startStop);
      buttons.add(m_acceptButton);
      add(buttons, BorderLayout.WEST);
      JPanel data = new JPanel();
      data.setLayout(new BoxLayout(data, BoxLayout.Y_AXIS));
     
      Box ab = new Box(BoxLayout.X_AXIS);
      ab.add(m_epochsLabel);
      data.add(ab);
     
      ab = new Box(BoxLayout.X_AXIS);
      Component b = Box.createGlue();
      ab.add(m_totalEpochsLabel);
      ab.add(m_changeEpochs);
      m_changeEpochs.setMaximumSize(new Dimension(200, 20));
      ab.add(b);
      data.add(ab);
     
      ab = new Box(BoxLayout.X_AXIS);
      ab.add(m_errorLabel);
      data.add(ab);
     
      add(data, BorderLayout.CENTER);
     
      data = new JPanel();
      data.setLayout(new BoxLayout(data, BoxLayout.Y_AXIS));
      ab = new Box(BoxLayout.X_AXIS);
      b = Box.createGlue();
      ab.add(m_learningLabel);
      ab.add(m_changeLearning);
      m_changeLearning.setMaximumSize(new Dimension(200, 20));
      ab.add(b);
      data.add(ab);
     
      ab = new Box(BoxLayout.X_AXIS);
      b = Box.createGlue();
      ab.add(m_momentumLabel);
      ab.add(m_changeMomentum);
      m_changeMomentum.setMaximumSize(new Dimension(200, 20));
      ab.add(b);
      data.add(ab);
     
      add(data, BorderLayout.EAST);
     
      m_startStop.addActionListener(new ActionListener() {
          public void actionPerformed(ActionEvent e) {
            if (e.getActionCommand().equals("Start")) {
              m_stopIt = false;
              m_startStop.setText("Stop");
              m_startStop.setActionCommand("Stop");
              int n = Integer.valueOf(m_changeEpochs.getText()).intValue();
             
              m_numEpochs = n;
              m_changeEpochs.setText("" + m_numEpochs);
             
              double m = Double.valueOf(m_changeLearning.getText()).
                doubleValue();
              setLearningRate(m);
              m_changeLearning.setText("" + m_learningRate);
             
              m = Double.valueOf(m_changeMomentum.getText()).doubleValue();
              setMomentum(m);
              m_changeMomentum.setText("" + m_momentum);
             
              blocker(false);
            }
            else if (e.getActionCommand().equals("Stop")) {
              m_stopIt = true;
              m_startStop.setText("Start");
              m_startStop.setActionCommand("Start");
            }
          }
        });
     
      m_acceptButton.addActionListener(new ActionListener() {
          public void actionPerformed(ActionEvent e) {
            m_accepted = true;
            blocker(false);
          }
        });
     
      m_changeEpochs.addActionListener(new ActionListener() {
          public void actionPerformed(ActionEvent e) {
            int n = Integer.valueOf(m_changeEpochs.getText()).intValue();
            if (n > 0) {
              m_numEpochs = n;
              blocker(false);
            }
          }
        });
    }
   
    /**
     * Returns the revision string.
     *
     * @return          the revision
     */
    public String getRevision() {
      return RevisionUtils.extract("$Revision: 5928 $");
    }
  }
 

  /** a ZeroR model in case no model can be built from the data */
  private Classifier m_ZeroR;
   
  /** The training instances. */
  private Instances m_instances;
 
  /** The current instance running through the network. */
  private Instance m_currentInstance;
 
  /** A flag to say that it's a numeric class. */
  private boolean m_numeric;

  /** The ranges for all the attributes. */
  private double[] m_attributeRanges;

  /** The base values for all the attributes. */
  private double[] m_attributeBases;

  /** The output units. (only feed the errors, do no calcs) */
  private NeuralEnd[] m_outputs;

  /** The input units. (only feed the inputs, do no calcs) */
  private NeuralEnd[] m_inputs;

  /** All the nodes that actually comprise the logical neural net. */
  private NeuralConnection[] m_neuralNodes;

  /** The number of classes. */
  private int m_numClasses = 0;
 
  /** The number of attributes. */
  private int m_numAttributes = 0; //note the number doesn't include the class.
 
  /** The panel the nodes are displayed on. */
  private NodePanel m_nodePanel;
 
  /** The control panel. */
  private ControlPanel m_controlPanel;

  /** The next id number available for default naming. */
  private int m_nextId;
   
  /** A Vector list of the units currently selected. */
  private FastVector m_selected;

  /** A Vector list of the graphers. */
  private FastVector m_graphers;

  /** The number of epochs to train through. */
  private int m_numEpochs;

  /** a flag to state if the network should be running, or stopped. */
  private boolean m_stopIt;

  /** a flag to state that the network has in fact stopped. */
  private boolean m_stopped;

  /** a flag to state that the network should be accepted the way it is. */
  private boolean m_accepted;
  /** The window for the network. */
  private JFrame m_win;

  /** A flag to tell the build classifier to automatically build a neural net.
   */
  private boolean m_autoBuild;

  /** A flag to state that the gui for the network should be brought up.
      To allow interaction while training. */
  private boolean m_gui;

  /** An int to say how big the validation set should be. */
  private int m_valSize;

  /** The number to use to quit on validation testing. */
  private int m_driftThreshold;

  /** The number used to seed the random number generator. */
  private int m_randomSeed;

  /** The actual random number generator. */
  private Random m_random;

  /** A flag to state that a nominal to binary filter should be used. */
  private boolean m_useNomToBin;
 
  /** The actual filter. */
  private NominalToBinary m_nominalToBinaryFilter;

  /** The string that defines the hidden layers */
  private String m_hiddenLayers;

  /** This flag states that the user wants the input values normalized. */
  private boolean m_normalizeAttributes;

  /** This flag states that the user wants the learning rate to decay. */
  private boolean m_decay;

  /** This is the learning rate for the network. */
  private double m_learningRate;

  /** This is the momentum for the network. */
  private double m_momentum;

  /** Shows the number of the epoch that the network just finished. */
  private int m_epoch;

  /** Shows the error of the epoch that the network just finished. */
  private double m_error;

  /** This flag states that the user wants the network to restart if it
   * is found to be generating infinity or NaN for the error value. This
   * would restart the network with the current options except that the
   * learning rate would be smaller than before, (perhaps half of its current
   * value). This option will not be available if the gui is chosen (if the
   * gui is open the user can fix the network themselves, it is an
   * architectural minefield for the network to be reset with the gui open). */
  private boolean m_reset;

  /** This flag states that the user wants the class to be normalized while
   * processing in the network is done. (the final answer will be in the
   * original range regardless). This option will only be used when the class
   * is numeric. */
  private boolean m_normalizeClass;

  /**
   * this is a sigmoid unit.
   */
  private SigmoidUnit m_sigmoidUnit;
 
  /**
   * This is a linear unit.
   */
  private LinearUnit m_linearUnit;
 
  /**
   * The constructor.
   */
  public MultilayerPerceptron() {
    m_instances = null;
    m_currentInstance = null;
    m_controlPanel = null;
    m_nodePanel = null;
    m_epoch = 0;
    m_error = 0;
   
   
    m_outputs = new NeuralEnd[0];
    m_inputs = new NeuralEnd[0];
    m_numAttributes = 0;
    m_numClasses = 0;
    m_neuralNodes = new NeuralConnection[0];
    m_selected = new FastVector(4);
    m_graphers = new FastVector(2);
    m_nextId = 0;
    m_stopIt = true;
    m_stopped = true;
    m_accepted = false;
    m_numeric = false;
    m_random = null;
    m_nominalToBinaryFilter = new NominalToBinary();
    m_sigmoidUnit = new SigmoidUnit();
    m_linearUnit = new LinearUnit();
    //setting all the options to their defaults. To completely change these
    //defaults they will also need to be changed down the bottom in the
    //setOptions function (the text info in the accompanying functions should
    //also be changed to reflect the new defaults)
    m_normalizeClass = true;
    m_normalizeAttributes = true;
    m_autoBuild = true;
    m_gui = false;
    m_useNomToBin = true;
    m_driftThreshold = 20;
    m_numEpochs = 500;
    m_valSize = 0;
    m_randomSeed = 0;
    m_hiddenLayers = "a";
    m_learningRate = .3;
    m_momentum = .2;
    m_reset = true;
    m_decay = false;
  }

  /**
   * @param d True if the learning rate should decay.
   */
  public void setDecay(boolean d) {
    m_decay = d;
  }
 
  /**
   * @return the flag for having the learning rate decay.
   */
  public boolean getDecay() {
    return m_decay;
  }

  /**
   * This sets the network up to be able to reset itself with the current
   * settings and the learning rate at half of what it is currently. This
   * will only happen if the network creates NaN or infinite errors. Also this
   * will continue to happen until the network is trained properly. The
   * learning rate will also get set back to its original value at the end of
   * this. This can only be set to true if the GUI is not brought up.
   * @param r True if the network should restart with its current options
   * and set the learning rate to half what it currently is.
   */
  public void setReset(boolean r) {
    if (m_gui) {
      r = false;
    }
    m_reset = r;
     
  }

  /**
   * @return The flag for resetting the network.
   */
  public boolean getReset() {
    return m_reset;
  }
 
  /**
   * @param c True if the class should be normalized (the class will only ever
   * be normalized if it is numeric). (Normalization puts the range between
   * -1 and 1).
   */
  public void setNormalizeNumericClass(boolean c) {
    m_normalizeClass = c;
  }
 
  /**
   * @return The flag for normalizing a numeric class.
   */
  public boolean getNormalizeNumericClass() {
    return m_normalizeClass;
  }

  /**
   * @param a True if the attributes should be normalized (even nominal
   * attributes will get normalized here) (range goes between -1 and 1).
   */
  public void setNormalizeAttributes(boolean a) {
    m_normalizeAttributes = a;
  }

  /**
   * @return The flag for normalizing attributes.
   */
  public boolean getNormalizeAttributes() {
    return m_normalizeAttributes;
  }

  /**
   * @param f True if a nominalToBinary filter should be used on the
   * data.
   */
  public void setNominalToBinaryFilter(boolean f) {
    m_useNomToBin = f;
  }

  /**
   * @return The flag for nominal to binary filter use.
   */
  public boolean getNominalToBinaryFilter() {
    return m_useNomToBin;
  }

  /**
   * This seeds the random number generator that is used when a random
   * number is needed for the network.
   * @param l The seed.
   */
  public void setSeed(int l) {
    if (l >= 0) {
      m_randomSeed = l;
    }
  }
 
  /**
   * @return The seed for the random number generator.
   */
  public int getSeed() {
    return m_randomSeed;
  }

  /**
   * This sets the threshold to use for when validation testing is being done.
   * It works by ending testing once the error on the validation set has
   * consecutively increased a certain number of times.
   * @param t The threshold to use for this.
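   * For example, with the default threshold of 20, training stops once the
   * validation error has risen 20 times in a row.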
   */
  public void setValidationThreshold(int t) {
    if (t > 0) {
      m_driftThreshold = t;
    }
  }

  /**
   * @return The threshold used for validation testing.
   */
  public int getValidationThreshold() {
    return m_driftThreshold;
  }
 
  /**
   * The learning rate can be set using this command.
   * Must be greater than 0 and no more than 1.
   * @param l The new learning rate.
   */
  public void setLearningRate(double l) {
    if (l > 0 && l <= 1) {
      m_learningRate = l;
   
      if (m_controlPanel != null) {
        m_controlPanel.m_changeLearning.setText("" + l);
      }
    }
  }

  /**
   * @return The learning rate for the nodes.
   */
  public double getLearningRate() {
    return m_learningRate;
  }

  /**
   * The momentum can be set using this command.
   * The same conditions apply to this as to the learning rate.
   * @param m The new momentum.
   */
  public void setMomentum(double m) {
    if (m >= 0 && m <= 1) {
      m_momentum = m;
 
      if (m_controlPanel != null) {
        m_controlPanel.m_changeMomentum.setText("" + m);
      }
    }
  }
 
  /**
   * @return The momentum for the nodes.
   */
  public double getMomentum() {
    return m_momentum;
  }

  /**
   * This will set whether the network is automatically built
   * or if it is left up to the user. (there is nothing to stop a user
   * from altering an autobuilt network however).
   * @param a True if the network should be auto built.
   */
  public void setAutoBuild(boolean a) {
    if (!m_gui) {
      a = true;
    }
    m_autoBuild = a;
  }

  /**
   * @return The auto build state.
   */
  public boolean getAutoBuild() {
    return m_autoBuild;
  }


  /**
   * This will set what the hidden layers are made up of when auto build is
   * enabled. Note that to have no hidden units, just put in a single 0. Any
   * more 0's will indicate that the string is badly formed and cause it to
   * be rejected. Negative numbers and floats will do the same. There are
   * also some wildcards. These are 'a' = (number of attributes + number of
   * classes) / 2, 'i' = number of attributes, 'o' = number of classes, and
   * 't' = number of attributes + number of classes.
   * @param h A string with a comma separated list of numbers. Each number is
   * the number of nodes to be on a hidden layer.
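   * For example (the numeric values are illustrative only):
   * <pre>
   * setHiddenLayers("a");     // one hidden layer of (attribs + classes) / 2 nodes
   * setHiddenLayers("4, 2");  // two hidden layers of 4 and 2 nodes
   * setHiddenLayers("0");     // no hidden layer at all
   * </pre>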
   */
  public void setHiddenLayers(String h) {
    String tmp = "";
    StringTokenizer tok = new StringTokenizer(h, ",");
    if (tok.countTokens() == 0) {
      return;
    }
    double dval;
    int val;
    String c;
    boolean first = true;
    while (tok.hasMoreTokens()) {
      c = tok.nextToken().trim();

      if (c.equals("a") || c.equals("i") || c.equals("o") || 
               c.equals("t")) {
        tmp += c;
      }
      else {
        dval = Double.valueOf(c).doubleValue();
        val = (int)dval;
       
        if ((val == dval && (val != 0 || (tok.countTokens() == 0 && first)) && 
             val >= 0)) {
          tmp += val;
        }
        else {
          return;
        }
      }
     
      first = false;
      if (tok.hasMoreTokens()) {
        tmp += ", ";
      }
    }
    m_hiddenLayers = tmp;
  }

  /**
   * @return A string representing the hidden layers, each number is the number
   * of nodes on a hidden layer.
   */
  public String getHiddenLayers() {
    return m_hiddenLayers;
  }

  /**
   * This will set whether a GUI is brought up to allow interaction by the user
   * with the neural network during training.
   * @param a True if gui should be created.
   */
  public void setGUI(boolean a) {
    m_gui = a;
    if (!a) {
      setAutoBuild(true);
     
    }
    else {
      setReset(false);
    }
  }

  /**
   * @return True if the gui should be shown.
   */
  public boolean getGUI() {
    return m_gui;
  }

  /**
   * This will set the size of the validation set.
   * @param a The size of the validation set, as a percentage of the whole.
   */
  public void setValidationSetSize(int a) {
    if (a < 0 || a > 99) {
      return;
    }
    m_valSize = a;
  }

  /**
   * @return The percentage size of the validation set.
   */
  public int getValidationSetSize() {
    return m_valSize;
  }

 
  /**
   * Set the number of training epochs to perform.
   * Must be greater than 0.
   * @param n The number of epochs to train through.
   */
  public void setTrainingTime(int n) {
    if (n > 0) {
      m_numEpochs = n;
    }
  }

  /**
   * @return The number of epochs to train through.
   */
  public int getTrainingTime() {
    return m_numEpochs;
  }
 
  /**
   * Call this function to place a node into the network list.
   * @param n The node to place in the list.
   */
  private void addNode(NeuralConnection n) {
   
    NeuralConnection[] temp1 = new NeuralConnection[m_neuralNodes.length + 1];
    for (int noa = 0; noa < m_neuralNodes.length; noa++) {
      temp1[noa] = m_neuralNodes[noa];
    }

    temp1[temp1.length - 1] = n;
    m_neuralNodes = temp1;
  }

  /**
   * Call this function to remove the passed node from the list.
   * This will only remove the node if it is in the neural nodes list.
   * @param n The neuralConnection to remove.
   * @return True if removed, false if not (because it wasn't there).
   */
  private boolean removeNode(NeuralConnection n) {
    NeuralConnection[] temp1 = new NeuralConnection[m_neuralNodes.length - 1];
    int skip = 0;
    for (int noa = 0; noa < m_neuralNodes.length; noa++) {
      if (n == m_neuralNodes[noa]) {
        skip++;
      }
      else if (!((noa - skip) >= temp1.length)) {
        temp1[noa - skip] = m_neuralNodes[noa];
      }
      else {
        return false;
      }
    }
    m_neuralNodes = temp1;
    return true;
  }

  /**
   * This function sets the m_numeric flag to represent the passed class.
   * It also performs the normalization of the attributes if applicable
   * and sets up the info to normalize the class. (note that regardless of
   * the options it will fill an array with the range and base, set to
   * normalize all attributes and the class to be between -1 and 1)
   * @param inst the instances.
   * @return The modified instances. This needs to be done. If the attributes
   * are normalized then deep copies will be made of all the instances which
   * will need to be passed back out.
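   * For example, an attribute whose observed values span [0, 10] gets
   * base (10 + 0) / 2 = 5 and range (10 - 0) / 2 = 5, so a value of 7.5
   * normalizes to (7.5 - 5) / 5 = 0.5.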
   */
  private Instances setClassType(Instances inst) throws Exception {
    if (inst != null) {
      // x bounds
      double min = Double.POSITIVE_INFINITY;
      double max = Double.NEGATIVE_INFINITY;
      double value;
      m_attributeRanges = new double[inst.numAttributes()];
      m_attributeBases = new double[inst.numAttributes()];
      for (int noa = 0; noa < inst.numAttributes(); noa++) {
        min = Double.POSITIVE_INFINITY;
        max = Double.NEGATIVE_INFINITY;
        for (int i = 0; i < inst.numInstances(); i++) {
          if (!inst.instance(i).isMissing(noa)) {
            value = inst.instance(i).value(noa);
            if (value < min) {
              min = value;
            }
            if (value > max) {
              max = value;
            }
          }
        }
       
        m_attributeRanges[noa] = (max - min) / 2;
        m_attributeBases[noa] = (max + min) / 2;
        if (noa != inst.classIndex() && m_normalizeAttributes) {
          for (int i = 0; i < inst.numInstances(); i++) {
            if (m_attributeRanges[noa] != 0) {
              inst.instance(i).setValue(noa, (inst.instance(i).value(noa) 
                                              - m_attributeBases[noa]) /
                                        m_attributeRanges[noa]);
            }
            else {
              inst.instance(i).setValue(noa, inst.instance(i).value(noa) - 
                                        m_attributeBases[noa]);
            }
          }
        }
      }
      if (inst.classAttribute().isNumeric()) {
        m_numeric = true;
      }
      else {
        m_numeric = false;
      }
    }
    return inst;
  }

  /**
   * A function used to stop the code that called buildClassifier
   * from continuing on before the user has finished with the neural network.
   * @param tf True to stop the thread, False to release the thread that is
   * waiting there (if one).
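   * A minimal sketch of the idiom (the caller code here is hypothetical):
   * <pre>
   * blocker(true);   // training thread parks in wait()
   * // ... later, from the GUI thread (e.g. the Start button's listener):
   * blocker(false);  // notifyAll() releases the waiting thread
   * </pre>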
   */
  public synchronized void blocker(boolean tf) {
    if (tf) {
      try {
        wait();
      } catch(InterruptedException e) {
      }
    }
    else {
      notifyAll();
    }
  }

  /**
   * Call this function to update the control panel for the gui.
   */
  private void updateDisplay() {
   
    if (m_gui) {
      m_controlPanel.m_errorLabel.repaint();
      m_controlPanel.m_epochsLabel.repaint();
    }
  }
 

  /**
   * this will reset all the nodes in the network.
   */
  private void resetNetwork() {
    for (int noc = 0; noc < m_numClasses; noc++) {
      m_outputs[noc].reset();
    }
  }
 
  /**
   * This will cause the output values of all the nodes to be calculated.
   * Note that the m_currentInstance is used to calculate these values.
   */
  private void calculateOutputs() {
    for (int noc = 0; noc < m_numClasses; noc++) {
      //get the values.
      m_outputs[noc].outputValue(true);
    }
  }

  /**
   * This will cause the error values to be calculated for all nodes.
   * Note that the m_currentInstance is used to calculate these values.
   * Also the output values should have been calculated first.
   * @return The squared error.
   */
  private double calculateErrors() throws Exception {
    double ret = 0, temp = 0; 
    for (int noc = 0; noc < m_numAttributes; noc++) {
      //get the errors.
      m_inputs[noc].errorValue(true);
     
    }
    for (int noc = 0; noc < m_numClasses; noc++) {
      temp = m_outputs[noc].errorValue(false);
      ret += temp * temp;
    }
    return ret;
   
  }

  /**
   * This will cause the weight values to be updated based on the learning
   * rate, momentum and the errors that have been calculated for each node.
   * @param l The learning rate to update with.
   * @param m The momentum to update with.
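   * (In standard backpropagation with momentum, which these parameters feed,
   * each weight change is the learning rate times the back-propagated error
   * gradient plus the momentum times the previous weight change; the update
   * itself is implemented by the node classes in
   * weka.classifiers.functions.neural.)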
1562   */
1563  private void updateNetworkWeights(double l, double m) {
1564    for (int noc = 0; noc < m_numClasses; noc++) {
1565      //update weights
1566      m_outputs[noc].updateWeights(l, m);
1567    }
1568
1569  }
1570 
1571  /**
1572   * This creates the required input units.
1573   */
1574  private void setupInputs() throws Exception {
1575    m_inputs = new NeuralEnd[m_numAttributes];
1576    int now = 0;
1577    for (int noa = 0; noa < m_numAttributes+1; noa++) {
1578      if (m_instances.classIndex() != noa) {
1579        m_inputs[noa - now] = new NeuralEnd(m_instances.attribute(noa).name());
1580       
1581        m_inputs[noa - now].setX(.1);
1582        m_inputs[noa - now].setY((noa - now + 1.0) / (m_numAttributes + 1));
1583        m_inputs[noa - now].setLink(true, noa);
1584      }   
1585      else {
1586        now = 1;
1587      }
1588    }
1589
1590  }
1591
1592  /**
1593   * This creates the required output units.
1594   */
1595  private void setupOutputs() throws Exception {
1596 
1597    m_outputs = new NeuralEnd[m_numClasses];
1598    for (int noa = 0; noa < m_numClasses; noa++) {
1599      if (m_numeric) {
1600        m_outputs[noa] = new NeuralEnd(m_instances.classAttribute().name());
1601      }
1602      else {
1603        m_outputs[noa]= new NeuralEnd(m_instances.classAttribute().value(noa));
1604      }
1605     
1606      m_outputs[noa].setX(.9);
1607      m_outputs[noa].setY((noa + 1.0) / (m_numClasses + 1));
1608      m_outputs[noa].setLink(false, noa);
1609      NeuralNode temp = new NeuralNode(String.valueOf(m_nextId), m_random,
1610                                       m_sigmoidUnit);
1611      m_nextId++;
1612      temp.setX(.75);
1613      temp.setY((noa + 1.0) / (m_numClasses + 1));
1614      addNode(temp);
1615      NeuralConnection.connect(temp, m_outputs[noa]);
1616    }
1617 
1618  }
1619 
1620  /**
1621   * Call this function to automatically generate the hidden units
1622   */
1623  private void setupHiddenLayer()
1624  {
1625    StringTokenizer tok = new StringTokenizer(m_hiddenLayers, ",");
1626    int val = 0;  //num of nodes in a layer
1627    int prev = 0; //used to remember the previous layer
1628    int num = tok.countTokens(); //number of layers
1629    String c;
1630    for (int noa = 0; noa < num; noa++) {
1631      //note that I am using the Double to get the value rather than the
1632      //Integer class, because for some reason the Double implementation can
1633      //handle leading white space and the integer version can't!?!
1634      c = tok.nextToken().trim();
1635      if (c.equals("a")) {
1636        val = (m_numAttributes + m_numClasses) / 2;
1637      }
1638      else if (c.equals("i")) {
1639        val = m_numAttributes;
1640      }
1641      else if (c.equals("o")) {
1642        val = m_numClasses;
1643      }
1644      else if (c.equals("t")) {
1645        val = m_numAttributes + m_numClasses;
1646      }
1647      else {
1648        val = Double.valueOf(c).intValue();
1649      }
1650      for (int nob = 0; nob < val; nob++) {
1651        NeuralNode temp = new NeuralNode(String.valueOf(m_nextId), m_random,
1652                                         m_sigmoidUnit);
1653        m_nextId++;
1654        temp.setX(.5 / (num) * noa + .25);
1655        temp.setY((nob + 1.0) / (val + 1));
1656        addNode(temp);
1657        if (noa > 0) {
1658          //then do connections
1659          for (int noc = m_neuralNodes.length - nob - 1 - prev;
1660               noc < m_neuralNodes.length - nob - 1; noc++) {
1661            NeuralConnection.connect(m_neuralNodes[noc], temp);
1662          }
1663        }
1664      }     
1665      prev = val;
1666    }
1667    tok = new StringTokenizer(m_hiddenLayers, ",");
1668    c = tok.nextToken();
1669    if (c.equals("a")) {
1670      val = (m_numAttributes + m_numClasses) / 2;
1671    }
1672    else if (c.equals("i")) {
1673      val = m_numAttributes;
1674    }
1675    else if (c.equals("o")) {
1676      val = m_numClasses;
1677    }
1678    else if (c.equals("t")) {
1679      val = m_numAttributes + m_numClasses;
1680    }
1681    else {
1682      val = Double.valueOf(c).intValue();
1683    }
1684   
1685    if (val == 0) {
1686      for (int noa = 0; noa < m_numAttributes; noa++) {
1687        for (int nob = 0; nob < m_numClasses; nob++) {
1688          NeuralConnection.connect(m_inputs[noa], m_neuralNodes[nob]);
1689        }
1690      }
1691    }
1692    else {
1693      for (int noa = 0; noa < m_numAttributes; noa++) {
1694        for (int nob = m_numClasses; nob < m_numClasses + val; nob++) {
1695          NeuralConnection.connect(m_inputs[noa], m_neuralNodes[nob]);
1696        }
1697      }
1698      for (int noa = m_neuralNodes.length - prev; noa < m_neuralNodes.length;
1699           noa++) {
1700        for (int nob = 0; nob < m_numClasses; nob++) {
1701          NeuralConnection.connect(m_neuralNodes[noa], m_neuralNodes[nob]);
1702        }
1703      }
1704    }
1705   
1706  }
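  /*
   * Illustrative examples (not in the original source) of how the hidden
   * layer spec parsed above is interpreted, assuming 10 input attributes
   * and 3 classes:
   *
   *   "a"    -> one hidden layer of (10 + 3) / 2 = 6 nodes
   *   "4,3"  -> two hidden layers of 4 and 3 nodes respectively
   *   "0"    -> no hidden layer; inputs are wired directly to the outputs
   */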
1707 
1708  /**
1709   * This will go through all the nodes and check if they are connected
1710   * to a pure output unit. If so they will be set to be linear units.
1711   * If not they will be set to be sigmoid units.
1712   */
1713  private void setEndsToLinear() {
1714    for (int noa = 0; noa < m_neuralNodes.length; noa++) {
1715      if ((m_neuralNodes[noa].getType() & NeuralConnection.OUTPUT) ==
1716          NeuralConnection.OUTPUT) {
1717        ((NeuralNode)m_neuralNodes[noa]).setMethod(m_linearUnit);
1718      }
1719      else {
1720        ((NeuralNode)m_neuralNodes[noa]).setMethod(m_sigmoidUnit);
1721      }
1722    }
1723  }
1724
1725  /**
1726   * Returns default capabilities of the classifier.
1727   *
1728   * @return      the capabilities of this classifier
1729   */
1730  public Capabilities getCapabilities() {
1731    Capabilities result = super.getCapabilities();
1732    result.disableAll();
1733
1734    // attributes
1735    result.enable(Capability.NOMINAL_ATTRIBUTES);
1736    result.enable(Capability.NUMERIC_ATTRIBUTES);
1737    result.enable(Capability.DATE_ATTRIBUTES);
1738    result.enable(Capability.MISSING_VALUES);
1739
1740    // class
1741    result.enable(Capability.NOMINAL_CLASS);
1742    result.enable(Capability.NUMERIC_CLASS);
1743    result.enable(Capability.DATE_CLASS);
1744    result.enable(Capability.MISSING_CLASS_VALUES);
1745   
1746    return result;
1747  }
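  /*
   * Usage sketch (illustrative, assuming "data" is an Instances object with
   * its class index set): the capabilities can be checked against a dataset
   * before training, exactly as buildClassifier() does internally.
   *
   *   MultilayerPerceptron mlp = new MultilayerPerceptron();
   *   mlp.getCapabilities().testWithFail(data); // throws if data unsupported
   */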
1748 
1749  /**
1750   * Call this function to build and train a neural network for the training
1751   * data provided.
1752   * @param i The training data.
1753   * @throws Exception if can't build classification properly.
1754   */
1755  public void buildClassifier(Instances i) throws Exception {
1756
1757    // can classifier handle the data?
1758    getCapabilities().testWithFail(i);
1759
1760    // remove instances with missing class
1761    i = new Instances(i);
1762    i.deleteWithMissingClass();
1763
1764    // only class? -> build ZeroR model
1765    if (i.numAttributes() == 1) {
1766      System.err.println(
1767          "Cannot build model (only class attribute present in data!), "
1768          + "using ZeroR model instead!");
1769      m_ZeroR = new weka.classifiers.rules.ZeroR();
1770      m_ZeroR.buildClassifier(i);
1771      return;
1772    }
1773    else {
1774      m_ZeroR = null;
1775    }
1776   
1777    m_epoch = 0;
1778    m_error = 0;
1779    m_instances = null;
1780    m_currentInstance = null;
1781    m_controlPanel = null;
1782    m_nodePanel = null;
1783   
1784   
1785    m_outputs = new NeuralEnd[0];
1786    m_inputs = new NeuralEnd[0];
1787    m_numAttributes = 0;
1788    m_numClasses = 0;
1789    m_neuralNodes = new NeuralConnection[0];
1790   
1791    m_selected = new FastVector(4);
1792    m_graphers = new FastVector(2);
1793    m_nextId = 0;
1794    m_stopIt = true;
1795    m_stopped = true;
1796    m_accepted = false;   
1797    m_instances = new Instances(i);
1798    m_random = new Random(m_randomSeed);
1799    m_instances.randomize(m_random);
1800
1801    if (m_useNomToBin) {
1802      m_nominalToBinaryFilter = new NominalToBinary();
1803      m_nominalToBinaryFilter.setInputFormat(m_instances);
1804      m_instances = Filter.useFilter(m_instances,
1805                                     m_nominalToBinaryFilter);
1806    }
1807    m_numAttributes = m_instances.numAttributes() - 1;
1808    m_numClasses = m_instances.numClasses();
1809 
1810   
1811    setClassType(m_instances);
1812   
1813
1814   
1815    //this sets up the validation set.
1816    Instances valSet = null;
1817    //numinval is needed later
1818    int numInVal = (int)(m_valSize / 100.0 * m_instances.numInstances());
1819    if (m_valSize > 0) {
1820      if (numInVal == 0) {
1821        numInVal = 1;
1822      }
1823      valSet = new Instances(m_instances, 0, numInVal);
1824    }
1825    ///////////
1826
1827    setupInputs();
1828     
1829    setupOutputs();   
1830    if (m_autoBuild) {
1831      setupHiddenLayer();
1832    }
1833   
1834    /////////////////////////////
1835    //this sets up the gui for usage
1836    if (m_gui) {
1837      m_win = new JFrame();
1838     
1839      m_win.addWindowListener(new WindowAdapter() {
1840          public void windowClosing(WindowEvent e) {
1841            boolean k = m_stopIt;
1842            m_stopIt = true;
1843            int well =JOptionPane.showConfirmDialog(m_win, 
1844                                                    "Are You Sure...\n"
1845                                                    + "Click Yes To Accept"
1846                                                    + " The Neural Network" 
1847                                                    + "\n Click No To Return",
1848                                                    "Accept Neural Network", 
1849                                                    JOptionPane.YES_NO_OPTION);
1850           
1851            if (well == JOptionPane.YES_OPTION) {
1852              m_win.setDefaultCloseOperation(JFrame.DISPOSE_ON_CLOSE);
1853              m_accepted = true;
1854              blocker(false);
1855            }
1856            else {
1857              m_win.setDefaultCloseOperation(JFrame.DO_NOTHING_ON_CLOSE);
1858            }
1859            m_stopIt = k;
1860          }
1861        });
1862     
1863      m_win.getContentPane().setLayout(new BorderLayout());
1864      m_win.setTitle("Neural Network");
1865      m_nodePanel = new NodePanel();
1866      // without the following two lines, the NodePanel.paintComponents(Graphics)
1867      // method will go berserk if the network doesn't fit completely: it will
1868      // get called on a constant basis, using 100% of the CPU
1869      // see the following forum thread:
1870      // http://forum.java.sun.com/thread.jspa?threadID=580929&messageID=2945011
1871      m_nodePanel.setPreferredSize(new Dimension(640, 480));
1872      m_nodePanel.revalidate();
1873
1874      JScrollPane sp = new JScrollPane(m_nodePanel,
1875                                       JScrollPane.VERTICAL_SCROLLBAR_ALWAYS, 
1876                                       JScrollPane.HORIZONTAL_SCROLLBAR_NEVER);
1877      m_controlPanel = new ControlPanel();
1878           
1879      m_win.getContentPane().add(sp, BorderLayout.CENTER);
1880      m_win.getContentPane().add(m_controlPanel, BorderLayout.SOUTH);
1881      m_win.setSize(640, 480);
1882      m_win.setVisible(true);
1883    }
1884   
1885    //This sets up the initial state of the gui
1886    if (m_gui) {
1887      blocker(true);
1888      m_controlPanel.m_changeEpochs.setEnabled(false);
1889      m_controlPanel.m_changeLearning.setEnabled(false);
1890      m_controlPanel.m_changeMomentum.setEnabled(false);
1891    } 
1892   
1893    //For silly situations in which the network gets accepted before training
1894    //commences
1895    if (m_numeric) {
1896      setEndsToLinear();
1897    }
1898    if (m_accepted) {
1899      m_win.dispose();
1900      m_controlPanel = null;
1901      m_nodePanel = null;
1902      m_instances = new Instances(m_instances, 0);
1903      return;
1904    }
1905
1906    //connections done.
1907    double right = 0;
1908    double driftOff = 0;
1909    double lastRight = Double.POSITIVE_INFINITY;
1910    double bestError = Double.POSITIVE_INFINITY;
1911    double tempRate;
1912    double totalWeight = 0;
1913    double totalValWeight = 0;
1914    double origRate = m_learningRate; //only used for when reset
1915   
1916    //ensure that at least 1 instance is trained through.
1917    if (numInVal == m_instances.numInstances()) {
1918      numInVal--;
1919    }
1920    if (numInVal < 0) {
1921      numInVal = 0;
1922    }
1923    for (int noa = numInVal; noa < m_instances.numInstances(); noa++) {
1924      if (!m_instances.instance(noa).classIsMissing()) {
1925        totalWeight += m_instances.instance(noa).weight();
1926      }
1927    }
1928    if (m_valSize != 0) {
1929      for (int noa = 0; noa < valSet.numInstances(); noa++) {
1930        if (!valSet.instance(noa).classIsMissing()) {
1931          totalValWeight += valSet.instance(noa).weight();
1932        }
1933      }
1934    }
1935    m_stopped = false;
1936     
1937
1938    for (int noa = 1; noa < m_numEpochs + 1; noa++) {
1939      right = 0;
1940      for (int nob = numInVal; nob < m_instances.numInstances(); nob++) {
1941        m_currentInstance = m_instances.instance(nob);
1942       
1943        if (!m_currentInstance.classIsMissing()) {
1944           
1945          //this is where the network updating (and training) occurs, for
1946          //the training set
1947          resetNetwork();
1948          calculateOutputs();
1949          tempRate = m_learningRate * m_currentInstance.weight(); 
1950          if (m_decay) {
1951            tempRate /= noa;
1952          }
1953
1954          right += (calculateErrors() / m_instances.numClasses()) *
1955            m_currentInstance.weight();
1956          updateNetworkWeights(tempRate, m_momentum);
1957         
1958        }
1959       
1960      }
1961      right /= totalWeight;
1962      if (Double.isInfinite(right) || Double.isNaN(right)) {
1963        if (!m_reset) {
1964          m_instances = null;
1965          throw new Exception("Network cannot train. Try restarting with a" +
1966                              " smaller learning rate.");
1967        }
1968        else {
1969          //reset the network if possible
1970          if (m_learningRate <= Utils.SMALL)
1971            throw new IllegalStateException(
1972                "Learning rate got too small (" + m_learningRate
1973                + " <= " + Utils.SMALL + ")!");
1974          m_learningRate /= 2;
1975          buildClassifier(i);
1976          m_learningRate = origRate;
1977          m_instances = new Instances(m_instances, 0);   
1978          return;
1979        }
1980      }
1981
1982      ////////////////////////do validation testing if applicable
1983      if (m_valSize != 0) {
1984        right = 0;
1985        for (int nob = 0; nob < valSet.numInstances(); nob++) {
1986          m_currentInstance = valSet.instance(nob);
1987          if (!m_currentInstance.classIsMissing()) {
1988            //this is where the network updating occurs, for the validation set
1989            resetNetwork();
1990            calculateOutputs();
1991            right += (calculateErrors() / valSet.numClasses()) 
1992              * m_currentInstance.weight();
1993            //note 'right' could be calculated here just using
1994            //the calculated output values. This would be faster,
1995            //but less modular.
1996          }
1997         
1998        }
1999       
2000        if (right < lastRight) {
2001          if (right < bestError) {
2002            bestError = right;
2003            // save the network weights at this point
2004            for (int noc = 0; noc < m_numClasses; noc++) {
2005              m_outputs[noc].saveWeights();
2006            }
2007            driftOff = 0;
2008          }
2009        }
2010        else {
2011          driftOff++;
2012        }
2013        lastRight = right;
2014        if (driftOff > m_driftThreshold || noa + 1 >= m_numEpochs) {
2015          for (int noc = 0; noc < m_numClasses; noc++) {
2016            m_outputs[noc].restoreWeights();
2017          }
2018          m_accepted = true;
2019        }
2020        right /= totalValWeight;
2021      }
2022      m_epoch = noa;
2023      m_error = right;
2024      //shows what the neural net is up to if a gui exists.
2025      updateDisplay();
2026      //This junction controls what state the gui is in at the end of each
2027      //epoch, such as whether it is paused, resumable, etc.
2028      if (m_gui) {
2029        while ((m_stopIt || (m_epoch >= m_numEpochs && m_valSize == 0)) && 
2030                !m_accepted) {
2031          m_stopIt = true;
2032          m_stopped = true;
2033          if (m_epoch >= m_numEpochs && m_valSize == 0) {
2034           
2035            m_controlPanel.m_startStop.setEnabled(false);
2036          }
2037          else {
2038            m_controlPanel.m_startStop.setEnabled(true);
2039          }
2040          m_controlPanel.m_startStop.setText("Start");
2041          m_controlPanel.m_startStop.setActionCommand("Start");
2042          m_controlPanel.m_changeEpochs.setEnabled(true);
2043          m_controlPanel.m_changeLearning.setEnabled(true);
2044          m_controlPanel.m_changeMomentum.setEnabled(true);
2045         
2046          blocker(true);
2047          if (m_numeric) {
2048            setEndsToLinear();
2049          }
2050        }
2051        m_controlPanel.m_changeEpochs.setEnabled(false);
2052        m_controlPanel.m_changeLearning.setEnabled(false);
2053        m_controlPanel.m_changeMomentum.setEnabled(false);
2054       
2055        m_stopped = false;
2056        //if the network has been accepted stop the training loop
2057        if (m_accepted) {
2058          m_win.dispose();
2059          m_controlPanel = null;
2060          m_nodePanel = null;
2061          m_instances = new Instances(m_instances, 0);
2062          return;
2063        }
2064      }
2065      if (m_accepted) {
2066        m_instances = new Instances(m_instances, 0);
2067        return;
2068      }
2069    }
2070    if (m_gui) {
2071      m_win.dispose();
2072      m_controlPanel = null;
2073      m_nodePanel = null;
2074    }
2075    m_instances = new Instances(m_instances, 0); 
2076  }
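  /*
   * Usage sketch (illustrative, not part of the original class): a typical
   * programmatic training run. "train" is assumed to be an Instances object
   * with its class index already set; the setters all belong to this class.
   *
   *   MultilayerPerceptron mlp = new MultilayerPerceptron();
   *   mlp.setLearningRate(0.3);   // -L
   *   mlp.setMomentum(0.2);       // -M
   *   mlp.setTrainingTime(500);   // -N, number of epochs
   *   mlp.setHiddenLayers("a");   // -H
   *   mlp.setGUI(false);          // -G, run without the training window
   *   mlp.buildClassifier(train);
   */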
2077
2078  /**
2079   * Call this function to predict the class of an instance once a
2080   * classification model has been built with the buildClassifier call.
2081   * @param i The instance to classify.
2082   * @return A double array filled with the probabilities of each class type (or the numeric prediction when the class is numeric).
2083   * @throws Exception if can't classify instance.
2084   */
2085  public double[] distributionForInstance(Instance i) throws Exception {
2086
2087    // default model?
2088    if (m_ZeroR != null) {
2089      return m_ZeroR.distributionForInstance(i);
2090    }
2091   
2092    if (m_useNomToBin) {
2093      m_nominalToBinaryFilter.input(i);
2094      m_currentInstance = m_nominalToBinaryFilter.output();
2095    }
2096    else {
2097      m_currentInstance = i;
2098    }
2099   
2100    if (m_normalizeAttributes) {
2101      for (int noa = 0; noa < m_instances.numAttributes(); noa++) {
2102        if (noa != m_instances.classIndex()) {
2103          if (m_attributeRanges[noa] != 0) {
2104            m_currentInstance.setValue(noa, (m_currentInstance.value(noa) - 
2105                                             m_attributeBases[noa]) / 
2106                                       m_attributeRanges[noa]);
2107          }
2108          else {
2109            m_currentInstance.setValue(noa, m_currentInstance.value(noa) -
2110                                       m_attributeBases[noa]);
2111          }
2112        }
2113      }
2114    }
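    // Illustrative note (not in the original source): assuming the bases and
    // ranges were stored as the midpoint and half-width of each attribute's
    // training range, an attribute seen in [0, 50] has base 25 and range 25,
    // so a raw value of 40 is rescaled to (40 - 25) / 25 = 0.6, i.e. into
    // the interval [-1, 1].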
2115    resetNetwork();
2116   
2117    //since all the output values are needed.
2118    //They are calculated manually here and the values collected.
2119    double[] theArray = new double[m_numClasses];
2120    for (int noa = 0; noa < m_numClasses; noa++) {
2121      theArray[noa] = m_outputs[noa].outputValue(true);
2122    }
2123    if (m_instances.classAttribute().isNumeric()) {
2124      return theArray;
2125    }
2126   
2127    //now normalize the array
2128    double count = 0;
2129    for (int noa = 0; noa < m_numClasses; noa++) {
2130      count += theArray[noa];
2131    }
2132    if (count <= 0) {
2133      return null;
2134    }
2135    for (int noa = 0; noa < m_numClasses; noa++) {
2136      theArray[noa] /= count;
2137    }
2138    return theArray;
2139  }
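  /*
   * Usage sketch (illustrative): scoring an unseen instance once the model
   * has been built. "test" is assumed to be an Instances object compatible
   * with the training data.
   *
   *   double[] dist = mlp.distributionForInstance(test.instance(0));
   *   int predicted = weka.core.Utils.maxIndex(dist); // most probable
   *                                                   // nominal class
   */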
2140 
2141
2142
2143  /**
2144   * Returns an enumeration describing the available options.
2145   *
2146   * @return an enumeration of all the available options.
2147   */
2148  public Enumeration listOptions() {
2149   
2150    Vector newVector = new Vector(14);
2151
2152    newVector.addElement(new Option(
2153              "\tLearning Rate for the backpropagation algorithm.\n"
2154              +"\t(Value should be between 0 - 1, Default = 0.3).",
2155              "L", 1, "-L <learning rate>"));
2156    newVector.addElement(new Option(
2157              "\tMomentum Rate for the backpropagation algorithm.\n"
2158              +"\t(Value should be between 0 - 1, Default = 0.2).",
2159              "M", 1, "-M <momentum>"));
2160    newVector.addElement(new Option(
2161              "\tNumber of epochs to train through.\n"
2162              +"\t(Default = 500).",
2163              "N", 1,"-N <number of epochs>"));
2164    newVector.addElement(new Option(
2165              "\tPercentage size of validation set to use to terminate\n"
2166              + "\ttraining (if this is non zero it can pre-empt num of epochs.\n"
2167              +"\t(Value should be between 0 - 100, Default = 0).",
2168              "V", 1, "-V <percentage size of validation set>"));
2169    newVector.addElement(new Option(
2170              "\tThe value used to seed the random number generator\n"
2171              + "\t(Value should be >= 0 and and a long, Default = 0).",
2172              "S", 1, "-S <seed>"));
2173    newVector.addElement(new Option(
2174              "\tThe consequetive number of errors allowed for validation\n"
2175              + "\ttesting before the netwrok terminates.\n"
2176              + "\t(Value should be > 0, Default = 20).",
2177              "E", 1, "-E <threshold for number of consequetive errors>"));
2178    newVector.addElement(new Option(
2179              "\tGUI will be opened.\n"
2180              +"\t(Use this to bring up a GUI).",
2181              "G", 0,"-G"));
2182    newVector.addElement(new Option(
2183              "\tAutocreation of the network connections will NOT be done.\n"
2184              +"\t(This will be ignored if -G is NOT set)",
2185              "A", 0,"-A"));
2186    newVector.addElement(new Option(
2187              "\tA NominalToBinary filter will NOT automatically be used.\n"
2188              +"\t(Set this to not use a NominalToBinary filter).",
2189              "B", 0,"-B"));
2190    newVector.addElement(new Option(
2191              "\tThe hidden layers to be created for the network.\n"
2192              + "\t(Value should be a list of comma separated Natural \n"
2193              + "\tnumbers or the letters 'a' = (attribs + classes) / 2, \n"
2194              + "\t'i' = attribs, 'o' = classes, 't' = attribs .+ classes)\n"
2195              + "\tfor wildcard values, Default = a).",
2196              "H", 1, "-H <comma seperated numbers for nodes on each layer>"));
2197    newVector.addElement(new Option(
2198              "\tNormalizing a numeric class will NOT be done.\n"
2199              +"\t(Set this to not normalize the class if it's numeric).",
2200              "C", 0,"-C"));
2201    newVector.addElement(new Option(
2202              "\tNormalizing the attributes will NOT be done.\n"
2203              +"\t(Set this to not normalize the attributes).",
2204              "I", 0,"-I"));
2205    newVector.addElement(new Option(
2206              "\tReseting the network will NOT be allowed.\n"
2207              +"\t(Set this to not allow the network to reset).",
2208              "R", 0,"-R"));
2209    newVector.addElement(new Option(
2210              "\tLearning rate decay will occur.\n"
2211              +"\t(Set this to cause the learning rate to decay).",
2212              "D", 0,"-D"));
2213   
2214   
2215    return newVector.elements();
2216  }
2217
2218  /**
2219   * Parses a given list of options. <p/>
2220   *
2221   <!-- options-start -->
2222   * Valid options are: <p/>
2223   *
2224   * <pre> -L &lt;learning rate&gt;
2225   *  Learning Rate for the backpropagation algorithm.
2226   *  (Value should be between 0 - 1, Default = 0.3).</pre>
2227   *
2228   * <pre> -M &lt;momentum&gt;
2229   *  Momentum Rate for the backpropagation algorithm.
2230   *  (Value should be between 0 - 1, Default = 0.2).</pre>
2231   *
2232   * <pre> -N &lt;number of epochs&gt;
2233   *  Number of epochs to train through.
2234   *  (Default = 500).</pre>
2235   *
2236   * <pre> -V &lt;percentage size of validation set&gt;
2237   *  Percentage size of validation set to use to terminate
2238   *  training (if this is non zero it can pre-empt number of epochs).
2239   *  (Value should be between 0 - 100, Default = 0).</pre>
2240   *
2241   * <pre> -S &lt;seed&gt;
2242   *  The value used to seed the random number generator
2243   *  (Value should be &gt;= 0 and a long, Default = 0).</pre>
2244   *
2245   * <pre> -E &lt;threshold for number of consecutive errors&gt;
2246   *  The consecutive number of errors allowed for validation
2247   *  testing before the network terminates.
2248   *  (Value should be &gt; 0, Default = 20).</pre>
2249   *
2250   * <pre> -G
2251   *  GUI will be opened.
2252   *  (Use this to bring up a GUI).</pre>
2253   *
2254   * <pre> -A
2255   *  Autocreation of the network connections will NOT be done.
2256   *  (This will be ignored if -G is NOT set)</pre>
2257   *
2258   * <pre> -B
2259   *  A NominalToBinary filter will NOT automatically be used.
2260   *  (Set this to not use a NominalToBinary filter).</pre>
2261   *
2262   * <pre> -H &lt;comma separated numbers for nodes on each layer&gt;
2263   *  The hidden layers to be created for the network.
2264   *  (Value should be a list of comma separated Natural
2265   *  numbers or the letters 'a' = (attribs + classes) / 2,
2266   *  'i' = attribs, 'o' = classes, 't' = attribs + classes)
2267   *  for wildcard values, Default = a).</pre>
2268   *
2269   * <pre> -C
2270   *  Normalizing a numeric class will NOT be done.
2271   *  (Set this to not normalize the class if it's numeric).</pre>
2272   *
2273   * <pre> -I
2274   *  Normalizing the attributes will NOT be done.
2275   *  (Set this to not normalize the attributes).</pre>
2276   *
2277   * <pre> -R
2278   *  Reseting the network will NOT be allowed.
2279   *  (Set this to not allow the network to reset).</pre>
2280   *
2281   * <pre> -D
2282   *  Learning rate decay will occur.
2283   *  (Set this to cause the learning rate to decay).</pre>
2284   *
2285   <!-- options-end -->
2286   *
2287   * @param options the list of options as an array of strings
2288   * @throws Exception if an option is not supported
2289   */
2290  public void setOptions(String[] options) throws Exception {
2291    //the defaults can be found here
2292    String learningString = Utils.getOption('L', options);
2293    if (learningString.length() != 0) {
2294      setLearningRate((new Double(learningString)).doubleValue());
2295    } else {
2296      setLearningRate(0.3);
2297    }
2298    String momentumString = Utils.getOption('M', options);
2299    if (momentumString.length() != 0) {
2300      setMomentum((new Double(momentumString)).doubleValue());
2301    } else {
2302      setMomentum(0.2);
2303    }
2304    String epochsString = Utils.getOption('N', options);
2305    if (epochsString.length() != 0) {
2306      setTrainingTime(Integer.parseInt(epochsString));
2307    } else {
2308      setTrainingTime(500);
2309    }
2310    String valSizeString = Utils.getOption('V', options);
2311    if (valSizeString.length() != 0) {
2312      setValidationSetSize(Integer.parseInt(valSizeString));
2313    } else {
2314      setValidationSetSize(0);
2315    }
2316    String seedString = Utils.getOption('S', options);
2317    if (seedString.length() != 0) {
2318      setSeed(Integer.parseInt(seedString));
2319    } else {
2320      setSeed(0);
2321    }
2322    String thresholdString = Utils.getOption('E', options);
2323    if (thresholdString.length() != 0) {
2324      setValidationThreshold(Integer.parseInt(thresholdString));
2325    } else {
2326      setValidationThreshold(20);
2327    }
2328    String hiddenLayers = Utils.getOption('H', options);
2329    if (hiddenLayers.length() != 0) {
2330      setHiddenLayers(hiddenLayers);
2331    } else {
2332      setHiddenLayers("a");
2333    }
2334    if (Utils.getFlag('G', options)) {
2335      setGUI(true);
2336    } else {
2337      setGUI(false);
2338    } //small note: since the gui is the only option that can change the other
2339    //options, it should be set first to allow the other options to be set
2340    //properly
2341    if (Utils.getFlag('A', options)) {
2342      setAutoBuild(false);
2343    } else {
2344      setAutoBuild(true);
2345    }
2346    if (Utils.getFlag('B', options)) {
2347      setNominalToBinaryFilter(false);
2348    } else {
2349      setNominalToBinaryFilter(true);
2350    }
2351    if (Utils.getFlag('C', options)) {
2352      setNormalizeNumericClass(false);
2353    } else {
2354      setNormalizeNumericClass(true);
2355    }
2356    if (Utils.getFlag('I', options)) {
2357      setNormalizeAttributes(false);
2358    } else {
2359      setNormalizeAttributes(true);
2360    }
2361    if (Utils.getFlag('R', options)) {
2362      setReset(false);
2363    } else {
2364      setReset(true);
2365    }
2366    if (Utils.getFlag('D', options)) {
2367      setDecay(true);
2368    } else {
2369      setDecay(false);
2370    }
2371   
2372    Utils.checkForRemainingOptions(options);
2373  }
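  /*
   * Usage sketch (illustrative): the same options can be supplied as a
   * command-line style string via weka.core.Utils.splitOptions.
   *
   *   MultilayerPerceptron mlp = new MultilayerPerceptron();
   *   mlp.setOptions(weka.core.Utils.splitOptions(
   *       "-L 0.1 -M 0.5 -N 1000 -H 4,3"));
   */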
2374 
2375  /**
2376   * Gets the current settings of NeuralNet.
2377   *
2378   * @return an array of strings suitable for passing to setOptions()
2379   */
2380  public String [] getOptions() {
2381
2382    String [] options = new String [21];
2383    int current = 0;
2384    options[current++] = "-L"; options[current++] = "" + getLearningRate(); 
2385    options[current++] = "-M"; options[current++] = "" + getMomentum();
2386    options[current++] = "-N"; options[current++] = "" + getTrainingTime(); 
2387    options[current++] = "-V"; options[current++] = "" +getValidationSetSize();
2388    options[current++] = "-S"; options[current++] = "" + getSeed();
2389    options[current++] = "-E"; options[current++] =""+getValidationThreshold();
2390    options[current++] = "-H"; options[current++] = getHiddenLayers();
2391    if (getGUI()) {
2392      options[current++] = "-G";
2393    }
2394    if (!getAutoBuild()) {
2395      options[current++] = "-A";
2396    }
2397    if (!getNominalToBinaryFilter()) {
2398      options[current++] = "-B";
2399    }
2400    if (!getNormalizeNumericClass()) {
2401      options[current++] = "-C";
2402    }
2403    if (!getNormalizeAttributes()) {
2404      options[current++] = "-I";
2405    }
2406    if (!getReset()) {
2407      options[current++] = "-R";
2408    }
2409    if (getDecay()) {
2410      options[current++] = "-D";
2411    }
2412
2413   
2414    while (current < options.length) {
2415      options[current++] = "";
2416    }
2417    return options;
2418  }
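  /*
   * Usage sketch (illustrative): the returned array round-trips through
   * setOptions(), and can be rendered back to a single string:
   *
   *   System.out.println(weka.core.Utils.joinOptions(mlp.getOptions()));
   *   // with the defaults this would print something like:
   *   // -L 0.3 -M 0.2 -N 500 -V 0 -S 0 -E 20 -H a
   */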
2419 
2420  /**
2421   * @return string describing the model.
2422   */
2423  public String toString() {
2424    // only ZeroR model?
2425    if (m_ZeroR != null) {
2426      StringBuffer buf = new StringBuffer();
2427      buf.append(this.getClass().getName().replaceAll(".*\\.", "") + "\n");
2428      buf.append(this.getClass().getName().replaceAll(".*\\.", "").replaceAll(".", "=") + "\n\n");
2429      buf.append("Warning: No model could be built, hence ZeroR model is used:\n\n");
2430      buf.append(m_ZeroR.toString());
2431      return buf.toString();
2432    }
2433   
2434    StringBuffer model = new StringBuffer(m_neuralNodes.length * 100); 
2435    //just a rough size guess
2436    NeuralNode con;
2437    double[] weights;
2438    NeuralConnection[] inputs;
2439    for (int noa = 0; noa < m_neuralNodes.length; noa++) {
2440      con = (NeuralNode) m_neuralNodes[noa];  //this would need a change
2441                                              //for items other than nodes!!!
2442      weights = con.getWeights();
2443      inputs = con.getInputs();
2444      if (con.getMethod() instanceof SigmoidUnit) {
2445        model.append("Sigmoid ");
2446      }
2447      else if (con.getMethod() instanceof LinearUnit) {
2448        model.append("Linear ");
2449      }
2450      model.append("Node " + con.getId() + "\n    Inputs    Weights\n");
2451      model.append("    Threshold    " + weights[0] + "\n");
2452      for (int nob = 1; nob < con.getNumInputs() + 1; nob++) {
2453        if ((inputs[nob - 1].getType() & NeuralConnection.PURE_INPUT) 
2454            == NeuralConnection.PURE_INPUT) {
2455          model.append("    Attrib " + 
2456                       m_instances.attribute(((NeuralEnd)inputs[nob-1]).
2457                                             getLink()).name()
2458                       + "    " + weights[nob] + "\n");
2459        }
2460        else {
2461          model.append("    Node " + inputs[nob-1].getId() + "    " +
2462                       weights[nob] + "\n");
2463        }
2464      }     
2465    }
2466    //now put in the ends
2467    for (int noa = 0; noa < m_outputs.length; noa++) {
2468      inputs = m_outputs[noa].getInputs();
2469      model.append("Class " + 
2470                   m_instances.classAttribute().
2471                   value(m_outputs[noa].getLink()) + 
2472                   "\n    Input\n");
2473      for (int nob = 0; nob < m_outputs[noa].getNumInputs(); nob++) {
2474        if ((inputs[nob].getType() & NeuralConnection.PURE_INPUT)
2475            == NeuralConnection.PURE_INPUT) {
2476          model.append("    Attrib " +
2477                       m_instances.attribute(((NeuralEnd)inputs[nob]).
2478                                             getLink()).name() + "\n");
2479        }
2480        else {
2481          model.append("    Node " + inputs[nob].getId() + "\n");
2482        }
2483      }
2484    }
2485    return model.toString();
2486  }
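  /*
   * Illustrative sample of the format produced above (the weights and
   * attribute names are made up):
   *
   *   Sigmoid Node 0
   *       Inputs    Weights
   *       Threshold    -0.42
   *       Attrib petalwidth    1.97
   *   Class Iris-setosa
   *       Input
   *       Node 0
   */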
2487
2488  /**
2489   * This will return a string describing the classifier.
2490   * @return The string.
2491   */
2492  public String globalInfo() {
2493    return 
2494        "A Classifier that uses backpropagation to classify instances.\n"
2495      + "This network can be built by hand, created by an algorithm or both. "
2496      + "The network can also be monitored and modified during training time. "
2497      + "The nodes in this network are all sigmoid (except for when the class "
2498      + "is numeric in which case the the output nodes become unthresholded "
2499      + "linear units).";
2500  }
2501 
2502  /**
2503   * @return a string to describe the learning rate option.
2504   */
2505  public String learningRateTipText() {
2506    return "The amount the" + 
2507      " weights are updated.";
2508  }
2509 
2510  /**
2511   * @return a string to describe the momentum option.
2512   */
2513  public String momentumTipText() {
2514    return "Momentum applied to the weights during updating.";
2515  }
2516
2517  /**
2518   * @return a string to describe the AutoBuild option.
2519   */
2520  public String autoBuildTipText() {
2521    return "Adds and connects up hidden layers in the network.";
2522  }
2523
2524  /**
2525   * @return a string to describe the random seed option.
2526   */
2527  public String seedTipText() {
2528    return "Seed used to initialise the random number generator." +
2529      "Random numbers are used for setting the initial weights of the" +
2530      " connections betweem nodes, and also for shuffling the training data.";
2531  }
2532 
2533  /**
2534   * @return a string to describe the validation threshold option.
2535   */
2536  public String validationThresholdTipText() {
2537    return "Used to terminate validation testing." +
2538      "The value here dictates how many times in a row the validation set" +
2539      " error can get worse before training is terminated.";
2540  }
2541 
2542  /**
2543   * @return a string to describe the GUI option.
2544   */
2545  public String GUITipText() {
2546    return "Brings up a gui interface." +
2547      " This will allow the pausing and altering of the nueral network" +
2548      " during training.\n\n" +
2549      "* To add a node left click (this node will be automatically selected," +
2550      " ensure no other nodes were selected).\n" +
2551      "* To select a node left click on it either while no other node is" +
2552      " selected or while holding down the control key (this toggles that" +
2553      " node as being selected and not selected.\n" + 
2554      "* To connect a node, first have the start node(s) selected, then click"+
2555      " either the end node or on an empty space (this will create a new node"+
2556      " that is connected with the selected nodes). The selection status of" +
2557      " nodes will stay the same after the connection. (Note these are" +
2558      " directed connections, also a connection between two nodes will not" +
2559      " be established more than once and certain connections that are" + 
2560      " deemed to be invalid will not be made).\n" +
2561      "* To remove a connection select one of the connected node(s) in the" +
2562      " connection and then right click the other node (it does not matter" +
2563      " whether the node is the start or end the connection will be removed" +
2564      ").\n" +
2565      "* To remove a node right click it while no other nodes (including it)" +
2566      " are selected. (This will also remove all connections to it)\n." +
2567      "* To deselect a node either left click it while holding down control," +
2568      " or right click on empty space.\n" +
2569      "* The raw inputs are provided from the labels on the left.\n" +
2570      "* The red nodes are hidden layers.\n" +
2571      "* The orange nodes are the output nodes.\n" +
2572      "* The labels on the right show the class the output node represents." +
2573      " Note that with a numeric class the output node will automatically be" +
2574      " made into an unthresholded linear unit.\n\n" +
2575      "Alterations to the neural network can only be done while the network" +
2576      " is not running, This also applies to the learning rate and other" +
2577      " fields on the control panel.\n\n" + 
2578      "* You can accept the network as being finished at any time.\n" +
2579      "* The network is automatically paused at the beginning.\n" +
2580      "* There is a running indication of what epoch the network is up to" + 
2581      " and what the (rough) error for that epoch was (or for" +
2582      " the validation if that is being used). Note that this error value" +
2583      " is based on a network that changes as the value is computed." +
2584      " (also depending on whether" +
2585      " the class is normalized will effect the error reported for numeric" +
2586      " classes.\n" +
2587      "* Once the network is done it will pause again and either wait to be" +
2588      " accepted or trained more.\n\n" +
2589      "Note that if the gui is not set the network will not require any" +
2590      " interaction.\n";
2591  }
2592 
2593  /**
2594   * @return a string to describe the validation size option.
2595   */
2596  public String validationSetSizeTipText() {
2597    return "The percentage size of the validation set." +
2598      "(The training will continue until it is observed that" +
2599      " the error on the validation set has been consistently getting" +
2600      " worse, or if the training time is reached).\n" +
2601      "If This is set to zero no validation set will be used and instead" +
2602      " the network will train for the specified number of epochs.";
2603  }
2604 
2605  /**
2606   * @return a string to describe the learning rate option.
2607   */
2608  public String trainingTimeTipText() {
2609    return "The number of epochs to train through." + 
2610      " If the validation set is non-zero then it can terminate the network" +
2611      " early";
2612  }
2613
2614
2615  /**
2616   * @return a string to describe the nominal to binary option.
2617   */
2618  public String nominalToBinaryFilterTipText() {
2619    return "This will preprocess the instances with the filter." +
2620      " This could help improve performance if there are nominal attributes" +
2621      " in the data.";
2622  }
2623
2624  /**
2625   * @return a string to describe the hidden layers in the network.
2626   */
2627  public String hiddenLayersTipText() {
2628    return "This defines the hidden layers of the neural network." +
2629      " This is a list of positive whole numbers. 1 for each hidden layer." +
2630      " Comma seperated. To have no hidden layers put a single 0 here." +
2631      " This will only be used if autobuild is set. There are also wildcard" +
2632      " values 'a' = (attribs + classes) / 2, 'i' = attribs, 'o' = classes" +
2633      " , 't' = attribs + classes.";
2634  }
2635  /**
2636   * @return a string to describe the nominal to binary option.
2637   */
2638  public String normalizeNumericClassTipText() {
2639    return "This will normalize the class if it's numeric." +
2640      " This could help improve performance of the network, It normalizes" +
2641      " the class to be between -1 and 1. Note that this is only internally" +
2642      ", the output will be scaled back to the original range.";
2643  }
2644  /**
2645   * @return a string to describe the nominal to binary option.
2646   */
2647  public String normalizeAttributesTipText() {
2648    return "This will normalize the attributes." +
2649      " This could help improve performance of the network." +
2650      " This is not reliant on the class being numeric. This will also" +
2651      " normalize nominal attributes as well (after they have been run" +
2652      " through the nominal to binary filter if that is in use) so that the" +
2653      " nominal values are between -1 and 1";
2654  }
2655  /**
2656   * @return a string to describe the Reset option.
2657   */
2658  public String resetTipText() {
2659    return "This will allow the network to reset with a lower learning rate." +
2660      " If the network diverges from the answer this will automatically" +
2661      " reset the network with a lower learning rate and begin training" +
2662      " again. This option is only available if the gui is not set. Note" +
2663      " that if the network diverges but isn't allowed to reset it will" +
2664      " fail the training process and return an error message.";
2665  }
2666 
2667  /**
2668   * @return a string to describe the Decay option.
2669   */
2670  public String decayTipText() {
2671    return "This will cause the learning rate to decrease." +
2672      " This will divide the starting learning rate by the epoch number, to" +
2673      " determine what the current learning rate should be. This may help" +
2674      " to stop the network from diverging from the target output, as well" +
2675      " as improve general performance. Note that the decaying learning" +
2676      " rate will not be shown in the gui, only the original learning rate" +
2677      ". If the learning rate is changed in the gui, this is treated as the" +
2678      " starting learning rate.";
2679  }
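  /*
   * Worked example (illustrative) of the decay arithmetic described above:
   * with -D set and a starting rate of 0.3, the rate used at epoch e is
   * 0.3 / e, so epoch 1 -> 0.3, epoch 10 -> 0.03, epoch 100 -> 0.003.
   */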
2680 
2681  /**
2682   * Returns the revision string.
2683   *
2684   * @return            the revision
2685   */
2686  public String getRevision() {
2687    return RevisionUtils.extract("$Revision: 5928 $");
2688  }
2689}