source: src/main/java/weka/experiment/CrossValidationSplitResultProducer.java @ 8

Last change on this file since 8 was 4, checked in by gnappo, 14 years ago

Import di weka.

File size: 8.8 KB
Line 
1/*
2 *    This program is free software; you can redistribute it and/or modify
3 *    it under the terms of the GNU General Public License as published by
4 *    the Free Software Foundation; either version 2 of the License, or
5 *    (at your option) any later version.
6 *
7 *    This program is distributed in the hope that it will be useful,
8 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
9 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
10 *    GNU General Public License for more details.
11 *
12 *    You should have received a copy of the GNU General Public License
13 *    along with this program; if not, write to the Free Software
14 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
15 */
16
17/*
18 *    CrossValidationSplitResultProducer.java
19 *    Copyright (C) 1999, 2009 University of Waikato, Hamilton, New Zealand
20 *
21 */
22
23
24package weka.experiment;
25
26import weka.core.AdditionalMeasureProducer;
27import weka.core.Instance;
28import weka.core.Instances;
29import weka.core.Option;
30import weka.core.OptionHandler;
31import weka.core.RevisionHandler;
32import weka.core.RevisionUtils;
33import weka.core.Utils;
34
35import java.io.File;
36import java.util.Calendar;
37import java.util.Enumeration;
38import java.util.Random;
39import java.util.TimeZone;
40import java.util.Vector;
41
42/**
43 <!-- globalinfo-start -->
44 * Carries out one split of a repeated k-fold cross-validation, using the set SplitEvaluator to generate some results. Note that the run number is actually the nth split of a repeated k-fold cross-validation, i.e. if k=10, run number 100 is the 10th fold of the 10th cross-validation run. This producer's sole purpose is to allow more fine-grained distribution of cross-validation experiments. If the class attribute is nominal, the dataset is stratified.
45 * <p/>
46 <!-- globalinfo-end -->
47 *
48 <!-- options-start -->
49 * Valid options are: <p/>
50 *
51 * <pre> -X &lt;number of folds&gt;
52 *  The number of folds to use for the cross-validation.
53 *  (default 10)</pre>
54 *
55 * <pre> -D
56 * Save raw split evaluator output.</pre>
57 *
58 * <pre> -O &lt;file/directory name/path&gt;
59 *  The filename where raw output will be stored.
60 *  If a directory name is specified then individual
61 *  outputs will be gzipped, otherwise all output will be
62 *  zipped to the named file. Use in conjunction with -D. (default splitEvalutorOut.zip)</pre>
63 *
64 * <pre> -W &lt;class name&gt;
65 *  The full class name of a SplitEvaluator.
66 *  eg: weka.experiment.ClassifierSplitEvaluator</pre>
67 *
68 * <pre>
69 * Options specific to split evaluator weka.experiment.ClassifierSplitEvaluator:
70 * </pre>
71 *
72 * <pre> -W &lt;class name&gt;
73 *  The full class name of the classifier.
74 *  eg: weka.classifiers.bayes.NaiveBayes</pre>
75 *
76 * <pre> -C &lt;index&gt;
77 *  The index of the class for which IR statistics
78 *  are to be output. (default 1)</pre>
79 *
80 * <pre> -I &lt;index&gt;
81 *  The index of an attribute to output in the
82 *  results. This attribute should identify an
83 *  instance in order to know which instances are
84 *  in the test set of a cross validation. if 0
85 *  no output (default 0).</pre>
86 *
87 * <pre> -P
88 *  Add target and prediction columns to the result
89 *  for each fold.</pre>
90 *
91 * <pre>
92 * Options specific to classifier weka.classifiers.rules.ZeroR:
93 * </pre>
94 *
95 * <pre> -D
96 *  If set, classifier is run in debug mode and
97 *  may output additional info to the console</pre>
98 *
99 <!-- options-end -->
100 *
101 * All options after -- will be passed to the split evaluator.
102 *
103 * @author Len Trigg
104 * @author Eibe Frank
105 * @version $Revision: 5828 $
106 */
107public class CrossValidationSplitResultProducer 
108  extends CrossValidationResultProducer {
109 
110  /** for serialization */
111  static final long serialVersionUID = 1403798164046795073L;
112 
113  /**
114   * Returns a string describing this result producer
115   * @return a description of the result producer suitable for
116   * displaying in the explorer/experimenter gui
117   */
118  public String globalInfo() {
119    return 
120        "Carries out one split of a repeated k-fold cross-validation, "
121      + "using the set SplitEvaluator to generate some results. "
122      + "Note that the run number is actually the nth split of a repeated "
123      + "k-fold cross-validation, i.e. if k=10, run number 100 is the 10th "
124      + "fold of the 10th cross-validation run. This producer's sole purpose "
125      + "is to allow more fine-grained distribution of cross-validation "
126      + "experiments. If the class attribute is nominal, the dataset is stratified.";
127  }
128 
129  /**
130   * Gets the keys for a specified run number. Different run
131   * numbers correspond to different randomizations of the data. Keys
132   * produced should be sent to the current ResultListener
133   *
134   * @param run the run number to get keys for.
135   * @throws Exception if a problem occurs while getting the keys
136   */
137  public void doRunKeys(int run) throws Exception {
138    if (m_Instances == null) {
139      throw new Exception("No Instances set");
140    }
141
142    // Add in some fields to the key like run and fold number, dataset name
143    Object [] seKey = m_SplitEvaluator.getKey();
144    Object [] key = new Object [seKey.length + 3];
145    key[0] = Utils.backQuoteChars(m_Instances.relationName());
146    key[2] = "" + (((run - 1) % m_NumFolds) + 1);
147    key[1] = "" + (((run - 1) / m_NumFolds) + 1);
148    System.arraycopy(seKey, 0, key, 3, seKey.length);
149    if (m_ResultListener.isResultRequired(this, key)) {
150      try {
151        m_ResultListener.acceptResult(this, key, null);
152      } catch (Exception ex) {
153        // Save the train and test datasets for debugging purposes?
154        throw ex;
155      }
156    }
157  }
158
159  /**
160   * Gets the results for a specified run number. Different run
161   * numbers correspond to different randomizations of the data. Results
162   * produced should be sent to the current ResultListener
163   *
164   * @param run the run number to get results for.
165   * @throws Exception if a problem occurs while getting the results
166   */
167  public void doRun(int run) throws Exception {
168
169    if (getRawOutput()) {
170      if (m_ZipDest == null) {
171        m_ZipDest = new OutputZipper(m_OutputFile);
172      }
173    }
174
175    if (m_Instances == null) {
176      throw new Exception("No Instances set");
177    }
178
179    // Compute run and fold number from given run
180    int fold = (run - 1) % m_NumFolds;
181    run = ((run - 1) / m_NumFolds) + 1; 
182   
183
184    // Randomize on a copy of the original dataset
185    Instances runInstances = new Instances(m_Instances);
186    Random random = new Random(run);
187    runInstances.randomize(random);
188    if (runInstances.classAttribute().isNominal()) {
189      runInstances.stratify(m_NumFolds);
190    }
191
192    // Add in some fields to the key like run and fold number, dataset name
193    Object [] seKey = m_SplitEvaluator.getKey();
194    Object [] key = new Object [seKey.length + 3];
195    key[0] =  Utils.backQuoteChars(m_Instances.relationName());
196    key[1] = "" + run;
197    key[2] = "" + (fold + 1);
198    System.arraycopy(seKey, 0, key, 3, seKey.length);
199    if (m_ResultListener.isResultRequired(this, key)) {
200      Instances train = runInstances.trainCV(m_NumFolds, fold, random);
201      Instances test = runInstances.testCV(m_NumFolds, fold);
202      try {
203        Object [] seResults = m_SplitEvaluator.getResult(train, test);
204        Object [] results = new Object [seResults.length + 1];
205        results[0] = getTimestamp();
206        System.arraycopy(seResults, 0, results, 1,
207                         seResults.length);
208        if (m_debugOutput) {
209          String resultName = (""+run+"."+(fold+1)+"."
210                               + Utils.backQuoteChars(runInstances.relationName())
211                               +"."
212                               +m_SplitEvaluator.toString()).replace(' ','_');
213          resultName = Utils.removeSubstring(resultName, 
214                                             "weka.classifiers.");
215          resultName = Utils.removeSubstring(resultName, 
216                                             "weka.filters.");
217          resultName = Utils.removeSubstring(resultName, 
218                                             "weka.attributeSelection.");
219          m_ZipDest.zipit(m_SplitEvaluator.getRawResultOutput(), resultName);
220        }
221        m_ResultListener.acceptResult(this, key, results);
222      } catch (Exception ex) {
223        // Save the train and test datasets for debugging purposes?
224        throw ex;
225      }
226    }
227  }
228
229  /**
230   * Gets a text descrption of the result producer.
231   *
232   * @return a text description of the result producer.
233   */
234  public String toString() {
235
236    String result = "CrossValidationSplitResultProducer: ";
237    result += getCompatibilityState();
238    if (m_Instances == null) {
239      result += ": <null Instances>";
240    } else {
241      result += ": " + Utils.backQuoteChars(m_Instances.relationName());
242    }
243    return result;
244  }
245
246  /**
247   * Returns the revision string.
248   *
249   * @return            the revision
250   */
251  public String getRevision() {
252    return RevisionUtils.extract("$Revision: 5828 $");
253  }
254} // CrossValidationSplitResultProducer
255
Note: See TracBrowser for help on using the repository browser.