source: src/main/java/weka/classifiers/meta/CostSensitiveClassifier.java @ 7

Last change on this file since 7 was 4, checked in by gnappo, 14 years ago

Import di weka.

File size: 20.0 KB
Line 
1/*
2 *    This program is free software; you can redistribute it and/or modify
3 *    it under the terms of the GNU General Public License as published by
4 *    the Free Software Foundation; either version 2 of the License, or
5 *    (at your option) any later version.
6 *
7 *    This program is distributed in the hope that it will be useful,
8 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
9 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
10 *    GNU General Public License for more details.
11 *
12 *    You should have received a copy of the GNU General Public License
13 *    along with this program; if not, write to the Free Software
14 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
15 */
16
17/*
18 *    CostSensitiveClassifier.java
19 *    Copyright (C) 2002 University of Waikato, Hamilton, New Zealand
20 *
21 */
22
23package weka.classifiers.meta;
24
25import weka.classifiers.Classifier;
26import weka.classifiers.AbstractClassifier;
27import weka.classifiers.CostMatrix;
28import weka.classifiers.RandomizableSingleClassifierEnhancer;
29import weka.core.Capabilities;
30import weka.core.Drawable;
31import weka.core.Instance;
32import weka.core.Instances;
33import weka.core.Option;
34import weka.core.OptionHandler;
35import weka.core.RevisionUtils;
36import weka.core.SelectedTag;
37import weka.core.Tag;
38import weka.core.Utils;
39import weka.core.WeightedInstancesHandler;
40import weka.core.Capabilities.Capability;
41
42import java.io.BufferedReader;
43import java.io.File;
44import java.io.FileReader;
45import java.io.StringReader;
46import java.io.StringWriter;
47import java.util.Enumeration;
48import java.util.Random;
49import java.util.Vector;
50
51/**
52 <!-- globalinfo-start -->
53 * A metaclassifier that makes its base classifier cost-sensitive. Two methods can be used to introduce cost-sensitivity: reweighting training instances according to the total cost assigned to each class; or predicting the class with minimum expected misclassification cost (rather than the most likely class). Performance can often be improved by using a Bagged classifier to improve the probability estimates of the base classifier.
54 * <p/>
55 <!-- globalinfo-end -->
56 *
57 <!-- options-start -->
58 * Valid options are: <p/>
59 *
60 * <pre> -M
61 *  Minimize expected misclassification cost. Default is to
62 *  reweight training instances according to costs per class</pre>
63 *
64 * <pre> -C &lt;cost file name&gt;
65 *  File name of a cost matrix to use. If this is not supplied,
66 *  a cost matrix will be loaded on demand. The name of the
67 *  on-demand file is the relation name of the training data
68 *  plus ".cost", and the path to the on-demand file is
69 *  specified with the -N option.</pre>
70 *
71 * <pre> -N &lt;directory&gt;
72 *  Name of a directory to search for cost files when loading
73 *  costs on demand (default current directory).</pre>
74 *
75 * <pre> -cost-matrix &lt;matrix&gt;
76 *  The cost matrix in Matlab single line format.</pre>
77 *
78 * <pre> -S &lt;num&gt;
79 *  Random number seed.
80 *  (default 1)</pre>
81 *
82 * <pre> -D
83 *  If set, classifier is run in debug mode and
84 *  may output additional info to the console</pre>
85 *
86 * <pre> -W
87 *  Full name of base classifier.
88 *  (default: weka.classifiers.rules.ZeroR)</pre>
89 *
90 * <pre>
91 * Options specific to classifier weka.classifiers.rules.ZeroR:
92 * </pre>
93 *
94 * <pre> -D
95 *  If set, classifier is run in debug mode and
96 *  may output additional info to the console</pre>
97 *
98 <!-- options-end -->
99 *
100 * Options after -- are passed to the designated classifier.<p>
101 *
102 * @author Len Trigg (len@reeltwo.com)
103 * @version $Revision: 5928 $
104 */
105public class CostSensitiveClassifier 
106  extends RandomizableSingleClassifierEnhancer
107  implements OptionHandler, Drawable {
108
109  /** for serialization */
110  static final long serialVersionUID = -720658209263002404L;
111 
112  /** load cost matrix on demand */
113  public static final int MATRIX_ON_DEMAND = 1;
114  /** use explicit cost matrix */
115  public static final int MATRIX_SUPPLIED = 2;
116  /** Specify possible sources of the cost matrix */
117  public static final Tag [] TAGS_MATRIX_SOURCE = {
118    new Tag(MATRIX_ON_DEMAND, "Load cost matrix on demand"),
119    new Tag(MATRIX_SUPPLIED, "Use explicit cost matrix")
120  };
121
122  /** Indicates the current cost matrix source */
123  protected int m_MatrixSource = MATRIX_ON_DEMAND;
124
125  /**
126   * The directory used when loading cost files on demand, null indicates
127   * current directory
128   */
129  protected File m_OnDemandDirectory = new File(System.getProperty("user.dir"));
130
131  /** The name of the cost file, for command line options */
132  protected String m_CostFile;
133
134  /** The cost matrix */
135  protected CostMatrix m_CostMatrix = new CostMatrix(1);
136
137  /**
138   * True if the costs should be used by selecting the minimum expected
139   * cost (false means weight training data by the costs)
140   */
141  protected boolean m_MinimizeExpectedCost;
142 
143  /**
144   * String describing default classifier.
145   *
146   * @return the default classifier classname
147   */
148  protected String defaultClassifierString() {
149   
150    return "weka.classifiers.rules.ZeroR";
151  }
152
153  /**
154   * Default constructor.
155   */
156  public CostSensitiveClassifier() {
157    m_Classifier = new weka.classifiers.rules.ZeroR();
158  }
159
160  /**
161   * Returns an enumeration describing the available options.
162   *
163   * @return an enumeration of all the available options.
164   */
165  public Enumeration listOptions() {
166
167    Vector newVector = new Vector(5);
168
169    newVector.addElement(new Option(
170              "\tMinimize expected misclassification cost. Default is to\n"
171              +"\treweight training instances according to costs per class",
172              "M", 0, "-M"));
173    newVector.addElement(new Option(
174              "\tFile name of a cost matrix to use. If this is not supplied,\n"
175              +"\ta cost matrix will be loaded on demand. The name of the\n"
176              +"\ton-demand file is the relation name of the training data\n"
177              +"\tplus \".cost\", and the path to the on-demand file is\n"
178              +"\tspecified with the -N option.",
179              "C", 1, "-C <cost file name>"));
180    newVector.addElement(new Option(
181              "\tName of a directory to search for cost files when loading\n"
182              +"\tcosts on demand (default current directory).",
183              "N", 1, "-N <directory>"));
184    newVector.addElement(new Option(
185              "\tThe cost matrix in Matlab single line format.",
186              "cost-matrix", 1, "-cost-matrix <matrix>"));
187
188    Enumeration enu = super.listOptions();
189    while (enu.hasMoreElements()) {
190      newVector.addElement(enu.nextElement());
191    }
192
193    return newVector.elements();
194  }
195
196  /**
197   * Parses a given list of options. <p/>
198   *
199   <!-- options-start -->
200   * Valid options are: <p/>
201   *
202   * <pre> -M
203   *  Minimize expected misclassification cost. Default is to
204   *  reweight training instances according to costs per class</pre>
205   *
206   * <pre> -C &lt;cost file name&gt;
207   *  File name of a cost matrix to use. If this is not supplied,
208   *  a cost matrix will be loaded on demand. The name of the
209   *  on-demand file is the relation name of the training data
210   *  plus ".cost", and the path to the on-demand file is
211   *  specified with the -N option.</pre>
212   *
213   * <pre> -N &lt;directory&gt;
214   *  Name of a directory to search for cost files when loading
215   *  costs on demand (default current directory).</pre>
216   *
217   * <pre> -cost-matrix &lt;matrix&gt;
218   *  The cost matrix in Matlab single line format.</pre>
219   *
220   * <pre> -S &lt;num&gt;
221   *  Random number seed.
222   *  (default 1)</pre>
223   *
224   * <pre> -D
225   *  If set, classifier is run in debug mode and
226   *  may output additional info to the console</pre>
227   *
228   * <pre> -W
229   *  Full name of base classifier.
230   *  (default: weka.classifiers.rules.ZeroR)</pre>
231   *
232   * <pre>
233   * Options specific to classifier weka.classifiers.rules.ZeroR:
234   * </pre>
235   *
236   * <pre> -D
237   *  If set, classifier is run in debug mode and
238   *  may output additional info to the console</pre>
239   *
240   <!-- options-end -->
241   *
242   * Options after -- are passed to the designated classifier.<p>
243   *
244   * @param options the list of options as an array of strings
245   * @throws Exception if an option is not supported
246   */
247  public void setOptions(String[] options) throws Exception {
248
249    setMinimizeExpectedCost(Utils.getFlag('M', options));
250
251    String costFile = Utils.getOption('C', options);
252    if (costFile.length() != 0) {
253      try {
254        setCostMatrix(new CostMatrix(new BufferedReader(
255                                     new FileReader(costFile))));
256      } catch (Exception ex) {
257        // now flag as possible old format cost matrix. Delay cost matrix
258        // loading until buildClassifer is called
259        setCostMatrix(null);
260      }
261      setCostMatrixSource(new SelectedTag(MATRIX_SUPPLIED,
262                                          TAGS_MATRIX_SOURCE));
263      m_CostFile = costFile;
264    } else {
265      setCostMatrixSource(new SelectedTag(MATRIX_ON_DEMAND, 
266                                          TAGS_MATRIX_SOURCE));
267    }
268   
269    String demandDir = Utils.getOption('N', options);
270    if (demandDir.length() != 0) {
271      setOnDemandDirectory(new File(demandDir));
272    }
273
274    String cost_matrix = Utils.getOption("cost-matrix", options);
275    if (cost_matrix.length() != 0) {
276      StringWriter writer = new StringWriter();
277      CostMatrix.parseMatlab(cost_matrix).write(writer);
278      setCostMatrix(new CostMatrix(new StringReader(writer.toString())));
279      setCostMatrixSource(new SelectedTag(MATRIX_SUPPLIED,
280                                          TAGS_MATRIX_SOURCE));
281    }
282   
283    super.setOptions(options);
284  }
285
286
287  /**
288   * Gets the current settings of the Classifier.
289   *
290   * @return an array of strings suitable for passing to setOptions
291   */
292  public String [] getOptions() {
293    String [] superOptions = super.getOptions();
294    String [] options = new String [superOptions.length + 7];
295
296    int current = 0;
297
298    if (m_MatrixSource == MATRIX_SUPPLIED) {
299      if (m_CostFile != null) {
300        options[current++] = "-C";
301        options[current++] = "" + m_CostFile;
302      }
303      else {
304        options[current++] = "-cost-matrix";
305        options[current++] = getCostMatrix().toMatlab();
306      }
307    } else {
308      options[current++] = "-N";
309      options[current++] = "" + getOnDemandDirectory();
310    }
311
312    if (getMinimizeExpectedCost()) {
313      options[current++] = "-M";
314    }
315
316    System.arraycopy(superOptions, 0, options, current, 
317                     superOptions.length);
318
319    while (current < options.length) {
320      if (options[current] == null) {
321        options[current] = "";
322      }
323      current++;
324    }
325
326    return options;
327  }
328
329  /**
330   * @return a description of the classifier suitable for
331   * displaying in the explorer/experimenter gui
332   */
333  public String globalInfo() {
334
335    return "A metaclassifier that makes its base classifier cost-sensitive. "
336      + "Two methods can be used to introduce cost-sensitivity: reweighting "
337      + "training instances according to the total cost assigned to each "
338      + "class; or predicting the class with minimum expected "
339      + "misclassification cost (rather than the most likely class). "
340      + "Performance can often be "
341      + "improved by using a Bagged classifier to improve the probability "
342      + "estimates of the base classifier.";
343  }
344
345  /**
346   * @return tip text for this property suitable for
347   * displaying in the explorer/experimenter gui
348   */
349  public String costMatrixSourceTipText() {
350
351    return "Sets where to get the cost matrix. The two options are"
352      + "to use the supplied explicit cost matrix (the setting of the "
353      + "costMatrix property), or to load a cost matrix from a file when "
354      + "required (this file will be loaded from the directory set by the "
355      + "onDemandDirectory property and will be named relation_name" 
356      + CostMatrix.FILE_EXTENSION + ").";
357  }
358
359  /**
360   * Gets the source location method of the cost matrix. Will be one of
361   * MATRIX_ON_DEMAND or MATRIX_SUPPLIED.
362   *
363   * @return the cost matrix source.
364   */
365  public SelectedTag getCostMatrixSource() {
366
367    return new SelectedTag(m_MatrixSource, TAGS_MATRIX_SOURCE);
368  }
369 
370  /**
371   * Sets the source location of the cost matrix. Values other than
372   * MATRIX_ON_DEMAND or MATRIX_SUPPLIED will be ignored.
373   *
374   * @param newMethod the cost matrix location method.
375   */
376  public void setCostMatrixSource(SelectedTag newMethod) {
377   
378    if (newMethod.getTags() == TAGS_MATRIX_SOURCE) {
379      m_MatrixSource = newMethod.getSelectedTag().getID();
380    }
381  }
382
383  /**
384   * @return tip text for this property suitable for
385   * displaying in the explorer/experimenter gui
386   */
387  public String onDemandDirectoryTipText() {
388
389    return "Sets the directory where cost files are loaded from. This option "
390      + "is used when the costMatrixSource is set to \"On Demand\".";
391  }
392
393  /**
394   * Returns the directory that will be searched for cost files when
395   * loading on demand.
396   *
397   * @return The cost file search directory.
398   */
399  public File getOnDemandDirectory() {
400
401    return m_OnDemandDirectory;
402  }
403
404  /**
405   * Sets the directory that will be searched for cost files when
406   * loading on demand.
407   *
408   * @param newDir The cost file search directory.
409   */
410  public void setOnDemandDirectory(File newDir) {
411
412    if (newDir.isDirectory()) {
413      m_OnDemandDirectory = newDir;
414    } else {
415      m_OnDemandDirectory = new File(newDir.getParent());
416    }
417    m_MatrixSource = MATRIX_ON_DEMAND;
418  }
419
420  /**
421   * @return tip text for this property suitable for
422   * displaying in the explorer/experimenter gui
423   */
424  public String minimizeExpectedCostTipText() {
425
426    return "Sets whether the minimum expected cost criteria will be used. If "
427      + "this is false, the training data will be reweighted according to the "
428      + "costs assigned to each class. If true, the minimum expected cost "
429      + "criteria will be used.";
430  }
431
432  /**
433   * Gets the value of MinimizeExpectedCost.
434   *
435   * @return Value of MinimizeExpectedCost.
436   */
437  public boolean getMinimizeExpectedCost() {
438   
439    return m_MinimizeExpectedCost;
440  }
441 
442  /**
443   * Set the value of MinimizeExpectedCost.
444   *
445   * @param newMinimizeExpectedCost Value to assign to MinimizeExpectedCost.
446   */
447  public void setMinimizeExpectedCost(boolean newMinimizeExpectedCost) {
448   
449    m_MinimizeExpectedCost = newMinimizeExpectedCost;
450  }
451 
452  /**
453   * Gets the classifier specification string, which contains the class name of
454   * the classifier and any options to the classifier
455   *
456   * @return the classifier string.
457   */
458  protected String getClassifierSpec() {
459   
460    Classifier c = getClassifier();
461    if (c instanceof OptionHandler) {
462      return c.getClass().getName() + " "
463        + Utils.joinOptions(((OptionHandler)c).getOptions());
464    }
465    return c.getClass().getName();
466  }
467 
468  /**
469   * @return tip text for this property suitable for
470   * displaying in the explorer/experimenter gui
471   */
472  public String costMatrixTipText() {
473    return "Sets the cost matrix explicitly. This matrix is used if the "
474      + "costMatrixSource property is set to \"Supplied\".";
475  }
476
477  /**
478   * Gets the misclassification cost matrix.
479   *
480   * @return the cost matrix
481   */
482  public CostMatrix getCostMatrix() {
483   
484    return m_CostMatrix;
485  }
486 
487  /**
488   * Sets the misclassification cost matrix.
489   *
490   * @param newCostMatrix the cost matrix
491   */
492  public void setCostMatrix(CostMatrix newCostMatrix) {
493   
494    m_CostMatrix = newCostMatrix;
495    m_MatrixSource = MATRIX_SUPPLIED;
496  }
497
498  /**
499   * Returns default capabilities of the classifier.
500   *
501   * @return      the capabilities of this classifier
502   */
503  public Capabilities getCapabilities() {
504    Capabilities result = super.getCapabilities();
505
506    // class
507    result.disableAllClasses();
508    result.disableAllClassDependencies();
509    result.enable(Capability.NOMINAL_CLASS);
510   
511    return result;
512  }
513
514  /**
515   * Builds the model of the base learner.
516   *
517   * @param data the training data
518   * @throws Exception if the classifier could not be built successfully
519   */
520  public void buildClassifier(Instances data) throws Exception {
521
522    // can classifier handle the data?
523    getCapabilities().testWithFail(data);
524
525    // remove instances with missing class
526    data = new Instances(data);
527    data.deleteWithMissingClass();
528   
529    if (m_Classifier == null) {
530      throw new Exception("No base classifier has been set!");
531    }
532    if (m_MatrixSource == MATRIX_ON_DEMAND) {
533      String costName = data.relationName() + CostMatrix.FILE_EXTENSION;
534      File costFile = new File(getOnDemandDirectory(), costName);
535      if (!costFile.exists()) {
536        throw new Exception("On-demand cost file doesn't exist: " + costFile);
537      }
538      setCostMatrix(new CostMatrix(new BufferedReader(
539                                   new FileReader(costFile))));
540    } else if (m_CostMatrix == null) {
541      // try loading an old format cost file
542      m_CostMatrix = new CostMatrix(data.numClasses());
543      m_CostMatrix.readOldFormat(new BufferedReader(
544                               new FileReader(m_CostFile)));
545    }
546
547    if (!m_MinimizeExpectedCost) {
548      Random random = null;
549      if (!(m_Classifier instanceof WeightedInstancesHandler)) {
550        random = new Random(m_Seed);
551      }
552      data = m_CostMatrix.applyCostMatrix(data, random);     
553    }
554    m_Classifier.buildClassifier(data);
555  }
556
557  /**
558   * Returns class probabilities. When minimum expected cost approach is chosen,
559   * returns probability one for class with the minimum expected misclassification
560   * cost. Otherwise it returns the probability distribution returned by
561   * the base classifier.
562   *
563   * @param instance the instance to be classified
564   * @return the computed distribution for the given instance
565   * @throws Exception if instance could not be classified
566   * successfully */
567  public double[] distributionForInstance(Instance instance) throws Exception {
568
569    if (!m_MinimizeExpectedCost) {
570      return m_Classifier.distributionForInstance(instance);
571    }
572    double [] pred = m_Classifier.distributionForInstance(instance);
573    double [] costs = m_CostMatrix.expectedCosts(pred, instance);
574    /*
575    for (int i = 0; i < pred.length; i++) {
576      System.out.print(pred[i] + " ");
577    }
578    System.out.println();
579    for (int i = 0; i < costs.length; i++) {
580      System.out.print(costs[i] + " ");
581    }
582    System.out.println("\n");
583    */
584
585    // This is probably not ideal
586    int classIndex = Utils.minIndex(costs);
587    for (int i = 0; i  < pred.length; i++) {
588      if (i == classIndex) {
589        pred[i] = 1.0;
590      } else {
591        pred[i] = 0.0;
592      }
593    }
594    return pred; 
595  }
596
597  /**
598   *  Returns the type of graph this classifier
599   *  represents.
600   * 
601   *  @return the type of graph this classifier represents
602   */   
603  public int graphType() {
604   
605    if (m_Classifier instanceof Drawable)
606      return ((Drawable)m_Classifier).graphType();
607    else 
608      return Drawable.NOT_DRAWABLE;
609  }
610
611  /**
612   * Returns graph describing the classifier (if possible).
613   *
614   * @return the graph of the classifier in dotty format
615   * @throws Exception if the classifier cannot be graphed
616   */
617  public String graph() throws Exception {
618   
619    if (m_Classifier instanceof Drawable)
620      return ((Drawable)m_Classifier).graph();
621    else throw new Exception("Classifier: " + getClassifierSpec()
622                             + " cannot be graphed");
623  }
624
625  /**
626   * Output a representation of this classifier
627   *
628   * @return a string representation of the classifier
629   */
630  public String toString() {
631
632    if (m_Classifier == null) {
633      return "CostSensitiveClassifier: No model built yet.";
634    }
635
636    String result = "CostSensitiveClassifier using ";
637      if (m_MinimizeExpectedCost) {
638        result += "minimized expected misclasification cost\n";
639      } else {
640        result += "reweighted training instances\n";
641      }
642      result += "\n" + getClassifierSpec()
643        + "\n\nClassifier Model\n"
644        + m_Classifier.toString()
645        + "\n\nCost Matrix\n"
646        + m_CostMatrix.toString();
647
648    return result;
649  }
650 
651  /**
652   * Returns the revision string.
653   *
654   * @return            the revision
655   */
656  public String getRevision() {
657    return RevisionUtils.extract("$Revision: 5928 $");
658  }
659
660  /**
661   * Main method for testing this class.
662   *
663   * @param argv should contain the following arguments:
664   * -t training file [-T test file] [-c class index]
665   */
666  public static void main(String [] argv) {
667    runClassifier(new CostSensitiveClassifier(), argv);
668  }
669}
Note: See TracBrowser for help on using the repository browser.