source: src/main/java/weka/clusterers/HierarchicalClusterer.java @ 28

Last change on this file since 28 was 4, checked in by gnappo, 14 years ago

Import di weka.

File size: 36.3 KB
Line 
1/*
2 * This program is free software; you can redistribute it and/or modify
3 * it under the terms of the GNU General Public License as published by
4 * the Free Software Foundation; either version 2 of the License, or
5 * (at your option) any later version.
6 *
7 * This program is distributed in the hope that it will be useful,
8 * but WITHOUT ANY WARRANTY; without even the implied warranty of
9 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
10 * GNU General Public License for more details.
11 *
12 * You should have received a copy of the GNU General Public License
13 * along with this program; if not, write to the Free Software
14 * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
15 */
16/*
17 * HierarchicalClusterer.java
18 * Copyright (C) 2009 University of Waikato, Hamilton, New Zealand
19*/
20/**
21 <!-- globalinfo-start -->
22 * Hierarchical clustering class.
23 * Implements a number of classic hierarchical clustering methods.
24 <!-- globalinfo-end -->
25 *
26 <!-- options-start -->
27 * Valid options are: <p/>
28 *
29 * <pre> -N
30 *  number of clusters
31 * </pre>
32 *
33 *
34 * <pre> -L
35 *  Link type (Single, Complete, Average, Mean, Centroid, Ward, Adjusted complete, Neighbor Joining)
36 *  [SINGLE|COMPLETE|AVERAGE|MEAN|CENTROID|WARD|ADJCOMLPETE|NEIGHBOR_JOINING]
37 * </pre>
38 *
39 * <pre> -A
40 * Distance function to use. (default: weka.core.EuclideanDistance)
41 * </pre>
42 *
43 * <pre> -P
44 * Print hierarchy in Newick format, which can be used for display in other programs.
45 * </pre>
46 * 
47 * <pre> -D
48 * If set, classifier is run in debug mode and may output additional info to the console.
49 * </pre>
50 *
51 * <pre> -B
52 * \If set, distance is interpreted as branch length, otherwise it is node height.
53 * </pre>
54 *
55 *<!-- options-end -->
56 *
57 *
58 * @author Remco Bouckaert (rrb@xm.co.nz, remco@cs.waikato.ac.nz)
59 * @author Eibe Frank (eibe@cs.waikato.ac.nz)
60 * @version $Revision: 6042 $
61 */
62
63package weka.clusterers;
64
65import java.io.Serializable;
66import java.text.DecimalFormat;
67import java.util.Comparator;
68import java.util.Enumeration;
69import java.util.PriorityQueue;
70import java.util.Vector;
71
72import weka.core.Capabilities;
73import weka.core.CapabilitiesHandler;
74import weka.core.DistanceFunction;
75import weka.core.Drawable;
76import weka.core.EuclideanDistance;
77import weka.core.Instance;
78import weka.core.Instances;
79import weka.core.Option;
80import weka.core.OptionHandler;
81import weka.core.RevisionUtils;
82import weka.core.SelectedTag;
83import weka.core.Tag;
84import weka.core.Utils;
85import weka.core.Capabilities.Capability;
86
87public class HierarchicalClusterer extends AbstractClusterer implements OptionHandler, CapabilitiesHandler, Drawable {
88        private static final long serialVersionUID = 1L;
89
90        /** Whether the classifier is run in debug mode. */
91        protected boolean m_bDebug = false;
92
93        /** Whether the distance represent node height (if false) or branch length (if true). */
94        protected boolean m_bDistanceIsBranchLength = false;
95
96        /** training data **/
97        Instances m_instances;
98
99        /** number of clusters desired in clustering **/
100        int m_nNumClusters = 2;
101        public void setNumClusters(int nClusters) {m_nNumClusters = Math.max(1,nClusters);}
102        public int getNumClusters() {return m_nNumClusters;}
103       
104        /** distance function used for comparing members of a cluster **/
105        protected DistanceFunction m_DistanceFunction = new EuclideanDistance();
106        public DistanceFunction getDistanceFunction() {return m_DistanceFunction;}
107        public void setDistanceFunction(DistanceFunction distanceFunction) {m_DistanceFunction = distanceFunction;}
108
109        /** used for priority queue for efficient retrieval of pair of clusters to merge**/
110        class Tuple {
111                public Tuple(double d, int i, int j, int nSize1, int nSize2) {
112                        m_fDist = d;
113                        m_iCluster1 = i;
114                        m_iCluster2 = j;
115                        m_nClusterSize1 = nSize1;
116                        m_nClusterSize2 = nSize2;
117                }
118                double m_fDist;
119                int m_iCluster1;
120                int m_iCluster2;
121                int m_nClusterSize1;
122                int m_nClusterSize2;
123        }
124        /** comparator used by priority queue**/
125        class TupleComparator implements Comparator<Tuple> {
126                public int compare(Tuple o1, Tuple o2) {
127                        if (o1.m_fDist < o2.m_fDist) {
128                                return -1;
129                        } else if (o1.m_fDist == o2.m_fDist) {
130                                return 0;
131                        }
132                        return 1;
133                }
134        }
135
136        /** the various link types */
137        final static int SINGLE = 0;
138        final static int COMPLETE = 1;
139        final static int AVERAGE = 2;
140        final static int MEAN = 3;
141        final static int CENTROID = 4;
142        final static int WARD = 5;
143        final static int ADJCOMLPETE = 6;
144        final static int NEIGHBOR_JOINING = 7;
145        public static final Tag[] TAGS_LINK_TYPE = {
146          new Tag(SINGLE, "SINGLE"),
147          new Tag(COMPLETE, "COMPLETE"),
148          new Tag(AVERAGE, "AVERAGE"),
149          new Tag(MEAN, "MEAN"),
150          new Tag(CENTROID, "CENTROID"),
151          new Tag(WARD, "WARD"),
152          new Tag(ADJCOMLPETE,"ADJCOMLPETE"),
153          new Tag(NEIGHBOR_JOINING,"NEIGHBOR_JOINING")
154        };
155
156        /**
157         * Holds the Link type used calculate distance between clusters
158         */
159        int m_nLinkType = SINGLE;
160       
161        boolean m_bPrintNewick = true;;
162        public boolean getPrintNewick() {return m_bPrintNewick;}
163        public void setPrintNewick(boolean bPrintNewick) {m_bPrintNewick = bPrintNewick;}
164       
165        public void setLinkType(SelectedTag newLinkType) {
166                if (newLinkType.getTags() == TAGS_LINK_TYPE) {
167                        m_nLinkType = newLinkType.getSelectedTag().getID();
168                }
169        }
170
171        public SelectedTag getLinkType() {
172                return new SelectedTag(m_nLinkType, TAGS_LINK_TYPE);
173        }
174
175        /** class representing node in cluster hierarchy **/
176        class Node implements Serializable {
177                Node m_left;
178                Node m_right;
179                Node m_parent;
180                int m_iLeftInstance;
181                int m_iRightInstance;
182                double m_fLeftLength = 0;
183                double m_fRightLength = 0;
184                double m_fHeight = 0;
185                public String toString(int attIndex) {
186                        DecimalFormat myFormatter = new DecimalFormat("#.#####");
187
188                        if (m_left == null) {
189                                if (m_right == null) {
190                                        return "(" + m_instances.instance(m_iLeftInstance).stringValue(attIndex) + ":" + myFormatter.format(m_fLeftLength) + "," +
191                                                     m_instances.instance(m_iRightInstance).stringValue(attIndex) +":" + myFormatter.format(m_fRightLength) + ")";
192                                } else {
193                                        return "(" + m_instances.instance(m_iLeftInstance).stringValue(attIndex) + ":" + myFormatter.format(m_fLeftLength) + "," +
194                                                m_right.toString(attIndex) + ":" + myFormatter.format(m_fRightLength) + ")";
195                                }
196                        } else {
197                                if (m_right == null) {
198                                        return "(" + m_left.toString(attIndex) + ":" + myFormatter.format(m_fLeftLength) + "," +
199                                                     m_instances.instance(m_iRightInstance).stringValue(attIndex) + ":" + myFormatter.format(m_fRightLength) + ")";
200                                } else {
201                                        return "(" + m_left.toString(attIndex) + ":" + myFormatter.format(m_fLeftLength) + "," +m_right.toString(attIndex) + ":" + myFormatter.format(m_fRightLength) + ")";
202                                }
203                        }
204                }
205                public String toString2(int attIndex) {
206                        DecimalFormat myFormatter = new DecimalFormat("#.#####");
207
208                        if (m_left == null) {
209                                if (m_right == null) {
210                                        return "(" + m_instances.instance(m_iLeftInstance).value(attIndex) + ":" + myFormatter.format(m_fLeftLength) + "," +
211                                                     m_instances.instance(m_iRightInstance).value(attIndex) +":" + myFormatter.format(m_fRightLength) + ")";
212                                } else {
213                                        return "(" + m_instances.instance(m_iLeftInstance).value(attIndex) + ":" + myFormatter.format(m_fLeftLength) + "," +
214                                                m_right.toString2(attIndex) + ":" + myFormatter.format(m_fRightLength) + ")";
215                                }
216                        } else {
217                                if (m_right == null) {
218                                        return "(" + m_left.toString2(attIndex) + ":" + myFormatter.format(m_fLeftLength) + "," +
219                                                     m_instances.instance(m_iRightInstance).value(attIndex) + ":" + myFormatter.format(m_fRightLength) + ")";
220                                } else {
221                                        return "(" + m_left.toString2(attIndex) + ":" + myFormatter.format(m_fLeftLength) + "," +m_right.toString2(attIndex) + ":" + myFormatter.format(m_fRightLength) + ")";
222                                }
223                        }
224                }
225                void setHeight(double fHeight1, double fHeight2) {
226                        m_fHeight = fHeight1;
227                        if (m_left == null) {
228                                m_fLeftLength = fHeight1;
229                        } else {
230                                m_fLeftLength = fHeight1 - m_left.m_fHeight;
231                        }
232                        if (m_right == null) {
233                                m_fRightLength = fHeight2;
234                        } else {
235                                m_fRightLength = fHeight2 - m_right.m_fHeight;
236                        }
237                }
238                void setLength(double fLength1, double fLength2) {
239                        m_fLeftLength = fLength1;
240                        m_fRightLength = fLength2;
241                        m_fHeight = fLength1;
242                        if (m_left != null) {
243                                m_fHeight += m_left.m_fHeight;
244                        }
245                }
246        }
247        Node [] m_clusters;
248        int [] m_nClusterNr;
249       
250       
251        @Override
252        public void buildClusterer(Instances data) throws Exception {
253//              /System.err.println("Method " + m_nLinkType);
254                m_instances = data;
255                int nInstances = m_instances.numInstances();
256                if (nInstances == 0) {
257                        return;
258                }
259                m_DistanceFunction.setInstances(m_instances);
260                // use array of integer vectors to store cluster indices,
261                // starting with one cluster per instance
262                Vector<Integer> [] nClusterID = new Vector[data.numInstances()];
263                for (int i = 0; i < data.numInstances(); i++) {
264                        nClusterID[i] = new Vector<Integer>();
265                        nClusterID[i].add(i);
266                }
267                // calculate distance matrix
268                int nClusters = data.numInstances();
269               
270                // used for keeping track of hierarchy
271                Node [] clusterNodes = new Node[nInstances];
272                if (m_nLinkType == NEIGHBOR_JOINING) {
273                        neighborJoining(nClusters, nClusterID, clusterNodes);
274                } else {
275                        doLinkClustering(nClusters, nClusterID, clusterNodes);
276                }
277               
278                // move all clusters in m_nClusterID array
279                // & collect hierarchy
280                int iCurrent = 0;
281                m_clusters = new Node[m_nNumClusters];
282                m_nClusterNr = new int[nInstances];
283                for (int i = 0; i < nInstances; i++) {
284                        if (nClusterID[i].size() > 0) {
285                                for (int j = 0; j < nClusterID[i].size(); j++) {
286                                        m_nClusterNr[nClusterID[i].elementAt(j)] = iCurrent;
287                                }
288                                m_clusters[iCurrent] = clusterNodes[i];
289                                iCurrent++;
290                        }
291                }
292               
293        } // buildClusterer
294
295        /** use neighbor joining algorithm for clustering
296         * This is roughly based on the RapidNJ simple implementation and runs at O(n^3)
297         * More efficient implementations exist, see RapidNJ (or my GPU implementation :-))
298         * @param nClusters
299         * @param nClusterID
300         * @param clusterNodes
301         */
302        void neighborJoining(int nClusters, Vector<Integer>[] nClusterID, Node [] clusterNodes) {
303                int n = m_instances.numInstances();
304
305                double [][] fDist = new double[nClusters][nClusters];
306                for (int i = 0; i < nClusters; i++) {
307                        fDist[i][i] = 0;
308                        for (int j = i+1; j < nClusters; j++) {
309                                fDist[i][j] = getDistance0(nClusterID[i], nClusterID[j]);
310                                fDist[j][i] = fDist[i][j];
311                        }
312                }
313               
314                double [] fSeparationSums = new double [n];
315                double [] fSeparations = new double [n];
316            int [] nNextActive = new int[n];
317
318                //calculate initial separation rows
319                for(int i = 0; i < n; i++){
320                    double fSum = 0;
321                    for(int j = 0; j < n; j++){
322                        fSum += fDist[i][j];
323                    }
324                    fSeparationSums[i] = fSum;
325                    fSeparations[i] = fSum / (nClusters - 2);
326                    nNextActive[i] = i +1;
327                }
328
329                while (nClusters > 2) {
330                        // find minimum
331                        int iMin1 = -1;
332                        int iMin2 = -1;
333                        double fMin = Double.MAX_VALUE;
334                        if (m_bDebug) {
335                                for (int i = 0; i < n; i++) {
336                                        if(nClusterID[i].size() > 0){
337                                                double [] fRow = fDist[i];
338                                                double fSep1 = fSeparations[i];
339                                                for(int j = 0; j < n; j++){
340                                            if(nClusterID[j].size() > 0 && i != j){
341                                                double fSep2 = fSeparations[j];
342                                                double fVal = fRow[j] - fSep1 - fSep2;
343       
344                                                                if(fVal < fMin){
345                                                                        // new minimum
346                                                                        iMin1 = i;
347                                                                        iMin2 = j;
348                                                                        fMin = fVal;
349                                                                }
350                                            }
351                                                }
352                                        }
353                                }
354                        } else {
355                                int i = 0;
356                                while (i < n) {
357                                        double fSep1 = fSeparations[i];
358                                        double [] fRow = fDist[i];
359                                        int j = nNextActive[i];
360                                        while (j < n) {
361                                double fSep2 = fSeparations[j];
362                                double fVal = fRow[j] - fSep1 - fSep2;
363                                                if(fVal < fMin){
364                                                        // new minimum
365                                                        iMin1 = i;
366                                                        iMin2 = j;
367                                                        fMin = fVal;
368                                                }
369                                                j = nNextActive[j];
370                                        }
371                                        i = nNextActive[i];
372                                }               
373                        }
374                        // record distance
375                        double fMinDistance = fDist[iMin1][iMin2];
376                        nClusters--;
377                        double fSep1 = fSeparations[iMin1];
378                        double fSep2 = fSeparations[iMin2];
379                        double fDist1 = (0.5 * fMinDistance) + (0.5 * (fSep1 - fSep2));
380                        double fDist2 = (0.5 * fMinDistance) + (0.5 * (fSep2 - fSep1));
381                        if (nClusters > 2) {
382                                // update separations  & distance
383                                double fNewSeparationSum = 0;
384                                double fMutualDistance = fDist[iMin1][iMin2];
385                                double[] fRow1 = fDist[iMin1];
386                                double[] fRow2 = fDist[iMin2];
387                                for(int i = 0; i < n; i++) {
388                                    if(i == iMin1 || i == iMin2 || nClusterID[i].size() == 0) {
389                                        fRow1[i] = 0;
390                                    } else {
391                                        double fVal1 = fRow1[i];
392                                        double fVal2 = fRow2[i];
393                                        double fDistance = (fVal1 + fVal2 - fMutualDistance) / 2.0;
394                                        fNewSeparationSum += fDistance;
395                                        // update the separationsum of cluster i.
396                                        fSeparationSums[i] += (fDistance - fVal1 - fVal2);
397                                        fSeparations[i] = fSeparationSums[i] / (nClusters -2);
398                                        fRow1[i] = fDistance;
399                                        fDist[i][iMin1] = fDistance;
400                                    }
401                                }
402                                fSeparationSums[iMin1] = fNewSeparationSum;
403                                fSeparations[iMin1] = fNewSeparationSum / (nClusters - 2);
404                                fSeparationSums[iMin2] = 0;
405                                merge(iMin1, iMin2, fDist1, fDist2, nClusterID, clusterNodes);
406                                int iPrev = iMin2;
407                                // since iMin1 < iMin2 we havenActiveRows[0] >= 0, so the next loop should be save
408                                while (nClusterID[iPrev].size() == 0) {
409                                        iPrev--;
410                                }
411                                nNextActive[iPrev] = nNextActive[iMin2];
412                        } else {
413                                merge(iMin1, iMin2, fDist1, fDist2, nClusterID, clusterNodes);
414                                break;
415                        }
416                }
417
418                for (int i = 0; i < n; i++) {
419                        if (nClusterID[i].size() > 0) {
420                                for (int j = i+1; j < n; j++) {
421                                        if (nClusterID[j].size() > 0) {
422                                                double fDist1 = fDist[i][j];
423                                                if(nClusterID[i].size() == 1) {
424                                                        merge(i,j,fDist1,0,nClusterID, clusterNodes);
425                                                } else if (nClusterID[j].size() == 1) {
426                                                        merge(i,j,0,fDist1,nClusterID, clusterNodes);
427                                                } else {
428                                                        merge(i,j,fDist1/2.0,fDist1/2.0,nClusterID, clusterNodes);
429                                                }
430                                                break;
431                                        }
432                                }
433                        }
434                }
435        } // neighborJoining
436       
437        /** Perform clustering using a link method
438         * This implementation uses a priority queue resulting in a O(n^2 log(n)) algorithm
439         * @param nClusters number of clusters
440         * @param nClusterID
441         * @param clusterNodes
442         */
443        void doLinkClustering(int nClusters, Vector<Integer>[] nClusterID, Node [] clusterNodes) {
444                int nInstances = m_instances.numInstances();
445                PriorityQueue<Tuple> queue = new PriorityQueue<Tuple>(nClusters*nClusters/2, new TupleComparator());
446                double [][] fDistance0 = new double[nClusters][nClusters];
447                double [][] fClusterDistance = null;
448                if (m_bDebug) {
449                        fClusterDistance = new double[nClusters][nClusters];
450                }
451                for (int i = 0; i < nClusters; i++) {
452                        fDistance0[i][i] = 0;
453                        for (int j = i+1; j < nClusters; j++) {
454                                fDistance0[i][j] = getDistance0(nClusterID[i], nClusterID[j]);
455                                fDistance0[j][i] = fDistance0[i][j];
456                                queue.add(new Tuple(fDistance0[i][j], i, j, 1, 1));
457                                if (m_bDebug) {
458                                        fClusterDistance[i][j] = fDistance0[i][j];
459                                        fClusterDistance[j][i] = fDistance0[i][j];
460                                }
461                        }
462                }
463                while (nClusters > m_nNumClusters) {
464                        int iMin1 = -1;
465                        int iMin2 = -1;
466                        // find closest two clusters
467                        if (m_bDebug) {
468                                /* simple but inefficient implementation */
469                                double fMinDistance = Double.MAX_VALUE;
470                                for (int i = 0; i < nInstances; i++) {
471                                        if (nClusterID[i].size()>0) {
472                                                for (int j = i+1; j < nInstances; j++) {
473                                                        if (nClusterID[j].size()>0) {
474                                                                double fDist = fClusterDistance[i][j];
475                                                                if (fDist < fMinDistance) {
476                                                                        fMinDistance = fDist;
477                                                                        iMin1 = i;
478                                                                        iMin2 = j;
479                                                                }
480                                                        }
481                                                }
482                                        }
483                                }
484                                merge(iMin1, iMin2, fMinDistance, fMinDistance, nClusterID, clusterNodes);
485                        } else {
486                                // use priority queue to find next best pair to cluster
487                                Tuple t;
488                                do {
489                                        t = queue.poll();
490                                } while (t!=null && (nClusterID[t.m_iCluster1].size() != t.m_nClusterSize1 || nClusterID[t.m_iCluster2].size() != t.m_nClusterSize2));
491                                iMin1 = t.m_iCluster1;
492                                iMin2 = t.m_iCluster2;
493                                merge(iMin1, iMin2, t.m_fDist, t.m_fDist, nClusterID, clusterNodes);
494                        }
495                        // merge  clusters
496                       
497                        // update distances & queue
498                        for (int i = 0; i < nInstances; i++) {
499                                if (i != iMin1 && nClusterID[i].size()!=0) {
500                                        int i1 = Math.min(iMin1,i);
501                                        int i2 = Math.max(iMin1,i);
502                                        double fDistance = getDistance(fDistance0, nClusterID[i1], nClusterID[i2]);
503                                        if (m_bDebug) {
504                                                fClusterDistance[i1][i2] = fDistance;
505                                                fClusterDistance[i2][i1] = fDistance;
506                                        }
507                                        queue.add(new Tuple(fDistance, i1, i2, nClusterID[i1].size(), nClusterID[i2].size()));
508                                }
509                        }
510                       
511                        nClusters--;
512                }
513        } // doLinkClustering
514       
515        void merge(int iMin1, int iMin2, double fDist1, double fDist2, Vector<Integer>[] nClusterID, Node [] clusterNodes) {
516                if (m_bDebug) {
517                        System.err.println("Merging " + iMin1 + " " + iMin2 + " " + fDist1 + " " + fDist2);
518                }
519                if (iMin1 > iMin2) {
520                        int h = iMin1; iMin1 = iMin2; iMin2 = h;
521                        double f = fDist1; fDist1 = fDist2; fDist2 = f;
522                }
523                nClusterID[iMin1].addAll(nClusterID[iMin2]);
524                nClusterID[iMin2].removeAllElements();
525               
526                // track hierarchy
527                Node node = new Node();
528                if (clusterNodes[iMin1] == null) {
529                        node.m_iLeftInstance = iMin1;
530                } else {
531                        node.m_left = clusterNodes[iMin1];
532                        clusterNodes[iMin1].m_parent = node;
533                }
534                if (clusterNodes[iMin2] == null) {
535                        node.m_iRightInstance = iMin2;
536                } else {
537                        node.m_right = clusterNodes[iMin2];
538                        clusterNodes[iMin2].m_parent = node;
539                }
540                if (m_bDistanceIsBranchLength) {
541                        node.setLength(fDist1, fDist2);
542                } else {
543                        node.setHeight(fDist1, fDist2);
544                }
545                clusterNodes[iMin1] = node;
546        } // merge
547       
548        /** calculate distance the first time when setting up the distance matrix **/
549        double getDistance0(Vector<Integer> cluster1, Vector<Integer> cluster2) {
550                double fBestDist = Double.MAX_VALUE;
551                switch (m_nLinkType) {
552                case SINGLE:
553                case NEIGHBOR_JOINING:
554                case CENTROID:
555                case COMPLETE:
556                case ADJCOMLPETE:
557                case AVERAGE:
558                case MEAN:
559                        // set up two instances for distance function
560                        Instance instance1 = (Instance) m_instances.instance(cluster1.elementAt(0)).copy();
561                        Instance instance2 = (Instance) m_instances.instance(cluster2.elementAt(0)).copy();
562                        fBestDist = m_DistanceFunction.distance(instance1, instance2);
563                        break;
564                case WARD:
565                        {
566                                // finds the distance of the change in caused by merging the cluster.
567                                // The information of a cluster is calculated as the error sum of squares of the
568                                // centroids of the cluster and its members.
569                                double ESS1 = calcESS(cluster1);
570                                double ESS2 = calcESS(cluster2);
571                                Vector<Integer> merged = new Vector<Integer>();
572                                merged.addAll(cluster1);
573                                merged.addAll(cluster2);
574                                double ESS = calcESS(merged);
575                                fBestDist = ESS * merged.size() - ESS1 * cluster1.size() - ESS2 * cluster2.size();
576                        }
577                        break;
578                }
579                return fBestDist;
580        } // getDistance0
581
582        /** calculate the distance between two clusters
583         * @param cluster1 list of indices of instances in the first cluster
584         * @param cluster2 dito for second cluster
585         * @return distance between clusters based on link type
586         */
587        double getDistance(double [][] fDistance, Vector<Integer> cluster1, Vector<Integer> cluster2) {
588                double fBestDist = Double.MAX_VALUE;
589                switch (m_nLinkType) {
590                case SINGLE:
591                        // find single link distance aka minimum link, which is the closest distance between
592                        // any item in cluster1 and any item in cluster2
593                        fBestDist = Double.MAX_VALUE;
594                        for (int i = 0; i < cluster1.size(); i++) {
595                                int i1 = cluster1.elementAt(i);
596                                for (int j = 0; j < cluster2.size(); j++) {
597                                        int i2  = cluster2.elementAt(j);
598                                        double fDist = fDistance[i1][i2];
599                                        if (fBestDist > fDist) {
600                                                fBestDist = fDist;
601                                        }
602                                }
603                        }
604                        break;
605                case COMPLETE:
606                case ADJCOMLPETE:
607                        // find complete link distance aka maximum link, which is the largest distance between
608                        // any item in cluster1 and any item in cluster2
609                        fBestDist = 0;
610                        for (int i = 0; i < cluster1.size(); i++) {
611                                int i1 = cluster1.elementAt(i);
612                                for (int j = 0; j < cluster2.size(); j++) {
613                                        int i2 = cluster2.elementAt(j);
614                                        double fDist = fDistance[i1][i2];
615                                        if (fBestDist < fDist) {
616                                                fBestDist = fDist;
617                                        }
618                                }
619                        }
620                        if (m_nLinkType == COMPLETE) {
621                                break;
622                        }
623                        // calculate adjustment, which is the largest within cluster distance
624                        double fMaxDist = 0;
625                        for (int i = 0; i < cluster1.size(); i++) {
626                                int i1 = cluster1.elementAt(i);
627                                for (int j = i+1; j < cluster1.size(); j++) {
628                                        int i2 = cluster1.elementAt(j);
629                                        double fDist = fDistance[i1][i2];
630                                        if (fMaxDist < fDist) {
631                                                fMaxDist = fDist;
632                                        }
633                                }
634                        }
635                        for (int i = 0; i < cluster2.size(); i++) {
636                                int i1 = cluster2.elementAt(i);
637                                for (int j = i+1; j < cluster2.size(); j++) {
638                                        int i2 = cluster2.elementAt(j);
639                                        double fDist = fDistance[i1][i2];
640                                        if (fMaxDist < fDist) {
641                                                fMaxDist = fDist;
642                                        }
643                                }
644                        }
645                        fBestDist -= fMaxDist;
646                        break;
647                case AVERAGE:
648                        // finds average distance between the elements of the two clusters
649                        fBestDist = 0;
650                        for (int i = 0; i < cluster1.size(); i++) {
651                                int i1 = cluster1.elementAt(i);
652                                for (int j = 0; j < cluster2.size(); j++) {
653                                        int i2 = cluster2.elementAt(j);
654                                        fBestDist += fDistance[i1][i2];
655                                }
656                        }
657                        fBestDist /= (cluster1.size() * cluster2.size());
658                        break;
659                case MEAN: 
660                        {
661                                // calculates the mean distance of a merged cluster (akak Group-average agglomerative clustering)
662                                Vector<Integer> merged = new Vector<Integer>();
663                                merged.addAll(cluster1);
664                                merged.addAll(cluster2);
665                                fBestDist = 0;
666                                for (int i = 0; i < merged.size(); i++) {
667                                        int i1 = merged.elementAt(i);
668                                        for (int j = i+1; j < merged.size(); j++) {
669                                                int i2 = merged.elementAt(j);
670                                                fBestDist += fDistance[i1][i2];
671                                        }
672                                }
673                                int n = merged.size();
674                                fBestDist /= (n*(n-1.0)/2.0);
675                        }
676                        break;
677                case CENTROID:
678                        // finds the distance of the centroids of the clusters
679                        double [] fValues1 = new double[m_instances.numAttributes()];
680                        for (int i = 0; i < cluster1.size(); i++) {
681                                Instance instance = m_instances.instance(cluster1.elementAt(i));
682                                for (int j = 0; j < m_instances.numAttributes(); j++) {
683                                        fValues1[j] += instance.value(j);
684                                }
685                        }
686                        double [] fValues2 = new double[m_instances.numAttributes()];
687                        for (int i = 0; i < cluster2.size(); i++) {
688                                Instance instance = m_instances.instance(cluster2.elementAt(i));
689                                for (int j = 0; j < m_instances.numAttributes(); j++) {
690                                        fValues2[j] += instance.value(j);
691                                }
692                        }
693                        for (int j = 0; j < m_instances.numAttributes(); j++) {
694                                fValues1[j] /= cluster1.size();
695                                fValues2[j] /= cluster2.size();
696                        }
697                        // set up two instances for distance function
698                        Instance instance1 = (Instance) m_instances.instance(0).copy();
699                        Instance instance2 = (Instance) m_instances.instance(0).copy();
700                        for (int j = 0; j < m_instances.numAttributes(); j++) {
701                                instance1.setValue(j, fValues1[j]);
702                                instance2.setValue(j, fValues2[j]);
703                        }
704                        fBestDist = m_DistanceFunction.distance(instance1, instance2);
705                        break;
706                case WARD:
707                        {
708                                // finds the distance of the change in caused by merging the cluster.
709                                // The information of a cluster is calculated as the error sum of squares of the
710                                // centroids of the cluster and its members.
711                                double ESS1 = calcESS(cluster1);
712                                double ESS2 = calcESS(cluster2);
713                                Vector<Integer> merged = new Vector<Integer>();
714                                merged.addAll(cluster1);
715                                merged.addAll(cluster2);
716                                double ESS = calcESS(merged);
717                                fBestDist = ESS * merged.size() - ESS1 * cluster1.size() - ESS2 * cluster2.size();
718                        }
719                        break;
720                }
721                return fBestDist;
722        } // getDistance
723
724        /** calculated error sum-of-squares for instances wrt centroid **/
725        double calcESS(Vector<Integer> cluster) {
726                double [] fValues1 = new double[m_instances.numAttributes()];
727                for (int i = 0; i < cluster.size(); i++) {
728                        Instance instance = m_instances.instance(cluster.elementAt(i));
729                        for (int j = 0; j < m_instances.numAttributes(); j++) {
730                                fValues1[j] += instance.value(j);
731                        }
732                }
733                for (int j = 0; j < m_instances.numAttributes(); j++) {
734                        fValues1[j] /= cluster.size();
735                }
736                // set up two instances for distance function
737                Instance centroid = (Instance) m_instances.instance(cluster.elementAt(0)).copy();
738                for (int j = 0; j < m_instances.numAttributes(); j++) {
739                        centroid.setValue(j, fValues1[j]);
740                }
741                double fESS = 0;
742                for (int i = 0; i < cluster.size(); i++) {
743                        Instance instance = m_instances.instance(cluster.elementAt(i));
744                        fESS += m_DistanceFunction.distance(centroid, instance);
745                }
746                return fESS / cluster.size(); 
747        } // calcESS
748       
749        @Override
750        /** instances are assigned a cluster by finding the instance in the training data
751         * with the closest distance to the instance to be clustered. The cluster index of
752         * the training data point is taken as the cluster index.
753         */
754        public int clusterInstance(Instance instance) throws Exception {
755                if (m_instances.numInstances() == 0) {
756                        return 0;
757                }
758                double fBestDist = Double.MAX_VALUE;
759                int iBestInstance = -1;
760                for (int i = 0; i < m_instances.numInstances(); i++) {
761                        double fDist = m_DistanceFunction.distance(instance, m_instances.instance(i));
762                        if (fDist < fBestDist) {
763                                fBestDist = fDist;
764                                iBestInstance = i;
765                        }
766                }
767                return m_nClusterNr[iBestInstance];
768        }
769
770        @Override
771        /** create distribution with all clusters having zero probability, except the
772         * cluster the instance is assigned to.
773         */
774        public double[] distributionForInstance(Instance instance) throws Exception {
775                if (numberOfClusters() == 0) {
776                        double [] p = new double[1];
777                        p[0] = 1;
778                        return p;
779                }
780                double [] p = new double[numberOfClusters()];
781                p[clusterInstance(instance)] = 1.0;
782                return p;
783        }
784
785        @Override
786        public Capabilities getCapabilities() {
787            Capabilities result = new Capabilities(this);
788            result.disableAll();
789            result.enable(Capability.NO_CLASS);
790
791            // attributes
792            result.enable(Capability.NOMINAL_ATTRIBUTES);
793            result.enable(Capability.NUMERIC_ATTRIBUTES);
794            result.enable(Capability.DATE_ATTRIBUTES);
795            result.enable(Capability.MISSING_VALUES);
796            result.enable(Capability.STRING_ATTRIBUTES);
797
798            // other
799            result.setMinimumNumberInstances(0);
800            return result;
801        }
802
803        @Override
804        public int numberOfClusters() throws Exception {
805                return Math.min(m_nNumClusters, m_instances.numInstances());
806        }
807
808          /**
809           * Returns an enumeration describing the available options.
810           *
811           * @return an enumeration of all the available options.
812           */
813          public Enumeration listOptions() {
814
815            Vector newVector = new Vector(8);
816            newVector.addElement(new Option(
817                      "\tIf set, classifier is run in debug mode and\n"
818                      + "\tmay output additional info to the console",
819                      "D", 0, "-D"));
820            newVector.addElement(new Option(
821                              "\tIf set, distance is interpreted as branch length\n"
822                              + "\totherwise it is node height.",
823                              "B", 0, "-B"));
824
825            newVector.addElement(new Option(
826                    "\tnumber of clusters",
827                      "N", 1,"-N <Nr Of Clusters>"));
828            newVector.addElement(new Option(
829                    "\tFlag to indicate the cluster should be printed in Newick format.",
830                      "P", 0,"-P"));
831                newVector.addElement(
832                                new Option(
833                                        "Link type (Single, Complete, Average, Mean, Centroid, Ward, Adjusted complete, Neighbor joining)", "L", 1,
834                                        "-L [SINGLE|COMPLETE|AVERAGE|MEAN|CENTROID|WARD|ADJCOMLPETE|NEIGHBOR_JOINING]"));
835            newVector.add(new Option(
836                        "\tDistance function to use.\n"
837                        + "\t(default: weka.core.EuclideanDistance)",
838                        "A", 1,"-A <classname and options>"));
839            return newVector.elements();
840          }
841
842          /**
843           * Parses a given list of options. <p/>
844           *
845           <!-- options-start -->
846           * Valid options are: <p/>
847           *
848           <!-- options-end -->
849           *
850           * @param options the list of options as an array of strings
851           * @throws Exception if an option is not supported
852           */
853          public void setOptions(String[] options) throws Exception {
854                    m_bPrintNewick = Utils.getFlag('P', options);
855
856                    String optionString = Utils.getOption('N', options); 
857                    if (optionString.length() != 0) {
858                      Integer temp = new Integer(optionString);
859                      setNumClusters(temp);
860                    }
861                    else {
862                      setNumClusters(2);
863                    }
864           
865        setDebug(Utils.getFlag('D', options));
866        setDistanceIsBranchLength(Utils.getFlag('B', options));
867
868            String sLinkType = Utils.getOption('L', options);
869
870
871                if (sLinkType.compareTo("SINGLE") == 0) {setLinkType(new SelectedTag(SINGLE, TAGS_LINK_TYPE));}
872                if (sLinkType.compareTo("COMPLETE") == 0) {setLinkType(new SelectedTag(COMPLETE, TAGS_LINK_TYPE));}
873                if (sLinkType.compareTo("AVERAGE") == 0) {setLinkType(new SelectedTag(AVERAGE, TAGS_LINK_TYPE));}
874                if (sLinkType.compareTo("MEAN") == 0) {setLinkType(new SelectedTag(MEAN, TAGS_LINK_TYPE));}
875                if (sLinkType.compareTo("CENTROID") == 0) {setLinkType(new SelectedTag(CENTROID, TAGS_LINK_TYPE));}
876                if (sLinkType.compareTo("WARD") == 0) {setLinkType(new SelectedTag(WARD, TAGS_LINK_TYPE));}
877                if (sLinkType.compareTo("ADJCOMLPETE") == 0) {setLinkType(new SelectedTag(ADJCOMLPETE, TAGS_LINK_TYPE));}
878                if (sLinkType.compareTo("NEIGHBOR_JOINING") == 0) {setLinkType(new SelectedTag(NEIGHBOR_JOINING, TAGS_LINK_TYPE));}
879               
880            String nnSearchClass = Utils.getOption('A', options);
881            if(nnSearchClass.length() != 0) {
882              String nnSearchClassSpec[] = Utils.splitOptions(nnSearchClass);
883              if(nnSearchClassSpec.length == 0) { 
884                throw new Exception("Invalid DistanceFunction specification string."); 
885              }
886              String className = nnSearchClassSpec[0];
887              nnSearchClassSpec[0] = "";
888
889              setDistanceFunction( (DistanceFunction)
890                                    Utils.forName( DistanceFunction.class, 
891                                                   className, nnSearchClassSpec) );
892            }
893            else {
894              setDistanceFunction(new EuclideanDistance());
895            }
896           
897            Utils.checkForRemainingOptions(options);
898          }
899
900          /**
901           * Gets the current settings of the clusterer.
902           *
903           * @return an array of strings suitable for passing to setOptions()
904           */
905          public String [] getOptions() {
906
907            String [] options = new String [14];
908            int current = 0;
909
910            options[current++] = "-N";
911            options[current++] = "" + getNumClusters();
912           
913            options[current++] = "-L";
914                switch (m_nLinkType) {
915                        case (SINGLE) :options[current++] = "SINGLE";break;
916                        case (COMPLETE) :options[current++] = "COMPLETE";break;
917                        case (AVERAGE) :options[current++] = "AVERAGE";break;
918                        case (MEAN) :options[current++] = "MEAN";break;
919                        case (CENTROID) :options[current++] = "CENTROID";break;
920                        case (WARD) :options[current++] = "WARD";break;
921                        case (ADJCOMLPETE) :options[current++] = "ADJCOMLPETE";break;
922                        case (NEIGHBOR_JOINING) :options[current++] = "NEIGHBOR_JOINING";break;
923                }
924                if (m_bPrintNewick) {
925                        options[current++] = "-P";
926                }
927            if (getDebug()) {
928                options[current++] = "-D";
929            }
930        if (getDistanceIsBranchLength()) {
931                options[current++] = "-B";
932        }
933           
934                options[current++] = "-A";
935                options[current++] = (m_DistanceFunction.getClass().getName() + " " +
936                           Utils.joinOptions(m_DistanceFunction.getOptions())).trim();
937           
938            while (current < options.length) {
939              options[current++] = "";
940            }
941           
942            return options;
943          }
944          public String toString() {
945                  StringBuffer buf = new StringBuffer();
946                  int attIndex = m_instances.classIndex();
947                  if (attIndex < 0) {
948                          // try find a string, or last attribute otherwise
949                          attIndex = 0;
950                          while (attIndex < m_instances.numAttributes()-1) {
951                                  if (m_instances.attribute(attIndex).isString()) {
952                                          break;
953                                  }
954                                  attIndex++;
955                          }
956                  }
957                  try {
958                        if (m_bPrintNewick && (numberOfClusters() > 0)) {
959                                  for (int i = 0; i < m_clusters.length; i++) {
960                                          if (m_clusters[i] != null) {
961                                                  buf.append("Cluster " + i + "\n");
962                                                  if (m_instances.attribute(attIndex).isString()) {
963                                                          buf.append(m_clusters[i].toString(attIndex));
964                                                  } else {
965                                                          buf.append(m_clusters[i].toString2(attIndex));
966                                                  }
967                                                  buf.append("\n\n");
968                                          }
969                                  }
970                          }
971                } catch (Exception e) {
972                        e.printStackTrace();
973                }
974                  return buf.toString();
975          }
976          /**
977           * Set debugging mode.
978           *
979           * @param debug true if debug output should be printed
980           */
981          public void setDebug(boolean debug) {
982
983            m_bDebug = debug;
984          }
985
986          /**
987           * Get whether debugging is turned on.
988           *
989           * @return true if debugging output is on
990           */
991          public boolean getDebug() {
992
993            return m_bDebug;
994          }
995
996          public boolean getDistanceIsBranchLength() {return m_bDistanceIsBranchLength;}
997
998          public void setDistanceIsBranchLength(boolean bDistanceIsHeight) {m_bDistanceIsBranchLength = bDistanceIsHeight;}
999
1000          public String distanceIsHeightTipText() {
1001                  return "If set to false, the distance between clusters is interpreted " +
1002                  "as the height of the node linking the clusters. This is appropriate for " +
1003                  "example for single link clustering. However, for neighbor joining, the " +
1004                  "distance is better interpreted as branch length. Set this flag to " +
1005                  "get the latter interpretation.";
1006          }
1007          /**
1008           * Returns the tip text for this property
1009           * @return tip text for this property suitable for
1010           * displaying in the explorer/experimenter gui
1011           */
1012          public String debugTipText() {
1013            return "If set to true, classifier may output additional info to " +
1014              "the console.";
1015          }
1016          /**
1017           * @return a string to describe the NumClusters
1018           */
1019          public String numClustersTipText() {
1020            return "Sets the number of clusters. " +
1021            "If a single hierarchy is desired, set this to 1.";
1022          }
1023
1024          /**
1025           * @return a string to describe the print Newick flag
1026           */
1027          public String printNewickTipText() {
1028            return "Flag to indicate whether the cluster should be print in Newick format." +
1029            " This can be useful for display in other programs. However, for large datasets" +
1030            " a lot of text may be produced, which may not be a nuisance when the Newick format" +
1031            " is not required";
1032          }
1033
1034          /**
1035           * @return a string to describe the distance function
1036           */
1037          public String distanceFunctionTipText() {
1038            return "Sets the distance function, which measures the distance between two individual. " +
1039            "instances (or possibly the distance between an instance and the centroid of a cluster" +
1040            "depending on the Link type).";
1041          }
1042
1043          /**
1044           * @return a string to describe the Link type
1045           */
1046          public String linkTypeTipText() {
1047            return "Sets the method used to measure the distance between two clusters.\n" +
1048            "SINGLE:\n" +
1049            " find single link distance aka minimum link, which is the closest distance between" +
1050            " any item in cluster1 and any item in cluster2\n" +
1051            "COMPLETE:\n" +
1052            " find complete link distance aka maximum link, which is the largest distance between" +
1053            " any item in cluster1 and any item in cluster2\n" +
1054            "ADJCOMLPETE:\n" +
1055            " as COMPLETE, but with adjustment, which is the largest within cluster distance\n" +
1056            "AVERAGE:\n" +
1057            " finds average distance between the elements of the two clusters\n" +
1058            "MEAN: \n" +
1059            " calculates the mean distance of a merged cluster (akak Group-average agglomerative clustering)\n" +
1060            "CENTROID:\n" +
1061            " finds the distance of the centroids of the clusters\n" +
1062            "WARD:\n" +
1063            " finds the distance of the change in caused by merging the cluster." +
1064            " The information of a cluster is calculated as the error sum of squares of the" +
1065            " centroids of the cluster and its members.\n" +
1066            "NEIGHBOR_JOINING\n" +
1067            " use neighbor joining algorithm."
1068            ;
1069          }
1070
1071          /**
1072           * This will return a string describing the clusterer.
1073           * @return The string.
1074           */
1075          public String globalInfo() {
1076            return 
1077            "Hierarchical clustering class.\n" +
1078            "Implements a number of classic agglomorative (i.e. bottom up) hierarchical clustering methods" +
1079            "based on .";
1080          }
1081         
1082          public static void main(String [] argv) {
1083                    runClusterer(new HierarchicalClusterer(), argv);
1084                  }
1085        @Override
1086        public String graph() throws Exception {
1087                if (numberOfClusters() == 0) {
1088                          return "Newick:(no,clusters)";
1089                }
1090                  int attIndex = m_instances.classIndex();
1091                  if (attIndex < 0) {
1092                          // try find a string, or last attribute otherwise
1093                          attIndex = 0;
1094                          while (attIndex < m_instances.numAttributes()-1) {
1095                                  if (m_instances.attribute(attIndex).isString()) {
1096                                          break;
1097                                  }
1098                                  attIndex++;
1099                          }
1100                  }
1101                  String sNewick = null;
1102                  if (m_instances.attribute(attIndex).isString()) {
1103                          sNewick = m_clusters[0].toString(attIndex);
1104                  } else {
1105                          sNewick = m_clusters[0].toString2(attIndex);
1106                  }
1107                  return "Newick:" + sNewick;
1108        }
1109        @Override
1110        public int graphType() {
1111                return Drawable.Newick;
1112        }
1113          /**
1114           * Returns the revision string.
1115           *
1116           * @return            the revision
1117           */
1118          public String getRevision() {
1119            return RevisionUtils.extract("$Revision: 6042 $");
1120          }
1121} // class HierarchicalClusterer
Note: See TracBrowser for help on using the repository browser.