source: src/main/java/weka/classifiers/functions/VotedPerceptron.java @ 7

Last change on this file since 7 was 4, checked in by gnappo, 14 years ago

Import di weka.

File size: 16.0 KB
RevLine 
[4]1/*
2 *    This program is free software; you can redistribute it and/or modify
3 *    it under the terms of the GNU General Public License as published by
4 *    the Free Software Foundation; either version 2 of the License, or
5 *    (at your option) any later version.
6 *
7 *    This program is distributed in the hope that it will be useful,
8 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
9 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
10 *    GNU General Public License for more details.
11 *
12 *    You should have received a copy of the GNU General Public License
13 *    along with this program; if not, write to the Free Software
14 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
15 */
16
17/*
18 *    VotedPerceptron.java
19 *    Copyright (C) 1999 University of Waikato, Hamilton, New Zealand
20 *
21 */
22
23
24package weka.classifiers.functions;
25
26import weka.classifiers.Classifier;
27import weka.classifiers.AbstractClassifier;
28import weka.core.Capabilities;
29import weka.core.Instance;
30import weka.core.Instances;
31import weka.core.Option;
32import weka.core.OptionHandler;
33import weka.core.RevisionUtils;
34import weka.core.TechnicalInformation;
35import weka.core.TechnicalInformationHandler;
36import weka.core.Utils;
37import weka.core.Capabilities.Capability;
38import weka.core.TechnicalInformation.Field;
39import weka.core.TechnicalInformation.Type;
40import weka.filters.Filter;
41import weka.filters.unsupervised.attribute.NominalToBinary;
42import weka.filters.unsupervised.attribute.ReplaceMissingValues;
43
44import java.util.Enumeration;
45import java.util.Random;
46import java.util.Vector;
47
48/**
49 <!-- globalinfo-start -->
50 * Implementation of the voted perceptron algorithm by Freund and Schapire. Globally replaces all missing values, and transforms nominal attributes into binary ones.<br/>
51 * <br/>
52 * For more information, see:<br/>
53 * <br/>
54 * Y. Freund, R. E. Schapire: Large margin classification using the perceptron algorithm. In: 11th Annual Conference on Computational Learning Theory, New York, NY, 209-217, 1998.
55 * <p/>
56 <!-- globalinfo-end -->
57 *
58 <!-- technical-bibtex-start -->
59 * BibTeX:
60 * <pre>
61 * &#64;inproceedings{Freund1998,
62 *    address = {New York, NY},
63 *    author = {Y. Freund and R. E. Schapire},
64 *    booktitle = {11th Annual Conference on Computational Learning Theory},
65 *    pages = {209-217},
66 *    publisher = {ACM Press},
67 *    title = {Large margin classification using the perceptron algorithm},
68 *    year = {1998}
69 * }
70 * </pre>
71 * <p/>
72 <!-- technical-bibtex-end -->
73 *
74 <!-- options-start -->
75 * Valid options are: <p/>
76 *
77 * <pre> -I &lt;int&gt;
78 *  The number of iterations to be performed.
79 *  (default 1)</pre>
80 *
81 * <pre> -E &lt;double&gt;
82 *  The exponent for the polynomial kernel.
83 *  (default 1)</pre>
84 *
85 * <pre> -S &lt;int&gt;
86 *  The seed for the random number generation.
87 *  (default 1)</pre>
88 *
89 * <pre> -M &lt;int&gt;
90 *  The maximum number of alterations allowed.
91 *  (default 10000)</pre>
92 *
93 <!-- options-end -->
94 *
95 * @author Eibe Frank (eibe@cs.waikato.ac.nz)
96 * @version $Revision: 5928 $
97 */
98public class VotedPerceptron 
99  extends AbstractClassifier
100  implements OptionHandler, TechnicalInformationHandler {
101 
102  /** for serialization */
103  static final long serialVersionUID = -1072429260104568698L;
104 
105  /** The maximum number of alterations to the perceptron */
106  private int m_MaxK = 10000;
107
108  /** The number of iterations */
109  private int m_NumIterations = 1;
110
111  /** The exponent */
112  private double m_Exponent = 1.0;
113
114  /** The actual number of alterations */
115  private int m_K = 0;
116
117  /** The training instances added to the perceptron */
118  private int[] m_Additions = null;
119
120  /** Addition or subtraction? */
121  private boolean[] m_IsAddition = null;
122
123  /** The weights for each perceptron */
124  private int[] m_Weights = null;
125 
126  /** The training instances */
127  private Instances m_Train = null;
128
129  /** Seed used for shuffling the dataset */
130  private int m_Seed = 1;
131
132  /** The filter used to make attributes numeric. */
133  private NominalToBinary m_NominalToBinary;
134
135  /** The filter used to get rid of missing values. */
136  private ReplaceMissingValues m_ReplaceMissingValues;
137
138  /**
139   * Returns a string describing this classifier
140   * @return a description of the classifier suitable for
141   * displaying in the explorer/experimenter gui
142   */
143  public String globalInfo() {
144    return 
145        "Implementation of the voted perceptron algorithm by Freund and "
146      + "Schapire. Globally replaces all missing values, and transforms "
147      + "nominal attributes into binary ones.\n\n"
148      + "For more information, see:\n\n"
149      + getTechnicalInformation().toString();
150  }
151
152  /**
153   * Returns an instance of a TechnicalInformation object, containing
154   * detailed information about the technical background of this class,
155   * e.g., paper reference or book this class is based on.
156   *
157   * @return the technical information about this class
158   */
159  public TechnicalInformation getTechnicalInformation() {
160    TechnicalInformation        result;
161   
162    result = new TechnicalInformation(Type.INPROCEEDINGS);
163    result.setValue(Field.AUTHOR, "Y. Freund and R. E. Schapire");
164    result.setValue(Field.TITLE, "Large margin classification using the perceptron algorithm");
165    result.setValue(Field.BOOKTITLE, "11th Annual Conference on Computational Learning Theory");
166    result.setValue(Field.YEAR, "1998");
167    result.setValue(Field.PAGES, "209-217");
168    result.setValue(Field.PUBLISHER, "ACM Press");
169    result.setValue(Field.ADDRESS, "New York, NY");
170   
171    return result;
172  }
173
174  /**
175   * Returns an enumeration describing the available options.
176   *
177   * @return an enumeration of all the available options.
178   */
179  public Enumeration listOptions() {
180
181    Vector newVector = new Vector(4);
182
183    newVector.addElement(new Option("\tThe number of iterations to be performed.\n"
184                                    + "\t(default 1)",
185                                    "I", 1, "-I <int>"));
186    newVector.addElement(new Option("\tThe exponent for the polynomial kernel.\n"
187                                    + "\t(default 1)",
188                                    "E", 1, "-E <double>"));
189    newVector.addElement(new Option("\tThe seed for the random number generation.\n"
190                                    + "\t(default 1)",
191                                    "S", 1, "-S <int>"));
192    newVector.addElement(new Option("\tThe maximum number of alterations allowed.\n"
193                                    + "\t(default 10000)",
194                                    "M", 1, "-M <int>"));
195
196    return newVector.elements();
197  }
198
199  /**
200   * Parses a given list of options. <p/>
201   *
202   <!-- options-start -->
203   * Valid options are: <p/>
204   *
205   * <pre> -I &lt;int&gt;
206   *  The number of iterations to be performed.
207   *  (default 1)</pre>
208   *
209   * <pre> -E &lt;double&gt;
210   *  The exponent for the polynomial kernel.
211   *  (default 1)</pre>
212   *
213   * <pre> -S &lt;int&gt;
214   *  The seed for the random number generation.
215   *  (default 1)</pre>
216   *
217   * <pre> -M &lt;int&gt;
218   *  The maximum number of alterations allowed.
219   *  (default 10000)</pre>
220   *
221   <!-- options-end -->
222   *
223   * @param options the list of options as an array of strings
224   * @throws Exception if an option is not supported
225   */
226  public void setOptions(String[] options) throws Exception {
227   
228    String iterationsString = Utils.getOption('I', options);
229    if (iterationsString.length() != 0) {
230      m_NumIterations = Integer.parseInt(iterationsString);
231    } else {
232      m_NumIterations = 1;
233    }
234    String exponentsString = Utils.getOption('E', options);
235    if (exponentsString.length() != 0) {
236      m_Exponent = (new Double(exponentsString)).doubleValue();
237    } else {
238      m_Exponent = 1.0;
239    }
240    String seedString = Utils.getOption('S', options);
241    if (seedString.length() != 0) {
242      m_Seed = Integer.parseInt(seedString);
243    } else {
244      m_Seed = 1;
245    }
246    String alterationsString = Utils.getOption('M', options);
247    if (alterationsString.length() != 0) {
248      m_MaxK = Integer.parseInt(alterationsString);
249    } else {
250      m_MaxK = 10000;
251    }
252  }
253
254  /**
255   * Gets the current settings of the classifier.
256   *
257   * @return an array of strings suitable for passing to setOptions
258   */
259  public String[] getOptions() {
260
261    String[] options = new String [8];
262    int current = 0;
263
264    options[current++] = "-I"; options[current++] = "" + m_NumIterations;
265    options[current++] = "-E"; options[current++] = "" + m_Exponent;
266    options[current++] = "-S"; options[current++] = "" + m_Seed;
267    options[current++] = "-M"; options[current++] = "" + m_MaxK;
268    while (current < options.length) {
269      options[current++] = "";
270    }
271    return options;
272  }
273
274  /**
275   * Returns default capabilities of the classifier.
276   *
277   * @return      the capabilities of this classifier
278   */
279  public Capabilities getCapabilities() {
280    Capabilities result = super.getCapabilities();
281    result.disableAll();
282
283    // attributes
284    result.enable(Capability.NOMINAL_ATTRIBUTES);
285    result.enable(Capability.NUMERIC_ATTRIBUTES);
286    result.enable(Capability.DATE_ATTRIBUTES);
287    result.enable(Capability.MISSING_VALUES);
288
289    // class
290    result.enable(Capability.BINARY_CLASS);
291    result.enable(Capability.MISSING_CLASS_VALUES);
292
293    // instances
294    result.setMinimumNumberInstances(0);
295   
296    return result;
297  }
298
299  /**
300   * Builds the ensemble of perceptrons.
301   *
302   * @param insts the data to train the classifier with
303   * @throws Exception if something goes wrong during building
304   */
305  public void buildClassifier(Instances insts) throws Exception {
306 
307    // can classifier handle the data?
308    getCapabilities().testWithFail(insts);
309
310    // remove instances with missing class
311    insts = new Instances(insts);
312    insts.deleteWithMissingClass();
313   
314    // Filter data
315    m_Train = new Instances(insts);
316    m_ReplaceMissingValues = new ReplaceMissingValues();
317    m_ReplaceMissingValues.setInputFormat(m_Train);
318    m_Train = Filter.useFilter(m_Train, m_ReplaceMissingValues);
319   
320    m_NominalToBinary = new NominalToBinary();
321    m_NominalToBinary.setInputFormat(m_Train);
322    m_Train = Filter.useFilter(m_Train, m_NominalToBinary);
323
324    /** Randomize training data */
325    m_Train.randomize(new Random(m_Seed));
326
327    /** Make space to store perceptrons */
328    m_Additions = new int[m_MaxK + 1];
329    m_IsAddition = new boolean[m_MaxK + 1];
330    m_Weights = new int[m_MaxK + 1];
331
332    /** Compute perceptrons */
333    m_K = 0;
334  out:
335    for (int it = 0; it < m_NumIterations; it++) {
336      for (int i = 0; i < m_Train.numInstances(); i++) {
337        Instance inst = m_Train.instance(i);
338        if (!inst.classIsMissing()) {
339          int prediction = makePrediction(m_K, inst);
340          int classValue = (int) inst.classValue();
341          if (prediction == classValue) {
342            m_Weights[m_K]++;
343          } else {
344            m_IsAddition[m_K] = (classValue == 1);
345            m_Additions[m_K] = i;
346            m_K++;
347            m_Weights[m_K]++;
348          }
349          if (m_K == m_MaxK) {
350            break out;
351          }
352        }
353      }
354    }
355  }
356
357  /**
358   * Outputs the distribution for the given output.
359   *
360   * Pipes output of SVM through sigmoid function.
361   * @param inst the instance for which distribution is to be computed
362   * @return the distribution
363   * @throws Exception if something goes wrong
364   */
365  public double[] distributionForInstance(Instance inst) throws Exception {
366
367    // Filter instance
368    m_ReplaceMissingValues.input(inst);
369    m_ReplaceMissingValues.batchFinished();
370    inst = m_ReplaceMissingValues.output();
371
372    m_NominalToBinary.input(inst);
373    m_NominalToBinary.batchFinished();
374    inst = m_NominalToBinary.output();
375   
376    // Get probabilities
377    double output = 0, sumSoFar = 0;
378    if (m_K > 0) {
379      for (int i = 0; i <= m_K; i++) {
380        if (sumSoFar < 0) {
381          output -= m_Weights[i];
382        } else {
383          output += m_Weights[i];
384        }
385        if (m_IsAddition[i]) {
386          sumSoFar += innerProduct(m_Train.instance(m_Additions[i]), inst);
387        } else {
388          sumSoFar -= innerProduct(m_Train.instance(m_Additions[i]), inst);
389        }
390      }
391    }
392    double[] result = new double[2];
393    result[1] = 1 / (1 + Math.exp(-output));
394    result[0] = 1 - result[1];
395
396    return result;
397  }
398
399  /**
400   * Returns textual description of classifier.
401   *
402   * @return the model as string
403   */
404  public String toString() {
405
406    return "VotedPerceptron: Number of perceptrons=" + m_K;
407  }
408 
409  /**
410   * Returns the tip text for this property
411   * @return tip text for this property suitable for
412   * displaying in the explorer/experimenter gui
413   */
414  public String maxKTipText() {
415    return "The maximum number of alterations to the perceptron.";
416  }
417
418  /**
419   * Get the value of maxK.
420   *
421   * @return Value of maxK.
422   */
423  public int getMaxK() {
424   
425    return m_MaxK;
426  }
427 
428  /**
429   * Set the value of maxK.
430   *
431   * @param v  Value to assign to maxK.
432   */
433  public void setMaxK(int v) {
434   
435    m_MaxK = v;
436  }
437 
438  /**
439   * Returns the tip text for this property
440   * @return tip text for this property suitable for
441   * displaying in the explorer/experimenter gui
442   */
443  public String numIterationsTipText() {
444    return "Number of iterations to be performed.";
445  }
446
447  /**
448   * Get the value of NumIterations.
449   *
450   * @return Value of NumIterations.
451   */
452  public int getNumIterations() {
453   
454    return m_NumIterations;
455  }
456 
457  /**
458   * Set the value of NumIterations.
459   *
460   * @param v  Value to assign to NumIterations.
461   */
462  public void setNumIterations(int v) {
463   
464    m_NumIterations = v;
465  }
466
467  /**
468   * Returns the tip text for this property
469   * @return tip text for this property suitable for
470   * displaying in the explorer/experimenter gui
471   */
472  public String exponentTipText() {
473    return "Exponent for the polynomial kernel.";
474  }
475
476  /**
477   * Get the value of exponent.
478   *
479   * @return Value of exponent.
480   */
481  public double getExponent() {
482   
483    return m_Exponent;
484  }
485 
486  /**
487   * Set the value of exponent.
488   *
489   * @param v  Value to assign to exponent.
490   */
491  public void setExponent(double v) {
492   
493    m_Exponent = v;
494  }
495 
496  /**
497   * Returns the tip text for this property
498   * @return tip text for this property suitable for
499   * displaying in the explorer/experimenter gui
500   */
501  public String seedTipText() {
502    return "Seed for the random number generator.";
503  }
504
505  /**
506   * Get the value of Seed.
507   *
508   * @return Value of Seed.
509   */
510  public int getSeed() {
511   
512    return m_Seed;
513  }
514 
515  /**
516   * Set the value of Seed.
517   *
518   * @param v  Value to assign to Seed.
519   */
520  public void setSeed(int v) {
521   
522    m_Seed = v;
523  }
524
525  /**
526   * Computes the inner product of two instances
527   *
528   * @param i1 first instance
529   * @param i2 second instance
530   * @return the inner product
531   * @throws Exception if computation fails
532   */
533  private double innerProduct(Instance i1, Instance i2) throws Exception {
534
535    // we can do a fast dot product
536    double result = 0;
537    int n1 = i1.numValues(); int n2 = i2.numValues();
538    int classIndex = m_Train.classIndex();
539    for (int p1 = 0, p2 = 0; p1 < n1 && p2 < n2;) {
540        int ind1 = i1.index(p1);
541        int ind2 = i2.index(p2);
542        if (ind1 == ind2) {
543            if (ind1 != classIndex) {
544                result += i1.valueSparse(p1) *
545                          i2.valueSparse(p2);
546            }
547            p1++; p2++;
548        } else if (ind1 > ind2) {
549            p2++;
550        } else {
551            p1++;
552        }
553    }
554    result += 1.0;
555   
556    if (m_Exponent != 1) {
557      return Math.pow(result, m_Exponent);
558    } else {
559      return result;
560    }
561  }
562
563  /**
564   * Compute a prediction from a perceptron
565   *
566   * @param k
567   * @param inst the instance to make a prediction for
568   * @return the prediction
569   * @throws Exception if computation fails
570   */
571  private int makePrediction(int k, Instance inst) throws Exception {
572
573    double result = 0;
574    for (int i = 0; i < k; i++) {
575      if (m_IsAddition[i]) {
576        result += innerProduct(m_Train.instance(m_Additions[i]), inst);
577      } else {
578        result -= innerProduct(m_Train.instance(m_Additions[i]), inst);
579      }
580    }
581    if (result < 0) {
582      return 0;
583    } else {
584      return 1;
585    }
586  }
587 
588  /**
589   * Returns the revision string.
590   *
591   * @return            the revision
592   */
593  public String getRevision() {
594    return RevisionUtils.extract("$Revision: 5928 $");
595  }
596
597  /**
598   * Main method.
599   *
600   * @param argv the commandline options
601   */
602  public static void main(String[] argv) {
603    runClassifier(new VotedPerceptron(), argv);
604  }
605}
Note: See TracBrowser for help on using the repository browser.