source: tags/MetisMQIDemo/src/main/java/weka/filters/supervised/instance/SpreadSubsample.java

Last change on this file was 29, checked in by gnappo, 15 years ago

Taggata versione per la demo e aggiunto branch.

File size: 17.6 KB
Line 
1/*
2 *    This program is free software; you can redistribute it and/or modify
3 *    it under the terms of the GNU General Public License as published by
4 *    the Free Software Foundation; either version 2 of the License, or
5 *    (at your option) any later version.
6 *
7 *    This program is distributed in the hope that it will be useful,
8 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
9 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
10 *    GNU General Public License for more details.
11 *
12 *    You should have received a copy of the GNU General Public License
13 *    along with this program; if not, write to the Free Software
14 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
15 */
16
17/*
18 *    SpreadSubsample.java
19 *    Copyright (C) 2002 University of Waikato, Hamilton, New Zealand
20 *
21 */
22
23
24package weka.filters.supervised.instance;
25
26import weka.core.Capabilities;
27import weka.core.Instance;
28import weka.core.Instances;
29import weka.core.Option;
30import weka.core.OptionHandler;
31import weka.core.RevisionUtils;
32import weka.core.UnassignedClassException;
33import weka.core.UnsupportedClassTypeException;
34import weka.core.Utils;
35import weka.core.Capabilities.Capability;
36import weka.filters.Filter;
37import weka.filters.SupervisedFilter;
38
39import java.util.Enumeration;
40import java.util.Hashtable;
41import java.util.Random;
42import java.util.Vector;
43
44/**
45 <!-- globalinfo-start -->
46 * Produces a random subsample of a dataset. The original dataset must fit entirely in memory. This filter allows you to specify the maximum "spread" between the rarest and most common class. For example, you may specify that there be at most a 2:1 difference in class frequencies. When used in batch mode, subsequent batches are NOT resampled.
47 * <p/>
48 <!-- globalinfo-end -->
49 *
50 <!-- options-start -->
51 * Valid options are: <p/>
52 *
53 * <pre> -S &lt;num&gt;
54 *  Specify the random number seed (default 1)</pre>
55 *
56 * <pre> -M &lt;num&gt;
57 *  The maximum class distribution spread.
58 *  0 = no maximum spread, 1 = uniform distribution, 10 = allow at most
59 *  a 10:1 ratio between the classes (default 0)</pre>
60 *
61 * <pre> -W
62 *  Adjust weights so that total weight per class is maintained.
63 *  Individual instance weighting is not preserved. (default no
64 *  weights adjustment</pre>
65 *
66 * <pre> -X &lt;num&gt;
67 *  The maximum count for any class value (default 0 = unlimited).
68 * </pre>
69 *
70 <!-- options-end -->
71 *
72 * @author Stuart Inglis (stuart@reeltwo.com)
73 * @version $Revision: 5492 $
74 **/
75public class SpreadSubsample 
76  extends Filter
77  implements SupervisedFilter, OptionHandler {
78 
79  /** for serialization */
80  static final long serialVersionUID = -3947033795243930016L;
81
82  /** The random number generator seed */
83  private int m_RandomSeed = 1;
84
85  /** The maximum count of any class */
86  private int m_MaxCount;
87
88  /** True if the first batch has been done */
89  private double m_DistributionSpread = 0;
90
91  /**
92   * True if instance weights will be adjusted to maintain
93   * total weight per class.
94   */
95  private boolean m_AdjustWeights = false;
96
97  /**
98   * Returns a string describing this filter
99   *
100   * @return a description of the filter suitable for
101   * displaying in the explorer/experimenter gui
102   */
103  public String globalInfo() {
104
105    return "Produces a random subsample of a dataset. The original dataset must "
106      + "fit entirely in memory. This filter allows you to specify the maximum "
107      + "\"spread\" between the rarest and most common class. For example, you may "
108      + "specify that there be at most a 2:1 difference in class frequencies. "
109      + "When used in batch mode, subsequent batches are NOT resampled.";
110
111  }
112   
113  /**
114   * Returns the tip text for this property
115   *
116   * @return tip text for this property suitable for
117   * displaying in the explorer/experimenter gui
118   */
119  public String adjustWeightsTipText() {
120    return "Wether instance weights will be adjusted to maintain total weight per "
121      + "class.";
122  }
123 
124  /**
125   * Returns true if instance  weights will be adjusted to maintain
126   * total weight per class.
127   *
128   * @return true if instance weights will be adjusted to maintain
129   * total weight per class.
130   */
131  public boolean getAdjustWeights() {
132
133    return m_AdjustWeights;
134  }
135 
136  /**
137   * Sets whether the instance weights will be adjusted to maintain
138   * total weight per class.
139   *
140   * @param newAdjustWeights whether to adjust weights
141   */
142  public void setAdjustWeights(boolean newAdjustWeights) {
143
144    m_AdjustWeights = newAdjustWeights;
145  }
146 
147  /**
148   * Returns an enumeration describing the available options.
149   *
150   * @return an enumeration of all the available options.
151   */
152  public Enumeration listOptions() {
153
154    Vector newVector = new Vector(4);
155
156    newVector.addElement(new Option(
157              "\tSpecify the random number seed (default 1)",
158              "S", 1, "-S <num>"));
159    newVector.addElement(new Option(
160              "\tThe maximum class distribution spread.\n"
161              +"\t0 = no maximum spread, 1 = uniform distribution, 10 = allow at most\n"
162              +"\ta 10:1 ratio between the classes (default 0)",
163              "M", 1, "-M <num>"));
164    newVector.addElement(new Option(
165              "\tAdjust weights so that total weight per class is maintained.\n"
166              +"\tIndividual instance weighting is not preserved. (default no\n"
167              +"\tweights adjustment",
168              "W", 0, "-W"));
169    newVector.addElement(new Option(
170              "\tThe maximum count for any class value (default 0 = unlimited).\n",
171              "X", 0, "-X <num>"));
172
173    return newVector.elements();
174  }
175
176
177  /**
178   * Parses a given list of options. <p/>
179   *
180   <!-- options-start -->
181   * Valid options are: <p/>
182   *
183   * <pre> -S &lt;num&gt;
184   *  Specify the random number seed (default 1)</pre>
185   *
186   * <pre> -M &lt;num&gt;
187   *  The maximum class distribution spread.
188   *  0 = no maximum spread, 1 = uniform distribution, 10 = allow at most
189   *  a 10:1 ratio between the classes (default 0)</pre>
190   *
191   * <pre> -W
192   *  Adjust weights so that total weight per class is maintained.
193   *  Individual instance weighting is not preserved. (default no
194   *  weights adjustment</pre>
195   *
196   * <pre> -X &lt;num&gt;
197   *  The maximum count for any class value (default 0 = unlimited).
198   * </pre>
199   *
200   <!-- options-end -->
201   *
202   * @param options the list of options as an array of strings
203   * @throws Exception if an option is not supported
204   */
205  public void setOptions(String[] options) throws Exception {
206   
207    String seedString = Utils.getOption('S', options);
208    if (seedString.length() != 0) {
209      setRandomSeed(Integer.parseInt(seedString));
210    } else {
211      setRandomSeed(1);
212    }
213
214    String maxString = Utils.getOption('M', options);
215    if (maxString.length() != 0) {
216      setDistributionSpread(Double.valueOf(maxString).doubleValue());
217    } else {
218      setDistributionSpread(0);
219    }
220
221    String maxCount = Utils.getOption('X', options);
222    if (maxCount.length() != 0) {
223      setMaxCount(Double.valueOf(maxCount).doubleValue());
224    } else {
225      setMaxCount(0);
226    }
227
228    setAdjustWeights(Utils.getFlag('W', options));
229
230    if (getInputFormat() != null) {
231      setInputFormat(getInputFormat());
232    }
233  }
234
235  /**
236   * Gets the current settings of the filter.
237   *
238   * @return an array of strings suitable for passing to setOptions
239   */
240  public String [] getOptions() {
241
242    String [] options = new String [7];
243    int current = 0;
244
245    options[current++] = "-M"; 
246    options[current++] = "" + getDistributionSpread();
247
248    options[current++] = "-X"; 
249    options[current++] = "" + getMaxCount();
250
251    options[current++] = "-S"; 
252    options[current++] = "" + getRandomSeed();
253
254    if (getAdjustWeights()) {
255      options[current++] = "-W";
256    }
257
258    while (current < options.length) {
259      options[current++] = "";
260    }
261    return options;
262  }
263   
264  /**
265   * Returns the tip text for this property
266   *
267   * @return tip text for this property suitable for
268   * displaying in the explorer/experimenter gui
269   */
270  public String distributionSpreadTipText() {
271    return "The maximum class distribution spread. "
272      + "(0 = no maximum spread, 1 = uniform distribution, 10 = allow at most a "
273      + "10:1 ratio between the classes).";
274  }
275 
276  /**
277   * Sets the value for the distribution spread
278   *
279   * @param spread the new distribution spread
280   */
281  public void setDistributionSpread(double spread) {
282
283    m_DistributionSpread = spread;
284  }
285
286  /**
287   * Gets the value for the distribution spread
288   *
289   * @return the distribution spread
290   */   
291  public double getDistributionSpread() {
292
293    return m_DistributionSpread;
294  }
295   
296  /**
297   * Returns the tip text for this property
298   *
299   * @return tip text for this property suitable for
300   * displaying in the explorer/experimenter gui
301   */
302  public String maxCountTipText() {
303    return "The maximum count for any class value (0 = unlimited).";
304  }
305 
306  /**
307   * Sets the value for the max count
308   *
309   * @param maxcount the new max count
310   */
311  public void setMaxCount(double maxcount) {
312
313    m_MaxCount = (int)maxcount;
314  }
315
316  /**
317   * Gets the value for the max count
318   *
319   * @return the max count
320   */   
321  public double getMaxCount() {
322
323    return m_MaxCount;
324  }
325   
326  /**
327   * Returns the tip text for this property
328   *
329   * @return tip text for this property suitable for
330   * displaying in the explorer/experimenter gui
331   */
332  public String randomSeedTipText() {
333    return "Sets the random number seed for subsampling.";
334  }
335 
336  /**
337   * Gets the random number seed.
338   *
339   * @return the random number seed.
340   */
341  public int getRandomSeed() {
342
343    return m_RandomSeed;
344  }
345 
346  /**
347   * Sets the random number seed.
348   *
349   * @param newSeed the new random number seed.
350   */
351  public void setRandomSeed(int newSeed) {
352
353    m_RandomSeed = newSeed;
354  }
355
356  /**
357   * Returns the Capabilities of this filter.
358   *
359   * @return            the capabilities of this object
360   * @see               Capabilities
361   */
362  public Capabilities getCapabilities() {
363    Capabilities result = super.getCapabilities();
364    result.disableAll();
365
366    // attributes
367    result.enableAllAttributes();
368    result.enable(Capability.MISSING_VALUES);
369   
370    // class
371    result.enable(Capability.NOMINAL_CLASS);
372   
373    return result;
374  }
375 
376  /**
377   * Sets the format of the input instances.
378   *
379   * @param instanceInfo an Instances object containing the input
380   * instance structure (any instances contained in the object are
381   * ignored - only the structure is required).
382   * @return true if the outputFormat may be collected immediately
383   * @throws UnassignedClassException if no class attribute has been set.
384   * @throws UnsupportedClassTypeException if the class attribute
385   * is not nominal.
386   */
387  public boolean setInputFormat(Instances instanceInfo) 
388       throws Exception {
389
390    super.setInputFormat(instanceInfo);
391    setOutputFormat(instanceInfo);
392    return true;
393  }
394
395  /**
396   * Input an instance for filtering. Filter requires all
397   * training instances be read before producing output.
398   *
399   * @param instance the input instance
400   * @return true if the filtered instance may now be
401   * collected with output().
402   * @throws IllegalStateException if no input structure has been defined
403   */
404  public boolean input(Instance instance) {
405
406    if (getInputFormat() == null) {
407      throw new IllegalStateException("No input instance format defined");
408    }
409    if (m_NewBatch) {
410      resetQueue();
411      m_NewBatch = false;
412    }
413    if (isFirstBatchDone()) {
414      push(instance);
415      return true;
416    } else {
417      bufferInput(instance);
418      return false;
419    }
420  }
421
422  /**
423   * Signify that this batch of input to the filter is finished.
424   * If the filter requires all instances prior to filtering,
425   * output() may now be called to retrieve the filtered instances.
426   *
427   * @return true if there are instances pending output
428   * @throws IllegalStateException if no input structure has been defined
429   */
430  public boolean batchFinished() {
431
432    if (getInputFormat() == null) {
433      throw new IllegalStateException("No input instance format defined");
434    }
435
436    if (!isFirstBatchDone()) {
437      // Do the subsample, and clear the input instances.
438      createSubsample();
439    }
440
441    flushInput();
442    m_NewBatch = true;
443    m_FirstBatchDone = true;
444    return (numPendingOutput() != 0);
445  }
446
447
448  /**
449   * Creates a subsample of the current set of input instances. The output
450   * instances are pushed onto the output queue for collection.
451   */
452  private void createSubsample() {
453
454    int classI = getInputFormat().classIndex();
455    // Sort according to class attribute.
456    getInputFormat().sort(classI);
457    // Determine where each class starts in the sorted dataset
458    int [] classIndices = getClassIndices();
459
460    // Get the existing class distribution
461    int [] counts = new int [getInputFormat().numClasses()];
462    double [] weights = new double [getInputFormat().numClasses()];
463    int min = -1;
464    for (int i = 0; i < getInputFormat().numInstances(); i++) {
465      Instance current = getInputFormat().instance(i);
466      if (current.classIsMissing() == false) {
467        counts[(int)current.classValue()]++;
468        weights[(int)current.classValue()]+= current.weight();
469      }
470    }
471
472    // Convert from total weight to average weight
473    for (int i = 0; i < counts.length; i++) {
474      if (counts[i] > 0) {
475        weights[i] = weights[i] / counts[i];
476      }
477      /*
478      System.err.println("Class:" + i + " " + getInputFormat().classAttribute().value(i)
479                         + " Count:" + counts[i]
480                         + " Total:" + weights[i] * counts[i]
481                         + " Avg:" + weights[i]);
482      */
483    }
484   
485    // find the class with the minimum number of instances
486    int minIndex = -1;
487    for (int i = 0; i < counts.length; i++) {
488      if ( (min < 0) && (counts[i] > 0) ) {
489        min = counts[i];
490        minIndex = i;
491      } else if ((counts[i] < min) && (counts[i] > 0)) {
492        min = counts[i];
493        minIndex = i;
494      }
495    }
496
497    if (min < 0) { 
498        System.err.println("SpreadSubsample: *warning* none of the classes have any values in them.");
499        return;
500    }
501
502    // determine the new distribution
503    int [] new_counts = new int [getInputFormat().numClasses()];
504    for (int i = 0; i < counts.length; i++) {
505      new_counts[i] = (int)Math.abs(Math.min(counts[i],
506                                             min * m_DistributionSpread));
507      if (i == minIndex) {
508        if (m_DistributionSpread > 0 && m_DistributionSpread < 1.0) {
509          // don't undersample the minority class!
510          new_counts[i] = counts[i];
511        }
512      }
513      if (m_DistributionSpread == 0) {
514        new_counts[i] = counts[i];
515      }
516
517      if (m_MaxCount > 0) {
518        new_counts[i] = Math.min(new_counts[i], m_MaxCount);
519      }
520    }
521
522    // Sample without replacement
523    Random random = new Random(m_RandomSeed);
524    Hashtable t = new Hashtable();
525    for (int j = 0; j < new_counts.length; j++) {
526      double newWeight = 1.0;
527      if (m_AdjustWeights && (new_counts[j] > 0)) {
528        newWeight = weights[j] * counts[j] / new_counts[j];
529        /*
530        System.err.println("Class:" + j + " " + getInputFormat().classAttribute().value(j)
531                           + " Count:" + counts[j]
532                           + " Total:" + weights[j] * counts[j]
533                           + " Avg:" + weights[j]
534                           + " NewCount:" + new_counts[j]
535                           + " NewAvg:" + newWeight);
536        */
537      }
538      for (int k = 0; k < new_counts[j]; k++) {
539        boolean ok = false;
540        do {
541          int index = classIndices[j] + (Math.abs(random.nextInt()) 
542                                         % (classIndices[j + 1] - classIndices[j])) ;
543          // Have we used this instance before?
544          if (t.get("" + index) == null) {
545            // if not, add it to the hashtable and use it
546            t.put("" + index, "");
547            ok = true;
548            if(index >= 0) {
549              Instance newInst = (Instance)getInputFormat().instance(index).copy();
550              if (m_AdjustWeights) {
551                newInst.setWeight(newWeight);
552              }
553              push(newInst);
554            }
555          }
556        } while (!ok);
557      }
558    }
559  }
560
561  /**
562   * Creates an index containing the position where each class starts in
563   * the getInputFormat(). m_InputFormat must be sorted on the class attribute.
564   *
565   * @return the positions
566   */
567  private int[] getClassIndices() {
568
569    // Create an index of where each class value starts
570    int [] classIndices = new int [getInputFormat().numClasses() + 1];
571    int currentClass = 0;
572    classIndices[currentClass] = 0;
573    for (int i = 0; i < getInputFormat().numInstances(); i++) {
574      Instance current = getInputFormat().instance(i);
575      if (current.classIsMissing()) {
576        for (int j = currentClass + 1; j < classIndices.length; j++) {
577          classIndices[j] = i;
578        }
579        break;
580      } else if (current.classValue() != currentClass) {
581        for (int j = currentClass + 1; j <= current.classValue(); j++) {
582          classIndices[j] = i;
583        }         
584        currentClass = (int) current.classValue();
585      }
586    }
587    if (currentClass <= getInputFormat().numClasses()) {
588      for (int j = currentClass + 1; j < classIndices.length; j++) {
589        classIndices[j] = getInputFormat().numInstances();
590      }
591    }
592    return classIndices;
593  }
594 
595  /**
596   * Returns the revision string.
597   *
598   * @return            the revision
599   */
600  public String getRevision() {
601    return RevisionUtils.extract("$Revision: 5492 $");
602  }
603
604  /**
605   * Main method for testing this class.
606   *
607   * @param argv should contain arguments to the filter:
608   * use -h for help
609   */
610  public static void main(String [] argv) {
611    runFilter(new SpreadSubsample(), argv);
612  }
613}
Note: See TracBrowser for help on using the repository browser.