1 | /* |
---|
2 | * This program is free software; you can redistribute it and/or modify |
---|
3 | * it under the terms of the GNU General Public License as published by |
---|
4 | * the Free Software Foundation; either version 2 of the License, or |
---|
5 | * (at your option) any later version. |
---|
6 | * |
---|
7 | * This program is distributed in the hope that it will be useful, |
---|
8 | * but WITHOUT ANY WARRANTY; without even the implied warranty of |
---|
9 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the |
---|
10 | * GNU General Public License for more details. |
---|
11 | * |
---|
12 | * You should have received a copy of the GNU General Public License |
---|
13 | * along with this program; if not, write to the Free Software |
---|
14 | * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA. |
---|
15 | */ |
---|
16 | |
---|
17 | /* |
---|
18 | * IBk.java |
---|
19 | * Copyright (C) 1999 University of Waikato, Hamilton, New Zealand |
---|
20 | * |
---|
21 | */ |
---|
22 | |
---|
23 | package weka.classifiers.lazy; |
---|
24 | |
---|
25 | import weka.classifiers.Classifier; |
---|
26 | import weka.classifiers.AbstractClassifier; |
---|
27 | import weka.classifiers.UpdateableClassifier; |
---|
28 | import weka.core.Attribute; |
---|
29 | import weka.core.Capabilities; |
---|
30 | import weka.core.Instance; |
---|
31 | import weka.core.Instances; |
---|
32 | import weka.core.neighboursearch.LinearNNSearch; |
---|
33 | import weka.core.neighboursearch.NearestNeighbourSearch; |
---|
34 | import weka.core.Option; |
---|
35 | import weka.core.OptionHandler; |
---|
36 | import weka.core.RevisionUtils; |
---|
37 | import weka.core.SelectedTag; |
---|
38 | import weka.core.Tag; |
---|
39 | import weka.core.TechnicalInformation; |
---|
40 | import weka.core.TechnicalInformationHandler; |
---|
41 | import weka.core.Utils; |
---|
42 | import weka.core.WeightedInstancesHandler; |
---|
43 | import weka.core.Capabilities.Capability; |
---|
44 | import weka.core.TechnicalInformation.Field; |
---|
45 | import weka.core.TechnicalInformation.Type; |
---|
46 | import weka.core.AdditionalMeasureProducer; |
---|
47 | |
---|
48 | import java.util.Enumeration; |
---|
49 | import java.util.Vector; |
---|
50 | |
---|
51 | /** |
---|
52 | <!-- globalinfo-start --> |
---|
53 | * K-nearest neighbours classifier. Can select appropriate value of K based on cross-validation. Can also do distance weighting.<br/> |
---|
54 | * <br/> |
---|
55 | * For more information, see<br/> |
---|
56 | * <br/> |
---|
57 | * D. Aha, D. Kibler (1991). Instance-based learning algorithms. Machine Learning. 6:37-66. |
---|
58 | * <p/> |
---|
59 | <!-- globalinfo-end --> |
---|
60 | * |
---|
61 | <!-- technical-bibtex-start --> |
---|
62 | * BibTeX: |
---|
63 | * <pre> |
---|
64 | * @article{Aha1991, |
---|
65 | * author = {D. Aha and D. Kibler}, |
---|
66 | * journal = {Machine Learning}, |
---|
67 | * pages = {37-66}, |
---|
68 | * title = {Instance-based learning algorithms}, |
---|
69 | * volume = {6}, |
---|
70 | * year = {1991} |
---|
71 | * } |
---|
72 | * </pre> |
---|
73 | * <p/> |
---|
74 | <!-- technical-bibtex-end --> |
---|
75 | * |
---|
76 | <!-- options-start --> |
---|
77 | * Valid options are: <p/> |
---|
78 | * |
---|
79 | * <pre> -I |
---|
80 | * Weight neighbours by the inverse of their distance |
---|
81 | * (use when k > 1)</pre> |
---|
82 | * |
---|
83 | * <pre> -F |
---|
84 | * Weight neighbours by 1 - their distance |
---|
85 | * (use when k > 1)</pre> |
---|
86 | * |
---|
87 | * <pre> -K <number of neighbors> |
---|
88 | * Number of nearest neighbours (k) used in classification. |
---|
89 | * (Default = 1)</pre> |
---|
90 | * |
---|
91 | * <pre> -E |
---|
92 | * Minimise mean squared error rather than mean absolute |
---|
93 | * error when using -X option with numeric prediction.</pre> |
---|
94 | * |
---|
95 | * <pre> -W <window size> |
---|
96 | * Maximum number of training instances maintained. |
---|
97 | * Training instances are dropped FIFO. (Default = no window)</pre> |
---|
98 | * |
---|
99 | * <pre> -X |
---|
100 | * Select the number of nearest neighbours between 1 |
---|
101 | * and the k value specified using hold-one-out evaluation |
---|
102 | * on the training data (use when k > 1)</pre> |
---|
103 | * |
---|
104 | * <pre> -A |
---|
105 | * The nearest neighbour search algorithm to use (default: weka.core.neighboursearch.LinearNNSearch). |
---|
106 | * </pre> |
---|
107 | * |
---|
108 | <!-- options-end --> |
---|
109 | * |
---|
110 | * @author Stuart Inglis (singlis@cs.waikato.ac.nz) |
---|
111 | * @author Len Trigg (trigg@cs.waikato.ac.nz) |
---|
112 | * @author Eibe Frank (eibe@cs.waikato.ac.nz) |
---|
113 | * @version $Revision: 5928 $ |
---|
114 | */ |
---|
115 | public class IBk |
---|
116 | extends AbstractClassifier |
---|
117 | implements OptionHandler, UpdateableClassifier, WeightedInstancesHandler, |
---|
118 | TechnicalInformationHandler, AdditionalMeasureProducer { |
---|
119 | |
---|
  /** For serialization; keep stable so previously serialized models stay readable. */
  static final long serialVersionUID = -3080186098777067172L;

  /** The training instances used for classification. */
  protected Instances m_Train;

  /** The number of class values (or 1 if predicting numeric). */
  protected int m_NumClasses;

  /** The class attribute type (e.g. Attribute.NOMINAL or Attribute.NUMERIC). */
  protected int m_ClassType;

  /** The number of neighbours to use for classification (currently). */
  protected int m_kNN;

  /**
   * The value of kNN provided by the user. This may differ from
   * m_kNN if cross-validation is being used to pick a smaller k.
   */
  protected int m_kNNUpper;

  /**
   * Whether the value of k selected by cross validation has
   * been invalidated by a change in the training instances.
   */
  protected boolean m_kNNValid;

  /**
   * The maximum number of training instances allowed. When
   * this limit is reached, old training instances are removed,
   * so the training data is "windowed". Set to 0 for unlimited
   * numbers of instances.
   */
  protected int m_WindowSize;

  /** Whether the neighbours should be distance-weighted (one of the WEIGHT_* constants). */
  protected int m_DistanceWeighting;

  /** Whether to select k by hold-one-out cross validation. */
  protected boolean m_CrossValidate;

  /**
   * Whether to minimise mean squared error rather than mean absolute
   * error when cross-validating on numeric prediction tasks.
   */
  protected boolean m_MeanSquared;

  /** No distance weighting. */
  public static final int WEIGHT_NONE = 1;
  /** Weight neighbours by 1/distance. */
  public static final int WEIGHT_INVERSE = 2;
  /** Weight neighbours by 1-distance. */
  public static final int WEIGHT_SIMILARITY = 4;
  /** Possible instance weighting methods (for use with SelectedTag). */
  public static final Tag [] TAGS_WEIGHTING = {
    new Tag(WEIGHT_NONE, "No distance weighting"),
    new Tag(WEIGHT_INVERSE, "Weight by 1/distance"),
    new Tag(WEIGHT_SIMILARITY, "Weight by 1-distance")
  };

  /** For nearest-neighbour search; defaults to a linear scan. */
  protected NearestNeighbourSearch m_NNSearch = new LinearNNSearch();

  /** The number of attributes that contribute to a prediction. */
  protected double m_NumAttributesUsed;
---|
185 | |
---|
186 | /** |
---|
187 | * IBk classifier. Simple instance-based learner that uses the class |
---|
188 | * of the nearest k training instances for the class of the test |
---|
189 | * instances. |
---|
190 | * |
---|
191 | * @param k the number of nearest neighbors to use for prediction |
---|
192 | */ |
---|
193 | public IBk(int k) { |
---|
194 | |
---|
195 | init(); |
---|
196 | setKNN(k); |
---|
197 | } |
---|
198 | |
---|
199 | /** |
---|
200 | * IB1 classifer. Instance-based learner. Predicts the class of the |
---|
201 | * single nearest training instance for each test instance. |
---|
202 | */ |
---|
203 | public IBk() { |
---|
204 | |
---|
205 | init(); |
---|
206 | } |
---|
207 | |
---|
208 | /** |
---|
209 | * Returns a string describing classifier. |
---|
210 | * @return a description suitable for |
---|
211 | * displaying in the explorer/experimenter gui |
---|
212 | */ |
---|
213 | public String globalInfo() { |
---|
214 | |
---|
215 | return "K-nearest neighbours classifier. Can " |
---|
216 | + "select appropriate value of K based on cross-validation. Can also do " |
---|
217 | + "distance weighting.\n\n" |
---|
218 | + "For more information, see\n\n" |
---|
219 | + getTechnicalInformation().toString(); |
---|
220 | } |
---|
221 | |
---|
222 | /** |
---|
223 | * Returns an instance of a TechnicalInformation object, containing |
---|
224 | * detailed information about the technical background of this class, |
---|
225 | * e.g., paper reference or book this class is based on. |
---|
226 | * |
---|
227 | * @return the technical information about this class |
---|
228 | */ |
---|
229 | public TechnicalInformation getTechnicalInformation() { |
---|
230 | TechnicalInformation result; |
---|
231 | |
---|
232 | result = new TechnicalInformation(Type.ARTICLE); |
---|
233 | result.setValue(Field.AUTHOR, "D. Aha and D. Kibler"); |
---|
234 | result.setValue(Field.YEAR, "1991"); |
---|
235 | result.setValue(Field.TITLE, "Instance-based learning algorithms"); |
---|
236 | result.setValue(Field.JOURNAL, "Machine Learning"); |
---|
237 | result.setValue(Field.VOLUME, "6"); |
---|
238 | result.setValue(Field.PAGES, "37-66"); |
---|
239 | |
---|
240 | return result; |
---|
241 | } |
---|
242 | |
---|
243 | /** |
---|
244 | * Returns the tip text for this property. |
---|
245 | * @return tip text for this property suitable for |
---|
246 | * displaying in the explorer/experimenter gui |
---|
247 | */ |
---|
248 | public String KNNTipText() { |
---|
249 | return "The number of neighbours to use."; |
---|
250 | } |
---|
251 | |
---|
252 | /** |
---|
253 | * Set the number of neighbours the learner is to use. |
---|
254 | * |
---|
255 | * @param k the number of neighbours. |
---|
256 | */ |
---|
257 | public void setKNN(int k) { |
---|
258 | m_kNN = k; |
---|
259 | m_kNNUpper = k; |
---|
260 | m_kNNValid = false; |
---|
261 | } |
---|
262 | |
---|
263 | /** |
---|
264 | * Gets the number of neighbours the learner will use. |
---|
265 | * |
---|
266 | * @return the number of neighbours. |
---|
267 | */ |
---|
268 | public int getKNN() { |
---|
269 | |
---|
270 | return m_kNN; |
---|
271 | } |
---|
272 | |
---|
273 | /** |
---|
274 | * Returns the tip text for this property. |
---|
275 | * @return tip text for this property suitable for |
---|
276 | * displaying in the explorer/experimenter gui |
---|
277 | */ |
---|
278 | public String windowSizeTipText() { |
---|
279 | return "Gets the maximum number of instances allowed in the training " + |
---|
280 | "pool. The addition of new instances above this value will result " + |
---|
281 | "in old instances being removed. A value of 0 signifies no limit " + |
---|
282 | "to the number of training instances."; |
---|
283 | } |
---|
284 | |
---|
285 | /** |
---|
286 | * Gets the maximum number of instances allowed in the training |
---|
287 | * pool. The addition of new instances above this value will result |
---|
288 | * in old instances being removed. A value of 0 signifies no limit |
---|
289 | * to the number of training instances. |
---|
290 | * |
---|
291 | * @return Value of WindowSize. |
---|
292 | */ |
---|
293 | public int getWindowSize() { |
---|
294 | |
---|
295 | return m_WindowSize; |
---|
296 | } |
---|
297 | |
---|
298 | /** |
---|
299 | * Sets the maximum number of instances allowed in the training |
---|
300 | * pool. The addition of new instances above this value will result |
---|
301 | * in old instances being removed. A value of 0 signifies no limit |
---|
302 | * to the number of training instances. |
---|
303 | * |
---|
304 | * @param newWindowSize Value to assign to WindowSize. |
---|
305 | */ |
---|
306 | public void setWindowSize(int newWindowSize) { |
---|
307 | |
---|
308 | m_WindowSize = newWindowSize; |
---|
309 | } |
---|
310 | |
---|
311 | /** |
---|
312 | * Returns the tip text for this property. |
---|
313 | * @return tip text for this property suitable for |
---|
314 | * displaying in the explorer/experimenter gui |
---|
315 | */ |
---|
316 | public String distanceWeightingTipText() { |
---|
317 | |
---|
318 | return "Gets the distance weighting method used."; |
---|
319 | } |
---|
320 | |
---|
321 | /** |
---|
322 | * Gets the distance weighting method used. Will be one of |
---|
323 | * WEIGHT_NONE, WEIGHT_INVERSE, or WEIGHT_SIMILARITY |
---|
324 | * |
---|
325 | * @return the distance weighting method used. |
---|
326 | */ |
---|
327 | public SelectedTag getDistanceWeighting() { |
---|
328 | |
---|
329 | return new SelectedTag(m_DistanceWeighting, TAGS_WEIGHTING); |
---|
330 | } |
---|
331 | |
---|
332 | /** |
---|
333 | * Sets the distance weighting method used. Values other than |
---|
334 | * WEIGHT_NONE, WEIGHT_INVERSE, or WEIGHT_SIMILARITY will be ignored. |
---|
335 | * |
---|
336 | * @param newMethod the distance weighting method to use |
---|
337 | */ |
---|
338 | public void setDistanceWeighting(SelectedTag newMethod) { |
---|
339 | |
---|
340 | if (newMethod.getTags() == TAGS_WEIGHTING) { |
---|
341 | m_DistanceWeighting = newMethod.getSelectedTag().getID(); |
---|
342 | } |
---|
343 | } |
---|
344 | |
---|
345 | /** |
---|
346 | * Returns the tip text for this property. |
---|
347 | * @return tip text for this property suitable for |
---|
348 | * displaying in the explorer/experimenter gui |
---|
349 | */ |
---|
350 | public String meanSquaredTipText() { |
---|
351 | |
---|
352 | return "Whether the mean squared error is used rather than mean " |
---|
353 | + "absolute error when doing cross-validation for regression problems."; |
---|
354 | } |
---|
355 | |
---|
356 | /** |
---|
357 | * Gets whether the mean squared error is used rather than mean |
---|
358 | * absolute error when doing cross-validation. |
---|
359 | * |
---|
360 | * @return true if so. |
---|
361 | */ |
---|
362 | public boolean getMeanSquared() { |
---|
363 | |
---|
364 | return m_MeanSquared; |
---|
365 | } |
---|
366 | |
---|
367 | /** |
---|
368 | * Sets whether the mean squared error is used rather than mean |
---|
369 | * absolute error when doing cross-validation. |
---|
370 | * |
---|
371 | * @param newMeanSquared true if so. |
---|
372 | */ |
---|
373 | public void setMeanSquared(boolean newMeanSquared) { |
---|
374 | |
---|
375 | m_MeanSquared = newMeanSquared; |
---|
376 | } |
---|
377 | |
---|
378 | /** |
---|
379 | * Returns the tip text for this property. |
---|
380 | * @return tip text for this property suitable for |
---|
381 | * displaying in the explorer/experimenter gui |
---|
382 | */ |
---|
383 | public String crossValidateTipText() { |
---|
384 | |
---|
385 | return "Whether hold-one-out cross-validation will be used " + |
---|
386 | "to select the best k value."; |
---|
387 | } |
---|
388 | |
---|
389 | /** |
---|
390 | * Gets whether hold-one-out cross-validation will be used |
---|
391 | * to select the best k value. |
---|
392 | * |
---|
393 | * @return true if cross-validation will be used. |
---|
394 | */ |
---|
395 | public boolean getCrossValidate() { |
---|
396 | |
---|
397 | return m_CrossValidate; |
---|
398 | } |
---|
399 | |
---|
400 | /** |
---|
401 | * Sets whether hold-one-out cross-validation will be used |
---|
402 | * to select the best k value. |
---|
403 | * |
---|
404 | * @param newCrossValidate true if cross-validation should be used. |
---|
405 | */ |
---|
406 | public void setCrossValidate(boolean newCrossValidate) { |
---|
407 | |
---|
408 | m_CrossValidate = newCrossValidate; |
---|
409 | } |
---|
410 | |
---|
411 | /** |
---|
412 | * Returns the tip text for this property. |
---|
413 | * @return tip text for this property suitable for |
---|
414 | * displaying in the explorer/experimenter gui |
---|
415 | */ |
---|
416 | public String nearestNeighbourSearchAlgorithmTipText() { |
---|
417 | return "The nearest neighbour search algorithm to use " + |
---|
418 | "(Default: weka.core.neighboursearch.LinearNNSearch)."; |
---|
419 | } |
---|
420 | |
---|
421 | /** |
---|
422 | * Returns the current nearestNeighbourSearch algorithm in use. |
---|
423 | * @return the NearestNeighbourSearch algorithm currently in use. |
---|
424 | */ |
---|
425 | public NearestNeighbourSearch getNearestNeighbourSearchAlgorithm() { |
---|
426 | return m_NNSearch; |
---|
427 | } |
---|
428 | |
---|
429 | /** |
---|
430 | * Sets the nearestNeighbourSearch algorithm to be used for finding nearest |
---|
431 | * neighbour(s). |
---|
432 | * @param nearestNeighbourSearchAlgorithm - The NearestNeighbourSearch class. |
---|
433 | */ |
---|
434 | public void setNearestNeighbourSearchAlgorithm(NearestNeighbourSearch nearestNeighbourSearchAlgorithm) { |
---|
435 | m_NNSearch = nearestNeighbourSearchAlgorithm; |
---|
436 | } |
---|
437 | |
---|
438 | /** |
---|
439 | * Get the number of training instances the classifier is currently using. |
---|
440 | * |
---|
441 | * @return the number of training instances the classifier is currently using |
---|
442 | */ |
---|
443 | public int getNumTraining() { |
---|
444 | |
---|
445 | return m_Train.numInstances(); |
---|
446 | } |
---|
447 | |
---|
448 | /** |
---|
449 | * Returns default capabilities of the classifier. |
---|
450 | * |
---|
451 | * @return the capabilities of this classifier |
---|
452 | */ |
---|
453 | public Capabilities getCapabilities() { |
---|
454 | Capabilities result = super.getCapabilities(); |
---|
455 | result.disableAll(); |
---|
456 | |
---|
457 | // attributes |
---|
458 | result.enable(Capability.NOMINAL_ATTRIBUTES); |
---|
459 | result.enable(Capability.NUMERIC_ATTRIBUTES); |
---|
460 | result.enable(Capability.DATE_ATTRIBUTES); |
---|
461 | result.enable(Capability.MISSING_VALUES); |
---|
462 | |
---|
463 | // class |
---|
464 | result.enable(Capability.NOMINAL_CLASS); |
---|
465 | result.enable(Capability.NUMERIC_CLASS); |
---|
466 | result.enable(Capability.DATE_CLASS); |
---|
467 | result.enable(Capability.MISSING_CLASS_VALUES); |
---|
468 | |
---|
469 | // instances |
---|
470 | result.setMinimumNumberInstances(0); |
---|
471 | |
---|
472 | return result; |
---|
473 | } |
---|
474 | |
---|
  /**
   * Generates the classifier. Being a lazy learner, this mostly stores
   * (a windowed copy of) the training data and primes the neighbour
   * search structure; the real work happens at prediction time.
   *
   * @param instances set of instances serving as training data
   * @throws Exception if the classifier has not been generated successfully
   */
  public void buildClassifier(Instances instances) throws Exception {

    // can classifier handle the data?
    getCapabilities().testWithFail(instances);

    // remove instances with missing class; copy first so the caller's
    // dataset is left untouched
    instances = new Instances(instances);
    instances.deleteWithMissingClass();

    m_NumClasses = instances.numClasses();
    m_ClassType = instances.classAttribute().type();
    m_Train = new Instances(instances, 0, instances.numInstances());

    // Throw away initial instances until within the specified window size
    // (keeps only the most recent m_WindowSize instances)
    if ((m_WindowSize > 0) && (instances.numInstances() > m_WindowSize)) {
      m_Train = new Instances(m_Train,
                              m_Train.numInstances()-m_WindowSize,
                              m_WindowSize);
    }

    // Count attributes usable for distance computation (everything except
    // the class attribute that is nominal or numeric)
    m_NumAttributesUsed = 0.0;
    for (int i = 0; i < m_Train.numAttributes(); i++) {
      if ((i != m_Train.classIndex()) &&
          (m_Train.attribute(i).isNominal() ||
           m_Train.attribute(i).isNumeric())) {
        m_NumAttributesUsed += 1.0;
      }
    }

    m_NNSearch.setInstances(m_Train);

    // Invalidate any currently cross-validation selected k
    m_kNNValid = false;
  }
---|
515 | |
---|
  /**
   * Adds the supplied instance to the training set (incremental learning),
   * pruning the oldest instances when a window size is in effect.
   *
   * @param instance the instance to add
   * @throws Exception if instance could not be incorporated
   * successfully (incompatible header)
   */
  public void updateClassifier(Instance instance) throws Exception {

    if (m_Train.equalHeaders(instance.dataset()) == false) {
      throw new Exception("Incompatible instance types\n" + m_Train.equalHeadersMsg(instance.dataset()));
    }
    if (instance.classIsMissing()) {
      // an instance without a class value carries no information for k-NN
      return;
    }

    m_Train.add(instance);
    m_NNSearch.update(instance);
    // training data changed, so any cross-validated k is stale
    m_kNNValid = false;
    if ((m_WindowSize > 0) && (m_Train.numInstances() > m_WindowSize)) {
      boolean deletedInstance=false;
      while (m_Train.numInstances() > m_WindowSize) {
        // drop the oldest instance (FIFO window)
        m_Train.delete(0);
        deletedInstance=true;
      }
      //rebuild datastructure KDTree currently can't delete
      if(deletedInstance==true)
        m_NNSearch.setInstances(m_Train);
    }
  }
---|
546 | |
---|
547 | /** |
---|
548 | * Calculates the class membership probabilities for the given test instance. |
---|
549 | * |
---|
550 | * @param instance the instance to be classified |
---|
551 | * @return predicted class probability distribution |
---|
552 | * @throws Exception if an error occurred during the prediction |
---|
553 | */ |
---|
554 | public double [] distributionForInstance(Instance instance) throws Exception { |
---|
555 | |
---|
556 | if (m_Train.numInstances() == 0) { |
---|
557 | throw new Exception("No training instances!"); |
---|
558 | } |
---|
559 | if ((m_WindowSize > 0) && (m_Train.numInstances() > m_WindowSize)) { |
---|
560 | m_kNNValid = false; |
---|
561 | boolean deletedInstance=false; |
---|
562 | while (m_Train.numInstances() > m_WindowSize) { |
---|
563 | m_Train.delete(0); |
---|
564 | } |
---|
565 | //rebuild datastructure KDTree currently can't delete |
---|
566 | if(deletedInstance==true) |
---|
567 | m_NNSearch.setInstances(m_Train); |
---|
568 | } |
---|
569 | |
---|
570 | // Select k by cross validation |
---|
571 | if (!m_kNNValid && (m_CrossValidate) && (m_kNNUpper >= 1)) { |
---|
572 | crossValidate(); |
---|
573 | } |
---|
574 | |
---|
575 | m_NNSearch.addInstanceInfo(instance); |
---|
576 | |
---|
577 | Instances neighbours = m_NNSearch.kNearestNeighbours(instance, m_kNN); |
---|
578 | double [] distances = m_NNSearch.getDistances(); |
---|
579 | double [] distribution = makeDistribution( neighbours, distances ); |
---|
580 | |
---|
581 | return distribution; |
---|
582 | } |
---|
583 | |
---|
  /**
   * Returns an enumeration describing the available options.
   * Note: the listing order below is what users see in the help output,
   * so keep it stable.
   *
   * @return an enumeration of all the available options.
   */
  public Enumeration listOptions() {

    Vector newVector = new Vector(8);

    newVector.addElement(new Option(
	      "\tWeight neighbours by the inverse of their distance\n"+
	      "\t(use when k > 1)",
	      "I", 0, "-I"));
    newVector.addElement(new Option(
	      "\tWeight neighbours by 1 - their distance\n"+
	      "\t(use when k > 1)",
	      "F", 0, "-F"));
    newVector.addElement(new Option(
	      "\tNumber of nearest neighbours (k) used in classification.\n"+
	      "\t(Default = 1)",
	      "K", 1,"-K <number of neighbors>"));
    newVector.addElement(new Option(
          "\tMinimise mean squared error rather than mean absolute\n"+
	      "\terror when using -X option with numeric prediction.",
	      "E", 0,"-E"));
    newVector.addElement(new Option(
          "\tMaximum number of training instances maintained.\n"+
	      "\tTraining instances are dropped FIFO. (Default = no window)",
	      "W", 1,"-W <window size>"));
    newVector.addElement(new Option(
	      "\tSelect the number of nearest neighbours between 1\n"+
	      "\tand the k value specified using hold-one-out evaluation\n"+
	      "\ton the training data (use when k > 1)",
	      "X", 0,"-X"));
    newVector.addElement(new Option(
          "\tThe nearest neighbour search algorithm to use "+
          "(default: weka.core.neighboursearch.LinearNNSearch).\n",
          "A", 0, "-A"));

    return newVector.elements();
  }
---|
625 | |
---|
  /**
   * Parses a given list of options. <p/>
   *
   * Valid options are: <p/>
   *
   * -I: weight neighbours by the inverse of their distance (use when k &gt; 1)<br/>
   * -F: weight neighbours by 1 - their distance (use when k &gt; 1)<br/>
   * -K &lt;number of neighbors&gt;: number of nearest neighbours (default 1)<br/>
   * -E: minimise mean squared error rather than mean absolute error
   *     when using -X with numeric prediction<br/>
   * -W &lt;window size&gt;: maximum number of training instances maintained,
   *     dropped FIFO (default: no window)<br/>
   * -X: select k between 1 and the specified value using hold-one-out
   *     evaluation on the training data (use when k &gt; 1)<br/>
   * -A: the nearest neighbour search algorithm to use
   *     (default: weka.core.neighboursearch.LinearNNSearch)
   *
   * Note: Utils.getOption/getFlag consume matched entries from the options
   * array, so the parsing order here matters.
   *
   * @param options the list of options as an array of strings
   * @throws Exception if an option is not supported
   */
  public void setOptions(String[] options) throws Exception {

    String knnString = Utils.getOption('K', options);
    if (knnString.length() != 0) {
      setKNN(Integer.parseInt(knnString));
    } else {
      setKNN(1);
    }
    String windowString = Utils.getOption('W', options);
    if (windowString.length() != 0) {
      setWindowSize(Integer.parseInt(windowString));
    } else {
      setWindowSize(0);
    }
    // -I and -F are mutually exclusive; -I wins if both are given
    if (Utils.getFlag('I', options)) {
      setDistanceWeighting(new SelectedTag(WEIGHT_INVERSE, TAGS_WEIGHTING));
    } else if (Utils.getFlag('F', options)) {
      setDistanceWeighting(new SelectedTag(WEIGHT_SIMILARITY, TAGS_WEIGHTING));
    } else {
      setDistanceWeighting(new SelectedTag(WEIGHT_NONE, TAGS_WEIGHTING));
    }
    setCrossValidate(Utils.getFlag('X', options));
    setMeanSquared(Utils.getFlag('E', options));

    // -A takes a full class specification with nested options,
    // e.g. "weka.core.neighboursearch.KDTree -- ..."
    String nnSearchClass = Utils.getOption('A', options);
    if(nnSearchClass.length() != 0) {
      String nnSearchClassSpec[] = Utils.splitOptions(nnSearchClass);
      if(nnSearchClassSpec.length == 0) {
        throw new Exception("Invalid NearestNeighbourSearch algorithm " +
                            "specification string.");
      }
      String className = nnSearchClassSpec[0];
      // blank out the class name so only the remaining tokens are
      // forwarded to the search algorithm as its own options
      nnSearchClassSpec[0] = "";

      setNearestNeighbourSearchAlgorithm( (NearestNeighbourSearch)
                  Utils.forName( NearestNeighbourSearch.class,
                                 className,
                                 nnSearchClassSpec)
                                        );
    }
    else
      this.setNearestNeighbourSearchAlgorithm(new LinearNNSearch());

    Utils.checkForRemainingOptions(options);
  }
---|
711 | |
---|
712 | /** |
---|
713 | * Gets the current settings of IBk. |
---|
714 | * |
---|
715 | * @return an array of strings suitable for passing to setOptions() |
---|
716 | */ |
---|
717 | public String [] getOptions() { |
---|
718 | |
---|
719 | String [] options = new String [11]; |
---|
720 | int current = 0; |
---|
721 | options[current++] = "-K"; options[current++] = "" + getKNN(); |
---|
722 | options[current++] = "-W"; options[current++] = "" + m_WindowSize; |
---|
723 | if (getCrossValidate()) { |
---|
724 | options[current++] = "-X"; |
---|
725 | } |
---|
726 | if (getMeanSquared()) { |
---|
727 | options[current++] = "-E"; |
---|
728 | } |
---|
729 | if (m_DistanceWeighting == WEIGHT_INVERSE) { |
---|
730 | options[current++] = "-I"; |
---|
731 | } else if (m_DistanceWeighting == WEIGHT_SIMILARITY) { |
---|
732 | options[current++] = "-F"; |
---|
733 | } |
---|
734 | |
---|
735 | options[current++] = "-A"; |
---|
736 | options[current++] = m_NNSearch.getClass().getName()+" "+Utils.joinOptions(m_NNSearch.getOptions()); |
---|
737 | |
---|
738 | while (current < options.length) { |
---|
739 | options[current++] = ""; |
---|
740 | } |
---|
741 | |
---|
742 | return options; |
---|
743 | } |
---|
744 | |
---|
745 | /** |
---|
746 | * Returns an enumeration of the additional measure names |
---|
747 | * produced by the neighbour search algorithm, plus the chosen K in case |
---|
748 | * cross-validation is enabled. |
---|
749 | * |
---|
750 | * @return an enumeration of the measure names |
---|
751 | */ |
---|
752 | public Enumeration enumerateMeasures() { |
---|
753 | if (m_CrossValidate) { |
---|
754 | Enumeration enm = m_NNSearch.enumerateMeasures(); |
---|
755 | Vector measures = new Vector(); |
---|
756 | while (enm.hasMoreElements()) |
---|
757 | measures.add(enm.nextElement()); |
---|
758 | measures.add("measureKNN"); |
---|
759 | return measures.elements(); |
---|
760 | } |
---|
761 | else { |
---|
762 | return m_NNSearch.enumerateMeasures(); |
---|
763 | } |
---|
764 | } |
---|
765 | |
---|
766 | /** |
---|
767 | * Returns the value of the named measure from the |
---|
768 | * neighbour search algorithm, plus the chosen K in case |
---|
769 | * cross-validation is enabled. |
---|
770 | * |
---|
771 | * @param additionalMeasureName the name of the measure to query for its value |
---|
772 | * @return the value of the named measure |
---|
773 | * @throws IllegalArgumentException if the named measure is not supported |
---|
774 | */ |
---|
775 | public double getMeasure(String additionalMeasureName) { |
---|
776 | if (additionalMeasureName.equals("measureKNN")) |
---|
777 | return m_kNN; |
---|
778 | else |
---|
779 | return m_NNSearch.getMeasure(additionalMeasureName); |
---|
780 | } |
---|
781 | |
---|
782 | |
---|
783 | /** |
---|
784 | * Returns a description of this classifier. |
---|
785 | * |
---|
786 | * @return a description of this classifier as a string. |
---|
787 | */ |
---|
788 | public String toString() { |
---|
789 | |
---|
790 | if (m_Train == null) { |
---|
791 | return "IBk: No model built yet."; |
---|
792 | } |
---|
793 | |
---|
794 | if (!m_kNNValid && m_CrossValidate) { |
---|
795 | crossValidate(); |
---|
796 | } |
---|
797 | |
---|
798 | String result = "IB1 instance-based classifier\n" + |
---|
799 | "using " + m_kNN; |
---|
800 | |
---|
801 | switch (m_DistanceWeighting) { |
---|
802 | case WEIGHT_INVERSE: |
---|
803 | result += " inverse-distance-weighted"; |
---|
804 | break; |
---|
805 | case WEIGHT_SIMILARITY: |
---|
806 | result += " similarity-weighted"; |
---|
807 | break; |
---|
808 | } |
---|
809 | result += " nearest neighbour(s) for classification\n"; |
---|
810 | |
---|
811 | if (m_WindowSize != 0) { |
---|
812 | result += "using a maximum of " |
---|
813 | + m_WindowSize + " (windowed) training instances\n"; |
---|
814 | } |
---|
815 | return result; |
---|
816 | } |
---|
817 | |
---|
818 | /** |
---|
819 | * Initialise scheme variables. |
---|
820 | */ |
---|
821 | protected void init() { |
---|
822 | |
---|
823 | setKNN(1); |
---|
824 | m_WindowSize = 0; |
---|
825 | m_DistanceWeighting = WEIGHT_NONE; |
---|
826 | m_CrossValidate = false; |
---|
827 | m_MeanSquared = false; |
---|
828 | } |
---|
829 | |
---|
830 | /** |
---|
831 | * Turn the list of nearest neighbors into a probability distribution. |
---|
832 | * |
---|
833 | * @param neighbours the list of nearest neighboring instances |
---|
834 | * @param distances the distances of the neighbors |
---|
835 | * @return the probability distribution |
---|
836 | * @throws Exception if computation goes wrong or has no class attribute |
---|
837 | */ |
---|
838 | protected double [] makeDistribution(Instances neighbours, double[] distances) |
---|
839 | throws Exception { |
---|
840 | |
---|
841 | double total = 0, weight; |
---|
842 | double [] distribution = new double [m_NumClasses]; |
---|
843 | |
---|
844 | // Set up a correction to the estimator |
---|
845 | if (m_ClassType == Attribute.NOMINAL) { |
---|
846 | for(int i = 0; i < m_NumClasses; i++) { |
---|
847 | distribution[i] = 1.0 / Math.max(1,m_Train.numInstances()); |
---|
848 | } |
---|
849 | total = (double)m_NumClasses / Math.max(1,m_Train.numInstances()); |
---|
850 | } |
---|
851 | |
---|
852 | for(int i=0; i < neighbours.numInstances(); i++) { |
---|
853 | // Collect class counts |
---|
854 | Instance current = neighbours.instance(i); |
---|
855 | distances[i] = distances[i]*distances[i]; |
---|
856 | distances[i] = Math.sqrt(distances[i]/m_NumAttributesUsed); |
---|
857 | switch (m_DistanceWeighting) { |
---|
858 | case WEIGHT_INVERSE: |
---|
859 | weight = 1.0 / (distances[i] + 0.001); // to avoid div by zero |
---|
860 | break; |
---|
861 | case WEIGHT_SIMILARITY: |
---|
862 | weight = 1.0 - distances[i]; |
---|
863 | break; |
---|
864 | default: // WEIGHT_NONE: |
---|
865 | weight = 1.0; |
---|
866 | break; |
---|
867 | } |
---|
868 | weight *= current.weight(); |
---|
869 | try { |
---|
870 | switch (m_ClassType) { |
---|
871 | case Attribute.NOMINAL: |
---|
872 | distribution[(int)current.classValue()] += weight; |
---|
873 | break; |
---|
874 | case Attribute.NUMERIC: |
---|
875 | distribution[0] += current.classValue() * weight; |
---|
876 | break; |
---|
877 | } |
---|
878 | } catch (Exception ex) { |
---|
879 | throw new Error("Data has no class attribute!"); |
---|
880 | } |
---|
881 | total += weight; |
---|
882 | } |
---|
883 | |
---|
884 | // Normalise distribution |
---|
885 | if (total > 0) { |
---|
886 | Utils.normalize(distribution, total); |
---|
887 | } |
---|
888 | return distribution; |
---|
889 | } |
---|
890 | |
---|
  /**
   * Select the best value for k by hold-one-out cross-validation.
   * If the class attribute is nominal, classification error is
   * minimised. If the class attribute is numeric, mean absolute
   * error is minimised
   */
  protected void crossValidate() {

    try {
      // CoverTree cannot exclude the query point itself, so
      // hold-one-out is not supported with it.
      if (m_NNSearch instanceof weka.core.neighboursearch.CoverTree)
        throw new Exception("CoverTree doesn't support hold-one-out "+
                            "cross-validation. Use some other NN " +
                            "method.");

      // performanceStats[j] accumulates the error for k = j + 1;
      // performanceStatsSq[j] holds the squared error (numeric class).
      double [] performanceStats = new double [m_kNNUpper];
      double [] performanceStatsSq = new double [m_kNNUpper];

      for(int i = 0; i < m_kNNUpper; i++) {
        performanceStats[i] = 0;
        performanceStatsSq[i] = 0;
      }


      // Query with the largest k once per instance, then shrink the
      // neighbour list downwards instead of re-querying for each k.
      m_kNN = m_kNNUpper;
      Instance instance;
      Instances neighbours;
      double[] origDistances, convertedDistances;
      for(int i = 0; i < m_Train.numInstances(); i++) {
        if (m_Debug && (i % 50 == 0)) {
          System.err.print("Cross validating "
                           + i + "/" + m_Train.numInstances() + "\r");
        }
        instance = m_Train.instance(i);
        neighbours = m_NNSearch.kNearestNeighbours(instance, m_kNN);
        origDistances = m_NNSearch.getDistances();

        // Evaluate k = m_kNNUpper down to 1. A fresh copy of the raw
        // distances is needed each round because makeDistribution
        // converts its distance array in place, and pruneToK then
        // consumes those converted values.
        for(int j = m_kNNUpper - 1; j >= 0; j--) {
          // Update the performance stats
          convertedDistances = new double[origDistances.length];
          System.arraycopy(origDistances, 0,
                           convertedDistances, 0, origDistances.length);
          double [] distribution = makeDistribution(neighbours,
                                                    convertedDistances);
          double thisPrediction = Utils.maxIndex(distribution);
          if (m_Train.classAttribute().isNumeric()) {
            // Numeric class: slot 0 holds the weighted estimate.
            thisPrediction = distribution[0];
            double err = thisPrediction - instance.classValue();
            performanceStatsSq[j] += err * err;    // Squared error
            performanceStats[j] += Math.abs(err);  // Absolute error
          } else {
            if (thisPrediction != instance.classValue()) {
              performanceStats[j] ++;              // Classification error
            }
          }
          // Shrink the list to j neighbours for the next (smaller) k;
          // ties at the j'th distance are kept by pruneToK.
          if (j >= 1) {
            neighbours = pruneToK(neighbours, convertedDistances, j);
          }
        }
      }

      // Display the results of the cross-validation
      for(int i = 0; i < m_kNNUpper; i++) {
        if (m_Debug) {
          System.err.print("Hold-one-out performance of " + (i + 1)
                           + " neighbors " );
        }
        if (m_Train.classAttribute().isNumeric()) {
          if (m_Debug) {
            if (m_MeanSquared) {
              System.err.println("(RMSE) = "
                                 + Math.sqrt(performanceStatsSq[i]
                                             / m_Train.numInstances()));
            } else {
              System.err.println("(MAE) = "
                                 + performanceStats[i]
                                 / m_Train.numInstances());
            }
          }
        } else {
          if (m_Debug) {
            System.err.println("(%ERR) = "
                               + 100.0 * performanceStats[i]
                               / m_Train.numInstances());
          }
        }
      }


      // Check through the performance stats and select the best
      // k value (or the lowest k if more than one best)
      double [] searchStats = performanceStats;
      if (m_Train.classAttribute().isNumeric() && m_MeanSquared) {
        searchStats = performanceStatsSq;
      }
      // bestPerformance starts as NaN so the first entry always wins;
      // strict '>' afterwards keeps the lowest k among ties.
      double bestPerformance = Double.NaN;
      int bestK = 1;
      for(int i = 0; i < m_kNNUpper; i++) {
        if (Double.isNaN(bestPerformance)
            || (bestPerformance > searchStats[i])) {
          bestPerformance = searchStats[i];
          bestK = i + 1;
        }
      }
      m_kNN = bestK;
      if (m_Debug) {
        System.err.println("Selected k = " + bestK);
      }

      m_kNNValid = true;
    } catch (Exception ex) {
      // NOTE(review): wraps any failure in an Error and drops the
      // original stack trace (only the message survives) — consider
      // chaining the cause if this is ever revisited.
      throw new Error("Couldn't optimize by cross-validation: "
                      +ex.getMessage());
    }
  }
---|
1005 | |
---|
1006 | /** |
---|
1007 | * Prunes the list to contain the k nearest neighbors. If there are |
---|
1008 | * multiple neighbors at the k'th distance, all will be kept. |
---|
1009 | * |
---|
1010 | * @param neighbours the neighbour instances. |
---|
1011 | * @param distances the distances of the neighbours from target instance. |
---|
1012 | * @param k the number of neighbors to keep. |
---|
1013 | * @return the pruned neighbours. |
---|
1014 | */ |
---|
1015 | public Instances pruneToK(Instances neighbours, double[] distances, int k) { |
---|
1016 | |
---|
1017 | if(neighbours==null || distances==null || neighbours.numInstances()==0) { |
---|
1018 | return null; |
---|
1019 | } |
---|
1020 | if (k < 1) { |
---|
1021 | k = 1; |
---|
1022 | } |
---|
1023 | |
---|
1024 | int currentK = 0; |
---|
1025 | double currentDist; |
---|
1026 | for(int i=0; i < neighbours.numInstances(); i++) { |
---|
1027 | currentK++; |
---|
1028 | currentDist = distances[i]; |
---|
1029 | if(currentK>k && currentDist!=distances[i-1]) { |
---|
1030 | currentK--; |
---|
1031 | neighbours = new Instances(neighbours, 0, currentK); |
---|
1032 | break; |
---|
1033 | } |
---|
1034 | } |
---|
1035 | |
---|
1036 | return neighbours; |
---|
1037 | } |
---|
1038 | |
---|
1039 | /** |
---|
1040 | * Returns the revision string. |
---|
1041 | * |
---|
1042 | * @return the revision |
---|
1043 | */ |
---|
1044 | public String getRevision() { |
---|
1045 | return RevisionUtils.extract("$Revision: 5928 $"); |
---|
1046 | } |
---|
1047 | |
---|
1048 | /** |
---|
1049 | * Main method for testing this class. |
---|
1050 | * |
---|
1051 | * @param argv should contain command line options (see setOptions) |
---|
1052 | */ |
---|
1053 | public static void main(String [] argv) { |
---|
1054 | runClassifier(new IBk(), argv); |
---|
1055 | } |
---|
1056 | } |
---|