source: tags/MetisMQIDemo/src/main/java/weka/filters/supervised/instance/StratifiedRemoveFolds.java

Last change on this file was 29, checked in by gnappo, 15 years ago

Taggata versione per la demo e aggiunto branch.

File size: 12.0 KB
Line 
1/*
2 *    This program is free software; you can redistribute it and/or modify
3 *    it under the terms of the GNU General Public License as published by
4 *    the Free Software Foundation; either version 2 of the License, or
5 *    (at your option) any later version.
6 *
7 *    This program is distributed in the hope that it will be useful,
8 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
9 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
10 *    GNU General Public License for more details.
11 *
12 *    You should have received a copy of the GNU General Public License
13 *    along with this program; if not, write to the Free Software
14 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
15 */
16
17/*
18 *    StratifiedRemoveFolds.java
19 *    Copyright (C) 1999 University of Waikato, Hamilton, New Zealand
20 *
21 */
22
23
24package weka.filters.supervised.instance;
25
26import weka.core.Capabilities;
27import weka.core.Instance;
28import weka.core.Instances;
29import weka.core.Option;
30import weka.core.OptionHandler;
31import weka.core.RevisionUtils;
32import weka.core.Utils;
33import weka.core.Capabilities.Capability;
34import weka.filters.Filter;
35import weka.filters.SupervisedFilter;
36
37import java.util.Enumeration;
38import java.util.Random;
39import java.util.Vector;
40
41/**
42 <!-- globalinfo-start -->
43 * This filter takes a dataset and outputs a specified fold for cross validation. If you do not want the folds to be stratified use the unsupervised version.
44 * <p/>
45 <!-- globalinfo-end -->
46 *
47 <!-- options-start -->
48 * Valid options are: <p/>
49 *
50 * <pre> -V
51 *  Specifies if inverse of selection is to be output.
52 * </pre>
53 *
54 * <pre> -N &lt;number of folds&gt;
55 *  Specifies number of folds dataset is split into.
56 *  (default 10)
57 * </pre>
58 *
59 * <pre> -F &lt;fold&gt;
60 *  Specifies which fold is selected. (default 1)
61 * </pre>
62 *
63 * <pre> -S &lt;seed&gt;
64 *  Specifies random number seed. (default 0, no randomizing)
65 * </pre>
66 *
67 <!-- options-end -->
68 *
69 * @author Eibe Frank (eibe@cs.waikato.ac.nz)
70 * @version $Revision: 5492 $
71 */
72public class StratifiedRemoveFolds 
73  extends Filter
74  implements SupervisedFilter, OptionHandler {
75 
76  /** for serialization */
77  static final long serialVersionUID = -7069148179905814324L;
78
79  /** Indicates if inverse of selection is to be output. */
80  private boolean m_Inverse = false;
81
82  /** Number of folds to split dataset into */
83  private int m_NumFolds = 10;
84
85  /** Fold to output */
86  private int m_Fold = 1;
87
88  /** Random number seed. */
89  private long m_Seed = 0;
90
91  /**
92   * Gets an enumeration describing the available options..
93   *
94   * @return an enumeration of all the available options.
95   */
96  public Enumeration listOptions() {
97
98    Vector newVector = new Vector(6);
99
100    newVector.addElement(new Option(
101              "\tSpecifies if inverse of selection is to be output.\n",
102              "V", 0, "-V"));
103
104    newVector.addElement(new Option(
105              "\tSpecifies number of folds dataset is split into. \n"
106              + "\t(default 10)\n",
107              "N", 1, "-N <number of folds>"));
108
109    newVector.addElement(new Option(
110              "\tSpecifies which fold is selected. (default 1)\n",
111              "F", 1, "-F <fold>"));
112
113    newVector.addElement(new Option(
114              "\tSpecifies random number seed. (default 0, no randomizing)\n",
115              "S", 1, "-S <seed>"));
116
117    return newVector.elements();
118  }
119
120  /**
121   * Parses a given list of options. <p/>
122   *
123   <!-- options-start -->
124   * Valid options are: <p/>
125   *
126   * <pre> -V
127   *  Specifies if inverse of selection is to be output.
128   * </pre>
129   *
130   * <pre> -N &lt;number of folds&gt;
131   *  Specifies number of folds dataset is split into.
132   *  (default 10)
133   * </pre>
134   *
135   * <pre> -F &lt;fold&gt;
136   *  Specifies which fold is selected. (default 1)
137   * </pre>
138   *
139   * <pre> -S &lt;seed&gt;
140   *  Specifies random number seed. (default 0, no randomizing)
141   * </pre>
142   *
143   <!-- options-end -->
144   *
145   * @param options the list of options as an array of strings
146   * @throws Exception if an option is not supported
147   */
148  public void setOptions(String[] options) throws Exception {
149
150    setInvertSelection(Utils.getFlag('V', options));
151    String numFolds = Utils.getOption('N', options);
152    if (numFolds.length() != 0) {
153      setNumFolds(Integer.parseInt(numFolds));
154    } else {
155      setNumFolds(10);
156    }
157    String fold = Utils.getOption('F', options);
158    if (fold.length() != 0) {
159      setFold(Integer.parseInt(fold));
160    } else {
161      setFold(1);
162    }
163    String seed = Utils.getOption('S', options);
164    if (seed.length() != 0) {
165      setSeed(Integer.parseInt(seed));
166    } else {
167      setSeed(0);
168    }
169    if (getInputFormat() != null) {
170      setInputFormat(getInputFormat());
171    }
172  }
173
174  /**
175   * Gets the current settings of the filter.
176   *
177   * @return an array of strings suitable for passing to setOptions
178   */
179  public String [] getOptions() {
180
181    String [] options = new String [8];
182    int current = 0;
183
184    options[current++] = "-S"; options[current++] = "" + getSeed();
185    if (getInvertSelection()) {
186      options[current++] = "-V";
187    }
188    options[current++] = "-N"; options[current++] = "" + getNumFolds();
189    options[current++] = "-F"; options[current++] = "" + getFold();
190    while (current < options.length) {
191      options[current++] = "";
192    }
193    return options;
194  }
195
196  /**
197   * Returns a string describing this filter
198   *
199   * @return a description of the filter suitable for
200   * displaying in the explorer/experimenter gui
201   */
202  public String globalInfo() {
203    return 
204        "This filter takes a dataset and outputs a specified fold for "
205      + "cross validation. If you do not want the folds to be stratified "
206      + "use the unsupervised version.";
207  }
208
209  /**
210   * Returns the tip text for this property
211   *
212   * @return tip text for this property suitable for
213   * displaying in the explorer/experimenter gui
214   */
215  public String invertSelectionTipText() {
216
217    return "Whether to invert the selection.";
218  }
219
220  /**
221   * Gets if selection is to be inverted.
222   *
223   * @return true if the selection is to be inverted
224   */
225  public boolean getInvertSelection() {
226
227    return m_Inverse;
228  }
229
230  /**
231   * Sets if selection is to be inverted.
232   *
233   * @param inverse true if inversion is to be performed
234   */
235  public void setInvertSelection(boolean inverse) {
236   
237    m_Inverse = inverse;
238  }
239
240  /**
241   * Returns the tip text for this property
242   *
243   * @return tip text for this property suitable for
244   * displaying in the explorer/experimenter gui
245   */
246  public String numFoldsTipText() {
247
248    return "The number of folds to split the dataset into.";
249  }
250
251  /**
252   * Gets the number of folds in which dataset is to be split into.
253   *
254   * @return the number of folds the dataset is to be split into.
255   */
256  public int getNumFolds() {
257
258    return m_NumFolds;
259  }
260
261  /**
262   * Sets the number of folds the dataset is split into. If the number
263   * of folds is zero, it won't split it into folds.
264   *
265   * @param numFolds number of folds dataset is to be split into
266   * @throws IllegalArgumentException if number of folds is negative
267   */
268  public void setNumFolds(int numFolds) {
269
270    if (numFolds < 0) {
271      throw new IllegalArgumentException("Number of folds has to be positive or zero.");
272    }
273    m_NumFolds = numFolds;
274  }
275
276  /**
277   * Returns the tip text for this property
278   *
279   * @return tip text for this property suitable for
280   * displaying in the explorer/experimenter gui
281   */
282  public String foldTipText() {
283
284    return "The fold which is selected.";
285  }
286
287  /**
288   * Gets the fold which is selected.
289   *
290   * @return the fold which is selected
291   */
292  public int getFold() {
293
294    return m_Fold;
295  }
296
297  /**
298   * Selects a fold.
299   *
300   * @param fold the fold to be selected.
301   * @throws IllegalArgumentException if fold's index is smaller than 1
302   */
303  public void setFold(int fold) {
304
305    if (fold < 1) {
306      throw new IllegalArgumentException("Fold's index has to be greater than 0.");
307    }
308    m_Fold = fold;
309  }
310
311  /**
312   * Returns the tip text for this property
313   *
314   * @return tip text for this property suitable for
315   * displaying in the explorer/experimenter gui
316   */
317  public String seedTipText() {
318
319    return "the random number seed for shuffling the dataset. If seed is negative, shuffling will not be performed.";
320  }
321
322  /**
323   * Gets the random number seed used for shuffling the dataset.
324   *
325   * @return the random number seed
326   */
327  public long getSeed() {
328
329    return m_Seed;
330  }
331
332  /**
333   * Sets the random number seed for shuffling the dataset. If seed
334   * is negative, shuffling won't be performed.
335   *
336   * @param seed the random number seed
337   */
338  public void setSeed(long seed) {
339   
340    m_Seed = seed;
341  }
342
343  /**
344   * Returns the Capabilities of this filter.
345   *
346   * @return            the capabilities of this object
347   * @see               Capabilities
348   */
349  public Capabilities getCapabilities() {
350    Capabilities result = super.getCapabilities();
351    result.disableAll();
352
353    // attributes
354    result.enableAllAttributes();
355    result.enable(Capability.MISSING_VALUES);
356   
357    // class
358    result.enableAllClasses();
359    result.enable(Capability.MISSING_CLASS_VALUES);
360   
361    return result;
362  }
363
364  /**
365   * Sets the format of the input instances.
366   *
367   * @param instanceInfo an Instances object containing the input instance
368   * structure (any instances contained in the object are ignored - only the
369   * structure is required).
370   * @return true because outputFormat can be collected immediately
371   * @throws Exception if the input format can't be set successfully
372   */ 
373  public boolean setInputFormat(Instances instanceInfo) throws Exception {
374
375    if ((m_NumFolds > 0) && (m_NumFolds < m_Fold)) {
376      throw new IllegalArgumentException("Fold has to be smaller or equal to "+
377                                         "number of folds.");
378    }
379    super.setInputFormat(instanceInfo);
380    setOutputFormat(instanceInfo);
381    return true;
382  }
383
384  /**
385   * Input an instance for filtering. Filter requires all
386   * training instances be read before producing output.
387   *
388   * @param instance the input instance
389   * @return true if the filtered instance may now be
390   * collected with output().
391   * @throws IllegalStateException if no input structure has been defined
392   */
393  public boolean input(Instance instance) {
394
395    if (getInputFormat() == null) {
396      throw new IllegalStateException("No input instance format defined");
397    }
398    if (m_NewBatch) {
399      resetQueue();
400      m_NewBatch = false;
401    }
402    if (isFirstBatchDone()) {
403      push(instance);
404      return true;
405    } else {
406      bufferInput(instance);
407      return false;
408    }
409  }
410
411  /**
412   * Signify that this batch of input to the filter is
413   * finished. Output() may now be called to retrieve the filtered
414   * instances.
415   *
416   * @return true if there are instances pending output
417   * @throws IllegalStateException if no input structure has been defined
418   */
419  public boolean batchFinished() {
420
421    if (getInputFormat() == null) {
422      throw new IllegalStateException("No input instance format defined");
423    }
424   
425    Instances instances;
426
427    if (!isFirstBatchDone()) {
428      if (m_Seed > 0) {
429        // User has provided a random number seed.
430        getInputFormat().randomize(new Random(m_Seed));
431      }
432
433      // Select out a fold
434      getInputFormat().stratify(m_NumFolds);
435      if (!m_Inverse) {
436        instances = getInputFormat().testCV(m_NumFolds, m_Fold - 1);
437      } else {
438        instances = getInputFormat().trainCV(m_NumFolds, m_Fold - 1);
439      }
440    }
441    else {
442      instances = getInputFormat();
443    }
444   
445    flushInput();
446
447    for (int i = 0; i < instances.numInstances(); i++) {
448      push(instances.instance(i));
449    }
450
451    m_NewBatch = true;
452    m_FirstBatchDone = true;
453    return (numPendingOutput() != 0);
454  }
455 
456  /**
457   * Returns the revision string.
458   *
459   * @return            the revision
460   */
461  public String getRevision() {
462    return RevisionUtils.extract("$Revision: 5492 $");
463  }
464
465  /**
466   * Main method for testing this class.
467   *
468   * @param argv should contain arguments to the filter: use -h for help
469   */
470  public static void main(String [] argv) {
471    runFilter(new StratifiedRemoveFolds(), argv);
472  }
473}
Note: See TracBrowser for help on using the repository browser.