source: src/main/java/weka/gui/beans/SerializedModelSaver.java @ 20

Last change on this file since 20 was 4, checked in by gnappo, 14 years ago

Import di weka.

File size: 20.8 KB
Line 
1/*
2 *    This program is free software; you can redistribute it and/or modify
3 *    it under the terms of the GNU General Public License as published by
4 *    the Free Software Foundation; either version 2 of the License, or
5 *    (at your option) any later version.
6 *
7 *    This program is distributed in the hope that it will be useful,
8 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
9 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
10 *    GNU General Public License for more details.
11 *
12 *    You should have received a copy of the GNU General Public License
13 *    along with this program; if not, write to the Free Software
14 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
15 */
16
17/*
18 *    SerializedModelSaver.java
19 *    Copyright (C) 2008 University of Waikato, Hamilton, New Zealand
20 *
21 */
22
23package weka.gui.beans;
24
25import java.io.ObjectInputStream;
26import java.io.Serializable;
27import java.io.File;
28import java.io.ObjectOutputStream;
29import java.io.FileOutputStream;
30import java.io.BufferedOutputStream;
31import java.io.IOException;
32import java.awt.BorderLayout;
33import java.beans.EventSetDescriptor;
34import java.util.ArrayList;
35import java.util.Vector;
36import javax.swing.JPanel;
37
38import weka.classifiers.Classifier;
39import weka.classifiers.AbstractClassifier;
40import weka.core.Instances;
41import weka.core.Environment;
42import weka.core.EnvironmentHandler;
43import weka.core.xml.KOML;
44import weka.core.xml.XStream;
45import weka.core.Tag;
46import weka.core.Utils;
47
48/**
49 * A bean that saves serialized models
50 *
51 * @author Mark Hall (mhall{[at]}pentaho{[dot]}org
52 * @version $Revision: 5928 $
53 */
54public class SerializedModelSaver
55  extends JPanel
56  implements BeanCommon, Visible, BatchClassifierListener, 
57             IncrementalClassifierListener, BatchClustererListener,
58             EnvironmentHandler, Serializable {
59
60  /** for serialization */
61  private static final long serialVersionUID = 3956528599473814287L;
62
63  /**
64   * Default visual for data sources
65   */
66  protected BeanVisual m_visual = 
67    new BeanVisual("AbstractDataSink", 
68                   BeanVisual.ICON_PATH+"SerializedModelSaver.gif",
69                   BeanVisual.ICON_PATH+"SerializedModelSaver_animated.gif");
70
71  /**
72   * Non null if this object is a target for any events.
73   * Provides for the simplest case when only one incomming connection
74   * is allowed.
75   */
76  protected Object m_listenee = null;
77
78  /**
79   * The log for this bean
80   */
81  protected transient weka.gui.Logger m_logger = null;
82
83  /**
84   * The prefix for the file name (model + training set info will be appended)
85   */
86  private String m_filenamePrefix = "";
87
88  /**
89   * The directory to hold the saved model(s)
90   */
91  private File m_directory = new File(System.getProperty("user.dir"));
92
93  /**
94   * File format stuff
95   */
96  private Tag m_fileFormat;
97
98  public final static int BINARY = 0;
99  public final static int KOMLV = 1;
100  public final static int XSTREAM = 2;
101
102  /** the extension for serialized models (binary Java serialization) */
103  public final static String FILE_EXTENSION = "model";
104
105  /** relative path for the directory (relative to the user.dir (startup directory))? */
106  private boolean m_useRelativePath = false;
107 
108  /** include relation name in filename */
109  private boolean m_includeRelationName = false;
110
111  /**
112   * Available file formats. Reflection is used to check if classes
113   * are available for deep object serialization to XML
114   */
115  public static ArrayList<Tag> s_fileFormatsAvailable;
116  static {
117    s_fileFormatsAvailable = new ArrayList<Tag>();
118    s_fileFormatsAvailable.add(new Tag(BINARY, "Binary serialized model file (*"
119                                       + FILE_EXTENSION + ")", "", false));
120    if (KOML.isPresent()) {
121      s_fileFormatsAvailable.add(new Tag(KOMLV,
122                                         "XML serialized model file (*"
123                                         + KOML.FILE_EXTENSION + FILE_EXTENSION + ")", "", false));
124    }
125
126    if (XStream.isPresent()) {
127      s_fileFormatsAvailable.add(new Tag(XSTREAM,
128                                         "XML serialized model file (*"
129                                         + XStream.FILE_EXTENSION + FILE_EXTENSION + ")", "", false));
130    }
131  }
132 
133  /**
134   * The environment variables.
135   */
136  protected transient Environment m_env;
137
138  /**
139   * Constructor.
140   */
141  public SerializedModelSaver() {
142    useDefaultVisual();
143    setLayout(new BorderLayout());
144    add(m_visual, BorderLayout.CENTER);
145    m_fileFormat = s_fileFormatsAvailable.get(0);
146   
147    m_env = Environment.getSystemWide();
148  }
149
150  /**
151   * Set a custom (descriptive) name for this bean
152   *
153   * @param name the name to use
154   */
155  public void setCustomName(String name) {
156    m_visual.setText(name);
157  }
158
159  /**
160   * Get the custom (descriptive) name for this bean (if one has been set)
161   *
162   * @return the custom name (or the default name)
163   */
164  public String getCustomName() {
165    return m_visual.getText();
166  }
167
168  /**
169   * Use the default images for this bean.
170   *
171   */
172  public void useDefaultVisual() {
173    m_visual.loadIcons(BeanVisual.ICON_PATH+"SerializedModelSaver.gif",
174                       BeanVisual.ICON_PATH+"SerializedModelSaver_animated.gif");
175    m_visual.setText("SerializedModelSaver");
176  }
177
178  /**
179   * Set the visual for this data source.
180   *
181   * @param newVisual a <code>BeanVisual</code> value
182   */
183  public void setVisual(BeanVisual newVisual) {
184    m_visual = newVisual;
185  }
186
187  /**
188   * Get the visual being used by this data source.
189   *
190   */
191  public BeanVisual getVisual() {
192    return m_visual;
193  }
194
195  /**
196   * Returns true if, at this time,
197   * the object will accept a connection according to the supplied
198   * EventSetDescriptor.
199   *
200   * @param esd the EventSetDescriptor
201   * @return true if the object will accept a connection
202   */
203  public boolean connectionAllowed(EventSetDescriptor esd) {
204    return connectionAllowed(esd.getName());
205  }
206
207  /**
208   * Returns true if, at this time,
209   * the object will accept a connection according to the supplied
210   * event name.
211   *
212   * @param eventName the event
213   * @return true if the object will accept a connection
214   */
215  public boolean connectionAllowed(String eventName) {
216    return (m_listenee == null);
217  }
218
219  /**
220   * Notify this object that it has been registered as a listener with
221   * a source with respect to the supplied event name.
222   *
223   * @param eventName the event
224   * @param source the source with which this object has been registered as
225   * a listener
226   */
227  public synchronized void connectionNotification(String eventName,
228                                                  Object source) {
229    if (connectionAllowed(eventName)) {
230      m_listenee = source;
231    }
232  }
233
234  /**
235   * Notify this object that it has been deregistered as a listener with
236   * a source with respect to the supplied event name.
237   *
238   * @param eventName the event
239   * @param source the source with which this object has been registered as
240   * a listener
241   */
242  public synchronized void disconnectionNotification(String eventName,
243                                                     Object source) {
244    if (m_listenee == source) {
245      m_listenee = null;
246    }
247  }
248 
249  /**
250   * Set a log for this bean.
251   *
252   * @param logger a <code>weka.gui.Logger</code> value
253   */
254  public void setLog(weka.gui.Logger logger) {
255    m_logger = logger;
256  }
257
258  /**
259   * Stop any processing that the bean might be doing.
260   */
261  public void stop() {
262    // tell the listenee (upstream bean) to stop
263    if (m_listenee instanceof BeanCommon) {
264      ((BeanCommon)m_listenee).stop();
265    }
266  }
267 
268  /**
269   * Returns true if. at this time, the bean is busy with some
270   * (i.e. perhaps a worker thread is performing some calculation).
271   *
272   * @return true if the bean is busy.
273   */
274  public boolean isBusy() {
275    return false;
276  }
277
278  /**
279   * makes sure that the filename is valid, i.e., replaces slashes,
280   * backslashes and colons with underscores ("_").
281   *
282   * @param filename    the filename to cleanse
283   * @return            the cleansed filename
284   */
285  protected String sanitizeFilename(String filename) {
286    return filename.replaceAll("\\\\", "_").replaceAll(":", "_").replaceAll("/", "_");
287  }
288
289  /**
290   * Accept and save a batch trained clusterer.
291   *
292   * @param ce a <code>BatchClassifierEvent</code> value
293   */
294  public void acceptClusterer(BatchClustererEvent ce) {
295    if (ce.getTestSet() == null || 
296        ce.getTestOrTrain() == BatchClustererEvent.TEST ||
297        ce.getTestSet().isStructureOnly()) {
298      return;
299    }
300
301    Instances trainHeader = new Instances(ce.getTestSet().getDataSet(), 0);
302    String titleString = ce.getClusterer().getClass().getName();                     
303    titleString = titleString.
304      substring(titleString.lastIndexOf('.') + 1,
305                titleString.length());
306
307    String prefix = "";
308    String relationName = (m_includeRelationName)
309    ? trainHeader.relationName()
310    : "";
311    try {
312      prefix = m_env.substitute(m_filenamePrefix);
313    } catch (Exception ex) {
314      stop(); // stop all processing
315      String message = "[SerializedModelSaver] " 
316        + statusMessagePrefix() 
317        + " Can't save model. Reason: " 
318        + ex.getMessage();
319      if (m_logger != null) {
320        m_logger.logMessage(message);
321        m_logger.statusMessage(statusMessagePrefix()
322            + "ERROR (See log for details)");
323      } else {
324        System.err.println(message);
325      }
326      return;
327    }
328    String fileName = "" 
329      + prefix
330      + relationName
331      + titleString
332      + "_"
333      + ce.getSetNumber() 
334      + "_" + ce.getMaxSetNumber();
335    fileName = sanitizeFilename(fileName);
336   
337    String dirName = m_directory.getPath();
338    try {
339      dirName = m_env.substitute(dirName);
340    } catch (Exception ex) {
341      stop(); // stop all processing
342      String message = "[SerializedModelSaver] "
343        + statusMessagePrefix() + " Can't save model. Reason: " 
344                           + ex.getMessage();
345      if (m_logger != null) {
346        m_logger.logMessage(message);
347        m_logger.statusMessage(statusMessagePrefix()
348            + "ERROR (See log for details)");
349      } else {
350        System.err.println(message);
351      }
352      return;
353    }
354    File tempFile = new File(dirName);
355    fileName = tempFile.getAbsolutePath() 
356      + File.separator
357      + fileName;
358
359    saveModel(fileName, trainHeader, ce.getClusterer());
360  }
361
362  /**
363   * Accept and save an incrementally trained classifier.
364   *
365   * @param ce the BatchClassifierEvent containing the classifier
366   */
367  public void acceptClassifier(final IncrementalClassifierEvent ce) {
368    if (ce.getStatus() == IncrementalClassifierEvent.BATCH_FINISHED) {
369      // Only save model when the end of the stream is reached
370      Instances header = ce.getStructure();
371      String titleString = ce.getClassifier().getClass().getName();                   
372      titleString = titleString.
373        substring(titleString.lastIndexOf('.') + 1,
374                  titleString.length());
375
376      String prefix = "";
377      String relationName = (m_includeRelationName)
378        ? header.relationName()
379        : "";
380       
381      try {
382        prefix = m_env.substitute(m_filenamePrefix);
383      } catch (Exception ex) {
384        stop(); // stop processing
385        String message = "[SerializedModelSaver] "
386          + statusMessagePrefix() + " Can't save model. Reason: " 
387          + ex.getMessage();
388        if (m_logger != null) {
389          m_logger.logMessage(message);
390          m_logger.statusMessage(statusMessagePrefix()
391              + "ERROR (See log for details)");
392        } else {
393          System.err.println(message);
394        }
395        return;
396      }
397     
398      String fileName = "" + prefix + relationName + titleString;
399      fileName = sanitizeFilename(fileName);
400
401      String dirName = m_directory.getPath();
402      try {
403        dirName = m_env.substitute(dirName);
404      } catch (Exception ex) {
405        stop(); // stop processing
406        String message = "[SerializedModelSaver] "
407          + statusMessagePrefix() + " Can't save model. Reason: " 
408          + ex.getMessage();
409        if (m_logger != null) {
410          m_logger.logMessage(message);
411          m_logger.statusMessage(statusMessagePrefix()
412              + "ERROR (See log for details)");
413        } else {
414          System.err.println(message);
415        }
416        return;
417      }
418      File tempFile = new File(dirName);
419
420      fileName = tempFile.getAbsolutePath() 
421        + File.separator
422        + fileName;
423     
424      saveModel(fileName, header, ce.getClassifier());
425    }
426  }
427 
428  /**
429   * Accept and save a batch trained classifier.
430   *
431   * @param ce the BatchClassifierEvent containing the classifier
432   */
433  public void acceptClassifier(final BatchClassifierEvent ce) {
434    if (ce.getTrainSet() == null || 
435        ce.getTrainSet().isStructureOnly()) {
436      return;
437    }
438    Instances trainHeader = new Instances(ce.getTrainSet().getDataSet(), 0);
439    String titleString = ce.getClassifier().getClass().getName();                     
440    titleString = titleString.
441      substring(titleString.lastIndexOf('.') + 1,
442                titleString.length());
443
444    String prefix = "";
445    String relationName = (m_includeRelationName)
446    ? trainHeader.relationName()
447    : "";
448    try {
449      prefix = m_env.substitute(m_filenamePrefix);
450    } catch (Exception ex) {
451      stop(); // stop processing
452      String message = "[SerializedModelSaver] "
453        + statusMessagePrefix() + " Can't save model. Reason: " 
454        + ex.getMessage();
455      if (m_logger != null) {
456        m_logger.logMessage(message);
457        m_logger.statusMessage(statusMessagePrefix()
458            + "ERROR (See log for details)");
459      } else {
460        System.err.println(message);
461      }
462      return;
463    }
464
465    String fileName = "" 
466      + prefix
467      + relationName
468      + titleString
469      + "_"
470      + ce.getSetNumber() 
471      + "_" + ce.getMaxSetNumber();
472    fileName = sanitizeFilename(fileName);
473   
474    String dirName = m_directory.getPath();
475    try {
476      dirName = m_env.substitute(dirName);
477    } catch (Exception ex) {
478      stop(); // stop processing
479      String message = "[SerializedModelSaver] "
480        + statusMessagePrefix() + " Can't save model. Reason: " 
481                           + ex.getMessage();
482      if (m_logger != null) {
483        m_logger.logMessage(message);
484        m_logger.statusMessage(statusMessagePrefix()
485            + "ERROR (See log for details)");
486      } else {
487        System.err.println(message);
488      }
489      return;
490    }
491    File tempFile = new File(dirName);
492
493    fileName = tempFile.getAbsolutePath() 
494      + File.separator
495      + fileName;
496
497    saveModel(fileName, trainHeader, ce.getClassifier());
498  }
499
500  /**
501   * Helper routine to actually save the models.
502   */
503  private void saveModel(String fileName, Instances trainHeader, Object model) {
504    m_fileFormat = validateFileFormat(m_fileFormat);
505    if (m_fileFormat == null) {
506      // default to binary if validation fails
507      m_fileFormat = s_fileFormatsAvailable.get(0);
508    }
509    try {
510      switch (m_fileFormat.getID()) {
511      case KOMLV:
512        fileName = fileName + KOML.FILE_EXTENSION + FILE_EXTENSION;
513        saveKOML(new File(fileName), model, trainHeader);
514        break;
515      case XSTREAM:
516        fileName = fileName + XStream.FILE_EXTENSION + FILE_EXTENSION;
517        saveXStream(new File(fileName), model, trainHeader);
518        break;
519      default:
520        fileName = fileName + "." + FILE_EXTENSION;
521        saveBinary(new File(fileName), model, trainHeader);
522        break;
523      }       
524    } catch (Exception ex) {
525      stop(); // stop all processing
526      System.err.println("[SerializedModelSaver] Problem saving model");
527      if (m_logger != null) {
528        m_logger.logMessage("[SerializedModelSaver] "
529            + statusMessagePrefix() + " Problem saving model");
530        m_logger.statusMessage(statusMessagePrefix()
531            + "ERROR (See log for details)");
532      }
533    }
534  }
535
536  /**
537   * Save a model in binary form.
538   *
539   * @param saveTo the file name to save to
540   * @param model the model to save
541   * @param header the header of the data that was used to train the model (optional)
542   */
543  public static void saveBinary(File saveTo, Object model, Instances header) throws IOException {
544    ObjectOutputStream os =
545      new ObjectOutputStream(new BufferedOutputStream(
546                             new FileOutputStream(saveTo)));
547    os.writeObject(model);
548    // now the header
549    if (header != null) {
550      os.writeObject(header);
551    }
552    os.close();
553  }
554
555  /**
556   * Save a model in KOML deep object serialized XML form.
557   *
558   * @param saveTo the file name to save to
559   * @param model the model to save
560   * @param header the header of the data that was used to train the model (optional)
561   */
562  public static void saveKOML(File saveTo, Object model, Instances header) throws Exception {
563    Vector v = new Vector();
564    v.add(model);
565    if (header != null) {
566      v.add(header);
567    }
568    v.trimToSize();
569    KOML.write(saveTo.getAbsolutePath(), v);
570  }
571
572  /**
573   * Save a model in XStream deep object serialized XML form.
574   *
575   * @param saveTo the file name to save to
576   * @param model the model to save
577   * @param header the header of the data that was used to train the model (optional)
578   */
579  public static void saveXStream(File saveTo, Object model, Instances header) throws Exception {
580    Vector v = new Vector();
581    v.add(model);
582    if (header != null) {
583      v.add(header);
584    }
585    v.trimToSize();
586    XStream.write(saveTo.getAbsolutePath(), v);
587  }
588
589  /**
590   * Get the directory that the model(s) will be saved into
591   *
592   * @return the directory to save to
593   */
594  public File getDirectory() {
595    return m_directory;
596  }
597 
598  /**
599   * Set the directory that the model(s) will be saved into.
600   *
601   * @param d the directory to save to
602   */
603  public void setDirectory(File d) {
604    m_directory = d;
605    if (m_useRelativePath) {
606      try {
607        m_directory = Utils.convertToRelativePath(m_directory);
608      } catch (Exception ex) {
609      }
610    }
611  }
612
613  /**
614   * Set whether to use relative paths for the directory.
615   * I.e. relative to the startup (user.dir) directory
616   *
617   * @param rp true if relative paths are to be used
618   */
619  public void setUseRelativePath(boolean rp) {
620    m_useRelativePath = rp;
621  }
622 
623  /**
624   * Get whether to use relative paths for the directory.
625   * I.e. relative to the startup (user.dir) directory
626   *
627   * @return true if relative paths are to be used
628   */
629  public boolean getUseRelativePath() {
630    return m_useRelativePath;
631  }
632 
633  /**
634   * Set whether the relation name of the training data
635   * used to create the model should be included as part
636   * of the filename for the serialized model.
637   *
638   * @param rn true if the relation name should be included
639   * in the file name
640   */
641  public void setIncludeRelationName(boolean rn) {
642    m_includeRelationName = rn;
643  }
644 
645  /**
646   * Get whether the relation name of the training
647   * data used to create the model is to be included
648   * in the filename of the serialized model.
649   *
650   * @return true if the relation name is to be included
651   * in the file name
652   */
653  public boolean getIncludeRelationName() {
654    return m_includeRelationName;
655  }
656
657  /**
658   * Get the prefix to prepend to the model file names.
659   *
660   * @return the prefix to prepend
661   */
662  public String getPrefix() {
663    return m_filenamePrefix;
664  }
665
666  /**
667   * Set the prefix to prepend to the model file names.
668   *
669   * @param p the prefix to prepend
670   */
671  public void setPrefix(String p) {
672    m_filenamePrefix = p;
673  }
674
675  /**
676   * Global info for this bean. Gets displayed in the GUI.
677   *
678   * @return information about this bean.
679   */
680  public String globalInfo() {
681    return "Save trained models to serialized object files.";
682  }
683
684  /**
685   * Set the file format to use for saving.
686   *
687   * @param ff the file format to use
688   */
689  public void setFileFormat(Tag ff) {
690    m_fileFormat = ff;
691  }
692
693  /**
694   * Get the file format to use for saving.
695   *
696   * @return the file format to use
697   */
698  public Tag getFileFormat() {
699    return m_fileFormat;
700  }
701
702  /**
703   * Validate the file format. After this bean is deserialized, classes for
704   * XML serialization may not be in the classpath any more.
705   *
706   * @param ff the current file format to validate
707   */
708  public Tag validateFileFormat(Tag ff) {
709    Tag r = ff;
710    if (ff.getID() == BINARY) {
711      return ff;
712    }
713
714    if (ff.getID() == KOMLV && !KOML.isPresent()) {
715      r = null;
716    }
717
718    if (ff.getID() == XSTREAM && !XStream.isPresent()) {
719      r = null;
720    }
721
722    return r;
723  }
724 
725  private String statusMessagePrefix() {
726    return getCustomName() + "$" + hashCode() + "|";
727  }
728 
729  /**
730   * Set environment variables to use.
731   *
732   * @param env the environment variables to
733   * use
734   */
735  public void setEnvironment(Environment env) {
736    m_env = env;
737  }
738 
739  // Custom de-serialization in order to set default
740  // environment variables on de-serialization
741  private void readObject(ObjectInputStream aStream) 
742    throws IOException, ClassNotFoundException {
743    aStream.defaultReadObject();
744   
745    // set a default environment to use
746    m_env = Environment.getSystemWide();
747  }
748}
Note: See TracBrowser for help on using the repository browser.