source: tags/MetisMQIDemo/src/main/java/weka/classifiers/bayes/net/search/ci/ICSSearchAlgorithm.java

Last change on this file was 29, checked in by gnappo, 15 years ago

Taggata versione per la demo e aggiunto branch.

File size: 19.0 KB
Line 
1/*
2 * This program is free software; you can redistribute it and/or modify
3 * it under the terms of the GNU General Public License as published by
4 * the Free Software Foundation; either version 2 of the License, or
5 * (at your option) any later version.
6 *
7 * This program is distributed in the hope that it will be useful,
8 * but WITHOUT ANY WARRANTY; without even the implied warranty of
9 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
10 * GNU General Public License for more details.
11 *
12 * You should have received a copy of the GNU General Public License
13 * along with this program; if not, write to the Free Software
14 * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
15 */
16
17/*
18 * ICSSearchAlgorithm.java
19 * Copyright (C) 2004 University of Waikato, Hamilton, New Zealand
20 *
21 */
22
23
24package weka.classifiers.bayes.net.search.ci;
25
26import weka.classifiers.bayes.BayesNet;
27import weka.classifiers.bayes.net.ParentSet;
28import weka.core.Instances;
29import weka.core.Option;
30import weka.core.RevisionHandler;
31import weka.core.RevisionUtils;
32import weka.core.Utils;
33
34import java.io.FileReader;
35import java.util.Enumeration;
36import java.util.Vector;
37
38/**
39 <!-- globalinfo-start -->
40 * This Bayes Network learning algorithm uses conditional independence tests to find a skeleton, finds V-nodes and applies a set of rules to find the directions of the remaining arrows.
41 * <p/>
42 <!-- globalinfo-end -->
43 *
44 <!-- options-start -->
45 * Valid options are: <p/>
46 *
47 * <pre> -cardinality &lt;num&gt;
48 *  When determining whether an edge exists a search is performed
49 *  for a set Z that separates the nodes. MaxCardinality determines
50 *  the maximum size of the set Z. This greatly influences the
51 *  length of the search. (default 2)</pre>
52 *
53 * <pre> -mbc
54 *  Applies a Markov Blanket correction to the network structure,
55 *  after a network structure is learned. This ensures that all
56 *  nodes in the network are part of the Markov blanket of the
57 *  classifier node.</pre>
58 *
59 * <pre> -S [BAYES|MDL|ENTROPY|AIC|CROSS_CLASSIC|CROSS_BAYES]
60 *  Score type (BAYES, BDeu, MDL, ENTROPY and AIC)</pre>
61 *
62 <!-- options-end -->
63 *
64 * @author Remco Bouckaert
65 * @version $Revision: 1.8 $
66 */ 
67public class ICSSearchAlgorithm 
68    extends CISearchAlgorithm {
69
70    /** for serialization */
71    static final long serialVersionUID = -2510985917284798576L;
72 
73    /**
74     * returns the name of the attribute with the given index
75     *
76     * @param iAttribute the index of the attribute
77     * @return the name of the attribute
78     */
79    String name(int iAttribute) {
80      return m_instances.attribute(iAttribute).name();
81    }
82   
83    /**
84     * returns the number of attributes
85     *
86     * @return the number of attributes
87     */
88    int maxn() {
89      return m_instances.numAttributes();
90    }
91   
92    /** maximum size of separating set **/
93    private int m_nMaxCardinality = 2; 
94
95    /**
96     * sets the cardinality
97     *
98     * @param nMaxCardinality the max cardinality
99     */
100    public void setMaxCardinality(int nMaxCardinality) {
101      m_nMaxCardinality = nMaxCardinality;
102    }
103   
104    /**
105     * returns the max cardinality
106     *
107     * @return the max cardinality
108     */
109    public int getMaxCardinality() {
110      return m_nMaxCardinality;
111    }
112       
113
114       
115    class SeparationSet
116        implements RevisionHandler {
117     
118        public int [] m_set;
119       
120        /**
121         * constructor
122         */
123        public SeparationSet() {
124                        m_set= new int [getMaxCardinality() + 1];
125        } // c'tor
126
127        public boolean contains(int nItem) {
128                for (int iItem = 0; iItem < getMaxCardinality() && m_set[iItem] != -1; iItem++) {
129                        if (m_set[iItem] == nItem) {
130                                        return true;
131                                }
132                }
133                        return false;
134        } // contains
135       
136        /**
137         * Returns the revision string.
138         *
139         * @return              the revision
140         */
141        public String getRevision() {
142          return RevisionUtils.extract("$Revision: 1.8 $");
143        }
144
145    } // class sepset
146
147        /**
148         * Search for Bayes network structure using ICS algorithm
149         * @param bayesNet datastructure to build network structure for
150         * @param instances data set to learn from
151         * @throws Exception if something goes wrong
152         */
153        protected void search(BayesNet bayesNet, Instances instances) throws Exception {
154        // init
155        m_BayesNet = bayesNet;
156        m_instances = instances;
157
158        boolean edges[][] = new boolean [maxn() + 1][];
159                boolean [] [] arrows = new boolean [maxn() + 1][];
160        SeparationSet [] [] sepsets = new SeparationSet [maxn() + 1][];
161        for (int iNode = 0 ; iNode < maxn() + 1; iNode++) {
162            edges[iNode] = new boolean[maxn()];
163                        arrows[iNode] = new boolean[maxn()]; 
164            sepsets[iNode] = new SeparationSet[maxn()];
165        }
166
167        calcDependencyGraph(edges, sepsets);
168        calcVeeNodes(edges, arrows, sepsets);
169        calcArcDirections(edges, arrows);
170
171                // transfrom into BayesNet datastructure
172                for (int iNode = 0; iNode < maxn(); iNode++) {
173                        // clear parent set of AttributeX
174                        ParentSet oParentSet = m_BayesNet.getParentSet(iNode);
175                        while (oParentSet.getNrOfParents() > 0) {
176                                oParentSet.deleteLastParent(m_instances);
177                        }
178                        for (int iParent = 0; iParent < maxn(); iParent++) {
179                                if (arrows[iParent][iNode]) {
180                                        oParentSet.addParent(iParent, m_instances);
181                                }
182                        }
183                }
184        } // search
185       
186       
187        /** CalcDependencyGraph determines the skeleton of the BayesNetwork by
188         * starting with a complete graph and removing edges (a--b) if it can
189         * find a set Z such that a and b conditionally independent given Z.
190         * The set Z is found by trying all possible subsets of nodes adjacent
191         * to a and b, first of size 0, then of size 1, etc. up to size
192         * m_nMaxCardinality
193         * @param edges boolean matrix representing the edges
194         * @param sepsets set of separating sets
195         */
196        void calcDependencyGraph(boolean[][] edges, SeparationSet[][] sepsets) {
197                /*calc undirected graph a-b iff D(a,S,b) for all S)*/
198                SeparationSet oSepSet;
199
200                for (int iNode1 = 0; iNode1 < maxn(); iNode1++) { 
201                        /*start with complete graph*/
202                        for (int iNode2 = 0; iNode2 < maxn(); iNode2++) {
203                                edges[iNode1][iNode2] = true;
204                        }
205                }
206                for (int iNode1 = 0; iNode1 < maxn(); iNode1++) {
207                        edges[iNode1][iNode1] = false;
208                }
209
210                for (int iCardinality = 0; iCardinality <= getMaxCardinality(); iCardinality++) {
211                        for (int iNode1 = 0; iNode1 <= maxn() - 2; iNode1++) {
212                                for (int iNode2 = iNode1 + 1; iNode2 < maxn(); iNode2++) {
213                                        if (edges[iNode1][iNode2]) {
214                                                oSepSet = existsSepSet(iNode1, iNode2, iCardinality, edges);
215                                                if (oSepSet != null) {
216                                                        edges[iNode1][iNode2] = false;
217                                                        edges[iNode2][iNode1] = false;
218                                                        sepsets[iNode1][iNode2] = oSepSet;
219                                                        sepsets[iNode2][iNode1] = oSepSet;
220                                                        // report separating set
221                                                        System.err.print("I(" + name(iNode1) + ", {");
222                                                        for (int iNode3 = 0; iNode3 < iCardinality; iNode3++) {
223                                                                System.err.print(name(oSepSet.m_set[iNode3]) + " ");
224                                                        }
225                                                        System.err.print("} ," + name(iNode2) + ")\n");
226                                                }
227                                        }
228                                }
229                        }
230                        // report current state of dependency graph
231                        System.err.print(iCardinality + " ");
232                        for (int iNode1 = 0; iNode1 < maxn(); iNode1++) {
233                                System.err.print(name(iNode1) + " ");
234                        }
235                        System.err.print('\n');
236                        for (int iNode1 = 0; iNode1 < maxn(); iNode1++) {
237                                for (int iNode2 = 0; iNode2 < maxn(); iNode2++) {
238                                        if (edges[iNode1][iNode2])
239                                                System.err.print("X ");
240                                        else
241                                                System.err.print(". ");
242                                }
243                                System.err.print(name(iNode1) + " ");
244                                System.err.print('\n');
245                        }
246                }
247        } /*CalcDependencyGraph*/
248
249        /** ExistsSepSet tests if a separating set Z of node a and b exists of given
250         * cardiniality exists.
251         * The set Z is found by trying all possible subsets of nodes adjacent
252         * to both a and b of the requested cardinality.
253         * @param iNode1 index of first node a
254         * @param iNode2 index of second node b
255         * @param nCardinality size of the separating set Z
256         * @param edges
257         * @return SeparationSet containing set that separates iNode1 and iNode2 or null if no such set exists
258         */
259    SeparationSet existsSepSet(int iNode1, int iNode2, int nCardinality, boolean [] [] edges)
260    {
261        /*Test if a separating set of node d and e exists of cardiniality k*/
262//        int iNode1_, iNode2_;
263        int iNode3, iZ;
264                SeparationSet Z = new SeparationSet();
265                Z.m_set[nCardinality] = -1;
266
267//        iNode1_ = iNode1;
268//        iNode2_ = iNode2;
269
270                // find first candidate separating set Z
271        if (nCardinality > 0) {
272            Z.m_set[0] = next(-1, iNode1, iNode2, edges);
273            iNode3 = 1;
274            while (iNode3 < nCardinality) {
275              Z.m_set[iNode3] = next(Z.m_set[iNode3 - 1], iNode1, iNode2, edges);
276              iNode3++;
277            }
278        }
279
280        if (nCardinality > 0) {
281                iZ = maxn() - Z.m_set[nCardinality - 1] - 1;
282        } else {
283            iZ = 0;
284        }
285       
286
287        while (iZ >= 0)
288        { 
289                //check if candidate separating set makes iNode2_ and iNode1_ independent
290            if (isConditionalIndependent(iNode2, iNode1, Z.m_set, nCardinality))        {
291                return Z;
292            }
293                        // calc next candidate separating set
294            if (nCardinality > 0) {
295                Z.m_set[nCardinality - 1] = next(Z.m_set[nCardinality - 1], iNode1, iNode2, edges);
296            }
297            iZ = nCardinality - 1;   
298            while (iZ >= 0 && Z.m_set[iZ] >= maxn()) {
299                iZ = nCardinality - 1;
300                while (iZ >= 0 && Z.m_set[iZ] >= maxn()) {
301                        iZ--;
302                }
303                if (iZ < 0) {
304                    break;
305                }
306                Z.m_set[iZ] = next(Z.m_set[iZ], iNode1, iNode2, edges);
307                for (iNode3 = iZ + 1; iNode3 < nCardinality; iNode3++) {
308                    Z.m_set[iNode3] = next(Z.m_set[iNode3 - 1], iNode1, iNode2, edges);
309                }
310                iZ = nCardinality - 1;
311            }
312        }
313
314        return null;
315    }  /*ExistsSepSet*/
316
317        /**
318         * determine index of node that makes next candidate separating set
319         * adjacent to iNode1 and iNode2, but not iNode2 itself
320         * @param x index of current node
321         * @param iNode1 first node
322         * @param iNode2 second node (must be larger than iNode1)
323         * @param edges skeleton so far
324         * @return int index of next node adjacent to iNode1 after x
325         */
326        int next(int x, int iNode1, int iNode2, boolean [] [] edges)
327        {
328                x++;
329                while (x < maxn() && (!edges[iNode1][x] || !edges[iNode2][x] ||x == iNode2)) {
330                        x++;
331                }
332                return x;
333        }  /*next*/
334
335
336        /** CalcVeeNodes tries to find V-nodes, i.e. nodes a,b,c such that
337         * a->c<-b and a-/-b. These nodes are identified by finding nodes
338         * a,b,c in the skeleton such that a--c, c--b and a-/-b and furthermore
339         * c is not in the set Z that separates a and b
340         * @param edges skeleton
341         * @param arrows resulting partially directed skeleton after all V-nodes
342         * have been identified
343         * @param sepsets separating sets
344         */
345        void calcVeeNodes(
346                boolean[][] edges,
347                boolean[][] arrows,
348                SeparationSet[][] sepsets) {
349
350                // start with complete empty graph
351                for (int iNode1 = 0; iNode1 < maxn(); iNode1++) {
352                        for (int iNode2 = 0; iNode2 < maxn(); iNode2++) {
353                                arrows[iNode1][iNode2] = false;
354                        }
355                }
356
357                for (int iNode1 = 0; iNode1 < maxn() - 1; iNode1++) {
358                        for (int iNode2 = iNode1 + 1; iNode2 < maxn(); iNode2++) {
359                                if (!edges[iNode1][iNode2]) { /*i nonadj j*/
360                                        for (int iNode3 = 0; iNode3 < maxn(); iNode3++) {
361                                                if ((iNode3 != iNode1
362                                                        && iNode3 != iNode2
363                                                        && edges[iNode1][iNode3]
364                                                        && edges[iNode2][iNode3])
365                                                        & (!sepsets[iNode1][iNode2].contains(iNode3))) {
366                                                        arrows[iNode1][iNode3] = true; /*add arc i->k*/
367                                                        arrows[iNode2][iNode3] = true; /*add arc j->k*/
368                                                }
369                                        }
370                                }
371                        }
372                }
373        } // CalcVeeNodes
374
375
376        /** CalcArcDirections assigns directions to edges that remain after V-nodes have
377         * been identified. The arcs are directed using the following rules:
378           Rule 1: i->j--k & i-/-k => j->k
379           Rule 2: i->j->k & i--k => i->k
380           Rule 3  m
381                         /|\
382                        i | k  => m->j
383        i->j<-k  \|/
384                          j
385       
386           Rule 4  m
387                         / \
388                        i---k  => i->m & k->m
389          i->j   \ /
390                          j
391           Rule 5: if no edges are directed then take a random one (first we can find)
392         * @param edges skeleton
393         * @param arrows resulting fully directed DAG
394         */
395        void calcArcDirections(boolean[][] edges, boolean[][] arrows) {
396                /*give direction to remaining arcs*/
397                int i, j, k, m;
398                boolean bFound;
399
400                do {
401                        bFound = false;
402
403                        /*Rule 1: i->j--k & i-/-k => j->k*/
404
405                        for (i = 0; i < maxn(); i++) {
406                                for (j = 0; j < maxn(); j++) {
407                                        if (i != j && arrows[i][j]) {
408                                                for (k = 0; k < maxn(); k++) {
409                                                        if (i != k
410                                                                && j != k
411                                                                && edges[j][k]
412                                                                && !edges[i][k]
413                                                                && !arrows[j][k]
414                                                                && !arrows[k][j]) {
415                                                                arrows[j][k] = true;
416                                                                bFound = true;
417                                                        }
418                                                }
419                                        }
420                                }
421                        }
422
423                        /*Rule 2: i->j->k & i--k => i->k*/
424
425                        for (i = 0; i < maxn(); i++) {
426                                for (j = 0; j < maxn(); j++) {
427                                        if (i != j && arrows[i][j]) {
428                                                for (k = 0; k < maxn(); k++) {
429                                                        if (i != k
430                                                                && j != k
431                                                                && edges[i][k]
432                                                                && arrows[j][k]
433                                                                && !arrows[i][k]
434                                                                && !arrows[k][i]) {
435                                                                arrows[i][k] = true;
436                                                                bFound = true;
437                                                        }
438                                                }
439                                        }
440                                }
441                        }
442
443                        /* Rule 3  m
444                                 /|\
445                                i | k  => m->j
446                        i->j<-k  \|/
447                                  j
448                        */
449                        for (i = 0; i < maxn(); i++) {
450                                for (j = 0; j < maxn(); j++) {
451                                        if (i != j && arrows[i][j]) {
452                                                for (k = 0; k < maxn(); k++) {
453                                                        if (k != i
454                                                                && k != j
455                                                                && arrows[k][j]
456                                                                && !edges[k][i]) {
457                                                                for (m = 0; m < maxn(); m++) {
458                                                                        if (m != i
459                                                                                && m != j
460                                                                                && m != k
461                                                                                && edges[m][i]
462                                                                                && !arrows[m][i]
463                                                                                && !arrows[i][m]
464                                                                                && edges[m][j]
465                                                                                && !arrows[m][j]
466                                                                                && !arrows[j][m]
467                                                                                && edges[m][k]
468                                                                                && !arrows[m][k]
469                                                                                && !arrows[k][m]) {
470                                                                                arrows[m][j] = true;
471                                                                                bFound = true;
472                                                                        }
473                                                                }
474                                                        }
475                                                }
476                                        }
477                                }
478                        }
479
480                        /* Rule 4  m
481                                 / \
482                                i---k  => i->m & k->m
483                          i->j   \ /
484                                  j
485                        */
486                        for (i = 0; i < maxn(); i++) {
487                                for (j = 0; j < maxn(); j++) {
488                                        if (i != j && arrows[j][i]) {
489                                                for (k = 0; k < maxn(); k++) {
490                                                        if (k != i
491                                                                && k != j
492                                                                && edges[k][j]
493                                                                && !arrows[k][j]
494                                                                && !arrows[j][k]
495                                                                && edges[k][i]
496                                                                && !arrows[k][i]
497                                                                && !arrows[i][k]) {
498                                                                for (m = 0; m < maxn(); m++) {
499                                                                        if (m != i
500                                                                                && m != j
501                                                                                && m != k
502                                                                                && edges[m][i]
503                                                                                && !arrows[m][i]
504                                                                                && !arrows[i][m]
505                                                                                && edges[m][k]
506                                                                                && !arrows[m][k]
507                                                                                && !arrows[k][m]) {
508                                                                                arrows[i][m] = true;
509                                                                                arrows[k][m] = true;
510                                                                                bFound = true;
511                                                                        }
512                                                                }
513                                                        }
514                                                }
515                                        }
516                                }
517                        }
518
519                        /*Rule 5: if no edges are directed then take a random one (first we can find)*/
520
521                        if (!bFound) {
522                                i = 0;
523                                while (!bFound && i < maxn()) {
524                                        j = 0;
525                                        while (!bFound && j < maxn()) {
526                                                if (edges[i][j]
527                                                        && !arrows[i][j]
528                                                        && !arrows[j][i]) {
529                                                        arrows[i][j] = true;
530                                                        bFound = true;
531                                                }
532                                                j++;
533                                        }
534                                        i++;
535                                }
536                        }
537
538                }
539                while (bFound);
540
541        } // CalcArcDirections
542
543        /**
544         * Returns an enumeration describing the available options.
545         *
546         * @return an enumeration of all the available options.
547         */
548        public Enumeration listOptions() {
549          Vector result = new Vector();
550         
551          result.addElement(new Option(
552                "\tWhen determining whether an edge exists a search is performed \n"
553              + "\tfor a set Z that separates the nodes. MaxCardinality determines \n"
554              + "\tthe maximum size of the set Z. This greatly influences the \n"
555              + "\tlength of the search. (default 2)",
556              "cardinality", 1, "-cardinality <num>"));
557         
558          Enumeration en = super.listOptions();
559          while (en.hasMoreElements())
560            result.addElement(en.nextElement());
561         
562          return result.elements();
563        } // listOption
564       
565        /**
566         * Parses a given list of options. <p/>
567         *
568         <!-- options-start -->
569         * Valid options are: <p/>
570         *
571         * <pre> -cardinality &lt;num&gt;
572         *  When determining whether an edge exists a search is performed
573         *  for a set Z that separates the nodes. MaxCardinality determines
574         *  the maximum size of the set Z. This greatly influences the
575         *  length of the search. (default 2)</pre>
576         *
577         * <pre> -mbc
578         *  Applies a Markov Blanket correction to the network structure,
579         *  after a network structure is learned. This ensures that all
580         *  nodes in the network are part of the Markov blanket of the
581         *  classifier node.</pre>
582         *
583         * <pre> -S [BAYES|MDL|ENTROPY|AIC|CROSS_CLASSIC|CROSS_BAYES]
584         *  Score type (BAYES, BDeu, MDL, ENTROPY and AIC)</pre>
585         *
586         <!-- options-end -->
587         *
588         * @param options the list of options as an array of strings
589         * @throws Exception if an option is not supported
590         */
591        public void setOptions(String[] options) throws Exception {
592          String        tmpStr;
593         
594          tmpStr = Utils.getOption("cardinality", options);
595          if (tmpStr.length() != 0)
596            setMaxCardinality(Integer.parseInt(tmpStr));
597          else
598            setMaxCardinality(2);
599           
600          super.setOptions(options);
601        } // setOptions
602       
603        /**
604         * Gets the current settings of the Classifier.
605         *
606         * @return an array of strings suitable for passing to setOptions
607         */
608        public String[] getOptions() {
609          Vector        result;
610          String[]      options;
611          int           i;
612         
613          result  = new Vector();
614          options = super.getOptions();
615          for (i = 0; i < options.length; i++)
616            result.add(options[i]);
617         
618          result.add("-cardinality");
619          result.add("" + getMaxCardinality());
620         
621          return (String[]) result.toArray(new String[result.size()]);
622        } // getOptions
623       
624
625        /**
626         * @return a string to describe the MaxCardinality option.
627         */
628        public String maxCardinalityTipText() {
629          return "When determining whether an edge exists a search is performed for a set Z "+
630          "that separates the nodes. MaxCardinality determines the maximum size of the set Z. " +
631          "This greatly influences the length of the search. Default value is 2.";
632        } // maxCardinalityTipText
633
634        /**
635         * This will return a string describing the search algorithm.
636         * @return The string.
637         */
638        public String globalInfo() {
639          return "This Bayes Network learning algorithm uses conditional independence tests " +
640          "to find a skeleton, finds V-nodes and applies a set of rules to find the directions " +
641          "of the remaining arrows.";
642        }
643
644        /**
645         * Returns the revision string.
646         *
647         * @return              the revision
648         */
649        public String getRevision() {
650          return RevisionUtils.extract("$Revision: 1.8 $");
651        }
652
653        /**
654         * for testing the class
655         *
656         * @param argv the commandline parameters
657         */
658        static public void main(String [] argv) {
659                try {
660                        BayesNet b = new BayesNet();
661                        b.setSearchAlgorithm( new ICSSearchAlgorithm());
662                        Instances instances = new Instances(new FileReader("C:\\eclipse\\workspace\\weka\\data\\contact-lenses.arff"));
663                        instances.setClassIndex(instances.numAttributes() - 1);
664                        b.buildClassifier(instances);
665                        System.out.println(b.toString());
666                } catch (Exception e) {
667                        e.printStackTrace();
668                }
669        } // main
670
671} // class ICSSearchAlgorithm
Note: See TracBrowser for help on using the repository browser.