source: src/main/java/weka/estimators/Estimator.java @ 6

Last change on this file since 6 was 4, checked in by gnappo, 14 years ago

Import di weka.

File size: 21.3 KB
Line 
1/*
2 *    This program is free software; you can redistribute it and/or modify
3 *    it under the terms of the GNU General Public License as published by
4 *    the Free Software Foundation; either version 2 of the License, or
5 *    (at your option) any later version.
6 *
7 *    This program is distributed in the hope that it will be useful,
8 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
9 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
10 *    GNU General Public License for more details.
11 *
12 *    You should have received a copy of the GNU General Public License
13 *    along with this program; if not, write to the Free Software
14 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
15 */
16
17/*
18 *    Estimator.java
19 *    Copyright (C) 1999 University of Waikato, Hamilton, New Zealand
20 *
21 */
22
23package weka.estimators;
24
25import weka.core.Capabilities;
26import weka.core.CapabilitiesHandler;
27import weka.core.Instance;
28import weka.core.Instances;
29import weka.core.Option;
30import weka.core.OptionHandler;
31import weka.core.RevisionHandler;
32import weka.core.RevisionUtils;
33import weka.core.SerializedObject;
34import weka.core.Utils;
35import weka.core.Capabilities.Capability;
36
37import java.io.BufferedReader;
38import java.io.FileReader;
39import java.io.InputStreamReader;
40import java.io.Reader;
41import java.io.Serializable;
42import java.util.Enumeration;
43import java.util.Vector;
44 
45/**
46 *
47 * Abstract class for all estimators.
48 *
49 * Example code for a nonincremental estimator
50 * <code> <pre>
51 *   // create a histogram for estimation
52 *   EqualWidthEstimator est = new EqualWidthEstimator();
53 *   est.addValues(instances, attrIndex);
54 * </pre> </code>
55 *
56 *
57 * Example code for an incremental estimator (incremental
58 * estimators must implement interface IncrementalEstimator)
59 * <code> <pre>
60 *   // Create a discrete estimator that takes values 0 to 9
61 *   DiscreteEstimator newEst = new DiscreteEstimator(10, true);
62 *
63 *   // Create 50 random integers first predicting the probability of the
64 *   // value, then adding the value to the estimator
65 *   Random r = new Random(seed);
66 *   for(int i = 0; i < 50; i++) {
67 *     current = Math.abs(r.nextInt() % 10);
68 *     System.out.println(newEst);
69 *     System.out.println("Prediction for " + current
70 *                        + " = " + newEst.getProbability(current));
71 *     newEst.addValue(current, 1);
72 *   }
73 * </pre> </code>
74 *
75 *
76 * Example code for a main method for an estimator.<p>
77 * <code> <pre>
78 * public static void main(String [] argv) {
79 *
80 *   try {
81 *     LoglikeliEstimator est = new LoglikeliEstimator();     
82 *     Estimator.buildEstimator((Estimator) est, argv, false);     
83 *     System.out.println(est.toString());
84 *   } catch (Exception ex) {
85 *     ex.printStackTrace();
86 *     System.out.println(ex.getMessage());
87 *   }
88 * }
89 * </pre> </code>
90 *
91 *
92 * @author Gabi Schmidberger (gabi@cs.waikato.ac.nz)
93 * @author Len Trigg (trigg@cs.waikato.ac.nz)
94 * @version $Revision: 5489 $
95 */
96public abstract class Estimator 
97  implements Cloneable, Serializable, OptionHandler, CapabilitiesHandler, 
98             RevisionHandler {
99 
100  /** for serialization */
101  static final long serialVersionUID = -5902411487362274342L;
102 
103  /** Debugging mode */
104  private boolean m_Debug = false;
105 
106  /** The class value index is > -1 if subset is taken with specific class value only*/
107  protected double m_classValueIndex = -1.0;
108 
109  /** set if class is not important */
110  protected boolean m_noClass = true;
111 
112  /**
113   * Class to support a building process of an estimator.
114   */
115  private static class Builder
116    implements Serializable, RevisionHandler {
117
118    /** for serialization */
119    private static final long serialVersionUID = -5810927990193597303L;
120   
121    /** instances of the builder */
122    Instances m_instances = null;
123   
124    /** attribute index of the builder */
125    int m_attrIndex = -1;
126   
127    /** class index of the builder, only relevant if class value index is set*/
128    int m_classIndex = -1; 
129
130    /** class value index of the builder */
131    int m_classValueIndex = -1; 
132   
133    /**
134     * Returns the revision string.
135     *
136     * @return          the revision
137     */
138    public String getRevision() {
139      return RevisionUtils.extract("$Revision: 5489 $");
140    }
141  }
142 
143  /**
144   * Add a new data value to the current estimator.
145   *
146   * @param data the new data value
147   * @param weight the weight assigned to the data value
148   */
149  public void addValue(double data, double weight) {
150    try { 
151      throw new Exception("Method to add single value is not implemented!\n"+
152                          "Estimator should implement IncrementalEstimator.");
153    } catch (Exception ex) {
154      ex.printStackTrace();
155      System.out.println(ex.getMessage());
156    }
157  }
158
159  /**
160   * Initialize the estimator with a new dataset.
161   * Finds min and max first.
162   *
163   * @param data the dataset used to build this estimator
164   * @param attrIndex attribute the estimator is for
165   * @exception Exception if building of estimator goes wrong
166   */
167  public void addValues(Instances data, int attrIndex) throws Exception {
168    // can estimator handle the data?
169    getCapabilities().testWithFail(data);
170   
171    double []minMax = new double[2];
172   
173    try {
174      EstimatorUtils.getMinMax(data, attrIndex, minMax);
175    } catch (Exception ex) {
176      ex.printStackTrace();
177      System.out.println(ex.getMessage());
178    }
179   
180    double min = minMax[0];
181    double max = minMax[1];
182
183    // factor is 1.0, data set has not been reduced
184    addValues(data, attrIndex, min, max, 1.0);
185  }
186 
187  /**
188   * Initialize the estimator with all values of one attribute of a dataset.
189   * Some estimator might ignore the min and max values.
190   *
191   * @param data the dataset used to build this estimator
192   * @param attrIndex attribute the estimator is for
193   * @param min minimal border of range
194   * @param max maximal border of range
195   * @param factor number of instances has been reduced to that factor
196   * @exception Exception if building of estimator goes wrong
197   */
198  public void addValues(Instances data, int attrIndex,
199                        double min, double max, double factor) throws Exception {
200    // no handling of factor, would have to be overridden
201
202    // no handling of min and max, would have to be overridden
203
204    int numInst = data.numInstances();
205    for (int i = 1; i < numInst; i++) {
206      addValue(data.instance(i).value(attrIndex), 1.0);
207    }
208  }
209 
210  /**
211   * Initialize the estimator using only the instance of one class.
212   * It is using the values of one attribute only.
213   *
214   * @param data the dataset used to build this estimator
215   * @param attrIndex attribute the estimator is for
216   * @param classIndex index of the class attribute
217   * @param classValue the class value
218   * @exception Exception if building of estimator goes wrong
219   */
220  public void addValues(Instances data, int attrIndex,
221                        int classIndex, int classValue) throws Exception{
222    // can estimator handle the data?
223    m_noClass = false;   
224    getCapabilities().testWithFail(data);
225   
226    // find the minimal and the maximal value
227    double []minMax = new double[2];
228   
229    try {
230      EstimatorUtils.getMinMax(data, attrIndex, minMax);
231    } catch (Exception ex) {
232      ex.printStackTrace();
233      System.out.println(ex.getMessage());
234    }
235   
236    double min = minMax[0];
237    double max = minMax[1];
238 
239    // extract the instances with the given class value
240    Instances workData = new Instances(data, 0);
241    double factor = getInstancesFromClass(data, attrIndex,
242                                          classIndex, 
243                                          (double)classValue, workData);
244
245    // if no data return
246    if (workData.numInstances() == 0) return;
247
248    addValues(data, attrIndex, min, max, factor);
249  }
250 
251  /**
252   * Initialize the estimator using only the instance of one class.
253   * It is using the values of one attribute only.
254   *
255   * @param data the dataset used to build this estimator
256   * @param attrIndex attribute the estimator is for
257   * @param classIndex index of the class attribute
258   * @param classValue the class value
259   * @param min minimal value of this attribute
260   * @param max maximal value of this attribute
261   * @exception Exception if building of estimator goes wrong
262   */
263  public void addValues(Instances data, int attrIndex,
264      int classIndex, int classValue,
265      double min, double max) throws Exception{
266     
267    // extract the instances with the given class value
268    Instances workData = new Instances(data, 0);
269    double factor = getInstancesFromClass(data, attrIndex,
270            classIndex, 
271            (double)classValue, workData);
272
273    // if no data return
274    if (workData.numInstances() == 0) return;
275
276    addValues(data, attrIndex, min, max, factor);
277  }
278 
279 
280  /**
281   * Returns a dataset that contains all instances of a certain class value.
282   *
283   * @param data dataset to select the instances from
284   * @param attrIndex index of the relevant attribute
285   * @param classIndex index of the class attribute
286   * @param classValue the relevant class value
287   * @return a dataset with only
288   */
289  private double getInstancesFromClass(Instances data, int attrIndex,
290                                       int classIndex,
291                                       double classValue, Instances workData) {
292    //DBO.pln("getInstancesFromClass classValue"+classValue+" workData"+data.numInstances());
293
294    int num = 0;
295    int numClassValue = 0;
296    for (int i = 0; i < data.numInstances(); i++) {
297      if (!data.instance(i).isMissing(attrIndex)) {
298        num++;
299        if (data.instance(i).value(classIndex) == classValue) {
300          workData.add(data.instance(i));
301          numClassValue++;
302        }
303      }
304    } 
305
306    Double alphaFactor = new Double((double)numClassValue/(double)num);
307    return alphaFactor;
308  }
309
310  /**
311   * Get a probability estimate for a value.
312   *
313   * @param data the value to estimate the probability of
314   * @return the estimated probability of the supplied value
315   */
316  public abstract double getProbability(double data);
317
318  /**
319   * Build an estimator using the options. The data is given in the options.
320   *
321   * @param est the estimator used
322   * @param options the list of options
323   * @param isIncremental true if estimator is incremental
324   * @exception Exception if something goes wrong or the user requests help on
325   * command options
326   */
327  public static void buildEstimator(Estimator est, String [] options,
328                                    boolean isIncremental) 
329    throws Exception {
330    //DBO.pln("buildEstimator");
331   
332    boolean debug = false;
333    boolean helpRequest;
334   
335    // read all options
336    Builder build = new Builder();
337    try {
338      setGeneralOptions(build, est, options);
339     
340      if (est instanceof OptionHandler) {
341        ((OptionHandler)est).setOptions(options);
342      }
343     
344      Utils.checkForRemainingOptions(options);
345     
346   
347      buildEstimator(est, build.m_instances, build.m_attrIndex,
348                     build.m_classIndex, build.m_classValueIndex, isIncremental);
349    } catch (Exception ex) {
350      ex.printStackTrace();
351      System.out.println(ex.getMessage());
352      String specificOptions = "";
353      // Output the error and also the valid options
354      if (est instanceof OptionHandler) {
355        specificOptions += "\nEstimator options:\n\n";
356        Enumeration enumOptions = ((OptionHandler)est).listOptions();
357        while (enumOptions.hasMoreElements()) {
358          Option option = (Option) enumOptions.nextElement();
359          specificOptions += option.synopsis() + '\n'
360            + option.description() + "\n";
361        }
362      }
363     
364      String genericOptions = "\nGeneral options:\n\n"
365        + "-h\n"
366        + "\tGet help on available options.\n"
367        + "-i <file>\n"
368        + "\tThe name of the file containing input instances.\n"
369        + "\tIf not supplied then instances will be read from stdin.\n"
370        + "-a <attribute index>\n"
371        + "\tThe number of the attribute the probability distribution\n"
372        + "\testimation is done for.\n"
373        + "\t\"first\" and \"last\" are also valid entries.\n"
374        + "\tIf not supplied then no class is assigned.\n"
375        + "-c <class index>\n"
376        + "\tIf class value index is set, this attribute is taken as class.\n"
377        + "\t\"first\" and \"last\" are also valid entries.\n"
378        + "\tIf not supplied then last is default.\n"
379        + "-v <class value index>\n"
380        + "\tIf value is different to -1, select instances of this class value.\n"
381        + "\t\"first\" and \"last\" are also valid entries.\n"
382        + "\tIf not supplied then all instances are taken.\n";
383     
384      throw new Exception('\n' + ex.getMessage()
385                          + specificOptions+genericOptions);
386    }
387  }
388
389  public static void buildEstimator(Estimator est,
390                                    Instances instances, int attrIndex, 
391                                    int classIndex, int classValueIndex,
392                                    boolean isIncremental) throws Exception {
393
394    // DBO.pln("buildEstimator 2 " + classValueIndex);
395
396    // non-incremental estimator add all instances at once
397    if (!isIncremental) {
398     
399      if (classValueIndex == -1) {
400        // DBO.pln("before addValues -- Estimator");
401        est.addValues(instances, attrIndex);
402      } else {
403        // DBO.pln("before addValues with classvalue -- Estimator");
404        est.addValues(instances, attrIndex, 
405                      classIndex, classValueIndex);
406      }
407    } else {
408      // incremental estimator, read one value at a time
409      Enumeration enumInsts = (instances).enumerateInstances();
410      while (enumInsts.hasMoreElements()) {
411        Instance instance = 
412          (Instance) enumInsts.nextElement();
413        ((IncrementalEstimator)est).addValue(instance.value(attrIndex),
414                                             instance.weight());
415      }
416    }
417  }
418 
419  /**
420   * Parses and sets the general options
421   * @param build contains the data used
422   * @param est the estimator used
423   * @param options the options from the command line
424   */
425  private static void setGeneralOptions(Builder build, Estimator est, 
426                                        String [] options) 
427    throws Exception {
428    Reader input = null;
429   
430    // help request option
431    boolean helpRequest = Utils.getFlag('h', options);
432    if (helpRequest) {
433      throw new Exception("Help requested.\n");
434    }
435   
436    // instances used
437    String infileName = Utils.getOption('i', options);
438    if (infileName.length() != 0) {
439      input = new BufferedReader(new FileReader(infileName));
440    } else {
441      input = new BufferedReader(new InputStreamReader(System.in));
442    }
443   
444    build.m_instances = new Instances(input);
445   
446    // attribute index
447    String attrIndex = Utils.getOption('a', options);
448   
449    if (attrIndex.length() != 0) {
450      if (attrIndex.equals("first")) {
451        build.m_attrIndex = 0;
452      } else if (attrIndex.equals("last")) {
453        build.m_attrIndex = build.m_instances.numAttributes() - 1;
454      } else {
455        int index = Integer.parseInt(attrIndex) - 1;
456        if ((index < 0) || (index >= build.m_instances.numAttributes())) {
457          throw new IllegalArgumentException("Option a: attribute index out of range.");
458        }
459        build.m_attrIndex = index;
460       
461      }
462    } else {
463      // default is the first attribute
464      build.m_attrIndex = 0;
465    }
466   
467    //class index, if not given is set to last attribute
468    String classIndex = Utils.getOption('c', options);
469    if (classIndex.length() == 0) classIndex = "last";
470
471    if (classIndex.length() != 0) {
472      if (classIndex.equals("first")) {
473        build.m_classIndex = 0;
474      } else if (classIndex.equals("last")) {
475        build.m_classIndex = build.m_instances.numAttributes() - 1;
476      } else {
477        int cl = Integer.parseInt(classIndex);
478        if (cl == -1) {
479          build.m_classIndex = build.m_instances.numAttributes() - 1;
480        } else {
481          build.m_classIndex = cl - 1; 
482        }
483      }
484    } 
485   
486    //class value index, if not given is set to  -1
487    String classValueIndex = Utils.getOption('v', options);
488    if (classValueIndex.length() != 0) {
489      if (classValueIndex.equals("first")) {
490        build.m_classValueIndex = 0;
491      } else if (classValueIndex.equals("last")) {
492        build.m_classValueIndex = build.m_instances.numAttributes() - 1;
493      } else {
494        int cl = Integer.parseInt(classValueIndex);
495        if (cl == -1) {
496          build.m_classValueIndex = -1;
497        } else {
498          build.m_classValueIndex = cl - 1;     
499        }
500      }
501    } 
502   
503    build.m_instances.setClassIndex(build.m_classIndex);
504  }
505 
506  /**
507   * Creates a deep copy of the given estimator using serialization.
508   *
509   * @param model the estimator to copy
510   * @return a deep copy of the estimator
511   * @exception Exception if an error occurs
512   */
513  public static Estimator clone(Estimator model) throws Exception {
514   
515    return makeCopy(model);
516  }
517 
518  /**
519   * Creates a deep copy of the given estimator using serialization.
520   *
521   * @param model the estimator to copy
522   * @return a deep copy of the estimator
523   * @exception Exception if an error occurs
524   */
525  public static Estimator makeCopy(Estimator model) throws Exception {
526
527    return (Estimator)new SerializedObject(model).getObject();
528  }
529
530  /**
531   * Creates a given number of deep copies of the given estimator using serialization.
532   *
533   * @param model the estimator to copy
534   * @param num the number of estimator copies to create.
535   * @return an array of estimators.
536   * @exception Exception if an error occurs
537   */
538  public static Estimator [] makeCopies(Estimator model,
539                                         int num) throws Exception {
540
541    if (model == null) {
542      throw new Exception("No model estimator set");
543    }
544    Estimator [] estimators = new Estimator [num];
545    SerializedObject so = new SerializedObject(model);
546    for(int i = 0; i < estimators.length; i++) {
547      estimators[i] = (Estimator) so.getObject();
548    }
549    return estimators;
550  }
551 
552  /**
553   * Tests whether the current estimation object is equal to another
554   * estimation object
555   *
556   * @param obj the object to compare against
557   * @return true if the two objects are equal
558   */
559  public boolean equals(Object obj) {
560   
561    if ((obj == null) || !(obj.getClass().equals(this.getClass()))) {
562      return false;
563    }
564    Estimator cmp = (Estimator) obj;
565    if (m_Debug != cmp.m_Debug) return false;
566    if (m_classValueIndex != cmp.m_classValueIndex) return false;
567    if (m_noClass != cmp.m_noClass) return false;
568   
569    return true;
570  }
571
572  /**
573   * Returns an enumeration describing the available options.
574   *
575   * @return an enumeration of all the available options.
576   */
577  public Enumeration listOptions() {
578
579    Vector newVector = new Vector(1);
580
581    newVector.addElement(new Option(
582              "\tIf set, estimator is run in debug mode and\n"
583              + "\tmay output additional info to the console",
584              "D", 0, "-D"));
585    return newVector.elements();
586  }
587
588  /**
589   * Parses a given list of options. Valid options are:<p>
590   *
591   * -D  <br>
592   * If set, estimator is run in debug mode and
593   * may output additional info to the console.<p>
594   *
595   * @param options the list of options as an array of strings
596   * @exception Exception if an option is not supported
597   */
598  public void setOptions(String[] options) throws Exception {
599
600    setDebug(Utils.getFlag('D', options));
601  }
602
603  /**
604   * Gets the current settings of the Estimator.
605   *
606   * @return an array of strings suitable for passing to setOptions
607   */
608  public String [] getOptions() {
609
610    String [] options;
611    if (getDebug()) {
612      options = new String[1];
613      options[0] = "-D";
614    } else {
615      options = new String[0];
616    }
617    return options;
618  }
619 
620  /**
621   * Creates a new instance of a estimatorr given it's class name and
622   * (optional) arguments to pass to it's setOptions method. If the
623   * classifier implements OptionHandler and the options parameter is
624   * non-null, the classifier will have it's options set.
625   *
626   * @param name the fully qualified class name of the estimatorr
627   * @param options an array of options suitable for passing to setOptions. May
628   * be null.
629   * @return the newly created classifier, ready for use.
630   * @exception Exception if the classifier name is invalid, or the options
631   * supplied are not acceptable to the classifier
632   */
633  public static Estimator forName(String name,
634      String [] options) throws Exception {
635   
636    return (Estimator)Utils.forName(Estimator.class,
637        name,
638        options);
639  }
640
641 /**
642   * Set debugging mode.
643   *
644   * @param debug true if debug output should be printed
645   */
646  public void setDebug(boolean debug) {
647
648    m_Debug = debug;
649  }
650
651  /**
652   * Get whether debugging is turned on.
653   *
654   * @return true if debugging output is on
655   */
656  public boolean getDebug() {
657
658    return m_Debug;
659  }
660 
661  /**
662   * Returns the tip text for this property
663   * @return tip text for this property suitable for
664   * displaying in the explorer/experimenter gui
665   */
666  public String debugTipText() {
667    return "If set to true, estimator may output additional info to " +
668      "the console.";
669  }
670 
671  /**
672   * Returns the Capabilities of this Estimator. Derived estimators have to
673   * override this method to enable capabilities.
674   *
675   * @return            the capabilities of this object
676   * @see               Capabilities
677   */
678  public Capabilities getCapabilities() {
679    Capabilities result = new Capabilities(this);
680    result.enableAll();
681   
682/*    // class
683    if (!m_noClass) {
684      result.enable(Capability.NOMINAL_CLASS);
685      result.enable(Capability.MISSING_CLASS_VALUES);
686    } else {
687      result.enable(Capability.NO_CLASS);
688    } */
689       
690    return result;
691  }
692 
693  /**
694   * Returns the revision string.
695   *
696   * @return            the revision
697   */
698  public String getRevision() {
699    return RevisionUtils.extract("$Revision: 5489 $");
700  }
701 
702  /**
703   * Test if the estimator can handle the data.
704   * @param data the dataset the estimator takes an attribute from
705   * @param attrIndex the index of the attribute
706   * @see Capabilities
707   */
708  public void testCapabilities(Instances data, int attrIndex) throws Exception {
709    getCapabilities().testWithFail(data);
710    getCapabilities().testWithFail(data.attribute(attrIndex));
711  }
712}
713
714
715
716
717
718
719
720
Note: See TracBrowser for help on using the repository browser.