[4] | 1 | /* |
---|
| 2 | * This program is free software; you can redistribute it and/or modify |
---|
| 3 | * it under the terms of the GNU General Public License as published by |
---|
| 4 | * the Free Software Foundation; either version 2 of the License, or |
---|
| 5 | * (at your option) any later version. |
---|
| 6 | * |
---|
| 7 | * This program is distributed in the hope that it will be useful, |
---|
| 8 | * but WITHOUT ANY WARRANTY; without even the implied warranty of |
---|
| 9 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the |
---|
| 10 | * GNU General Public License for more details. |
---|
| 11 | * |
---|
| 12 | * You should have received a copy of the GNU General Public License |
---|
| 13 | * along with this program; if not, write to the Free Software |
---|
| 14 | * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA. |
---|
| 15 | */ |
---|
| 16 | |
---|
| 17 | /* |
---|
| 18 | * CostMatrix.java |
---|
| 19 | * Copyright (C) 2006 University of Waikato, Hamilton, New Zealand |
---|
| 20 | * |
---|
| 21 | */ |
---|
| 22 | |
---|
| 23 | package weka.classifiers; |
---|
| 24 | |
---|
| 25 | import weka.core.AttributeExpression; |
---|
| 26 | import weka.core.Instance; |
---|
| 27 | import weka.core.Instances; |
---|
| 28 | import weka.core.Matrix; |
---|
| 29 | import weka.core.RevisionHandler; |
---|
| 30 | import weka.core.RevisionUtils; |
---|
| 31 | import weka.core.Utils; |
---|
| 32 | |
---|
| 33 | import java.io.LineNumberReader; |
---|
| 34 | import java.io.Reader; |
---|
| 35 | import java.io.Serializable; |
---|
| 36 | import java.io.StreamTokenizer; |
---|
| 37 | import java.io.Writer; |
---|
| 38 | import java.util.Random; |
---|
| 39 | import java.util.StringTokenizer; |
---|
| 40 | |
---|
| 41 | /** |
---|
| 42 | * Class for storing and manipulating a misclassification cost matrix. |
---|
| 43 | * The element at position i,j in the matrix is the penalty for classifying |
---|
| 44 | * an instance of class j as class i. Cost values can be fixed or |
---|
| 45 | * computed on a per-instance basis (cost sensitive evaluation only) |
---|
| 46 | * from the value of an attribute or an expression involving |
---|
| 47 | * attribute(s). |
---|
| 48 | * |
---|
| 49 | * @author Mark Hall |
---|
| 50 | * @author Richard Kirkby (rkirkby@cs.waikato.ac.nz) |
---|
| 51 | * @version $Revision: 6041 $ |
---|
| 52 | */ |
---|
| 53 | public class CostMatrix implements Serializable, RevisionHandler { |
---|
| 54 | |
---|
| 55 | /** for serialization */ |
---|
| 56 | private static final long serialVersionUID = -1973792250544554965L; |
---|
| 57 | |
---|
| 58 | private int m_size; |
---|
| 59 | |
---|
| 60 | /** [rows][columns] */ |
---|
| 61 | protected Object [][] m_matrix; |
---|
| 62 | |
---|
| 63 | /** The default file extension for cost matrix files */ |
---|
| 64 | public static String FILE_EXTENSION = ".cost"; |
---|
| 65 | |
---|
/**
 * Builds a square cost matrix for the given number of classes and fills it
 * with the default costs: 0 on the diagonal and 1 everywhere else.
 *
 * @param numOfClasses the number of classes, i.e. the number of rows and
 * columns of the matrix.
 */
public CostMatrix(int numOfClasses) {
  this.m_size = numOfClasses;
  this.initialize();
}
---|
| 76 | |
---|
/**
 * Builds a cost matrix with the same size and the same cell contents as an
 * existing one. The cell objects themselves are shared with the source
 * matrix, they are not cloned.
 *
 * @param toCopy the matrix to copy.
 */
public CostMatrix(CostMatrix toCopy) {
  this(toCopy.size());

  for (int row = 0; row < m_size; row++) {
    for (int col = 0; col < m_size; col++) {
      Object cell = toCopy.getCell(row, col);
      setCell(row, col, cell);
    }
  }
}
---|
| 91 | |
---|
| 92 | /** |
---|
| 93 | * Initializes the matrix |
---|
| 94 | */ |
---|
| 95 | public void initialize() { |
---|
| 96 | m_matrix = new Object[m_size][m_size]; |
---|
| 97 | for (int i = 0; i < m_size; i++) { |
---|
| 98 | for (int j = 0; j < m_size; j++) { |
---|
| 99 | setCell(i, j, i == j ? new Double(0.0) : new Double(1.0)); |
---|
| 100 | } |
---|
| 101 | } |
---|
| 102 | } |
---|
| 103 | |
---|
| 104 | /** |
---|
| 105 | * The number of rows (and columns) |
---|
| 106 | * @return the size of the matrix |
---|
| 107 | */ |
---|
| 108 | public int size() { |
---|
| 109 | return m_size; |
---|
| 110 | } |
---|
| 111 | |
---|
| 112 | /** |
---|
| 113 | * Same as size |
---|
| 114 | * @return the number of columns |
---|
| 115 | */ |
---|
| 116 | public int numColumns() { |
---|
| 117 | return size(); |
---|
| 118 | } |
---|
| 119 | |
---|
| 120 | /** |
---|
| 121 | * Same as size |
---|
| 122 | * @return the number of rows |
---|
| 123 | */ |
---|
| 124 | public int numRows() { |
---|
| 125 | return size(); |
---|
| 126 | } |
---|
| 127 | |
---|
| 128 | private boolean replaceStrings() throws Exception { |
---|
| 129 | boolean nonDouble = false; |
---|
| 130 | |
---|
| 131 | for (int i = 0; i < m_size; i++) { |
---|
| 132 | for (int j = 0; j < m_size; j++) { |
---|
| 133 | if (getCell(i, j) instanceof String) { |
---|
| 134 | AttributeExpression temp = new AttributeExpression(); |
---|
| 135 | temp.convertInfixToPostfix((String)getCell(i, j)); |
---|
| 136 | setCell(i, j, temp); |
---|
| 137 | nonDouble = true; |
---|
| 138 | } else if (getCell(i, j) instanceof AttributeExpression) { |
---|
| 139 | nonDouble = true; |
---|
| 140 | } |
---|
| 141 | } |
---|
| 142 | } |
---|
| 143 | |
---|
| 144 | return nonDouble; |
---|
| 145 | } |
---|
| 146 | |
---|
| 147 | /** |
---|
| 148 | * Applies the cost matrix to a set of instances. If a random number generator is |
---|
| 149 | * supplied the instances will be resampled, otherwise they will be rewighted. |
---|
| 150 | * Adapted from code once sitting in Instances.java |
---|
| 151 | * |
---|
| 152 | * @param data the instances to reweight. |
---|
| 153 | * @param random a random number generator for resampling, if null then instances are |
---|
| 154 | * rewighted. |
---|
| 155 | * @return a new dataset reflecting the cost of misclassification. |
---|
| 156 | * @exception Exception if the data has no class or the matrix in inappropriate. |
---|
| 157 | */ |
---|
| 158 | public Instances applyCostMatrix(Instances data, Random random) |
---|
| 159 | throws Exception { |
---|
| 160 | |
---|
| 161 | double sumOfWeightFactors = 0, sumOfMissClassWeights, |
---|
| 162 | sumOfWeights; |
---|
| 163 | double [] weightOfInstancesInClass, weightFactor, weightOfInstances; |
---|
| 164 | Instances newData; |
---|
| 165 | |
---|
| 166 | if (data.classIndex() < 0) { |
---|
| 167 | throw new Exception("Class index is not set!"); |
---|
| 168 | } |
---|
| 169 | |
---|
| 170 | if (size() != data.numClasses()) { |
---|
| 171 | throw new Exception("Misclassification cost matrix has wrong format!"); |
---|
| 172 | } |
---|
| 173 | |
---|
| 174 | // are there any non-fixed, per-instance costs defined in the matrix? |
---|
| 175 | if (replaceStrings()) { |
---|
| 176 | // could reweight in the two class case |
---|
| 177 | if (data.classAttribute().numValues() > 2) { |
---|
| 178 | throw new Exception("Can't resample/reweight instances using " |
---|
| 179 | +"non-fixed cost values when there are more " |
---|
| 180 | +"than two classes!"); |
---|
| 181 | } else { |
---|
| 182 | // Store new weights |
---|
| 183 | weightOfInstances = new double[data.numInstances()]; |
---|
| 184 | for (int i = 0; i < data.numInstances(); i++) { |
---|
| 185 | Instance inst = data.instance(i); |
---|
| 186 | int classValIndex = (int)inst.classValue(); |
---|
| 187 | double factor = 1.0; |
---|
| 188 | Object element = (classValIndex == 0) |
---|
| 189 | ? getCell(classValIndex, 1) |
---|
| 190 | : getCell(classValIndex, 0); |
---|
| 191 | if (element instanceof Double) { |
---|
| 192 | factor = ((Double)element).doubleValue(); |
---|
| 193 | } else { |
---|
| 194 | factor = ((AttributeExpression)element).evaluateExpression(inst); |
---|
| 195 | } |
---|
| 196 | weightOfInstances[i] = inst.weight() * factor; |
---|
| 197 | /* System.err.println("Multiplying " + inst.classAttribute().value((int)inst.classValue()) |
---|
| 198 | +" by factor " + factor); */ |
---|
| 199 | } |
---|
| 200 | |
---|
| 201 | // Change instances weight or do resampling |
---|
| 202 | if (random != null) { |
---|
| 203 | return data.resampleWithWeights(random, weightOfInstances); |
---|
| 204 | } else { |
---|
| 205 | Instances instances = new Instances(data); |
---|
| 206 | for (int i = 0; i < data.numInstances(); i++) { |
---|
| 207 | instances.instance(i).setWeight(weightOfInstances[i]); |
---|
| 208 | } |
---|
| 209 | return instances; |
---|
| 210 | } |
---|
| 211 | } |
---|
| 212 | } |
---|
| 213 | |
---|
| 214 | weightFactor = new double[data.numClasses()]; |
---|
| 215 | weightOfInstancesInClass = new double[data.numClasses()]; |
---|
| 216 | for (int j = 0; j < data.numInstances(); j++) { |
---|
| 217 | weightOfInstancesInClass[(int)data.instance(j).classValue()] += |
---|
| 218 | data.instance(j).weight(); |
---|
| 219 | } |
---|
| 220 | sumOfWeights = Utils.sum(weightOfInstancesInClass); |
---|
| 221 | |
---|
| 222 | // normalize the matrix if not already |
---|
| 223 | for (int i=0; i< m_size; i++) { |
---|
| 224 | if (!Utils.eq(((Double)getCell(i, i)).doubleValue(), 0)) { |
---|
| 225 | CostMatrix normMatrix = new CostMatrix(this); |
---|
| 226 | normMatrix.normalize(); |
---|
| 227 | return normMatrix.applyCostMatrix(data, random); |
---|
| 228 | } |
---|
| 229 | } |
---|
| 230 | |
---|
| 231 | for (int i = 0; i < data.numClasses(); i++) { |
---|
| 232 | // Using Kai Ming Ting's formula for deriving weights for |
---|
| 233 | // the classes and Breiman's heuristic for multiclass |
---|
| 234 | // problems. |
---|
| 235 | |
---|
| 236 | sumOfMissClassWeights = 0; |
---|
| 237 | for (int j = 0; j < data.numClasses(); j++) { |
---|
| 238 | if (Utils.sm(((Double)getCell(i,j)).doubleValue(),0)) { |
---|
| 239 | throw new Exception("Neg. weights in misclassification "+ |
---|
| 240 | "cost matrix!"); |
---|
| 241 | } |
---|
| 242 | sumOfMissClassWeights |
---|
| 243 | += ((Double)getCell(i,j)).doubleValue(); |
---|
| 244 | } |
---|
| 245 | weightFactor[i] = sumOfMissClassWeights * sumOfWeights; |
---|
| 246 | sumOfWeightFactors += sumOfMissClassWeights * |
---|
| 247 | weightOfInstancesInClass[i]; |
---|
| 248 | } |
---|
| 249 | for (int i = 0; i < data.numClasses(); i++) { |
---|
| 250 | weightFactor[i] /= sumOfWeightFactors; |
---|
| 251 | } |
---|
| 252 | |
---|
| 253 | // Store new weights |
---|
| 254 | weightOfInstances = new double[data.numInstances()]; |
---|
| 255 | for (int i = 0; i < data.numInstances(); i++) { |
---|
| 256 | weightOfInstances[i] = data.instance(i).weight()* |
---|
| 257 | weightFactor[(int)data.instance(i).classValue()]; |
---|
| 258 | } |
---|
| 259 | |
---|
| 260 | // Change instances weight or do resampling |
---|
| 261 | if (random != null) { |
---|
| 262 | return data.resampleWithWeights(random, weightOfInstances); |
---|
| 263 | } else { |
---|
| 264 | Instances instances = new Instances(data); |
---|
| 265 | for (int i = 0; i < data.numInstances(); i++) { |
---|
| 266 | instances.instance(i).setWeight(weightOfInstances[i]); |
---|
| 267 | } |
---|
| 268 | return instances; |
---|
| 269 | } |
---|
| 270 | } |
---|
| 271 | |
---|
| 272 | /** |
---|
| 273 | * Calculates the expected misclassification cost for each possible class value, |
---|
| 274 | * given class probability estimates. |
---|
| 275 | * |
---|
| 276 | * @param classProbs the class probability estimates. |
---|
| 277 | * @return the expected costs. |
---|
| 278 | * @exception Exception if the wrong number of class probabilities is supplied. |
---|
| 279 | */ |
---|
| 280 | public double[] expectedCosts(double[] classProbs) throws Exception { |
---|
| 281 | |
---|
| 282 | if (classProbs.length != m_size) { |
---|
| 283 | throw new Exception("Length of probability estimates don't " |
---|
| 284 | +"match cost matrix"); |
---|
| 285 | } |
---|
| 286 | |
---|
| 287 | double[] costs = new double[m_size]; |
---|
| 288 | |
---|
| 289 | for (int x = 0; x < m_size; x++) { |
---|
| 290 | for (int y = 0; y < m_size; y++) { |
---|
| 291 | Object element = getCell(y, x); |
---|
| 292 | if (!(element instanceof Double)) { |
---|
| 293 | throw new Exception("Can't use non-fixed costs in " |
---|
| 294 | +"computing expected costs."); |
---|
| 295 | } |
---|
| 296 | costs[x] += classProbs[y] * ((Double)element).doubleValue(); |
---|
| 297 | } |
---|
| 298 | } |
---|
| 299 | |
---|
| 300 | return costs; |
---|
| 301 | } |
---|
| 302 | |
---|
| 303 | /** |
---|
| 304 | * Calculates the expected misclassification cost for each possible class value, |
---|
| 305 | * given class probability estimates. |
---|
| 306 | * |
---|
| 307 | * @param classProbs the class probability estimates. |
---|
| 308 | * @param inst the current instance for which the class probabilites |
---|
| 309 | * apply. Is used for computing any non-fixed cost values. |
---|
| 310 | * @return the expected costs. |
---|
| 311 | * @exception Exception if something goes wrong |
---|
| 312 | */ |
---|
| 313 | public double[] expectedCosts(double [] classProbs, |
---|
| 314 | Instance inst) throws Exception { |
---|
| 315 | |
---|
| 316 | if (classProbs.length != m_size) { |
---|
| 317 | throw new Exception("Length of probability estimates don't " |
---|
| 318 | +"match cost matrix"); |
---|
| 319 | } |
---|
| 320 | |
---|
| 321 | if (!replaceStrings()) { |
---|
| 322 | return expectedCosts(classProbs); |
---|
| 323 | } |
---|
| 324 | |
---|
| 325 | double[] costs = new double[m_size]; |
---|
| 326 | |
---|
| 327 | for (int x = 0; x < m_size; x++) { |
---|
| 328 | for (int y = 0; y < m_size; y++) { |
---|
| 329 | Object element = getCell(y, x); |
---|
| 330 | double costVal; |
---|
| 331 | if (!(element instanceof Double)) { |
---|
| 332 | costVal = |
---|
| 333 | ((AttributeExpression)element).evaluateExpression(inst); |
---|
| 334 | } else { |
---|
| 335 | costVal = ((Double)element).doubleValue(); |
---|
| 336 | } |
---|
| 337 | costs[x] += classProbs[y] * costVal; |
---|
| 338 | } |
---|
| 339 | } |
---|
| 340 | |
---|
| 341 | return costs; |
---|
| 342 | } |
---|
| 343 | |
---|
| 344 | /** |
---|
| 345 | * Gets the maximum cost for a particular class value. |
---|
| 346 | * |
---|
| 347 | * @param classVal the class value. |
---|
| 348 | * @return the maximum cost. |
---|
| 349 | * @exception Exception if cost matrix contains non-fixed |
---|
| 350 | * costs |
---|
| 351 | */ |
---|
| 352 | public double getMaxCost(int classVal) throws Exception { |
---|
| 353 | |
---|
| 354 | double maxCost = Double.NEGATIVE_INFINITY; |
---|
| 355 | |
---|
| 356 | for (int i = 0; i < m_size; i++) { |
---|
| 357 | Object element = getCell(classVal, i); |
---|
| 358 | if (!(element instanceof Double)) { |
---|
| 359 | throw new Exception("Can't use non-fixed costs when " |
---|
| 360 | +"getting max cost."); |
---|
| 361 | } |
---|
| 362 | double cost = ((Double)element).doubleValue(); |
---|
| 363 | if (cost > maxCost) maxCost = cost; |
---|
| 364 | } |
---|
| 365 | |
---|
| 366 | return maxCost; |
---|
| 367 | } |
---|
| 368 | |
---|
| 369 | /** |
---|
| 370 | * Gets the maximum cost for a particular class value. |
---|
| 371 | * |
---|
| 372 | * @param classVal the class value. |
---|
| 373 | * @return the maximum cost. |
---|
| 374 | * @exception Exception if cost matrix contains non-fixed |
---|
| 375 | * costs |
---|
| 376 | */ |
---|
| 377 | public double getMaxCost(int classVal, Instance inst) |
---|
| 378 | throws Exception { |
---|
| 379 | |
---|
| 380 | if (!replaceStrings()) { |
---|
| 381 | return getMaxCost(classVal); |
---|
| 382 | } |
---|
| 383 | |
---|
| 384 | double maxCost = Double.NEGATIVE_INFINITY; |
---|
| 385 | double cost; |
---|
| 386 | for (int i = 0; i < m_size; i++) { |
---|
| 387 | Object element = getCell(classVal, i); |
---|
| 388 | if (!(element instanceof Double)) { |
---|
| 389 | cost = |
---|
| 390 | ((AttributeExpression)element).evaluateExpression(inst); |
---|
| 391 | } else { |
---|
| 392 | cost = ((Double)element).doubleValue(); |
---|
| 393 | } |
---|
| 394 | if (cost > maxCost) maxCost = cost; |
---|
| 395 | } |
---|
| 396 | |
---|
| 397 | return maxCost; |
---|
| 398 | } |
---|
| 399 | |
---|
| 400 | |
---|
| 401 | /** |
---|
| 402 | * Normalizes the matrix so that the diagonal contains zeros. |
---|
| 403 | * |
---|
| 404 | */ |
---|
| 405 | public void normalize() { |
---|
| 406 | |
---|
| 407 | for (int y=0; y<m_size; y++) { |
---|
| 408 | double diag = ((Double)getCell(y, y)).doubleValue(); |
---|
| 409 | for (int x=0; x<m_size; x++) { |
---|
| 410 | setCell(x, y, new Double(((Double)getCell(x, y)). |
---|
| 411 | doubleValue() - diag)); |
---|
| 412 | } |
---|
| 413 | } |
---|
| 414 | } |
---|
| 415 | |
---|
| 416 | /** |
---|
| 417 | * Loads a cost matrix in the old format from a reader. Adapted from code once sitting |
---|
| 418 | * in Instances.java |
---|
| 419 | * |
---|
| 420 | * @param reader the reader to get the values from. |
---|
| 421 | * @exception Exception if the matrix cannot be read correctly. |
---|
| 422 | */ |
---|
| 423 | public void readOldFormat(Reader reader) throws Exception { |
---|
| 424 | |
---|
| 425 | StreamTokenizer tokenizer; |
---|
| 426 | int currentToken; |
---|
| 427 | double firstIndex, secondIndex, weight; |
---|
| 428 | |
---|
| 429 | tokenizer = new StreamTokenizer(reader); |
---|
| 430 | |
---|
| 431 | initialize(); |
---|
| 432 | |
---|
| 433 | tokenizer.commentChar('%'); |
---|
| 434 | tokenizer.eolIsSignificant(true); |
---|
| 435 | while (StreamTokenizer.TT_EOF != (currentToken = tokenizer.nextToken())) { |
---|
| 436 | |
---|
| 437 | // Skip empty lines |
---|
| 438 | if (currentToken == StreamTokenizer.TT_EOL) { |
---|
| 439 | continue; |
---|
| 440 | } |
---|
| 441 | |
---|
| 442 | // Get index of first class. |
---|
| 443 | if (currentToken != StreamTokenizer.TT_NUMBER) { |
---|
| 444 | throw new Exception("Only numbers and comments allowed "+ |
---|
| 445 | "in cost file!"); |
---|
| 446 | } |
---|
| 447 | firstIndex = tokenizer.nval; |
---|
| 448 | if (!Utils.eq((double)(int)firstIndex,firstIndex)) { |
---|
| 449 | throw new Exception("First number in line has to be "+ |
---|
| 450 | "index of a class!"); |
---|
| 451 | } |
---|
| 452 | if ((int)firstIndex >= size()) { |
---|
| 453 | throw new Exception("Class index out of range!"); |
---|
| 454 | } |
---|
| 455 | |
---|
| 456 | // Get index of second class. |
---|
| 457 | if (StreamTokenizer.TT_EOF == (currentToken = tokenizer.nextToken())) { |
---|
| 458 | throw new Exception("Premature end of file!"); |
---|
| 459 | } |
---|
| 460 | if (currentToken == StreamTokenizer.TT_EOL) { |
---|
| 461 | throw new Exception("Premature end of line!"); |
---|
| 462 | } |
---|
| 463 | if (currentToken != StreamTokenizer.TT_NUMBER) { |
---|
| 464 | throw new Exception("Only numbers and comments allowed "+ |
---|
| 465 | "in cost file!"); |
---|
| 466 | } |
---|
| 467 | secondIndex = tokenizer.nval; |
---|
| 468 | if (!Utils.eq((double)(int)secondIndex,secondIndex)) { |
---|
| 469 | throw new Exception("Second number in line has to be "+ |
---|
| 470 | "index of a class!"); |
---|
| 471 | } |
---|
| 472 | if ((int)secondIndex >= size()) { |
---|
| 473 | throw new Exception("Class index out of range!"); |
---|
| 474 | } |
---|
| 475 | if ((int)secondIndex == (int)firstIndex) { |
---|
| 476 | throw new Exception("Diagonal of cost matrix non-zero!"); |
---|
| 477 | } |
---|
| 478 | |
---|
| 479 | // Get cost factor. |
---|
| 480 | if (StreamTokenizer.TT_EOF == (currentToken = tokenizer.nextToken())) { |
---|
| 481 | throw new Exception("Premature end of file!"); |
---|
| 482 | } |
---|
| 483 | if (currentToken == StreamTokenizer.TT_EOL) { |
---|
| 484 | throw new Exception("Premature end of line!"); |
---|
| 485 | } |
---|
| 486 | if (currentToken != StreamTokenizer.TT_NUMBER) { |
---|
| 487 | throw new Exception("Only numbers and comments allowed "+ |
---|
| 488 | "in cost file!"); |
---|
| 489 | } |
---|
| 490 | weight = tokenizer.nval; |
---|
| 491 | if (!Utils.gr(weight,0)) { |
---|
| 492 | throw new Exception("Only positive weights allowed!"); |
---|
| 493 | } |
---|
| 494 | setCell((int)firstIndex, (int)secondIndex, |
---|
| 495 | new Double(weight)); |
---|
| 496 | } |
---|
| 497 | } |
---|
| 498 | |
---|
| 499 | /** |
---|
| 500 | * Reads a matrix from a reader. The first line in the file should |
---|
| 501 | * contain the number of rows and columns. Subsequent lines |
---|
| 502 | * contain elements of the matrix. |
---|
| 503 | * (FracPete: taken from old weka.core.Matrix class) |
---|
| 504 | * |
---|
| 505 | * @param reader the reader containing the matrix |
---|
| 506 | * @throws Exception if an error occurs |
---|
| 507 | * @see #write(Writer) |
---|
| 508 | */ |
---|
| 509 | public CostMatrix(Reader reader) throws Exception { |
---|
| 510 | LineNumberReader lnr = new LineNumberReader(reader); |
---|
| 511 | String line; |
---|
| 512 | int currentRow = -1; |
---|
| 513 | |
---|
| 514 | while ((line = lnr.readLine()) != null) { |
---|
| 515 | |
---|
| 516 | // Comments |
---|
| 517 | if (line.startsWith("%")) { |
---|
| 518 | continue; |
---|
| 519 | } |
---|
| 520 | |
---|
| 521 | StringTokenizer st = new StringTokenizer(line); |
---|
| 522 | // Ignore blank lines |
---|
| 523 | if (!st.hasMoreTokens()) { |
---|
| 524 | continue; |
---|
| 525 | } |
---|
| 526 | |
---|
| 527 | if (currentRow < 0) { |
---|
| 528 | int rows = Integer.parseInt(st.nextToken()); |
---|
| 529 | if (!st.hasMoreTokens()) { |
---|
| 530 | throw new Exception("Line " + lnr.getLineNumber() |
---|
| 531 | + ": expected number of columns"); |
---|
| 532 | } |
---|
| 533 | |
---|
| 534 | int cols = Integer.parseInt(st.nextToken()); |
---|
| 535 | if (rows != cols) { |
---|
| 536 | throw new Exception("Trying to create a non-square cost " |
---|
| 537 | +"matrix"); |
---|
| 538 | } |
---|
| 539 | // m_matrix = new Object[rows][cols]; |
---|
| 540 | m_size = rows; |
---|
| 541 | initialize(); |
---|
| 542 | currentRow++; |
---|
| 543 | continue; |
---|
| 544 | |
---|
| 545 | } else { |
---|
| 546 | if (currentRow == m_size) { |
---|
| 547 | throw new Exception("Line " + lnr.getLineNumber() |
---|
| 548 | + ": too many rows provided"); |
---|
| 549 | } |
---|
| 550 | |
---|
| 551 | for (int i = 0; i < m_size; i++) { |
---|
| 552 | if (!st.hasMoreTokens()) { |
---|
| 553 | throw new Exception("Line " + lnr.getLineNumber() |
---|
| 554 | + ": too few matrix elements provided"); |
---|
| 555 | } |
---|
| 556 | |
---|
| 557 | String nextTok = st.nextToken(); |
---|
| 558 | // try to parse as a double first |
---|
| 559 | Double val = null; |
---|
| 560 | try { |
---|
| 561 | val = new Double(nextTok); |
---|
| 562 | double value = val.doubleValue(); |
---|
| 563 | } catch (Exception ex) { |
---|
| 564 | val = null; |
---|
| 565 | } |
---|
| 566 | if (val == null) { |
---|
| 567 | setCell(currentRow, i, nextTok); |
---|
| 568 | } else { |
---|
| 569 | setCell(currentRow, i, val); |
---|
| 570 | } |
---|
| 571 | } |
---|
| 572 | currentRow++; |
---|
| 573 | } |
---|
| 574 | } |
---|
| 575 | |
---|
| 576 | if (currentRow == -1) { |
---|
| 577 | throw new Exception("Line " + lnr.getLineNumber() |
---|
| 578 | + ": expected number of rows"); |
---|
| 579 | } else if (currentRow != m_size) { |
---|
| 580 | throw new Exception("Line " + lnr.getLineNumber() |
---|
| 581 | + ": too few rows provided"); |
---|
| 582 | } |
---|
| 583 | } |
---|
| 584 | |
---|
| 585 | /** |
---|
| 586 | * Writes out a matrix. The format can be read via the |
---|
| 587 | * CostMatrix(Reader) constructor. |
---|
| 588 | * (FracPete: taken from old weka.core.Matrix class) |
---|
| 589 | * |
---|
| 590 | * @param w the output Writer |
---|
| 591 | * @throws Exception if an error occurs |
---|
| 592 | */ |
---|
| 593 | public void write(Writer w) throws Exception { |
---|
| 594 | w.write("% Rows\tColumns\n"); |
---|
| 595 | w.write("" + m_size + "\t" + m_size + "\n"); |
---|
| 596 | w.write("% Matrix elements\n"); |
---|
| 597 | for(int i = 0; i < m_size; i++) { |
---|
| 598 | for(int j = 0; j < m_size; j++) { |
---|
| 599 | w.write("" + getCell(i, j) + "\t"); |
---|
| 600 | } |
---|
| 601 | w.write("\n"); |
---|
| 602 | } |
---|
| 603 | w.flush(); |
---|
| 604 | } |
---|
| 605 | |
---|
| 606 | /** |
---|
| 607 | * converts the Matrix into a single line Matlab string: matrix is enclosed |
---|
| 608 | * by parentheses, rows are separated by semicolon and single cells by |
---|
| 609 | * blanks, e.g., [1 2; 3 4]. |
---|
| 610 | * @return the matrix in Matlab single line format |
---|
| 611 | */ |
---|
| 612 | public String toMatlab() { |
---|
| 613 | StringBuffer result; |
---|
| 614 | int i; |
---|
| 615 | int n; |
---|
| 616 | |
---|
| 617 | result = new StringBuffer(); |
---|
| 618 | |
---|
| 619 | result.append("["); |
---|
| 620 | |
---|
| 621 | for (i = 0; i < m_size; i++) { |
---|
| 622 | if (i > 0) { |
---|
| 623 | result.append("; "); |
---|
| 624 | } |
---|
| 625 | |
---|
| 626 | for (n = 0; n < m_size; n++) { |
---|
| 627 | if (n > 0) { |
---|
| 628 | result.append(" "); |
---|
| 629 | } |
---|
| 630 | result.append(getCell(i, n)); |
---|
| 631 | } |
---|
| 632 | } |
---|
| 633 | |
---|
| 634 | result.append("]"); |
---|
| 635 | |
---|
| 636 | return result.toString(); |
---|
| 637 | } |
---|
| 638 | |
---|
| 639 | /** |
---|
| 640 | * Set the value of a particular cell in the matrix |
---|
| 641 | * |
---|
| 642 | * @param rowIndex the row |
---|
| 643 | * @param columnIndex the column |
---|
| 644 | * @param value the value to set |
---|
| 645 | */ |
---|
| 646 | public final void setCell(int rowIndex, int columnIndex, |
---|
| 647 | Object value) { |
---|
| 648 | m_matrix[rowIndex][columnIndex] = value; |
---|
| 649 | } |
---|
| 650 | |
---|
| 651 | /** |
---|
| 652 | * Return the contents of a particular cell. Note: this |
---|
| 653 | * method returns the Object stored at a particular cell. |
---|
| 654 | * |
---|
| 655 | * @param rowIndex the row |
---|
| 656 | * @param columnIndex the column |
---|
| 657 | * @return the value at the cell |
---|
| 658 | */ |
---|
| 659 | public final Object getCell(int rowIndex, int columnIndex) { |
---|
| 660 | return m_matrix[rowIndex][columnIndex]; |
---|
| 661 | } |
---|
| 662 | |
---|
| 663 | /** |
---|
| 664 | * Return the value of a cell as a double (for legacy code) |
---|
| 665 | * |
---|
| 666 | * @param rowIndex the row |
---|
| 667 | * @param columnIndex the column |
---|
| 668 | * @return the value at a particular cell as a double |
---|
| 669 | * @exception Exception if the value is not a double |
---|
| 670 | */ |
---|
| 671 | public final double getElement(int rowIndex, int columnIndex) |
---|
| 672 | throws Exception { |
---|
| 673 | if (!(m_matrix[rowIndex][columnIndex] instanceof Double)) { |
---|
| 674 | throw new Exception("Cost matrix contains non-fixed costs!"); |
---|
| 675 | } |
---|
| 676 | return ((Double)m_matrix[rowIndex][columnIndex]).doubleValue(); |
---|
| 677 | } |
---|
| 678 | |
---|
| 679 | /** |
---|
| 680 | * Return the value of a cell as a double. Computes the |
---|
| 681 | * value for non-fixed costs using the supplied Instance |
---|
| 682 | * |
---|
| 683 | * @param rowIndex the row |
---|
| 684 | * @param columnIndex the column |
---|
| 685 | * @return the value from a particular cell |
---|
| 686 | * @exception Exception if something goes wrong |
---|
| 687 | */ |
---|
| 688 | public final double getElement(int rowIndex, int columnIndex, |
---|
| 689 | Instance inst) throws Exception { |
---|
| 690 | |
---|
| 691 | if (m_matrix[rowIndex][columnIndex] instanceof Double) { |
---|
| 692 | return ((Double)m_matrix[rowIndex][columnIndex]).doubleValue(); |
---|
| 693 | } else if (m_matrix[rowIndex][columnIndex] instanceof String) { |
---|
| 694 | replaceStrings(); |
---|
| 695 | } |
---|
| 696 | |
---|
| 697 | return ((AttributeExpression)m_matrix[rowIndex][columnIndex]). |
---|
| 698 | evaluateExpression(inst); |
---|
| 699 | } |
---|
| 700 | |
---|
| 701 | /** |
---|
| 702 | * Set the value of a cell as a double |
---|
| 703 | * |
---|
| 704 | * @param rowIndex the row |
---|
| 705 | * @param columnIndex the column |
---|
| 706 | * @param value the value (double) to set |
---|
| 707 | */ |
---|
| 708 | public final void setElement(int rowIndex, int columnIndex, |
---|
| 709 | double value) { |
---|
| 710 | m_matrix[rowIndex][columnIndex] = new Double(value); |
---|
| 711 | } |
---|
| 712 | |
---|
| 713 | /** |
---|
| 714 | * creates a matrix from the given Matlab string. |
---|
| 715 | * @param matlab the matrix in matlab format |
---|
| 716 | * @return the matrix represented by the given string |
---|
| 717 | */ |
---|
| 718 | public static Matrix parseMatlab(String matlab) throws Exception { |
---|
| 719 | return Matrix.parseMatlab(matlab); |
---|
| 720 | } |
---|
| 721 | |
---|
| 722 | /** |
---|
| 723 | * Converts a matrix to a string. |
---|
| 724 | * (FracPete: taken from old weka.core.Matrix class) |
---|
| 725 | * |
---|
| 726 | * @return the converted string |
---|
| 727 | */ |
---|
| 728 | public String toString() { |
---|
| 729 | // Determine the width required for the maximum element, |
---|
| 730 | // and check for fractional display requirement. |
---|
| 731 | double maxval = 0; |
---|
| 732 | boolean fractional = false; |
---|
| 733 | Object element = null; |
---|
| 734 | int widthNumber = 0; |
---|
| 735 | int widthExpression = 0; |
---|
| 736 | for (int i = 0; i < size(); i++) { |
---|
| 737 | for (int j = 0; j < size(); j++) { |
---|
| 738 | element = getCell(i, j); |
---|
| 739 | if (element instanceof Double) { |
---|
| 740 | double current = ((Double)element).doubleValue(); |
---|
| 741 | |
---|
| 742 | if (current < 0) |
---|
| 743 | current *= -11; |
---|
| 744 | if (current > maxval) |
---|
| 745 | maxval = current; |
---|
| 746 | double fract = Math.abs(current - Math.rint(current)); |
---|
| 747 | if (!fractional |
---|
| 748 | && ((Math.log(fract) / Math.log(10)) >= -2)) { |
---|
| 749 | fractional = true; |
---|
| 750 | } |
---|
| 751 | } else { |
---|
| 752 | if (element.toString().length() > widthExpression) { |
---|
| 753 | widthExpression = element.toString().length(); |
---|
| 754 | } |
---|
| 755 | } |
---|
| 756 | } |
---|
| 757 | } |
---|
| 758 | if (maxval > 0) { |
---|
| 759 | widthNumber = (int)(Math.log(maxval) / Math.log(10) |
---|
| 760 | + (fractional ? 4 : 1)); |
---|
| 761 | } |
---|
| 762 | |
---|
| 763 | int width = (widthNumber > widthExpression) |
---|
| 764 | ? widthNumber |
---|
| 765 | : widthExpression; |
---|
| 766 | |
---|
| 767 | StringBuffer text = new StringBuffer(); |
---|
| 768 | for (int i = 0; i < size(); i++) { |
---|
| 769 | for (int j = 0; j < size(); j++) { |
---|
| 770 | element = getCell(i, j); |
---|
| 771 | if (element instanceof Double) { |
---|
| 772 | text.append(" "). |
---|
| 773 | append(Utils.doubleToString(((Double)element). |
---|
| 774 | doubleValue(), |
---|
| 775 | width, (fractional ? 2 : 0))); |
---|
| 776 | } else { |
---|
| 777 | int diff = width - element.toString().length(); |
---|
| 778 | if (diff > 0) { |
---|
| 779 | int left = diff % 2; |
---|
| 780 | left += diff / 2; |
---|
| 781 | String temp = Utils.padLeft(element.toString(), |
---|
| 782 | element.toString().length()+left); |
---|
| 783 | temp = Utils.padRight(temp, width); |
---|
| 784 | text.append(" ").append(temp); |
---|
| 785 | } else { |
---|
| 786 | text.append(" "). |
---|
| 787 | append(element.toString()); |
---|
| 788 | } |
---|
| 789 | } |
---|
| 790 | } |
---|
| 791 | text.append("\n"); |
---|
| 792 | } |
---|
| 793 | |
---|
| 794 | return text.toString(); |
---|
| 795 | } |
---|
| 796 | |
---|
| 797 | /** |
---|
| 798 | * Returns the revision string. |
---|
| 799 | * |
---|
| 800 | * @return the revision |
---|
| 801 | */ |
---|
| 802 | public String getRevision() { |
---|
| 803 | return RevisionUtils.extract("$Revision: 6041 $"); |
---|
| 804 | } |
---|
| 805 | } |
---|