source: src/main/java/weka/classifiers/lazy/IBk.java @ 21

Last change on this file since 21 was 4, checked in by gnappo, 14 years ago

Import di weka.

File size: 31.3 KB
Line 
1/*
2 *    This program is free software; you can redistribute it and/or modify
3 *    it under the terms of the GNU General Public License as published by
4 *    the Free Software Foundation; either version 2 of the License, or
5 *    (at your option) any later version.
6 *
7 *    This program is distributed in the hope that it will be useful,
8 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
9 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
10 *    GNU General Public License for more details.
11 *
12 *    You should have received a copy of the GNU General Public License
13 *    along with this program; if not, write to the Free Software
14 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
15 */
16
17/*
18 *    IBk.java
19 *    Copyright (C) 1999 University of Waikato, Hamilton, New Zealand
20 *
21 */
22
23package weka.classifiers.lazy;
24
25import weka.classifiers.Classifier;
26import weka.classifiers.AbstractClassifier;
27import weka.classifiers.UpdateableClassifier;
28import weka.core.Attribute;
29import weka.core.Capabilities;
30import weka.core.Instance;
31import weka.core.Instances;
32import weka.core.neighboursearch.LinearNNSearch;
33import weka.core.neighboursearch.NearestNeighbourSearch;
34import weka.core.Option;
35import weka.core.OptionHandler;
36import weka.core.RevisionUtils;
37import weka.core.SelectedTag;
38import weka.core.Tag;
39import weka.core.TechnicalInformation;
40import weka.core.TechnicalInformationHandler;
41import weka.core.Utils;
42import weka.core.WeightedInstancesHandler;
43import weka.core.Capabilities.Capability;
44import weka.core.TechnicalInformation.Field;
45import weka.core.TechnicalInformation.Type;
46import weka.core.AdditionalMeasureProducer;
47
48import java.util.Enumeration;
49import java.util.Vector;
50
51/**
52 <!-- globalinfo-start -->
53 * K-nearest neighbours classifier. Can select appropriate value of K based on cross-validation. Can also do distance weighting.<br/>
54 * <br/>
55 * For more information, see<br/>
56 * <br/>
57 * D. Aha, D. Kibler (1991). Instance-based learning algorithms. Machine Learning. 6:37-66.
58 * <p/>
59 <!-- globalinfo-end -->
60 *
61 <!-- technical-bibtex-start -->
62 * BibTeX:
63 * <pre>
64 * &#64;article{Aha1991,
65 *    author = {D. Aha and D. Kibler},
66 *    journal = {Machine Learning},
67 *    pages = {37-66},
68 *    title = {Instance-based learning algorithms},
69 *    volume = {6},
70 *    year = {1991}
71 * }
72 * </pre>
73 * <p/>
74 <!-- technical-bibtex-end -->
75 *
76 <!-- options-start -->
77 * Valid options are: <p/>
78 *
79 * <pre> -I
80 *  Weight neighbours by the inverse of their distance
81 *  (use when k &gt; 1)</pre>
82 *
83 * <pre> -F
84 *  Weight neighbours by 1 - their distance
85 *  (use when k &gt; 1)</pre>
86 *
87 * <pre> -K &lt;number of neighbors&gt;
88 *  Number of nearest neighbours (k) used in classification.
89 *  (Default = 1)</pre>
90 *
91 * <pre> -E
92 *  Minimise mean squared error rather than mean absolute
93 *  error when using -X option with numeric prediction.</pre>
94 *
95 * <pre> -W &lt;window size&gt;
96 *  Maximum number of training instances maintained.
97 *  Training instances are dropped FIFO. (Default = no window)</pre>
98 *
99 * <pre> -X
100 *  Select the number of nearest neighbours between 1
101 *  and the k value specified using hold-one-out evaluation
102 *  on the training data (use when k &gt; 1)</pre>
103 *
104 * <pre> -A
105 *  The nearest neighbour search algorithm to use (default: weka.core.neighboursearch.LinearNNSearch).
106 * </pre>
107 *
108 <!-- options-end -->
109 *
110 * @author Stuart Inglis (singlis@cs.waikato.ac.nz)
111 * @author Len Trigg (trigg@cs.waikato.ac.nz)
112 * @author Eibe Frank (eibe@cs.waikato.ac.nz)
113 * @version $Revision: 5928 $
114 */
115public class IBk 
116  extends AbstractClassifier
117  implements OptionHandler, UpdateableClassifier, WeightedInstancesHandler,
118             TechnicalInformationHandler, AdditionalMeasureProducer {
119
120  /** for serialization. */
121  static final long serialVersionUID = -3080186098777067172L;
122
123  /** The training instances used for classification. */
124  protected Instances m_Train;
125
126  /** The number of class values (or 1 if predicting numeric). */
127  protected int m_NumClasses;
128
129  /** The class attribute type. */
130  protected int m_ClassType;
131
132  /** The number of neighbours to use for classification (currently). */
133  protected int m_kNN;
134
135  /**
136   * The value of kNN provided by the user. This may differ from
137   * m_kNN if cross-validation is being used.
138   */
139  protected int m_kNNUpper;
140
141  /**
142   * Whether the value of k selected by cross validation has
143   * been invalidated by a change in the training instances.
144   */
145  protected boolean m_kNNValid;
146
147  /**
148   * The maximum number of training instances allowed. When
149   * this limit is reached, old training instances are removed,
150   * so the training data is "windowed". Set to 0 for unlimited
151   * numbers of instances.
152   */
153  protected int m_WindowSize;
154
155  /** Whether the neighbours should be distance-weighted. */
156  protected int m_DistanceWeighting;
157
158  /** Whether to select k by cross validation. */
159  protected boolean m_CrossValidate;
160
161  /**
162   * Whether to minimise mean squared error rather than mean absolute
163   * error when cross-validating on numeric prediction tasks.
164   */
165  protected boolean m_MeanSquared;
166
167  /** no weighting. */
168  public static final int WEIGHT_NONE = 1;
169  /** weight by 1/distance. */
170  public static final int WEIGHT_INVERSE = 2;
171  /** weight by 1-distance. */
172  public static final int WEIGHT_SIMILARITY = 4;
173  /** possible instance weighting methods. */
174  public static final Tag [] TAGS_WEIGHTING = {
175    new Tag(WEIGHT_NONE, "No distance weighting"),
176    new Tag(WEIGHT_INVERSE, "Weight by 1/distance"),
177    new Tag(WEIGHT_SIMILARITY, "Weight by 1-distance")
178  };
179 
180  /** for nearest-neighbor search. */
181  protected NearestNeighbourSearch m_NNSearch = new LinearNNSearch();
182
183  /** The number of attributes the contribute to a prediction. */
184  protected double m_NumAttributesUsed;
185 
186  /**
187   * IBk classifier. Simple instance-based learner that uses the class
188   * of the nearest k training instances for the class of the test
189   * instances.
190   *
191   * @param k the number of nearest neighbors to use for prediction
192   */
193  public IBk(int k) {
194
195    init();
196    setKNN(k);
197  } 
198
199  /**
200   * IB1 classifer. Instance-based learner. Predicts the class of the
201   * single nearest training instance for each test instance.
202   */
203  public IBk() {
204
205    init();
206  }
207 
208  /**
209   * Returns a string describing classifier.
210   * @return a description suitable for
211   * displaying in the explorer/experimenter gui
212   */
213  public String globalInfo() {
214
215    return  "K-nearest neighbours classifier. Can "
216      + "select appropriate value of K based on cross-validation. Can also do "
217      + "distance weighting.\n\n"
218      + "For more information, see\n\n"
219      + getTechnicalInformation().toString();
220  }
221
222  /**
223   * Returns an instance of a TechnicalInformation object, containing
224   * detailed information about the technical background of this class,
225   * e.g., paper reference or book this class is based on.
226   *
227   * @return the technical information about this class
228   */
229  public TechnicalInformation getTechnicalInformation() {
230    TechnicalInformation        result;
231   
232    result = new TechnicalInformation(Type.ARTICLE);
233    result.setValue(Field.AUTHOR, "D. Aha and D. Kibler");
234    result.setValue(Field.YEAR, "1991");
235    result.setValue(Field.TITLE, "Instance-based learning algorithms");
236    result.setValue(Field.JOURNAL, "Machine Learning");
237    result.setValue(Field.VOLUME, "6");
238    result.setValue(Field.PAGES, "37-66");
239   
240    return result;
241  }
242
243  /**
244   * Returns the tip text for this property.
245   * @return tip text for this property suitable for
246   * displaying in the explorer/experimenter gui
247   */
248  public String KNNTipText() {
249    return "The number of neighbours to use.";
250  }
251 
252  /**
253   * Set the number of neighbours the learner is to use.
254   *
255   * @param k the number of neighbours.
256   */
257  public void setKNN(int k) {
258    m_kNN = k;
259    m_kNNUpper = k;
260    m_kNNValid = false;
261  }
262
263  /**
264   * Gets the number of neighbours the learner will use.
265   *
266   * @return the number of neighbours.
267   */
268  public int getKNN() {
269
270    return m_kNN;
271  }
272
273  /**
274   * Returns the tip text for this property.
275   * @return tip text for this property suitable for
276   * displaying in the explorer/experimenter gui
277   */
278  public String windowSizeTipText() {
279    return "Gets the maximum number of instances allowed in the training " +
280      "pool. The addition of new instances above this value will result " +
281      "in old instances being removed. A value of 0 signifies no limit " +
282      "to the number of training instances.";
283  }
284 
285  /**
286   * Gets the maximum number of instances allowed in the training
287   * pool. The addition of new instances above this value will result
288   * in old instances being removed. A value of 0 signifies no limit
289   * to the number of training instances.
290   *
291   * @return Value of WindowSize.
292   */
293  public int getWindowSize() {
294   
295    return m_WindowSize;
296  }
297 
298  /**
299   * Sets the maximum number of instances allowed in the training
300   * pool. The addition of new instances above this value will result
301   * in old instances being removed. A value of 0 signifies no limit
302   * to the number of training instances.
303   *
304   * @param newWindowSize Value to assign to WindowSize.
305   */
306  public void setWindowSize(int newWindowSize) {
307   
308    m_WindowSize = newWindowSize;
309  }
310 
311  /**
312   * Returns the tip text for this property.
313   * @return tip text for this property suitable for
314   * displaying in the explorer/experimenter gui
315   */
316  public String distanceWeightingTipText() {
317
318    return "Gets the distance weighting method used.";
319  }
320 
321  /**
322   * Gets the distance weighting method used. Will be one of
323   * WEIGHT_NONE, WEIGHT_INVERSE, or WEIGHT_SIMILARITY
324   *
325   * @return the distance weighting method used.
326   */
327  public SelectedTag getDistanceWeighting() {
328
329    return new SelectedTag(m_DistanceWeighting, TAGS_WEIGHTING);
330  }
331 
332  /**
333   * Sets the distance weighting method used. Values other than
334   * WEIGHT_NONE, WEIGHT_INVERSE, or WEIGHT_SIMILARITY will be ignored.
335   *
336   * @param newMethod the distance weighting method to use
337   */
338  public void setDistanceWeighting(SelectedTag newMethod) {
339   
340    if (newMethod.getTags() == TAGS_WEIGHTING) {
341      m_DistanceWeighting = newMethod.getSelectedTag().getID();
342    }
343  }
344 
345  /**
346   * Returns the tip text for this property.
347   * @return tip text for this property suitable for
348   * displaying in the explorer/experimenter gui
349   */
350  public String meanSquaredTipText() {
351
352    return "Whether the mean squared error is used rather than mean "
353      + "absolute error when doing cross-validation for regression problems.";
354  }
355
356  /**
357   * Gets whether the mean squared error is used rather than mean
358   * absolute error when doing cross-validation.
359   *
360   * @return true if so.
361   */
362  public boolean getMeanSquared() {
363   
364    return m_MeanSquared;
365  }
366 
367  /**
368   * Sets whether the mean squared error is used rather than mean
369   * absolute error when doing cross-validation.
370   *
371   * @param newMeanSquared true if so.
372   */
373  public void setMeanSquared(boolean newMeanSquared) {
374   
375    m_MeanSquared = newMeanSquared;
376  }
377 
378  /**
379   * Returns the tip text for this property.
380   * @return tip text for this property suitable for
381   * displaying in the explorer/experimenter gui
382   */
383  public String crossValidateTipText() {
384
385    return "Whether hold-one-out cross-validation will be used " +
386      "to select the best k value.";
387  }
388 
389  /**
390   * Gets whether hold-one-out cross-validation will be used
391   * to select the best k value.
392   *
393   * @return true if cross-validation will be used.
394   */
395  public boolean getCrossValidate() {
396   
397    return m_CrossValidate;
398  }
399 
400  /**
401   * Sets whether hold-one-out cross-validation will be used
402   * to select the best k value.
403   *
404   * @param newCrossValidate true if cross-validation should be used.
405   */
406  public void setCrossValidate(boolean newCrossValidate) {
407   
408    m_CrossValidate = newCrossValidate;
409  }
410
411  /**
412   * Returns the tip text for this property.
413   * @return tip text for this property suitable for
414   * displaying in the explorer/experimenter gui
415   */
416  public String nearestNeighbourSearchAlgorithmTipText() {
417    return "The nearest neighbour search algorithm to use " +
418           "(Default: weka.core.neighboursearch.LinearNNSearch).";
419  }
420 
421  /**
422   * Returns the current nearestNeighbourSearch algorithm in use.
423   * @return the NearestNeighbourSearch algorithm currently in use.
424   */
425  public NearestNeighbourSearch getNearestNeighbourSearchAlgorithm() {
426    return m_NNSearch;
427  }
428 
429  /**
430   * Sets the nearestNeighbourSearch algorithm to be used for finding nearest
431   * neighbour(s).
432   * @param nearestNeighbourSearchAlgorithm - The NearestNeighbourSearch class.
433   */
434  public void setNearestNeighbourSearchAlgorithm(NearestNeighbourSearch nearestNeighbourSearchAlgorithm) {
435    m_NNSearch = nearestNeighbourSearchAlgorithm;
436  }
437   
438  /**
439   * Get the number of training instances the classifier is currently using.
440   *
441   * @return the number of training instances the classifier is currently using
442   */
443  public int getNumTraining() {
444
445    return m_Train.numInstances();
446  }
447
448  /**
449   * Returns default capabilities of the classifier.
450   *
451   * @return      the capabilities of this classifier
452   */
453  public Capabilities getCapabilities() {
454    Capabilities result = super.getCapabilities();
455    result.disableAll();
456
457    // attributes
458    result.enable(Capability.NOMINAL_ATTRIBUTES);
459    result.enable(Capability.NUMERIC_ATTRIBUTES);
460    result.enable(Capability.DATE_ATTRIBUTES);
461    result.enable(Capability.MISSING_VALUES);
462
463    // class
464    result.enable(Capability.NOMINAL_CLASS);
465    result.enable(Capability.NUMERIC_CLASS);
466    result.enable(Capability.DATE_CLASS);
467    result.enable(Capability.MISSING_CLASS_VALUES);
468
469    // instances
470    result.setMinimumNumberInstances(0);
471   
472    return result;
473  }
474 
475  /**
476   * Generates the classifier.
477   *
478   * @param instances set of instances serving as training data
479   * @throws Exception if the classifier has not been generated successfully
480   */
481  public void buildClassifier(Instances instances) throws Exception {
482   
483    // can classifier handle the data?
484    getCapabilities().testWithFail(instances);
485
486    // remove instances with missing class
487    instances = new Instances(instances);
488    instances.deleteWithMissingClass();
489   
490    m_NumClasses = instances.numClasses();
491    m_ClassType = instances.classAttribute().type();
492    m_Train = new Instances(instances, 0, instances.numInstances());
493
494    // Throw away initial instances until within the specified window size
495    if ((m_WindowSize > 0) && (instances.numInstances() > m_WindowSize)) {
496      m_Train = new Instances(m_Train, 
497                              m_Train.numInstances()-m_WindowSize, 
498                              m_WindowSize);
499    }
500
501    m_NumAttributesUsed = 0.0;
502    for (int i = 0; i < m_Train.numAttributes(); i++) {
503      if ((i != m_Train.classIndex()) && 
504          (m_Train.attribute(i).isNominal() ||
505           m_Train.attribute(i).isNumeric())) {
506        m_NumAttributesUsed += 1.0;
507      }
508    }
509   
510    m_NNSearch.setInstances(m_Train);
511
512    // Invalidate any currently cross-validation selected k
513    m_kNNValid = false;
514  }
515
516  /**
517   * Adds the supplied instance to the training set.
518   *
519   * @param instance the instance to add
520   * @throws Exception if instance could not be incorporated
521   * successfully
522   */
523  public void updateClassifier(Instance instance) throws Exception {
524
525    if (m_Train.equalHeaders(instance.dataset()) == false) {
526      throw new Exception("Incompatible instance types\n" + m_Train.equalHeadersMsg(instance.dataset()));
527    }
528    if (instance.classIsMissing()) {
529      return;
530    }
531
532    m_Train.add(instance);
533    m_NNSearch.update(instance);
534    m_kNNValid = false;
535    if ((m_WindowSize > 0) && (m_Train.numInstances() > m_WindowSize)) {
536      boolean deletedInstance=false;
537      while (m_Train.numInstances() > m_WindowSize) {
538        m_Train.delete(0);
539        deletedInstance=true;
540      }
541      //rebuild datastructure KDTree currently can't delete
542      if(deletedInstance==true)
543        m_NNSearch.setInstances(m_Train);
544    }
545  }
546
547  /**
548   * Calculates the class membership probabilities for the given test instance.
549   *
550   * @param instance the instance to be classified
551   * @return predicted class probability distribution
552   * @throws Exception if an error occurred during the prediction
553   */
554  public double [] distributionForInstance(Instance instance) throws Exception {
555
556    if (m_Train.numInstances() == 0) {
557      throw new Exception("No training instances!");
558    }
559    if ((m_WindowSize > 0) && (m_Train.numInstances() > m_WindowSize)) {
560      m_kNNValid = false;
561      boolean deletedInstance=false;
562      while (m_Train.numInstances() > m_WindowSize) {
563        m_Train.delete(0);
564      }
565      //rebuild datastructure KDTree currently can't delete
566      if(deletedInstance==true)
567        m_NNSearch.setInstances(m_Train);
568    }
569
570    // Select k by cross validation
571    if (!m_kNNValid && (m_CrossValidate) && (m_kNNUpper >= 1)) {
572      crossValidate();
573    }
574
575    m_NNSearch.addInstanceInfo(instance);
576
577    Instances neighbours = m_NNSearch.kNearestNeighbours(instance, m_kNN);
578    double [] distances = m_NNSearch.getDistances();
579    double [] distribution = makeDistribution( neighbours, distances );
580
581    return distribution;
582  }
583
584  /**
585   * Returns an enumeration describing the available options.
586   *
587   * @return an enumeration of all the available options.
588   */
589  public Enumeration listOptions() {
590
591    Vector newVector = new Vector(8);
592
593    newVector.addElement(new Option(
594              "\tWeight neighbours by the inverse of their distance\n"+
595              "\t(use when k > 1)",
596              "I", 0, "-I"));
597    newVector.addElement(new Option(
598              "\tWeight neighbours by 1 - their distance\n"+
599              "\t(use when k > 1)",
600              "F", 0, "-F"));
601    newVector.addElement(new Option(
602              "\tNumber of nearest neighbours (k) used in classification.\n"+
603              "\t(Default = 1)",
604              "K", 1,"-K <number of neighbors>"));
605    newVector.addElement(new Option(
606          "\tMinimise mean squared error rather than mean absolute\n"+
607              "\terror when using -X option with numeric prediction.",
608              "E", 0,"-E"));
609    newVector.addElement(new Option(
610          "\tMaximum number of training instances maintained.\n"+
611              "\tTraining instances are dropped FIFO. (Default = no window)",
612              "W", 1,"-W <window size>"));
613    newVector.addElement(new Option(
614              "\tSelect the number of nearest neighbours between 1\n"+
615              "\tand the k value specified using hold-one-out evaluation\n"+
616              "\ton the training data (use when k > 1)",
617              "X", 0,"-X"));
618    newVector.addElement(new Option(
619              "\tThe nearest neighbour search algorithm to use "+
620          "(default: weka.core.neighboursearch.LinearNNSearch).\n",
621              "A", 0, "-A"));
622
623    return newVector.elements();
624  }
625
626  /**
627   * Parses a given list of options. <p/>
628   *
629   <!-- options-start -->
630   * Valid options are: <p/>
631   *
632   * <pre> -I
633   *  Weight neighbours by the inverse of their distance
634   *  (use when k &gt; 1)</pre>
635   *
636   * <pre> -F
637   *  Weight neighbours by 1 - their distance
638   *  (use when k &gt; 1)</pre>
639   *
640   * <pre> -K &lt;number of neighbors&gt;
641   *  Number of nearest neighbours (k) used in classification.
642   *  (Default = 1)</pre>
643   *
644   * <pre> -E
645   *  Minimise mean squared error rather than mean absolute
646   *  error when using -X option with numeric prediction.</pre>
647   *
648   * <pre> -W &lt;window size&gt;
649   *  Maximum number of training instances maintained.
650   *  Training instances are dropped FIFO. (Default = no window)</pre>
651   *
652   * <pre> -X
653   *  Select the number of nearest neighbours between 1
654   *  and the k value specified using hold-one-out evaluation
655   *  on the training data (use when k &gt; 1)</pre>
656   *
657   * <pre> -A
658   *  The nearest neighbour search algorithm to use (default: weka.core.neighboursearch.LinearNNSearch).
659   * </pre>
660   *
661   <!-- options-end -->
662   *
663   * @param options the list of options as an array of strings
664   * @throws Exception if an option is not supported
665   */
666  public void setOptions(String[] options) throws Exception {
667   
668    String knnString = Utils.getOption('K', options);
669    if (knnString.length() != 0) {
670      setKNN(Integer.parseInt(knnString));
671    } else {
672      setKNN(1);
673    }
674    String windowString = Utils.getOption('W', options);
675    if (windowString.length() != 0) {
676      setWindowSize(Integer.parseInt(windowString));
677    } else {
678      setWindowSize(0);
679    }
680    if (Utils.getFlag('I', options)) {
681      setDistanceWeighting(new SelectedTag(WEIGHT_INVERSE, TAGS_WEIGHTING));
682    } else if (Utils.getFlag('F', options)) {
683      setDistanceWeighting(new SelectedTag(WEIGHT_SIMILARITY, TAGS_WEIGHTING));
684    } else {
685      setDistanceWeighting(new SelectedTag(WEIGHT_NONE, TAGS_WEIGHTING));
686    }
687    setCrossValidate(Utils.getFlag('X', options));
688    setMeanSquared(Utils.getFlag('E', options));
689
690    String nnSearchClass = Utils.getOption('A', options);
691    if(nnSearchClass.length() != 0) {
692      String nnSearchClassSpec[] = Utils.splitOptions(nnSearchClass);
693      if(nnSearchClassSpec.length == 0) { 
694        throw new Exception("Invalid NearestNeighbourSearch algorithm " +
695                            "specification string."); 
696      }
697      String className = nnSearchClassSpec[0];
698      nnSearchClassSpec[0] = "";
699
700      setNearestNeighbourSearchAlgorithm( (NearestNeighbourSearch)
701                  Utils.forName( NearestNeighbourSearch.class, 
702                                 className, 
703                                 nnSearchClassSpec)
704                                        );
705    }
706    else 
707      this.setNearestNeighbourSearchAlgorithm(new LinearNNSearch());
708   
709    Utils.checkForRemainingOptions(options);
710  }
711
712  /**
713   * Gets the current settings of IBk.
714   *
715   * @return an array of strings suitable for passing to setOptions()
716   */
717  public String [] getOptions() {
718
719    String [] options = new String [11];
720    int current = 0;
721    options[current++] = "-K"; options[current++] = "" + getKNN();
722    options[current++] = "-W"; options[current++] = "" + m_WindowSize;
723    if (getCrossValidate()) {
724      options[current++] = "-X";
725    }
726    if (getMeanSquared()) {
727      options[current++] = "-E";
728    }
729    if (m_DistanceWeighting == WEIGHT_INVERSE) {
730      options[current++] = "-I";
731    } else if (m_DistanceWeighting == WEIGHT_SIMILARITY) {
732      options[current++] = "-F";
733    }
734
735    options[current++] = "-A";
736    options[current++] = m_NNSearch.getClass().getName()+" "+Utils.joinOptions(m_NNSearch.getOptions()); 
737   
738    while (current < options.length) {
739      options[current++] = "";
740    }
741   
742    return options;
743  }
744
745  /**
746   * Returns an enumeration of the additional measure names
747   * produced by the neighbour search algorithm, plus the chosen K in case
748   * cross-validation is enabled.
749   *
750   * @return an enumeration of the measure names
751   */
752  public Enumeration enumerateMeasures() {
753    if (m_CrossValidate) {
754      Enumeration enm = m_NNSearch.enumerateMeasures();
755      Vector measures = new Vector();
756      while (enm.hasMoreElements())
757        measures.add(enm.nextElement());
758      measures.add("measureKNN");
759      return measures.elements();
760    }
761    else {
762      return m_NNSearch.enumerateMeasures();
763    }
764  }
765 
766  /**
767   * Returns the value of the named measure from the
768   * neighbour search algorithm, plus the chosen K in case
769   * cross-validation is enabled.
770   *
771   * @param additionalMeasureName the name of the measure to query for its value
772   * @return the value of the named measure
773   * @throws IllegalArgumentException if the named measure is not supported
774   */
775  public double getMeasure(String additionalMeasureName) {
776    if (additionalMeasureName.equals("measureKNN"))
777      return m_kNN;
778    else
779      return m_NNSearch.getMeasure(additionalMeasureName);
780  }
781 
782 
783  /**
784   * Returns a description of this classifier.
785   *
786   * @return a description of this classifier as a string.
787   */
788  public String toString() {
789
790    if (m_Train == null) {
791      return "IBk: No model built yet.";
792    }
793
794    if (!m_kNNValid && m_CrossValidate) {
795      crossValidate();
796    }
797
798    String result = "IB1 instance-based classifier\n" +
799      "using " + m_kNN;
800
801    switch (m_DistanceWeighting) {
802    case WEIGHT_INVERSE:
803      result += " inverse-distance-weighted";
804      break;
805    case WEIGHT_SIMILARITY:
806      result += " similarity-weighted";
807      break;
808    }
809    result += " nearest neighbour(s) for classification\n";
810
811    if (m_WindowSize != 0) {
812      result += "using a maximum of " 
813        + m_WindowSize + " (windowed) training instances\n";
814    }
815    return result;
816  }
817
818  /**
819   * Initialise scheme variables.
820   */
821  protected void init() {
822
823    setKNN(1);
824    m_WindowSize = 0;
825    m_DistanceWeighting = WEIGHT_NONE;
826    m_CrossValidate = false;
827    m_MeanSquared = false;
828  }
829 
830  /**
831   * Turn the list of nearest neighbors into a probability distribution.
832   *
833   * @param neighbours the list of nearest neighboring instances
834   * @param distances the distances of the neighbors
835   * @return the probability distribution
836   * @throws Exception if computation goes wrong or has no class attribute
837   */
838  protected double [] makeDistribution(Instances neighbours, double[] distances)
839    throws Exception {
840
841    double total = 0, weight;
842    double [] distribution = new double [m_NumClasses];
843   
844    // Set up a correction to the estimator
845    if (m_ClassType == Attribute.NOMINAL) {
846      for(int i = 0; i < m_NumClasses; i++) {
847        distribution[i] = 1.0 / Math.max(1,m_Train.numInstances());
848      }
849      total = (double)m_NumClasses / Math.max(1,m_Train.numInstances());
850    }
851
852    for(int i=0; i < neighbours.numInstances(); i++) {
853      // Collect class counts
854      Instance current = neighbours.instance(i);
855      distances[i] = distances[i]*distances[i];
856      distances[i] = Math.sqrt(distances[i]/m_NumAttributesUsed);
857      switch (m_DistanceWeighting) {
858        case WEIGHT_INVERSE:
859          weight = 1.0 / (distances[i] + 0.001); // to avoid div by zero
860          break;
861        case WEIGHT_SIMILARITY:
862          weight = 1.0 - distances[i];
863          break;
864        default:                                 // WEIGHT_NONE:
865          weight = 1.0;
866          break;
867      }
868      weight *= current.weight();
869      try {
870        switch (m_ClassType) {
871          case Attribute.NOMINAL:
872            distribution[(int)current.classValue()] += weight;
873            break;
874          case Attribute.NUMERIC:
875            distribution[0] += current.classValue() * weight;
876            break;
877        }
878      } catch (Exception ex) {
879        throw new Error("Data has no class attribute!");
880      }
881      total += weight;     
882    }
883
884    // Normalise distribution
885    if (total > 0) {
886      Utils.normalize(distribution, total);
887    }
888    return distribution;
889  }
890
891  /**
892   * Select the best value for k by hold-one-out cross-validation.
893   * If the class attribute is nominal, classification error is
894   * minimised. If the class attribute is numeric, mean absolute
895   * error is minimised
896   */
897  protected void crossValidate() {
898
899    try {
900      if (m_NNSearch instanceof weka.core.neighboursearch.CoverTree)
901        throw new Exception("CoverTree doesn't support hold-one-out "+
902                            "cross-validation. Use some other NN " +
903                            "method.");
904
905      double [] performanceStats = new double [m_kNNUpper];
906      double [] performanceStatsSq = new double [m_kNNUpper];
907
908      for(int i = 0; i < m_kNNUpper; i++) {
909        performanceStats[i] = 0;
910        performanceStatsSq[i] = 0;
911      }
912
913
914      m_kNN = m_kNNUpper;
915      Instance instance;
916      Instances neighbours;
917      double[] origDistances, convertedDistances;
918      for(int i = 0; i < m_Train.numInstances(); i++) {
919        if (m_Debug && (i % 50 == 0)) {
920          System.err.print("Cross validating "
921                           + i + "/" + m_Train.numInstances() + "\r");
922        }
923        instance = m_Train.instance(i);
924        neighbours = m_NNSearch.kNearestNeighbours(instance, m_kNN);
925        origDistances = m_NNSearch.getDistances();
926       
927        for(int j = m_kNNUpper - 1; j >= 0; j--) {
928          // Update the performance stats
929          convertedDistances = new double[origDistances.length];
930          System.arraycopy(origDistances, 0, 
931                           convertedDistances, 0, origDistances.length);
932          double [] distribution = makeDistribution(neighbours, 
933                                                    convertedDistances);
934          double thisPrediction = Utils.maxIndex(distribution);
935          if (m_Train.classAttribute().isNumeric()) {
936            thisPrediction = distribution[0];
937            double err = thisPrediction - instance.classValue();
938            performanceStatsSq[j] += err * err;   // Squared error
939            performanceStats[j] += Math.abs(err); // Absolute error
940          } else {
941            if (thisPrediction != instance.classValue()) {
942              performanceStats[j] ++;             // Classification error
943            }
944          }
945          if (j >= 1) {
946            neighbours = pruneToK(neighbours, convertedDistances, j);
947          }
948        }
949      }
950
951      // Display the results of the cross-validation
952      for(int i = 0; i < m_kNNUpper; i++) {
953        if (m_Debug) {
954          System.err.print("Hold-one-out performance of " + (i + 1)
955                           + " neighbors " );
956        }
957        if (m_Train.classAttribute().isNumeric()) {
958          if (m_Debug) {
959            if (m_MeanSquared) {
960              System.err.println("(RMSE) = "
961                                 + Math.sqrt(performanceStatsSq[i]
962                                             / m_Train.numInstances()));
963            } else {
964              System.err.println("(MAE) = "
965                                 + performanceStats[i]
966                                 / m_Train.numInstances());
967            }
968          }
969        } else {
970          if (m_Debug) {
971            System.err.println("(%ERR) = "
972                               + 100.0 * performanceStats[i]
973                               / m_Train.numInstances());
974          }
975        }
976      }
977
978
979      // Check through the performance stats and select the best
980      // k value (or the lowest k if more than one best)
981      double [] searchStats = performanceStats;
982      if (m_Train.classAttribute().isNumeric() && m_MeanSquared) {
983        searchStats = performanceStatsSq;
984      }
985      double bestPerformance = Double.NaN;
986      int bestK = 1;
987      for(int i = 0; i < m_kNNUpper; i++) {
988        if (Double.isNaN(bestPerformance)
989            || (bestPerformance > searchStats[i])) {
990          bestPerformance = searchStats[i];
991          bestK = i + 1;
992        }
993      }
994      m_kNN = bestK;
995      if (m_Debug) {
996        System.err.println("Selected k = " + bestK);
997      }
998     
999      m_kNNValid = true;
1000    } catch (Exception ex) {
1001      throw new Error("Couldn't optimize by cross-validation: "
1002                      +ex.getMessage());
1003    }
1004  }
1005 
1006  /**
1007   * Prunes the list to contain the k nearest neighbors. If there are
1008   * multiple neighbors at the k'th distance, all will be kept.
1009   *
1010   * @param neighbours the neighbour instances.
1011   * @param distances the distances of the neighbours from target instance.
1012   * @param k the number of neighbors to keep.
1013   * @return the pruned neighbours.
1014   */
1015  public Instances pruneToK(Instances neighbours, double[] distances, int k) {
1016   
1017    if(neighbours==null || distances==null || neighbours.numInstances()==0) {
1018      return null;
1019    }
1020    if (k < 1) {
1021      k = 1;
1022    }
1023   
1024    int currentK = 0;
1025    double currentDist;
1026    for(int i=0; i < neighbours.numInstances(); i++) {
1027      currentK++;
1028      currentDist = distances[i];
1029      if(currentK>k && currentDist!=distances[i-1]) {
1030        currentK--;
1031        neighbours = new Instances(neighbours, 0, currentK);
1032        break;
1033      }
1034    }
1035
1036    return neighbours;
1037  }
1038 
1039  /**
1040   * Returns the revision string.
1041   *
1042   * @return            the revision
1043   */
1044  public String getRevision() {
1045    return RevisionUtils.extract("$Revision: 5928 $");
1046  }
1047 
1048  /**
1049   * Main method for testing this class.
1050   *
1051   * @param argv should contain command line options (see setOptions)
1052   */
1053  public static void main(String [] argv) {
1054    runClassifier(new IBk(), argv);
1055  }
1056}
Note: See TracBrowser for help on using the repository browser.