| [29] | 1 | /* | 
|---|
|  | 2 | *    This program is free software; you can redistribute it and/or modify | 
|---|
|  | 3 | *    it under the terms of the GNU General Public License as published by | 
|---|
|  | 4 | *    the Free Software Foundation; either version 2 of the License, or | 
|---|
|  | 5 | *    (at your option) any later version. | 
|---|
|  | 6 | * | 
|---|
|  | 7 | *    This program is distributed in the hope that it will be useful, | 
|---|
|  | 8 | *    but WITHOUT ANY WARRANTY; without even the implied warranty of | 
|---|
|  | 9 | *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the | 
|---|
|  | 10 | *    GNU General Public License for more details. | 
|---|
|  | 11 | * | 
|---|
|  | 12 | *    You should have received a copy of the GNU General Public License | 
|---|
|  | 13 | *    along with this program; if not, write to the Free Software | 
|---|
|  | 14 | *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA. | 
|---|
|  | 15 | */ | 
|---|
|  | 16 |  | 
|---|
|  | 17 | /* | 
|---|
|  | 18 | *    NominalPrediction.java | 
|---|
|  | 19 | *    Copyright (C) 2002 University of Waikato, Hamilton, New Zealand | 
|---|
|  | 20 | * | 
|---|
|  | 21 | */ | 
|---|
|  | 22 |  | 
|---|
|  | 23 | package weka.classifiers.evaluation; | 
|---|
|  | 24 |  | 
|---|
|  | 25 | import weka.core.RevisionHandler; | 
|---|
|  | 26 | import weka.core.RevisionUtils; | 
|---|
|  | 27 |  | 
|---|
|  | 28 | import java.io.Serializable; | 
|---|
|  | 29 |  | 
|---|
|  | 30 | /** | 
|---|
|  | 31 | * Encapsulates an evaluatable nominal prediction: the predicted probability | 
|---|
|  | 32 | * distribution plus the actual class value. | 
|---|
|  | 33 | * | 
|---|
|  | 34 | * @author Len Trigg (len@reeltwo.com) | 
|---|
|  | 35 | * @version $Revision: 1.12 $ | 
|---|
|  | 36 | */ | 
|---|
|  | 37 | public class NominalPrediction | 
|---|
|  | 38 | implements Prediction, Serializable, RevisionHandler { | 
|---|
|  | 39 |  | 
|---|
|  | 40 | /** | 
|---|
|  | 41 | * Remove this if you change this class so that serialization would be | 
|---|
|  | 42 | * affected. | 
|---|
|  | 43 | */ | 
|---|
|  | 44 | static final long serialVersionUID = -8871333992740492788L; | 
|---|
|  | 45 |  | 
|---|
|  | 46 | /** The predicted probabilities */ | 
|---|
|  | 47 | private double [] m_Distribution; | 
|---|
|  | 48 |  | 
|---|
|  | 49 | /** The actual class value */ | 
|---|
|  | 50 | private double m_Actual = MISSING_VALUE; | 
|---|
|  | 51 |  | 
|---|
|  | 52 | /** The predicted class value */ | 
|---|
|  | 53 | private double m_Predicted = MISSING_VALUE; | 
|---|
|  | 54 |  | 
|---|
|  | 55 | /** The weight assigned to this prediction */ | 
|---|
|  | 56 | private double m_Weight = 1; | 
|---|
|  | 57 |  | 
|---|
|  | 58 | /** | 
|---|
|  | 59 | * Creates the NominalPrediction object with a default weight of 1.0. | 
|---|
|  | 60 | * | 
|---|
|  | 61 | * @param actual the actual value, or MISSING_VALUE. | 
|---|
|  | 62 | * @param distribution the predicted probability distribution. Use | 
|---|
|  | 63 | * NominalPrediction.makeDistribution() if you only know the predicted value. | 
|---|
|  | 64 | */ | 
|---|
|  | 65 | public NominalPrediction(double actual, double [] distribution) { | 
|---|
|  | 66 |  | 
|---|
|  | 67 | this(actual, distribution, 1); | 
|---|
|  | 68 | } | 
|---|
|  | 69 |  | 
|---|
|  | 70 | /** | 
|---|
|  | 71 | * Creates the NominalPrediction object. | 
|---|
|  | 72 | * | 
|---|
|  | 73 | * @param actual the actual value, or MISSING_VALUE. | 
|---|
|  | 74 | * @param distribution the predicted probability distribution. Use | 
|---|
|  | 75 | * NominalPrediction.makeDistribution() if you only know the predicted value. | 
|---|
|  | 76 | * @param weight the weight assigned to the prediction. | 
|---|
|  | 77 | */ | 
|---|
|  | 78 | public NominalPrediction(double actual, double [] distribution, | 
|---|
|  | 79 | double weight) { | 
|---|
|  | 80 |  | 
|---|
|  | 81 | if (distribution == null) { | 
|---|
|  | 82 | throw new NullPointerException("Null distribution in NominalPrediction."); | 
|---|
|  | 83 | } | 
|---|
|  | 84 | m_Actual = actual; | 
|---|
|  | 85 | m_Distribution = distribution.clone(); | 
|---|
|  | 86 | m_Weight = weight; | 
|---|
|  | 87 | updatePredicted(); | 
|---|
|  | 88 | } | 
|---|
|  | 89 |  | 
|---|
|  | 90 | /** | 
|---|
|  | 91 | * Gets the predicted probabilities | 
|---|
|  | 92 | * | 
|---|
|  | 93 | * @return the predicted probabilities | 
|---|
|  | 94 | */ | 
|---|
|  | 95 | public double [] distribution() { | 
|---|
|  | 96 |  | 
|---|
|  | 97 | return m_Distribution; | 
|---|
|  | 98 | } | 
|---|
|  | 99 |  | 
|---|
|  | 100 | /** | 
|---|
|  | 101 | * Gets the actual class value. | 
|---|
|  | 102 | * | 
|---|
|  | 103 | * @return the actual class value, or MISSING_VALUE if no | 
|---|
|  | 104 | * prediction was made. | 
|---|
|  | 105 | */ | 
|---|
|  | 106 | public double actual() { | 
|---|
|  | 107 |  | 
|---|
|  | 108 | return m_Actual; | 
|---|
|  | 109 | } | 
|---|
|  | 110 |  | 
|---|
|  | 111 | /** | 
|---|
|  | 112 | * Gets the predicted class value. | 
|---|
|  | 113 | * | 
|---|
|  | 114 | * @return the predicted class value, or MISSING_VALUE if no | 
|---|
|  | 115 | * prediction was made. | 
|---|
|  | 116 | */ | 
|---|
|  | 117 | public double predicted() { | 
|---|
|  | 118 |  | 
|---|
|  | 119 | return m_Predicted; | 
|---|
|  | 120 | } | 
|---|
|  | 121 |  | 
|---|
|  | 122 | /** | 
|---|
|  | 123 | * Gets the weight assigned to this prediction. This is typically the weight | 
|---|
|  | 124 | * of the test instance the prediction was made for. | 
|---|
|  | 125 | * | 
|---|
|  | 126 | * @return the weight assigned to this prediction. | 
|---|
|  | 127 | */ | 
|---|
|  | 128 | public double weight() { | 
|---|
|  | 129 |  | 
|---|
|  | 130 | return m_Weight; | 
|---|
|  | 131 | } | 
|---|
|  | 132 |  | 
|---|
|  | 133 | /** | 
|---|
|  | 134 | * Calculates the prediction margin. This is defined as the difference | 
|---|
|  | 135 | * between the probability predicted for the actual class and the highest | 
|---|
|  | 136 | * predicted probability of the other classes. | 
|---|
|  | 137 | * | 
|---|
|  | 138 | * @return the margin for this prediction, or | 
|---|
|  | 139 | * MISSING_VALUE if either the actual or predicted value | 
|---|
|  | 140 | * is missing. | 
|---|
|  | 141 | */ | 
|---|
|  | 142 | public double margin() { | 
|---|
|  | 143 |  | 
|---|
|  | 144 | if ((m_Actual == MISSING_VALUE) || | 
|---|
|  | 145 | (m_Predicted == MISSING_VALUE)) { | 
|---|
|  | 146 | return MISSING_VALUE; | 
|---|
|  | 147 | } | 
|---|
|  | 148 | double probActual = m_Distribution[(int)m_Actual]; | 
|---|
|  | 149 | double probNext = 0; | 
|---|
|  | 150 | for(int i = 0; i < m_Distribution.length; i++) | 
|---|
|  | 151 | if ((i != m_Actual) && | 
|---|
|  | 152 | (m_Distribution[i] > probNext)) | 
|---|
|  | 153 | probNext = m_Distribution[i]; | 
|---|
|  | 154 |  | 
|---|
|  | 155 | return probActual - probNext; | 
|---|
|  | 156 | } | 
|---|
|  | 157 |  | 
|---|
|  | 158 | /** | 
|---|
|  | 159 | * Convert a single prediction into a probability distribution | 
|---|
|  | 160 | * with all zero probabilities except the predicted value which | 
|---|
|  | 161 | * has probability 1.0. If no prediction was made, all probabilities | 
|---|
|  | 162 | * are zero. | 
|---|
|  | 163 | * | 
|---|
|  | 164 | * @param predictedClass the index of the predicted class, or | 
|---|
|  | 165 | * MISSING_VALUE if no prediction was made. | 
|---|
|  | 166 | * @param numClasses the number of possible classes for this nominal | 
|---|
|  | 167 | * prediction. | 
|---|
|  | 168 | * @return the probability distribution. | 
|---|
|  | 169 | */ | 
|---|
|  | 170 | public static double [] makeDistribution(double predictedClass, | 
|---|
|  | 171 | int numClasses) { | 
|---|
|  | 172 |  | 
|---|
|  | 173 | double [] dist = new double [numClasses]; | 
|---|
|  | 174 | if (predictedClass == MISSING_VALUE) { | 
|---|
|  | 175 | return dist; | 
|---|
|  | 176 | } | 
|---|
|  | 177 | dist[(int)predictedClass] = 1.0; | 
|---|
|  | 178 | return dist; | 
|---|
|  | 179 | } | 
|---|
|  | 180 |  | 
|---|
|  | 181 | /** | 
|---|
|  | 182 | * Creates a uniform probability distribution -- where each of the | 
|---|
|  | 183 | * possible classes is assigned equal probability. | 
|---|
|  | 184 | * | 
|---|
|  | 185 | * @param numClasses the number of possible classes for this nominal | 
|---|
|  | 186 | * prediction. | 
|---|
|  | 187 | * @return the probability distribution. | 
|---|
|  | 188 | */ | 
|---|
|  | 189 | public static double [] makeUniformDistribution(int numClasses) { | 
|---|
|  | 190 |  | 
|---|
|  | 191 | double [] dist = new double [numClasses]; | 
|---|
|  | 192 | for (int i = 0; i < numClasses; i++) { | 
|---|
|  | 193 | dist[i] = 1.0 / numClasses; | 
|---|
|  | 194 | } | 
|---|
|  | 195 | return dist; | 
|---|
|  | 196 | } | 
|---|
|  | 197 |  | 
|---|
|  | 198 | /** | 
|---|
|  | 199 | * Determines the predicted class (doesn't detect multiple | 
|---|
|  | 200 | * classifications). If no prediction was made (i.e. all zero | 
|---|
|  | 201 | * probababilities in the distribution), m_Prediction is set to | 
|---|
|  | 202 | * MISSING_VALUE. | 
|---|
|  | 203 | */ | 
|---|
|  | 204 | private void updatePredicted() { | 
|---|
|  | 205 |  | 
|---|
|  | 206 | int predictedClass = -1; | 
|---|
|  | 207 | double bestProb = 0.0; | 
|---|
|  | 208 | for(int i = 0; i < m_Distribution.length; i++) { | 
|---|
|  | 209 | if (m_Distribution[i] > bestProb) { | 
|---|
|  | 210 | predictedClass = i; | 
|---|
|  | 211 | bestProb = m_Distribution[i]; | 
|---|
|  | 212 | } | 
|---|
|  | 213 | } | 
|---|
|  | 214 |  | 
|---|
|  | 215 | if (predictedClass != -1) { | 
|---|
|  | 216 | m_Predicted = predictedClass; | 
|---|
|  | 217 | } else { | 
|---|
|  | 218 | m_Predicted = MISSING_VALUE; | 
|---|
|  | 219 | } | 
|---|
|  | 220 | } | 
|---|
|  | 221 |  | 
|---|
|  | 222 | /** | 
|---|
|  | 223 | * Gets a human readable representation of this prediction. | 
|---|
|  | 224 | * | 
|---|
|  | 225 | * @return a human readable representation of this prediction. | 
|---|
|  | 226 | */ | 
|---|
|  | 227 | public String toString() { | 
|---|
|  | 228 |  | 
|---|
|  | 229 | StringBuffer sb = new StringBuffer(); | 
|---|
|  | 230 | sb.append("NOM: ").append(actual()).append(" ").append(predicted()); | 
|---|
|  | 231 | sb.append(' ').append(weight()); | 
|---|
|  | 232 | double [] dist = distribution(); | 
|---|
|  | 233 | for (int i = 0; i < dist.length; i++) { | 
|---|
|  | 234 | sb.append(' ').append(dist[i]); | 
|---|
|  | 235 | } | 
|---|
|  | 236 | return sb.toString(); | 
|---|
|  | 237 | } | 
|---|
|  | 238 |  | 
|---|
|  | 239 | /** | 
|---|
|  | 240 | * Returns the revision string. | 
|---|
|  | 241 | * | 
|---|
|  | 242 | * @return            the revision | 
|---|
|  | 243 | */ | 
|---|
|  | 244 | public String getRevision() { | 
|---|
|  | 245 | return RevisionUtils.extract("$Revision: 1.12 $"); | 
|---|
|  | 246 | } | 
|---|
|  | 247 | } | 
|---|
|  | 248 |  | 
|---|