source: src/main/java/weka/clusterers/SimpleKMeans.java @ 28

Last change on this file since 28 was 4, checked in by gnappo, 14 years ago

Import di weka.

File size: 38.2 KB
RevLine 
[4]1/*
2 *    This program is free software; you can redistribute it and/or modify
3 *    it under the terms of the GNU General Public License as published by
4 *    the Free Software Foundation; either version 2 of the License, or
5 *    (at your option) any later version.
6 *
7 *    This program is distributed in the hope that it will be useful,
8 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
9 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
10 *    GNU General Public License for more details.
11 *
12 *    You should have received a copy of the GNU General Public License
13 *    along with this program; if not, write to the Free Software
14 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
15 */
16
17/*
18 *    SimpleKMeans.java
19 *    Copyright (C) 2000 University of Waikato, Hamilton, New Zealand
20 *
21 */
22package weka.clusterers;
23
24import weka.classifiers.rules.DecisionTableHashKey;
25import weka.core.Attribute;
26import weka.core.Capabilities;
27import weka.core.DistanceFunction;
28import weka.core.EuclideanDistance;
29import weka.core.Instance;
30import weka.core.DenseInstance;
31import weka.core.Instances;
32import weka.core.ManhattanDistance;
33import weka.core.Option;
34import weka.core.RevisionUtils;
35import weka.core.Utils;
36import weka.core.WeightedInstancesHandler;
37import weka.core.Capabilities.Capability;
38import weka.filters.Filter;
39import weka.filters.unsupervised.attribute.ReplaceMissingValues;
40
41import java.util.Enumeration;
42import java.util.HashMap;
43import java.util.Random;
44import java.util.Vector;
45
46/**
47 <!-- globalinfo-start -->
48 * Cluster data using the k means algorithm
49 * <p/>
50 <!-- globalinfo-end -->
51 *
52 <!-- options-start -->
53 * Valid options are: <p/>
54 *
55 * <pre> -N &lt;num&gt;
56 *  number of clusters.
57 *  (default 2).</pre>
58 *
59 * <pre> -V
60 *  Display std. deviations for centroids.
61 * </pre>
62 *
63 * <pre> -M
64 *  Replace missing values with mean/mode.
65 * </pre>
66 *
67 * <pre> -S &lt;num&gt;
68 *  Random number seed.
69 *  (default 10)</pre>
70 *
71 * <pre> -A &lt;classname and options&gt;
72 *  Distance function to be used for instance comparison
73 *  (default weka.core.EuclidianDistance)</pre>
74 *
75 * <pre> -I &lt;num&gt;
76 *  Maximum number of iterations. </pre>
77 *
78 * <pre> -O
79 *  Preserve order of instances. </pre>
80 *
81 *
82 <!-- options-end -->
83 *
84 * @author Mark Hall (mhall@cs.waikato.ac.nz)
85 * @author Eibe Frank (eibe@cs.waikato.ac.nz)
86 * @version $Revision: 5987 $
87 * @see RandomizableClusterer
88 */
89public class SimpleKMeans 
90  extends RandomizableClusterer
91  implements NumberOfClustersRequestable, WeightedInstancesHandler {
92
93  /** for serialization */
94  static final long serialVersionUID = -3235809600124455376L;
95 
96  /**
97   * replace missing values in training instances
98   */
99  private ReplaceMissingValues m_ReplaceMissingFilter;
100
101  /**
102   * number of clusters to generate
103   */
104  private int m_NumClusters = 2;
105
106  /**
107   * holds the cluster centroids
108   */
109  private Instances m_ClusterCentroids;
110
111  /**
112   * Holds the standard deviations of the numeric attributes in each cluster
113   */
114  private Instances m_ClusterStdDevs;
115
116 
117  /**
118   * For each cluster, holds the frequency counts for the values of each
119   * nominal attribute
120   */
121  private int [][][] m_ClusterNominalCounts;
122  private int[][] m_ClusterMissingCounts;
123 
124  /**
125   * Stats on the full data set for comparison purposes
126   * In case the attribute is numeric the value is the mean if is
127   * being used the Euclidian distance or the median if Manhattan distance
128   * and if the attribute is nominal then it's mode is saved
129   */
130  private double[] m_FullMeansOrMediansOrModes;
131  private double[] m_FullStdDevs;
132  private int[][] m_FullNominalCounts;
133  private int[] m_FullMissingCounts;
134
135  /**
136   * Display standard deviations for numeric atts
137   */
138  private boolean m_displayStdDevs;
139
140  /**
141   * Replace missing values globally?
142   */
143  private boolean m_dontReplaceMissing = false;
144
145  /**
146   * The number of instances in each cluster
147   */
148  private int [] m_ClusterSizes;
149
150  /**
151   * Maximum number of iterations to be executed
152   */
153  private int m_MaxIterations = 500;
154
155  /**
156   * Keep track of the number of iterations completed before convergence
157   */
158  private int m_Iterations = 0;
159
160  /**
161   * Holds the squared errors for all clusters
162   */
163  private double [] m_squaredErrors;
164
165  /** the distance function used. */
166  protected DistanceFunction m_DistanceFunction = new EuclideanDistance();
167
168  /**
169   * Preserve order of instances
170   */
171  private boolean m_PreserveOrder = false;
172       
173  /**
174   * Assignments obtained
175   */
176  protected int[] m_Assignments = null;
177       
178  /**
179   * the default constructor
180   */
181  public SimpleKMeans() {
182    super();
183   
184    m_SeedDefault = 10;
185    setSeed(m_SeedDefault);
186  }
187 
188  /**
189   * Returns a string describing this clusterer
190   * @return a description of the evaluator suitable for
191   * displaying in the explorer/experimenter gui
192   */
193  public String globalInfo() {
194    return "Cluster data using the k means algorithm. Can use either "
195      + "the Euclidean distance (default) or the Manhattan distance."
196      + " If the Manhattan distance is used, then centroids are computed "
197      + "as the component-wise median rather than mean.";
198  }
199
200  /**
201   * Returns default capabilities of the clusterer.
202   *
203   * @return      the capabilities of this clusterer
204   */
205  public Capabilities getCapabilities() {
206    Capabilities result = super.getCapabilities();
207    result.disableAll();
208    result.enable(Capability.NO_CLASS);
209
210    // attributes
211    result.enable(Capability.NOMINAL_ATTRIBUTES);
212    result.enable(Capability.NUMERIC_ATTRIBUTES);
213    result.enable(Capability.MISSING_VALUES);
214
215    return result;
216  }
217
218  /**
219   * Generates a clusterer. Has to initialize all fields of the clusterer
220   * that are not being set via options.
221   *
222   * @param data set of instances serving as training data
223   * @throws Exception if the clusterer has not been
224   * generated successfully
225   */
226  public void buildClusterer(Instances data) throws Exception {
227
228    // can clusterer handle the data?
229    getCapabilities().testWithFail(data);
230
231    m_Iterations = 0;
232
233    m_ReplaceMissingFilter = new ReplaceMissingValues();
234    Instances instances = new Instances(data);
235                               
236    instances.setClassIndex(-1);
237    if (!m_dontReplaceMissing) {
238      m_ReplaceMissingFilter.setInputFormat(instances);
239      instances = Filter.useFilter(instances, m_ReplaceMissingFilter);
240    }
241
242    m_FullMissingCounts = new int[instances.numAttributes()];
243    if (m_displayStdDevs) {
244      m_FullStdDevs = new double[instances.numAttributes()];
245    }
246    m_FullNominalCounts = new int[instances.numAttributes()][0];
247               
248    m_FullMeansOrMediansOrModes = moveCentroid(0, instances, false);
249    for (int i = 0; i < instances.numAttributes(); i++) {
250      m_FullMissingCounts[i] = instances.attributeStats(i).missingCount;
251      if (instances.attribute(i).isNumeric()) {
252        if (m_displayStdDevs) {
253          m_FullStdDevs[i] = Math.sqrt(instances.variance(i));
254        }
255        if (m_FullMissingCounts[i] == instances.numInstances()) {
256          m_FullMeansOrMediansOrModes[i] = Double.NaN; // mark missing as mean
257        }
258      } else {
259        m_FullNominalCounts[i] = instances.attributeStats(i).nominalCounts;
260        if (m_FullMissingCounts[i] 
261            > m_FullNominalCounts[i][Utils.maxIndex(m_FullNominalCounts[i])]) {
262          m_FullMeansOrMediansOrModes[i] = -1; // mark missing as most common value
263        }
264      }
265    }
266
267    m_ClusterCentroids = new Instances(instances, m_NumClusters);
268    int[] clusterAssignments = new int [instances.numInstances()];
269
270    if(m_PreserveOrder)
271      m_Assignments = clusterAssignments;
272               
273    m_DistanceFunction.setInstances(instances);
274   
275    Random RandomO = new Random(getSeed());
276    int instIndex;
277    HashMap initC = new HashMap();
278    DecisionTableHashKey hk = null;
279
280    Instances initInstances = null;
281    if(m_PreserveOrder)
282      initInstances = new Instances(instances);
283    else
284      initInstances = instances;
285               
286    for (int j = initInstances.numInstances() - 1; j >= 0; j--) {
287      instIndex = RandomO.nextInt(j+1);
288      hk = new DecisionTableHashKey(initInstances.instance(instIndex),
289                                    initInstances.numAttributes(), true);
290      if (!initC.containsKey(hk)) {
291        m_ClusterCentroids.add(initInstances.instance(instIndex));
292        initC.put(hk, null);
293      }
294      initInstances.swap(j, instIndex);
295     
296      if (m_ClusterCentroids.numInstances() == m_NumClusters) {
297        break;
298      }
299    }
300
301    m_NumClusters = m_ClusterCentroids.numInstances();
302   
303    //removing reference
304    initInstances = null;
305               
306    int i;
307    boolean converged = false;
308    int emptyClusterCount;
309    Instances [] tempI = new Instances[m_NumClusters];
310    m_squaredErrors = new double [m_NumClusters];
311    m_ClusterNominalCounts = new int [m_NumClusters][instances.numAttributes()][0];
312    m_ClusterMissingCounts = new int[m_NumClusters][instances.numAttributes()];
313    while (!converged) {
314      emptyClusterCount = 0;
315      m_Iterations++;
316      converged = true;
317      for (i = 0; i < instances.numInstances(); i++) {
318        Instance toCluster = instances.instance(i);
319        int newC = clusterProcessedInstance(toCluster, true);
320        if (newC != clusterAssignments[i]) {
321          converged = false;
322        }
323        clusterAssignments[i] = newC;
324      }
325     
326      // update centroids
327      m_ClusterCentroids = new Instances(instances, m_NumClusters);
328      for (i = 0; i < m_NumClusters; i++) {
329        tempI[i] = new Instances(instances, 0);
330      }
331      for (i = 0; i < instances.numInstances(); i++) {
332        tempI[clusterAssignments[i]].add(instances.instance(i));
333      }
334      for (i = 0; i < m_NumClusters; i++) {
335        if (tempI[i].numInstances() == 0) {
336          // empty cluster
337          emptyClusterCount++;
338        } else {
339          moveCentroid( i, tempI[i], true  );                                   
340        }
341      }
342
343      if (emptyClusterCount > 0) {
344        m_NumClusters -= emptyClusterCount;
345        if (converged) {
346          Instances[] t = new Instances[m_NumClusters];
347          int index = 0;
348          for (int k = 0; k < tempI.length; k++) {
349            if (tempI[k].numInstances() > 0) {
350              t[index++] = tempI[k];
351            }
352          }
353          tempI = t;
354        } else {
355          tempI = new Instances[m_NumClusters];
356        }
357      }
358                       
359      if(m_Iterations == m_MaxIterations)
360        converged = true;
361                       
362      if (!converged) {
363        m_squaredErrors = new double [m_NumClusters];
364        m_ClusterNominalCounts = new int [m_NumClusters][instances.numAttributes()][0];
365      }
366    }
367               
368    if (m_displayStdDevs) {
369      m_ClusterStdDevs = new Instances(instances, m_NumClusters);
370    }
371    m_ClusterSizes = new int [m_NumClusters];
372    for (i = 0; i < m_NumClusters; i++) {
373      if (m_displayStdDevs) {
374        double [] vals2 = new double[instances.numAttributes()];
375        for (int j = 0; j < instances.numAttributes(); j++) {
376          if (instances.attribute(j).isNumeric()) {
377            vals2[j] = Math.sqrt(tempI[i].variance(j));
378          } else {
379            vals2[j] = Utils.missingValue();
380          }     
381        }   
382        m_ClusterStdDevs.add(new DenseInstance(1.0, vals2));
383      }
384      m_ClusterSizes[i] = tempI[i].numInstances();
385    }
386  }
387
388  /**
389   * Move the centroid to it's new coordinates. Generate the centroid coordinates based
390   * on it's  members (objects assigned to the cluster of the centroid) and the distance
391   * function being used.
392   * @param centroidIndex index of the centroid which the coordinates will be computed
393   * @param members the objects that are assigned to the cluster of this centroid
394   * @param updateClusterInfo if the method is supposed to update the m_Cluster arrays
395   * @return the centroid coordinates
396   */
397  protected double[] moveCentroid(int centroidIndex, Instances members, boolean updateClusterInfo){
398    double [] vals = new double[members.numAttributes()];
399               
400    //used only for Manhattan Distance
401    Instances sortedMembers = null;
402    int middle = 0;
403    boolean dataIsEven = false;
404               
405    if(m_DistanceFunction instanceof ManhattanDistance){
406      middle = (members.numInstances()-1)/2;
407      dataIsEven = ((members.numInstances()%2)==0);
408      if(m_PreserveOrder){
409        sortedMembers = members;
410      }else{
411        sortedMembers = new Instances(members);
412      }
413    }
414               
415    for (int j = 0; j < members.numAttributes(); j++) {                                         
416                       
417      //in case of Euclidian distance the centroid is the mean point
418      //in case of Manhattan distance the centroid is the median point
419      //in both cases, if the attribute is nominal, the centroid is the mode
420      if(m_DistanceFunction instanceof EuclideanDistance ||
421         members.attribute(j).isNominal())
422        {                                                                                                       
423          vals[j] = members.meanOrMode(j);
424        }else if(m_DistanceFunction instanceof ManhattanDistance){
425        //singleton special case
426        if(members.numInstances() == 1){
427          vals[j] = members.instance(0).value(j);
428        }else{
429          sortedMembers.kthSmallestValue(j, middle+1);
430          vals[j] = sortedMembers.instance(middle).value(j);
431          if( dataIsEven ){                                             
432            sortedMembers.kthSmallestValue(j, middle+2);                                               
433            vals[j] = (vals[j]+sortedMembers.instance(middle+1).value(j))/2;
434          }
435        }
436      } 
437                       
438      if(updateClusterInfo){
439        m_ClusterMissingCounts[centroidIndex][j] = members.attributeStats(j).missingCount;
440        m_ClusterNominalCounts[centroidIndex][j] = members.attributeStats(j).nominalCounts;
441        if (members.attribute(j).isNominal()) {
442          if (m_ClusterMissingCounts[centroidIndex][j] > 
443              m_ClusterNominalCounts[centroidIndex][j][Utils.maxIndex(m_ClusterNominalCounts[centroidIndex][j])]) 
444            {
445              vals[j] = Utils.missingValue(); // mark mode as missing
446            }
447        } else {
448          if (m_ClusterMissingCounts[centroidIndex][j] == members.numInstances()) {
449            vals[j] = Utils.missingValue(); // mark mean as missing
450          }
451        }
452      }
453    }
454    if(updateClusterInfo)
455      m_ClusterCentroids.add(new DenseInstance(1.0, vals));
456    return vals;
457  }
458       
459  /**
460   * clusters an instance that has been through the filters
461   *
462   * @param instance the instance to assign a cluster to
463   * @param updateErrors if true, update the within clusters sum of errors
464   * @return a cluster number
465   */
466  private int clusterProcessedInstance(Instance instance, boolean updateErrors) {
467    double minDist = Integer.MAX_VALUE;
468    int bestCluster = 0;
469    for (int i = 0; i < m_NumClusters; i++) {
470      double dist = m_DistanceFunction.distance(instance, m_ClusterCentroids.instance(i));
471      if (dist < minDist) {
472        minDist = dist;
473        bestCluster = i;
474      }
475    }
476    if (updateErrors) {
477      if(m_DistanceFunction instanceof EuclideanDistance){
478        //Euclidean distance to Squared Euclidean distance
479        minDist *= minDist;
480      }
481      m_squaredErrors[bestCluster] += minDist;
482    }
483    return bestCluster;
484  }
485
486  /**
487   * Classifies a given instance.
488   *
489   * @param instance the instance to be assigned to a cluster
490   * @return the number of the assigned cluster as an interger
491   * if the class is enumerated, otherwise the predicted value
492   * @throws Exception if instance could not be classified
493   * successfully
494   */
495  public int clusterInstance(Instance instance) throws Exception {
496    Instance inst = null;
497    if (!m_dontReplaceMissing) {
498      m_ReplaceMissingFilter.input(instance);
499      m_ReplaceMissingFilter.batchFinished();
500      inst = m_ReplaceMissingFilter.output();
501    } else {
502      inst = instance;
503    }
504
505    return clusterProcessedInstance(inst, false);
506  }
507
508  /**
509   * Returns the number of clusters.
510   *
511   * @return the number of clusters generated for a training dataset.
512   * @throws Exception if number of clusters could not be returned
513   * successfully
514   */
515  public int numberOfClusters() throws Exception {
516    return m_NumClusters;
517  }
518
519  /**
520   * Returns an enumeration describing the available options.
521   *
522   * @return an enumeration of all the available options.
523   */
524  public Enumeration listOptions () {
525    Vector result = new Vector();
526
527    result.addElement(new Option(
528                                 "\tnumber of clusters.\n"
529                                 + "\t(default 2).", 
530                                 "N", 1, "-N <num>"));
531    result.addElement(new Option(
532                                 "\tDisplay std. deviations for centroids.\n", 
533                                 "V", 0, "-V"));
534    result.addElement(new Option(
535                                 "\tReplace missing values with mean/mode.\n", 
536                                 "M", 0, "-M"));
537
538    result.add(new Option(
539                          "\tDistance function to use.\n"
540                          + "\t(default: weka.core.EuclideanDistance)",
541                          "A", 1,"-A <classname and options>"));
542               
543    result.add(new Option(
544                          "\tMaximum number of iterations.\n",
545                          "I",1,"-I <num>"));
546
547    result.addElement(new Option(
548                                 "\tPreserve order of instances.\n", 
549                                 "O", 0, "-O"));
550               
551    Enumeration en = super.listOptions();
552    while (en.hasMoreElements())
553      result.addElement(en.nextElement());
554
555    return  result.elements();
556  }
557
558  /**
559   * Returns the tip text for this property
560   * @return tip text for this property suitable for
561   * displaying in the explorer/experimenter gui
562   */
563  public String numClustersTipText() {
564    return "set number of clusters";
565  }
566
567  /**
568   * set the number of clusters to generate
569   *
570   * @param n the number of clusters to generate
571   * @throws Exception if number of clusters is negative
572   */
573  public void setNumClusters(int n) throws Exception {
574    if (n <= 0) {
575      throw new Exception("Number of clusters must be > 0");
576    }
577    m_NumClusters = n;
578  }
579
580  /**
581   * gets the number of clusters to generate
582   *
583   * @return the number of clusters to generate
584   */
585  public int getNumClusters() {
586    return m_NumClusters;
587  }
588
589  /**
590   * Returns the tip text for this property
591   * @return tip text for this property suitable for
592   * displaying in the explorer/experimenter gui
593   */
594  public String maxIterationsTipText() {
595    return "set maximum number of iterations";
596  }
597
598  /**
599   * set the maximum number of iterations to be executed
600   *
601   * @param n the maximum number of iterations
602   * @throws Exception if maximum number of iteration is smaller than 1
603   */
604  public void setMaxIterations(int n) throws Exception {
605    if (n <= 0) {
606      throw new Exception("Maximum number of iterations must be > 0");
607    }
608    m_MaxIterations = n;
609  }
610
611  /**
612   * gets the number of maximum iterations to be executed
613   *
614   * @return the number of clusters to generate
615   */
616  public int getMaxIterations() {
617    return m_MaxIterations;
618  }
619       
620
621  /**
622   * Returns the tip text for this property
623   * @return tip text for this property suitable for
624   * displaying in the explorer/experimenter gui
625   */
626  public String displayStdDevsTipText() {
627    return "Display std deviations of numeric attributes "
628      + "and counts of nominal attributes.";
629  }
630
631  /**
632   * Sets whether standard deviations and nominal count
633   * Should be displayed in the clustering output
634   *
635   * @param stdD true if std. devs and counts should be
636   * displayed
637   */
638  public void setDisplayStdDevs(boolean stdD) {
639    m_displayStdDevs = stdD;
640  }
641
642  /**
643   * Gets whether standard deviations and nominal count
644   * Should be displayed in the clustering output
645   *
646   * @return true if std. devs and counts should be
647   * displayed
648   */
649  public boolean getDisplayStdDevs() {
650    return m_displayStdDevs;
651  }
652
653  /**
654   * Returns the tip text for this property
655   * @return tip text for this property suitable for
656   * displaying in the explorer/experimenter gui
657   */
658  public String dontReplaceMissingValuesTipText() {
659    return "Replace missing values globally with mean/mode.";
660  }
661
662  /**
663   * Sets whether missing values are to be replaced
664   *
665   * @param r true if missing values are to be
666   * replaced
667   */
668  public void setDontReplaceMissingValues(boolean r) {
669    m_dontReplaceMissing = r;
670  }
671
672  /**
673   * Gets whether missing values are to be replaced
674   *
675   * @return true if missing values are to be
676   * replaced
677   */
678  public boolean getDontReplaceMissingValues() {
679    return m_dontReplaceMissing;
680  }
681
682  /**
683   * Returns the tip text for this property.
684   *
685   * @return            tip text for this property suitable for
686   *                    displaying in the explorer/experimenter gui
687   */
688  public String distanceFunctionTipText() {
689    return "The distance function to use for instances comparison " +
690      "(default: weka.core.EuclideanDistance). ";
691  }
692
693  /**
694   * returns the distance function currently in use.
695   *
696   * @return the distance function
697   */
698  public DistanceFunction getDistanceFunction() {
699    return m_DistanceFunction;
700  }
701
702  /**
703   * sets the distance function to use for instance comparison.
704   *
705   * @param df the new distance function to use
706   * @throws Exception if instances cannot be processed
707   */
708  public void setDistanceFunction(DistanceFunction df) throws Exception {
709    if(!(df instanceof EuclideanDistance) && 
710       !(df instanceof ManhattanDistance))
711      {
712        throw new Exception("SimpleKMeans currently only supports the Euclidean and Manhattan distances.");
713      }
714    m_DistanceFunction = df;
715  }     
716
717  /**
718   * Returns the tip text for this property
719   * @return tip text for this property suitable for
720   * displaying in the explorer/experimenter gui
721   */
722  public String preserveInstancesOrderTipText() {
723    return "Preserve order of instances.";
724  }
725
726  /**
727   * Sets whether order of instances must be preserved
728   *
729   * @param r true if missing values are to be
730   * replaced
731   */
732  public void setPreserveInstancesOrder(boolean r) {
733    m_PreserveOrder = r;
734  }
735
736  /**
737   * Gets whether order of instances must be preserved
738   *
739   * @return true if missing values are to be
740   * replaced
741   */
742  public boolean getPreserveInstancesOrder() {
743    return m_PreserveOrder;
744  }
745       
746       
747  /**
748   * Parses a given list of options. <p/>
749   *
750   <!-- options-start -->
751   * Valid options are: <p/>
752   *
753   * <pre> -N &lt;num&gt;
754   *  number of clusters.
755   *  (default 2).</pre>
756   *
757   * <pre> -V
758   *  Display std. deviations for centroids.
759   * </pre>
760   *
761   * <pre> -M
762   *  Replace missing values with mean/mode.
763   * </pre>
764   *
765   * <pre> -S &lt;num&gt;
766   *  Random number seed.
767   *  (default 10)</pre>
768   *
769   * <pre> -A &lt;classname and options&gt;
770   *  Distance function to be used for instance comparison
771   *  (default weka.core.EuclidianDistance)</pre>
772   *
773   * <pre> -I &lt;num&gt;
774   *  Maximum number of iterations. </pre>
775   * 
776   * <pre> -O
777   *  Preserve order of instances.
778   * </pre>
779   *
780   <!-- options-end -->
781   *
782   * @param options the list of options as an array of strings
783   * @throws Exception if an option is not supported
784   */
785  public void setOptions (String[] options)
786    throws Exception {
787
788    m_displayStdDevs = Utils.getFlag("V", options);
789    m_dontReplaceMissing = Utils.getFlag("M", options);
790
791    String optionString = Utils.getOption('N', options);
792
793    if (optionString.length() != 0) {
794      setNumClusters(Integer.parseInt(optionString));
795    }
796   
797    optionString = Utils.getOption("I", options);
798    if (optionString.length() != 0) {
799      setMaxIterations(Integer.parseInt(optionString));
800    }
801               
802    String distFunctionClass = Utils.getOption('A', options);
803    if(distFunctionClass.length() != 0) {
804      String distFunctionClassSpec[] = Utils.splitOptions(distFunctionClass);
805      if(distFunctionClassSpec.length == 0) { 
806        throw new Exception("Invalid DistanceFunction specification string."); 
807      }
808      String className = distFunctionClassSpec[0];
809      distFunctionClassSpec[0] = "";
810
811      setDistanceFunction( (DistanceFunction)
812                           Utils.forName( DistanceFunction.class, 
813                                          className, distFunctionClassSpec) );
814    }
815    else {
816      setDistanceFunction(new EuclideanDistance());
817    }
818               
819    m_PreserveOrder = Utils.getFlag("O", options);
820
821    super.setOptions(options);
822  }
823
824  /**
825   * Gets the current settings of SimpleKMeans
826   *
827   * @return an array of strings suitable for passing to setOptions()
828   */
829  public String[] getOptions () {
830    int         i;
831    Vector      result;
832    String[]    options;
833
834    result = new Vector();
835
836    if (m_displayStdDevs) {
837      result.add("-V");
838    }
839
840    if (m_dontReplaceMissing) {
841      result.add("-M");
842    }
843
844    result.add("-N");
845    result.add("" + getNumClusters());
846
847    result.add("-A");
848    result.add((m_DistanceFunction.getClass().getName() + " " +
849                Utils.joinOptions(m_DistanceFunction.getOptions())).trim());
850               
851    result.add("-I");
852    result.add(""+ getMaxIterations());
853
854    if(m_PreserveOrder){
855      result.add("-O");
856    }
857               
858    options = super.getOptions();
859    for (i = 0; i < options.length; i++)
860      result.add(options[i]);
861
862    return (String[]) result.toArray(new String[result.size()]);         
863  }
864
865  /**
866   * return a string describing this clusterer
867   *
868   * @return a description of the clusterer as a string
869   */
870  public String toString() {
871    if (m_ClusterCentroids == null) {
872      return "No clusterer built yet!";
873    }
874
875    int maxWidth = 0;
876    int maxAttWidth = 0;
877    boolean containsNumeric = false;
878    for (int i = 0; i < m_NumClusters; i++) {
879      for (int j = 0 ;j < m_ClusterCentroids.numAttributes(); j++) {
880        if (m_ClusterCentroids.attribute(j).name().length() > maxAttWidth) {
881          maxAttWidth = m_ClusterCentroids.attribute(j).name().length();
882        }
883        if (m_ClusterCentroids.attribute(j).isNumeric()) {
884          containsNumeric = true;
885          double width = Math.log(Math.abs(m_ClusterCentroids.instance(i).value(j))) /
886            Math.log(10.0);
887          //          System.err.println(m_ClusterCentroids.instance(i).value(j)+" "+width);
888          if (width < 0) {
889            width = 1;
890          }
891          // decimal + # decimal places + 1
892          width += 6.0;
893          if ((int)width > maxWidth) {
894            maxWidth = (int)width;
895          }
896        }
897      }
898    }
899
900    for (int i = 0; i < m_ClusterCentroids.numAttributes(); i++) {
901      if (m_ClusterCentroids.attribute(i).isNominal()) {
902        Attribute a = m_ClusterCentroids.attribute(i);
903        for (int j = 0; j < m_ClusterCentroids.numInstances(); j++) {
904          String val = a.value((int)m_ClusterCentroids.instance(j).value(i));
905          if (val.length() > maxWidth) {
906            maxWidth = val.length();
907          }
908        }
909        for (int j = 0; j < a.numValues(); j++) {
910          String val = a.value(j) + " ";
911          if (val.length() > maxAttWidth) {
912            maxAttWidth = val.length();
913          }
914        }
915      }
916    }
917
918    if (m_displayStdDevs) {
919      // check for maximum width of maximum frequency count
920      for (int i = 0; i < m_ClusterCentroids.numAttributes(); i++) {
921        if (m_ClusterCentroids.attribute(i).isNominal()) {
922          int maxV = Utils.maxIndex(m_FullNominalCounts[i]);
923          /*          int percent = (int)((double)m_FullNominalCounts[i][maxV] /
924                      Utils.sum(m_ClusterSizes) * 100.0); */
925          int percent = 6; // max percent width (100%)
926          String nomV = "" + m_FullNominalCounts[i][maxV];
927          //            + " (" + percent + "%)";
928          if (nomV.length() + percent > maxWidth) {
929            maxWidth = nomV.length() + 1;
930          }
931        }
932      }
933    }
934
935    // check for size of cluster sizes
936    for (int i = 0; i < m_ClusterSizes.length; i++) {
937      String size = "(" + m_ClusterSizes[i] + ")";
938      if (size.length() > maxWidth) {
939        maxWidth = size.length();
940      }
941    }
942   
943    if (m_displayStdDevs && maxAttWidth < "missing".length()) {
944      maxAttWidth = "missing".length();
945    }
946   
947    String plusMinus = "+/-";
948    maxAttWidth += 2;
949    if (m_displayStdDevs && containsNumeric) {
950      maxWidth += plusMinus.length();
951    }
952    if (maxAttWidth < "Attribute".length() + 2) {
953      maxAttWidth = "Attribute".length() + 2;
954    }
955
956    if (maxWidth < "Full Data".length()) {
957      maxWidth = "Full Data".length() + 1;
958    }
959
960    if (maxWidth < "missing".length()) {
961      maxWidth = "missing".length() + 1;
962    }
963
964
965   
966    StringBuffer temp = new StringBuffer();
967    //    String naString = "N/A";
968
969   
970    /*    for (int i = 0; i < maxWidth+2; i++) {
971          naString += " ";
972          } */
973    temp.append("\nkMeans\n======\n");
974    temp.append("\nNumber of iterations: " + m_Iterations+"\n");
975               
976    if(m_DistanceFunction instanceof EuclideanDistance){
977      temp.append("Within cluster sum of squared errors: " + Utils.sum(m_squaredErrors));
978    }else{
979      temp.append("Sum of within cluster distances: " + Utils.sum(m_squaredErrors));
980    }
981               
982               
983    if (!m_dontReplaceMissing) {
984      temp.append("\nMissing values globally replaced with mean/mode");
985    }
986
987    temp.append("\n\nCluster centroids:\n");
988    temp.append(pad("Cluster#", " ", (maxAttWidth + (maxWidth * 2 + 2)) - "Cluster#".length(), true));
989
990    temp.append("\n");
991    temp.append(pad("Attribute", " ", maxAttWidth - "Attribute".length(), false));
992
993   
994    temp.append(pad("Full Data", " ", maxWidth + 1 - "Full Data".length(), true));
995
996    // cluster numbers
997    for (int i = 0; i < m_NumClusters; i++) {
998      String clustNum = "" + i;
999      temp.append(pad(clustNum, " ", maxWidth + 1 - clustNum.length(), true));
1000    }
1001    temp.append("\n");
1002
1003    // cluster sizes
1004    String cSize = "(" + Utils.sum(m_ClusterSizes) + ")";
1005    temp.append(pad(cSize, " ", maxAttWidth + maxWidth + 1 - cSize.length(), true));
1006    for (int i = 0; i < m_NumClusters; i++) {
1007      cSize = "(" + m_ClusterSizes[i] + ")";
1008      temp.append(pad(cSize, " ",maxWidth + 1 - cSize.length(), true));
1009    }
1010    temp.append("\n");
1011
1012    temp.append(pad("", "=", maxAttWidth + 
1013                    (maxWidth * (m_ClusterCentroids.numInstances()+1) 
1014                     + m_ClusterCentroids.numInstances() + 1), true));
1015    temp.append("\n");
1016
1017    for (int i = 0; i < m_ClusterCentroids.numAttributes(); i++) {
1018      String attName = m_ClusterCentroids.attribute(i).name();
1019      temp.append(attName);
1020      for (int j = 0; j < maxAttWidth - attName.length(); j++) {
1021        temp.append(" ");
1022      }
1023
1024      String strVal;
1025      String valMeanMode;
1026      // full data
1027      if (m_ClusterCentroids.attribute(i).isNominal()) {
1028        if (m_FullMeansOrMediansOrModes[i] == -1) { // missing
1029          valMeanMode = pad("missing", " ", maxWidth + 1 - "missing".length(), true);
1030        } else {
1031          valMeanMode = 
1032            pad((strVal = m_ClusterCentroids.attribute(i).value((int)m_FullMeansOrMediansOrModes[i])),
1033                " ", maxWidth + 1 - strVal.length(), true);
1034        }
1035      } else {
1036        if (Double.isNaN(m_FullMeansOrMediansOrModes[i])) {
1037          valMeanMode = pad("missing", " ", maxWidth + 1 - "missing".length(), true);
1038        } else {
1039          valMeanMode =  pad((strVal = Utils.doubleToString(m_FullMeansOrMediansOrModes[i],
1040                                                            maxWidth,4).trim()), 
1041                             " ", maxWidth + 1 - strVal.length(), true);
1042        }
1043      }
1044      temp.append(valMeanMode);
1045
1046      for (int j = 0; j < m_NumClusters; j++) {
1047        if (m_ClusterCentroids.attribute(i).isNominal()) {
1048          if (m_ClusterCentroids.instance(j).isMissing(i)) {
1049            valMeanMode = pad("missing", " ", maxWidth + 1 - "missing".length(), true);
1050          } else {
1051            valMeanMode = 
1052              pad((strVal = m_ClusterCentroids.attribute(i).value((int)m_ClusterCentroids.instance(j).value(i))),
1053                  " ", maxWidth + 1 - strVal.length(), true);
1054          }
1055        } else {
1056          if (m_ClusterCentroids.instance(j).isMissing(i)) {
1057            valMeanMode = pad("missing", " ", maxWidth + 1 - "missing".length(), true);
1058          } else {
1059            valMeanMode = pad((strVal = Utils.doubleToString(m_ClusterCentroids.instance(j).value(i),
1060                                                             maxWidth,4).trim()), 
1061                              " ", maxWidth + 1 - strVal.length(), true);
1062          }
1063        }
1064        temp.append(valMeanMode);
1065      }
1066      temp.append("\n");
1067
1068      if (m_displayStdDevs) {
1069        // Std devs/max nominal
1070        String stdDevVal = "";
1071
1072        if (m_ClusterCentroids.attribute(i).isNominal()) {
1073          // Do the values of the nominal attribute
1074          Attribute a = m_ClusterCentroids.attribute(i);
1075          for (int j = 0; j < a.numValues(); j++) {
1076            // full data
1077            String val = "  " + a.value(j);
1078            temp.append(pad(val, " ", maxAttWidth + 1 - val.length(), false));
1079            int count = m_FullNominalCounts[i][j];
1080            int percent = (int)((double)m_FullNominalCounts[i][j] /
1081                                Utils.sum(m_ClusterSizes) * 100.0);
1082            String percentS = "" + percent + "%)";
1083            percentS = pad(percentS, " ", 5 - percentS.length(), true);
1084            stdDevVal = "" + count + " (" + percentS;
1085            stdDevVal = 
1086              pad(stdDevVal, " ", maxWidth + 1 - stdDevVal.length(), true);
1087            temp.append(stdDevVal);
1088
1089            // Clusters
1090            for (int k = 0; k < m_NumClusters; k++) {
1091              count = m_ClusterNominalCounts[k][i][j];
1092              percent = (int)((double)m_ClusterNominalCounts[k][i][j] /
1093                              m_ClusterSizes[k] * 100.0);
1094              percentS = "" + percent + "%)";
1095              percentS = pad(percentS, " ", 5 - percentS.length(), true);
1096              stdDevVal = "" + count + " (" + percentS;
1097              stdDevVal = 
1098                pad(stdDevVal, " ", maxWidth + 1 - stdDevVal.length(), true);
1099              temp.append(stdDevVal);
1100            }
1101            temp.append("\n");
1102          }
1103          // missing (if any)
1104          if (m_FullMissingCounts[i] > 0) {
1105            // Full data
1106            temp.append(pad("  missing", " ", maxAttWidth + 1 - "  missing".length(), false));
1107            int count = m_FullMissingCounts[i];
1108            int percent = (int)((double)m_FullMissingCounts[i] /
1109                                Utils.sum(m_ClusterSizes) * 100.0);
1110            String percentS = "" + percent + "%)";
1111            percentS = pad(percentS, " ", 5 - percentS.length(), true);
1112            stdDevVal = "" + count + " (" + percentS;
1113            stdDevVal = 
1114              pad(stdDevVal, " ", maxWidth + 1 - stdDevVal.length(), true);
1115            temp.append(stdDevVal);
1116           
1117            // Clusters
1118            for (int k = 0; k < m_NumClusters; k++) {
1119              count = m_ClusterMissingCounts[k][i];
1120              percent = (int)((double)m_ClusterMissingCounts[k][i] /
1121                              m_ClusterSizes[k] * 100.0);
1122              percentS = "" + percent + "%)";
1123              percentS = pad(percentS, " ", 5 - percentS.length(), true);
1124              stdDevVal = "" + count + " (" + percentS;
1125              stdDevVal = 
1126                pad(stdDevVal, " ", maxWidth + 1 - stdDevVal.length(), true);
1127              temp.append(stdDevVal);
1128            }
1129
1130            temp.append("\n");
1131          }
1132
1133          temp.append("\n");
1134        } else {
1135          // Full data
1136          if (Double.isNaN(m_FullMeansOrMediansOrModes[i])) {
1137            stdDevVal = pad("--", " ", maxAttWidth + maxWidth + 1 - 2, true);
1138          } else {
1139            stdDevVal = pad((strVal = plusMinus
1140                             + Utils.doubleToString(m_FullStdDevs[i],
1141                                                    maxWidth,4).trim()), 
1142                            " ", maxWidth + maxAttWidth + 1 - strVal.length(), true);
1143          }
1144          temp.append(stdDevVal);
1145
1146          // Clusters
1147          for (int j = 0; j < m_NumClusters; j++) {
1148            if (m_ClusterCentroids.instance(j).isMissing(i)) {
1149              stdDevVal = pad("--", " ", maxWidth + 1 - 2, true);
1150            } else {
1151              stdDevVal = 
1152                pad((strVal = plusMinus
1153                     + Utils.doubleToString(m_ClusterStdDevs.instance(j).value(i),
1154                                            maxWidth,4).trim()), 
1155                    " ", maxWidth + 1 - strVal.length(), true);
1156            }
1157            temp.append(stdDevVal);
1158          }
1159          temp.append("\n\n");
1160        }
1161      }
1162    }
1163
1164    temp.append("\n\n");
1165    return temp.toString();
1166  }
1167
1168  private String pad(String source, String padChar, 
1169                     int length, boolean leftPad) {
1170    StringBuffer temp = new StringBuffer();
1171
1172    if (leftPad) {
1173      for (int i = 0; i< length; i++) {
1174        temp.append(padChar);
1175      }
1176      temp.append(source);
1177    } else {
1178      temp.append(source);
1179      for (int i = 0; i< length; i++) {
1180        temp.append(padChar);
1181      }
1182    }
1183    return temp.toString();
1184  }
1185
1186  /**
1187   * Gets the the cluster centroids
1188   *
1189   * @return            the cluster centroids
1190   */
1191  public Instances getClusterCentroids() {
1192    return m_ClusterCentroids;
1193  }
1194
1195  /**
1196   * Gets the standard deviations of the numeric attributes in each cluster
1197   *
1198   * @return            the standard deviations of the numeric attributes
1199   *                    in each cluster
1200   */
1201  public Instances getClusterStandardDevs() {
1202    return m_ClusterStdDevs;
1203  }
1204
1205  /**
1206   * Returns for each cluster the frequency counts for the values of each
1207   * nominal attribute
1208   *
1209   * @return            the counts
1210   */
1211  public int [][][] getClusterNominalCounts() {
1212    return m_ClusterNominalCounts;
1213  }
1214
1215  /**
1216   * Gets the squared error for all clusters
1217   *
1218   * @return            the squared error
1219   */
1220  public double getSquaredError() {
1221    return Utils.sum(m_squaredErrors);
1222  }
1223
1224  /**
1225   * Gets the number of instances in each cluster
1226   *
1227   * @return            The number of instances in each cluster
1228   */
1229  public int [] getClusterSizes() {
1230    return m_ClusterSizes;
1231  }
1232 
1233  /**
1234   * Gets the assignments for each instance
1235   * @return Array of indexes of the centroid assigned to each instance
1236   * @throws Exception if order of instances wasn't preserved or no assignments were made
1237   */
1238  public int [] getAssignments() throws Exception{
1239    if(!m_PreserveOrder){
1240      throw new Exception("The assignments are only available when order of instances is preserved (-O)");
1241    }
1242    if(m_Assignments == null){
1243      throw new Exception("No assignments made.");
1244    }
1245    return m_Assignments;
1246  }
1247       
1248  /**
1249   * Returns the revision string.
1250   *
1251   * @return            the revision
1252   */
1253  public String getRevision() {
1254    return RevisionUtils.extract("$Revision: 5987 $");
1255  }
1256
1257  /**
1258   * Main method for testing this class.
1259   *
1260   * @param argv should contain the following arguments: <p>
1261   * -t training file [-N number of clusters]
1262   */
1263  public static void main (String[] argv) {
1264    runClusterer(new SimpleKMeans(), argv);
1265  }
1266}
1267
Note: See TracBrowser for help on using the repository browser.