source: src/main/java/weka/classifiers/functions/LibLINEAR.java @ 25

Last change on this file since 25 was 4, checked in by gnappo, 14 years ago

Import di weka.

File size: 34.8 KB
RevLine 
[4]1/*
2 *    This program is free software; you can redistribute it and/or modify
3 *    it under the terms of the GNU General Public License as published by
4 *    the Free Software Foundation; either version 2 of the License, or
5 *    (at your option) any later version.
6 *
7 *    This program is distributed in the hope that it will be useful,
8 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
9 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
10 *    GNU General Public License for more details.
11 *
12 *    You should have received a copy of the GNU General Public License
13 *    along with this program; if not, write to the Free Software
14 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
15 */
16
17/*
18 * LibLINEAR.java
19 * Copyright (C) Benedikt Waldvogel
20 */
21package weka.classifiers.functions;
22
23import java.lang.reflect.Array;
24import java.lang.reflect.Constructor;
25import java.lang.reflect.Field;
26import java.lang.reflect.Method;
27import java.util.ArrayList;
28import java.util.Enumeration;
29import java.util.List;
30import java.util.StringTokenizer;
31import java.util.Vector;
32
33import weka.classifiers.Classifier;
34import weka.classifiers.AbstractClassifier;
35import weka.core.Capabilities;
36import weka.core.Instance;
37import weka.core.Instances;
38import weka.core.Option;
39import weka.core.RevisionUtils;
40import weka.core.SelectedTag;
41import weka.core.Tag;
42import weka.core.TechnicalInformation;
43import weka.core.TechnicalInformationHandler;
44import weka.core.Utils;
45import weka.core.WekaException;
46import weka.core.Capabilities.Capability;
47import weka.core.TechnicalInformation.Type;
48import weka.filters.Filter;
49import weka.filters.unsupervised.attribute.NominalToBinary;
50import weka.filters.unsupervised.attribute.Normalize;
51import weka.filters.unsupervised.attribute.ReplaceMissingValues;
52
53/**
54  <!-- globalinfo-start -->
55  * A wrapper class for the liblinear tools (the liblinear classes, typically the jar file, need to be in the classpath to use this classifier).<br/>
56  * Rong-En Fan, Kai-Wei Chang, Cho-Jui Hsieh, Xiang-Rui Wang, Chih-Jen Lin (2008). LIBLINEAR - A Library for Large Linear Classification. URL http://www.csie.ntu.edu.tw/~cjlin/liblinear/.
57  * <p/>
58  <!-- globalinfo-end -->
59 *
60 <!-- technical-bibtex-start -->
61 * BibTeX:
62 * <pre>
63 * &#64;misc{Fan2008,
64 *    author = {Rong-En Fan and Kai-Wei Chang and Cho-Jui Hsieh and Xiang-Rui Wang and Chih-Jen Lin},
65 *    note = {The Weka classifier works with version 1.33 of LIBLINEAR},
66 *    title = {LIBLINEAR - A Library for Large Linear Classification},
67 *    year = {2008},
68 *    URL = {http://www.csie.ntu.edu.tw/\~cjlin/liblinear/}
69 * }
70 * </pre>
71 * <p/>
72 <!-- technical-bibtex-end -->
73 *
74 <!-- options-start -->
75 * Valid options are: <p/>
76 *
77 * <pre> -S &lt;int&gt;
78 *  Set type of solver (default: 1)
79 *    0 = L2-regularized logistic regression
80 *    1 = L2-loss support vector machines (dual)
81 *    2 = L2-loss support vector machines (primal)
82 *    3 = L1-loss support vector machines (dual)
83 *    4 = multi-class support vector machines by Crammer and Singer</pre>
84 *
85 * <pre> -C &lt;double&gt;
86 *  Set the cost parameter C
87 *   (default: 1)</pre>
88 *
89 * <pre> -Z
90 *  Turn on normalization of input data (default: off)</pre>
91 *
92 * <pre> -N
93 *  Turn on nominal to binary conversion.</pre>
94 *
95 * <pre> -M
96 *  Turn off missing value replacement.
97 *  WARNING: use only if your data has no missing values.</pre>
98 *
99 * <pre> -P
100 *  Use probability estimation (default: off)
101 * currently for L2-regularized logistic regression only! </pre>
102 *
103 * <pre> -E &lt;double&gt;
104 *  Set tolerance of termination criterion (default: 0.01)</pre>
105 *
106 * <pre> -W &lt;double&gt;
107 *  Set the parameters C of class i to weight[i]*C
108 *   (default: 1)</pre>
109 *
110 * <pre> -B &lt;double&gt;
111 *  Add Bias term with the given value if &gt;= 0; if &lt; 0, no bias term added (default: 1)</pre>
112 *
113 * <pre> -D
114 *  If set, classifier is run in debug mode and
115 *  may output additional info to the console</pre>
116 *
117 <!-- options-end -->
118 *
119 * @author  Benedikt Waldvogel (mail at bwaldvogel.de)
120 * @version $Revision: 5928 $
121 */
122public class LibLINEAR
123  extends AbstractClassifier
124  implements TechnicalInformationHandler {
125
126  /** the svm classname */
127  protected final static String CLASS_LINEAR = "liblinear.Linear";
128
129  /** the svm_model classname */
130  protected final static String CLASS_MODEL = "liblinear.Model";
131
132  /** the svm_problem classname */
133  protected final static String CLASS_PROBLEM = "liblinear.Problem";
134
135  /** the svm_parameter classname */
136  protected final static String CLASS_PARAMETER = "liblinear.Parameter";
137
138  /** the svm_parameter classname */
139  protected final static String CLASS_SOLVERTYPE = "liblinear.SolverType";
140
141  /** the svm_node classname */
142  protected final static String CLASS_FEATURENODE = "liblinear.FeatureNode";
143
144  /** serial UID */
145  protected static final long serialVersionUID = 230504711;
146
147  /** LibLINEAR Model */
148  protected Object m_Model;
149
150
151  public Object getModel() {
152    return m_Model;
153  }
154
155  /** for normalizing the data */
156  protected Filter m_Filter = null;
157
158  /** normalize input data */
159  protected boolean m_Normalize = false;
160
161  /** SVM solver type L2-regularized logistic regression */
162  public static final int SVMTYPE_L2_LR = 0;
163  /** SVM solver type L2-loss support vector machines (dual) */
164  public static final int SVMTYPE_L2LOSS_SVM_DUAL = 1;
165  /** SVM solver type L2-loss support vector machines (primal) */
166  public static final int SVMTYPE_L2LOSS_SVM = 2;
167  /** SVM solver type L1-loss support vector machines (dual) */
168  public static final int SVMTYPE_L1LOSS_SVM_DUAL = 3;
169  /** SVM solver type multi-class support vector machines by Crammer and Singer */
170  public static final int SVMTYPE_MCSVM_CS = 4;
171  /** SVM solver types */
172  public static final Tag[] TAGS_SVMTYPE = {
173    new Tag(SVMTYPE_L2_LR, "L2-regularized logistic regression"),
174    new Tag(SVMTYPE_L2LOSS_SVM_DUAL, "L2-loss support vector machines (dual)"),
175    new Tag(SVMTYPE_L2LOSS_SVM, "L2-loss support vector machines (primal)"),
176    new Tag(SVMTYPE_L1LOSS_SVM_DUAL, "L1-loss support vector machines (dual)"),
177    new Tag(SVMTYPE_MCSVM_CS, "multi-class support vector machines by Crammer and Singer")
178  };
179
180  /** the SVM solver type */
181  protected int m_SVMType = SVMTYPE_L2LOSS_SVM_DUAL;
182
183  /** stopping criteria */
184  protected double m_eps = 0.01;
185
186  /** cost Parameter C */
187  protected double m_Cost = 1;
188
189  /** bias term value */
190  protected double m_Bias = 1;
191
192  protected int[] m_WeightLabel = new int[0];
193
194  protected double[] m_Weight = new double[0];
195
196  /** whether to generate probability estimates instead of +1/-1 in case of
197   * classification problems */
198  protected boolean m_ProbabilityEstimates = false;
199
200  /** The filter used to get rid of missing values. */
201  protected ReplaceMissingValues m_ReplaceMissingValues;
202
203  /** The filter used to make attributes numeric. */
204  protected NominalToBinary m_NominalToBinary;
205
206  /** If true, the nominal to binary filter is applied */
207  private boolean m_nominalToBinary = false;
208
209  /** If true, the replace missing values filter is not applied */
210  private boolean m_noReplaceMissingValues;
211
212  /** whether the liblinear classes are in the Classpath */
213  protected static boolean m_Present = false;
214  static {
215    try {
216      Class.forName(CLASS_LINEAR);
217      m_Present = true;
218    }
219    catch (Exception e) {
220      m_Present = false;
221    }
222  }
223
224  /**
225   * Returns a string describing classifier
226   *
227   * @return a description suitable for displaying in the
228   *         explorer/experimenter gui
229   */
230  public String globalInfo() {
231    return
232      "A wrapper class for the liblinear tools (the liblinear classes, typically "
233      + "the jar file, need to be in the classpath to use this classifier).\n"
234      + getTechnicalInformation().toString();
235  }
236
237  /**
238   * Returns an instance of a TechnicalInformation object, containing
239   * detailed information about the technical background of this class,
240   * e.g., paper reference or book this class is based on.
241   *
242   * @return the technical information about this class
243   */
244  public TechnicalInformation getTechnicalInformation() {
245    TechnicalInformation        result;
246
247    result = new TechnicalInformation(Type.MISC);
248    result.setValue(TechnicalInformation.Field.AUTHOR, "Rong-En Fan and Kai-Wei Chang and Cho-Jui Hsieh and Xiang-Rui Wang and Chih-Jen Lin");
249    result.setValue(TechnicalInformation.Field.TITLE, "LIBLINEAR - A Library for Large Linear Classification");
250    result.setValue(TechnicalInformation.Field.YEAR, "2008");
251    result.setValue(TechnicalInformation.Field.URL, "http://www.csie.ntu.edu.tw/~cjlin/liblinear/");
252    result.setValue(TechnicalInformation.Field.NOTE, "The Weka classifier works with version 1.33 of LIBLINEAR");
253
254    return result;
255  }
256
257  /**
258   * Returns an enumeration describing the available options.
259   *
260   * @return an enumeration of all the available options.
261   */
262  public Enumeration listOptions() {
263    Vector      result;
264
265    result = new Vector();
266
267    result.addElement(
268        new Option(
269          "\tSet type of solver (default: 1)\n"
270          + "\t\t 0 = L2-regularized logistic regression\n"
271          + "\t\t 1 = L2-loss support vector machines (dual)\n"
272          + "\t\t 2 = L2-loss support vector machines (primal)\n"
273          + "\t\t 3 = L1-loss support vector machines (dual)\n"
274          + "\t\t 4 = multi-class support vector machines by Crammer and Singer",
275          "S", 1, "-S <int>"));
276
277    result.addElement(
278        new Option(
279          "\tSet the cost parameter C\n"
280          + "\t (default: 1)",
281          "C", 1, "-C <double>"));
282
283    result.addElement(
284        new Option(
285          "\tTurn on normalization of input data (default: off)",
286          "Z", 0, "-Z"));
287   
288    result.addElement(
289        new Option("\tTurn on nominal to binary conversion.",
290            "N", 0, "-N"));
291   
292    result.addElement(
293        new Option("\tTurn off missing value replacement."
294            + "\n\tWARNING: use only if your data has no missing "
295            + "values.", "M", 0, "-M"));
296
297    result.addElement(
298        new Option(
299          "\tUse probability estimation (default: off)\n" +
300          "currently for L2-regularized logistic regression only! ",
301          "P", 0, "-P"));
302
303    result.addElement(
304        new Option(
305          "\tSet tolerance of termination criterion (default: 0.01)",
306          "E", 1, "-E <double>"));
307
308    result.addElement(
309        new Option(
310          "\tSet the parameters C of class i to weight[i]*C\n"
311          + "\t (default: 1)",
312          "W", 1, "-W <double>"));
313
314    result.addElement(
315        new Option(
316          "\tAdd Bias term with the given value if >= 0; if < 0, no bias term added (default: 1)",
317          "B", 1, "-B <double>"));
318
319    Enumeration en = super.listOptions();
320    while (en.hasMoreElements())
321      result.addElement(en.nextElement());
322
323    return result.elements();
324  }
325
326  /**
327   * Sets the classifier options <p/>
328   *
329   <!-- options-start -->
330   * Valid options are: <p/>
331   *
332   * <pre> -S &lt;int&gt;
333   *  Set type of solver (default: 1)
334   *    0 = L2-regularized logistic regression
335   *    1 = L2-loss support vector machines (dual)
336   *    2 = L2-loss support vector machines (primal)
337   *    3 = L1-loss support vector machines (dual)
338   *    4 = multi-class support vector machines by Crammer and Singer</pre>
339   *
340   * <pre> -C &lt;double&gt;
341   *  Set the cost parameter C
342   *   (default: 1)</pre>
343   *
344   * <pre> -Z
345   *  Turn on normalization of input data (default: off)</pre>
346   *
347   * <pre> -N
348   *  Turn on nominal to binary conversion.</pre>
349   *
350   * <pre> -M
351   *  Turn off missing value replacement.
352   *  WARNING: use only if your data has no missing values.</pre>
353   *
354   * <pre> -P
355   *  Use probability estimation (default: off)
356   * currently for L2-regularized logistic regression only! </pre>
357   *
358   * <pre> -E &lt;double&gt;
359   *  Set tolerance of termination criterion (default: 0.01)</pre>
360   *
361   * <pre> -W &lt;double&gt;
362   *  Set the parameters C of class i to weight[i]*C
363   *   (default: 1)</pre>
364   *
365   * <pre> -B &lt;double&gt;
366   *  Add Bias term with the given value if &gt;= 0; if &lt; 0, no bias term added (default: 1)</pre>
367   *
368   * <pre> -D
369   *  If set, classifier is run in debug mode and
370   *  may output additional info to the console</pre>
371   *
372   <!-- options-end -->
373   *
374   * @param options     the options to parse
375   * @throws Exception  if parsing fails
376   */
377  public void setOptions(String[] options) throws Exception {
378    String      tmpStr;
379
380    tmpStr = Utils.getOption('S', options);
381    if (tmpStr.length() != 0)
382      setSVMType(
383          new SelectedTag(Integer.parseInt(tmpStr), TAGS_SVMTYPE));
384    else
385      setSVMType(
386          new SelectedTag(SVMTYPE_L2LOSS_SVM_DUAL, TAGS_SVMTYPE));
387
388    tmpStr = Utils.getOption('C', options);
389    if (tmpStr.length() != 0)
390      setCost(Double.parseDouble(tmpStr));
391    else
392      setCost(1);
393
394    tmpStr = Utils.getOption('E', options);
395    if (tmpStr.length() != 0)
396      setEps(Double.parseDouble(tmpStr));
397    else
398      setEps(1e-3);
399
400    setNormalize(Utils.getFlag('Z', options));
401   
402    setConvertNominalToBinary(Utils.getFlag('N', options));
403    setDoNotReplaceMissingValues(Utils.getFlag('M', options));
404
405    tmpStr = Utils.getOption('B', options);
406    if (tmpStr.length() != 0)
407      setBias(Double.parseDouble(tmpStr));
408    else
409      setBias(1);
410
411    setWeights(Utils.getOption('W', options));
412
413    setProbabilityEstimates(Utils.getFlag('P', options));
414   
415    super.setOptions(options);
416  }
417
418  /**
419   * Returns the current options
420   *
421   * @return            the current setup
422   */
423  public String[] getOptions() {
424    Vector        result;
425
426    result  = new Vector();
427
428    result.add("-S");
429    result.add("" + m_SVMType);
430
431    result.add("-C");
432    result.add("" + getCost());
433
434    result.add("-E");
435    result.add("" + getEps());
436
437    result.add("-B");
438    result.add("" + getBias());
439
440    if (getNormalize())
441      result.add("-Z");
442   
443    if (getConvertNominalToBinary())
444      result.add("-N");
445   
446    if (getDoNotReplaceMissingValues())
447      result.add("-M");
448
449    if (getWeights().length() != 0) {
450      result.add("-W");
451      result.add("" + getWeights());
452    }
453
454    if (getProbabilityEstimates())
455      result.add("-P");
456
457    return (String[]) result.toArray(new String[result.size()]);
458  }
459
460  /**
461   * returns whether the liblinear classes are present or not, i.e. whether the
462   * classes are in the classpath or not
463   *
464   * @return whether the liblinear classes are available
465   */
466  public static boolean isPresent() {
467    return m_Present;
468  }
469
470  /**
471   * Sets type of SVM (default SVMTYPE_L2)
472   *
473   * @param value       the type of the SVM
474   */
475  public void setSVMType(SelectedTag value) {
476    if (value.getTags() == TAGS_SVMTYPE)
477      m_SVMType = value.getSelectedTag().getID();
478  }
479
480  /**
481   * Gets type of SVM
482   *
483   * @return            the type of the SVM
484   */
485  public SelectedTag getSVMType() {
486    return new SelectedTag(m_SVMType, TAGS_SVMTYPE);
487  }
488
489  /**
490   * Returns the tip text for this property
491   *
492   * @return tip text for this property suitable for
493   *         displaying in the explorer/experimenter gui
494   */
495  public String SVMTypeTipText() {
496    return "The type of SVM to use.";
497  }
498
499  /**
500   * Sets the cost parameter C (default 1)
501   *
502   * @param value       the cost value
503   */
504  public void setCost(double value) {
505    m_Cost = value;
506  }
507
508  /**
509   * Returns the cost parameter C
510   *
511   * @return            the cost value
512   */
513  public double getCost() {
514    return m_Cost;
515  }
516
517  /**
518   * Returns the tip text for this property
519   *
520   * @return tip text for this property suitable for
521   *         displaying in the explorer/experimenter gui
522   */
523  public String costTipText() {
524    return "The cost parameter C.";
525  }
526
527  /**
528   * Sets tolerance of termination criterion (default 0.001)
529   *
530   * @param value       the tolerance
531   */
532  public void setEps(double value) {
533    m_eps = value;
534  }
535
536  /**
537   * Gets tolerance of termination criterion
538   *
539   * @return            the current tolerance
540   */
541  public double getEps() {
542    return m_eps;
543  }
544
545  /**
546   * Returns the tip text for this property
547   *
548   * @return tip text for this property suitable for
549   *         displaying in the explorer/experimenter gui
550   */
551  public String epsTipText() {
552    return "The tolerance of the termination criterion.";
553  }
554
555  /**
556   * Sets bias term value (default 1)
557   * No bias term is added if value &lt; 0
558   *
559   * @param value       the bias term value
560   */
561  public void setBias(double value) {
562    m_Bias = value;
563  }
564
565  /**
566   * Returns bias term value (default 1)
567   * No bias term is added if value &lt; 0
568   *
569   * @return             the bias term value
570   */
571  public double getBias() {
572    return m_Bias;
573  }
574
575  /**
576   * Returns the tip text for this property
577   *
578   * @return tip text for this property suitable for
579   *         displaying in the explorer/experimenter gui
580   */
581  public String biasTipText() {
582    return "If >= 0, a bias term with that value is added; " +
583      "otherwise (<0) no bias term is added (default: 1).";
584  }
585
586  /**
587   * Returns the tip text for this property
588   *
589   * @return tip text for this property suitable for
590   *         displaying in the explorer/experimenter gui
591   */
592  public String normalizeTipText() {
593    return "Whether to normalize the data.";
594  }
595 
596  /**
597   * whether to normalize input data
598   *
599   * @param value       whether to normalize the data
600   */
601  public void setNormalize(boolean value) {
602    m_Normalize = value;
603  }
604
605  /**
606   * whether to normalize input data
607   *
608   * @return            true, if the data is normalized
609   */
610  public boolean getNormalize() {
611    return m_Normalize;
612  }
613 
614  /**
615   * Returns the tip text for this property
616   *
617   * @return tip text for this property suitable for
618   *         displaying in the explorer/experimenter gui
619   */
620  public String convertNominalToBinaryTipText() {
621    return "Whether to turn on conversion of nominal attributes "
622      + "to binary.";
623  }
624 
625  /**
626   * Whether to turn on conversion of nominal attributes
627   * to binary.
628   *
629   * @param b true if nominal to binary conversion is to be
630   * turned on
631   */
632  public void setConvertNominalToBinary(boolean b) {
633    m_nominalToBinary = b;
634  }
635 
636  /**
637   * Gets whether conversion of nominal to binary is
638   * turned on.
639   *
640   * @return true if nominal to binary conversion is turned
641   * on.
642   */
643  public boolean getConvertNominalToBinary() {
644    return m_nominalToBinary;
645  }
646 
647  /**
648   * Returns the tip text for this property
649   *
650   * @return tip text for this property suitable for
651   *         displaying in the explorer/experimenter gui
652   */
653  public String doNotReplaceMissingValuesTipText() {
654    return "Whether to turn off automatic replacement of missing "
655      + "values. WARNING: set to true only if the data does not "
656      + "contain missing values.";
657  }
658 
659  /**
660   * Whether to turn off automatic replacement of missing values.
661   * Set to true only if the data does not contain missing values.
662   *
663   * @param b true if automatic missing values replacement is
664   * to be disabled.
665   */
666  public void setDoNotReplaceMissingValues(boolean b) {
667    m_noReplaceMissingValues = b;
668  }
669 
670  /**
671   * Gets whether automatic replacement of missing values is
672   * disabled.
673   *
674   * @return true if automatic replacement of missing values
675   * is disabled.
676   */
677  public boolean getDoNotReplaceMissingValues() {
678    return m_noReplaceMissingValues;
679  }
680
681  /**
682   * Sets the parameters C of class i to weight[i]*C (default 1).
683   * Blank separated list of doubles.
684   *
685   * @param weightsStr          the weights (doubles, separated by blanks)
686   */
687  public void setWeights(String weightsStr) {
688    StringTokenizer       tok;
689    int                   i;
690
691    tok           = new StringTokenizer(weightsStr, " ");
692    m_Weight      = new double[tok.countTokens()];
693    m_WeightLabel = new int[tok.countTokens()];
694
695    if (m_Weight.length == 0)
696      System.out.println(
697          "Zero Weights processed. Default weights will be used");
698
699    for (i = 0; i < m_Weight.length; i++) {
700      m_Weight[i]      = Double.parseDouble(tok.nextToken());
701      m_WeightLabel[i] = i;
702    }
703  }
704
705  /**
706   * Gets the parameters C of class i to weight[i]*C (default 1).
707   * Blank separated doubles.
708   *
709   * @return            the weights (doubles separated by blanks)
710   */
711  public String getWeights() {
712    String      result;
713    int         i;
714
715    result = "";
716    for (i = 0; i < m_Weight.length; i++) {
717      if (i > 0)
718        result += " ";
719      result += Double.toString(m_Weight[i]);
720    }
721
722    return result;
723  }
724
725  /**
726   * Returns the tip text for this property
727   *
728   * @return tip text for this property suitable for
729   *         displaying in the explorer/experimenter gui
730   */
731  public String weightsTipText() {
732    return "The weights to use for the classes, if empty 1 is used by default.";
733  }
734
735  /**
736   * Returns whether probability estimates are generated instead of -1/+1 for
737   * classification problems.
738   *
739   * @param value       whether to predict probabilities
740   */
741  public void setProbabilityEstimates(boolean value) {
742    m_ProbabilityEstimates = value;
743  }
744
745  /**
746   * Sets whether to generate probability estimates instead of -1/+1 for
747   * classification problems.
748   *
749   * @return            true, if probability estimates should be returned
750   */
751  public boolean getProbabilityEstimates() {
752    return m_ProbabilityEstimates;
753  }
754
755  /**
756   * Returns the tip text for this property
757   *
758   * @return tip text for this property suitable for
759   *         displaying in the explorer/experimenter gui
760   */
761  public String probabilityEstimatesTipText() {
762    return "Whether to generate probability estimates instead of -1/+1 for classification problems " +
763      "(currently for L2-regularized logistic regression only!)";
764  }
765
766  /**
767   * sets the specified field
768   *
769   * @param o           the object to set the field for
770   * @param name        the name of the field
771   * @param value       the new value of the field
772   */
773  protected void setField(Object o, String name, Object value) {
774    Field       f;
775
776    try {
777      f = o.getClass().getField(name);
778      f.set(o, value);
779    }
780    catch (Exception e) {
781      e.printStackTrace();
782    }
783  }
784
785  /**
786   * sets the specified field in an array
787   *
788   * @param o           the object to set the field for
789   * @param name        the name of the field
790   * @param index       the index in the array
791   * @param value       the new value of the field
792   */
793  protected void setField(Object o, String name, int index, Object value) {
794    Field       f;
795
796    try {
797      f = o.getClass().getField(name);
798      Array.set(f.get(o), index, value);
799    }
800    catch (Exception e) {
801      e.printStackTrace();
802    }
803  }
804
805  /**
806   * returns the current value of the specified field
807   *
808   * @param o           the object the field is member of
809   * @param name        the name of the field
810   * @return            the value
811   */
812  protected Object getField(Object o, String name) {
813    Field       f;
814    Object      result;
815
816    try {
817      f      = o.getClass().getField(name);
818      result = f.get(o);
819    }
820    catch (Exception e) {
821      e.printStackTrace();
822      result = null;
823    }
824
825    return result;
826  }
827
828  /**
829   * sets a new array for the field
830   *
831   * @param o           the object to set the array for
832   * @param name        the name of the field
833   * @param type        the type of the array
834   * @param length      the length of the one-dimensional array
835   */
836  protected void newArray(Object o, String name, Class type, int length) {
837    newArray(o, name, type, new int[]{length});
838  }
839
840  /**
841   * sets a new array for the field
842   *
843   * @param o           the object to set the array for
844   * @param name        the name of the field
845   * @param type        the type of the array
846   * @param dimensions  the dimensions of the array
847   */
848  protected void newArray(Object o, String name, Class type, int[] dimensions) {
849    Field       f;
850
851    try {
852      f = o.getClass().getField(name);
853      f.set(o, Array.newInstance(type, dimensions));
854    }
855    catch (Exception e) {
856      e.printStackTrace();
857    }
858  }
859
860  /**
861   * executes the specified method and returns the result, if any
862   *
863   * @param o                   the object the method should be called from
864   * @param name                the name of the method
865   * @param paramClasses        the classes of the parameters
866   * @param paramValues         the values of the parameters
867   * @return                    the return value of the method, if any (in that case null)
868   */
869  protected Object invokeMethod(Object o, String name, Class[] paramClasses, Object[] paramValues) {
870    Method      m;
871    Object      result;
872
873    result = null;
874
875    try {
876      m      = o.getClass().getMethod(name, paramClasses);
877      result = m.invoke(o, paramValues);
878    }
879    catch (Exception e) {
880      e.printStackTrace();
881      result = null;
882    }
883
884    return result;
885  }
886
887  /**
888   * transfers the local variables into a svm_parameter object
889   *
890   * @return the configured svm_parameter object
891   */
892  protected Object getParameters() {
893    Object      result;
894    int         i;
895
896    try {
897      Class solverTypeEnumClass = Class.forName(CLASS_SOLVERTYPE);
898      Object[] enumValues = solverTypeEnumClass.getEnumConstants();
899      Object solverType = enumValues[m_SVMType];
900
901      Class[] constructorClasses = new Class[] { solverTypeEnumClass, double.class, double.class };
902      Constructor parameterConstructor = Class.forName(CLASS_PARAMETER).getConstructor(constructorClasses);
903
904      result = parameterConstructor.newInstance(solverType, Double.valueOf(m_Cost),
905          Double.valueOf(m_eps));
906
907      if (m_Weight.length > 0) {
908        invokeMethod(result, "setWeights", new Class[] { double[].class, int[].class },
909            new Object[] { m_Weight, m_WeightLabel });
910      }
911    }
912    catch (Exception e) {
913      e.printStackTrace();
914      result = null;
915    }
916
917    return result;
918  }
919
920  /**
921   * returns the svm_problem
922   *
923   * @param vx the x values
924   * @param vy the y values
925   * @param max_index
926   * @return the Problem object
927   */
928  protected Object getProblem(List<Object> vx, List<Integer> vy, int max_index) {
929    Object      result;
930
931    try {
932      result = Class.forName(CLASS_PROBLEM).newInstance();
933
934      setField(result, "l", Integer.valueOf(vy.size()));
935      setField(result, "n", Integer.valueOf(max_index));
936      setField(result, "bias", getBias());
937
938      newArray(result, "x", Class.forName(CLASS_FEATURENODE), new int[]{vy.size(), 0});
939      for (int i = 0; i < vy.size(); i++)
940        setField(result, "x", i, vx.get(i));
941
942      newArray(result, "y", Integer.TYPE, vy.size());
943      for (int i = 0; i < vy.size(); i++)
944        setField(result, "y", i, vy.get(i));
945    }
946    catch (Exception e) {
947      e.printStackTrace();
948      result = null;
949    }
950
951    return result;
952  }
953
954  /**
955   * returns an instance into a sparse liblinear array
956   *
957   * @param instance    the instance to work on
958   * @return            the liblinear array
959   * @throws Exception  if setup of array fails
960   */
961  protected Object instanceToArray(Instance instance) throws Exception {
962    int     index;
963    int     count;
964    int     i;
965    Object  result;
966
967    // determine number of non-zero attributes
968    count = 0;
969
970    for (i = 0; i < instance.numValues(); i++) {
971      if (instance.index(i) == instance.classIndex())
972        continue;
973      if (instance.valueSparse(i) != 0)
974        count++;
975    }
976
977    if (m_Bias >= 0) {
978      count++;
979    }
980
981    Class[] intDouble = new Class[] { int.class, double.class };
982    Constructor nodeConstructor = Class.forName(CLASS_FEATURENODE).getConstructor(intDouble);
983
984    // fill array
985    result = Array.newInstance(Class.forName(CLASS_FEATURENODE), count);
986    index  = 0;
987    for (i = 0; i < instance.numValues(); i++) {
988
989      int idx = instance.index(i);
990      double val = instance.valueSparse(i);
991
992      if (idx == instance.classIndex())
993        continue;
994      if (val == 0)
995        continue;
996
997      Object node = nodeConstructor.newInstance(Integer.valueOf(idx+1), Double.valueOf(val));
998      Array.set(result, index, node);
999      index++;
1000    }
1001
1002    // add bias term
1003    if (m_Bias >= 0) {
1004      Integer idx = Integer.valueOf(instance.numAttributes()+1);
1005      Double value = Double.valueOf(m_Bias);
1006      Object node = nodeConstructor.newInstance(idx, value);
1007      Array.set(result, index, node);
1008    }
1009
1010    return result;
1011  }
1012  /**
1013   * Computes the distribution for a given instance.
1014   *
1015   * @param instance            the instance for which distribution is computed
1016   * @return                    the distribution
1017   * @throws Exception          if the distribution can't be computed successfully
1018   */
1019  public double[] distributionForInstance (Instance instance) throws Exception {
1020
1021    if (!getDoNotReplaceMissingValues()) {
1022      m_ReplaceMissingValues.input(instance);
1023      m_ReplaceMissingValues.batchFinished();
1024      instance = m_ReplaceMissingValues.output();
1025    }
1026
1027    if (getConvertNominalToBinary() 
1028        && m_NominalToBinary != null) {
1029      m_NominalToBinary.input(instance);
1030      m_NominalToBinary.batchFinished();
1031      instance = m_NominalToBinary.output();
1032    }
1033
1034    if (m_Filter != null) {
1035      m_Filter.input(instance);
1036      m_Filter.batchFinished();
1037      instance = m_Filter.output();
1038    }
1039
1040    Object x = instanceToArray(instance);
1041    double v;
1042    double[] result = new double[instance.numClasses()];
1043    if (m_ProbabilityEstimates) {
1044      if (m_SVMType != SVMTYPE_L2_LR) {
1045        throw new WekaException("probability estimation is currently only " +
1046            "supported for L2-regularized logistic regression");
1047      }
1048
1049      int[] labels = (int[])invokeMethod(m_Model, "getLabels", null, null);
1050      double[] prob_estimates = new double[instance.numClasses()];
1051
1052      v = ((Integer) invokeMethod(
1053            Class.forName(CLASS_LINEAR).newInstance(),
1054            "predictProbability",
1055            new Class[]{
1056              Class.forName(CLASS_MODEL),
1057        Array.newInstance(Class.forName(CLASS_FEATURENODE), Array.getLength(x)).getClass(),
1058        Array.newInstance(Double.TYPE, prob_estimates.length).getClass()},
1059        new Object[]{ m_Model, x, prob_estimates})).doubleValue();
1060
1061      // Return order of probabilities to canonical weka attribute order
1062      for (int k = 0; k < prob_estimates.length; k++) {
1063        result[labels[k]] = prob_estimates[k];
1064      }
1065    }
1066    else {
1067      v = ((Integer) invokeMethod(
1068            Class.forName(CLASS_LINEAR).newInstance(),
1069            "predict",
1070            new Class[]{
1071              Class.forName(CLASS_MODEL),
1072        Array.newInstance(Class.forName(CLASS_FEATURENODE), Array.getLength(x)).getClass()},
1073        new Object[]{
1074          m_Model,
1075        x})).doubleValue();
1076
1077      assert (instance.classAttribute().isNominal());
1078      result[(int) v] = 1;
1079    }
1080
1081    return result;
1082  }
1083
1084  /**
1085   * Returns default capabilities of the classifier.
1086   *
1087   * @return      the capabilities of this classifier
1088   */
1089  public Capabilities getCapabilities() {
1090    Capabilities result = super.getCapabilities();
1091    result.disableAll();
1092
1093    // attributes
1094    result.enable(Capability.NOMINAL_ATTRIBUTES);
1095    result.enable(Capability.NUMERIC_ATTRIBUTES);
1096    result.enable(Capability.DATE_ATTRIBUTES);
1097//    result.enable(Capability.MISSING_VALUES);
1098
1099    // class
1100    result.enable(Capability.NOMINAL_CLASS);
1101    result.enable(Capability.MISSING_CLASS_VALUES);
1102    return result;
1103  }
1104
1105  /**
1106   * builds the classifier
1107   *
1108   * @param insts       the training instances
1109   * @throws Exception  if liblinear classes not in classpath or liblinear
1110   *                    encountered a problem
1111   */
1112  public void buildClassifier(Instances insts) throws Exception {
1113    m_NominalToBinary = null;
1114    m_Filter = null;
1115   
1116    if (!isPresent())
1117      throw new Exception("liblinear classes not in CLASSPATH!");
1118
1119    // remove instances with missing class
1120    insts = new Instances(insts);
1121    insts.deleteWithMissingClass();
1122   
1123    if (!getDoNotReplaceMissingValues()) {
1124      m_ReplaceMissingValues = new ReplaceMissingValues();
1125      m_ReplaceMissingValues.setInputFormat(insts);
1126      insts = Filter.useFilter(insts, m_ReplaceMissingValues);
1127    }
1128   
1129    // can classifier handle the data?
1130    // we check this here so that if the user turns off
1131    // replace missing values filtering, it will fail
1132    // if the data actually does have missing values
1133    getCapabilities().testWithFail(insts);
1134
1135    if (getConvertNominalToBinary()) {
1136      insts = nominalToBinary(insts);
1137    }
1138
1139    if (getNormalize()) {
1140      m_Filter = new Normalize();
1141      m_Filter.setInputFormat(insts);
1142      insts = Filter.useFilter(insts, m_Filter);
1143    }
1144
1145    List<Integer> vy = new ArrayList<Integer>(insts.numInstances());
1146    List<Object> vx = new ArrayList<Object>(insts.numInstances());
1147    int max_index = 0;
1148
1149    for (int d = 0; d < insts.numInstances(); d++) {
1150      Instance inst = insts.instance(d);
1151      Object x = instanceToArray(inst);
1152      int m = Array.getLength(x);
1153      if (m > 0)
1154        max_index = Math.max(max_index, ((Integer) getField(Array.get(x, m - 1), "index")).intValue());
1155      vx.add(x);
1156      double classValue = inst.classValue();
1157      int classValueInt = (int)classValue;
1158      if (classValueInt != classValue) throw new RuntimeException("unsupported class value: " + classValue);
1159      vy.add(Integer.valueOf(classValueInt));
1160    }
1161
1162    if (!m_Debug) {
1163      invokeMethod(
1164          Class.forName(CLASS_LINEAR).newInstance(),
1165          "disableDebugOutput", null, null);
1166    } else {
1167      invokeMethod(
1168          Class.forName(CLASS_LINEAR).newInstance(),
1169          "enableDebugOutput", null, null);
1170    }
1171
1172    // reset the PRNG for regression-stable results
1173    invokeMethod(
1174        Class.forName(CLASS_LINEAR).newInstance(),
1175        "resetRandom", null, null);
1176
1177    // train model
1178    m_Model = invokeMethod(
1179        Class.forName(CLASS_LINEAR).newInstance(),
1180        "train",
1181        new Class[]{
1182          Class.forName(CLASS_PROBLEM),
1183            Class.forName(CLASS_PARAMETER)},
1184            new Object[]{
1185              getProblem(vx, vy, max_index),
1186            getParameters()});
1187  }
1188
1189  /**
1190   * turns on nominal to binary filtering
1191   * if there are not only numeric attributes
1192   */
1193  private Instances nominalToBinary( Instances insts ) throws Exception {
1194    boolean onlyNumeric = true;
1195    for (int i = 0; i < insts.numAttributes(); i++) {
1196      if (i != insts.classIndex()) {
1197        if (!insts.attribute(i).isNumeric()) {
1198          onlyNumeric = false;
1199          break;
1200        }
1201      }
1202    }
1203
1204    if (!onlyNumeric) {
1205      m_NominalToBinary = new NominalToBinary();
1206      m_NominalToBinary.setInputFormat(insts);
1207      insts = Filter.useFilter(insts, m_NominalToBinary);
1208    }
1209    return insts;
1210  }
1211
1212  /**
1213   * returns a string representation
1214   *
1215   * @return            a string representation
1216   */
1217  public String toString() {
1218    return "LibLINEAR wrapper";
1219  }
1220
1221  /**
1222   * Returns the revision string.
1223   *
1224   * @return            the revision
1225   */
1226  public String getRevision() {
1227    return RevisionUtils.extract("$Revision: 5928 $");
1228  }
1229
1230  /**
1231   * Main method for testing this class.
1232   *
1233   * @param args the options
1234   */
1235  public static void main(String[] args) {
1236    runClassifier(new LibLINEAR(), args);
1237  }
1238}
1239
Note: See TracBrowser for help on using the repository browser.