source: src/main/java/weka/classifiers/meta/OneClassClassifier.java @ 16

Last change on this file since 16 was 4, checked in by gnappo, 14 years ago

Import di weka.

File size: 43.2 KB
Line 
1/*
2 *    This program is free software; you can redistribute it and/or modify
3 *    it under the terms of the GNU General Public License as published by
4 *    the Free Software Foundation; either version 2 of the License, or
5 *    (at your option) any later version.
6 *
7 *    This program is distributed in the hope that it will be useful,
8 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
9 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
10 *    GNU General Public License for more details.
11 *
12 *    You should have received a copy of the GNU General Public License
13 *    along with this program; if not, write to the Free Software
14 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
15 */
16
17/**
18 *   OneClassClassifier.java
19 *   Copyright (C) 2008 K.Hempstalk, University of Waikato, Hamilton, New Zealand.
20 */
21
22package weka.classifiers.meta;
23
24import weka.classifiers.RandomizableSingleClassifierEnhancer;
25import weka.classifiers.meta.generators.GaussianGenerator;
26import weka.classifiers.meta.generators.Generator;
27import weka.classifiers.meta.generators.InstanceHandler;
28import weka.classifiers.meta.generators.Mean;
29import weka.classifiers.meta.generators.NominalGenerator;
30import weka.classifiers.meta.generators.NominalAttributeGenerator;
31import weka.classifiers.meta.generators.NumericAttributeGenerator;
32import weka.classifiers.meta.generators.Ranged;
33import weka.core.Attribute;
34import weka.core.Capabilities;
35import weka.core.Instance;
36import weka.core.DenseInstance;
37import weka.core.Instances;
38import weka.core.Option;
39import weka.core.TechnicalInformation;
40import weka.core.TechnicalInformationHandler;
41import weka.core.Utils;
42import weka.core.Capabilities.Capability;
43import weka.core.TechnicalInformation.Field;
44import weka.core.TechnicalInformation.Type;
45import weka.filters.Filter;
46import weka.filters.unsupervised.attribute.AddValues;
47import weka.filters.unsupervised.attribute.MergeManyValues;
48
49import java.io.StringReader;
50import java.util.ArrayList;
51import java.util.Arrays;
52import java.util.Enumeration;
53import java.util.Random;
54import java.util.Vector;
55
56/**
57 <!-- globalinfo-start -->
58 * Performs one-class classification on a dataset.<br/>
59 * <br/>
 * Classifier reduces the class being classified to just a single class, and learns the data without using any information from other classes.  The testing stage will classify as 'target' or 'outlier' - so in order to calculate the outlier pass rate the dataset must contain information from more than one class.<br/>
61 * <br/>
 * Also, the output varies depending on whether the label 'outlier' exists in the instances used to build the classifier.  If so, then 'outlier' will be predicted, if not, then the label will be considered missing when the prediction does not favour the target class.  The 'outlier' class will not be used to build the model if there are instances of this class in the dataset.  It can simply be used as a flag, you do not need to relabel any classes.<br/>
63 * <br/>
64 * For more information, see:<br/>
65 * <br/>
66 * Kathryn Hempstalk, Eibe Frank, Ian H. Witten: One-Class Classification by Combining Density and Class Probability Estimation. In: Proceedings of the 12th European Conference on Principles and Practice of Knowledge Discovery in Databases and 19th European Conference on Machine Learning, ECMLPKDD2008, Berlin, 505--519, 2008.
67 * <p/>
68 <!-- globalinfo-end -->
69 *
70 <!-- technical-bibtex-start -->
71 * BibTeX:
72 * <pre>
73 * &#64;conference{Hempstalk2008,
74 *    address = {Berlin},
75 *    author = {Kathryn Hempstalk and Eibe Frank and Ian H. Witten},
76 *    booktitle = {Proceedings of the 12th European Conference on Principles and Practice of Knowledge Discovery in Databases and 19th European Conference on Machine Learning, ECMLPKDD2008},
77 *    month = {September},
78 *    pages = {505--519},
79 *    publisher = {Springer},
80 *    series = {Lecture Notes in Computer Science},
81 *    title = {One-Class Classification by Combining Density and Class Probability Estimation},
82 *    volume = {Vol. 5211},
83 *    year = {2008},
84 *    location = {Antwerp, Belgium}
85 * }
86 * </pre>
87 * <p/>
88 <!-- technical-bibtex-end -->
89 *
90 <!-- options-start -->
91 * Valid options are: <p/>
92 *
93 * <pre> -trr &lt;rate&gt;
94 *  Sets the target rejection rate
95 *  (default: 0.1)</pre>
96 *
97 * <pre> -tcl &lt;label&gt;
98 *  Sets the target class label
99 *  (default: 'target')</pre>
100 *
101 * <pre> -cvr &lt;rep&gt;
102 *  Sets the number of times to repeat cross validation
103 *  to find the threshold
104 *  (default: 10)</pre>
105 *
106 * <pre> -P &lt;prop&gt;
107 *  Sets the proportion of generated data
108 *  (default: 0.5)</pre>
109 *
110 * <pre> -cvf &lt;perc&gt;
111 *  Sets the percentage of heldout data for each cross validation
112 *  fold
113 *  (default: 10)</pre>
114 *
115 * <pre> -num &lt;classname + options&gt;
116 *  Sets the numeric generator
117 *  (default: weka.classifiers.meta.generators.GaussianGenerator)</pre>
118 *
119 * <pre> -nom &lt;classname + options&gt;
120 *  Sets the nominal generator
121 *  (default: weka.classifiers.meta.generators.NominalGenerator)</pre>
122 *
123 * <pre> -L
124 *  Sets whether to correct the number of classes to two,
125 *  if omitted no correction will be made.</pre>
126 *
127 * <pre> -E
128 *  Sets whether to exclusively use the density estimate.</pre>
129 *
130 * <pre> -I
131 *  Sets whether to use instance weights.</pre>
132 *
133 * <pre> -S &lt;num&gt;
134 *  Random number seed.
135 *  (default 1)</pre>
136 *
137 * <pre> -D
138 *  If set, classifier is run in debug mode and
139 *  may output additional info to the console</pre>
140 *
141 * <pre> -W
142 *  Full name of base classifier.
143 *  (default: weka.classifiers.meta.Bagging)</pre>
144 *
145 * <pre>
146 * Options specific to classifier weka.classifiers.meta.Bagging:
147 * </pre>
148 *
149 * <pre> -P
150 *  Size of each bag, as a percentage of the
151 *  training set size. (default 100)</pre>
152 *
153 * <pre> -O
154 *  Calculate the out of bag error.</pre>
155 *
156 * <pre> -S &lt;num&gt;
157 *  Random number seed.
158 *  (default 1)</pre>
159 *
160 * <pre> -I &lt;num&gt;
161 *  Number of iterations.
162 *  (default 10)</pre>
163 *
164 * <pre> -D
165 *  If set, classifier is run in debug mode and
166 *  may output additional info to the console</pre>
167 *
168 * <pre> -W
169 *  Full name of base classifier.
170 *  (default: weka.classifiers.trees.REPTree)</pre>
171 *
172 * <pre>
173 * Options specific to classifier weka.classifiers.trees.REPTree:
174 * </pre>
175 *
176 * <pre> -M &lt;minimum number of instances&gt;
177 *  Set minimum number of instances per leaf (default 2).</pre>
178 *
179 * <pre> -V &lt;minimum variance for split&gt;
180 *  Set minimum numeric class variance proportion
181 *  of train variance for split (default 1e-3).</pre>
182 *
183 * <pre> -N &lt;number of folds&gt;
184 *  Number of folds for reduced error pruning (default 3).</pre>
185 *
186 * <pre> -S &lt;seed&gt;
187 *  Seed for random data shuffling (default 1).</pre>
188 *
189 * <pre> -P
190 *  No pruning.</pre>
191 *
192 * <pre> -L
193 *  Maximum tree depth (default -1, no maximum)</pre>
194 *
195 <!-- options-end -->
196 *
197 * Options after -- are passed to the designated classifier.
198 *
199 * @author Kathryn Hempstalk (kah18 at cs.waikato.ac.nz)
200 * @author Eibe Frank (eibe at cs.waikato.ac.nz)
201 * @version $Revision: 5987 $
202 */
203public class OneClassClassifier
204  extends RandomizableSingleClassifierEnhancer
205  implements TechnicalInformationHandler {
206
207  /** for serialization. */
208  private static final long serialVersionUID = 6199125385010158931L;
209
210  /**
211   * The rejection rate of valid target objects (used to set the threshold).
212   */
213  protected double m_TargetRejectionRate = 0.1;
214
215  /**
216   * The probability threshold (only classes above this will be considered target).
217   */
218  protected double m_Threshold = 0.5;
219
220  /**
221   * The generators for the numeric attributes.
222   */
223  protected ArrayList m_Generators;
224
225  /**
226   * The value of the class attribute to consider the target class.
227   */
228  protected String m_TargetClassLabel = "target";
229
230  /**
231   * The number of times to repeat cross validation during learning.
232   */
233  protected int m_NumRepeats = 10; 
234
235  /**
236   * The percentage of heldout data.
237   */
238  protected double m_PercentHeldout = 10;
239
240  /**
241   * The proportion of the data that will be generated.
242   */
243  protected double m_ProportionGenerated = 0.5;
244
245  /**
246   * The default data generator for numeric attributes.
247   */
248    protected NumericAttributeGenerator m_DefaultNumericGenerator = (NumericAttributeGenerator) new GaussianGenerator();
249
250  /**
251   * The default data generator for nominal attributes.
252   */
253    protected NominalAttributeGenerator m_DefaultNominalGenerator = (NominalAttributeGenerator) new NominalGenerator();
254
255  /**
256   * Adds the outlier class if it doesn't already exist.
257   */
258  protected AddValues m_AddOutlierFilter;
259
260  /**
261   * Whether to include laplace correction so if there are multiple
262   * values for a class, it is reduced to just two so that any laplace
263   * correction in another classifier corrects with one possible other class
264   * rather than several. 
265   */
266  protected boolean m_UseLaplaceCorrection = false;
267
268  /**
269   * The filter that merges the instances down to two values.
270   */
271  protected MergeManyValues m_MergeFilter;
272
273  /**
274   * The label for the outlier class.
275   */
276  public static final String OUTLIER_LABEL = "outlier";
277
278
279  /**
280   * Whether to use only the density estimate, or to include the
281   * base classifier in the probability estimates.
282   */
283  protected boolean m_UseDensityOnly = false;
284
285  /**
286   * Whether to weight instances based on their prevalence in the
287   * test set used for calculating P(X|T).
288   */
289  protected boolean m_UseInstanceWeights = false;
290 
291  /** The random number generator used internally. */
292  protected Random m_Random;
293
294 
295  /**
296   * Default constructor.
297   */
298  public OneClassClassifier() {
299    super();
300   
301    m_Classifier = new weka.classifiers.meta.Bagging();
302  }
303
304  /**
305   * Returns a string describing this classes ability.
306   *
307   * @return A description of the method.
308   */
309  public String globalInfo() {
310    return 
311       "Performs one-class classification on a dataset.\n\n"
312     + "Classifier reduces the class being classified to just a single class, and learns the data"
313     + "without using any information from other classes.  The testing stage will classify as 'target'"
314     + "or 'outlier' - so in order to calculate the outlier pass rate the dataset must contain information"
315     + "from more than one class.\n"
316     + "\n"
317     + "Also, the output varies depending on whether the label 'outlier' exists in the instances used"
318     + "to build the classifier.  If so, then 'outlier' will be predicted, if not, then the label will"
319     + "be considered missing when the prediction does not favour the target class.  The 'outlier' class"
320     + "will not be used to build the model if there are instances of this class in the dataset.  It can"
321     + "simply be used as a flag, you do not need to relabel any classes.\n"
322     + "\n"
323     + "For more information, see:\n"
324     + "\n"
325     + getTechnicalInformation().toString();
326  }
327
328  /**
329   * Returns an instance of a TechnicalInformation object, containing
330   * detailed information about the technical background of this class,
331   * e.g., paper reference or book this class is based on.
332   *
333   * @return the technical information about this class
334   */
335  public TechnicalInformation getTechnicalInformation() {
336    TechnicalInformation        result;
337
338    result = new TechnicalInformation(Type.CONFERENCE);
339    result.setValue(Field.AUTHOR, "Kathryn Hempstalk and Eibe Frank and Ian H. Witten");
340    result.setValue(Field.YEAR, "2008");
341    result.setValue(Field.TITLE, "One-Class Classification by Combining Density and Class Probability Estimation");
342    result.setValue(Field.BOOKTITLE, "Proceedings of the 12th European Conference on Principles and Practice of Knowledge Discovery in Databases and 19th European Conference on Machine Learning, ECMLPKDD2008");
343    result.setValue(Field.VOLUME, "Vol. 5211");
344    result.setValue(Field.PAGES, "505--519");
345    result.setValue(Field.PUBLISHER, "Springer");
346    result.setValue(Field.ADDRESS, "Berlin");
347    result.setValue(Field.SERIES, "Lecture Notes in Computer Science");
348    result.setValue(Field.LOCATION, "Antwerp, Belgium");
349    result.setValue(Field.MONTH, "September");
350
351    return result;
352  }
353
354  /**
355   * Returns an enumeration describing the available options.
356   *
357   * @return An enumeration of all the available options.
358   */
359  public Enumeration listOptions() {
360    Vector result = new Vector();   
361
362    result.addElement(new Option(
363        "\tSets the target rejection rate\n"
364        + "\t(default: 0.1)",
365        "trr", 1, "-trr <rate>"));
366
367    result.addElement(new Option(
368        "\tSets the target class label\n"
369        + "\t(default: 'target')",
370        "tcl", 1, "-tcl <label>"));
371
372    result.addElement(new Option(
373        "\tSets the number of times to repeat cross validation\n" 
374        + "\tto find the threshold\n"
375        + "\t(default: 10)",
376        "cvr", 1, "-cvr <rep>"));
377
378    result.addElement(new Option(
379        "\tSets the proportion of generated data\n"
380        + "\t(default: 0.5)",
381        "P", 1, "-P <prop>"));
382
383    result.addElement(new Option(
384        "\tSets the percentage of heldout data for each cross validation\n"
385        + "\tfold\n"
386        + "\t(default: 10)",
387        "cvf", 1, "-cvf <perc>"));
388
389    result.addElement(new Option(
390        "\tSets the numeric generator\n"
391        + "\t(default: " + GaussianGenerator.class.getName() + ")",
392        "num", 1, "-num <classname + options>"));
393
394    result.addElement(new Option(
395        "\tSets the nominal generator\n"
396        + "\t(default: " + NominalGenerator.class.getName() + ")",
397        "nom", 1, "-nom <classname + options>"));
398
399    result.addElement(new Option(
400        "\tSets whether to correct the number of classes to two,\n"
401        + "\tif omitted no correction will be made.",
402        "L", 1, "-L"));
403
404    result.addElement(new Option(
405        "\tSets whether to exclusively use the density estimate.",
406        "E", 0, "-E"));
407
408    result.addElement(new Option(
409        "\tSets whether to use instance weights.",
410        "I", 0, "-I"));
411   
412    Enumeration enu = super.listOptions();
413    while (enu.hasMoreElements())
414      result.addElement(enu.nextElement());
415
416    return result.elements();
417  }
418
419  /**
420   * Parses a given list of options. <p/>
421   *
422   <!-- options-start -->
423   * Valid options are: <p/>
424   *
425   * <pre> -trr &lt;rate&gt;
426   *  Sets the target rejection rate
427   *  (default: 0.1)</pre>
428   *
429   * <pre> -tcl &lt;label&gt;
430   *  Sets the target class label
431   *  (default: 'target')</pre>
432   *
433   * <pre> -cvr &lt;rep&gt;
434   *  Sets the number of times to repeat cross validation
435   *  to find the threshold
436   *  (default: 10)</pre>
437   *
438   * <pre> -P &lt;prop&gt;
439   *  Sets the proportion of generated data
440   *  (default: 0.5)</pre>
441   *
442   * <pre> -cvf &lt;perc&gt;
443   *  Sets the percentage of heldout data for each cross validation
444   *  fold
445   *  (default: 10)</pre>
446   *
447   * <pre> -num &lt;classname + options&gt;
448   *  Sets the numeric generator
449   *  (default: weka.classifiers.meta.generators.GaussianGenerator)</pre>
450   *
451   * <pre> -nom &lt;classname + options&gt;
452   *  Sets the nominal generator
453   *  (default: weka.classifiers.meta.generators.NominalGenerator)</pre>
454   *
455   * <pre> -L
456   *  Sets whether to correct the number of classes to two,
457   *  if omitted no correction will be made.</pre>
458   *
459   * <pre> -E
460   *  Sets whether to exclusively use the density estimate.</pre>
461   *
462   * <pre> -I
463   *  Sets whether to use instance weights.</pre>
464   *
465   * <pre> -S &lt;num&gt;
466   *  Random number seed.
467   *  (default 1)</pre>
468   *
469   * <pre> -D
470   *  If set, classifier is run in debug mode and
471   *  may output additional info to the console</pre>
472   *
473   * <pre> -W
474   *  Full name of base classifier.
475   *  (default: weka.classifiers.meta.Bagging)</pre>
476   *
477   * <pre>
478   * Options specific to classifier weka.classifiers.meta.Bagging:
479   * </pre>
480   *
481   * <pre> -P
482   *  Size of each bag, as a percentage of the
483   *  training set size. (default 100)</pre>
484   *
485   * <pre> -O
486   *  Calculate the out of bag error.</pre>
487   *
488   * <pre> -S &lt;num&gt;
489   *  Random number seed.
490   *  (default 1)</pre>
491   *
492   * <pre> -I &lt;num&gt;
493   *  Number of iterations.
494   *  (default 10)</pre>
495   *
496   * <pre> -D
497   *  If set, classifier is run in debug mode and
498   *  may output additional info to the console</pre>
499   *
500   * <pre> -W
501   *  Full name of base classifier.
502   *  (default: weka.classifiers.trees.REPTree)</pre>
503   *
504   * <pre>
505   * Options specific to classifier weka.classifiers.trees.REPTree:
506   * </pre>
507   *
508   * <pre> -M &lt;minimum number of instances&gt;
509   *  Set minimum number of instances per leaf (default 2).</pre>
510   *
511   * <pre> -V &lt;minimum variance for split&gt;
512   *  Set minimum numeric class variance proportion
513   *  of train variance for split (default 1e-3).</pre>
514   *
515   * <pre> -N &lt;number of folds&gt;
516   *  Number of folds for reduced error pruning (default 3).</pre>
517   *
518   * <pre> -S &lt;seed&gt;
519   *  Seed for random data shuffling (default 1).</pre>
520   *
521   * <pre> -P
522   *  No pruning.</pre>
523   *
524   * <pre> -L
525   *  Maximum tree depth (default -1, no maximum)</pre>
526   *
527   <!-- options-end -->
528   *
529   * @param options The list of options as an array of strings.
530   * @throws Exception If an option is not supported.
531   */
532  public void setOptions(String[] options) throws Exception {
533    String      tmpStr;
534    String[]    tmpOptions;
535
536    // numeric generator
537    tmpStr = Utils.getOption("num", options);
538    if (tmpStr.length() != 0) { 
539      tmpOptions    = Utils.splitOptions(tmpStr);
540      tmpStr        = tmpOptions[0];
541      tmpOptions[0] = "";
542      setNumericGenerator((NumericAttributeGenerator) Utils.forName(Generator.class, tmpStr, tmpOptions));
543    }
544    else {
545      setNumericGenerator((NumericAttributeGenerator) Utils.forName(Generator.class, defaultNumericGeneratorString(), null));
546    }
547
548    // nominal generator
549    tmpStr = Utils.getOption("nom", options);
550    if (tmpStr.length() != 0) { 
551      tmpOptions    = Utils.splitOptions(tmpStr);
552      tmpStr        = tmpOptions[0];
553      tmpOptions[0] = "";
554      setNominalGenerator((NominalAttributeGenerator) Utils.forName(Generator.class, tmpStr, tmpOptions));
555    }
556    else {
557      setNominalGenerator((NominalAttributeGenerator) Utils.forName(Generator.class, defaultNominalGeneratorString(), null));
558    }
559
560    //target rejection rate
561    tmpStr = Utils.getOption("trr", options);
562    if (tmpStr.length() != 0)
563      setTargetRejectionRate(Double.parseDouble(tmpStr));
564    else
565      setTargetRejectionRate(0.1);
566
567    //target class label
568    tmpStr = Utils.getOption("tcl", options);
569    if (tmpStr.length() != 0)
570      setTargetClassLabel(tmpStr);
571    else
572      setTargetClassLabel("target");
573
574    //cross validation repeats
575    tmpStr = Utils.getOption("cvr", options);
576    if (tmpStr.length() != 0)
577      setNumRepeats(Integer.parseInt(tmpStr));
578    else
579      setNumRepeats(10);
580
581    //cross validation fold size
582    tmpStr = Utils.getOption("cvf", options);
583    if (tmpStr.length() != 0)
584      setPercentageHeldout(Double.parseDouble(tmpStr));
585    else
586      setPercentageHeldout(10.0);
587
588    //proportion generated
589    tmpStr = Utils.getOption("P", options);
590    if (tmpStr.length() != 0)
591      setProportionGenerated(Double.parseDouble(tmpStr));
592    else
593      setProportionGenerated(0.5);
594
595    //use laplace
596    setUseLaplaceCorrection(Utils.getFlag('L',options));
597
598    //set whether to exclusively use the density estimate
599    setDensityOnly(Utils.getFlag('E', options));
600
601    //use instance weights
602    setUseInstanceWeights(Utils.getFlag('I', options));
603   
604    // set the parent's options first
605    super.setOptions(options);
606  }
607
608  /**
609   * Gets the current settings of the Classifier.
610   *
611   * @return An array of strings suitable for passing to setOptions.
612   */
613  public String[] getOptions() {
614    Vector<String>      result;
615    String[]            options;
616    int                 i;
617
618    result = new Vector<String>();
619   
620    result.add("-num");
621    result.add(
622          m_DefaultNumericGenerator.getClass().getName() 
623        + " " 
624          + Utils.joinOptions(((Generator)m_DefaultNumericGenerator).getOptions()));
625
626    result.add("-nom");
627    result.add(
628          m_DefaultNominalGenerator.getClass().getName() 
629        + " " 
630          + Utils.joinOptions(((Generator)m_DefaultNominalGenerator).getOptions()));
631   
632    result.add("-trr");
633    result.add("" + m_TargetRejectionRate);
634   
635    result.add("-tcl");
636    result.add("" + m_TargetClassLabel);
637   
638    result.add("-cvr");
639    result.add("" + m_NumRepeats);
640   
641    result.add("-cvf");
642    result.add("" + m_PercentHeldout);
643   
644    result.add("-P");
645    result.add("" + m_ProportionGenerated);
646   
647    if (m_UseLaplaceCorrection)
648      result.add("-L");   
649 
650    if (m_UseDensityOnly)
651      result.add("-E");
652   
653    if (m_UseInstanceWeights)
654      result.add("-I");
655
656    options = super.getOptions();
657    for (i = 0; i < options.length; i++)
658      result.add(options[i]);
659
660    return result.toArray(new String[result.size()]);
661  }
662
663  /**
664   * Gets whether only the density estimate should be used by the classifier.  If false,
665   * the base classifier's estimate will be incorporated using bayes rule for two classes.
666   *
667   * @return Whether to use only the density estimate.
668   */
669  public boolean getDensityOnly() {
670    return m_UseDensityOnly;
671  }
672
673  /**
674   * Sets whether the density estimate will be used by itself.
675   *
676   * @param density Whether to use the density estimate exclusively or not.
677   */
678  public void setDensityOnly(boolean density) {
679    m_UseDensityOnly = density;
680  }
681 
682  /**
683   * Returns the tip text for this property.
684   *
685   * @return            tip text for this property suitable for
686   *                    displaying in the explorer/experimenter gui
687   */
688  public String densityOnlyTipText() {
689    return "If true, the density estimate will be used by itself.";
690  }
691
692  /**
693   * Gets the target rejection rate - the proportion of target class samples
694   * that will be rejected in order to build a threshold.
695   *
696   * @return The target rejection rate.
697   */
698  public double getTargetRejectionRate() {
699    return m_TargetRejectionRate;
700  }
701
702  /**
703   * Sets the target rejection rate.
704   *
705   * @param rate The new target rejection rate.
706   */
707  public void setTargetRejectionRate(double rate) {
708    m_TargetRejectionRate = rate;
709  }
710 
711  /**
712   * Returns the tip text for this property.
713   *
714   * @return            tip text for this property suitable for
715   *                    displaying in the explorer/experimenter gui
716   */
717  public String targetRejectionRateTipText() {
718    return 
719        "The target rejection rate, ie, the proportion of target class "
720      + "samples that will be rejected in order to build a threshold.";
721  }
722
723  /**
724   * Gets the target class label - the class label to perform one
725   * class classification on.
726   *
727   * @return The target class label.
728   */
729  public String getTargetClassLabel() {
730    return m_TargetClassLabel;
731  }
732
733  /**
734   * Sets the target class label to a new value.
735   *
736   * @param label The target class label to classify for.
737   */
738  public void setTargetClassLabel(String label) {
739    m_TargetClassLabel = label;
740  }
741 
742  /**
743   * Returns the tip text for this property.
744   *
745   * @return            tip text for this property suitable for
746   *                    displaying in the explorer/experimenter gui
747   */
748  public String targetClassLabelTipText() {
749    return "The class label to perform one-class classification on.";
750  }
751
752  /**
753   * Gets the number of repeats for (internal) cross validation.
754   *
755   * @return The number of repeats for internal cross validation.
756   */
757  public int getNumRepeats() {
758    return m_NumRepeats;
759  }
760
761  /**
762   * Sets the number of repeats for (internal) cross validation to a new value.
763   *
764   * @param repeats The new number of repeats for cross validation.
765   */
766  public void setNumRepeats(int repeats) {
767    m_NumRepeats = repeats;
768  }
769 
770  /**
771   * Returns the tip text for this property.
772   *
773   * @return            tip text for this property suitable for
774   *                    displaying in the explorer/experimenter gui
775   */
776  public String numRepeatsTipText() {
777    return "The number of repeats for (internal) cross-validation.";
778  }
779
780  /**
781   * Sets the proportion of generated data to a new value.
782   *
783   * @param prop The new proportion.
784   */
785  public void setProportionGenerated(double prop) {
786    m_ProportionGenerated = prop;
787  }
788
789  /**
790   * Gets the proportion of data that will be generated compared to the
791   * target class label.
792   *
793   * @return The proportion of generated data.
794   */
795  public double getProportionGenerated() {
796    return m_ProportionGenerated;
797  }
798 
799  /**
800   * Returns the tip text for this property.
801   *
802   * @return            tip text for this property suitable for
803   *                    displaying in the explorer/experimenter gui
804   */
805  public String proportionGeneratedTipText() {
806    return 
807        "The proportion of data that will be generated compared to the "
808      + "target class label.";
809  }
810
811  /**
812   * Sets the percentage heldout in each CV fold.
813   *
814   * @param percent The new percent of heldout data.
815   */
816  public void setPercentageHeldout(double percent) {
817    m_PercentHeldout = percent;
818  }
819
820  /**
821   * Gets the percentage of data that will be heldout in each
822   * iteration of cross validation.
823   *
824   * @return The percentage of heldout data.
825   */
826  public double getPercentageHeldout() {
827    return m_PercentHeldout;
828  }
829 
830  /**
831   * Returns the tip text for this property.
832   *
833   * @return            tip text for this property suitable for
834   *                    displaying in the explorer/experimenter gui
835   */
836  public String percentageHeldoutTipText() {
837    return 
838        "The percentage of data that will be heldout in each iteration "
839      + "of (internal) cross-validation.";
840  }
841
842  /**
843   * Gets thegenerator that will be used by default to generate
844   * numeric outlier data.
845   *
846   * @return The numeric data generator.
847   */
848  public NumericAttributeGenerator getNumericGenerator() {
849    return m_DefaultNumericGenerator;
850  }
851
852  /**
853   * Sets the generator that will be used by default to generate
854   * numeric outlier data.
855   *
856   * @param agen The new numeric data generator to use.
857   */
858  public void setNumericGenerator(NumericAttributeGenerator agen) {
859    m_DefaultNumericGenerator = agen;
860  }
861 
862  /**
863   * Returns the tip text for this property.
864   *
865   * @return            tip text for this property suitable for
866   *                    displaying in the explorer/experimenter gui
867   */
868  public String numericGeneratorTipText() {
869    return "The numeric data generator to use.";
870  }
871
872  /**
873   * Gets the generator that will be used by default to generate
874   * nominal outlier data.
875   *
876   * @return The nominal data generator.
877   */
878  public NominalAttributeGenerator getNominalGenerator() {
879    return m_DefaultNominalGenerator;
880  }
881
882  /**
883   * Sets the generator that will be used by default to generate
884   * nominal outlier data.
885   *
886   * @param agen The new nominal data generator to use.
887   */
888  public void setNominalGenerator(NominalAttributeGenerator agen) {
889    m_DefaultNominalGenerator = agen;
890  }
891 
892  /**
893   * Returns the tip text for this property.
894   *
895   * @return            tip text for this property suitable for
896   *                    displaying in the explorer/experimenter gui
897   */
898  public String nominalGeneratorTipText() {
899    return "The nominal data generator to use.";
900  }
901
902  /**
903   * Gets whether a laplace correction should be used.
904   *
905   * @return Whether a laplace correction should be used.
906   */
907  public boolean getUseLaplaceCorrection() {
908    return m_UseLaplaceCorrection;
909  }
910
911  /**
912   * Sets whether a laplace correction should be used.  A laplace
913   * correction will reduce the number of class labels to two, the
914   * target and outlier classes, regardless of how many labels
915   * actually exist.  This is useful for classifiers that use
916   * the number of class labels to make use a laplace value
917   * based on the unseen class.
918   *
919   * @param newuse Whether to use the laplace correction (default: true).
920   */
921  public void setUseLaplaceCorrection(boolean newuse) {
922    m_UseLaplaceCorrection = newuse;
923  }
924 
925  /**
926   * Returns the tip text for this property.
927   *
928   * @return            tip text for this property suitable for
929   *                    displaying in the explorer/experimenter gui
930   */
931  public String useLaplaceCorrectionTipText() {
932    return 
933        "If true, then Laplace correction will be used (reduces the "
934      + "number of class labels to two, target and outlier class, regardless "
935      + "of how many class labels actually exist) - useful for classifiers "
936      + "that use the number of class labels to make use of a Laplace value "
937      + "based on the unseen class.";
938  }
939
940
941  /**
942   * Sets whether to perform weighting on instances based on their
943   * prevalence in the data.
944   *
945   * @param newuse Whether or not to use instance weighting.
946   */
947  public void setUseInstanceWeights(boolean newuse) {
948    m_UseInstanceWeights = newuse;
949  }
950
951  /**
952   * Gets whether instance weighting will be performed.
953   *
954   * @return Whether instance weighting will be performed.
955   */
956  public boolean getUseInstanceWeights() {
957    return m_UseInstanceWeights;
958  }
959 
960  /**
961   * Returns the tip text for this property.
962   *
963   * @return            tip text for this property suitable for
964   *                    displaying in the explorer/experimenter gui
965   */
966  public String useInstanceWeightsTipText() {
967    return 
968        "If true, the weighting on instances is based on their prevalence "
969      + "in the data.";
970  }
971
972  /**
973   * String describing default internal classifier.
974   *
975   * @return The default classifier classname.
976   */
977  protected String defaultClassifierString() {   
978    return "weka.classifiers.meta.Bagging";
979  }
980
981  /**
982   * String describing default generator / density estimator.
983   *
984   * @return The default numeric generator classname.
985   */
986  protected String defaultNumericGeneratorString() {   
987    return "weka.classifiers.meta.generators.GaussianGenerator";
988  }
989
990  /**
991   * String describing default generator / density estimator.
992   *
993   * @return The default nominal generator classname.
994   */
995  protected String defaultNominalGeneratorString() {   
996    return "weka.classifiers.meta.generators.NominalGenerator";
997  }
998
999  /**
1000   * Returns default capabilities of the base classifier.
1001   *
1002   * @return      the capabilities of the base classifier
1003   */
1004  public Capabilities getCapabilities() {
1005    Capabilities        result;
1006   
1007    result = super.getCapabilities();
1008   
1009    // only nominal classes can be processed!
1010    result.disableAllClasses();
1011    result.disableAllClassDependencies();
1012    result.enable(Capability.NOMINAL_CLASS);
1013   
1014    return result;
1015  }
1016
1017  /**
1018   * Build the one-class classifier, any non-target data values
1019   * are ignored.  The target class label must exist in the arff
1020   * file or else an exception will be thrown.
1021   *
1022   * @param data The training data.
1023   * @throws Exception If the classifier could not be built successfully.
1024   */
1025  public void buildClassifier(Instances data) throws Exception {
1026    if (m_Classifier == null) {
1027      throw new Exception("No base classifier has been set!");
1028    }
1029    // can classifier handle the data?
1030    getCapabilities().testWithFail(data);
1031
1032    // remove instances with missing class
1033    Instances newData = new Instances(data);
1034
1035    m_Random = new Random(m_Seed);
1036   
1037    //delete the data that's not of the class we are trying to classify.
1038    Attribute classAttribute = newData.classAttribute();
1039    double targetClassValue = classAttribute.indexOfValue(m_TargetClassLabel);       
1040    if (targetClassValue == -1) {
1041      throw new Exception("Target class value doesn't exist!");
1042    }
1043
1044    int index = 0;
1045    while(index < newData.numInstances()) {
1046      Instance aninst = newData.instance(index);
1047      if (aninst.classValue() != targetClassValue) {
1048        newData.delete(index);
1049      } else
1050        index++;
1051    }
1052
1053    if (newData.numInstances() == 0) {
1054      throw new Exception("No instances found belonging to the target class!");
1055    }
1056
1057    //now we need to add the "outlier" attribute if it doesn't already exist.
1058    m_AddOutlierFilter = new AddValues();
1059    m_AddOutlierFilter.setAttributeIndex("" + (newData.classIndex() + 1));
1060    m_AddOutlierFilter.setLabels(OneClassClassifier.OUTLIER_LABEL);
1061    m_AddOutlierFilter.setInputFormat(newData);
1062    newData = Filter.useFilter(newData, m_AddOutlierFilter);
1063
1064    if (m_UseLaplaceCorrection) {
1065      newData = this.mergeToTwo(newData);   
1066    }   
1067
1068    //make sure the generators are created.
1069    m_Generators = new ArrayList();
1070
1071    //one for each column
1072    //need to work out the range or mean/stddev for each attribute
1073    int numAttributes = newData.numAttributes();
1074
1075    //work out the ranges
1076    double lowranges[] = new double[numAttributes - 1];
1077    double highranges[] = new double[numAttributes - 1];
1078    double means[] = new double[numAttributes - 1];
1079    double stddevs[] = new double[numAttributes - 1];
1080    double instanceCount[] = new double[numAttributes - 1];
1081    int attrIndexes[] = new int[numAttributes - 1];
1082
1083    //initialise
1084    for (int i = 0; i < numAttributes - 1; i++) {
1085      lowranges[i] = Double.MAX_VALUE;
1086      highranges[i] = -1 * Double.MAX_VALUE;
1087      means[i] = 0;
1088      stddevs[i] = 0;
1089      attrIndexes[i] = 0;
1090      instanceCount[i] = 0;
1091    }
1092
1093    //calculate low/high ranges and means.
1094    //missing attributes are ignored
1095    for (int i = 0; i < newData.numInstances(); i++) {
1096      Instance anInst = newData.instance(i);
1097      int attCount = 0;
1098      for (int j = 0; j < numAttributes; j++) {
1099        if (j != newData.classIndex()) {
1100          double attVal = anInst.value(j);
1101          if (!anInst.isMissing(j)) {
1102            if (attVal > highranges[attCount])
1103              highranges[attCount] = attVal;
1104            if (attVal < lowranges[attCount])
1105              lowranges[attCount] = attVal;
1106
1107            means[attCount] += attVal;
1108            instanceCount[attCount] += 1;                       
1109          }
1110          attrIndexes[attCount] = j;
1111          attCount++;
1112
1113        }
1114      }
1115    }
1116
1117    //calculate means...
1118    for (int i = 0; i < numAttributes - 1; i++) {
1119      if (instanceCount[i] > 0) {
1120        means[i] = means[i] / instanceCount[i];
1121      }
1122    }
1123
1124    //and now standard deviations
1125    for (int i = 0; i < newData.numInstances(); i++) {
1126      Instance anInst = newData.instance(i);
1127      int attCount = 0;
1128      for (int j = 0; j < numAttributes - 1; j++) {
1129        if (instanceCount[j] > 0) {
1130          stddevs[attCount] += Math.pow(anInst.value(j) - means[attCount], 2);
1131          attCount++;
1132        }
1133      }
1134    }
1135
1136    for (int i = 0; i < numAttributes - 1; i++) {
1137      if (instanceCount[i] > 0) {
1138        stddevs[i] = Math.sqrt(stddevs[i] / instanceCount[i]);         
1139      }
1140    }
1141
1142
1143    //ok, now we have everything, need to make a generator for each column
1144    for (int i = 0; i < numAttributes - 1; i++) {
1145      Generator agen;
1146      if (newData.attribute(attrIndexes[i]).isNominal()) {
1147        agen = ((Generator)m_DefaultNominalGenerator).copy();
1148        ((NominalAttributeGenerator)agen).buildGenerator(newData, newData.attribute(attrIndexes[i]));
1149      } else {
1150        agen = ((Generator)m_DefaultNumericGenerator).copy();
1151       
1152        if (agen instanceof Ranged) {
1153          ((Ranged)agen).setLowerRange(lowranges[i]);
1154          ((Ranged)agen).setUpperRange(highranges[i]);
1155        }
1156       
1157        if (agen instanceof Mean) {
1158          ((Mean)agen).setMean(means[i]);
1159          ((Mean)agen).setStandardDeviation(stddevs[i]);
1160        }
1161
1162        if (agen instanceof InstanceHandler) {
1163          //generator needs to be setup with the instances,
1164          //need to pass over a set of instances with just the current
1165          //attribute.
1166          StringBuffer sb = new StringBuffer("@relation OneClass-SingleAttribute\n\n");
1167          sb.append("@attribute tempName numeric\n\n");
1168          sb.append("@data\n\n");
1169          Enumeration instancesEnum = newData.enumerateInstances();
1170          while(instancesEnum.hasMoreElements()) {
1171            Instance aninst = (Instance)instancesEnum.nextElement();
1172            if (!aninst.isMissing(attrIndexes[i]))
1173              sb.append("" + aninst.value(attrIndexes[i]) + "\n");
1174          }
1175          sb.append("\n\n");
1176          Instances removed = new Instances(new StringReader(sb.toString()));
1177          removed.deleteWithMissing(0);
1178          ((InstanceHandler)agen).buildGenerator(removed);
1179        }
1180      }
1181
1182      m_Generators.add(agen);
1183    }
1184
1185
1186    //REPEAT
1187    ArrayList thresholds = new ArrayList();
1188    for (int i = 0; i < m_NumRepeats; i++) {
1189
1190      //hold some data out
1191      Instances copyData = new Instances(newData);
1192      Instances heldout = new Instances(newData, 0);
1193      for (int k = 0; k < newData.numInstances() / m_PercentHeldout; k++) {
1194        int anindex = m_Random.nextInt(copyData.numInstances());
1195        heldout.add(copyData.instance(anindex));
1196        copyData.delete(anindex);
1197      }
1198
1199
1200      //generate some data
1201      this.generateData(copyData);
1202
1203      //build the classifier on the generated data   
1204      if (!m_UseDensityOnly)
1205        m_Classifier.buildClassifier(copyData);
1206
1207      //test the generated data, work out the threshold (average it later)
1208      double[] scores = new double[heldout.numInstances()];
1209      Enumeration iterInst = heldout.enumerateInstances();
1210      int classIndex = heldout.classAttribute().indexOfValue(m_TargetClassLabel);
1211      int count = 0; 
1212      while(iterInst.hasMoreElements()) {
1213        Instance anInst = (Instance)iterInst.nextElement();
1214        scores[count] = this.getProbXGivenC(anInst, classIndex);       
1215        count++;
1216      }
1217
1218      Arrays.sort(scores);
1219      //work out the where the threshold should be
1220      //higher probabilities = passes
1221      //sorted into ascending order (getting bigger)
1222      int passposition = (int)((double)heldout.numInstances() * m_TargetRejectionRate);
1223      if (passposition >=  heldout.numInstances())
1224        passposition = heldout.numInstances() - 1;
1225
1226      thresholds.add(new Double(scores[passposition]));
1227    }
1228    //END REPEAT
1229
1230
1231
1232    //build the classifier on the generated data 
1233
1234    //set the threshold
1235    m_Threshold = 0;
1236    for (int k = 0; k < thresholds.size(); k++) {
1237      m_Threshold += ((Double)thresholds.get(k)).doubleValue();
1238    }
1239    m_Threshold /= (double)thresholds.size();
1240
1241    //rebuild the classifier using all the data
1242    this.generateData(newData);
1243
1244    if (!m_UseDensityOnly)
1245      m_Classifier.buildClassifier(newData);
1246
1247  }
1248 
1249
1250  /**
1251   * Merges the class values of the instances down to two values,
1252   * the target class and the "outlier" class.
1253   *
1254   * @param newData The data to merge.
1255   * @return The merged data.
1256   */
1257  protected Instances mergeToTwo(Instances newData) throws Exception{
1258
1259    m_MergeFilter = new MergeManyValues();
1260    m_MergeFilter.setAttributeIndex("" + (newData.classIndex() + 1));
1261
1262    //figure out the indexes that aren't the outlier label or
1263    //the target label
1264    StringBuffer sb = new StringBuffer("");
1265
1266    Attribute theAttr = newData.classAttribute();
1267    for (int i = 0; i < theAttr.numValues(); i++) {
1268      if (! (theAttr.value(i).equalsIgnoreCase(OneClassClassifier.OUTLIER_LABEL) 
1269          || theAttr.value(i).equalsIgnoreCase(m_TargetClassLabel))) {
1270        //add it to the merge list
1271        sb.append((i + 1) + ",");
1272      }
1273    }
1274    String mergeList = sb.toString();
1275    if (mergeList.length() != 0) {
1276      mergeList = mergeList.substring(0, mergeList.length() - 1);
1277      int classIndex = newData.classIndex();
1278      newData.setClassIndex(-1);
1279      m_MergeFilter.setMergeValueRange(mergeList);
1280      m_MergeFilter.setLabel(OneClassClassifier.OUTLIER_LABEL);
1281      m_MergeFilter.setInputFormat(newData);
1282      newData = Filter.useFilter(newData, m_MergeFilter);
1283      newData.setClassIndex(classIndex);
1284    } else {
1285      m_MergeFilter = null;
1286    }
1287
1288    return newData;
1289  }
1290
1291  /**
1292   * Gets the probability that an instance, X, belongs to the target class, C. 
1293   *
1294   * @param instance The instance X.
1295   * @param targetClassIndex The index of the target class label for the class attribute.
1296   * @return The probability of X given C, P(X|C).
1297   */
1298  protected double getProbXGivenC(Instance instance, int targetClassIndex) throws Exception{
1299    double probC = 1 - m_ProportionGenerated;   
1300    double probXgivenA = 0;
1301    int count = 0;
1302    for (int i = 0; i < instance.numAttributes(); i++) {
1303      if (i != instance.classIndex()) {
1304        Generator agen = (Generator)m_Generators.get(count);
1305        if (!instance.isMissing(i)) {
1306          probXgivenA += agen.getLogProbabilityOf(instance.value(i));
1307        }
1308        count++;       
1309      }
1310    }
1311
1312    if (m_UseDensityOnly)
1313      return probXgivenA;
1314
1315    double[] distribution = m_Classifier.distributionForInstance(instance);
1316    double probCgivenX = distribution[targetClassIndex];
1317    if(probCgivenX == 1)
1318        return Double.POSITIVE_INFINITY;
1319
1320    //final calculation
1321    double top = Math.log(1 - probC) + Math.log(probCgivenX);
1322    double bottom = Math.log(probC) + Math.log(1 - probCgivenX);   
1323
1324    return (top - bottom) + probXgivenA;       
1325  }
1326
1327
1328  /**
1329   * Generates some outlier data and returns the targetData with some outlier data included.
1330   *
1331   * @param targetData The data for the target class.
1332   * @return The outlier and target data together.
1333   */
1334  protected Instances generateData(Instances targetData) {
1335    double totalInstances = ((double)targetData.numInstances()) / (1 - m_ProportionGenerated);
1336
1337    int numInstances = (int)(totalInstances - (double)targetData.numInstances());
1338
1339    //first reweight the target data
1340    if (m_UseInstanceWeights) {
1341      for (int i = 0; i < targetData.numInstances(); i++) {
1342        targetData.instance(i).setWeight(0.5 * (1 / (1 - m_ProportionGenerated)));
1343      }
1344    }
1345
1346    for (int j = 0; j < numInstances; j++) {
1347      //add to the targetData the instances that we generate...
1348      Instance anInst = new DenseInstance(targetData.numAttributes());
1349      anInst.setDataset(targetData);
1350      int position = 0;
1351      for (int i = 0; i < targetData.numAttributes(); i++) {
1352        if (targetData.classIndex() != i) {
1353          //not the class attribute
1354          Generator agen = (Generator)m_Generators.get(position);
1355          anInst.setValue(i, agen.generate());
1356          position++;
1357        } else {
1358          //is the class attribute
1359          anInst.setValue(i, OneClassClassifier.OUTLIER_LABEL);
1360          if (m_UseInstanceWeights)
1361            anInst.setWeight(0.5 * (1 / m_ProportionGenerated));
1362        }
1363
1364      }
1365      targetData.add(anInst);       
1366    }
1367
1368    return targetData;
1369  }
1370
1371  /**
1372   * Returns a probability distribution for a given instance.
1373   *
1374   * @param instance The instance to calculate the probability distribution for.
1375   * @return The probability for each class.
1376   */
1377  public double[] distributionForInstance(Instance instance) throws Exception {
1378    Instance filtered = (Instance)instance.copy();
1379
1380    m_AddOutlierFilter.input(instance);
1381    filtered = m_AddOutlierFilter.output();
1382
1383    if (m_UseLaplaceCorrection && m_MergeFilter != null) {
1384      m_MergeFilter.input(filtered);
1385      filtered = m_MergeFilter.output();
1386    }
1387
1388    double[] dist = new double[instance.numClasses()];
1389    double probForOutlierClass = 1 / (1 + Math.exp(this.getProbXGivenC(filtered, filtered.classAttribute().indexOfValue(m_TargetClassLabel)) - m_Threshold));
1390    if(this.getProbXGivenC(filtered, filtered.classAttribute().indexOfValue(m_TargetClassLabel)) == Double.POSITIVE_INFINITY)
1391        probForOutlierClass = 0;
1392
1393    dist[instance.classAttribute().indexOfValue(m_TargetClassLabel)] = 1 - probForOutlierClass; 
1394    if (instance.classAttribute().indexOfValue(OneClassClassifier.OUTLIER_LABEL) == -1) {
1395      if (this.getProbXGivenC(filtered, filtered.classAttribute().indexOfValue(m_TargetClassLabel)) >= m_Threshold)
1396        dist[instance.classAttribute().indexOfValue(m_TargetClassLabel)] = 1;
1397      else
1398        dist[instance.classAttribute().indexOfValue(m_TargetClassLabel)] = 0;
1399    } else
1400      dist[instance.classAttribute().indexOfValue(OneClassClassifier.OUTLIER_LABEL)] = probForOutlierClass;
1401
1402    return dist;
1403  }
1404
1405  /**
1406   * Output a representation of this classifier
1407   *
1408   * @return a representation of this classifier
1409   */
1410  public String toString() {
1411
1412    StringBuffer result = new StringBuffer();
1413    result.append("\n\nClassifier Model\n"+m_Classifier.toString());
1414
1415    return result.toString();
1416  }
1417
1418  /**
1419   * Returns the revision string.
1420   *
1421   * @return The revision string.
1422   */
1423  public String getRevision() {
1424    return "$Revision: 5987 $";
1425  }
1426
1427  /**
1428   * Main method for executing this classifier.
1429   *
1430   * @param args        use -h to see all available options
1431   */
1432  public static void main(String[] args) {
1433    runClassifier(new OneClassClassifier(), args);
1434  }
1435}
1436
Note: See TracBrowser for help on using the repository browser.