source: src/main/java/weka/classifiers/mi/MINND.java @ 4

Last change on this file since 4 was 4, checked in by gnappo, 14 years ago

Import di weka.

File size: 32.6 KB
RevLine 
[4]1/*
2 *    This program is free software; you can redistribute it and/or modify
3 *    it under the terms of the GNU General Public License as published by
4 *    the Free Software Foundation; either version 2 of the License, or
5 *    (at your option) any later version.
6 *
7 *    This program is distributed in the hope that it will be useful,
8 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
9 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
10 *    GNU General Public License for more details.
11 *
12 *    You should have received a copy of the GNU General Public License
13 *    along with this program; if not, write to the Free Software
14 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
15 */
16
17/*
18 * MINND.java
19 * Copyright (C) 2005 University of Waikato, Hamilton, New Zealand
20 *
21 */
22
23package weka.classifiers.mi;
24
25import weka.classifiers.Classifier;
26import weka.classifiers.AbstractClassifier;
27import weka.core.Capabilities;
28import weka.core.Instance;
29import weka.core.DenseInstance;
30import weka.core.Instances;
31import weka.core.MultiInstanceCapabilitiesHandler;
32import weka.core.Option;
33import weka.core.OptionHandler;
34import weka.core.RevisionUtils;
35import weka.core.TechnicalInformation;
36import weka.core.TechnicalInformationHandler;
37import weka.core.Utils;
38import weka.core.Capabilities.Capability;
39import weka.core.TechnicalInformation.Field;
40import weka.core.TechnicalInformation.Type;
41
42import java.util.Enumeration;
43import java.util.Vector;
44
45/**
46 <!-- globalinfo-start -->
47 * Multiple-Instance Nearest Neighbour with Distribution learner.<br/>
48 * <br/>
49 * It uses gradient descent to find the weight for each dimension of each exemplar from the starting point of 1.0. In order to avoid overfitting, it uses mean-square function (i.e. the Euclidean distance) to search for the weights.<br/>
50 *  It then uses the weights to cleanse the training data. After that it searches for the weights again from the starting points of the weights searched before.<br/>
51 *  Finally it uses the most updated weights to cleanse the test exemplar and then finds the nearest neighbour of the test exemplar using partly-weighted Kullback distance. But the variances in the Kullback distance are the ones before cleansing.<br/>
52 * <br/>
53 * For more information see:<br/>
54 * <br/>
55 * Xin Xu (2001). A nearest distribution approach to multiple-instance learning. Hamilton, NZ.
56 * <p/>
57 <!-- globalinfo-end -->
58 *
59 <!-- technical-bibtex-start -->
60 * BibTeX:
61 * <pre>
62 * &#64;misc{Xu2001,
63 *    address = {Hamilton, NZ},
64 *    author = {Xin Xu},
65 *    note = {0657.591B},
66 *    school = {University of Waikato},
67 *    title = {A nearest distribution approach to multiple-instance learning},
68 *    year = {2001}
69 * }
70 * </pre>
71 * <p/>
72 <!-- technical-bibtex-end -->
73 *
74 <!-- options-start -->
75 * Valid options are: <p/>
76 *
77 * <pre> -K &lt;number of neighbours&gt;
78 *  Set number of nearest neighbour for prediction
79 *  (default 1)</pre>
80 *
81 * <pre> -S &lt;number of neighbours&gt;
82 *  Set number of nearest neighbour for cleansing the training data
83 *  (default 1)</pre>
84 *
85 * <pre> -E &lt;number of neighbours&gt;
86 *  Set number of nearest neighbour for cleansing the testing data
87 *  (default 1)</pre>
88 *
89 <!-- options-end -->
90 *
91 * @author Xin Xu (xx5@cs.waikato.ac.nz)
92 * @version $Revision: 5987 $
93 */
94public class MINND 
95  extends AbstractClassifier
96  implements OptionHandler, MultiInstanceCapabilitiesHandler,
97             TechnicalInformationHandler {
98
99  /** for serialization */
100  static final long serialVersionUID = -4512599203273864994L;
101 
102  /** The number of nearest neighbour for prediction */
103  protected int m_Neighbour = 1;
104
105  /** The mean for each attribute of each exemplar */
106  protected double[][] m_Mean = null;
107
108  /** The variance for each attribute of each exemplar */
109  protected double[][] m_Variance = null;
110
111  /** The dimension of each exemplar, i.e. (numAttributes-2) */
112  protected int m_Dimension = 0;
113
114  /** header info of the data */
115  protected Instances m_Attributes;;
116
117  /** The class label of each exemplar */
118  protected double[] m_Class = null;
119
120  /** The number of class labels in the data */
121  protected int m_NumClasses = 0;
122
123  /** The weight of each exemplar */
124  protected double[] m_Weights = null;
125
126  /** The very small number representing zero */
127  static private double m_ZERO = 1.0e-45;
128
129  /** The learning rate in the gradient descent */
130  protected double m_Rate = -1;
131
132  /** The minimum values for numeric attributes. */
133  private double [] m_MinArray=null;
134
135  /** The maximum values for numeric attributes. */
136  private double [] m_MaxArray=null;
137
138  /** The stopping criteria of gradient descent*/
139  private double m_STOP = 1.0e-45;
140
141  /** The weights that alter the dimnesion of each exemplar */
142  private double[][] m_Change=null;
143
144  /** The noise data of each exemplar */
145  private double[][] m_NoiseM = null, m_NoiseV = null, m_ValidM = null, 
146          m_ValidV = null;
147
148  /** The number of nearest neighbour instances in the selection of noises
149    in the training data*/
150  private int m_Select = 1;
151
152  /** The number of nearest neighbour exemplars in the selection of noises
153    in the test data */
154  private int m_Choose = 1;
155
156  /** The decay rate of learning rate */
157  private double m_Decay = 0.5;
158
159  /**
160   * Returns a string describing this filter
161   *
162   * @return a description of the filter suitable for
163   * displaying in the explorer/experimenter gui
164   */
165  public String globalInfo() {
166    return 
167        "Multiple-Instance Nearest Neighbour with Distribution learner.\n\n"
168      + "It uses gradient descent to find the weight for each dimension of "
169      + "each exeamplar from the starting point of 1.0. In order to avoid "
170      + "overfitting, it uses mean-square function (i.e. the Euclidean "
171      + "distance) to search for the weights.\n "
172      + "It then uses the weights to cleanse the training data. After that "
173      + "it searches for the weights again from the starting points of the "
174      + "weights searched before.\n "
175      + "Finally it uses the most updated weights to cleanse the test exemplar "
176      + "and then finds the nearest neighbour of the test exemplar using "
177      + "partly-weighted Kullback distance. But the variances in the Kullback "
178      + "distance are the ones before cleansing.\n\n"
179      + "For more information see:\n\n"
180      + getTechnicalInformation().toString();
181  }
182
183  /**
184   * Returns an instance of a TechnicalInformation object, containing
185   * detailed information about the technical background of this class,
186   * e.g., paper reference or book this class is based on.
187   *
188   * @return the technical information about this class
189   */
190  public TechnicalInformation getTechnicalInformation() {
191    TechnicalInformation        result;
192   
193    result = new TechnicalInformation(Type.MISC);
194    result.setValue(Field.AUTHOR, "Xin Xu");
195    result.setValue(Field.YEAR, "2001");
196    result.setValue(Field.TITLE, "A nearest distribution approach to multiple-instance learning");
197    result.setValue(Field.SCHOOL, "University of Waikato");
198    result.setValue(Field.ADDRESS, "Hamilton, NZ");
199    result.setValue(Field.NOTE, "0657.591B");
200   
201    return result;
202  }
203
204  /**
205   * Returns default capabilities of the classifier.
206   *
207   * @return      the capabilities of this classifier
208   */
209  public Capabilities getCapabilities() {
210    Capabilities result = super.getCapabilities();
211    result.disableAll();
212
213    // attributes
214    result.enable(Capability.NOMINAL_ATTRIBUTES);
215    result.enable(Capability.RELATIONAL_ATTRIBUTES);
216    result.enable(Capability.MISSING_VALUES);
217
218    // class
219    result.enable(Capability.NOMINAL_CLASS);
220    result.enable(Capability.MISSING_CLASS_VALUES);
221   
222    // other
223    result.enable(Capability.ONLY_MULTIINSTANCE);
224   
225    return result;
226  }
227
228  /**
229   * Returns the capabilities of this multi-instance classifier for the
230   * relational data.
231   *
232   * @return            the capabilities of this object
233   * @see               Capabilities
234   */
235  public Capabilities getMultiInstanceCapabilities() {
236    Capabilities result = super.getCapabilities();
237    result.disableAll();
238   
239    // attributes
240    result.enable(Capability.NUMERIC_ATTRIBUTES);
241    result.enable(Capability.DATE_ATTRIBUTES);
242    result.enable(Capability.MISSING_VALUES);
243
244    // class
245    result.disableAllClasses();
246    result.enable(Capability.NO_CLASS);
247   
248    return result;
249  }
250
251  /**
252   * As normal Nearest Neighbour algorithm does, it's lazy and simply
253   * records the exemplar information (i.e. mean and variance for each
254   * dimension of each exemplar and their classes) when building the model.
255   * There is actually no need to store the exemplars themselves.
256   *
257   * @param exs the training exemplars
258   * @throws Exception if the model cannot be built properly
259   */   
260  public void buildClassifier(Instances exs)throws Exception{
261    // can classifier handle the data?
262    getCapabilities().testWithFail(exs);
263
264    // remove instances with missing class
265    Instances newData = new Instances(exs);
266    newData.deleteWithMissingClass();
267   
268    int numegs = newData.numInstances();
269    m_Dimension = newData.attribute(1).relation().numAttributes();
270    m_Attributes = newData.stringFreeStructure(); 
271    m_Change = new double[numegs][m_Dimension];
272    m_NumClasses = exs.numClasses();
273    m_Mean = new double[numegs][m_Dimension];
274    m_Variance = new double[numegs][m_Dimension];
275    m_Class = new double[numegs];
276    m_Weights = new double[numegs];
277    m_NoiseM = new double[numegs][m_Dimension];
278    m_NoiseV = new double[numegs][m_Dimension];
279    m_ValidM = new double[numegs][m_Dimension];
280    m_ValidV = new double[numegs][m_Dimension];
281    m_MinArray = new double[m_Dimension];
282    m_MaxArray = new double[m_Dimension];
283    for(int v=0; v < m_Dimension; v++)
284      m_MinArray[v] = m_MaxArray[v] = Double.NaN;
285
286    for(int w=0; w < numegs; w++){
287      updateMinMax(newData.instance(w));
288    }
289
290    // Scale exemplars
291    Instances data = m_Attributes;
292
293    for(int x=0; x < numegs; x++){
294      Instance example = newData.instance(x);
295      example = scale(example);
296      for (int i=0; i<m_Dimension; i++) {
297        m_Mean[x][i] = example.relationalValue(1).meanOrMode(i);       
298        m_Variance[x][i] = example.relationalValue(1).variance(i);
299        if(Utils.eq(m_Variance[x][i],0.0))
300          m_Variance[x][i] = m_ZERO;
301        m_Change[x][i] = 1.0;
302      }
303      /* for(int y=0; y < m_Variance[x].length; y++){
304         if(Utils.eq(m_Variance[x][y],0.0))
305         m_Variance[x][y] = m_ZERO;
306         m_Change[x][y] = 1.0;
307         }  */ 
308
309      data.add(example);
310      m_Class[x] = example.classValue();
311      m_Weights[x] = example.weight(); 
312    }
313
314    for(int z=0; z < numegs; z++)
315      findWeights(z, m_Mean);
316
317    // Pre-process and record "true estimated" parameters for distributions
318    for(int x=0; x < numegs; x++){
319      Instance example = preprocess(data, x);
320      if (getDebug())
321        System.out.println("???Exemplar "+x+" has been pre-processed:"+
322            data.instance(x).relationalValue(1).sumOfWeights()+
323            "|"+example.relationalValue(1).sumOfWeights()+
324            "; class:"+m_Class[x]);
325      if(Utils.gr(example.relationalValue(1).sumOfWeights(), 0)){       
326        for (int i=0; i<m_Dimension; i++) {
327          m_ValidM[x][i] = example.relationalValue(1).meanOrMode(i);
328          m_ValidV[x][i] = example.relationalValue(1).variance(i);
329          if(Utils.eq(m_ValidV[x][i],0.0))
330            m_ValidV[x][i] = m_ZERO;
331        }
332        /*      for(int y=0; y < m_ValidV[x].length; y++){
333                if(Utils.eq(m_ValidV[x][y],0.0))
334                m_ValidV[x][y] = m_ZERO;
335                }*/     
336      }
337      else{
338        m_ValidM[x] = null;
339        m_ValidV[x] = null;
340      }
341    }
342
343    for(int z=0; z < numegs; z++)
344      if(m_ValidM[z] != null)
345        findWeights(z, m_ValidM);       
346
347  }
348
349  /**
350   * Pre-process the given exemplar according to the other exemplars
351   * in the given exemplars.  It also updates noise data statistics.
352   *
353   * @param data the whole exemplars
354   * @param pos the position of given exemplar in data
355   * @return the processed exemplar
356   * @throws Exception if the returned exemplar is wrong
357   */
358  public Instance preprocess(Instances data, int pos)
359    throws Exception{
360    Instance before = data.instance(pos);
361    if((int)before.classValue() == 0){
362      m_NoiseM[pos] = null;
363      m_NoiseV[pos] = null;
364      return before;
365    }
366
367    Instances after_relationInsts =before.attribute(1).relation().stringFreeStructure();
368    Instances noises_relationInsts =before.attribute(1).relation().stringFreeStructure();
369
370    Instances newData = m_Attributes;
371    Instance after = new DenseInstance(before.numAttributes());
372    Instance noises =  new DenseInstance(before.numAttributes());
373    after.setDataset(newData);
374    noises.setDataset(newData);
375
376    for(int g=0; g < before.relationalValue(1).numInstances(); g++){
377      Instance datum = before.relationalValue(1).instance(g);
378      double[] dists = new double[data.numInstances()];
379
380      for(int i=0; i < data.numInstances(); i++){
381        if(i != pos)
382          dists[i] = distance(datum, m_Mean[i], m_Variance[i], i);
383        else
384          dists[i] = Double.POSITIVE_INFINITY;
385      }           
386
387      int[] pred = new int[m_NumClasses];
388      for(int n=0; n < pred.length; n++)
389        pred[n] = 0;
390
391      for(int o=0; o<m_Select; o++){
392        int index = Utils.minIndex(dists);
393        pred[(int)m_Class[index]]++;
394        dists[index] = Double.POSITIVE_INFINITY;
395      }
396
397      int clas = Utils.maxIndex(pred);
398      if((int)before.classValue() != clas)
399        noises_relationInsts.add(datum);
400      else
401        after_relationInsts.add(datum);         
402    }
403
404    int relationValue;
405    relationValue = noises.attribute(1).addRelation( noises_relationInsts);
406    noises.setValue(0,before.value(0));
407    noises.setValue(1, relationValue);
408    noises.setValue(2, before.classValue());
409
410    relationValue = after.attribute(1).addRelation( after_relationInsts);
411    after.setValue(0,before.value(0));
412    after.setValue(1, relationValue);
413    after.setValue(2, before.classValue());
414
415
416    if(Utils.gr(noises.relationalValue(1).sumOfWeights(), 0)){ 
417      for (int i=0; i<m_Dimension; i++) {
418        m_NoiseM[pos][i] = noises.relationalValue(1).meanOrMode(i);
419        m_NoiseV[pos][i] = noises.relationalValue(1).variance(i);
420        if(Utils.eq(m_NoiseV[pos][i],0.0))
421          m_NoiseV[pos][i] = m_ZERO;
422      }
423      /* for(int y=0; y < m_NoiseV[pos].length; y++){
424         if(Utils.eq(m_NoiseV[pos][y],0.0))
425         m_NoiseV[pos][y] = m_ZERO;
426         } */   
427    }
428    else{
429      m_NoiseM[pos] = null;
430      m_NoiseV[pos] = null;
431    }
432
433    return after;
434  }
435
436  /**
437   * Calculates the distance between two instances
438   *
439   * @param first the first instance
440   * @param second the second instance
441   * @return the distance between the two given instances
442   */         
443  private double distance(Instance first, double[] mean, double[] var, int pos) {
444
445    double diff, distance = 0;
446
447    for(int i = 0; i < m_Dimension; i++) { 
448      // If attribute is numeric
449      if(first.attribute(i).isNumeric()){
450        if (!first.isMissing(i)){     
451          diff = first.value(i) - mean[i];
452          if(Utils.gr(var[i], m_ZERO))
453            distance += m_Change[pos][i] * var[i] * diff * diff;
454          else
455            distance += m_Change[pos][i] * diff * diff; 
456        }
457        else{
458          if(Utils.gr(var[i], m_ZERO))
459            distance += m_Change[pos][i] * var[i];
460          else
461            distance += m_Change[pos][i] * 1.0;
462        }
463      }
464
465    }
466
467    return distance;
468  }
469
470  /**
471   * Updates the minimum and maximum values for all the attributes
472   * based on a new exemplar.
473   *
474   * @param ex the new exemplar
475   */
476  private void updateMinMax(Instance ex) {     
477    Instances insts = ex.relationalValue(1);
478    for (int j = 0;j < m_Dimension; j++) {
479      if (insts.attribute(j).isNumeric()){
480        for(int k=0; k < insts.numInstances(); k++){
481          Instance ins = insts.instance(k);
482          if(!ins.isMissing(j)){
483            if (Double.isNaN(m_MinArray[j])) {
484              m_MinArray[j] = ins.value(j);
485              m_MaxArray[j] = ins.value(j);
486            } else {
487              if (ins.value(j) < m_MinArray[j])
488                m_MinArray[j] = ins.value(j);
489              else if (ins.value(j) > m_MaxArray[j])
490                m_MaxArray[j] = ins.value(j);
491            }
492          }
493        }
494      }
495    }
496  }
497
498  /**
499   * Scale the given exemplar so that the returned exemplar
500   * has the value of 0 to 1 for each dimension
501   *
502   * @param before the given exemplar
503   * @return the resultant exemplar after scaling
504   * @throws Exception if given exampler cannot be scaled properly
505   */
506  private Instance scale(Instance before) throws Exception{
507
508    Instances afterInsts = before.relationalValue(1).stringFreeStructure();
509    Instance after = new DenseInstance(before.numAttributes());
510    after.setDataset(m_Attributes);
511
512    for(int i=0; i < before.relationalValue(1).numInstances(); i++){
513      Instance datum = before.relationalValue(1).instance(i);
514      Instance inst = (Instance)datum.copy();
515
516      for(int j=0; j < m_Dimension; j++){
517        if(before.relationalValue(1).attribute(j).isNumeric())
518          inst.setValue(j, (datum.value(j) - m_MinArray[j])/(m_MaxArray[j] - m_MinArray[j]));   
519      }
520      afterInsts.add(inst);
521    }
522
523    int attValue = after.attribute(1).addRelation(afterInsts);
524    after.setValue(0, before.value( 0));
525    after.setValue(1, attValue);       
526    after.setValue(2, before.value( 2));
527
528    return after;
529  }
530
531  /**
532   * Use gradient descent to distort the MU parameter for
533   * the exemplar.  The exemplar can be in the specified row in the
534   * given matrix, which has numExemplar rows and numDimension columns;
535   * or not in the matrix.
536   *
537   * @param row the given row index
538   * @param mean
539   */
540  public void findWeights(int row, double[][] mean){
541
542    double[] neww = new double[m_Dimension];
543    double[] oldw = new double[m_Dimension];
544    System.arraycopy(m_Change[row], 0, neww, 0, m_Dimension);
545    //for(int z=0; z<m_Dimension; z++)
546    //System.out.println("mu("+row+"): "+origin[z]+" | "+newmu[z]);
547    double newresult = target(neww, mean, row, m_Class);
548    double result = Double.POSITIVE_INFINITY;
549    double rate= 0.05;
550    if(m_Rate != -1)
551      rate = m_Rate;
552    //System.out.println("???Start searching ...");
553search: 
554    while(Utils.gr((result-newresult), m_STOP)){ // Full step
555      oldw = neww;
556      neww= new double[m_Dimension];
557
558      double[] delta = delta(oldw, mean, row, m_Class);
559
560      for(int i=0; i < m_Dimension; i++)
561        if(Utils.gr(m_Variance[row][i], 0.0))
562          neww[i] = oldw[i] + rate * delta[i];
563
564      result = newresult;
565      newresult = target(neww, mean, row, m_Class);
566
567      //System.out.println("???old: "+result+"|new: "+newresult);
568      while(Utils.gr(newresult, result)){ // Search back
569        //System.out.println("search back");
570        if(m_Rate == -1){
571          rate *= m_Decay; // Decay
572          for(int i=0; i < m_Dimension; i++)
573            if(Utils.gr(m_Variance[row][i], 0.0))
574              neww[i] = oldw[i] + rate * delta[i];
575          newresult = target(neww, mean, row, m_Class);
576        }
577        else{
578          for(int i=0; i < m_Dimension; i++)
579            neww[i] = oldw[i];
580          break search;
581        }
582      }
583    }
584    //System.out.println("???Stop");
585    m_Change[row] = neww;
586  }
587
588  /**
589   * Delta of x in one step of gradient descent:
590   * delta(Wij) = 1/2 * sum[k=1..N, k!=i](sqrt(P)*(Yi-Yk)/D - 1) * (MUij -
591   * MUkj)^2 where D = sqrt(sum[j=1..P]Kkj(MUij - MUkj)^2)
592   * N is number of exemplars and P is number of dimensions
593   *
594   * @param x the weights of the exemplar in question
595   * @param rowpos row index of x in X
596   * @param Y the observed class label
597   * @return the delta for all dimensions
598   */
599  private double[] delta(double[] x, double[][] X, int rowpos, double[] Y){
600    double y = Y[rowpos];
601
602    double[] delta=new double[m_Dimension];
603    for(int h=0; h < m_Dimension; h++)
604      delta[h] = 0.0;
605
606    for(int i=0; i < X.length; i++){
607      if((i != rowpos) && (X[i] != null)){
608        double var = (y==Y[i]) ? 0.0 : Math.sqrt((double)m_Dimension - 1);
609        double distance=0;
610        for(int j=0; j < m_Dimension; j++)
611          if(Utils.gr(m_Variance[rowpos][j], 0.0))
612            distance += x[j]*(X[rowpos][j]-X[i][j]) * (X[rowpos][j]-X[i][j]);
613        distance = Math.sqrt(distance);
614        if(distance != 0)
615          for(int k=0; k < m_Dimension; k++)
616            if(m_Variance[rowpos][k] > 0.0)
617              delta[k] += (var/distance - 1.0) * 0.5 *
618                (X[rowpos][k]-X[i][k]) *
619                (X[rowpos][k]-X[i][k]);
620      }
621    }
622    //System.out.println("???delta: "+delta);
623    return delta;
624  }
625
626  /**
627   * Compute the target function to minimize in gradient descent
628   * The formula is:<br/>
629   * 1/2*sum[i=1..p](f(X, Xi)-var(Y, Yi))^2 <p/>
630   * where p is the number of exemplars and Y is the class label.
631   * In the case of X=MU, f() is the Euclidean distance between two
632   * exemplars together with the related weights and var() is
633   * sqrt(numDimension)*(Y-Yi) where Y-Yi is either 0 (when Y==Yi)
634   * or 1 (Y!=Yi)
635   *
636   * @param x the weights of the exemplar in question
637   * @param rowpos row index of x in X
638   * @param Y the observed class label
639   * @return the result of the target function
640   */
641  public double target(double[] x, double[][] X, int rowpos, double[] Y){
642    double y = Y[rowpos], result=0;
643
644    for(int i=0; i < X.length; i++){
645      if((i != rowpos) && (X[i] != null)){
646        double var = (y==Y[i]) ? 0.0 : Math.sqrt((double)m_Dimension - 1);
647        double f=0;
648        for(int j=0; j < m_Dimension; j++)
649          if(Utils.gr(m_Variance[rowpos][j], 0.0)){
650            f += x[j]*(X[rowpos][j]-X[i][j]) * (X[rowpos][j]-X[i][j]);     
651            //System.out.println("i:"+i+" j: "+j+" row: "+rowpos);
652          }
653        f = Math.sqrt(f);
654        //System.out.println("???distance between "+rowpos+" and "+i+": "+f+"|y:"+y+" vs "+Y[i]);
655        if(Double.isInfinite(f))
656          System.exit(1);
657        result += 0.5 * (f - var) * (f - var);
658      }
659    }
660    //System.out.println("???target: "+result);
661    return result;
662  }   
663
664  /**
665   * Use Kullback Leibler distance to find the nearest neighbours of
666   * the given exemplar.
667   * It also uses K-Nearest Neighbour algorithm to classify the
668   * test exemplar
669   *
670   * @param ex the given test exemplar
671   * @return the classification
672   * @throws Exception if the exemplar could not be classified
673   * successfully
674   */
675  public double classifyInstance(Instance ex)throws Exception{
676
677    ex = scale(ex);
678
679    double[] var = new double [m_Dimension];
680    for (int i=0; i<m_Dimension; i++) 
681      var[i]= ex.relationalValue(1).variance(i);       
682
683    // The Kullback distance to all exemplars
684    double[] kullback = new double[m_Class.length];
685
686    // The first K nearest neighbours' predictions */
687  double[] predict = new double[m_NumClasses];
688  for(int h=0; h < predict.length; h++)
689    predict[h] = 0;
690  ex = cleanse(ex);
691
692  if(ex.relationalValue(1).numInstances() == 0){
693    if (getDebug())
694      System.out.println("???Whole exemplar falls into ambiguous area!");
695    return 1.0;                          // Bias towards positive class
696  }
697
698  double[] mean = new double[m_Dimension];     
699  for (int i=0; i<m_Dimension; i++)
700    mean [i]=ex.relationalValue(1).meanOrMode(i);
701
702  // Avoid zero sigma
703  for(int h=0; h < var.length; h++){
704    if(Utils.eq(var[h],0.0))
705      var[h] = m_ZERO;
706  }     
707
708  for(int i=0; i < m_Class.length; i++){
709    if(m_ValidM[i] != null)
710      kullback[i] = kullback(mean, m_ValidM[i], var, m_Variance[i], i);
711    else
712      kullback[i] = Double.POSITIVE_INFINITY;
713  }
714
715  for(int j=0; j < m_Neighbour; j++){
716    int pos = Utils.minIndex(kullback);
717    predict[(int)m_Class[pos]] += m_Weights[pos];         
718    kullback[pos] = Double.POSITIVE_INFINITY;
719  }     
720
721  if (getDebug())
722    System.out.println("???There are still some unambiguous instances in this exemplar! Predicted as: "+Utils.maxIndex(predict));
723  return (double)Utils.maxIndex(predict);       
724  } 
725
726  /**
727   * Cleanse the given exemplar according to the valid and noise data
728   * statistics
729   *
730   * @param before the given exemplar
731   * @return the processed exemplar
732   * @throws Exception if the returned exemplar is wrong
733   */
734  public Instance cleanse(Instance before) throws Exception{
735
736    Instances insts = before.relationalValue(1).stringFreeStructure();
737    Instance after = new DenseInstance(before.numAttributes());
738    after.setDataset(m_Attributes);
739
740    for(int g=0; g < before.relationalValue(1).numInstances(); g++){
741      Instance datum = before.relationalValue(1).instance(g);
742      double[] minNoiDists = new double[m_Choose];
743      double[] minValDists = new double[m_Choose];
744      int noiseCount = 0, validCount = 0;
745      double[] nDist = new double[m_Mean.length]; 
746      double[] vDist = new double[m_Mean.length]; 
747
748      for(int h=0; h < m_Mean.length; h++){
749        if(m_ValidM[h] == null)
750          vDist[h] = Double.POSITIVE_INFINITY;
751        else
752          vDist[h] = distance(datum, m_ValidM[h], m_ValidV[h], h);
753
754        if(m_NoiseM[h] == null)
755          nDist[h] = Double.POSITIVE_INFINITY;
756        else
757          nDist[h] = distance(datum, m_NoiseM[h], m_NoiseV[h], h);
758      }
759
760      for(int k=0; k < m_Choose; k++){
761        int pos = Utils.minIndex(vDist);
762        minValDists[k] = vDist[pos];
763        vDist[pos] = Double.POSITIVE_INFINITY;
764        pos = Utils.minIndex(nDist);
765        minNoiDists[k] = nDist[pos];
766        nDist[pos] = Double.POSITIVE_INFINITY;
767      }
768
769      int x = 0,y = 0;
770      while((x+y) < m_Choose){
771        if(minValDists[x] <= minNoiDists[y]){
772          validCount++;
773          x++;
774        }
775        else{
776          noiseCount++;
777          y++;
778        }
779      }
780      if(x >= y)
781        insts.add (datum);
782
783    }
784
785    after.setValue(0, before.value( 0));
786    after.setValue(1, after.attribute(1).addRelation(insts));
787    after.setValue(2, before.value( 2));
788
789    return after;
790  }   
791
792  /**
793   * This function calculates the Kullback Leibler distance between
794   * two normal distributions.  This distance is always positive.
795   * Kullback Leibler distance = integral{f(X)ln(f(X)/g(X))}
796   * Note that X is a vector.  Since we assume dimensions are independent
797   * f(X)(g(X) the same) is actually the product of normal density
798   * functions of each dimensions.  Also note that it should be log2
799   * instead of (ln) in the formula, but we use (ln) simply for computational
800   * convenience.
801   *
802   * The result is as follows, suppose there are P dimensions, and f(X)
803   * is the first distribution and g(X) is the second:
804   * Kullback = sum[1..P](ln(SIGMA2/SIGMA1)) +
805   *            sum[1..P](SIGMA1^2 / (2*(SIGMA2^2))) +
806   *            sum[1..P]((MU1-MU2)^2 / (2*(SIGMA2^2))) -
807   *            P/2
808   *
809   * @param mu1 mu of the first normal distribution
810   * @param mu2 mu of the second normal distribution
811   * @param var1 variance(SIGMA^2) of the first normal distribution
812   * @param var2 variance(SIGMA^2) of the second normal distribution
813   * @return the Kullback distance of two distributions
814   */
815  public double kullback(double[] mu1, double[] mu2,
816      double[] var1, double[] var2, int pos){
817    int p = mu1.length;
818    double result = 0;
819
820    for(int y=0; y < p; y++){
821      if((Utils.gr(var1[y], 0)) && (Utils.gr(var2[y], 0))){
822        result += 
823          ((Math.log(Math.sqrt(var2[y]/var1[y]))) +
824           (var1[y] / (2.0*var2[y])) + 
825           (m_Change[pos][y] * (mu1[y]-mu2[y])*(mu1[y]-mu2[y]) / (2.0*var2[y])) -
826           0.5);
827      }
828    }
829
830    return result;
831  }
832
833  /**
834   * Returns an enumeration describing the available options
835   *
836   * @return an enumeration of all the available options
837   */
838  public Enumeration listOptions() {
839    Vector result = new Vector();
840
841    result.addElement(new Option(
842          "\tSet number of nearest neighbour for prediction\n"
843          + "\t(default 1)",
844          "K", 1, "-K <number of neighbours>"));
845   
846    result.addElement(new Option(
847          "\tSet number of nearest neighbour for cleansing the training data\n"
848          + "\t(default 1)",
849          "S", 1, "-S <number of neighbours>"));
850   
851    result.addElement(new Option(
852          "\tSet number of nearest neighbour for cleansing the testing data\n"
853          + "\t(default 1)",
854          "E", 1, "-E <number of neighbours>"));
855
856    return result.elements();
857  }
858
859  /**
860   * Parses a given list of options. <p/>
861   *
862   <!-- options-start -->
863   * Valid options are: <p/>
864   *
865   * <pre> -K &lt;number of neighbours&gt;
866   *  Set number of nearest neighbour for prediction
867   *  (default 1)</pre>
868   *
869   * <pre> -S &lt;number of neighbours&gt;
870   *  Set number of nearest neighbour for cleansing the training data
871   *  (default 1)</pre>
872   *
873   * <pre> -E &lt;number of neighbours&gt;
874   *  Set number of nearest neighbour for cleansing the testing data
875   *  (default 1)</pre>
876   *
877   <!-- options-end -->
878   *
879   * @param options the list of options as an array of strings
880   * @throws Exception if an option is not supported
881   */
882  public void setOptions(String[] options) throws Exception{
883
884    setDebug(Utils.getFlag('D', options));
885
886    String numNeighbourString = Utils.getOption('K', options);
887    if (numNeighbourString.length() != 0) 
888      setNumNeighbours(Integer.parseInt(numNeighbourString));
889    else 
890      setNumNeighbours(1);
891
892    numNeighbourString = Utils.getOption('S', options);
893    if (numNeighbourString.length() != 0) 
894      setNumTrainingNoises(Integer.parseInt(numNeighbourString));
895    else 
896      setNumTrainingNoises(1);
897
898    numNeighbourString = Utils.getOption('E', options);
899    if (numNeighbourString.length() != 0) 
900      setNumTestingNoises(Integer.parseInt(numNeighbourString));
901    else 
902      setNumTestingNoises(1);
903  }
904
905  /**
906   * Gets the current settings of the Classifier.
907   *
908   * @return an array of strings suitable for passing to setOptions
909   */
910  public String[] getOptions() {
911    Vector        result;
912   
913    result = new Vector();
914
915    if (getDebug())
916      result.add("-D");
917   
918    result.add("-K");
919    result.add("" + getNumNeighbours());
920   
921    result.add("-S");
922    result.add("" + getNumTrainingNoises());
923   
924    result.add("-E");
925    result.add("" + getNumTestingNoises());
926
927    return (String[]) result.toArray(new String[result.size()]);
928  }
929
930  /**
931   * Returns the tip text for this property
932   *
933   * @return tip text for this property suitable for
934   * displaying in the explorer/experimenter gui
935   */
936  public String numNeighboursTipText() {
937    return "The number of nearest neighbours to the estimate the class prediction of test bags.";
938  }
939
940  /**
941   * Sets the number of nearest neighbours to estimate
942   * the class prediction of tests bags
943   * @param numNeighbour the number of citers
944   */
945  public void setNumNeighbours(int numNeighbour){
946    m_Neighbour = numNeighbour;
947  }
948
949  /**
950   * Returns the number of nearest neighbours to estimate
951   * the class prediction of tests bags
952   * @return the number of neighbours
953   */
954  public int getNumNeighbours(){
955    return m_Neighbour;
956  }
957
958  /**
959   * Returns the tip text for this property
960   *
961   * @return tip text for this property suitable for
962   * displaying in the explorer/experimenter gui
963   */
964  public String numTrainingNoisesTipText() {
965    return "The number of nearest neighbour instances in the selection of noises in the training data.";
966  }
967
968  /**
969   * Sets the number of nearest neighbour instances in the
970   * selection of noises in the training data
971   *
972   * @param numTraining the number of noises in training data
973   */
974  public void setNumTrainingNoises (int numTraining){
975    m_Select = numTraining;
976  }
977
978  /**
979   * Returns the number of nearest neighbour instances in the
980   * selection of noises in the training data
981   *
982   * @return the number of noises in training data
983   */
984  public int getNumTrainingNoises(){
985    return m_Select;
986  }
987
988  /**
989   * Returns the tip text for this property
990   *
991   * @return tip text for this property suitable for
992   * displaying in the explorer/experimenter gui
993   */
994  public String numTestingNoisesTipText() {
995    return "The number of nearest neighbour instances in the selection of noises in the test data.";
996  }
997
998  /**
999   * Returns The number of nearest neighbour instances in the
1000   * selection of noises in the test data
1001   * @return the number of noises in test data
1002   */
1003  public int getNumTestingNoises(){
1004    return m_Choose;
1005  }
1006
1007  /**
1008   * Sets The number of nearest neighbour exemplars in the
1009   * selection of noises in the test data
1010   * @param numTesting the number of noises in test data
1011   */
1012  public void setNumTestingNoises (int numTesting){
1013    m_Choose = numTesting;
1014  }
1015 
1016  /**
1017   * Returns the revision string.
1018   *
1019   * @return            the revision
1020   */
1021  public String getRevision() {
1022    return RevisionUtils.extract("$Revision: 5987 $");
1023  }
1024
1025  /**
1026   * Main method for testing.
1027   *
1028   * @param args the options for the classifier
1029   */
1030  public static void main(String[] args) {     
1031    runClassifier(new MINND(), args);
1032  }
1033}
Note: See TracBrowser for help on using the repository browser.