source: src/main/java/weka/classifiers/meta/RotationForest.java @ 16

Last change on this file since 16 was 4, checked in by gnappo, 14 years ago

Import di weka.

File size: 36.7 KB
Line 
1/*
2 *    This program is free software; you can redistribute it and/or modify
3 *    it under the terms of the GNU General Public License as published by
4 *    the Free Software Foundation; either version 2 of the License, or
5 *    (at your option) any later version.
6 *
7 *    This program is distributed in the hope that it will be useful,
8 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
9 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
10 *    GNU General Public License for more details.
11 *
12 *    You should have received a copy of the GNU General Public License
13 *    along with this program; if not, write to the Free Software
14 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
15 */
16
17/*
18 *    RotationForest.java
19 *    Copyright (C) 2008 Juan Jose Rodriguez
20 *    Copyright (C) 2008 University of Waikato, Hamilton, New Zealand
21 *
22 */
23
24
25package weka.classifiers.meta;
26
27import weka.classifiers.RandomizableParallelIteratedSingleClassifierEnhancer;
28import weka.core.Attribute;
29import weka.core.FastVector;
30import weka.core.Instance;
31import weka.core.DenseInstance;
32import weka.core.Instances;
33import weka.core.Option;
34import weka.core.OptionHandler;
35import weka.core.Randomizable;
36import weka.core.RevisionUtils;
37import weka.core.TechnicalInformation;
38import weka.core.TechnicalInformationHandler;
39import weka.core.Utils;
40import weka.core.WeightedInstancesHandler;
41import weka.core.TechnicalInformation.Field;
42import weka.core.TechnicalInformation.Type;
43import weka.filters.Filter;
44import weka.filters.unsupervised.attribute.Normalize;
45import weka.filters.unsupervised.attribute.PrincipalComponents;
46import weka.filters.unsupervised.attribute.RemoveUseless;
47import weka.filters.unsupervised.instance.RemovePercentage;
48
49import java.util.Enumeration;
50import java.util.Random;
51import java.util.Vector;
52
53/**
54 <!-- globalinfo-start -->
55 * Class for construction a Rotation Forest. Can do classification and regression depending on the base learner. <br/>
56 * <br/>
57 * For more information, see<br/>
58 * <br/>
59 * Juan J. Rodriguez, Ludmila I. Kuncheva, Carlos J. Alonso (2006). Rotation Forest: A new classifier ensemble method. IEEE Transactions on Pattern Analysis and Machine Intelligence. 28(10):1619-1630. URL http://doi.ieeecomputersociety.org/10.1109/TPAMI.2006.211.
60 * <p/>
61 <!-- globalinfo-end -->
62 *
63 <!-- technical-bibtex-start -->
64 * BibTeX:
65 * <pre>
66 * &#64;article{Rodriguez2006,
67 *    author = {Juan J. Rodriguez and Ludmila I. Kuncheva and Carlos J. Alonso},
68 *    journal = {IEEE Transactions on Pattern Analysis and Machine Intelligence},
69 *    number = {10},
70 *    pages = {1619-1630},
71 *    title = {Rotation Forest: A new classifier ensemble method},
72 *    volume = {28},
73 *    year = {2006},
74 *    ISSN = {0162-8828},
75 *    URL = {http://doi.ieeecomputersociety.org/10.1109/TPAMI.2006.211}
76 * }
77 * </pre>
78 * <p/>
79 <!-- technical-bibtex-end -->
80 *
81 <!-- options-start -->
82 * Valid options are: <p/>
83 *
84 * <pre> -N
85 *  Whether minGroup (-G) and maxGroup (-H) refer to
86 *  the number of groups or their size.
87 *  (default: false)</pre>
88 *
89 * <pre> -G &lt;num&gt;
90 *  Minimum size of a group of attributes:
91 *   if numberOfGroups is true, the minimum number
92 *   of groups.
93 *   (default: 3)</pre>
94 *
95 * <pre> -H &lt;num&gt;
96 *  Maximum size of a group of attributes:
97 *   if numberOfGroups is true, the maximum number
98 *   of groups.
99 *   (default: 3)</pre>
100 *
101 * <pre> -P &lt;num&gt;
102 *  Percentage of instances to be removed.
103 *   (default: 50)</pre>
104 *
105 * <pre> -F &lt;filter specification&gt;
106 *  Full class name of filter to use, followed
107 *  by filter options.
108 *  eg: "weka.filters.unsupervised.attribute.PrincipalComponents-R 1.0"</pre>
109 *
110 * <pre> -S &lt;num&gt;
111 *  Random number seed.
112 *  (default 1)</pre>
113 *
114 * <pre> -I &lt;num&gt;
115 *  Number of iterations.
116 *  (default 10)</pre>
117 *
118 * <pre> -D
119 *  If set, classifier is run in debug mode and
120 *  may output additional info to the console</pre>
121 *
122 * <pre> -W
123 *  Full name of base classifier.
124 *  (default: weka.classifiers.trees.J48)</pre>
125 *
126 * <pre>
127 * Options specific to classifier weka.classifiers.trees.J48:
128 * </pre>
129 *
130 * <pre> -U
131 *  Use unpruned tree.</pre>
132 *
133 * <pre> -C &lt;pruning confidence&gt;
134 *  Set confidence threshold for pruning.
135 *  (default 0.25)</pre>
136 *
137 * <pre> -M &lt;minimum number of instances&gt;
138 *  Set minimum number of instances per leaf.
139 *  (default 2)</pre>
140 *
141 * <pre> -R
142 *  Use reduced error pruning.</pre>
143 *
144 * <pre> -N &lt;number of folds&gt;
145 *  Set number of folds for reduced error
146 *  pruning. One fold is used as pruning set.
147 *  (default 3)</pre>
148 *
149 * <pre> -B
150 *  Use binary splits only.</pre>
151 *
152 * <pre> -S
153 *  Don't perform subtree raising.</pre>
154 *
155 * <pre> -L
156 *  Do not clean up after the tree has been built.</pre>
157 *
158 * <pre> -A
159 *  Laplace smoothing for predicted probabilities.</pre>
160 *
161 * <pre> -Q &lt;seed&gt;
162 *  Seed for random data shuffling (default 1).</pre>
163 *
164 <!-- options-end -->
165 *
166 * @author Juan Jose Rodriguez (jjrodriguez@ubu.es)
167 * @version $Revision: 5987 $
168 */
169public class RotationForest 
170  extends RandomizableParallelIteratedSingleClassifierEnhancer
171  implements WeightedInstancesHandler, TechnicalInformationHandler {
172  // It implements WeightedInstancesHandler because the base classifier
173  // can implement this interface, but in this method the weights are
174  // not used
175
176  /** for serialization */
177  static final long serialVersionUID = -3255631880798499936L;
178
179  /** The minimum size of a group */
180  protected int m_MinGroup = 3;
181
182  /** The maximum size of a group */
183  protected int m_MaxGroup = 3;
184
185  /**
186   * Whether minGroup and maxGroup refer to the number of groups or their
187   * size */
188  protected boolean m_NumberOfGroups = false;
189
190  /** The percentage of instances to be removed */
191  protected int m_RemovedPercentage = 50;
192
193  /** The attributes of each group */
194  protected int [][][] m_Groups = null;
195
196  /** The type of projection filter */
197  protected Filter m_ProjectionFilter = null;
198
199  /** The projection filters */
200  protected Filter [][] m_ProjectionFilters = null;
201
202  /** Headers of the transformed dataset */
203  protected Instances [] m_Headers = null;
204
205  /** Headers of the reduced datasets */
206  protected Instances [][] m_ReducedHeaders = null;
207
208  /** Filter that remove useless attributes */
209  protected RemoveUseless m_RemoveUseless = null;
210
211  /** Filter that normalized the attributes */
212  protected Normalize m_Normalize = null;
213 
214  /** Training data */
215  protected Instances m_data;
216
217  protected Instances [] m_instancesOfClasses;
218
219  protected Random m_random;
220
221  /**
222   * Constructor.
223   */
224  public RotationForest() {
225   
226    m_Classifier = new weka.classifiers.trees.J48();
227    m_ProjectionFilter = defaultFilter();
228  }
229
230  /**
231   * Default projection method.
232   */
233  protected Filter defaultFilter() {
234    PrincipalComponents filter = new PrincipalComponents();
235    filter.setNormalize(false);
236    filter.setVarianceCovered(1.0);
237    return filter;
238  }
239 
240  /**
241   * Returns a string describing classifier
242   * @return a description suitable for
243   * displaying in the explorer/experimenter gui
244   */
245  public String globalInfo() {
246 
247    return "Class for construction a Rotation Forest. Can do classification "
248      + "and regression depending on the base learner. \n\n"
249      + "For more information, see\n\n"
250      + getTechnicalInformation().toString();
251  }
252
253  /**
254   * Returns an instance of a TechnicalInformation object, containing
255   * detailed information about the technical background of this class,
256   * e.g., paper reference or book this class is based on.
257   *
258   * @return the technical information about this class
259   */
260  public TechnicalInformation getTechnicalInformation() {
261    TechnicalInformation        result;
262   
263    result = new TechnicalInformation(Type.ARTICLE);
264    result.setValue(Field.AUTHOR, "Juan J. Rodriguez and Ludmila I. Kuncheva and Carlos J. Alonso");
265    result.setValue(Field.YEAR, "2006");
266    result.setValue(Field.TITLE, "Rotation Forest: A new classifier ensemble method");
267    result.setValue(Field.JOURNAL, "IEEE Transactions on Pattern Analysis and Machine Intelligence");
268    result.setValue(Field.VOLUME, "28");
269    result.setValue(Field.NUMBER, "10");
270    result.setValue(Field.PAGES, "1619-1630");
271    result.setValue(Field.ISSN, "0162-8828");
272    result.setValue(Field.URL, "http://doi.ieeecomputersociety.org/10.1109/TPAMI.2006.211");
273   
274    return result;
275  }
276
277  /**
278   * String describing default classifier.
279   *
280   * @return the default classifier classname
281   */
282  protected String defaultClassifierString() {
283   
284    return "weka.classifiers.trees.J48";
285  }
286
287  /**
288   * Returns an enumeration describing the available options.
289   *
290   * @return an enumeration of all the available options.
291   */
292  public Enumeration listOptions() {
293
294    Vector newVector = new Vector(5);
295
296    newVector.addElement(new Option(
297              "\tWhether minGroup (-G) and maxGroup (-H) refer to"
298              + "\n\tthe number of groups or their size."
299              + "\n\t(default: false)",
300              "N", 0, "-N"));
301
302    newVector.addElement(new Option(
303              "\tMinimum size of a group of attributes:"
304              + "\n\t\tif numberOfGroups is true, the minimum number"
305              + "\n\t\tof groups."
306              + "\n\t\t(default: 3)",
307              "G", 1, "-G <num>"));
308
309    newVector.addElement(new Option(
310              "\tMaximum size of a group of attributes:"
311              + "\n\t\tif numberOfGroups is true, the maximum number" 
312              + "\n\t\tof groups."
313              + "\n\t\t(default: 3)",
314              "H", 1, "-H <num>"));
315
316    newVector.addElement(new Option(
317              "\tPercentage of instances to be removed."
318              + "\n\t\t(default: 50)",
319              "P", 1, "-P <num>"));
320
321    newVector.addElement(new Option(
322              "\tFull class name of filter to use, followed\n"
323              + "\tby filter options.\n"
324              + "\teg: \"weka.filters.unsupervised.attribute.PrincipalComponents-R 1.0\"",
325              "F", 1, "-F <filter specification>"));
326
327    Enumeration enu = super.listOptions();
328    while (enu.hasMoreElements()) {
329      newVector.addElement(enu.nextElement());
330    }
331    return newVector.elements();
332  }
333
334  /**
335   * Parses a given list of options. <p/>
336   *
337   <!-- options-start -->
338   * Valid options are: <p/>
339   *
340   * <pre> -N
341   *  Whether minGroup (-G) and maxGroup (-H) refer to
342   *  the number of groups or their size.
343   *  (default: false)</pre>
344   *
345   * <pre> -G &lt;num&gt;
346   *  Minimum size of a group of attributes:
347   *   if numberOfGroups is true, the minimum number
348   *   of groups.
349   *   (default: 3)</pre>
350   *
351   * <pre> -H &lt;num&gt;
352   *  Maximum size of a group of attributes:
353   *   if numberOfGroups is true, the maximum number
354   *   of groups.
355   *   (default: 3)</pre>
356   *
357   * <pre> -P &lt;num&gt;
358   *  Percentage of instances to be removed.
359   *   (default: 50)</pre>
360   *
361   * <pre> -F &lt;filter specification&gt;
362   *  Full class name of filter to use, followed
363   *  by filter options.
364   *  eg: "weka.filters.unsupervised.attribute.PrincipalComponents-R 1.0"</pre>
365   *
366   * <pre> -S &lt;num&gt;
367   *  Random number seed.
368   *  (default 1)</pre>
369   *
370   * <pre> -I &lt;num&gt;
371   *  Number of iterations.
372   *  (default 10)</pre>
373   *
374   * <pre> -D
375   *  If set, classifier is run in debug mode and
376   *  may output additional info to the console</pre>
377   *
378   * <pre> -W
379   *  Full name of base classifier.
380   *  (default: weka.classifiers.trees.J48)</pre>
381   *
382   * <pre>
383   * Options specific to classifier weka.classifiers.trees.J48:
384   * </pre>
385   *
386   * <pre> -U
387   *  Use unpruned tree.</pre>
388   *
389   * <pre> -C &lt;pruning confidence&gt;
390   *  Set confidence threshold for pruning.
391   *  (default 0.25)</pre>
392   *
393   * <pre> -M &lt;minimum number of instances&gt;
394   *  Set minimum number of instances per leaf.
395   *  (default 2)</pre>
396   *
397   * <pre> -R
398   *  Use reduced error pruning.</pre>
399   *
400   * <pre> -N &lt;number of folds&gt;
401   *  Set number of folds for reduced error
402   *  pruning. One fold is used as pruning set.
403   *  (default 3)</pre>
404   *
405   * <pre> -B
406   *  Use binary splits only.</pre>
407   *
408   * <pre> -S
409   *  Don't perform subtree raising.</pre>
410   *
411   * <pre> -L
412   *  Do not clean up after the tree has been built.</pre>
413   *
414   * <pre> -A
415   *  Laplace smoothing for predicted probabilities.</pre>
416   *
417   * <pre> -Q &lt;seed&gt;
418   *  Seed for random data shuffling (default 1).</pre>
419   *
420   <!-- options-end -->
421   *
422   * @param options the list of options as an array of strings
423   * @throws Exception if an option is not supported
424   */
425  public void setOptions(String[] options) throws Exception {
426
427    /* Taken from FilteredClassifier */
428    String filterString = Utils.getOption('F', options);
429    if (filterString.length() > 0) {
430      String [] filterSpec = Utils.splitOptions(filterString);
431      if (filterSpec.length == 0) {
432        throw new IllegalArgumentException("Invalid filter specification string");
433      }
434      String filterName = filterSpec[0];
435      filterSpec[0] = "";
436      setProjectionFilter((Filter) Utils.forName(Filter.class, filterName, filterSpec));
437    } else {
438      setProjectionFilter(defaultFilter());
439    }
440
441    String tmpStr;
442   
443    tmpStr = Utils.getOption('G', options);
444    if (tmpStr.length() != 0)
445      setMinGroup(Integer.parseInt(tmpStr));
446    else
447      setMinGroup(3);
448
449    tmpStr = Utils.getOption('H', options);
450    if (tmpStr.length() != 0)
451      setMaxGroup(Integer.parseInt(tmpStr));
452    else
453      setMaxGroup(3);
454
455    tmpStr = Utils.getOption('P', options);
456    if (tmpStr.length() != 0)
457      setRemovedPercentage(Integer.parseInt(tmpStr));
458    else
459      setRemovedPercentage(50);
460
461    setNumberOfGroups(Utils.getFlag('N', options));
462
463    super.setOptions(options);
464  }
465
466  /**
467   * Gets the current settings of the Classifier.
468   *
469   * @return an array of strings suitable for passing to setOptions
470   */
471  public String [] getOptions() {
472
473    String [] superOptions = super.getOptions();
474    String [] options = new String [superOptions.length + 9];
475
476    int current = 0;
477
478    if (getNumberOfGroups()) { 
479      options[current++] = "-N";
480    }
481
482    options[current++] = "-G"; 
483    options[current++] = "" + getMinGroup();
484
485    options[current++] = "-H"; 
486    options[current++] = "" + getMaxGroup();
487
488    options[current++] = "-P"; 
489    options[current++] = "" + getRemovedPercentage();
490
491    options[current++] = "-F";
492    options[current++] = getProjectionFilterSpec();
493
494    System.arraycopy(superOptions, 0, options, current, 
495                     superOptions.length);
496
497    current += superOptions.length;
498    while (current < options.length) {
499      options[current++] = "";
500    }
501    return options;
502  }
503
504  /**
505   * Returns the tip text for this property
506   * @return tip text for this property suitable for
507   * displaying in the explorer/experimenter gui
508   */
509  public String numberOfGroupsTipText() {
510    return "Whether minGroup and maxGroup refer to the number of groups or their size.";
511  }
512
513  /**
514   * Set whether minGroup and maxGroup refer to the number of groups or their
515   * size
516   *
517   * @param numberOfGroups whether minGroup and maxGroup refer to the number
518   * of groups or their size
519   */
520  public void setNumberOfGroups(boolean numberOfGroups) {
521
522    m_NumberOfGroups = numberOfGroups;
523  }
524
525  /**
526   * Get whether minGroup and maxGroup refer to the number of groups or their
527   * size
528   *
529   * @return whether minGroup and maxGroup refer to the number of groups or
530   * their size
531   */
532  public boolean getNumberOfGroups() {
533
534    return m_NumberOfGroups;
535  }
536
537  /**
538   * Returns the tip text for this property
539   * @return tip text for this property suitable for displaying in the
540   * explorer/experimenter gui
541   */
542  public String minGroupTipText() {
543    return "Minimum size of a group (if numberOfGrups is true, the minimum number of groups.";
544  }
545
546  /**
547   * Sets the minimum size of a group.
548   *
549   * @param minGroup the minimum value.
550   * of attributes.
551   */
552  public void setMinGroup( int minGroup ) throws IllegalArgumentException {
553
554    if( minGroup <= 0 )
555      throw new IllegalArgumentException( "MinGroup has to be positive." );
556    m_MinGroup = minGroup;
557  }
558
559  /**
560   * Gets the minimum size of a group.
561   *
562   * @return            the minimum value.
563   */
564  public int getMinGroup() {
565    return m_MinGroup;
566  }
567
568  /**
569   * Returns the tip text for this property
570   * @return tip text for this property suitable for
571   * displaying in the explorer/experimenter gui
572   */
573  public String maxGroupTipText() {
574    return "Maximum size of a group (if numberOfGrups is true, the maximum number of groups.";
575  }
576
577  /**
578   * Sets the maximum size of a group.
579   *
580   * @param maxGroup the maximum value.
581   * of attributes.
582   */
583  public void setMaxGroup( int maxGroup ) throws IllegalArgumentException {
584 
585    if( maxGroup <= 0 )
586      throw new IllegalArgumentException( "MaxGroup has to be positive." );
587    m_MaxGroup = maxGroup;
588  }
589
590  /**
591   * Gets the maximum size of a group.
592   *
593   * @return            the maximum value.
594   */
595  public int getMaxGroup() {
596    return m_MaxGroup;
597  }
598
599  /**
600   * Returns the tip text for this property
601   * @return tip text for this property suitable for
602   * displaying in the explorer/experimenter gui
603   */
604  public String removedPercentageTipText() {
605    return "The percentage of instances to be removed.";
606  }
607
608  /**
609   * Sets the percentage of instance to be removed
610   *
611   * @param removedPercentage the percentage.
612   */
613  public void setRemovedPercentage( int removedPercentage ) throws IllegalArgumentException {
614
615    if( removedPercentage < 0 )
616      throw new IllegalArgumentException( "RemovedPercentage has to be >=0." );
617    if( removedPercentage >= 100 )
618      throw new IllegalArgumentException( "RemovedPercentage has to be <100." );
619 
620    m_RemovedPercentage = removedPercentage;
621  }
622
623  /**
624   * Gets the percentage of instances to be removed
625   *
626   * @return            the percentage.
627   */
628  public int getRemovedPercentage() {
629    return m_RemovedPercentage;
630  }
631
632  /**
633   * Returns the tip text for this property
634   * @return tip text for this property suitable for
635   * displaying in the explorer/experimenter gui
636   */
637  public String projectionFilterTipText() {
638    return "The filter used to project the data (e.g., PrincipalComponents).";
639  }
640
641  /**
642   * Sets the filter used to project the data.
643   *
644   * @param projectionFilter the filter.
645   */
646  public void setProjectionFilter( Filter projectionFilter ) {
647
648    m_ProjectionFilter = projectionFilter;
649  }
650
651  /**
652   * Gets the filter used to project the data.
653   *
654   * @return            the filter.
655   */
656  public Filter getProjectionFilter() {
657    return m_ProjectionFilter;
658  }
659
660  /**
661   * Gets the filter specification string, which contains the class name of
662   * the filter and any options to the filter
663   *
664   * @return the filter string.
665   */
666  /* Taken from FilteredClassifier */
667  protected String getProjectionFilterSpec() {
668   
669    Filter c = getProjectionFilter();
670    if (c instanceof OptionHandler) {
671      return c.getClass().getName() + " "
672        + Utils.joinOptions(((OptionHandler)c).getOptions());
673    }
674    return c.getClass().getName();
675  }
676
677  /**
678   * Returns description of the Rotation Forest classifier.
679   *
680   * @return description of the Rotation Forest classifier as a string
681   */
682  public String toString() {
683   
684    if (m_Classifiers == null) {
685      return "RotationForest: No model built yet.";
686    }
687    StringBuffer text = new StringBuffer();
688    text.append("All the base classifiers: \n\n");
689    for (int i = 0; i < m_Classifiers.length; i++)
690      text.append(m_Classifiers[i].toString() + "\n\n");
691   
692    return text.toString();
693  }
694
695  /**
696   * Returns the revision string.
697   *
698   * @return            the revision
699   */
700  public String getRevision() {
701    return RevisionUtils.extract("$Revision: 5987 $");
702  }
703 
704  protected class ClassifierWrapper extends weka.classifiers.AbstractClassifier {
705   
706    /** For serialization */
707    private static final long serialVersionUID = 2327175798869994435L;
708   
709    protected weka.classifiers.Classifier m_wrappedClassifier;
710    protected int m_classifierNumber;
711   
712    public ClassifierWrapper(weka.classifiers.Classifier classifier, int classifierNumber) {
713      super();
714     
715      m_wrappedClassifier = classifier;
716      m_classifierNumber = classifierNumber;
717    }
718   
719    public void buildClassifier(Instances data) throws Exception {
720      m_ReducedHeaders[m_classifierNumber] = new Instances[ m_Groups[m_classifierNumber].length ];
721      FastVector transformedAttributes = new FastVector( m_data.numAttributes() );
722     
723      // Construction of the dataset for each group of attributes
724      for( int j = 0; j < m_Groups[ m_classifierNumber ].length; j++ ) {
725        FastVector fv = new FastVector( m_Groups[m_classifierNumber][j].length + 1 );
726        for( int k = 0; k < m_Groups[m_classifierNumber][j].length; k++ ) {
727          fv.addElement( m_data.attribute( m_Groups[m_classifierNumber][j][k] ).copy() );
728        }
729        fv.addElement( m_data.classAttribute( ).copy() );
730        Instances dataSubSet = new Instances( "rotated-" + m_classifierNumber + "-" + j + "-", 
731            fv, 0);
732        dataSubSet.setClassIndex( dataSubSet.numAttributes() - 1 );
733       
734        // Select instances for the dataset
735        m_ReducedHeaders[m_classifierNumber][j] = new Instances( dataSubSet, 0 );
736        boolean [] selectedClasses = selectClasses( m_instancesOfClasses.length, 
737              m_random );
738        for( int c = 0; c < selectedClasses.length; c++ ) {
739          if( !selectedClasses[c] )
740            continue;
741          Enumeration enu = m_instancesOfClasses[c].enumerateInstances();
742          while( enu.hasMoreElements() ) {
743            Instance instance = (Instance)enu.nextElement();
744            Instance newInstance = new DenseInstance(dataSubSet.numAttributes());
745            newInstance.setDataset( dataSubSet );
746            for( int k = 0; k < m_Groups[m_classifierNumber][j].length; k++ ) {
747              newInstance.setValue( k, instance.value( m_Groups[m_classifierNumber][j][k] ) );
748            }
749            newInstance.setClassValue( instance.classValue( ) );
750            dataSubSet.add( newInstance );
751          }
752        }
753       
754        dataSubSet.randomize(m_random);
755        // Remove a percentage of the instances
756        Instances originalDataSubSet = dataSubSet;
757        dataSubSet.randomize(m_random);
758        RemovePercentage rp = new RemovePercentage();
759        rp.setPercentage( m_RemovedPercentage );
760        rp.setInputFormat( dataSubSet );
761        dataSubSet = Filter.useFilter( dataSubSet, rp );
762        if( dataSubSet.numInstances() < 2 ) {
763          dataSubSet = originalDataSubSet;
764        }
765       
766        // Project de data
767        m_ProjectionFilters[m_classifierNumber][j].setInputFormat( dataSubSet );
768        Instances projectedData = null;
769        do {
770          try {
771            projectedData = Filter.useFilter( dataSubSet, 
772                m_ProjectionFilters[m_classifierNumber][j] );
773          } catch ( Exception e ) {
774            // The data could not be projected, we add some random instances
775            addRandomInstances( dataSubSet, 10, m_random );
776          }
777        } while( projectedData == null );
778
779        // Include the projected attributes in the attributes of the
780        // transformed dataset
781        for( int a = 0; a < projectedData.numAttributes() - 1; a++ ) {
782          transformedAttributes.addElement( projectedData.attribute(a).copy());
783        }                       
784      }
785     
786      transformedAttributes.addElement( m_data.classAttribute().copy() );
787      Instances transformedData = new Instances( "rotated-" + m_classifierNumber + "-", 
788        transformedAttributes, 0 );
789      transformedData.setClassIndex( transformedData.numAttributes() - 1 );
790      m_Headers[ m_classifierNumber ] = new Instances( transformedData, 0 );
791
792      // Project all the training data
793      Enumeration enu = m_data.enumerateInstances();
794      while( enu.hasMoreElements() ) {
795        Instance instance = (Instance)enu.nextElement();
796        Instance newInstance = convertInstance( instance, m_classifierNumber );
797        transformedData.add( newInstance );
798      }
799
800      // Build the base classifier
801      if (m_wrappedClassifier instanceof Randomizable) {
802        ((Randomizable) m_wrappedClassifier).setSeed(m_random.nextInt());
803      }
804      m_wrappedClassifier.buildClassifier( transformedData );           
805    }
806   
807    public double classifierInstance(Instance instance) throws Exception {
808      return m_wrappedClassifier.classifyInstance(instance);
809    }
810   
811    public double[] distributionForInstance(Instance instance) throws Exception {
812      return m_wrappedClassifier.distributionForInstance(instance);
813    }
814   
815    public String toString() {
816      return m_wrappedClassifier.toString();
817    }
818  }
819 
820  protected Instances getTrainingSet(int iteration) throws Exception {
821   
822    // The wrapped base classifiers' buildClassifier method creates the
823    // transformed training data
824    return m_data;
825  }
826
827  /**
828   * builds the classifier.
829   *
830   * @param data        the training data to be used for generating the
831   *                    classifier.
832   * @throws Exception  if the classifier could not be built successfully
833   */
834  public void buildClassifier(Instances data) throws Exception {
835
836    // can classifier handle the data?
837    getCapabilities().testWithFail(data);
838
839    m_data = new Instances( data );
840    super.buildClassifier(m_data);
841   
842    // Wrap up the base classifiers
843    for (int i = 0; i < m_Classifiers.length; i++) {
844      ClassifierWrapper cw = new ClassifierWrapper(m_Classifiers[i], i);
845     
846      m_Classifiers[i] = cw;
847    }
848
849    checkMinMax(m_data);
850
851    if( m_data.numInstances() > 0 ) {
852      // This function fails if there are 0 instances
853      m_random = m_data.getRandomNumberGenerator(m_Seed);
854    }
855    else {
856      m_random = new Random(m_Seed);
857    }
858
859    m_RemoveUseless = new RemoveUseless();
860    m_RemoveUseless.setInputFormat(m_data);
861    m_data = Filter.useFilter(data, m_RemoveUseless);
862
863    m_Normalize = new Normalize();
864    m_Normalize.setInputFormat(m_data);
865    m_data = Filter.useFilter(m_data, m_Normalize);
866
867    if(m_NumberOfGroups) {
868      generateGroupsFromNumbers(m_data, m_random);
869    }
870    else {
871      generateGroupsFromSizes(m_data, m_random);
872    }
873
874    m_ProjectionFilters = new Filter[m_Groups.length][];
875    for(int i = 0; i < m_ProjectionFilters.length; i++ ) {
876      m_ProjectionFilters[i] = Filter.makeCopies( m_ProjectionFilter, 
877          m_Groups[i].length );
878    }
879
880    int numClasses = m_data.numClasses();
881
882    m_instancesOfClasses = new Instances[numClasses + 1]; 
883    if( m_data.classAttribute().isNumeric() ) {
884      m_instancesOfClasses = new Instances[numClasses]; 
885      m_instancesOfClasses[0] = m_data;
886    }
887    else {
888      m_instancesOfClasses = new Instances[numClasses+1]; 
889      for( int i = 0; i < m_instancesOfClasses.length; i++ ) {
890        m_instancesOfClasses[ i ] = new Instances( m_data, 0 );
891      }
892      Enumeration enu = m_data.enumerateInstances();
893      while( enu.hasMoreElements() ) {
894        Instance instance = (Instance)enu.nextElement();
895        if( instance.classIsMissing() ) {
896          m_instancesOfClasses[numClasses].add( instance );
897        }
898        else {
899          int c = (int)instance.classValue();
900          m_instancesOfClasses[c].add( instance );
901        }
902      }
903      // If there are not instances with a missing class, we do not need to
904      // consider them
905      if( m_instancesOfClasses[numClasses].numInstances() == 0 ) {
906        Instances [] tmp = m_instancesOfClasses;
907        m_instancesOfClasses =  new Instances[ numClasses ];
908        System.arraycopy( tmp, 0, m_instancesOfClasses, 0, numClasses );
909      }
910    }
911
912    // These arrays keep the information of the transformed data set
913    m_Headers = new Instances[ m_Classifiers.length ];
914    m_ReducedHeaders = new Instances[ m_Classifiers.length ][];
915   
916    buildClassifiers();
917
918    if(m_Debug){
919      printGroups();
920    }
921   
922    // save memory
923    m_data = null;
924    m_instancesOfClasses = null;
925    m_random = null;
926  }
927
928  /**
929   * Adds random instances to the dataset.
930   *
931   * @param dataset the dataset
932   * @param numInstances the number of instances
933   * @param random a random number generator
934   */
935  protected void addRandomInstances( Instances dataset, int numInstances, 
936                                  Random random ) {
937    int n = dataset.numAttributes();                           
938    double [] v = new double[ n ];
939    for( int i = 0; i < numInstances; i++ ) {
940      for( int j = 0; j < n; j++ ) {
941        Attribute att = dataset.attribute( j );
942        if( att.isNumeric() ) {
943          v[ j ] = random.nextDouble();
944        }
945        else if ( att.isNominal() ) { 
946          v[ j ] = random.nextInt( att.numValues() );
947        }
948      }
949      dataset.add( new DenseInstance( 1, v ) );
950    }
951  }
952
953  /**
954   * Checks m_MinGroup and m_MaxGroup
955   *
956   * @param data the dataset
957   */
958  protected void checkMinMax(Instances data) {
959    if( m_MinGroup > m_MaxGroup ) {
960      int tmp = m_MaxGroup;
961      m_MaxGroup = m_MinGroup;
962      m_MinGroup = tmp;
963    }
964   
965    int n = data.numAttributes();
966    if( m_MaxGroup >= n )
967      m_MaxGroup = n - 1;
968    if( m_MinGroup >= n )
969      m_MinGroup = n - 1;
970  }
971
972  /**
973   * Selects a non-empty subset of the classes
974   *
975   * @param numClasses         the number of classes
976   * @param random             the random number generator.
977   * @return a random subset of classes
978   */
979  protected boolean [] selectClasses( int numClasses, Random random ) {
980
981    int numSelected = 0;
982    boolean selected[] = new boolean[ numClasses ];
983
984    for( int i = 0; i < selected.length; i++ ) {
985      if(random.nextBoolean()) {
986        selected[i] = true;
987        numSelected++;
988      }
989    }
990    if( numSelected == 0 ) {
991      selected[random.nextInt( selected.length )] = true;
992    }
993    return selected;
994  }
995
996  /**
997   * generates the groups of attributes, given their minimum and maximum
998   * sizes.
999   *
1000   * @param data        the training data to be used for generating the
1001   *                    groups.
1002   * @param random      the random number generator.
1003   */
1004  protected void generateGroupsFromSizes(Instances data, Random random) {
1005    m_Groups = new int[m_Classifiers.length][][];
1006    for( int i = 0; i < m_Classifiers.length; i++ ) {
1007      int [] permutation = attributesPermutation(data.numAttributes(), 
1008                           data.classIndex(), random);
1009
1010      // The number of groups that have a given size
1011      int [] numGroupsOfSize = new int[m_MaxGroup - m_MinGroup + 1];
1012
1013      int numAttributes = 0;
1014      int numGroups;
1015
1016      // Select the size of each group
1017      for( numGroups = 0; numAttributes < permutation.length; numGroups++ ) {
1018        int n = random.nextInt( numGroupsOfSize.length );
1019        numGroupsOfSize[n]++;
1020        numAttributes += m_MinGroup + n;
1021      }
1022
1023      m_Groups[i] = new int[numGroups][];
1024      int currentAttribute = 0;
1025      int currentSize = 0;
1026      for( int j = 0; j < numGroups; j++ ) {
1027        while( numGroupsOfSize[ currentSize ] == 0 )
1028          currentSize++;
1029        numGroupsOfSize[ currentSize ]--;
1030        int n = m_MinGroup + currentSize;
1031        m_Groups[i][j] = new int[n];
1032        for( int k = 0; k < n; k++ ) {
1033          if( currentAttribute < permutation.length )
1034            m_Groups[i][j][k] = permutation[ currentAttribute ];
1035          else
1036            // For the last group, it can be necessary to reuse some attributes
1037            m_Groups[i][j][k] = permutation[ random.nextInt( 
1038                permutation.length ) ];
1039          currentAttribute++;
1040        }
1041      }
1042    }
1043  }
1044
1045  /**
1046   * generates the groups of attributes, given their minimum and maximum
1047   * numbers.
1048   *
1049   * @param data        the training data to be used for generating the
1050   *                    groups.
1051   * @param random      the random number generator.
1052   */
1053  protected void generateGroupsFromNumbers(Instances data, Random random) {
1054    m_Groups = new int[m_Classifiers.length][][];
1055    for( int i = 0; i < m_Classifiers.length; i++ ) {
1056      int [] permutation = attributesPermutation(data.numAttributes(), 
1057                           data.classIndex(), random);
1058      int numGroups = m_MinGroup + random.nextInt(m_MaxGroup - m_MinGroup + 1);
1059      m_Groups[i] = new int[numGroups][];
1060      int groupSize = permutation.length / numGroups;
1061
1062      // Some groups will have an additional attribute
1063      int numBiggerGroups = permutation.length % numGroups;
1064
1065      // Distribute the attributes in the groups
1066      int currentAttribute = 0;
1067      for( int j = 0; j < numGroups; j++ ) {
1068        if( j < numBiggerGroups ) {
1069          m_Groups[i][j] = new int[groupSize + 1];
1070        }
1071        else {
1072          m_Groups[i][j] = new int[groupSize];
1073        }
1074        for( int k = 0; k < m_Groups[i][j].length; k++ ) {
1075          m_Groups[i][j][k] = permutation[currentAttribute++];
1076        }
1077      }
1078    }
1079  }
1080
1081  /**
1082   * generates a permutation of the attributes.
1083   *
1084   * @param numAttributes       the number of attributes.
1085   * @param classAttributes     the index of the class attribute.
1086   * @param random              the random number generator.
1087   * @return a permutation of the attributes
1088   */
1089  protected int [] attributesPermutation(int numAttributes, int classAttribute,
1090                                         Random random) {
1091    int [] permutation = new int[numAttributes-1];
1092    int i = 0;
1093    for(; i < classAttribute; i++){
1094      permutation[i] = i;
1095    }
1096    for(; i < permutation.length; i++){
1097      permutation[i] = i + 1;
1098    }
1099
1100    permute( permutation, random );
1101
1102    return permutation;
1103  }
1104
1105  /**
1106   * permutes the elements of a given array.
1107   *
1108   * @param v       the array to permute
1109   * @param random  the random number generator.
1110   */
1111  protected void permute( int v[], Random random ) {
1112
1113    for(int i = v.length - 1; i > 0; i-- ) {
1114      int j = random.nextInt( i + 1 );
1115      if( i != j ) {
1116        int tmp = v[i];
1117        v[i] = v[j];
1118        v[j] = tmp;
1119      }
1120    }
1121  }
1122
1123  /**
1124   * prints the groups.
1125   */
1126  protected void printGroups( ) {
1127    for( int i = 0; i < m_Groups.length; i++ ) {
1128      for( int j = 0; j < m_Groups[i].length; j++ ) {
1129        System.err.print( "( " );
1130        for( int k = 0; k < m_Groups[i][j].length; k++ ) {
1131          System.err.print( m_Groups[i][j][k] );
1132          System.err.print( " " );
1133        }
1134        System.err.print( ") " );
1135      }
1136      System.err.println( );
1137    }
1138  }
1139
1140  /**
1141   * Transforms an instance for the i-th classifier.
1142   *
1143   * @param instance the instance to be transformed
1144   * @param i the base classifier number
1145   * @return the transformed instance
1146   * @throws Exception if the instance can't be converted successfully
1147   */
1148  protected Instance convertInstance( Instance instance, int i ) 
1149  throws Exception {
1150    Instance newInstance = new DenseInstance( m_Headers[ i ].numAttributes( ) );
1151    newInstance.setDataset( m_Headers[ i ] );
1152    int currentAttribute = 0;
1153
1154    // Project the data for each group
1155    for( int j = 0; j < m_Groups[i].length; j++ ) {
1156      Instance auxInstance = new DenseInstance( m_Groups[i][j].length + 1 );
1157      int k;
1158      for( k = 0; k < m_Groups[i][j].length; k++ ) {
1159        auxInstance.setValue( k, instance.value( m_Groups[i][j][k] ) );
1160      }
1161      auxInstance.setValue( k, instance.classValue( ) );
1162      auxInstance.setDataset( m_ReducedHeaders[ i ][ j ] );
1163      m_ProjectionFilters[i][j].input( auxInstance );
1164      auxInstance = m_ProjectionFilters[i][j].output( );
1165      m_ProjectionFilters[i][j].batchFinished();
1166      for( int a = 0; a < auxInstance.numAttributes() - 1; a++ ) {
1167        newInstance.setValue( currentAttribute++, auxInstance.value( a ) );
1168      }
1169    }
1170
1171    newInstance.setClassValue( instance.classValue() );
1172    return newInstance;
1173  }
1174
1175  /**
1176   * Calculates the class membership probabilities for the given test
1177   * instance.
1178   *
1179   * @param instance the instance to be classified
1180   * @return preedicted class probability distribution
1181   * @throws Exception if distribution can't be computed successfully
1182   */
1183  public double[] distributionForInstance(Instance instance) throws Exception {
1184
1185    m_RemoveUseless.input(instance);
1186    instance =m_RemoveUseless.output();
1187    m_RemoveUseless.batchFinished();
1188
1189    m_Normalize.input(instance);
1190    instance =m_Normalize.output();
1191    m_Normalize.batchFinished();
1192
1193    double [] sums = new double [instance.numClasses()], newProbs; 
1194   
1195    for (int i = 0; i < m_Classifiers.length; i++) {
1196      Instance convertedInstance = convertInstance(instance, i);
1197      if (instance.classAttribute().isNumeric() == true) {
1198        sums[0] += m_Classifiers[i].classifyInstance(convertedInstance);
1199      } else {
1200        newProbs = m_Classifiers[i].distributionForInstance(convertedInstance);
1201        for (int j = 0; j < newProbs.length; j++)
1202          sums[j] += newProbs[j];
1203      }
1204    }
1205    if (instance.classAttribute().isNumeric() == true) {
1206      sums[0] /= (double)m_NumIterations;
1207      return sums;
1208    } else if (Utils.eq(Utils.sum(sums), 0)) {
1209      return sums;
1210    } else {
1211      Utils.normalize(sums);
1212      return sums;
1213    }
1214  }
1215
1216  /**
1217   * Main method for testing this class.
1218   *
1219   * @param argv the options
1220   */
1221  public static void main(String [] argv) {
1222    runClassifier(new RotationForest(), argv);
1223  }
1224
1225}
1226
Note: See TracBrowser for help on using the repository browser.