1 | /* |
---|
2 | * This program is free software; you can redistribute it and/or modify |
---|
3 | * it under the terms of the GNU General Public License as published by |
---|
4 | * the Free Software Foundation; either version 2 of the License, or |
---|
5 | * (at your option) any later version. |
---|
6 | * |
---|
7 | * This program is distributed in the hope that it will be useful, |
---|
8 | * but WITHOUT ANY WARRANTY; without even the implied warranty of |
---|
9 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the |
---|
10 | * GNU General Public License for more details. |
---|
11 | * |
---|
12 | * You should have received a copy of the GNU General Public License |
---|
13 | * along with this program; if not, write to the Free Software |
---|
14 | * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA. |
---|
15 | */ |
---|
16 | |
---|
17 | /* |
---|
18 | * LibLINEAR.java |
---|
19 | * Copyright (C) Benedikt Waldvogel |
---|
20 | */ |
---|
21 | package weka.classifiers.functions; |
---|
22 | |
---|
23 | import java.lang.reflect.Array; |
---|
24 | import java.lang.reflect.Constructor; |
---|
25 | import java.lang.reflect.Field; |
---|
26 | import java.lang.reflect.Method; |
---|
27 | import java.util.ArrayList; |
---|
28 | import java.util.Enumeration; |
---|
29 | import java.util.List; |
---|
30 | import java.util.StringTokenizer; |
---|
31 | import java.util.Vector; |
---|
32 | |
---|
33 | import weka.classifiers.Classifier; |
---|
34 | import weka.classifiers.AbstractClassifier; |
---|
35 | import weka.core.Capabilities; |
---|
36 | import weka.core.Instance; |
---|
37 | import weka.core.Instances; |
---|
38 | import weka.core.Option; |
---|
39 | import weka.core.RevisionUtils; |
---|
40 | import weka.core.SelectedTag; |
---|
41 | import weka.core.Tag; |
---|
42 | import weka.core.TechnicalInformation; |
---|
43 | import weka.core.TechnicalInformationHandler; |
---|
44 | import weka.core.Utils; |
---|
45 | import weka.core.WekaException; |
---|
46 | import weka.core.Capabilities.Capability; |
---|
47 | import weka.core.TechnicalInformation.Type; |
---|
48 | import weka.filters.Filter; |
---|
49 | import weka.filters.unsupervised.attribute.NominalToBinary; |
---|
50 | import weka.filters.unsupervised.attribute.Normalize; |
---|
51 | import weka.filters.unsupervised.attribute.ReplaceMissingValues; |
---|
52 | |
---|
53 | /** |
---|
54 | <!-- globalinfo-start --> |
---|
55 | * A wrapper class for the liblinear tools (the liblinear classes, typically the jar file, need to be in the classpath to use this classifier).<br/> |
---|
56 | * Rong-En Fan, Kai-Wei Chang, Cho-Jui Hsieh, Xiang-Rui Wang, Chih-Jen Lin (2008). LIBLINEAR - A Library for Large Linear Classification. URL http://www.csie.ntu.edu.tw/~cjlin/liblinear/. |
---|
57 | * <p/> |
---|
58 | <!-- globalinfo-end --> |
---|
59 | * |
---|
60 | <!-- technical-bibtex-start --> |
---|
61 | * BibTeX: |
---|
62 | * <pre> |
---|
63 | * @misc{Fan2008, |
---|
64 | * author = {Rong-En Fan and Kai-Wei Chang and Cho-Jui Hsieh and Xiang-Rui Wang and Chih-Jen Lin}, |
---|
65 | * note = {The Weka classifier works with version 1.33 of LIBLINEAR}, |
---|
66 | * title = {LIBLINEAR - A Library for Large Linear Classification}, |
---|
67 | * year = {2008}, |
---|
68 | * URL = {http://www.csie.ntu.edu.tw/\~cjlin/liblinear/} |
---|
69 | * } |
---|
70 | * </pre> |
---|
71 | * <p/> |
---|
72 | <!-- technical-bibtex-end --> |
---|
73 | * |
---|
74 | <!-- options-start --> |
---|
75 | * Valid options are: <p/> |
---|
76 | * |
---|
77 | * <pre> -S <int> |
---|
78 | * Set type of solver (default: 1) |
---|
79 | * 0 = L2-regularized logistic regression |
---|
80 | * 1 = L2-loss support vector machines (dual) |
---|
81 | * 2 = L2-loss support vector machines (primal) |
---|
82 | * 3 = L1-loss support vector machines (dual) |
---|
83 | * 4 = multi-class support vector machines by Crammer and Singer</pre> |
---|
84 | * |
---|
85 | * <pre> -C <double> |
---|
86 | * Set the cost parameter C |
---|
87 | * (default: 1)</pre> |
---|
88 | * |
---|
89 | * <pre> -Z |
---|
90 | * Turn on normalization of input data (default: off)</pre> |
---|
91 | * |
---|
92 | * <pre> -N |
---|
93 | * Turn on nominal to binary conversion.</pre> |
---|
94 | * |
---|
95 | * <pre> -M |
---|
96 | * Turn off missing value replacement. |
---|
97 | * WARNING: use only if your data has no missing values.</pre> |
---|
98 | * |
---|
99 | * <pre> -P |
---|
100 | * Use probability estimation (default: off) |
---|
101 | * currently for L2-regularized logistic regression only! </pre> |
---|
102 | * |
---|
103 | * <pre> -E <double> |
---|
104 | * Set tolerance of termination criterion (default: 0.01)</pre> |
---|
105 | * |
---|
106 | * <pre> -W <double> |
---|
107 | * Set the parameters C of class i to weight[i]*C |
---|
108 | * (default: 1)</pre> |
---|
109 | * |
---|
110 | * <pre> -B <double> |
---|
111 | * Add Bias term with the given value if >= 0; if < 0, no bias term added (default: 1)</pre> |
---|
112 | * |
---|
113 | * <pre> -D |
---|
114 | * If set, classifier is run in debug mode and |
---|
115 | * may output additional info to the console</pre> |
---|
116 | * |
---|
117 | <!-- options-end --> |
---|
118 | * |
---|
119 | * @author Benedikt Waldvogel (mail at bwaldvogel.de) |
---|
120 | * @version $Revision: 5928 $ |
---|
121 | */ |
---|
public class LibLINEAR
  extends AbstractClassifier
  implements TechnicalInformationHandler {

  /** classname of the liblinear entry point (all liblinear access is via reflection) */
  protected final static String CLASS_LINEAR = "liblinear.Linear";

  /** classname of the liblinear model class */
  protected final static String CLASS_MODEL = "liblinear.Model";

  /** classname of the liblinear problem class */
  protected final static String CLASS_PROBLEM = "liblinear.Problem";

  /** classname of the liblinear parameter class */
  protected final static String CLASS_PARAMETER = "liblinear.Parameter";

  /** classname of the liblinear solver-type enum */
  protected final static String CLASS_SOLVERTYPE = "liblinear.SolverType";

  /** classname of the liblinear feature-node class */
  protected final static String CLASS_FEATURENODE = "liblinear.FeatureNode";

  /** serial UID for serialization */
  protected static final long serialVersionUID = 230504711;

  /** the trained LibLINEAR model; held as Object because liblinear is only
   * accessed via reflection (may be null before the classifier is built) */
  protected Object m_Model;


  /**
   * Returns the trained LibLINEAR model object.
   *
   * @return the model, or null if the classifier has not been built yet
   */
  public Object getModel() {
    return m_Model;
  }

  /** filter for normalizing the data (only used when m_Normalize is true) */
  protected Filter m_Filter = null;

  /** whether to normalize input data (option -Z) */
  protected boolean m_Normalize = false;

  /** SVM solver type L2-regularized logistic regression */
  public static final int SVMTYPE_L2_LR = 0;
  /** SVM solver type L2-loss support vector machines (dual) */
  public static final int SVMTYPE_L2LOSS_SVM_DUAL = 1;
  /** SVM solver type L2-loss support vector machines (primal) */
  public static final int SVMTYPE_L2LOSS_SVM = 2;
  /** SVM solver type L1-loss support vector machines (dual) */
  public static final int SVMTYPE_L1LOSS_SVM_DUAL = 3;
  /** SVM solver type multi-class support vector machines by Crammer and Singer */
  public static final int SVMTYPE_MCSVM_CS = 4;
  /** SVM solver types (GUI tags for option -S) */
  public static final Tag[] TAGS_SVMTYPE = {
    new Tag(SVMTYPE_L2_LR, "L2-regularized logistic regression"),
    new Tag(SVMTYPE_L2LOSS_SVM_DUAL, "L2-loss support vector machines (dual)"),
    new Tag(SVMTYPE_L2LOSS_SVM, "L2-loss support vector machines (primal)"),
    new Tag(SVMTYPE_L1LOSS_SVM_DUAL, "L1-loss support vector machines (dual)"),
    new Tag(SVMTYPE_MCSVM_CS, "multi-class support vector machines by Crammer and Singer")
  };

  /** the SVM solver type (one of the SVMTYPE_* constants) */
  protected int m_SVMType = SVMTYPE_L2LOSS_SVM_DUAL;

  /** tolerance of the termination criterion (option -E) */
  protected double m_eps = 0.01;

  /** cost parameter C (option -C) */
  protected double m_Cost = 1;

  /** bias term value (option -B); no bias term is added if the value is negative */
  protected double m_Bias = 1;

  /** class indices whose cost is reweighted; kept parallel to m_Weight
   * (assigned 0..n-1 in setWeights) */
  protected int[] m_WeightLabel = new int[0];

  /** per-class weight factors: C of class i becomes weight[i]*C (option -W) */
  protected double[] m_Weight = new double[0];

  /** whether to generate probability estimates instead of +1/-1 in case of
   * classification problems (option -P) */
  protected boolean m_ProbabilityEstimates = false;

  /** The filter used to get rid of missing values. */
  protected ReplaceMissingValues m_ReplaceMissingValues;

  /** The filter used to make attributes numeric. */
  protected NominalToBinary m_NominalToBinary;

  /** If true, the nominal to binary filter is applied (option -N) */
  private boolean m_nominalToBinary = false;

  /** If true, the replace missing values filter is not applied (option -M) */
  private boolean m_noReplaceMissingValues;

  /** whether the liblinear classes are in the classpath; determined once,
   * at class-load time, by the static initializer below */
  protected static boolean m_Present = false;
  static {
    try {
      Class.forName(CLASS_LINEAR);
      m_Present = true;
    }
    catch (Exception e) {
      // liblinear jar is not on the classpath
      m_Present = false;
    }
  }
---|
223 | |
---|
224 | /** |
---|
225 | * Returns a string describing classifier |
---|
226 | * |
---|
227 | * @return a description suitable for displaying in the |
---|
228 | * explorer/experimenter gui |
---|
229 | */ |
---|
230 | public String globalInfo() { |
---|
231 | return |
---|
232 | "A wrapper class for the liblinear tools (the liblinear classes, typically " |
---|
233 | + "the jar file, need to be in the classpath to use this classifier).\n" |
---|
234 | + getTechnicalInformation().toString(); |
---|
235 | } |
---|
236 | |
---|
237 | /** |
---|
238 | * Returns an instance of a TechnicalInformation object, containing |
---|
239 | * detailed information about the technical background of this class, |
---|
240 | * e.g., paper reference or book this class is based on. |
---|
241 | * |
---|
242 | * @return the technical information about this class |
---|
243 | */ |
---|
244 | public TechnicalInformation getTechnicalInformation() { |
---|
245 | TechnicalInformation result; |
---|
246 | |
---|
247 | result = new TechnicalInformation(Type.MISC); |
---|
248 | result.setValue(TechnicalInformation.Field.AUTHOR, "Rong-En Fan and Kai-Wei Chang and Cho-Jui Hsieh and Xiang-Rui Wang and Chih-Jen Lin"); |
---|
249 | result.setValue(TechnicalInformation.Field.TITLE, "LIBLINEAR - A Library for Large Linear Classification"); |
---|
250 | result.setValue(TechnicalInformation.Field.YEAR, "2008"); |
---|
251 | result.setValue(TechnicalInformation.Field.URL, "http://www.csie.ntu.edu.tw/~cjlin/liblinear/"); |
---|
252 | result.setValue(TechnicalInformation.Field.NOTE, "The Weka classifier works with version 1.33 of LIBLINEAR"); |
---|
253 | |
---|
254 | return result; |
---|
255 | } |
---|
256 | |
---|
257 | /** |
---|
258 | * Returns an enumeration describing the available options. |
---|
259 | * |
---|
260 | * @return an enumeration of all the available options. |
---|
261 | */ |
---|
262 | public Enumeration listOptions() { |
---|
263 | Vector result; |
---|
264 | |
---|
265 | result = new Vector(); |
---|
266 | |
---|
267 | result.addElement( |
---|
268 | new Option( |
---|
269 | "\tSet type of solver (default: 1)\n" |
---|
270 | + "\t\t 0 = L2-regularized logistic regression\n" |
---|
271 | + "\t\t 1 = L2-loss support vector machines (dual)\n" |
---|
272 | + "\t\t 2 = L2-loss support vector machines (primal)\n" |
---|
273 | + "\t\t 3 = L1-loss support vector machines (dual)\n" |
---|
274 | + "\t\t 4 = multi-class support vector machines by Crammer and Singer", |
---|
275 | "S", 1, "-S <int>")); |
---|
276 | |
---|
277 | result.addElement( |
---|
278 | new Option( |
---|
279 | "\tSet the cost parameter C\n" |
---|
280 | + "\t (default: 1)", |
---|
281 | "C", 1, "-C <double>")); |
---|
282 | |
---|
283 | result.addElement( |
---|
284 | new Option( |
---|
285 | "\tTurn on normalization of input data (default: off)", |
---|
286 | "Z", 0, "-Z")); |
---|
287 | |
---|
288 | result.addElement( |
---|
289 | new Option("\tTurn on nominal to binary conversion.", |
---|
290 | "N", 0, "-N")); |
---|
291 | |
---|
292 | result.addElement( |
---|
293 | new Option("\tTurn off missing value replacement." |
---|
294 | + "\n\tWARNING: use only if your data has no missing " |
---|
295 | + "values.", "M", 0, "-M")); |
---|
296 | |
---|
297 | result.addElement( |
---|
298 | new Option( |
---|
299 | "\tUse probability estimation (default: off)\n" + |
---|
300 | "currently for L2-regularized logistic regression only! ", |
---|
301 | "P", 0, "-P")); |
---|
302 | |
---|
303 | result.addElement( |
---|
304 | new Option( |
---|
305 | "\tSet tolerance of termination criterion (default: 0.01)", |
---|
306 | "E", 1, "-E <double>")); |
---|
307 | |
---|
308 | result.addElement( |
---|
309 | new Option( |
---|
310 | "\tSet the parameters C of class i to weight[i]*C\n" |
---|
311 | + "\t (default: 1)", |
---|
312 | "W", 1, "-W <double>")); |
---|
313 | |
---|
314 | result.addElement( |
---|
315 | new Option( |
---|
316 | "\tAdd Bias term with the given value if >= 0; if < 0, no bias term added (default: 1)", |
---|
317 | "B", 1, "-B <double>")); |
---|
318 | |
---|
319 | Enumeration en = super.listOptions(); |
---|
320 | while (en.hasMoreElements()) |
---|
321 | result.addElement(en.nextElement()); |
---|
322 | |
---|
323 | return result.elements(); |
---|
324 | } |
---|
325 | |
---|
326 | /** |
---|
327 | * Sets the classifier options <p/> |
---|
328 | * |
---|
329 | <!-- options-start --> |
---|
330 | * Valid options are: <p/> |
---|
331 | * |
---|
332 | * <pre> -S <int> |
---|
333 | * Set type of solver (default: 1) |
---|
334 | * 0 = L2-regularized logistic regression |
---|
335 | * 1 = L2-loss support vector machines (dual) |
---|
336 | * 2 = L2-loss support vector machines (primal) |
---|
337 | * 3 = L1-loss support vector machines (dual) |
---|
338 | * 4 = multi-class support vector machines by Crammer and Singer</pre> |
---|
339 | * |
---|
340 | * <pre> -C <double> |
---|
341 | * Set the cost parameter C |
---|
342 | * (default: 1)</pre> |
---|
343 | * |
---|
344 | * <pre> -Z |
---|
345 | * Turn on normalization of input data (default: off)</pre> |
---|
346 | * |
---|
347 | * <pre> -N |
---|
348 | * Turn on nominal to binary conversion.</pre> |
---|
349 | * |
---|
350 | * <pre> -M |
---|
351 | * Turn off missing value replacement. |
---|
352 | * WARNING: use only if your data has no missing values.</pre> |
---|
353 | * |
---|
354 | * <pre> -P |
---|
355 | * Use probability estimation (default: off) |
---|
356 | * currently for L2-regularized logistic regression only! </pre> |
---|
357 | * |
---|
358 | * <pre> -E <double> |
---|
359 | * Set tolerance of termination criterion (default: 0.01)</pre> |
---|
360 | * |
---|
361 | * <pre> -W <double> |
---|
362 | * Set the parameters C of class i to weight[i]*C |
---|
363 | * (default: 1)</pre> |
---|
364 | * |
---|
365 | * <pre> -B <double> |
---|
366 | * Add Bias term with the given value if >= 0; if < 0, no bias term added (default: 1)</pre> |
---|
367 | * |
---|
368 | * <pre> -D |
---|
369 | * If set, classifier is run in debug mode and |
---|
370 | * may output additional info to the console</pre> |
---|
371 | * |
---|
372 | <!-- options-end --> |
---|
373 | * |
---|
374 | * @param options the options to parse |
---|
375 | * @throws Exception if parsing fails |
---|
376 | */ |
---|
377 | public void setOptions(String[] options) throws Exception { |
---|
378 | String tmpStr; |
---|
379 | |
---|
380 | tmpStr = Utils.getOption('S', options); |
---|
381 | if (tmpStr.length() != 0) |
---|
382 | setSVMType( |
---|
383 | new SelectedTag(Integer.parseInt(tmpStr), TAGS_SVMTYPE)); |
---|
384 | else |
---|
385 | setSVMType( |
---|
386 | new SelectedTag(SVMTYPE_L2LOSS_SVM_DUAL, TAGS_SVMTYPE)); |
---|
387 | |
---|
388 | tmpStr = Utils.getOption('C', options); |
---|
389 | if (tmpStr.length() != 0) |
---|
390 | setCost(Double.parseDouble(tmpStr)); |
---|
391 | else |
---|
392 | setCost(1); |
---|
393 | |
---|
394 | tmpStr = Utils.getOption('E', options); |
---|
395 | if (tmpStr.length() != 0) |
---|
396 | setEps(Double.parseDouble(tmpStr)); |
---|
397 | else |
---|
398 | setEps(1e-3); |
---|
399 | |
---|
400 | setNormalize(Utils.getFlag('Z', options)); |
---|
401 | |
---|
402 | setConvertNominalToBinary(Utils.getFlag('N', options)); |
---|
403 | setDoNotReplaceMissingValues(Utils.getFlag('M', options)); |
---|
404 | |
---|
405 | tmpStr = Utils.getOption('B', options); |
---|
406 | if (tmpStr.length() != 0) |
---|
407 | setBias(Double.parseDouble(tmpStr)); |
---|
408 | else |
---|
409 | setBias(1); |
---|
410 | |
---|
411 | setWeights(Utils.getOption('W', options)); |
---|
412 | |
---|
413 | setProbabilityEstimates(Utils.getFlag('P', options)); |
---|
414 | |
---|
415 | super.setOptions(options); |
---|
416 | } |
---|
417 | |
---|
418 | /** |
---|
419 | * Returns the current options |
---|
420 | * |
---|
421 | * @return the current setup |
---|
422 | */ |
---|
423 | public String[] getOptions() { |
---|
424 | Vector result; |
---|
425 | |
---|
426 | result = new Vector(); |
---|
427 | |
---|
428 | result.add("-S"); |
---|
429 | result.add("" + m_SVMType); |
---|
430 | |
---|
431 | result.add("-C"); |
---|
432 | result.add("" + getCost()); |
---|
433 | |
---|
434 | result.add("-E"); |
---|
435 | result.add("" + getEps()); |
---|
436 | |
---|
437 | result.add("-B"); |
---|
438 | result.add("" + getBias()); |
---|
439 | |
---|
440 | if (getNormalize()) |
---|
441 | result.add("-Z"); |
---|
442 | |
---|
443 | if (getConvertNominalToBinary()) |
---|
444 | result.add("-N"); |
---|
445 | |
---|
446 | if (getDoNotReplaceMissingValues()) |
---|
447 | result.add("-M"); |
---|
448 | |
---|
449 | if (getWeights().length() != 0) { |
---|
450 | result.add("-W"); |
---|
451 | result.add("" + getWeights()); |
---|
452 | } |
---|
453 | |
---|
454 | if (getProbabilityEstimates()) |
---|
455 | result.add("-P"); |
---|
456 | |
---|
457 | return (String[]) result.toArray(new String[result.size()]); |
---|
458 | } |
---|
459 | |
---|
460 | /** |
---|
461 | * returns whether the liblinear classes are present or not, i.e. whether the |
---|
462 | * classes are in the classpath or not |
---|
463 | * |
---|
464 | * @return whether the liblinear classes are available |
---|
465 | */ |
---|
466 | public static boolean isPresent() { |
---|
467 | return m_Present; |
---|
468 | } |
---|
469 | |
---|
470 | /** |
---|
471 | * Sets type of SVM (default SVMTYPE_L2) |
---|
472 | * |
---|
473 | * @param value the type of the SVM |
---|
474 | */ |
---|
475 | public void setSVMType(SelectedTag value) { |
---|
476 | if (value.getTags() == TAGS_SVMTYPE) |
---|
477 | m_SVMType = value.getSelectedTag().getID(); |
---|
478 | } |
---|
479 | |
---|
480 | /** |
---|
481 | * Gets type of SVM |
---|
482 | * |
---|
483 | * @return the type of the SVM |
---|
484 | */ |
---|
485 | public SelectedTag getSVMType() { |
---|
486 | return new SelectedTag(m_SVMType, TAGS_SVMTYPE); |
---|
487 | } |
---|
488 | |
---|
489 | /** |
---|
490 | * Returns the tip text for this property |
---|
491 | * |
---|
492 | * @return tip text for this property suitable for |
---|
493 | * displaying in the explorer/experimenter gui |
---|
494 | */ |
---|
495 | public String SVMTypeTipText() { |
---|
496 | return "The type of SVM to use."; |
---|
497 | } |
---|
498 | |
---|
499 | /** |
---|
500 | * Sets the cost parameter C (default 1) |
---|
501 | * |
---|
502 | * @param value the cost value |
---|
503 | */ |
---|
504 | public void setCost(double value) { |
---|
505 | m_Cost = value; |
---|
506 | } |
---|
507 | |
---|
508 | /** |
---|
509 | * Returns the cost parameter C |
---|
510 | * |
---|
511 | * @return the cost value |
---|
512 | */ |
---|
513 | public double getCost() { |
---|
514 | return m_Cost; |
---|
515 | } |
---|
516 | |
---|
517 | /** |
---|
518 | * Returns the tip text for this property |
---|
519 | * |
---|
520 | * @return tip text for this property suitable for |
---|
521 | * displaying in the explorer/experimenter gui |
---|
522 | */ |
---|
523 | public String costTipText() { |
---|
524 | return "The cost parameter C."; |
---|
525 | } |
---|
526 | |
---|
527 | /** |
---|
528 | * Sets tolerance of termination criterion (default 0.001) |
---|
529 | * |
---|
530 | * @param value the tolerance |
---|
531 | */ |
---|
532 | public void setEps(double value) { |
---|
533 | m_eps = value; |
---|
534 | } |
---|
535 | |
---|
536 | /** |
---|
537 | * Gets tolerance of termination criterion |
---|
538 | * |
---|
539 | * @return the current tolerance |
---|
540 | */ |
---|
541 | public double getEps() { |
---|
542 | return m_eps; |
---|
543 | } |
---|
544 | |
---|
545 | /** |
---|
546 | * Returns the tip text for this property |
---|
547 | * |
---|
548 | * @return tip text for this property suitable for |
---|
549 | * displaying in the explorer/experimenter gui |
---|
550 | */ |
---|
551 | public String epsTipText() { |
---|
552 | return "The tolerance of the termination criterion."; |
---|
553 | } |
---|
554 | |
---|
555 | /** |
---|
556 | * Sets bias term value (default 1) |
---|
557 | * No bias term is added if value < 0 |
---|
558 | * |
---|
559 | * @param value the bias term value |
---|
560 | */ |
---|
561 | public void setBias(double value) { |
---|
562 | m_Bias = value; |
---|
563 | } |
---|
564 | |
---|
565 | /** |
---|
566 | * Returns bias term value (default 1) |
---|
567 | * No bias term is added if value < 0 |
---|
568 | * |
---|
569 | * @return the bias term value |
---|
570 | */ |
---|
571 | public double getBias() { |
---|
572 | return m_Bias; |
---|
573 | } |
---|
574 | |
---|
575 | /** |
---|
576 | * Returns the tip text for this property |
---|
577 | * |
---|
578 | * @return tip text for this property suitable for |
---|
579 | * displaying in the explorer/experimenter gui |
---|
580 | */ |
---|
581 | public String biasTipText() { |
---|
582 | return "If >= 0, a bias term with that value is added; " + |
---|
583 | "otherwise (<0) no bias term is added (default: 1)."; |
---|
584 | } |
---|
585 | |
---|
586 | /** |
---|
587 | * Returns the tip text for this property |
---|
588 | * |
---|
589 | * @return tip text for this property suitable for |
---|
590 | * displaying in the explorer/experimenter gui |
---|
591 | */ |
---|
592 | public String normalizeTipText() { |
---|
593 | return "Whether to normalize the data."; |
---|
594 | } |
---|
595 | |
---|
596 | /** |
---|
597 | * whether to normalize input data |
---|
598 | * |
---|
599 | * @param value whether to normalize the data |
---|
600 | */ |
---|
601 | public void setNormalize(boolean value) { |
---|
602 | m_Normalize = value; |
---|
603 | } |
---|
604 | |
---|
605 | /** |
---|
606 | * whether to normalize input data |
---|
607 | * |
---|
608 | * @return true, if the data is normalized |
---|
609 | */ |
---|
610 | public boolean getNormalize() { |
---|
611 | return m_Normalize; |
---|
612 | } |
---|
613 | |
---|
614 | /** |
---|
615 | * Returns the tip text for this property |
---|
616 | * |
---|
617 | * @return tip text for this property suitable for |
---|
618 | * displaying in the explorer/experimenter gui |
---|
619 | */ |
---|
620 | public String convertNominalToBinaryTipText() { |
---|
621 | return "Whether to turn on conversion of nominal attributes " |
---|
622 | + "to binary."; |
---|
623 | } |
---|
624 | |
---|
625 | /** |
---|
626 | * Whether to turn on conversion of nominal attributes |
---|
627 | * to binary. |
---|
628 | * |
---|
629 | * @param b true if nominal to binary conversion is to be |
---|
630 | * turned on |
---|
631 | */ |
---|
632 | public void setConvertNominalToBinary(boolean b) { |
---|
633 | m_nominalToBinary = b; |
---|
634 | } |
---|
635 | |
---|
636 | /** |
---|
637 | * Gets whether conversion of nominal to binary is |
---|
638 | * turned on. |
---|
639 | * |
---|
640 | * @return true if nominal to binary conversion is turned |
---|
641 | * on. |
---|
642 | */ |
---|
643 | public boolean getConvertNominalToBinary() { |
---|
644 | return m_nominalToBinary; |
---|
645 | } |
---|
646 | |
---|
647 | /** |
---|
648 | * Returns the tip text for this property |
---|
649 | * |
---|
650 | * @return tip text for this property suitable for |
---|
651 | * displaying in the explorer/experimenter gui |
---|
652 | */ |
---|
653 | public String doNotReplaceMissingValuesTipText() { |
---|
654 | return "Whether to turn off automatic replacement of missing " |
---|
655 | + "values. WARNING: set to true only if the data does not " |
---|
656 | + "contain missing values."; |
---|
657 | } |
---|
658 | |
---|
659 | /** |
---|
660 | * Whether to turn off automatic replacement of missing values. |
---|
661 | * Set to true only if the data does not contain missing values. |
---|
662 | * |
---|
663 | * @param b true if automatic missing values replacement is |
---|
664 | * to be disabled. |
---|
665 | */ |
---|
666 | public void setDoNotReplaceMissingValues(boolean b) { |
---|
667 | m_noReplaceMissingValues = b; |
---|
668 | } |
---|
669 | |
---|
670 | /** |
---|
671 | * Gets whether automatic replacement of missing values is |
---|
672 | * disabled. |
---|
673 | * |
---|
674 | * @return true if automatic replacement of missing values |
---|
675 | * is disabled. |
---|
676 | */ |
---|
677 | public boolean getDoNotReplaceMissingValues() { |
---|
678 | return m_noReplaceMissingValues; |
---|
679 | } |
---|
680 | |
---|
681 | /** |
---|
682 | * Sets the parameters C of class i to weight[i]*C (default 1). |
---|
683 | * Blank separated list of doubles. |
---|
684 | * |
---|
685 | * @param weightsStr the weights (doubles, separated by blanks) |
---|
686 | */ |
---|
687 | public void setWeights(String weightsStr) { |
---|
688 | StringTokenizer tok; |
---|
689 | int i; |
---|
690 | |
---|
691 | tok = new StringTokenizer(weightsStr, " "); |
---|
692 | m_Weight = new double[tok.countTokens()]; |
---|
693 | m_WeightLabel = new int[tok.countTokens()]; |
---|
694 | |
---|
695 | if (m_Weight.length == 0) |
---|
696 | System.out.println( |
---|
697 | "Zero Weights processed. Default weights will be used"); |
---|
698 | |
---|
699 | for (i = 0; i < m_Weight.length; i++) { |
---|
700 | m_Weight[i] = Double.parseDouble(tok.nextToken()); |
---|
701 | m_WeightLabel[i] = i; |
---|
702 | } |
---|
703 | } |
---|
704 | |
---|
705 | /** |
---|
706 | * Gets the parameters C of class i to weight[i]*C (default 1). |
---|
707 | * Blank separated doubles. |
---|
708 | * |
---|
709 | * @return the weights (doubles separated by blanks) |
---|
710 | */ |
---|
711 | public String getWeights() { |
---|
712 | String result; |
---|
713 | int i; |
---|
714 | |
---|
715 | result = ""; |
---|
716 | for (i = 0; i < m_Weight.length; i++) { |
---|
717 | if (i > 0) |
---|
718 | result += " "; |
---|
719 | result += Double.toString(m_Weight[i]); |
---|
720 | } |
---|
721 | |
---|
722 | return result; |
---|
723 | } |
---|
724 | |
---|
725 | /** |
---|
726 | * Returns the tip text for this property |
---|
727 | * |
---|
728 | * @return tip text for this property suitable for |
---|
729 | * displaying in the explorer/experimenter gui |
---|
730 | */ |
---|
731 | public String weightsTipText() { |
---|
732 | return "The weights to use for the classes, if empty 1 is used by default."; |
---|
733 | } |
---|
734 | |
---|
735 | /** |
---|
736 | * Returns whether probability estimates are generated instead of -1/+1 for |
---|
737 | * classification problems. |
---|
738 | * |
---|
739 | * @param value whether to predict probabilities |
---|
740 | */ |
---|
741 | public void setProbabilityEstimates(boolean value) { |
---|
742 | m_ProbabilityEstimates = value; |
---|
743 | } |
---|
744 | |
---|
745 | /** |
---|
746 | * Sets whether to generate probability estimates instead of -1/+1 for |
---|
747 | * classification problems. |
---|
748 | * |
---|
749 | * @return true, if probability estimates should be returned |
---|
750 | */ |
---|
751 | public boolean getProbabilityEstimates() { |
---|
752 | return m_ProbabilityEstimates; |
---|
753 | } |
---|
754 | |
---|
755 | /** |
---|
756 | * Returns the tip text for this property |
---|
757 | * |
---|
758 | * @return tip text for this property suitable for |
---|
759 | * displaying in the explorer/experimenter gui |
---|
760 | */ |
---|
761 | public String probabilityEstimatesTipText() { |
---|
762 | return "Whether to generate probability estimates instead of -1/+1 for classification problems " + |
---|
763 | "(currently for L2-regularized logistic regression only!)"; |
---|
764 | } |
---|
765 | |
---|
766 | /** |
---|
767 | * sets the specified field |
---|
768 | * |
---|
769 | * @param o the object to set the field for |
---|
770 | * @param name the name of the field |
---|
771 | * @param value the new value of the field |
---|
772 | */ |
---|
773 | protected void setField(Object o, String name, Object value) { |
---|
774 | Field f; |
---|
775 | |
---|
776 | try { |
---|
777 | f = o.getClass().getField(name); |
---|
778 | f.set(o, value); |
---|
779 | } |
---|
780 | catch (Exception e) { |
---|
781 | e.printStackTrace(); |
---|
782 | } |
---|
783 | } |
---|
784 | |
---|
785 | /** |
---|
786 | * sets the specified field in an array |
---|
787 | * |
---|
788 | * @param o the object to set the field for |
---|
789 | * @param name the name of the field |
---|
790 | * @param index the index in the array |
---|
791 | * @param value the new value of the field |
---|
792 | */ |
---|
793 | protected void setField(Object o, String name, int index, Object value) { |
---|
794 | Field f; |
---|
795 | |
---|
796 | try { |
---|
797 | f = o.getClass().getField(name); |
---|
798 | Array.set(f.get(o), index, value); |
---|
799 | } |
---|
800 | catch (Exception e) { |
---|
801 | e.printStackTrace(); |
---|
802 | } |
---|
803 | } |
---|
804 | |
---|
805 | /** |
---|
806 | * returns the current value of the specified field |
---|
807 | * |
---|
808 | * @param o the object the field is member of |
---|
809 | * @param name the name of the field |
---|
810 | * @return the value |
---|
811 | */ |
---|
812 | protected Object getField(Object o, String name) { |
---|
813 | Field f; |
---|
814 | Object result; |
---|
815 | |
---|
816 | try { |
---|
817 | f = o.getClass().getField(name); |
---|
818 | result = f.get(o); |
---|
819 | } |
---|
820 | catch (Exception e) { |
---|
821 | e.printStackTrace(); |
---|
822 | result = null; |
---|
823 | } |
---|
824 | |
---|
825 | return result; |
---|
826 | } |
---|
827 | |
---|
828 | /** |
---|
829 | * sets a new array for the field |
---|
830 | * |
---|
831 | * @param o the object to set the array for |
---|
832 | * @param name the name of the field |
---|
833 | * @param type the type of the array |
---|
834 | * @param length the length of the one-dimensional array |
---|
835 | */ |
---|
836 | protected void newArray(Object o, String name, Class type, int length) { |
---|
837 | newArray(o, name, type, new int[]{length}); |
---|
838 | } |
---|
839 | |
---|
840 | /** |
---|
841 | * sets a new array for the field |
---|
842 | * |
---|
843 | * @param o the object to set the array for |
---|
844 | * @param name the name of the field |
---|
845 | * @param type the type of the array |
---|
846 | * @param dimensions the dimensions of the array |
---|
847 | */ |
---|
848 | protected void newArray(Object o, String name, Class type, int[] dimensions) { |
---|
849 | Field f; |
---|
850 | |
---|
851 | try { |
---|
852 | f = o.getClass().getField(name); |
---|
853 | f.set(o, Array.newInstance(type, dimensions)); |
---|
854 | } |
---|
855 | catch (Exception e) { |
---|
856 | e.printStackTrace(); |
---|
857 | } |
---|
858 | } |
---|
859 | |
---|
860 | /** |
---|
861 | * executes the specified method and returns the result, if any |
---|
862 | * |
---|
863 | * @param o the object the method should be called from |
---|
864 | * @param name the name of the method |
---|
865 | * @param paramClasses the classes of the parameters |
---|
866 | * @param paramValues the values of the parameters |
---|
867 | * @return the return value of the method, if any (in that case null) |
---|
868 | */ |
---|
869 | protected Object invokeMethod(Object o, String name, Class[] paramClasses, Object[] paramValues) { |
---|
870 | Method m; |
---|
871 | Object result; |
---|
872 | |
---|
873 | result = null; |
---|
874 | |
---|
875 | try { |
---|
876 | m = o.getClass().getMethod(name, paramClasses); |
---|
877 | result = m.invoke(o, paramValues); |
---|
878 | } |
---|
879 | catch (Exception e) { |
---|
880 | e.printStackTrace(); |
---|
881 | result = null; |
---|
882 | } |
---|
883 | |
---|
884 | return result; |
---|
885 | } |
---|
886 | |
---|
887 | /** |
---|
888 | * transfers the local variables into a svm_parameter object |
---|
889 | * |
---|
890 | * @return the configured svm_parameter object |
---|
891 | */ |
---|
892 | protected Object getParameters() { |
---|
893 | Object result; |
---|
894 | int i; |
---|
895 | |
---|
896 | try { |
---|
897 | Class solverTypeEnumClass = Class.forName(CLASS_SOLVERTYPE); |
---|
898 | Object[] enumValues = solverTypeEnumClass.getEnumConstants(); |
---|
899 | Object solverType = enumValues[m_SVMType]; |
---|
900 | |
---|
901 | Class[] constructorClasses = new Class[] { solverTypeEnumClass, double.class, double.class }; |
---|
902 | Constructor parameterConstructor = Class.forName(CLASS_PARAMETER).getConstructor(constructorClasses); |
---|
903 | |
---|
904 | result = parameterConstructor.newInstance(solverType, Double.valueOf(m_Cost), |
---|
905 | Double.valueOf(m_eps)); |
---|
906 | |
---|
907 | if (m_Weight.length > 0) { |
---|
908 | invokeMethod(result, "setWeights", new Class[] { double[].class, int[].class }, |
---|
909 | new Object[] { m_Weight, m_WeightLabel }); |
---|
910 | } |
---|
911 | } |
---|
912 | catch (Exception e) { |
---|
913 | e.printStackTrace(); |
---|
914 | result = null; |
---|
915 | } |
---|
916 | |
---|
917 | return result; |
---|
918 | } |
---|
919 | |
---|
920 | /** |
---|
921 | * returns the svm_problem |
---|
922 | * |
---|
923 | * @param vx the x values |
---|
924 | * @param vy the y values |
---|
925 | * @param max_index |
---|
926 | * @return the Problem object |
---|
927 | */ |
---|
928 | protected Object getProblem(List<Object> vx, List<Integer> vy, int max_index) { |
---|
929 | Object result; |
---|
930 | |
---|
931 | try { |
---|
932 | result = Class.forName(CLASS_PROBLEM).newInstance(); |
---|
933 | |
---|
934 | setField(result, "l", Integer.valueOf(vy.size())); |
---|
935 | setField(result, "n", Integer.valueOf(max_index)); |
---|
936 | setField(result, "bias", getBias()); |
---|
937 | |
---|
938 | newArray(result, "x", Class.forName(CLASS_FEATURENODE), new int[]{vy.size(), 0}); |
---|
939 | for (int i = 0; i < vy.size(); i++) |
---|
940 | setField(result, "x", i, vx.get(i)); |
---|
941 | |
---|
942 | newArray(result, "y", Integer.TYPE, vy.size()); |
---|
943 | for (int i = 0; i < vy.size(); i++) |
---|
944 | setField(result, "y", i, vy.get(i)); |
---|
945 | } |
---|
946 | catch (Exception e) { |
---|
947 | e.printStackTrace(); |
---|
948 | result = null; |
---|
949 | } |
---|
950 | |
---|
951 | return result; |
---|
952 | } |
---|
953 | |
---|
954 | /** |
---|
955 | * returns an instance into a sparse liblinear array |
---|
956 | * |
---|
957 | * @param instance the instance to work on |
---|
958 | * @return the liblinear array |
---|
959 | * @throws Exception if setup of array fails |
---|
960 | */ |
---|
961 | protected Object instanceToArray(Instance instance) throws Exception { |
---|
962 | int index; |
---|
963 | int count; |
---|
964 | int i; |
---|
965 | Object result; |
---|
966 | |
---|
967 | // determine number of non-zero attributes |
---|
968 | count = 0; |
---|
969 | |
---|
970 | for (i = 0; i < instance.numValues(); i++) { |
---|
971 | if (instance.index(i) == instance.classIndex()) |
---|
972 | continue; |
---|
973 | if (instance.valueSparse(i) != 0) |
---|
974 | count++; |
---|
975 | } |
---|
976 | |
---|
977 | if (m_Bias >= 0) { |
---|
978 | count++; |
---|
979 | } |
---|
980 | |
---|
981 | Class[] intDouble = new Class[] { int.class, double.class }; |
---|
982 | Constructor nodeConstructor = Class.forName(CLASS_FEATURENODE).getConstructor(intDouble); |
---|
983 | |
---|
984 | // fill array |
---|
985 | result = Array.newInstance(Class.forName(CLASS_FEATURENODE), count); |
---|
986 | index = 0; |
---|
987 | for (i = 0; i < instance.numValues(); i++) { |
---|
988 | |
---|
989 | int idx = instance.index(i); |
---|
990 | double val = instance.valueSparse(i); |
---|
991 | |
---|
992 | if (idx == instance.classIndex()) |
---|
993 | continue; |
---|
994 | if (val == 0) |
---|
995 | continue; |
---|
996 | |
---|
997 | Object node = nodeConstructor.newInstance(Integer.valueOf(idx+1), Double.valueOf(val)); |
---|
998 | Array.set(result, index, node); |
---|
999 | index++; |
---|
1000 | } |
---|
1001 | |
---|
1002 | // add bias term |
---|
1003 | if (m_Bias >= 0) { |
---|
1004 | Integer idx = Integer.valueOf(instance.numAttributes()+1); |
---|
1005 | Double value = Double.valueOf(m_Bias); |
---|
1006 | Object node = nodeConstructor.newInstance(idx, value); |
---|
1007 | Array.set(result, index, node); |
---|
1008 | } |
---|
1009 | |
---|
1010 | return result; |
---|
1011 | } |
---|
/**
 * Computes the distribution for a given instance.
 *
 * @param instance the instance for which distribution is computed
 * @return the distribution
 * @throws Exception if the distribution can't be computed successfully
 */
public double[] distributionForInstance (Instance instance) throws Exception {

  // push the instance through the same filter chain used during training,
  // in the same order: missing-value replacement, nominal-to-binary, normalize
  if (!getDoNotReplaceMissingValues()) {
    m_ReplaceMissingValues.input(instance);
    m_ReplaceMissingValues.batchFinished();
    instance = m_ReplaceMissingValues.output();
  }

  if (getConvertNominalToBinary()
      && m_NominalToBinary != null) {
    m_NominalToBinary.input(instance);
    m_NominalToBinary.batchFinished();
    instance = m_NominalToBinary.output();
  }

  if (m_Filter != null) {
    m_Filter.input(instance);
    m_Filter.batchFinished();
    instance = m_Filter.output();
  }

  // convert to liblinear's sparse FeatureNode[] representation
  Object x = instanceToArray(instance);
  double v;
  double[] result = new double[instance.numClasses()];
  if (m_ProbabilityEstimates) {
    if (m_SVMType != SVMTYPE_L2_LR) {
      throw new WekaException("probability estimation is currently only " +
        "supported for L2-regularized logistic regression");
    }

    // liblinear reports probabilities in its own label order
    int[] labels = (int[])invokeMethod(m_Model, "getLabels", null, null);
    double[] prob_estimates = new double[instance.numClasses()];

    // reflective call: Linear.predictProbability(model, x, prob_estimates);
    // prob_estimates is filled in place, the return value is the predicted label
    v = ((Integer) invokeMethod(
        Class.forName(CLASS_LINEAR).newInstance(),
        "predictProbability",
        new Class[]{
          Class.forName(CLASS_MODEL),
          Array.newInstance(Class.forName(CLASS_FEATURENODE), Array.getLength(x)).getClass(),
          Array.newInstance(Double.TYPE, prob_estimates.length).getClass()},
        new Object[]{ m_Model, x, prob_estimates})).doubleValue();

    // Return order of probabilities to canonical weka attribute order
    for (int k = 0; k < prob_estimates.length; k++) {
      result[labels[k]] = prob_estimates[k];
    }
  }
  else {
    // reflective call: Linear.predict(model, x) returns the predicted label
    v = ((Integer) invokeMethod(
        Class.forName(CLASS_LINEAR).newInstance(),
        "predict",
        new Class[]{
          Class.forName(CLASS_MODEL),
          Array.newInstance(Class.forName(CLASS_FEATURENODE), Array.getLength(x)).getClass()},
        new Object[]{
          m_Model,
          x})).doubleValue();

    assert (instance.classAttribute().isNominal());
    // without probability estimates: all mass on the predicted class
    result[(int) v] = 1;
  }

  return result;
}
---|
1083 | |
---|
1084 | /** |
---|
1085 | * Returns default capabilities of the classifier. |
---|
1086 | * |
---|
1087 | * @return the capabilities of this classifier |
---|
1088 | */ |
---|
1089 | public Capabilities getCapabilities() { |
---|
1090 | Capabilities result = super.getCapabilities(); |
---|
1091 | result.disableAll(); |
---|
1092 | |
---|
1093 | // attributes |
---|
1094 | result.enable(Capability.NOMINAL_ATTRIBUTES); |
---|
1095 | result.enable(Capability.NUMERIC_ATTRIBUTES); |
---|
1096 | result.enable(Capability.DATE_ATTRIBUTES); |
---|
1097 | // result.enable(Capability.MISSING_VALUES); |
---|
1098 | |
---|
1099 | // class |
---|
1100 | result.enable(Capability.NOMINAL_CLASS); |
---|
1101 | result.enable(Capability.MISSING_CLASS_VALUES); |
---|
1102 | return result; |
---|
1103 | } |
---|
1104 | |
---|
1105 | /** |
---|
1106 | * builds the classifier |
---|
1107 | * |
---|
1108 | * @param insts the training instances |
---|
1109 | * @throws Exception if liblinear classes not in classpath or liblinear |
---|
1110 | * encountered a problem |
---|
1111 | */ |
---|
1112 | public void buildClassifier(Instances insts) throws Exception { |
---|
1113 | m_NominalToBinary = null; |
---|
1114 | m_Filter = null; |
---|
1115 | |
---|
1116 | if (!isPresent()) |
---|
1117 | throw new Exception("liblinear classes not in CLASSPATH!"); |
---|
1118 | |
---|
1119 | // remove instances with missing class |
---|
1120 | insts = new Instances(insts); |
---|
1121 | insts.deleteWithMissingClass(); |
---|
1122 | |
---|
1123 | if (!getDoNotReplaceMissingValues()) { |
---|
1124 | m_ReplaceMissingValues = new ReplaceMissingValues(); |
---|
1125 | m_ReplaceMissingValues.setInputFormat(insts); |
---|
1126 | insts = Filter.useFilter(insts, m_ReplaceMissingValues); |
---|
1127 | } |
---|
1128 | |
---|
1129 | // can classifier handle the data? |
---|
1130 | // we check this here so that if the user turns off |
---|
1131 | // replace missing values filtering, it will fail |
---|
1132 | // if the data actually does have missing values |
---|
1133 | getCapabilities().testWithFail(insts); |
---|
1134 | |
---|
1135 | if (getConvertNominalToBinary()) { |
---|
1136 | insts = nominalToBinary(insts); |
---|
1137 | } |
---|
1138 | |
---|
1139 | if (getNormalize()) { |
---|
1140 | m_Filter = new Normalize(); |
---|
1141 | m_Filter.setInputFormat(insts); |
---|
1142 | insts = Filter.useFilter(insts, m_Filter); |
---|
1143 | } |
---|
1144 | |
---|
1145 | List<Integer> vy = new ArrayList<Integer>(insts.numInstances()); |
---|
1146 | List<Object> vx = new ArrayList<Object>(insts.numInstances()); |
---|
1147 | int max_index = 0; |
---|
1148 | |
---|
1149 | for (int d = 0; d < insts.numInstances(); d++) { |
---|
1150 | Instance inst = insts.instance(d); |
---|
1151 | Object x = instanceToArray(inst); |
---|
1152 | int m = Array.getLength(x); |
---|
1153 | if (m > 0) |
---|
1154 | max_index = Math.max(max_index, ((Integer) getField(Array.get(x, m - 1), "index")).intValue()); |
---|
1155 | vx.add(x); |
---|
1156 | double classValue = inst.classValue(); |
---|
1157 | int classValueInt = (int)classValue; |
---|
1158 | if (classValueInt != classValue) throw new RuntimeException("unsupported class value: " + classValue); |
---|
1159 | vy.add(Integer.valueOf(classValueInt)); |
---|
1160 | } |
---|
1161 | |
---|
1162 | if (!m_Debug) { |
---|
1163 | invokeMethod( |
---|
1164 | Class.forName(CLASS_LINEAR).newInstance(), |
---|
1165 | "disableDebugOutput", null, null); |
---|
1166 | } else { |
---|
1167 | invokeMethod( |
---|
1168 | Class.forName(CLASS_LINEAR).newInstance(), |
---|
1169 | "enableDebugOutput", null, null); |
---|
1170 | } |
---|
1171 | |
---|
1172 | // reset the PRNG for regression-stable results |
---|
1173 | invokeMethod( |
---|
1174 | Class.forName(CLASS_LINEAR).newInstance(), |
---|
1175 | "resetRandom", null, null); |
---|
1176 | |
---|
1177 | // train model |
---|
1178 | m_Model = invokeMethod( |
---|
1179 | Class.forName(CLASS_LINEAR).newInstance(), |
---|
1180 | "train", |
---|
1181 | new Class[]{ |
---|
1182 | Class.forName(CLASS_PROBLEM), |
---|
1183 | Class.forName(CLASS_PARAMETER)}, |
---|
1184 | new Object[]{ |
---|
1185 | getProblem(vx, vy, max_index), |
---|
1186 | getParameters()}); |
---|
1187 | } |
---|
1188 | |
---|
1189 | /** |
---|
1190 | * turns on nominal to binary filtering |
---|
1191 | * if there are not only numeric attributes |
---|
1192 | */ |
---|
1193 | private Instances nominalToBinary( Instances insts ) throws Exception { |
---|
1194 | boolean onlyNumeric = true; |
---|
1195 | for (int i = 0; i < insts.numAttributes(); i++) { |
---|
1196 | if (i != insts.classIndex()) { |
---|
1197 | if (!insts.attribute(i).isNumeric()) { |
---|
1198 | onlyNumeric = false; |
---|
1199 | break; |
---|
1200 | } |
---|
1201 | } |
---|
1202 | } |
---|
1203 | |
---|
1204 | if (!onlyNumeric) { |
---|
1205 | m_NominalToBinary = new NominalToBinary(); |
---|
1206 | m_NominalToBinary.setInputFormat(insts); |
---|
1207 | insts = Filter.useFilter(insts, m_NominalToBinary); |
---|
1208 | } |
---|
1209 | return insts; |
---|
1210 | } |
---|
1211 | |
---|
1212 | /** |
---|
1213 | * returns a string representation |
---|
1214 | * |
---|
1215 | * @return a string representation |
---|
1216 | */ |
---|
1217 | public String toString() { |
---|
1218 | return "LibLINEAR wrapper"; |
---|
1219 | } |
---|
1220 | |
---|
1221 | /** |
---|
1222 | * Returns the revision string. |
---|
1223 | * |
---|
1224 | * @return the revision |
---|
1225 | */ |
---|
1226 | public String getRevision() { |
---|
1227 | return RevisionUtils.extract("$Revision: 5928 $"); |
---|
1228 | } |
---|
1229 | |
---|
1230 | /** |
---|
1231 | * Main method for testing this class. |
---|
1232 | * |
---|
1233 | * @param args the options |
---|
1234 | */ |
---|
1235 | public static void main(String[] args) { |
---|
1236 | runClassifier(new LibLINEAR(), args); |
---|
1237 | } |
---|
1238 | } |
---|
1239 | |
---|