source: src/main/java/weka/clusterers/EM.java @ 18

Last change on this file since 18 was 4, checked in by gnappo, 14 years ago

Import di weka.

File size: 40.0 KB
RevLine 
[4]1/*
2 *    This program is free software; you can redistribute it and/or modify
3 *    it under the terms of the GNU General Public License as published by
4 *    the Free Software Foundation; either version 2 of the License, or
5 *    (at your option) any later version.
6 *
7 *    This program is distributed in the hope that it will be useful,
8 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
9 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
10 *    GNU General Public License for more details.
11 *
12 *    You should have received a copy of the GNU General Public License
13 *    along with this program; if not, write to the Free Software
14 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
15 */
16
17/*
18 *    EM.java
19 *    Copyright (C) 1999 University of Waikato, Hamilton, New Zealand
20 *
21 */
22
23package weka.clusterers;
24
25import weka.core.Capabilities;
26import weka.core.Instance;
27import weka.core.Attribute;
28import weka.core.Instances;
29import weka.core.Option;
30import weka.core.RevisionUtils;
31import weka.core.Utils;
32import weka.core.WeightedInstancesHandler;
33import weka.estimators.DiscreteEstimator;
34import weka.estimators.Estimator;
35import weka.filters.unsupervised.attribute.ReplaceMissingValues;
36
37import java.util.Enumeration;
38import java.util.Random;
39import java.util.Vector;
40
41/**
42 <!-- globalinfo-start -->
43 * Simple EM (expectation maximisation) class.<br/>
44 * <br/>
45 * EM assigns a probability distribution to each instance which indicates the probability of it belonging to each of the clusters. EM can decide how many clusters to create by cross validation, or you may specify apriori how many clusters to generate.<br/>
46 * <br/>
47 * The cross validation performed to determine the number of clusters is done in the following steps:<br/>
48 * 1. the number of clusters is set to 1<br/>
49 * 2. the training set is split randomly into 10 folds.<br/>
50 * 3. EM is performed 10 times using the 10 folds the usual CV way.<br/>
51 * 4. the loglikelihood is averaged over all 10 results.<br/>
52 * 5. if loglikelihood has increased the number of clusters is increased by 1 and the program continues at step 2. <br/>
53 * <br/>
54 * The number of folds is fixed to 10, as long as the number of instances in the training set is not smaller 10. If this is the case the number of folds is set equal to the number of instances.
55 * <p/>
56 <!-- globalinfo-end -->
57 *
58 <!-- options-start -->
59 * Valid options are: <p/>
60 *
61 * <pre> -N &lt;num&gt;
62 *  number of clusters. If omitted or -1 specified, then
63 *  cross validation is used to select the number of clusters.</pre>
64 *
65 * <pre> -I &lt;num&gt;
66 *  max iterations.
67 * (default 100)</pre>
68 *
69 * <pre> -V
70 *  verbose.</pre>
71 *
72 * <pre> -M &lt;num&gt;
73 *  minimum allowable standard deviation for normal density
74 *  computation
75 *  (default 1e-6)</pre>
76 *
77 * <pre> -O
78 *  Display model in old format (good when there are many clusters)
79 * </pre>
80 *
81 * <pre> -S &lt;num&gt;
82 *  Random number seed.
83 *  (default 100)</pre>
84 *
85 <!-- options-end -->
86 *
87 * @author Mark Hall (mhall@cs.waikato.ac.nz)
88 * @author Eibe Frank (eibe@cs.waikato.ac.nz)
89 * @version $Revision: 1.44 $
90 */
91public class EM
92  extends RandomizableDensityBasedClusterer
93  implements NumberOfClustersRequestable, WeightedInstancesHandler {
94
95  /** for serialization */
96  static final long serialVersionUID = 8348181483812829475L;
97 
98  /** hold the discrete estimators for each cluster */
99  private Estimator m_model[][];
100
101  /** hold the normal estimators for each cluster */
102  private double m_modelNormal[][][];
103
104  /** default minimum standard deviation */
105  private double m_minStdDev = 1e-6;
106
107  private double [] m_minStdDevPerAtt;
108
109  /** hold the weights of each instance for each cluster */
110  private double m_weights[][];
111
112  /** the prior probabilities for clusters */
113  private double m_priors[];
114
115  /** the loglikelihood of the data */
116  private double m_loglikely;
117
118  /** training instances */
119  private Instances m_theInstances = null;
120
121  /** number of clusters selected by the user or cross validation */
122  private int m_num_clusters;
123
124  /** the initial number of clusters requested by the user--- -1 if
125      xval is to be used to find the number of clusters */
126  private int m_initialNumClusters;
127
128  /** number of attributes */
129  private int m_num_attribs;
130
131  /** number of training instances */
132  private int m_num_instances;
133
134  /** maximum iterations to perform */
135  private int m_max_iterations;
136
137  /** attribute min values */
138  private double [] m_minValues;
139
140  /** attribute max values */
141  private double [] m_maxValues;
142
143  /** random number generator */
144  private Random m_rr;
145
146  /** Verbose? */
147  private boolean m_verbose;
148
149 /** globally replace missing values */
150  private ReplaceMissingValues m_replaceMissing;
151
152  /** display model output in old-style format */
153  private boolean m_displayModelInOldFormat;
154
155  /**
156   * Returns a string describing this clusterer
157   * @return a description of the evaluator suitable for
158   * displaying in the explorer/experimenter gui
159   */
160  public String globalInfo() {
161    return
162        "Simple EM (expectation maximisation) class.\n\n"
163      + "EM assigns a probability distribution to each instance which "
164      + "indicates the probability of it belonging to each of the clusters. "
165      + "EM can decide how many clusters to create by cross validation, or you "
166      + "may specify apriori how many clusters to generate.\n\n"
167      + "The cross validation performed to determine the number of clusters "
168      + "is done in the following steps:\n"
169      + "1. the number of clusters is set to 1\n"
170      + "2. the training set is split randomly into 10 folds.\n"
171      + "3. EM is performed 10 times using the 10 folds the usual CV way.\n"
172      + "4. the loglikelihood is averaged over all 10 results.\n"
173      + "5. if loglikelihood has increased the number of clusters is increased "
174      + "by 1 and the program continues at step 2. \n\n"
175      + "The number of folds is fixed to 10, as long as the number of "
176      + "instances in the training set is not smaller 10. If this is the case "
177      + "the number of folds is set equal to the number of instances.";
178  }
179
180  /**
181   * Returns an enumeration describing the available options.
182   *
183   * @return an enumeration of all the available options.
184   */
185  public Enumeration listOptions () {
186    Vector result = new Vector();
187   
188    result.addElement(new Option(
189        "\tnumber of clusters. If omitted or -1 specified, then \n"
190        + "\tcross validation is used to select the number of clusters.", 
191        "N", 1, "-N <num>"));
192
193    result.addElement(new Option(
194        "\tmax iterations."
195        + "\n(default 100)", 
196        "I", 1, "-I <num>"));
197   
198    result.addElement(new Option(
199        "\tverbose.",
200        "V", 0, "-V"));
201   
202    result.addElement(new Option(
203        "\tminimum allowable standard deviation for normal density\n"
204        + "\tcomputation\n"
205        + "\t(default 1e-6)",
206        "M",1,"-M <num>"));
207
208    result.addElement(
209              new Option("\tDisplay model in old format (good when there are "
210                         + "many clusters)\n",
211                         "O", 0, "-O"));
212
213    Enumeration en = super.listOptions();
214    while (en.hasMoreElements())
215      result.addElement(en.nextElement());
216   
217    return  result.elements();
218  }
219
220
221  /**
222   * Parses a given list of options. <p/>
223   *
224   <!-- options-start -->
225   * Valid options are: <p/>
226   *
227   * <pre> -N &lt;num&gt;
228   *  number of clusters. If omitted or -1 specified, then
229   *  cross validation is used to select the number of clusters.</pre>
230   *
231   * <pre> -I &lt;num&gt;
232   *  max iterations.
233   * (default 100)</pre>
234   *
235   * <pre> -V
236   *  verbose.</pre>
237   *
238   * <pre> -M &lt;num&gt;
239   *  minimum allowable standard deviation for normal density
240   *  computation
241   *  (default 1e-6)</pre>
242   *
243   * <pre> -O
244   *  Display model in old format (good when there are many clusters)
245   * </pre>
246   *
247   * <pre> -S &lt;num&gt;
248   *  Random number seed.
249   *  (default 100)</pre>
250   *
251   <!-- options-end -->
252   *
253   * @param options the list of options as an array of strings
254   * @throws Exception if an option is not supported
255   */
256  public void setOptions (String[] options)
257    throws Exception {
258    resetOptions();
259    setDebug(Utils.getFlag('V', options));
260    String optionString = Utils.getOption('I', options);
261
262    if (optionString.length() != 0) {
263      setMaxIterations(Integer.parseInt(optionString));
264    }
265
266    optionString = Utils.getOption('N', options);
267    if (optionString.length() != 0) {
268      setNumClusters(Integer.parseInt(optionString));
269    }
270
271    optionString = Utils.getOption('M', options);
272    if (optionString.length() != 0) {
273      setMinStdDev((new Double(optionString)).doubleValue());
274    }
275
276    setDisplayModelInOldFormat(Utils.getFlag('O', options));
277   
278    super.setOptions(options);
279  }
280
281  /**
282   * Returns the tip text for this property
283   * @return tip text for this property suitable for
284   * displaying in the explorer/experimenter gui
285   */
286  public String displayModelInOldFormatTipText() {
287    return "Use old format for model output. The old format is "
288      + "better when there are many clusters. The new format "
289      + "is better when there are fewer clusters and many attributes.";
290  }
291
292  /**
293   * Set whether to display model output in the old, original
294   * format.
295   *
296   * @param d true if model ouput is to be shown in the old format
297   */
298  public void setDisplayModelInOldFormat(boolean d) {
299    m_displayModelInOldFormat = d;
300  }
301
302  /**
303   * Get whether to display model output in the old, original
304   * format.
305   *
306   * @return true if model ouput is to be shown in the old format
307   */
308  public boolean getDisplayModelInOldFormat() {
309    return m_displayModelInOldFormat;
310  }
311
312  /**
313   * Returns the tip text for this property
314   * @return tip text for this property suitable for
315   * displaying in the explorer/experimenter gui
316   */
317  public String minStdDevTipText() {
318    return "set minimum allowable standard deviation";
319  }
320
321  /**
322   * Set the minimum value for standard deviation when calculating
323   * normal density. Reducing this value can help prevent arithmetic
324   * overflow resulting from multiplying large densities (arising from small
325   * standard deviations) when there are many singleton or near singleton
326   * values.
327   * @param m minimum value for standard deviation
328   */
329  public void setMinStdDev(double m) {
330    m_minStdDev = m;
331  }
332
333  public void setMinStdDevPerAtt(double [] m) {
334    m_minStdDevPerAtt = m;
335  }
336
337  /**
338   * Get the minimum allowable standard deviation.
339   * @return the minumum allowable standard deviation
340   */
341  public double getMinStdDev() {
342    return m_minStdDev;
343  }
344
345  /**
346   * Returns the tip text for this property
347   * @return tip text for this property suitable for
348   * displaying in the explorer/experimenter gui
349   */
350  public String numClustersTipText() {
351    return "set number of clusters. -1 to select number of clusters "
352      +"automatically by cross validation.";
353  }
354
355  /**
356   * Set the number of clusters (-1 to select by CV).
357   *
358   * @param n the number of clusters
359   * @throws Exception if n is 0
360   */
361  public void setNumClusters (int n)
362    throws Exception {
363   
364    if (n == 0) {
365      throw  new Exception("Number of clusters must be > 0. (or -1 to " 
366                           + "select by cross validation).");
367    }
368
369    if (n < 0) {
370      m_num_clusters = -1;
371      m_initialNumClusters = -1;
372    }
373    else {
374      m_num_clusters = n;
375      m_initialNumClusters = n;
376    }
377  }
378
379
380  /**
381   * Get the number of clusters
382   *
383   * @return the number of clusters.
384   */
385  public int getNumClusters () {
386    return  m_initialNumClusters;
387  }
388
389  /**
390   * Returns the tip text for this property
391   * @return tip text for this property suitable for
392   * displaying in the explorer/experimenter gui
393   */
394  public String maxIterationsTipText() {
395    return "maximum number of iterations";
396  }
397
398  /**
399   * Set the maximum number of iterations to perform
400   *
401   * @param i the number of iterations
402   * @throws Exception if i is less than 1
403   */
404  public void setMaxIterations (int i)
405    throws Exception {
406    if (i < 1) {
407      throw  new Exception("Maximum number of iterations must be > 0!");
408    }
409
410    m_max_iterations = i;
411  }
412
413
414  /**
415   * Get the maximum number of iterations
416   *
417   * @return the number of iterations
418   */
419  public int getMaxIterations () {
420    return  m_max_iterations;
421  }
422
423 
424  /**
425   * Returns the tip text for this property
426   * @return tip text for this property suitable for
427   * displaying in the explorer/experimenter gui
428   */
429  public String debugTipText() {
430    return "If set to true, clusterer may output additional info to " +
431      "the console.";
432  }
433
434
435  /**
436   * Set debug mode - verbose output
437   *
438   * @param v true for verbose output
439   */
440  public void setDebug (boolean v) {
441    m_verbose = v;
442  }
443
444
445  /**
446   * Get debug mode
447   *
448   * @return true if debug mode is set
449   */
450  public boolean getDebug () {
451    return  m_verbose;
452  }
453
454
455  /**
456   * Gets the current settings of EM.
457   *
458   * @return an array of strings suitable for passing to setOptions()
459   */
460  public String[] getOptions () {
461    int         i;
462    Vector      result;
463    String[]    options;
464
465    result = new Vector();
466
467    result.add("-I");
468    result.add("" + m_max_iterations);
469    result.add("-N");
470    result.add("" + getNumClusters());
471    result.add("-M");
472    result.add("" + getMinStdDev());
473    if (m_displayModelInOldFormat) {
474      result.add("-O");
475    }
476
477    options = super.getOptions();
478    for (i = 0; i < options.length; i++)
479      result.add(options[i]);
480
481    return (String[]) result.toArray(new String[result.size()]);         
482  }
483
484  /**
485   * Initialise estimators and storage.
486   *
487   * @param inst the instances
488   * @throws Exception if initialization fails
489   **/
490  private void EM_Init (Instances inst)
491    throws Exception {
492    int i, j, k;
493
494    // run k means 10 times and choose best solution
495    SimpleKMeans bestK = null;
496    double bestSqE = Double.MAX_VALUE;
497    for (i = 0; i < 10; i++) {
498      SimpleKMeans sk = new SimpleKMeans();
499      sk.setSeed(m_rr.nextInt());
500      sk.setNumClusters(m_num_clusters);
501      sk.setDisplayStdDevs(true);
502      sk.buildClusterer(inst);
503      if (sk.getSquaredError() < bestSqE) {
504        bestSqE = sk.getSquaredError();
505        bestK = sk;
506      }
507    }
508   
509    // initialize with best k-means solution
510    m_num_clusters = bestK.numberOfClusters();
511    m_weights = new double[inst.numInstances()][m_num_clusters];
512    m_model = new DiscreteEstimator[m_num_clusters][m_num_attribs];
513    m_modelNormal = new double[m_num_clusters][m_num_attribs][3];
514    m_priors = new double[m_num_clusters];
515    Instances centers = bestK.getClusterCentroids();
516    Instances stdD = bestK.getClusterStandardDevs();
517    int [][][] nominalCounts = bestK.getClusterNominalCounts();
518    int [] clusterSizes = bestK.getClusterSizes();
519
520    for (i = 0; i < m_num_clusters; i++) {
521      Instance center = centers.instance(i);
522      for (j = 0; j < m_num_attribs; j++) {
523        if (inst.attribute(j).isNominal()) {
524          m_model[i][j] = new DiscreteEstimator(m_theInstances.
525                                                attribute(j).numValues()
526                                                , true);
527          for (k = 0; k < inst.attribute(j).numValues(); k++) {
528            m_model[i][j].addValue(k, nominalCounts[i][j][k]);
529          }
530        } else {
531          double minStdD = (m_minStdDevPerAtt != null)
532            ? m_minStdDevPerAtt[j]
533            : m_minStdDev;
534          double mean = (center.isMissing(j))
535            ? inst.meanOrMode(j)
536            : center.value(j);
537          m_modelNormal[i][j][0] = mean;
538          double stdv = (stdD.instance(i).isMissing(j))
539            ? ((m_maxValues[j] - m_minValues[j]) / (2 * m_num_clusters))
540            : stdD.instance(i).value(j);
541          if (stdv < minStdD) {
542            stdv = inst.attributeStats(j).numericStats.stdDev;
543            if (Double.isInfinite(stdv)) {
544              stdv = minStdD;
545            }
546            if (stdv < minStdD) {
547              stdv = minStdD;
548            }
549          }
550          if (stdv <= 0) {
551            stdv = m_minStdDev;
552          }
553
554          m_modelNormal[i][j][1] = stdv;
555          m_modelNormal[i][j][2] = 1.0;
556        }
557      } 
558    }   
559   
560   
561    for (j = 0; j < m_num_clusters; j++) {
562      //      m_priors[j] += 1.0;
563      m_priors[j] = clusterSizes[j];
564    }
565    Utils.normalize(m_priors);
566  }
567
568
569  /**
570   * calculate prior probabilites for the clusters
571   *
572   * @param inst the instances
573   * @throws Exception if priors can't be calculated
574   **/
575  private void estimate_priors (Instances inst)
576    throws Exception {
577
578    for (int i = 0; i < m_num_clusters; i++) {
579      m_priors[i] = 0.0;
580    }
581
582    for (int i = 0; i < inst.numInstances(); i++) {
583      for (int j = 0; j < m_num_clusters; j++) {
584        m_priors[j] += inst.instance(i).weight() * m_weights[i][j];
585      }
586    }
587
588    Utils.normalize(m_priors);
589  }
590
591
592  /** Constant for normal distribution. */
593  private static double m_normConst = Math.log(Math.sqrt(2*Math.PI));
594
595  /**
596   * Density function of normal distribution.
597   * @param x input value
598   * @param mean mean of distribution
599   * @param stdDev standard deviation of distribution
600   * @return the density
601   */
602  private double logNormalDens (double x, double mean, double stdDev) {
603
604    double diff = x - mean;
605    //    System.err.println("x: "+x+" mean: "+mean+" diff: "+diff+" stdv: "+stdDev);
606    //    System.err.println("diff*diff/(2*stdv*stdv): "+ (diff * diff / (2 * stdDev * stdDev)));
607   
608    return - (diff * diff / (2 * stdDev * stdDev))  - m_normConst - Math.log(stdDev);
609  }
610
611  /**
612   * New probability estimators for an iteration
613   */
614  private void new_estimators () {
615    for (int i = 0; i < m_num_clusters; i++) {
616      for (int j = 0; j < m_num_attribs; j++) {
617        if (m_theInstances.attribute(j).isNominal()) {
618          m_model[i][j] = new DiscreteEstimator(m_theInstances.
619                                                attribute(j).numValues()
620                                                , true);
621        }
622        else {
623          m_modelNormal[i][j][0] = m_modelNormal[i][j][1] = 
624            m_modelNormal[i][j][2] = 0.0;
625        }
626      }
627    }
628  }
629
630
631  /**
632   * The M step of the EM algorithm.
633   * @param inst the training instances
634   * @throws Exception if something goes wrong
635   */
636  private void M (Instances inst)
637    throws Exception {
638
639    int i, j, l;
640
641    new_estimators();
642
643    for (i = 0; i < m_num_clusters; i++) {
644      for (j = 0; j < m_num_attribs; j++) {
645        for (l = 0; l < inst.numInstances(); l++) {
646          Instance in = inst.instance(l);
647          if (!in.isMissing(j)) {
648            if (inst.attribute(j).isNominal()) {
649              m_model[i][j].addValue(in.value(j), 
650                                     in.weight() * m_weights[l][i]);
651            }
652            else {
653              m_modelNormal[i][j][0] += (in.value(j) * in.weight() *
654                                         m_weights[l][i]);
655              m_modelNormal[i][j][2] += in.weight() * m_weights[l][i];
656              m_modelNormal[i][j][1] += (in.value(j) * 
657                                         in.value(j) * in.weight() * m_weights[l][i]);
658            }
659          }
660        }
661      }
662    }
663   
664    // calcualte mean and std deviation for numeric attributes
665    for (j = 0; j < m_num_attribs; j++) {
666      if (!inst.attribute(j).isNominal()) {
667        for (i = 0; i < m_num_clusters; i++) {
668          if (m_modelNormal[i][j][2] <= 0) {
669            m_modelNormal[i][j][1] = Double.MAX_VALUE;
670            //      m_modelNormal[i][j][0] = 0;
671            m_modelNormal[i][j][0] = m_minStdDev;
672          } else {
673             
674            // variance
675            m_modelNormal[i][j][1] = (m_modelNormal[i][j][1] - 
676                                      (m_modelNormal[i][j][0] * 
677                                       m_modelNormal[i][j][0] / 
678                                       m_modelNormal[i][j][2])) / 
679              (m_modelNormal[i][j][2]);
680           
681            if (m_modelNormal[i][j][1] < 0) {
682              m_modelNormal[i][j][1] = 0;
683            }
684           
685            // std dev     
686            double minStdD = (m_minStdDevPerAtt != null)
687            ? m_minStdDevPerAtt[j]
688            : m_minStdDev;
689
690            m_modelNormal[i][j][1] = Math.sqrt(m_modelNormal[i][j][1]);             
691
692            if ((m_modelNormal[i][j][1] <= minStdD)) {
693              m_modelNormal[i][j][1] = inst.attributeStats(j).numericStats.stdDev;
694              if ((m_modelNormal[i][j][1] <= minStdD)) {
695                m_modelNormal[i][j][1] = minStdD;
696              }
697            }
698            if ((m_modelNormal[i][j][1] <= 0)) {
699              m_modelNormal[i][j][1] = m_minStdDev;
700            }
701            if (Double.isInfinite(m_modelNormal[i][j][1])) {
702              m_modelNormal[i][j][1] = m_minStdDev;
703            }
704           
705            // mean
706            m_modelNormal[i][j][0] /= m_modelNormal[i][j][2];
707          }
708        }
709      }
710    }
711  }
712
713  /**
714   * The E step of the EM algorithm. Estimate cluster membership
715   * probabilities.
716   *
717   * @param inst the training instances
718   * @param change_weights whether to change the weights
719   * @return the average log likelihood
720   * @throws Exception if computation fails
721   */
722  private double E (Instances inst, boolean change_weights)
723    throws Exception {
724
725    double loglk = 0.0, sOW = 0.0;
726
727    for (int l = 0; l < inst.numInstances(); l++) {
728
729      Instance in = inst.instance(l);
730
731      loglk += in.weight() * logDensityForInstance(in);
732      sOW += in.weight();
733
734      if (change_weights) {
735        m_weights[l] = distributionForInstance(in);
736      }
737    }
738   
739    // reestimate priors
740    if (change_weights) {
741      estimate_priors(inst);
742    }
743    return  loglk / sOW;
744  }
745 
746 
747  /**
748   * Constructor.
749   *
750   **/
751  public EM () {
752    super();
753   
754    m_SeedDefault = 100;
755    resetOptions();
756  }
757
758
759  /**
760   * Reset to default options
761   */
762  protected void resetOptions () {
763    m_minStdDev = 1e-6;
764    m_max_iterations = 100;
765    m_Seed = m_SeedDefault;
766    m_num_clusters = -1;
767    m_initialNumClusters = -1;
768    m_verbose = false;
769  }
770
771  /**
772   * Return the normal distributions for the cluster models
773   *
774   * @return a <code>double[][][]</code> value
775   */
776  public double [][][] getClusterModelsNumericAtts() {
777    return m_modelNormal;
778  }
779
780  /**
781   * Return the priors for the clusters
782   *
783   * @return a <code>double[]</code> value
784   */
785  public double [] getClusterPriors() {
786    return m_priors;
787  }
788
789  /**
790   * Outputs the generated clusters into a string.
791   *
792   * @return the clusterer in string representation
793   */
794  public String toString() {
795    if (m_displayModelInOldFormat) {
796      return toStringOriginal();
797    }
798
799    if (m_priors == null) {
800      return "No clusterer built yet!";
801    }
802    StringBuffer temp = new StringBuffer();
803    temp.append("\nEM\n==\n");
804    if (m_initialNumClusters == -1) {
805      temp.append("\nNumber of clusters selected by cross validation: "
806                  +m_num_clusters+"\n");
807    } else {
808      temp.append("\nNumber of clusters: " + m_num_clusters + "\n");
809    }
810
811    int maxWidth = 0;
812    int maxAttWidth = 0;
813    boolean containsKernel = false;
814   
815    // set up max widths
816    // attributes
817    for (int i = 0; i < m_num_attribs; i++) {
818      Attribute a = m_theInstances.attribute(i);
819      if (a.name().length() > maxAttWidth) {
820        maxAttWidth = m_theInstances.attribute(i).name().length();
821      }
822      if (a.isNominal()) {
823        // check values
824        for (int j = 0; j < a.numValues(); j++) {
825          String val = a.value(j) + "  ";
826          if (val.length() > maxAttWidth) {
827            maxAttWidth = val.length();
828          }
829        }
830      }
831    }
832
833    for (int i = 0; i < m_num_clusters; i++) {
834      for (int j = 0; j < m_num_attribs; j++) {
835        if (m_theInstances.attribute(j).isNumeric()) {
836          // check mean and std. dev. against maxWidth
837          double mean = Math.log(Math.abs(m_modelNormal[i][j][0])) / Math.log(10.0);
838          double stdD = Math.log(Math.abs(m_modelNormal[i][j][1])) / Math.log(10.0);
839          double width = (mean > stdD)
840            ? mean
841            : stdD;
842          if (width < 0) {
843            width = 1;
844          }
845          // decimal + # decimal places + 1
846          width += 6.0;
847          if ((int)width > maxWidth) {
848            maxWidth = (int)width;
849          }
850        } else {
851          // nominal distributions
852          DiscreteEstimator d = (DiscreteEstimator)m_model[i][j];
853          for (int k = 0; k < d.getNumSymbols(); k++) {
854            String size = Utils.doubleToString(d.getCount(k), maxWidth, 4).trim();
855            if (size.length() > maxWidth) {
856              maxWidth = size.length();
857            }
858          }
859          int sum = 
860            Utils.doubleToString(d.getSumOfCounts(), maxWidth, 4).trim().length();
861          if (sum > maxWidth) {
862            maxWidth = sum;
863          }
864        }
865      }
866    }
867
868    if (maxAttWidth < "Attribute".length()) {
869      maxAttWidth = "Attribute".length();
870    }   
871   
872    maxAttWidth += 2;
873
874    temp.append("\n\n");
875    temp.append(pad("Cluster", " ", 
876                    (maxAttWidth + maxWidth + 1) - "Cluster".length(), 
877                    true));
878   
879    temp.append("\n");
880    temp.append(pad("Attribute", " ", maxAttWidth - "Attribute".length(), false));
881
882    // cluster #'s
883    for (int i = 0; i < m_num_clusters; i++) {
884      String classL = "" + i;
885      temp.append(pad(classL, " ", maxWidth + 1 - classL.length(), true));
886    }
887    temp.append("\n");
888
889    // cluster priors
890    temp.append(pad("", " ", maxAttWidth, true));
891    for (int i = 0; i < m_num_clusters; i++) {
892      String priorP = Utils.doubleToString(m_priors[i], maxWidth, 2).trim();
893      priorP = "(" + priorP + ")";
894      temp.append(pad(priorP, " ", maxWidth + 1 - priorP.length(), true));
895    }
896
897    temp.append("\n");
898    temp.append(pad("", "=", maxAttWidth + 
899                    (maxWidth * m_num_clusters)
900                    + m_num_clusters + 1, true));
901    temp.append("\n");
902
903    for (int i = 0; i < m_num_attribs; i++) {
904      String attName = m_theInstances.attribute(i).name();
905      temp.append(attName + "\n");
906
907      if (m_theInstances.attribute(i).isNumeric()) {
908        String meanL = "  mean";
909        temp.append(pad(meanL, " ", maxAttWidth + 1 - meanL.length(), false));
910        for (int j = 0; j < m_num_clusters; j++) {
911          // means
912          String mean = 
913            Utils.doubleToString(m_modelNormal[j][i][0], maxWidth, 4).trim();
914          temp.append(pad(mean, " ", maxWidth + 1 - mean.length(), true));
915        }
916        temp.append("\n");           
917        // now do std deviations
918        String stdDevL = "  std. dev.";
919        temp.append(pad(stdDevL, " ", maxAttWidth + 1 - stdDevL.length(), false));
920        for (int j = 0; j < m_num_clusters; j++) {
921          String stdDev = 
922            Utils.doubleToString(m_modelNormal[j][i][1], maxWidth, 4).trim();
923          temp.append(pad(stdDev, " ", maxWidth + 1 - stdDev.length(), true));
924        }
925        temp.append("\n\n");
926      } else {
927        Attribute a = m_theInstances.attribute(i);
928        for (int j = 0; j < a.numValues(); j++) {
929          String val = "  " + a.value(j);
930          temp.append(pad(val, " ", maxAttWidth + 1 - val.length(), false));
931          for (int k = 0; k < m_num_clusters; k++) {
932            DiscreteEstimator d = (DiscreteEstimator)m_model[k][i];
933            String count = Utils.doubleToString(d.getCount(j), maxWidth, 4).trim();
934            temp.append(pad(count, " ", maxWidth + 1 - count.length(), true));
935          }
936          temp.append("\n");
937        }
938        // do the totals
939        String total = "  [total]";
940        temp.append(pad(total, " ", maxAttWidth + 1 - total.length(), false));
941        for (int k = 0; k < m_num_clusters; k++) {
942          DiscreteEstimator d = (DiscreteEstimator)m_model[k][i];
943          String count = 
944            Utils.doubleToString(d.getSumOfCounts(), maxWidth, 4).trim();
945            temp.append(pad(count, " ", maxWidth + 1 - count.length(), true));
946        }
947        temp.append("\n");       
948      }
949    }
950
951    return temp.toString();
952  }
953
954  private String pad(String source, String padChar, 
955                     int length, boolean leftPad) {
956    StringBuffer temp = new StringBuffer();
957
958    if (leftPad) {
959      for (int i = 0; i< length; i++) {
960        temp.append(padChar);
961      }
962      temp.append(source);
963    } else {
964      temp.append(source);
965      for (int i = 0; i< length; i++) {
966        temp.append(padChar);
967      }
968    }
969    return temp.toString();
970  }
971
972  /**
973   * Outputs the generated clusters into a string.
974   *
975   * @return the clusterer in string representation
976   */
977  protected String toStringOriginal () {
978    if (m_priors == null) {
979      return "No clusterer built yet!";
980    }
981    StringBuffer temp = new StringBuffer();
982    temp.append("\nEM\n==\n");
983    if (m_initialNumClusters == -1) {
984      temp.append("\nNumber of clusters selected by cross validation: "
985                  +m_num_clusters+"\n");
986    } else {
987      temp.append("\nNumber of clusters: " + m_num_clusters + "\n");
988    }
989
990    for (int j = 0; j < m_num_clusters; j++) {
991      temp.append("\nCluster: " + j + " Prior probability: " 
992                  + Utils.doubleToString(m_priors[j], 4) + "\n\n");
993
994      for (int i = 0; i < m_num_attribs; i++) {
995        temp.append("Attribute: " + m_theInstances.attribute(i).name() + "\n");
996
997        if (m_theInstances.attribute(i).isNominal()) {
998          if (m_model[j][i] != null) {
999            temp.append(m_model[j][i].toString());
1000          }
1001        }
1002        else {
1003          temp.append("Normal Distribution. Mean = " 
1004                      + Utils.doubleToString(m_modelNormal[j][i][0], 4) 
1005                      + " StdDev = " 
1006                      + Utils.doubleToString(m_modelNormal[j][i][1], 4) 
1007                      + "\n");
1008        }
1009      }
1010    }
1011
1012    return  temp.toString();
1013  }
1014
1015
1016  /**
1017   * verbose output for debugging
1018   * @param inst the training instances
1019   */
1020  private void EM_Report (Instances inst) {
1021    int i, j, l, m;
1022    System.out.println("======================================");
1023
1024    for (j = 0; j < m_num_clusters; j++) {
1025      for (i = 0; i < m_num_attribs; i++) {
1026        System.out.println("Clust: " + j + " att: " + i + "\n");
1027
1028        if (m_theInstances.attribute(i).isNominal()) {
1029          if (m_model[j][i] != null) {
1030            System.out.println(m_model[j][i].toString());
1031          }
1032        }
1033        else {
1034          System.out.println("Normal Distribution. Mean = " 
1035                             + Utils.doubleToString(m_modelNormal[j][i][0]
1036                                                    , 8, 4) 
1037                             + " StandardDev = " 
1038                             + Utils.doubleToString(m_modelNormal[j][i][1]
1039                                                    , 8, 4) 
1040                             + " WeightSum = " 
1041                             + Utils.doubleToString(m_modelNormal[j][i][2]
1042                                                    , 8, 4));
1043        }
1044      }
1045    }
1046   
1047    for (l = 0; l < inst.numInstances(); l++) {
1048      m = Utils.maxIndex(m_weights[l]);
1049      System.out.print("Inst " + Utils.doubleToString((double)l, 5, 0) 
1050                       + " Class " + m + "\t");
1051      for (j = 0; j < m_num_clusters; j++) {
1052        System.out.print(Utils.doubleToString(m_weights[l][j], 7, 5) + "  ");
1053      }
1054      System.out.println();
1055    }
1056  }
1057
1058
1059  /**
1060   * estimate the number of clusters by cross validation on the training
1061   * data.
1062   *
1063   * @throws Exception if something goes wrong
1064   */
1065  private void CVClusters ()
1066    throws Exception {
1067    double CVLogLikely = -Double.MAX_VALUE;
1068    double templl, tll;
1069    boolean CVincreased = true;
1070    m_num_clusters = 1;
1071    int num_clusters = m_num_clusters;
1072    int i;
1073    Random cvr;
1074    Instances trainCopy;
1075    int numFolds = (m_theInstances.numInstances() < 10) 
1076      ? m_theInstances.numInstances() 
1077      : 10;
1078
1079    boolean ok = true;
1080    int seed = getSeed();
1081    int restartCount = 0;
1082    CLUSTER_SEARCH: while (CVincreased) {
1083      // theInstances.stratify(10);
1084       
1085      CVincreased = false;
1086      cvr = new Random(getSeed());
1087      trainCopy = new Instances(m_theInstances);
1088      trainCopy.randomize(cvr);
1089      templl = 0.0;
1090      for (i = 0; i < numFolds; i++) {
1091        Instances cvTrain = trainCopy.trainCV(numFolds, i, cvr);
1092        if (num_clusters > cvTrain.numInstances()) {
1093          break CLUSTER_SEARCH;
1094        }
1095        Instances cvTest = trainCopy.testCV(numFolds, i);
1096        m_rr = new Random(seed);
1097        for (int z=0; z<10; z++) m_rr.nextDouble();
1098        m_num_clusters = num_clusters;
1099        EM_Init(cvTrain);
1100        try {
1101          iterate(cvTrain, false);
1102        } catch (Exception ex) {
1103          // catch any problems - i.e. empty clusters occuring
1104          ex.printStackTrace();
1105          //          System.err.println("Restarting after CV training failure ("+num_clusters+" clusters");
1106          seed++;
1107          restartCount++;
1108          ok = false;
1109          if (restartCount > 5) {
1110            break CLUSTER_SEARCH;
1111          }
1112          break;
1113        }
1114        try {
1115          tll = E(cvTest, false);
1116        } catch (Exception ex) {
1117          // catch any problems - i.e. empty clusters occuring
1118          //          ex.printStackTrace();
1119          ex.printStackTrace();
1120          //          System.err.println("Restarting after CV testing failure ("+num_clusters+" clusters");
1121          //          throw new Exception(ex);
1122          seed++;
1123          restartCount++;
1124          ok = false;
1125          if (restartCount > 5) {
1126            break CLUSTER_SEARCH;
1127          }
1128          break;
1129        }
1130
1131        if (m_verbose) {
1132          System.out.println("# clust: " + num_clusters + " Fold: " + i
1133                             + " Loglikely: " + tll);
1134        }
1135        templl += tll;
1136      }
1137
1138      if (ok) {
1139        restartCount = 0;
1140        seed = getSeed();
1141        templl /= (double)numFolds;
1142       
1143        if (m_verbose) {
1144          System.out.println("===================================" 
1145                             + "==============\n# clust: " 
1146                             + num_clusters
1147                             + " Mean Loglikely: " 
1148                             + templl
1149                             + "\n================================" 
1150                             + "=================");
1151        }
1152       
1153        if (templl > CVLogLikely) {
1154          CVLogLikely = templl;
1155          CVincreased = true;
1156          num_clusters++;
1157        }
1158      }
1159    }
1160
1161    if (m_verbose) {
1162      System.out.println("Number of clusters: " + (num_clusters - 1));
1163    }
1164
1165    m_num_clusters = num_clusters - 1;
1166  }
1167
1168
1169  /**
1170   * Returns the number of clusters.
1171   *
1172   * @return the number of clusters generated for a training dataset.
1173   * @throws Exception if number of clusters could not be returned
1174   * successfully
1175   */
1176  public int numberOfClusters ()
1177    throws Exception {
1178    if (m_num_clusters == -1) {
1179      throw  new Exception("Haven't generated any clusters!");
1180    }
1181
1182    return  m_num_clusters;
1183  }
1184
1185 /**
1186  * Updates the minimum and maximum values for all the attributes
1187  * based on a new instance.
1188  *
1189  * @param instance the new instance
1190  */
1191  private void updateMinMax(Instance instance) {
1192   
1193    for (int j = 0; j < m_theInstances.numAttributes(); j++) {
1194      if (!instance.isMissing(j)) {
1195        if (Double.isNaN(m_minValues[j])) {
1196          m_minValues[j] = instance.value(j);
1197          m_maxValues[j] = instance.value(j);
1198        } else {
1199          if (instance.value(j) < m_minValues[j]) {
1200            m_minValues[j] = instance.value(j);
1201          } else {
1202            if (instance.value(j) > m_maxValues[j]) {
1203              m_maxValues[j] = instance.value(j);
1204            }
1205          }
1206        }
1207      }
1208    }
1209  }
1210
1211  /**
1212   * Returns default capabilities of the clusterer (i.e., the ones of
1213   * SimpleKMeans).
1214   *
1215   * @return      the capabilities of this clusterer
1216   */
1217  public Capabilities getCapabilities() {
1218    Capabilities result = new SimpleKMeans().getCapabilities();
1219    result.setOwner(this);
1220    return result;
1221  }
1222 
1223  /**
1224   * Generates a clusterer. Has to initialize all fields of the clusterer
1225   * that are not being set via options.
1226   *
1227   * @param data set of instances serving as training data
1228   * @throws Exception if the clusterer has not been
1229   * generated successfully
1230   */
1231  public void buildClusterer (Instances data)
1232    throws Exception {
1233   
1234    // can clusterer handle the data?
1235    getCapabilities().testWithFail(data);
1236
1237    m_replaceMissing = new ReplaceMissingValues();
1238    Instances instances = new Instances(data);
1239    instances.setClassIndex(-1);
1240    m_replaceMissing.setInputFormat(instances);
1241    data = weka.filters.Filter.useFilter(instances, m_replaceMissing);
1242    instances = null;
1243   
1244    m_theInstances = data;
1245
1246    // calculate min and max values for attributes
1247    m_minValues = new double [m_theInstances.numAttributes()];
1248    m_maxValues = new double [m_theInstances.numAttributes()];
1249    for (int i = 0; i < m_theInstances.numAttributes(); i++) {
1250      m_minValues[i] = m_maxValues[i] = Double.NaN;
1251    }
1252    for (int i = 0; i < m_theInstances.numInstances(); i++) {
1253      updateMinMax(m_theInstances.instance(i));
1254    }
1255
1256    doEM();
1257   
1258    // save memory
1259    m_theInstances = new Instances(m_theInstances,0);
1260  }
1261
1262  /**
1263   * Returns the cluster priors.
1264   *
1265   * @return the cluster priors
1266   */
1267  public double[] clusterPriors() {
1268
1269    double[] n = new double[m_priors.length];
1270 
1271    System.arraycopy(m_priors, 0, n, 0, n.length);
1272    return n;
1273  }
1274
1275  /**
1276   * Computes the log of the conditional density (per cluster) for a given instance.
1277   *
1278   * @param inst the instance to compute the density for
1279   * @return an array containing the estimated densities
1280   * @throws Exception if the density could not be computed
1281   * successfully
1282   */
1283  public double[] logDensityPerClusterForInstance(Instance inst) throws Exception {
1284
1285    int i, j;
1286    double logprob;
1287    double[] wghts = new double[m_num_clusters];
1288   
1289    m_replaceMissing.input(inst);
1290    inst = m_replaceMissing.output();
1291
1292    for (i = 0; i < m_num_clusters; i++) {
1293      //      System.err.println("Cluster : "+i);
1294      logprob = 0.0;
1295
1296      for (j = 0; j < m_num_attribs; j++) {
1297        if (!inst.isMissing(j)) {
1298          if (inst.attribute(j).isNominal()) {
1299            logprob += Math.log(m_model[i][j].getProbability(inst.value(j)));
1300          }
1301          else { // numeric attribute
1302            logprob += logNormalDens(inst.value(j), 
1303                                     m_modelNormal[i][j][0], 
1304                                     m_modelNormal[i][j][1]);
1305            /*      System.err.println(logNormalDens(inst.value(j),
1306                                     m_modelNormal[i][j][0],
1307                                     m_modelNormal[i][j][1]) + " "); */
1308          }
1309        }
1310      }
1311      //      System.err.println("");
1312
1313      wghts[i] = logprob;
1314    }
1315    return  wghts;
1316  }
1317
1318
1319  /**
1320   * Perform the EM algorithm
1321   *
1322   * @throws Exception if something goes wrong
1323   */
1324  private void doEM ()
1325    throws Exception {
1326   
1327    if (m_verbose) {
1328      System.out.println("Seed: " + getSeed());
1329    }
1330
1331    m_rr = new Random(getSeed());
1332
1333    // throw away numbers to avoid problem of similar initial numbers
1334    // from a similar seed
1335    for (int i=0; i<10; i++) m_rr.nextDouble();
1336
1337    m_num_instances = m_theInstances.numInstances();
1338    m_num_attribs = m_theInstances.numAttributes();
1339
1340    if (m_verbose) {
1341      System.out.println("Number of instances: " 
1342                         + m_num_instances
1343                         + "\nNumber of atts: " 
1344                         + m_num_attribs
1345                         + "\n");
1346    }
1347
1348    // setDefaultStdDevs(theInstances);
1349    // cross validate to determine number of clusters?
1350    if (m_initialNumClusters == -1) {
1351      if (m_theInstances.numInstances() > 9) {
1352        CVClusters();
1353        m_rr = new Random(getSeed());
1354        for (int i=0; i<10; i++) m_rr.nextDouble();
1355      } else {
1356        m_num_clusters = 1;
1357      }
1358    }
1359
1360    // fit full training set
1361    EM_Init(m_theInstances);
1362    m_loglikely = iterate(m_theInstances, m_verbose);
1363  }
1364
1365
1366  /**
1367   * iterates the E and M steps until the log likelihood of the data
1368   * converges.
1369   *
1370   * @param inst the training instances.
1371   * @param report be verbose.
1372   * @return the log likelihood of the data
1373   * @throws Exception if something goes wrong
1374   */
1375  private double iterate (Instances inst, boolean report)
1376    throws Exception {
1377   
1378    int i;
1379    double llkold = 0.0;
1380    double llk = 0.0;
1381
1382    if (report) {
1383      EM_Report(inst);
1384    }
1385
1386    boolean ok = false;
1387    int seed = getSeed();
1388    int restartCount = 0;
1389    while (!ok) {
1390      try {
1391        for (i = 0; i < m_max_iterations; i++) {
1392          llkold = llk;
1393          llk = E(inst, true);
1394         
1395          if (report) {
1396            System.out.println("Loglikely: " + llk);
1397          }
1398         
1399          if (i > 0) {
1400            if ((llk - llkold) < 1e-6) {
1401              break;
1402            }
1403          }
1404          M(inst);
1405        }
1406        ok = true;
1407      } catch (Exception ex) {
1408        //        System.err.println("Restarting after training failure");
1409        ex.printStackTrace();
1410        seed++;
1411        restartCount++;
1412        m_rr = new Random(seed);
1413        for (int z = 0; z < 10; z++) {
1414          m_rr.nextDouble(); m_rr.nextInt();
1415        }
1416        if (restartCount > 5) {
1417          //          System.err.println("Reducing the number of clusters");
1418          m_num_clusters--;
1419          restartCount = 0;
1420        }
1421        EM_Init(m_theInstances);
1422      }
1423    }
1424     
1425    if (report) {
1426      EM_Report(inst);
1427    }
1428
1429    return  llk;
1430  }
1431 
1432  /**
1433   * Returns the revision string.
1434   *
1435   * @return            the revision
1436   */
1437  public String getRevision() {
1438    return RevisionUtils.extract("$Revision: 1.44 $");
1439  }
1440
1441  // ============
1442  // Test method.
1443  // ============
1444  /**
1445   * Main method for testing this class.
1446   *
1447   * @param argv should contain the following arguments: <p>
1448   * -t training file [-T test file] [-N number of clusters] [-S random seed]
1449   */
1450  public static void main (String[] argv) {
1451    runClusterer(new EM(), argv);
1452  }
1453}
1454
Note: See TracBrowser for help on using the repository browser.