source: src/main/java/weka/classifiers/bayes/net/BayesNetGenerator.java @ 16

Last change on this file since 16 was 4, checked in by gnappo, 14 years ago

Import di weka.

File size: 19.0 KB
Line 
1/*
2 * This program is free software; you can redistribute it and/or modify
3 * it under the terms of the GNU General Public License as published by
4 * the Free Software Foundation; either version 2 of the License, or
5 * (at your option) any later version.
6 *
7 * This program is distributed in the hope that it will be useful,
8 * but WITHOUT ANY WARRANTY; without even the implied warranty of
9 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
10 * GNU General Public License for more details.
11 *
12 * You should have received a copy of the GNU General Public License
13 * along with this program; if not, write to the Free Software
14 * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
15 */
16
17/*
18 * BayesNet.java
19 * Copyright (C) 2003 University of Waikato, Hamilton, New Zealand
20 *
21 */
22
23package weka.classifiers.bayes.net;
24
25import weka.classifiers.bayes.net.estimate.DiscreteEstimatorBayes;
26import weka.core.Attribute;
27import weka.core.FastVector;
28import weka.core.Instance;
29import weka.core.Instances;
30import weka.core.DenseInstance;
31import weka.core.Option;
32import weka.core.OptionHandler;
33import weka.core.RevisionUtils;
34import weka.core.Utils;
35import weka.estimators.Estimator;
36
37import java.util.Enumeration;
38import java.util.Random;
39import java.util.Vector;
40
41/**
42 <!-- globalinfo-start -->
43 * Bayes Network learning using various search algorithms and quality measures.<br/>
44 * Base class for a Bayes Network classifier. Provides datastructures (network structure, conditional probability distributions, etc.) and facilities common to Bayes Network learning algorithms like K2 and B.<br/>
45 * <br/>
46 * For more information see:<br/>
47 * <br/>
48 * http://www.cs.waikato.ac.nz/~remco/weka.pdf
49 * <p/>
50 <!-- globalinfo-end -->
51 *
52 <!-- options-start -->
53 * Valid options are: <p/>
54 *
55 * <pre> -B
56 *  Generate network (instead of instances)
57 * </pre>
58 *
59 * <pre> -N &lt;integer&gt;
60 *  Nr of nodes
61 * </pre>
62 *
63 * <pre> -A &lt;integer&gt;
64 *  Nr of arcs
65 * </pre>
66 *
67 * <pre> -M &lt;integer&gt;
68 *  Nr of instances
69 * </pre>
70 *
71 * <pre> -C &lt;integer&gt;
72 *  Cardinality of the variables
73 * </pre>
74 *
75 * <pre> -S &lt;integer&gt;
76 *  Seed for random number generator
77 * </pre>
78 *
79 * <pre> -F &lt;file&gt;
80 *  The BIF file to obtain the structure from.
81 * </pre>
82 *
83 <!-- options-end -->
84 *
85 * @author Remco Bouckaert (rrb@xm.co.nz)
86 * @version $Revision: 5987 $
87 */
88public class BayesNetGenerator extends EditableBayesNet {
89    /** the seed value */
90    int m_nSeed = 1;
91   
92    /** the random number generator */
93    Random random;
94   
95    /** for serialization */
96    static final long serialVersionUID = -7462571170596157720L;
97
98        /**
99         * Constructor for BayesNetGenerator.
100         */
101        public BayesNetGenerator() {
102                super();
103        } // c'tor
104
105        /**
106         * Generate random connected Bayesian network with discrete nodes
107         * having all the same cardinality.
108         *
109         * @throws Exception if something goes wrong
110         */
111        public void generateRandomNetwork () throws Exception {
112                if (m_otherBayesNet == null) {
113                        // generate from scratch
114                        Init(m_nNrOfNodes, m_nCardinality);
115                        generateRandomNetworkStructure(m_nNrOfNodes, m_nNrOfArcs);
116                        generateRandomDistributions(m_nNrOfNodes, m_nCardinality);
117                } else {
118                        // read from file, just copy parent sets and distributions
119                        m_nNrOfNodes = m_otherBayesNet.getNrOfNodes();
120                        m_ParentSets = m_otherBayesNet.getParentSets();
121                        m_Distributions = m_otherBayesNet.getDistributions();
122
123
124                        random = new Random(m_nSeed);
125                        // initialize m_Instances
126                        FastVector attInfo = new FastVector(m_nNrOfNodes);
127                        // generate value strings
128
129                        for (int iNode = 0; iNode < m_nNrOfNodes; iNode++) {
130                                int nValues = m_otherBayesNet.getCardinality(iNode);
131                                FastVector nomStrings = new FastVector(nValues + 1);
132                                for (int iValue = 0; iValue < nValues; iValue++) {
133                                        nomStrings.addElement(m_otherBayesNet.getNodeValue(iNode, iValue));
134                                }
135                                Attribute att = new Attribute(m_otherBayesNet.getNodeName(iNode), nomStrings);
136                                attInfo.addElement(att);
137                        }
138
139                        m_Instances = new Instances(m_otherBayesNet.getName(), attInfo, 100);
140                        m_Instances.setClassIndex(m_nNrOfNodes - 1);
141                }
142        } // GenerateRandomNetwork
143
144        /**
145         * Init defines a minimal Bayes net with no arcs
146         * @param nNodes number of nodes in the Bayes net
147         * @param nValues number of values each of the nodes can take
148         * @throws Exception if something goes wrong
149         */
150        public void Init(int nNodes, int nValues) throws Exception {
151                random = new Random(m_nSeed);
152                // initialize structure
153                FastVector attInfo = new FastVector(nNodes);
154                // generate value strings
155        FastVector nomStrings = new FastVector(nValues + 1);
156        for (int iValue = 0; iValue < nValues; iValue++) {
157                        nomStrings.addElement("Value" + (iValue + 1));
158        }
159
160                for (int iNode = 0; iNode < nNodes; iNode++) {
161                        Attribute att = new Attribute("Node" + (iNode + 1), nomStrings);
162                        attInfo.addElement(att);
163                }
164                m_Instances = new Instances("RandomNet", attInfo, 100);
165                m_Instances.setClassIndex(nNodes - 1);
166                setUseADTree(false);
167//              m_bInitAsNaiveBayes = false;
168//              m_bMarkovBlanketClassifier = false;
169                initStructure();
170               
171                // initialize conditional distribution tables
172                m_Distributions = new Estimator[nNodes][1];
173                for (int iNode = 0; iNode < nNodes; iNode++) {
174                        m_Distributions[iNode][0] = 
175                          new DiscreteEstimatorBayes(nValues, getEstimator().getAlpha());
176                }
177                m_nEvidence = new FastVector(nNodes);
178                for (int i = 0; i < nNodes; i++) {
179                        m_nEvidence.addElement(-1);
180                }
181                m_fMarginP = new FastVector(nNodes);
182                for (int i = 0; i < nNodes; i++) {
183                        double[] P = new double[getCardinality(i)];
184                        m_fMarginP.addElement(P);
185                }
186
187                m_nPositionX = new FastVector(nNodes);
188                m_nPositionY = new FastVector(nNodes);
189                for (int iNode = 0; iNode < nNodes; iNode++) {
190                        m_nPositionX.addElement(iNode%10 * 50);
191                        m_nPositionY.addElement(((int)(iNode/10)) * 50);
192                }
193        } // DefineNodes
194
195        /**
196         * GenerateRandomNetworkStructure generate random connected Bayesian network
197         * @param nNodes number of nodes in the Bayes net to generate
198         * @param nArcs number of arcs to generate. Must be between nNodes - 1 and nNodes * (nNodes-1) / 2
199         * @throws Exception if number of arcs is incorrect
200         */
201        public void generateRandomNetworkStructure(int nNodes, int nArcs) 
202                throws Exception
203        {
204                if (nArcs < nNodes - 1) {
205                        throw new Exception("Number of arcs should be at least (nNodes - 1) = " + (nNodes - 1) + " instead of " + nArcs);
206                }
207                if (nArcs > nNodes * (nNodes - 1) / 2) {
208                        throw new Exception("Number of arcs should be at most nNodes * (nNodes - 1) / 2 = "+ (nNodes * (nNodes - 1) / 2) + " instead of " + nArcs);
209                }
210                if (nArcs == 0) {return;} // deal with  patalogical case for nNodes = 1
211
212            // first generate tree connecting all nodes
213            generateTree(nNodes);
214            // The tree contains nNodes - 1 arcs, so there are
215            // nArcs - (nNodes-1) to add at random.
216            // All arcs point from lower to higher ordered nodes
217            // so that acyclicity is ensured.
218            for (int iArc = nNodes - 1; iArc < nArcs; iArc++) {
219                boolean bDone = false;
220                while (!bDone) {
221                                int nNode1 = random.nextInt(nNodes);
222                                int nNode2 = random.nextInt(nNodes);
223                                if (nNode1 == nNode2) {nNode2 = (nNode1 + 1) % nNodes;}
224                                if (nNode2 < nNode1) {int h = nNode1; nNode1 = nNode2; nNode2 = h;}
225                                if (!m_ParentSets[nNode2].contains(nNode1)) {
226                                        m_ParentSets[nNode2].addParent(nNode1, m_Instances);
227                                        bDone = true;
228                                }
229                }
230            }
231
232        } // GenerateRandomNetworkStructure
233       
234        /**
235         * GenerateTree creates a tree-like network structure (actually a
236         * forest) by starting with a randomly selected pair of nodes, add
237         * an arc between. Then keep on selecting one of the connected nodes
238         * and one of the unconnected ones and add an arrow between them,
239         * till all nodes are connected.
240         * @param nNodes number of nodes in the Bayes net to generate
241         */
242        void generateTree(int nNodes) {
243        boolean [] bConnected = new boolean [nNodes];
244        // start adding an arc at random
245                int nNode1 = random.nextInt(nNodes);
246                int nNode2 = random.nextInt(nNodes);
247                if (nNode1 == nNode2) {nNode2 = (nNode1 + 1) % nNodes;}
248                if (nNode2 < nNode1) {int h = nNode1; nNode1 = nNode2; nNode2 = h;}
249                m_ParentSets[nNode2].addParent(nNode1, m_Instances);
250                bConnected[nNode1] = true;
251                bConnected[nNode2] = true;
252                // Repeatedly, select one of the connected nodes, and one of
253                // the unconnected nodes and add an arc.
254            // All arcs point from lower to higher ordered nodes
255            // so that acyclicity is ensured.
256                for (int iArc = 2; iArc < nNodes; iArc++ ) {
257                        int nNode = random.nextInt(nNodes);
258                        nNode1 = 0; //  one of the connected nodes
259                        while (nNode >= 0) {
260                                nNode1 = (nNode1 + 1) % nNodes;
261                                while (!bConnected[nNode1]) {
262                                        nNode1 = (nNode1 + 1) % nNodes;
263                                }
264                                nNode--;
265                        }
266                        nNode = random.nextInt(nNodes);
267                        nNode2 = 0; //  one of the unconnected nodes
268                        while (nNode >= 0) {
269                                nNode2 = (nNode2 + 1) % nNodes;
270                                while (bConnected[nNode2]) {
271                                        nNode2 = (nNode2 + 1) % nNodes;
272                                }
273                                nNode--;
274                        }
275                        if (nNode2 < nNode1) {int h = nNode1; nNode1 = nNode2; nNode2 = h;}
276                        m_ParentSets[nNode2].addParent(nNode1, m_Instances);
277                        bConnected[nNode1] = true;
278                        bConnected[nNode2] = true;
279                }
280        } // GenerateTree
281       
282        /**
283         * GenerateRandomDistributions generates discrete conditional distribution tables
284         * for all nodes of a Bayes network once a network structure has been determined.
285         * @param nNodes number of nodes in the Bayes net
286         * @param nValues number of values each of the nodes can take
287         */
288    void generateRandomDistributions(int nNodes, int nValues) {
289            // Reserve space for CPTs
290        int nMaxParentCardinality = 1;
291            for (int iAttribute = 0; iAttribute < nNodes; iAttribute++) {
292            if (m_ParentSets[iAttribute].getCardinalityOfParents() > nMaxParentCardinality) {
293                     nMaxParentCardinality = m_ParentSets[iAttribute].getCardinalityOfParents();
294            } 
295        } 
296
297        // Reserve plenty of memory
298        m_Distributions = new Estimator[m_Instances.numAttributes()][nMaxParentCardinality];
299
300        // estimate CPTs
301        for (int iAttribute = 0; iAttribute < nNodes; iAttribute++) {
302                int [] nPs = new int [nValues + 1];
303                nPs[0] = 0;
304                nPs[nValues] = 1000;
305            for (int iParent = 0; iParent < m_ParentSets[iAttribute].getCardinalityOfParents(); iParent++) {
306                // fill array with random nr's
307                for (int iValue = 1; iValue < nValues; iValue++)  {
308                        nPs[iValue] = random.nextInt(1000);
309                }
310                // sort
311                for (int iValue = 1; iValue < nValues; iValue++)  {
312                        for (int iValue2 = iValue + 1; iValue2 < nValues; iValue2++)  {
313                                if (nPs[iValue2] < nPs[iValue]) {
314                                        int h = nPs[iValue2]; nPs[iValue2] = nPs[iValue]; nPs[iValue] = h;
315                                }
316                        }
317                }
318                // assign to probability tables
319                DiscreteEstimatorBayes d = new DiscreteEstimatorBayes(nValues, getEstimator().getAlpha());
320                for (int iValue = 0; iValue < nValues; iValue++)  {
321                        d.addValue(iValue, nPs[iValue + 1] - nPs[iValue]);
322                }
323                    m_Distributions[iAttribute][iParent] = d;
324            } 
325        } 
326    } // GenerateRandomDistributions
327   
328        /**
329         * GenerateInstances generates random instances sampling from the
330         * distribution represented by the Bayes network structure. It assumes
331         * a Bayes network structure has been initialized
332         *
333         * @throws Exception if something goes wrong
334         */
335        public void generateInstances () throws Exception {
336            int [] order = getOrder();
337                for (int iInstance = 0; iInstance < m_nNrOfInstances; iInstance++) {
338                    int nNrOfAtts = m_Instances.numAttributes();
339                        Instance instance = new DenseInstance(nNrOfAtts);
340                        instance.setDataset(m_Instances);
341                        for (int iAtt2 = 0; iAtt2 < nNrOfAtts; iAtt2++) {
342                            int iAtt = order[iAtt2];
343
344                                double iCPT = 0;
345
346                                for (int iParent = 0; iParent < m_ParentSets[iAtt].getNrOfParents(); iParent++) {
347                                  int nParent = m_ParentSets[iAtt].getParent(iParent);
348                                  iCPT = iCPT * m_Instances.attribute(nParent).numValues() + instance.value(nParent);
349                                } 
350       
351                                double fRandom = random.nextInt(1000) / 1000.0f;
352                                int iValue = 0;
353                                while (fRandom > m_Distributions[iAtt][(int) iCPT].getProbability(iValue)) {
354                                        fRandom = fRandom - m_Distributions[iAtt][(int) iCPT].getProbability(iValue);
355                                        iValue++ ;
356                                }
357                                instance.setValue(iAtt, iValue);
358                        }
359                        m_Instances.add(instance);
360                }
361        } // GenerateInstances
362
363    /**
364     * @throws Exception if there's a cycle in the graph
365     */ 
366    int [] getOrder() throws Exception {
367        int nNrOfAtts = m_Instances.numAttributes();
368        int [] order = new int[nNrOfAtts];
369        boolean [] bDone = new boolean[nNrOfAtts];
370        for (int iAtt = 0; iAtt < nNrOfAtts; iAtt++) {
371            int iAtt2 = 0; 
372            boolean allParentsDone = false;
373            while (!allParentsDone && iAtt2 < nNrOfAtts) {
374                if (!bDone[iAtt2]) {
375                    allParentsDone = true;
376                    int iParent = 0;
377                    while (allParentsDone && iParent < m_ParentSets[iAtt2].getNrOfParents()) {
378                        allParentsDone = bDone[m_ParentSets[iAtt2].getParent(iParent++)];
379                    }
380                    if (allParentsDone && iParent == m_ParentSets[iAtt2].getNrOfParents()) {
381                        order[iAtt] = iAtt2;
382                        bDone[iAtt2] = true;
383                    } else {
384                        iAtt2++;
385                    }
386                } else {
387                    iAtt2++;
388                }
389            }
390            if (!allParentsDone && iAtt2 == nNrOfAtts) {
391                throw new Exception("There appears to be a cycle in the graph");
392            }
393        }
394        return order;
395    } // getOrder
396
397        /**
398         * Returns either the net (if BIF format) or the generated instances
399         *
400         * @return either the net or the generated instances
401         */
402        public String toString() {
403          if (m_bGenerateNet) {
404            return toXMLBIF03();
405          }
406          return m_Instances.toString();
407        } // toString
408       
409
410        boolean m_bGenerateNet = false;
411        int m_nNrOfNodes = 10;
412        int m_nNrOfArcs = 10;
413        int m_nNrOfInstances = 10;
414        int m_nCardinality = 2;
415        String m_sBIFFile = "";
416
417        void setNrOfNodes(int nNrOfNodes) {m_nNrOfNodes = nNrOfNodes;}
418        void setNrOfArcs(int nNrOfArcs) {m_nNrOfArcs = nNrOfArcs;}
419        void setNrOfInstances(int nNrOfInstances) {m_nNrOfInstances = nNrOfInstances;}
420        void setCardinality(int nCardinality) {m_nCardinality = nCardinality;}
421        void setSeed(int nSeed) {m_nSeed = nSeed;}
422
423        /**
424         * Returns an enumeration describing the available options
425         *
426         * @return an enumeration of all the available options
427         */
428        public Enumeration listOptions() {
429                Vector newVector = new Vector(6);
430
431                newVector.addElement(new Option("\tGenerate network (instead of instances)\n", "B", 0, "-B"));
432                newVector.addElement(new Option("\tNr of nodes\n", "N", 1, "-N <integer>"));
433                newVector.addElement(new Option("\tNr of arcs\n", "A", 1, "-A <integer>"));
434                newVector.addElement(new Option("\tNr of instances\n", "M", 1, "-M <integer>"));
435                newVector.addElement(new Option("\tCardinality of the variables\n", "C", 1, "-C <integer>"));
436                newVector.addElement(new Option("\tSeed for random number generator\n", "S", 1, "-S <integer>"));
437                newVector.addElement(new Option("\tThe BIF file to obtain the structure from.\n", "F", 1, "-F <file>"));
438
439                return newVector.elements();
440        } // listOptions
441
442        /**
443         * Parses a given list of options. <p/>
444         *
445         <!-- options-start -->
446         * Valid options are: <p/>
447         *
448         * <pre> -B
449         *  Generate network (instead of instances)
450         * </pre>
451         *
452         * <pre> -N &lt;integer&gt;
453         *  Nr of nodes
454         * </pre>
455         *
456         * <pre> -A &lt;integer&gt;
457         *  Nr of arcs
458         * </pre>
459         *
460         * <pre> -M &lt;integer&gt;
461         *  Nr of instances
462         * </pre>
463         *
464         * <pre> -C &lt;integer&gt;
465         *  Cardinality of the variables
466         * </pre>
467         *
468         * <pre> -S &lt;integer&gt;
469         *  Seed for random number generator
470         * </pre>
471         *
472         * <pre> -F &lt;file&gt;
473         *  The BIF file to obtain the structure from.
474         * </pre>
475         *
476         <!-- options-end -->
477         *
478         * @param options the list of options as an array of strings
479         * @exception Exception if an option is not supported
480         */
481        public void setOptions(String[] options) throws Exception {
482                m_bGenerateNet = Utils.getFlag('B', options);
483
484                String sNrOfNodes = Utils.getOption('N', options);
485                if (sNrOfNodes.length() != 0) {
486                  setNrOfNodes(Integer.parseInt(sNrOfNodes));
487                } else {
488                        setNrOfNodes(10);
489                } 
490
491                String sNrOfArcs = Utils.getOption('A', options);
492                if (sNrOfArcs.length() != 0) {
493                  setNrOfArcs(Integer.parseInt(sNrOfArcs));
494                } else {
495                        setNrOfArcs(10);
496                } 
497
498                String sNrOfInstances = Utils.getOption('M', options);
499                if (sNrOfInstances.length() != 0) {
500                  setNrOfInstances(Integer.parseInt(sNrOfInstances));
501                } else {
502                        setNrOfInstances(10);
503                } 
504
505                String sCardinality = Utils.getOption('C', options);
506                if (sCardinality.length() != 0) {
507                  setCardinality(Integer.parseInt(sCardinality));
508                } else {
509                        setCardinality(2);
510                } 
511
512                String sSeed = Utils.getOption('S', options);
513                if (sSeed.length() != 0) {
514                  setSeed(Integer.parseInt(sSeed));
515                } else {
516                        setSeed(1);
517                } 
518
519                String sBIFFile = Utils.getOption('F', options);
520                if ((sBIFFile != null) && (sBIFFile != "")) {
521                        setBIFFile(sBIFFile);
522                }
523        } // setOptions
524
525        /**
526         * Gets the current settings of the classifier.
527         *
528         * @return an array of strings suitable for passing to setOptions
529         */
530        public String[] getOptions() {
531                String[] options = new String[13];
532                int current = 0;
533                if (m_bGenerateNet) {
534                  options[current++] = "-B";
535                } 
536
537                options[current++] = "-N";
538                options[current++] = "" + m_nNrOfNodes;
539
540                options[current++] = "-A";
541                options[current++] = "" + m_nNrOfArcs;
542
543                options[current++] = "-M";
544                options[current++] = "" + m_nNrOfInstances;
545
546                options[current++] = "-C";
547                options[current++] = "" + m_nCardinality;
548
549                options[current++] = "-S";
550                options[current++] = "" + m_nSeed;
551
552                if (m_sBIFFile.length() != 0) {
553                  options[current++] = "-F";
554                  options[current++] = "" + m_sBIFFile;
555                }
556
557                // Fill up rest with empty strings, not nulls!
558                while (current < options.length) {
559                        options[current++] = "";
560                }
561
562                return options;
563        } // getOptions
564
565    /**
566     * prints all the options to stdout
567     */
568    protected static void printOptions(OptionHandler o) {
569      Enumeration enm = o.listOptions();
570     
571      System.out.println("Options for " + o.getClass().getName() + ":\n");
572     
573      while (enm.hasMoreElements()) {
574        Option option = (Option) enm.nextElement();
575        System.out.println(option.synopsis());
576        System.out.println(option.description());
577      }
578    }
579   
580    /**
581     * Returns the revision string.
582     *
583     * @return          the revision
584     */
585    public String getRevision() {
586      return RevisionUtils.extract("$Revision: 5987 $");
587    }
588
589    /**
590     * Main method
591     *
592     * @param args the commandline parameters
593     */
594    static public void main(String [] args) {
595                BayesNetGenerator b = new BayesNetGenerator();
596        try {
597                if ( (args.length == 0) || (Utils.getFlag('h', args)) ) {
598                        printOptions(b);
599                        return;
600                }
601                b.setOptions(args);
602               
603                b.generateRandomNetwork();
604                if (!b.m_bGenerateNet) { // skip if not required
605                                b.generateInstances();
606                }
607                System.out.println(b.toString());
608        } catch (Exception e) {
609                e.printStackTrace();
610                printOptions(b);
611        }
612    } // main
613   
614} // class BayesNetGenerator
Note: See TracBrowser for help on using the repository browser.