source: src/main/java/weka/classifiers/bayes/net/BIFReader.java @ 25

Last change on this file since 25 was 4, checked in by gnappo, 14 years ago

Import di weka.

File size: 21.8 KB
Line 
1/*
2 * This program is free software; you can redistribute it and/or modify
3 * it under the terms of the GNU General Public License as published by
4 * the Free Software Foundation; either version 2 of the License, or
5 * (at your option) any later version.
6 *
7 * This program is distributed in the hope that it will be useful,
8 * but WITHOUT ANY WARRANTY; without even the implied warranty of
9 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
10 * GNU General Public License for more details.
11 *
12 * You should have received a copy of the GNU General Public License
13 * along with this program; if not, write to the Free Software
14 * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
15 */
16
17/*
18 * BIFReader.java
19 * Copyright (C) 2003 University of Waikato, Hamilton, New Zealand
20 *
21 */
22
23package weka.classifiers.bayes.net;
24
25import weka.classifiers.bayes.BayesNet;
26import weka.classifiers.bayes.net.estimate.DiscreteEstimatorBayes;
27import weka.core.FastVector;
28import weka.core.Instances;
29import weka.core.RevisionUtils;
30import weka.core.TechnicalInformation;
31import weka.core.TechnicalInformation.Type;
32import weka.core.TechnicalInformation.Field;
33import weka.core.TechnicalInformationHandler;
34import weka.estimators.Estimator;
35
36import java.io.File;
37import java.io.StringReader;
38import java.util.StringTokenizer;
39
40import javax.xml.parsers.DocumentBuilderFactory;
41
42import org.w3c.dom.CharacterData;
43import org.w3c.dom.Document;
44import org.w3c.dom.Element;
45import org.w3c.dom.Node;
46import org.w3c.dom.NodeList;
47
48/**
49 <!-- globalinfo-start -->
50 * Builds a description of a Bayes Net classifier stored in XML BIF 0.3 format.<br/>
51 * <br/>
52 * For more details on XML BIF see:<br/>
53 * <br/>
54 * Fabio Cozman, Marek Druzdzel, Daniel Garcia (1998). XML BIF version 0.3. URL http://www-2.cs.cmu.edu/~fgcozman/Research/InterchangeFormat/.
55 * <p/>
56 <!-- globalinfo-end -->
57 *
58 <!-- technical-bibtex-start -->
59 * BibTeX:
60 * <pre>
61 * &#64;misc{Cozman1998,
62 *    author = {Fabio Cozman and Marek Druzdzel and Daniel Garcia},
63 *    title = {XML BIF version 0.3},
64 *    year = {1998},
65 *    URL = {http://www-2.cs.cmu.edu/\~fgcozman/Research/InterchangeFormat/}
66 * }
67 * </pre>
68 * <p/>
69 <!-- technical-bibtex-end -->
70 *
71 <!-- options-start -->
72 * Valid options are: <p/>
73 *
74 * <pre> -D
75 *  Do not use ADTree data structure
76 * </pre>
77 *
78 * <pre> -B &lt;BIF file&gt;
79 *  BIF file to compare with
80 * </pre>
81 *
82 * <pre> -Q weka.classifiers.bayes.net.search.SearchAlgorithm
83 *  Search algorithm
84 * </pre>
85 *
86 * <pre> -E weka.classifiers.bayes.net.estimate.SimpleEstimator
87 *  Estimator algorithm
88 * </pre>
89 *
90 <!-- options-end -->
91 *
92 * @author Remco Bouckaert (rrb@xm.co.nz)
93 * @version $Revision: 1.15 $
94 */
95public class BIFReader 
96    extends BayesNet
97    implements TechnicalInformationHandler {
98 
99    protected int [] m_nPositionX;
100    protected int [] m_nPositionY;
101    private int [] m_order;
102   
103    /** for serialization */
104    static final long serialVersionUID = -8358864680379881429L;
105
106    /**
107     * This will return a string describing the classifier.
108     * @return The string.
109     */
110    public String globalInfo() {
111        return 
112            "Builds a description of a Bayes Net classifier stored in XML "
113        + "BIF 0.3 format.\n\n"
114        + "For more details on XML BIF see:\n\n"
115        + getTechnicalInformation().toString();
116    }
117
118        /** processFile reads a BIFXML file and initializes a Bayes Net
119         * @param sFile name of the file to parse
120         * @return the BIFReader
121         * @throws Exception if processing fails
122         */
123        public BIFReader processFile(String sFile) throws Exception {
124                m_sFile = sFile;
125        DocumentBuilderFactory factory = DocumentBuilderFactory.newInstance();
126        factory.setValidating(true);
127        Document doc = factory.newDocumentBuilder().parse(new File(sFile));
128        doc.normalize();
129
130        buildInstances(doc, sFile);
131        buildStructure(doc);
132        return this;
133        } // processFile
134
135        public BIFReader processString(String sStr) throws Exception {
136        DocumentBuilderFactory factory = DocumentBuilderFactory.newInstance();
137        factory.setValidating(true);
138                Document doc = factory.newDocumentBuilder().parse(new org.xml.sax.InputSource(new StringReader(sStr)));
139        doc.normalize();
140        buildInstances(doc, "from-string");
141        buildStructure(doc);
142        return this;
143        } // processString
144       
145       
146        /** the current filename */
147        String m_sFile;
148       
149        /**
150         * returns the current filename
151         *
152         * @return the current filename
153         */
154        public String getFileName() {
155          return m_sFile;
156        }
157       
158       
159        /**
160         * Returns an instance of a TechnicalInformation object, containing
161         * detailed information about the technical background of this class,
162         * e.g., paper reference or book this class is based on.
163         *
164         * @return the technical information about this class
165         */
166        public TechnicalInformation getTechnicalInformation() {
167          TechnicalInformation  result;
168         
169          result = new TechnicalInformation(Type.MISC);
170          result.setValue(Field.AUTHOR, "Fabio Cozman and Marek Druzdzel and Daniel Garcia");
171          result.setValue(Field.YEAR, "1998");
172          result.setValue(Field.TITLE, "XML BIF version 0.3");
173          result.setValue(Field.URL, "http://www-2.cs.cmu.edu/~fgcozman/Research/InterchangeFormat/");
174         
175          return result;
176        }
177       
178        /** buildStructure parses the BIF document in the DOM tree contained
179         * in the doc parameter and specifies the the network structure and
180         * probability tables.
181         * It assumes that buildInstances has been called before
182         * @param doc DOM document containing BIF document in DOM tree
183         * @throws Exception if building of structure fails
184         */
185    void buildStructure(Document doc)  throws Exception {
186        // Get the name of the network
187                // initialize conditional distribution tables
188                m_Distributions = new Estimator[m_Instances.numAttributes()][];
189        for (int iNode = 0; iNode < m_Instances.numAttributes(); iNode++) {
190                // find definition that goes with this node
191                String sName = m_Instances.attribute(iNode).name();
192                        Element definition = getDefinition(doc, sName);
193/*
194                if (nodelist.getLength() == 0) {
195                        throw new Exception("No definition found for node " + sName);
196                }
197                if (nodelist.getLength() > 1) {
198                        System.err.println("More than one definition found for node " + sName + ". Using first definition.");
199                }
200                Element definition = (Element) nodelist.item(0);
201*/             
202               
203                // get the parents for this node
204                // resolve structure
205                FastVector nodelist = getParentNodes(definition);
206                for (int iParent = 0; iParent < nodelist.size(); iParent++) {
207                        Node parentName = ((Node) nodelist.elementAt(iParent)).getFirstChild();
208                        String sParentName = ((CharacterData) (parentName)).getData();
209                        int nParent = getNode(sParentName);
210                        m_ParentSets[iNode].addParent(nParent, m_Instances);
211                }
212                // resolve conditional probability table
213                        int nCardinality = m_ParentSets[iNode].getCardinalityOfParents();
214                int nValues = m_Instances.attribute(iNode).numValues();
215                m_Distributions[iNode] = new Estimator[nCardinality];
216                        for (int i = 0; i < nCardinality; i++) {
217                                m_Distributions[iNode][i] = new DiscreteEstimatorBayes(nValues, 0.0f);
218                        }
219
220/*
221                StringBuffer sTable = new StringBuffer();
222                for (int iText = 0; iText < nodelist.getLength(); iText++) {
223                        sTable.append(((CharacterData) (nodelist.item(iText))).getData());
224                        sTable.append(' ');
225                }
226                StringTokenizer st = new StringTokenizer(sTable.toString());
227*/
228                String sTable = getTable(definition);
229                        StringTokenizer st = new StringTokenizer(sTable.toString());
230               
231               
232                        for (int i = 0; i < nCardinality; i++) {
233                                DiscreteEstimatorBayes d = (DiscreteEstimatorBayes) m_Distributions[iNode][i];
234                                for (int iValue = 0; iValue < nValues; iValue++) {
235                                        String sWeight = st.nextToken();
236                                        d.addValue(iValue, new Double(sWeight).doubleValue());
237                                }
238                        }
239         }
240    } // buildStructure
241
242    /** synchronizes the node ordering of this Bayes network with
243     * those in the other network (if possible).
244     * @param other Bayes network to synchronize with
245     * @throws Exception if nr of attributes differs or not all of the variables have the same name.
246     */
247    public void Sync(BayesNet other) throws Exception {
248        int nAtts = m_Instances.numAttributes();
249        if (nAtts != other.m_Instances.numAttributes()) {
250                throw new Exception ("Cannot synchronize networks: different number of attributes.");
251        }
252        m_order = new int[nAtts];
253        for (int iNode = 0; iNode < nAtts; iNode++) {
254                String sName = other.getNodeName(iNode);
255                m_order[getNode(sName)] = iNode;
256        }
257    } // Sync
258
259
260    /**
261     * Returns all TEXT children of the given node in one string. Between
262     * the node values new lines are inserted.
263     *
264     * @param node the node to return the content for
265     * @return the content of the node
266     */
267    public String getContent(Element node) {
268      NodeList       list;
269      Node           item;
270      int            i;
271      String         result;
272     
273      result = "";
274      list   = node.getChildNodes();
275     
276      for (i = 0; i < list.getLength(); i++) {
277         item = list.item(i);
278         if (item.getNodeType() == Node.TEXT_NODE)
279            result += "\n" + item.getNodeValue();
280      }
281         
282      return result;
283    }
284
285
286        /** buildInstances parses the BIF document and creates a Bayes Net with its
287         * nodes specified, but leaves the network structure and probability tables empty.
288         * @param doc DOM document containing BIF document in DOM tree
289         * @param sName default name to give to the Bayes Net. Will be overridden if specified in the BIF document.
290         * @throws Exception if building fails
291         */
292        void buildInstances(Document doc, String sName) throws Exception {
293                NodeList nodelist;
294        // Get the name of the network
295        nodelist = selectAllNames(doc);
296        if (nodelist.getLength() > 0) {
297                sName = ((CharacterData) (nodelist.item(0).getFirstChild())).getData();
298        }
299
300        // Process variables
301        nodelist = selectAllVariables(doc);
302                int nNodes = nodelist.getLength();
303                // initialize structure
304                FastVector attInfo = new FastVector(nNodes);
305
306        // Initialize
307        m_nPositionX = new int[nodelist.getLength()];
308        m_nPositionY = new int[nodelist.getLength()];
309
310        // Process variables
311        for (int iNode = 0; iNode < nodelist.getLength(); iNode++) {
312            // Get element
313                        FastVector valueslist;
314                // Get the name of the network
315            valueslist = selectOutCome(nodelist.item(iNode));
316
317                        int nValues = valueslist.size();
318                        // generate value strings
319                FastVector nomStrings = new FastVector(nValues + 1);
320                for (int iValue = 0; iValue < nValues; iValue++) {
321                        Node node = ((Node) valueslist.elementAt(iValue)).getFirstChild();
322                        String sValue = ((CharacterData) (node)).getData();
323                        if (sValue == null) {
324                                sValue = "Value" + (iValue + 1);
325                        }
326                                nomStrings.addElement(sValue);
327                }
328                        FastVector nodelist2;
329                // Get the name of the network
330            nodelist2 = selectName(nodelist.item(iNode));
331            if (nodelist2.size() == 0) {
332                throw new Exception ("No name specified for variable");
333            }
334            String sNodeName = ((CharacterData) (((Node) nodelist2.elementAt(0)).getFirstChild())).getData();
335
336                        weka.core.Attribute att = new weka.core.Attribute(sNodeName, nomStrings);
337                        attInfo.addElement(att);
338
339            valueslist = selectProperty(nodelist.item(iNode));
340                        nValues = valueslist.size();
341                        // generate value strings
342                for (int iValue = 0; iValue < nValues; iValue++) {
343                // parsing for strings of the form "position = (73, 165)"
344                        Node node = ((Node)valueslist.elementAt(iValue)).getFirstChild();
345                        String sValue = ((CharacterData) (node)).getData();
346                if (sValue.startsWith("position")) {
347                    int i0 = sValue.indexOf('(');
348                    int i1 = sValue.indexOf(',');
349                    int i2 = sValue.indexOf(')');
350                    String sX = sValue.substring(i0 + 1, i1).trim();
351                    String sY = sValue.substring(i1 + 1, i2).trim();
352                    try {
353                        m_nPositionX[iNode] = (int) Integer.parseInt(sX);
354                        m_nPositionY[iNode] = (int) Integer.parseInt(sY);
355                    } catch (NumberFormatException e) {
356                        System.err.println("Wrong number format in position :(" + sX + "," + sY +")");
357                            m_nPositionX[iNode] = 0;
358                            m_nPositionY[iNode] = 0;
359                    }
360                }
361            }
362
363        }
364       
365                m_Instances = new Instances(sName, attInfo, 100);
366                m_Instances.setClassIndex(nNodes - 1);
367                setUseADTree(false);
368                initStructure();
369        } // buildInstances
370
371//      /** selectNodeList selects list of nodes from document specified in XPath expression
372//       * @param doc : document (or node) to query
373//       * @param sXPath : XPath expression
374//       * @return list of nodes conforming to XPath expression in doc
375//       * @throws Exception
376//       */
377//      private NodeList selectNodeList(Node doc, String sXPath) throws Exception {
378//              NodeList nodelist = org.apache.xpath.XPathAPI.selectNodeList(doc, sXPath);
379//              return nodelist;
380//      } // selectNodeList
381
382        NodeList selectAllNames(Document doc) throws Exception {
383                //NodeList nodelist = selectNodeList(doc, "//NAME");
384                NodeList nodelist = doc.getElementsByTagName("NAME");
385                return nodelist;
386        } // selectAllNames
387
388        NodeList selectAllVariables(Document doc) throws Exception {
389                //NodeList nodelist = selectNodeList(doc, "//VARIABLE");
390                NodeList nodelist = doc.getElementsByTagName("VARIABLE");
391                return nodelist;
392        } // selectAllVariables
393
394        Element getDefinition(Document doc, String sName) throws Exception {
395                //NodeList nodelist = selectNodeList(doc, "//DEFINITION[normalize-space(FOR/text())=\"" + sName + "\"]");
396
397                NodeList nodelist = doc.getElementsByTagName("DEFINITION");
398                for (int iNode = 0; iNode < nodelist.getLength(); iNode++) {
399                        Node node = nodelist.item(iNode);
400                        FastVector list = selectElements(node, "FOR");
401                        if (list.size() > 0) {
402                                Node forNode = (Node) list.elementAt(0);
403                                if (getContent((Element) forNode).trim().equals(sName)) {
404                                        return (Element) node;
405                                }
406                        }
407                }
408                throw new Exception("Could not find definition for ((" + sName + "))");
409        } // getDefinition
410
411        FastVector getParentNodes(Node definition) throws Exception {
412                //NodeList nodelist = selectNodeList(definition, "GIVEN");
413                FastVector nodelist = selectElements(definition, "GIVEN");
414                return nodelist;
415        } // getParentNodes
416
417        String getTable(Node definition) throws Exception {
418                //NodeList nodelist = selectNodeList(definition, "TABLE/text()");
419                FastVector nodelist = selectElements(definition, "TABLE");
420                String sTable = getContent((Element) nodelist.elementAt(0));
421                sTable = sTable.replaceAll("\\n"," ");
422                return sTable;
423        } // getTable
424
425        FastVector selectOutCome(Node item) throws Exception {
426                //NodeList nodelist = selectNodeList(item, "OUTCOME");
427                FastVector nodelist = selectElements(item, "OUTCOME");
428                return nodelist;
429        } // selectOutCome
430
431        FastVector selectName(Node item) throws Exception {
432           //NodeList nodelist = selectNodeList(item, "NAME");
433           FastVector nodelist = selectElements(item, "NAME");
434           return nodelist;
435   } // selectName
436
437   FastVector selectProperty(Node item) throws Exception {
438          // NodeList nodelist = selectNodeList(item, "PROPERTY");
439          FastVector nodelist = selectElements(item, "PROPERTY");
440          return nodelist;
441   } // selectProperty
442
443        FastVector selectElements(Node item, String sElement) throws Exception {
444          NodeList children = item.getChildNodes();
445          FastVector nodelist = new FastVector();
446          for (int iNode = 0; iNode < children.getLength(); iNode++) {
447                Node node = children.item(iNode);
448                if ((node.getNodeType() == Node.ELEMENT_NODE) && node.getNodeName().equals(sElement)) {
449                        nodelist.addElement(node);
450                }
451          }
452          return nodelist;
453  } // selectElements
454        /** Count nr of arcs missing from other network compared to current network
455         * Note that an arc is not 'missing' if it is reversed.
456         * @param other network to compare with
457         * @return nr of missing arcs
458         */
459        public int missingArcs(BayesNet other) {
460                try {
461                        Sync(other);
462                        int nMissing = 0;
463                        for (int iAttribute = 0; iAttribute < m_Instances.numAttributes(); iAttribute++) {
464                                for (int iParent = 0; iParent < m_ParentSets[iAttribute].getNrOfParents(); iParent++) {
465                                        int nParent = m_ParentSets[iAttribute].getParent(iParent);
466                                        if (!other.getParentSet(m_order[iAttribute]).contains(m_order[nParent]) && !other.getParentSet(m_order[nParent]).contains(m_order[iAttribute])) {
467                                                nMissing++;
468                                        }
469                                }
470                        }
471                        return nMissing;
472                } catch (Exception e) {
473                        System.err.println(e.getMessage());
474                        return 0;
475                }
476        } // missingArcs
477
478        /** Count nr of exta arcs  from other network compared to current network
479         * Note that an arc is not 'extra' if it is reversed.
480         * @param other network to compare with
481         * @return nr of missing arcs
482         */
483        public int extraArcs(BayesNet other) {
484                try {
485                        Sync(other);
486                        int nExtra = 0;
487                        for (int iAttribute = 0; iAttribute < m_Instances.numAttributes(); iAttribute++) {
488                                for (int iParent = 0; iParent < other.getParentSet(m_order[iAttribute]).getNrOfParents(); iParent++) {
489                                        int nParent = m_order[other.getParentSet(m_order[iAttribute]).getParent(iParent)];
490                                        if (!m_ParentSets[iAttribute].contains(nParent) && !m_ParentSets[nParent].contains(iAttribute)) {
491                                                nExtra++;
492                                        }
493                                }
494                        }
495                        return nExtra;
496                } catch (Exception e) {
497                        System.err.println(e.getMessage());
498                        return 0;
499                }
500        } // extraArcs
501
502
503        /** calculates the divergence between the probability distribution
504         * represented by this network and that of another, that is,
505         * \sum_{x\in X} P(x)log P(x)/Q(x)
506         * where X is the set of values the nodes in the network can take,
507         * P(x) the probability of this network for configuration x
508         * Q(x) the probability of the other network for configuration x
509         * @param other network to compare with
510         * @return divergence between this and other Bayes Network
511         */
512        public double divergence(BayesNet other) {
513                try {
514                        Sync(other);
515                        // D: divergence
516                        double D = 0.0;
517                        int nNodes = m_Instances.numAttributes();
518                        int [] nCard = new int[nNodes];
519                        for (int iNode = 0; iNode < nNodes; iNode++) {
520                                nCard[iNode] = m_Instances.attribute(iNode).numValues();
521                        }
522                        // x: holds current configuration of nodes
523                        int [] x = new int[nNodes];
524                        // simply sum over all configurations to calc divergence D
525                        int i = 0;
526                        while (i < nNodes) {
527                                // update configuration
528                                x[i]++;
529                                while (i < nNodes && x[i] == m_Instances.attribute(i).numValues()) {
530                                        x[i] = 0;
531                                        i++;
532                                        if (i < nNodes){
533                                                x[i]++;
534                                        }
535                                }
536                                if (i < nNodes) {
537                                        i = 0;
538                                        // calc P(x) and Q(x)
539                                        double P = 1.0;
540                                        for (int iNode = 0; iNode < nNodes; iNode++) {
541                                                int iCPT = 0;
542                                                for (int iParent = 0; iParent < m_ParentSets[iNode].getNrOfParents(); iParent++) {
543                                                int nParent = m_ParentSets[iNode].getParent(iParent);
544                                                    iCPT = iCPT * nCard[nParent] + x[nParent];
545                                                } 
546                                                P = P * m_Distributions[iNode][iCPT].getProbability(x[iNode]);
547                                        }
548       
549                                        double Q = 1.0;
550                                        for (int iNode = 0; iNode < nNodes; iNode++) {
551                                                int iCPT = 0;
552                                                for (int iParent = 0; iParent < other.getParentSet(m_order[iNode]).getNrOfParents(); iParent++) {
553                                                int nParent = m_order[other.getParentSet(m_order[iNode]).getParent(iParent)];
554                                                    iCPT = iCPT * nCard[nParent] + x[nParent];
555                                                } 
556                                                Q = Q * other.m_Distributions[m_order[iNode]][iCPT].getProbability(x[iNode]);
557                                        }
558       
559                                        // update divergence if probabilities are positive
560                                        if (P > 0.0 && Q > 0.0) {
561                                                D = D + P * Math.log(Q / P);
562                                        }
563                                }
564                        }
565                        return D;
566                } catch (Exception e) {
567                        System.err.println(e.getMessage());
568                        return 0;
569                }
570        } // divergence
571
572        /** Count nr of reversed arcs from other network compared to current network
573         * @param other network to compare with
574         * @return nr of missing arcs
575         */
576        public int reversedArcs(BayesNet other) {
577                try {
578                        Sync(other);
579                        int nReversed = 0;
580                    for (int iAttribute = 0; iAttribute < m_Instances.numAttributes(); iAttribute++) {
581                                for (int iParent = 0; iParent < m_ParentSets[iAttribute].getNrOfParents(); iParent++) {
582                                        int nParent = m_ParentSets[iAttribute].getParent(iParent);
583                                        if (!other.getParentSet(m_order[iAttribute]).contains(m_order[nParent]) && other.getParentSet(m_order[nParent]).contains(m_order[iAttribute])) {
584                                                nReversed++;
585                                        }
586                                }
587                        }
588                        return nReversed;
589                } catch (Exception e) {
590                        System.err.println(e.getMessage());
591                        return 0;
592                }
593        } // reversedArcs
594        /** getNode finds the index of the node with name sNodeName
595         * and throws an exception if no such node can be found.
596         * @param sNodeName name of the node to get the index from
597         * @return index of the node with name sNodeName
598         * @throws Exception if node cannot be found
599         */
600    public int getNode(String sNodeName) throws Exception {
601                int iNode = 0;
602                while (iNode < m_Instances.numAttributes()) {
603                        if (m_Instances.attribute(iNode).name().equals(sNodeName)) {
604                                return iNode;
605                        }
606                        iNode++;
607                }
608                throw new Exception("Could not find node [[" + sNodeName + "]]");
609    } // getNode
610
611    /**
612     * the default constructor
613     */
614    public BIFReader() {
615    }
616   
617    /**
618     * Returns the revision string.
619     *
620     * @return          the revision
621     */
622    public String getRevision() {
623      return RevisionUtils.extract("$Revision: 1.15 $");
624    }
625
626    /**
627     * Loads the file specified as first parameter and prints it to stdout.
628     *
629     * @param args the command line parameters
630     */
631    public static void main(String[] args) {
632        try {
633            BIFReader br = new BIFReader();
634            br.processFile(args[0]);
635            System.out.println(br.toString());
636       
637        }
638        catch (Throwable t) {
639            t.printStackTrace();
640        }
641    } // main
642} // class BIFReader
643
Note: See TracBrowser for help on using the repository browser.