source: src/main/java/weka/attributeSelection/ReliefFAttributeEval.java @ 10

Last change on this file since 10 was 4, checked in by gnappo, 14 years ago

Import di weka.

File size: 39.1 KB
Line 
1/*
2 *    This program is free software; you can redistribute it and/or modify
3 *    it under the terms of the GNU General Public License as published by
4 *    the Free Software Foundation; either version 2 of the License, or
5 *    (at your option) any later version.
6 *
7 *    This program is distributed in the hope that it will be useful,
8 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
9 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
10 *    GNU General Public License for more details.
11 *
12 *    You should have received a copy of the GNU General Public License
13 *    along with this program; if not, write to the Free Software
14 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
15 */
16
17/*
18 *    ReliefFAttributeEval.java
19 *    Copyright (C) 1999 University of Waikato, Hamilton, New Zealand
20 *
21 */
22
23package weka.attributeSelection;
24
25import weka.core.Attribute;
26import weka.core.Capabilities;
27import weka.core.Instance;
28import weka.core.Instances;
29import weka.core.Option;
30import weka.core.OptionHandler;
31import weka.core.RevisionUtils;
32import weka.core.TechnicalInformation;
33import weka.core.TechnicalInformationHandler;
34import weka.core.Utils;
35import weka.core.Capabilities.Capability;
36import weka.core.TechnicalInformation.Field;
37import weka.core.TechnicalInformation.Type;
38
39import java.util.Enumeration;
40import java.util.Random;
41import java.util.Vector;
42
43/**
44 <!-- globalinfo-start -->
45 * ReliefFAttributeEval :<br/>
46 * <br/>
47 * Evaluates the worth of an attribute by repeatedly sampling an instance and considering the value of the given attribute for the nearest instance of the same and different class. Can operate on both discrete and continuous class data.<br/>
48 * <br/>
49 * For more information see:<br/>
50 * <br/>
51 * Kenji Kira, Larry A. Rendell: A Practical Approach to Feature Selection. In: Ninth International Workshop on Machine Learning, 249-256, 1992.<br/>
52 * <br/>
53 * Igor Kononenko: Estimating Attributes: Analysis and Extensions of RELIEF. In: European Conference on Machine Learning, 171-182, 1994.<br/>
54 * <br/>
55 * Marko Robnik-Sikonja, Igor Kononenko: An adaptation of Relief for attribute estimation in regression. In: Fourteenth International Conference on Machine Learning, 296-304, 1997.
56 * <p/>
57 <!-- globalinfo-end -->
58 *
59 <!-- technical-bibtex-start -->
60 * BibTeX:
61 * <pre>
62 * &#64;inproceedings{Kira1992,
63 *    author = {Kenji Kira and Larry A. Rendell},
64 *    booktitle = {Ninth International Workshop on Machine Learning},
65 *    editor = {Derek H. Sleeman and Peter Edwards},
66 *    pages = {249-256},
67 *    publisher = {Morgan Kaufmann},
68 *    title = {A Practical Approach to Feature Selection},
69 *    year = {1992}
70 * }
71 *
72 * &#64;inproceedings{Kononenko1994,
73 *    author = {Igor Kononenko},
74 *    booktitle = {European Conference on Machine Learning},
75 *    editor = {Francesco Bergadano and Luc De Raedt},
76 *    pages = {171-182},
77 *    publisher = {Springer},
78 *    title = {Estimating Attributes: Analysis and Extensions of RELIEF},
79 *    year = {1994}
80 * }
81 *
82 * &#64;inproceedings{Robnik-Sikonja1997,
83 *    author = {Marko Robnik-Sikonja and Igor Kononenko},
84 *    booktitle = {Fourteenth International Conference on Machine Learning},
85 *    editor = {Douglas H. Fisher},
86 *    pages = {296-304},
87 *    publisher = {Morgan Kaufmann},
88 *    title = {An adaptation of Relief for attribute estimation in regression},
89 *    year = {1997}
90 * }
91 * </pre>
92 * <p/>
93 <!-- technical-bibtex-end -->
94 *
95 <!-- options-start -->
96 * Valid options are: <p/>
97 *
98 * <pre> -M &lt;num instances&gt;
99 *  Specify the number of instances to
100 *  sample when estimating attributes.
101 *  If not specified, then all instances
102 *  will be used.</pre>
103 *
104 * <pre> -D &lt;seed&gt;
105 *  Seed for randomly sampling instances.
106 *  (Default = 1)</pre>
107 *
108 * <pre> -K &lt;number of neighbours&gt;
109 *  Number of nearest neighbours (k) used
110 *  to estimate attribute relevances
111 *  (Default = 10).</pre>
112 *
113 * <pre> -W
114 *  Weight nearest neighbours by distance</pre>
115 *
116 * <pre> -A &lt;num&gt;
117 *  Specify sigma value (used in an exp
118 *  function to control how quickly
119 *  weights for more distant instances
120 *  decrease. Use in conjunction with -W.
121 *  Sensible value=1/5 to 1/10 of the
122 *  number of nearest neighbours.
123 *  (Default = 2)</pre>
124 *
125 <!-- options-end -->
126 *
127 * @author Mark Hall (mhall@cs.waikato.ac.nz)
128 * @version $Revision: 5987 $
129 */
130public class ReliefFAttributeEval
131  extends ASEvaluation
132  implements AttributeEvaluator,
133             OptionHandler, 
134             TechnicalInformationHandler {
135 
136  /** for serialization */
137  static final long serialVersionUID = -8422186665795839379L;
138
139  /** The training instances */
140  private Instances m_trainInstances;
141
142  /** The class index */
143  private int m_classIndex;
144
145  /** The number of attributes */
146  private int m_numAttribs;
147
148  /** The number of instances */
149  private int m_numInstances;
150
151  /** Numeric class */
152  private boolean m_numericClass;
153
154  /** The number of classes if class is nominal */
155  private int m_numClasses;
156
157  /**
158   * Used to hold the probability of a different class val given nearest
159   * instances (numeric class)
160   */
161  private double m_ndc;
162
163  /**
164   * Used to hold the prob of different value of an attribute given
165   * nearest instances (numeric class case)
166   */
167  private double[] m_nda;
168
169  /**
170   * Used to hold the prob of a different class val and different att
171   * val given nearest instances (numeric class case)
172   */
173  private double[] m_ndcda;
174
175  /** Holds the weights that relief assigns to attributes */
176  private double[] m_weights;
177
178  /** Prior class probabilities (discrete class case) */
179  private double[] m_classProbs;
180
181  /**
182   * The number of instances to sample when estimating attributes
183   * default == -1, use all instances
184   */
185  private int m_sampleM;
186
187  /** The number of nearest hits/misses */
188  private int m_Knn;
189
190  /** k nearest scores + instance indexes for n classes */
191  private double[][][] m_karray;
192
193  /** Upper bound for numeric attributes */
194  private double[] m_maxArray;
195
196  /** Lower bound for numeric attributes */
197  private double[] m_minArray;
198
199  /** Keep track of the farthest instance for each class */
200  private double[] m_worst;
201
202  /** Index in the m_karray of the farthest instance for each class */
203  private int[] m_index;
204
205  /** Number of nearest neighbours stored of each class */
206  private int[] m_stored;
207 
208  /** Random number seed used for sampling instances */
209  private int m_seed;
210
211  /**
212   *  used to (optionally) weight nearest neighbours by their distance
213   *  from the instance in question. Each entry holds
214   *  exp(-((rank(r_i, i_j)/sigma)^2)) where rank(r_i,i_j) is the rank of
215   *  instance i_j in a sequence of instances ordered by the distance
216   *  from r_i. sigma is a user defined parameter, default=20
217   **/
218  private double[] m_weightsByRank;
219  private int m_sigma;
220 
221  /** Weight by distance rather than equal weights */
222  private boolean m_weightByDistance;
223
224  /**
225   * Constructor
226   */
227  public ReliefFAttributeEval () {
228    resetOptions();
229  }
230
231  /**
232   * Returns a string describing this attribute evaluator
233   * @return a description of the evaluator suitable for
234   * displaying in the explorer/experimenter gui
235   */
236  public String globalInfo() {
237    return "ReliefFAttributeEval :\n\nEvaluates the worth of an attribute by "
238      +"repeatedly sampling an instance and considering the value of the "
239      +"given attribute for the nearest instance of the same and different "
240      +"class. Can operate on both discrete and continuous class data.\n\n"
241      + "For more information see:\n\n"
242      + getTechnicalInformation().toString();
243  }
244
245  /**
246   * Returns an instance of a TechnicalInformation object, containing
247   * detailed information about the technical background of this class,
248   * e.g., paper reference or book this class is based on.
249   *
250   * @return the technical information about this class
251   */
252  public TechnicalInformation getTechnicalInformation() {
253    TechnicalInformation        result;
254    TechnicalInformation        additional;
255   
256    result = new TechnicalInformation(Type.INPROCEEDINGS);
257    result.setValue(Field.AUTHOR, "Kenji Kira and Larry A. Rendell");
258    result.setValue(Field.TITLE, "A Practical Approach to Feature Selection");
259    result.setValue(Field.BOOKTITLE, "Ninth International Workshop on Machine Learning");
260    result.setValue(Field.EDITOR, "Derek H. Sleeman and Peter Edwards");
261    result.setValue(Field.YEAR, "1992");
262    result.setValue(Field.PAGES, "249-256");
263    result.setValue(Field.PUBLISHER, "Morgan Kaufmann");
264   
265    additional = result.add(Type.INPROCEEDINGS);
266    additional.setValue(Field.AUTHOR, "Igor Kononenko");
267    additional.setValue(Field.TITLE, "Estimating Attributes: Analysis and Extensions of RELIEF");
268    additional.setValue(Field.BOOKTITLE, "European Conference on Machine Learning");
269    additional.setValue(Field.EDITOR, "Francesco Bergadano and Luc De Raedt");
270    additional.setValue(Field.YEAR, "1994");
271    additional.setValue(Field.PAGES, "171-182");
272    additional.setValue(Field.PUBLISHER, "Springer");
273   
274    additional = result.add(Type.INPROCEEDINGS);
275    additional.setValue(Field.AUTHOR, "Marko Robnik-Sikonja and Igor Kononenko");
276    additional.setValue(Field.TITLE, "An adaptation of Relief for attribute estimation in regression");
277    additional.setValue(Field.BOOKTITLE, "Fourteenth International Conference on Machine Learning");
278    additional.setValue(Field.EDITOR, "Douglas H. Fisher");
279    additional.setValue(Field.YEAR, "1997");
280    additional.setValue(Field.PAGES, "296-304");
281    additional.setValue(Field.PUBLISHER, "Morgan Kaufmann");
282   
283    return result;
284  }
285
286  /**
287   * Returns an enumeration describing the available options.
288   * @return an enumeration of all the available options.
289   **/
290  public Enumeration listOptions () {
291    Vector newVector = new Vector(4);
292    newVector
293      .addElement(new Option("\tSpecify the number of instances to\n" 
294                             + "\tsample when estimating attributes.\n" 
295                             + "\tIf not specified, then all instances\n" 
296                             + "\twill be used.", "M", 1
297                             , "-M <num instances>"));
298    newVector.
299      addElement(new Option("\tSeed for randomly sampling instances.\n" 
300                            + "\t(Default = 1)", "D", 1
301                            , "-D <seed>"));
302    newVector.
303      addElement(new Option("\tNumber of nearest neighbours (k) used\n" 
304                            + "\tto estimate attribute relevances\n" 
305                            + "\t(Default = 10).", "K", 1
306                            , "-K <number of neighbours>"));
307    newVector.
308      addElement(new Option("\tWeight nearest neighbours by distance", "W"
309                            , 0, "-W"));
310    newVector.
311      addElement(new Option("\tSpecify sigma value (used in an exp\n" 
312                            + "\tfunction to control how quickly\n" 
313                            + "\tweights for more distant instances\n" 
314                            + "\tdecrease. Use in conjunction with -W.\n" 
315                            + "\tSensible value=1/5 to 1/10 of the\n" 
316                            + "\tnumber of nearest neighbours.\n" 
317                            + "\t(Default = 2)", "A", 1, "-A <num>"));
318    return  newVector.elements();
319  }
320
321
322  /**
323   * Parses a given list of options. <p/>
324   *
325   <!-- options-start -->
326   * Valid options are: <p/>
327   *
328   * <pre> -M &lt;num instances&gt;
329   *  Specify the number of instances to
330   *  sample when estimating attributes.
331   *  If not specified, then all instances
332   *  will be used.</pre>
333   *
334   * <pre> -D &lt;seed&gt;
335   *  Seed for randomly sampling instances.
336   *  (Default = 1)</pre>
337   *
338   * <pre> -K &lt;number of neighbours&gt;
339   *  Number of nearest neighbours (k) used
340   *  to estimate attribute relevances
341   *  (Default = 10).</pre>
342   *
343   * <pre> -W
344   *  Weight nearest neighbours by distance</pre>
345   *
346   * <pre> -A &lt;num&gt;
347   *  Specify sigma value (used in an exp
348   *  function to control how quickly
349   *  weights for more distant instances
350   *  decrease. Use in conjunction with -W.
351   *  Sensible value=1/5 to 1/10 of the
352   *  number of nearest neighbours.
353   *  (Default = 2)</pre>
354   *
355   <!-- options-end -->
356   *
357   * @param options the list of options as an array of strings
358   * @throws Exception if an option is not supported
359   */
360  public void setOptions (String[] options)
361    throws Exception {
362    String optionString;
363    resetOptions();
364    setWeightByDistance(Utils.getFlag('W', options));
365    optionString = Utils.getOption('M', options);
366
367    if (optionString.length() != 0) {
368      setSampleSize(Integer.parseInt(optionString));
369    }
370
371    optionString = Utils.getOption('D', options);
372
373    if (optionString.length() != 0) {
374      setSeed(Integer.parseInt(optionString));
375    }
376
377    optionString = Utils.getOption('K', options);
378
379    if (optionString.length() != 0) {
380      setNumNeighbours(Integer.parseInt(optionString));
381    }
382
383    optionString = Utils.getOption('A', options);
384
385    if (optionString.length() != 0) {
386      setWeightByDistance(true); // turn on weighting by distance
387      setSigma(Integer.parseInt(optionString));
388    }
389  }
390
391  /**
392   * Returns the tip text for this property
393   * @return tip text for this property suitable for
394   * displaying in the explorer/experimenter gui
395   */
396  public String sigmaTipText() {
397    return "Set influence of nearest neighbours. Used in an exp function to "
398      +"control how quickly weights decrease for more distant instances. "
399      +"Use in conjunction with weightByDistance. Sensible values = 1/5 to "
400      +"1/10 the number of nearest neighbours.";
401  }
402
403  /**
404   * Sets the sigma value.
405   *
406   * @param s the value of sigma (> 0)
407   * @throws Exception if s is not positive
408   */
409  public void setSigma (int s)
410    throws Exception {
411    if (s <= 0) {
412      throw  new Exception("value of sigma must be > 0!");
413    }
414
415    m_sigma = s;
416  }
417
418
419  /**
420   * Get the value of sigma.
421   *
422   * @return the sigma value.
423   */
424  public int getSigma () {
425    return  m_sigma;
426  }
427
428  /**
429   * Returns the tip text for this property
430   * @return tip text for this property suitable for
431   * displaying in the explorer/experimenter gui
432   */
433  public String numNeighboursTipText() {
434    return "Number of nearest neighbours for attribute estimation.";
435  }
436
437  /**
438   * Set the number of nearest neighbours
439   *
440   * @param n the number of nearest neighbours.
441   */
442  public void setNumNeighbours (int n) {
443    m_Knn = n;
444  }
445
446
447  /**
448   * Get the number of nearest neighbours
449   *
450   * @return the number of nearest neighbours
451   */
452  public int getNumNeighbours () {
453    return  m_Knn;
454  }
455
456  /**
457   * Returns the tip text for this property
458   * @return tip text for this property suitable for
459   * displaying in the explorer/experimenter gui
460   */
461  public String seedTipText() {
462    return "Random seed for sampling instances.";
463  }
464
465  /**
466   * Set the random number seed for randomly sampling instances.
467   *
468   * @param s the random number seed.
469   */
470  public void setSeed (int s) {
471    m_seed = s;
472  }
473
474
475  /**
476   * Get the seed used for randomly sampling instances.
477   *
478   * @return the random number seed.
479   */
480  public int getSeed () {
481    return  m_seed;
482  }
483
484  /**
485   * Returns the tip text for this property
486   * @return tip text for this property suitable for
487   * displaying in the explorer/experimenter gui
488   */
489  public String sampleSizeTipText() {
490    return "Number of instances to sample. Default (-1) indicates that all "
491      +"instances will be used for attribute estimation.";
492  }
493
494  /**
495   * Set the number of instances to sample for attribute estimation
496   *
497   * @param s the number of instances to sample.
498   */
499  public void setSampleSize (int s) {
500    m_sampleM = s;
501  }
502
503
504  /**
505   * Get the number of instances used for estimating attributes
506   *
507   * @return the number of instances.
508   */
509  public int getSampleSize () {
510    return  m_sampleM;
511  }
512
513  /**
514   * Returns the tip text for this property
515   * @return tip text for this property suitable for
516   * displaying in the explorer/experimenter gui
517   */
518  public String weightByDistanceTipText() {
519    return "Weight nearest neighbours by their distance.";
520  }
521
522  /**
523   * Set the nearest neighbour weighting method
524   *
525   * @param b true nearest neighbours are to be weighted by distance.
526   */
527  public void setWeightByDistance (boolean b) {
528    m_weightByDistance = b;
529  }
530
531
532  /**
533   * Get whether nearest neighbours are being weighted by distance
534   *
535   * @return m_weightByDiffernce
536   */
537  public boolean getWeightByDistance () {
538    return  m_weightByDistance;
539  }
540
541
542  /**
543   * Gets the current settings of ReliefFAttributeEval.
544   *
545   * @return an array of strings suitable for passing to setOptions()
546   */
547  public String[] getOptions () {
548    String[] options = new String[9];
549    int current = 0;
550
551    if (getWeightByDistance()) {
552      options[current++] = "-W";
553    }
554
555    options[current++] = "-M";
556    options[current++] = "" + getSampleSize();
557    options[current++] = "-D";
558    options[current++] = "" + getSeed();
559    options[current++] = "-K";
560    options[current++] = "" + getNumNeighbours();
561   
562    if (getWeightByDistance()) {
563      options[current++] = "-A";
564      options[current++] = "" + getSigma();
565    }
566
567    while (current < options.length) {
568      options[current++] = "";
569    }
570
571    return  options;
572  }
573
574
575  /**
576   * Return a description of the ReliefF attribute evaluator.
577   *
578   * @return a description of the evaluator as a String.
579   */
580  public String toString () {
581    StringBuffer text = new StringBuffer();
582
583    if (m_trainInstances == null) {
584      text.append("ReliefF feature evaluator has not been built yet\n");
585    }
586    else {
587      text.append("\tReliefF Ranking Filter");
588      text.append("\n\tInstances sampled: ");
589
590      if (m_sampleM == -1) {
591        text.append("all\n");
592      }
593      else {
594        text.append(m_sampleM + "\n");
595      }
596
597      text.append("\tNumber of nearest neighbours (k): " + m_Knn + "\n");
598
599      if (m_weightByDistance) {
600        text.append("\tExponentially decreasing (with distance) " 
601                    + "influence for\n" 
602                    + "\tnearest neighbours. Sigma: " 
603                    + m_sigma + "\n");
604      }
605      else {
606        text.append("\tEqual influence nearest neighbours\n");
607      }
608    }
609
610    return  text.toString();
611  }
612
613  /**
614   * Returns the capabilities of this evaluator.
615   *
616   * @return            the capabilities of this evaluator
617   * @see               Capabilities
618   */
619  public Capabilities getCapabilities() {
620    Capabilities result = super.getCapabilities();
621    result.disableAll();
622   
623    // attributes
624    result.enable(Capability.NOMINAL_ATTRIBUTES);
625    result.enable(Capability.NUMERIC_ATTRIBUTES);
626    result.enable(Capability.DATE_ATTRIBUTES);
627    result.enable(Capability.MISSING_VALUES);
628   
629    // class
630    result.enable(Capability.NOMINAL_CLASS);
631    result.enable(Capability.NUMERIC_CLASS);
632    result.enable(Capability.DATE_CLASS);
633    result.enable(Capability.MISSING_CLASS_VALUES);
634   
635    return result;
636  }
637
638  /**
639   * Initializes a ReliefF attribute evaluator.
640   *
641   * @param data set of instances serving as training data
642   * @throws Exception if the evaluator has not been
643   * generated successfully
644   */
645  public void buildEvaluator (Instances data)
646    throws Exception {
647   
648    int z, totalInstances;
649    Random r = new Random(m_seed);
650
651    // can evaluator handle data?
652    getCapabilities().testWithFail(data);
653
654    m_trainInstances = data;
655    m_classIndex = m_trainInstances.classIndex();
656    m_numAttribs = m_trainInstances.numAttributes();
657    m_numInstances = m_trainInstances.numInstances();
658
659    if (m_trainInstances.attribute(m_classIndex).isNumeric()) {
660      m_numericClass = true;
661    }
662    else {
663      m_numericClass = false;
664    }
665
666    if (!m_numericClass) {
667      m_numClasses = m_trainInstances.attribute(m_classIndex).numValues();
668    }
669    else {
670      m_ndc = 0;
671      m_numClasses = 1;
672      m_nda = new double[m_numAttribs];
673      m_ndcda = new double[m_numAttribs];
674    }
675
676    if (m_weightByDistance) // set up the rank based weights
677      {
678        m_weightsByRank = new double[m_Knn];
679
680        for (int i = 0; i < m_Knn; i++) {
681          m_weightsByRank[i] = 
682            Math.exp(-((i/(double)m_sigma)*(i/(double)m_sigma)));
683        }
684      }
685
686    // the final attribute weights
687    m_weights = new double[m_numAttribs];
688    // num classes (1 for numeric class) knn neighbours,
689    // and 0 = distance, 1 = instance index
690    m_karray = new double[m_numClasses][m_Knn][2];
691
692    if (!m_numericClass) {
693      m_classProbs = new double[m_numClasses];
694
695      for (int i = 0; i < m_numInstances; i++) {
696        m_classProbs[(int)m_trainInstances.instance(i).value(m_classIndex)]++;
697      }
698
699      for (int i = 0; i < m_numClasses; i++) {
700        m_classProbs[i] /= m_numInstances;
701      }
702    }
703
704    m_worst = new double[m_numClasses];
705    m_index = new int[m_numClasses];
706    m_stored = new int[m_numClasses];
707    m_minArray = new double[m_numAttribs];
708    m_maxArray = new double[m_numAttribs];
709
710    for (int i = 0; i < m_numAttribs; i++) {
711      m_minArray[i] = m_maxArray[i] = Double.NaN;
712    }
713
714    for (int i = 0; i < m_numInstances; i++) {
715      updateMinMax(m_trainInstances.instance(i));
716    }
717   
718    if ((m_sampleM > m_numInstances) || (m_sampleM < 0)) {
719      totalInstances = m_numInstances;
720    }
721    else {
722      totalInstances = m_sampleM;
723    }
724
725    // process each instance, updating attribute weights
726    for (int i = 0; i < totalInstances; i++) {
727      if (totalInstances == m_numInstances) {
728        z = i;
729      }
730      else {
731        z = r.nextInt()%m_numInstances;
732      }
733
734      if (z < 0) {
735        z *= -1;
736      }
737
738      if (!(m_trainInstances.instance(z).isMissing(m_classIndex))) {
739        // first clear the knn and worst index stuff for the classes
740        for (int j = 0; j < m_numClasses; j++) {
741          m_index[j] = m_stored[j] = 0;
742
743          for (int k = 0; k < m_Knn; k++) {
744            m_karray[j][k][0] = m_karray[j][k][1] = 0;
745          }
746        }
747
748        findKHitMiss(z);
749
750        if (m_numericClass) {
751          updateWeightsNumericClass(z);
752        }
753        else {
754          updateWeightsDiscreteClass(z);
755        }
756      }
757    }
758
759    // now scale weights by 1/m_numInstances (nominal class) or
760    // calculate weights numeric class
761    // System.out.println("num inst:"+m_numInstances+" r_ndc:"+r_ndc);
762    for (int i = 0; i < m_numAttribs; i++) {if (i != m_classIndex) {
763      if (m_numericClass) {
764        m_weights[i] = m_ndcda[i]/m_ndc - 
765          ((m_nda[i] - m_ndcda[i])/((double)totalInstances - m_ndc));
766      }
767      else {
768        m_weights[i] *= (1.0/(double)totalInstances);
769      }
770
771      //          System.out.println(r_weights[i]);
772    }
773    }
774  }
775
776
777  /**
778   * Evaluates an individual attribute using ReliefF's instance based approach.
779   * The actual work is done by buildEvaluator which evaluates all features.
780   *
781   * @param attribute the index of the attribute to be evaluated
782   * @throws Exception if the attribute could not be evaluated
783   */
784  public double evaluateAttribute (int attribute)
785    throws Exception {
786    return  m_weights[attribute];
787  }
788
789
790  /**
791   * Reset options to their default values
792   */
793  protected void resetOptions () {
794    m_trainInstances = null;
795    m_sampleM = -1;
796    m_Knn = 10;
797    m_sigma = 2;
798    m_weightByDistance = false;
799    m_seed = 1;
800  }
801
802
803  /**
804   * Normalizes a given value of a numeric attribute.
805   *
806   * @param x the value to be normalized
807   * @param i the attribute's index
808   * @return the normalized value
809   */
810  private double norm (double x, int i) {
811    if (Double.isNaN(m_minArray[i]) || 
812        Utils.eq(m_maxArray[i], m_minArray[i])) {
813      return  0;
814    }
815    else {
816      return  (x - m_minArray[i])/(m_maxArray[i] - m_minArray[i]);
817    }
818  }
819
820
821  /**
822   * Updates the minimum and maximum values for all the attributes
823   * based on a new instance.
824   *
825   * @param instance the new instance
826   */
827  private void updateMinMax (Instance instance) {
828    //    for (int j = 0; j < m_numAttribs; j++) {
829    try {
830      for (int j = 0; j < instance.numValues(); j++) {
831        if ((instance.attributeSparse(j).isNumeric()) && 
832            (!instance.isMissingSparse(j))) {
833          if (Double.isNaN(m_minArray[instance.index(j)])) {
834            m_minArray[instance.index(j)] = instance.valueSparse(j);
835            m_maxArray[instance.index(j)] = instance.valueSparse(j);
836          }
837        else {
838          if (instance.valueSparse(j) < m_minArray[instance.index(j)]) {
839            m_minArray[instance.index(j)] = instance.valueSparse(j);
840          }
841          else {
842            if (instance.valueSparse(j) > m_maxArray[instance.index(j)]) {
843              m_maxArray[instance.index(j)] = instance.valueSparse(j);
844            }
845          }
846        }
847        }
848      }
849    } catch (Exception ex) {
850      System.err.println(ex);
851      ex.printStackTrace();
852    }
853  }
854
855  /**
856   * Computes the difference between two given attribute
857   * values.
858   */
859  private double difference(int index, double val1, double val2) {
860
861    switch (m_trainInstances.attribute(index).type()) {
862    case Attribute.NOMINAL:
863     
864      // If attribute is nominal
865      if (Utils.isMissingValue(val1) || 
866          Utils.isMissingValue(val2)) {
867        return (1.0 - (1.0/((double)m_trainInstances.
868                            attribute(index).numValues())));
869      } else if ((int)val1 != (int)val2) {
870        return 1;
871      } else {
872        return 0;
873      }
874    case Attribute.NUMERIC:
875
876      // If attribute is numeric
877      if (Utils.isMissingValue(val1) || 
878          Utils.isMissingValue(val2)) {
879        if (Utils.isMissingValue(val1) && 
880            Utils.isMissingValue(val2)) {
881          return 1;
882        } else {
883          double diff;
884          if (Utils.isMissingValue(val2)) {
885            diff = norm(val1, index);
886          } else {
887            diff = norm(val2, index);
888          }
889          if (diff < 0.5) {
890            diff = 1.0 - diff;
891          }
892          return diff;
893        }
894      } else {
895        return Math.abs(norm(val1, index) - norm(val2, index));
896      }
897    default:
898      return 0;
899    }
900  }
901
902  /**
903   * Calculates the distance between two instances
904   *
905   * @param first the first instance
906   * @param second the second instance
907   * @return the distance between the two given instances, between 0 and 1
908   */         
909  private double distance(Instance first, Instance second) { 
910
911    double distance = 0;
912    int firstI, secondI;
913
914    for (int p1 = 0, p2 = 0; 
915         p1 < first.numValues() || p2 < second.numValues();) {
916      if (p1 >= first.numValues()) {
917        firstI = m_trainInstances.numAttributes();
918      } else {
919        firstI = first.index(p1); 
920      }
921      if (p2 >= second.numValues()) {
922        secondI = m_trainInstances.numAttributes();
923      } else {
924        secondI = second.index(p2);
925      }
926      if (firstI == m_trainInstances.classIndex()) {
927        p1++; continue;
928      } 
929      if (secondI == m_trainInstances.classIndex()) {
930        p2++; continue;
931      } 
932      double diff;
933      if (firstI == secondI) {
934        diff = difference(firstI, 
935                          first.valueSparse(p1),
936                          second.valueSparse(p2));
937        p1++; p2++;
938      } else if (firstI > secondI) {
939        diff = difference(secondI, 
940                          0, second.valueSparse(p2));
941        p2++;
942      } else {
943        diff = difference(firstI, 
944                          first.valueSparse(p1), 0);
945        p1++;
946      }
947      //      distance += diff * diff;
948      distance += diff;
949    }
950   
951    //    return Math.sqrt(distance / m_NumAttributesUsed);
952    return distance;
953  }
954
955
956  /**
957   * update attribute weights given an instance when the class is numeric
958   *
959   * @param instNum the index of the instance to use when updating weights
960   */
961  private void updateWeightsNumericClass (int instNum) {
962    int i, j;
963    double temp,temp2;
964    int[] tempSorted = null;
965    double[] tempDist = null;
966    double distNorm = 1.0;
967    int firstI, secondI;
968
969    Instance inst = m_trainInstances.instance(instNum);
970   
971    // sort nearest neighbours and set up normalization variable
972    if (m_weightByDistance) {
973      tempDist = new double[m_stored[0]];
974
975      for (j = 0, distNorm = 0; j < m_stored[0]; j++) {
976        // copy the distances
977        tempDist[j] = m_karray[0][j][0];
978        // sum normalizer
979        distNorm += m_weightsByRank[j];
980      }
981
982      tempSorted = Utils.sort(tempDist);
983    }
984
985    for (i = 0; i < m_stored[0]; i++) {
986      // P diff prediction (class) given nearest instances
987      if (m_weightByDistance) {
988        temp = difference(m_classIndex, 
989                          inst.value(m_classIndex),
990                          m_trainInstances.
991                          instance((int)m_karray[0][tempSorted[i]][1]).
992                          value(m_classIndex));
993        temp *= (m_weightsByRank[i]/distNorm);
994      }
995      else {
996        temp = difference(m_classIndex, 
997                          inst.value(m_classIndex), 
998                          m_trainInstances.
999                          instance((int)m_karray[0][i][1]).
1000                          value(m_classIndex));
1001        temp *= (1.0/(double)m_stored[0]); // equal influence
1002      }
1003
1004      m_ndc += temp;
1005
1006      Instance cmp;
1007      cmp = (m_weightByDistance) 
1008        ? m_trainInstances.instance((int)m_karray[0][tempSorted[i]][1])
1009        : m_trainInstances.instance((int)m_karray[0][i][1]);
1010 
1011      double temp_diffP_diffA_givNearest = 
1012        difference(m_classIndex, inst.value(m_classIndex),
1013                   cmp.value(m_classIndex));
1014      // now the attributes
1015      for (int p1 = 0, p2 = 0; 
1016           p1 < inst.numValues() || p2 < cmp.numValues();) {
1017        if (p1 >= inst.numValues()) {
1018          firstI = m_trainInstances.numAttributes();
1019        } else {
1020          firstI = inst.index(p1); 
1021        }
1022        if (p2 >= cmp.numValues()) {
1023          secondI = m_trainInstances.numAttributes();
1024        } else {
1025          secondI = cmp.index(p2);
1026        }
1027        if (firstI == m_trainInstances.classIndex()) {
1028          p1++; continue;
1029        } 
1030        if (secondI == m_trainInstances.classIndex()) {
1031          p2++; continue;
1032        } 
1033        temp = 0.0;
1034        temp2 = 0.0;
1035     
1036        if (firstI == secondI) {
1037          j = firstI;
1038          temp = difference(j, inst.valueSparse(p1), cmp.valueSparse(p2)); 
1039          p1++;p2++;
1040        } else if (firstI > secondI) {
1041          j = secondI;
1042          temp = difference(j, 0, cmp.valueSparse(p2));
1043          p2++;
1044        } else {
1045          j = firstI;
1046          temp = difference(j, inst.valueSparse(p1), 0);
1047          p1++;
1048        } 
1049       
1050        temp2 = temp_diffP_diffA_givNearest * temp; 
1051        // P of different prediction and different att value given
1052        // nearest instances
1053        if (m_weightByDistance) {
1054          temp2 *= (m_weightsByRank[i]/distNorm);
1055        }
1056        else {
1057          temp2 *= (1.0/(double)m_stored[0]); // equal influence
1058        }
1059
1060        m_ndcda[j] += temp2;
1061       
1062        // P of different attribute val given nearest instances
1063        if (m_weightByDistance) {
1064          temp *= (m_weightsByRank[i]/distNorm);
1065        }
1066        else {
1067          temp *= (1.0/(double)m_stored[0]); // equal influence
1068        }
1069
1070        m_nda[j] += temp;
1071      }
1072    }
1073  }
1074
1075
1076  /**
1077   * update attribute weights given an instance when the class is discrete
1078   *
1079   * @param instNum the index of the instance to use when updating weights
1080   */
1081  private void updateWeightsDiscreteClass (int instNum) {
1082    int i, j, k;
1083    int cl;
1084    double temp_diff, w_norm = 1.0;
1085    double[] tempDistClass;
1086    int[] tempSortedClass = null;
1087    double distNormClass = 1.0;
1088    double[] tempDistAtt;
1089    int[][] tempSortedAtt = null;
1090    double[] distNormAtt = null;
1091    int firstI, secondI;
1092
1093    // store the indexes (sparse instances) of non-zero elements
1094    Instance inst = m_trainInstances.instance(instNum);
1095
1096    // get the class of this instance
1097    cl = (int)m_trainInstances.instance(instNum).value(m_classIndex);
1098
1099    // sort nearest neighbours and set up normalization variables
1100    if (m_weightByDistance) {
1101      // do class (hits) first
1102      // sort the distances
1103      tempDistClass = new double[m_stored[cl]];
1104
1105      for (j = 0, distNormClass = 0; j < m_stored[cl]; j++) {
1106        // copy the distances
1107        tempDistClass[j] = m_karray[cl][j][0];
1108        // sum normalizer
1109        distNormClass += m_weightsByRank[j];
1110      }
1111
1112      tempSortedClass = Utils.sort(tempDistClass);
1113      // do misses (other classes)
1114      tempSortedAtt = new int[m_numClasses][1];
1115      distNormAtt = new double[m_numClasses];
1116
1117      for (k = 0; k < m_numClasses; k++) {
1118        if (k != cl) // already done cl
1119          {
1120            // sort the distances
1121            tempDistAtt = new double[m_stored[k]];
1122
1123            for (j = 0, distNormAtt[k] = 0; j < m_stored[k]; j++) {
1124              // copy the distances
1125              tempDistAtt[j] = m_karray[k][j][0];
1126              // sum normalizer
1127              distNormAtt[k] += m_weightsByRank[j];
1128            }
1129
1130            tempSortedAtt[k] = Utils.sort(tempDistAtt);
1131          }
1132      }
1133    }
1134
1135    if (m_numClasses > 2) {
1136      // the amount of probability space left after removing the
1137      // probability of this instance's class value
1138      w_norm = (1.0 - m_classProbs[cl]);
1139    }
1140   
1141    // do the k nearest hits of the same class
1142    for (j = 0, temp_diff = 0.0; j < m_stored[cl]; j++) {
1143      Instance cmp;
1144      cmp = (m_weightByDistance) 
1145        ? m_trainInstances.
1146        instance((int)m_karray[cl][tempSortedClass[j]][1])
1147        : m_trainInstances.instance((int)m_karray[cl][j][1]);
1148
1149      for (int p1 = 0, p2 = 0; 
1150           p1 < inst.numValues() || p2 < cmp.numValues();) {
1151        if (p1 >= inst.numValues()) {
1152          firstI = m_trainInstances.numAttributes();
1153        } else {
1154          firstI = inst.index(p1); 
1155        }
1156        if (p2 >= cmp.numValues()) {
1157          secondI = m_trainInstances.numAttributes();
1158        } else {
1159          secondI = cmp.index(p2);
1160        }
1161        if (firstI == m_trainInstances.classIndex()) {
1162          p1++; continue;
1163        } 
1164        if (secondI == m_trainInstances.classIndex()) {
1165          p2++; continue;
1166        } 
1167        if (firstI == secondI) {
1168          i = firstI;
1169          temp_diff = difference(i, inst.valueSparse(p1), 
1170                                 cmp.valueSparse(p2)); 
1171          p1++;p2++;
1172        } else if (firstI > secondI) {
1173          i = secondI;
1174          temp_diff = difference(i, 0, cmp.valueSparse(p2));
1175          p2++;
1176        } else {
1177          i = firstI;
1178          temp_diff = difference(i, inst.valueSparse(p1), 0);
1179          p1++;
1180        } 
1181       
1182        if (m_weightByDistance) {
1183          temp_diff *=
1184            (m_weightsByRank[j]/distNormClass);
1185        } else {
1186          if (m_stored[cl] > 0) {
1187            temp_diff /= (double)m_stored[cl];
1188          }
1189        }
1190        m_weights[i] -= temp_diff;
1191
1192      }
1193    }
1194     
1195
1196    // now do k nearest misses from each of the other classes
1197    temp_diff = 0.0;
1198
1199    for (k = 0; k < m_numClasses; k++) {
1200      if (k != cl) // already done cl
1201        {
1202          for (j = 0; j < m_stored[k]; j++) {
1203            Instance cmp;
1204            cmp = (m_weightByDistance) 
1205              ? m_trainInstances.
1206              instance((int)m_karray[k][tempSortedAtt[k][j]][1])
1207              : m_trainInstances.instance((int)m_karray[k][j][1]);
1208       
1209            for (int p1 = 0, p2 = 0; 
1210                 p1 < inst.numValues() || p2 < cmp.numValues();) {
1211              if (p1 >= inst.numValues()) {
1212                firstI = m_trainInstances.numAttributes();
1213              } else {
1214                firstI = inst.index(p1); 
1215              }
1216              if (p2 >= cmp.numValues()) {
1217                secondI = m_trainInstances.numAttributes();
1218              } else {
1219                secondI = cmp.index(p2);
1220              }
1221              if (firstI == m_trainInstances.classIndex()) {
1222                p1++; continue;
1223              } 
1224              if (secondI == m_trainInstances.classIndex()) {
1225                p2++; continue;
1226              } 
1227              if (firstI == secondI) {
1228                i = firstI;
1229                temp_diff = difference(i, inst.valueSparse(p1), 
1230                                       cmp.valueSparse(p2)); 
1231                p1++;p2++;
1232              } else if (firstI > secondI) {
1233                i = secondI;
1234                temp_diff = difference(i, 0, cmp.valueSparse(p2));
1235                p2++;
1236              } else {
1237                i = firstI;
1238                temp_diff = difference(i, inst.valueSparse(p1), 0);
1239                p1++;
1240              } 
1241
1242              if (m_weightByDistance) {
1243                temp_diff *=
1244                  (m_weightsByRank[j]/distNormAtt[k]);
1245              }
1246              else {
1247                if (m_stored[k] > 0) {
1248                  temp_diff /= (double)m_stored[k];
1249                }
1250              }
1251              if (m_numClasses > 2) {
1252                m_weights[i] += ((m_classProbs[k]/w_norm)*temp_diff);
1253              } else {
1254                m_weights[i] += temp_diff;
1255              }
1256            }
1257          }
1258        }
1259    }
1260  }
1261
1262
1263  /**
1264   * Find the K nearest instances to supplied instance if the class is numeric,
1265   * or the K nearest Hits (same class) and Misses (K from each of the other
1266   * classes) if the class is discrete.
1267   *
1268   * @param instNum the index of the instance to find nearest neighbours of
1269   */
1270  private void findKHitMiss (int instNum) {
1271    int i, j;
1272    int cl;
1273    double ww;
1274    double temp_diff = 0.0;
1275    Instance thisInst = m_trainInstances.instance(instNum);
1276
1277    for (i = 0; i < m_numInstances; i++) {
1278      if (i != instNum) {
1279        Instance cmpInst = m_trainInstances.instance(i);
1280        temp_diff = distance(cmpInst, thisInst);
1281
1282        // class of this training instance or 0 if numeric
1283        if (m_numericClass) {
1284          cl = 0;
1285        }
1286        else {
1287          cl = (int)m_trainInstances.instance(i).value(m_classIndex);
1288        }
1289
1290        // add this diff to the list for the class of this instance
1291        if (m_stored[cl] < m_Knn) {
1292          m_karray[cl][m_stored[cl]][0] = temp_diff;
1293          m_karray[cl][m_stored[cl]][1] = i;
1294          m_stored[cl]++;
1295
1296          // note the worst diff for this class
1297          for (j = 0, ww = -1.0; j < m_stored[cl]; j++) {
1298            if (m_karray[cl][j][0] > ww) {
1299              ww = m_karray[cl][j][0];
1300              m_index[cl] = j;
1301            }
1302          }
1303
1304          m_worst[cl] = ww;
1305        }
1306        else 
1307          /* if we already have stored knn for this class then check to
1308             see if this instance is better than the worst */
1309          {
1310            if (temp_diff < m_karray[cl][m_index[cl]][0]) {
1311              m_karray[cl][m_index[cl]][0] = temp_diff;
1312              m_karray[cl][m_index[cl]][1] = i;
1313
1314              for (j = 0, ww = -1.0; j < m_stored[cl]; j++) {
1315                if (m_karray[cl][j][0] > ww) {
1316                  ww = m_karray[cl][j][0];
1317                  m_index[cl] = j;
1318                }
1319              }
1320
1321              m_worst[cl] = ww;
1322            }
1323          }
1324      }
1325    }
1326  }
1327 
1328  /**
1329   * Returns the revision string.
1330   *
1331   * @return            the revision
1332   */
1333  public String getRevision() {
1334    return RevisionUtils.extract("$Revision: 5987 $");
1335  }
1336
1337  // ============
1338  // Test method.
1339  // ============
1340  /**
1341   * Main method for testing this class.
1342   *
1343   * @param args the options
1344   */
1345  public static void main (String[] args) {
1346    runEvaluator(new ReliefFAttributeEval(), args);
1347  }
1348}
Note: See TracBrowser for help on using the repository browser.