source: src/main/java/weka/experiment/AveragingResultProducer.java @ 6

Last change on this file since 6 was 4, checked in by gnappo, 14 years ago

Import di weka.

File size: 36.5 KB
RevLine 
[4]1/*
2 *    This program is free software; you can redistribute it and/or modify
3 *    it under the terms of the GNU General Public License as published by
4 *    the Free Software Foundation; either version 2 of the License, or
5 *    (at your option) any later version.
6 *
7 *    This program is distributed in the hope that it will be useful,
8 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
9 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
10 *    GNU General Public License for more details.
11 *
12 *    You should have received a copy of the GNU General Public License
13 *    along with this program; if not, write to the Free Software
14 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
15 */
16
17/*
18 *    AveragingResultProducer.java
19 *    Copyright (C) 1999 University of Waikato, Hamilton, New Zealand
20 *
21 */
22
23
24package weka.experiment;
25
26import weka.core.AdditionalMeasureProducer;
27import weka.core.FastVector;
28import weka.core.Instances;
29import weka.core.Option;
30import weka.core.OptionHandler;
31import weka.core.RevisionHandler;
32import weka.core.RevisionUtils;
33import weka.core.Utils;
34
35import java.util.Enumeration;
36import java.util.Hashtable;
37import java.util.Vector;
38
39/**
40 <!-- globalinfo-start -->
41 * Takes the results from a ResultProducer and submits the average to the result listener. Normally used with a CrossValidationResultProducer to perform n x m fold cross validation. For non-numeric result fields, the first value is used.
42 * <p/>
43 <!-- globalinfo-end -->
44 *
45 <!-- options-start -->
46 * Valid options are: <p/>
47 *
48 * <pre> -F &lt;field name&gt;
49 *  The name of the field to average over.
50 *  (default "Fold")</pre>
51 *
52 * <pre> -X &lt;num results&gt;
53 *  The number of results expected per average.
54 *  (default 10)</pre>
55 *
56 * <pre> -S
57 *  Calculate standard deviations.
58 *  (default only averages)</pre>
59 *
60 * <pre> -W &lt;class name&gt;
61 *  The full class name of a ResultProducer.
62 *  eg: weka.experiment.CrossValidationResultProducer</pre>
63 *
64 * <pre>
65 * Options specific to result producer weka.experiment.CrossValidationResultProducer:
66 * </pre>
67 *
68 * <pre> -X &lt;number of folds&gt;
69 *  The number of folds to use for the cross-validation.
70 *  (default 10)</pre>
71 *
72 * <pre> -D
73 * Save raw split evaluator output.</pre>
74 *
75 * <pre> -O &lt;file/directory name/path&gt;
76 *  The filename where raw output will be stored.
 *  If a directory name is specified then individual
 *  outputs will be gzipped, otherwise all output will be
 *  zipped to the named file. Use in conjunction with -D. (default splitEvalutorOut.zip)</pre>
80 *
81 * <pre> -W &lt;class name&gt;
82 *  The full class name of a SplitEvaluator.
83 *  eg: weka.experiment.ClassifierSplitEvaluator</pre>
84 *
85 * <pre>
86 * Options specific to split evaluator weka.experiment.ClassifierSplitEvaluator:
87 * </pre>
88 *
89 * <pre> -W &lt;class name&gt;
90 *  The full class name of the classifier.
91 *  eg: weka.classifiers.bayes.NaiveBayes</pre>
92 *
93 * <pre> -C &lt;index&gt;
94 *  The index of the class for which IR statistics
95 *  are to be output. (default 1)</pre>
96 *
97 * <pre> -I &lt;index&gt;
98 *  The index of an attribute to output in the
99 *  results. This attribute should identify an
100 *  instance in order to know which instances are
101 *  in the test set of a cross validation. if 0
102 *  no output (default 0).</pre>
103 *
104 * <pre> -P
105 *  Add target and prediction columns to the result
106 *  for each fold.</pre>
107 *
108 * <pre>
109 * Options specific to classifier weka.classifiers.rules.ZeroR:
110 * </pre>
111 *
112 * <pre> -D
113 *  If set, classifier is run in debug mode and
114 *  may output additional info to the console</pre>
115 *
116 <!-- options-end -->
117 *
118 * All options after -- will be passed to the result producer.
119 *
120 * @author Len Trigg (trigg@cs.waikato.ac.nz)
121 * @version $Revision: 1.18 $
122 */
123public class AveragingResultProducer 
124  implements ResultListener, ResultProducer, OptionHandler,
125             AdditionalMeasureProducer, RevisionHandler {
126
127  /** for serialization */
128  static final long serialVersionUID = 2551284958501991352L;
129 
130  /** The dataset of interest */
131  protected Instances m_Instances;
132
133  /** The ResultListener to send results to */
134  protected ResultListener m_ResultListener = new CSVResultListener();
135
136  /** The ResultProducer used to generate results */
137  protected ResultProducer m_ResultProducer
138    = new CrossValidationResultProducer();
139
140  /** The names of any additional measures to look for in SplitEvaluators */
141  protected String [] m_AdditionalMeasures = null;
142 
143  /** The number of results expected to average over for each run */
144  protected int m_ExpectedResultsPerAverage = 10;
145
146  /** True if standard deviation fields should be produced */
147  protected boolean m_CalculateStdDevs;
148   
149  /**
150   * The name of the field that will contain the number of results
151   * averaged over.
152   */
153  protected String m_CountFieldName = "Num_" + CrossValidationResultProducer
154    .FOLD_FIELD_NAME;
155
156  /** The name of the key field to average over */
157  protected String m_KeyFieldName = CrossValidationResultProducer
158    .FOLD_FIELD_NAME;
159
160  /** The index of the field to average over in the resultproducers key */
161  protected int m_KeyIndex = -1;
162
163  /** Collects the keys from a single run */
164  protected FastVector m_Keys = new FastVector();
165 
166  /** Collects the results from a single run */
167  protected FastVector m_Results = new FastVector();
168
169  /**
170   * Returns a string describing this result producer
171   * @return a description of the result producer suitable for
172   * displaying in the explorer/experimenter gui
173   */
174  public String globalInfo() {
175    return "Takes the results from a ResultProducer "
176      +"and submits the average to the result listener. Normally used with "
177      +"a CrossValidationResultProducer to perform n x m fold cross "
178      +"validation. For non-numeric result fields, the first value is used.";
179  }
180
181  /**
182   * Scans through the key field names of the result producer to find
183   * the index of the key field to average over. Sets the value of
184   * m_KeyIndex to the index, or -1 if no matching key field was found.
185   *
186   * @return the index of the key field to average over
187   */
188  protected int findKeyIndex() {
189
190    m_KeyIndex = -1;
191    try {
192      if (m_ResultProducer != null) {
193        String [] keyNames = m_ResultProducer.getKeyNames();
194        for (int i = 0; i < keyNames.length; i++) {
195          if (keyNames[i].equals(m_KeyFieldName)) {
196            m_KeyIndex = i;
197            break;
198          }
199        }
200      }
201    } catch (Exception ex) {
202    }
203    return m_KeyIndex;
204  }
205
206  /**
207   * Determines if there are any constraints (imposed by the
208   * destination) on the result columns to be produced by
209   * resultProducers. Null should be returned if there are NO
210   * constraints, otherwise a list of column names should be
211   * returned as an array of Strings.
212   * @param rp the ResultProducer to which the constraints will apply
213   * @return an array of column names to which resutltProducer's
214   * results will be restricted.
215   * @throws Exception if constraints can't be determined
216   */
217  public String [] determineColumnConstraints(ResultProducer rp) 
218    throws Exception {
219    return null;
220  }
221
222  /**
223   * Simulates a run to collect the keys the sub-resultproducer could
224   * generate. Does some checking on the keys and determines the
225   * template key.
226   *
227   * @param run the run number
228   * @return a template key (null for the field being averaged)
229   * @throws Exception if an error occurs
230   */
231  protected Object [] determineTemplate(int run) throws Exception {
232
233    if (m_Instances == null) {
234      throw new Exception("No Instances set");
235    }
236    m_ResultProducer.setInstances(m_Instances);
237
238    // Clear the collected results
239    m_Keys.removeAllElements();
240    m_Results.removeAllElements();
241   
242    m_ResultProducer.doRunKeys(run);
243    checkForMultipleDifferences();
244
245    Object [] template = (Object [])((Object [])m_Keys.elementAt(0)).clone();
246    template[m_KeyIndex] = null;
247    // Check for duplicate keys
248    checkForDuplicateKeys(template);
249
250    return template;
251  }
252
253  /**
254   * Gets the keys for a specified run number. Different run
255   * numbers correspond to different randomizations of the data. Keys
256   * produced should be sent to the current ResultListener
257   *
258   * @param run the run number to get keys for.
259   * @throws Exception if a problem occurs while getting the keys
260   */
261  public void doRunKeys(int run) throws Exception {
262
263    // Generate the template
264    Object [] template = determineTemplate(run);
265    String [] newKey = new String [template.length - 1];
266    System.arraycopy(template, 0, newKey, 0, m_KeyIndex);
267    System.arraycopy(template, m_KeyIndex + 1,
268                     newKey, m_KeyIndex,
269                     template.length - m_KeyIndex - 1);
270    m_ResultListener.acceptResult(this, newKey, null);     
271  }
272
273  /**
274   * Gets the results for a specified run number. Different run
275   * numbers correspond to different randomizations of the data. Results
276   * produced should be sent to the current ResultListener
277   *
278   * @param run the run number to get results for.
279   * @throws Exception if a problem occurs while getting the results
280   */
281  public void doRun(int run) throws Exception {
282
283    // Generate the key and ask whether the result is required
284    Object [] template = determineTemplate(run);
285    String [] newKey = new String [template.length - 1];
286    System.arraycopy(template, 0, newKey, 0, m_KeyIndex);
287    System.arraycopy(template, m_KeyIndex + 1,
288                     newKey, m_KeyIndex,
289                     template.length - m_KeyIndex - 1);
290
291    if (m_ResultListener.isResultRequired(this, newKey)) {
292      // Clear the collected keys
293      m_Keys.removeAllElements();
294      m_Results.removeAllElements();
295     
296      m_ResultProducer.doRun(run);
297     
298      // Average the results collected
299      //System.err.println("Number of results collected: " + m_Keys.size());
300     
301      // Check that the keys only differ on the selected key field
302      checkForMultipleDifferences();
303     
304      template = (Object [])((Object [])m_Keys.elementAt(0)).clone();
305      template[m_KeyIndex] = null;
306      // Check for duplicate keys
307      checkForDuplicateKeys(template);
308      // Calculate the average and submit it if necessary
309      doAverageResult(template);
310    }
311  }
312
313 
314  /**
315   * Compares a key to a template to see whether they match. Null
316   * fields in the template are ignored in the matching.
317   *
318   * @param template the template to match against
319   * @param test the key to test
320   * @return true if the test key matches the template on all non-null template
321   * fields
322   */
323  protected boolean matchesTemplate(Object [] template, Object [] test) {
324   
325    if (template.length != test.length) {
326      return false;
327    }
328    for (int i = 0; i < test.length; i++) {
329      if ((template[i] != null) && (!template[i].equals(test[i]))) {
330        return false;
331      }
332    }
333    return true;
334  }
335 
336  /**
337   * Asks the resultlistener whether an average result is required, and
338   * if so, calculates it.
339   *
340   * @param template the template to match keys against when calculating the
341   * average
342   * @throws Exception if an error occurs
343   */
344  protected void doAverageResult(Object [] template) throws Exception {
345
346    // Generate the key and ask whether the result is required
347    String [] newKey = new String [template.length - 1];
348    System.arraycopy(template, 0, newKey, 0, m_KeyIndex);
349    System.arraycopy(template, m_KeyIndex + 1,
350                     newKey, m_KeyIndex,
351                     template.length - m_KeyIndex - 1);
352    if (m_ResultListener.isResultRequired(this, newKey)) {
353      Object [] resultTypes = m_ResultProducer.getResultTypes();
354      Stats [] stats = new Stats [resultTypes.length];
355      for (int i = 0; i < stats.length; i++) {
356        stats[i] = new Stats();
357      }
358      Object [] result = getResultTypes();
359      int numMatches = 0;
360      for (int i = 0; i < m_Keys.size(); i++) {
361        Object [] currentKey = (Object [])m_Keys.elementAt(i);
362        // Skip non-matching keys
363        if (!matchesTemplate(template, currentKey)) {
364          continue;
365        }
366        // Add the results to the stats accumulator
367        Object [] currentResult = (Object [])m_Results.elementAt(i);
368        numMatches++;
369        for (int j = 0; j < resultTypes.length; j++) {
370          if (resultTypes[j] instanceof Double) {
371            if (currentResult[j] == null) {
372
373              // set the stats object for this result to null---
374              // more than likely this is an additional measure field
375              // not supported by the low level split evaluator
376              if (stats[j] != null) {
377                stats[j] = null;
378              }
379             
380              /* throw new Exception("Null numeric result field found:\n"
381                 + DatabaseUtils.arrayToString(currentKey)
382                 + " -- "
383                 + DatabaseUtils
384                 .arrayToString(currentResult)); */
385            }
386            if (stats[j] != null) {
387              double currentVal = ((Double)currentResult[j]).doubleValue();
388              stats[j].add(currentVal);
389            }
390          }
391        }
392      }
393      if (numMatches != m_ExpectedResultsPerAverage) {
394        throw new Exception("Expected " + m_ExpectedResultsPerAverage
395                            + " results matching key \""
396                            + DatabaseUtils.arrayToString(template)
397                            + "\" but got "
398                            + numMatches);
399      }
400      result[0] = new Double(numMatches);
401      Object [] currentResult = (Object [])m_Results.elementAt(0);
402      int k = 1;
403      for (int j = 0; j < resultTypes.length; j++) {
404        if (resultTypes[j] instanceof Double) {
405          if (stats[j] != null) {
406            stats[j].calculateDerived();
407            result[k++] = new Double(stats[j].mean);
408          } else {
409            result[k++] = null;
410          }
411          if (getCalculateStdDevs()) {
412            if (stats[j] != null) {
413              result[k++] = new Double(stats[j].stdDev);
414            } else {
415              result[k++] = null;
416            }
417          }
418        } else {
419          result[k++] = currentResult[j];
420        }
421      }
422      m_ResultListener.acceptResult(this, newKey, result);     
423    }
424  }
425 
426  /**
427   * Checks whether any duplicate results (with respect to a key template)
428   * were received.
429   *
430   * @param template the template key.
431   * @throws Exception if duplicate results are detected
432   */
433  protected void checkForDuplicateKeys(Object [] template) throws Exception {
434
435    Hashtable hash = new Hashtable();
436    int numMatches = 0;
437    for (int i = 0; i < m_Keys.size(); i++) {
438      Object [] current = (Object [])m_Keys.elementAt(i);
439      // Skip non-matching keys
440      if (!matchesTemplate(template, current)) {
441        continue;
442      }
443      if (hash.containsKey(current[m_KeyIndex])) {
444        throw new Exception("Duplicate result received:"
445                            + DatabaseUtils.arrayToString(current));
446      }
447      numMatches++;
448      hash.put(current[m_KeyIndex], current[m_KeyIndex]);
449    }
450    if (numMatches != m_ExpectedResultsPerAverage) {
451      throw new Exception("Expected " + m_ExpectedResultsPerAverage
452                          + " results matching key \""
453                          + DatabaseUtils.arrayToString(template)
454                          + "\" but got "
455                          + numMatches);
456    }
457  }
458 
459  /**
460   * Checks that the keys for a run only differ in one key field. If they
461   * differ in more than one field, a more sophisticated averager will submit
462   * multiple results - for now an exception is thrown. Currently assumes that
463   * the most differences will be shown between the first and last
464   * result received.
465   *
466   * @throws Exception if the keys differ on fields other than the
467   * key averaging field
468   */
469  protected void checkForMultipleDifferences() throws Exception {
470   
471    Object [] firstKey = (Object [])m_Keys.elementAt(0);
472    Object [] lastKey = (Object [])m_Keys.elementAt(m_Keys.size() - 1);
473    /*
474    System.err.println("First key:" +  DatabaseUtils.arrayToString(firstKey));
475    System.err.println("Last key :" + DatabaseUtils.arrayToString(lastKey));
476    */
477    for (int i = 0; i < firstKey.length; i++) {
478      if ((i != m_KeyIndex) && !firstKey[i].equals(lastKey[i])) {
479        throw new Exception("Keys differ on fields other than \""
480                            + m_KeyFieldName
481                            + "\" -- time to implement multiple averaging");
482      }
483    }
484  }
485 
486  /**
487   * Prepare for the results to be received.
488   *
489   * @param rp the ResultProducer that will generate the results
490   * @throws Exception if an error occurs during preprocessing.
491   */
492  public void preProcess(ResultProducer rp) throws Exception {
493
494    if (m_ResultListener == null) {
495      throw new Exception("No ResultListener set");
496    }
497    m_ResultListener.preProcess(this);
498  }
499
500  /**
501   * Prepare to generate results. The ResultProducer should call
502   * preProcess(this) on the ResultListener it is to send results to.
503   *
504   * @throws Exception if an error occurs during preprocessing.
505   */
506  public void preProcess() throws Exception {
507   
508    if (m_ResultProducer == null) {
509      throw new Exception("No ResultProducer set");
510    }
511    // Tell the resultproducer to send results to us
512    m_ResultProducer.setResultListener(this);
513    findKeyIndex();
514    if (m_KeyIndex == -1) {
515      throw new Exception("No key field called " + m_KeyFieldName
516                          + " produced by "
517                          + m_ResultProducer.getClass().getName());
518    }
519    m_ResultProducer.preProcess();
520  }
521 
522  /**
523   * When this method is called, it indicates that no more results
524   * will be sent that need to be grouped together in any way.
525   *
526   * @param rp the ResultProducer that generated the results
527   * @throws Exception if an error occurs
528   */
529  public void postProcess(ResultProducer rp) throws Exception {
530
531    m_ResultListener.postProcess(this);
532  }
533
534  /**
535   * When this method is called, it indicates that no more requests to
536   * generate results for the current experiment will be sent. The
537   * ResultProducer should call preProcess(this) on the
538   * ResultListener it is to send results to.
539   *
540   * @throws Exception if an error occurs
541   */
542  public void postProcess() throws Exception {
543
544    m_ResultProducer.postProcess();
545  }
546 
547  /**
548   * Accepts results from a ResultProducer.
549   *
550   * @param rp the ResultProducer that generated the results
551   * @param key an array of Objects (Strings or Doubles) that uniquely
552   * identify a result for a given ResultProducer with given compatibilityState
553   * @param result the results stored in an array. The objects stored in
554   * the array may be Strings, Doubles, or null (for the missing value).
555   * @throws Exception if the result could not be accepted.
556   */
557  public void acceptResult(ResultProducer rp, Object [] key, Object [] result)
558    throws Exception {
559
560    if (m_ResultProducer != rp) {
561      throw new Error("Unrecognized ResultProducer sending results!!");
562    }
563    m_Keys.addElement(key);
564    m_Results.addElement(result);
565  }
566
567  /**
568   * Determines whether the results for a specified key must be
569   * generated.
570   *
571   * @param rp the ResultProducer wanting to generate the results
572   * @param key an array of Objects (Strings or Doubles) that uniquely
573   * identify a result for a given ResultProducer with given compatibilityState
574   * @return true if the result should be generated
575   * @throws Exception if it could not be determined if the result
576   * is needed.
577   */
578  public boolean isResultRequired(ResultProducer rp, Object [] key) 
579    throws Exception {
580
581    if (m_ResultProducer != rp) {
582      throw new Error("Unrecognized ResultProducer sending results!!");
583    }
584    return true;
585  }
586
587  /**
588   * Gets the names of each of the columns produced for a single run.
589   *
590   * @return an array containing the name of each column
591   * @throws Exception if key names cannot be generated
592   */
593  public String [] getKeyNames() throws Exception {
594
595    if (m_KeyIndex == -1) {
596      throw new Exception("No key field called " + m_KeyFieldName
597                          + " produced by "
598                          + m_ResultProducer.getClass().getName());
599    }
600    String [] keyNames = m_ResultProducer.getKeyNames();
601    String [] newKeyNames = new String [keyNames.length - 1];
602    System.arraycopy(keyNames, 0, newKeyNames, 0, m_KeyIndex);
603    System.arraycopy(keyNames, m_KeyIndex + 1,
604                     newKeyNames, m_KeyIndex,
605                     keyNames.length - m_KeyIndex - 1);
606    return newKeyNames;
607  }
608
609  /**
610   * Gets the data types of each of the columns produced for a single run.
611   * This method should really be static.
612   *
613   * @return an array containing objects of the type of each column. The
614   * objects should be Strings, or Doubles.
615   * @throws Exception if the key types could not be determined (perhaps
616   * because of a problem from a nested sub-resultproducer)
617   */
618  public Object [] getKeyTypes() throws Exception {
619
620    if (m_KeyIndex == -1) {
621      throw new Exception("No key field called " + m_KeyFieldName
622                          + " produced by "
623                          + m_ResultProducer.getClass().getName());
624    }
625    Object [] keyTypes = m_ResultProducer.getKeyTypes();
626    // Find and remove the key field that is being averaged over
627    Object [] newKeyTypes = new String [keyTypes.length - 1];
628    System.arraycopy(keyTypes, 0, newKeyTypes, 0, m_KeyIndex);
629    System.arraycopy(keyTypes, m_KeyIndex + 1,
630                     newKeyTypes, m_KeyIndex,
631                     keyTypes.length - m_KeyIndex - 1);
632    return newKeyTypes;
633  }
634
635  /**
636   * Gets the names of each of the columns produced for a single run.
637   * A new result field is added for the number of results used to
638   * produce each average.
639   * If only averages are being produced the names are not altered, if
640   * standard deviations are produced then "Dev_" and "Avg_" are prepended
641   * to each result deviation and average field respectively.
642   *
643   * @return an array containing the name of each column
644   * @throws Exception if the result names could not be determined (perhaps
645   * because of a problem from a nested sub-resultproducer)
646   */
647  public String [] getResultNames() throws Exception {
648
649    String [] resultNames = m_ResultProducer.getResultNames();
650    // Add in the names of our extra Result fields
651    if (getCalculateStdDevs()) {
652      Object [] resultTypes = m_ResultProducer.getResultTypes();
653      int numNumeric = 0;
654      for (int i = 0; i < resultTypes.length; i++) {
655        if (resultTypes[i] instanceof Double) {
656          numNumeric++;
657        }
658      }
659      String [] newResultNames = new String [resultNames.length +
660                                            1 + numNumeric];
661      newResultNames[0] = m_CountFieldName;
662      int j = 1;
663      for (int i = 0; i < resultNames.length; i++) {
664        newResultNames[j++] = "Avg_" + resultNames[i];
665        if (resultTypes[i] instanceof Double) {
666          newResultNames[j++] = "Dev_" + resultNames[i];
667        }
668      }
669      return newResultNames;
670    } else {
671      String [] newResultNames = new String [resultNames.length + 1];
672      newResultNames[0] = m_CountFieldName;
673      System.arraycopy(resultNames, 0, newResultNames, 1, resultNames.length);
674      return newResultNames;
675    }
676  }
677
678  /**
679   * Gets the data types of each of the columns produced for a single run.
680   *
681   * @return an array containing objects of the type of each column. The
682   * objects should be Strings, or Doubles.
683   * @throws Exception if the result types could not be determined (perhaps
684   * because of a problem from a nested sub-resultproducer)
685   */
686  public Object [] getResultTypes() throws Exception {
687
688    Object [] resultTypes = m_ResultProducer.getResultTypes();
689    // Add in the types of our extra Result fields
690    if (getCalculateStdDevs()) {
691      int numNumeric = 0;
692      for (int i = 0; i < resultTypes.length; i++) {
693        if (resultTypes[i] instanceof Double) {
694          numNumeric++;
695        }
696      }
697      Object [] newResultTypes = new Object [resultTypes.length +
698                                            1 + numNumeric];
699      newResultTypes[0] = new Double(0);
700      int j = 1;
701      for (int i = 0; i < resultTypes.length; i++) {
702        newResultTypes[j++] = resultTypes[i];
703        if (resultTypes[i] instanceof Double) {
704          newResultTypes[j++] = new Double(0);
705        }
706      }
707      return newResultTypes;
708    } else {
709      Object [] newResultTypes = new Object [resultTypes.length + 1];
710      newResultTypes[0] = new Double(0);
711      System.arraycopy(resultTypes, 0, newResultTypes, 1, resultTypes.length);
712      return newResultTypes;
713    }
714  }
715
716  /**
717   * Gets a description of the internal settings of the result
718   * producer, sufficient for distinguishing a ResultProducer
719   * instance from another with different settings (ignoring
720   * those settings set through this interface). For example,
721   * a cross-validation ResultProducer may have a setting for the
722   * number of folds. For a given state, the results produced should
723   * be compatible. Typically if a ResultProducer is an OptionHandler,
724   * this string will represent the command line arguments required
725   * to set the ResultProducer to that state.
726   *
727   * @return the description of the ResultProducer state, or null
728   * if no state is defined
729   */
730  public String getCompatibilityState() {
731
732    String result = // "-F " + Utils.quote(getKeyFieldName())
733      " -X " + getExpectedResultsPerAverage() + " ";
734    if (getCalculateStdDevs()) {
735      result += "-S ";
736    }
737    if (m_ResultProducer == null) {
738      result += "<null ResultProducer>";
739    } else {
740      result += "-W " + m_ResultProducer.getClass().getName();
741    }
742    result  += " -- " + m_ResultProducer.getCompatibilityState();
743    return result.trim();
744  }
745
746
747  /**
748   * Returns an enumeration describing the available options..
749   *
750   * @return an enumeration of all the available options.
751   */
752  public Enumeration listOptions() {
753
754    Vector newVector = new Vector(2);
755
756    newVector.addElement(new Option(
757             "\tThe name of the field to average over.\n"
758              +"\t(default \"Fold\")", 
759             "F", 1, 
760             "-F <field name>"));
761    newVector.addElement(new Option(
762             "\tThe number of results expected per average.\n"
763              +"\t(default 10)", 
764             "X", 1, 
765             "-X <num results>"));
766    newVector.addElement(new Option(
767             "\tCalculate standard deviations.\n"
768              +"\t(default only averages)", 
769             "S", 0, 
770             "-S"));
771    newVector.addElement(new Option(
772             "\tThe full class name of a ResultProducer.\n"
773              +"\teg: weka.experiment.CrossValidationResultProducer", 
774             "W", 1, 
775             "-W <class name>"));
776
777    if ((m_ResultProducer != null) &&
778        (m_ResultProducer instanceof OptionHandler)) {
779      newVector.addElement(new Option(
780             "",
781             "", 0, "\nOptions specific to result producer "
782             + m_ResultProducer.getClass().getName() + ":"));
783      Enumeration enu = ((OptionHandler)m_ResultProducer).listOptions();
784      while (enu.hasMoreElements()) {
785        newVector.addElement(enu.nextElement());
786      }
787    }
788    return newVector.elements();
789  }
790
791  /**
792   * Parses a given list of options. <p/>
793   *
794   <!-- options-start -->
795   * Valid options are: <p/>
796   *
797   * <pre> -F &lt;field name&gt;
798   *  The name of the field to average over.
799   *  (default "Fold")</pre>
800   *
801   * <pre> -X &lt;num results&gt;
802   *  The number of results expected per average.
803   *  (default 10)</pre>
804   *
805   * <pre> -S
806   *  Calculate standard deviations.
807   *  (default only averages)</pre>
808   *
809   * <pre> -W &lt;class name&gt;
810   *  The full class name of a ResultProducer.
811   *  eg: weka.experiment.CrossValidationResultProducer</pre>
812   *
813   * <pre>
814   * Options specific to result producer weka.experiment.CrossValidationResultProducer:
815   * </pre>
816   *
817   * <pre> -X &lt;number of folds&gt;
818   *  The number of folds to use for the cross-validation.
819   *  (default 10)</pre>
820   *
821   * <pre> -D
822   * Save raw split evaluator output.</pre>
823   *
824   * <pre> -O &lt;file/directory name/path&gt;
825   *  The filename where raw output will be stored.
   *  If a directory name is specified then individual
   *  outputs will be gzipped, otherwise all output will be
   *  zipped to the named file. Use in conjunction with -D. (default splitEvalutorOut.zip)</pre>
829   *
830   * <pre> -W &lt;class name&gt;
831   *  The full class name of a SplitEvaluator.
832   *  eg: weka.experiment.ClassifierSplitEvaluator</pre>
833   *
834   * <pre>
835   * Options specific to split evaluator weka.experiment.ClassifierSplitEvaluator:
836   * </pre>
837   *
838   * <pre> -W &lt;class name&gt;
839   *  The full class name of the classifier.
840   *  eg: weka.classifiers.bayes.NaiveBayes</pre>
841   *
842   * <pre> -C &lt;index&gt;
843   *  The index of the class for which IR statistics
844   *  are to be output. (default 1)</pre>
845   *
846   * <pre> -I &lt;index&gt;
847   *  The index of an attribute to output in the
848   *  results. This attribute should identify an
849   *  instance in order to know which instances are
850   *  in the test set of a cross validation. if 0
851   *  no output (default 0).</pre>
852   *
853   * <pre> -P
854   *  Add target and prediction columns to the result
855   *  for each fold.</pre>
856   *
857   * <pre>
858   * Options specific to classifier weka.classifiers.rules.ZeroR:
859   * </pre>
860   *
861   * <pre> -D
862   *  If set, classifier is run in debug mode and
863   *  may output additional info to the console</pre>
864   *
865   <!-- options-end -->
866   *
867   * All options after -- will be passed to the result producer.
868   *
869   * @param options the list of options as an array of strings
870   * @throws Exception if an option is not supported
871   */
872  public void setOptions(String[] options) throws Exception {
873   
874    String keyFieldName = Utils.getOption('F', options);
875    if (keyFieldName.length() != 0) {
876      setKeyFieldName(keyFieldName);
877    } else {
878      setKeyFieldName(CrossValidationResultProducer.FOLD_FIELD_NAME);
879    }
880
881    String numResults = Utils.getOption('X', options);
882    if (numResults.length() != 0) {
883      setExpectedResultsPerAverage(Integer.parseInt(numResults));
884    } else {
885      setExpectedResultsPerAverage(10);
886    }
887
888    setCalculateStdDevs(Utils.getFlag('S', options));
889   
890    String rpName = Utils.getOption('W', options);
891    if (rpName.length() == 0) {
892      throw new Exception("A ResultProducer must be specified with"
893                          + " the -W option.");
894    }
895    // Do it first without options, so if an exception is thrown during
896    // the option setting, listOptions will contain options for the actual
897    // RP.
898    setResultProducer((ResultProducer)Utils.forName(
899                      ResultProducer.class,
900                      rpName,
901                      null));
902    if (getResultProducer() instanceof OptionHandler) {
903      ((OptionHandler) getResultProducer())
904        .setOptions(Utils.partitionOptions(options));
905    }
906  }
907
908  /**
909   * Gets the current settings of the result producer.
910   *
911   * @return an array of strings suitable for passing to setOptions
912   */
913  public String [] getOptions() {
914
915    String [] seOptions = new String [0];
916    if ((m_ResultProducer != null) && 
917        (m_ResultProducer instanceof OptionHandler)) {
918      seOptions = ((OptionHandler)m_ResultProducer).getOptions();
919    }
920   
921    String [] options = new String [seOptions.length + 8];
922    int current = 0;
923
924    options[current++] = "-F";
925    options[current++] = "" + getKeyFieldName();
926    options[current++] = "-X";
927    options[current++] = "" + getExpectedResultsPerAverage();
928    if (getCalculateStdDevs()) {
929      options[current++] = "-S";
930    }
931    if (getResultProducer() != null) {
932      options[current++] = "-W";
933      options[current++] = getResultProducer().getClass().getName();
934    }
935    options[current++] = "--";
936
937    System.arraycopy(seOptions, 0, options, current, 
938                     seOptions.length);
939    current += seOptions.length;
940    while (current < options.length) {
941      options[current++] = "";
942    }
943    return options;
944  }
945
946  /**
947   * Set a list of method names for additional measures to look for
948   * in SplitEvaluators. This could contain many measures (of which only a
949   * subset may be produceable by the current resultProducer) if an experiment
950   * is the type that iterates over a set of properties.
951   * @param additionalMeasures an array of measure names, null if none
952   */
953  public void setAdditionalMeasures(String [] additionalMeasures) {
954    m_AdditionalMeasures = additionalMeasures;
955
956    if (m_ResultProducer != null) {
957      System.err.println("AveragingResultProducer: setting additional "
958                         +"measures for "
959                         +"ResultProducer");
960      m_ResultProducer.setAdditionalMeasures(m_AdditionalMeasures);
961    }
962  }
963
964  /**
965   * Returns an enumeration of any additional measure names that might be
966   * in the result producer
967   * @return an enumeration of the measure names
968   */
969  public Enumeration enumerateMeasures() {
970    Vector newVector = new Vector();
971    if (m_ResultProducer instanceof AdditionalMeasureProducer) {
972      Enumeration en = ((AdditionalMeasureProducer)m_ResultProducer).
973        enumerateMeasures();
974      while (en.hasMoreElements()) {
975        String mname = (String)en.nextElement();
976        newVector.addElement(mname);
977      }
978    }
979    return newVector.elements();
980  }
981
982  /**
983   * Returns the value of the named measure
984   * @param additionalMeasureName the name of the measure to query for its value
985   * @return the value of the named measure
986   * @throws IllegalArgumentException if the named measure is not supported
987   */
988  public double getMeasure(String additionalMeasureName) {
989    if (m_ResultProducer instanceof AdditionalMeasureProducer) {
990      return ((AdditionalMeasureProducer)m_ResultProducer).
991        getMeasure(additionalMeasureName);
992    } else {
993      throw new IllegalArgumentException("AveragingResultProducer: "
994                          +"Can't return value for : "+additionalMeasureName
995                          +". "+m_ResultProducer.getClass().getName()+" "
996                          +"is not an AdditionalMeasureProducer");
997    }
998  }
999
1000  /**
1001   * Sets the dataset that results will be obtained for.
1002   *
1003   * @param instances a value of type 'Instances'.
1004   */
1005  public void setInstances(Instances instances) {
1006   
1007    m_Instances = instances;
1008  }
1009
1010  /**
1011   * Returns the tip text for this property
1012   * @return tip text for this property suitable for
1013   * displaying in the explorer/experimenter gui
1014   */
1015  public String calculateStdDevsTipText() {
1016    return "Record standard deviations for each run.";
1017  }
1018
1019  /**
1020   * Get the value of CalculateStdDevs.
1021   *
1022   * @return Value of CalculateStdDevs.
1023   */
1024  public boolean getCalculateStdDevs() {
1025   
1026    return m_CalculateStdDevs;
1027  }
1028 
1029  /**
1030   * Set the value of CalculateStdDevs.
1031   *
1032   * @param newCalculateStdDevs Value to assign to CalculateStdDevs.
1033   */
1034  public void setCalculateStdDevs(boolean newCalculateStdDevs) {
1035   
1036    m_CalculateStdDevs = newCalculateStdDevs;
1037  }
1038
1039  /**
1040   * Returns the tip text for this property
1041   * @return tip text for this property suitable for
1042   * displaying in the explorer/experimenter gui
1043   */
1044  public String expectedResultsPerAverageTipText() {
1045    return "Set the expected number of results to average per run. "
1046      +"For example if a CrossValidationResultProducer is being used "
1047      +"(with the number of folds set to 10), then the expected number "
1048      +"of results per run is 10.";
1049  }
1050
1051  /**
1052   * Get the value of ExpectedResultsPerAverage.
1053   *
1054   * @return Value of ExpectedResultsPerAverage.
1055   */
1056  public int getExpectedResultsPerAverage() {
1057   
1058    return m_ExpectedResultsPerAverage;
1059  }
1060 
1061  /**
1062   * Set the value of ExpectedResultsPerAverage.
1063   *
1064   * @param newExpectedResultsPerAverage Value to assign to
1065   * ExpectedResultsPerAverage.
1066   */
1067  public void setExpectedResultsPerAverage(int newExpectedResultsPerAverage) {
1068   
1069    m_ExpectedResultsPerAverage = newExpectedResultsPerAverage;
1070  }
1071
1072  /**
1073   * Returns the tip text for this property
1074   * @return tip text for this property suitable for
1075   * displaying in the explorer/experimenter gui
1076   */
1077  public String keyFieldNameTipText() {
1078    return "Set the field name that will be unique for a run.";
1079  }
1080
1081  /**
1082   * Get the value of KeyFieldName.
1083   *
1084   * @return Value of KeyFieldName.
1085   */
1086  public String getKeyFieldName() {
1087   
1088    return m_KeyFieldName;
1089  }
1090 
1091  /**
1092   * Set the value of KeyFieldName.
1093   *
1094   * @param newKeyFieldName Value to assign to KeyFieldName.
1095   */
1096  public void setKeyFieldName(String newKeyFieldName) {
1097   
1098    m_KeyFieldName = newKeyFieldName;
1099    m_CountFieldName = "Num_" + m_KeyFieldName;
1100    findKeyIndex();
1101  }
1102 
1103  /**
1104   * Sets the object to send results of each run to.
1105   *
1106   * @param listener a value of type 'ResultListener'
1107   */
1108  public void setResultListener(ResultListener listener) {
1109
1110    m_ResultListener = listener;
1111  }
1112 
1113  /**
1114   * Returns the tip text for this property
1115   * @return tip text for this property suitable for
1116   * displaying in the explorer/experimenter gui
1117   */
1118  public String resultProducerTipText() {
1119    return "Set the resultProducer for which results are to be averaged.";
1120  }
1121
1122  /**
1123   * Get the ResultProducer.
1124   *
1125   * @return the ResultProducer.
1126   */
1127  public ResultProducer getResultProducer() {
1128   
1129    return m_ResultProducer;
1130  }
1131 
1132  /**
1133   * Set the ResultProducer.
1134   *
1135   * @param newResultProducer new ResultProducer to use.
1136   */
1137  public void setResultProducer(ResultProducer newResultProducer) {
1138
1139    m_ResultProducer = newResultProducer;
1140    m_ResultProducer.setResultListener(this);
1141    findKeyIndex();
1142  }
1143
1144  /**
1145   * Gets a text descrption of the result producer.
1146   *
1147   * @return a text description of the result producer.
1148   */
1149  public String toString() {
1150
1151    String result = "AveragingResultProducer: ";
1152    result += getCompatibilityState();
1153    if (m_Instances == null) {
1154      result += ": <null Instances>";
1155    } else {
1156      result += ": " + Utils.backQuoteChars(m_Instances.relationName());
1157    }
1158    return result;
1159  }
1160 
1161  /**
1162   * Returns the revision string.
1163   *
1164   * @return            the revision
1165   */
1166  public String getRevision() {
1167    return RevisionUtils.extract("$Revision: 1.18 $");
1168  }
1169} // AveragingResultProducer
Note: See TracBrowser for help on using the repository browser.