source: src/main/java/weka/classifiers/meta/CVParameterSelection.java @ 6

Last change on this file since 6 was 4, checked in by gnappo, 14 years ago

Import di weka.

File size: 24.9 KB
Line 
1/*
2 *    This program is free software; you can redistribute it and/or modify
3 *    it under the terms of the GNU General Public License as published by
4 *    the Free Software Foundation; either version 2 of the License, or
5 *    (at your option) any later version.
6 *
7 *    This program is distributed in the hope that it will be useful,
8 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
9 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
10 *    GNU General Public License for more details.
11 *
12 *    You should have received a copy of the GNU General Public License
13 *    along with this program; if not, write to the Free Software
14 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
15 */
16
17/*
18 *    CVParameterSelection.java
19 *    Copyright (C) 1999 University of Waikato, Hamilton, New Zealand
20 *
21 */
22
23package weka.classifiers.meta;
24
25import weka.classifiers.Evaluation;
26import weka.classifiers.RandomizableSingleClassifierEnhancer;
27import weka.core.Capabilities;
28import weka.core.Drawable;
29import weka.core.FastVector;
30import weka.core.Instance;
31import weka.core.Instances;
32import weka.core.Option;
33import weka.core.OptionHandler;
34import weka.core.RevisionHandler;
35import weka.core.RevisionUtils;
36import weka.core.Summarizable;
37import weka.core.TechnicalInformation;
38import weka.core.TechnicalInformationHandler;
39import weka.core.Utils;
40import weka.core.TechnicalInformation.Field;
41import weka.core.TechnicalInformation.Type;
42
43import java.io.Serializable;
44import java.io.StreamTokenizer;
45import java.io.StringReader;
46import java.util.Enumeration;
47import java.util.Random;
48import java.util.Vector;
49
50/**
51 <!-- globalinfo-start -->
52 * Class for performing parameter selection by cross-validation for any classifier.<br/>
53 * <br/>
54 * For more information, see:<br/>
55 * <br/>
56 * R. Kohavi (1995). Wrappers for Performance Enhancement and Oblivious Decision Graphs. Department of Computer Science, Stanford University.
57 * <p/>
58 <!-- globalinfo-end -->
59 *
60 <!-- technical-bibtex-start -->
61 * BibTeX:
62 * <pre>
63 * &#64;phdthesis{Kohavi1995,
64 *    address = {Department of Computer Science, Stanford University},
65 *    author = {R. Kohavi},
66 *    school = {Stanford University},
67 *    title = {Wrappers for Performance Enhancement and Oblivious Decision Graphs},
68 *    year = {1995}
69 * }
70 * </pre>
71 * <p/>
72 <!-- technical-bibtex-end -->
73 *
74 <!-- options-start -->
75 * Valid options are: <p/>
76 *
77 * <pre> -X &lt;number of folds&gt;
78 *  Number of folds used for cross validation (default 10).</pre>
79 *
80 * <pre> -P &lt;classifier parameter&gt;
81 *  Classifier parameter options.
82 *  eg: "N 1 5 10" Sets an optimisation parameter for the
83 *  classifier with name -N, with lower bound 1, upper bound
84 *  5, and 10 optimisation steps. The upper bound may be the
85 *  character 'A' or 'I' to substitute the number of
86 *  attributes or instances in the training data,
87 *  respectively. This parameter may be supplied more than
88 *  once to optimise over several classifier options
89 *  simultaneously.</pre>
90 *
91 * <pre> -S &lt;num&gt;
92 *  Random number seed.
93 *  (default 1)</pre>
94 *
95 * <pre> -D
96 *  If set, classifier is run in debug mode and
97 *  may output additional info to the console</pre>
98 *
99 * <pre> -W
100 *  Full name of base classifier.
101 *  (default: weka.classifiers.rules.ZeroR)</pre>
102 *
103 * <pre>
104 * Options specific to classifier weka.classifiers.rules.ZeroR:
105 * </pre>
106 *
107 * <pre> -D
108 *  If set, classifier is run in debug mode and
109 *  may output additional info to the console</pre>
110 *
111 <!-- options-end -->
112 *
113 * Options after -- are passed to the designated sub-classifier. <p>
114 *
115 * @author Len Trigg (trigg@cs.waikato.ac.nz)
116 * @version $Revision: 5928 $
117*/
118public class CVParameterSelection 
119  extends RandomizableSingleClassifierEnhancer
120  implements Drawable, Summarizable, TechnicalInformationHandler {
121
122  /** for serialization */
123  static final long serialVersionUID = -6529603380876641265L;
124 
125  /**
126   * A data structure to hold values associated with a single
127   * cross-validation search parameter
128   */
129  protected class CVParameter 
130    implements Serializable, RevisionHandler {
131   
132    /** for serialization */
133    static final long serialVersionUID = -4668812017709421953L;
134
135    /**  Char used to identify the option of interest */
136    private char m_ParamChar;   
137
138    /**  Lower bound for the CV search */
139    private double m_Lower;     
140
141    /**  Upper bound for the CV search */
142    private double m_Upper;     
143
144    /**  Number of steps during the search */
145    private double m_Steps;     
146
147    /**  The parameter value with the best performance */
148    private double m_ParamValue; 
149
150    /**  True if the parameter should be added at the end of the argument list */
151    private boolean m_AddAtEnd; 
152
153    /**  True if the parameter should be rounded to an integer */
154    private boolean m_RoundParam;
155
156    /**
157     * Constructs a CVParameter.
158     *
159     * @param param the parameter definition
160     * @throws Exception if construction of CVParameter fails
161     */
162    public CVParameter(String param) throws Exception {
163     
164      // Tokenize the string into it's parts
165      StreamTokenizer st = new StreamTokenizer(new StringReader(param));
166      if (st.nextToken() != StreamTokenizer.TT_WORD) {
167        throw new Exception("CVParameter " + param
168                            + ": Character parameter identifier expected");
169      }
170      m_ParamChar = st.sval.charAt(0);
171      if (st.nextToken() != StreamTokenizer.TT_NUMBER) {
172        throw new Exception("CVParameter " + param
173                            + ": Numeric lower bound expected");
174      }
175      m_Lower = st.nval;
176      if (st.nextToken() == StreamTokenizer.TT_NUMBER) {
177        m_Upper = st.nval;
178        if (m_Upper < m_Lower) {
179          throw new Exception("CVParameter " + param
180                              + ": Upper bound is less than lower bound");
181        }
182      } else if (st.ttype == StreamTokenizer.TT_WORD) {
183        if (st.sval.toUpperCase().charAt(0) == 'A') {
184          m_Upper = m_Lower - 1;
185        } else if (st.sval.toUpperCase().charAt(0) == 'I') {
186          m_Upper = m_Lower - 2;
187        } else {
188          throw new Exception("CVParameter " + param
189              + ": Upper bound must be numeric, or 'A' or 'N'");
190        }
191      } else {
192        throw new Exception("CVParameter " + param
193              + ": Upper bound must be numeric, or 'A' or 'N'");
194      }
195      if (st.nextToken() != StreamTokenizer.TT_NUMBER) {
196        throw new Exception("CVParameter " + param
197                            + ": Numeric number of steps expected");
198      }
199      m_Steps = st.nval;
200      if (st.nextToken() == StreamTokenizer.TT_WORD) {
201        if (st.sval.toUpperCase().charAt(0) == 'R') {
202          m_RoundParam = true;
203        }
204      }
205    }
206
207    /**
208     * Returns a CVParameter as a string.
209     *
210     * @return the CVParameter as string
211     */
212    public String toString() {
213
214      String result = m_ParamChar + " " + m_Lower + " ";
215      switch ((int)(m_Lower - m_Upper + 0.5)) {
216      case 1:
217        result += "A";
218        break;
219      case 2:
220        result += "I";
221        break;
222      default:
223        result += m_Upper;
224        break;
225      }
226      result += " " + m_Steps;
227      if (m_RoundParam) {
228        result += " R";
229      }
230      return result;
231    }
232   
233    /**
234     * Returns the revision string.
235     *
236     * @return          the revision
237     */
238    public String getRevision() {
239      return RevisionUtils.extract("$Revision: 5928 $");
240    }
241  }
242
243  /**
244   * The base classifier options (not including those being set
245   * by cross-validation)
246   */
247  protected String [] m_ClassifierOptions;
248
249  /** The set of all classifier options as determined by cross-validation */
250  protected String [] m_BestClassifierOptions;
251
252  /** The set of all options at initialization time. So that getOptions
253      can return this. */
254  protected String [] m_InitOptions;
255
256  /** The cross-validated performance of the best options */
257  protected double m_BestPerformance;
258
259  /** The set of parameters to cross-validate over */
260  protected FastVector m_CVParams = new FastVector();
261
262  /** The number of attributes in the data */
263  protected int m_NumAttributes;
264
265  /** The number of instances in a training fold */
266  protected int m_TrainFoldSize;
267 
268  /** The number of folds used in cross-validation */
269  protected int m_NumFolds = 10;
270
271  /**
272   * Create the options array to pass to the classifier. The parameter
273   * values and positions are taken from m_ClassifierOptions and
274   * m_CVParams.
275   *
276   * @return the options array
277   */
278  protected String [] createOptions() {
279   
280    String [] options = new String [m_ClassifierOptions.length 
281                                   + 2 * m_CVParams.size()];
282    int start = 0, end = options.length;
283
284    // Add the cross-validation parameters and their values
285    for (int i = 0; i < m_CVParams.size(); i++) {
286      CVParameter cvParam = (CVParameter)m_CVParams.elementAt(i);
287      double paramValue = cvParam.m_ParamValue;
288      if (cvParam.m_RoundParam) {
289        //      paramValue = (double)((int) (paramValue + 0.5));
290        paramValue = Math.rint(paramValue);
291      }
292      if (cvParam.m_AddAtEnd) {
293        options[--end] = "" + 
294        Utils.doubleToString(paramValue,4);
295        options[--end] = "-" + cvParam.m_ParamChar;
296      } else {
297        options[start++] = "-" + cvParam.m_ParamChar;
298        options[start++] = "" 
299        + Utils.doubleToString(paramValue,4);
300      }
301    }
302    // Add the static parameters
303    System.arraycopy(m_ClassifierOptions, 0,
304                     options, start,
305                     m_ClassifierOptions.length);
306
307    return options;
308  }
309
310  /**
311   * Finds the best parameter combination. (recursive for each parameter
312   * being optimised).
313   *
314   * @param depth the index of the parameter to be optimised at this level
315   * @param trainData the data the search is based on
316   * @param random a random number generator
317   * @throws Exception if an error occurs
318   */
319  protected void findParamsByCrossValidation(int depth, Instances trainData,
320                                             Random random)
321    throws Exception {
322
323    if (depth < m_CVParams.size()) {
324      CVParameter cvParam = (CVParameter)m_CVParams.elementAt(depth);
325
326      double upper;
327      switch ((int)(cvParam.m_Lower - cvParam.m_Upper + 0.5)) {
328      case 1:
329        upper = m_NumAttributes;
330        break;
331      case 2:
332        upper = m_TrainFoldSize;
333        break;
334      default:
335        upper = cvParam.m_Upper;
336        break;
337      }
338      double increment = (upper - cvParam.m_Lower) / (cvParam.m_Steps - 1);
339      for(cvParam.m_ParamValue = cvParam.m_Lower; 
340          cvParam.m_ParamValue <= upper; 
341          cvParam.m_ParamValue += increment) {
342        findParamsByCrossValidation(depth + 1, trainData, random);
343      }
344    } else {
345     
346      Evaluation evaluation = new Evaluation(trainData);
347
348      // Set the classifier options
349      String [] options = createOptions();
350      if (m_Debug) {
351        System.err.print("Setting options for " 
352                         + m_Classifier.getClass().getName() + ":");
353        for (int i = 0; i < options.length; i++) {
354          System.err.print(" " + options[i]);
355        }
356        System.err.println("");
357      }
358      ((OptionHandler)m_Classifier).setOptions(options);
359      for (int j = 0; j < m_NumFolds; j++) {
360
361        // We want to randomize the data the same way for every
362        // learning scheme.
363        Instances train = trainData.trainCV(m_NumFolds, j, new Random(1));
364        Instances test = trainData.testCV(m_NumFolds, j);
365        m_Classifier.buildClassifier(train);
366        evaluation.setPriors(train);
367        evaluation.evaluateModel(m_Classifier, test);
368      }
369      double error = evaluation.errorRate();
370      if (m_Debug) {
371        System.err.println("Cross-validated error rate: " 
372                           + Utils.doubleToString(error, 6, 4));
373      }
374      if ((m_BestPerformance == -99) || (error < m_BestPerformance)) {
375       
376        m_BestPerformance = error;
377        m_BestClassifierOptions = createOptions();
378      }
379    }
380  }
381
382  /**
383   * Returns a string describing this classifier
384   * @return a description of the classifier suitable for
385   * displaying in the explorer/experimenter gui
386   */
387  public String globalInfo() {
388    return    "Class for performing parameter selection by cross-validation "
389            + "for any classifier.\n\n"
390            + "For more information, see:\n\n"
391            + getTechnicalInformation().toString();
392  }
393
394  /**
395   * Returns an instance of a TechnicalInformation object, containing
396   * detailed information about the technical background of this class,
397   * e.g., paper reference or book this class is based on.
398   *
399   * @return the technical information about this class
400   */
401  public TechnicalInformation getTechnicalInformation() {
402    TechnicalInformation        result;
403   
404    result = new TechnicalInformation(Type.PHDTHESIS);
405    result.setValue(Field.AUTHOR, "R. Kohavi");
406    result.setValue(Field.YEAR, "1995");
407    result.setValue(Field.TITLE, "Wrappers for Performance Enhancement and Oblivious Decision Graphs");
408    result.setValue(Field.SCHOOL, "Stanford University");
409    result.setValue(Field.ADDRESS, "Department of Computer Science, Stanford University");
410   
411    return result;
412  }
413
414  /**
415   * Returns an enumeration describing the available options.
416   *
417   * @return an enumeration of all the available options.
418   */
419  public Enumeration listOptions() {
420
421    Vector newVector = new Vector(2);
422
423    newVector.addElement(new Option(
424              "\tNumber of folds used for cross validation (default 10).",
425              "X", 1, "-X <number of folds>"));
426    newVector.addElement(new Option(
427              "\tClassifier parameter options.\n"
428              + "\teg: \"N 1 5 10\" Sets an optimisation parameter for the\n"
429              + "\tclassifier with name -N, with lower bound 1, upper bound\n"
430              + "\t5, and 10 optimisation steps. The upper bound may be the\n"
431              + "\tcharacter 'A' or 'I' to substitute the number of\n"
432              + "\tattributes or instances in the training data,\n"
433              + "\trespectively. This parameter may be supplied more than\n"
434              + "\tonce to optimise over several classifier options\n"
435              + "\tsimultaneously.",
436              "P", 1, "-P <classifier parameter>"));
437
438
439    Enumeration enu = super.listOptions();
440    while (enu.hasMoreElements()) {
441      newVector.addElement(enu.nextElement());
442    }
443    return newVector.elements();
444  }
445
446
447  /**
448   * Parses a given list of options. <p/>
449   *
450   <!-- options-start -->
451   * Valid options are: <p/>
452   *
453   * <pre> -X &lt;number of folds&gt;
454   *  Number of folds used for cross validation (default 10).</pre>
455   *
456   * <pre> -P &lt;classifier parameter&gt;
457   *  Classifier parameter options.
458   *  eg: "N 1 5 10" Sets an optimisation parameter for the
459   *  classifier with name -N, with lower bound 1, upper bound
460   *  5, and 10 optimisation steps. The upper bound may be the
461   *  character 'A' or 'I' to substitute the number of
462   *  attributes or instances in the training data,
463   *  respectively. This parameter may be supplied more than
464   *  once to optimise over several classifier options
465   *  simultaneously.</pre>
466   *
467   * <pre> -S &lt;num&gt;
468   *  Random number seed.
469   *  (default 1)</pre>
470   *
471   * <pre> -D
472   *  If set, classifier is run in debug mode and
473   *  may output additional info to the console</pre>
474   *
475   * <pre> -W
476   *  Full name of base classifier.
477   *  (default: weka.classifiers.rules.ZeroR)</pre>
478   *
479   * <pre>
480   * Options specific to classifier weka.classifiers.rules.ZeroR:
481   * </pre>
482   *
483   * <pre> -D
484   *  If set, classifier is run in debug mode and
485   *  may output additional info to the console</pre>
486   *
487   <!-- options-end -->
488   *
489   * Options after -- are passed to the designated sub-classifier. <p>
490   *
491   * @param options the list of options as an array of strings
492   * @throws Exception if an option is not supported
493   */
494  public void setOptions(String[] options) throws Exception {
495
496    String foldsString = Utils.getOption('X', options);
497    if (foldsString.length() != 0) {
498      setNumFolds(Integer.parseInt(foldsString));
499    } else {
500      setNumFolds(10);
501    }
502
503    String cvParam;
504    m_CVParams = new FastVector();
505    do {
506      cvParam = Utils.getOption('P', options);
507      if (cvParam.length() != 0) {
508        addCVParameter(cvParam);
509      }
510    } while (cvParam.length() != 0);
511
512    super.setOptions(options);
513  }
514
515  /**
516   * Gets the current settings of the Classifier.
517   *
518   * @return an array of strings suitable for passing to setOptions
519   */
520  public String [] getOptions() {
521
522    String[] superOptions;
523
524    if (m_InitOptions != null) {
525      try {
526        ((OptionHandler)m_Classifier).setOptions((String[])m_InitOptions.clone());
527        superOptions = super.getOptions();
528        ((OptionHandler)m_Classifier).setOptions((String[])m_BestClassifierOptions.clone());
529      } catch (Exception e) {
530        throw new RuntimeException("CVParameterSelection: could not set options " +
531                                   "in getOptions().");
532      } 
533    } else {
534      superOptions = super.getOptions();
535    }
536    String [] options = new String [superOptions.length + m_CVParams.size() * 2 + 2];
537
538    int current = 0;
539    for (int i = 0; i < m_CVParams.size(); i++) {
540      options[current++] = "-P"; options[current++] = "" + getCVParameter(i);
541    }
542    options[current++] = "-X"; options[current++] = "" + getNumFolds();
543
544    System.arraycopy(superOptions, 0, options, current, 
545                     superOptions.length);
546
547    return options;
548  }
549
550  /**
551   * Returns (a copy of) the best options found for the classifier.
552   *
553   * @return the best options
554   */
555  public String[] getBestClassifierOptions() {
556    return (String[]) m_BestClassifierOptions.clone();
557  }
558 
559  /**
560   * Returns default capabilities of the classifier.
561   *
562   * @return      the capabilities of this classifier
563   */
564  public Capabilities getCapabilities() {
565    Capabilities result = super.getCapabilities();
566
567    result.setMinimumNumberInstances(m_NumFolds);
568   
569    return result;
570  }
571
572  /**
573   * Generates the classifier.
574   *
575   * @param instances set of instances serving as training data
576   * @throws Exception if the classifier has not been generated successfully
577   */
578  public void buildClassifier(Instances instances) throws Exception {
579
580    // can classifier handle the data?
581    getCapabilities().testWithFail(instances);
582
583    // remove instances with missing class
584    Instances trainData = new Instances(instances);
585    trainData.deleteWithMissingClass();
586   
587    if (!(m_Classifier instanceof OptionHandler)) {
588      throw new IllegalArgumentException("Base classifier should be OptionHandler.");
589    }
590    m_InitOptions = ((OptionHandler)m_Classifier).getOptions();
591    m_BestPerformance = -99;
592    m_NumAttributes = trainData.numAttributes();
593    Random random = new Random(m_Seed);
594    trainData.randomize(random);
595    m_TrainFoldSize = trainData.trainCV(m_NumFolds, 0).numInstances();
596
597    // Check whether there are any parameters to optimize
598    if (m_CVParams.size() == 0) {
599       m_Classifier.buildClassifier(trainData);
600       m_BestClassifierOptions = m_InitOptions;
601       return;
602    }
603
604    if (trainData.classAttribute().isNominal()) {
605      trainData.stratify(m_NumFolds);
606    }
607    m_BestClassifierOptions = null;
608   
609    // Set up m_ClassifierOptions -- take getOptions() and remove
610    // those being optimised.
611    m_ClassifierOptions = ((OptionHandler)m_Classifier).getOptions();
612    for (int i = 0; i < m_CVParams.size(); i++) {
613      Utils.getOption(((CVParameter)m_CVParams.elementAt(i)).m_ParamChar,
614                      m_ClassifierOptions);
615    }
616    findParamsByCrossValidation(0, trainData, random);
617
618    String [] options = (String [])m_BestClassifierOptions.clone();
619    ((OptionHandler)m_Classifier).setOptions(options);
620    m_Classifier.buildClassifier(trainData);
621  }
622
623
624  /**
625   * Predicts the class distribution for the given test instance.
626   *
627   * @param instance the instance to be classified
628   * @return the predicted class value
629   * @throws Exception if an error occurred during the prediction
630   */
631  public double[] distributionForInstance(Instance instance) throws Exception {
632   
633    return m_Classifier.distributionForInstance(instance);
634  }
635
636  /**
637   * Adds a scheme parameter to the list of parameters to be set
638   * by cross-validation
639   *
640   * @param cvParam the string representation of a scheme parameter. The
641   * format is: <br>
642   * param_char lower_bound upper_bound number_of_steps <br>
643   * eg to search a parameter -P from 1 to 10 by increments of 1: <br>
644   * P 1 10 11 <br>
645   * @throws Exception if the parameter specifier is of the wrong format
646   */
647  public void addCVParameter(String cvParam) throws Exception {
648
649    CVParameter newCV = new CVParameter(cvParam);
650   
651    m_CVParams.addElement(newCV);
652  }
653
654  /**
655   * Gets the scheme paramter with the given index.
656   *
657   * @param index the index for the parameter
658   * @return the scheme parameter
659   */
660  public String getCVParameter(int index) {
661
662    if (m_CVParams.size() <= index) {
663      return "";
664    }
665    return ((CVParameter)m_CVParams.elementAt(index)).toString();
666  }
667
668  /**
669   * Returns the tip text for this property
670   * @return tip text for this property suitable for
671   * displaying in the explorer/experimenter gui
672   */
673  public String CVParametersTipText() {
674    return "Sets the scheme parameters which are to be set "+
675           "by cross-validation.\n"+
676           "The format for each string should be:\n"+
677           "param_char lower_bound upper_bound number_of_steps\n"+
678           "eg to search a parameter -P from 1 to 10 by increments of 1:\n"+
679           "    \"P 1 10 10\" ";
680  }
681
682  /**
683   * Get method for CVParameters.
684   *
685   * @return the CVParameters
686   */
687  public Object[] getCVParameters() {
688     
689      Object[] CVParams = m_CVParams.toArray();
690     
691      String params[] = new String[CVParams.length];
692     
693      for(int i=0; i<CVParams.length; i++) 
694          params[i] = CVParams[i].toString();
695     
696      return params;
697     
698  }
699 
700  /**
701   * Set method for CVParameters.
702   *
703   * @param params the CVParameters to use
704   * @throws Exception if the setting of the CVParameters fails
705   */
706  public void setCVParameters(Object[] params) throws Exception {
707     
708      FastVector backup = m_CVParams;
709      m_CVParams = new FastVector();
710     
711      for(int i=0; i<params.length; i++) {
712          try{
713          addCVParameter((String)params[i]);
714          }
715          catch(Exception ex) { m_CVParams = backup; throw ex; }
716      }
717  }
718
719  /**
720   * Returns the tip text for this property
721   * @return tip text for this property suitable for
722   * displaying in the explorer/experimenter gui
723   */
724  public String numFoldsTipText() {
725    return "Get the number of folds used for cross-validation.";
726  }
727
728  /**
729   * Gets the number of folds for the cross-validation.
730   *
731   * @return the number of folds for the cross-validation
732   */
733  public int getNumFolds() {
734
735    return m_NumFolds;
736  }
737
738  /**
739   * Sets the number of folds for the cross-validation.
740   *
741   * @param numFolds the number of folds for the cross-validation
742   * @throws Exception if parameter illegal
743   */
744  public void setNumFolds(int numFolds) throws Exception {
745   
746    if (numFolds < 0) {
747      throw new IllegalArgumentException("Stacking: Number of cross-validation " +
748                                         "folds must be positive.");
749    }
750    m_NumFolds = numFolds;
751  }
752 
753  /**
754   *  Returns the type of graph this classifier
755   *  represents.
756   * 
757   *  @return the type of graph this classifier represents
758   */   
759  public int graphType() {
760   
761    if (m_Classifier instanceof Drawable)
762      return ((Drawable)m_Classifier).graphType();
763    else 
764      return Drawable.NOT_DRAWABLE;
765  }
766
767  /**
768   * Returns graph describing the classifier (if possible).
769   *
770   * @return the graph of the classifier in dotty format
771   * @throws Exception if the classifier cannot be graphed
772   */
773  public String graph() throws Exception {
774   
775    if (m_Classifier instanceof Drawable)
776      return ((Drawable)m_Classifier).graph();
777    else throw new Exception("Classifier: " + 
778                             m_Classifier.getClass().getName() + " " +
779                             Utils.joinOptions(m_BestClassifierOptions)
780                             + " cannot be graphed");
781  }
782
783  /**
784   * Returns description of the cross-validated classifier.
785   *
786   * @return description of the cross-validated classifier as a string
787   */
788  public String toString() {
789
790    if (m_InitOptions == null)
791      return "CVParameterSelection: No model built yet.";
792
793    String result = "Cross-validated Parameter selection.\n"
794    + "Classifier: " + m_Classifier.getClass().getName() + "\n";
795    try {
796      for (int i = 0; i < m_CVParams.size(); i++) {
797        CVParameter cvParam = (CVParameter)m_CVParams.elementAt(i);
798        result += "Cross-validation Parameter: '-" 
799          + cvParam.m_ParamChar + "'"
800          + " ranged from " + cvParam.m_Lower 
801          + " to ";
802        switch ((int)(cvParam.m_Lower - cvParam.m_Upper + 0.5)) {
803        case 1:
804          result += m_NumAttributes;
805          break;
806        case 2:
807          result += m_TrainFoldSize;
808          break;
809        default:
810          result += cvParam.m_Upper;
811          break;
812        }
813        result += " with " + cvParam.m_Steps + " steps\n";
814      }
815    } catch (Exception ex) {
816      result += ex.getMessage();
817    }
818    result += "Classifier Options: "
819      + Utils.joinOptions(m_BestClassifierOptions)
820      + "\n\n" + m_Classifier.toString();
821    return result;
822  }
823
824  /**
825   * A concise description of the model.
826   *
827   * @return a concise description of the model
828   */
829  public String toSummaryString() {
830
831    String result = "Selected values: "
832      + Utils.joinOptions(m_BestClassifierOptions);
833    return result + '\n';
834  }
835 
836  /**
837   * Returns the revision string.
838   *
839   * @return            the revision
840   */
841  public String getRevision() {
842    return RevisionUtils.extract("$Revision: 5928 $");
843  }
844 
845  /**
846   * Main method for testing this class.
847   *
848   * @param argv the options
849   */
850  public static void main(String [] argv) {
851    runClassifier(new CVParameterSelection(), argv);
852  }
853}
Note: See TracBrowser for help on using the repository browser.