source: tags/MetisMQIDemo/src/main/java/weka/filters/supervised/instance/SMOTE.java

Last change on this file was 29, checked in by gnappo, 15 years ago

Taggata versione per la demo e aggiunto branch.

File size: 21.4 KB
Line 
1/*
2 *    This program is free software; you can redistribute it and/or modify
3 *    it under the terms of the GNU General Public License as published by
4 *    the Free Software Foundation; either version 2 of the License, or
5 *    (at your option) any later version.
6 *
7 *    This program is distributed in the hope that it will be useful,
8 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
9 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
10 *    GNU General Public License for more details.
11 *
12 *    You should have received a copy of the GNU General Public License
13 *    along with this program; if not, write to the Free Software
14 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
15 */
16
17/*
18 * SMOTE.java
19 *
20 * Copyright (C) 2008 Ryan Lichtenwalter
21 * Copyright (C) 2008 University of Waikato, Hamilton, New Zealand
22 */
23
24package weka.filters.supervised.instance;
25
26import weka.core.Attribute;
27import weka.core.Capabilities;
28import weka.core.Instance;
29import weka.core.DenseInstance;
30import weka.core.Instances;
31import weka.core.Option;
32import weka.core.OptionHandler;
33import weka.core.RevisionUtils;
34import weka.core.TechnicalInformation;
35import weka.core.TechnicalInformationHandler;
36import weka.core.Utils;
37import weka.core.Capabilities.Capability;
38import weka.core.TechnicalInformation.Field;
39import weka.core.TechnicalInformation.Type;
40import weka.filters.Filter;
41import weka.filters.SupervisedFilter;
42
43import java.util.Collections;
44import java.util.Comparator;
45import java.util.Enumeration;
46import java.util.HashMap;
47import java.util.HashSet;
48import java.util.Iterator;
49import java.util.LinkedList;
50import java.util.List;
51import java.util.Map;
52import java.util.Random;
53import java.util.Set;
54import java.util.Vector;
55
56/**
57 <!-- globalinfo-start -->
58 * Resamples a dataset by applying the Synthetic Minority Oversampling TEchnique (SMOTE). The original dataset must fit entirely in memory. The amount of SMOTE and number of nearest neighbors may be specified. For more information, see <br/>
59 * <br/>
60 * Nitesh V. Chawla et. al. (2002). Synthetic Minority Over-sampling Technique. Journal of Artificial Intelligence Research. 16:321-357.
61 * <p/>
62 <!-- globalinfo-end -->
63 *
64 <!-- technical-bibtex-start -->
65 * BibTeX:
66 * <pre>
67 * &#64;article{al.2002,
68 *    author = {Nitesh V. Chawla et. al.},
69 *    journal = {Journal of Artificial Intelligence Research},
70 *    pages = {321-357},
71 *    title = {Synthetic Minority Over-sampling Technique},
72 *    volume = {16},
73 *    year = {2002}
74 * }
75 * </pre>
76 * <p/>
77 <!-- technical-bibtex-end -->
78 *
79 <!-- options-start -->
80 * Valid options are: <p/>
81 *
82 * <pre> -S &lt;num&gt;
83 *  Specifies the random number seed
84 *  (default 1)</pre>
85 *
86 * <pre> -P &lt;percentage&gt;
87 *  Specifies percentage of SMOTE instances to create.
88 *  (default 100.0)
89 * </pre>
90 *
91 * <pre> -K &lt;nearest-neighbors&gt;
92 *  Specifies the number of nearest neighbors to use.
93 *  (default 5)
94 * </pre>
95 *
96 * <pre> -C &lt;value-index&gt;
97 *  Specifies the index of the nominal class value to SMOTE
98 *  (default 0: auto-detect non-empty minority class))
99 * </pre>
100 *
101 <!-- options-end -->
102 * 
103 * @author Ryan Lichtenwalter (rlichtenwalter@gmail.com)
104 * @version $Revision: 5987 $
105 */
106public class SMOTE
107  extends Filter
108  implements SupervisedFilter, OptionHandler, TechnicalInformationHandler {
109
110  /** for serialization. */
111  static final long serialVersionUID = -1653880819059250364L;
112
113  /** the number of neighbors to use. */
114  protected int m_NearestNeighbors = 5;
115 
116  /** the random seed to use. */
117  protected int m_RandomSeed = 1;
118 
119  /** the percentage of SMOTE instances to create. */
120  protected double m_Percentage = 100.0;
121 
122  /** the index of the class value. */
123  protected String m_ClassValueIndex = "0";
124 
125  /** whether to detect the minority class automatically. */
126  protected boolean m_DetectMinorityClass = true;
127
128  /**
129   * Returns a string describing this classifier.
130   *
131   * @return            a description of the classifier suitable for
132   *                    displaying in the explorer/experimenter gui
133   */
134  public String globalInfo() {
135    return "Resamples a dataset by applying the Synthetic Minority Oversampling TEchnique (SMOTE)." +
136    " The original dataset must fit entirely in memory." +
137    " The amount of SMOTE and number of nearest neighbors may be specified." +
138    " For more information, see \n\n" 
139    + getTechnicalInformation().toString();
140  }
141
142  /**
143   * Returns an instance of a TechnicalInformation object, containing
144   * detailed information about the technical background of this class,
145   * e.g., paper reference or book this class is based on.
146   *
147   * @return            the technical information about this class
148   */
149  public TechnicalInformation getTechnicalInformation() {
150    TechnicalInformation result = new TechnicalInformation(Type.ARTICLE);
151
152    result.setValue(Field.AUTHOR, "Nitesh V. Chawla et. al.");
153    result.setValue(Field.TITLE, "Synthetic Minority Over-sampling Technique");
154    result.setValue(Field.JOURNAL, "Journal of Artificial Intelligence Research");
155    result.setValue(Field.YEAR, "2002");
156    result.setValue(Field.VOLUME, "16");
157    result.setValue(Field.PAGES, "321-357");
158
159    return result;
160  }
161
162  /**
163   * Returns the revision string.
164   *
165   * @return            the revision
166   */
167  public String getRevision() {
168    return RevisionUtils.extract("$Revision: 5987 $");
169  }
170
171  /**
172   * Returns the Capabilities of this filter.
173   *
174   * @return            the capabilities of this object
175   * @see               Capabilities
176   */
177  public Capabilities getCapabilities() {
178    Capabilities result = super.getCapabilities();
179    result.disableAll();
180
181    // attributes
182    result.enableAllAttributes();
183    result.enable(Capability.MISSING_VALUES);
184
185    // class
186    result.enable(Capability.NOMINAL_CLASS);
187    result.enable(Capability.MISSING_CLASS_VALUES);
188
189    return result;
190  }
191
192  /**
193   * Returns an enumeration describing the available options.
194   *
195   * @return an enumeration of all the available options.
196   */
197  public Enumeration listOptions() {
198    Vector newVector = new Vector();
199   
200    newVector.addElement(new Option(
201        "\tSpecifies the random number seed\n"
202        + "\t(default 1)",
203        "S", 1, "-S <num>"));
204   
205    newVector.addElement(new Option(
206        "\tSpecifies percentage of SMOTE instances to create.\n"
207        + "\t(default 100.0)\n",
208        "P", 1, "-P <percentage>"));
209   
210    newVector.addElement(new Option(
211        "\tSpecifies the number of nearest neighbors to use.\n"
212        + "\t(default 5)\n",
213        "K", 1, "-K <nearest-neighbors>"));
214   
215    newVector.addElement(new Option(
216        "\tSpecifies the index of the nominal class value to SMOTE\n"
217        +"\t(default 0: auto-detect non-empty minority class))\n",
218        "C", 1, "-C <value-index>"));
219
220    return newVector.elements();
221  }
222
223  /**
224   * Parses a given list of options.
225   *
226   <!-- options-start -->
227   * Valid options are: <p/>
228   *
229   * <pre> -S &lt;num&gt;
230   *  Specifies the random number seed
231   *  (default 1)</pre>
232   *
233   * <pre> -P &lt;percentage&gt;
234   *  Specifies percentage of SMOTE instances to create.
235   *  (default 100.0)
236   * </pre>
237   *
238   * <pre> -K &lt;nearest-neighbors&gt;
239   *  Specifies the number of nearest neighbors to use.
240   *  (default 5)
241   * </pre>
242   *
243   * <pre> -C &lt;value-index&gt;
244   *  Specifies the index of the nominal class value to SMOTE
245   *  (default 0: auto-detect non-empty minority class))
246   * </pre>
247   *
248   <!-- options-end -->
249   *
250   * @param options the list of options as an array of strings
251   * @throws Exception if an option is not supported
252   */
253  public void setOptions(String[] options) throws Exception {
254    String seedStr = Utils.getOption('S', options);
255    if (seedStr.length() != 0) {
256      setRandomSeed(Integer.parseInt(seedStr));
257    } else {
258      setRandomSeed(1);
259    }
260
261    String percentageStr = Utils.getOption('P', options);
262    if (percentageStr.length() != 0) {
263      setPercentage(new Double(percentageStr).doubleValue());
264    } else {
265      setPercentage(100.0);
266    }
267
268    String nnStr = Utils.getOption('K', options);
269    if (nnStr.length() != 0) {
270      setNearestNeighbors(Integer.parseInt(nnStr));
271    } else {
272      setNearestNeighbors(5);
273    }
274
275    String classValueIndexStr = Utils.getOption( 'C', options);
276    if (classValueIndexStr.length() != 0) {
277      setClassValue(classValueIndexStr);
278    } else {
279      m_DetectMinorityClass = true;
280    }
281  }
282
283  /**
284   * Gets the current settings of the filter.
285   *
286   * @return an array   of strings suitable for passing to setOptions
287   */
288  public String[] getOptions() {
289    Vector<String>      result;
290   
291    result = new Vector<String>();
292   
293    result.add("-C");
294    result.add(getClassValue());
295   
296    result.add("-K");
297    result.add("" + getNearestNeighbors());
298   
299    result.add("-P");
300    result.add("" + getPercentage());
301   
302    result.add("-S");
303    result.add("" + getRandomSeed());
304   
305    return result.toArray(new String[result.size()]);
306  }
307
308  /**
309   * Returns the tip text for this property.
310   *
311   * @return            tip text for this property suitable for
312   *                    displaying in the explorer/experimenter gui
313   */
314  public String randomSeedTipText() {
315    return "The seed used for random sampling.";
316  }
317
318  /**
319   * Gets the random number seed.
320   *
321   * @return            the random number seed.
322   */
323  public int getRandomSeed() {
324    return m_RandomSeed;
325  }
326
327  /**
328   * Sets the random number seed.
329   *
330   * @param value       the new random number seed.
331   */
332  public void setRandomSeed(int value) {
333    m_RandomSeed = value;
334  }
335
336  /**
337   * Returns the tip text for this property.
338   *
339   * @return            tip text for this property suitable for
340   *                    displaying in the explorer/experimenter gui
341   */
342  public String percentageTipText() {
343    return "The percentage of SMOTE instances to create.";
344  }
345
346  /**
347   * Sets the percentage of SMOTE instances to create.
348   *
349   * @param value       the percentage to use
350   */
351  public void setPercentage(double value) {
352    if (value >= 0)
353      m_Percentage = value;
354    else
355      System.err.println("Percentage must be >= 0!");
356  }
357
358  /**
359   * Gets the percentage of SMOTE instances to create.
360   *
361   * @return            the percentage of SMOTE instances to create
362   */
363  public double getPercentage() {
364    return m_Percentage;
365  }
366
367  /**
368   * Returns the tip text for this property.
369   *
370   * @return            tip text for this property suitable for
371   *                    displaying in the explorer/experimenter gui
372   */
373  public String nearestNeighborsTipText() {
374    return "The number of nearest neighbors to use.";
375  }
376
377  /**
378   * Sets the number of nearest neighbors to use.
379   *
380   * @param value       the number of nearest neighbors to use
381   */
382  public void setNearestNeighbors(int value) {
383    if (value >= 1)
384      m_NearestNeighbors = value;
385    else
386      System.err.println("At least 1 neighbor necessary!");
387  }
388
389  /**
390   * Gets the number of nearest neighbors to use.
391   *
392   * @return            the number of nearest neighbors to use
393   */
394  public int getNearestNeighbors() {
395    return m_NearestNeighbors;
396  }
397
398  /**
399   * Returns the tip text for this property.
400   *
401   * @return            tip text for this property suitable for
402   *                    displaying in the explorer/experimenter gui
403   */
404  public String classValueTipText() {
405    return "The index of the class value to which SMOTE should be applied. " +
406    "Use a value of 0 to auto-detect the non-empty minority class.";
407  }
408
409  /**
410   * Sets the index of the class value to which SMOTE should be applied.
411   *
412   * @param value       the class value index
413   */
414  public void setClassValue(String value) {
415    m_ClassValueIndex = value;
416    if (m_ClassValueIndex.equals("0")) {
417      m_DetectMinorityClass = true;
418    } else {
419      m_DetectMinorityClass = false;
420    }
421  }
422
423  /**
424   * Gets the index of the class value to which SMOTE should be applied.
425   *
426   * @return            the index of the clas value to which SMOTE should be applied
427   */
428  public String getClassValue() {
429    return m_ClassValueIndex;
430  }
431
432  /**
433   * Sets the format of the input instances.
434   *
435   * @param instanceInfo        an Instances object containing the input
436   *                            instance structure (any instances contained in
437   *                            the object are ignored - only the structure is required).
438   * @return                    true if the outputFormat may be collected immediately
439   * @throws Exception          if the input format can't be set successfully
440   */
441  public boolean setInputFormat(Instances instanceInfo) throws Exception {
442    super.setInputFormat(instanceInfo);
443    super.setOutputFormat(instanceInfo);
444    return true;
445  }
446
447  /**
448   * Input an instance for filtering. Filter requires all
449   * training instances be read before producing output.
450   *
451   * @param instance            the input instance
452   * @return                    true if the filtered instance may now be
453   *                            collected with output().
454   * @throws IllegalStateException if no input structure has been defined
455   */
456  public boolean input(Instance instance) {
457    if (getInputFormat() == null) {
458      throw new IllegalStateException("No input instance format defined");
459    }
460    if (m_NewBatch) {
461      resetQueue();
462      m_NewBatch = false;
463    }
464    if (m_FirstBatchDone) {
465      push(instance);
466      return true;
467    } else {
468      bufferInput(instance);
469      return false;
470    }
471  }
472
473  /**
474   * Signify that this batch of input to the filter is finished.
475   * If the filter requires all instances prior to filtering,
476   * output() may now be called to retrieve the filtered instances.
477   *
478   * @return            true if there are instances pending output
479   * @throws IllegalStateException if no input structure has been defined
480   * @throws Exception  if provided options cannot be executed
481   *                    on input instances
482   */
483  public boolean batchFinished() throws Exception {
484    if (getInputFormat() == null) {
485      throw new IllegalStateException("No input instance format defined");
486    }
487
488    if (!m_FirstBatchDone) {
489      // Do SMOTE, and clear the input instances.
490      doSMOTE();
491    }
492    flushInput();
493
494    m_NewBatch = true;
495    m_FirstBatchDone = true;
496    return (numPendingOutput() != 0);
497  }
498
499  /**
500   * The procedure implementing the SMOTE algorithm. The output
501   * instances are pushed onto the output queue for collection.
502   *
503   * @throws Exception  if provided options cannot be executed
504   *                    on input instances
505   */
506  protected void doSMOTE() throws Exception {
507    int minIndex = 0;
508    int min = Integer.MAX_VALUE;
509    if (m_DetectMinorityClass) {
510      // find minority class
511      int[] classCounts = getInputFormat().attributeStats(getInputFormat().classIndex()).nominalCounts;
512      for (int i = 0; i < classCounts.length; i++) {
513        if (classCounts[i] != 0 && classCounts[i] < min) {
514          min = classCounts[i];
515          minIndex = i;
516        }
517      }
518    } else {
519      String classVal = getClassValue();
520      if (classVal.equalsIgnoreCase("first")) {
521        minIndex = 1;
522      } else if (classVal.equalsIgnoreCase("last")) {
523        minIndex = getInputFormat().numClasses();
524      } else {
525        minIndex = Integer.parseInt(classVal);
526      }
527      if (minIndex > getInputFormat().numClasses()) {
528        throw new Exception("value index must be <= the number of classes");
529      }
530      minIndex--; // make it an index
531    }
532
533    int nearestNeighbors;
534    if (min <= getNearestNeighbors()) {
535      nearestNeighbors = min - 1;
536    } else {
537      nearestNeighbors = getNearestNeighbors();
538    }
539    if (nearestNeighbors < 1)
540      throw new Exception("Cannot use 0 neighbors!");
541
542    // compose minority class dataset
543    // also push all dataset instances
544    Instances sample = getInputFormat().stringFreeStructure();
545    Enumeration instanceEnum = getInputFormat().enumerateInstances();
546    while(instanceEnum.hasMoreElements()) {
547      Instance instance = (Instance) instanceEnum.nextElement();
548      push((Instance) instance.copy());
549      if ((int) instance.classValue() == minIndex) {
550        sample.add(instance);
551      }
552    }
553
554    // compute Value Distance Metric matrices for nominal features
555    Map vdmMap = new HashMap();
556    Enumeration attrEnum = getInputFormat().enumerateAttributes();
557    while(attrEnum.hasMoreElements()) {
558      Attribute attr = (Attribute) attrEnum.nextElement();
559      if (!attr.equals(getInputFormat().classAttribute())) {
560        if (attr.isNominal() || attr.isString()) {
561          double[][] vdm = new double[attr.numValues()][attr.numValues()];
562          vdmMap.put(attr, vdm);
563          int[] featureValueCounts = new int[attr.numValues()];
564          int[][] featureValueCountsByClass = new int[getInputFormat().classAttribute().numValues()][attr.numValues()];
565          instanceEnum = getInputFormat().enumerateInstances();
566          while(instanceEnum.hasMoreElements()) {
567            Instance instance = (Instance) instanceEnum.nextElement();
568            int value = (int) instance.value(attr);
569            int classValue = (int) instance.classValue();
570            featureValueCounts[value]++;
571            featureValueCountsByClass[classValue][value]++;
572          }
573          for (int valueIndex1 = 0; valueIndex1 < attr.numValues(); valueIndex1++) {
574            for (int valueIndex2 = 0; valueIndex2 < attr.numValues(); valueIndex2++) {
575              double sum = 0;
576              for (int classValueIndex = 0; classValueIndex < getInputFormat().numClasses(); classValueIndex++) {
577                double c1i = (double) featureValueCountsByClass[classValueIndex][valueIndex1];
578                double c2i = (double) featureValueCountsByClass[classValueIndex][valueIndex2];
579                double c1 = (double) featureValueCounts[valueIndex1];
580                double c2 = (double) featureValueCounts[valueIndex2];
581                double term1 = c1i / c1;
582                double term2 = c2i / c2;
583                sum += Math.abs(term1 - term2);
584              }
585              vdm[valueIndex1][valueIndex2] = sum;
586            }
587          }
588        }
589      }
590    }
591
592    // use this random source for all required randomness
593    Random rand = new Random(getRandomSeed());
594
595    // find the set of extra indices to use if the percentage is not evenly divisible by 100
596    List extraIndices = new LinkedList();
597    double percentageRemainder = (getPercentage() / 100) - Math.floor(getPercentage() / 100.0);
598    int extraIndicesCount = (int) (percentageRemainder * sample.numInstances());
599    if (extraIndicesCount >= 1) {
600      for (int i = 0; i < sample.numInstances(); i++) {
601        extraIndices.add(i);
602      }
603    }
604    Collections.shuffle(extraIndices, rand);
605    extraIndices = extraIndices.subList(0, extraIndicesCount);
606    Set extraIndexSet = new HashSet(extraIndices);
607
608    // the main loop to handle computing nearest neighbors and generating SMOTE
609    // examples from each instance in the original minority class data
610    Instance[] nnArray = new Instance[nearestNeighbors];
611    for (int i = 0; i < sample.numInstances(); i++) {
612      Instance instanceI = sample.instance(i);
613      // find k nearest neighbors for each instance
614      List distanceToInstance = new LinkedList();
615      for (int j = 0; j < sample.numInstances(); j++) {
616        Instance instanceJ = sample.instance(j);
617        if (i != j) {
618          double distance = 0;
619          attrEnum = getInputFormat().enumerateAttributes();
620          while(attrEnum.hasMoreElements()) {
621            Attribute attr = (Attribute) attrEnum.nextElement();
622            if (!attr.equals(getInputFormat().classAttribute())) {
623              double iVal = instanceI.value(attr);
624              double jVal = instanceJ.value(attr);
625              if (attr.isNumeric()) {
626                distance += Math.pow(iVal - jVal, 2);
627              } else {
628                distance += ((double[][]) vdmMap.get(attr))[(int) iVal][(int) jVal];
629              }
630            }
631          }
632          distance = Math.pow(distance, .5);
633          distanceToInstance.add(new Object[] {distance, instanceJ});
634        }
635      }
636
637      // sort the neighbors according to distance
638      Collections.sort(distanceToInstance, new Comparator() {
639        public int compare(Object o1, Object o2) {
640          double distance1 = (Double) ((Object[]) o1)[0];
641          double distance2 = (Double) ((Object[]) o2)[0];
642          return (int) Math.ceil(distance1 - distance2);
643        }
644      });
645
646      // populate the actual nearest neighbor instance array
647      Iterator entryIterator = distanceToInstance.iterator();
648      int j = 0;
649      while(entryIterator.hasNext() && j < nearestNeighbors) {
650        nnArray[j] = (Instance) ((Object[])entryIterator.next())[1];
651        j++;
652      }
653
654      // create synthetic examples
655      int n = (int) Math.floor(getPercentage() / 100);
656      while(n > 0 || extraIndexSet.remove(i)) {
657        double[] values = new double[sample.numAttributes()];
658        int nn = rand.nextInt(nearestNeighbors);
659        attrEnum = getInputFormat().enumerateAttributes();
660        while(attrEnum.hasMoreElements()) {
661          Attribute attr = (Attribute) attrEnum.nextElement();
662          if (!attr.equals(getInputFormat().classAttribute())) {
663            if (attr.isNumeric()) {
664              double dif = nnArray[nn].value(attr) - instanceI.value(attr);
665              double gap = rand.nextDouble();
666              values[attr.index()] = (double) (instanceI.value(attr) + gap * dif);
667            } else if (attr.isDate()) {
668              double dif = nnArray[nn].value(attr) - instanceI.value(attr);
669              double gap = rand.nextDouble();
670              values[attr.index()] = (long) (instanceI.value(attr) + gap * dif);
671            } else {
672              int[] valueCounts = new int[attr.numValues()];
673              int iVal = (int) instanceI.value(attr);
674              valueCounts[iVal]++;
675              for (int nnEx = 0; nnEx < nearestNeighbors; nnEx++) {
676                int val = (int) nnArray[nnEx].value(attr);
677                valueCounts[val]++;
678              }
679              int maxIndex = 0;
680              int max = Integer.MIN_VALUE;
681              for (int index = 0; index < attr.numValues(); index++) {
682                if (valueCounts[index] > max) {
683                  max = valueCounts[index];
684                  maxIndex = index;
685                }
686              }
687              values[attr.index()] = maxIndex;
688            }
689          }
690        }
691        values[sample.classIndex()] = minIndex;
692        Instance synthetic = new DenseInstance(1.0, values);
693        push(synthetic);
694        n--;
695      }
696    }
697  }
698
699  /**
700   * Main method for running this filter.
701   *
702   * @param args        should contain arguments to the filter:
703   *                    use -h for help
704   */
705  public static void main(String[] args) {
706    runFilter(new SMOTE(), args);
707  }
708}
Note: See TracBrowser for help on using the repository browser.