/*
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 2 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 */

/*
 * EM.java
 * Copyright (C) 1999 University of Waikato, Hamilton, New Zealand
 *
 */

package weka.clusterers;

import weka.core.Capabilities;
import weka.core.Instance;
import weka.core.Attribute;
import weka.core.Instances;
import weka.core.Option;
import weka.core.RevisionUtils;
import weka.core.Utils;
import weka.core.WeightedInstancesHandler;
import weka.estimators.DiscreteEstimator;
import weka.estimators.Estimator;
import weka.filters.unsupervised.attribute.ReplaceMissingValues;

import java.util.Enumeration;
import java.util.Random;
import java.util.Vector;

/**
 <!-- globalinfo-start -->
 * Simple EM (expectation maximisation) class.<br/>
 * <br/>
 * EM assigns a probability distribution to each instance which indicates the probability of it belonging to each of the clusters. EM can decide how many clusters to create by cross validation, or you may specify a priori how many clusters to generate.<br/>
 * <br/>
 * The cross validation performed to determine the number of clusters is done in the following steps:<br/>
 * 1. the number of clusters is set to 1<br/>
 * 2. the training set is split randomly into 10 folds.<br/>
 * 3. EM is performed 10 times using the 10 folds the usual CV way.<br/>
 * 4. the loglikelihood is averaged over all 10 results.<br/>
 * 5. if loglikelihood has increased the number of clusters is increased by 1 and the program continues at step 2. <br/>
 * <br/>
 * The number of folds is fixed to 10, as long as the number of instances in the training set is not smaller than 10. If there are fewer than 10 instances, the number of folds is set equal to the number of instances.
 * <p/>
 <!-- globalinfo-end -->
 *
 <!-- options-start -->
 * Valid options are: <p/>
 *
 * <pre> -N <num>
 *  number of clusters. If omitted or -1 specified, then
 *  cross validation is used to select the number of clusters.</pre>
 *
 * <pre> -I <num>
 *  max iterations.
 *  (default 100)</pre>
 *
 * <pre> -V
 *  verbose.</pre>
 *
 * <pre> -M <num>
 *  minimum allowable standard deviation for normal density
 *  computation
 *  (default 1e-6)</pre>
 *
 * <pre> -O
 *  Display model in old format (good when there are many clusters)
 * </pre>
 *
 * <pre> -S <num>
 *  Random number seed.
 *  (default 100)</pre>
 *
 <!-- options-end -->
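 *
 * A typical programmatic usage sketch (assuming the training data have
 * already been loaded; the file name and variable names below are
 * illustrative only):
 * <pre>
 * Instances data =
 *   new weka.core.converters.ConverterUtils.DataSource("data.arff").getDataSet();
 * EM clusterer = new EM();
 * clusterer.setNumClusters(-1);     // -1 = select the number of clusters by cross validation
 * clusterer.setMaxIterations(100);
 * clusterer.buildClusterer(data);
 * double[] membership = clusterer.distributionForInstance(data.instance(0));
 * </pre>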
 *
 * @author Mark Hall (mhall@cs.waikato.ac.nz)
 * @author Eibe Frank (eibe@cs.waikato.ac.nz)
 * @version $Revision: 1.44 $
 */
public class EM
  extends RandomizableDensityBasedClusterer
  implements NumberOfClustersRequestable, WeightedInstancesHandler {

  /** for serialization */
  static final long serialVersionUID = 8348181483812829475L;

  /** hold the discrete estimators for each cluster */
  private Estimator m_model[][];

  /** hold the normal estimators for each cluster */
  private double m_modelNormal[][][];

  /** default minimum standard deviation */
  private double m_minStdDev = 1e-6;

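  /** minimum standard deviation per attribute; when set, overrides
      m_minStdDev for the corresponding attribute */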
  private double [] m_minStdDevPerAtt;

  /** hold the weights of each instance for each cluster */
  private double m_weights[][];

  /** the prior probabilities for clusters */
  private double m_priors[];

  /** the loglikelihood of the data */
  private double m_loglikely;

  /** training instances */
  private Instances m_theInstances = null;

  /** number of clusters selected by the user or cross validation */
  private int m_num_clusters;

  /** the initial number of clusters requested by the user--- -1 if
      xval is to be used to find the number of clusters */
  private int m_initialNumClusters;

  /** number of attributes */
  private int m_num_attribs;

  /** number of training instances */
  private int m_num_instances;

  /** maximum iterations to perform */
  private int m_max_iterations;

  /** attribute min values */
  private double [] m_minValues;

  /** attribute max values */
  private double [] m_maxValues;

  /** random number generator */
  private Random m_rr;

  /** Verbose? */
  private boolean m_verbose;

  /** globally replace missing values */
  private ReplaceMissingValues m_replaceMissing;

  /** display model output in old-style format */
  private boolean m_displayModelInOldFormat;

  /**
   * Returns a string describing this clusterer
   * @return a description of the clusterer suitable for
   * displaying in the explorer/experimenter gui
   */
  public String globalInfo() {
    return
      "Simple EM (expectation maximisation) class.\n\n"
      + "EM assigns a probability distribution to each instance which "
      + "indicates the probability of it belonging to each of the clusters. "
      + "EM can decide how many clusters to create by cross validation, or you "
      + "may specify a priori how many clusters to generate.\n\n"
      + "The cross validation performed to determine the number of clusters "
      + "is done in the following steps:\n"
      + "1. the number of clusters is set to 1\n"
      + "2. the training set is split randomly into 10 folds.\n"
      + "3. EM is performed 10 times using the 10 folds the usual CV way.\n"
      + "4. the loglikelihood is averaged over all 10 results.\n"
      + "5. if loglikelihood has increased the number of clusters is increased "
      + "by 1 and the program continues at step 2. \n\n"
      + "The number of folds is fixed to 10, as long as the number of "
      + "instances in the training set is not smaller than 10. If there are "
      + "fewer than 10 instances, the number of folds is set equal to the "
      + "number of instances.";
  }

  /**
   * Returns an enumeration describing the available options.
   *
   * @return an enumeration of all the available options.
   */
  public Enumeration listOptions () {
    Vector result = new Vector();

    result.addElement(new Option(
        "\tnumber of clusters. If omitted or -1 specified, then \n"
        + "\tcross validation is used to select the number of clusters.",
        "N", 1, "-N <num>"));

    result.addElement(new Option(
        "\tmax iterations."
        + "\n(default 100)",
        "I", 1, "-I <num>"));

    result.addElement(new Option(
        "\tverbose.",
        "V", 0, "-V"));

    result.addElement(new Option(
        "\tminimum allowable standard deviation for normal density\n"
        + "\tcomputation\n"
        + "\t(default 1e-6)",
        "M", 1, "-M <num>"));

    result.addElement(
        new Option("\tDisplay model in old format (good when there are "
                   + "many clusters)\n",
                   "O", 0, "-O"));

    Enumeration en = super.listOptions();
    while (en.hasMoreElements())
      result.addElement(en.nextElement());

    return result.elements();
  }


  /**
   * Parses a given list of options. <p/>
   *
   <!-- options-start -->
   * Valid options are: <p/>
   *
   * <pre> -N <num>
   *  number of clusters. If omitted or -1 specified, then
   *  cross validation is used to select the number of clusters.</pre>
   *
   * <pre> -I <num>
   *  max iterations.
   *  (default 100)</pre>
   *
   * <pre> -V
   *  verbose.</pre>
   *
   * <pre> -M <num>
   *  minimum allowable standard deviation for normal density
   *  computation
   *  (default 1e-6)</pre>
   *
   * <pre> -O
   *  Display model in old format (good when there are many clusters)
   * </pre>
   *
   * <pre> -S <num>
   *  Random number seed.
   *  (default 100)</pre>
   *
   <!-- options-end -->
   *
   * @param options the list of options as an array of strings
   * @throws Exception if an option is not supported
   */
  public void setOptions (String[] options)
    throws Exception {
    resetOptions();
    setDebug(Utils.getFlag('V', options));
    String optionString = Utils.getOption('I', options);

    if (optionString.length() != 0) {
      setMaxIterations(Integer.parseInt(optionString));
    }

    optionString = Utils.getOption('N', options);
    if (optionString.length() != 0) {
      setNumClusters(Integer.parseInt(optionString));
    }

    optionString = Utils.getOption('M', options);
    if (optionString.length() != 0) {
      setMinStdDev((new Double(optionString)).doubleValue());
    }

    setDisplayModelInOldFormat(Utils.getFlag('O', options));

    super.setOptions(options);
  }

  /**
   * Returns the tip text for this property
   * @return tip text for this property suitable for
   * displaying in the explorer/experimenter gui
   */
  public String displayModelInOldFormatTipText() {
    return "Use old format for model output. The old format is "
      + "better when there are many clusters. The new format "
      + "is better when there are fewer clusters and many attributes.";
  }

  /**
   * Set whether to display model output in the old, original
   * format.
   *
   * @param d true if model output is to be shown in the old format
   */
  public void setDisplayModelInOldFormat(boolean d) {
    m_displayModelInOldFormat = d;
  }

  /**
   * Get whether to display model output in the old, original
   * format.
   *
   * @return true if model output is to be shown in the old format
   */
  public boolean getDisplayModelInOldFormat() {
    return m_displayModelInOldFormat;
  }

  /**
   * Returns the tip text for this property
   * @return tip text for this property suitable for
   * displaying in the explorer/experimenter gui
   */
  public String minStdDevTipText() {
    return "set minimum allowable standard deviation";
  }

  /**
   * Set the minimum value for standard deviation when calculating
   * normal density. Reducing this value can help prevent arithmetic
   * overflow resulting from multiplying large densities (arising from small
   * standard deviations) when there are many singleton or near singleton
   * values.
   * @param m minimum value for standard deviation
   */
  public void setMinStdDev(double m) {
    m_minStdDev = m;
  }

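  /**
   * Set the minimum allowable standard deviation on a per-attribute basis
   * (overriding the global minimum set with setMinStdDev for those
   * attributes).
   *
   * @param m an array holding the minimum allowable standard deviation
   * for each attribute
   */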
  public void setMinStdDevPerAtt(double [] m) {
    m_minStdDevPerAtt = m;
  }

  /**
   * Get the minimum allowable standard deviation.
   * @return the minimum allowable standard deviation
   */
  public double getMinStdDev() {
    return m_minStdDev;
  }


  /**
   * Returns the tip text for this property
   * @return tip text for this property suitable for
   * displaying in the explorer/experimenter gui
   */
  public String numClustersTipText() {
    return "set number of clusters. -1 to select number of clusters "
      + "automatically by cross validation.";
  }

  /**
   * Set the number of clusters (-1 to select by CV).
   *
   * @param n the number of clusters
   * @throws Exception if n is 0
   */
  public void setNumClusters (int n)
    throws Exception {

    if (n == 0) {
      throw new Exception("Number of clusters must be > 0. (or -1 to "
                          + "select by cross validation).");
    }

    if (n < 0) {
      m_num_clusters = -1;
      m_initialNumClusters = -1;
    }
    else {
      m_num_clusters = n;
      m_initialNumClusters = n;
    }
  }


  /**
   * Get the number of clusters
   *
   * @return the number of clusters.
   */
  public int getNumClusters () {
    return m_initialNumClusters;
  }

  /**
   * Returns the tip text for this property
   * @return tip text for this property suitable for
   * displaying in the explorer/experimenter gui
   */
  public String maxIterationsTipText() {
    return "maximum number of iterations";
  }

  /**
   * Set the maximum number of iterations to perform
   *
   * @param i the number of iterations
   * @throws Exception if i is less than 1
   */
  public void setMaxIterations (int i)
    throws Exception {
    if (i < 1) {
      throw new Exception("Maximum number of iterations must be > 0!");
    }

    m_max_iterations = i;
  }


  /**
   * Get the maximum number of iterations
   *
   * @return the number of iterations
   */
  public int getMaxIterations () {
    return m_max_iterations;
  }


  /**
   * Returns the tip text for this property
   * @return tip text for this property suitable for
   * displaying in the explorer/experimenter gui
   */
  public String debugTipText() {
    return "If set to true, clusterer may output additional info to " +
      "the console.";
  }


  /**
   * Set debug mode - verbose output
   *
   * @param v true for verbose output
   */
  public void setDebug (boolean v) {
    m_verbose = v;
  }


  /**
   * Get debug mode
   *
   * @return true if debug mode is set
   */
  public boolean getDebug () {
    return m_verbose;
  }


  /**
   * Gets the current settings of EM.
   *
   * @return an array of strings suitable for passing to setOptions()
   */
  public String[] getOptions () {
    int i;
    Vector result;
    String[] options;

    result = new Vector();

    result.add("-I");
    result.add("" + m_max_iterations);
    result.add("-N");
    result.add("" + getNumClusters());
    result.add("-M");
    result.add("" + getMinStdDev());
    if (m_displayModelInOldFormat) {
      result.add("-O");
    }

    options = super.getOptions();
    for (i = 0; i < options.length; i++)
      result.add(options[i]);

    return (String[]) result.toArray(new String[result.size()]);
  }

  /**
   * Initialise estimators and storage.
   *
   * @param inst the instances
   * @throws Exception if initialization fails
   **/
  private void EM_Init (Instances inst)
    throws Exception {
    int i, j, k;

    // run k means 10 times and choose best solution
    SimpleKMeans bestK = null;
    double bestSqE = Double.MAX_VALUE;
    for (i = 0; i < 10; i++) {
      SimpleKMeans sk = new SimpleKMeans();
      sk.setSeed(m_rr.nextInt());
      sk.setNumClusters(m_num_clusters);
      sk.setDisplayStdDevs(true);
      sk.buildClusterer(inst);
      if (sk.getSquaredError() < bestSqE) {
        bestSqE = sk.getSquaredError();
        bestK = sk;
      }
    }

    // initialize with best k-means solution
    m_num_clusters = bestK.numberOfClusters();
    m_weights = new double[inst.numInstances()][m_num_clusters];
    m_model = new DiscreteEstimator[m_num_clusters][m_num_attribs];
    m_modelNormal = new double[m_num_clusters][m_num_attribs][3];
    m_priors = new double[m_num_clusters];
    Instances centers = bestK.getClusterCentroids();
    Instances stdD = bestK.getClusterStandardDevs();
    int [][][] nominalCounts = bestK.getClusterNominalCounts();
    int [] clusterSizes = bestK.getClusterSizes();

    for (i = 0; i < m_num_clusters; i++) {
      Instance center = centers.instance(i);
      for (j = 0; j < m_num_attribs; j++) {
        if (inst.attribute(j).isNominal()) {
          m_model[i][j] = new DiscreteEstimator(m_theInstances.
                                                attribute(j).numValues()
                                                , true);
          for (k = 0; k < inst.attribute(j).numValues(); k++) {
            m_model[i][j].addValue(k, nominalCounts[i][j][k]);
          }
        } else {
          double minStdD = (m_minStdDevPerAtt != null)
            ? m_minStdDevPerAtt[j]
            : m_minStdDev;
          double mean = (center.isMissing(j))
            ? inst.meanOrMode(j)
            : center.value(j);
          m_modelNormal[i][j][0] = mean;
          double stdv = (stdD.instance(i).isMissing(j))
            ? ((m_maxValues[j] - m_minValues[j]) / (2 * m_num_clusters))
            : stdD.instance(i).value(j);
          if (stdv < minStdD) {
            stdv = inst.attributeStats(j).numericStats.stdDev;
            if (Double.isInfinite(stdv)) {
              stdv = minStdD;
            }
            if (stdv < minStdD) {
              stdv = minStdD;
            }
          }
          if (stdv <= 0) {
            stdv = m_minStdDev;
          }

          m_modelNormal[i][j][1] = stdv;
          m_modelNormal[i][j][2] = 1.0;
        }
      }
    }


    for (j = 0; j < m_num_clusters; j++) {
      // m_priors[j] += 1.0;
      m_priors[j] = clusterSizes[j];
    }
    Utils.normalize(m_priors);
  }


  /**
   * calculate prior probabilities for the clusters
   *
   * @param inst the instances
   * @throws Exception if priors can't be calculated
   **/
  private void estimate_priors (Instances inst)
    throws Exception {

    for (int i = 0; i < m_num_clusters; i++) {
      m_priors[i] = 0.0;
    }

    for (int i = 0; i < inst.numInstances(); i++) {
      for (int j = 0; j < m_num_clusters; j++) {
        m_priors[j] += inst.instance(i).weight() * m_weights[i][j];
      }
    }

    Utils.normalize(m_priors);
  }


  /** Constant for normal distribution. */
  private static double m_normConst = Math.log(Math.sqrt(2*Math.PI));

  /**
   * Density function of normal distribution.
   * @param x input value
   * @param mean mean of distribution
   * @param stdDev standard deviation of distribution
   * @return the density
   */
  private double logNormalDens (double x, double mean, double stdDev) {

    double diff = x - mean;
    // System.err.println("x: "+x+" mean: "+mean+" diff: "+diff+" stdv: "+stdDev);
    // System.err.println("diff*diff/(2*stdv*stdv): "+ (diff * diff / (2 * stdDev * stdDev)));

    return - (diff * diff / (2 * stdDev * stdDev)) - m_normConst - Math.log(stdDev);
  }

  /**
   * New probability estimators for an iteration
   */
  private void new_estimators () {
    for (int i = 0; i < m_num_clusters; i++) {
      for (int j = 0; j < m_num_attribs; j++) {
        if (m_theInstances.attribute(j).isNominal()) {
          m_model[i][j] = new DiscreteEstimator(m_theInstances.
                                                attribute(j).numValues()
                                                , true);
        }
        else {
          m_modelNormal[i][j][0] = m_modelNormal[i][j][1] =
            m_modelNormal[i][j][2] = 0.0;
        }
      }
    }
  }


  /**
   * The M step of the EM algorithm.
   * @param inst the training instances
   * @throws Exception if something goes wrong
   */
  private void M (Instances inst)
    throws Exception {

    int i, j, l;

    new_estimators();

    for (i = 0; i < m_num_clusters; i++) {
      for (j = 0; j < m_num_attribs; j++) {
        for (l = 0; l < inst.numInstances(); l++) {
          Instance in = inst.instance(l);
          if (!in.isMissing(j)) {
            if (inst.attribute(j).isNominal()) {
              m_model[i][j].addValue(in.value(j),
                                     in.weight() * m_weights[l][i]);
            }
            else {
              m_modelNormal[i][j][0] += (in.value(j) * in.weight() *
                                         m_weights[l][i]);
              m_modelNormal[i][j][2] += in.weight() * m_weights[l][i];
              m_modelNormal[i][j][1] += (in.value(j) *
                                         in.value(j) * in.weight() * m_weights[l][i]);
            }
          }
        }
      }
    }

    // calculate mean and std deviation for numeric attributes
    for (j = 0; j < m_num_attribs; j++) {
      if (!inst.attribute(j).isNominal()) {
        for (i = 0; i < m_num_clusters; i++) {
          if (m_modelNormal[i][j][2] <= 0) {
            m_modelNormal[i][j][1] = Double.MAX_VALUE;
            // m_modelNormal[i][j][0] = 0;
            m_modelNormal[i][j][0] = m_minStdDev;
          } else {

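            // the accumulation pass above left [0] = sum(w*x), [1] = sum(w*x^2)
            // and [2] = sum(w), where w is the instance weight times the
            // cluster membership weight; the variance below is therefore
            // (sum(w*x^2) - (sum(w*x))^2/sum(w)) / sum(w)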
            // variance
            m_modelNormal[i][j][1] = (m_modelNormal[i][j][1] -
                                      (m_modelNormal[i][j][0] *
                                       m_modelNormal[i][j][0] /
                                       m_modelNormal[i][j][2])) /
              (m_modelNormal[i][j][2]);

            if (m_modelNormal[i][j][1] < 0) {
              m_modelNormal[i][j][1] = 0;
            }

            // std dev
            double minStdD = (m_minStdDevPerAtt != null)
              ? m_minStdDevPerAtt[j]
              : m_minStdDev;

            m_modelNormal[i][j][1] = Math.sqrt(m_modelNormal[i][j][1]);

            if ((m_modelNormal[i][j][1] <= minStdD)) {
              m_modelNormal[i][j][1] = inst.attributeStats(j).numericStats.stdDev;
              if ((m_modelNormal[i][j][1] <= minStdD)) {
                m_modelNormal[i][j][1] = minStdD;
              }
            }
            if ((m_modelNormal[i][j][1] <= 0)) {
              m_modelNormal[i][j][1] = m_minStdDev;
            }
            if (Double.isInfinite(m_modelNormal[i][j][1])) {
              m_modelNormal[i][j][1] = m_minStdDev;
            }

            // mean
            m_modelNormal[i][j][0] /= m_modelNormal[i][j][2];
          }
        }
      }
    }
  }

  /**
   * The E step of the EM algorithm. Estimate cluster membership
   * probabilities.
   *
   * @param inst the training instances
   * @param change_weights whether to change the weights
   * @return the average log likelihood
   * @throws Exception if computation fails
   */
  private double E (Instances inst, boolean change_weights)
    throws Exception {

    double loglk = 0.0, sOW = 0.0;

    for (int l = 0; l < inst.numInstances(); l++) {

      Instance in = inst.instance(l);

      loglk += in.weight() * logDensityForInstance(in);
      sOW += in.weight();

      if (change_weights) {
        m_weights[l] = distributionForInstance(in);
      }
    }

    // reestimate priors
    if (change_weights) {
      estimate_priors(inst);
    }
    return loglk / sOW;
  }


  /**
   * Constructor.
   *
   **/
  public EM () {
    super();

    m_SeedDefault = 100;
    resetOptions();
  }


  /**
   * Reset to default options
   */
  protected void resetOptions () {
    m_minStdDev = 1e-6;
    m_max_iterations = 100;
    m_Seed = m_SeedDefault;
    m_num_clusters = -1;
    m_initialNumClusters = -1;
    m_verbose = false;
  }

  /**
   * Return the normal distributions for the cluster models
   * (indexed by cluster and attribute; the last dimension holds the
   * mean, standard deviation and weight sum for that attribute).
   *
   * @return a <code>double[][][]</code> value
   */
  public double [][][] getClusterModelsNumericAtts() {
    return m_modelNormal;
  }


  /**
   * Return the priors for the clusters
   *
   * @return a <code>double[]</code> value
   */
  public double [] getClusterPriors() {
    return m_priors;
  }

  /**
   * Outputs the generated clusters into a string.
   *
   * @return the clusterer in string representation
   */
  public String toString() {
    if (m_displayModelInOldFormat) {
      return toStringOriginal();
    }

    if (m_priors == null) {
      return "No clusterer built yet!";
    }
    StringBuffer temp = new StringBuffer();
    temp.append("\nEM\n==\n");
    if (m_initialNumClusters == -1) {
      temp.append("\nNumber of clusters selected by cross validation: "
                  + m_num_clusters + "\n");
    } else {
      temp.append("\nNumber of clusters: " + m_num_clusters + "\n");
    }

    int maxWidth = 0;
    int maxAttWidth = 0;
    boolean containsKernel = false;

    // set up max widths
    // attributes
    for (int i = 0; i < m_num_attribs; i++) {
      Attribute a = m_theInstances.attribute(i);
      if (a.name().length() > maxAttWidth) {
        maxAttWidth = m_theInstances.attribute(i).name().length();
      }
      if (a.isNominal()) {
        // check values
        for (int j = 0; j < a.numValues(); j++) {
          String val = a.value(j) + " ";
          if (val.length() > maxAttWidth) {
            maxAttWidth = val.length();
          }
        }
      }
    }

    for (int i = 0; i < m_num_clusters; i++) {
      for (int j = 0; j < m_num_attribs; j++) {
        if (m_theInstances.attribute(j).isNumeric()) {
          // check mean and std. dev. against maxWidth
          double mean = Math.log(Math.abs(m_modelNormal[i][j][0])) / Math.log(10.0);
          double stdD = Math.log(Math.abs(m_modelNormal[i][j][1])) / Math.log(10.0);
          double width = (mean > stdD)
            ? mean
            : stdD;
          if (width < 0) {
            width = 1;
          }
          // decimal + # decimal places + 1
          width += 6.0;
          if ((int)width > maxWidth) {
            maxWidth = (int)width;
          }
        } else {
          // nominal distributions
          DiscreteEstimator d = (DiscreteEstimator)m_model[i][j];
          for (int k = 0; k < d.getNumSymbols(); k++) {
            String size = Utils.doubleToString(d.getCount(k), maxWidth, 4).trim();
            if (size.length() > maxWidth) {
              maxWidth = size.length();
            }
          }
          int sum =
            Utils.doubleToString(d.getSumOfCounts(), maxWidth, 4).trim().length();
          if (sum > maxWidth) {
            maxWidth = sum;
          }
        }
      }
    }

    if (maxAttWidth < "Attribute".length()) {
      maxAttWidth = "Attribute".length();
    }

    maxAttWidth += 2;

    temp.append("\n\n");
    temp.append(pad("Cluster", " ",
                    (maxAttWidth + maxWidth + 1) - "Cluster".length(),
                    true));

    temp.append("\n");
    temp.append(pad("Attribute", " ", maxAttWidth - "Attribute".length(), false));

    // cluster #'s
    for (int i = 0; i < m_num_clusters; i++) {
      String classL = "" + i;
      temp.append(pad(classL, " ", maxWidth + 1 - classL.length(), true));
    }
    temp.append("\n");

    // cluster priors
    temp.append(pad("", " ", maxAttWidth, true));
    for (int i = 0; i < m_num_clusters; i++) {
      String priorP = Utils.doubleToString(m_priors[i], maxWidth, 2).trim();
      priorP = "(" + priorP + ")";
      temp.append(pad(priorP, " ", maxWidth + 1 - priorP.length(), true));
    }

    temp.append("\n");
    temp.append(pad("", "=", maxAttWidth +
                    (maxWidth * m_num_clusters)
                    + m_num_clusters + 1, true));
    temp.append("\n");

    for (int i = 0; i < m_num_attribs; i++) {
      String attName = m_theInstances.attribute(i).name();
      temp.append(attName + "\n");

      if (m_theInstances.attribute(i).isNumeric()) {
        String meanL = "  mean";
        temp.append(pad(meanL, " ", maxAttWidth + 1 - meanL.length(), false));
        for (int j = 0; j < m_num_clusters; j++) {
          // means
          String mean =
            Utils.doubleToString(m_modelNormal[j][i][0], maxWidth, 4).trim();
          temp.append(pad(mean, " ", maxWidth + 1 - mean.length(), true));
        }
        temp.append("\n");
        // now do std deviations
        String stdDevL = "  std. dev.";
        temp.append(pad(stdDevL, " ", maxAttWidth + 1 - stdDevL.length(), false));
        for (int j = 0; j < m_num_clusters; j++) {
          String stdDev =
            Utils.doubleToString(m_modelNormal[j][i][1], maxWidth, 4).trim();
          temp.append(pad(stdDev, " ", maxWidth + 1 - stdDev.length(), true));
        }
        temp.append("\n\n");
      } else {
        Attribute a = m_theInstances.attribute(i);
        for (int j = 0; j < a.numValues(); j++) {
          String val = "  " + a.value(j);
          temp.append(pad(val, " ", maxAttWidth + 1 - val.length(), false));
          for (int k = 0; k < m_num_clusters; k++) {
            DiscreteEstimator d = (DiscreteEstimator)m_model[k][i];
            String count = Utils.doubleToString(d.getCount(j), maxWidth, 4).trim();
            temp.append(pad(count, " ", maxWidth + 1 - count.length(), true));
          }
          temp.append("\n");
        }
        // do the totals
        String total = "  [total]";
        temp.append(pad(total, " ", maxAttWidth + 1 - total.length(), false));
        for (int k = 0; k < m_num_clusters; k++) {
          DiscreteEstimator d = (DiscreteEstimator)m_model[k][i];
          String count =
            Utils.doubleToString(d.getSumOfCounts(), maxWidth, 4).trim();
          temp.append(pad(count, " ", maxWidth + 1 - count.length(), true));
        }
        temp.append("\n");
      }
    }

    return temp.toString();
  }

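  /**
   * Pads a string with copies of a padding character, either on the left
   * or on the right.
   *
   * @param source the string to pad
   * @param padChar the padding character (as a string)
   * @param length the number of padding characters to add
   * @param leftPad true to pad on the left, false to pad on the right
   * @return the padded string
   */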
  private String pad(String source, String padChar,
                     int length, boolean leftPad) {
    StringBuffer temp = new StringBuffer();

    if (leftPad) {
      for (int i = 0; i < length; i++) {
        temp.append(padChar);
      }
      temp.append(source);
    } else {
      temp.append(source);
      for (int i = 0; i < length; i++) {
        temp.append(padChar);
      }
    }
    return temp.toString();
  }

  /**
   * Outputs the generated clusters into a string.
   *
   * @return the clusterer in string representation
   */
  protected String toStringOriginal () {
    if (m_priors == null) {
      return "No clusterer built yet!";
    }
    StringBuffer temp = new StringBuffer();
    temp.append("\nEM\n==\n");
    if (m_initialNumClusters == -1) {
      temp.append("\nNumber of clusters selected by cross validation: "
                  + m_num_clusters + "\n");
    } else {
      temp.append("\nNumber of clusters: " + m_num_clusters + "\n");
    }

    for (int j = 0; j < m_num_clusters; j++) {
      temp.append("\nCluster: " + j + " Prior probability: "
                  + Utils.doubleToString(m_priors[j], 4) + "\n\n");

      for (int i = 0; i < m_num_attribs; i++) {
        temp.append("Attribute: " + m_theInstances.attribute(i).name() + "\n");

        if (m_theInstances.attribute(i).isNominal()) {
          if (m_model[j][i] != null) {
            temp.append(m_model[j][i].toString());
          }
        }
        else {
          temp.append("Normal Distribution. Mean = "
                      + Utils.doubleToString(m_modelNormal[j][i][0], 4)
                      + " StdDev = "
                      + Utils.doubleToString(m_modelNormal[j][i][1], 4)
                      + "\n");
        }
      }
    }

    return temp.toString();
  }


  /**
   * verbose output for debugging
   * @param inst the training instances
   */
  private void EM_Report (Instances inst) {
    int i, j, l, m;
    System.out.println("======================================");

    for (j = 0; j < m_num_clusters; j++) {
      for (i = 0; i < m_num_attribs; i++) {
        System.out.println("Clust: " + j + " att: " + i + "\n");

        if (m_theInstances.attribute(i).isNominal()) {
          if (m_model[j][i] != null) {
            System.out.println(m_model[j][i].toString());
          }
        }
        else {
          System.out.println("Normal Distribution. Mean = "
                             + Utils.doubleToString(m_modelNormal[j][i][0], 8, 4)
                             + " StandardDev = "
                             + Utils.doubleToString(m_modelNormal[j][i][1], 8, 4)
                             + " WeightSum = "
                             + Utils.doubleToString(m_modelNormal[j][i][2], 8, 4));
        }
      }
    }

    for (l = 0; l < inst.numInstances(); l++) {
      m = Utils.maxIndex(m_weights[l]);
      System.out.print("Inst " + Utils.doubleToString((double)l, 5, 0)
                       + " Class " + m + "\t");
      for (j = 0; j < m_num_clusters; j++) {
        System.out.print(Utils.doubleToString(m_weights[l][j], 7, 5) + " ");
      }
      System.out.println();
    }
  }


  /**
   * estimate the number of clusters by cross validation on the training
   * data.
   *
   * @throws Exception if something goes wrong
   */
  private void CVClusters ()
    throws Exception {
    double CVLogLikely = -Double.MAX_VALUE;
    double templl, tll;
    boolean CVincreased = true;
    m_num_clusters = 1;
    int num_clusters = m_num_clusters;
    int i;
    Random cvr;
    Instances trainCopy;
    int numFolds = (m_theInstances.numInstances() < 10)
      ? m_theInstances.numInstances()
      : 10;

    boolean ok = true;
    int seed = getSeed();
    int restartCount = 0;
    CLUSTER_SEARCH: while (CVincreased) {
      // theInstances.stratify(10);

      CVincreased = false;
      cvr = new Random(getSeed());
      trainCopy = new Instances(m_theInstances);
      trainCopy.randomize(cvr);
      templl = 0.0;
      for (i = 0; i < numFolds; i++) {
        Instances cvTrain = trainCopy.trainCV(numFolds, i, cvr);
        if (num_clusters > cvTrain.numInstances()) {
          break CLUSTER_SEARCH;
        }
        Instances cvTest = trainCopy.testCV(numFolds, i);
        m_rr = new Random(seed);
        for (int z = 0; z < 10; z++) m_rr.nextDouble();
        m_num_clusters = num_clusters;
        EM_Init(cvTrain);
        try {
          iterate(cvTrain, false);
        } catch (Exception ex) {
          // catch any problems - i.e. empty clusters occurring
          ex.printStackTrace();
          // System.err.println("Restarting after CV training failure ("+num_clusters+" clusters");
          seed++;
          restartCount++;
          ok = false;
          if (restartCount > 5) {
            break CLUSTER_SEARCH;
          }
          break;
        }
        try {
          tll = E(cvTest, false);
        } catch (Exception ex) {
          // catch any problems - i.e. empty clusters occurring
          // ex.printStackTrace();
          ex.printStackTrace();
          // System.err.println("Restarting after CV testing failure ("+num_clusters+" clusters");
          // throw new Exception(ex);
          seed++;
          restartCount++;
          ok = false;
          if (restartCount > 5) {
            break CLUSTER_SEARCH;
          }
          break;
        }

        if (m_verbose) {
          System.out.println("# clust: " + num_clusters + " Fold: " + i
                             + " Loglikely: " + tll);
        }
        templl += tll;
      }

      if (ok) {
        restartCount = 0;
        seed = getSeed();
        templl /= (double)numFolds;

        if (m_verbose) {
          System.out.println("==================================="
                             + "==============\n# clust: "
                             + num_clusters
                             + " Mean Loglikely: "
                             + templl
                             + "\n================================"
                             + "=================");
        }

        if (templl > CVLogLikely) {
          CVLogLikely = templl;
          CVincreased = true;
          num_clusters++;
        }
      }
    }

    if (m_verbose) {
      System.out.println("Number of clusters: " + (num_clusters - 1));
    }

    m_num_clusters = num_clusters - 1;
  }


  /**
   * Returns the number of clusters.
   *
   * @return the number of clusters generated for a training dataset.
   * @throws Exception if number of clusters could not be returned
   * successfully
   */
  public int numberOfClusters ()
    throws Exception {
    if (m_num_clusters == -1) {
      throw new Exception("Haven't generated any clusters!");
    }

    return m_num_clusters;
  }

  /**
   * Updates the minimum and maximum values for all the attributes
   * based on a new instance.
   *
   * @param instance the new instance
   */
  private void updateMinMax(Instance instance) {

    for (int j = 0; j < m_theInstances.numAttributes(); j++) {
      if (!instance.isMissing(j)) {
        if (Double.isNaN(m_minValues[j])) {
          m_minValues[j] = instance.value(j);
          m_maxValues[j] = instance.value(j);
        } else {
          if (instance.value(j) < m_minValues[j]) {
            m_minValues[j] = instance.value(j);
          } else {
            if (instance.value(j) > m_maxValues[j]) {
              m_maxValues[j] = instance.value(j);
            }
          }
        }
      }
    }
  }

  /**
   * Returns default capabilities of the clusterer (i.e., the ones of
   * SimpleKMeans).
   *
   * @return the capabilities of this clusterer
   */
  public Capabilities getCapabilities() {
    Capabilities result = new SimpleKMeans().getCapabilities();
    result.setOwner(this);
    return result;
  }

  /**
   * Generates a clusterer. Has to initialize all fields of the clusterer
   * that are not being set via options.
   *
   * @param data set of instances serving as training data
   * @throws Exception if the clusterer has not been
   * generated successfully
   */
  public void buildClusterer (Instances data)
    throws Exception {

    // can clusterer handle the data?
    getCapabilities().testWithFail(data);

    m_replaceMissing = new ReplaceMissingValues();
    Instances instances = new Instances(data);
    instances.setClassIndex(-1);
    m_replaceMissing.setInputFormat(instances);
    data = weka.filters.Filter.useFilter(instances, m_replaceMissing);
    instances = null;

    m_theInstances = data;

    // calculate min and max values for attributes
    m_minValues = new double [m_theInstances.numAttributes()];
    m_maxValues = new double [m_theInstances.numAttributes()];
    for (int i = 0; i < m_theInstances.numAttributes(); i++) {
      m_minValues[i] = m_maxValues[i] = Double.NaN;
    }
    for (int i = 0; i < m_theInstances.numInstances(); i++) {
      updateMinMax(m_theInstances.instance(i));
    }

    doEM();

    // save memory
    m_theInstances = new Instances(m_theInstances, 0);
  }

  /**
   * Returns the cluster priors.
   *
   * @return the cluster priors
   */
  public double[] clusterPriors() {

    double[] n = new double[m_priors.length];

    System.arraycopy(m_priors, 0, n, 0, n.length);
    return n;
  }

  /**
   * Computes the log of the conditional density (per cluster) for a given instance.
   *
   * @param inst the instance to compute the density for
   * @return an array containing the estimated densities
   * @throws Exception if the density could not be computed
   * successfully
   */
  public double[] logDensityPerClusterForInstance(Instance inst) throws Exception {

    int i, j;
    double logprob;
    double[] wghts = new double[m_num_clusters];

    m_replaceMissing.input(inst);
    inst = m_replaceMissing.output();

    for (i = 0; i < m_num_clusters; i++) {
      // System.err.println("Cluster : "+i);
      logprob = 0.0;

      for (j = 0; j < m_num_attribs; j++) {
        if (!inst.isMissing(j)) {
          if (inst.attribute(j).isNominal()) {
            logprob += Math.log(m_model[i][j].getProbability(inst.value(j)));
          }
          else { // numeric attribute
            logprob += logNormalDens(inst.value(j),
                                     m_modelNormal[i][j][0],
                                     m_modelNormal[i][j][1]);
            /* System.err.println(logNormalDens(inst.value(j),
                                                m_modelNormal[i][j][0],
                                                m_modelNormal[i][j][1]) + " "); */
          }
        }
      }
      // System.err.println("");

      wghts[i] = logprob;
    }
    return wghts;
  }


  /**
   * Perform the EM algorithm
   *
   * @throws Exception if something goes wrong
   */
  private void doEM ()
    throws Exception {

    if (m_verbose) {
      System.out.println("Seed: " + getSeed());
    }

    m_rr = new Random(getSeed());

    // throw away numbers to avoid problem of similar initial numbers
    // from a similar seed
    for (int i = 0; i < 10; i++) m_rr.nextDouble();

    m_num_instances = m_theInstances.numInstances();
    m_num_attribs = m_theInstances.numAttributes();

    if (m_verbose) {
      System.out.println("Number of instances: "
                         + m_num_instances
                         + "\nNumber of atts: "
                         + m_num_attribs
                         + "\n");
    }

    // setDefaultStdDevs(theInstances);
    // cross validate to determine number of clusters?
    if (m_initialNumClusters == -1) {
      if (m_theInstances.numInstances() > 9) {
        CVClusters();
        m_rr = new Random(getSeed());
        for (int i = 0; i < 10; i++) m_rr.nextDouble();
      } else {
        m_num_clusters = 1;
      }
    }

    // fit full training set
    EM_Init(m_theInstances);
    m_loglikely = iterate(m_theInstances, m_verbose);
  }


  /**
   * iterates the E and M steps until the log likelihood of the data
   * converges.
   *
   * @param inst the training instances.
   * @param report be verbose.
   * @return the log likelihood of the data
   * @throws Exception if something goes wrong
   */
  private double iterate (Instances inst, boolean report)
    throws Exception {

    int i;
    double llkold = 0.0;
    double llk = 0.0;

    if (report) {
      EM_Report(inst);
    }

    boolean ok = false;
    int seed = getSeed();
    int restartCount = 0;
    while (!ok) {
      try {
        for (i = 0; i < m_max_iterations; i++) {
          llkold = llk;
          llk = E(inst, true);

          if (report) {
            System.out.println("Loglikely: " + llk);
          }

          if (i > 0) {
            if ((llk - llkold) < 1e-6) {
              break;
            }
          }
          M(inst);
        }
        ok = true;
      } catch (Exception ex) {
        // System.err.println("Restarting after training failure");
        ex.printStackTrace();
        seed++;
        restartCount++;
        m_rr = new Random(seed);
        for (int z = 0; z < 10; z++) {
          m_rr.nextDouble(); m_rr.nextInt();
        }
        if (restartCount > 5) {
          // System.err.println("Reducing the number of clusters");
          m_num_clusters--;
          restartCount = 0;
        }
        EM_Init(m_theInstances);
      }
    }

    if (report) {
      EM_Report(inst);
    }

    return llk;
  }

  /**
   * Returns the revision string.
   *
   * @return the revision
   */
  public String getRevision() {
    return RevisionUtils.extract("$Revision: 1.44 $");
  }

  // ============
  // Test method.
  // ============
  /**
   * Main method for testing this class.
   *
   * @param argv should contain the following arguments: <p>
   * -t training file [-T test file] [-N number of clusters] [-S random seed]
   */
  public static void main (String[] argv) {
    runClusterer(new EM(), argv);
  }
}