source: tags/MetisMQIDemo/src/main/java/weka/classifiers/bayes/net/search/global/TAN.java

Last change on this file was 29, checked in by gnappo, 15 years ago

Taggata versione per la demo e aggiunto branch.

File size: 8.4 KB
Line 
1/*
2 * This program is free software; you can redistribute it and/or modify
3 * it under the terms of the GNU General Public License as published by
4 * the Free Software Foundation; either version 2 of the License, or
5 * (at your option) any later version.
6 *
7 * This program is distributed in the hope that it will be useful,
8 * but WITHOUT ANY WARRANTY; without even the implied warranty of
9 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
10 * GNU General Public License for more details.
11 *
12 * You should have received a copy of the GNU General Public License
13 * along with this program; if not, write to the Free Software
14 * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
15 */
16
17/*
18 * TAN.java
19 * Copyright (C) 2004 University of Waikato, Hamilton, New Zealand
20 *
21 */
22
23package weka.classifiers.bayes.net.search.global;
24
25import weka.classifiers.bayes.BayesNet;
26import weka.core.Instances;
27import weka.core.RevisionUtils;
28import weka.core.TechnicalInformation;
29import weka.core.TechnicalInformation.Type;
30import weka.core.TechnicalInformation.Field;
31import weka.core.TechnicalInformationHandler;
32
33import java.util.Enumeration;
34
35/**
36 <!-- globalinfo-start -->
37 * This Bayes Network learning algorithm determines the maximum weight spanning tree and returns a Naive Bayes network augmented with a tree.<br/>
38 * <br/>
39 * For more information see:<br/>
40 * <br/>
41 * N. Friedman, D. Geiger, M. Goldszmidt (1997). Bayesian network classifiers. Machine Learning. 29(2-3):131-163.
42 * <p/>
43 <!-- globalinfo-end -->
44 *
45 <!-- technical-bibtex-start -->
46 * BibTeX:
47 * <pre>
48 * &#64;article{Friedman1997,
49 *    author = {N. Friedman and D. Geiger and M. Goldszmidt},
50 *    journal = {Machine Learning},
51 *    number = {2-3},
52 *    pages = {131-163},
53 *    title = {Bayesian network classifiers},
54 *    volume = {29},
55 *    year = {1997}
56 * }
57 * </pre>
58 * <p/>
59 <!-- technical-bibtex-end -->
60 *
61 <!-- options-start -->
62 * Valid options are: <p/>
63 *
64 * <pre> -mbc
65 *  Applies a Markov Blanket correction to the network structure,
66 *  after a network structure is learned. This ensures that all
67 *  nodes in the network are part of the Markov blanket of the
68 *  classifier node.</pre>
69 *
70 * <pre> -S [LOO-CV|k-Fold-CV|Cumulative-CV]
71 *  Score type (LOO-CV,k-Fold-CV,Cumulative-CV)</pre>
72 *
73 * <pre> -Q
74 *  Use probabilistic or 0/1 scoring.
75 *  (default probabilistic scoring)</pre>
76 *
77 <!-- options-end -->
78 *
79 * @author Remco Bouckaert
80 * @version $Revision: 1.7 $
81 */
82public class TAN 
83        extends GlobalScoreSearchAlgorithm
84        implements TechnicalInformationHandler {
85
86        /** for serialization */
87        static final long serialVersionUID = 1715277053980895298L;
88
89        /**
90         * Returns an instance of a TechnicalInformation object, containing
91         * detailed information about the technical background of this class,
92         * e.g., paper reference or book this class is based on.
93         *
94         * @return the technical information about this class
95         */
96        public TechnicalInformation getTechnicalInformation() {
97          TechnicalInformation  result;
98         
99          result = new TechnicalInformation(Type.ARTICLE);
100          result.setValue(Field.AUTHOR, "N. Friedman and D. Geiger and M. Goldszmidt");
101          result.setValue(Field.YEAR, "1997");
102          result.setValue(Field.TITLE, "Bayesian network classifiers");
103          result.setValue(Field.JOURNAL, "Machine Learning");
104          result.setValue(Field.VOLUME, "29");
105          result.setValue(Field.NUMBER, "2-3");
106          result.setValue(Field.PAGES, "131-163");
107         
108          return result;
109        }
110
111        /**
112         * buildStructure determines the network structure/graph of the network
113         * using the maximimum weight spanning tree algorithm of Chow and Liu
114         *
115         * @param bayesNet
116         * @param instances
117         * @throws Exception if something goes wrong
118         */
119        public void buildStructure(BayesNet bayesNet, Instances instances) throws Exception {
120                m_BayesNet = bayesNet;
121
122                m_bInitAsNaiveBayes = true;
123                m_nMaxNrOfParents = 2;
124                super.buildStructure(bayesNet, instances);
125                int      nNrOfAtts = instances.numAttributes();
126
127                // TAN greedy search (not restricted by ordering like K2)
128                // 1. find strongest link
129                // 2. find remaining links by adding strongest link to already
130                //    connected nodes
131                // 3. assign direction to links
132                int nClassNode = instances.classIndex();
133                int [] link1 = new int [nNrOfAtts - 1];
134                int [] link2 = new int [nNrOfAtts - 1];
135                boolean [] linked = new boolean [nNrOfAtts];
136                // 1. find strongest link
137                int    nBestLinkNode1 = -1;
138                int    nBestLinkNode2 = -1;
139                double fBestDeltaScore = 0.0;
140                int iLinkNode1;
141                for (iLinkNode1 = 0; iLinkNode1 < nNrOfAtts; iLinkNode1++) {
142                        if (iLinkNode1 != nClassNode) {
143                                for (int iLinkNode2 = 0; iLinkNode2 < nNrOfAtts; iLinkNode2++) {
144                                        if ((iLinkNode1 != iLinkNode2) && (iLinkNode2 != nClassNode)) {
145                                                double fScore = calcScoreWithExtraParent(iLinkNode1, iLinkNode2);
146                                            if ((nBestLinkNode1 == -1) || (fScore > fBestDeltaScore)) {
147                                                fBestDeltaScore = fScore;
148                                                nBestLinkNode1 = iLinkNode2;
149                                                nBestLinkNode2 = iLinkNode1;
150                                            } 
151                                        }
152                                }
153                        }
154                }
155
156                link1[0] = nBestLinkNode1;
157                link2[0] = nBestLinkNode2;
158                linked[nBestLinkNode1] = true;
159                linked[nBestLinkNode2] = true;
160       
161                // 2. find remaining links by adding strongest link to already
162                //    connected nodes
163                for (int iLink = 1; iLink < nNrOfAtts - 2; iLink++) {
164                        nBestLinkNode1 = -1;
165                        for (iLinkNode1 = 0; iLinkNode1 < nNrOfAtts; iLinkNode1++) {
166                                if (iLinkNode1 != nClassNode) {
167                                        for (int iLinkNode2 = 0; iLinkNode2 < nNrOfAtts; iLinkNode2++) {
168                                                if ((iLinkNode1 != iLinkNode2) &&
169                                                    (iLinkNode2 != nClassNode) && 
170                                                (linked[iLinkNode1] || linked[iLinkNode2]) &&
171                                                (!linked[iLinkNode1] || !linked[iLinkNode2])) {
172                                                        double fScore = calcScoreWithExtraParent(iLinkNode1, iLinkNode2);
173
174                                                        if ((nBestLinkNode1 == -1) || (fScore > fBestDeltaScore)) {
175                                                                fBestDeltaScore = fScore;
176                                                                nBestLinkNode1 = iLinkNode2;
177                                                                nBestLinkNode2 = iLinkNode1;
178                                                        } 
179                                                }
180                                        } 
181                                }
182                        }
183                        link1[iLink] = nBestLinkNode1;
184                        link2[iLink] = nBestLinkNode2;
185                        linked[nBestLinkNode1] = true;
186                        linked[nBestLinkNode2] = true;
187                }
188               
189               
190//              System.out.println();   
191//              for (int i = 0; i < 3; i++) {
192//                      System.out.println(link1[i] + " " + link2[i]);
193//              }
194                // 3. assign direction to links
195                boolean [] hasParent = new boolean [nNrOfAtts];
196                for (int iLink = 0; iLink < nNrOfAtts - 2; iLink++) {
197                        if (!hasParent[link1[iLink]]) {
198                                bayesNet.getParentSet(link1[iLink]).addParent(link2[iLink], instances);
199                                hasParent[link1[iLink]] = true;
200                        } else {
201                                if (hasParent[link2[iLink]]) {
202                                        throw new Exception("Bug condition found: too many arrows");
203                                }
204                                bayesNet.getParentSet(link2[iLink]).addParent(link1[iLink], instances);
205                                hasParent[link2[iLink]] = true;
206                        }
207                }
208
209        } // buildStructure
210
211
212        /**
213         * Returns an enumeration describing the available options.
214         *
215         * @return an enumeration of all the available options.
216         */
217        public Enumeration listOptions() {
218          return super.listOptions();
219        } // listOption
220
221        /**
222         * Parses a given list of options. <p/>
223         *
224         <!-- options-start -->
225         * Valid options are: <p/>
226         *
227         * <pre> -mbc
228         *  Applies a Markov Blanket correction to the network structure,
229         *  after a network structure is learned. This ensures that all
230         *  nodes in the network are part of the Markov blanket of the
231         *  classifier node.</pre>
232         *
233         * <pre> -S [LOO-CV|k-Fold-CV|Cumulative-CV]
234         *  Score type (LOO-CV,k-Fold-CV,Cumulative-CV)</pre>
235         *
236         * <pre> -Q
237         *  Use probabilistic or 0/1 scoring.
238         *  (default probabilistic scoring)</pre>
239         *
240         <!-- options-end -->
241         *
242         * @param options the list of options as an array of strings
243         * @throws Exception if an option is not supported
244         */
245        public void setOptions(String[] options) throws Exception {
246                super.setOptions(options);
247        } // setOptions
248       
249        /**
250         * Gets the current settings of the Classifier.
251         *
252         * @return an array of strings suitable for passing to setOptions
253         */
254        public String [] getOptions() {
255                return super.getOptions();
256        } // getOptions
257
258        /**
259         * This will return a string describing the classifier.
260         * @return The string.
261         */
262        public String globalInfo() {
263          return 
264              "This Bayes Network learning algorithm determines the maximum weight spanning tree "
265            + "and returns a Naive Bayes network augmented with a tree.\n\n"
266            + "For more information see:\n\n"
267            + getTechnicalInformation().toString();
268        } // globalInfo
269
270        /**
271         * Returns the revision string.
272         *
273         * @return              the revision
274         */
275        public String getRevision() {
276          return RevisionUtils.extract("$Revision: 1.7 $");
277        }
278
279} // TAN
280
Note: See TracBrowser for help on using the repository browser.