source: src/main/java/weka/classifiers/evaluation/ConfusionMatrix.java @ 19

Last change on this file since 19 was 4, checked in by gnappo, 14 years ago

Import di weka.

File size: 9.3 KB
Line 
1/*
2 *    This program is free software; you can redistribute it and/or modify
3 *    it under the terms of the GNU General Public License as published by
4 *    the Free Software Foundation; either version 2 of the License, or
5 *    (at your option) any later version.
6 *
7 *    This program is distributed in the hope that it will be useful,
8 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
9 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
10 *    GNU General Public License for more details.
11 *
12 *    You should have received a copy of the GNU General Public License
13 *    along with this program; if not, write to the Free Software
14 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
15 */
16
/*
 *    ConfusionMatrix.java
 *    Copyright (C) 2002 University of Waikato, Hamilton, New Zealand
 *
 */
22
23package weka.classifiers.evaluation;
24
25import weka.classifiers.CostMatrix;
26import weka.core.FastVector;
27import weka.core.Matrix;
28import weka.core.RevisionUtils;
29import weka.core.Utils;
30
31/**
32 * Cells of this matrix correspond to counts of the number (or weight)
33 * of predictions for each actual value / predicted value combination.
34 *
35 * @author Len Trigg (len@reeltwo.com)
36 * @version $Revision: 1.9 $
37 */
38public class ConfusionMatrix extends Matrix {
39
40  /** for serialization */
41  private static final long serialVersionUID = -181789981401504090L;
42
43  /** Stores the names of the classes */
44  protected String [] m_ClassNames;
45
46  /**
47   * Creates the confusion matrix with the given class names.
48   *
49   * @param classNames an array containing the names the classes.
50   */
51  public ConfusionMatrix(String [] classNames) {
52
53    super(classNames.length, classNames.length);
54    m_ClassNames = (String [])classNames.clone();
55  }
56
57  /**
58   * Makes a copy of this ConfusionMatrix after applying the
59   * supplied CostMatrix to the cells. The resulting ConfusionMatrix
60   * can be used to get cost-weighted statistics.
61   *
62   * @param costs the CostMatrix.
63   * @return a ConfusionMatrix that has had costs applied.
64   * @exception Exception if the CostMatrix is not of the same size
65   * as this ConfusionMatrix.
66   */
67  public ConfusionMatrix makeWeighted(CostMatrix costs) throws Exception {
68
69    if (costs.size() != size()) {
70      throw new Exception("Cost and confusion matrices must be the same size");
71    }
72    ConfusionMatrix weighted = new ConfusionMatrix(m_ClassNames);
73    for (int row = 0; row < size(); row++) {
74      for (int col = 0; col < size(); col++) {
75        weighted.setElement(row, col, getElement(row, col) * 
76                            costs.getElement(row, col));
77      }
78    }
79    return weighted;
80  }
81
82
83  /**
84   * Creates and returns a clone of this object.
85   *
86   * @return a clone of this instance.
87   */
88  public Object clone() {
89
90    ConfusionMatrix m = (ConfusionMatrix)super.clone();
91    m.m_ClassNames = (String [])m_ClassNames.clone();
92    return m;
93  }
94
95  /**
96   * Gets the number of classes.
97   *
98   * @return the number of classes
99   */
100  public int size() {
101
102    return m_ClassNames.length;
103  }
104
105  /**
106   * Gets the name of one of the classes.
107   *
108   * @param index the index of the class.
109   * @return the class name.
110   */
111  public String className(int index) {
112
113    return m_ClassNames[index];
114  }
115
116  /**
117   * Includes a prediction in the confusion matrix.
118   *
119   * @param pred the NominalPrediction to include
120   * @exception Exception if no valid prediction was made (i.e.
121   * unclassified).
122   */
123  public void addPrediction(NominalPrediction pred) throws Exception {
124
125    if (pred.predicted() == NominalPrediction.MISSING_VALUE) {
126      throw new Exception("No predicted value given.");
127    }
128    if (pred.actual() == NominalPrediction.MISSING_VALUE) {
129      throw new Exception("No actual value given.");
130    }
131    addElement((int)pred.actual(), (int)pred.predicted(), pred.weight());
132  }
133
134  /**
135   * Includes a whole bunch of predictions in the confusion matrix.
136   *
137   * @param predictions a FastVector containing the NominalPredictions
138   * to include
139   * @exception Exception if no valid prediction was made (i.e.
140   * unclassified).
141   */
142  public void addPredictions(FastVector predictions) throws Exception {
143
144    for (int i = 0; i < predictions.size(); i++) {
145      addPrediction((NominalPrediction)predictions.elementAt(i));
146    }
147  }
148
149 
150  /**
151   * Gets the performance with respect to one of the classes
152   * as a TwoClassStats object.
153   *
154   * @param classIndex the index of the class of interest.
155   * @return the generated TwoClassStats object.
156   */
157  public TwoClassStats getTwoClassStats(int classIndex) {
158
159    double fp = 0, tp = 0, fn = 0, tn = 0;
160    for (int row = 0; row < size(); row++) {
161      for (int col = 0; col < size(); col++) {
162        if (row == classIndex) {
163          if (col == classIndex) {
164            tp += getElement(row, col);
165          } else {
166            fn += getElement(row, col);
167          }         
168        } else {
169          if (col == classIndex) {
170            fp += getElement(row, col);
171          } else {
172            tn += getElement(row, col);
173          }         
174        }
175      }
176    }
177    return new TwoClassStats(tp, fp, tn, fn);
178  }
179
180  /**
181   * Gets the number of correct classifications (that is, for which a
182   * correct prediction was made). (Actually the sum of the weights of
183   * these classifications)
184   *
185   * @return the number of correct classifications
186   */
187  public double correct() {
188
189    double correct = 0;
190    for (int i = 0; i < size(); i++) {
191      correct += getElement(i, i);
192    }
193    return correct;
194  }
195
196  /**
197   * Gets the number of incorrect classifications (that is, for which an
198   * incorrect prediction was made). (Actually the sum of the weights of
199   * these classifications)
200   *
201   * @return the number of incorrect classifications
202   */
203  public double incorrect() {
204
205    double incorrect = 0;
206    for (int row = 0; row < size(); row++) {
207      for (int col = 0; col < size(); col++) {
208        if (row != col) {
209          incorrect += getElement(row, col);
210        }
211      }
212    }
213    return incorrect;
214  }
215
216  /**
217   * Gets the number of predictions that were made
218   * (actually the sum of the weights of predictions where the
219   * class value was known).
220   *
221   * @return the number of predictions with known class
222   */
223  public double total() {
224
225    double total = 0;
226    for (int row = 0; row < size(); row++) {
227      for (int col = 0; col < size(); col++) {
228        total += getElement(row, col);
229      }
230    }
231    return total;
232  }
233
234  /**
235   * Returns the estimated error rate.
236   *
237   * @return the estimated error rate (between 0 and 1).
238   */
239  public double errorRate() {
240
241    return incorrect() / total();
242  }
243
244  /**
245   * Calls toString() with a default title.
246   *
247   * @return the confusion matrix as a string
248   */
249  public String toString() {
250
251    return toString("=== Confusion Matrix ===\n");
252  }
253
254  /**
255   * Outputs the performance statistics as a classification confusion
256   * matrix. For each class value, shows the distribution of
257   * predicted class values.
258   *
259   * @param title the title for the confusion matrix
260   * @return the confusion matrix as a String
261   */
262  public String toString(String title) {
263
264    StringBuffer text = new StringBuffer();
265    char [] IDChars = {'a','b','c','d','e','f','g','h','i','j',
266                       'k','l','m','n','o','p','q','r','s','t',
267                       'u','v','w','x','y','z'};
268    int IDWidth;
269    boolean fractional = false;
270
271    // Find the maximum value in the matrix
272    // and check for fractional display requirement
273    double maxval = 0;
274    for (int i = 0; i < size(); i++) {
275      for (int j = 0; j < size(); j++) {
276        double current = getElement(i, j);
277        if (current < 0) {
278          current *= -10;
279        }
280        if (current > maxval) {
281          maxval = current;
282        }
283        double fract = current - Math.rint(current);
284        if (!fractional
285            && ((Math.log(fract) / Math.log(10)) >= -2)) {
286          fractional = true;
287        }
288      }
289    }
290
291    IDWidth = 1 + Math.max((int)(Math.log(maxval) / Math.log(10) 
292                                 + (fractional ? 3 : 0)),
293                             (int)(Math.log(size()) / 
294                                   Math.log(IDChars.length)));
295    text.append(title).append("\n");
296    for (int i = 0; i < size(); i++) {
297      if (fractional) {
298        text.append(" ").append(num2ShortID(i,IDChars,IDWidth - 3))
299          .append("   ");
300      } else {
301        text.append(" ").append(num2ShortID(i,IDChars,IDWidth));
302      }
303    }
304    text.append("     actual class\n");
305    for (int i = 0; i< size(); i++) { 
306      for (int j = 0; j < size(); j++) {
307        text.append(" ").append(
308                    Utils.doubleToString(getElement(i, j),
309                                         IDWidth,
310                                         (fractional ? 2 : 0)));
311      }
312      text.append(" | ").append(num2ShortID(i,IDChars,IDWidth))
313        .append(" = ").append(m_ClassNames[i]).append("\n");
314    }
315    return text.toString();
316  }
317
318  /**
319   * Method for generating indices for the confusion matrix.
320   *
321   * @param num integer to format
322   * @return the formatted integer as a string
323   */
324  private static String num2ShortID(int num, char [] IDChars, int IDWidth) {
325   
326    char ID [] = new char [IDWidth];
327    int i;
328   
329    for(i = IDWidth - 1; i >=0; i--) {
330      ID[i] = IDChars[num % IDChars.length];
331      num = num / IDChars.length - 1;
332      if (num < 0) {
333        break;
334      }
335    }
336    for(i--; i >= 0; i--) {
337      ID[i] = ' ';
338    }
339
340    return new String(ID);
341  }
342 
343  /**
344   * Returns the revision string.
345   *
346   * @return            the revision
347   */
348  public String getRevision() {
349    return RevisionUtils.extract("$Revision: 1.9 $");
350  }
351}
Note: See TracBrowser for help on using the repository browser.