source: src/main/java/weka/gui/beans/TrainTestSplitMaker.java @ 24

Last change on this file since 24 was 4, checked in by gnappo, 14 years ago

Import di weka.

File size: 10.2 KB
Line 
1/*
2 *    This program is free software; you can redistribute it and/or modify
3 *    it under the terms of the GNU General Public License as published by
4 *    the Free Software Foundation; either version 2 of the License, or
5 *    (at your option) any later version.
6 *
7 *    This program is distributed in the hope that it will be useful,
8 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
9 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
10 *    GNU General Public License for more details.
11 *
12 *    You should have received a copy of the GNU General Public License
13 *    along with this program; if not, write to the Free Software
14 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
15 */
16
17/*
18 *    TrainTestSplitMaker.java
19 *    Copyright (C) 2002 University of Waikato, Hamilton, New Zealand
20 *
21 */
22
23package weka.gui.beans;
24
25import weka.core.Instances;
26
27import java.io.Serializable;
28import java.util.Enumeration;
29import java.util.Random;
30import java.util.Vector;
31
32/**
33 * Bean that accepts data sets, training sets, test sets and produces
34 * both a training and test set by randomly splitting the data
35 *
36 * @author <a href="mailto:mhall@cs.waikato.ac.nz">Mark Hall</a>
37 * @version $Revision: 4813 $
38 */
39public class TrainTestSplitMaker
40  extends AbstractTrainAndTestSetProducer
41  implements DataSourceListener, TrainingSetListener, TestSetListener,
42             UserRequestAcceptor, EventConstraints, Serializable {
43
44  /** for serialization */
45  private static final long serialVersionUID = 7390064039444605943L;
46
47  private double m_trainPercentage = 66;
48  private int m_randomSeed = 1;
49 
50  private Thread m_splitThread = null;
51
52  public TrainTestSplitMaker() {
53         m_visual.loadIcons(BeanVisual.ICON_PATH
54                       +"TrainTestSplitMaker.gif",
55                       BeanVisual.ICON_PATH
56                       +"TrainTestSplittMaker_animated.gif");
57    m_visual.setText("TrainTestSplitMaker");
58  }
59
60  /**
61   * Set a custom (descriptive) name for this bean
62   *
63   * @param name the name to use
64   */
65  public void setCustomName(String name) {
66    m_visual.setText(name);
67  }
68
69  /**
70   * Get the custom (descriptive) name for this bean (if one has been set)
71   *
72   * @return the custom name (or the default name)
73   */
74  public String getCustomName() {
75    return m_visual.getText();
76  }
77
78  /**
79   * Global info for this bean
80   *
81   * @return a <code>String</code> value
82   */
83  public String globalInfo() {
84    return "Split an incoming data set into separate train and test sets." ;
85  }
86
87  /**
88   * Tip text info for this property
89   *
90   * @return a <code>String</code> value
91   */
92  public String trainPercentTipText() {
93    return "The percentage of data to go into the training set";
94  }
95
96  /**
97   * Set the percentage of data to be in the training portion of the split
98   *
99   * @param newTrainPercent an <code>int</code> value
100   */
101  public void setTrainPercent(double newTrainPercent) {
102    m_trainPercentage = newTrainPercent;
103  }
104
105  /**
106   * Get the percentage of the data that will be in the training portion of
107   * the split
108   *
109   * @return an <code>int</code> value
110   */
111  public double getTrainPercent() {
112    return m_trainPercentage;
113  }
114
115  /**
116   * Tip text for this property
117   *
118   * @return a <code>String</code> value
119   */
120  public String seedTipText() {
121    return "The randomization seed";
122  }
123
124  /**
125   * Set the random seed
126   *
127   * @param newSeed an <code>int</code> value
128   */
129  public void setSeed(int newSeed) {
130    m_randomSeed = newSeed;
131  }
132
133  /**
134   * Get the value of the random seed
135   *
136   * @return an <code>int</code> value
137   */
138  public int getSeed() {
139    return m_randomSeed;
140  }
141
142  /**
143   * Accept a training set
144   *
145   * @param e a <code>TrainingSetEvent</code> value
146   */
147  public void acceptTrainingSet(TrainingSetEvent e) {
148    Instances trainingSet = e.getTrainingSet();
149    DataSetEvent dse = new DataSetEvent(this, trainingSet);
150    acceptDataSet(dse);
151  }
152
153  /**
154   * Accept a test set
155   *
156   * @param e a <code>TestSetEvent</code> value
157   */
158  public void acceptTestSet(TestSetEvent e) {
159    Instances testSet = e.getTestSet();
160    DataSetEvent dse = new DataSetEvent(this, testSet);
161    acceptDataSet(dse);
162  }
163
164  /**
165   * Accept a data set
166   *
167   * @param e a <code>DataSetEvent</code> value
168   */
169  public void acceptDataSet(DataSetEvent e) {
170    if (m_splitThread == null) {
171      final Instances dataSet = new Instances(e.getDataSet());
172      m_splitThread = new Thread() {
173          public void run() {
174            try {
175              dataSet.randomize(new Random(m_randomSeed));
176              int trainSize = 
177                (int)Math.round(dataSet.numInstances() * m_trainPercentage / 100);
178              int testSize = dataSet.numInstances() - trainSize;
179     
180              Instances train = new Instances(dataSet, 0, trainSize);
181              Instances test = new Instances(dataSet, trainSize, testSize);
182     
183              TrainingSetEvent tse =
184                new TrainingSetEvent(TrainTestSplitMaker.this, train);
185              tse.m_setNumber = 1; tse.m_maxSetNumber = 1;
186              if (m_splitThread != null) {
187                notifyTrainingSetProduced(tse);
188              }
189   
190              // inform all test set listeners
191              TestSetEvent teste = 
192                new TestSetEvent(TrainTestSplitMaker.this, test);
193              teste.m_setNumber = 1; teste.m_maxSetNumber = 1;
194              if (m_splitThread != null) {
195                notifyTestSetProduced(teste);
196              } else {
197                if (m_logger != null) {
198                  m_logger.logMessage("[TrainTestSplitMaker] "
199                      + statusMessagePrefix() + " Split has been canceled!");
200                  m_logger.statusMessage(statusMessagePrefix()
201                      + "INTERRUPTED");
202                }
203              }
204            } catch (Exception ex) {
205              stop(); // stop all processing
206              if (m_logger != null) {
207                  m_logger.statusMessage(statusMessagePrefix()
208                      + "ERROR (See log for details)");
209                  m_logger.logMessage("[TrainTestSplitMaker] " 
210                      + statusMessagePrefix()
211                      + " problem during split creation. " 
212                      + ex.getMessage());
213              }
214              ex.printStackTrace();
215            } finally {
216              if (isInterrupted()) {
217                if (m_logger != null) {
218                  m_logger.logMessage("[TrainTestSplitMaker] "
219                      + statusMessagePrefix() + " Split has been canceled!");
220                  m_logger.statusMessage(statusMessagePrefix()
221                      + "INTERRUPTED");
222                }
223              }
224              block(false);
225            }
226          }
227        };
228      m_splitThread.setPriority(Thread.MIN_PRIORITY);
229      m_splitThread.start();
230
231      //      if (m_splitThread.isAlive()) {
232      block(true);
233      //      }
234      m_splitThread = null;
235    }
236  }
237
238  /**
239   * Notify test set listeners that a test set is available
240   *
241   * @param tse a <code>TestSetEvent</code> value
242   */
243  protected void notifyTestSetProduced(TestSetEvent tse) {
244    Vector l;
245    synchronized (this) {
246      l = (Vector)m_testListeners.clone();
247    }
248    if (l.size() > 0) {
249      for(int i = 0; i < l.size(); i++) {
250        if (m_splitThread == null) {
251          break;
252        }
253        //      System.err.println("Notifying test listeners "
254        //                         +"(Train - test split maker)");
255        ((TestSetListener)l.elementAt(i)).acceptTestSet(tse);
256      }
257    }
258  }
259
260  /**
261   * Notify training set listeners that a training set is available
262   *
263   * @param tse a <code>TrainingSetEvent</code> value
264   */
265  protected void notifyTrainingSetProduced(TrainingSetEvent tse) {
266    Vector l;
267    synchronized (this) {
268      l = (Vector)m_trainingListeners.clone();
269    }
270    if (l.size() > 0) {
271      for(int i = 0; i < l.size(); i++) {
272        if (m_splitThread == null) {
273          break;
274        }
275        //      System.err.println("Notifying training listeners "
276        //                         +"(Train - test split fold maker)");
277        ((TrainingSetListener)l.elementAt(i)).acceptTrainingSet(tse);
278      }
279    }
280  }
281
282  /**
283   * Function used to stop code that calls acceptDataSet. This is
284   * needed as split is performed inside a separate
285   * thread of execution.
286   *
287   * @param tf a <code>boolean</code> value
288   */
289  private synchronized void block(boolean tf) {
290    if (tf) {
291      try {
292        // make sure that the thread is still alive before blocking
293        if (m_splitThread.isAlive()) {
294          wait();
295        }
296      } catch (InterruptedException ex) {
297      }
298    } else {
299      notifyAll();
300    }
301  }
302
303  /**
304   * Stop processing
305   */
306  public void stop() {
307    // tell the listenee (upstream bean) to stop
308    if (m_listenee instanceof BeanCommon) {
309      //      System.err.println("Listener is BeanCommon");
310      ((BeanCommon)m_listenee).stop();
311    }
312
313    // stop the split thread
314    if (m_splitThread != null) {
315      Thread temp = m_splitThread;
316      m_splitThread = null;
317      temp.interrupt();
318      temp.stop();
319    }
320  }
321 
322  /**
323   * Returns true if. at this time, the bean is busy with some
324   * (i.e. perhaps a worker thread is performing some calculation).
325   *
326   * @return true if the bean is busy.
327   */
328  public boolean isBusy() {
329    return (m_splitThread != null);
330  }
331
332  /**
333   * Get list of user requests
334   *
335   * @return an <code>Enumeration</code> value
336   */
337  public Enumeration enumerateRequests() {
338    Vector newVector = new Vector(0);
339    if (m_splitThread != null) {
340      newVector.addElement("Stop");
341    }
342    return newVector.elements();
343  }
344
345  /**
346   * Perform the named request
347   *
348   * @param request a <code>String</code> value
349   * @exception IllegalArgumentException if an error occurs
350   */
351  public void performRequest(String request) {
352    if (request.compareTo("Stop") == 0) {
353      stop();
354    } else {
355      throw new IllegalArgumentException(request
356                         + " not supported (TrainTestSplitMaker)");
357    }
358  }
359
360  /**
361   * Returns true, if at the current time, the named event could
362   * be generated. Assumes that the supplied event name is
363   * an event that could be generated by this bean
364   *
365   * @param eventName the name of the event in question
366   * @return true if the named event could be generated at this point in
367   * time
368   */
369  public boolean eventGeneratable(String eventName) {
370    if (m_listenee == null) {
371      return false;
372    }
373   
374    if (m_listenee instanceof EventConstraints) {
375      if (((EventConstraints)m_listenee).eventGeneratable("dataSet") ||
376          ((EventConstraints)m_listenee).eventGeneratable("trainingSet") ||
377          ((EventConstraints)m_listenee).eventGeneratable("testSet")) {
378        return true;
379      } else {
380        return false;
381      }
382    }
383    return true;
384  }
385 
386  private String statusMessagePrefix() {
387    return getCustomName() + "$" + hashCode() + "|";
388  }
389}
Note: See TracBrowser for help on using the repository browser.