/*
 *    This program is free software; you can redistribute it and/or modify
 *    it under the terms of the GNU General Public License as published by
 *    the Free Software Foundation; either version 2 of the License, or
 *    (at your option) any later version.
 *
 *    This program is distributed in the hope that it will be useful,
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *    GNU General Public License for more details.
 *
 *    You should have received a copy of the GNU General Public License
 *    along with this program; if not, write to the Free Software
 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 */
---|
16 | |
---|
/*
 *    EnsembleSelection.java
 *    Copyright (C) 2006 David Michael
 *
 */
---|
22 | |
---|
23 | package weka.classifiers.meta.ensembleSelection; |
---|
24 | |
---|
25 | import weka.classifiers.Classifier; |
---|
26 | import weka.classifiers.AbstractClassifier; |
---|
27 | import weka.classifiers.EnsembleLibraryModel; |
---|
28 | import weka.core.Instance; |
---|
29 | import weka.core.Instances; |
---|
30 | import weka.core.RevisionUtils; |
---|
31 | import weka.core.Utils; |
---|
32 | import weka.core.OptionHandler; |
---|
33 | |
---|
34 | import java.io.File; |
---|
35 | import java.io.FileInputStream; |
---|
36 | import java.io.FileOutputStream; |
---|
37 | import java.io.IOException; |
---|
38 | import java.io.ObjectInputStream; |
---|
39 | import java.io.ObjectOutput; |
---|
40 | import java.io.ObjectOutputStream; |
---|
41 | import java.io.Serializable; |
---|
42 | import java.io.UnsupportedEncodingException; |
---|
43 | import java.util.Date; |
---|
44 | import java.util.zip.Adler32; |
---|
45 | |
---|
46 | /** |
---|
47 | * This class represents a library model that is used for EnsembleSelection. At |
---|
48 | * this level the concept of cross validation is abstracted away. This class |
---|
49 | * keeps track of the performance statistics and bookkeeping information for its |
---|
50 | * "model type" accross all the CV folds. By "model type", I mean the |
---|
51 | * combination of both the Classifier type (e.g. J48), and its set of parameters |
---|
52 | * (e.g. -C 0.5 -X 1 -Y 5). So for example, if you are using 5 fold cross |
---|
53 | * validaiton, this model will keep an array of classifiers[] of length 5 and |
---|
54 | * will keep track of their performances accordingly. This class also has |
---|
55 | * methods to deal with serializing all of this information into the .elm file |
---|
56 | * that will represent this model. |
---|
57 | * <p/> |
---|
58 | * Also it is worth mentioning that another important function of this class is |
---|
59 | * to track all of the dataset information that was used to create this model. |
---|
60 | * This is because we want to protect users from doing foreseeably bad things. |
---|
61 | * e.g., trying to build an ensemble for a dataset with models that were trained |
---|
62 | * on the wrong partitioning of the dataset. This could lead to artificially high |
---|
63 | * performance due to the fact that instances used for the test set to gauge |
---|
64 | * performance could have accidentally been used to train the base classifiers. |
---|
65 | * So in a nutshell, we are preventing people from unintentionally "cheating" by |
---|
66 | * enforcing that the seed, #folds, validation ration, and the checksum of the |
---|
67 | * Instances.toString() method ALL match exactly. Otherwise we throw an |
---|
68 | * exception. |
---|
69 | * |
---|
70 | * @author Robert Jung (mrbobjung@gmail.com) |
---|
71 | * @version $Revision: 5928 $ |
---|
72 | */ |
---|
73 | public class EnsembleSelectionLibraryModel |
---|
74 | extends EnsembleLibraryModel |
---|
75 | implements Serializable { |
---|
76 | |
---|
77 | /** |
---|
78 | * This is the serialVersionUID that SHOULD stay the same so that future |
---|
79 | * modified versions of this class will be backwards compatible with older |
---|
80 | * model versions. |
---|
81 | */ |
---|
82 | private static final long serialVersionUID = -6426075459862947640L; |
---|
83 | |
---|
84 | /** The default file extension for ensemble library models */ |
---|
85 | public static final String FILE_EXTENSION = ".elm"; |
---|
86 | |
---|
87 | /** the models */ |
---|
88 | private Classifier[] m_models = null; |
---|
89 | |
---|
90 | /** The seed that was used to create this model */ |
---|
91 | private int m_seed; |
---|
92 | |
---|
93 | /** |
---|
94 | * The checksum of the instances.arff object that was used to create this |
---|
95 | * model |
---|
96 | */ |
---|
97 | private String m_checksum; |
---|
98 | |
---|
99 | /** The validation ratio that was used to create this model */ |
---|
100 | private double m_validationRatio; |
---|
101 | |
---|
102 | /** |
---|
103 | * The number of folds, or number of CV models that was used to create this |
---|
104 | * "model" |
---|
105 | */ |
---|
106 | private int m_folds; |
---|
107 | |
---|
108 | /** |
---|
109 | * The .elm file name that this model should be saved/loaded to/from |
---|
110 | */ |
---|
111 | private String m_fileName; |
---|
112 | |
---|
113 | /** |
---|
114 | * The debug flag as propagated from the main EnsembleSelection class. |
---|
115 | */ |
---|
116 | public transient boolean m_Debug = true; |
---|
117 | |
---|
118 | /** |
---|
119 | * the validation predictions of this model. First index for the instance. |
---|
120 | * third is for the class (we use distributionForInstance). |
---|
121 | */ |
---|
122 | private double[][] m_validationPredictions = null; // = new double[0][0]; |
---|
123 | |
---|
124 | /** |
---|
125 | * Default Constructor |
---|
126 | */ |
---|
127 | public EnsembleSelectionLibraryModel() { |
---|
128 | } |
---|
129 | |
---|
130 | /** |
---|
131 | * Constructor for LibaryModel |
---|
132 | * |
---|
133 | * @param classifier the classifier to use |
---|
134 | * @param seed the random seed value |
---|
135 | * @param checksum the checksum |
---|
136 | * @param validationRatio the ration to use |
---|
137 | * @param folds the number of folds to use |
---|
138 | */ |
---|
139 | public EnsembleSelectionLibraryModel(Classifier classifier, int seed, |
---|
140 | String checksum, double validationRatio, int folds) { |
---|
141 | |
---|
142 | super(classifier); |
---|
143 | |
---|
144 | m_seed = seed; |
---|
145 | m_checksum = checksum; |
---|
146 | m_validationRatio = validationRatio; |
---|
147 | m_models = null; |
---|
148 | m_folds = folds; |
---|
149 | } |
---|
150 | |
---|
151 | /** |
---|
152 | * This is used to propagate the m_Debug flag of the EnsembleSelection |
---|
153 | * classifier to this class. There are things we would want to print out |
---|
154 | * here also. |
---|
155 | * |
---|
156 | * @param debug if true additional information is output |
---|
157 | */ |
---|
158 | public void setDebug(boolean debug) { |
---|
159 | m_Debug = debug; |
---|
160 | } |
---|
161 | |
---|
162 | /** |
---|
163 | * Returns the average of the prediction of the models across all folds. |
---|
164 | * |
---|
165 | * @param instance the instance to get predictions for |
---|
166 | * @return the average prediction |
---|
167 | * @throws Exception if something goes wrong |
---|
168 | */ |
---|
169 | public double[] getAveragePrediction(Instance instance) throws Exception { |
---|
170 | |
---|
171 | // Return the average prediction from all classifiers that make up |
---|
172 | // this model. |
---|
173 | double average[] = new double[instance.numClasses()]; |
---|
174 | for (int i = 0; i < m_folds; ++i) { |
---|
175 | // Some models alter the instance (MultiLayerPerceptron), so we need |
---|
176 | // to copy it. |
---|
177 | Instance temp_instance = (Instance) instance.copy(); |
---|
178 | double[] pred = getFoldPrediction(temp_instance, i); |
---|
179 | if (pred == null) { |
---|
180 | // Some models have bugs whereby they can return a null |
---|
181 | // prediction |
---|
182 | // array (again, MultiLayerPerceptron). We return null, and this |
---|
183 | // should be handled above in EnsembleSelection. |
---|
184 | System.err.println("Null validation predictions given: " |
---|
185 | + getStringRepresentation()); |
---|
186 | return null; |
---|
187 | } |
---|
188 | if (i == 0) { |
---|
189 | // The first time through the loop, just use the first returned |
---|
190 | // prediction array. Just a simple optimization. |
---|
191 | average = pred; |
---|
192 | } else { |
---|
193 | // For the rest, add the prediction to the average array. |
---|
194 | for (int j = 0; j < pred.length; ++j) { |
---|
195 | average[j] += pred[j]; |
---|
196 | } |
---|
197 | } |
---|
198 | } |
---|
199 | if (instance.classAttribute().isNominal()) { |
---|
200 | // Normalize predictions for classes to add up to 1. |
---|
201 | Utils.normalize(average); |
---|
202 | } else { |
---|
203 | average[0] /= m_folds; |
---|
204 | } |
---|
205 | return average; |
---|
206 | } |
---|
207 | |
---|
208 | /** |
---|
209 | * Basic Constructor |
---|
210 | * |
---|
211 | * @param classifier the classifier to use |
---|
212 | */ |
---|
213 | public EnsembleSelectionLibraryModel(Classifier classifier) { |
---|
214 | super(classifier); |
---|
215 | } |
---|
216 | |
---|
217 | /** |
---|
218 | * Returns prediction of the classifier for the specified fold. |
---|
219 | * |
---|
220 | * @param instance |
---|
221 | * instance for which to make a prediction. |
---|
222 | * @param fold |
---|
223 | * fold number of the classifier to use. |
---|
224 | * @return the prediction for the classes |
---|
225 | * @throws Exception if prediction fails |
---|
226 | */ |
---|
227 | public double[] getFoldPrediction(Instance instance, int fold) |
---|
228 | throws Exception { |
---|
229 | |
---|
230 | return m_models[fold].distributionForInstance(instance); |
---|
231 | } |
---|
232 | |
---|
233 | /** |
---|
234 | * Creates the model. If there are n folds, it constructs n classifiers |
---|
235 | * using the current Classifier class and options. If the model has already |
---|
236 | * been created or loaded, starts fresh. |
---|
237 | * |
---|
238 | * @param data the data to work with |
---|
239 | * @param hillclimbData the data for hillclimbing |
---|
240 | * @param dataDirectoryName the directory to use |
---|
241 | * @param algorithm the type of algorithm |
---|
242 | * @throws Exception if something goeds wrong |
---|
243 | */ |
---|
244 | public void createModel(Instances[] data, Instances[] hillclimbData, |
---|
245 | String dataDirectoryName, int algorithm) throws Exception { |
---|
246 | |
---|
247 | String modelFileName = getFileName(getStringRepresentation()); |
---|
248 | |
---|
249 | File modelFile = new File(dataDirectoryName, modelFileName); |
---|
250 | |
---|
251 | String relativePath = (new File(dataDirectoryName)).getName() |
---|
252 | + File.separatorChar + modelFileName; |
---|
253 | // if (m_Debug) System.out.println("setting relative path to: |
---|
254 | // "+relativePath); |
---|
255 | setFileName(relativePath); |
---|
256 | |
---|
257 | if (!modelFile.exists()) { |
---|
258 | |
---|
259 | Date startTime = new Date(); |
---|
260 | |
---|
261 | String lockFileName = EnsembleSelectionLibraryModel |
---|
262 | .getFileName(getStringRepresentation()); |
---|
263 | lockFileName = lockFileName.substring(0, lockFileName.length() - 3) |
---|
264 | + "LCK"; |
---|
265 | File lockFile = new File(dataDirectoryName, lockFileName); |
---|
266 | |
---|
267 | if (lockFile.exists()) { |
---|
268 | if (m_Debug) |
---|
269 | System.out.println("Detected lock file. Skipping: " |
---|
270 | + lockFileName); |
---|
271 | throw new Exception("Lock File Detected: " + lockFile.getName()); |
---|
272 | |
---|
273 | } else { // if (algorithm == |
---|
274 | // EnsembleSelection.ALGORITHM_BUILD_LIBRARY) { |
---|
275 | // This lock file lets other computers that might be sharing the |
---|
276 | // same file |
---|
277 | // system that this model is already being trained so they know |
---|
278 | // to move ahead |
---|
279 | // and train other models. |
---|
280 | |
---|
281 | if (lockFile.createNewFile()) { |
---|
282 | |
---|
283 | if (m_Debug) |
---|
284 | System.out |
---|
285 | .println("lock file created: " + lockFileName); |
---|
286 | |
---|
287 | if (m_Debug) |
---|
288 | System.out.println("Creating model in locked mode: " |
---|
289 | + modelFile.getPath()); |
---|
290 | |
---|
291 | m_models = new Classifier[m_folds]; |
---|
292 | for (int i = 0; i < m_folds; ++i) { |
---|
293 | |
---|
294 | try { |
---|
295 | m_models[i] = AbstractClassifier.forName(getModelClass() |
---|
296 | .getName(), null); |
---|
297 | ((OptionHandler)m_models[i]).setOptions(getOptions()); |
---|
298 | } catch (Exception e) { |
---|
299 | throw new Exception("Invalid Options: " |
---|
300 | + e.getMessage()); |
---|
301 | } |
---|
302 | } |
---|
303 | |
---|
304 | try { |
---|
305 | for (int i = 0; i < m_folds; ++i) { |
---|
306 | train(data[i], i); |
---|
307 | } |
---|
308 | } catch (Exception e) { |
---|
309 | throw new Exception("Could not Train: " |
---|
310 | + e.getMessage()); |
---|
311 | } |
---|
312 | |
---|
313 | Date endTime = new Date(); |
---|
314 | int diff = (int) (endTime.getTime() - startTime.getTime()); |
---|
315 | |
---|
316 | // We don't need the actual model for hillclimbing. To save |
---|
317 | // memory, release |
---|
318 | // it. |
---|
319 | |
---|
320 | // if (!invalidModels.contains(model)) { |
---|
321 | // EnsembleLibraryModel.saveModel(dataDirectory.getPath(), |
---|
322 | // model); |
---|
323 | // model.releaseModel(); |
---|
324 | // } |
---|
325 | if (m_Debug) |
---|
326 | System.out.println("Train time for " + modelFileName |
---|
327 | + " was: " + diff); |
---|
328 | |
---|
329 | if (m_Debug) |
---|
330 | System.out |
---|
331 | .println("Generating validation set predictions"); |
---|
332 | |
---|
333 | startTime = new Date(); |
---|
334 | |
---|
335 | int total = 0; |
---|
336 | for (int i = 0; i < m_folds; ++i) { |
---|
337 | total += hillclimbData[i].numInstances(); |
---|
338 | } |
---|
339 | |
---|
340 | m_validationPredictions = new double[total][]; |
---|
341 | |
---|
342 | int preds_index = 0; |
---|
343 | for (int i = 0; i < m_folds; ++i) { |
---|
344 | for (int j = 0; j < hillclimbData[i].numInstances(); ++j) { |
---|
345 | Instance temp = (Instance) hillclimbData[i] |
---|
346 | .instance(j).copy();// new |
---|
347 | // Instance(m_hillclimbData[i].instance(j)); |
---|
348 | // must copy the instance because SOME classifiers |
---|
349 | // (I'm not pointing fingers... |
---|
350 | // MULTILAYERPERCEPTRON) |
---|
351 | // change the instance! |
---|
352 | |
---|
353 | m_validationPredictions[preds_index] = getFoldPrediction( |
---|
354 | temp, i); |
---|
355 | |
---|
356 | if (m_validationPredictions[preds_index] == null) { |
---|
357 | throw new Exception( |
---|
358 | "Null validation predictions given: " |
---|
359 | + getStringRepresentation()); |
---|
360 | } |
---|
361 | |
---|
362 | ++preds_index; |
---|
363 | } |
---|
364 | } |
---|
365 | |
---|
366 | endTime = new Date(); |
---|
367 | diff = (int) (endTime.getTime() - startTime.getTime()); |
---|
368 | |
---|
369 | // if (m_Debug) System.out.println("Generated a validation |
---|
370 | // set array of size: "+m_validationPredictions.length); |
---|
371 | if (m_Debug) |
---|
372 | System.out |
---|
373 | .println("Time to create validation predictions was: " |
---|
374 | + diff); |
---|
375 | |
---|
376 | EnsembleSelectionLibraryModel.saveModel(dataDirectoryName, |
---|
377 | this); |
---|
378 | |
---|
379 | if (m_Debug) |
---|
380 | System.out.println("deleting lock file: " |
---|
381 | + lockFileName); |
---|
382 | lockFile.delete(); |
---|
383 | |
---|
384 | } else { |
---|
385 | |
---|
386 | if (m_Debug) |
---|
387 | System.out |
---|
388 | .println("Could not create lock file. Skipping: " |
---|
389 | + lockFileName); |
---|
390 | throw new Exception( |
---|
391 | "Could not create lock file. Skipping: " |
---|
392 | + lockFile.getName()); |
---|
393 | |
---|
394 | } |
---|
395 | |
---|
396 | } |
---|
397 | |
---|
398 | } else { |
---|
399 | // This branch is responsible for loading a model from a .elm file |
---|
400 | |
---|
401 | if (m_Debug) |
---|
402 | System.out.println("Loading model: " + modelFile.getPath()); |
---|
403 | // now we need to check to see if the model is valid, if so then |
---|
404 | // load it |
---|
405 | Date startTime = new Date(); |
---|
406 | |
---|
407 | EnsembleSelectionLibraryModel newModel = loadModel(modelFile |
---|
408 | .getPath()); |
---|
409 | |
---|
410 | if (!newModel.getStringRepresentation().equals( |
---|
411 | getStringRepresentation())) |
---|
412 | throw new EnsembleModelMismatchException( |
---|
413 | "String representations " |
---|
414 | + newModel.getStringRepresentation() + " and " |
---|
415 | + getStringRepresentation() + " not equal"); |
---|
416 | |
---|
417 | if (!newModel.getChecksum().equals(getChecksum())) |
---|
418 | throw new EnsembleModelMismatchException("Checksums " |
---|
419 | + newModel.getChecksum() + " and " + getChecksum() |
---|
420 | + " not equal"); |
---|
421 | |
---|
422 | if (newModel.getSeed() != getSeed()) |
---|
423 | throw new EnsembleModelMismatchException("Seeds " |
---|
424 | + newModel.getSeed() + " and " + getSeed() |
---|
425 | + " not equal"); |
---|
426 | |
---|
427 | if (newModel.getFolds() != getFolds()) |
---|
428 | throw new EnsembleModelMismatchException("Folds " |
---|
429 | + newModel.getFolds() + " and " + getFolds() |
---|
430 | + " not equal"); |
---|
431 | |
---|
432 | if (newModel.getValidationRatio() != getValidationRatio()) |
---|
433 | throw new EnsembleModelMismatchException("Validation Ratios " |
---|
434 | + newModel.getValidationRatio() + " and " |
---|
435 | + getValidationRatio() + " not equal"); |
---|
436 | |
---|
437 | // setFileName(modelFileName); |
---|
438 | |
---|
439 | m_models = newModel.getModels(); |
---|
440 | m_validationPredictions = newModel.getValidationPredictions(); |
---|
441 | |
---|
442 | Date endTime = new Date(); |
---|
443 | int diff = (int) (endTime.getTime() - startTime.getTime()); |
---|
444 | if (m_Debug) |
---|
445 | System.out.println("Time to load " + modelFileName + " was: " |
---|
446 | + diff); |
---|
447 | } |
---|
448 | } |
---|
449 | |
---|
450 | /** |
---|
451 | * The purpose of this method is to "rehydrate" the classifier object fot |
---|
452 | * this library model from the filesystem. |
---|
453 | * |
---|
454 | * @param workingDirectory the working directory to use |
---|
455 | */ |
---|
456 | public void rehydrateModel(String workingDirectory) { |
---|
457 | |
---|
458 | if (m_models == null) { |
---|
459 | |
---|
460 | File file = new File(workingDirectory, m_fileName); |
---|
461 | |
---|
462 | if (m_Debug) |
---|
463 | System.out.println("Rehydrating Model: " + file.getPath()); |
---|
464 | EnsembleSelectionLibraryModel model = EnsembleSelectionLibraryModel |
---|
465 | .loadModel(file.getPath()); |
---|
466 | |
---|
467 | m_models = model.getModels(); |
---|
468 | |
---|
469 | } |
---|
470 | } |
---|
471 | |
---|
472 | /** |
---|
473 | * Releases the model from memory. TODO - need to be saving these so we can |
---|
474 | * retrieve them later!! |
---|
475 | */ |
---|
476 | public void releaseModel() { |
---|
477 | /* |
---|
478 | * if (m_unsaved) { saveModel(); } |
---|
479 | */ |
---|
480 | m_models = null; |
---|
481 | } |
---|
482 | |
---|
483 | /** |
---|
484 | * Train the classifier for the specified fold on the given data |
---|
485 | * |
---|
486 | * @param trainData the data to train with |
---|
487 | * @param fold the fold number |
---|
488 | * @throws Exception if something goes wrong, e.g., out of memory |
---|
489 | */ |
---|
490 | public void train(Instances trainData, int fold) throws Exception { |
---|
491 | if (m_models != null) { |
---|
492 | |
---|
493 | try { |
---|
494 | // OK, this is it... this is the point where our code surrenders |
---|
495 | // to the weka classifiers. |
---|
496 | m_models[fold].buildClassifier(trainData); |
---|
497 | } catch (Throwable t) { |
---|
498 | m_models[fold] = null; |
---|
499 | throw new Exception( |
---|
500 | "Exception caught while training: (null could mean out of memory)" |
---|
501 | + t.getMessage()); |
---|
502 | } |
---|
503 | |
---|
504 | } else { |
---|
505 | throw new Exception("Cannot train: model was null"); |
---|
506 | // TODO: throw Exception? |
---|
507 | } |
---|
508 | } |
---|
509 | |
---|
510 | /** |
---|
511 | * Set the seed |
---|
512 | * |
---|
513 | * @param seed the seed value |
---|
514 | */ |
---|
515 | public void setSeed(int seed) { |
---|
516 | m_seed = seed; |
---|
517 | } |
---|
518 | |
---|
519 | /** |
---|
520 | * Get the seed |
---|
521 | * |
---|
522 | * @return the seed value |
---|
523 | */ |
---|
524 | public int getSeed() { |
---|
525 | return m_seed; |
---|
526 | } |
---|
527 | |
---|
528 | /** |
---|
529 | * Sets the validation set ratio (only meaningful if folds == 1) |
---|
530 | * |
---|
531 | * @param validationRatio the new ration |
---|
532 | */ |
---|
533 | public void setValidationRatio(double validationRatio) { |
---|
534 | m_validationRatio = validationRatio; |
---|
535 | } |
---|
536 | |
---|
537 | /** |
---|
538 | * get validationRatio |
---|
539 | * |
---|
540 | * @return the current ratio |
---|
541 | */ |
---|
542 | public double getValidationRatio() { |
---|
543 | return m_validationRatio; |
---|
544 | } |
---|
545 | |
---|
546 | /** |
---|
547 | * Set the number of folds for cross validation. The number of folds also |
---|
548 | * indicates how many classifiers will be built to represent this model. |
---|
549 | * |
---|
550 | * @param folds the number of folds to use |
---|
551 | */ |
---|
552 | public void setFolds(int folds) { |
---|
553 | m_folds = folds; |
---|
554 | } |
---|
555 | |
---|
556 | /** |
---|
557 | * get the number of folds |
---|
558 | * |
---|
559 | * @return the current number of folds |
---|
560 | */ |
---|
561 | public int getFolds() { |
---|
562 | return m_folds; |
---|
563 | } |
---|
564 | |
---|
565 | /** |
---|
566 | * set the checksum |
---|
567 | * |
---|
568 | * @param instancesChecksum the new checksum |
---|
569 | */ |
---|
570 | public void setChecksum(String instancesChecksum) { |
---|
571 | m_checksum = instancesChecksum; |
---|
572 | } |
---|
573 | |
---|
574 | /** |
---|
575 | * get the checksum |
---|
576 | * |
---|
577 | * @return the current checksum |
---|
578 | */ |
---|
579 | public String getChecksum() { |
---|
580 | return m_checksum; |
---|
581 | } |
---|
582 | |
---|
583 | /** |
---|
584 | * Returs the array of classifiers |
---|
585 | * |
---|
586 | * @return the current models |
---|
587 | */ |
---|
588 | public Classifier[] getModels() { |
---|
589 | return m_models; |
---|
590 | } |
---|
591 | |
---|
592 | /** |
---|
593 | * Sets the .elm file name for this library model |
---|
594 | * |
---|
595 | * @param fileName the new filename |
---|
596 | */ |
---|
597 | public void setFileName(String fileName) { |
---|
598 | m_fileName = fileName; |
---|
599 | } |
---|
600 | |
---|
601 | /** |
---|
602 | * Gets a checksum for the string defining this classifier. This is used to |
---|
603 | * preserve uniqueness in the classifier names. |
---|
604 | * |
---|
605 | * @param string the classifier definition |
---|
606 | * @return the checksum string |
---|
607 | */ |
---|
608 | public static String getStringChecksum(String string) { |
---|
609 | |
---|
610 | String checksumString = null; |
---|
611 | |
---|
612 | try { |
---|
613 | |
---|
614 | Adler32 checkSummer = new Adler32(); |
---|
615 | |
---|
616 | byte[] utf8 = string.toString().getBytes("UTF8"); |
---|
617 | ; |
---|
618 | |
---|
619 | checkSummer.update(utf8); |
---|
620 | checksumString = Long.toHexString(checkSummer.getValue()); |
---|
621 | |
---|
622 | } catch (UnsupportedEncodingException e) { |
---|
623 | // TODO Auto-generated catch block |
---|
624 | e.printStackTrace(); |
---|
625 | } |
---|
626 | |
---|
627 | return checksumString; |
---|
628 | } |
---|
629 | |
---|
630 | /** |
---|
631 | * The purpose of this method is to get an appropriate file name for a model |
---|
632 | * based on its string representation of a model. All generated filenames |
---|
633 | * are limited to less than 128 characters and all of them will end with a |
---|
634 | * 64 bit checksum value of their string representation to try to maintain |
---|
635 | * some uniqueness of file names. |
---|
636 | * |
---|
637 | * @param stringRepresentation string representation of model |
---|
638 | * @return unique filename |
---|
639 | */ |
---|
640 | public static String getFileName(String stringRepresentation) { |
---|
641 | |
---|
642 | // Get rid of space and quote marks(windows doesn't lke them) |
---|
643 | String fileName = stringRepresentation.trim().replace(' ', '_') |
---|
644 | .replace('"', '_'); |
---|
645 | |
---|
646 | if (fileName.length() > 115) { |
---|
647 | |
---|
648 | fileName = fileName.substring(0, 115); |
---|
649 | |
---|
650 | } |
---|
651 | |
---|
652 | fileName += getStringChecksum(stringRepresentation) |
---|
653 | + EnsembleSelectionLibraryModel.FILE_EXTENSION; |
---|
654 | |
---|
655 | return fileName; |
---|
656 | } |
---|
657 | |
---|
658 | /** |
---|
659 | * Saves the given model to the specified file. |
---|
660 | * |
---|
661 | * @param directory the directory to save the model to |
---|
662 | * @param model the model to save |
---|
663 | */ |
---|
664 | public static void saveModel(String directory, |
---|
665 | EnsembleSelectionLibraryModel model) { |
---|
666 | |
---|
667 | try { |
---|
668 | String fileName = getFileName(model.getStringRepresentation()); |
---|
669 | |
---|
670 | File file = new File(directory, fileName); |
---|
671 | |
---|
672 | // System.out.println("Saving model: "+file.getPath()); |
---|
673 | |
---|
674 | // model.setFileName(new String(file.getPath())); |
---|
675 | |
---|
676 | // Serialize to a file |
---|
677 | ObjectOutput out = new ObjectOutputStream( |
---|
678 | new FileOutputStream(file)); |
---|
679 | out.writeObject(model); |
---|
680 | |
---|
681 | out.close(); |
---|
682 | |
---|
683 | } catch (IOException e) { |
---|
684 | |
---|
685 | e.printStackTrace(); |
---|
686 | } |
---|
687 | } |
---|
688 | |
---|
689 | /** |
---|
690 | * loads the specified model |
---|
691 | * |
---|
692 | * @param modelFilePath the path of the model |
---|
693 | * @return the model |
---|
694 | */ |
---|
695 | public static EnsembleSelectionLibraryModel loadModel(String modelFilePath) { |
---|
696 | |
---|
697 | EnsembleSelectionLibraryModel model = null; |
---|
698 | |
---|
699 | try { |
---|
700 | |
---|
701 | File file = new File(modelFilePath); |
---|
702 | |
---|
703 | ObjectInputStream in = new ObjectInputStream(new FileInputStream( |
---|
704 | file)); |
---|
705 | |
---|
706 | model = (EnsembleSelectionLibraryModel) in.readObject(); |
---|
707 | |
---|
708 | in.close(); |
---|
709 | |
---|
710 | } catch (ClassNotFoundException e) { |
---|
711 | |
---|
712 | e.printStackTrace(); |
---|
713 | |
---|
714 | } catch (IOException e) { |
---|
715 | |
---|
716 | e.printStackTrace(); |
---|
717 | |
---|
718 | } |
---|
719 | |
---|
720 | return model; |
---|
721 | } |
---|
722 | |
---|
723 | /* |
---|
724 | * Problems persist in this code so we left it commented out. The intent was |
---|
725 | * to create the methods necessary for custom serialization to allow for |
---|
726 | * forwards/backwards compatability of .elm files accross multiple versions |
---|
727 | * of this classifier. The main problem however is that these methods do not |
---|
728 | * appear to be called. I'm not sure what the problem is, but this would be |
---|
729 | * a great feature. If anyone is a seasoned veteran of this serialization |
---|
730 | * stuff, please help! |
---|
731 | * |
---|
732 | * private void writeObject(ObjectOutputStream stream) throws IOException { |
---|
733 | * //stream.defaultWriteObject(); //stream.writeObject(b); |
---|
734 | * |
---|
735 | * //first serialize the LibraryModel fields |
---|
736 | * |
---|
737 | * //super.writeObject(stream); |
---|
738 | * |
---|
739 | * //now serialize the LibraryModel fields |
---|
740 | * |
---|
741 | * stream.writeObject(m_Classifier); |
---|
742 | * |
---|
743 | * stream.writeObject(m_DescriptionText); |
---|
744 | * |
---|
745 | * stream.writeObject(m_ErrorText); |
---|
746 | * |
---|
747 | * stream.writeObject(new Boolean(m_OptionsWereValid)); |
---|
748 | * |
---|
749 | * stream.writeObject(m_StringRepresentation); |
---|
750 | * |
---|
751 | * stream.writeObject(m_models); |
---|
752 | * |
---|
753 | * |
---|
754 | * //now serialize the EnsembleLibraryModel fields //stream.writeObject(new |
---|
755 | * String("blah")); |
---|
756 | * |
---|
757 | * stream.writeObject(new Integer(m_seed)); |
---|
758 | * |
---|
759 | * stream.writeObject(m_checksum); |
---|
760 | * |
---|
761 | * stream.writeObject(new Double(m_validationRatio)); |
---|
762 | * |
---|
763 | * stream.writeObject(new Integer(m_folds)); |
---|
764 | * |
---|
765 | * stream.writeObject(m_fileName); |
---|
766 | * |
---|
767 | * stream.writeObject(new Boolean(m_isTrained)); |
---|
768 | * |
---|
769 | * |
---|
770 | * if (m_validationPredictions == null) { |
---|
771 | * } |
---|
772 | * |
---|
773 | * if (m_Debug) System.out.println("Saving |
---|
774 | * "+m_validationPredictions.length+" indexed array"); |
---|
775 | * stream.writeObject(m_validationPredictions); |
---|
776 | * } |
---|
777 | * |
---|
778 | * private void readObject(ObjectInputStream stream) throws IOException, |
---|
779 | * ClassNotFoundException { //stream.defaultReadObject(); //b = (String) |
---|
780 | * stream.readObject(); |
---|
781 | * |
---|
782 | * //super.readObject(stream); |
---|
783 | * |
---|
784 | * //deserialize the LibraryModel fields m_Classifier = |
---|
785 | * (Classifier)stream.readObject(); |
---|
786 | * |
---|
787 | * m_DescriptionText = (String)stream.readObject(); |
---|
788 | * |
---|
789 | * m_ErrorText = (String)stream.readObject(); |
---|
790 | * |
---|
791 | * m_OptionsWereValid = ((Boolean)stream.readObject()).booleanValue(); |
---|
792 | * |
---|
793 | * m_StringRepresentation = (String)stream.readObject(); |
---|
794 | * |
---|
795 | * |
---|
796 | * |
---|
797 | * //now deserialize the EnsembleLibraryModel fields m_models = |
---|
798 | * (Classifier[])stream.readObject(); |
---|
799 | * |
---|
800 | * m_seed = ((Integer)stream.readObject()).intValue(); |
---|
801 | * |
---|
802 | * m_checksum = (String)stream.readObject(); |
---|
803 | * |
---|
804 | * m_validationRatio = ((Double)stream.readObject()).doubleValue(); |
---|
805 | * |
---|
806 | * m_folds = ((Integer)stream.readObject()).intValue(); |
---|
807 | * |
---|
808 | * m_fileName = (String)stream.readObject(); |
---|
809 | * |
---|
810 | * m_isTrained = ((Boolean)stream.readObject()).booleanValue(); |
---|
811 | * |
---|
812 | * m_validationPredictions = (double[][])stream.readObject(); |
---|
813 | * |
---|
814 | * if (m_Debug) System.out.println("Loaded |
---|
815 | * "+m_validationPredictions.length+" indexed array"); } |
---|
816 | * |
---|
817 | */ |
---|
818 | |
---|
819 | /** |
---|
820 | * getter for validation predictions |
---|
821 | * |
---|
822 | * @return the current validation predictions |
---|
823 | */ |
---|
824 | public double[][] getValidationPredictions() { |
---|
825 | return m_validationPredictions; |
---|
826 | } |
---|
827 | |
---|
828 | /** |
---|
829 | * setter for validation predictions |
---|
830 | * |
---|
831 | * @param predictions the new validation predictions |
---|
832 | */ |
---|
833 | public void setValidationPredictions(double[][] predictions) { |
---|
834 | if (m_Debug) |
---|
835 | System.out.println("Saving validation array of size " |
---|
836 | + predictions.length); |
---|
837 | m_validationPredictions = new double[predictions.length][]; |
---|
838 | System.arraycopy(predictions, 0, m_validationPredictions, 0, |
---|
839 | predictions.length); |
---|
840 | } |
---|
841 | |
---|
842 | /** |
---|
843 | * Returns the revision string. |
---|
844 | * |
---|
845 | * @return the revision |
---|
846 | */ |
---|
847 | public String getRevision() { |
---|
848 | return RevisionUtils.extract("$Revision: 5928 $"); |
---|
849 | } |
---|
850 | } |
---|