source: src/main/java/weka/gui/beans/ClassifierPerformanceEvaluator.java @ 24

Last change on this file since 24 was 4, checked in by gnappo, 14 years ago

Import di weka.

File size: 16.3 KB
Line 
1/*
2 *    This program is free software; you can redistribute it and/or modify
3 *    it under the terms of the GNU General Public License as published by
4 *    the Free Software Foundation; either version 2 of the License, or
5 *    (at your option) any later version.
6 *
7 *    This program is distributed in the hope that it will be useful,
8 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
9 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
10 *    GNU General Public License for more details.
11 *
12 *    You should have received a copy of the GNU General Public License
13 *    along with this program; if not, write to the Free Software
14 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
15 */
16
17/*
18 *    ClassifierPerformanceEvaluator.java
19 *    Copyright (C) 2002 University of Waikato, Hamilton, New Zealand
20 *
21 */
22
23package weka.gui.beans;
24
25import weka.classifiers.Classifier;
26import weka.classifiers.AbstractClassifier;
27import weka.classifiers.Evaluation;
28import weka.classifiers.evaluation.ThresholdCurve;
29import weka.core.Instance;
30import weka.core.Instances;
31import weka.core.OptionHandler;
32import weka.core.Utils;
33import weka.gui.explorer.ClassifierErrorsPlotInstances;
34import weka.gui.explorer.ExplorerDefaults;
35import weka.gui.visualize.PlotData2D;
36
37import java.io.Serializable;
38import java.util.Enumeration;
39import java.util.Vector;
40
41/**
42 * A bean that evaluates the performance of batch trained classifiers
43 *
44 * @author <a href="mailto:mhall@cs.waikato.ac.nz">Mark Hall</a>
45 * @version $Revision: 5928 $
46 */
47public class ClassifierPerformanceEvaluator 
48  extends AbstractEvaluator
49  implements BatchClassifierListener, 
50             Serializable, UserRequestAcceptor, EventConstraints {
51
52  /** for serialization */
53  private static final long serialVersionUID = -3511801418192148690L;
54
55  /**
56   * Evaluation object used for evaluating a classifier
57   */
58  private transient Evaluation m_eval;
59
60  private transient Thread m_evaluateThread = null;
61 
62  private transient long m_currentBatchIdentifier;
63  private transient int m_setsComplete;
64 
65  private Vector m_textListeners = new Vector();
66  private Vector m_thresholdListeners = new Vector();
67  private Vector m_visualizableErrorListeners = new Vector();
68
69  public ClassifierPerformanceEvaluator() {
70    m_visual.loadIcons(BeanVisual.ICON_PATH
71                       +"ClassifierPerformanceEvaluator.gif",
72                       BeanVisual.ICON_PATH
73                       +"ClassifierPerformanceEvaluator_animated.gif");
74    m_visual.setText("ClassifierPerformanceEvaluator");
75  }
76
77  /**
78   * Set a custom (descriptive) name for this bean
79   *
80   * @param name the name to use
81   */
82  public void setCustomName(String name) {
83    m_visual.setText(name);
84  }
85
86  /**
87   * Get the custom (descriptive) name for this bean (if one has been set)
88   *
89   * @return the custom name (or the default name)
90   */
91  public String getCustomName() {
92    return m_visual.getText();
93  }
94 
95  /**
96   * Global info for this bean
97   *
98   * @return a <code>String</code> value
99   */
100  public String globalInfo() {
101    return "Evaluate the performance of batch trained classifiers.";
102  }
103
104  // ----- Stuff for ROC curves
105  private boolean m_rocListenersConnected = false;
106 
107  /** for generating plottable instance with predictions appended. */
108  private transient ClassifierErrorsPlotInstances m_PlotInstances = null;
109
110  /**
111   * Accept a classifier to be evaluated.
112   *
113   * @param ce a <code>BatchClassifierEvent</code> value
114   */
115  public void acceptClassifier(final BatchClassifierEvent ce) {
116    if (ce.getTestSet() == null || ce.getTestSet().isStructureOnly()) {
117      return; // cant evaluate empty/non-existent test instances
118    }
119    try {
120      if (m_evaluateThread == null) {
121        m_evaluateThread = new Thread() {
122            public void run() {
123              boolean errorOccurred = false;
124//            final String oldText = m_visual.getText();
125              Classifier classifier = ce.getClassifier();
126              try {
127                // if (ce.getSetNumber() == 1) {
128                if (ce.getGroupIdentifier() != m_currentBatchIdentifier) {
129                 
130                  if (ce.getTrainSet().getDataSet() == null ||
131                      ce.getTrainSet().getDataSet().numInstances() == 0) {
132                    // we have no training set to estimate majority class
133                    // or mean of target from
134                    m_eval = new Evaluation(ce.getTestSet().getDataSet());
135                    m_eval.useNoPriors();
136                  } else {
137                    m_eval = new Evaluation(ce.getTrainSet().getDataSet());
138                  }
139//                m_classifier = ce.getClassifier();
140                  m_PlotInstances = ExplorerDefaults.getClassifierErrorsPlotInstances();
141                  m_PlotInstances.setInstances(ce.getTestSet().getDataSet());
142                  m_PlotInstances.setClassifier(ce.getClassifier());
143                  m_PlotInstances.setClassIndex(ce.getTestSet().getDataSet().classIndex());
144                  m_PlotInstances.setEvaluation(m_eval);
145                  m_PlotInstances.setUp();
146                 
147                  m_currentBatchIdentifier = ce.getGroupIdentifier();
148                  m_setsComplete = 0;
149                }
150//              if (ce.getSetNumber() <= ce.getMaxSetNumber()) {
151                if (m_setsComplete < ce.getMaxSetNumber()) {
152                 
153                  if (ce.getTrainSet().getDataSet() != null &&
154                      ce.getTrainSet().getDataSet().numInstances() > 0) {
155                    // set the priors
156                    m_eval.setPriors(ce.getTrainSet().getDataSet());
157                  }
158                 
159//                m_visual.setText("Evaluating ("+ce.getSetNumber()+")...");
160                  if (m_logger != null) {
161                    m_logger.statusMessage(statusMessagePrefix()
162                                           +"Evaluating ("+ce.getSetNumber()
163                                           +")...");
164                  }
165                  m_visual.setAnimated();
166                  /*
167                  m_eval.evaluateModel(ce.getClassifier(),
168                  ce.getTestSet().getDataSet()); */
169                  for (int i = 0; i < ce.getTestSet().getDataSet().numInstances(); i++) {
170                    Instance temp = ce.getTestSet().getDataSet().instance(i);
171                    m_PlotInstances.process(temp, ce.getClassifier(), m_eval);
172                  }
173                 
174                  m_setsComplete++;
175                }
176               
177//              if (ce.getSetNumber() == ce.getMaxSetNumber()) {
178                if (m_setsComplete == ce.getMaxSetNumber()) {
179                  //              System.err.println(m_eval.toSummaryString());
180                  // m_resultsString.append(m_eval.toSummaryString());
181                  // m_outText.setText(m_resultsString.toString());
182                  String textTitle = classifier.getClass().getName();
183                  String textOptions = "";
184                  if (classifier instanceof OptionHandler) {
185                     textOptions = 
186                       Utils.joinOptions(((OptionHandler)classifier).getOptions()); 
187                  }
188                  textTitle = 
189                    textTitle.substring(textTitle.lastIndexOf('.')+1,
190                                        textTitle.length());
191                  String resultT = "=== Evaluation result ===\n\n"
192                    + "Scheme: " + textTitle + "\n"
193                    + ((textOptions.length() > 0) ? "Options: " + textOptions + "\n": "")
194                    + "Relation: " + ce.getTestSet().getDataSet().relationName()
195                    + "\n\n" + m_eval.toSummaryString();
196                 
197                  if (ce.getTestSet().getDataSet().
198                      classAttribute().isNominal()) {
199                    resultT += "\n" + m_eval.toClassDetailsString()
200                      + "\n" + m_eval.toMatrixString();
201                  }
202                 
203                  TextEvent te = 
204                    new TextEvent(ClassifierPerformanceEvaluator.this, 
205                                  resultT,
206                                  textTitle);
207                  notifyTextListeners(te);
208
209                  // set up visualizable errors
210                  if (m_visualizableErrorListeners.size() > 0) {
211                    PlotData2D errorD = m_PlotInstances.getPlotData(
212                        textTitle + " " + textOptions);
213                    VisualizableErrorEvent vel = 
214                      new VisualizableErrorEvent(ClassifierPerformanceEvaluator.this, errorD);
215                    notifyVisualizableErrorListeners(vel);
216                    m_PlotInstances.cleanUp();
217                  }
218                 
219
220                  if (ce.getTestSet().getDataSet().classAttribute().isNominal() &&
221                      m_thresholdListeners.size() > 0) {
222                    ThresholdCurve tc = new ThresholdCurve();
223                    Instances result = tc.getCurve(m_eval.predictions(), 0);
224                    result.
225                      setRelationName(ce.getTestSet().getDataSet().relationName());
226                    PlotData2D pd = new PlotData2D(result);
227                    String htmlTitle = "<html><font size=-2>"
228                      + textTitle;
229                    String newOptions = "";
230                    if (classifier instanceof OptionHandler) {
231                      String[] options = 
232                        ((OptionHandler) classifier).getOptions();
233                      if (options.length > 0) {
234                        for (int ii = 0; ii < options.length; ii++) {
235                          if (options[ii].length() == 0) {
236                            continue;
237                          }
238                          if (options[ii].charAt(0) == '-' && 
239                              !(options[ii].charAt(1) >= '0' &&
240                                  options[ii].charAt(1)<= '9')) {
241                            newOptions += "<br>";
242                          }
243                          newOptions += options[ii];
244                        }
245                      }
246                    }
247                   
248                   htmlTitle += " " + newOptions + "<br>" 
249                      + " (class: "
250                      +ce.getTestSet().getDataSet().
251                        classAttribute().value(0) + ")" 
252                      + "</font></html>";
253                    pd.setPlotName(textTitle + " (class: "
254                              +ce.getTestSet().getDataSet().
255                                classAttribute().value(0) + ")");
256                    pd.setPlotNameHTML(htmlTitle);
257                    boolean [] connectPoints = 
258                      new boolean [result.numInstances()];
259                    for (int jj = 1; jj < connectPoints.length; jj++) {
260                      connectPoints[jj] = true;
261                    }
262                    pd.setConnectPoints(connectPoints);
263                    ThresholdDataEvent rde = 
264                      new ThresholdDataEvent(ClassifierPerformanceEvaluator.this,
265                                       pd, ce.getTestSet().getDataSet().classAttribute());
266                    notifyThresholdListeners(rde);
267                    /*te = new TextEvent(ClassifierPerformanceEvaluator.this,
268                                       result.toString(),
269                                       "ThresholdCurveInst");
270                                       notifyTextListeners(te); */
271                  }
272                  if (m_logger != null) {
273                    m_logger.statusMessage(statusMessagePrefix() + "Finished.");
274                  }
275
276                  // save memory
277                  m_PlotInstances = null;
278                }
279              } catch (Exception ex) {
280                errorOccurred = true;
281                ClassifierPerformanceEvaluator.this.stop(); // stop all processing
282                if (m_logger != null) {
283                  m_logger.logMessage("[ClassifierPerformanceEvaluator] "
284                      + statusMessagePrefix() 
285                      + " problem evaluating classifier. " 
286                      + ex.getMessage());
287                }
288                ex.printStackTrace();
289              } finally {
290//              m_visual.setText(oldText);
291                m_visual.setStatic();
292                m_evaluateThread = null;
293                                               
294                if (m_logger != null) {
295                  if (errorOccurred) {
296                    m_logger.statusMessage(statusMessagePrefix() 
297                        + "ERROR (See log for details)");
298                  } else if (isInterrupted()) {
299                    m_logger.logMessage("[" + getCustomName() +"] Evaluation interrupted!");
300                    m_logger.statusMessage(statusMessagePrefix() 
301                        + "INTERRUPTED");
302                  }
303                }
304                block(false);
305              }
306            }
307          };
308        m_evaluateThread.setPriority(Thread.MIN_PRIORITY);
309        m_evaluateThread.start();
310
311        // make sure the thread is still running before we block
312        //      if (m_evaluateThread.isAlive()) {
313        block(true);
314          //    }
315        m_evaluateThread = null;
316      }
317    }  catch (Exception ex) {
318      ex.printStackTrace();
319    }
320  }
321
322  /**
323   * Returns true if. at this time, the bean is busy with some
324   * (i.e. perhaps a worker thread is performing some calculation).
325   *
326   * @return true if the bean is busy.
327   */
328  public boolean isBusy() {
329    return (m_evaluateThread != null);
330  }
331   
332  /**
333   * Try and stop any action
334   */
335  public void stop() {
336    // tell the listenee (upstream bean) to stop
337    if (m_listenee instanceof BeanCommon) {
338      //      System.err.println("Listener is BeanCommon");
339      ((BeanCommon)m_listenee).stop();
340    }
341
342    // stop the evaluate thread
343    if (m_evaluateThread != null) {
344      m_evaluateThread.interrupt();
345      m_evaluateThread.stop();
346      m_evaluateThread = null;
347      m_visual.setStatic();
348    }
349  }
350 
351  /**
352   * Function used to stop code that calls acceptClassifier. This is
353   * needed as classifier evaluation is performed inside a separate
354   * thread of execution.
355   *
356   * @param tf a <code>boolean</code> value
357   */
358  private synchronized void block(boolean tf) {
359    if (tf) {
360      try {
361        // only block if thread is still doing something useful!
362        if (m_evaluateThread != null && m_evaluateThread.isAlive()) {
363          wait();
364        }
365      } catch (InterruptedException ex) {
366      }
367    } else {
368      notifyAll();
369    }
370  }
371
372  /**
373   * Return an enumeration of user activated requests for this bean
374   *
375   * @return an <code>Enumeration</code> value
376   */
377  public Enumeration enumerateRequests() {
378    Vector newVector = new Vector(0);
379    if (m_evaluateThread != null) {
380      newVector.addElement("Stop");
381    }
382    return newVector.elements();
383  }
384
385  /**
386   * Perform the named request
387   *
388   * @param request the request to perform
389   * @exception IllegalArgumentException if an error occurs
390   */
391  public void performRequest(String request) {
392    if (request.compareTo("Stop") == 0) {
393      stop();
394    } else {
395      throw new 
396        IllegalArgumentException(request
397
398                    + " not supported (ClassifierPerformanceEvaluator)");
399    }
400  }
401
402  /**
403   * Add a text listener
404   *
405   * @param cl a <code>TextListener</code> value
406   */
407  public synchronized void addTextListener(TextListener cl) {
408    m_textListeners.addElement(cl);
409  }
410
411  /**
412   * Remove a text listener
413   *
414   * @param cl a <code>TextListener</code> value
415   */
416  public synchronized void removeTextListener(TextListener cl) {
417    m_textListeners.remove(cl);
418  }
419 
420  /**
421   * Add a threshold data listener
422   *
423   * @param cl a <code>ThresholdDataListener</code> value
424   */
425  public synchronized void addThresholdDataListener(ThresholdDataListener cl) {
426    m_thresholdListeners.addElement(cl);
427  }
428
429  /**
430   * Remove a Threshold data listener
431   *
432   * @param cl a <code>ThresholdDataListener</code> value
433   */
434  public synchronized void removeThresholdDataListener(ThresholdDataListener cl) {
435    m_thresholdListeners.remove(cl);
436  }
437
438  /**
439   * Add a visualizable error listener
440   *
441   * @param vel a <code>VisualizableErrorListener</code> value
442   */
443  public synchronized void addVisualizableErrorListener(VisualizableErrorListener vel) {
444    m_visualizableErrorListeners.add(vel);
445  }
446
447  /**
448   * Remove a visualizable error listener
449   *
450   * @param vel a <code>VisualizableErrorListener</code> value
451   */
452  public synchronized void removeVisualizableErrorListener(VisualizableErrorListener vel) {
453    m_visualizableErrorListeners.remove(vel);
454  }
455
456  /**
457   * Notify all text listeners of a TextEvent
458   *
459   * @param te a <code>TextEvent</code> value
460   */
461  private void notifyTextListeners(TextEvent te) {
462    Vector l;
463    synchronized (this) {
464      l = (Vector)m_textListeners.clone();
465    }
466    if (l.size() > 0) {
467      for(int i = 0; i < l.size(); i++) {
468        //      System.err.println("Notifying text listeners "
469        //                         +"(ClassifierPerformanceEvaluator)");
470        ((TextListener)l.elementAt(i)).acceptText(te);
471      }
472    }
473  }
474
475  /**
476   * Notify all ThresholdDataListeners of a ThresholdDataEvent
477   *
478   * @param te a <code>ThresholdDataEvent</code> value
479   */
480  private void notifyThresholdListeners(ThresholdDataEvent re) {
481    Vector l;
482    synchronized (this) {
483      l = (Vector)m_thresholdListeners.clone();
484    }
485    if (l.size() > 0) {
486      for(int i = 0; i < l.size(); i++) {
487        //      System.err.println("Notifying text listeners "
488        //                         +"(ClassifierPerformanceEvaluator)");
489        ((ThresholdDataListener)l.elementAt(i)).acceptDataSet(re);
490      }
491    }
492  }
493
494  /**
495   * Notify all VisualizableErrorListeners of a VisualizableErrorEvent
496   *
497   * @param te a <code>VisualizableErrorEvent</code> value
498   */
499  private void notifyVisualizableErrorListeners(VisualizableErrorEvent re) {
500    Vector l;
501    synchronized (this) {
502      l = (Vector)m_visualizableErrorListeners.clone();
503    }
504    if (l.size() > 0) {
505      for(int i = 0; i < l.size(); i++) {
506        //      System.err.println("Notifying text listeners "
507        //                         +"(ClassifierPerformanceEvaluator)");
508        ((VisualizableErrorListener)l.elementAt(i)).acceptDataSet(re);
509      }
510    }
511  }
512
513  /**
514   * Returns true, if at the current time, the named event could
515   * be generated. Assumes that supplied event names are names of
516   * events that could be generated by this bean.
517   *
518   * @param eventName the name of the event in question
519   * @return true if the named event could be generated at this point in
520   * time
521   */
522  public boolean eventGeneratable(String eventName) {
523    if (m_listenee == null) {
524      return false;
525    }
526
527    if (m_listenee instanceof EventConstraints) {
528      if (!((EventConstraints)m_listenee).
529          eventGeneratable("batchClassifier")) {
530        return false;
531      }
532    }
533    return true;
534  }
535 
536  private String statusMessagePrefix() {
537    return getCustomName() + "$" + hashCode() + "|";
538  }
539}
540
Note: See TracBrowser for help on using the repository browser.