source: src/main/java/weka/classifiers/functions/MultilayerPerceptronCS.java @ 16

Last change on this file since 16 was 4, checked in by gnappo, 14 years ago

Import di weka.

File size: 95.1 KB
Line 
1/*
2 *    This program is free software; you can redistribute it and/or modify
3 *    it under the terms of the GNU General Public License as published by
4 *    the Free Software Foundation; either version 2 of the License, or
5 *    (at your option) any later version.
6 *
7 *    This program is distributed in the hope that it will be useful,
8 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
9 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
10 *    GNU General Public License for more details.
11 *
12 *    You should have received a copy of the GNU General Public License
13 *    along with this program; if not, write to the Free Software
14 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
15 */
16
17/*
18 *    MultilayerPerceptronCS.java
19 *    Copyright (C) 2000 University of Waikato, Hamilton, New Zealand
20 */
21
22package weka.classifiers.functions;
23import weka.classifiers.AbstractClassifier;
24import weka.classifiers.Classifier;
25import weka.classifiers.functions.neural.LinearUnit;
26import weka.classifiers.functions.neural.NeuralConnection;
27import weka.classifiers.functions.neural.NeuralNode;
28import weka.classifiers.functions.neural.SigmoidUnit;
29import weka.core.Capabilities;
30//modCS.V & modCS.S; require the data source converter to read the files.
31import weka.core.converters.ConverterUtils.DataSource;
32
33import weka.core.FastVector;
34import weka.core.Instance;
35import weka.core.Instances;
36import weka.core.Option;
37import weka.core.OptionHandler;
38import weka.core.Randomizable;
39import weka.core.RevisionHandler;
40import weka.core.RevisionUtils;
41import weka.core.Utils;
42import weka.core.WeightedInstancesHandler;
43import weka.core.Capabilities.Capability;
44import weka.filters.Filter;
45import weka.filters.unsupervised.attribute.NominalToBinary;
46
47import java.awt.BorderLayout;
48import java.awt.Color;
49import java.awt.Component;
50import java.awt.Dimension;
51import java.awt.FontMetrics;
52import java.awt.Graphics;
53import java.awt.event.ActionEvent;
54import java.awt.event.ActionListener;
55import java.awt.event.MouseAdapter;
56import java.awt.event.MouseEvent;
57import java.awt.event.WindowAdapter;
58import java.awt.event.WindowEvent;
59import java.util.Enumeration;
60import java.util.Random;
61import java.util.StringTokenizer;
62import java.util.Vector;
63
64import javax.swing.BorderFactory;
65import javax.swing.Box;
66import javax.swing.BoxLayout;
67import javax.swing.JButton;
68import javax.swing.JFrame;
69import javax.swing.JLabel;
70import javax.swing.JOptionPane;
71import javax.swing.JPanel;
72import javax.swing.JScrollPane;
73import javax.swing.JTextField;
74
75/**
76 <!-- globalinfo-start -->
77 * A Classifier that uses backpropagation to classify instances.<br/>
78 * This network can be built by hand, created by an algorithm or both. The network can also be monitored and modified during training time. The nodes in this network are all sigmoid (except for when the class is numeric in which case the output nodes become unthresholded linear units).
79 * <p/>
80 <!-- globalinfo-end -->
81 *
82 <!-- options-start -->
83 * Valid options are: <p/>
84 *
85 * <pre> -L &lt;learning rate&gt;
86 *  Learning Rate for the backpropagation algorithm.
87 *  (Value should be between 0 - 1, Default = 0.3).</pre>
88 *
89 * <pre> -M &lt;momentum&gt;
90 *  Momentum Rate for the backpropagation algorithm.
91 *  (Value should be between 0 - 1, Default = 0.2).</pre>
92 *
93 * <pre> -N &lt;number of epochs&gt;
94 *  Number of epochs to train through.
95 *  (Default = 500).</pre>
96 *
97 * <pre> -V &lt;percentage size of validation set&gt;
98 *  Percentage size of validation set to use to terminate
99 *  training (if this is non zero it can pre-empt num of epochs).
100 *  (Value should be between 0 - 100, Default = 0).</pre>
101 *
102 * <pre> -S &lt;seed&gt;
103 *  The value used to seed the random number generator
104 *  (Value should be &gt;= 0 and a long, Default = 0).</pre>
105 *
106 * <pre> -E &lt;threshold for number of consecutive errors&gt;
107 *  The consecutive number of errors allowed for validation
108 *  testing before the network terminates.
109 *  (Value should be &gt; 0, Default = 20).</pre>
110 *
111 * <pre> -G
112 *  GUI will be opened.
113 *  (Use this to bring up a GUI).</pre>
114 *
115 * <pre> -A
116 *  Autocreation of the network connections will NOT be done.
117 *  (This will be ignored if -G is NOT set)</pre>
118 *
119 * <pre> -B
120 *  A NominalToBinary filter will NOT automatically be used.
121 *  (Set this to not use a NominalToBinary filter).</pre>
122 *
123 * <pre> -H &lt;comma separated numbers for nodes on each layer&gt;
124 *  The hidden layers to be created for the network.
125 *  (Value should be a list of comma separated Natural
126 *  numbers or the letters 'a' = (attribs + classes) / 2,
127 *  'i' = attribs, 'o' = classes, 't' = attribs + classes)
128 *  for wildcard values, Default = a).</pre>
129 *
130 * <pre> -C
131 *  Normalizing a numeric class will NOT be done.
132 *  (Set this to not normalize the class if it's numeric).</pre>
133 *
134 * <pre> -I
135 *  Normalizing the attributes will NOT be done.
136 *  (Set this to not normalize the attributes).</pre>
137 *
138 * <pre> -R
139 *  Reseting the network will NOT be allowed.
140 *  (Set this to not allow the network to reset).</pre>
141 *
142 * <pre> -D
143 *  Learning rate decay will occur.
144 *  (Set this to cause the learning rate to decay).</pre>
145 *
146 * <pre> -validation-set &lt;data source file&gt;
147 *  Validation set to use,  as drawn from the data source file.
148 * </pre>
149 *
150 * <pre> -secondary-training &lt;data source file&gt;
151 *  Secondary task training set to use, as drawn from the data source file.
152 * </pre>
153 *
154 <!-- options-end -->
155 *
156 * @author Malcolm Ware (mfw4@cs.waikato.ac.nz)
157 * @version $Revision: 6202 $
158 */
159public class MultilayerPerceptronCS 
160  extends AbstractClassifier
161  implements OptionHandler, WeightedInstancesHandler, Randomizable {
162 
163  /** for serialization */
164  static final long serialVersionUID = 572250905027665169L;
165 
166  /**
167   * Main method for testing this class.
168   *
169   * @param argv should contain command line options (see setOptions)
170   */
171  public static void main(String [] argv) {
172    runClassifier(new MultilayerPerceptronCS(), argv);
173  }
174 
175
176  /**
177   * This inner class is used to connect the nodes in the network up to
178   * the data that they are classifying, Note that objects of this class are
179   * only suitable to go on the attribute side or class side of the network
180   * and not both.
181   */
182  protected class NeuralEnd 
183    extends NeuralConnection {
184   
185    /** for serialization */
186    static final long serialVersionUID = 7305185603191183338L;
187 
188    /**
189     * the value that represents the instance value this node represents.
190     * For an input it is the attribute number, for an output, if nominal
191     * it is the class value.
192     */
193    private int m_link;
194   
195    /** True if node is an input, False if it's an output. */
196    private boolean m_input;
197
198    /**
199     * Constructor
200     */
201    public NeuralEnd(String id) {
202      super(id);
203
204      m_link = 0;
205      m_input = true;
206     
207    }
208 
209    /**
210     * Call this function to determine if the point at x,y is on the unit.
211     * @param g The graphics context for font size info.
212     * @param x The x coord.
213     * @param y The y coord.
214     * @param w The width of the display.
215     * @param h The height of the display.
216     * @return True if the point is on the unit, false otherwise.
217     */
218    public boolean onUnit(Graphics g, int x, int y, int w, int h) {
219     
220      FontMetrics fm = g.getFontMetrics();
221      int l = (int)(m_x * w) - fm.stringWidth(m_id) / 2;
222      int t = (int)(m_y * h) - fm.getHeight() / 2;
223      if (x < l || x > l + fm.stringWidth(m_id) + 4 
224          || y < t || y > t + fm.getHeight() + fm.getDescent() + 4) {
225        return false;
226      }
227      return true;
228     
229    }
230   
231
232    /**
233     * This will draw the node id to the graphics context.
234     * @param g The graphics context.
235     * @param w The width of the drawing area.
236     * @param h The height of the drawing area.
237     */
238    public void drawNode(Graphics g, int w, int h) {
239     
240      if ((m_type & PURE_INPUT) == PURE_INPUT) {
241        g.setColor(Color.green);
242      }
243      else {
244        g.setColor(Color.orange);
245      }
246     
247      FontMetrics fm = g.getFontMetrics();
248      int l = (int)(m_x * w) - fm.stringWidth(m_id) / 2;
249      int t = (int)(m_y * h) - fm.getHeight() / 2;
250      g.fill3DRect(l, t, fm.stringWidth(m_id) + 4
251                   , fm.getHeight() + fm.getDescent() + 4
252                   , true);
253      g.setColor(Color.black);
254     
255      g.drawString(m_id, l + 2, t + fm.getHeight() + 2);
256
257    }
258
259
260    /**
261     * Call this function to draw the node highlighted.
262     * @param g The graphics context.
263     * @param w The width of the drawing area.
264     * @param h The height of the drawing area.
265     */
266    public void drawHighlight(Graphics g, int w, int h) {
267     
268      g.setColor(Color.black);
269      FontMetrics fm = g.getFontMetrics();
270      int l = (int)(m_x * w) - fm.stringWidth(m_id) / 2;
271      int t = (int)(m_y * h) - fm.getHeight() / 2;
272      g.fillRect(l - 2, t - 2, fm.stringWidth(m_id) + 8
273                 , fm.getHeight() + fm.getDescent() + 8); 
274      drawNode(g, w, h);
275    }
276   
277    /**
278     * Call this to get the output value of this unit.
279     * @param calculate True if the value should be calculated if it hasn't
280     * been already.
281     * @return The output value, or NaN, if the value has not been calculated.
282     */
283    public double outputValue(boolean calculate) {
284     
285      if (Double.isNaN(m_unitValue) && calculate) {
286        if (m_input) {
287          if (m_currentInstance.isMissing(m_link)) {
288            m_unitValue = 0;
289          }
290          else {
291           
292            m_unitValue = m_currentInstance.value(m_link);
293          }
294        }
295        else {
296          //node is an output.
297          m_unitValue = 0;
298          for (int noa = 0; noa < m_numInputs; noa++) {
299            m_unitValue += m_inputList[noa].outputValue(true);
300           
301          }
302          if (m_numeric && m_normalizeClass) {
303            //then scale the value;
304            //this scales linearly from between -1 and 1
305            m_unitValue = m_unitValue * 
306              m_attributeRanges[m_instances.classIndex()] + 
307              m_attributeBases[m_instances.classIndex()];
308          }         
309        }
310      }
311      return m_unitValue;
312     
313     
314    }
315   
316    /**
317     * Call this to get the error value of this unit, which in this case is
318     * the difference between the predicted class, and the actual class.
319     * @param calculate True if the value should be calculated if it hasn't
320     * been already.
321     * @return The error value, or NaN, if the value has not been calculated.
322     */
323    public double errorValue(boolean calculate) {
324     
325      if (!Double.isNaN(m_unitValue) && Double.isNaN(m_unitError) 
326          && calculate) {
327       
328        if (m_input) {
329          m_unitError = 0;
330          for (int noa = 0; noa < m_numOutputs; noa++) {
331            m_unitError += m_outputList[noa].errorValue(true);
332          }
333        }
334        else {
335          if (m_currentInstance.classIsMissing()) {
336            m_unitError = .1; 
337          }
338          //dmod
339          else if (m_instances.classAttribute().isNominal()) {
340            if (m_currentInstance.classValue() == m_link) {
341              //1
342                m_unitError = 1 - m_unitValue;
343            }
344            else { //0
345              m_unitError =  0 - m_unitValue;
346            }
347          }
348          else if (m_numeric) {
349           
350            if (m_normalizeClass) {
351              if (m_attributeRanges[m_instances.classIndex()] == 0) {
352                m_unitError = 0;
353              }
354              else {
355                m_unitError = (m_currentInstance.classValue() - m_unitValue ) /
356                  m_attributeRanges[m_instances.classIndex()];
357                //m_numericRange;
358               
359              }
360            }
361            else {
362              m_unitError = m_currentInstance.classValue() - m_unitValue;
363            }
364          }
365        }
366      }
367      return m_unitError;
368    }
369   
370   
371    /**
372     * Call this to reset the value and error for this unit, ready for the next
373     * run. This will also call the reset function of all units that are
374     * connected as inputs to this one.
375     * This is also the time that the update for the listeners will be
376     * performed.
377     */
378    public void reset() {
379     
380      if (!Double.isNaN(m_unitValue) || !Double.isNaN(m_unitError)) {
381        m_unitValue = Double.NaN;
382        m_unitError = Double.NaN;
383        m_weightsUpdated = false;
384        for (int noa = 0; noa < m_numInputs; noa++) {
385          m_inputList[noa].reset();
386        }
387      }
388    }
389   
390    /**
391     * Call this to have the connection save the current
392     * weights.
393     */
394    public void saveWeights() {
395      for (int i = 0; i < m_numInputs; i++) {
396        m_inputList[i].saveWeights();
397      }
398    }
399   
400    /**
401     * Call this to have the connection restore from the saved
402     * weights.
403     */
404    public void restoreWeights() {
405      for (int i = 0; i < m_numInputs; i++) {
406        m_inputList[i].restoreWeights();
407      }
408    }
409   
410   
411    /**
412     * Call this function to set What this end unit represents.
413     * @param input True if this unit is used for entering an attribute,
414     * False if it's used for determining a class value.
415     * @param val The attribute number or class type that this unit represents.
416     * (for nominal attributes).
417     */
418    public void setLink(boolean input, int val) throws Exception {
419      m_input = input;
420     
421      if (input) {
422        m_type = PURE_INPUT;
423      }
424      else {
425        m_type = PURE_OUTPUT;
426      }
427      if (val < 0 || (input && val > m_instances.numAttributes()) 
428          || (!input && m_instances.classAttribute().isNominal() 
429              && val > m_instances.classAttribute().numValues())) {
430        m_link = 0;
431      }
432      else {
433        m_link = val;
434      }
435    }
436   
437    /**
438     * @return link for this node.
439     */
440    public int getLink() {
441      return m_link;
442    }
443   
444    /**
445     * Returns the revision string.
446     *
447     * @return          the revision
448     */
449    public String getRevision() {
450      return RevisionUtils.extract("$Revision: 6202 $");
451    }
452  }
453 
454
455 
456  /** Inner class used to draw the nodes onto.(uses the node lists!!)
457   * This will also handle the user input. */
458  private class NodePanel 
459    extends JPanel
460    implements RevisionHandler {
461   
462    /** for serialization */
463    static final long serialVersionUID = -3067621833388149984L;
464
465    /**
466     * The constructor.
467     */
468    public NodePanel() {
469     
470
471      addMouseListener(new MouseAdapter() {
472         
473          public void mousePressed(MouseEvent e) {
474           
475            if (!m_stopped) {
476              return;
477            }
478            if ((e.getModifiers() & MouseEvent.BUTTON1_MASK) == MouseEvent.BUTTON1_MASK && 
479                !e.isAltDown()) {
480              Graphics g = NodePanel.this.getGraphics();
481              int x = e.getX();
482              int y = e.getY();
483              int w = NodePanel.this.getWidth();
484              int h = NodePanel.this.getHeight();
485              FastVector tmp = new FastVector(4);
486              for (int noa = 0; noa < m_numAttributes; noa++) {
487                if (m_inputs[noa].onUnit(g, x, y, w, h)) {
488                  tmp.addElement(m_inputs[noa]);
489                  selection(tmp, 
490                            (e.getModifiers() & MouseEvent.CTRL_MASK) == MouseEvent.CTRL_MASK
491                            , true);
492                  return;
493                }
494              }
495              for (int noa = 0; noa < m_numClasses; noa++) {
496                if (m_outputs[noa].onUnit(g, x, y, w, h)) {
497                  tmp.addElement(m_outputs[noa]);
498                  selection(tmp,
499                            (e.getModifiers() & MouseEvent.CTRL_MASK) == MouseEvent.CTRL_MASK
500                            , true);
501                  return;
502                }
503              }
504              for (int noa = 0; noa < m_neuralNodes.length; noa++) {
505                if (m_neuralNodes[noa].onUnit(g, x, y, w, h)) {
506                  tmp.addElement(m_neuralNodes[noa]);
507                  selection(tmp,
508                            (e.getModifiers() & MouseEvent.CTRL_MASK) == MouseEvent.CTRL_MASK
509                            , true);
510                  return;
511                }
512
513              }
514              NeuralNode temp = new NeuralNode(String.valueOf(m_nextId), 
515                                               m_random, m_sigmoidUnit);
516              m_nextId++;
517              temp.setX((double)e.getX() / w);
518              temp.setY((double)e.getY() / h);
519              tmp.addElement(temp);
520              addNode(temp);
521              selection(tmp, (e.getModifiers() & MouseEvent.CTRL_MASK) == MouseEvent.CTRL_MASK
522                        , true);
523            }
524            else {
525              //then right click
526              Graphics g = NodePanel.this.getGraphics();
527              int x = e.getX();
528              int y = e.getY();
529              int w = NodePanel.this.getWidth();
530              int h = NodePanel.this.getHeight();
531              FastVector tmp = new FastVector(4);
532              for (int noa = 0; noa < m_numAttributes; noa++) {
533                if (m_inputs[noa].onUnit(g, x, y, w, h)) {
534                  tmp.addElement(m_inputs[noa]);
535                  selection(tmp, 
536                            (e.getModifiers() & MouseEvent.CTRL_MASK) == MouseEvent.CTRL_MASK
537                            , false);
538                  return;
539                }
540               
541               
542              }
543              for (int noa = 0; noa < m_numClasses; noa++) {
544                if (m_outputs[noa].onUnit(g, x, y, w, h)) {
545                  tmp.addElement(m_outputs[noa]);
546                  selection(tmp,
547                            (e.getModifiers() & MouseEvent.CTRL_MASK) == MouseEvent.CTRL_MASK
548                            , false);
549                  return;
550                }
551              }
552              for (int noa = 0; noa < m_neuralNodes.length; noa++) {
553                if (m_neuralNodes[noa].onUnit(g, x, y, w, h)) {
554                  tmp.addElement(m_neuralNodes[noa]);
555                  selection(tmp,
556                            (e.getModifiers() & MouseEvent.CTRL_MASK) == MouseEvent.CTRL_MASK
557                            , false);
558                  return;
559                }
560              }
561              selection(null, (e.getModifiers() & MouseEvent.CTRL_MASK) == MouseEvent.CTRL_MASK
562                        , false);
563            }
564          }
565        });
566    }
567   
568   
569    /**
570     * This function gets called when the user has clicked something
571     * It will amend the current selection or connect the current selection
572     * to the new selection.
573     * Or if nothing was selected and the right button was used it will
574     * delete the node.
575     * @param v The units that were selected.
576     * @param ctrl True if ctrl was held down.
577     * @param left True if it was the left mouse button.
578     */
579    private void selection(FastVector v, boolean ctrl, boolean left) {
580     
581      if (v == null) {
582        //then unselect all.
583        m_selected.removeAllElements();
584        repaint();
585        return;
586      }
587     
588
589      //then exclusive or the new selection with the current one.
590      if ((ctrl || m_selected.size() == 0) && left) {
591        boolean removed = false;
592        for (int noa = 0; noa < v.size(); noa++) {
593          removed = false;
594          for (int nob = 0; nob < m_selected.size(); nob++) {
595            if (v.elementAt(noa) == m_selected.elementAt(nob)) {
596              //then remove that element
597              m_selected.removeElementAt(nob);
598              removed = true;
599              break;
600            }
601          }
602          if (!removed) {
603            m_selected.addElement(v.elementAt(noa));
604          }
605        }
606        repaint();
607        return;
608      }
609
610     
611      if (left) {
612        //then connect the current selection to the new one.
613        for (int noa = 0; noa < m_selected.size(); noa++) {
614          for (int nob = 0; nob < v.size(); nob++) {
615            NeuralConnection
616              .connect((NeuralConnection)m_selected.elementAt(noa)
617                       , (NeuralConnection)v.elementAt(nob));
618          }
619        }
620      }
621      else if (m_selected.size() > 0) {
622        //then disconnect the current selection from the new one.
623       
624        for (int noa = 0; noa < m_selected.size(); noa++) {
625          for (int nob = 0; nob < v.size(); nob++) {
626            NeuralConnection
627              .disconnect((NeuralConnection)m_selected.elementAt(noa)
628                          , (NeuralConnection)v.elementAt(nob));
629           
630            NeuralConnection
631              .disconnect((NeuralConnection)v.elementAt(nob)
632                          , (NeuralConnection)m_selected.elementAt(noa));
633           
634          }
635        }
636      }
637      else {
638        //then remove the selected node. (it was right clicked while
639        //no other units were selected
640        for (int noa = 0; noa < v.size(); noa++) {
641          ((NeuralConnection)v.elementAt(noa)).removeAllInputs();
642          ((NeuralConnection)v.elementAt(noa)).removeAllOutputs();
643          removeNode((NeuralConnection)v.elementAt(noa));
644        }
645      }
646      repaint();
647    }
648
649    /**
650     * This will paint the nodes ontot the panel.
651     * @param g The graphics context.
652     */
653    public void paintComponent(Graphics g) {
654
655      super.paintComponent(g);
656      int x = getWidth();
657      int y = getHeight();
658      if (25 * m_numAttributes > 25 * m_numClasses && 
659          25 * m_numAttributes > y) {
660        setSize(x, 25 * m_numAttributes);
661      }
662      else if (25 * m_numClasses > y) {
663        setSize(x, 25 * m_numClasses);
664      }
665      else {
666        setSize(x, y);
667      }
668
669      y = getHeight();
670      for (int noa = 0; noa < m_numAttributes; noa++) {
671        m_inputs[noa].drawInputLines(g, x, y);
672      }
673      for (int noa = 0; noa < m_numClasses; noa++) {
674        m_outputs[noa].drawInputLines(g, x, y);
675        m_outputs[noa].drawOutputLines(g, x, y);
676      }
677      for (int noa = 0; noa < m_neuralNodes.length; noa++) {
678        m_neuralNodes[noa].drawInputLines(g, x, y);
679      }
680      for (int noa = 0; noa < m_numAttributes; noa++) {
681        m_inputs[noa].drawNode(g, x, y);
682      }
683      for (int noa = 0; noa < m_numClasses; noa++) {
684        m_outputs[noa].drawNode(g, x, y);
685      }
686      for (int noa = 0; noa < m_neuralNodes.length; noa++) {
687        m_neuralNodes[noa].drawNode(g, x, y);
688      }
689
690      for (int noa = 0; noa < m_selected.size(); noa++) {
691        ((NeuralConnection)m_selected.elementAt(noa)).drawHighlight(g, x, y);
692      }
693    }
694   
695    /**
696     * Returns the revision string.
697     *
698     * @return          the revision
699     */
700    public String getRevision() {
701      return RevisionUtils.extract("$Revision: 6202 $");
702    }
703  }
704
705  /**
706   * This provides the basic controls for working with the neuralnetwork
707   * @author Malcolm Ware (mfw4@cs.waikato.ac.nz)
708   * @version $Revision: 6202 $
709   */
710  class ControlPanel 
711    extends JPanel
712    implements RevisionHandler {
713   
714    /** for serialization */
715    static final long serialVersionUID = 7393543302294142271L;
716   
717    /** The start stop button. */
718    public JButton m_startStop;
719   
720    /** The button to accept the network (even if it hasn't done all epochs. */
721    public JButton m_acceptButton;
722   
723    /** A label to state the number of epochs processed so far. */
724    public JPanel m_epochsLabel;
725   
726    /** A label to state the total number of epochs to be processed. */
727    public JLabel m_totalEpochsLabel;
728   
729    /** A text field to allow the changing of the total number of epochs. */
730    public JTextField m_changeEpochs;
731   
732    /** A label to state the learning rate. */
733    public JLabel m_learningLabel;
734   
735    /** A label to state the momentum. */
736    public JLabel m_momentumLabel;
737   
738    /** A text field to allow the changing of the learning rate. */
739    public JTextField m_changeLearning;
740   
741    /** A text field to allow the changing of the momentum. */
742    public JTextField m_changeMomentum;
743   
744    /** A label to state roughly the accuracy of the network.(because the
745        accuracy is calculated per epoch, but the network is changing
746        throughout each epoch train).
747    */
748    public JPanel m_errorLabel;
749    //modCS.E
750    /** A label to indicate the currently lowest validation error on the
751     * neural network GUI. */
752    public JPanel m_lowValErrorLabel;
753    /**
754     * A label to indicate the epoch index when the lowest validation
755     * error is found
756     */
757    public JPanel m_epochIndexLabel;   
758   
759   
760    /** The constructor. */
761    public ControlPanel() { 
762      setBorder(BorderFactory.createTitledBorder("Controls"));
763     
764      m_totalEpochsLabel = new JLabel("Num Of Epochs  ");
765      m_epochsLabel = new JPanel(){ 
766          /** for serialization */
767          private static final long serialVersionUID = 2562773937093221399L;
768
769          public void paintComponent(Graphics g) {
770            super.paintComponent(g);
771            g.setColor(m_controlPanel.m_totalEpochsLabel.getForeground());
772            g.drawString("Epoch  " + m_epoch, 0, 10);
773          }
774        };
775      m_epochsLabel.setFont(m_totalEpochsLabel.getFont());
776
777        //modCS.E; Displays lowest validation error output in the neural
778        // network GUI.     
779        m_lowValErrorLabel = new JPanel()
780        {
781            /** for serialization */
782            private static final long serialVersionUID = 4390239056336679189L;
783
784            public void paintComponent(Graphics g)
785            {
786                    super.paintComponent(g);
787                    g.setColor(m_controlPanel.m_totalEpochsLabel.getForeground());
788                    if (m_valSize == 0 && m_valSet == null)
789                    {
790                            g.drawString("Lowest Validation Error = NA", 0, 10);
791                    }
792                    else
793                    {
794                            g.drawString("Lowest Validation Error = "
795                                    + Utils.doubleToString(m_lowValError, 7), 0, 10);
796                    }
797            }
798        };
799        m_lowValErrorLabel.setFont(m_epochsLabel.getFont());
800        //added : should not be needed anymore; for val error
801        m_epochIndexLabel = new JPanel()
802        {
803            /** for serialization */
804            private static final long serialVersionUID = 4390239056336679189L;
805
806            public void paintComponent(Graphics g)
807            {
808                    super.paintComponent(g);
809                    g.setColor(m_controlPanel.m_totalEpochsLabel.getForeground());
810                    if (m_valSize == 0 && m_valSet == null)
811                    {
812                            g.drawString(
813                                    "Epoch index (lowest validation error) = NA", 0, 10);
814                    }
815                    else
816                    {
817                            g.drawString("Epoch index (lowest validation error) = "
818                                    + m_epochIndex, 0, 10);
819                    }
820            }
821        };
822        m_epochIndexLabel.setFont(m_epochsLabel.getFont());
823     
824     
825      m_changeEpochs = new JTextField();
826      m_changeEpochs.setText("" + m_numEpochs);
827      m_errorLabel = new JPanel(){
828          /** for serialization */
829          private static final long serialVersionUID = 4390239056336679189L;
830
831          public void paintComponent(Graphics g) {
832            super.paintComponent(g);
833            g.setColor(m_controlPanel.m_totalEpochsLabel.getForeground());
834            if (m_valSize == 0 && m_valSet == null) {
835              g.drawString("Error per Epoch = " + 
836                           Utils.doubleToString(m_error, 7), 0, 10);
837            }
838            else {
839              g.drawString("Validation Error per Epoch = "
840                           + Utils.doubleToString(m_error, 7), 0, 10);
841            }
842          }
843        };
844      m_errorLabel.setFont(m_epochsLabel.getFont());
845     
846      m_learningLabel = new JLabel("Learning Rate = ");
847      m_momentumLabel = new JLabel("Momentum = ");
848      m_changeLearning = new JTextField();
849      m_changeMomentum = new JTextField();
850      m_changeLearning.setText("" + m_learningRate);
851      m_changeMomentum.setText("" + m_momentum);
852      setLayout(new BorderLayout(15, 10));
853
854      m_stopIt = true;
855      m_accepted = false;
856      m_startStop = new JButton("Start");
857      m_startStop.setActionCommand("Start");
858     
859      m_acceptButton = new JButton("Accept");
860      m_acceptButton.setActionCommand("Accept");
861     
862      JPanel buttons = new JPanel();
863      buttons.setLayout(new BoxLayout(buttons, BoxLayout.Y_AXIS));
864      buttons.add(m_startStop);
865      buttons.add(m_acceptButton);
866      add(buttons, BorderLayout.WEST);
867      JPanel data = new JPanel();
868      data.setLayout(new BoxLayout(data, BoxLayout.Y_AXIS));
869     
870      Box ab = new Box(BoxLayout.X_AXIS);
871      ab.add(m_epochsLabel);
872      data.add(ab);
873     
874
875     
876      ab = new Box(BoxLayout.X_AXIS);
877      ab.add(m_errorLabel);
878      data.add(ab);
879     
880      //modCS.E
881      //establishes space on the neural network GUI for lowest validation error
882      ab = new Box(BoxLayout.X_AXIS);
883      ab.add(m_lowValErrorLabel);
884      data.add(ab);
885      ab = new Box(BoxLayout.X_AXIS);
886      ab.add(m_epochIndexLabel);
887      data.add(ab);     
888         
889      add(data, BorderLayout.CENTER);
890
891      data = new JPanel();
892      data.setLayout(new BoxLayout(data, BoxLayout.Y_AXIS));
893     
894      ab = new Box(BoxLayout.X_AXIS);
895      Component b = Box.createGlue();
896      ab.add(m_totalEpochsLabel);
897      ab.add(m_changeEpochs);
898      m_changeEpochs.setMaximumSize(new Dimension(200, 20));
899      ab.add(b);
900      data.add(ab);     
901           
902      ab = new Box(BoxLayout.X_AXIS);
903      b = Box.createGlue();
904      ab.add(m_learningLabel);
905      ab.add(m_changeLearning);
906      m_changeLearning.setMaximumSize(new Dimension(200, 20));
907      ab.add(b);
908      data.add(ab);
909     
910      ab = new Box(BoxLayout.X_AXIS);
911      b = Box.createGlue();
912      ab.add(m_momentumLabel);
913      ab.add(m_changeMomentum);
914      m_changeMomentum.setMaximumSize(new Dimension(200, 20));
915      ab.add(b);
916      data.add(ab);
917     
918      add(data, BorderLayout.EAST);
919     
920      m_startStop.addActionListener(new ActionListener() {
921          public void actionPerformed(ActionEvent e) {
922            if (e.getActionCommand().equals("Start")) {
923              m_stopIt = false;
924              m_startStop.setText("Stop");
925              m_startStop.setActionCommand("Stop");
926              int n = Integer.valueOf(m_changeEpochs.getText()).intValue();
927             
928              m_numEpochs = n;
929              m_changeEpochs.setText("" + m_numEpochs);
930             
931              double m=Double.valueOf(m_changeLearning.getText()).
932                doubleValue();
933              setLearningRate(m);
934              m_changeLearning.setText("" + m_learningRate);
935             
936              m = Double.valueOf(m_changeMomentum.getText()).doubleValue();
937              setMomentum(m);
938              m_changeMomentum.setText("" + m_momentum);
939             
940              blocker(false);
941            }
942            else if (e.getActionCommand().equals("Stop")) {
943              m_stopIt = true;
944              m_startStop.setText("Start");
945              m_startStop.setActionCommand("Start");
946            }
947          }
948        });
949     
950      m_acceptButton.addActionListener(new ActionListener() {
951          public void actionPerformed(ActionEvent e) {
952            m_accepted = true;
953            blocker(false);
954          }
955        });
956     
957      m_changeEpochs.addActionListener(new ActionListener() {
958          public void actionPerformed(ActionEvent e) {
959            int n = Integer.valueOf(m_changeEpochs.getText()).intValue();
960            if (n > 0) {
961              m_numEpochs = n;
962              blocker(false);
963            }
964          }
965        });
966    }
967   
968    /**
969     * Returns the revision string.
970     *
971     * @return          the revision
972     */
973    public String getRevision() {
974      return RevisionUtils.extract("$Revision: 6202 $");
975    }
976  }
977 
978
  /** A ZeroR model, kept in case no model can be built from the data. */
  private Classifier m_ZeroR;
   
  /** The training instances. */
  private Instances m_instances;
 
  //modCS.S & //modCS.V
  // Structures needed for loading and storing the external validation
  // file (modCS.V) and the secondary task training file (modCS.S).
  /** Data source for the external validation file. */
  protected DataSource m_valSetSource = null;
  /** The externally supplied validation instances, or null if none. */
  protected Instances m_valSet = null;
  /** File name of the external validation file. */
  protected String m_valSetFileName = null;
  /** Data source for the secondary task training file. */
  protected DataSource m_secSetSource = null;
  /** The secondary task training instances, or null if none. */
  protected Instances m_secSet = null;
  /** File name of the secondary task training file. */
  protected String m_secSetFileName = null; 
 
  /** The current instance running through the network. */
  private Instance m_currentInstance;
 
  /** A flag to say that the class is numeric. */
  private boolean m_numeric;

  /** The ranges for all the attributes (half of max minus min). */
  private double[] m_attributeRanges;

  /** The base values for all the attributes (midpoint of max and min). */
  private double[] m_attributeBases;

  /** The output units (only feed the errors, do no calculations). */
  private NeuralEnd[] m_outputs;

  /** The input units (only feed the inputs, do no calculations). */
  private NeuralEnd[] m_inputs;

  /** All the nodes that actually comprise the logical neural net. */
  private NeuralConnection[] m_neuralNodes;

  /** The number of classes. */
  private int m_numClasses = 0;
 
  /** The number of attributes. */
  private int m_numAttributes = 0; //note the number doesn't include the class.
 
  /** The panel the nodes are displayed on. */
  private NodePanel m_nodePanel;
 
  /** The control panel. */
  private ControlPanel m_controlPanel;

  /** The next id number available for default naming. */
  private int m_nextId;
   
  /** A Vector list of the units currently selected. */
  private FastVector m_selected;

  /** A Vector list of the graphers. */
  private FastVector m_graphers;

  /** The number of epochs to train through. */
  private int m_numEpochs;

  /** A flag to state if the network should be running, or stopped. */
  private boolean m_stopIt;

  /** A flag to state that the network has in fact stopped. */
  private boolean m_stopped;

  /** A flag to state that the network should be accepted the way it is. */
  private boolean m_accepted;
  /** The window for the network. */
  private JFrame m_win;

  /** A flag to tell the build classifier to automatically build a neural net.
   */
  private boolean m_autoBuild;

  /** A flag to state that the gui for the network should be brought up,
      to allow interaction while training. */
  private boolean m_gui;

  /** An int to say how big the validation set should be
      (a percentage of the whole, see setValidationSetSize). */
  private int m_valSize;

  /** The threshold used to quit on validation testing. */
  private int m_driftThreshold;

  /** The number used to seed the random number generator. */
  private int m_randomSeed;

  /** The actual random number generator. */
  private Random m_random;

  /** A flag to state that a nominal to binary filter should be used. */
  private boolean m_useNomToBin;
 
  /** The actual filter. */
  private NominalToBinary m_nominalToBinaryFilter;

  /** The string that defines the hidden layers. */
  private String m_hiddenLayers;

  /** This flag states that the user wants the input values normalized. */
  private boolean m_normalizeAttributes;

  /** This flag states that the user wants the learning rate to decay. */
  private boolean m_decay;

  /** This is the learning rate for the network. */
  private double m_learningRate;

  /** This is the momentum for the network. */
  private double m_momentum;

  /** Shows the number of the epoch that the network just finished. */
  private int m_epoch;

  /** Shows the error of the epoch that the network just finished. */
  private double m_error;

   //modCS.E
   /** Shows the lowest validation error. */
   private double m_lowValError;
   /** Shows the epoch index at which the lowest validation error was found. */
   private int m_epochIndex; 

  /** This flag states that the user wants the network to restart if it
   * is found to be generating infinity or NaN for the error value. This
   * would restart the network with the current options except that the
   * learning rate would be smaller than before (perhaps half of its current
   * value). This option will not be available if the gui is chosen (if the
   * gui is open the user can fix the network themselves; it is an
   * architectural minefield for the network to be reset with the gui open). */
  private boolean m_reset;

  /** This flag states that the user wants the class to be normalized while
   * processing in the network is done (the final answer will be in the
   * original range regardless). This option will only be used when the class
   * is numeric. */
  private boolean m_normalizeClass;

  /**
   * This is a sigmoid unit.
   */
  private SigmoidUnit m_sigmoidUnit;
 
  /**
   * This is a linear unit.
   */
  private LinearUnit m_linearUnit;
1130 
1131  /**
1132   * The constructor.
1133   */
1134  public MultilayerPerceptronCS() {
1135    m_instances = null;
1136    //modCS.V
1137    m_valSet = null;
1138    m_valSetSource = null;
1139    m_valSetFileName = null;
1140    //modCS.S
1141    m_secSet = null;
1142    m_secSetSource = null;
1143    m_secSetFileName = null;   
1144   
1145    m_currentInstance = null;
1146    m_controlPanel = null;
1147    m_nodePanel = null;
1148    m_epoch = 0;
1149    m_error = 0;
1150   
1151   
1152    m_outputs = new NeuralEnd[0];
1153    m_inputs = new NeuralEnd[0];
1154    m_numAttributes = 0;
1155    m_numClasses = 0;
1156    m_neuralNodes = new NeuralConnection[0];
1157    m_selected = new FastVector(4);
1158    m_graphers = new FastVector(2);
1159    m_nextId = 0;
1160    m_stopIt = true;
1161    m_stopped = true;
1162    m_accepted = false;
1163    m_numeric = false;
1164    m_random = null;
1165    m_nominalToBinaryFilter = new NominalToBinary();
1166    m_sigmoidUnit = new SigmoidUnit();
1167    m_linearUnit = new LinearUnit();
1168    //setting all the options to their defaults. To completely change these
1169    //defaults they will also need to be changed down the bottom in the
1170    //setoptions function (the text info in the accompanying functions should
1171    //also be changed to reflect the new defaults
1172    m_normalizeClass = true;
1173    m_normalizeAttributes = true;
1174    m_autoBuild = true;
1175    m_gui = false;
1176    m_useNomToBin = true;
1177    m_driftThreshold = 20;
1178    m_numEpochs = 500;
1179    m_valSize = 0;
1180    m_randomSeed = 0;
1181    m_hiddenLayers = "a";
1182    m_learningRate = .3;
1183    m_momentum = .2;
1184    m_reset = true;
1185    m_decay = false;
1186 
1187    //modCS.E
1188    m_lowValError = 0;
1189    m_epochIndex = 0;   
1190   
1191  }
1192
1193  /**
1194   * @param d True if the learning rate should decay.
1195   */
1196  public void setDecay(boolean d) {
1197    m_decay = d;
1198  }
1199 
1200  /**
1201   * @return the flag for having the learning rate decay.
1202   */
1203  public boolean getDecay() {
1204    return m_decay;
1205  }
1206
1207  /**
1208   * This sets the network up to be able to reset itself with the current
1209   * settings and the learning rate at half of what it is currently. This
1210   * will only happen if the network creates NaN or infinite errors. Also this
1211   * will continue to happen until the network is trained properly. The
1212   * learning rate will also get set back to it's original value at the end of
1213   * this. This can only be set to true if the GUI is not brought up.
1214   * @param r True if the network should restart with it's current options
1215   * and set the learning rate to half what it currently is.
1216   */
1217  public void setReset(boolean r) {
1218    if (m_gui) {
1219      r = false;
1220    }
1221    m_reset = r;
1222     
1223  }
1224
1225  /**
1226   * @return The flag for reseting the network.
1227   */
1228  public boolean getReset() {
1229    return m_reset;
1230  }
1231 
1232  /**
1233   * @param c True if the class should be normalized (the class will only ever
1234   * be normalized if it is numeric). (Normalization puts the range between
1235   * -1 - 1).
1236   */
1237  public void setNormalizeNumericClass(boolean c) {
1238    m_normalizeClass = c;
1239  }
1240 
1241  /**
1242   * @return The flag for normalizing a numeric class.
1243   */
1244  public boolean getNormalizeNumericClass() {
1245    return m_normalizeClass;
1246  }
1247
1248  /**
1249   * @param a True if the attributes should be normalized (even nominal
1250   * attributes will get normalized here) (range goes between -1 - 1).
1251   */
1252  public void setNormalizeAttributes(boolean a) {
1253    m_normalizeAttributes = a;
1254  }
1255
1256  /**
1257   * @return The flag for normalizing attributes.
1258   */
1259  public boolean getNormalizeAttributes() {
1260    return m_normalizeAttributes;
1261  }
1262
1263  /**
1264   * @param f True if a nominalToBinary filter should be used on the
1265   * data.
1266   */
1267  public void setNominalToBinaryFilter(boolean f) {
1268    m_useNomToBin = f;
1269  }
1270
1271  /**
1272   * @return The flag for nominal to binary filter use.
1273   */
1274  public boolean getNominalToBinaryFilter() {
1275    return m_useNomToBin;
1276  }
1277
1278  /**
1279   * This seeds the random number generator, that is used when a random
1280   * number is needed for the network.
1281   * @param l The seed.
1282   */
1283  public void setSeed(int l) {
1284    if (l >= 0) {
1285      m_randomSeed = l;
1286    }
1287  }
1288 
1289  /**
1290   * @return The seed for the random number generator.
1291   */
1292  public int getSeed() {
1293    return m_randomSeed;
1294  }
1295
1296  /**
1297   * This sets the threshold to use for when validation testing is being done.
1298   * It works by ending testing once the error on the validation set has
1299   * consecutively increased a certain number of times.
1300   * @param t The threshold to use for this.
1301   */
1302  public void setValidationThreshold(int t) {
1303    if (t > 0) {
1304      m_driftThreshold = t;
1305    }
1306  }
1307
1308  /**
1309   * @return The threshold used for validation testing.
1310   */
1311  public int getValidationThreshold() {
1312    return m_driftThreshold;
1313  }
1314 
1315  /**
1316   * The learning rate can be set using this command.
1317   * NOTE That this is a static variable so it affect all networks that are
1318   * running.
1319   * Must be greater than 0 and no more than 1.
1320   * @param l The New learning rate.
1321   */
1322  public void setLearningRate(double l) {
1323    if (l > 0 && l <= 1) {
1324      m_learningRate = l;
1325   
1326      if (m_controlPanel != null) {
1327        m_controlPanel.m_changeLearning.setText("" + l);
1328      }
1329    }
1330  }
1331
1332  /**
1333   * @return The learning rate for the nodes.
1334   */
1335  public double getLearningRate() {
1336    return m_learningRate;
1337  }
1338
1339  /**
1340   * The momentum can be set using this command.
1341   * THE same conditions apply to this as to the learning rate.
1342   * @param m The new Momentum.
1343   */
1344  public void setMomentum(double m) {
1345    if (m >= 0 && m <= 1) {
1346      m_momentum = m;
1347 
1348      if (m_controlPanel != null) {
1349        m_controlPanel.m_changeMomentum.setText("" + m);
1350      }
1351    }
1352  }
1353 
1354  /**
1355   * @return The momentum for the nodes.
1356   */
1357  public double getMomentum() {
1358    return m_momentum;
1359  }
1360
1361  /**
1362   * This will set whether the network is automatically built
1363   * or if it is left up to the user. (there is nothing to stop a user
1364   * from altering an autobuilt network however).
1365   * @param a True if the network should be auto built.
1366   */
1367  public void setAutoBuild(boolean a) {
1368    if (!m_gui) {
1369      a = true;
1370    }
1371    m_autoBuild = a;
1372  }
1373
1374  /**
1375   * @return The auto build state.
1376   */
1377  public boolean getAutoBuild() {
1378    return m_autoBuild;
1379  }
1380
1381
1382  /**
1383   * This will set what the hidden layers are made up of when auto build is
1384   * enabled. Note to have no hidden units, just put a single 0, Any more
1385   * 0's will indicate that the string is badly formed and make it unaccepted.
1386   * Negative numbers, and floats will do the same. There are also some
1387   * wildcards. These are 'a' = (number of attributes + number of classes) / 2,
1388   * 'i' = number of attributes, 'o' = number of classes, and 't' = number of
1389   * attributes + number of classes.
1390   * @param h A string with a comma seperated list of numbers. Each number is
1391   * the number of nodes to be on a hidden layer.
1392   */
1393  public void setHiddenLayers(String h) {
1394    String tmp = "";
1395    StringTokenizer tok = new StringTokenizer(h, ",");
1396    if (tok.countTokens() == 0) {
1397      return;
1398    }
1399    double dval;
1400    int val;
1401    String c;
1402    boolean first = true;
1403    while (tok.hasMoreTokens()) {
1404      c = tok.nextToken().trim();
1405
1406      if (c.equals("a") || c.equals("i") || c.equals("o") || 
1407               c.equals("t")) {
1408        tmp += c;
1409      }
1410      else {
1411        dval = Double.valueOf(c).doubleValue();
1412        val = (int)dval;
1413       
1414        if ((val == dval && (val != 0 || (tok.countTokens() == 0 && first)) && 
1415             val >= 0)) {
1416          tmp += val;
1417        }
1418        else {
1419          return;
1420        }
1421      }
1422     
1423      first = false;
1424      if (tok.hasMoreTokens()) {
1425        tmp += ", ";
1426      }
1427    }
1428    m_hiddenLayers = tmp;
1429  }
1430
1431  /**
1432   * @return A string representing the hidden layers, each number is the number
1433   * of nodes on a hidden layer.
1434   */
1435  public String getHiddenLayers() {
1436    return m_hiddenLayers;
1437  }
1438
1439  /**
1440   * This will set whether A GUI is brought up to allow interaction by the user
1441   * with the neural network during training.
1442   * @param a True if gui should be created.
1443   */
1444  public void setGUI(boolean a) {
1445    m_gui = a;
1446    if (!a) {
1447      setAutoBuild(true);
1448     
1449    }
1450    else {
1451      setReset(false);
1452    }
1453  }
1454
1455  /**
1456   * @return The true if should show gui.
1457   */
1458  public boolean getGUI() {
1459    return m_gui;
1460  }
1461
1462  /**
1463   * This will set the size of the validation set.
1464   * @param a The size of the validation set, as a percentage of the whole.
1465   */
1466  public void setValidationSetSize(int a) {
1467    if (a < 0 || a > 99) {
1468      return;
1469    }
1470    m_valSize = a;
1471  }
1472
1473  /**
1474   * @return The percentage size of the validation set.
1475   */
1476  public int getValidationSetSize() {
1477    return m_valSize;
1478  }
1479
1480 
1481 
1482 
1483  /**
1484   * Set the number of training epochs to perform.
1485   * Must be greater than 0.
1486   * @param n The number of epochs to train through.
1487   */
1488  public void setTrainingTime(int n) {
1489    if (n > 0) {
1490      m_numEpochs = n;
1491    }
1492  }
1493
1494  /**
1495   * @return The number of epochs to train through.
1496   */
1497  public int getTrainingTime() {
1498    return m_numEpochs;
1499  }
1500 
1501  /**
1502   * Call this function to place a node into the network list.
1503   * @param n The node to place in the list.
1504   */
1505  private void addNode(NeuralConnection n) {
1506   
1507    NeuralConnection[] temp1 = new NeuralConnection[m_neuralNodes.length + 1];
1508    for (int noa = 0; noa < m_neuralNodes.length; noa++) {
1509      temp1[noa] = m_neuralNodes[noa];
1510    }
1511
1512    temp1[temp1.length-1] = n;
1513    m_neuralNodes = temp1;
1514  }
1515
1516  /**
1517   * Call this function to remove the passed node from the list.
1518   * This will only remove the node if it is in the neuralnodes list.
1519   * @param n The neuralConnection to remove.
1520   * @return True if removed false if not (because it wasn't there).
1521   */
1522  private boolean removeNode(NeuralConnection n) {
1523    NeuralConnection[] temp1 = new NeuralConnection[m_neuralNodes.length - 1];
1524    int skip = 0;
1525    for (int noa = 0; noa < m_neuralNodes.length; noa++) {
1526      if (n == m_neuralNodes[noa]) {
1527        skip++;
1528      }
1529      else if (!((noa - skip) >= temp1.length)) {
1530        temp1[noa - skip] = m_neuralNodes[noa];
1531      }
1532      else {
1533        return false;
1534      }
1535    }
1536    m_neuralNodes = temp1;
1537    return true;
1538  }
1539
1540  /**
1541   * This function sets what the m_numeric flag to represent the passed class
1542   * it also performs the normalization of the attributes if applicable
1543   * and sets up the info to normalize the class. (note that regardless of
1544   * the options it will fill an array with the range and base, set to
1545   * normalize all attributes and the class to be between -1 and 1)
1546   * @param inst the instances.
1547   * @return The modified instances. This needs to be done. If the attributes
1548   * are normalized then deep copies will be made of all the instances which
1549   * will need to be passed back out.
1550   */
1551  private Instances setClassType(Instances inst) throws Exception {
1552    if (inst != null) {
1553      // x bounds
1554      double min=Double.POSITIVE_INFINITY;
1555      double max=Double.NEGATIVE_INFINITY;
1556      double value;
1557      m_attributeRanges = new double[inst.numAttributes()];
1558      m_attributeBases = new double[inst.numAttributes()];
1559      for (int noa = 0; noa < inst.numAttributes(); noa++) {
1560        min = Double.POSITIVE_INFINITY;
1561        max = Double.NEGATIVE_INFINITY;
1562        for (int i=0; i < inst.numInstances();i++) {
1563          if (!inst.instance(i).isMissing(noa)) {
1564            value = inst.instance(i).value(noa);
1565            if (value < min) {
1566              min = value;
1567            }
1568            if (value > max) {
1569              max = value;
1570            }
1571          }
1572        }
1573       
1574        m_attributeRanges[noa] = (max - min) / 2;
1575        m_attributeBases[noa] = (max + min) / 2;
1576        //dmod
1577    /*    System.out.println("Attribute " + noa + " Range: "  + m_attributeRanges[noa]);
1578        System.out.println("Attribute " + noa + " Bases: " + m_attributeBases[noa]);
1579        System.out.println(); */
1580       
1581        //Nominal class; hardcode base and range
1582 /*       if(noa == inst.classIndex() && !inst.classAttribute().isNumeric())
1583        {
1584            m_attributeRanges[noa] = 0.4;
1585            m_attributeBases[noa] = 0.5;
1586            System.out.println();
1587            System.out.println();
1588            System.out.println("Nominal attribute detected; hardcoding range and base.");           
1589        } */
1590       
1591        if (noa != inst.classIndex() && m_normalizeAttributes) {
1592          for (int i = 0; i < inst.numInstances(); i++) {
1593            if (m_attributeRanges[noa] != 0) {
1594              inst.instance(i).setValue(noa, (inst.instance(i).value(noa) 
1595                                              - m_attributeBases[noa]) /
1596                                        m_attributeRanges[noa]);
1597            }
1598            else {
1599              inst.instance(i).setValue(noa, inst.instance(i).value(noa) - 
1600                                        m_attributeBases[noa]);
1601            }
1602          }
1603        }
1604      }
1605      if (inst.classAttribute().isNumeric()) {
1606        m_numeric = true;
1607      }
1608      else {
1609        m_numeric = false;
1610        //warningmod
1611        //m_numeric = true;
1612      }
1613    }
1614    return inst;
1615  }
1616
1617  /**
1618   * A function used to stop the code that called buildclassifier
1619   * from continuing on before the user has finished the decision tree.
1620   * @param tf True to stop the thread, False to release the thread that is
1621   * waiting there (if one).
1622   */
1623  public synchronized void blocker(boolean tf) {
1624    if (tf) {
1625      try {
1626        wait();
1627      } catch(InterruptedException e) {
1628      }
1629    }
1630    else {
1631      notifyAll();
1632    }
1633  }
1634
1635  /**
1636   * Call this function to update the control panel for the gui.
1637   */
1638  private void updateDisplay() {
1639   
1640    if (m_gui) {
1641      m_controlPanel.m_errorLabel.repaint();
1642      m_controlPanel.m_epochsLabel.repaint();
1643      //modCS.E
1644      m_controlPanel.m_lowValErrorLabel.repaint();
1645      m_controlPanel.m_epochIndexLabel.repaint();     
1646    }
1647  }
1648 
1649
1650  /**
1651   * this will reset all the nodes in the network.
1652   */
1653  private void resetNetwork() {
1654    for (int noc = 0; noc < m_numClasses; noc++) {
1655      m_outputs[noc].reset();
1656    }
1657  }
1658 
1659  /**
1660   * This will cause the output values of all the nodes to be calculated.
1661   * Note that the m_currentInstance is used to calculate these values.
1662   */
1663  private void calculateOutputs() {
1664    for (int noc = 0; noc < m_numClasses; noc++) {     
1665      //get the values.
1666      m_outputs[noc].outputValue(true);
1667    }
1668  }
1669
1670  /**
1671   * This will cause the error values to be calculated for all nodes.
1672   * Note that the m_currentInstance is used to calculate these values.
1673   * Also the output values should have been calculated first.
1674   * @return The squared error.
1675   */
1676  private double calculateErrors() throws Exception {
1677    double ret = 0, temp = 0; 
1678    for (int noc = 0; noc < m_numAttributes; noc++) {
1679      //get the errors.
1680      m_inputs[noc].errorValue(true);
1681     
1682    }
1683    for (int noc = 0; noc < m_numClasses; noc++) {
1684      temp = m_outputs[noc].errorValue(false);
1685      ret += temp * temp;
1686    }   
1687    return ret;
1688   
1689  }
1690
1691  /**
1692   * This will cause the weight values to be updated based on the learning
1693   * rate, momentum and the errors that have been calculated for each node.
1694   * @param l The learning rate to update with.
1695   * @param m The momentum to update with.
1696   */
1697  private void updateNetworkWeights(double l, double m) {
1698    for (int noc = 0; noc < m_numClasses; noc++) {
1699      //update weights
1700      m_outputs[noc].updateWeights(l, m);
1701    }
1702
1703  }
1704 
1705  /**
1706   * This creates the required input units.
1707   */
1708  private void setupInputs() throws Exception {
1709    m_inputs = new NeuralEnd[m_numAttributes];
1710    int now = 0;
1711    for (int noa = 0; noa < m_numAttributes+1; noa++) {
1712      if (m_instances.classIndex() != noa) {
1713        m_inputs[noa - now] = new NeuralEnd(m_instances.attribute(noa).name());
1714       
1715        m_inputs[noa - now].setX(.1);
1716        m_inputs[noa - now].setY((noa - now + 1.0) / (m_numAttributes + 1));
1717        m_inputs[noa - now].setLink(true, noa);
1718      }   
1719      else {
1720        now = 1;
1721      }
1722    }
1723
1724  }
1725
1726  /**
1727   * This creates the required output units.
1728   */
1729  private void setupOutputs() throws Exception {
1730 
1731    m_outputs = new NeuralEnd[m_numClasses];
1732    for (int noa = 0; noa < m_numClasses; noa++) {
1733      if (m_numeric) {
1734        m_outputs[noa] = new NeuralEnd(m_instances.classAttribute().name());
1735      }
1736      else {
1737        m_outputs[noa]= new NeuralEnd(m_instances.classAttribute().value(noa));
1738      }
1739     
1740      m_outputs[noa].setX(.9);
1741      m_outputs[noa].setY((noa + 1.0) / (m_numClasses + 1));
1742      m_outputs[noa].setLink(false, noa);
1743      NeuralNode temp = new NeuralNode(String.valueOf(m_nextId), m_random,
1744                                       m_sigmoidUnit);
1745      m_nextId++;
1746      temp.setX(.75);
1747      temp.setY((noa + 1.0) / (m_numClasses + 1));
1748      addNode(temp);
1749      NeuralConnection.connect(temp, m_outputs[noa]);
1750    }
1751 
1752  }
1753 
1754  /**
1755   * Call this function to automatically generate the hidden units
1756   */
1757  private void setupHiddenLayer()
1758  {
1759    StringTokenizer tok = new StringTokenizer(m_hiddenLayers, ",");
1760    int val = 0;  //num of nodes in a layer
1761    int prev = 0; //used to remember the previous layer
1762    int num = tok.countTokens(); //number of layers
1763    String c;
1764    for (int noa = 0; noa < num; noa++) {
1765      //note that I am using the Double to get the value rather than the
1766      //Integer class, because for some reason the Double implementation can
1767      //handle leading white space and the integer version can't!?!
1768      c = tok.nextToken().trim();
1769      if (c.equals("a")) {
1770        val = (m_numAttributes + m_numClasses) / 2;
1771      }
1772      else if (c.equals("i")) {
1773        val = m_numAttributes;
1774      }
1775      else if (c.equals("o")) {
1776        val = m_numClasses;
1777      }
1778      else if (c.equals("t")) {
1779        val = m_numAttributes + m_numClasses;
1780      }
1781      else {
1782        val = Double.valueOf(c).intValue();
1783      }
1784      for (int nob = 0; nob < val; nob++) {
1785        NeuralNode temp = new NeuralNode(String.valueOf(m_nextId), m_random,
1786                                         m_sigmoidUnit);
1787        m_nextId++;
1788        temp.setX(.5 / (num) * noa + .25);
1789        temp.setY((nob + 1.0) / (val + 1));
1790        addNode(temp);
1791        if (noa > 0) {
1792          //then do connections
1793          for (int noc = m_neuralNodes.length - nob - 1 - prev;
1794               noc < m_neuralNodes.length - nob - 1; noc++) {
1795            NeuralConnection.connect(m_neuralNodes[noc], temp);
1796          }
1797        }
1798      }     
1799      prev = val;
1800    }
1801    tok = new StringTokenizer(m_hiddenLayers, ",");
1802    c = tok.nextToken();
1803    if (c.equals("a")) {
1804      val = (m_numAttributes + m_numClasses) / 2;
1805    }
1806    else if (c.equals("i")) {
1807      val = m_numAttributes;
1808    }
1809    else if (c.equals("o")) {
1810      val = m_numClasses;
1811    }
1812    else if (c.equals("t")) {
1813      val = m_numAttributes + m_numClasses;
1814    }
1815    else {
1816      val = Double.valueOf(c).intValue();
1817    }
1818   
1819    if (val == 0) {
1820      for (int noa = 0; noa < m_numAttributes; noa++) {
1821        for (int nob = 0; nob < m_numClasses; nob++) {
1822          NeuralConnection.connect(m_inputs[noa], m_neuralNodes[nob]);
1823        }
1824      }
1825    }
1826    else {
1827      for (int noa = 0; noa < m_numAttributes; noa++) {
1828        for (int nob = m_numClasses; nob < m_numClasses + val; nob++) {
1829          NeuralConnection.connect(m_inputs[noa], m_neuralNodes[nob]);
1830        }
1831      }
1832      for (int noa = m_neuralNodes.length - prev; noa < m_neuralNodes.length;
1833           noa++) {
1834        for (int nob = 0; nob < m_numClasses; nob++) {
1835          NeuralConnection.connect(m_neuralNodes[noa], m_neuralNodes[nob]);
1836        }
1837      }
1838    }
1839   
1840  }
1841 
1842  /**
1843   * This will go through all the nodes and check if they are connected
1844   * to a pure output unit. If so they will be set to be linear units.
1845   * If not they will be set to be sigmoid units.
1846   */
1847  private void setEndsToLinear() {
1848    for (int noa = 0; noa < m_neuralNodes.length; noa++) {
1849      if ((m_neuralNodes[noa].getType() & NeuralConnection.OUTPUT) ==
1850          NeuralConnection.OUTPUT) {
1851        ((NeuralNode)m_neuralNodes[noa]).setMethod(m_linearUnit);
1852      }
1853      else {
1854        ((NeuralNode)m_neuralNodes[noa]).setMethod(m_sigmoidUnit);
1855      }
1856    }
1857  }
1858
1859  /**
1860   * Returns default capabilities of the classifier.
1861   *
1862   * @return      the capabilities of this classifier
1863   */
1864  public Capabilities getCapabilities() {
1865    Capabilities result = super.getCapabilities();
1866    result.disableAll();
1867
1868    // attributes
1869    result.enable(Capability.NOMINAL_ATTRIBUTES);
1870    result.enable(Capability.NUMERIC_ATTRIBUTES);
1871    result.enable(Capability.DATE_ATTRIBUTES);
1872    result.enable(Capability.MISSING_VALUES);
1873
1874    // class
1875    result.enable(Capability.NOMINAL_CLASS);
1876    result.enable(Capability.NUMERIC_CLASS);
1877    result.enable(Capability.DATE_CLASS);
1878    result.enable(Capability.MISSING_CLASS_VALUES);
1879   
1880    return result;
1881  }
1882 
1883  /**
1884   * Call this function to build and train a neural network for the training
1885   * data provided.
1886   * @param i The training data.
1887   * @throws Exception if can't build classification properly.
1888   */
1889  public void buildClassifier(Instances i) throws Exception {
1890
1891    // can classifier handle the data?
1892    getCapabilities().testWithFail(i);
1893   
1894    i = new Instances(i);
1895   
1896    //modCS.S Moved randomizer up, so only order of primary data is randomized.   
1897    m_random = new Random(m_randomSeed);
1898    i.randomize(m_random);
1899   
1900    //modCS.S, add secondary instances to primary instances, and replicate
1901    //primaries to match the number of secondaries.
1902    int m_originalNumInstances = i.numInstances();   
1903    if (m_secSet != null)
1904    {
1905        if (!m_secSet.equalHeaders(i))
1906        {
1907            throw new Exception("Training and secondary sets "
1908                + "have different headers.");
1909        }
1910       
1911        int m_numSecondaries = 0;
1912        //calculate how many instances there are for each secondary task.
1913        for (int noa = 0; noa < m_secSet.numInstances(); noa++)
1914        {
1915            //1 indicates the index of the attribute
1916            if(m_secSet.instance(noa).value(1) == 0)
1917            {             
1918            }
1919            else if(m_secSet.instance(noa).value(1) == 1)
1920            {
1921                m_numSecondaries++;
1922            }
1923            else
1924            {
1925                throw new Exception("Cannot find appropriate secondary task " +
1926                        "attribute(s).");
1927            }
1928        }
1929       
1930        if(m_secSet.numInstances() < m_originalNumInstances)
1931        {
1932            throw new Exception("Secondary task training set has less " +
1933                    "instances than the primary training set.");           
1934        }
1935       
1936        /* Do not replicate primaries to be used in a percentage
1937           validation set. Duplicates the remaining primaries. */         
1938        if (m_valSize > 0)
1939        {
1940            int exclude = 0;
1941            int counter = 0;
1942           
1943            exclude = (int)(m_valSize / 100.0 * (m_originalNumInstances));
1944            if(exclude == 0)
1945                exclude = 1;
1946           
1947            while(i.numInstances() < m_numSecondaries + exclude)
1948            {
1949                if(counter % m_originalNumInstances >= exclude)
1950                {
1951                    i.add(i.instance(counter % m_originalNumInstances));
1952                    counter++;
1953                }
1954                else
1955                {
1956                    counter+=exclude;
1957                }
1958            }
1959           
1960           
1961        }
1962        else
1963        {
1964            //Replicate the primaries to match the number of secondaries.
1965            for(int noa= m_originalNumInstances; noa<m_numSecondaries; noa++)
1966            {
1967                i.add(i.instance(noa % m_originalNumInstances));
1968            }
1969           
1970        }
1971        //add the secondaries
1972        for (int noa = 0; noa < m_secSet.numInstances(); noa++)
1973        {
1974            i.add(m_secSet.instance(noa));
1975        }
1976    }   
1977    // remove instances with missing class
1978    i.deleteWithMissingClass();
1979           
1980    // only class? -> build ZeroR model
1981    if (i.numAttributes() == 1) {
1982      System.err.println(
1983          "Cannot build model (only class attribute present in data!), "
1984          + "using ZeroR model instead!");
1985      m_ZeroR = new weka.classifiers.rules.ZeroR();
1986      m_ZeroR.buildClassifier(i);
1987      return;
1988    }
1989    else {
1990      m_ZeroR = null;
1991    }
1992   
1993    m_epoch = 0;
1994    m_error = 0;
1995    m_instances = null;
1996    m_currentInstance = null;
1997    m_controlPanel = null;
1998    m_nodePanel = null;
1999   
2000   
2001    m_outputs = new NeuralEnd[0];
2002    m_inputs = new NeuralEnd[0];
2003    m_numAttributes = 0;
2004    m_numClasses = 0;
2005    m_neuralNodes = new NeuralConnection[0];
2006   
2007    m_selected = new FastVector(4);
2008    m_graphers = new FastVector(2);
2009    m_nextId = 0;
2010    m_stopIt = true;
2011    m_stopped = true;
2012    m_accepted = false;   
2013    m_instances = new Instances(i);
2014     
2015    //modCS.S Moved randomizer up, so only order of primary data is randomized.   
2016 
2017    if (m_useNomToBin) 
2018    {
2019      m_nominalToBinaryFilter = new NominalToBinary();
2020      m_nominalToBinaryFilter.setInputFormat(m_instances);
2021      m_instances = Filter.useFilter(m_instances,
2022                                     m_nominalToBinaryFilter);
2023      //modCS.V
2024      //Run the nominal to binary filter on the validation set, if applicable.
2025      if(m_valSet != null)
2026      {
2027          m_nominalToBinaryFilter.setInputFormat(m_valSet);
2028          m_valSet = Filter.useFilter(
2029                m_valSet, m_nominalToBinaryFilter);   
2030      }
2031    }
2032    m_numAttributes = m_instances.numAttributes() - 1;
2033    m_numClasses = m_instances.numClasses();
2034 
2035   
2036    setClassType(m_instances);
2037    //modCS.V
2038    //Run the setClassType procedure on the validation set, as well.
2039    if(m_valSet != null)
2040        setClassType(m_valSet);
2041
2042    //modCS.V
2043    //Prepare validation instances using either the specified validation set or
2044    //using a percentage of the training examples.
2045    Instances valSet;
2046    int numInVal = 0;
2047    //this sets up the validation set.
2048    if(m_valSet == null)//sets up a validation set using some primary instances.
2049    {
2050        valSet = null;
2051        //numinval is needed later
2052        //modCS.S; properly count the total number of instances if a secondary
2053        //set is included.
2054        if(m_secSet != null)
2055        {
2056            numInVal = (int)(m_valSize / 100.0 * (m_originalNumInstances));
2057        }
2058        else
2059        {
2060            numInVal = (int)(m_valSize / 100.0 * m_instances.numInstances());
2061        }
2062        if (m_valSize > 0) {
2063          if (numInVal == 0) {
2064            numInVal = 1;
2065          }
2066          valSet = new Instances(m_instances, 0, numInVal);
2067        }
2068       
2069        //debug; see the training and validation sets
2070 /*       System.out.println("\n\n\nTraining Set; ");
2071        for(int noa = 0; noa < i.numInstances(); noa++)
2072            System.out.println(i.instance(noa));
2073       
2074        System.out.println("\n\n\n\nValidation Set");
2075        for(int noa = 0; noa < valSet.numInstances(); noa++)
2076            System.out.println(valSet.instance(noa));       
2077       
2078      */ 
2079        ///////////       
2080    }
2081    else //sets up a specified validation set
2082    {
2083        valSet = new Instances(m_valSet);
2084        numInVal = valSet.numInstances();
2085        valSet.deleteWithMissingClass();
2086        valSet.setClassIndex(m_instances.classIndex());
2087        if (!m_instances.equalHeaders(valSet))
2088        {
2089            throw new Exception("Training and validation sets "
2090                + "have different headers.");
2091        }
2092        setClassType(valSet);
2093        if (m_valSize != 0)
2094        {
2095            throw new Exception("Given both a validation set size split and" 
2096                + "a specified validation set. Use only one of either.");
2097        }
2098    }
2099
2100    setupInputs();
2101     
2102    setupOutputs();   
2103    if (m_autoBuild) {
2104      setupHiddenLayer();
2105    }
2106   
2107    /////////////////////////////
2108    //this sets up the gui for usage
2109    if (m_gui) {
2110      m_win = new JFrame();
2111     
2112      m_win.addWindowListener(new WindowAdapter() {
2113          public void windowClosing(WindowEvent e) {
2114            boolean k = m_stopIt;
2115            m_stopIt = true;
2116            int well =JOptionPane.showConfirmDialog(m_win, 
2117                                                    "Are You Sure...\n"
2118                                                    + "Click Yes To Accept"
2119                                                    + " The Neural Network" 
2120                                                    + "\n Click No To Return",
2121                                                    "Accept Neural Network", 
2122                                                    JOptionPane.YES_NO_OPTION);
2123           
2124            if (well == 0) {
2125              m_win.setDefaultCloseOperation(JFrame.DISPOSE_ON_CLOSE);
2126              m_accepted = true;
2127              blocker(false);
2128            }
2129            else {
2130              m_win.setDefaultCloseOperation(JFrame.DO_NOTHING_ON_CLOSE);
2131            }
2132            m_stopIt = k;
2133          }
2134        });
2135     
2136      m_win.getContentPane().setLayout(new BorderLayout());
2137      m_win.setTitle("Neural Network");
2138      m_nodePanel = new NodePanel();
2139      // without the following two lines, the NodePanel.paintComponents(Graphics)
2140      // method will go berserk if the network doesn't fit completely: it will
2141      // get called on a constant basis, using 100% of the CPU
2142      // see the following forum thread:
2143      // http://forum.java.sun.com/thread.jspa?threadID=580929&messageID=2945011
2144      m_nodePanel.setPreferredSize(new Dimension(640, 480));
2145      m_nodePanel.revalidate();
2146
2147      JScrollPane sp = new JScrollPane(m_nodePanel,
2148                                       JScrollPane.VERTICAL_SCROLLBAR_ALWAYS, 
2149                                       JScrollPane.HORIZONTAL_SCROLLBAR_NEVER);
2150      m_controlPanel = new ControlPanel();
2151           
2152      m_win.getContentPane().add(sp, BorderLayout.CENTER);
2153      m_win.getContentPane().add(m_controlPanel, BorderLayout.SOUTH);
2154      m_win.setSize(640, 480);
2155      m_win.setVisible(true);
2156    }
2157   
2158    //This sets up the initial state of the gui
2159    if (m_gui) {
2160      blocker(true);
2161      m_controlPanel.m_changeEpochs.setEnabled(false);
2162      m_controlPanel.m_changeLearning.setEnabled(false);
2163      m_controlPanel.m_changeMomentum.setEnabled(false);
2164    } 
2165   
2166    //For silly situations in which the network gets accepted before training
2167    //commenses
2168    if (m_numeric) {
2169      setEndsToLinear();
2170    }
2171    if (m_accepted) {
2172      m_win.dispose();
2173      m_controlPanel = null;
2174      m_nodePanel = null;
2175      m_instances = new Instances(m_instances, 0);
2176      return;
2177    }
2178
2179    //connections done.
2180    double right = 0;
2181    double driftOff = 0;
2182    double lastRight = Double.POSITIVE_INFINITY;
2183    double bestError = Double.POSITIVE_INFINITY;
2184    double tempRate;
2185    double totalWeight = 0;
2186    double totalValWeight = 0;
2187    double origRate = m_learningRate; //only used for when reset
2188   
2189    //ensure that at least 1 instance is trained through.
2190    if (numInVal == m_instances.numInstances()) {
2191      numInVal--;
2192    }
2193    if (numInVal < 0) {
2194      numInVal = 0;
2195    }
2196    //modCS.V; utilizes either the specified training set or one constructed
2197    //using the percentage of the primary training set.
2198    if (m_valSet == null)
2199    {
2200        for (int noa = numInVal; noa < m_instances.numInstances(); noa++) {
2201          if (!m_instances.instance(noa).classIsMissing()) {
2202            totalWeight += m_instances.instance(noa).weight();
2203          }
2204        }
2205        if (m_valSize != 0) {
2206          for (int noa = 0; noa < valSet.numInstances(); noa++) {
2207            if (!valSet.instance(noa).classIsMissing()) {
2208              totalValWeight += valSet.instance(noa).weight();
2209            }
2210          }
2211        }
2212        m_stopped = false;
2213    }
2214    else
2215    {
2216        for (int noa = 0; noa < m_instances.numInstances(); noa++)
2217        {
2218                if (!m_instances.instance(noa).classIsMissing())
2219                {
2220                        totalWeight += m_instances.instance(noa).weight();
2221                }
2222        }
2223        if (valSet.numInstances() != 0)
2224        {
2225                for (int noa = 0; noa < valSet.numInstances(); noa++)
2226                {
2227                        if (!valSet.instance(noa).classIsMissing())
2228                        {
2229                                totalValWeight += valSet.instance(noa).weight();
2230                        }
2231                }
2232        }   
2233    }
2234
2235    for (int noa = 1; noa < m_numEpochs + 1; noa++) {
2236      right = 0;
2237      for (int nob = numInVal; nob < m_instances.numInstances(); nob++) {
2238        m_currentInstance = m_instances.instance(nob);
2239       
2240        if (!m_currentInstance.classIsMissing()) {
2241           
2242          //this is where the network updating (and training occurs, for the
2243          //training set
2244          resetNetwork();
2245          calculateOutputs();
2246          tempRate = m_learningRate * m_currentInstance.weight(); 
2247          if (m_decay) {
2248            tempRate /= noa;
2249          }
2250
2251          right += (calculateErrors() / m_instances.numClasses()) *
2252            m_currentInstance.weight();
2253          updateNetworkWeights(tempRate, m_momentum);
2254         
2255        }
2256       
2257      }
2258      right /= totalWeight;
2259      if (Double.isInfinite(right) || Double.isNaN(right)) {
2260        if (!m_reset) {
2261          m_instances = null;
2262          throw new Exception("Network cannot train. Try restarting with a" +
2263                              " smaller learning rate.");
2264        }
2265        else {
2266          //reset the network if possible
2267          if (m_learningRate <= Utils.SMALL)
2268            throw new IllegalStateException(
2269                "Learning rate got too small (" + m_learningRate
2270                + " <= " + Utils.SMALL + ")!");
2271          m_learningRate /= 2;
2272          buildClassifier(i);
2273          m_learningRate = origRate;
2274          m_instances = new Instances(m_instances, 0);   
2275          return;
2276        }
2277      }
2278
2279      ////////////////////////do validation testing if applicable
2280     
2281      //modCS.V Recent change for validation calcs
2282      if (m_valSize != 0 || m_valSet != null) {
2283        right = 0;
2284        for (int nob = 0; nob < valSet.numInstances(); nob++) {
2285          m_currentInstance = valSet.instance(nob);
2286          if (!m_currentInstance.classIsMissing()) {
2287            //this is where the network updating occurs, for the validation set
2288            resetNetwork();
2289            calculateOutputs();
2290            right += (calculateErrors() / valSet.numClasses()) 
2291              * m_currentInstance.weight();
2292            //note 'right' could be calculated here just using
2293            //the calculate output values. This would be faster.
2294            //be less modular
2295          }
2296         
2297        }
2298        //mod? If an epoch has less error then the previous, validation
2299        //testing won't end even if the threshold is crossed. Consider
2300        //noting this in the description of the validation threshold, or
2301        //changing this to function as the description currently states.
2302        if (right < lastRight) {
2303          if (right < bestError) {
2304            bestError = right;
2305            //modCS.E; Calculate the lowest validation error and save its index.
2306            m_lowValError = right / totalValWeight;
2307            m_epochIndex = noa;
2308           
2309            // save the network weights at this point
2310            for (int noc = 0; noc < m_numClasses; noc++) {
2311              m_outputs[noc].saveWeights();
2312            }
2313            driftOff = 0;
2314          }
2315        }
2316        else {
2317          driftOff++;
2318        }
2319        lastRight = right;
2320        if (driftOff > m_driftThreshold || noa + 1 >= m_numEpochs) {
2321          for (int noc = 0; noc < m_numClasses; noc++) {
2322            m_outputs[noc].restoreWeights();
2323          }
2324          m_accepted = true;
2325        }
2326        right /= totalValWeight;
2327      }
2328      m_epoch = noa;
2329      m_error = right;
2330      //shows what the neuralnet is upto if a gui exists.
2331      updateDisplay();
2332      //This junction controls what state the gui is in at the end of each
2333      //epoch, Such as if it is paused, if it is resumable etc...
2334      //modCS.V; Extended conditional statements to consider supplied val sets.
2335      if (m_gui) {
2336        while ((m_stopIt || (m_epoch >= m_numEpochs && (m_valSize == 0 
2337                && m_valSet == null))) && !m_accepted) {
2338          m_stopIt = true;
2339          m_stopped = true;
2340          if (m_epoch >= m_numEpochs && (m_valSize == 0 && m_valSet == null)) {
2341           
2342            m_controlPanel.m_startStop.setEnabled(false);
2343          }
2344          else {
2345            m_controlPanel.m_startStop.setEnabled(true);
2346          }
2347          m_controlPanel.m_startStop.setText("Start");
2348          m_controlPanel.m_startStop.setActionCommand("Start");
2349          m_controlPanel.m_changeEpochs.setEnabled(true);
2350          m_controlPanel.m_changeLearning.setEnabled(true);
2351          m_controlPanel.m_changeMomentum.setEnabled(true);
2352         
2353          blocker(true);
2354          if (m_numeric) {
2355            setEndsToLinear();
2356          }
2357        }
2358        m_controlPanel.m_changeEpochs.setEnabled(false);
2359        m_controlPanel.m_changeLearning.setEnabled(false);
2360        m_controlPanel.m_changeMomentum.setEnabled(false);
2361       
2362        m_stopped = false;
2363        //if the network has been accepted stop the training loop
2364        if (m_accepted) {
2365          m_win.dispose();
2366          m_controlPanel = null;
2367          m_nodePanel = null;
2368          m_instances = new Instances(m_instances, 0);
2369          return;
2370        }
2371      }
2372      if (m_accepted) {
2373        m_instances = new Instances(m_instances, 0);
2374        return;
2375      }
2376    }
2377    if (m_gui) {
2378      m_win.dispose();
2379      m_controlPanel = null;
2380      m_nodePanel = null;
2381    }
2382    m_instances = new Instances(m_instances, 0); 
2383  }
2384
2385  /**
2386   * Call this function to predict the class of an instance once a
2387   * classification model has been built with the buildClassifier call.
2388   * @param i The instance to classify.
2389   * @return A double array filled with the probabilities of each class type.
2390   * @throws Exception if can't classify instance.
2391   */
2392  public double[] distributionForInstance(Instance i) throws Exception {
2393
2394    // default model?
2395    if (m_ZeroR != null) {
2396      return m_ZeroR.distributionForInstance(i);
2397    }
2398   
2399    if (m_useNomToBin) {
2400      m_nominalToBinaryFilter.input(i);
2401      m_currentInstance = m_nominalToBinaryFilter.output();
2402    }
2403    else {
2404      m_currentInstance = i;
2405    }
2406   
2407    if (m_normalizeAttributes) {
2408      for (int noa = 0; noa < m_instances.numAttributes(); noa++) {
2409        if (noa != m_instances.classIndex()) {
2410          if (m_attributeRanges[noa] != 0) {
2411            m_currentInstance.setValue(noa, (m_currentInstance.value(noa) - 
2412                                             m_attributeBases[noa]) / 
2413                                       m_attributeRanges[noa]);
2414          }
2415          else {
2416            m_currentInstance.setValue(noa, m_currentInstance.value(noa) -
2417                                       m_attributeBases[noa]);
2418          }
2419        }
2420      }
2421    }
2422    resetNetwork();
2423   
2424    //since all the output values are needed.
2425    //They are calculated manually here and the values collected.
2426    double[] theArray = new double[m_numClasses];
2427    for (int noa = 0; noa < m_numClasses; noa++) {
2428      theArray[noa] = m_outputs[noa].outputValue(true);
2429    }
2430    if (m_instances.classAttribute().isNumeric()) {
2431      return theArray;
2432    }
2433   
2434    //now normalize the array
2435    double count = 0;
2436    for (int noa = 0; noa < m_numClasses; noa++) {
2437      count += theArray[noa];
2438    }
2439  //  System.out.println("Count: " + count + ".");
2440    if (count <= 0) {
2441      return null;
2442    }
2443    for (int noa = 0; noa < m_numClasses; noa++) {
2444        //dmod
2445  //      System.out.println("Array " + noa +":" + theArray[noa]);
2446      theArray[noa] /= count;
2447   //   System.out.println(" and " + theArray[noa] +".");     
2448    }
2449  //  System.out.println(); System.out.println(); System.out.println();
2450    return theArray;
2451  }
2452 
2453
2454
2455  /**
2456   * Returns an enumeration describing the available options.
2457   *
2458   * @return an enumeration of all the available options.
2459   */
2460  public Enumeration listOptions() {
2461    //modCS.S & modCS.V;
2462    //needed more vectors for the increased number of parameters
2463    Vector newVector = new Vector(16);
2464
2465    newVector.addElement(new Option(
2466              "\tLearning Rate for the backpropagation algorithm.\n"
2467              +"\t(Value should be between 0 - 1, Default = 0.3).",
2468              "L", 1, "-L <learning rate>"));
2469    newVector.addElement(new Option(
2470              "\tMomentum Rate for the backpropagation algorithm.\n"
2471              +"\t(Value should be between 0 - 1, Default = 0.2).",
2472              "M", 1, "-M <momentum>"));
2473    newVector.addElement(new Option(
2474              "\tNumber of epochs to train through.\n"
2475              +"\t(Default = 500).",
2476              "N", 1,"-N <number of epochs>"));
2477    newVector.addElement(new Option(
2478              "\tPercentage size of validation set to use to terminate\n"
2479              + "\ttraining (if this is non zero it can pre-empt num of epochs.\n"
2480              +"\t(Value should be between 0 - 100, Default = 0).",
2481              "V", 1, "-V <percentage size of validation set>"));
2482    newVector.addElement(new Option(
2483              "\tThe value used to seed the random number generator\n"
2484              + "\t(Value should be >= 0 and and a long, Default = 0).",
2485              "S", 1, "-S <seed>"));
2486    newVector.addElement(new Option(
2487              "\tThe consequetive number of errors allowed for validation\n"
2488              + "\ttesting before the netwrok terminates.\n"
2489              + "\t(Value should be > 0, Default = 20).",
2490              "E", 1, "-E <threshold for number of consequetive errors>"));
2491    newVector.addElement(new Option(
2492              "\tGUI will be opened.\n"
2493              +"\t(Use this to bring up a GUI).",
2494              "G", 0,"-G"));
2495    newVector.addElement(new Option(
2496              "\tAutocreation of the network connections will NOT be done.\n"
2497              +"\t(This will be ignored if -G is NOT set)",
2498              "A", 0,"-A"));
2499    newVector.addElement(new Option(
2500              "\tA NominalToBinary filter will NOT automatically be used.\n"
2501              +"\t(Set this to not use a NominalToBinary filter).",
2502              "B", 0,"-B"));
2503    newVector.addElement(new Option(
2504              "\tThe hidden layers to be created for the network.\n"
2505              + "\t(Value should be a list of comma separated Natural \n"
2506              + "\tnumbers or the letters 'a' = (attribs + classes) / 2, \n"
2507              + "\t'i' = attribs, 'o' = classes, 't' = attribs .+ classes)\n"
2508              + "\tfor wildcard values, Default = a).",
2509              "H", 1, "-H <comma seperated numbers for nodes on each layer>"));
2510    newVector.addElement(new Option(
2511              "\tNormalizing a numeric class will NOT be done.\n"
2512              +"\t(Set this to not normalize the class if it's numeric).",
2513              "C", 0,"-C"));
2514    newVector.addElement(new Option(
2515              "\tNormalizing the attributes will NOT be done.\n"
2516              +"\t(Set this to not normalize the attributes).",
2517              "I", 0,"-I"));
2518    newVector.addElement(new Option(
2519              "\tReseting the network will NOT be allowed.\n"
2520              +"\t(Set this to not allow the network to reset).",
2521              "R", 0,"-R"));
2522    newVector.addElement(new Option(
2523              "\tLearning rate decay will occur.\n"
2524              +"\t(Set this to cause the learning rate to decay).",
2525              "D", 0,"-D"));
2526    //modCS.V
2527    newVector.addElement(new Option("\tValidation set to use, " +
2528            " as drawn from the data source file.\n",
2529            "validation-set", 1, "-validation-set <data source file>"));
2530    //modCS.S
2531    newVector.addElement(new Option("\tSecondary task training set to use," +
2532            " as drawn from the data source file.\n",
2533            "secondary-training", 1, "-secondary-training <data source file>"));
2534    return newVector.elements();
2535  }
2536
2537  /**
2538   * Parses a given list of options. <p/>
2539   *
2540   <!-- options-start -->
2541   * Valid options are: <p/>
2542   *
2543   * <pre> -L &lt;learning rate&gt;
2544   *  Learning Rate for the backpropagation algorithm.
2545   *  (Value should be between 0 - 1, Default = 0.3).</pre>
2546   *
2547   * <pre> -M &lt;momentum&gt;
2548   *  Momentum Rate for the backpropagation algorithm.
2549   *  (Value should be between 0 - 1, Default = 0.2).</pre>
2550   *
2551   * <pre> -N &lt;number of epochs&gt;
2552   *  Number of epochs to train through.
2553   *  (Default = 500).</pre>
2554   *
2555   * <pre> -V &lt;percentage size of validation set&gt;
2556   *  Percentage size of validation set to use to terminate
2557   *  training (if this is non zero it can pre-empt num of epochs.
2558   *  (Value should be between 0 - 100, Default = 0).</pre>
2559   *
2560   * <pre> -S &lt;seed&gt;
2561   *  The value used to seed the random number generator
2562   *  (Value should be &gt;= 0 and and a long, Default = 0).</pre>
2563   *
2564   * <pre> -E &lt;threshold for number of consequetive errors&gt;
2565   *  The consequetive number of errors allowed for validation
2566   *  testing before the netwrok terminates.
2567   *  (Value should be &gt; 0, Default = 20).</pre>
2568   *
2569   * <pre> -G
2570   *  GUI will be opened.
2571   *  (Use this to bring up a GUI).</pre>
2572   *
2573   * <pre> -A
2574   *  Autocreation of the network connections will NOT be done.
2575   *  (This will be ignored if -G is NOT set)</pre>
2576   *
2577   * <pre> -B
2578   *  A NominalToBinary filter will NOT automatically be used.
2579   *  (Set this to not use a NominalToBinary filter).</pre>
2580   *
2581   * <pre> -H &lt;comma seperated numbers for nodes on each layer&gt;
2582   *  The hidden layers to be created for the network.
2583   *  (Value should be a list of comma separated Natural
2584   *  numbers or the letters 'a' = (attribs + classes) / 2,
2585   *  'i' = attribs, 'o' = classes, 't' = attribs .+ classes)
2586   *  for wildcard values, Default = a).</pre>
2587   *
2588   * <pre> -C
2589   *  Normalizing a numeric class will NOT be done.
2590   *  (Set this to not normalize the class if it's numeric).</pre>
2591   *
2592   * <pre> -I
2593   *  Normalizing the attributes will NOT be done.
2594   *  (Set this to not normalize the attributes).</pre>
2595   *
2596   * <pre> -R
2597   *  Reseting the network will NOT be allowed.
2598   *  (Set this to not allow the network to reset).</pre>
2599   *
2600   * <pre> -D
2601   *  Learning rate decay will occur.
2602   *  (Set this to cause the learning rate to decay).</pre>
2603   *
2604   * <pre> -validation-set &lt;data source file&gt;
2605   *  Validation set to use,  as drawn from the data source file.
2606   * </pre>
2607   *
2608   * <pre> -secondary-training &lt;data source file&gt;
2609   *  Secondary task training set to use, as drawn from the data source file.
2610   * </pre>
2611   *
2612   <!-- options-end -->
2613   *
2614   * @param options the list of options as an array of strings
2615   * @throws Exception if an option is not supported
2616   */
2617  public void setOptions(String[] options) throws Exception {
2618    //the defaults can be found here!!!!
2619    String learningString = Utils.getOption('L', options);
2620    if (learningString.length() != 0) {
2621      setLearningRate((new Double(learningString)).doubleValue());
2622    } else {
2623      setLearningRate(0.3);
2624    }
2625    String momentumString = Utils.getOption('M', options);
2626    if (momentumString.length() != 0) {
2627      setMomentum((new Double(momentumString)).doubleValue());
2628    } else {
2629      setMomentum(0.2);
2630    }
2631    String epochsString = Utils.getOption('N', options);
2632    if (epochsString.length() != 0) {
2633      setTrainingTime(Integer.parseInt(epochsString));
2634    } else {
2635      setTrainingTime(500);
2636    }
2637    String valSizeString = Utils.getOption('V', options);
2638    if (valSizeString.length() != 0) {
2639      setValidationSetSize(Integer.parseInt(valSizeString));
2640    } else {
2641      setValidationSetSize(0);
2642    }
2643    String seedString = Utils.getOption('S', options);
2644    if (seedString.length() != 0) {
2645      setSeed(Integer.parseInt(seedString));
2646    } else {
2647      setSeed(0);
2648    }
2649    String thresholdString = Utils.getOption('E', options);
2650    if (thresholdString.length() != 0) {
2651      setValidationThreshold(Integer.parseInt(thresholdString));
2652    } else {
2653      setValidationThreshold(20);
2654    }
2655    String hiddenLayers = Utils.getOption('H', options);
2656    if (hiddenLayers.length() != 0) {
2657      setHiddenLayers(hiddenLayers);
2658    } else {
2659      setHiddenLayers("a");
2660    }
2661    if (Utils.getFlag('G', options)) {
2662      setGUI(true);
2663    } else {
2664      setGUI(false);
2665    } //small note. since the gui is the only option that can change the other
2666    //options this should be set first to allow the other options to set
2667    //properly
2668    if (Utils.getFlag('A', options)) {
2669      setAutoBuild(false);
2670    } else {
2671      setAutoBuild(true);
2672    }
2673    if (Utils.getFlag('B', options)) {
2674      setNominalToBinaryFilter(false);
2675    } else {
2676      setNominalToBinaryFilter(true);
2677    }
2678    if (Utils.getFlag('C', options)) {
2679      setNormalizeNumericClass(false);
2680    } else {
2681      setNormalizeNumericClass(true);
2682    }
2683    if (Utils.getFlag('I', options)) {
2684      setNormalizeAttributes(false);
2685    } else {
2686      setNormalizeAttributes(true);
2687    }
2688    if (Utils.getFlag('R', options)) {
2689      setReset(false);
2690    } else {
2691      setReset(true);
2692    }
2693    if (Utils.getFlag('D', options)) {
2694      setDecay(true);
2695    } else {
2696      setDecay(false);
2697    }
2698
2699    //modCS.V
2700    String sValFile = Utils.getOption("validation-set", options);
2701    if (sValFile != null && !sValFile.equals("")) {
2702      setValFile(sValFile);
2703    }
2704
2705    //modCS.S
2706    String sSecFile = Utils.getOption("secondary-training", options);
2707    if (sSecFile != null && !sSecFile.equals("")) {
2708      setSecFile(sSecFile);
2709    }   
2710   
2711    Utils.checkForRemainingOptions(options);
2712  }
2713 
2714  /**
2715   * Gets the current settings of NeuralNet.
2716   *
2717   * @return an array of strings suitable for passing to setOptions()
2718   */
2719  public String [] getOptions() {
2720    //modCS.S & modCS.V; more options, so we need a larger array.
2721    String [] options = new String [21];
2722    int current = 0;
2723    options[current++] = "-L"; options[current++] = "" + getLearningRate(); 
2724    options[current++] = "-M"; options[current++] = "" + getMomentum();
2725    options[current++] = "-N"; options[current++] = "" + getTrainingTime(); 
2726    options[current++] = "-V"; options[current++] = "" +getValidationSetSize();
2727    options[current++] = "-S"; options[current++] = "" + getSeed();
2728    options[current++] = "-E"; options[current++] =""+getValidationThreshold();
2729    options[current++] = "-H"; options[current++] = getHiddenLayers();
2730    if (getGUI()) {
2731      options[current++] = "-G";
2732    }
2733    if (!getAutoBuild()) {
2734      options[current++] = "-A";
2735    }
2736    if (!getNominalToBinaryFilter()) {
2737      options[current++] = "-B";
2738    }
2739    if (!getNormalizeNumericClass()) {
2740      options[current++] = "-C";
2741    }
2742    if (!getNormalizeAttributes()) {
2743      options[current++] = "-I";
2744    }
2745    if (!getReset()) {
2746      options[current++] = "-R";
2747    }
2748    if (getDecay()) {
2749      options[current++] = "-D";
2750    }
2751
2752    //modCS.V
2753    if (m_valSet != null) {
2754      options[current++] = "-validation-set";
2755      options[current++] = m_valSetFileName;
2756    }   
2757    //modCS.S
2758    if (m_secSet != null) {
2759      options[current++] = "-secondary-training";
2760      options[current++] = m_secSetFileName;
2761    }       
2762   
2763    while (current < options.length) {
2764      options[current++] = "";
2765    }
2766    return options;
2767  }
2768 
2769  //modCS.V
2770  /**
2771   * Loads and stores validation file information from the given file name.
2772   * @param sValFile the name of the data source file
2773   */
2774  public void setValFile(String sValFile) 
2775  {     
2776      try
2777      {
2778         m_valSetFileName = sValFile; 
2779         m_valSetSource = new DataSource(sValFile);         
2780         m_valSet = m_valSetSource.getDataSet();
2781         if (m_valSet.classIndex() == -1)
2782             m_valSet.setClassIndex(m_valSet.numAttributes() - 1); 
2783       
2784         m_valSet.deleteWithMissingClass();
2785      }
2786      catch (Throwable t)
2787      {
2788          m_valSetFileName = null;
2789          m_valSetSource = null;
2790          m_valSet = null;
2791      }
2792  }
2793
2794  //modCS.V
2795  /**
2796   * Get the name of the validation file.
2797   * @return Validation set file name
2798   */
2799  public String getValFile() {
2800    if (m_valSetFileName != null) 
2801    {
2802        return m_valSetFileName;
2803    }
2804    return "";
2805  } 
2806 
2807  //modCS.S
2808  /**
2809   * Loads and stores secondary task training file information from the given
2810   * file name.
2811   * @param sSecFile the name of the data source file
2812   */
2813  public void setSecFile(String sSecFile) 
2814  {     
2815      try
2816      {
2817         m_secSetFileName = sSecFile; 
2818         m_secSetSource = new DataSource(sSecFile);         
2819         m_secSet = m_secSetSource.getDataSet();
2820         if (m_secSet.classIndex() == -1)
2821             m_secSet.setClassIndex(m_secSet.numAttributes() - 1); 
2822               
2823         m_secSet.deleteWithMissingClass();
2824      }
2825      catch (Throwable t)
2826      {
2827          m_secSetFileName = null;
2828          m_secSetSource = null;
2829          m_secSet = null;
2830      }
2831  }
2832
2833  //modCS.S
2834  /**
2835   * Get the name of the secondary task training file.
2836   * @return Secondary task training set file name
2837   */
2838  public String getSecFile() {
2839    if (m_secSetFileName != null) 
2840    {
2841        return m_secSetFileName;
2842    }
2843    return "";
2844  }   
2845  /**
2846   * @return string describing the model.
2847   */
2848  public String toString() {
2849    // only ZeroR model?
2850    if (m_ZeroR != null) {
2851      StringBuffer buf = new StringBuffer();
2852      buf.append(this.getClass().getName().replaceAll(".*\\.", "") + "\n");
2853      buf.append(this.getClass().getName().replaceAll(".*\\.", "").replaceAll(".", "=") + "\n\n");
2854      buf.append("Warning: No model could be built, hence ZeroR model is used:\n\n");
2855      buf.append(m_ZeroR.toString());
2856      return buf.toString();
2857    }
2858   
2859    StringBuffer model = new StringBuffer(m_neuralNodes.length * 100); 
2860    //just a rough size guess
2861    NeuralNode con;
2862    double[] weights;
2863    NeuralConnection[] inputs;
2864    for (int noa = 0; noa < m_neuralNodes.length; noa++) {
2865      con = (NeuralNode) m_neuralNodes[noa];  //this would need a change
2866                                              //for items other than nodes!!!
2867      weights = con.getWeights();
2868      inputs = con.getInputs();
2869      if (con.getMethod() instanceof SigmoidUnit) {
2870        model.append("Sigmoid ");
2871      }
2872      else if (con.getMethod() instanceof LinearUnit) {
2873        model.append("Linear ");
2874      }
2875      model.append("Node " + con.getId() + "\n    Inputs    Weights\n");
2876      model.append("    Threshold    " + weights[0] + "\n");
2877      for (int nob = 1; nob < con.getNumInputs() + 1; nob++) {
2878        if ((inputs[nob - 1].getType() & NeuralConnection.PURE_INPUT) 
2879            == NeuralConnection.PURE_INPUT) {
2880          model.append("    Attrib " + 
2881                       m_instances.attribute(((NeuralEnd)inputs[nob-1]).
2882                                             getLink()).name()
2883                       + "    " + weights[nob] + "\n");
2884        }
2885        else {
2886          model.append("    Node " + inputs[nob-1].getId() + "    " +
2887                       weights[nob] + "\n");
2888        }
2889      }     
2890    }
2891    //now put in the ends
2892    for (int noa = 0; noa < m_outputs.length; noa++) {
2893      inputs = m_outputs[noa].getInputs();
2894      model.append("Class " + 
2895                   m_instances.classAttribute().
2896                   value(m_outputs[noa].getLink()) + 
2897                   "\n    Input\n");
2898      for (int nob = 0; nob < m_outputs[noa].getNumInputs(); nob++) {
2899        if ((inputs[nob].getType() & NeuralConnection.PURE_INPUT)
2900            == NeuralConnection.PURE_INPUT) {
2901          model.append("    Attrib " +
2902                       m_instances.attribute(((NeuralEnd)inputs[nob]).
2903                                             getLink()).name() + "\n");
2904        }
2905        else {
2906          model.append("    Node " + inputs[nob].getId() + "\n");
2907        }
2908      }
2909    }
2910    return model.toString();
2911  }
2912
2913  /**
2914   * This will return a string describing the classifier.
2915   * @return The string.
2916   */
2917  public String globalInfo() {
2918    return 
2919        "A Classifier that uses backpropagation to classify instances.\n"
2920      + "This network can be built by hand, created by an algorithm or both. "
2921      + "The network can also be monitored and modified during training time. "
2922      + "The nodes in this network are all sigmoid (except for when the class "
2923      + "is numeric in which case the the output nodes become unthresholded "
2924      + "linear units).";
2925  }
2926 
2927  /**
2928   * @return a string to describe the learning rate option.
2929   */
2930  public String learningRateTipText() {
2931    return "The amount the" + 
2932      " weights are updated.";
2933  }
2934 
2935  /**
2936   * @return a string to describe the momentum option.
2937   */
2938  public String momentumTipText() {
2939    return "Momentum applied to the weights during updating.";
2940  }
2941
2942  /**
2943   * @return a string to describe the AutoBuild option.
2944   */
2945  public String autoBuildTipText() {
2946    return "Adds and connects up hidden layers in the network.";
2947  }
2948
2949  /**
2950   * @return a string to describe the random seed option.
2951   */
2952  public String seedTipText() {
2953    return "Seed used to initialise the random number generator." +
2954      "Random numbers are used for setting the initial weights of the" +
2955      " connections betweem nodes, and also for shuffling the training data.";
2956  }
2957 
2958  /**
2959   * @return a string to describe the validation threshold option.
2960   */
2961  public String validationThresholdTipText() {
2962    return "Used to terminate validation testing." +
2963      "The value here dictates how many times in a row the validation set" +
2964      " error can get worse before training is terminated.";
2965  }
2966 
2967  /**
2968   * @return a string to describe the GUI option.
2969   */
2970  public String GUITipText() {
2971    return "Brings up a gui interface." +
2972      " This will allow the pausing and altering of the nueral network" +
2973      " during training.\n\n" +
2974      "* To add a node left click (this node will be automatically selected," +
2975      " ensure no other nodes were selected).\n" +
2976      "* To select a node left click on it either while no other node is" +
2977      " selected or while holding down the control key (this toggles that" +
2978      " node as being selected and not selected.\n" + 
2979      "* To connect a node, first have the start node(s) selected, then click"+
2980      " either the end node or on an empty space (this will create a new node"+
2981      " that is connected with the selected nodes). The selection status of" +
2982      " nodes will stay the same after the connection. (Note these are" +
2983      " directed connections, also a connection between two nodes will not" +
2984      " be established more than once and certain connections that are" + 
2985      " deemed to be invalid will not be made).\n" +
2986      "* To remove a connection select one of the connected node(s) in the" +
2987      " connection and then right click the other node (it does not matter" +
2988      " whether the node is the start or end the connection will be removed" +
2989      ").\n" +
2990      "* To remove a node right click it while no other nodes (including it)" +
2991      " are selected. (This will also remove all connections to it)\n." +
2992      "* To deselect a node either left click it while holding down control," +
2993      " or right click on empty space.\n" +
2994      "* The raw inputs are provided from the labels on the left.\n" +
2995      "* The red nodes are hidden layers.\n" +
2996      "* The orange nodes are the output nodes.\n" +
2997      "* The labels on the right show the class the output node represents." +
2998      " Note that with a numeric class the output node will automatically be" +
2999      " made into an unthresholded linear unit.\n\n" +
3000      "Alterations to the neural network can only be done while the network" +
3001      " is not running, This also applies to the learning rate and other" +
3002      " fields on the control panel.\n\n" + 
3003      "* You can accept the network as being finished at any time.\n" +
3004      "* The network is automatically paused at the beginning.\n" +
3005      "* There is a running indication of what epoch the network is up to" + 
3006      " and what the (rough) error for that epoch was (or for" +
3007      " the validation if that is being used). Note that this error value" +
3008      " is based on a network that changes as the value is computed." +
3009      " (also depending on whether" +
3010      " the class is normalized will effect the error reported for numeric" +
3011      " classes.\n" +
3012      "* Once the network is done it will pause again and either wait to be" +
3013      " accepted or trained more.\n\n" +
3014      "Note that if the gui is not set the network will not require any" +
3015      " interaction.\n";
3016  }
3017 
3018  /**
3019   * @return a string to describe the validation size option.
3020   */
3021  public String validationSetSizeTipText() {
3022    return "The percentage size of the validation set." +
3023      "(The training will continue until it is observed that" +
3024      " the error on the validation set has been consistently getting" +
3025      " worse, or if the training time is reached).\n" +
3026      "If This is set to zero no validation set will be used and instead" +
3027      " the network will train for the specified number of epochs.";
3028  }
3029 
3030  /**
3031   * @return a string to describe the learning rate option.
3032   */
3033  public String trainingTimeTipText() {
3034    return "The number of epochs to train through." + 
3035      " If the validation set is non-zero then it can terminate the network" +
3036      " early";
3037  }
3038
3039
3040  /**
3041   * @return a string to describe the nominal to binary option.
3042   */
3043  public String nominalToBinaryFilterTipText() {
3044    return "This will preprocess the instances with the filter." +
3045      " This could help improve performance if there are nominal attributes" +
3046      " in the data.";
3047  }
3048
3049  /**
3050   * @return a string to describe the hidden layers in the network.
3051   */
3052  public String hiddenLayersTipText() {
3053    return "This defines the hidden layers of the neural network." +
3054      " This is a list of positive whole numbers. 1 for each hidden layer." +
3055      " Comma seperated. To have no hidden layers put a single 0 here." +
3056      " This will only be used if autobuild is set. There are also wildcard" +
3057      " values 'a' = (attribs + classes) / 2, 'i' = attribs, 'o' = classes" +
3058      " , 't' = attribs + classes.";
3059  }
3060  /**
3061   * @return a string to describe the nominal to binary option.
3062   */
3063  public String normalizeNumericClassTipText() {
3064    return "This will normalize the class if it's numeric." +
3065      " This could help improve performance of the network, It normalizes" +
3066      " the class to be between -1 and 1. Note that this is only internally" +
3067      ", the output will be scaled back to the original range.";
3068  }
3069  /**
3070   * @return a string to describe the nominal to binary option.
3071   */
3072  public String normalizeAttributesTipText() {
3073    return "This will normalize the attributes." +
3074      " This could help improve performance of the network." +
3075      " This is not reliant on the class being numeric. This will also" +
3076      " normalize nominal attributes as well (after they have been run" +
3077      " through the nominal to binary filter if that is in use) so that the" +
3078      " nominal values are between -1 and 1";
3079  }
3080  /**
3081   * @return a string to describe the Reset option.
3082   */
3083  public String resetTipText() {
3084    return "This will allow the network to reset with a lower learning rate." +
3085      " If the network diverges from the answer this will automatically" +
3086      " reset the network with a lower learning rate and begin training" +
3087      " again. This option is only available if the gui is not set. Note" +
3088      " that if the network diverges but isn't allowed to reset it will" +
3089      " fail the training process and return an error message.";
3090  }
3091 
3092  /**
3093   * @return a string to describe the Decay option.
3094   */
3095  public String decayTipText() {
3096    return "This will cause the learning rate to decrease." +
3097      " This will divide the starting learning rate by the epoch number, to" +
3098      " determine what the current learning rate should be. This may help" +
3099      " to stop the network from diverging from the target output, as well" +
3100      " as improve general performance. Note that the decaying learning" +
3101      " rate will not be shown in the gui, only the original learning rate" +
3102      ". If the learning rate is changed in the gui, this is treated as the" +
3103      " starting learning rate.";
3104  }
3105 
3106  //modCS.V
3107  /**
3108   * @return a string to describe the validation set file.
3109   */
3110  public String valFileTipText() {
3111    return "Set the name of a validation file in data source format.";
3112  }
3113  //modCS.S
3114  /**
3115   * @return a string to describe the secondary task training set file.
3116   */ 
3117    public String secFileTipText() {
3118    return "Set the name of a secondary training file in data source format.";
3119  } 
3120 
3121  /**
3122   * Returns the revision string.
3123   *
3124   * @return            the revision
3125   */
3126  public String getRevision() {
3127    return RevisionUtils.extract("$Revision: 6202 $");
3128  }
3129}
3130
Note: See TracBrowser for help on using the repository browser.