source: src/main/java/weka/gui/beans/CrossValidationFoldMaker.java @ 18

Last change on this file since 18 was 4, checked in by gnappo, 14 years ago

Import di weka.

File size: 12.4 KB
Line 
1/*
2 *    This program is free software; you can redistribute it and/or modify
3 *    it under the terms of the GNU General Public License as published by
4 *    the Free Software Foundation; either version 2 of the License, or
5 *    (at your option) any later version.
6 *
7 *    This program is distributed in the hope that it will be useful,
8 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
9 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
10 *    GNU General Public License for more details.
11 *
12 *    You should have received a copy of the GNU General Public License
13 *    along with this program; if not, write to the Free Software
14 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
15 */
16
17/*
18 *    CrossValidationFoldMaker.java
19 *    Copyright (C) 2002 University of Waikato, Hamilton, New Zealand
20 *
21 */
22
23package weka.gui.beans;
24
25import weka.core.Instances;
26
27import java.io.Serializable;
28import java.util.Enumeration;
29import java.util.Random;
30import java.util.Vector;
31
32/**
33 * Bean for splitting instances into training ant test sets according to
34 * a cross validation
35 *
36 * @author <a href="mailto:mhall@cs.waikato.ac.nz">Mark Hall</a>
37 * @version $Revision: 6003 $
38 */
39public class CrossValidationFoldMaker 
40  extends AbstractTrainAndTestSetProducer
41  implements DataSourceListener, TrainingSetListener, TestSetListener, 
42             UserRequestAcceptor, EventConstraints, Serializable {
43
44  /** for serialization */
45  private static final long serialVersionUID = -6350179298851891512L;
46
47  private int m_numFolds = 10;
48  private int m_randomSeed = 1;
49 
50  private boolean m_preserveOrder = false;
51
52  private transient Thread m_foldThread = null;
53
54  public CrossValidationFoldMaker() {
55    m_visual.loadIcons(BeanVisual.ICON_PATH
56                       +"CrossValidationFoldMaker.gif",
57                       BeanVisual.ICON_PATH
58                       +"CrossValidationFoldMaker_animated.gif");
59    m_visual.setText("CrossValidationFoldMaker");
60  }
61
62  /**
63   * Set a custom (descriptive) name for this bean
64   *
65   * @param name the name to use
66   */
67  public void setCustomName(String name) {
68    m_visual.setText(name);
69  }
70
71  /**
72   * Get the custom (descriptive) name for this bean (if one has been set)
73   *
74   * @return the custom name (or the default name)
75   */
76  public String getCustomName() {
77    return m_visual.getText();
78  }
79
80  /**
81   * Global info for this bean
82   *
83   * @return a <code>String</code> value
84   */
85  public String globalInfo() {
86    return "Split an incoming data set into cross validation folds. "
87      +"Separate train and test sets are produced for each of the k folds.";
88  }
89
90  /**
91   * Accept a training set
92   *
93   * @param e a <code>TrainingSetEvent</code> value
94   */
95  public void acceptTrainingSet(TrainingSetEvent e) {
96    Instances trainingSet = e.getTrainingSet();
97    DataSetEvent dse = new DataSetEvent(this, trainingSet);
98    acceptDataSet(dse);
99  }
100
101  /**
102   * Accept a test set
103   *
104   * @param e a <code>TestSetEvent</code> value
105   */
106  public void acceptTestSet(TestSetEvent e) {
107    Instances testSet = e.getTestSet();
108    DataSetEvent dse = new DataSetEvent(this, testSet);
109    acceptDataSet(dse);
110  }
111 
112  /**
113   * Accept a data set
114   *
115   * @param e a <code>DataSetEvent</code> value
116   */
117  public void acceptDataSet(DataSetEvent e) {
118    if (e.isStructureOnly()) {
119      // Pass on structure to training and test set listeners
120      TrainingSetEvent tse = new TrainingSetEvent(this, e.getDataSet());
121      TestSetEvent tsee = new TestSetEvent(this, e.getDataSet());
122      notifyTrainingSetProduced(tse);
123      notifyTestSetProduced(tsee);
124      return;
125    }
126    if (m_foldThread == null) {
127      final Instances dataSet = new Instances(e.getDataSet());
128      m_foldThread = new Thread() {
129          public void run() {
130            boolean errorOccurred = false;
131            try {
132              Random random = new Random(getSeed());
133              if (!m_preserveOrder) {
134                dataSet.randomize(random);
135              }
136              if (dataSet.classIndex() >= 0 && 
137                  dataSet.attribute(dataSet.classIndex()).isNominal() &&
138                  !m_preserveOrder) {
139                dataSet.stratify(getFolds());
140                if (m_logger != null) {
141                  m_logger.logMessage("[" + getCustomName() + "] "
142                                      +"stratifying data");
143                }
144              }
145             
146              for (int i = 0; i < getFolds(); i++) {
147                if (m_foldThread == null) {
148                  if (m_logger != null) {
149                    m_logger.logMessage("[" + getCustomName() + "] Cross validation has been canceled!");
150                  }
151                  // exit gracefully
152                  break;
153                }
154                Instances train = (!m_preserveOrder) 
155                  ? dataSet.trainCV(getFolds(), i, random)
156                  : dataSet.trainCV(getFolds(), i); 
157                Instances test  = dataSet.testCV(getFolds(), i);
158
159                // inform all training set listeners
160                TrainingSetEvent tse = new TrainingSetEvent(this, train);
161                tse.m_setNumber = i+1; tse.m_maxSetNumber = getFolds();
162                String msg = getCustomName() + "$" 
163                  + CrossValidationFoldMaker.this.hashCode() + "|";
164                if (m_logger != null) {
165                  m_logger.statusMessage(msg + "seed: " + getSeed() + " folds: "
166                      + getFolds() + "|Training fold " + (i+1));
167                }
168                if (m_foldThread != null) {
169                  //              System.err.println("--Just before notify training set");
170                  notifyTrainingSetProduced(tse);
171                  //              System.err.println("---Just after notify");
172                }
173             
174                // inform all test set listeners
175                TestSetEvent teste = new TestSetEvent(this, test);
176                teste.m_setNumber = i+1; teste.m_maxSetNumber = getFolds();
177               
178                if (m_logger != null) {
179                  m_logger.statusMessage(msg + "seed: " + getSeed() + " folds: "
180                      + getFolds() + "|Test fold " + (i+1));
181                }
182                if (m_foldThread != null) {
183                  notifyTestSetProduced(teste);
184                }
185              }
186            } catch (Exception ex) {
187              // stop all processing
188              errorOccurred = true;
189              if (m_logger != null) {
190                m_logger.logMessage("[" + getCustomName() 
191                    + "] problem during fold creation. "
192                    + ex.getMessage());
193              }
194              ex.printStackTrace();
195              CrossValidationFoldMaker.this.stop();
196            } finally {
197              m_foldThread = null;
198             
199              if (errorOccurred) {
200                if (m_logger != null) {
201                  m_logger.statusMessage(getCustomName() 
202                      + "$" + CrossValidationFoldMaker.this.hashCode()
203                      + "|"
204                      + "ERROR (See log for details).");
205                }
206              } else if (isInterrupted()) {
207                String msg = "[" + getCustomName() + "] Cross validation interrupted";
208                if (m_logger != null) {
209                  m_logger.logMessage("[" + getCustomName() + "] Cross validation interrupted");
210                  m_logger.statusMessage(getCustomName() + "$"
211                      + CrossValidationFoldMaker.this.hashCode() + "|"
212                      + "INTERRUPTED");
213                } else {
214                  System.err.println(msg);
215                }
216              } else {
217                String msg = getCustomName() + "$" 
218                + CrossValidationFoldMaker.this.hashCode() + "|";
219                if (m_logger != null) {
220                  m_logger.statusMessage(msg + "Finished.");
221                }
222              }
223              block(false);
224            }
225          }
226        };
227      m_foldThread.setPriority(Thread.MIN_PRIORITY);
228      m_foldThread.start();
229
230      //      if (m_foldThread.isAlive()) {
231      block(true);
232        //      }
233      m_foldThread = null;
234    }
235  }
236
237
238  /**
239   * Notify all test set listeners of a TestSet event
240   *
241   * @param tse a <code>TestSetEvent</code> value
242   */
243  private void notifyTestSetProduced(TestSetEvent tse) {
244    Vector l;
245    synchronized (this) {
246      l = (Vector)m_testListeners.clone();
247    }
248    if (l.size() > 0) {
249      for(int i = 0; i < l.size(); i++) {
250        if (m_foldThread == null) {
251          break;
252        }
253        //      System.err.println("Notifying test listeners "
254        //                         +"(cross validation fold maker)");
255        ((TestSetListener)l.elementAt(i)).acceptTestSet(tse);
256      }
257    }
258  }
259
260  /**
261   * Notify all listeners of a TrainingSet event
262   *
263   * @param tse a <code>TrainingSetEvent</code> value
264   */
265  protected void notifyTrainingSetProduced(TrainingSetEvent tse) {
266    Vector l;
267    synchronized (this) {
268      l = (Vector)m_trainingListeners.clone();
269    }
270    if (l.size() > 0) {
271      for(int i = 0; i < l.size(); i++) {
272        if (m_foldThread == null) {
273          break;
274        }
275        //      System.err.println("Notifying training listeners "
276        //                         +"(cross validation fold maker)");
277        ((TrainingSetListener)l.elementAt(i)).acceptTrainingSet(tse);
278      }
279    }
280  }
281
282  /**
283   * Set the number of folds for the cross validation
284   *
285   * @param numFolds an <code>int</code> value
286   */
287  public void setFolds(int numFolds) {
288    m_numFolds = numFolds;
289  }
290 
291  /**
292   * Get the currently set number of folds
293   *
294   * @return an <code>int</code> value
295   */
296  public int getFolds() {
297    return m_numFolds;
298  }
299
300  /**
301   * Tip text for this property
302   *
303   * @return a <code>String</code> value
304   */
305  public String foldsTipText() {
306    return "The number of train and test splits to produce";
307  }
308   
309  /**
310   * Set the seed
311   *
312   * @param randomSeed an <code>int</code> value
313   */
314  public void setSeed(int randomSeed) {
315    m_randomSeed = randomSeed;
316  }
317 
318  /**
319   * Get the currently set seed
320   *
321   * @return an <code>int</code> value
322   */
323  public int getSeed() {
324    return m_randomSeed;
325  }
326 
327  /**
328   * Tip text for this property
329   *
330   * @return a <code>String</code> value
331   */
332  public String seedTipText() {
333    return "The randomization seed";
334  }
335 
336  /**
337   * Returns true if the order of the incoming instances is to
338   * be preserved under cross-validation (no randomization or
339   * stratification is done in this case).
340   *
341   * @return true if the order of the incoming instances is to
342   * be preserved.
343   */
344  public boolean getPreserveOrder() {
345    return m_preserveOrder;
346  }
347 
348  /**
349   * Sets whether the order of the incoming instances is to be
350   * preserved under cross-validation (no randomization or
351   * stratification is done in this case).
352   * 
353   * @param p true if the order is to be preserved.
354   */
355  public void setPreserveOrder(boolean p) {
356    m_preserveOrder = p;
357  }
358 
359  /**
360   * Returns true if. at this time, the bean is busy with some
361   * (i.e. perhaps a worker thread is performing some calculation).
362   *
363   * @return true if the bean is busy.
364   */
365  public boolean isBusy() {
366    return (m_foldThread != null);
367  }
368
369  /**
370   * Stop any action
371   */
372  public void stop() {
373    // tell the listenee (upstream bean) to stop
374    if (m_listenee instanceof BeanCommon) {
375      //      System.err.println("Listener is BeanCommon");
376      ((BeanCommon)m_listenee).stop();
377    }
378
379    // stop the fold thread
380    if (m_foldThread != null) {
381      Thread temp = m_foldThread;
382      m_foldThread = null;
383      temp.interrupt();
384      temp.stop();
385    }
386  }
387
388  /**
389   * Function used to stop code that calls acceptDataSet. This is
390   * needed as cross validation is performed inside a separate
391   * thread of execution.
392   *
393   * @param tf a <code>boolean</code> value
394   */
395  private synchronized void block(boolean tf) {
396    if (tf) {
397      try {
398        // make sure the thread is still running before we block
399        if (m_foldThread != null && m_foldThread.isAlive()) {
400          wait();
401        }
402      } catch (InterruptedException ex) {
403      }
404    } else {
405      notifyAll();
406    }
407  }
408
409  /**
410   * Return an enumeration of user requests
411   *
412   * @return an <code>Enumeration</code> value
413   */
414  public Enumeration enumerateRequests() {
415    Vector newVector = new Vector(0);
416    if (m_foldThread != null) {
417      newVector.addElement("Stop");
418    }
419    return newVector.elements();
420  }
421
422  /**
423   * Perform the named request
424   *
425   * @param request a <code>String</code> value
426   * @exception IllegalArgumentException if an error occurs
427   */
428  public void performRequest(String request) {
429    if (request.compareTo("Stop") == 0) {
430      stop();
431    } else {
432      throw new IllegalArgumentException(request
433                                         + " not supported (CrossValidation)");
434    }
435  }
436
437  /**
438   * Returns true, if at the current time, the named event could
439   * be generated. Assumes that the supplied event name is
440   * an event that could be generated by this bean
441   *
442   * @param eventName the name of the event in question
443   * @return true if the named event could be generated at this point in
444   * time
445   */
446  public boolean eventGeneratable(String eventName) {
447    if (m_listenee == null) {
448      return false;
449    }
450   
451    if (m_listenee instanceof EventConstraints) {
452      if (((EventConstraints)m_listenee).eventGeneratable("dataSet") ||
453          ((EventConstraints)m_listenee).eventGeneratable("trainingSet") ||
454          ((EventConstraints)m_listenee).eventGeneratable("testSet")) {
455        return true;
456      } else {
457        return false;
458      }
459    }
460    return true;
461  }
462}
Note: See TracBrowser for help on using the repository browser.