source: src/main/java/weka/classifiers/bayes/net/MarginCalculator.java @ 28

Last change on this file since 28 was 4, checked in by gnappo, 14 years ago

Import di weka.

File size: 28.9 KB
Line 
1package weka.classifiers.bayes.net;
2
3import weka.classifiers.bayes.BayesNet;
4import weka.core.RevisionHandler;
5import weka.core.RevisionUtils;
6
7import java.io.Serializable;
8import java.util.HashSet;
9import java.util.Iterator;
10import java.util.Set;
11import java.util.Vector;
12
13
14public class MarginCalculator implements Serializable, RevisionHandler {
15          /** for serialization */
16          private static final long serialVersionUID = 650278019241175534L;
17
18          boolean m_debug = false;
19          public JunctionTreeNode m_root = null;
20        JunctionTreeNode [] jtNodes;
21
22        public int getNode(String sNodeName) {
23        int iNode = 0;
24        while (iNode < m_root.m_bayesNet.m_Instances.numAttributes()) {
25                if (m_root.m_bayesNet.m_Instances.attribute(iNode).name().equals(sNodeName)) {
26                        return iNode;
27                }
28                iNode++; 
29        }
30        //throw new Exception("Could not find node [[" + sNodeName + "]]");
31        return -1;
32        }
33        public String toXMLBIF03() {return m_root.m_bayesNet.toXMLBIF03();}
34       
35        /**
36         * Calc marginal distributions of nodes in Bayesian network
37         *       Note that a connected network is assumed.
38         *       Unconnected networks may give unexpected results.
39         * @param bayesNet
40         */
41        public void calcMargins(BayesNet bayesNet) throws Exception {
42                //System.out.println(bayesNet.toString());
43                boolean[][] bAdjacencyMatrix = moralize(bayesNet);
44                process(bAdjacencyMatrix, bayesNet);
45        } // calcMargins
46
47        public void calcFullMargins(BayesNet bayesNet) throws Exception {
48                //System.out.println(bayesNet.toString());
49                int nNodes = bayesNet.getNrOfNodes();
50                boolean[][] bAdjacencyMatrix = new boolean[nNodes][nNodes];
51                for (int iNode = 0; iNode < nNodes; iNode++) {
52                        for (int iNode2 = 0; iNode2 < nNodes; iNode2++) {
53                                bAdjacencyMatrix[iNode][iNode2] = true;
54                        }
55                }
56                process(bAdjacencyMatrix, bayesNet);
57        } // calcMargins
58       
59       
60        public void process(boolean[][] bAdjacencyMatrix, BayesNet bayesNet) throws Exception {
61                int[] order = getMaxCardOrder(bAdjacencyMatrix);
62                bAdjacencyMatrix = fillIn(order, bAdjacencyMatrix);
63                order = getMaxCardOrder(bAdjacencyMatrix);
64                Set [] cliques = getCliques(order, bAdjacencyMatrix);
65                Set [] separators = getSeparators(order, cliques);
66                int [] parentCliques = getCliqueTree(order, cliques, separators);
67                // report cliques
68                int nNodes = bAdjacencyMatrix.length;
69                if (m_debug) {
70                for (int i = 0; i < nNodes; i++) {
71                        int iNode = order[i];
72                        if (cliques[iNode] != null) {
73                                System.out.print("Clique " + iNode + " (");
74                                Iterator nodes = cliques[iNode].iterator();
75                                while (nodes.hasNext()) {
76                                        int iNode2 = (Integer) nodes.next();
77                                        System.out.print(iNode2 + " " + bayesNet.getNodeName(iNode2));
78                                        if (nodes.hasNext()) {
79                                                System.out.print(",");
80                                        }
81                                }
82                                System.out.print(") S(");
83                                nodes = separators[iNode].iterator();
84                                while (nodes.hasNext()) {
85                                        int iNode2 = (Integer) nodes.next();
86                                        System.out.print(iNode2 + " " + bayesNet.getNodeName(iNode2));
87                                        if (nodes.hasNext()) {
88                                                System.out.print(",");
89                                        }
90                                }
91                                System.out.println(") parent clique " + parentCliques[iNode]);
92                        }               
93                }
94                }
95                               
96                jtNodes = getJunctionTree(cliques, separators, parentCliques, order, bayesNet);
97                m_root = null;
98                for (int iNode = 0; iNode < nNodes; iNode++) {
99                        if (parentCliques[iNode] < 0 && jtNodes[iNode] != null) {
100                                m_root = jtNodes[iNode];
101                                break;
102                        }
103                }
104                m_Margins = new double[nNodes][];
105                initialize(jtNodes, order, cliques, separators, parentCliques);
106               
107                // sanity check
108                for (int i = 0; i < nNodes; i++) {
109                        int iNode = order[i];
110                        if (cliques[iNode] != null) {
111                                if (parentCliques[iNode] == -1 && separators[iNode].size() > 0) {
112                                        throw new Exception("Something wrong in clique tree");
113                                }
114                        }
115                }
116                if (m_debug) {
117                        //System.out.println(m_root.toString());
118                }
119        } // process
120               
121        void initialize(JunctionTreeNode [] jtNodes, int [] order, Set [] cliques, Set [] separators, int [] parentCliques) {
122                int nNodes = order.length;
123                for (int i = nNodes - 1; i >= 0; i--) {
124                        int iNode = order[i];
125                        if (jtNodes[iNode]!=null) {
126                                jtNodes[iNode].initializeUp();
127                        }
128                }       
129                for (int i = 0; i < nNodes; i++) {
130                        int iNode = order[i];
131                        if (jtNodes[iNode]!=null) {
132                                jtNodes[iNode].initializeDown(false);
133                        }
134                }       
135        } // initialize
136       
137        JunctionTreeNode [] getJunctionTree(Set [] cliques, Set [] separators, int [] parentCliques, int [] order, BayesNet bayesNet) {
138                int nNodes = order.length;
139                JunctionTreeNode root = null;
140                JunctionTreeNode [] jtns = new JunctionTreeNode[nNodes]; 
141                boolean [] bDone = new boolean[nNodes];
142                // create junction tree nodes
143                for (int i = 0; i < nNodes; i++) {
144                        int iNode = order[i];
145                        if (cliques[iNode] != null) {
146                                jtns[iNode] = new JunctionTreeNode(cliques[iNode], bayesNet, bDone);
147                        }
148                }
149                // create junction tree separators
150                for (int i = 0; i < nNodes; i++) {
151                        int iNode = order[i];
152                        if (cliques[iNode] != null) {
153                                JunctionTreeNode parent = null;
154                                if (parentCliques[iNode] > 0) {
155                                        parent = jtns[parentCliques[iNode]];
156                                        JunctionTreeSeparator jts = new JunctionTreeSeparator(separators[iNode], bayesNet, jtns[iNode], parent);
157                                        jtns[iNode].setParentSeparator(jts);
158                                        jtns[parentCliques[iNode]].addChildClique(jtns[iNode]);
159                                } else {
160                                        root = jtns[iNode];     
161                                }
162                        }
163                }
164                return jtns;
165        } // getJunctionTree
166       
167        public class JunctionTreeSeparator implements Serializable, RevisionHandler {
168         
169                  private static final long serialVersionUID = 6502780192411755343L;
170                int [] m_nNodes;
171                int m_nCardinality;
172                double [] m_fiParent;
173                double [] m_fiChild;
174                JunctionTreeNode m_parentNode;
175                JunctionTreeNode m_childNode;
176                BayesNet m_bayesNet;
177               
178                JunctionTreeSeparator(Set separator, BayesNet bayesNet, JunctionTreeNode childNode, JunctionTreeNode parentNode) {
179                        //////////////////////
180                        // initialize node set
181                        m_nNodes = new int[separator.size()];
182                        int iPos = 0;
183                        m_nCardinality = 1;
184                        for(Iterator nodes = separator.iterator(); nodes.hasNext();) {
185                                int iNode = (Integer) nodes.next();
186                                m_nNodes[iPos++] = iNode;
187                                m_nCardinality *= bayesNet.getCardinality(iNode);
188                        }
189                        m_parentNode = parentNode;
190                        m_childNode = childNode;
191                        m_bayesNet = bayesNet;
192                } // c'tor
193               
194                /** marginalize junciontTreeNode node over all nodes outside the separator set
195                 * of the parent clique
196                 *
197                 */
198                public void updateFromParent() {
199                        double [] fis = update(m_parentNode); 
200                        if (fis == null) {
201                                m_fiParent = null;
202                        } else {
203                                m_fiParent = fis;
204                                // normalize
205                                double sum = 0;
206                                for (int iPos = 0; iPos < m_nCardinality; iPos++) {
207                                        sum += m_fiParent[iPos];
208                                }
209                                for (int iPos = 0; iPos < m_nCardinality; iPos++) {
210                                        m_fiParent[iPos] /= sum;
211                                }
212                        }
213                } // updateFromParent
214
215                /** marginalize junciontTreeNode node over all nodes outside the separator set
216                 * of the child clique
217                 *
218                 */
219                public void updateFromChild() {
220                        double [] fis = update(m_childNode); 
221                        if (fis == null) {
222                                m_fiChild = null;
223                        } else {
224                                m_fiChild = fis;
225                                // normalize
226                                double sum = 0;
227                                for (int iPos = 0; iPos < m_nCardinality; iPos++) {
228                                        sum += m_fiChild[iPos];
229                                }
230                                for (int iPos = 0; iPos < m_nCardinality; iPos++) {
231                                        m_fiChild[iPos] /= sum;
232                                }
233                        }
234                } // updateFromChild
235               
236                /** marginalize junciontTreeNode node over all nodes outside the separator set
237                 *
238                 * @param node one of the neighboring junciont tree nodes of this separator
239                 */
240                public double [] update(JunctionTreeNode node) {
241                        if (node.m_P == null) {
242                                return null;
243                        }
244                        double [] fi = new double[m_nCardinality];
245
246                        int [] values = new int[node.m_nNodes.length];
247                        int [] order = new int[m_bayesNet.getNrOfNodes()];
248                        for (int iNode = 0; iNode < node.m_nNodes.length; iNode++) {
249                                order[node.m_nNodes[iNode]] = iNode;
250                        }
251                        // fill in the values
252                        for (int iPos = 0; iPos < node.m_nCardinality; iPos++) {
253                                int iNodeCPT = getCPT(node.m_nNodes, node.m_nNodes.length, values, order, m_bayesNet);
254                                int iSepCPT =  getCPT(m_nNodes, m_nNodes.length, values, order, m_bayesNet);
255                                fi[iSepCPT] += node.m_P[iNodeCPT];
256                                // update values
257                                int i = 0;
258                                values[i]++;
259                                while (i < node.m_nNodes.length && values[i] == m_bayesNet.getCardinality(node.m_nNodes[i])) {
260                                        values[i] = 0;
261                                        i++;
262                                        if (i < node.m_nNodes.length) {
263                                                values[i]++;
264                                        }
265                                }
266                        }
267                        return fi;
268                } // update
269                 
270                /**
271                 * Returns the revision string.
272                 *
273                 * @return              the revision
274                 */
275                public String getRevision() {
276                  return RevisionUtils.extract("$Revision: 4899 $");
277                }
278
279        } // class JunctionTreeSeparator
280
281        public class JunctionTreeNode implements Serializable, RevisionHandler {
282         
283                  private static final long serialVersionUID = 650278019241175536L;
284                /** reference Bayes net for information about variables like name, cardinality, etc.
285                 * but not for relations between nodes **/
286                BayesNet m_bayesNet;
287                /** nodes of the Bayes net in this junction node **/
288                public int [] m_nNodes;
289                /** cardinality of the instances of variables in this junction node **/
290                int m_nCardinality;
291                /** potentials for first network **/
292                double [] m_fi;
293
294                /** distribution over this junction node according to first Bayes network **/
295                double [] m_P;
296
297
298                double [][] m_MarginalP;               
299
300               
301                JunctionTreeSeparator m_parentSeparator;
302                public void setParentSeparator(JunctionTreeSeparator parentSeparator) {m_parentSeparator = parentSeparator;}
303                public Vector m_children;
304                public void addChildClique(JunctionTreeNode child) {m_children.add(child);}
305
306                public void initializeUp() {
307                        m_P = new double[m_nCardinality];
308                        for (int iPos = 0; iPos < m_nCardinality; iPos++) {
309                                m_P[iPos] = m_fi[iPos];
310                        }
311                        int [] values = new int[m_nNodes.length];
312                        int [] order = new int[m_bayesNet.getNrOfNodes()];
313                        for (int iNode = 0; iNode < m_nNodes.length; iNode++) {
314                                order[m_nNodes[iNode]] = iNode;
315                        }
316                        for (Iterator child = m_children.iterator(); child.hasNext(); ) {
317                                JunctionTreeNode childNode = (JunctionTreeNode) child.next();
318                                JunctionTreeSeparator separator = childNode.m_parentSeparator;
319                        // Update the values
320                        for (int iPos = 0; iPos < m_nCardinality; iPos++) {
321                                int iSepCPT = getCPT(separator.m_nNodes, separator.m_nNodes.length, values, order, m_bayesNet);
322                                int iNodeCPT =  getCPT(m_nNodes, m_nNodes.length, values, order, m_bayesNet);
323                                        m_P[iNodeCPT] *= separator.m_fiChild[iSepCPT];                                 
324                                // update values
325                                int i = 0;
326                                values[i]++;
327                                while (i < m_nNodes.length && values[i] == m_bayesNet.getCardinality(m_nNodes[i])) {
328                                        values[i] = 0;
329                                        i++;
330                                        if (i < m_nNodes.length) {
331                                                values[i]++;
332                                        }
333                                }
334                        }
335                        }
336                        // normalize
337                        double sum = 0;
338                        for (int iPos = 0; iPos < m_nCardinality; iPos++) {
339                                sum += m_P[iPos];
340                        }
341                        for (int iPos = 0; iPos < m_nCardinality; iPos++) {
342                                m_P[iPos] /= sum;
343                        }
344
345                        if (m_parentSeparator != null) { // not a root node
346                                m_parentSeparator.updateFromChild();
347                        }
348                } // initializeUp
349
350                public void initializeDown(boolean recursively) {
351                        if (m_parentSeparator == null) { // a root node
352                                calcMarginalProbabilities();
353                        } else {
354                        m_parentSeparator.updateFromParent();
355                                int [] values = new int[m_nNodes.length];
356                                int [] order = new int[m_bayesNet.getNrOfNodes()];
357                                for (int iNode = 0; iNode < m_nNodes.length; iNode++) {
358                                        order[m_nNodes[iNode]] = iNode;
359                                }
360
361                               
362                                // Update the values
363                                for (int iPos = 0; iPos < m_nCardinality; iPos++) {
364                                        int iSepCPT = getCPT(m_parentSeparator.m_nNodes, m_parentSeparator.m_nNodes.length, values, order, m_bayesNet);
365                                        int iNodeCPT =  getCPT(m_nNodes, m_nNodes.length, values, order, m_bayesNet);
366                                        if ( m_parentSeparator.m_fiChild[iSepCPT] > 0) {
367                                                m_P[iNodeCPT] *= m_parentSeparator.m_fiParent[iSepCPT] / m_parentSeparator.m_fiChild[iSepCPT];
368                                        } else {
369                                                m_P[iNodeCPT] = 0;
370                                        }
371                                        // update values
372                                        int i = 0;
373                                        values[i]++;
374                                        while (i < m_nNodes.length && values[i] == m_bayesNet.getCardinality(m_nNodes[i])) {
375                                                values[i] = 0;
376                                                i++;
377                                                if (i < m_nNodes.length) {
378                                                        values[i]++;
379                                                }
380                                        }
381                                }
382                                // normalize
383                                double sum = 0;
384                                for (int iPos = 0; iPos < m_nCardinality; iPos++) {
385                                        sum += m_P[iPos];
386                                }
387                                for (int iPos = 0; iPos < m_nCardinality; iPos++) {
388                                        m_P[iPos] /= sum;
389                                }
390                                m_parentSeparator.updateFromChild();
391                                calcMarginalProbabilities();
392                        }
393                        if (recursively) {
394                                for (Iterator child = m_children.iterator(); child.hasNext(); ) {
395                                        JunctionTreeNode childNode = (JunctionTreeNode) child.next();
396                                        childNode.initializeDown(true);
397                                }                       
398                        }
399                } // initializeDown
400               
401               
402                /** calculate marginal probabilities for the individual nodes in the clique.
403                 * Store results in m_MarginalP
404                 */
405                void calcMarginalProbabilities() {                     
406                        // calculate marginal probabilities
407                        int [] values = new int[m_nNodes.length];
408                        int [] order = new int[m_bayesNet.getNrOfNodes()];
409                        m_MarginalP = new double[m_nNodes.length][];
410                        for (int iNode = 0; iNode < m_nNodes.length; iNode++) {
411                                order[m_nNodes[iNode]] = iNode;
412                                m_MarginalP[iNode]=new double[m_bayesNet.getCardinality(m_nNodes[iNode])];
413                        }
414                        for (int iPos = 0; iPos < m_nCardinality; iPos++) {
415                                int iNodeCPT =  getCPT(m_nNodes, m_nNodes.length, values, order, m_bayesNet);
416                                for (int iNode = 0; iNode < m_nNodes.length; iNode++) {
417                                        m_MarginalP[iNode][values[iNode]] += m_P[iNodeCPT];
418                                }
419                                // update values
420                                int i = 0;
421                                values[i]++;
422                                while (i < m_nNodes.length && values[i] == m_bayesNet.getCardinality(m_nNodes[i])) {
423                                        values[i] = 0;
424                                        i++;
425                                        if (i < m_nNodes.length) {
426                                                values[i]++;
427                                        }
428                                }
429                        }
430                       
431                        for (int iNode = 0; iNode < m_nNodes.length; iNode++) {
432                                m_Margins[m_nNodes[iNode]] = m_MarginalP[iNode]; 
433                        }
434                } // calcMarginalProbabilities
435               
436                public String toString() {
437                        StringBuffer buf = new StringBuffer();
438                        for (int iNode = 0; iNode < m_nNodes.length; iNode++) {
439                                buf.append(m_bayesNet.getNodeName(m_nNodes[iNode]) + ": ");
440                                for (int iValue = 0; iValue < m_MarginalP[iNode].length; iValue++) {
441                                        buf.append(m_MarginalP[iNode][iValue] + " ");
442                                }
443                                buf.append('\n');
444                        }
445                        for (Iterator child = m_children.iterator(); child.hasNext(); ) {
446                                JunctionTreeNode childNode = (JunctionTreeNode) child.next();
447                                buf.append("----------------\n");
448                                buf.append(childNode.toString());
449                        }                       
450                        return buf.toString();
451                } // toString
452               
453                void calculatePotentials(BayesNet bayesNet, Set clique, boolean [] bDone) {
454                        m_fi = new double[m_nCardinality];
455                       
456                        int [] values = new int[m_nNodes.length];
457                        int [] order = new int[bayesNet.getNrOfNodes()];
458                        for (int iNode = 0; iNode < m_nNodes.length; iNode++) {
459                                order[m_nNodes[iNode]] = iNode;
460                        }
461                        // find conditional probabilities that need to be taken in account
462                        boolean [] bIsContained = new boolean[m_nNodes.length];
463                        for (int iNode = 0; iNode < m_nNodes.length; iNode++) {
464                                int nNode = m_nNodes[iNode];
465                                bIsContained[iNode] = !bDone[nNode];
466                                for (int iParent = 0; iParent < bayesNet.getNrOfParents(nNode); iParent++) {
467                                        int nParent = bayesNet.getParent(nNode, iParent);
468                                        if (!clique.contains(nParent)) {
469                                                bIsContained[iNode] = false;
470                                        }
471                                }
472                                if (bIsContained[iNode]) {
473                                        bDone[nNode] = true;
474                                        if (m_debug) {
475                                                System.out.println("adding node " +nNode);
476                                        }
477                                }
478                        }                       
479
480                        // fill in the values
481                        for (int iPos = 0; iPos < m_nCardinality; iPos++) {
482                                int iCPT = getCPT(m_nNodes, m_nNodes.length, values, order, bayesNet);
483                                m_fi[iCPT] = 1.0;
484                                for (int iNode = 0; iNode < m_nNodes.length; iNode++) {
485                                        if (bIsContained[iNode]) {
486                                                int nNode = m_nNodes[iNode];
487                                                int [] nNodes = bayesNet.getParentSet(nNode).getParents();
488                                                int iCPT2 = getCPT(nNodes, bayesNet.getNrOfParents(nNode), values, order, bayesNet);
489                                                double f = bayesNet.getDistributions()[nNode][iCPT2].getProbability(values[iNode]);
490                                                m_fi[iCPT] *= f;
491                                        }
492                                }
493                               
494                                // update values
495                                int i = 0;
496                                values[i]++;
497                                while (i < m_nNodes.length && values[i] == bayesNet.getCardinality(m_nNodes[i])) {
498                                        values[i] = 0;
499                                        i++;
500                                        if (i < m_nNodes.length) {
501                                                values[i]++;
502                                        }
503                                }
504                        }
505                } // calculatePotentials
506
507                JunctionTreeNode(Set clique, BayesNet bayesNet, boolean [] bDone) {
508                        m_bayesNet = bayesNet;
509                        m_children = new Vector();
510                        //////////////////////
511                        // initialize node set
512                        m_nNodes = new int[clique.size()];
513                        int iPos = 0;
514                        m_nCardinality = 1;
515                        for(Iterator nodes = clique.iterator(); nodes.hasNext();) {
516                                int iNode = (Integer) nodes.next();
517                                m_nNodes[iPos++] = iNode;
518                                m_nCardinality *= bayesNet.getCardinality(iNode);
519                        }
520                        ////////////////////////////////
521                        // initialize potential function
522                        calculatePotentials(bayesNet, clique, bDone);
523       } // JunctionTreeNode c'tor
524
525                /* check whether this junciton tree node contains node nNode
526                 *
527                 */
528                boolean contains(int nNode) {
529                        for (int iNode = 0; iNode < m_nNodes.length; iNode++) {
530                                if (m_nNodes[iNode]== nNode){
531                                        return true;
532                                }
533                        }
534                        return false;
535                } // contains
536               
537                public void setEvidence(int nNode, int iValue) throws Exception {
538                        int [] values = new int[m_nNodes.length];
539                        int [] order = new int[m_bayesNet.getNrOfNodes()];
540
541                        int nNodeIdx = -1;
542                        for (int iNode = 0; iNode < m_nNodes.length; iNode++) {
543                                order[m_nNodes[iNode]] = iNode;
544                                if (m_nNodes[iNode] == nNode) {
545                                        nNodeIdx = iNode;
546                                }
547                        }
548                        if (nNodeIdx < 0) {
549                                throw new Exception("setEvidence: Node " + nNode + " not found in this clique");
550                        }
551                        for (int iPos = 0; iPos < m_nCardinality; iPos++) {
552                                if (values[nNodeIdx] != iValue) {
553                                        int iNodeCPT =  getCPT(m_nNodes, m_nNodes.length, values, order, m_bayesNet);
554                                        m_P[iNodeCPT] = 0;
555                                }
556                                // update values
557                                int i = 0;
558                                values[i]++;
559                                while (i < m_nNodes.length && values[i] == m_bayesNet.getCardinality(m_nNodes[i])) {
560                                        values[i] = 0;
561                                        i++;
562                                        if (i < m_nNodes.length) {
563                                                values[i]++;
564                                        }
565                                }
566                        }               
567                        // normalize
568                        double sum = 0;
569                        for (int iPos = 0; iPos < m_nCardinality; iPos++) {
570                                sum += m_P[iPos];
571                        }
572                        for (int iPos = 0; iPos < m_nCardinality; iPos++) {
573                                m_P[iPos] /= sum;
574                        }
575                        calcMarginalProbabilities();
576                        updateEvidence(this);
577                } // setEvidence
578
579                void updateEvidence(JunctionTreeNode source) {
580                        if (source != this) {
581                                int [] values = new int[m_nNodes.length];
582                                int [] order = new int[m_bayesNet.getNrOfNodes()];
583                                for (int iNode = 0; iNode < m_nNodes.length; iNode++) {
584                                        order[m_nNodes[iNode]] = iNode;
585                                }
586                                int [] nChildNodes = source.m_parentSeparator.m_nNodes;
587                                int nNumChildNodes = nChildNodes.length; 
588                                for (int iPos = 0; iPos < m_nCardinality; iPos++) {
589                                        int iNodeCPT =  getCPT(m_nNodes, m_nNodes.length, values, order, m_bayesNet);
590                                        int iChildCPT =  getCPT(nChildNodes, nNumChildNodes, values, order, m_bayesNet);
591                                        if (source.m_parentSeparator.m_fiParent[iChildCPT] != 0) {
592                                                m_P[iNodeCPT] *= source.m_parentSeparator.m_fiChild[iChildCPT]/source.m_parentSeparator.m_fiParent[iChildCPT];
593                                        } else {
594                                                m_P[iNodeCPT] = 0;
595                                        }
596                                        // update values
597                                        int i = 0;
598                                        values[i]++;
599                                        while (i < m_nNodes.length && values[i] == m_bayesNet.getCardinality(m_nNodes[i])) {
600                                                values[i] = 0;
601                                                i++;
602                                                if (i < m_nNodes.length) {
603                                                        values[i]++;
604                                                }
605                                        }
606                                }               
607                                // normalize
608                                double sum = 0;
609                                for (int iPos = 0; iPos < m_nCardinality; iPos++) {
610                                        sum += m_P[iPos];
611                                }
612                                for (int iPos = 0; iPos < m_nCardinality; iPos++) {
613                                        m_P[iPos] /= sum;
614                                }
615                                calcMarginalProbabilities();
616                        }
617                        for (Iterator child = m_children.iterator(); child.hasNext(); ) {
618                                JunctionTreeNode childNode = (JunctionTreeNode) child.next();
619                                if (childNode != source) {
620                                        childNode.initializeDown(true);
621                                }
622                        }                       
623                        if (m_parentSeparator != null) {
624                                m_parentSeparator.updateFromChild();
625                                m_parentSeparator.m_parentNode.updateEvidence(this);
626                                m_parentSeparator.updateFromParent();
627                        }
628                } // updateEvidence
629
630                /**
631                 * Returns the revision string.
632                 *
633                 * @return              the revision
634                 */
635                public String getRevision() {
636                  return RevisionUtils.extract("$Revision: 4899 $");
637                }
638               
639        } // class JunctionTreeNode
640
641        int getCPT(int [] nodeSet, int nNodes, int[] values, int[] order, BayesNet bayesNet) {
642                int iCPTnew = 0;
643                for (int iNode = 0; iNode < nNodes; iNode++) {
644                        int nNode = nodeSet[iNode];
645                        iCPTnew = iCPTnew * bayesNet.getCardinality(nNode);
646                        iCPTnew += values[order[nNode]];
647                }
648                return iCPTnew;
649        } // getCPT
650
651        int [] getCliqueTree(int [] order, Set [] cliques, Set [] separators) {
652                int nNodes = order.length;
653                int [] parentCliques = new int[nNodes];
654                //for (int i = nNodes - 1; i >= 0; i--) {
655                for (int i = 0; i < nNodes; i++) {
656                        int iNode = order[i];
657                        parentCliques[iNode] = -1;
658                        if (cliques[iNode] != null && separators[iNode].size() > 0) {
659                                //for (int j = nNodes - 1; j > i; j--) {
660                                for (int j = 0; j < nNodes; j++) {
661                                        int iNode2 = order[j];
662                                        if (iNode!= iNode2 && cliques[iNode2] != null && cliques[iNode2].containsAll(separators[iNode])) {
663                                                parentCliques[iNode] = iNode2;
664                                                j = i;
665                                                j = 0;
666                                                j = nNodes;
667                                        }
668                                }
669                               
670                        }
671                }
672                return parentCliques;
673        } // getCliqueTree
674       
675        /** calculate separator sets in clique tree
676         *
677         * @param order: maximum cardinality ordering of the graph
678         * @param cliques: set of cliques
679         * @return set of separator sets
680         */
681        Set [] getSeparators(int [] order, Set [] cliques) {
682                int nNodes = order.length;
683                Set [] separators = new HashSet[nNodes];
684                Set processedNodes = new HashSet(); 
685                //for (int i = nNodes - 1; i >= 0; i--) {
686                for (int i = 0; i < nNodes; i++) {
687                        int iNode = order[i];
688                        if (cliques[iNode] != null) {
689                                Set separator = new HashSet();
690                                separator.addAll(cliques[iNode]);
691                                separator.retainAll(processedNodes);
692                                separators[iNode] = separator;
693                                processedNodes.addAll(cliques[iNode]);
694                        }
695                }
696                return separators;
697        } // getSeparators
698       
699        /**
700         * get cliques in a decomposable graph represented by an adjacency matrix
701         *
702         * @param order: maximum cardinality ordering of the graph
703         * @param bAdjacencyMatrix: decomposable graph
704         * @return set of cliques
705         */
706        Set [] getCliques(int[] order, boolean[][] bAdjacencyMatrix) throws Exception {
707                int nNodes = bAdjacencyMatrix.length;
708                Set [] cliques = new HashSet[nNodes];
709                //int[] inverseOrder = new int[nNodes];
710                //for (int iNode = 0; iNode < nNodes; iNode++) {
711                        //inverseOrder[order[iNode]] = iNode;
712                //}
713                // consult nodes in reverse order
714                for (int i = nNodes - 1; i >= 0; i--) {
715                        int iNode = order[i];
716                        if (iNode == 22) {
717                                int h = 3;
718                                h ++;
719                        }
720                        Set clique = new HashSet();
721                        clique.add(iNode);
722                        for (int j = 0; j < i; j++) {
723                                int iNode2 = order[j];
724                                if (bAdjacencyMatrix[iNode][iNode2]) {
725                                        clique.add(iNode2);
726                                }
727                        }
728                       
729                        //for (int iNode2 = 0; iNode2 < nNodes; iNode2++) {
730                                //if (bAdjacencyMatrix[iNode][iNode2] && inverseOrder[iNode2] < inverseOrder[iNode]) {
731                                        //clique.add(iNode2);
732                                //}
733                        //}
734                        cliques[iNode] = clique;
735                }
736                for (int iNode = 0; iNode < nNodes; iNode++) {
737                        for (int iNode2 = 0; iNode2 < nNodes; iNode2++) {
738                                if (iNode != iNode2 && cliques[iNode]!= null && cliques[iNode2]!= null && cliques[iNode].containsAll(cliques[iNode2])) {
739                                        cliques[iNode2] = null;
740                                }
741                        }
742                }               
743                // sanity check
744                if (m_debug) {
745                int [] nNodeSet = new int[nNodes];
746                for (int iNode = 0; iNode < nNodes; iNode++) {
747                        if (cliques[iNode] != null) {
748                                Iterator it = cliques[iNode].iterator();
749                                int k = 0;
750                                while (it.hasNext()) {
751                                        nNodeSet[k++] = (Integer) it.next();
752                                }
753                                for (int i = 0; i < cliques[iNode].size(); i++) {
754                                        for (int j = 0; j < cliques[iNode].size(); j++) {
755                                                if (i!=j && !bAdjacencyMatrix[nNodeSet[i]][nNodeSet[j]]) {
756                                                        throw new Exception("Non clique" + i + " " + j);
757                                                }
758                                        }
759                                }
760                        }
761                }
762                }
763                return cliques;
764        } // getCliques
765
766        /**
767         * moralize DAG and calculate
768         * adjacency matrix representation for a Bayes Network, effecively
769         * converting the directed acyclic graph to an undirected graph.
770         *
771         * @param bayesNet
772         *            Bayes Network to process
773         * @return adjacencies in boolean matrix format
774         */
775        public boolean[][] moralize(BayesNet bayesNet) {
776                int nNodes = bayesNet.getNrOfNodes();
777                boolean[][] bAdjacencyMatrix = new boolean[nNodes][nNodes];
778                for (int iNode = 0; iNode < nNodes; iNode++) {
779                        ParentSet parents = bayesNet.getParentSets()[iNode];
780                        moralizeNode(parents, iNode, bAdjacencyMatrix);
781                }
782                return bAdjacencyMatrix;
783        } // moralize
784
785        private void moralizeNode(ParentSet parents, int iNode, boolean[][] bAdjacencyMatrix) {
786                for (int iParent = 0; iParent < parents.getNrOfParents(); iParent++) {
787                        int nParent = parents.getParent(iParent);
788                        if ( m_debug && !bAdjacencyMatrix[iNode][nParent])
789                                System.out.println("Insert " + iNode + "--" + nParent);
790                        bAdjacencyMatrix[iNode][nParent] = true;
791                        bAdjacencyMatrix[nParent][iNode] = true;
792                        for (int iParent2 = iParent + 1; iParent2 < parents.getNrOfParents(); iParent2++) {
793                                int nParent2 = parents.getParent(iParent2);
794                                if (m_debug && !bAdjacencyMatrix[nParent2][nParent])
795                                        System.out.println("Mary " + nParent + "--" + nParent2);
796                                bAdjacencyMatrix[nParent2][nParent] = true;
797                                bAdjacencyMatrix[nParent][nParent2] = true;
798                        }
799                }       
800        } // moralizeNode
801       
802        /**
803         * Apply Tarjan and Yannakakis (1984) fill in algorithm for graph
804         * triangulation. In reverse order, insert edges between any non-adjacent
805         * neighbors that are lower numbered in the ordering.
806         *
807         * Side effect: input matrix is used as output
808         *
809         * @param order
810         *            node ordering
811         * @param bAdjacencyMatrix
812         *            boolean matrix representing the graph
813         * @return boolean matrix representing the graph with fill ins
814         */
815        public boolean[][] fillIn(int[] order, boolean[][] bAdjacencyMatrix) {
816                int nNodes = bAdjacencyMatrix.length;
817                int[] inverseOrder = new int[nNodes];
818                for (int iNode = 0; iNode < nNodes; iNode++) {
819                        inverseOrder[order[iNode]] = iNode;
820                }
821                // consult nodes in reverse order
822                for (int i = nNodes - 1; i >= 0; i--) {
823                        int iNode = order[i];
824                        // find pairs of neighbors with lower order
825                        for (int j = 0; j < i; j++) {
826                                int iNode2 = order[j];
827                                if (bAdjacencyMatrix[iNode][iNode2]) {
828                                        for (int k = j+1; k < i; k++) {
829                                                int iNode3 = order[k];
830                                                if (bAdjacencyMatrix[iNode][iNode3]) {
831                                                        // fill in
832                                                        if (m_debug && (!bAdjacencyMatrix[iNode2][iNode3] || !bAdjacencyMatrix[iNode3][iNode2]) )
833                                                                System.out.println("Fill in " + iNode2 + "--" + iNode3);
834                                                        bAdjacencyMatrix[iNode2][iNode3] = true;
835                                                        bAdjacencyMatrix[iNode3][iNode2] = true;
836                                                }
837                                        }
838                                }
839                        }
840                }
841                return bAdjacencyMatrix;
842        } // fillIn
843
844        /**
845         * calculate maximum cardinality ordering; start with first node add node
846         * that has most neighbors already ordered till all nodes are in the
847         * ordering
848         *
849         * This implementation does not assume the graph is connected
850         *
851         * @param bAdjacencyMatrix:
852         *            n by n matrix with adjacencies in graph of n nodes
853         * @return maximum cardinality ordering
854         */
855        int[] getMaxCardOrder(boolean[][] bAdjacencyMatrix) {
856                int nNodes = bAdjacencyMatrix.length;
857                int[] order = new int[nNodes];
858                if (nNodes==0) {return order;}
859                boolean[] bDone = new boolean[nNodes];
860                // start with node 0
861                order[0] = 0;
862                bDone[0] = true;
863                // order remaining nodes
864                for (int iNode = 1; iNode < nNodes; iNode++) {
865                        int nMaxCard = -1;
866                        int iBestNode = -1;
867                        // find node with higest cardinality of previously ordered nodes
868                        for (int iNode2 = 0; iNode2 < nNodes; iNode2++) {
869                                if (!bDone[iNode2]) {
870                                        int nCard = 0;
871                                        // calculate cardinality for node iNode2
872                                        for (int iNode3 = 0; iNode3 < nNodes; iNode3++) {
873                                                if (bAdjacencyMatrix[iNode2][iNode3] && bDone[iNode3]) {
874                                                        nCard++;
875                                                }
876                                        }
877                                        if (nCard > nMaxCard) {
878                                                nMaxCard = nCard;
879                                                iBestNode = iNode2;
880                                        }
881                                }
882                        }
883                        order[iNode] = iBestNode;
884                        bDone[iBestNode] = true;
885                }
886                return order;
887        } // getMaxCardOrder
888
889        public void setEvidence(int nNode, int iValue) throws Exception {
890                if (m_root == null) {
891                        throw new Exception("Junction tree not initialize yet");
892                }
893                int iJtNode = 0;
894                while (iJtNode < jtNodes.length && (jtNodes[iJtNode] == null ||!jtNodes[iJtNode].contains(nNode))) {
895                        iJtNode++;
896                }
897                if (jtNodes.length == iJtNode) {
898                        throw new Exception("Could not find node " + nNode + " in junction tree");
899                }
900                jtNodes[iJtNode].setEvidence(nNode, iValue);
901        } // setEvidence
902       
903        public String toString() {
904                return m_root.toString();
905        } // toString
906
907        double [][] m_Margins;
908        public double [] getMargin(int iNode) {
909                return m_Margins[iNode];
910        } // getMargin
911
912        /**
913         * Returns the revision string.
914         *
915         * @return              the revision
916         */
917        public String getRevision() {
918          return RevisionUtils.extract("$Revision: 4899 $");
919        }
920       
921        public static void main(String[] args) {
922                try {
923                        BIFReader bayesNet = new BIFReader();
924                        bayesNet.processFile(args[0]);
925
926                        MarginCalculator dc = new MarginCalculator();
927                        dc.calcMargins(bayesNet);
928                        int iNode = 2;
929                        int iValue = 0;
930                        int iNode2 = 4;
931                        int iValue2 = 0;
932                        dc.setEvidence(iNode, iValue);
933                        dc.setEvidence(iNode2, iValue2);
934                        System.out.print(dc.toString());
935
936
937                        dc.calcFullMargins(bayesNet);
938                        dc.setEvidence(iNode, iValue);
939                        dc.setEvidence(iNode2, iValue2);
940                        System.out.println("==============");
941                        System.out.print(dc.toString());
942                       
943                       
944                } catch (Exception e) {
945                        e.printStackTrace();
946                }
947        } // main
948
949} // class MarginCalculator
Note: See TracBrowser for help on using the repository browser.