source: src/main/java/weka/classifiers/bayes/BayesianLogisticRegression.java @ 21

Last change on this file since 21 was 4, checked in by gnappo, 14 years ago

Import di weka.

File size: 34.6 KB
Line 
1/*
2 *    This program is free software; you can redistribute it and/or modify
3 *    it under the terms of the GNU General Public License as published by
4 *    the Free Software Foundation; either version 2 of the License, or
5 *    (at your option) any later version.
6 *
7 *    This program is distributed in the hope that it will be useful,
8 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
9 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
10 *    GNU General Public License for more details.
11 *
12 *    You should have received a copy of the GNU General Public License
13 *    along with this program; if not, write to the Free Software
14 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
15 */
16
17/*
18 *    BayesianLogisticRegression.java
19 *    Copyright (C) 2008 Illinois Institute of Technology
20 *
21 */
22
23package weka.classifiers.bayes;
24
25import weka.classifiers.Classifier;
26import weka.classifiers.AbstractClassifier;
27import weka.classifiers.bayes.blr.GaussianPriorImpl;
28import weka.classifiers.bayes.blr.LaplacePriorImpl;
29import weka.classifiers.bayes.blr.Prior;
30import weka.core.Attribute;
31import weka.core.Capabilities;
32import weka.core.Instance;
33import weka.core.Instances;
34import weka.core.Option;
35import weka.core.OptionHandler;
36import weka.core.RevisionUtils;
37import weka.core.SelectedTag;
38import weka.core.SerializedObject;
39import weka.core.Tag;
40import weka.core.TechnicalInformation;
41import weka.core.TechnicalInformationHandler;
42import weka.core.Utils;
43import weka.core.Capabilities.Capability;
44import weka.core.TechnicalInformation.Field;
45import weka.core.TechnicalInformation.Type;
46import weka.filters.Filter;
47import weka.filters.unsupervised.attribute.Normalize;
48
49import java.util.Enumeration;
50import java.util.Random;
51import java.util.StringTokenizer;
52import java.util.Vector;
53
54/**
55 <!-- globalinfo-start -->
56 * Implements Bayesian Logistic Regression for both Gaussian and Laplace Priors.<br/>
57 * <br/>
58 * For more information, see<br/>
59 * <br/>
60 * Alexander Genkin, David D. Lewis, David Madigan (2004). Large-scale bayesian logistic regression for text categorization. URL http://www.stat.rutgers.edu/~madigan/PAPERS/shortFat-v3a.pdf.
61 * <p/>
62 <!-- globalinfo-end -->
63 *
64 <!-- technical-bibtex-start -->
65 * BibTeX:
66 * <pre>
67 * &#64;techreport{Genkin2004,
68 *    author = {Alexander Genkin and David D. Lewis and David Madigan},
69 *    institution = {DIMACS},
70 *    title = {Large-scale bayesian logistic regression for text categorization},
71 *    year = {2004},
72 *    URL = {http://www.stat.rutgers.edu/\~madigan/PAPERS/shortFat-v3a.pdf}
73 * }
74 * </pre>
75 * <p/>
76 <!-- technical-bibtex-end -->
77 *
78 *
79 *  @author Navendu Garg (gargnav at iit dot edu)
80 *  @version $Revision: 5928 $
81 */
82public class BayesianLogisticRegression extends AbstractClassifier
83  implements OptionHandler, TechnicalInformationHandler {
84 
85  static final long serialVersionUID = -8013478897911757631L;
86
87  /** Log-likelihood values to be used to choose the best hyperparameter. */
88  public static double[] LogLikelihood;
89
90  /** Set of values to be used as hyperparameter values during Cross-Validation. */
91  public static double[] InputHyperparameterValues;
92
93  /** DEBUG Mode*/
94  boolean debug = false;
95
96  /** Choose whether to normalize data or not */
97  public boolean NormalizeData = false;
98
99  /** Tolerance criteria for the stopping criterion. */
100  public double Tolerance = 0.0005;
101
102  /** Threshold for binary classification of the probabilistic estimate. */
103  public double Threshold = 0.5;
104
105  /** Distributions available */
106  public static final int GAUSSIAN = 1;
107  public static final int LAPLACIAN = 2;
108 
109  public static final Tag[] TAGS_PRIOR = {
110    new Tag(GAUSSIAN, "Gaussian"),
111    new Tag(LAPLACIAN, "Laplacian")
112  };
113
114  /** Distribution Prior class */
115  public int PriorClass = GAUSSIAN;
116
117  /** NumFolds for CV based Hyperparameters selection*/
118  public int NumFolds = 2;
119
120  /** Methods for selecting the hyperparameter value */
121  public static final int NORM_BASED = 1;
122  public static final int CV_BASED = 2;
123  public static final int SPECIFIC_VALUE = 3;
124
125  public static final Tag[] TAGS_HYPER_METHOD = {
126    new Tag(NORM_BASED, "Norm-based"),
127    new Tag(CV_BASED, "CV-based"),
128    new Tag(SPECIFIC_VALUE, "Specific value")
129  };
130
131  /** Hyperparameter selection method */
132  public int HyperparameterSelection = NORM_BASED;
133
134  /** The class index from the training data */
135  public int ClassIndex = -1;
136
137  /** Best hyperparameter for test phase */
138  public double HyperparameterValue = 0.27;
139
140  /** CV Hyperparameter Range */
141  public String HyperparameterRange = "R:0.01-316,3.16";
142
143  /** Maximum number of iterations */
144  public int maxIterations = 100;
145
146  /**Iteration counter */
147  public int iterationCounter = 0;
148
149  /** Array for storing coefficients of Bayesian regression model. */
150  public double[] BetaVector;
151
152  /** Array to store Regression Coefficient updates. */
153  public double[] DeltaBeta;
154
155  /**        Trust Region Radius Update*/
156  public double[] DeltaUpdate;
157
158  /** Trust Region Radius*/
159  public double[] Delta;
160
161  /**  Array to store Hyperparameter values for each feature. */
162  public double[] Hyperparameters;
163
164  /** R(i)= BetaVector X x(i) X y(i).
165   * This an intermediate value with respect to vector BETA, input values and corresponding class labels*/
166  public double[] R;
167
168  /** This vector is used to store the increments on the R(i). It is also used to determining the stopping criterion.*/
169  public double[] DeltaR;
170
171  /**
172   * This variable is used to keep track of change in
173   * the value of delta summation of r(i).
174   */
175  public double Change;
176
177  /**
178   * Bayesian Logistic Regression returns the probability that a given instance belongs to a certain
179   * class (p(y=+1|Beta,X)). To obtain a binary value the Threshold value is used.
180   * <pre>
181   * p(y=+1|Beta,X)>Threshold ? 1 : -1
182   * </pre>
183   */
184
185  /** Filter interface used to point to weka.filters.unsupervised.attribute.Normalize object
186   *
187   */
188  public Filter m_Filter;
189
190  /** Dataset provided to do Training/Test set.*/
191  protected Instances m_Instances;
192
193  /**        Prior class object interface*/
194  protected Prior m_PriorUpdate;
195
196  public String globalInfo() {
197    return "Implements Bayesian Logistic Regression "
198      + "for both Gaussian and Laplace Priors.\n\n"
199      + "For more information, see\n\n"
200      + getTechnicalInformation();
201  }
202
203  /**
204   * <pre>
205   * (1)Initialize m_Beta[j] to 0.
206   * (2)Initialize m_DeltaUpdate[j].
207   * </pre>
208   *
209   * */
210  public void initialize() throws Exception {
211    int numOfAttributes;
212    int numOfInstances;
213    int i;
214    int j;
215
216    Change = 0.0;
217
218    //Manipulate Data
219    if (NormalizeData) {
220      m_Filter = new Normalize();
221      m_Filter.setInputFormat(m_Instances);
222      m_Instances = Filter.useFilter(m_Instances, m_Filter);
223    }
224
225    //Set the intecept coefficient.
226    Attribute att = new Attribute("(intercept)");
227    Instance instance;
228
229    m_Instances.insertAttributeAt(att, 0);
230
231    for (i = 0; i < m_Instances.numInstances(); i++) {
232      instance = m_Instances.instance(i);
233      instance.setValue(0, 1.0);
234    }
235
236    //Get the number of attributes
237    numOfAttributes = m_Instances.numAttributes();
238    numOfInstances = m_Instances.numInstances();
239    ClassIndex = m_Instances.classIndex();
240    iterationCounter = 0;
241
242    //Initialize Arrays.
243    switch (HyperparameterSelection) {
244    case 1:
245      HyperparameterValue = normBasedHyperParameter();
246
247      if (debug) {
248        System.out.println("Norm-based Hyperparameter: " + HyperparameterValue);
249      }
250
251      break;
252
253    case 2:
254      HyperparameterValue = CVBasedHyperparameter();
255
256      if (debug) {
257        System.out.println("CV-based Hyperparameter: " + HyperparameterValue);
258      }
259
260      break;
261    }
262
263    BetaVector = new double[numOfAttributes];
264    Delta = new double[numOfAttributes];
265    DeltaBeta = new double[numOfAttributes];
266    Hyperparameters = new double[numOfAttributes];
267    DeltaUpdate = new double[numOfAttributes];
268
269    for (j = 0; j < numOfAttributes; j++) {
270      BetaVector[j] = 0.0;
271      Delta[j] = 1.0;
272      DeltaBeta[j] = 0.0;
273      DeltaUpdate[j] = 0.0;
274
275      //TODO: Change the way it takes values.
276      Hyperparameters[j] = HyperparameterValue;
277    }
278
279    DeltaR = new double[numOfInstances];
280    R = new double[numOfInstances];
281
282    for (i = 0; i < numOfInstances; i++) {
283      DeltaR[i] = 0.0;
284      R[i] = 0.0;
285    }
286
287    //Set the Prior interface to the appropriate prior implementation.
288    if (PriorClass == GAUSSIAN) {
289      m_PriorUpdate = new GaussianPriorImpl();
290    } else {
291      m_PriorUpdate = new LaplacePriorImpl();
292    }
293  }
294
295  /**
296   * This method tests what kind of data this classifier can handle.
297   * return Capabilities
298   */
299  public Capabilities getCapabilities() {
300    Capabilities result = super.getCapabilities();
301    result.disableAll();
302
303    // attributes
304    result.enable(Capability.NUMERIC_ATTRIBUTES);
305
306    result.enable(Capability.BINARY_ATTRIBUTES);
307
308    // class
309    result.enable(Capability.BINARY_CLASS);
310
311    // instances
312    result.setMinimumNumberInstances(0);
313
314    return result;
315  }
316
317  /**
318   * <ul>
319   *         <li>(1) Set the data to the class attribute m_Instances.</li>
320   *  <li>(2)Call the method initialize() to initialize the values.</li>
321   * </ul>
322   *        @param data training data
323   *        @exception Exception if classifier can't be built successfully.
324   */
325  public void buildClassifier(Instances data) throws Exception {
326    Instance instance;
327    int i;
328    int j;
329
330    // can classifier handle the data?
331    getCapabilities().testWithFail(data);
332
333    //(1) Set the data to the class attribute m_Instances.
334    m_Instances = new Instances(data);
335
336    //(2)Call the method initialize() to initialize the values.
337    initialize();
338
339    do {
340      //Compute the prior Trust Region Radius Update;
341      for (j = 0; j < m_Instances.numAttributes(); j++) {
342        if (j != ClassIndex) {
343          DeltaUpdate[j] = m_PriorUpdate.update(j, m_Instances, BetaVector[j],
344              Hyperparameters[j], R, Delta[j]);
345          //limit step to trust region.
346          DeltaBeta[j] = Math.min(Math.max(DeltaUpdate[j], 0 - Delta[j]),
347              Delta[j]);
348
349          //Update the
350          for (i = 0; i < m_Instances.numInstances(); i++) {
351            instance = m_Instances.instance(i);
352
353            if (instance.value(j) != 0) {
354              DeltaR[i] = DeltaBeta[j] * instance.value(j) * classSgn(instance.classValue());
355              R[i] += DeltaR[i];
356            }
357          }
358
359          //Updated Beta values.
360          BetaVector[j] += DeltaBeta[j];
361
362          //Update size of trust region.
363          Delta[j] = Math.max(2 * Math.abs(DeltaBeta[j]), Delta[j] / 2.0);
364        }
365      }
366    } while (!stoppingCriterion());
367
368    m_PriorUpdate.computelogLikelihood(BetaVector, m_Instances);
369    m_PriorUpdate.computePenalty(BetaVector, Hyperparameters);
370  }
371
372  /**
373   * This class is used to mask the internal class labels.
374   *
375   * @param value internal class label
376   * @return
377   * <pre>
378   * <ul><li>
379   *  -1 for internal class label 0
380   *  </li>
381   *  <li>
382   *  +1 for internal class label 1
383   *  </li>
384   *  </ul>
385   *  </pre>
386   */
387  public static double classSgn(double value) {
388    if (value == 0.0) {
389      return -1.0;
390    } else {
391      return 1.0;
392    }
393  }
394
395  /**
396    * Returns an instance of a TechnicalInformation object, containing
397    * detailed information about the technical background of this class,
398    * e.g., paper reference or book this class is based on.
399    *
400    * @return the technical information about this class
401    */
402  public TechnicalInformation getTechnicalInformation() {
403    TechnicalInformation result = null;
404
405    result = new TechnicalInformation(Type.TECHREPORT);
406    result.setValue(Field.AUTHOR, "Alexander Genkin and David D. Lewis and David Madigan");
407    result.setValue(Field.YEAR, "2004");
408    result.setValue(Field.TITLE, "Large-scale bayesian logistic regression for text categorization");
409    result.setValue(Field.INSTITUTION, "DIMACS");
410    result.setValue(Field.URL, "http://www.stat.rutgers.edu/~madigan/PAPERS/shortFat-v3a.pdf");
411    return result;
412  }
413
414  /**
415   * This is a convient function that defines and upper bound
416   * (Delta>0) for values of r(i) reachable by updates in the
417   * trust region.
418   *
419   * r BetaVector X x(i)y(i).
420   * delta A parameter where sigma > 0
421   * @return double function value
422   */
423  public static double bigF(double r, double sigma) {
424    double funcValue = 0.25;
425    double absR = Math.abs(r);
426
427    if (absR > sigma) {
428      funcValue = 1.0 / (2.0 + Math.exp(absR - sigma) + Math.exp(sigma - absR));
429    }
430
431    return funcValue;
432  }
433
434  /**
435   * This method implements the stopping criterion
436   * function.
437   *
438   * @return boolean whether to stop or not.
439   */
440  public boolean stoppingCriterion() {
441    int i;
442    double sum_deltaR = 0.0;
443    double sum_R = 1.0;
444    boolean shouldStop;
445    double value = 0.0;
446    double delta;
447
448    //Summation of changes in R(i) vector.
449    for (i = 0; i < m_Instances.numInstances(); i++) {
450      sum_deltaR += Math.abs(DeltaR[i]); //Numerator (deltaR(i))
451      sum_R += Math.abs(R[i]); // Denominator (1+sum(R(i))
452    }
453
454    delta = Math.abs(sum_deltaR - Change);
455    Change = delta / sum_R;
456
457    if (debug) {
458      System.out.println(Change + " <= " + Tolerance);
459    }
460
461    shouldStop = ((Change <= Tolerance) || (iterationCounter >= maxIterations))
462      ? true : false;
463    iterationCounter++;
464    Change = sum_deltaR;
465
466    return shouldStop;
467  }
468
469  /**
470   *  This method computes the values for the logistic link function.
471   *  <pre>f(r)=exp(r)/(1+exp(r))</pre>
472   *
473   * @return output value
474   */
475  public static double logisticLinkFunction(double r) {
476    return Math.exp(r) / (1.0 + Math.exp(r));
477  }
478
479  /**
480   * Sign for a given value.
481   * @param r
482   * @return double +1 if r>0, -1 if r<0
483   */
484  public static double sgn(double r) {
485    double sgn = 0.0;
486
487    if (r > 0) {
488      sgn = 1.0;
489    } else if (r < 0) {
490      sgn = -1.0;
491    }
492
493    return sgn;
494  }
495
496  /**
497   *        This function computes the norm-based hyperparameters
498   *        and stores them in the m_Hyperparameters.
499   */
500  public double normBasedHyperParameter() {
501    //TODO: Implement this method.
502    Instance instance;
503
504    double mean = 0.0;
505
506    for (int i = 0; i < m_Instances.numInstances(); i++) {
507      instance = m_Instances.instance(i);
508
509      double sqr_sum = 0.0;
510
511      for (int j = 0; j < m_Instances.numAttributes(); j++) {
512        if (j != ClassIndex) {
513          sqr_sum += (instance.value(j) * instance.value(j));
514        }
515      }
516
517      //sqr_sum=Math.sqrt(sqr_sum);
518      mean += sqr_sum;
519    }
520
521    mean = mean / (double) m_Instances.numInstances();
522
523    return ((double) m_Instances.numAttributes()) / mean;
524  }
525
526  /**
527   * Classifies the given instance using the Bayesian Logistic Regression function.
528   *
529   * @param instance the test instance
530   * @return the classification
531   * @throws Exception if classification can't be done successfully
532   */
533  public double classifyInstance(Instance instance) throws Exception {
534    //TODO: Implement
535    double sum_R = 0.0;
536    double classification = 0.0;
537
538    sum_R = BetaVector[0];
539
540    for (int j = 0; j < instance.numAttributes(); j++) {
541      if (j != (ClassIndex - 1)) {
542        sum_R += (BetaVector[j + 1] * instance.value(j));
543      }
544    }
545
546    sum_R = logisticLinkFunction(sum_R);
547
548    if (sum_R > Threshold) {
549      classification = 1.0;
550    } else {
551      classification = 0.0;
552    }
553
554    return classification;
555  }
556
557  /**
558   * Outputs the linear regression model as a string.
559   *
560   * @return the model as string
561   */
562  public String toString() {
563
564    if (m_Instances == null) {
565      return "Bayesian logistic regression: No model built yet.";
566    }
567
568    StringBuffer buf = new StringBuffer();
569    String text = "";
570
571    switch (HyperparameterSelection) {
572    case 1:
573      text = "Norm-Based Hyperparameter Selection: ";
574
575      break;
576
577    case 2:
578      text = "Cross-Validation Based Hyperparameter Selection: ";
579
580      break;
581
582    case 3:
583      text = "Specified Hyperparameter: ";
584
585      break;
586    }
587
588    buf.append(text).append(HyperparameterValue).append("\n\n");
589
590    buf.append("Regression Coefficients\n");
591    buf.append("=========================\n\n");
592
593    for (int j = 0; j < m_Instances.numAttributes(); j++) {
594      if (j != ClassIndex) {
595        if (BetaVector[j] != 0.0) {
596          buf.append(m_Instances.attribute(j).name()).append(" : ")
597             .append(BetaVector[j]).append("\n");
598        }
599      }
600    }
601
602    buf.append("===========================\n\n");
603    buf.append("Likelihood: " + m_PriorUpdate.getLoglikelihood() + "\n\n");
604    buf.append("Penalty: " + m_PriorUpdate.getPenalty() + "\n\n");
605    buf.append("Regularized Log Posterior: " + m_PriorUpdate.getLogPosterior() +
606      "\n");
607    buf.append("===========================\n\n");
608
609    return buf.toString();
610  }
611
612  /**
613   * Method computes the best hyperparameter value by doing cross
614   * -validation on the training data and compute the likelihood.
615   * The method can parse a range of values or a list of values.
616   * @return Best hyperparameter value with the max likelihood value on the training data.
617   * @throws Exception
618   */
619  public double CVBasedHyperparameter() throws Exception {
620    //TODO: Method incomplete.
621    double start;
622
623    //TODO: Method incomplete.
624    double end;
625
626    //TODO: Method incomplete.
627    double multiplier;
628    int size = 0;
629    double[] list = null;
630    double MaxHypeValue = 0.0;
631    double MaxLikelihood = 0.0;
632    StringTokenizer tokenizer = new StringTokenizer(HyperparameterRange);
633    String rangeType = tokenizer.nextToken(":");
634
635    if (rangeType.equals("R")) {
636      String temp = tokenizer.nextToken();
637      tokenizer = new StringTokenizer(temp);
638      start = Double.parseDouble(tokenizer.nextToken("-"));
639      tokenizer = new StringTokenizer(tokenizer.nextToken());
640      end = Double.parseDouble(tokenizer.nextToken(","));
641      multiplier = Double.parseDouble(tokenizer.nextToken());
642
643      int steps = (int) (((Math.log10(end) - Math.log10(start)) / Math.log10(multiplier)) +
644        1);
645      list = new double[steps];
646
647      int count = 0;
648
649      for (double i = start; i <= end; i *= multiplier) {
650        list[count++] = i;
651      }
652    } else if (rangeType.equals("L")) {
653      Vector vec = new Vector();
654
655      while (tokenizer.hasMoreTokens()) {
656        vec.add(tokenizer.nextToken(","));
657      }
658
659      list = new double[vec.size()];
660
661      for (int i = 0; i < vec.size(); i++) {
662        list[i] = Double.parseDouble((String) vec.get(i));
663      }
664    } else {
665      //throw exception. 
666    }
667
668    // Perform two-fold cross-validation to collect
669    // unbiased predictions
670    if (list != null) {
671      int numFolds = (int) NumFolds;
672      Random random = new Random();
673      m_Instances.randomize(random);
674      m_Instances.stratify(numFolds);
675
676      for (int k = 0; k < list.length; k++) {
677        for (int i = 0; i < numFolds; i++) {
678          Instances train = m_Instances.trainCV(numFolds, i, random);
679          SerializedObject so = new SerializedObject(this);
680          BayesianLogisticRegression blr = (BayesianLogisticRegression) so.getObject();
681          //          blr.setHyperparameterSelection(3);
682          blr.setHyperparameterSelection(new SelectedTag(SPECIFIC_VALUE, 
683                                                         TAGS_HYPER_METHOD));
684          blr.setHyperparameterValue(list[k]);
685          //          blr.setPriorClass(PriorClass);
686          blr.setPriorClass(new SelectedTag(PriorClass,
687                                            TAGS_PRIOR));
688          blr.setThreshold(Threshold);
689          blr.setTolerance(Tolerance);
690          blr.buildClassifier(train);
691
692          Instances test = m_Instances.testCV(numFolds, i);
693          double val = blr.getLoglikeliHood(blr.BetaVector, test);
694
695          if (debug) {
696            System.out.println("Fold " + i + "Hyperparameter: " + list[k]);
697            System.out.println("===================================");
698            System.out.println(" Likelihood: " + val);
699          }
700
701          if ((k == 0) | (val > MaxLikelihood)) {
702            MaxLikelihood = val;
703            MaxHypeValue = list[k];
704          }
705        }
706      }
707    } else {
708      return HyperparameterValue;
709    }
710
711    return MaxHypeValue;
712  }
713
714  /**
715   *
716   * @return likelihood for a given set of betas and instances
717   */
718  public double getLoglikeliHood(double[] betas, Instances instances) {
719    m_PriorUpdate.computelogLikelihood(betas, instances);
720
721    return m_PriorUpdate.getLoglikelihood();
722  }
723
724  /**
725   * Returns an enumeration describing the available options.
726   *
727   * @return an enumeration of all the available options.
728   */
729  public Enumeration listOptions() {
730    Vector newVector = new Vector();
731
732    newVector.addElement(new Option("\tShow Debugging Output\n", "D", 0, "-D"));
733    newVector.addElement(new Option("\tDistribution of the Prior "
734                                    +"(1=Gaussian, 2=Laplacian)"
735                                    +"\n\t(default: 1=Gaussian)"
736                                    , "P", 1,
737                                    "-P <integer>"));
738    newVector.addElement(new Option("\tHyperparameter Selection Method "
739                                    +"(1=Norm-based, 2=CV-based, 3=specific value)\n"
740                                    +"\t(default: 1=Norm-based)", 
741                                    "H",
742                                    1, 
743                                    "-H <integer>"));
744    newVector.addElement(new Option("\tSpecified Hyperparameter Value (use in conjunction with -H 3)\n"
745                                    +"\t(default: 0.27)", 
746                                    "V", 
747                                    1,
748                                    "-V <double>"));
749    newVector.addElement(new Option(
750        "\tHyperparameter Range (use in conjunction with -H 2)\n"
751        +"\t(format: R:start-end,multiplier OR L:val(1), val(2), ..., val(n))\n"
752        +"\t(default: R:0.01-316,3.16)", 
753        "R", 
754        1,
755        "-R <string>"));
756    newVector.addElement(new Option("\tTolerance Value\n\t(default: 0.0005)", 
757                                    "Tl", 
758                                    1,
759                                    "-Tl <double>"));
760    newVector.addElement(new Option("\tThreshold Value\n\t(default: 0.5)", 
761                                    "S", 
762                                    1, 
763                                    "-S <double>"));
764    newVector.addElement(new Option("\tNumber Of Folds (use in conjuction with -H 2)\n"
765                                    +"\t(default: 2)", 
766                                    "F", 
767                                    1,
768                                    "-F <integer>"));
769    newVector.addElement(new Option("\tMax Number of Iterations\n\t(default: 100)", 
770                                    "I", 
771                                    1,
772                                    "-I <integer>"));
773    newVector.addElement(new Option("\tNormalize the data",
774                                    "N", 0, "-N"));
775
776    return newVector.elements();
777  }
778
779  /**
780   * Parses a given list of options. <p/>
781   *
782   <!-- options-start -->
783   * Valid options are: <p/>
784   *
785   * <pre> -D
786   *  Show Debugging Output
787   * </pre>
788   *
789   * <pre> -P &lt;integer&gt;
790   *  Distribution of the Prior (1=Gaussian, 2=Laplacian)
791   *  (default: 1=Gaussian)</pre>
792   *
793   * <pre> -H &lt;integer&gt;
794   *  Hyperparameter Selection Method (1=Norm-based, 2=CV-based, 3=specific value)
795   *  (default: 1=Norm-based)</pre>
796   *
797   * <pre> -V &lt;double&gt;
798   *  Specified Hyperparameter Value (use in conjunction with -H 3)
799   *  (default: 0.27)</pre>
800   *
801   * <pre> -R &lt;string&gt;
802   *  Hyperparameter Range (use in conjunction with -H 2)
803   *  (format: R:start-end,multiplier OR L:val(1), val(2), ..., val(n))
804   *  (default: R:0.01-316,3.16)</pre>
805   *
806   * <pre> -Tl &lt;double&gt;
807   *  Tolerance Value
808   *  (default: 0.0005)</pre>
809   *
810   * <pre> -S &lt;double&gt;
811   *  Threshold Value
812   *  (default: 0.5)</pre>
813   *
814   * <pre> -F &lt;integer&gt;
815   *  Number Of Folds (use in conjuction with -H 2)
816   *  (default: 2)</pre>
817   *
818   * <pre> -I &lt;integer&gt;
819   *  Max Number of Iterations
820   *  (default: 100)</pre>
821   *
822   * <pre> -N
823   *  Normalize the data</pre>
824   *
825   <!-- options-end -->
826   *
827   * @param options the list of options as an array of strings
828   * @throws Exception if an option is not supported
829   */
830  public void setOptions(String[] options) throws Exception {
831    //Debug Option
832    debug = Utils.getFlag('D', options);
833
834    // Set Tolerance.
835    String Tol = Utils.getOption("Tl", options);
836
837    if (Tol.length() != 0) {
838      Tolerance = Double.parseDouble(Tol);
839    }
840
841    //Set Threshold
842    String Thres = Utils.getOption('S', options);
843
844    if (Thres.length() != 0) {
845      Threshold = Double.parseDouble(Thres);
846    }
847
848    //Set Hyperparameter Type
849    String Hype = Utils.getOption('H', options);
850
851    if (Hype.length() != 0) {
852      HyperparameterSelection = Integer.parseInt(Hype);
853    }
854
855    //Set Hyperparameter Value
856    String HyperValue = Utils.getOption('V', options);
857
858    if (HyperValue.length() != 0) {
859      HyperparameterValue = Double.parseDouble(HyperValue);
860    }
861
862    // Set hyper parameter range or list.
863    String HyperparameterRange = Utils.getOption("R", options);
864
865    //Set Prior class.
866    String strPrior = Utils.getOption('P', options);
867
868    if (strPrior.length() != 0) {
869      PriorClass = Integer.parseInt(strPrior);
870    }
871
872    String folds = Utils.getOption('F', options);
873
874    if (folds.length() != 0) {
875      NumFolds = Integer.parseInt(folds);
876    }
877
878    String iterations = Utils.getOption('I', options);
879
880    if (iterations.length() != 0) {
881      maxIterations = Integer.parseInt(iterations);
882    }
883
884    NormalizeData = Utils.getFlag('N', options);
885
886    //TODO: Implement this method for other options.
887    Utils.checkForRemainingOptions(options);
888  }
889
890  /**
891   *
892   */
893  public String[] getOptions() {
894    Vector result = new Vector();
895
896    //Add Debug Mode to options.
897    result.add("-D");
898
899    //Add Tolerance value to options
900    result.add("-Tl");
901    result.add("" + Tolerance);
902
903    //Add Threshold value to options
904    result.add("-S");
905    result.add("" + Threshold);
906
907    //Add Hyperparameter value to options
908    result.add("-H");
909    result.add("" + HyperparameterSelection);
910
911    result.add("-V");
912    result.add("" + HyperparameterValue);
913
914    result.add("-R");
915    result.add("" + HyperparameterRange);
916
917    //Add Prior Class to options
918    result.add("-P");
919    result.add("" + PriorClass);
920
921    result.add("-F");
922    result.add("" + NumFolds);
923
924    result.add("-I");
925    result.add("" + maxIterations);
926
927    result.add("-N");
928
929    return (String[]) result.toArray(new String[result.size()]);
930  }
931
932  /**
933   * Main method for testing this class.
934   *
935   * @param argv the options
936   */
937  public static void main(String[] argv) {
938    runClassifier(new BayesianLogisticRegression(), argv);
939  }
940
941  /**
942   * Returns the tip text for this property
943   *
944   * @return tip text for this property suitable for
945   * displaying in the explorer/experimenter gui
946   */
947  public String debugTipText() {
948    return "Turns on debugging mode.";
949  }
950
951  /**
952   *
953   */
954  public void setDebug(boolean debugMode) {
955    debug = debugMode;
956  }
957
958  /**
959   * Returns the tip text for this property
960   *
961   * @return tip text for this property suitable for
962   * displaying in the explorer/experimenter gui
963   */
964  public String hyperparameterSelectionTipText() {
965    return "Select the type of Hyperparameter to be used.";
966  }
967
968  /**
969   * Get the method used to select the hyperparameter
970   *
971   * @return the method used to select the hyperparameter
972   */
973  public SelectedTag getHyperparameterSelection() {
974    return new SelectedTag(HyperparameterSelection, 
975                           TAGS_HYPER_METHOD);
976  }
977
978  /**
979   * Set the method used to select the hyperparameter
980   *
981   * @param newMethod the method used to set the hyperparameter
982   */
983  public void setHyperparameterSelection(SelectedTag newMethod) {
984    if (newMethod.getTags() == TAGS_HYPER_METHOD) {
985      int c = newMethod.getSelectedTag().getID();
986      if (c >= 1 && c <= 3) {
987        HyperparameterSelection = c;
988      } else {
989        throw new IllegalArgumentException("Wrong selection type, -H value should be: "
990                                           + "1 for norm-based, 2 for CV-based and "
991                                         + "3 for specific value");
992      }
993    }
994  }
995
996  /**
997   * Returns the tip text for this property
998   *
999   * @return tip text for this property suitable for
1000   * displaying in the explorer/experimenter gui
1001   */
1002  public String priorClassTipText() {
1003    return "The type of prior to be used.";
1004  }
1005
1006  /**
1007   * Set the type of prior to use.
1008   *
1009   * @param newMethod the type of prior to use.
1010   */
1011  public void setPriorClass(SelectedTag newMethod) {
1012    if (newMethod.getTags() == TAGS_PRIOR) {
1013      int c = newMethod.getSelectedTag().getID();
1014      if (c == GAUSSIAN || c == LAPLACIAN) {
1015        PriorClass = c;
1016      } else {
1017        throw new IllegalArgumentException("Wrong selection type, -P value should be: "
1018                                           + "1 for Gaussian or 2 for Laplacian");
1019      }
1020    }
1021  }
1022
1023  /**
1024   * Get the type of prior to use.
1025   *
1026   * @return the type of prior to use
1027   */
1028  public SelectedTag getPriorClass() {
1029    return new SelectedTag(PriorClass,
1030                           TAGS_PRIOR);
1031  }
1032
1033  /**
1034   * Returns the tip text for this property
1035   *
1036   * @return tip text for this property suitable for
1037   * displaying in the explorer/experimenter gui
1038   */
1039  public String thresholdTipText() {
1040    return "Set the threshold for classifiction. The logistic function doesn't "
1041      + "return a class label but an estimate of p(y=+1|B,x(i)). "
1042      + "These estimates need to be converted to binary class label predictions. "
1043      + "values above the threshold are assigned class +1.";
1044  }
1045
1046  /**
1047   * Return the threshold being used.
1048   *
1049   * @return the threshold
1050   */
1051  public double getThreshold() {
1052    return Threshold;
1053  }
1054
1055  /**
1056   * Set the threshold to use.
1057   *
1058   * @param threshold the threshold to use
1059   */
1060  public void setThreshold(double threshold) {
1061    Threshold = threshold;
1062  }
1063
1064  /**
1065   * Returns the tip text for this property
1066   *
1067   * @return tip text for this property suitable for
1068   * displaying in the explorer/experimenter gui
1069   */
1070  public String toleranceTipText() {
1071    return "This value decides the stopping criterion.";
1072  }
1073
1074  /**
1075   * Get the tolerance value
1076   *
1077   * @return the tolerance value
1078   */
1079  public double getTolerance() {
1080    return Tolerance;
1081  }
1082
1083  /**
1084   * Set the tolerance value
1085   *
1086   * @param tolerance the tolerance value to use
1087   */
1088  public void setTolerance(double tolerance) {
1089    Tolerance = tolerance;
1090  }
1091
1092  /**
1093   * Returns the tip text for this property
1094   *
1095   * @return tip text for this property suitable for
1096   * displaying in the explorer/experimenter gui
1097   */
1098  public String hyperparameterValueTipText() {
1099    return "Specific hyperparameter value. Used when the hyperparameter "
1100      + "selection method is set to specific value";
1101  }
1102
1103  /**
1104   * Get the hyperparameter value. Used when the hyperparameter
1105   * selection method is set to specific value
1106   *
1107   * @return the hyperparameter value
1108   */
1109  public double getHyperparameterValue() {
1110    return HyperparameterValue;
1111  }
1112
1113  /**
1114   * Set the hyperparameter value. Used when the hyperparameter
1115   * selection method is set to specific value
1116   *
1117   * @param hyperparameterValue the value of the hyperparameter
1118   */
1119  public void setHyperparameterValue(double hyperparameterValue) {
1120    HyperparameterValue = hyperparameterValue;
1121  }
1122
1123  /**
1124   * Returns the tip text for this property
1125   *
1126   * @return tip text for this property suitable for
1127   * displaying in the explorer/experimenter gui
1128   */
1129  public String numFoldsTipText() {
1130    return "The number of folds to use for CV-based hyperparameter selection.";
1131  }
1132
1133  /**
1134   * Return the number of folds for CV-based hyperparameter selection
1135   *
1136   * @return the number of CV folds
1137   */
1138  public int getNumFolds() {
1139    return NumFolds;
1140  }
1141
1142  /**
1143   * Set the number of folds to use for CV-based hyperparameter
1144   * selection
1145   *
1146   * @param numFolds number of folds to select
1147   */
1148  public void setNumFolds(int numFolds) {
1149    NumFolds = numFolds;
1150  }
1151
1152  /**
1153   * Returns the tip text for this property
1154   *
1155   * @return tip text for this property suitable for
1156   * displaying in the explorer/experimenter gui
1157   */
1158  public String maxIterationsTipText() {
1159    return "The maximum number of iterations to perform.";
1160  }
1161
1162  /**
1163   * Get the maximum number of iterations to perform
1164   *
1165   * @return the maximum number of iterations
1166   */
1167  public int getMaxIterations() {
1168    return maxIterations;
1169  }
1170
1171  /**
1172   * Set the maximum number of iterations to perform
1173   *
1174   * @param maxIterations maximum number of iterations
1175   */
1176  public void setMaxIterations(int maxIterations) {
1177    this.maxIterations = maxIterations;
1178  }
1179
1180  /**
1181   * Returns the tip text for this property
1182   *
1183   * @return tip text for this property suitable for
1184   * displaying in the explorer/experimenter gui
1185   */
1186  public String normalizeDataTipText() {
1187    return "Normalize the data.";
1188  }
1189
1190  /**
1191   * Returns true if the data is to be normalized first
1192   *
1193   * @return true if the data is to be normalized
1194   */
1195  public boolean isNormalizeData() {
1196    return NormalizeData;
1197  }
1198
1199  /**
1200   * Set whether to normalize the data or not
1201   *
1202   * @param normalizeData true if data is to be normalized
1203   */
1204  public void setNormalizeData(boolean normalizeData) {
1205    NormalizeData = normalizeData;
1206  }
1207
1208  /**
1209   * Returns the tip text for this property
1210   *
1211   * @return tip text for this property suitable for
1212   * displaying in the explorer/experimenter gui
1213   */
1214  public String hyperparameterRangeTipText() {
1215    return "Hyperparameter value range. In case of CV-based Hyperparameters, "
1216      + "you can specify the range in two ways: \n"
1217      + "Comma-Separated: L: 3,5,6 (This will be a list of possible values.)\n"
1218      + "Range: R:0.01-316,3.16 (This will take values from 0.01-316 (inclusive) "
1219      + "in multiplications of 3.16";
1220  }
1221
1222  /**
1223   * Get the range of hyperparameter values to consider
1224   * during CV-based selection.
1225   *
1226   * @return the range of hyperparameters as a Stringe
1227   */
1228  public String getHyperparameterRange() {
1229    return HyperparameterRange;
1230  }
1231
1232  /**
1233   * Set the range of hyperparameter values to consider
1234   * during CV-based selection
1235   *
1236   * @param hyperparameterRange the range of hyperparameter values
1237   */
1238  public void setHyperparameterRange(String hyperparameterRange) {
1239    HyperparameterRange = hyperparameterRange;
1240  }
1241
1242  /**
1243   * Returns true if debug is turned on.
1244   *
1245   * @return true if debug is turned on
1246   */
1247  public boolean isDebug() {
1248    return debug;
1249  }
1250 
1251  /**
1252   * Returns the revision string.
1253   *
1254   * @return            the revision
1255   */
1256  public String getRevision() {
1257    return RevisionUtils.extract("$Revision: 5928 $");
1258  }
1259}
1260
Note: See TracBrowser for help on using the repository browser.