source: src/main/java/weka/classifiers/functions/LeastMedSq.java @ 28

Last change on this file since 28 was 4, checked in by gnappo, 14 years ago

Import di weka.

File size: 18.5 KB
Line 
1/*
2 *    This program is free software; you can redistribute it and/or modify
3 *    it under the terms of the GNU General Public License as published by
4 *    the Free Software Foundation; either version 2 of the License, or
5 *    (at your option) any later version.
6 *
7 *    This program is distributed in the hope that it will be useful,
8 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
9 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
10 *    GNU General Public License for more details.
11 *
12 *    You should have received a copy of the GNU General Public License
13 *    along with this program; if not, write to the Free Software
14 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
15 */
16
17/*
18 *    LeastMedSq.java
19 *
20 *    Copyright (C) 2001 University of Waikato
21 */
22
23package weka.classifiers.functions;
24
25import weka.classifiers.Classifier;
26import weka.classifiers.AbstractClassifier;
27import weka.core.Capabilities;
28import weka.core.Instance;
29import weka.core.Instances;
30import weka.core.Option;
31import weka.core.OptionHandler;
32import weka.core.RevisionUtils;
33import weka.core.TechnicalInformation;
34import weka.core.TechnicalInformationHandler;
35import weka.core.Utils;
36import weka.core.Capabilities.Capability;
37import weka.core.TechnicalInformation.Field;
38import weka.core.TechnicalInformation.Type;
39import weka.filters.Filter;
40import weka.filters.supervised.attribute.NominalToBinary;
41import weka.filters.unsupervised.attribute.ReplaceMissingValues;
42import weka.filters.unsupervised.instance.RemoveRange;
43
44import java.util.Enumeration;
45import java.util.Random;
46import java.util.Vector;
47
48/**
49 <!-- globalinfo-start -->
50 * Implements a least median sqaured linear regression utilising the existing weka LinearRegression class to form predictions. <br/>
51 * Least squared regression functions are generated from random subsamples of the data. The least squared regression with the lowest meadian squared error is chosen as the final model.<br/>
52 * <br/>
53 * The basis of the algorithm is <br/>
54 * <br/>
55 * Peter J. Rousseeuw, Annick M. Leroy (1987). Robust regression and outlier detection. .
56 * <p/>
57 <!-- globalinfo-end -->
58 *
59 <!-- technical-bibtex-start -->
60 * BibTeX:
61 * <pre>
62 * &#64;book{Rousseeuw1987,
63 *    author = {Peter J. Rousseeuw and Annick M. Leroy},
64 *    title = {Robust regression and outlier detection},
65 *    year = {1987}
66 * }
67 * </pre>
68 * <p/>
69 <!-- technical-bibtex-end -->
70 *
71 <!-- options-start -->
72 * Valid options are: <p/>
73 *
74 * <pre> -S &lt;sample size&gt;
75 *  Set sample size
76 *  (default: 4)
77 * </pre>
78 *
79 * <pre> -G &lt;seed&gt;
80 *  Set the seed used to generate samples
81 *  (default: 0)
82 * </pre>
83 *
84 * <pre> -D
85 *  Produce debugging output
86 *  (default no debugging output)
87 * </pre>
88 *
89 <!-- options-end -->
90 *
91 * @author Tony Voyle (tv6@waikato.ac.nz)
92 * @version $Revision: 5928 $
93 */
94public class LeastMedSq 
95  extends AbstractClassifier
96  implements OptionHandler, TechnicalInformationHandler {
97 
98  /** for serialization */
99  static final long serialVersionUID = 4288954049987652970L;
100 
101  private double[] m_Residuals;
102 
103  private double[] m_weight;
104 
105  private double m_SSR;
106 
107  private double m_scalefactor;
108 
109  private double m_bestMedian = Double.POSITIVE_INFINITY;
110 
111  private LinearRegression m_currentRegression;
112 
113  private LinearRegression m_bestRegression;
114 
115  private LinearRegression m_ls;
116
117  private Instances m_Data;
118
119  private Instances m_RLSData;
120
121  private Instances m_SubSample;
122
123  private ReplaceMissingValues m_MissingFilter;
124
125  private NominalToBinary m_TransformFilter;
126
127  private RemoveRange m_SplitFilter;
128
129  private int m_samplesize = 4;
130
131  private int m_samples;
132
133  private boolean m_israndom = false;
134
135  private boolean m_debug = false;
136
137  private Random m_random;
138
139  private long m_randomseed = 0;
140
141  /**
142   * Returns a string describing this classifier
143   * @return a description of the classifier suitable for
144   * displaying in the explorer/experimenter gui
145   */
146  public String globalInfo() {
147    return "Implements a least median sqaured linear regression utilising the "
148      +"existing weka LinearRegression class to form predictions. \n"
149      +"Least squared regression functions are generated from random subsamples of "
150      +"the data. The least squared regression with the lowest meadian squared error "
151      +"is chosen as the final model.\n\n"
152      +"The basis of the algorithm is \n\n"
153      + getTechnicalInformation().toString();
154  }
155
156  /**
157   * Returns an instance of a TechnicalInformation object, containing
158   * detailed information about the technical background of this class,
159   * e.g., paper reference or book this class is based on.
160   *
161   * @return the technical information about this class
162   */
163  public TechnicalInformation getTechnicalInformation() {
164    TechnicalInformation        result;
165   
166    result = new TechnicalInformation(Type.BOOK);
167    result.setValue(Field.AUTHOR, "Peter J. Rousseeuw and Annick M. Leroy");
168    result.setValue(Field.YEAR, "1987");
169    result.setValue(Field.TITLE, "Robust regression and outlier detection");
170   
171    return result;
172  }
173
174  /**
175   * Returns default capabilities of the classifier.
176   *
177   * @return      the capabilities of this classifier
178   */
179  public Capabilities getCapabilities() {
180    Capabilities result = super.getCapabilities();
181    result.disableAll();
182
183    // attributes
184    result.enable(Capability.NOMINAL_ATTRIBUTES);
185    result.enable(Capability.NUMERIC_ATTRIBUTES);
186    result.enable(Capability.DATE_ATTRIBUTES);
187    result.enable(Capability.MISSING_VALUES);
188
189    // class
190    result.enable(Capability.NUMERIC_CLASS);
191    result.enable(Capability.DATE_CLASS);
192    result.enable(Capability.MISSING_CLASS_VALUES);
193   
194    return result;
195  }
196
197  /**
198   * Build lms regression
199   *
200   * @param data training data
201   * @throws Exception if an error occurs
202   */
203  public void buildClassifier(Instances data)throws Exception{
204
205    // can classifier handle the data?
206    getCapabilities().testWithFail(data);
207
208    // remove instances with missing class
209    data = new Instances(data);
210    data.deleteWithMissingClass();
211   
212    cleanUpData(data);
213
214    getSamples();
215
216    findBestRegression();
217
218    buildRLSRegression();
219
220  } // buildClassifier
221
222  /**
223   * Classify a given instance using the best generated
224   * LinearRegression Classifier.
225   *
226   * @param instance instance to be classified
227   * @return class value
228   * @throws Exception if an error occurs
229   */
230  public double classifyInstance(Instance instance)throws Exception{
231
232    Instance transformedInstance = instance;
233    m_TransformFilter.input(transformedInstance);
234    transformedInstance = m_TransformFilter.output();
235    m_MissingFilter.input(transformedInstance);
236    transformedInstance = m_MissingFilter.output();
237
238    return m_ls.classifyInstance(transformedInstance);
239  } // classifyInstance
240
241  /**
242   * Cleans up data
243   *
244   * @param data data to be cleaned up
245   * @throws Exception if an error occurs
246   */
247  private void cleanUpData(Instances data)throws Exception{
248
249    m_Data = data;
250    m_TransformFilter = new NominalToBinary();
251    m_TransformFilter.setInputFormat(m_Data);
252    m_Data = Filter.useFilter(m_Data, m_TransformFilter);
253    m_MissingFilter = new ReplaceMissingValues();
254    m_MissingFilter.setInputFormat(m_Data);
255    m_Data = Filter.useFilter(m_Data, m_MissingFilter);
256    m_Data.deleteWithMissingClass();
257  }
258
259  /**
260   * Gets the number of samples to use.
261   *
262   * @throws Exception if an error occurs
263   */
264  private void getSamples()throws Exception{
265
266    int stuf[] = new int[] {500,50,22,17,15,14};
267    if ( m_samplesize < 7){
268      if ( m_Data.numInstances() < stuf[m_samplesize - 1])
269        m_samples = combinations(m_Data.numInstances(), m_samplesize);
270      else
271        m_samples = m_samplesize * 500;
272
273    } else m_samples = 3000;
274    if (m_debug){
275      System.out.println("m_samplesize: " + m_samplesize);
276      System.out.println("m_samples: " + m_samples);
277      System.out.println("m_randomseed: " + m_randomseed);
278    }
279
280  }
281
282  /**
283   * Set up the random number generator
284   *
285   */
286  private void setRandom(){
287
288    m_random = new Random(getRandomSeed());
289  }
290
291  /**
292   * Finds the best regression generated from m_samples
293   * random samples from the training data
294   *
295   * @throws Exception if an error occurs
296   */
297  private void findBestRegression()throws Exception{
298
299    setRandom();
300    m_bestMedian = Double.POSITIVE_INFINITY;
301    if (m_debug) {
302      System.out.println("Starting:");
303    }
304    for(int s = 0, r = 0; s < m_samples; s++, r++){
305      if (m_debug) {
306        if(s%(m_samples/100)==0)
307          System.out.print("*");
308      }
309      genRegression();
310      getMedian();
311    }
312    if (m_debug) {
313      System.out.println("");
314    }
315    m_currentRegression = m_bestRegression;
316  }
317
318  /**
319   * Generates a LinearRegression classifier from
320   * the current m_SubSample
321   *
322   * @throws Exception if an error occurs
323   */
324  private void genRegression()throws Exception{
325
326    m_currentRegression = new LinearRegression();
327    m_currentRegression.setOptions(new String[]{"-S", "1"});
328    selectSubSample(m_Data);
329    m_currentRegression.buildClassifier(m_SubSample);
330  }
331
332  /**
333   * Finds residuals (squared) for the current
334   * regression.
335   *
336   * @throws Exception if an error occurs
337   */
338  private void findResiduals()throws Exception{
339
340    m_SSR = 0;
341    m_Residuals = new double [m_Data.numInstances()];
342    for(int i = 0; i < m_Data.numInstances(); i++){
343      m_Residuals[i] = m_currentRegression.classifyInstance(m_Data.instance(i));
344      m_Residuals[i] -= m_Data.instance(i).value(m_Data.classAttribute());
345      m_Residuals[i] *= m_Residuals[i];
346      m_SSR += m_Residuals[i];
347    }
348  }
349
350  /**
351   * finds the median residual squared for the
352   * current regression
353   *
354   * @throws Exception if an error occurs
355   */
356  private void getMedian()throws Exception{
357
358    findResiduals();
359    int p = m_Residuals.length;
360    select(m_Residuals, 0, p - 1, p / 2);
361    if(m_Residuals[p / 2] < m_bestMedian){
362      m_bestMedian = m_Residuals[p / 2];
363      m_bestRegression = m_currentRegression;
364    }
365  }
366
367  /**
368   * Returns a string representing the best
369   * LinearRegression classifier found.
370   *
371   * @return String representing the regression
372   */
373  public String toString(){
374
375    if( m_ls == null){
376      return "model has not been built";
377    }
378    return m_ls.toString();
379  }
380
381  /**
382   * Builds a weight function removing instances with an
383   * abnormally high scaled residual
384   *
385   * @throws Exception if weight building fails
386   */
387  private void buildWeight()throws Exception{
388
389    findResiduals();
390    m_scalefactor = 1.4826 * ( 1 + 5 / (m_Data.numInstances()
391                                        - m_Data.numAttributes()))
392      * Math.sqrt(m_bestMedian);
393    m_weight = new double[m_Residuals.length];
394    for (int i = 0; i < m_Residuals.length; i++)
395      m_weight[i] = ((Math.sqrt(m_Residuals[i])/m_scalefactor < 2.5)?1.0:0.0);
396  }
397
398  /**
399   * Builds a new LinearRegression without the 'bad' data
400   * found by buildWeight
401   *
402   * @throws Exception if building fails
403   */
404  private void buildRLSRegression()throws Exception{
405
406    buildWeight();
407    m_RLSData = new Instances(m_Data);
408    int x = 0;
409    int y = 0;
410    int n = m_RLSData.numInstances();
411    while(y < n){
412      if (m_weight[x] == 0){
413        m_RLSData.delete(y);
414        n = m_RLSData.numInstances();
415        y--;
416      }
417      x++;
418      y++;
419    }
420    if ( m_RLSData.numInstances() == 0){
421      System.err.println("rls regression unbuilt");
422      m_ls = m_currentRegression;
423    }else{
424      m_ls = new LinearRegression();
425      m_ls.setOptions(new String[]{"-S", "1"});
426      m_ls.buildClassifier(m_RLSData);
427      m_currentRegression = m_ls;
428    }
429
430  }
431
432  /**
433   * Finds the kth number in an array
434   *
435   * @param a an array of numbers
436   * @param l left pointer
437   * @param r right pointer
438   * @param k position of number to be found
439   */
440  private static void select( double [] a, int l, int r, int k){
441
442    if (r <=l) return;
443    int i = partition( a, l, r);
444    if (i > k) select(a, l, i-1, k);
445    if (i < k) select(a, i+1, r, k);
446  }
447
448  /**
449   * Partitions an array of numbers such that all numbers
450   * less than that at index r, between indexes l and r
451   * will have a smaller index and all numbers greater than
452   * will have a larger index
453   *
454   * @param a an array of numbers
455   * @param l left pointer
456   * @param r right pointer
457   * @return final index of number originally at r
458   */
459  private static int partition( double [] a, int l, int r ){
460
461    int i = l-1, j = r;
462    double v = a[r], temp;
463    while(true){
464      while(a[++i] < v);
465      while(v < a[--j]) if(j == l) break;
466      if(i >= j) break;
467      temp = a[i];
468      a[i] = a[j];
469      a[j] = temp;
470    }
471    temp = a[i];
472    a[i] = a[r];
473    a[r] = temp;
474    return i;
475  }
476
477  /**
478   * Produces a random sample from m_Data
479   * in m_SubSample
480   *
481   * @param data data from which to take sample
482   * @throws Exception if an error occurs
483   */
484  private void selectSubSample(Instances data)throws Exception{
485
486    m_SplitFilter = new RemoveRange();
487    m_SplitFilter.setInvertSelection(true);
488    m_SubSample = data;
489    m_SplitFilter.setInputFormat(m_SubSample);
490    m_SplitFilter.setInstancesIndices(selectIndices(m_SubSample));
491    m_SubSample = Filter.useFilter(m_SubSample, m_SplitFilter);
492  }
493
494  /**
495   * Returns a string suitable for passing to RemoveRange consisting
496   * of m_samplesize indices.
497   *
498   * @param data dataset from which to take indicese
499   * @return string of indices suitable for passing to RemoveRange
500   */
501  private String selectIndices(Instances data){
502
503    StringBuffer text = new StringBuffer();
504    for(int i = 0, x = 0; i < m_samplesize; i++){
505      do{x = (int) (m_random.nextDouble() * data.numInstances());}
506      while(x==0);
507      text.append(Integer.toString(x));
508      if(i < m_samplesize - 1)
509        text.append(",");
510      else
511        text.append("\n");
512    }
513    return text.toString();
514  }
515
516  /**
517   * Returns the tip text for this property
518   * @return tip text for this property suitable for
519   * displaying in the explorer/experimenter gui
520   */
521  public String sampleSizeTipText() {
522    return "Set the size of the random samples used to generate the least sqaured "
523      +"regression functions.";
524  }
525
526  /**
527   * sets number of samples
528   *
529   * @param samplesize value
530   */
531  public void setSampleSize(int samplesize){
532
533    m_samplesize = samplesize;
534  }
535
536  /**
537   * gets number of samples
538   *
539   * @return value
540   */
541  public int getSampleSize(){
542
543    return m_samplesize;
544  }
545
546  /**
547   * Returns the tip text for this property
548   * @return tip text for this property suitable for
549   * displaying in the explorer/experimenter gui
550   */
551  public String randomSeedTipText() {
552    return "Set the seed for selecting random subsamples of the training data.";
553  }
554
555  /**
556   * Set the seed for the random number generator
557   *
558   * @param randomseed the seed
559   */
560  public void setRandomSeed(long randomseed){
561
562    m_randomseed = randomseed;
563  }
564
565  /**
566   * get the seed for the random number generator
567   *
568   * @return the seed value
569   */
570  public long getRandomSeed(){
571
572    return m_randomseed;
573  }
574
575  /**
576   * sets  whether or not debugging output shouild be printed
577   *
578   * @param debug true if debugging output selected
579   */
580  public void setDebug(boolean debug){
581
582    m_debug = debug;
583  }
584
585  /**
586   * Returns whether or not debugging output shouild be printed
587   *
588   * @return true if debuging output selected
589   */
590  public boolean getDebug(){
591
592    return m_debug;
593  }
594
595  /**
596   * Returns an enumeration of all the available options..
597   *
598   * @return an enumeration of all available options.
599   */
600  public Enumeration listOptions(){
601
602    Vector newVector = new Vector(1);
603    newVector.addElement(new Option("\tSet sample size\n"
604                                    + "\t(default: 4)\n",
605                                    "S", 4, "-S <sample size>"));
606    newVector.addElement(new Option("\tSet the seed used to generate samples\n"
607                                    + "\t(default: 0)\n",
608                                    "G", 0, "-G <seed>"));
609    newVector.addElement(new Option("\tProduce debugging output\n"
610                                    + "\t(default no debugging output)\n",
611                                    "D", 0, "-D"));
612
613    return newVector.elements();
614  }
615
616  /**
617   * Sets the OptionHandler's options using the given list. All options
618   * will be set (or reset) during this call (i.e. incremental setting
619   * of options is not possible).
620   *
621   <!-- options-start -->
622   * Valid options are: <p/>
623   *
624   * <pre> -S &lt;sample size&gt;
625   *  Set sample size
626   *  (default: 4)
627   * </pre>
628   *
629   * <pre> -G &lt;seed&gt;
630   *  Set the seed used to generate samples
631   *  (default: 0)
632   * </pre>
633   *
634   * <pre> -D
635   *  Produce debugging output
636   *  (default no debugging output)
637   * </pre>
638   *
639   <!-- options-end -->
640   *
641   * @param options the list of options as an array of strings
642   * @throws Exception if an option is not supported
643   */
644  public void setOptions(String[] options) throws Exception {
645
646    String curropt = Utils.getOption('S', options);
647    if ( curropt.length() != 0){
648      setSampleSize(Integer.parseInt(curropt));
649    } else
650      setSampleSize(4);
651
652    curropt = Utils.getOption('G', options);
653    if ( curropt.length() != 0){
654      setRandomSeed(Long.parseLong(curropt));
655    } else {
656      setRandomSeed(0);
657    }
658
659    setDebug(Utils.getFlag('D', options));
660  }
661
662  /**
663   * Gets the current option settings for the OptionHandler.
664   *
665   * @return the list of current option settings as an array of strings
666   */
667  public String[] getOptions(){
668
669    String options[] = new String[9];
670    int current = 0;
671
672    options[current++] = "-S";
673    options[current++] = "" + getSampleSize();
674
675    options[current++] = "-G";
676    options[current++] = "" + getRandomSeed();
677
678    if (getDebug()) {
679      options[current++] = "-D";
680    }
681
682    while (current < options.length) {
683      options[current++] = "";
684    }
685
686    return options;
687  }
688
689  /**
690   * Produces the combination nCr
691   *
692   * @param n
693   * @param r
694   * @return the combination
695   * @throws Exception if r is greater than n
696   */
697  public static int combinations (int n, int r)throws Exception {
698
699    int c = 1, denom = 1, num = 1, i,orig=r;
700    if (r > n) throw new Exception("r must be less that or equal to n.");
701    r = Math.min( r , n - r);
702
703    for (i = 1 ; i <= r; i++){
704
705      num *= n-i+1;
706      denom *= i;
707    }
708
709    c = num / denom;
710    if(false)
711      System.out.println( "n: "+n+" r: "+orig+" num: "+num+
712                          " denom: "+denom+" c: "+c);
713    return c;
714  }
715 
716  /**
717   * Returns the revision string.
718   *
719   * @return            the revision
720   */
721  public String getRevision() {
722    return RevisionUtils.extract("$Revision: 5928 $");
723  }
724
725  /**
726   * generate a Linear regression predictor for testing
727   *
728   * @param argv options
729   */
730  public static void main(String [] argv){
731    runClassifier(new LeastMedSq(), argv);
732  } // main
733} // lmr
Note: See TracBrowser for help on using the repository browser.