source: src/main/java/weka/classifiers/pmml/consumer/NeuralNetwork.java @ 23

Last change on this file since 23 was 4, checked in by gnappo, 14 years ago

Import di weka.

File size: 30.9 KB
Line 
1/*
2 *    This program is free software; you can redistribute it and/or modify
3 *    it under the terms of the GNU General Public License as published by
4 *    the Free Software Foundation; either version 2 of the License, or
5 *    (at your option) any later version.
6 *
7 *    This program is distributed in the hope that it will be useful,
8 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
9 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
10 *    GNU General Public License for more details.
11 *
12 *    You should have received a copy of the GNU General Public License
13 *    along with this program; if not, write to the Free Software
14 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
15 */
16
17/*
18 *    NeuralNetwork.java
19 *    Copyright (C) 2008 University of Waikato, Hamilton, New Zealand
20 *
21 */
22
23package weka.classifiers.pmml.consumer;
24
25import java.io.Serializable;
26import java.util.ArrayList;
27import java.util.HashMap;
28
29import org.w3c.dom.Element;
30import org.w3c.dom.Node;
31import org.w3c.dom.NodeList;
32
33import weka.core.Attribute;
34import weka.core.Instance;
35import weka.core.Instances;
36import weka.core.RevisionUtils;
37import weka.core.Utils;
38import weka.core.pmml.*;
39
40/**
41 * Class implementing import of PMML Neural Network model. Can be used as a Weka
42 * classifier for prediction (buildClassifier() raises an Exception).
43 *
44 * @author Mark Hall (mhall{[at]}pentaho{[dot]}com)
45 * @version $Revision 1.0 $
46 */
47public class NeuralNetwork extends PMMLClassifier {
48 
49  /**
50   * For serialization
51   */
52  private static final long serialVersionUID = -4545904813133921249L;
53
54  /**
55   * Small inner class for a NeuralInput (essentially just
56   * wraps a DerivedField and adds an ID)
57   */
58  static class NeuralInput implements Serializable {
59   
60    /**
61     * For serialization
62     */
63    private static final long serialVersionUID = -1902233762824835563L;
64   
65    /** Field that this input refers to */
66    private DerivedFieldMetaInfo m_field;
67   
68    /** ID string */
69    private String m_ID = null;
70   
71    private String getID() {
72      return m_ID;
73    }
74   
75    protected NeuralInput(Element input, MiningSchema miningSchema) throws Exception {
76      m_ID = input.getAttribute("id");
77     
78      NodeList fL = input.getElementsByTagName("DerivedField");
79      if (fL.getLength() != 1) {
80        throw new Exception("[NeuralInput] expecting just one derived field!");
81      }
82     
83      Element dF = (Element)fL.item(0);
84      Instances allFields = miningSchema.getFieldsAsInstances();
85      ArrayList<Attribute> fieldDefs = new ArrayList<Attribute>();
86      for (int i = 0; i < allFields.numAttributes(); i++) {
87        fieldDefs.add(allFields.attribute(i));
88      }
89      m_field = new DerivedFieldMetaInfo(dF, fieldDefs, miningSchema.getTransformationDictionary());
90    }
91   
92    protected double getValue(double[] incoming) throws Exception {
93      return m_field.getDerivedValue(incoming);
94    }
95   
96    public String toString() {
97      StringBuffer temp = new StringBuffer();
98     
99      temp.append("Nueral input (" + getID() + ")\n");
100      temp.append(m_field);
101     
102      return temp.toString();
103    }
104  }
105 
106  /**
107   * Inner class representing a layer in the network.
108   */
109  class NeuralLayer implements Serializable {
110   
111    /**
112     * For serialization
113     */
114    private static final long serialVersionUID = -8386042001675763922L;
115
116    /** The number of neurons in this layer */
117    private int m_numNeurons = 0;
118   
119    /** Activation function (if defined, overrides one in NeuralNetwork) */
120    private ActivationFunction m_layerActivationFunction = null;
121   
122    /** Threshold (if defined overrides one in NeuralNetwork) */
123    private double m_layerThreshold = Double.NaN; 
124   
125    /** Width (if defined overrides one in NeuralNetwork) */
126    private double m_layerWidth = Double.NaN;
127   
128    /** Altitude (if defined overrides one in NeuralNetwork) */
129    private double m_layerAltitude = Double.NaN;
130   
131    /** Normalization (if defined overrides one in NeuralNetwork) */
132    private Normalization m_layerNormalization = null;
133   
134    /** The neurons at this hidden layer */
135    private Neuron[] m_layerNeurons = null;
136   
137    /** Stores the output of this layer (for given inputs) */
138    private HashMap<String, Double> m_layerOutput = new HashMap<String, Double>();
139   
140    protected NeuralLayer(Element layerE) {
141     
142      String activationFunction = layerE.getAttribute("activationFunction");
143      if (activationFunction != null && activationFunction.length() > 0) {
144        for (ActivationFunction a : ActivationFunction.values()) {
145          if (a.toString().equals(activationFunction)) {
146            m_layerActivationFunction = a;
147            break;
148          }
149        }
150      } else {
151        // use the network-level activation function
152        m_layerActivationFunction = m_activationFunction;
153      }
154     
155      String threshold = layerE.getAttribute("threshold");
156      if (threshold != null && threshold.length() > 0) {
157        m_layerThreshold = Double.parseDouble(threshold);
158      } else {
159        // use network-level threshold
160        m_layerThreshold = m_threshold;
161      }
162     
163      String width = layerE.getAttribute("width");
164      if (width != null && width.length() > 0) {
165        m_layerWidth = Double.parseDouble(width);
166      } else {
167        // use network-level width
168        m_layerWidth = m_width;
169      }
170     
171      String altitude = layerE.getAttribute("altitude");
172      if (altitude != null && altitude.length() > 0) {
173        m_layerAltitude = Double.parseDouble(altitude);
174      } else {
175        // use network-level altitude
176        m_layerAltitude = m_altitude;
177      }
178     
179      String normMethod = layerE.getAttribute("normalizationMethod");
180      if (normMethod != null && normMethod.length() > 0) {
181        for (Normalization n : Normalization.values()) {
182          if (n.toString().equals(normMethod)) {
183            m_layerNormalization = n;
184            break;
185          }
186        }
187      } else {
188        // use network-level normalization method
189        m_layerNormalization = m_normalizationMethod;
190      }
191     
192      NodeList neuronL = layerE.getElementsByTagName("Neuron");
193      m_numNeurons = neuronL.getLength();
194      m_layerNeurons = new Neuron[m_numNeurons];
195      for (int i = 0; i < neuronL.getLength(); i++) {
196        Node neuronN = neuronL.item(i);
197        if (neuronN.getNodeType() == Node.ELEMENT_NODE) {
198          m_layerNeurons[i] = new Neuron((Element)neuronN, this);
199        }
200      }
201    }
202   
203    protected ActivationFunction getActivationFunction() {
204      return m_layerActivationFunction;
205    }
206   
207    protected double getThreshold() {
208      return m_layerThreshold;
209    }
210   
211    protected double getWidth() {
212      return m_layerWidth;
213    }
214   
215    protected double getAltitude() {
216      return m_layerAltitude;
217    }
218   
219    protected Normalization getNormalization() {
220      return m_layerNormalization;
221    }
222   
223    /**
224     * Compute the output values for this layer.
225     *
226     * @param incoming the incoming values
227     * @return the output values for this layer
228     * @throws Exception if there is a problem computing the outputs
229     */
230    protected HashMap<String, Double> computeOutput(HashMap<String, Double> incoming) 
231      throws Exception {
232     
233      m_layerOutput.clear();
234     
235      double normSum = 0;
236      for (int i = 0; i < m_layerNeurons.length; i++) {
237        double neuronOut = m_layerNeurons[i].getValue(incoming);
238        String neuronID = m_layerNeurons[i].getID();
239
240        if (m_layerNormalization == Normalization.SOFTMAX) {
241          normSum += Math.exp(neuronOut);
242        } else if (m_layerNormalization == Normalization.SIMPLEMAX) {
243          normSum += neuronOut;
244        }
245        //System.err.println("Inserting ID " + neuronID + " " + neuronOut);
246        m_layerOutput.put(neuronID, neuronOut);
247      }
248     
249      // apply the normalization (if necessary)
250      if (m_layerNormalization != Normalization.NONE) {
251        for (int i = 0; i < m_layerNeurons.length; i++) {
252          double val = m_layerOutput.get(m_layerNeurons[i].getID());
253//          System.err.println("Normalizing ID " + m_layerNeurons[i].getID() + " " + val);
254          if (m_layerNormalization == Normalization.SOFTMAX) {
255            val = Math.exp(val) / normSum;
256          } else {
257            val = (val / normSum);
258          }
259          m_layerOutput.put(m_layerNeurons[i].getID(), val);
260        }
261      }
262      return m_layerOutput;
263    }
264   
265    public String toString() {
266      StringBuffer temp = new StringBuffer();
267     
268      temp.append("activation: " + getActivationFunction() + "\n");
269      if (!Double.isNaN(getThreshold())) {
270        temp.append("threshold: " + getThreshold() + "\n");
271      }
272      if (!Double.isNaN(getWidth())) {
273        temp.append("width: " + getWidth() + "\n");
274      }
275      if (!Double.isNaN(getAltitude())) {
276        temp.append("altitude: " + getAltitude() + "\n");
277      }
278      temp.append("normalization: " + m_layerNormalization + "\n");
279      for (int i = 0; i < m_numNeurons; i++) {
280        temp.append(m_layerNeurons[i] + "\n");
281      }
282
283      return temp.toString();
284    }
285  }
286 
287  /**
288   * Inner class encapsulating a Neuron
289   */
290  static class Neuron implements Serializable {
291   
292    /**
293     * For serialization
294     */
295    private static final long serialVersionUID = -3817434025682603443L;
296
297    /** ID string */
298    private String m_ID = null;
299   
300    /** The layer we belong to (for accessing activation function, threshold etc.) */
301    private NeuralLayer m_layer;
302   
303    /** The bias */
304    private double m_bias = 0.0;
305   
306    /** The width (if defined overrides the one in NeuralLayer or NeuralNetwork) */
307    private double m_neuronWidth = Double.NaN;
308   
309    /** The altitude (if defined overrides the one in NeuralLayer or NeuralNetwork) */
310    private double m_neuronAltitude = Double.NaN;
311   
312    /** The IDs of the neurons/neural inputs that we are connected to */
313    private String[] m_connectionIDs = null;
314   
315    /** The weights corresponding to the connections */
316    private double[] m_weights = null;
317   
318    protected Neuron(Element neuronE, NeuralLayer layer) {
319      m_layer = layer;
320     
321      m_ID = neuronE.getAttribute("id");
322     
323      String bias = neuronE.getAttribute("bias");
324      if (bias != null && bias.length() > 0) {
325        m_bias = Double.parseDouble(bias);
326      }
327     
328      String width = neuronE.getAttribute("width");
329      if (width != null && width.length() > 0) {
330        m_neuronWidth = Double.parseDouble(width);
331      }
332     
333      String altitude = neuronE.getAttribute("altitude");
334      if (altitude != null && altitude.length() > 0) {
335        m_neuronAltitude = Double.parseDouble(altitude);
336      }
337     
338      // get the connection details
339      NodeList conL = neuronE.getElementsByTagName("Con");
340      m_connectionIDs = new String[conL.getLength()];
341      m_weights = new double[conL.getLength()];
342      for (int i = 0; i < conL.getLength(); i++) {
343        Node conN = conL.item(i);
344        if (conN.getNodeType() == Node.ELEMENT_NODE) {
345          Element conE = (Element)conN;
346          m_connectionIDs[i] = conE.getAttribute("from");
347          String weight = conE.getAttribute("weight");
348          m_weights[i] = Double.parseDouble(weight);
349        }
350      }
351    }
352   
353    protected String getID() {
354      return m_ID;
355    }   
356   
357    /**
358     * Compute the output of this Neuron.
359     *
360     * @param incoming a Map of input values. The keys are the IDs
361     * of incoming connections (either neural inputs or neurons) and
362     * the values are the output values of the neural input/neuron in
363     * question.
364     *
365     * @return the output of this neuron
366     * @throws Exception if any of our incoming connection IDs cannot be
367     * located in the Map
368     */
369    protected double getValue(HashMap<String, Double> incoming) throws Exception {
370     
371      double z = 0;
372      double result = Double.NaN;
373     
374      double width = (Double.isNaN(m_neuronWidth))
375        ? m_layer.getWidth()
376        : m_neuronWidth;
377
378      z = m_bias;
379      for (int i = 0; i < m_connectionIDs.length; i++) {
380        Double inVal = incoming.get(m_connectionIDs[i]);
381        if (inVal == null) {
382          throw new Exception("[Neuron] unable to find connection " 
383              + m_connectionIDs[i] + " in input Map!");
384        }
385
386        if (m_layer.getActivationFunction() != ActivationFunction.RADIALBASIS) {
387          // multiply with weight
388          double inV = inVal.doubleValue() * m_weights[i];
389          z += inV;
390        } else {
391          // Euclidean distance to the center (stored in m_weights)
392          double inV = Math.pow((inVal.doubleValue() - m_weights[i]), 2.0);
393          z += inV;
394        }
395      }
396     
397      // apply the width if necessary
398      if (m_layer.getActivationFunction() == ActivationFunction.RADIALBASIS) {
399        z /= (2.0 * (width * width));
400      }
401
402      double threshold = m_layer.getThreshold();
403      double altitude = (Double.isNaN(m_neuronAltitude))
404        ? m_layer.getAltitude()
405        : m_neuronAltitude;
406       
407      double fanIn = m_connectionIDs.length;       
408      result = m_layer.getActivationFunction().eval(z, threshold, altitude, fanIn);
409     
410      return result;
411    }
412   
413    public String toString() {
414      StringBuffer temp = new StringBuffer();
415      temp.append("Nueron (" + m_ID + ") [bias:" + m_bias);
416      if (!Double.isNaN(m_neuronWidth)) {
417        temp.append(" width:" + m_neuronWidth);
418      }
419      if (!Double.isNaN(m_neuronAltitude)) {
420        temp.append(" altitude:" + m_neuronAltitude);
421      }
422      temp.append("]\n");
423      temp.append("  con. (ID:weight): ");
424      for (int i = 0; i < m_connectionIDs.length; i++) {
425        temp.append(m_connectionIDs[i] + ":" + Utils.doubleToString(m_weights[i], 2));
426        if ((i + 1) % 10 == 0 || i == m_connectionIDs.length - 1) {
427          temp.append("\n                    ");
428        } else {
429          temp.append(", ");
430        }
431      }
432      return temp.toString();
433    }
434  }
435 
436  static class NeuralOutputs implements Serializable {
437   
438    /**
439     * For serialization
440     */
441    private static final long serialVersionUID = -233611113950482952L;
442
443    /** The neurons we are mapping */
444    private String[] m_outputNeurons = null;
445   
446    /**
447     *  In the case of a nominal class, the index of the value
448     * being predicted by each output neuron
449     */
450    private int[] m_categoricalIndexes = null;
451   
452    /** The class attribute we are mapping to */
453    private Attribute m_classAttribute = null;
454   
455    /** Used when the class is numeric */
456    private NormContinuous m_regressionMapping = null;
457       
458    protected NeuralOutputs(Element outputs, MiningSchema miningSchema) throws Exception {
459      m_classAttribute = miningSchema.getMiningSchemaAsInstances().classAttribute();
460     
461      int vals = (m_classAttribute.isNumeric())
462        ? 1
463        : m_classAttribute.numValues();
464     
465      m_outputNeurons = new String[vals];
466      m_categoricalIndexes = new int[vals];
467     
468      NodeList outputL = outputs.getElementsByTagName("NeuralOutput");
469      if (outputL.getLength() != m_outputNeurons.length) {
470        throw new Exception("[NeuralOutputs] the number of neural outputs does not match "
471            + "the number expected!");
472      }
473     
474      for (int i = 0; i < outputL.getLength(); i++) {
475        Node outputN = outputL.item(i);
476        if (outputN.getNodeType() == Node.ELEMENT_NODE) {
477          Element outputE = (Element)outputN;
478          // get the ID for this output neuron
479          m_outputNeurons[i] = outputE.getAttribute("outputNeuron");
480         
481          if (m_classAttribute.isNumeric()) {
482            // get the single norm continuous
483            NodeList contL = outputE.getElementsByTagName("NormContinuous");
484            if (contL.getLength() != 1) {
485              throw new Exception("[NeuralOutputs] Should be exactly one norm continuous element "
486                  + "for numeric class!");
487            }
488            Node normContNode = contL.item(0);
489            String attName = ((Element)normContNode).getAttribute("field");
490            Attribute dummyTargetDef = new Attribute(attName);
491            ArrayList<Attribute> dummyFieldDefs = new ArrayList<Attribute>();
492            dummyFieldDefs.add(dummyTargetDef);
493           
494            m_regressionMapping = new NormContinuous((Element)normContNode, 
495                FieldMetaInfo.Optype.CONTINUOUS, dummyFieldDefs);
496            break;
497          } else {
498            // we just need to grab the categorical value (out of the NormDiscrete element)
499            // that this output neuron is associated with
500            NodeList discL = outputE.getElementsByTagName("NormDiscrete");
501            if (discL.getLength() != 1) {
502              throw new Exception("[NeuralOutputs] Should be only one norm discrete element "
503                  + "per derived field/neural output for a nominal class!");
504            }
505            Node normDiscNode = discL.item(0);
506            String attValue = ((Element)normDiscNode).getAttribute("value");
507            int index = m_classAttribute.indexOfValue(attValue);
508            if (index < 0) {
509              throw new Exception("[NeuralOutputs] Can't find specified target value "
510                  + attValue + " in class attribute " + m_classAttribute.name());
511            }
512            m_categoricalIndexes[i] = index;
513          }
514        }
515      }
516    }
517   
518    /**
519     * Compute the output. Either a probability distribution or a single
520     * value (regression).
521     *
522     * @param incoming the values from the last hidden layer
523     * @param preds the array to fill with predicted values
524     * @throws Exception if there is a problem computing the output
525     */
526    protected void getOuput(HashMap<String, Double> incoming, double[] preds) throws Exception {
527     
528      if (preds.length != m_outputNeurons.length) {
529        throw new Exception("[NeuralOutputs] Incorrect number of predictions requested: "
530            + preds.length + "requested, " + m_outputNeurons.length + " expected");
531      }
532      for (int i = 0; i < m_outputNeurons.length; i++) {
533        Double neuronOut = incoming.get(m_outputNeurons[i]);
534        if (neuronOut == null) {
535          throw new Exception("[NeuralOutputs] Unable to find output neuron "
536              + m_outputNeurons[i] + " in the incoming HashMap!!");
537        }
538        if (m_classAttribute.isNumeric()) {
539          // will be only one output neuron anyway
540          preds[0] = neuronOut.doubleValue();
541         
542          preds[0] = m_regressionMapping.getResultInverse(preds);
543        } else {
544
545          // clip at zero
546          // preds[m_categoricalIndexes[i]] = (neuronOut < 0) ? 0.0 : neuronOut;
547          preds[m_categoricalIndexes[i]] = neuronOut;
548        }
549      }
550     
551      if (m_classAttribute.isNominal()) {
552        // check for negative values and adjust
553        double min = preds[Utils.minIndex(preds)];
554        if (min < 0) {
555          for (int i = 0; i < preds.length; i++) {
556            preds[i] -= min;
557          }
558        }
559        // do a simplemax normalization
560        Utils.normalize(preds);
561      }
562    }
563   
564    public String toString() {
565      StringBuffer temp = new StringBuffer();
566     
567      for (int i = 0; i < m_outputNeurons.length; i++) {
568        temp.append("Output neuron (" + m_outputNeurons[i] + ")\n");
569        temp.append("mapping:\n");
570        if (m_classAttribute.isNumeric()) {
571          temp.append(m_regressionMapping +"\n");
572        } else {
573          temp.append(m_classAttribute.name() + " = " 
574              + m_classAttribute.value(m_categoricalIndexes[i]) + "\n");
575        }
576      }
577     
578      return temp.toString();
579    }
580  }
581 
582  /**
583   * Enumerated type for the mining function
584   */
585  enum MiningFunction {
586    CLASSIFICATION,
587    REGRESSION;
588  }
589 
590  /** The mining function */
591  protected MiningFunction m_functionType = MiningFunction.CLASSIFICATION;
592 
593  /**
594   * Enumerated type for the activation function.
595   */
596  enum ActivationFunction {
597    THRESHOLD("threshold") {
598      double eval(double z, double threshold, double altitude, double fanIn) {
599        if (z > threshold) {
600          return 1.0;
601        }
602        return 0.0;
603      }
604    },
605    LOGISTIC("logistic") {
606      double eval(double z, double threshold, double altitude, double fanIn) {
607        return 1.0 / (1.0 + Math.exp(-z));
608      }
609    },
610    TANH("tanh") {
611      double eval(double z, double threshold, double altitude, double fanIn) {
612        double a = Math.exp( z );
613        double b = Math.exp( -z );
614        return ((a-b)/(a+b));
615        //return (1.0 - Math.exp(-2.0 * z)) / (1.0 + Math.exp(-2.0 * z));
616      }
617    },
618    IDENTITY("identity") {
619      double eval(double z, double threshold, double altitude, double fanIn) {
620        return z;
621      }
622    },
623    EXPONENTIAL("exponential") {
624      double eval(double z, double threshold, double altitude, double fanIn) {
625        return Math.exp(z);
626      }
627    },
628    RECIPROCAL("reciprocal") {
629      double eval(double z, double threshold, double altitude, double fanIn) {
630        return 1.0 / z;
631      }
632    },
633    SQUARE("square") {
634      double eval(double z, double threshold, double altitude, double fanIn) {
635        return  z * z;
636      }
637    },
638    GAUSS("gauss") {
639      double eval(double z, double threshold, double altitude, double fanIn) {
640        return Math.exp(-(z * z));
641      }
642    },
643    SINE("sine") {
644      double eval(double z, double threshold, double altitude, double fanIn) {
645        return Math.sin(z);
646      }
647    },
648    COSINE("cosine") {
649      double eval(double z, double threshold, double altitude, double fanIn) {
650        return Math.cos(z);
651      }
652    },
653    ELLICOT("ellicot") {
654      double eval(double z, double threshold, double altitude, double fanIn) {
655        return z / (1.0 + Math.abs(z));
656      }
657    },
658    ARCTAN("arctan") {
659      double eval(double z, double threshold, double altitude, double fanIn) {
660        return 2.0 * Math.atan(z) / Math.PI;
661      }
662    },
663    RADIALBASIS("radialBasis") {
664      double eval(double z, double threshold, double altitude, double fanIn) {
665        return Math.exp(fanIn * Math.log(altitude) - z);
666      }
667    };
668   
669    abstract double eval(double z, double threshold, double altitude, double fanIn);
670   
671    private final String m_stringVal;
672   
673    ActivationFunction(String name) {
674      m_stringVal = name;
675    }
676   
677    public String toString() {
678      return m_stringVal;
679    }
680  }
681 
682  /** The activation function to use */
683  protected ActivationFunction m_activationFunction = ActivationFunction.ARCTAN;
684 
685  /**
686   * Enumerated type for the normalization method
687   */
688  enum Normalization {
689    NONE ("none"),
690    SIMPLEMAX ("simplemax"),
691    SOFTMAX ("softmax");
692   
693    private final String m_stringVal;
694   
695    Normalization(String name) {
696      m_stringVal = name;
697    }
698   
699    public String toString() {
700      return m_stringVal;
701    }
702  }
703   
704  /** The normalization method */
705  protected Normalization m_normalizationMethod = Normalization.NONE;
706 
707  /** Threshold activation */
708  protected double m_threshold = 0.0; // default = 0
709 
710  /** Width for radial basis */
711  protected double m_width = Double.NaN; // no default
712 
713  /** Altitude for radial basis */
714  protected double m_altitude = 1.0; // default = 1
715 
716  /** The number of inputs to the network */
717  protected int m_numberOfInputs = 0;
718 
719  /** Number of hidden layers in the network */
720  protected int m_numberOfLayers = 0;
721 
722  /** The inputs to the network */
723  protected NeuralInput[] m_inputs = null;
724 
725  /** A map for storing network input values (computed from an incoming instance) */
726  protected HashMap<String, Double> m_inputMap = new HashMap<String, Double>();
727   
728  /** The hidden layers in the network */
729  protected NeuralLayer[] m_layers = null;
730 
731  /** The outputs of the network */
732  protected NeuralOutputs m_outputs = null;
733 
734  public NeuralNetwork(Element model, Instances dataDictionary,
735                       MiningSchema miningSchema) throws Exception {
736   
737    super(dataDictionary, miningSchema);
738   
739    String fn = model.getAttribute("functionName");
740    if (fn.equals("regression")) {
741      m_functionType = MiningFunction.REGRESSION;
742    }
743   
744    String act = model.getAttribute("activationFunction");
745    if (act == null || act.length() == 0) {
746      throw new Exception("[NeuralNetwork] no activation functon defined");
747    }
748   
749    // get the activation function
750    for (ActivationFunction a : ActivationFunction.values()) {
751      if (a.toString().equals(act)) {
752        m_activationFunction = a;
753        break;
754      }
755    }
756   
757    // get the normalization method (if specified)
758    String norm = model.getAttribute("normalizationMethod");
759    if (norm != null && norm.length() > 0) {
760      for (Normalization n : Normalization.values()) {
761        if (n.toString().equals(norm)) {
762          m_normalizationMethod = n;
763          break;
764        }
765      }
766    }
767   
768    String thresh = model.getAttribute("threshold");
769    if (thresh != null && thresh.length() > 0) {
770      m_threshold = Double.parseDouble(thresh);
771    }
772    String width = model.getAttribute("width");
773    if (width != null && width.length() > 0) {
774      m_width = Double.parseDouble(width);
775    }
776    String alt = model.getAttribute("altitude");
777    if (alt != null && alt.length() > 0) {
778      m_altitude = Double.parseDouble(alt);
779    }
780   
781    // get all the inputs
782    NodeList inputL = model.getElementsByTagName("NeuralInput");
783    m_numberOfInputs = inputL.getLength();
784    m_inputs = new NeuralInput[m_numberOfInputs];
785    for (int i = 0; i < m_numberOfInputs; i++) {
786      Node inputN = inputL.item(i);
787      if (inputN.getNodeType() == Node.ELEMENT_NODE) {
788        NeuralInput nI = new NeuralInput((Element)inputN, m_miningSchema);
789        m_inputs[i] = nI;
790      }
791    }
792   
793    // get the layers
794    NodeList layerL = model.getElementsByTagName("NeuralLayer");
795    m_numberOfLayers = layerL.getLength();
796    m_layers = new NeuralLayer[m_numberOfLayers];
797    for (int i = 0; i < m_numberOfLayers; i++) {
798      Node layerN = layerL.item(i);
799      if (layerN.getNodeType() == Node.ELEMENT_NODE) {
800        NeuralLayer nL = new NeuralLayer((Element)layerN);
801        m_layers[i] = nL;
802      }
803    }
804   
805    // get the outputs
806    NodeList outputL = model.getElementsByTagName("NeuralOutputs");
807    if (outputL.getLength() != 1) {
808      throw new Exception("[NeuralNetwork] Should be just one NeuralOutputs element defined!");
809    }
810   
811    m_outputs = new NeuralOutputs((Element)outputL.item(0), m_miningSchema);
812  }
813
814  /* (non-Javadoc)
815   * @see weka.core.RevisionHandler#getRevision()
816   */
817  public String getRevision() {
818    return RevisionUtils.extract("$Revision: 5987 $");
819  }
820 
821  /**                                                                                                             
822   * Classifies the given test instance. The instance has to belong to a                                         
823   * dataset when it's being classified.                                                         
824   *                                                                                                             
825   * @param inst the instance to be classified                                                               
826   * @return the predicted most likely class for the instance or                                                 
827   * Utils.missingValue() if no prediction is made                                                             
828   * @exception Exception if an error occurred during the prediction                                             
829   */
830  public double[] distributionForInstance(Instance inst) throws Exception {
831    if (!m_initialized) {
832      mapToMiningSchema(inst.dataset());
833    }
834    double[] preds = null;
835   
836    if (m_miningSchema.getFieldsAsInstances().classAttribute().isNumeric()) {
837      preds = new double[1];
838    } else {
839      preds = new double[m_miningSchema.getFieldsAsInstances().classAttribute().numValues()];
840    }
841   
842    double[] incoming = m_fieldsMap.instanceToSchema(inst, m_miningSchema);
843   
844    boolean hasMissing = false;
845    for (int i = 0; i < incoming.length; i++) {
846      if (i != m_miningSchema.getFieldsAsInstances().classIndex() && 
847          Double.isNaN(incoming[i])) {
848        hasMissing = true;
849        //System.err.println("Missing value for att : " + m_miningSchema.getFieldsAsInstances().attribute(i).name());
850        break;
851      }
852    }
853   
854    if (hasMissing) {
855      if (!m_miningSchema.hasTargetMetaData()) {
856        String message = "[NeuralNetwork] WARNING: Instance to predict has missing value(s) but "
857          + "there is no missing value handling meta data and no "
858          + "prior probabilities/default value to fall back to. No "
859          + "prediction will be made (" 
860          + ((m_miningSchema.getFieldsAsInstances().classAttribute().isNominal()
861              || m_miningSchema.getFieldsAsInstances().classAttribute().isString())
862              ? "zero probabilities output)."
863              : "NaN output).");
864        if (m_log == null) {
865          System.err.println(message);
866        } else {
867          m_log.logMessage(message);
868        }
869       
870        if (m_miningSchema.getFieldsAsInstances().classAttribute().isNumeric()) {
871          preds[0] = Utils.missingValue();
872        }
873        return preds;
874      } else {
875        // use prior probablilities/default value
876        TargetMetaInfo targetData = m_miningSchema.getTargetMetaData();
877        if (m_miningSchema.getFieldsAsInstances().classAttribute().isNumeric()) {
878          preds[0] = targetData.getDefaultValue();
879        } else {
880          Instances miningSchemaI = m_miningSchema.getFieldsAsInstances();
881          for (int i = 0; i < miningSchemaI.classAttribute().numValues(); i++) {
882            preds[i] = targetData.getPriorProbability(miningSchemaI.classAttribute().value(i));
883          }
884        }
885        return preds;
886      }
887    } else {
888     
889      // construct the input to the network for this instance
890      m_inputMap.clear();
891      for (int i = 0; i < m_inputs.length; i++) {
892        double networkInVal = m_inputs[i].getValue(incoming);
893        String ID = m_inputs[i].getID();
894        m_inputMap.put(ID, networkInVal);
895      }
896     
897      // now compute the output of each layer
898      HashMap<String, Double> layerOut = m_layers[0].computeOutput(m_inputMap);
899      for (int i = 1; i < m_layers.length; i++) {
900        layerOut = m_layers[i].computeOutput(layerOut);
901      }
902     
903      // now do the output
904      m_outputs.getOuput(layerOut, preds);
905    }
906   
907    return preds;
908  }
909
910  public String toString() {
911    StringBuffer temp = new StringBuffer();
912   
913    temp.append("PMML version " + getPMMLVersion());
914    if (!getCreatorApplication().equals("?")) {
915      temp.append("\nApplication: " + getCreatorApplication());
916    }
917    temp.append("\nPMML Model: Neural network");
918    temp.append("\n\n");
919    temp.append(m_miningSchema);
920   
921    temp.append("Inputs:\n");
922    for (int i = 0; i < m_inputs.length; i++) {
923      temp.append(m_inputs[i] + "\n");
924    }
925
926    for (int i = 0; i < m_layers.length; i++) {
927      temp.append("Layer: " + (i+1) + "\n");
928      temp.append(m_layers[i] + "\n");
929    }
930   
931    temp.append("Outputs:\n");
932    temp.append(m_outputs);
933   
934    return temp.toString();
935  }
936}
Note: See TracBrowser for help on using the repository browser.