source: src/main/java/weka/classifiers/mi/MISMO.java @ 17

Last change on this file since 17 was 4, checked in by gnappo, 14 years ago

Import di weka.

File size: 62.3 KB
Line 
1/*
2 *    This program is free software; you can redistribute it and/or modify
3 *    it under the terms of the GNU General Public License as published by
4 *    the Free Software Foundation; either version 2 of the License, or
5 *    (at your option) any later version.
6 *
7 *    This program is distributed in the hope that it will be useful,
8 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
9 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
10 *    GNU General Public License for more details.
11 *
12 *    You should have received a copy of the GNU General Public License
13 *    along with this program; if not, write to the Free Software
14 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
15 */
16
17/*
18 * MISMO.java
19 * Copyright (C) 2005 University of Waikato, Hamilton, New Zealand
20 *
21 */
22
23package weka.classifiers.mi;
24
25import weka.classifiers.Classifier;
26import weka.classifiers.AbstractClassifier;
27import weka.classifiers.functions.Logistic;
28import weka.classifiers.functions.supportVector.Kernel;
29import weka.classifiers.functions.supportVector.SMOset;
30import weka.classifiers.mi.supportVector.MIPolyKernel;
31import weka.core.Attribute;
32import weka.core.Capabilities;
33import weka.core.FastVector;
34import weka.core.Instance;
35import weka.core.DenseInstance;
36import weka.core.Instances;
37import weka.core.MultiInstanceCapabilitiesHandler;
38import weka.core.Option;
39import weka.core.OptionHandler;
40import weka.core.RevisionHandler;
41import weka.core.RevisionUtils;
42import weka.core.SelectedTag;
43import weka.core.SerializedObject;
44import weka.core.Tag;
45import weka.core.TechnicalInformation;
46import weka.core.TechnicalInformationHandler;
47import weka.core.Utils;
48import weka.core.WeightedInstancesHandler;
49import weka.core.Capabilities.Capability;
50import weka.core.TechnicalInformation.Field;
51import weka.core.TechnicalInformation.Type;
52import weka.filters.Filter;
53import weka.filters.unsupervised.attribute.MultiInstanceToPropositional;
54import weka.filters.unsupervised.attribute.NominalToBinary;
55import weka.filters.unsupervised.attribute.Normalize;
56import weka.filters.unsupervised.attribute.PropositionalToMultiInstance;
57import weka.filters.unsupervised.attribute.ReplaceMissingValues;
58import weka.filters.unsupervised.attribute.Standardize;
59
60import java.io.Serializable;
61import java.util.Enumeration;
62import java.util.Random;
63import java.util.Vector;
64
65/**
66 <!-- globalinfo-start -->
67 * Implements John Platt's sequential minimal optimization algorithm for training a support vector classifier.<br/>
68 * <br/>
69 * This implementation globally replaces all missing values and transforms nominal attributes into binary ones. It also normalizes all attributes by default. (In that case the coefficients in the output are based on the normalized data, not the original data --- this is important for interpreting the classifier.)<br/>
70 * <br/>
71 * Multi-class problems are solved using pairwise classification.<br/>
72 * <br/>
73 * To obtain proper probability estimates, use the option that fits logistic regression models to the outputs of the support vector machine. In the multi-class case the predicted probabilities are coupled using Hastie and Tibshirani's pairwise coupling method.<br/>
74 * <br/>
75 * Note: for improved speed normalization should be turned off when operating on SparseInstances.<br/>
76 * <br/>
77 * For more information on the SMO algorithm, see<br/>
78 * <br/>
79 * J. Platt: Machines using Sequential Minimal Optimization. In B. Schoelkopf and C. Burges and A. Smola, editors, Advances in Kernel Methods - Support Vector Learning, 1998.<br/>
80 * <br/>
81 * S.S. Keerthi, S.K. Shevade, C. Bhattacharyya, K.R.K. Murthy (2001). Improvements to Platt's SMO Algorithm for SVM Classifier Design. Neural Computation. 13(3):637-649.
82 * <p/>
83 <!-- globalinfo-end -->
84 *
85 <!-- technical-bibtex-start -->
86 * BibTeX:
87 * <pre>
88 * &#64;incollection{Platt1998,
89 *    author = {J. Platt},
90 *    booktitle = {Advances in Kernel Methods - Support Vector Learning},
91 *    editor = {B. Schoelkopf and C. Burges and A. Smola},
92 *    publisher = {MIT Press},
93 *    title = {Machines using Sequential Minimal Optimization},
94 *    year = {1998}
95 * }
96 *
97 * &#64;article{Keerthi2001,
98 *    author = {S.S. Keerthi and S.K. Shevade and C. Bhattacharyya and K.R.K. Murthy},
99 *    journal = {Neural Computation},
100 *    number = {3},
101 *    pages = {637-649},
102 *    title = {Improvements to Platt's SMO Algorithm for SVM Classifier Design},
103 *    volume = {13},
104 *    year = {2001}
105 * }
106 * </pre>
107 * <p/>
108 <!-- technical-bibtex-end -->
109 *
110 <!-- options-start -->
111 * Valid options are: <p/>
112 *
113 * <pre> -D
114 *  If set, classifier is run in debug mode and
115 *  may output additional info to the console</pre>
116 *
117 * <pre> -no-checks
118 *  Turns off all checks - use with caution!
119 *  Turning them off assumes that data is purely numeric, doesn't
120 *  contain any missing values, and has a nominal class. Turning them
121 *  off also means that no header information will be stored if the
122 *  machine is linear. Finally, it also assumes that no instance has
123 *  a weight equal to 0.
124 *  (default: checks on)</pre>
125 *
126 * <pre> -C &lt;double&gt;
127 *  The complexity constant C. (default 1)</pre>
128 *
129 * <pre> -N
130 *  Whether to 0=normalize/1=standardize/2=neither.
131 *  (default 0=normalize)</pre>
132 *
133 * <pre> -I
134 *  Use MIminimax feature space. </pre>
135 *
136 * <pre> -L &lt;double&gt;
137 *  The tolerance parameter. (default 1.0e-3)</pre>
138 *
139 * <pre> -P &lt;double&gt;
140 *  The epsilon for round-off error. (default 1.0e-12)</pre>
141 *
142 * <pre> -M
143 *  Fit logistic models to SVM outputs. </pre>
144 *
145 * <pre> -V &lt;double&gt;
146 *  The number of folds for the internal cross-validation.
147 *  (default -1, use training data)</pre>
148 *
149 * <pre> -W &lt;double&gt;
150 *  The random number seed. (default 1)</pre>
151 *
152 * <pre> -K &lt;classname and parameters&gt;
153 *  The Kernel to use.
154 *  (default: weka.classifiers.functions.supportVector.PolyKernel)</pre>
155 *
156 * <pre>
157 * Options specific to kernel weka.classifiers.mi.supportVector.MIPolyKernel:
158 * </pre>
159 *
160 * <pre> -D
161 *  Enables debugging output (if available) to be printed.
162 *  (default: off)</pre>
163 *
164 * <pre> -no-checks
165 *  Turns off all checks - use with caution!
166 *  (default: checks on)</pre>
167 *
168 * <pre> -C &lt;num&gt;
169 *  The size of the cache (a prime number), 0 for full cache and
170 *  -1 to turn it off.
171 *  (default: 250007)</pre>
172 *
173 * <pre> -E &lt;num&gt;
174 *  The Exponent to use.
175 *  (default: 1.0)</pre>
176 *
177 * <pre> -L
178 *  Use lower-order terms.
179 *  (default: no)</pre>
180 *
181 <!-- options-end -->
182 *
183 * @author Eibe Frank (eibe@cs.waikato.ac.nz)
184 * @author Shane Legg (shane@intelligenesis.net) (sparse vector code)
185 * @author Stuart Inglis (stuart@reeltwo.com) (sparse vector code)
186 * @author Lin Dong (ld21@cs.waikato.ac.nz) (code for adapting to MI data)
187 * @version $Revision: 5987 $
188 */
189public class MISMO 
190  extends AbstractClassifier
191  implements WeightedInstancesHandler, MultiInstanceCapabilitiesHandler,
192             TechnicalInformationHandler {
193
194  /** for serialization */
195  static final long serialVersionUID = -5834036950143719712L;
196 
197  /**
198   * Returns a string describing classifier
199   * @return a description suitable for
200   * displaying in the explorer/experimenter gui
201   */
202  public String globalInfo() {
203
204    return  "Implements John Platt's sequential minimal optimization "
205      + "algorithm for training a support vector classifier.\n\n"
206      + "This implementation globally replaces all missing values and "
207      + "transforms nominal attributes into binary ones. It also "
208      + "normalizes all attributes by default. (In that case the coefficients "
209      + "in the output are based on the normalized data, not the "
210      + "original data --- this is important for interpreting the classifier.)\n\n"
211      + "Multi-class problems are solved using pairwise classification.\n\n"
212      + "To obtain proper probability estimates, use the option that fits "
213      + "logistic regression models to the outputs of the support vector "
214      + "machine. In the multi-class case the predicted probabilities "
215      + "are coupled using Hastie and Tibshirani's pairwise coupling "
216      + "method.\n\n"
217      + "Note: for improved speed normalization should be turned off when "
218      + "operating on SparseInstances.\n\n"
219      + "For more information on the SMO algorithm, see\n\n"
220      + getTechnicalInformation().toString();
221  }
222
223  /**
224   * Returns an instance of a TechnicalInformation object, containing
225   * detailed information about the technical background of this class,
226   * e.g., paper reference or book this class is based on.
227   *
228   * @return the technical information about this class
229   */
230  public TechnicalInformation getTechnicalInformation() {
231    TechnicalInformation        result;
232    TechnicalInformation        additional;
233   
234    result = new TechnicalInformation(Type.INCOLLECTION);
235    result.setValue(Field.AUTHOR, "J. Platt");
236    result.setValue(Field.YEAR, "1998");
237    result.setValue(Field.TITLE, "Machines using Sequential Minimal Optimization");
238    result.setValue(Field.BOOKTITLE, "Advances in Kernel Methods - Support Vector Learning");
239    result.setValue(Field.EDITOR, "B. Schoelkopf and C. Burges and A. Smola");
240    result.setValue(Field.PUBLISHER, "MIT Press");
241   
242    additional = result.add(Type.ARTICLE);
243    additional.setValue(Field.AUTHOR, "S.S. Keerthi and S.K. Shevade and C. Bhattacharyya and K.R.K. Murthy");
244    additional.setValue(Field.YEAR, "2001");
245    additional.setValue(Field.TITLE, "Improvements to Platt's SMO Algorithm for SVM Classifier Design");
246    additional.setValue(Field.JOURNAL, "Neural Computation");
247    additional.setValue(Field.VOLUME, "13");
248    additional.setValue(Field.NUMBER, "3");
249    additional.setValue(Field.PAGES, "637-649");
250   
251    return result;
252  }
253
254  /**
255   * Class for building a binary support vector machine.
256   */
257  protected class BinaryMISMO 
258    implements Serializable, RevisionHandler {
259
260    /** for serialization */
261    static final long serialVersionUID = -7107082483475433531L;
262   
263    /** The Lagrange multipliers. */
264    protected double[] m_alpha;
265
266    /** The thresholds. */
267    protected double m_b, m_bLow, m_bUp;
268
269    /** The indices for m_bLow and m_bUp */
270    protected int m_iLow, m_iUp;
271
272    /** The training data. */
273    protected Instances m_data;
274
275    /** Weight vector for linear machine. */
276    protected double[] m_weights;
277
278    /** Variables to hold weight vector in sparse form.
279      (To reduce storage requirements.) */
280    protected double[] m_sparseWeights;
281    protected int[] m_sparseIndices;
282
283    /** Kernel to use **/
284    protected Kernel m_kernel;
285
286    /** The transformed class values. */
287    protected double[] m_class;
288
289    /** The current set of errors for all non-bound examples. */
290    protected double[] m_errors;
291
292    /* The five different sets used by the algorithm. */
293    /** {i: 0 < m_alpha[i] < C} */
294    protected SMOset m_I0;
295    /** {i: m_class[i] = 1, m_alpha[i] = 0} */
296    protected SMOset m_I1; 
297    /** {i: m_class[i] = -1, m_alpha[i] = C} */
298    protected SMOset m_I2; 
299    /** {i: m_class[i] = 1, m_alpha[i] = C} */
300    protected SMOset m_I3; 
301    /** {i: m_class[i] = -1, m_alpha[i] = 0} */
302    protected SMOset m_I4; 
303
304    /** The set of support vectors {i: 0 < m_alpha[i]} */
305    protected SMOset m_supportVectors;
306
307    /** Stores logistic regression model for probability estimate */
308    protected Logistic m_logistic = null;
309
310    /** Stores the weight of the training instances */
311    protected double m_sumOfWeights = 0;
312
313    /**
314     * Fits logistic regression model to SVM outputs analogue
315     * to John Platt's method. 
316     *
317     * @param insts the set of training instances
318     * @param cl1 the first class' index
319     * @param cl2 the second class' index
320     * @param numFolds the number of folds for cross-validation
321     * @param random the random number generator for cross-validation
322     * @throws Exception if the sigmoid can't be fit successfully
323     */
324    protected void fitLogistic(Instances insts, int cl1, int cl2,
325        int numFolds, Random random) 
326      throws Exception {
327
328      // Create header of instances object
329      FastVector atts = new FastVector(2);
330      atts.addElement(new Attribute("pred"));
331      FastVector attVals = new FastVector(2);
332      attVals.addElement(insts.classAttribute().value(cl1));
333      attVals.addElement(insts.classAttribute().value(cl2));
334      atts.addElement(new Attribute("class", attVals));
335      Instances data = new Instances("data", atts, insts.numInstances());
336      data.setClassIndex(1);
337
338      // Collect data for fitting the logistic model
339      if (numFolds <= 0) {
340
341        // Use training data
342        for (int j = 0; j < insts.numInstances(); j++) {
343          Instance inst = insts.instance(j);
344          double[] vals = new double[2];
345          vals[0] = SVMOutput(-1, inst);
346          if (inst.classValue() == cl2) {
347            vals[1] = 1;
348          }
349          data.add(new DenseInstance(inst.weight(), vals));
350        }
351      } else {
352
353        // Check whether number of folds too large
354        if (numFolds > insts.numInstances()) {
355          numFolds = insts.numInstances();
356        }
357
358        // Make copy of instances because we will shuffle them around
359        insts = new Instances(insts);
360
361        // Perform three-fold cross-validation to collect
362        // unbiased predictions
363        insts.randomize(random);
364        insts.stratify(numFolds);
365        for (int i = 0; i < numFolds; i++) {
366          Instances train = insts.trainCV(numFolds, i, random);
367          SerializedObject so = new SerializedObject(this);
368          BinaryMISMO smo = (BinaryMISMO)so.getObject();
369          smo.buildClassifier(train, cl1, cl2, false, -1, -1);
370          Instances test = insts.testCV(numFolds, i);
371          for (int j = 0; j < test.numInstances(); j++) {
372            double[] vals = new double[2];
373            vals[0] = smo.SVMOutput(-1, test.instance(j));
374            if (test.instance(j).classValue() == cl2) {
375              vals[1] = 1;
376            }
377            data.add(new DenseInstance(test.instance(j).weight(), vals));
378          }
379        }
380      }
381
382      // Build logistic regression model
383      m_logistic = new Logistic();
384      m_logistic.buildClassifier(data);
385    }
386   
387    /**
388     * sets the kernel to use
389     *
390     * @param value     the kernel to use
391     */
392    public void setKernel(Kernel value) {
393      m_kernel = value;
394    }
395   
396    /**
397     * Returns the kernel to use
398     *
399     * @return          the current kernel
400     */
401    public Kernel getKernel() {
402      return m_kernel;
403    }
404
405    /**
406     * Method for building the binary classifier.
407     *
408     * @param insts the set of training instances
409     * @param cl1 the first class' index
410     * @param cl2 the second class' index
411     * @param fitLogistic true if logistic model is to be fit
412     * @param numFolds number of folds for internal cross-validation
413     * @param randomSeed seed value for random number generator for cross-validation
414     * @throws Exception if the classifier can't be built successfully
415     */
416    protected void buildClassifier(Instances insts, int cl1, int cl2,
417        boolean fitLogistic, int numFolds,
418        int randomSeed) throws Exception {
419
420      // Initialize some variables
421      m_bUp = -1; m_bLow = 1; m_b = 0; 
422      m_alpha = null; m_data = null; m_weights = null; m_errors = null;
423      m_logistic = null; m_I0 = null; m_I1 = null; m_I2 = null;
424      m_I3 = null; m_I4 = null; m_sparseWeights = null; m_sparseIndices = null;
425
426      // Store the sum of weights
427      m_sumOfWeights = insts.sumOfWeights();
428
429      // Set class values
430      m_class = new double[insts.numInstances()];
431      m_iUp = -1; m_iLow = -1;
432      for (int i = 0; i < m_class.length; i++) {
433        if ((int) insts.instance(i).classValue() == cl1) {
434          m_class[i] = -1; m_iLow = i;
435        } else if ((int) insts.instance(i).classValue() == cl2) {
436          m_class[i] = 1; m_iUp = i;
437        } else {
438          throw new Exception ("This should never happen!");
439        }
440      }
441
442      // Check whether one or both classes are missing
443      if ((m_iUp == -1) || (m_iLow == -1)) {
444        if (m_iUp != -1) {
445          m_b = -1;
446        } else if (m_iLow != -1) {
447          m_b = 1;
448        } else {
449          m_class = null;
450          return;
451        }
452        m_supportVectors = new SMOset(0);
453        m_alpha = new double[0];
454        m_class = new double[0];
455
456        // Fit sigmoid if requested
457        if (fitLogistic) {
458          fitLogistic(insts, cl1, cl2, numFolds, new Random(randomSeed));
459        }
460        return;
461      }
462
463      // Set the reference to the data
464      m_data = insts;
465      m_weights = null;
466
467      // Initialize alpha array to zero
468      m_alpha = new double[m_data.numInstances()];
469
470      // Initialize sets
471      m_supportVectors = new SMOset(m_data.numInstances());
472      m_I0 = new SMOset(m_data.numInstances());
473      m_I1 = new SMOset(m_data.numInstances());
474      m_I2 = new SMOset(m_data.numInstances());
475      m_I3 = new SMOset(m_data.numInstances());
476      m_I4 = new SMOset(m_data.numInstances());
477
478      // Clean out some instance variables
479      m_sparseWeights = null;
480      m_sparseIndices = null;
481
482      // Initialize error cache
483      m_errors = new double[m_data.numInstances()];
484      m_errors[m_iLow] = 1; m_errors[m_iUp] = -1;
485
486      // Initialize kernel
487      m_kernel.buildKernel(m_data);
488
489      // Build up I1 and I4
490      for (int i = 0; i < m_class.length; i++ ) {
491        if (m_class[i] == 1) {
492          m_I1.insert(i);
493        } else {
494          m_I4.insert(i);
495        }
496      }
497
498      // Loop to find all the support vectors
499      int numChanged = 0;
500      boolean examineAll = true;
501      while ((numChanged > 0) || examineAll) {
502        numChanged = 0;
503        if (examineAll) {
504          for (int i = 0; i < m_alpha.length; i++) {
505            if (examineExample(i)) {
506              numChanged++;
507            }
508          }
509        } else {
510
511          // This code implements Modification 1 from Keerthi et al.'s paper
512          for (int i = 0; i < m_alpha.length; i++) {
513            if ((m_alpha[i] > 0) && 
514                (m_alpha[i] < m_C * m_data.instance(i).weight())) {
515              if (examineExample(i)) {
516                numChanged++;
517              }
518
519              // Is optimality on unbound vectors obtained?
520              if (m_bUp > m_bLow - 2 * m_tol) {
521                numChanged = 0;
522                break;
523              }
524                }
525          }
526
527          //This is the code for Modification 2 from Keerthi et al.'s paper
528          /*boolean innerLoopSuccess = true;
529            numChanged = 0;
530            while ((m_bUp < m_bLow - 2 * m_tol) && (innerLoopSuccess == true)) {
531            innerLoopSuccess = takeStep(m_iUp, m_iLow, m_errors[m_iLow]);
532            }*/
533        }
534
535        if (examineAll) {
536          examineAll = false;
537        } else if (numChanged == 0) {
538          examineAll = true;
539        }
540      }
541
542      // Set threshold
543      m_b = (m_bLow + m_bUp) / 2.0;
544
545      // Save memory
546      m_kernel.clean(); 
547
548      m_errors = null;
549      m_I0 = m_I1 = m_I2 = m_I3 = m_I4 = null;
550
551      // Fit sigmoid if requested
552      if (fitLogistic) {
553        fitLogistic(insts, cl1, cl2, numFolds, new Random(randomSeed));
554      }
555
556    }
557
558    /**
559     * Computes SVM output for given instance.
560     *
561     * @param index the instance for which output is to be computed
562     * @param inst the instance
563     * @return the output of the SVM for the given instance
564     * @throws Exception if something goes wrong
565     */
566    protected double SVMOutput(int index, Instance inst) throws Exception {
567
568      double result = 0;
569
570      for (int i = m_supportVectors.getNext(-1); i != -1; 
571          i = m_supportVectors.getNext(i)) {
572        result += m_class[i] * m_alpha[i] * m_kernel.eval(index, i, inst);
573      }
574      result -= m_b;
575
576      return result;
577    }
578
579    /**
580     * Prints out the classifier.
581     *
582     * @return a description of the classifier as a string
583     */
584    public String toString() {
585
586      StringBuffer text = new StringBuffer();
587      int printed = 0;
588
589      if ((m_alpha == null) && (m_sparseWeights == null)) {
590        return "BinaryMISMO: No model built yet.\n";
591      }
592      try {
593        text.append("BinaryMISMO\n\n");
594
595        for (int i = 0; i < m_alpha.length; i++) {
596          if (m_supportVectors.contains(i)) {
597            double val = m_alpha[i];
598            if (m_class[i] == 1) {
599              if (printed > 0) {
600                text.append(" + ");
601              }
602            } else {
603              text.append(" - ");
604            }
605            text.append(Utils.doubleToString(val, 12, 4) 
606                + " * <");
607            for (int j = 0; j < m_data.numAttributes(); j++) {
608              if (j != m_data.classIndex()) {
609                text.append(m_data.instance(i).toString(j));
610              }
611              if (j != m_data.numAttributes() - 1) {
612                text.append(" ");
613              }
614            }
615            text.append("> * X]\n");
616            printed++;
617          }
618        }
619
620        if (m_b > 0) {
621          text.append(" - " + Utils.doubleToString(m_b, 12, 4));
622        } else {
623          text.append(" + " + Utils.doubleToString(-m_b, 12, 4));
624        }
625
626        text.append("\n\nNumber of support vectors: " + 
627            m_supportVectors.numElements());
628        int numEval = 0;
629        int numCacheHits = -1;
630        if(m_kernel != null)
631        {
632          numEval = m_kernel.numEvals();
633          numCacheHits = m_kernel.numCacheHits();
634        }
635        text.append("\n\nNumber of kernel evaluations: " + numEval);
636        if (numCacheHits >= 0 && numEval > 0)
637        {
638          double hitRatio = 1 - numEval*1.0/(numCacheHits+numEval);
639          text.append(" (" + Utils.doubleToString(hitRatio*100, 7, 3).trim() + "% cached)");
640        }
641
642      } catch (Exception e) {
643        e.printStackTrace();
644
645        return "Can't print BinaryMISMO classifier.";
646      }
647
648      return text.toString();
649    }
650
651    /**
652     * Examines instance.
653     *
654     * @param i2 index of instance to examine
655     * @return true if examination was successfull
656     * @throws Exception if something goes wrong
657     */
658    protected boolean examineExample(int i2) throws Exception {
659
660      double y2, F2;
661      int i1 = -1;
662
663      y2 = m_class[i2];
664      if (m_I0.contains(i2)) {
665        F2 = m_errors[i2];
666      } else { 
667        F2 = SVMOutput(i2, m_data.instance(i2)) + m_b - y2;
668        m_errors[i2] = F2;
669
670        // Update thresholds
671        if ((m_I1.contains(i2) || m_I2.contains(i2)) && (F2 < m_bUp)) {
672          m_bUp = F2; m_iUp = i2;
673        } else if ((m_I3.contains(i2) || m_I4.contains(i2)) && (F2 > m_bLow)) {
674          m_bLow = F2; m_iLow = i2;
675        }
676      }
677
678      // Check optimality using current bLow and bUp and, if
679      // violated, find an index i1 to do joint optimization
680      // with i2...
681      boolean optimal = true;
682      if (m_I0.contains(i2) || m_I1.contains(i2) || m_I2.contains(i2)) {
683        if (m_bLow - F2 > 2 * m_tol) {
684          optimal = false; i1 = m_iLow;
685        }
686      }
687      if (m_I0.contains(i2) || m_I3.contains(i2) || m_I4.contains(i2)) {
688        if (F2 - m_bUp > 2 * m_tol) {
689          optimal = false; i1 = m_iUp;
690        }
691      }
692      if (optimal) {
693        return false;
694      }
695
696      // For i2 unbound choose the better i1...
697      if (m_I0.contains(i2)) {
698        if (m_bLow - F2 > F2 - m_bUp) {
699          i1 = m_iLow;
700        } else {
701          i1 = m_iUp;
702        }
703      }
704      if (i1 == -1) {
705        throw new Exception("This should never happen!");
706      }
707      return takeStep(i1, i2, F2);
708    }
709
710    /**
711     * Method solving for the Lagrange multipliers for
712     * two instances.
713     *
714     * @param i1 index of the first instance
715     * @param i2 index of the second instance
716     * @param F2
717     * @return true if multipliers could be found
718     * @throws Exception if something goes wrong
719     */
720    protected boolean takeStep(int i1, int i2, double F2) throws Exception {
721
722      double alph1, alph2, y1, y2, F1, s, L, H, k11, k12, k22, eta,
723             a1, a2, f1, f2, v1, v2, Lobj, Hobj;
724      double C1 = m_C * m_data.instance(i1).weight();
725      double C2 = m_C * m_data.instance(i2).weight();
726
727      // Don't do anything if the two instances are the same
728      if (i1 == i2) {
729        return false;
730      }
731
732      // Initialize variables
733      alph1 = m_alpha[i1]; alph2 = m_alpha[i2];
734      y1 = m_class[i1]; y2 = m_class[i2];
735      F1 = m_errors[i1];
736      s = y1 * y2;
737
738      // Find the constraints on a2
739      if (y1 != y2) {
740        L = Math.max(0, alph2 - alph1); 
741        H = Math.min(C2, C1 + alph2 - alph1);
742      } else {
743        L = Math.max(0, alph1 + alph2 - C1);
744        H = Math.min(C2, alph1 + alph2);
745      }
746      if (L >= H) {
747        return false;
748      }
749
750      // Compute second derivative of objective function
751      k11 = m_kernel.eval(i1, i1, m_data.instance(i1));
752      k12 = m_kernel.eval(i1, i2, m_data.instance(i1));
753      k22 = m_kernel.eval(i2, i2, m_data.instance(i2));
754      eta = 2 * k12 - k11 - k22;
755
756      // Check if second derivative is negative
757      if (eta < 0) {
758
759        // Compute unconstrained maximum
760        a2 = alph2 - y2 * (F1 - F2) / eta;
761
762        // Compute constrained maximum
763        if (a2 < L) {
764          a2 = L;
765        } else if (a2 > H) {
766          a2 = H;
767        }
768      } else {
769
770        // Look at endpoints of diagonal
771        f1 = SVMOutput(i1, m_data.instance(i1));
772        f2 = SVMOutput(i2, m_data.instance(i2));
773        v1 = f1 + m_b - y1 * alph1 * k11 - y2 * alph2 * k12; 
774        v2 = f2 + m_b - y1 * alph1 * k12 - y2 * alph2 * k22; 
775        double gamma = alph1 + s * alph2;
776        Lobj = (gamma - s * L) + L - 0.5 * k11 * (gamma - s * L) * (gamma - s * L) - 
777          0.5 * k22 * L * L - s * k12 * (gamma - s * L) * L - 
778          y1 * (gamma - s * L) * v1 - y2 * L * v2;
779        Hobj = (gamma - s * H) + H - 0.5 * k11 * (gamma - s * H) * (gamma - s * H) - 
780          0.5 * k22 * H * H - s * k12 * (gamma - s * H) * H - 
781          y1 * (gamma - s * H) * v1 - y2 * H * v2;
782        if (Lobj > Hobj + m_eps) {
783          a2 = L;
784        } else if (Lobj < Hobj - m_eps) {
785          a2 = H;
786        } else {
787          a2 = alph2;
788        }
789      }
790      if (Math.abs(a2 - alph2) < m_eps * (a2 + alph2 + m_eps)) {
791        return false;
792      }
793
794      // To prevent precision problems
795      if (a2 > C2 - m_Del * C2) {
796        a2 = C2;
797      } else if (a2 <= m_Del * C2) {
798        a2 = 0;
799      }
800
801      // Recompute a1
802      a1 = alph1 + s * (alph2 - a2);
803
804      // To prevent precision problems
805      if (a1 > C1 - m_Del * C1) {
806        a1 = C1;
807      } else if (a1 <= m_Del * C1) {
808        a1 = 0;
809      }
810
811      // Update sets
812      if (a1 > 0) {
813        m_supportVectors.insert(i1);
814      } else {
815        m_supportVectors.delete(i1);
816      }
817      if ((a1 > 0) && (a1 < C1)) {
818        m_I0.insert(i1);
819      } else {
820        m_I0.delete(i1);
821      }
822      if ((y1 == 1) && (a1 == 0)) {
823        m_I1.insert(i1);
824      } else {
825        m_I1.delete(i1);
826      }
827      if ((y1 == -1) && (a1 == C1)) {
828        m_I2.insert(i1);
829      } else {
830        m_I2.delete(i1);
831      }
832      if ((y1 == 1) && (a1 == C1)) {
833        m_I3.insert(i1);
834      } else {
835        m_I3.delete(i1);
836      }
837      if ((y1 == -1) && (a1 == 0)) {
838        m_I4.insert(i1);
839      } else {
840        m_I4.delete(i1);
841      }
842      if (a2 > 0) {
843        m_supportVectors.insert(i2);
844      } else {
845        m_supportVectors.delete(i2);
846      }
847      if ((a2 > 0) && (a2 < C2)) {
848        m_I0.insert(i2);
849      } else {
850        m_I0.delete(i2);
851      }
852      if ((y2 == 1) && (a2 == 0)) {
853        m_I1.insert(i2);
854      } else {
855        m_I1.delete(i2);
856      }
857      if ((y2 == -1) && (a2 == C2)) {
858        m_I2.insert(i2);
859      } else {
860        m_I2.delete(i2);
861      }
862      if ((y2 == 1) && (a2 == C2)) {
863        m_I3.insert(i2);
864      } else {
865        m_I3.delete(i2);
866      }
867      if ((y2 == -1) && (a2 == 0)) {
868        m_I4.insert(i2);
869      } else {
870        m_I4.delete(i2);
871      }
872
873      // Update error cache using new Lagrange multipliers
874      for (int j = m_I0.getNext(-1); j != -1; j = m_I0.getNext(j)) {
875        if ((j != i1) && (j != i2)) {
876          m_errors[j] += 
877            y1 * (a1 - alph1) * m_kernel.eval(i1, j, m_data.instance(i1)) + 
878            y2 * (a2 - alph2) * m_kernel.eval(i2, j, m_data.instance(i2));
879        }
880      }
881
882      // Update error cache for i1 and i2
883      m_errors[i1] += y1 * (a1 - alph1) * k11 + y2 * (a2 - alph2) * k12;
884      m_errors[i2] += y1 * (a1 - alph1) * k12 + y2 * (a2 - alph2) * k22;
885
886      // Update array with Lagrange multipliers
887      m_alpha[i1] = a1;
888      m_alpha[i2] = a2;
889
890      // Update thresholds
891      m_bLow = -Double.MAX_VALUE; m_bUp = Double.MAX_VALUE;
892      m_iLow = -1; m_iUp = -1;
893      for (int j = m_I0.getNext(-1); j != -1; j = m_I0.getNext(j)) {
894        if (m_errors[j] < m_bUp) {
895          m_bUp = m_errors[j]; m_iUp = j;
896        }
897        if (m_errors[j] > m_bLow) {
898          m_bLow = m_errors[j]; m_iLow = j;
899        }
900      }
901      if (!m_I0.contains(i1)) {
902        if (m_I3.contains(i1) || m_I4.contains(i1)) {
903          if (m_errors[i1] > m_bLow) {
904            m_bLow = m_errors[i1]; m_iLow = i1;
905          } 
906        } else {
907          if (m_errors[i1] < m_bUp) {
908            m_bUp = m_errors[i1]; m_iUp = i1;
909          }
910        }
911      }
912      if (!m_I0.contains(i2)) {
913        if (m_I3.contains(i2) || m_I4.contains(i2)) {
914          if (m_errors[i2] > m_bLow) {
915            m_bLow = m_errors[i2]; m_iLow = i2;
916          }
917        } else {
918          if (m_errors[i2] < m_bUp) {
919            m_bUp = m_errors[i2]; m_iUp = i2;
920          }
921        }
922      }
923      if ((m_iLow == -1) || (m_iUp == -1)) {
924        throw new Exception("This should never happen!");
925      }
926
927      // Made some progress.
928      return true;
929    }
930
931    /**
932     * Quick and dirty check whether the quadratic programming problem is solved.
933     *
934     * @throws Exception if something goes wrong
935     */
936    protected void checkClassifier() throws Exception {
937
938      double sum = 0;
939      for (int i = 0; i < m_alpha.length; i++) {
940        if (m_alpha[i] > 0) {
941          sum += m_class[i] * m_alpha[i];
942        }
943      }
944      System.err.println("Sum of y(i) * alpha(i): " + sum);
945
946      for (int i = 0; i < m_alpha.length; i++) {
947        double output = SVMOutput(i, m_data.instance(i));
948        if (Utils.eq(m_alpha[i], 0)) {
949          if (Utils.sm(m_class[i] * output, 1)) {
950            System.err.println("KKT condition 1 violated: " + m_class[i] * output);
951          }
952        } 
953        if (Utils.gr(m_alpha[i], 0) && 
954            Utils.sm(m_alpha[i], m_C * m_data.instance(i).weight())) {
955          if (!Utils.eq(m_class[i] * output, 1)) {
956            System.err.println("KKT condition 2 violated: " + m_class[i] * output);
957          }
958            } 
959        if (Utils.eq(m_alpha[i], m_C * m_data.instance(i).weight())) {
960          if (Utils.gr(m_class[i] * output, 1)) {
961            System.err.println("KKT condition 3 violated: " + m_class[i] * output);
962          }
963        } 
964      }
965    } 
966   
967    /**
968     * Returns the revision string.
969     *
970     * @return          the revision
971     */
972    public String getRevision() {
973      return RevisionUtils.extract("$Revision: 5987 $");
974    }
975  }
976
977  /** Normalize training data */
978  public static final int FILTER_NORMALIZE = 0;
979  /** Standardize training data */
980  public static final int FILTER_STANDARDIZE = 1;
981  /** No normalization/standardization */
982  public static final int FILTER_NONE = 2;
983  /** The filter to apply to the training data */
984  public static final Tag [] TAGS_FILTER = {
985    new Tag(FILTER_NORMALIZE, "Normalize training data"),
986    new Tag(FILTER_STANDARDIZE, "Standardize training data"),
987    new Tag(FILTER_NONE, "No normalization/standardization"),
988  };
989
990  /** The binary classifier(s) */
991  protected BinaryMISMO[][] m_classifiers = null;
992
993  /** The complexity parameter. */
994  protected double m_C = 1.0;
995
996  /** Epsilon for rounding. */
997  protected double m_eps = 1.0e-12;
998
999  /** Tolerance for accuracy of result. */
1000  protected double m_tol = 1.0e-3;
1001
1002  /** Whether to normalize/standardize/neither */
1003  protected int m_filterType = FILTER_NORMALIZE;
1004
1005  /** Use MIMinimax feature space?  */
1006  protected boolean m_minimax = false;   
1007
1008  /** The filter used to make attributes numeric. */
1009  protected NominalToBinary m_NominalToBinary;
1010
1011  /** The filter used to standardize/normalize all values. */
1012  protected Filter m_Filter = null;
1013
1014  /** The filter used to get rid of missing values. */
1015  protected ReplaceMissingValues m_Missing;
1016
1017  /** The class index from the training data */
1018  protected int m_classIndex = -1;
1019
1020  /** The class attribute */
1021  protected Attribute m_classAttribute;
1022 
1023  /** Kernel to use **/
1024  protected Kernel m_kernel = new MIPolyKernel();
1025
1026  /** Turn off all checks and conversions? Turning them off assumes
1027    that data is purely numeric, doesn't contain any missing values,
1028    and has a nominal class. Turning them off also means that
1029    no header information will be stored if the machine is linear.
1030    Finally, it also assumes that no instance has a weight equal to 0.*/
1031  protected boolean m_checksTurnedOff;
1032
1033  /** Precision constant for updating sets */
1034  protected static double m_Del = 1000 * Double.MIN_VALUE;
1035
1036  /** Whether logistic models are to be fit */
1037  protected boolean m_fitLogisticModels = false;
1038
1039  /** The number of folds for the internal cross-validation */
1040  protected int m_numFolds = -1;
1041
1042  /** The random number seed  */
1043  protected int m_randomSeed = 1;
1044
1045  /**
1046   * Turns off checks for missing values, etc. Use with caution.
1047   */
1048  public void turnChecksOff() {
1049
1050    m_checksTurnedOff = true;
1051  }
1052
1053  /**
1054   * Turns on checks for missing values, etc.
1055   */
1056  public void turnChecksOn() {
1057
1058    m_checksTurnedOff = false;
1059  }
1060
1061  /**
1062   * Returns default capabilities of the classifier.
1063   *
1064   * @return      the capabilities of this classifier
1065   */
1066  public Capabilities getCapabilities() {
1067    Capabilities result = getKernel().getCapabilities();
1068    result.setOwner(this);
1069
1070    // attributes
1071    result.enable(Capability.NOMINAL_ATTRIBUTES);
1072    result.enable(Capability.RELATIONAL_ATTRIBUTES);
1073    result.enable(Capability.MISSING_VALUES);
1074
1075    // class
1076    result.disableAllClasses();
1077    result.disableAllClassDependencies();
1078    result.enable(Capability.NOMINAL_CLASS);
1079    result.enable(Capability.MISSING_CLASS_VALUES);
1080   
1081    // other
1082    result.enable(Capability.ONLY_MULTIINSTANCE);
1083   
1084    return result;
1085  }
1086
1087  /**
1088   * Returns the capabilities of this multi-instance classifier for the
1089   * relational data.
1090   *
1091   * @return            the capabilities of this object
1092   * @see               Capabilities
1093   */
1094  public Capabilities getMultiInstanceCapabilities() {
1095    Capabilities result = ((MultiInstanceCapabilitiesHandler) getKernel()).getMultiInstanceCapabilities();
1096    result.setOwner(this);
1097
1098    // attribute
1099    result.enableAllAttributeDependencies();
1100    // with NominalToBinary we can also handle nominal attributes, but only
1101    // if the kernel can handle numeric attributes
1102    if (result.handles(Capability.NUMERIC_ATTRIBUTES))
1103      result.enable(Capability.NOMINAL_ATTRIBUTES);
1104    result.enable(Capability.MISSING_VALUES);
1105   
1106    return result;
1107  }
1108
1109  /**
1110   * Method for building the classifier. Implements a one-against-one
1111   * wrapper for multi-class problems.
1112   *
1113   * @param insts the set of training instances
1114   * @throws Exception if the classifier can't be built successfully
1115   */
1116  public void buildClassifier(Instances insts) throws Exception {
1117    if (!m_checksTurnedOff) {
1118      // can classifier handle the data?
1119      getCapabilities().testWithFail(insts);
1120
1121      // remove instances with missing class
1122      insts = new Instances(insts);
1123      insts.deleteWithMissingClass();
1124
1125      /* Removes all the instances with weight equal to 0.
1126         MUST be done since condition (8) of Keerthi's paper
1127         is made with the assertion Ci > 0 (See equation (3a). */
1128      Instances data = new Instances(insts, insts.numInstances());
1129      for(int i = 0; i < insts.numInstances(); i++){
1130        if(insts.instance(i).weight() > 0)
1131          data.add(insts.instance(i));
1132      }
1133      if (data.numInstances() == 0) {
1134        throw new Exception("No training instances left after removing " + 
1135            "instance with either a weight null or a missing class!");
1136      }
1137      insts = data;     
1138    }
1139
1140    // filter data
1141    if (!m_checksTurnedOff) 
1142      m_Missing = new ReplaceMissingValues();
1143    else 
1144      m_Missing = null;
1145
1146    if (getCapabilities().handles(Capability.NUMERIC_ATTRIBUTES)) {
1147      boolean onlyNumeric = true;
1148      if (!m_checksTurnedOff) {
1149        for (int i = 0; i < insts.numAttributes(); i++) {
1150          if (i != insts.classIndex()) {
1151            if (!insts.attribute(i).isNumeric()) {
1152              onlyNumeric = false;
1153              break;
1154            }
1155          }
1156        }
1157      }
1158     
1159      if (!onlyNumeric) {
1160        m_NominalToBinary = new NominalToBinary();
1161        // exclude the bag attribute
1162        m_NominalToBinary.setAttributeIndices("2-last");
1163      }
1164      else {
1165        m_NominalToBinary = null;
1166      }
1167    }
1168    else {
1169      m_NominalToBinary = null;
1170    }
1171
1172    if (m_filterType == FILTER_STANDARDIZE) 
1173      m_Filter = new Standardize();
1174    else if (m_filterType == FILTER_NORMALIZE)
1175      m_Filter = new Normalize();
1176    else 
1177      m_Filter = null;
1178
1179
1180    Instances transformedInsts;
1181    Filter convertToProp = new MultiInstanceToPropositional();
1182    Filter convertToMI = new PropositionalToMultiInstance();
1183
1184    //transform the data into single-instance format
1185    if (m_minimax){ 
1186      /* using SimpleMI class minimax transform method.
1187         this method transforms the multi-instance dataset into minmax feature space (single-instance) */
1188      SimpleMI transMinimax = new SimpleMI();
1189      transMinimax.setTransformMethod(
1190          new SelectedTag(
1191            SimpleMI.TRANSFORMMETHOD_MINIMAX, SimpleMI.TAGS_TRANSFORMMETHOD));
1192      transformedInsts = transMinimax.transform(insts);
1193    }
1194    else { 
1195      convertToProp.setInputFormat(insts);
1196      transformedInsts=Filter.useFilter(insts, convertToProp);
1197    }
1198
1199    if (m_Missing != null) {
1200      m_Missing.setInputFormat(transformedInsts);
1201      transformedInsts = Filter.useFilter(transformedInsts, m_Missing); 
1202    }
1203
1204    if (m_NominalToBinary != null) { 
1205      m_NominalToBinary.setInputFormat(transformedInsts);
1206      transformedInsts = Filter.useFilter(transformedInsts, m_NominalToBinary); 
1207    }
1208
1209    if (m_Filter != null) {
1210      m_Filter.setInputFormat(transformedInsts);
1211      transformedInsts = Filter.useFilter(transformedInsts, m_Filter); 
1212    }
1213
1214    // convert the single-instance format to multi-instance format
1215    convertToMI.setInputFormat(transformedInsts);
1216    insts = Filter.useFilter( transformedInsts, convertToMI);
1217
1218    m_classIndex = insts.classIndex();
1219    m_classAttribute = insts.classAttribute();
1220
1221    // Generate subsets representing each class
1222    Instances[] subsets = new Instances[insts.numClasses()];
1223    for (int i = 0; i < insts.numClasses(); i++) {
1224      subsets[i] = new Instances(insts, insts.numInstances());
1225    }
1226    for (int j = 0; j < insts.numInstances(); j++) {
1227      Instance inst = insts.instance(j);
1228      subsets[(int)inst.classValue()].add(inst);
1229    }
1230    for (int i = 0; i < insts.numClasses(); i++) {
1231      subsets[i].compactify();
1232    }
1233
1234    // Build the binary classifiers
1235    Random rand = new Random(m_randomSeed);
1236    m_classifiers = new BinaryMISMO[insts.numClasses()][insts.numClasses()];
1237    for (int i = 0; i < insts.numClasses(); i++) {
1238      for (int j = i + 1; j < insts.numClasses(); j++) {
1239        m_classifiers[i][j] = new BinaryMISMO(); 
1240        m_classifiers[i][j].setKernel(Kernel.makeCopy(getKernel()));
1241        Instances data = new Instances(insts, insts.numInstances());
1242        for (int k = 0; k < subsets[i].numInstances(); k++) {
1243          data.add(subsets[i].instance(k));
1244        }
1245        for (int k = 0; k < subsets[j].numInstances(); k++) {
1246          data.add(subsets[j].instance(k));
1247        } 
1248        data.compactify(); 
1249        data.randomize(rand);
1250        m_classifiers[i][j].buildClassifier(data, i, j, 
1251            m_fitLogisticModels,
1252            m_numFolds, m_randomSeed);
1253      }
1254    } 
1255
1256  }
1257
1258  /**
1259   * Estimates class probabilities for given instance.
1260   *
1261   * @param inst the instance to compute the distribution for
1262   * @return the class probabilities
1263   * @throws Exception if computation fails
1264   */
1265  public double[] distributionForInstance(Instance inst) throws Exception { 
1266
1267    //convert instance into instances
1268    Instances insts = new Instances(inst.dataset(), 0);
1269    insts.add(inst);
1270
1271    //transform the data into single-instance format
1272    Filter convertToProp = new MultiInstanceToPropositional();
1273    Filter convertToMI = new PropositionalToMultiInstance();
1274
1275    if (m_minimax){ // using minimax feature space
1276      SimpleMI transMinimax = new SimpleMI();
1277      transMinimax.setTransformMethod(
1278          new SelectedTag(
1279            SimpleMI.TRANSFORMMETHOD_MINIMAX, SimpleMI.TAGS_TRANSFORMMETHOD));
1280      insts = transMinimax.transform (insts);
1281    }
1282    else{
1283      convertToProp.setInputFormat(insts);
1284      insts=Filter.useFilter( insts, convertToProp);
1285    }
1286
1287    // Filter instances
1288    if (m_Missing!=null) 
1289      insts = Filter.useFilter(insts, m_Missing); 
1290
1291    if (m_Filter!=null)
1292      insts = Filter.useFilter(insts, m_Filter);     
1293
1294    // convert the single-instance format to multi-instance format
1295    convertToMI.setInputFormat(insts);
1296    insts=Filter.useFilter( insts, convertToMI);
1297
1298    inst = insts.instance(0); 
1299
1300    if (!m_fitLogisticModels) {
1301      double[] result = new double[inst.numClasses()];
1302      for (int i = 0; i < inst.numClasses(); i++) {
1303        for (int j = i + 1; j < inst.numClasses(); j++) {
1304          if ((m_classifiers[i][j].m_alpha != null) || 
1305              (m_classifiers[i][j].m_sparseWeights != null)) {
1306            double output = m_classifiers[i][j].SVMOutput(-1, inst);
1307            if (output > 0) {
1308              result[j] += 1;
1309            } else {
1310              result[i] += 1;
1311            }
1312              }
1313        } 
1314      }
1315      Utils.normalize(result);
1316      return result;
1317    } else {
1318
1319      // We only need to do pairwise coupling if there are more
1320      // then two classes.
1321      if (inst.numClasses() == 2) {
1322        double[] newInst = new double[2];
1323        newInst[0] = m_classifiers[0][1].SVMOutput(-1, inst);
1324        newInst[1] = Utils.missingValue();
1325        return m_classifiers[0][1].m_logistic.
1326          distributionForInstance(new DenseInstance(1, newInst));
1327      }
1328      double[][] r = new double[inst.numClasses()][inst.numClasses()];
1329      double[][] n = new double[inst.numClasses()][inst.numClasses()];
1330      for (int i = 0; i < inst.numClasses(); i++) {
1331        for (int j = i + 1; j < inst.numClasses(); j++) {
1332          if ((m_classifiers[i][j].m_alpha != null) || 
1333              (m_classifiers[i][j].m_sparseWeights != null)) {
1334            double[] newInst = new double[2];
1335            newInst[0] = m_classifiers[i][j].SVMOutput(-1, inst);
1336            newInst[1] = Utils.missingValue();
1337            r[i][j] = m_classifiers[i][j].m_logistic.
1338              distributionForInstance(new DenseInstance(1, newInst))[0];
1339            n[i][j] = m_classifiers[i][j].m_sumOfWeights;
1340              }
1341        }
1342      }
1343      return pairwiseCoupling(n, r);
1344    }
1345  }
1346
1347  /**
1348   * Implements pairwise coupling.
1349   *
1350   * @param n the sum of weights used to train each model
1351   * @param r the probability estimate from each model
1352   * @return the coupled estimates
1353   */
1354  public double[] pairwiseCoupling(double[][] n, double[][] r) {
1355
1356    // Initialize p and u array
1357    double[] p = new double[r.length];
1358    for (int i =0; i < p.length; i++) {
1359      p[i] = 1.0 / (double)p.length;
1360    }
1361    double[][] u = new double[r.length][r.length];
1362    for (int i = 0; i < r.length; i++) {
1363      for (int j = i + 1; j < r.length; j++) {
1364        u[i][j] = 0.5;
1365      }
1366    }
1367
1368    // firstSum doesn't change
1369    double[] firstSum = new double[p.length];
1370    for (int i = 0; i < p.length; i++) {
1371      for (int j = i + 1; j < p.length; j++) {
1372        firstSum[i] += n[i][j] * r[i][j];
1373        firstSum[j] += n[i][j] * (1 - r[i][j]);
1374      }
1375    }
1376
1377    // Iterate until convergence
1378    boolean changed;
1379    do {
1380      changed = false;
1381      double[] secondSum = new double[p.length];
1382      for (int i = 0; i < p.length; i++) {
1383        for (int j = i + 1; j < p.length; j++) {
1384          secondSum[i] += n[i][j] * u[i][j];
1385          secondSum[j] += n[i][j] * (1 - u[i][j]);
1386        }
1387      }
1388      for (int i = 0; i < p.length; i++) {
1389        if ((firstSum[i] == 0) || (secondSum[i] == 0)) {
1390          if (p[i] > 0) {
1391            changed = true;
1392          }
1393          p[i] = 0;
1394        } else {
1395          double factor = firstSum[i] / secondSum[i];
1396          double pOld = p[i];
1397          p[i] *= factor;
1398          if (Math.abs(pOld - p[i]) > 1.0e-3) {
1399            changed = true;
1400          }
1401        }
1402      }
1403      Utils.normalize(p);
1404      for (int i = 0; i < r.length; i++) {
1405        for (int j = i + 1; j < r.length; j++) {
1406          u[i][j] = p[i] / (p[i] + p[j]);
1407        }
1408      }
1409    } while (changed);
1410    return p;
1411  }
1412
1413  /**
1414   * Returns the weights in sparse format.
1415   *
1416   * @return the weights in sparse format
1417   */
1418  public double [][][] sparseWeights() {
1419
1420    int numValues = m_classAttribute.numValues();
1421    double [][][] sparseWeights = new double[numValues][numValues][];
1422
1423    for (int i = 0; i < numValues; i++) {
1424      for (int j = i + 1; j < numValues; j++) {
1425        sparseWeights[i][j] = m_classifiers[i][j].m_sparseWeights;
1426      }
1427    }
1428
1429    return sparseWeights;
1430  }
1431
1432  /**
1433   * Returns the indices in sparse format.
1434   *
1435   * @return the indices in sparse format
1436   */
1437  public int [][][] sparseIndices() {
1438
1439    int numValues = m_classAttribute.numValues();
1440    int [][][] sparseIndices = new int[numValues][numValues][];
1441
1442    for (int i = 0; i < numValues; i++) {
1443      for (int j = i + 1; j < numValues; j++) {
1444        sparseIndices[i][j] = m_classifiers[i][j].m_sparseIndices;
1445      }
1446    }
1447
1448    return sparseIndices;
1449  }
1450
1451  /**
1452   * Returns the bias of each binary SMO.
1453   *
1454   * @return the bias of each binary SMO
1455   */
1456  public double [][] bias() {
1457
1458    int numValues = m_classAttribute.numValues();
1459    double [][] bias = new double[numValues][numValues];
1460
1461    for (int i = 0; i < numValues; i++) {
1462      for (int j = i + 1; j < numValues; j++) {
1463        bias[i][j] = m_classifiers[i][j].m_b;
1464      }
1465    }
1466
1467    return bias;
1468  }
1469
1470  /**
1471   * Returns the number of values of the class attribute.
1472   *
1473   * @return the number values of the class attribute
1474   */
1475  public int numClassAttributeValues() {
1476
1477    return m_classAttribute.numValues();
1478  }
1479
1480  /**
1481   * Returns the names of the class attributes.
1482   *
1483   * @return the names of the class attributes
1484   */
1485  public String[] classAttributeNames() {
1486
1487    int numValues = m_classAttribute.numValues();
1488
1489    String[] classAttributeNames = new String[numValues];
1490
1491    for (int i = 0; i < numValues; i++) {
1492      classAttributeNames[i] = m_classAttribute.value(i);
1493    }
1494
1495    return classAttributeNames;
1496  }
1497
1498  /**
1499   * Returns the attribute names.
1500   *
1501   * @return the attribute names
1502   */
1503  public String[][][] attributeNames() {
1504
1505    int numValues = m_classAttribute.numValues();
1506    String[][][] attributeNames = new String[numValues][numValues][];
1507
1508    for (int i = 0; i < numValues; i++) {
1509      for (int j = i + 1; j < numValues; j++) {
1510        int numAttributes = m_classifiers[i][j].m_data.numAttributes();
1511        String[] attrNames = new String[numAttributes];
1512        for (int k = 0; k < numAttributes; k++) {
1513          attrNames[k] = m_classifiers[i][j].m_data.attribute(k).name();
1514        }
1515        attributeNames[i][j] = attrNames;         
1516      }
1517    }
1518    return attributeNames;
1519  }
1520
1521  /**
1522   * Returns an enumeration describing the available options.
1523   *
1524   * @return an enumeration of all the available options.
1525   */
1526  public Enumeration listOptions() {
1527
1528    Vector result = new Vector();
1529
1530    Enumeration enm = super.listOptions();
1531    while (enm.hasMoreElements())
1532      result.addElement(enm.nextElement());
1533
1534    result.addElement(new Option(
1535        "\tTurns off all checks - use with caution!\n"
1536        + "\tTurning them off assumes that data is purely numeric, doesn't\n"
1537        + "\tcontain any missing values, and has a nominal class. Turning them\n"
1538        + "\toff also means that no header information will be stored if the\n"
1539        + "\tmachine is linear. Finally, it also assumes that no instance has\n"
1540        + "\ta weight equal to 0.\n"
1541        + "\t(default: checks on)",
1542        "no-checks", 0, "-no-checks"));
1543
1544    result.addElement(new Option(
1545          "\tThe complexity constant C. (default 1)",
1546          "C", 1, "-C <double>"));
1547   
1548    result.addElement(new Option(
1549          "\tWhether to 0=normalize/1=standardize/2=neither.\n" 
1550          + "\t(default 0=normalize)",
1551          "N", 1, "-N"));
1552   
1553    result.addElement(new Option(
1554          "\tUse MIminimax feature space. ",
1555          "I", 0, "-I"));
1556   
1557    result.addElement(new Option(
1558          "\tThe tolerance parameter. (default 1.0e-3)",
1559          "L", 1, "-L <double>"));
1560   
1561    result.addElement(new Option(
1562          "\tThe epsilon for round-off error. (default 1.0e-12)",
1563          "P", 1, "-P <double>"));
1564   
1565    result.addElement(new Option(
1566          "\tFit logistic models to SVM outputs. ",
1567          "M", 0, "-M"));
1568   
1569    result.addElement(new Option(
1570          "\tThe number of folds for the internal cross-validation. \n"
1571          + "\t(default -1, use training data)",
1572          "V", 1, "-V <double>"));
1573   
1574    result.addElement(new Option(
1575          "\tThe random number seed. (default 1)",
1576          "W", 1, "-W <double>"));
1577   
1578    result.addElement(new Option(
1579        "\tThe Kernel to use.\n"
1580        + "\t(default: weka.classifiers.functions.supportVector.PolyKernel)",
1581        "K", 1, "-K <classname and parameters>"));
1582
1583    result.addElement(new Option(
1584        "",
1585        "", 0, "\nOptions specific to kernel "
1586        + getKernel().getClass().getName() + ":"));
1587   
1588    enm = ((OptionHandler) getKernel()).listOptions();
1589    while (enm.hasMoreElements())
1590      result.addElement(enm.nextElement());
1591
1592    return result.elements();
1593  }
1594
1595  /**
1596   * Parses a given list of options. <p/>
1597   *
1598   <!-- options-start -->
1599   * Valid options are: <p/>
1600   *
1601   * <pre> -D
1602   *  If set, classifier is run in debug mode and
1603   *  may output additional info to the console</pre>
1604   *
1605   * <pre> -no-checks
1606   *  Turns off all checks - use with caution!
1607   *  Turning them off assumes that data is purely numeric, doesn't
1608   *  contain any missing values, and has a nominal class. Turning them
1609   *  off also means that no header information will be stored if the
1610   *  machine is linear. Finally, it also assumes that no instance has
1611   *  a weight equal to 0.
1612   *  (default: checks on)</pre>
1613   *
1614   * <pre> -C &lt;double&gt;
1615   *  The complexity constant C. (default 1)</pre>
1616   *
1617   * <pre> -N
1618   *  Whether to 0=normalize/1=standardize/2=neither.
1619   *  (default 0=normalize)</pre>
1620   *
1621   * <pre> -I
1622   *  Use MIminimax feature space. </pre>
1623   *
1624   * <pre> -L &lt;double&gt;
1625   *  The tolerance parameter. (default 1.0e-3)</pre>
1626   *
1627   * <pre> -P &lt;double&gt;
1628   *  The epsilon for round-off error. (default 1.0e-12)</pre>
1629   *
1630   * <pre> -M
1631   *  Fit logistic models to SVM outputs. </pre>
1632   *
1633   * <pre> -V &lt;double&gt;
1634   *  The number of folds for the internal cross-validation.
1635   *  (default -1, use training data)</pre>
1636   *
1637   * <pre> -W &lt;double&gt;
1638   *  The random number seed. (default 1)</pre>
1639   *
1640   * <pre> -K &lt;classname and parameters&gt;
1641   *  The Kernel to use.
1642   *  (default: weka.classifiers.functions.supportVector.PolyKernel)</pre>
1643   *
1644   * <pre>
1645   * Options specific to kernel weka.classifiers.mi.supportVector.MIPolyKernel:
1646   * </pre>
1647   *
1648   * <pre> -D
1649   *  Enables debugging output (if available) to be printed.
1650   *  (default: off)</pre>
1651   *
1652   * <pre> -no-checks
1653   *  Turns off all checks - use with caution!
1654   *  (default: checks on)</pre>
1655   *
1656   * <pre> -C &lt;num&gt;
1657   *  The size of the cache (a prime number), 0 for full cache and
1658   *  -1 to turn it off.
1659   *  (default: 250007)</pre>
1660   *
1661   * <pre> -E &lt;num&gt;
1662   *  The Exponent to use.
1663   *  (default: 1.0)</pre>
1664   *
1665   * <pre> -L
1666   *  Use lower-order terms.
1667   *  (default: no)</pre>
1668   *
1669   <!-- options-end -->
1670   *
1671   * @param options the list of options as an array of strings
1672   * @throws Exception if an option is not supported
1673   */
1674  public void setOptions(String[] options) throws Exception {
1675    String      tmpStr;
1676    String[]    tmpOptions;
1677   
1678    setChecksTurnedOff(Utils.getFlag("no-checks", options));
1679
1680    tmpStr = Utils.getOption('C', options);
1681    if (tmpStr.length() != 0)
1682      setC(Double.parseDouble(tmpStr));
1683    else
1684      setC(1.0);
1685
1686    tmpStr = Utils.getOption('L', options);
1687    if (tmpStr.length() != 0)
1688      setToleranceParameter(Double.parseDouble(tmpStr));
1689    else
1690      setToleranceParameter(1.0e-3);
1691   
1692    tmpStr = Utils.getOption('P', options);
1693    if (tmpStr.length() != 0)
1694      setEpsilon(new Double(tmpStr));
1695    else
1696      setEpsilon(1.0e-12);
1697
1698    setMinimax(Utils.getFlag('I', options));
1699
1700    tmpStr = Utils.getOption('N', options);
1701    if (tmpStr.length() != 0)
1702      setFilterType(new SelectedTag(Integer.parseInt(tmpStr), TAGS_FILTER));
1703    else
1704      setFilterType(new SelectedTag(FILTER_NORMALIZE, TAGS_FILTER));
1705   
1706    setBuildLogisticModels(Utils.getFlag('M', options));
1707   
1708    tmpStr = Utils.getOption('V', options);
1709    if (tmpStr.length() != 0)
1710      m_numFolds = Integer.parseInt(tmpStr);
1711    else
1712      m_numFolds = -1;
1713
1714    tmpStr = Utils.getOption('W', options);
1715    if (tmpStr.length() != 0)
1716      setRandomSeed(Integer.parseInt(tmpStr));
1717    else
1718      setRandomSeed(1);
1719
1720    tmpStr     = Utils.getOption('K', options);
1721    tmpOptions = Utils.splitOptions(tmpStr);
1722    if (tmpOptions.length != 0) {
1723      tmpStr        = tmpOptions[0];
1724      tmpOptions[0] = "";
1725      setKernel(Kernel.forName(tmpStr, tmpOptions));
1726    }
1727   
1728    super.setOptions(options);
1729  }
1730
1731  /**
1732   * Gets the current settings of the classifier.
1733   *
1734   * @return an array of strings suitable for passing to setOptions
1735   */
1736  public String[] getOptions() {
1737    int       i;
1738    Vector    result;
1739    String[]  options;
1740
1741    result = new Vector();
1742    options = super.getOptions();
1743    for (i = 0; i < options.length; i++)
1744      result.add(options[i]);
1745
1746    if (getChecksTurnedOff())
1747      result.add("-no-checks");
1748
1749    result.add("-C"); 
1750    result.add("" + getC());
1751   
1752    result.add("-L");
1753    result.add("" + getToleranceParameter());
1754   
1755    result.add("-P");
1756    result.add("" + getEpsilon());
1757   
1758    result.add("-N");
1759    result.add("" + m_filterType);
1760   
1761    if (getMinimax())
1762      result.add("-I");
1763
1764    if (getBuildLogisticModels())
1765      result.add("-M");
1766   
1767    result.add("-V");
1768    result.add("" + getNumFolds());
1769   
1770    result.add("-W");
1771    result.add("" + getRandomSeed());
1772   
1773    result.add("-K");
1774    result.add("" + getKernel().getClass().getName() + " " + Utils.joinOptions(getKernel().getOptions()));
1775   
1776    return (String[]) result.toArray(new String[result.size()]);         
1777  }
1778
1779  /**
1780   * Disables or enables the checks (which could be time-consuming). Use with
1781   * caution!
1782   *
1783   * @param value       if true turns off all checks
1784   */
1785  public void setChecksTurnedOff(boolean value) {
1786    if (value)
1787      turnChecksOff();
1788    else
1789      turnChecksOn();
1790  }
1791 
1792  /**
1793   * Returns whether the checks are turned off or not.
1794   *
1795   * @return            true if the checks are turned off
1796   */
1797  public boolean getChecksTurnedOff() {
1798    return m_checksTurnedOff;
1799  }
1800
1801  /**
1802   * Returns the tip text for this property
1803   *
1804   * @return            tip text for this property suitable for
1805   *                    displaying in the explorer/experimenter gui
1806   */
1807  public String checksTurnedOffTipText() {
1808    return "Turns time-consuming checks off - use with caution.";
1809  }
1810 
1811  /**
1812   * Returns the tip text for this property
1813   *
1814   * @return            tip text for this property suitable for
1815   *                    displaying in the explorer/experimenter gui
1816   */
1817  public String kernelTipText() {
1818    return "The kernel to use.";
1819  }
1820
1821  /**
1822   * Gets the kernel to use.
1823   *
1824   * @return            the kernel
1825   */
1826  public Kernel getKernel() {
1827    return m_kernel;
1828  }
1829   
1830  /**
1831   * Sets the kernel to use.
1832   *
1833   * @param value       the kernel
1834   */
1835  public void setKernel(Kernel value) {
1836    if (!(value instanceof MultiInstanceCapabilitiesHandler))
1837      throw new IllegalArgumentException(
1838          "Kernel must be able to handle multi-instance data!\n"
1839          + "(This one does not implement " + MultiInstanceCapabilitiesHandler.class.getName() + ")");
1840   
1841    m_kernel = value;
1842  }
1843
1844  /**
1845   * Returns the tip text for this property
1846   * @return tip text for this property suitable for
1847   * displaying in the explorer/experimenter gui
1848   */
1849  public String cTipText() {
1850    return "The complexity parameter C.";
1851  }
1852
1853  /**
1854   * Get the value of C.
1855   *
1856   * @return Value of C.
1857   */
1858  public double getC() {
1859
1860    return m_C;
1861  }
1862
1863  /**
1864   * Set the value of C.
1865   *
1866   * @param v  Value to assign to C.
1867   */
1868  public void setC(double v) {
1869
1870    m_C = v;
1871  }
1872
1873  /**
1874   * Returns the tip text for this property
1875   * @return tip text for this property suitable for
1876   * displaying in the explorer/experimenter gui
1877   */
1878  public String toleranceParameterTipText() {
1879    return "The tolerance parameter (shouldn't be changed).";
1880  }
1881
1882  /**
1883   * Get the value of tolerance parameter.
1884   * @return Value of tolerance parameter.
1885   */
1886  public double getToleranceParameter() {
1887
1888    return m_tol;
1889  }
1890
1891  /**
1892   * Set the value of tolerance parameter.
1893   * @param v  Value to assign to tolerance parameter.
1894   */
1895  public void setToleranceParameter(double v) {
1896
1897    m_tol = v;
1898  }
1899
1900  /**
1901   * Returns the tip text for this property
1902   * @return tip text for this property suitable for
1903   * displaying in the explorer/experimenter gui
1904   */
1905  public String epsilonTipText() {
1906    return "The epsilon for round-off error (shouldn't be changed).";
1907  }
1908
1909  /**
1910   * Get the value of epsilon.
1911   * @return Value of epsilon.
1912   */
1913  public double getEpsilon() {
1914
1915    return m_eps;
1916  }
1917
1918  /**
1919   * Set the value of epsilon.
1920   * @param v  Value to assign to epsilon.
1921   */
1922  public void setEpsilon(double v) {
1923
1924    m_eps = v;
1925  }
1926
1927  /**
1928   * Returns the tip text for this property
1929   * @return tip text for this property suitable for
1930   * displaying in the explorer/experimenter gui
1931   */
1932  public String filterTypeTipText() {
1933    return "Determines how/if the data will be transformed.";
1934  }
1935
1936  /**
1937   * Gets how the training data will be transformed. Will be one of
1938   * FILTER_NORMALIZE, FILTER_STANDARDIZE, FILTER_NONE.
1939   *
1940   * @return the filtering mode
1941   */
1942  public SelectedTag getFilterType() {
1943
1944    return new SelectedTag(m_filterType, TAGS_FILTER);
1945  }
1946
1947  /**
1948   * Sets how the training data will be transformed. Should be one of
1949   * FILTER_NORMALIZE, FILTER_STANDARDIZE, FILTER_NONE.
1950   *
1951   * @param newType the new filtering mode
1952   */
1953  public void setFilterType(SelectedTag newType) {
1954
1955    if (newType.getTags() == TAGS_FILTER) {
1956      m_filterType = newType.getSelectedTag().getID();
1957    }
1958  }
1959
1960  /**
1961   * Returns the tip text for this property
1962   *
1963   * @return tip text for this property suitable for
1964   * displaying in the explorer/experimenter gui
1965   */
1966  public String minimaxTipText() {
1967    return "Whether the MIMinimax feature space is to be used.";
1968  }
1969
1970  /**
1971   * Check if the MIMinimax feature space is to be used.
1972   * @return true if minimax
1973   */
1974  public boolean getMinimax() {
1975
1976    return m_minimax;
1977  }
1978
1979  /**
1980   * Set if the MIMinimax feature space is to be used.
1981   * @param v  true if RBF
1982   */
1983  public void setMinimax(boolean v) {
1984    m_minimax = v;
1985  }
1986
1987  /**
1988   * Returns the tip text for this property
1989   * @return tip text for this property suitable for
1990   * displaying in the explorer/experimenter gui
1991   */
1992  public String buildLogisticModelsTipText() {
1993    return "Whether to fit logistic models to the outputs (for proper "
1994      + "probability estimates).";
1995  }
1996
1997  /**
1998   * Get the value of buildLogisticModels.
1999   *
2000   * @return Value of buildLogisticModels.
2001   */
2002  public boolean getBuildLogisticModels() {
2003
2004    return m_fitLogisticModels;
2005  }
2006
2007  /**
2008   * Set the value of buildLogisticModels.
2009   *
2010   * @param newbuildLogisticModels Value to assign to buildLogisticModels.
2011   */
2012  public void setBuildLogisticModels(boolean newbuildLogisticModels) {
2013
2014    m_fitLogisticModels = newbuildLogisticModels;
2015  }
2016
2017  /**
2018   * Returns the tip text for this property
2019   * @return tip text for this property suitable for
2020   * displaying in the explorer/experimenter gui
2021   */
2022  public String numFoldsTipText() {
2023    return "The number of folds for cross-validation used to generate "
2024      + "training data for logistic models (-1 means use training data).";
2025  }
2026
2027  /**
2028   * Get the value of numFolds.
2029   *
2030   * @return Value of numFolds.
2031   */
2032  public int getNumFolds() {
2033
2034    return m_numFolds;
2035  }
2036
2037  /**
2038   * Set the value of numFolds.
2039   *
2040   * @param newnumFolds Value to assign to numFolds.
2041   */
2042  public void setNumFolds(int newnumFolds) {
2043
2044    m_numFolds = newnumFolds;
2045  }
2046
2047  /**
2048   * Returns the tip text for this property
2049   * @return tip text for this property suitable for
2050   * displaying in the explorer/experimenter gui
2051   */
2052  public String randomSeedTipText() {
2053    return "Random number seed for the cross-validation.";
2054  }
2055
2056  /**
2057   * Get the value of randomSeed.
2058   *
2059   * @return Value of randomSeed.
2060   */
2061  public int getRandomSeed() {
2062
2063    return m_randomSeed;
2064  }
2065
2066  /**
2067   * Set the value of randomSeed.
2068   *
2069   * @param newrandomSeed Value to assign to randomSeed.
2070   */
2071  public void setRandomSeed(int newrandomSeed) {
2072
2073    m_randomSeed = newrandomSeed;
2074  }
2075
2076  /**
2077   * Prints out the classifier.
2078   *
2079   * @return a description of the classifier as a string
2080   */
2081  public String toString() {
2082
2083    StringBuffer text = new StringBuffer();
2084
2085    if ((m_classAttribute == null)) {
2086      return "SMO: No model built yet.";
2087    }
2088    try {
2089      text.append("SMO\n\n");
2090      for (int i = 0; i < m_classAttribute.numValues(); i++) {
2091        for (int j = i + 1; j < m_classAttribute.numValues(); j++) {
2092          text.append("Classifier for classes: " + 
2093              m_classAttribute.value(i) + ", " +
2094              m_classAttribute.value(j) + "\n\n");
2095          text.append(m_classifiers[i][j]);
2096          if (m_fitLogisticModels) {
2097            text.append("\n\n");
2098            if ( m_classifiers[i][j].m_logistic == null) {
2099              text.append("No logistic model has been fit.\n");
2100            } else {
2101              text.append(m_classifiers[i][j].m_logistic);
2102            }
2103          }
2104          text.append("\n\n");
2105        }
2106      }
2107    } catch (Exception e) {
2108      return "Can't print SMO classifier.";
2109    }
2110
2111    return text.toString();
2112  }
2113 
2114  /**
2115   * Returns the revision string.
2116   *
2117   * @return            the revision
2118   */
2119  public String getRevision() {
2120    return RevisionUtils.extract("$Revision: 5987 $");
2121  }
2122
2123  /**
2124   * Main method for testing this class.
2125   *
2126   * @param argv the commandline parameters
2127   */
2128  public static void main(String[] argv) {
2129    runClassifier(new MISMO(), argv);
2130  }
2131}
Note: See TracBrowser for help on using the repository browser.