1 | /* |
---|
2 | * This program is free software; you can redistribute it and/or modify |
---|
3 | * it under the terms of the GNU General Public License as published by |
---|
4 | * the Free Software Foundation; either version 2 of the License, or |
---|
5 | * (at your option) any later version. |
---|
6 | * |
---|
7 | * This program is distributed in the hope that it will be useful, |
---|
8 | * but WITHOUT ANY WARRANTY; without even the implied warranty of |
---|
9 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the |
---|
10 | * GNU General Public License for more details. |
---|
11 | * |
---|
12 | * You should have received a copy of the GNU General Public License |
---|
13 | * along with this program; if not, write to the Free Software |
---|
14 | * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA. |
---|
15 | */ |
---|
16 | |
---|
17 | /* |
---|
18 | * NeuralNetwork.java |
---|
19 | * Copyright (C) 2008 University of Waikato, Hamilton, New Zealand |
---|
20 | * |
---|
21 | */ |
---|
22 | |
---|
23 | package weka.classifiers.pmml.consumer; |
---|
24 | |
---|
25 | import java.io.Serializable; |
---|
26 | import java.util.ArrayList; |
---|
27 | import java.util.HashMap; |
---|
28 | |
---|
29 | import org.w3c.dom.Element; |
---|
30 | import org.w3c.dom.Node; |
---|
31 | import org.w3c.dom.NodeList; |
---|
32 | |
---|
33 | import weka.core.Attribute; |
---|
34 | import weka.core.Instance; |
---|
35 | import weka.core.Instances; |
---|
36 | import weka.core.RevisionUtils; |
---|
37 | import weka.core.Utils; |
---|
38 | import weka.core.pmml.*; |
---|
39 | |
---|
40 | /** |
---|
41 | * Class implementing import of PMML Neural Network model. Can be used as a Weka |
---|
42 | * classifier for prediction (buildClassifier() raises an Exception). |
---|
43 | * |
---|
44 | * @author Mark Hall (mhall{[at]}pentaho{[dot]}com) |
---|
45 | * @version $Revision 1.0 $ |
---|
46 | */ |
---|
47 | public class NeuralNetwork extends PMMLClassifier { |
---|
48 | |
---|
49 | /** |
---|
50 | * For serialization |
---|
51 | */ |
---|
52 | private static final long serialVersionUID = -4545904813133921249L; |
---|
53 | |
---|
54 | /** |
---|
55 | * Small inner class for a NeuralInput (essentially just |
---|
56 | * wraps a DerivedField and adds an ID) |
---|
57 | */ |
---|
58 | static class NeuralInput implements Serializable { |
---|
59 | |
---|
60 | /** |
---|
61 | * For serialization |
---|
62 | */ |
---|
63 | private static final long serialVersionUID = -1902233762824835563L; |
---|
64 | |
---|
65 | /** Field that this input refers to */ |
---|
66 | private DerivedFieldMetaInfo m_field; |
---|
67 | |
---|
68 | /** ID string */ |
---|
69 | private String m_ID = null; |
---|
70 | |
---|
71 | private String getID() { |
---|
72 | return m_ID; |
---|
73 | } |
---|
74 | |
---|
75 | protected NeuralInput(Element input, MiningSchema miningSchema) throws Exception { |
---|
76 | m_ID = input.getAttribute("id"); |
---|
77 | |
---|
78 | NodeList fL = input.getElementsByTagName("DerivedField"); |
---|
79 | if (fL.getLength() != 1) { |
---|
80 | throw new Exception("[NeuralInput] expecting just one derived field!"); |
---|
81 | } |
---|
82 | |
---|
83 | Element dF = (Element)fL.item(0); |
---|
84 | Instances allFields = miningSchema.getFieldsAsInstances(); |
---|
85 | ArrayList<Attribute> fieldDefs = new ArrayList<Attribute>(); |
---|
86 | for (int i = 0; i < allFields.numAttributes(); i++) { |
---|
87 | fieldDefs.add(allFields.attribute(i)); |
---|
88 | } |
---|
89 | m_field = new DerivedFieldMetaInfo(dF, fieldDefs, miningSchema.getTransformationDictionary()); |
---|
90 | } |
---|
91 | |
---|
92 | protected double getValue(double[] incoming) throws Exception { |
---|
93 | return m_field.getDerivedValue(incoming); |
---|
94 | } |
---|
95 | |
---|
96 | public String toString() { |
---|
97 | StringBuffer temp = new StringBuffer(); |
---|
98 | |
---|
99 | temp.append("Nueral input (" + getID() + ")\n"); |
---|
100 | temp.append(m_field); |
---|
101 | |
---|
102 | return temp.toString(); |
---|
103 | } |
---|
104 | } |
---|
105 | |
---|
106 | /** |
---|
107 | * Inner class representing a layer in the network. |
---|
108 | */ |
---|
109 | class NeuralLayer implements Serializable { |
---|
110 | |
---|
111 | /** |
---|
112 | * For serialization |
---|
113 | */ |
---|
114 | private static final long serialVersionUID = -8386042001675763922L; |
---|
115 | |
---|
116 | /** The number of neurons in this layer */ |
---|
117 | private int m_numNeurons = 0; |
---|
118 | |
---|
119 | /** Activation function (if defined, overrides one in NeuralNetwork) */ |
---|
120 | private ActivationFunction m_layerActivationFunction = null; |
---|
121 | |
---|
122 | /** Threshold (if defined overrides one in NeuralNetwork) */ |
---|
123 | private double m_layerThreshold = Double.NaN; |
---|
124 | |
---|
125 | /** Width (if defined overrides one in NeuralNetwork) */ |
---|
126 | private double m_layerWidth = Double.NaN; |
---|
127 | |
---|
128 | /** Altitude (if defined overrides one in NeuralNetwork) */ |
---|
129 | private double m_layerAltitude = Double.NaN; |
---|
130 | |
---|
131 | /** Normalization (if defined overrides one in NeuralNetwork) */ |
---|
132 | private Normalization m_layerNormalization = null; |
---|
133 | |
---|
134 | /** The neurons at this hidden layer */ |
---|
135 | private Neuron[] m_layerNeurons = null; |
---|
136 | |
---|
137 | /** Stores the output of this layer (for given inputs) */ |
---|
138 | private HashMap<String, Double> m_layerOutput = new HashMap<String, Double>(); |
---|
139 | |
---|
140 | protected NeuralLayer(Element layerE) { |
---|
141 | |
---|
142 | String activationFunction = layerE.getAttribute("activationFunction"); |
---|
143 | if (activationFunction != null && activationFunction.length() > 0) { |
---|
144 | for (ActivationFunction a : ActivationFunction.values()) { |
---|
145 | if (a.toString().equals(activationFunction)) { |
---|
146 | m_layerActivationFunction = a; |
---|
147 | break; |
---|
148 | } |
---|
149 | } |
---|
150 | } else { |
---|
151 | // use the network-level activation function |
---|
152 | m_layerActivationFunction = m_activationFunction; |
---|
153 | } |
---|
154 | |
---|
155 | String threshold = layerE.getAttribute("threshold"); |
---|
156 | if (threshold != null && threshold.length() > 0) { |
---|
157 | m_layerThreshold = Double.parseDouble(threshold); |
---|
158 | } else { |
---|
159 | // use network-level threshold |
---|
160 | m_layerThreshold = m_threshold; |
---|
161 | } |
---|
162 | |
---|
163 | String width = layerE.getAttribute("width"); |
---|
164 | if (width != null && width.length() > 0) { |
---|
165 | m_layerWidth = Double.parseDouble(width); |
---|
166 | } else { |
---|
167 | // use network-level width |
---|
168 | m_layerWidth = m_width; |
---|
169 | } |
---|
170 | |
---|
171 | String altitude = layerE.getAttribute("altitude"); |
---|
172 | if (altitude != null && altitude.length() > 0) { |
---|
173 | m_layerAltitude = Double.parseDouble(altitude); |
---|
174 | } else { |
---|
175 | // use network-level altitude |
---|
176 | m_layerAltitude = m_altitude; |
---|
177 | } |
---|
178 | |
---|
179 | String normMethod = layerE.getAttribute("normalizationMethod"); |
---|
180 | if (normMethod != null && normMethod.length() > 0) { |
---|
181 | for (Normalization n : Normalization.values()) { |
---|
182 | if (n.toString().equals(normMethod)) { |
---|
183 | m_layerNormalization = n; |
---|
184 | break; |
---|
185 | } |
---|
186 | } |
---|
187 | } else { |
---|
188 | // use network-level normalization method |
---|
189 | m_layerNormalization = m_normalizationMethod; |
---|
190 | } |
---|
191 | |
---|
192 | NodeList neuronL = layerE.getElementsByTagName("Neuron"); |
---|
193 | m_numNeurons = neuronL.getLength(); |
---|
194 | m_layerNeurons = new Neuron[m_numNeurons]; |
---|
195 | for (int i = 0; i < neuronL.getLength(); i++) { |
---|
196 | Node neuronN = neuronL.item(i); |
---|
197 | if (neuronN.getNodeType() == Node.ELEMENT_NODE) { |
---|
198 | m_layerNeurons[i] = new Neuron((Element)neuronN, this); |
---|
199 | } |
---|
200 | } |
---|
201 | } |
---|
202 | |
---|
203 | protected ActivationFunction getActivationFunction() { |
---|
204 | return m_layerActivationFunction; |
---|
205 | } |
---|
206 | |
---|
207 | protected double getThreshold() { |
---|
208 | return m_layerThreshold; |
---|
209 | } |
---|
210 | |
---|
211 | protected double getWidth() { |
---|
212 | return m_layerWidth; |
---|
213 | } |
---|
214 | |
---|
215 | protected double getAltitude() { |
---|
216 | return m_layerAltitude; |
---|
217 | } |
---|
218 | |
---|
219 | protected Normalization getNormalization() { |
---|
220 | return m_layerNormalization; |
---|
221 | } |
---|
222 | |
---|
223 | /** |
---|
224 | * Compute the output values for this layer. |
---|
225 | * |
---|
226 | * @param incoming the incoming values |
---|
227 | * @return the output values for this layer |
---|
228 | * @throws Exception if there is a problem computing the outputs |
---|
229 | */ |
---|
230 | protected HashMap<String, Double> computeOutput(HashMap<String, Double> incoming) |
---|
231 | throws Exception { |
---|
232 | |
---|
233 | m_layerOutput.clear(); |
---|
234 | |
---|
235 | double normSum = 0; |
---|
236 | for (int i = 0; i < m_layerNeurons.length; i++) { |
---|
237 | double neuronOut = m_layerNeurons[i].getValue(incoming); |
---|
238 | String neuronID = m_layerNeurons[i].getID(); |
---|
239 | |
---|
240 | if (m_layerNormalization == Normalization.SOFTMAX) { |
---|
241 | normSum += Math.exp(neuronOut); |
---|
242 | } else if (m_layerNormalization == Normalization.SIMPLEMAX) { |
---|
243 | normSum += neuronOut; |
---|
244 | } |
---|
245 | //System.err.println("Inserting ID " + neuronID + " " + neuronOut); |
---|
246 | m_layerOutput.put(neuronID, neuronOut); |
---|
247 | } |
---|
248 | |
---|
249 | // apply the normalization (if necessary) |
---|
250 | if (m_layerNormalization != Normalization.NONE) { |
---|
251 | for (int i = 0; i < m_layerNeurons.length; i++) { |
---|
252 | double val = m_layerOutput.get(m_layerNeurons[i].getID()); |
---|
253 | // System.err.println("Normalizing ID " + m_layerNeurons[i].getID() + " " + val); |
---|
254 | if (m_layerNormalization == Normalization.SOFTMAX) { |
---|
255 | val = Math.exp(val) / normSum; |
---|
256 | } else { |
---|
257 | val = (val / normSum); |
---|
258 | } |
---|
259 | m_layerOutput.put(m_layerNeurons[i].getID(), val); |
---|
260 | } |
---|
261 | } |
---|
262 | return m_layerOutput; |
---|
263 | } |
---|
264 | |
---|
265 | public String toString() { |
---|
266 | StringBuffer temp = new StringBuffer(); |
---|
267 | |
---|
268 | temp.append("activation: " + getActivationFunction() + "\n"); |
---|
269 | if (!Double.isNaN(getThreshold())) { |
---|
270 | temp.append("threshold: " + getThreshold() + "\n"); |
---|
271 | } |
---|
272 | if (!Double.isNaN(getWidth())) { |
---|
273 | temp.append("width: " + getWidth() + "\n"); |
---|
274 | } |
---|
275 | if (!Double.isNaN(getAltitude())) { |
---|
276 | temp.append("altitude: " + getAltitude() + "\n"); |
---|
277 | } |
---|
278 | temp.append("normalization: " + m_layerNormalization + "\n"); |
---|
279 | for (int i = 0; i < m_numNeurons; i++) { |
---|
280 | temp.append(m_layerNeurons[i] + "\n"); |
---|
281 | } |
---|
282 | |
---|
283 | return temp.toString(); |
---|
284 | } |
---|
285 | } |
---|
286 | |
---|
287 | /** |
---|
288 | * Inner class encapsulating a Neuron |
---|
289 | */ |
---|
290 | static class Neuron implements Serializable { |
---|
291 | |
---|
292 | /** |
---|
293 | * For serialization |
---|
294 | */ |
---|
295 | private static final long serialVersionUID = -3817434025682603443L; |
---|
296 | |
---|
297 | /** ID string */ |
---|
298 | private String m_ID = null; |
---|
299 | |
---|
300 | /** The layer we belong to (for accessing activation function, threshold etc.) */ |
---|
301 | private NeuralLayer m_layer; |
---|
302 | |
---|
303 | /** The bias */ |
---|
304 | private double m_bias = 0.0; |
---|
305 | |
---|
306 | /** The width (if defined overrides the one in NeuralLayer or NeuralNetwork) */ |
---|
307 | private double m_neuronWidth = Double.NaN; |
---|
308 | |
---|
309 | /** The altitude (if defined overrides the one in NeuralLayer or NeuralNetwork) */ |
---|
310 | private double m_neuronAltitude = Double.NaN; |
---|
311 | |
---|
312 | /** The IDs of the neurons/neural inputs that we are connected to */ |
---|
313 | private String[] m_connectionIDs = null; |
---|
314 | |
---|
315 | /** The weights corresponding to the connections */ |
---|
316 | private double[] m_weights = null; |
---|
317 | |
---|
318 | protected Neuron(Element neuronE, NeuralLayer layer) { |
---|
319 | m_layer = layer; |
---|
320 | |
---|
321 | m_ID = neuronE.getAttribute("id"); |
---|
322 | |
---|
323 | String bias = neuronE.getAttribute("bias"); |
---|
324 | if (bias != null && bias.length() > 0) { |
---|
325 | m_bias = Double.parseDouble(bias); |
---|
326 | } |
---|
327 | |
---|
328 | String width = neuronE.getAttribute("width"); |
---|
329 | if (width != null && width.length() > 0) { |
---|
330 | m_neuronWidth = Double.parseDouble(width); |
---|
331 | } |
---|
332 | |
---|
333 | String altitude = neuronE.getAttribute("altitude"); |
---|
334 | if (altitude != null && altitude.length() > 0) { |
---|
335 | m_neuronAltitude = Double.parseDouble(altitude); |
---|
336 | } |
---|
337 | |
---|
338 | // get the connection details |
---|
339 | NodeList conL = neuronE.getElementsByTagName("Con"); |
---|
340 | m_connectionIDs = new String[conL.getLength()]; |
---|
341 | m_weights = new double[conL.getLength()]; |
---|
342 | for (int i = 0; i < conL.getLength(); i++) { |
---|
343 | Node conN = conL.item(i); |
---|
344 | if (conN.getNodeType() == Node.ELEMENT_NODE) { |
---|
345 | Element conE = (Element)conN; |
---|
346 | m_connectionIDs[i] = conE.getAttribute("from"); |
---|
347 | String weight = conE.getAttribute("weight"); |
---|
348 | m_weights[i] = Double.parseDouble(weight); |
---|
349 | } |
---|
350 | } |
---|
351 | } |
---|
352 | |
---|
353 | protected String getID() { |
---|
354 | return m_ID; |
---|
355 | } |
---|
356 | |
---|
357 | /** |
---|
358 | * Compute the output of this Neuron. |
---|
359 | * |
---|
360 | * @param incoming a Map of input values. The keys are the IDs |
---|
361 | * of incoming connections (either neural inputs or neurons) and |
---|
362 | * the values are the output values of the neural input/neuron in |
---|
363 | * question. |
---|
364 | * |
---|
365 | * @return the output of this neuron |
---|
366 | * @throws Exception if any of our incoming connection IDs cannot be |
---|
367 | * located in the Map |
---|
368 | */ |
---|
369 | protected double getValue(HashMap<String, Double> incoming) throws Exception { |
---|
370 | |
---|
371 | double z = 0; |
---|
372 | double result = Double.NaN; |
---|
373 | |
---|
374 | double width = (Double.isNaN(m_neuronWidth)) |
---|
375 | ? m_layer.getWidth() |
---|
376 | : m_neuronWidth; |
---|
377 | |
---|
378 | z = m_bias; |
---|
379 | for (int i = 0; i < m_connectionIDs.length; i++) { |
---|
380 | Double inVal = incoming.get(m_connectionIDs[i]); |
---|
381 | if (inVal == null) { |
---|
382 | throw new Exception("[Neuron] unable to find connection " |
---|
383 | + m_connectionIDs[i] + " in input Map!"); |
---|
384 | } |
---|
385 | |
---|
386 | if (m_layer.getActivationFunction() != ActivationFunction.RADIALBASIS) { |
---|
387 | // multiply with weight |
---|
388 | double inV = inVal.doubleValue() * m_weights[i]; |
---|
389 | z += inV; |
---|
390 | } else { |
---|
391 | // Euclidean distance to the center (stored in m_weights) |
---|
392 | double inV = Math.pow((inVal.doubleValue() - m_weights[i]), 2.0); |
---|
393 | z += inV; |
---|
394 | } |
---|
395 | } |
---|
396 | |
---|
397 | // apply the width if necessary |
---|
398 | if (m_layer.getActivationFunction() == ActivationFunction.RADIALBASIS) { |
---|
399 | z /= (2.0 * (width * width)); |
---|
400 | } |
---|
401 | |
---|
402 | double threshold = m_layer.getThreshold(); |
---|
403 | double altitude = (Double.isNaN(m_neuronAltitude)) |
---|
404 | ? m_layer.getAltitude() |
---|
405 | : m_neuronAltitude; |
---|
406 | |
---|
407 | double fanIn = m_connectionIDs.length; |
---|
408 | result = m_layer.getActivationFunction().eval(z, threshold, altitude, fanIn); |
---|
409 | |
---|
410 | return result; |
---|
411 | } |
---|
412 | |
---|
413 | public String toString() { |
---|
414 | StringBuffer temp = new StringBuffer(); |
---|
415 | temp.append("Nueron (" + m_ID + ") [bias:" + m_bias); |
---|
416 | if (!Double.isNaN(m_neuronWidth)) { |
---|
417 | temp.append(" width:" + m_neuronWidth); |
---|
418 | } |
---|
419 | if (!Double.isNaN(m_neuronAltitude)) { |
---|
420 | temp.append(" altitude:" + m_neuronAltitude); |
---|
421 | } |
---|
422 | temp.append("]\n"); |
---|
423 | temp.append(" con. (ID:weight): "); |
---|
424 | for (int i = 0; i < m_connectionIDs.length; i++) { |
---|
425 | temp.append(m_connectionIDs[i] + ":" + Utils.doubleToString(m_weights[i], 2)); |
---|
426 | if ((i + 1) % 10 == 0 || i == m_connectionIDs.length - 1) { |
---|
427 | temp.append("\n "); |
---|
428 | } else { |
---|
429 | temp.append(", "); |
---|
430 | } |
---|
431 | } |
---|
432 | return temp.toString(); |
---|
433 | } |
---|
434 | } |
---|
435 | |
---|
436 | static class NeuralOutputs implements Serializable { |
---|
437 | |
---|
438 | /** |
---|
439 | * For serialization |
---|
440 | */ |
---|
441 | private static final long serialVersionUID = -233611113950482952L; |
---|
442 | |
---|
443 | /** The neurons we are mapping */ |
---|
444 | private String[] m_outputNeurons = null; |
---|
445 | |
---|
446 | /** |
---|
447 | * In the case of a nominal class, the index of the value |
---|
448 | * being predicted by each output neuron |
---|
449 | */ |
---|
450 | private int[] m_categoricalIndexes = null; |
---|
451 | |
---|
452 | /** The class attribute we are mapping to */ |
---|
453 | private Attribute m_classAttribute = null; |
---|
454 | |
---|
455 | /** Used when the class is numeric */ |
---|
456 | private NormContinuous m_regressionMapping = null; |
---|
457 | |
---|
458 | protected NeuralOutputs(Element outputs, MiningSchema miningSchema) throws Exception { |
---|
459 | m_classAttribute = miningSchema.getMiningSchemaAsInstances().classAttribute(); |
---|
460 | |
---|
461 | int vals = (m_classAttribute.isNumeric()) |
---|
462 | ? 1 |
---|
463 | : m_classAttribute.numValues(); |
---|
464 | |
---|
465 | m_outputNeurons = new String[vals]; |
---|
466 | m_categoricalIndexes = new int[vals]; |
---|
467 | |
---|
468 | NodeList outputL = outputs.getElementsByTagName("NeuralOutput"); |
---|
469 | if (outputL.getLength() != m_outputNeurons.length) { |
---|
470 | throw new Exception("[NeuralOutputs] the number of neural outputs does not match " |
---|
471 | + "the number expected!"); |
---|
472 | } |
---|
473 | |
---|
474 | for (int i = 0; i < outputL.getLength(); i++) { |
---|
475 | Node outputN = outputL.item(i); |
---|
476 | if (outputN.getNodeType() == Node.ELEMENT_NODE) { |
---|
477 | Element outputE = (Element)outputN; |
---|
478 | // get the ID for this output neuron |
---|
479 | m_outputNeurons[i] = outputE.getAttribute("outputNeuron"); |
---|
480 | |
---|
481 | if (m_classAttribute.isNumeric()) { |
---|
482 | // get the single norm continuous |
---|
483 | NodeList contL = outputE.getElementsByTagName("NormContinuous"); |
---|
484 | if (contL.getLength() != 1) { |
---|
485 | throw new Exception("[NeuralOutputs] Should be exactly one norm continuous element " |
---|
486 | + "for numeric class!"); |
---|
487 | } |
---|
488 | Node normContNode = contL.item(0); |
---|
489 | String attName = ((Element)normContNode).getAttribute("field"); |
---|
490 | Attribute dummyTargetDef = new Attribute(attName); |
---|
491 | ArrayList<Attribute> dummyFieldDefs = new ArrayList<Attribute>(); |
---|
492 | dummyFieldDefs.add(dummyTargetDef); |
---|
493 | |
---|
494 | m_regressionMapping = new NormContinuous((Element)normContNode, |
---|
495 | FieldMetaInfo.Optype.CONTINUOUS, dummyFieldDefs); |
---|
496 | break; |
---|
497 | } else { |
---|
498 | // we just need to grab the categorical value (out of the NormDiscrete element) |
---|
499 | // that this output neuron is associated with |
---|
500 | NodeList discL = outputE.getElementsByTagName("NormDiscrete"); |
---|
501 | if (discL.getLength() != 1) { |
---|
502 | throw new Exception("[NeuralOutputs] Should be only one norm discrete element " |
---|
503 | + "per derived field/neural output for a nominal class!"); |
---|
504 | } |
---|
505 | Node normDiscNode = discL.item(0); |
---|
506 | String attValue = ((Element)normDiscNode).getAttribute("value"); |
---|
507 | int index = m_classAttribute.indexOfValue(attValue); |
---|
508 | if (index < 0) { |
---|
509 | throw new Exception("[NeuralOutputs] Can't find specified target value " |
---|
510 | + attValue + " in class attribute " + m_classAttribute.name()); |
---|
511 | } |
---|
512 | m_categoricalIndexes[i] = index; |
---|
513 | } |
---|
514 | } |
---|
515 | } |
---|
516 | } |
---|
517 | |
---|
518 | /** |
---|
519 | * Compute the output. Either a probability distribution or a single |
---|
520 | * value (regression). |
---|
521 | * |
---|
522 | * @param incoming the values from the last hidden layer |
---|
523 | * @param preds the array to fill with predicted values |
---|
524 | * @throws Exception if there is a problem computing the output |
---|
525 | */ |
---|
526 | protected void getOuput(HashMap<String, Double> incoming, double[] preds) throws Exception { |
---|
527 | |
---|
528 | if (preds.length != m_outputNeurons.length) { |
---|
529 | throw new Exception("[NeuralOutputs] Incorrect number of predictions requested: " |
---|
530 | + preds.length + "requested, " + m_outputNeurons.length + " expected"); |
---|
531 | } |
---|
532 | for (int i = 0; i < m_outputNeurons.length; i++) { |
---|
533 | Double neuronOut = incoming.get(m_outputNeurons[i]); |
---|
534 | if (neuronOut == null) { |
---|
535 | throw new Exception("[NeuralOutputs] Unable to find output neuron " |
---|
536 | + m_outputNeurons[i] + " in the incoming HashMap!!"); |
---|
537 | } |
---|
538 | if (m_classAttribute.isNumeric()) { |
---|
539 | // will be only one output neuron anyway |
---|
540 | preds[0] = neuronOut.doubleValue(); |
---|
541 | |
---|
542 | preds[0] = m_regressionMapping.getResultInverse(preds); |
---|
543 | } else { |
---|
544 | |
---|
545 | // clip at zero |
---|
546 | // preds[m_categoricalIndexes[i]] = (neuronOut < 0) ? 0.0 : neuronOut; |
---|
547 | preds[m_categoricalIndexes[i]] = neuronOut; |
---|
548 | } |
---|
549 | } |
---|
550 | |
---|
551 | if (m_classAttribute.isNominal()) { |
---|
552 | // check for negative values and adjust |
---|
553 | double min = preds[Utils.minIndex(preds)]; |
---|
554 | if (min < 0) { |
---|
555 | for (int i = 0; i < preds.length; i++) { |
---|
556 | preds[i] -= min; |
---|
557 | } |
---|
558 | } |
---|
559 | // do a simplemax normalization |
---|
560 | Utils.normalize(preds); |
---|
561 | } |
---|
562 | } |
---|
563 | |
---|
564 | public String toString() { |
---|
565 | StringBuffer temp = new StringBuffer(); |
---|
566 | |
---|
567 | for (int i = 0; i < m_outputNeurons.length; i++) { |
---|
568 | temp.append("Output neuron (" + m_outputNeurons[i] + ")\n"); |
---|
569 | temp.append("mapping:\n"); |
---|
570 | if (m_classAttribute.isNumeric()) { |
---|
571 | temp.append(m_regressionMapping +"\n"); |
---|
572 | } else { |
---|
573 | temp.append(m_classAttribute.name() + " = " |
---|
574 | + m_classAttribute.value(m_categoricalIndexes[i]) + "\n"); |
---|
575 | } |
---|
576 | } |
---|
577 | |
---|
578 | return temp.toString(); |
---|
579 | } |
---|
580 | } |
---|
581 | |
---|
582 | /** |
---|
583 | * Enumerated type for the mining function |
---|
584 | */ |
---|
585 | enum MiningFunction { |
---|
586 | CLASSIFICATION, |
---|
587 | REGRESSION; |
---|
588 | } |
---|
589 | |
---|
590 | /** The mining function */ |
---|
591 | protected MiningFunction m_functionType = MiningFunction.CLASSIFICATION; |
---|
592 | |
---|
593 | /** |
---|
594 | * Enumerated type for the activation function. |
---|
595 | */ |
---|
596 | enum ActivationFunction { |
---|
597 | THRESHOLD("threshold") { |
---|
598 | double eval(double z, double threshold, double altitude, double fanIn) { |
---|
599 | if (z > threshold) { |
---|
600 | return 1.0; |
---|
601 | } |
---|
602 | return 0.0; |
---|
603 | } |
---|
604 | }, |
---|
605 | LOGISTIC("logistic") { |
---|
606 | double eval(double z, double threshold, double altitude, double fanIn) { |
---|
607 | return 1.0 / (1.0 + Math.exp(-z)); |
---|
608 | } |
---|
609 | }, |
---|
610 | TANH("tanh") { |
---|
611 | double eval(double z, double threshold, double altitude, double fanIn) { |
---|
612 | double a = Math.exp( z ); |
---|
613 | double b = Math.exp( -z ); |
---|
614 | return ((a-b)/(a+b)); |
---|
615 | //return (1.0 - Math.exp(-2.0 * z)) / (1.0 + Math.exp(-2.0 * z)); |
---|
616 | } |
---|
617 | }, |
---|
618 | IDENTITY("identity") { |
---|
619 | double eval(double z, double threshold, double altitude, double fanIn) { |
---|
620 | return z; |
---|
621 | } |
---|
622 | }, |
---|
623 | EXPONENTIAL("exponential") { |
---|
624 | double eval(double z, double threshold, double altitude, double fanIn) { |
---|
625 | return Math.exp(z); |
---|
626 | } |
---|
627 | }, |
---|
628 | RECIPROCAL("reciprocal") { |
---|
629 | double eval(double z, double threshold, double altitude, double fanIn) { |
---|
630 | return 1.0 / z; |
---|
631 | } |
---|
632 | }, |
---|
633 | SQUARE("square") { |
---|
634 | double eval(double z, double threshold, double altitude, double fanIn) { |
---|
635 | return z * z; |
---|
636 | } |
---|
637 | }, |
---|
638 | GAUSS("gauss") { |
---|
639 | double eval(double z, double threshold, double altitude, double fanIn) { |
---|
640 | return Math.exp(-(z * z)); |
---|
641 | } |
---|
642 | }, |
---|
643 | SINE("sine") { |
---|
644 | double eval(double z, double threshold, double altitude, double fanIn) { |
---|
645 | return Math.sin(z); |
---|
646 | } |
---|
647 | }, |
---|
648 | COSINE("cosine") { |
---|
649 | double eval(double z, double threshold, double altitude, double fanIn) { |
---|
650 | return Math.cos(z); |
---|
651 | } |
---|
652 | }, |
---|
653 | ELLICOT("ellicot") { |
---|
654 | double eval(double z, double threshold, double altitude, double fanIn) { |
---|
655 | return z / (1.0 + Math.abs(z)); |
---|
656 | } |
---|
657 | }, |
---|
658 | ARCTAN("arctan") { |
---|
659 | double eval(double z, double threshold, double altitude, double fanIn) { |
---|
660 | return 2.0 * Math.atan(z) / Math.PI; |
---|
661 | } |
---|
662 | }, |
---|
663 | RADIALBASIS("radialBasis") { |
---|
664 | double eval(double z, double threshold, double altitude, double fanIn) { |
---|
665 | return Math.exp(fanIn * Math.log(altitude) - z); |
---|
666 | } |
---|
667 | }; |
---|
668 | |
---|
669 | abstract double eval(double z, double threshold, double altitude, double fanIn); |
---|
670 | |
---|
671 | private final String m_stringVal; |
---|
672 | |
---|
673 | ActivationFunction(String name) { |
---|
674 | m_stringVal = name; |
---|
675 | } |
---|
676 | |
---|
677 | public String toString() { |
---|
678 | return m_stringVal; |
---|
679 | } |
---|
680 | } |
---|
681 | |
---|
682 | /** The activation function to use */ |
---|
683 | protected ActivationFunction m_activationFunction = ActivationFunction.ARCTAN; |
---|
684 | |
---|
685 | /** |
---|
686 | * Enumerated type for the normalization method |
---|
687 | */ |
---|
688 | enum Normalization { |
---|
689 | NONE ("none"), |
---|
690 | SIMPLEMAX ("simplemax"), |
---|
691 | SOFTMAX ("softmax"); |
---|
692 | |
---|
693 | private final String m_stringVal; |
---|
694 | |
---|
695 | Normalization(String name) { |
---|
696 | m_stringVal = name; |
---|
697 | } |
---|
698 | |
---|
699 | public String toString() { |
---|
700 | return m_stringVal; |
---|
701 | } |
---|
702 | } |
---|
703 | |
---|
704 | /** The normalization method */ |
---|
705 | protected Normalization m_normalizationMethod = Normalization.NONE; |
---|
706 | |
---|
707 | /** Threshold activation */ |
---|
708 | protected double m_threshold = 0.0; // default = 0 |
---|
709 | |
---|
710 | /** Width for radial basis */ |
---|
711 | protected double m_width = Double.NaN; // no default |
---|
712 | |
---|
713 | /** Altitude for radial basis */ |
---|
714 | protected double m_altitude = 1.0; // default = 1 |
---|
715 | |
---|
716 | /** The number of inputs to the network */ |
---|
717 | protected int m_numberOfInputs = 0; |
---|
718 | |
---|
719 | /** Number of hidden layers in the network */ |
---|
720 | protected int m_numberOfLayers = 0; |
---|
721 | |
---|
722 | /** The inputs to the network */ |
---|
723 | protected NeuralInput[] m_inputs = null; |
---|
724 | |
---|
725 | /** A map for storing network input values (computed from an incoming instance) */ |
---|
726 | protected HashMap<String, Double> m_inputMap = new HashMap<String, Double>(); |
---|
727 | |
---|
728 | /** The hidden layers in the network */ |
---|
729 | protected NeuralLayer[] m_layers = null; |
---|
730 | |
---|
731 | /** The outputs of the network */ |
---|
732 | protected NeuralOutputs m_outputs = null; |
---|
733 | |
---|
734 | public NeuralNetwork(Element model, Instances dataDictionary, |
---|
735 | MiningSchema miningSchema) throws Exception { |
---|
736 | |
---|
737 | super(dataDictionary, miningSchema); |
---|
738 | |
---|
739 | String fn = model.getAttribute("functionName"); |
---|
740 | if (fn.equals("regression")) { |
---|
741 | m_functionType = MiningFunction.REGRESSION; |
---|
742 | } |
---|
743 | |
---|
744 | String act = model.getAttribute("activationFunction"); |
---|
745 | if (act == null || act.length() == 0) { |
---|
746 | throw new Exception("[NeuralNetwork] no activation functon defined"); |
---|
747 | } |
---|
748 | |
---|
749 | // get the activation function |
---|
750 | for (ActivationFunction a : ActivationFunction.values()) { |
---|
751 | if (a.toString().equals(act)) { |
---|
752 | m_activationFunction = a; |
---|
753 | break; |
---|
754 | } |
---|
755 | } |
---|
756 | |
---|
757 | // get the normalization method (if specified) |
---|
758 | String norm = model.getAttribute("normalizationMethod"); |
---|
759 | if (norm != null && norm.length() > 0) { |
---|
760 | for (Normalization n : Normalization.values()) { |
---|
761 | if (n.toString().equals(norm)) { |
---|
762 | m_normalizationMethod = n; |
---|
763 | break; |
---|
764 | } |
---|
765 | } |
---|
766 | } |
---|
767 | |
---|
768 | String thresh = model.getAttribute("threshold"); |
---|
769 | if (thresh != null && thresh.length() > 0) { |
---|
770 | m_threshold = Double.parseDouble(thresh); |
---|
771 | } |
---|
772 | String width = model.getAttribute("width"); |
---|
773 | if (width != null && width.length() > 0) { |
---|
774 | m_width = Double.parseDouble(width); |
---|
775 | } |
---|
776 | String alt = model.getAttribute("altitude"); |
---|
777 | if (alt != null && alt.length() > 0) { |
---|
778 | m_altitude = Double.parseDouble(alt); |
---|
779 | } |
---|
780 | |
---|
781 | // get all the inputs |
---|
782 | NodeList inputL = model.getElementsByTagName("NeuralInput"); |
---|
783 | m_numberOfInputs = inputL.getLength(); |
---|
784 | m_inputs = new NeuralInput[m_numberOfInputs]; |
---|
785 | for (int i = 0; i < m_numberOfInputs; i++) { |
---|
786 | Node inputN = inputL.item(i); |
---|
787 | if (inputN.getNodeType() == Node.ELEMENT_NODE) { |
---|
788 | NeuralInput nI = new NeuralInput((Element)inputN, m_miningSchema); |
---|
789 | m_inputs[i] = nI; |
---|
790 | } |
---|
791 | } |
---|
792 | |
---|
793 | // get the layers |
---|
794 | NodeList layerL = model.getElementsByTagName("NeuralLayer"); |
---|
795 | m_numberOfLayers = layerL.getLength(); |
---|
796 | m_layers = new NeuralLayer[m_numberOfLayers]; |
---|
797 | for (int i = 0; i < m_numberOfLayers; i++) { |
---|
798 | Node layerN = layerL.item(i); |
---|
799 | if (layerN.getNodeType() == Node.ELEMENT_NODE) { |
---|
800 | NeuralLayer nL = new NeuralLayer((Element)layerN); |
---|
801 | m_layers[i] = nL; |
---|
802 | } |
---|
803 | } |
---|
804 | |
---|
805 | // get the outputs |
---|
806 | NodeList outputL = model.getElementsByTagName("NeuralOutputs"); |
---|
807 | if (outputL.getLength() != 1) { |
---|
808 | throw new Exception("[NeuralNetwork] Should be just one NeuralOutputs element defined!"); |
---|
809 | } |
---|
810 | |
---|
811 | m_outputs = new NeuralOutputs((Element)outputL.item(0), m_miningSchema); |
---|
812 | } |
---|
813 | |
---|
814 | /* (non-Javadoc) |
---|
815 | * @see weka.core.RevisionHandler#getRevision() |
---|
816 | */ |
---|
817 | public String getRevision() { |
---|
818 | return RevisionUtils.extract("$Revision: 5987 $"); |
---|
819 | } |
---|
820 | |
---|
821 | /** |
---|
822 | * Classifies the given test instance. The instance has to belong to a |
---|
823 | * dataset when it's being classified. |
---|
824 | * |
---|
825 | * @param inst the instance to be classified |
---|
826 | * @return the predicted most likely class for the instance or |
---|
827 | * Utils.missingValue() if no prediction is made |
---|
828 | * @exception Exception if an error occurred during the prediction |
---|
829 | */ |
---|
830 | public double[] distributionForInstance(Instance inst) throws Exception { |
---|
831 | if (!m_initialized) { |
---|
832 | mapToMiningSchema(inst.dataset()); |
---|
833 | } |
---|
834 | double[] preds = null; |
---|
835 | |
---|
836 | if (m_miningSchema.getFieldsAsInstances().classAttribute().isNumeric()) { |
---|
837 | preds = new double[1]; |
---|
838 | } else { |
---|
839 | preds = new double[m_miningSchema.getFieldsAsInstances().classAttribute().numValues()]; |
---|
840 | } |
---|
841 | |
---|
842 | double[] incoming = m_fieldsMap.instanceToSchema(inst, m_miningSchema); |
---|
843 | |
---|
844 | boolean hasMissing = false; |
---|
845 | for (int i = 0; i < incoming.length; i++) { |
---|
846 | if (i != m_miningSchema.getFieldsAsInstances().classIndex() && |
---|
847 | Double.isNaN(incoming[i])) { |
---|
848 | hasMissing = true; |
---|
849 | //System.err.println("Missing value for att : " + m_miningSchema.getFieldsAsInstances().attribute(i).name()); |
---|
850 | break; |
---|
851 | } |
---|
852 | } |
---|
853 | |
---|
854 | if (hasMissing) { |
---|
855 | if (!m_miningSchema.hasTargetMetaData()) { |
---|
856 | String message = "[NeuralNetwork] WARNING: Instance to predict has missing value(s) but " |
---|
857 | + "there is no missing value handling meta data and no " |
---|
858 | + "prior probabilities/default value to fall back to. No " |
---|
859 | + "prediction will be made (" |
---|
860 | + ((m_miningSchema.getFieldsAsInstances().classAttribute().isNominal() |
---|
861 | || m_miningSchema.getFieldsAsInstances().classAttribute().isString()) |
---|
862 | ? "zero probabilities output)." |
---|
863 | : "NaN output)."); |
---|
864 | if (m_log == null) { |
---|
865 | System.err.println(message); |
---|
866 | } else { |
---|
867 | m_log.logMessage(message); |
---|
868 | } |
---|
869 | |
---|
870 | if (m_miningSchema.getFieldsAsInstances().classAttribute().isNumeric()) { |
---|
871 | preds[0] = Utils.missingValue(); |
---|
872 | } |
---|
873 | return preds; |
---|
874 | } else { |
---|
875 | // use prior probablilities/default value |
---|
876 | TargetMetaInfo targetData = m_miningSchema.getTargetMetaData(); |
---|
877 | if (m_miningSchema.getFieldsAsInstances().classAttribute().isNumeric()) { |
---|
878 | preds[0] = targetData.getDefaultValue(); |
---|
879 | } else { |
---|
880 | Instances miningSchemaI = m_miningSchema.getFieldsAsInstances(); |
---|
881 | for (int i = 0; i < miningSchemaI.classAttribute().numValues(); i++) { |
---|
882 | preds[i] = targetData.getPriorProbability(miningSchemaI.classAttribute().value(i)); |
---|
883 | } |
---|
884 | } |
---|
885 | return preds; |
---|
886 | } |
---|
887 | } else { |
---|
888 | |
---|
889 | // construct the input to the network for this instance |
---|
890 | m_inputMap.clear(); |
---|
891 | for (int i = 0; i < m_inputs.length; i++) { |
---|
892 | double networkInVal = m_inputs[i].getValue(incoming); |
---|
893 | String ID = m_inputs[i].getID(); |
---|
894 | m_inputMap.put(ID, networkInVal); |
---|
895 | } |
---|
896 | |
---|
897 | // now compute the output of each layer |
---|
898 | HashMap<String, Double> layerOut = m_layers[0].computeOutput(m_inputMap); |
---|
899 | for (int i = 1; i < m_layers.length; i++) { |
---|
900 | layerOut = m_layers[i].computeOutput(layerOut); |
---|
901 | } |
---|
902 | |
---|
903 | // now do the output |
---|
904 | m_outputs.getOuput(layerOut, preds); |
---|
905 | } |
---|
906 | |
---|
907 | return preds; |
---|
908 | } |
---|
909 | |
---|
910 | public String toString() { |
---|
911 | StringBuffer temp = new StringBuffer(); |
---|
912 | |
---|
913 | temp.append("PMML version " + getPMMLVersion()); |
---|
914 | if (!getCreatorApplication().equals("?")) { |
---|
915 | temp.append("\nApplication: " + getCreatorApplication()); |
---|
916 | } |
---|
917 | temp.append("\nPMML Model: Neural network"); |
---|
918 | temp.append("\n\n"); |
---|
919 | temp.append(m_miningSchema); |
---|
920 | |
---|
921 | temp.append("Inputs:\n"); |
---|
922 | for (int i = 0; i < m_inputs.length; i++) { |
---|
923 | temp.append(m_inputs[i] + "\n"); |
---|
924 | } |
---|
925 | |
---|
926 | for (int i = 0; i < m_layers.length; i++) { |
---|
927 | temp.append("Layer: " + (i+1) + "\n"); |
---|
928 | temp.append(m_layers[i] + "\n"); |
---|
929 | } |
---|
930 | |
---|
931 | temp.append("Outputs:\n"); |
---|
932 | temp.append(m_outputs); |
---|
933 | |
---|
934 | return temp.toString(); |
---|
935 | } |
---|
936 | } |
---|