source: src/main/java/weka/clusterers/ClusterEvaluation.java @ 26

Last change on this file since 26 was 4, checked in by gnappo, 14 years ago

Import di weka.

File size: 39.7 KB
RevLine 
[4]1/*
2 *    This program is free software; you can redistribute it and/or modify
3 *    it under the terms of the GNU General Public License as published by
4 *    the Free Software Foundation; either version 2 of the License, or
5 *    (at your option) any later version.
6 *
7 *    This program is distributed in the hope that it will be useful,
8 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
9 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
10 *    GNU General Public License for more details.
11 *
12 *    You should have received a copy of the GNU General Public License
13 *    along with this program; if not, write to the Free Software
14 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
15 */
16
17/*
18 *    ClusterEvaluation.java
19 *    Copyright (C) 1999 University of Waikato, Hamilton, New Zealand
20 *
21 */
22
23package  weka.clusterers;
24
25import weka.core.Drawable;
26import weka.core.Instance;
27import weka.core.Instances;
28import weka.core.Option;
29import weka.core.OptionHandler;
30import weka.core.Range;
31import weka.core.RevisionHandler;
32import weka.core.RevisionUtils;
33import weka.core.Utils;
34import weka.core.converters.ConverterUtils.DataSource;
35import weka.filters.Filter;
36import weka.filters.unsupervised.attribute.Remove;
37
38import java.beans.BeanInfo;
39import java.beans.Introspector;
40import java.beans.MethodDescriptor;
41import java.io.BufferedWriter;
42import java.io.FileWriter;
43import java.io.Serializable;
44import java.lang.reflect.Method;
45import java.util.Enumeration;
46import java.util.Random;
47import java.util.Vector;
48
49/**
50 * Class for evaluating clustering models.<p/>
51 *
52 * Valid options are: <p/>
53 *
54 * -t name of the training file <br/>
55 * Specify the training file. <p/>
56 *
57 * -T name of the test file <br/>
58 * Specify the test file to apply clusterer to. <p/>
59 *
60 * -d name of file to save clustering model to <br/>
61 * Specify output file. <p/>
62 *
63 * -l name of file to load clustering model from <br/>
64 * Specifiy input file. <p/>
65 *
66 * -p attribute range <br/>
67 * Output predictions. Predictions are for the training file if only the
68 * training file is specified, otherwise they are for the test file. The range
69 * specifies attribute values to be output with the predictions.
70 * Use '-p 0' for none. <p/>
71 *
72 * -x num folds <br/>
73 * Set the number of folds for a cross validation of the training data.
74 * Cross validation can only be done for distribution clusterers and will
75 * be performed if the test file is missing. <p/>
76 *
77 * -s num <br/>
78 * Sets the seed for randomizing the data for cross-validation. <p/>
79 *
80 * -c class <br/>
81 * Set the class attribute. If set, then class based evaluation of clustering
82 * is performed. <p/>
83 *
84 * -g name of graph file <br/>
85 * Outputs the graph representation of the clusterer to the file. Only for
86 * clusterer that implemented the <code>weka.core.Drawable</code> interface.
87 * <p/>
88 *
89 * @author   Mark Hall (mhall@cs.waikato.ac.nz)
90 * @version  $Revision: 6021 $
91 * @see      weka.core.Drawable
92 */
93public class ClusterEvaluation 
94  implements Serializable, RevisionHandler {
95
96  /** for serialization */
97  static final long serialVersionUID = -830188327319128005L;
98 
99  /** the clusterer */
100  private Clusterer m_Clusterer;
101
102  /** holds a string describing the results of clustering the training data */
103  private StringBuffer m_clusteringResults;
104
105  /** holds the number of clusters found by the clusterer */
106  private int m_numClusters;
107
108  /** holds the assigments of instances to clusters for a particular testing
109      dataset */
110  private double[] m_clusterAssignments;
111
112  /** holds the average log likelihood for a particular testing dataset
113     if the clusterer is a DensityBasedClusterer */
114  private double m_logL;
115
116  /** will hold the mapping of classes to clusters (for class based
117      evaluation) */
118  private int[] m_classToCluster = null;
119
120  /**
121   * set the clusterer
122   * @param clusterer the clusterer to use
123   */
124  public void setClusterer(Clusterer clusterer) {
125    m_Clusterer = clusterer;
126  }
127
128  /**
129   * return the results of clustering.
130   * @return a string detailing the results of clustering a data set
131   */
132  public String clusterResultsToString() {
133    return m_clusteringResults.toString();
134  }
135
136  /**
137   * Return the number of clusters found for the most recent call to
138   * evaluateClusterer
139   * @return the number of clusters found
140   */
141  public int getNumClusters() {
142    return m_numClusters;
143  }
144
145  /**
146   * Return an array of cluster assignments corresponding to the most
147   * recent set of instances clustered.
148   * @return an array of cluster assignments
149   */
150  public double[] getClusterAssignments() {
151    return m_clusterAssignments;
152  }
153
154  /**
155   * Return the array (ordered by cluster number) of minimum error class to
156   * cluster mappings
157   * @return an array of class to cluster mappings
158   */
159  public int[] getClassesToClusters() {
160    return m_classToCluster;
161  }
162
163  /**
164   * Return the log likelihood corresponding to the most recent
165   * set of instances clustered.
166   *
167   * @return a <code>double</code> value
168   */
169  public double getLogLikelihood() {
170    return m_logL;
171  }
172
173  /**
174   * Constructor. Sets defaults for each member variable. Default Clusterer
175   * is EM.
176   */
177  public ClusterEvaluation () {
178    setClusterer(new SimpleKMeans());
179    m_clusteringResults = new StringBuffer();
180    m_clusterAssignments = null;
181  }
182
183  /**
184   * Evaluate the clusterer on a set of instances. Calculates clustering
185   * statistics and stores cluster assigments for the instances in
186   * m_clusterAssignments
187   *
188   * @param test the set of instances to cluster
189   * @throws Exception if something goes wrong
190   */
191  public void evaluateClusterer(Instances test) throws Exception {
192    evaluateClusterer(test, "");
193  }
194
195  /**
196   * Evaluate the clusterer on a set of instances. Calculates clustering
197   * statistics and stores cluster assigments for the instances in
198   * m_clusterAssignments
199   *
200   * @param test the set of instances to cluster
201   * @param testFileName the name of the test file for incremental testing,
202   * if "" or null then not used
203   * @throws Exception if something goes wrong
204   */
205  public void evaluateClusterer(Instances test, String testFileName) throws Exception {
206    int i = 0;
207    int cnum;
208    double loglk = 0.0;
209    int cc = m_Clusterer.numberOfClusters();
210    m_numClusters = cc;
211    double[] instanceStats = new double[cc];
212    Instances testRaw = null;
213    boolean hasClass = (test.classIndex() >= 0);
214    int unclusteredInstances = 0;
215    Vector<Double> clusterAssignments = new Vector<Double>();
216    Filter filter = null;
217    DataSource source = null;
218    Instance inst;
219
220    if (testFileName == null)
221      testFileName = "";
222   
223    // load data
224    if (testFileName.length() != 0)
225      source = new DataSource(testFileName);
226    else
227      source = new DataSource(test);
228    testRaw = source.getStructure(test.classIndex());
229   
230    // If class is set then do class based evaluation as well
231    if (hasClass) {
232      if (testRaw.classAttribute().isNumeric())
233        throw new Exception("ClusterEvaluation: Class must be nominal!");
234
235      filter = new Remove();
236      ((Remove) filter).setAttributeIndices("" + (testRaw.classIndex() + 1));
237      ((Remove) filter).setInvertSelection(false);
238      filter.setInputFormat(testRaw);
239    }
240   
241    i = 0;
242    while (source.hasMoreElements(testRaw)) {
243      // next instance
244      inst = source.nextElement(testRaw);
245      if (filter != null) {
246        filter.input(inst);
247        filter.batchFinished();
248        inst = filter.output();
249      }
250     
251      cnum = -1;
252      try {
253        if (m_Clusterer instanceof DensityBasedClusterer) {
254          loglk += ((DensityBasedClusterer)m_Clusterer).
255            logDensityForInstance(inst);
256          cnum = m_Clusterer.clusterInstance(inst); 
257          clusterAssignments.add((double) cnum);
258        }
259        else {
260          cnum = m_Clusterer.clusterInstance(inst);
261          clusterAssignments.add((double) cnum);
262        }
263      }
264      catch (Exception e) {
265        clusterAssignments.add(-1.0);
266        unclusteredInstances++;
267      }
268     
269      if (cnum != -1) {
270        instanceStats[cnum]++;
271      }
272    }
273   
274    double sum = Utils.sum(instanceStats);
275    loglk /= sum;
276    m_logL = loglk;
277    m_clusterAssignments = new double [clusterAssignments.size()];
278    for (i = 0; i < clusterAssignments.size(); i++) {
279      m_clusterAssignments[i] = clusterAssignments.get(i);
280    }
281    int numInstFieldWidth = (int)((Math.log(clusterAssignments.size())/Math.log(10))+1);
282   
283    m_clusteringResults.append(m_Clusterer.toString());
284    m_clusteringResults.append("Clustered Instances\n\n");
285    int clustFieldWidth = (int)((Math.log(cc)/Math.log(10))+1);
286    for (i = 0; i < cc; i++) {
287      if (instanceStats[i] > 0)
288        m_clusteringResults.append(Utils.doubleToString((double)i, 
289                                                        clustFieldWidth, 0) 
290                                   + "      " 
291                                   + Utils.doubleToString(instanceStats[i],
292                                                          numInstFieldWidth, 0) 
293                                   + " (" 
294                                   + Utils.doubleToString((instanceStats[i] / 
295                                                           sum * 100.0)
296                                                          , 3, 0) + "%)\n");
297    }
298   
299    if (unclusteredInstances > 0)
300      m_clusteringResults.append("\nUnclustered instances : "
301                                 +unclusteredInstances);
302
303    if (m_Clusterer instanceof DensityBasedClusterer)
304      m_clusteringResults.append("\n\nLog likelihood: " 
305                                 + Utils.doubleToString(loglk, 1, 5) 
306                                 + "\n");       
307   
308    if (hasClass) {
309      evaluateClustersWithRespectToClass(test, testFileName);
310    }
311  }
312
313  /**
314   * Evaluates cluster assignments with respect to actual class labels.
315   * Assumes that m_Clusterer has been trained and tested on
316   * inst (minus the class).
317   *
318   * @param inst the instances (including class) to evaluate with respect to
319   * @param fileName the name of the test file for incremental testing,
320   * if "" or null then not used
321   * @throws Exception if something goes wrong
322   */
323  private void evaluateClustersWithRespectToClass(Instances inst, String fileName)
324    throws Exception {
325   
326   
327   
328    int numClasses = inst.classAttribute().numValues();
329    int[][] counts = new int [m_numClusters][numClasses];
330    int[] clusterTotals = new int[m_numClusters];
331    double[] best = new double[m_numClusters+1];
332    double[] current = new double[m_numClusters+1];
333    DataSource source = null;
334    Instances instances = null;
335    Instance instance = null;
336    int i;
337    int numInstances;
338       
339
340    if (fileName == null)
341      fileName = "";
342   
343    if (fileName.length() != 0) {
344      source = new DataSource(fileName);
345    }
346    else
347      source = new DataSource(inst);
348    instances = source.getStructure(inst.classIndex());
349
350    i = 0;
351    while (source.hasMoreElements(instances)) {
352      instance = source.nextElement(instances);
353      if (m_clusterAssignments[i] >= 0) {
354        counts[(int)m_clusterAssignments[i]][(int)instance.classValue()]++;
355        clusterTotals[(int)m_clusterAssignments[i]]++;       
356      }
357      i++;
358    }
359    numInstances = i;
360   
361    best[m_numClusters] = Double.MAX_VALUE;
362    mapClasses(m_numClusters, 0, counts, clusterTotals, current, best, 0);
363
364    m_clusteringResults.append("\n\nClass attribute: "
365                        +inst.classAttribute().name()
366                        +"\n");
367    m_clusteringResults.append("Classes to Clusters:\n");
368    String matrixString = toMatrixString(counts, clusterTotals, new Instances(inst, 0));
369    m_clusteringResults.append(matrixString).append("\n");
370
371    int Cwidth = 1 + (int)(Math.log(m_numClusters) / Math.log(10));
372    // add the minimum error assignment
373    for (i = 0; i < m_numClusters; i++) {
374      if (clusterTotals[i] > 0) {
375        m_clusteringResults.append("Cluster "
376                                   +Utils.doubleToString((double)i,Cwidth,0));
377        m_clusteringResults.append(" <-- ");
378       
379        if (best[i] < 0) {
380          m_clusteringResults.append("No class\n");
381        } else {
382          m_clusteringResults.
383            append(inst.classAttribute().value((int)best[i])).append("\n");
384        }
385      }
386    }
387    m_clusteringResults.append("\nIncorrectly clustered instances :\t"
388                               +best[m_numClusters]+"\t"
389                               +(Utils.doubleToString((best[m_numClusters] / 
390                                                       numInstances * 
391                                                       100.0), 8, 4))
392                               +" %\n");
393
394    // copy the class assignments
395    m_classToCluster = new int [m_numClusters];
396    for (i = 0; i < m_numClusters; i++) {
397      m_classToCluster[i] = (int)best[i];
398    }
399  }
400
401  /**
402   * Returns a "confusion" style matrix of classes to clusters assignments
403   * @param counts the counts of classes for each cluster
404   * @param clusterTotals total number of examples in each cluster
405   * @param inst the training instances (with class)
406   * @return the "confusion" style matrix as string
407   * @throws Exception if matrix can't be generated
408   */
409  private String toMatrixString(int[][] counts, int[] clusterTotals,
410                                Instances inst) 
411    throws Exception {
412    StringBuffer ms = new StringBuffer();
413
414    int maxval = 0;
415    for (int i = 0; i < m_numClusters; i++) {
416      for (int j = 0; j < counts[i].length; j++) {
417        if (counts[i][j] > maxval) {
418          maxval = counts[i][j];
419        }
420      }
421    }
422
423    int Cwidth = 1 + Math.max((int)(Math.log(maxval) / Math.log(10)),
424                              (int)(Math.log(m_numClusters) / Math.log(10)));
425
426    ms.append("\n");
427   
428    for (int i = 0; i < m_numClusters; i++) {
429      if (clusterTotals[i] > 0) {
430        ms.append(" ").append(Utils.doubleToString((double)i, Cwidth, 0));
431      }
432    }
433    ms.append("  <-- assigned to cluster\n");
434   
435    for (int i = 0; i< counts[0].length; i++) {
436
437      for (int j = 0; j < m_numClusters; j++) {
438        if (clusterTotals[j] > 0) {
439          ms.append(" ").append(Utils.doubleToString((double)counts[j][i], 
440                                                     Cwidth, 0));
441        }
442      }
443      ms.append(" | ").append(inst.classAttribute().value(i)).append("\n");
444    }
445
446    return ms.toString();
447  }
448
449  /**
450   * Finds the minimum error mapping of classes to clusters. Recursively
451   * considers all possible class to cluster assignments.
452   *
453   * @param numClusters the number of clusters
454   * @param lev the cluster being processed
455   * @param counts the counts of classes in clusters
456   * @param clusterTotals the total number of examples in each cluster
457   * @param current the current path through the class to cluster assignment
458   * tree
459   * @param best the best assignment path seen
460   * @param error accumulates the error for a particular path
461   */
462  public static void mapClasses(int numClusters, int lev, int[][] counts, int[] clusterTotals,
463                          double[] current, double[] best, int error) {
464    // leaf
465    if (lev == numClusters) {
466      if (error < best[numClusters]) {
467        best[numClusters] = error;
468        for (int i = 0; i < numClusters; i++) {
469          best[i] = current[i];
470        }
471      }
472    } else {
473      // empty cluster -- ignore
474      if (clusterTotals[lev] == 0) {
475        current[lev] = -1; // cluster ignored
476        mapClasses(numClusters, lev+1, counts, clusterTotals, current, best,
477                   error);
478      } else {
479        // first try no class assignment to this cluster
480        current[lev] = -1; // cluster assigned no class (ie all errors)
481        mapClasses(numClusters, lev+1, counts, clusterTotals, current, best,
482                   error+clusterTotals[lev]);
483        // now loop through the classes in this cluster
484        for (int i = 0; i < counts[0].length; i++) {
485          if (counts[lev][i] > 0) {
486            boolean ok = true;
487            // check to see if this class has already been assigned
488            for (int j = 0; j < lev; j++) {
489              if ((int)current[j] == i) {
490                ok = false;
491                break;
492              }
493            }
494            if (ok) {
495              current[lev] = i;
496              mapClasses(numClusters, lev+1, counts, clusterTotals, current, best, 
497                         (error + (clusterTotals[lev] - counts[lev][i])));
498            }
499          }
500        }
501      }
502    }
503  }
504
505  /**
506   * Evaluates a clusterer with the options given in an array of
507   * strings. It takes the string indicated by "-t" as training file, the
508   * string indicated by "-T" as test file.
509   * If the test file is missing, a stratified ten-fold
510   * cross-validation is performed (distribution clusterers only).
511   * Using "-x" you can change the number of
512   * folds to be used, and using "-s" the random seed.
513   * If the "-p" option is present it outputs the classification for
514   * each test instance. If you provide the name of an object file using
515   * "-l", a clusterer will be loaded from the given file. If you provide the
516   * name of an object file using "-d", the clusterer built from the
517   * training data will be saved to the given file.
518   *
519   * @param clusterer machine learning clusterer
520   * @param options the array of string containing the options
521   * @throws Exception if model could not be evaluated successfully
522   * @return a string describing the results
523   */
524  public static String evaluateClusterer(Clusterer clusterer, String[] options)
525    throws Exception {
526   
527    int seed = 1, folds = 10;
528    boolean doXval = false;
529    Instances train = null;
530    Random random;
531    String trainFileName, testFileName, seedString, foldsString;
532    String objectInputFileName, objectOutputFileName, attributeRangeString;
533    String graphFileName;
534    String[] savedOptions = null;
535    boolean printClusterAssignments = false;
536    Range attributesToOutput = null;
537    StringBuffer text = new StringBuffer();
538    int theClass = -1; // class based evaluation of clustering
539    boolean updateable = (clusterer instanceof UpdateableClusterer);
540    DataSource source = null;
541    Instance inst;
542
543    if (Utils.getFlag('h', options) || Utils.getFlag("help", options)) {
544     
545      // global info requested as well?
546      boolean globalInfo = Utils.getFlag("synopsis", options) ||
547        Utils.getFlag("info", options);
548     
549      throw  new Exception("Help requested." 
550          + makeOptionString(clusterer, globalInfo));
551    }
552   
553    try {
554      // Get basic options (options the same for all clusterers
555      //printClusterAssignments = Utils.getFlag('p', options);
556      objectInputFileName = Utils.getOption('l', options);
557      objectOutputFileName = Utils.getOption('d', options);
558      trainFileName = Utils.getOption('t', options);
559      testFileName = Utils.getOption('T', options);
560      graphFileName = Utils.getOption('g', options);
561
562      // Check -p option
563      try {
564        attributeRangeString = Utils.getOption('p', options);
565      }
566      catch (Exception e) {
567        throw new Exception(e.getMessage() + "\nNOTE: the -p option has changed. " +
568                            "It now expects a parameter specifying a range of attributes " +
569                            "to list with the predictions. Use '-p 0' for none.");
570      }
571      if (attributeRangeString.length() != 0) {
572        printClusterAssignments = true;
573        if (!attributeRangeString.equals("0")) 
574          attributesToOutput = new Range(attributeRangeString);
575      }
576
577      if (trainFileName.length() == 0) {
578        if (objectInputFileName.length() == 0) {
579          throw  new Exception("No training file and no object " 
580                               + "input file given.");
581        }
582
583        if (testFileName.length() == 0) {
584          throw  new Exception("No training file and no test file given.");
585        }
586      }
587      else {
588        if ((objectInputFileName.length() != 0) 
589            && (printClusterAssignments == false)) {
590          throw  new Exception("Can't use both train and model file " 
591                               + "unless -p specified.");
592        }
593      }
594
595      seedString = Utils.getOption('s', options);
596
597      if (seedString.length() != 0) {
598        seed = Integer.parseInt(seedString);
599      }
600
601      foldsString = Utils.getOption('x', options);
602
603      if (foldsString.length() != 0) {
604        folds = Integer.parseInt(foldsString);
605        doXval = true;
606      }
607    }
608    catch (Exception e) {
609      throw  new Exception('\n' + e.getMessage() 
610                           + makeOptionString(clusterer, false));
611    }
612
613    try {
614      if (trainFileName.length() != 0) {
615        source = new DataSource(trainFileName);
616        train  = source.getStructure();
617
618        String classString = Utils.getOption('c',options);
619        if (classString.length() != 0) {
620          if (classString.compareTo("last") == 0)
621            theClass = train.numAttributes();
622          else if (classString.compareTo("first") == 0)
623            theClass = 1;
624          else
625            theClass = Integer.parseInt(classString);
626
627          if (theClass != -1) {
628            if (doXval || testFileName.length() != 0)
629              throw new Exception("Can only do class based evaluation on the "
630                  +"training data");
631
632            if (objectInputFileName.length() != 0)
633              throw new Exception("Can't load a clusterer and do class based "
634                  +"evaluation");
635
636            if (objectOutputFileName.length() != 0)
637              throw new Exception(
638                  "Can't do class based evaluation and save clusterer");
639          }
640        }
641        else {
642          // if the dataset defines a class attribute, use it
643          if (train.classIndex() != -1) {
644            theClass = train.classIndex() + 1;
645            System.err.println(
646                "Note: using class attribute from dataset, i.e., attribute #" 
647                + theClass);
648          }
649        }
650
651        if (theClass != -1) {
652          if (theClass < 1 || theClass > train.numAttributes())
653            throw new Exception("Class is out of range!");
654
655          if (!train.attribute(theClass - 1).isNominal())
656            throw new Exception("Class must be nominal!");
657         
658          train.setClassIndex(theClass - 1);
659        }
660      }
661    }
662    catch (Exception e) {
663      throw  new Exception("ClusterEvaluation: " + e.getMessage() + '.');
664    }
665
666    // Save options
667    if (options != null) {
668      savedOptions = new String[options.length];
669      System.arraycopy(options, 0, savedOptions, 0, options.length);
670    }
671
672    if (objectInputFileName.length() != 0)
673      Utils.checkForRemainingOptions(options);
674
675    // Set options for clusterer
676    if (clusterer instanceof OptionHandler)
677      ((OptionHandler)clusterer).setOptions(options);
678
679    Utils.checkForRemainingOptions(options);
680
681    Instances trainHeader = train;
682    if (objectInputFileName.length() != 0) {
683      // Load the clusterer from file
684      //      clusterer = (Clusterer) SerializationHelper.read(objectInputFileName);
685      java.io.ObjectInputStream ois = 
686        new java.io.ObjectInputStream(
687        new java.io.BufferedInputStream(
688        new java.io.FileInputStream(objectInputFileName)));
689      clusterer = (Clusterer) ois.readObject();
690      // try and get the training header
691      try {
692        trainHeader = (Instances) ois.readObject();
693      } catch (Exception ex) {
694        // don't moan if we cant
695      }
696    }
697    else {
698      // Build the clusterer if no object file provided
699      if (theClass == -1) {
700        if (updateable) {
701          clusterer.buildClusterer(source.getStructure());
702          while (source.hasMoreElements(train)) {
703            inst = source.nextElement(train);
704            ((UpdateableClusterer) clusterer).updateClusterer(inst);
705          }
706          ((UpdateableClusterer) clusterer).updateFinished();
707        }
708        else {
709          clusterer.buildClusterer(source.getDataSet());
710        }
711      }
712      else {
713        Remove removeClass = new Remove();
714        removeClass.setAttributeIndices("" + theClass);
715        removeClass.setInvertSelection(false);
716        removeClass.setInputFormat(train);
717        if (updateable) {
718          Instances clusterTrain = Filter.useFilter(train, removeClass);
719          clusterer.buildClusterer(clusterTrain);
720          trainHeader = clusterTrain;
721          while (source.hasMoreElements(train)) {
722            inst = source.nextElement(train);
723            removeClass.input(inst);
724            removeClass.batchFinished();
725            Instance clusterTrainInst = removeClass.output();
726            ((UpdateableClusterer) clusterer).updateClusterer(clusterTrainInst);
727          }
728          ((UpdateableClusterer) clusterer).updateFinished();
729        }
730        else {
731          Instances clusterTrain = Filter.useFilter(source.getDataSet(), removeClass);
732          clusterer.buildClusterer(clusterTrain);
733          trainHeader = clusterTrain;
734        }
735        ClusterEvaluation ce = new ClusterEvaluation();
736        ce.setClusterer(clusterer);
737        ce.evaluateClusterer(train, trainFileName);
738       
739        return "\n\n=== Clustering stats for training data ===\n\n" +
740          ce.clusterResultsToString();
741      }
742    }
743
744    /* Output cluster predictions only (for the test data if specified,
745       otherwise for the training data */
746    if (printClusterAssignments) {
747      return printClusterings(clusterer, trainFileName, testFileName, attributesToOutput);
748    }
749
750    text.append(clusterer.toString());
751    text.append("\n\n=== Clustering stats for training data ===\n\n" 
752                + printClusterStats(clusterer, trainFileName));
753
754    if (testFileName.length() != 0) {
755      // check header compatibility
756      DataSource test = new DataSource(testFileName);
757      Instances testStructure = test.getStructure();
758      if (!trainHeader.equalHeaders(testStructure)) {
759        throw new Exception("Training and testing data are not compatible\n" + trainHeader.equalHeadersMsg(testStructure));
760      }
761
762      text.append("\n\n=== Clustering stats for testing data ===\n\n" 
763                  + printClusterStats(clusterer, testFileName));
764    }
765
766    if ((clusterer instanceof DensityBasedClusterer) && 
767        (doXval == true) && 
768        (testFileName.length() == 0) && 
769        (objectInputFileName.length() == 0)) {
770      // cross validate the log likelihood on the training data
771      random = new Random(seed);
772      random.setSeed(seed);
773      train = source.getDataSet();
774      train.randomize(random);
775      text.append(
776          crossValidateModel(
777              clusterer.getClass().getName(), train, folds, savedOptions, random));
778    }
779
780    // Save the clusterer if an object output file is provided
781    if (objectOutputFileName.length() != 0) {
782      //SerializationHelper.write(objectOutputFileName, clusterer);
783      saveClusterer(objectOutputFileName, clusterer, trainHeader);
784    }
785
786    // If classifier is drawable output string describing graph
787    if ((clusterer instanceof Drawable) && (graphFileName.length() != 0)) {
788      BufferedWriter writer = new BufferedWriter(new FileWriter(graphFileName));
789      writer.write(((Drawable) clusterer).graph());
790      writer.newLine();
791      writer.flush();
792      writer.close();
793    }
794   
795    return  text.toString();
796  }
797
798  private static void saveClusterer(String fileName, 
799                             Clusterer clusterer, 
800                             Instances header) throws Exception {
801    java.io.ObjectOutputStream oos = 
802      new java.io.ObjectOutputStream(
803      new java.io.BufferedOutputStream(
804      new java.io.FileOutputStream(fileName)));
805
806    oos.writeObject(clusterer);
807    if (header != null) {
808      oos.writeObject(header);
809    }
810    oos.flush();
811    oos.close();
812  }
813
814  /**
815   * Perform a cross-validation for DensityBasedClusterer on a set of instances.
816   *
817   * @param clusterer the clusterer to use
818   * @param data the training data
819   * @param numFolds number of folds of cross validation to perform
820   * @param random random number seed for cross-validation
821   * @return the cross-validated log-likelihood
822   * @throws Exception if an error occurs
823   */
824  public static double crossValidateModel(DensityBasedClusterer clusterer,
825                                          Instances data,
826                                          int numFolds,
827                                          Random random) throws Exception {
828    Instances train, test;
829    double foldAv = 0;;
830    data = new Instances(data);
831    data.randomize(random);
832    //    double sumOW = 0;
833    for (int i = 0; i < numFolds; i++) {
834      // Build and test clusterer
835      train = data.trainCV(numFolds, i, random);
836
837      clusterer.buildClusterer(train);
838
839      test = data.testCV(numFolds, i);
840     
841      for (int j = 0; j < test.numInstances(); j++) {
842        try {
843          foldAv += ((DensityBasedClusterer)clusterer).
844            logDensityForInstance(test.instance(j));
845          //      sumOW += test.instance(j).weight();
846          //    double temp = Utils.sum(tempDist);
847        } catch (Exception ex) {
848          // unclustered instances
849        }
850      }
851    }
852   
853    //    return foldAv / sumOW;
854    return foldAv / data.numInstances();
855  }
856
857  /**
858   * Performs a cross-validation
859   * for a DensityBasedClusterer clusterer on a set of instances.
860   *
861   * @param clustererString a string naming the class of the clusterer
862   * @param data the data on which the cross-validation is to be
863   * performed
864   * @param numFolds the number of folds for the cross-validation
865   * @param options the options to the clusterer
866   * @param random a random number generator
867   * @return a string containing the cross validated log likelihood
868   * @throws Exception if a clusterer could not be generated
869   */
870  public static String crossValidateModel (String clustererString, 
871                                           Instances data, 
872                                           int numFolds, 
873                                           String[] options,
874                                           Random random)
875    throws Exception {
876    Clusterer clusterer = null;
877    String[] savedOptions = null;
878    double CvAv = 0.0;
879    StringBuffer CvString = new StringBuffer();
880
881    if (options != null) {
882      savedOptions = new String[options.length];
883    }
884
885    data = new Instances(data);
886
887    // create clusterer
888    try {
889      clusterer = (Clusterer)Class.forName(clustererString).newInstance();
890    }
891    catch (Exception e) {
892      throw  new Exception("Can't find class with name " 
893                           + clustererString + '.');
894    }
895
896    if (!(clusterer instanceof DensityBasedClusterer)) {
897      throw  new Exception(clustererString
898                           + " must be a distrinbution " 
899                           + "clusterer.");
900    }
901
902    // Save options
903    if (options != null) {
904      System.arraycopy(options, 0, savedOptions, 0, options.length);
905    }
906
907    // Parse options
908    if (clusterer instanceof OptionHandler) {
909      try {
910        ((OptionHandler)clusterer).setOptions(savedOptions);
911        Utils.checkForRemainingOptions(savedOptions);
912      }
913      catch (Exception e) {
914        throw  new Exception("Can't parse given options in " 
915                             + "cross-validation!");
916      }
917    }
918    CvAv = crossValidateModel((DensityBasedClusterer)clusterer, data, numFolds, random);
919
920    CvString.append("\n" + numFolds
921                    + " fold CV Log Likelihood: " 
922                    + Utils.doubleToString(CvAv, 6, 4) 
923                    + "\n");
924    return  CvString.toString();
925  }
926
927
928  // ===============
929  // Private methods
930  // ===============
931  /**
932   * Print the cluster statistics for either the training
933   * or the testing data.
934   *
935   * @param clusterer the clusterer to use for generating statistics.
936   * @param fileName the file to load
937   * @return a string containing cluster statistics.
938   * @throws Exception if statistics can't be generated.
939   */
940  private static String printClusterStats (Clusterer clusterer, 
941                                           String fileName)
942    throws Exception {
943    StringBuffer text = new StringBuffer();
944    int i = 0;
945    int cnum;
946    double loglk = 0.0;
947    int cc = clusterer.numberOfClusters();
948    double[] instanceStats = new double[cc];
949    int unclusteredInstances = 0;
950
951    if (fileName.length() != 0) {
952      DataSource source = new DataSource(fileName);
953      Instances structure = source.getStructure();
954      Instance inst;
955      while (source.hasMoreElements(structure)) {
956        inst = source.nextElement(structure);
957        try {
958          cnum = clusterer.clusterInstance(inst);
959
960          if (clusterer instanceof DensityBasedClusterer) {
961            loglk += ((DensityBasedClusterer)clusterer).
962              logDensityForInstance(inst);
963            //      temp = Utils.sum(dist);
964          }
965          instanceStats[cnum]++;
966        }
967        catch (Exception e) {
968          unclusteredInstances++;
969        }
970        i++;
971      }
972
973      /*
974      // count the actual number of used clusters
975      int count = 0;
976      for (i = 0; i < cc; i++) {
977        if (instanceStats[i] > 0) {
978          count++;
979        }
980      }
981      if (count > 0) {
982        double[] tempStats = new double [count];
983        count=0;
984        for (i=0;i<cc;i++) {
985          if (instanceStats[i] > 0) {
986            tempStats[count++] = instanceStats[i];
987        }
988        }
989        instanceStats = tempStats;
990        cc = instanceStats.length;
991        } */
992
993      int clustFieldWidth = (int)((Math.log(cc)/Math.log(10))+1);
994      int numInstFieldWidth = (int)((Math.log(i)/Math.log(10))+1);
995      double sum = Utils.sum(instanceStats);
996      loglk /= sum;
997      text.append("Clustered Instances\n");
998
999      for (i = 0; i < cc; i++) {
1000        if (instanceStats[i] > 0) {
1001          text.append(Utils.doubleToString((double)i, 
1002                                           clustFieldWidth, 0) 
1003                      + "      " 
1004                      + Utils.doubleToString(instanceStats[i], 
1005                                             numInstFieldWidth, 0) 
1006                      + " (" 
1007                    + Utils.doubleToString((instanceStats[i]/sum*100.0)
1008                                           , 3, 0) + "%)\n");
1009        }
1010      }
1011      if (unclusteredInstances > 0) {
1012        text.append("\nUnclustered Instances : "+unclusteredInstances);
1013      }
1014
1015      if (clusterer instanceof DensityBasedClusterer) {
1016        text.append("\n\nLog likelihood: " 
1017                    + Utils.doubleToString(loglk, 1, 5) 
1018                    + "\n");
1019      }
1020    }
1021
1022    return text.toString();
1023  }
1024
1025
1026  /**
1027   * Print the cluster assignments for either the training
1028   * or the testing data.
1029   *
1030   * @param clusterer the clusterer to use for cluster assignments
1031   * @param trainFileName the train file
1032   * @param testFileName an optional test file
1033   * @param attributesToOutput the attributes to print
1034   * @return a string containing the instance indexes and cluster assigns.
1035   * @throws Exception if cluster assignments can't be printed
1036   */
1037  private static String printClusterings (Clusterer clusterer, String trainFileName,
1038                                          String testFileName, Range attributesToOutput)
1039    throws Exception {
1040
1041    StringBuffer text = new StringBuffer();
1042    int i = 0;
1043    int cnum;
1044    DataSource source = null;
1045    Instance inst;
1046    Instances structure;
1047   
1048    if (testFileName.length() != 0)
1049      source = new DataSource(testFileName);
1050    else
1051      source = new DataSource(trainFileName);
1052   
1053    structure = source.getStructure();
1054    while (source.hasMoreElements(structure)) {
1055      inst = source.nextElement(structure);
1056      try {
1057        cnum = clusterer.clusterInstance(inst);
1058       
1059        text.append(i + " " + cnum + " "
1060            + attributeValuesString(inst, attributesToOutput) + "\n");
1061      }
1062      catch (Exception e) {
1063        /*        throw  new Exception('\n' + "Unable to cluster instance\n"
1064         + e.getMessage()); */
1065        text.append(i + " Unclustered "
1066            + attributeValuesString(inst, attributesToOutput) + "\n");
1067      }
1068      i++;
1069    }
1070   
1071    return text.toString();
1072  }
1073
1074  /**
1075   * Builds a string listing the attribute values in a specified range of indices,
1076   * separated by commas and enclosed in brackets.
1077   *
1078   * @param instance the instance to print the values from
1079   * @param attRange the range of the attributes to list
1080   * @return a string listing values of the attributes in the range
1081   */
1082  private static String attributeValuesString(Instance instance, Range attRange) {
1083    StringBuffer text = new StringBuffer();
1084    if (attRange != null) {
1085      boolean firstOutput = true;
1086      attRange.setUpper(instance.numAttributes() - 1);
1087      for (int i=0; i<instance.numAttributes(); i++)
1088        if (attRange.isInRange(i)) {
1089          if (firstOutput) text.append("(");
1090          else text.append(",");
1091          text.append(instance.toString(i));
1092          firstOutput = false;
1093        }
1094      if (!firstOutput) text.append(")");
1095    }
1096    return text.toString();
1097  }
1098
1099  /**
1100   * Make up the help string giving all the command line options
1101   *
1102   * @param clusterer the clusterer to include options for
1103   * @return a string detailing the valid command line options
1104   */
1105  private static String makeOptionString (Clusterer clusterer,
1106                                          boolean globalInfo) {
1107    StringBuffer optionsText = new StringBuffer("");
1108    // General options
1109    optionsText.append("\n\nGeneral options:\n\n");
1110    optionsText.append("-h or -help\n");
1111    optionsText.append("\tOutput help information.\n");
1112    optionsText.append("-synopsis or -info\n");
1113    optionsText.append("\tOutput synopsis for clusterer (use in conjunction "
1114        + " with -h)\n");
1115    optionsText.append("-t <name of training file>\n");
1116    optionsText.append("\tSets training file.\n");
1117    optionsText.append("-T <name of test file>\n");
1118    optionsText.append("\tSets test file.\n");
1119    optionsText.append("-l <name of input file>\n");
1120    optionsText.append("\tSets model input file.\n");
1121    optionsText.append("-d <name of output file>\n");
1122    optionsText.append("\tSets model output file.\n");
1123    optionsText.append("-p <attribute range>\n");
1124    optionsText.append("\tOutput predictions. Predictions are for " 
1125                       + "training file" 
1126                       + "\n\tif only training file is specified," 
1127                       + "\n\totherwise predictions are for the test file."
1128                       + "\n\tThe range specifies attribute values to be output"
1129                       + "\n\twith the predictions. Use '-p 0' for none.\n");
1130    optionsText.append("-x <number of folds>\n");
1131    optionsText.append("\tOnly Distribution Clusterers can be cross validated.\n");
1132    optionsText.append("-s <random number seed>\n");
1133    optionsText.append("\tSets the seed for randomizing the data in cross-validation\n");
1134    optionsText.append("-c <class index>\n");
1135    optionsText.append("\tSet class attribute. If supplied, class is ignored");
1136    optionsText.append("\n\tduring clustering but is used in a classes to");
1137    optionsText.append("\n\tclusters evaluation.\n");
1138    if (clusterer instanceof Drawable) {
1139      optionsText.append("-g <name of graph file>\n");
1140      optionsText.append("\tOutputs the graph representation of the clusterer to the file.\n");
1141    }
1142
1143    // Get scheme-specific options
1144    if (clusterer instanceof OptionHandler) {
1145      optionsText.append("\nOptions specific to " 
1146                         + clusterer.getClass().getName() + ":\n\n");
1147      Enumeration enu = ((OptionHandler)clusterer).listOptions();
1148
1149      while (enu.hasMoreElements()) {
1150        Option option = (Option)enu.nextElement();
1151        optionsText.append(option.synopsis() + '\n');
1152        optionsText.append(option.description() + "\n");
1153      }
1154    }
1155   
1156    // Get global information (if available)
1157    if (globalInfo) {
1158      try {
1159        String gi = getGlobalInfo(clusterer);
1160        optionsText.append(gi);
1161      } catch (Exception ex) {
1162        // quietly ignore
1163      }
1164    }
1165
1166    return  optionsText.toString();
1167  }
1168 
1169  /**
1170   * Return the global info (if it exists) for the supplied clusterer
1171   *
1172   * @param clusterer the clusterer to get the global info for
1173   * @return the global info (synopsis) for the clusterer
1174   * @throws Exception if there is a problem reflecting on the clusterer
1175   */
1176  protected static String getGlobalInfo(Clusterer clusterer) throws Exception {
1177    BeanInfo bi = Introspector.getBeanInfo(clusterer.getClass());
1178    MethodDescriptor[] methods;
1179    methods = bi.getMethodDescriptors();
1180    Object[] args = {};
1181    String result = "\nSynopsis for " + clusterer.getClass().getName()
1182      + ":\n\n";
1183   
1184    for (int i = 0; i < methods.length; i++) {
1185      String name = methods[i].getDisplayName();
1186      Method meth = methods[i].getMethod();
1187      if (name.equals("globalInfo")) {
1188        String globalInfo = (String)(meth.invoke(clusterer, args));
1189        result += globalInfo;
1190        break;
1191      }
1192    }
1193   
1194    return result;
1195  }
1196
1197  /**
1198   * Tests whether the current evaluation object is equal to another
1199   * evaluation object
1200   *
1201   * @param obj the object to compare against
1202   * @return true if the two objects are equal
1203   */
1204  public boolean equals(Object obj) {
1205    if ((obj == null) || !(obj.getClass().equals(this.getClass())))
1206      return false;
1207   
1208    ClusterEvaluation cmp = (ClusterEvaluation) obj;
1209   
1210    if ((m_classToCluster != null) != (cmp.m_classToCluster != null)) return false;
1211    if (m_classToCluster != null) {
1212      for (int i = 0; i < m_classToCluster.length; i++) {
1213        if (m_classToCluster[i] != cmp.m_classToCluster[i])
1214        return false;
1215      }
1216    }
1217   
1218    if ((m_clusterAssignments != null) != (cmp.m_clusterAssignments != null)) return false;
1219    if (m_clusterAssignments != null) {
1220      for (int i = 0; i < m_clusterAssignments.length; i++) {
1221        if (m_clusterAssignments[i] != cmp.m_clusterAssignments[i])
1222        return false;
1223      }
1224    }
1225
1226    if (Double.isNaN(m_logL) != Double.isNaN(cmp.m_logL)) return false;
1227    if (!Double.isNaN(m_logL)) {
1228      if (m_logL != cmp.m_logL) return false;
1229    }
1230   
1231    if (m_numClusters != cmp.m_numClusters) return false;
1232   
1233    // TODO: better comparison? via members?
1234    String clusteringResults1 = m_clusteringResults.toString().replaceAll("Elapsed time.*", "");
1235    String clusteringResults2 = cmp.m_clusteringResults.toString().replaceAll("Elapsed time.*", "");
1236    if (!clusteringResults1.equals(clusteringResults2)) return false;
1237   
1238    return true;
1239  }
1240 
1241  /**
1242   * Returns the revision string.
1243   *
1244   * @return            the revision
1245   */
1246  public String getRevision() {
1247    return RevisionUtils.extract("$Revision: 6021 $");
1248  }
1249
1250  /**
1251   * Main method for testing this class.
1252   *
1253   * @param args the options
1254   */
1255  public static void main (String[] args) {
1256    try {
1257      if (args.length == 0) {
1258        throw  new Exception("The first argument must be the name of a " 
1259                             + "clusterer");
1260      }
1261
1262      String ClustererString = args[0];
1263      args[0] = "";
1264      Clusterer newClusterer = AbstractClusterer.forName(ClustererString, null);
1265      System.out.println(evaluateClusterer(newClusterer, args));
1266    }
1267    catch (Exception e) {
1268      System.out.println(e.getMessage());
1269    }
1270  }
1271}
Note: See TracBrowser for help on using the repository browser.