source: tags/MetisMQIDemo/src/main/java/weka/classifiers/bayes/net/search/local/TAN.java

Last change on this file was 29, checked in by gnappo, 15 years ago

Taggata versione per la demo e aggiunto branch.

File size: 8.9 KB
Line 
1/*
2 * This program is free software; you can redistribute it and/or modify
3 * it under the terms of the GNU General Public License as published by
4 * the Free Software Foundation; either version 2 of the License, or
5 * (at your option) any later version.
6 *
7 * This program is distributed in the hope that it will be useful,
8 * but WITHOUT ANY WARRANTY; without even the implied warranty of
9 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
10 * GNU General Public License for more details.
11 *
12 * You should have received a copy of the GNU General Public License
13 * along with this program; if not, write to the Free Software
14 * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
15 */
16
17/*
18 * TAN.java
19 * Copyright (C) 2004 University of Waikato, Hamilton, New Zealand
20 *
21 */
22
23package weka.classifiers.bayes.net.search.local;
24
25import weka.classifiers.bayes.BayesNet;
26import weka.core.Instances;
27import weka.core.RevisionUtils;
28import weka.core.TechnicalInformation;
29import weka.core.TechnicalInformation.Type;
30import weka.core.TechnicalInformation.Field;
31import weka.core.TechnicalInformationHandler;
32
33import java.util.Enumeration;
34
35/**
36 <!-- globalinfo-start -->
37 * This Bayes Network learning algorithm determines the maximum weight spanning tree  and returns a Naive Bayes network augmented with a tree.<br/>
38 * <br/>
39 * For more information see:<br/>
40 * <br/>
41 * N. Friedman, D. Geiger, M. Goldszmidt (1997). Bayesian network classifiers. Machine Learning. 29(2-3):131-163.
42 * <p/>
43 <!-- globalinfo-end -->
44 *
45 <!-- technical-bibtex-start -->
46 * BibTeX:
47 * <pre>
48 * &#64;article{Friedman1997,
49 *    author = {N. Friedman and D. Geiger and M. Goldszmidt},
50 *    journal = {Machine Learning},
51 *    number = {2-3},
52 *    pages = {131-163},
53 *    title = {Bayesian network classifiers},
54 *    volume = {29},
55 *    year = {1997}
56 * }
57 * </pre>
58 * <p/>
59 <!-- technical-bibtex-end -->
60 *
61 <!-- options-start -->
62 * Valid options are: <p/>
63 *
64 * <pre> -mbc
65 *  Applies a Markov Blanket correction to the network structure,
66 *  after a network structure is learned. This ensures that all
67 *  nodes in the network are part of the Markov blanket of the
68 *  classifier node.</pre>
69 *
70 * <pre> -S [BAYES|MDL|ENTROPY|AIC|CROSS_CLASSIC|CROSS_BAYES]
71 *  Score type (BAYES, BDeu, MDL, ENTROPY and AIC)</pre>
72 *
73 <!-- options-end -->
74 *
75 * @author Remco Bouckaert
76 * @version $Revision: 5333 $
77 */
78public class TAN 
79        extends LocalScoreSearchAlgorithm
80        implements TechnicalInformationHandler {
81 
82        /** for serialization */
83        static final long serialVersionUID = 965182127977228690L;
84
85        /**
86         * Returns an instance of a TechnicalInformation object, containing
87         * detailed information about the technical background of this class,
88         * e.g., paper reference or book this class is based on.
89         *
90         * @return the technical information about this class
91         */
92        public TechnicalInformation getTechnicalInformation() {
93          TechnicalInformation  result;
94         
95          result = new TechnicalInformation(Type.ARTICLE);
96          result.setValue(Field.AUTHOR, "N. Friedman and D. Geiger and M. Goldszmidt");
97          result.setValue(Field.YEAR, "1997");
98          result.setValue(Field.TITLE, "Bayesian network classifiers");
99          result.setValue(Field.JOURNAL, "Machine Learning");
100          result.setValue(Field.VOLUME, "29");
101          result.setValue(Field.NUMBER, "2-3");
102          result.setValue(Field.PAGES, "131-163");
103         
104          return result;
105        }
106
107        /**
108         * buildStructure determines the network structure/graph of the network
109         * using the maximimum weight spanning tree algorithm of Chow and Liu
110         *
111         * @param bayesNet the network
112         * @param instances the data to use
113         * @throws Exception if something goes wrong
114         */
115        public void buildStructure(BayesNet bayesNet, Instances instances) throws Exception {
116         
117                m_bInitAsNaiveBayes = true;
118                m_nMaxNrOfParents = 2;
119                super.buildStructure(bayesNet, instances);
120                int      nNrOfAtts = instances.numAttributes();
121               
122                if (nNrOfAtts <= 1) {
123                    return;
124                }
125
126                // determine base scores
127                double[] fBaseScores = new double[instances.numAttributes()];
128
129                for (int iAttribute = 0; iAttribute < nNrOfAtts; iAttribute++) {
130                  fBaseScores[iAttribute] = calcNodeScore(iAttribute);
131                } 
132
133                //              // cache scores & whether adding an arc makes sense
134                double[][]  fScore = new double[nNrOfAtts][nNrOfAtts];
135
136                for (int iAttributeHead = 0; iAttributeHead < nNrOfAtts; iAttributeHead++) {
137                        for (int iAttributeTail = 0; iAttributeTail < nNrOfAtts; iAttributeTail++) {
138                                if (iAttributeHead != iAttributeTail) {
139                                        fScore[iAttributeHead][iAttributeTail] = calcScoreWithExtraParent(iAttributeHead, iAttributeTail);
140                                }
141                        } 
142                }
143               
144                // TAN greedy search (not restricted by ordering like K2)
145                // 1. find strongest link
146                // 2. find remaining links by adding strongest link to already
147                //    connected nodes
148                // 3. assign direction to links
149                int nClassNode = instances.classIndex();
150                int [] link1 = new int [nNrOfAtts - 1];
151                int [] link2 = new int [nNrOfAtts - 1];
152                boolean [] linked = new boolean [nNrOfAtts];
153
154                // 1. find strongest link
155                int    nBestLinkNode1 = -1;
156                int    nBestLinkNode2 = -1;
157                double fBestDeltaScore = 0.0;
158                int iLinkNode1;
159                for (iLinkNode1 = 0; iLinkNode1 < nNrOfAtts; iLinkNode1++) {
160                        if (iLinkNode1 != nClassNode) {
161                        for (int iLinkNode2 = 0; iLinkNode2 < nNrOfAtts; iLinkNode2++) {
162                                if ((iLinkNode1 != iLinkNode2) &&
163                                    (iLinkNode2 != nClassNode) && (
164                                    (nBestLinkNode1 == -1) || (fScore[iLinkNode1][iLinkNode2] - fBaseScores[iLinkNode1] > fBestDeltaScore)
165                                )) {
166                                        fBestDeltaScore = fScore[iLinkNode1][iLinkNode2] - fBaseScores[iLinkNode1];
167                                        nBestLinkNode1 = iLinkNode2;
168                                        nBestLinkNode2 = iLinkNode1;
169                            } 
170                        } 
171                        }
172                }
173                link1[0] = nBestLinkNode1;
174                link2[0] = nBestLinkNode2;
175                linked[nBestLinkNode1] = true;
176                linked[nBestLinkNode2] = true;
177       
178                // 2. find remaining links by adding strongest link to already
179                //    connected nodes
180                for (int iLink = 1; iLink < nNrOfAtts - 2; iLink++) {
181                        nBestLinkNode1 = -1;
182                        for (iLinkNode1 = 0; iLinkNode1 < nNrOfAtts; iLinkNode1++) {
183                                if (iLinkNode1 != nClassNode) {
184                                for (int iLinkNode2 = 0; iLinkNode2 < nNrOfAtts; iLinkNode2++) {
185                                        if ((iLinkNode1 != iLinkNode2) &&
186                                            (iLinkNode2 != nClassNode) && 
187                                        (linked[iLinkNode1] || linked[iLinkNode2]) &&
188                                        (!linked[iLinkNode1] || !linked[iLinkNode2]) &&
189                                        (
190                                        (nBestLinkNode1 == -1) || (fScore[iLinkNode1][iLinkNode2] - fBaseScores[iLinkNode1] > fBestDeltaScore)
191                                        )) {
192                                                fBestDeltaScore = fScore[iLinkNode1][iLinkNode2] - fBaseScores[iLinkNode1];
193                                                nBestLinkNode1 = iLinkNode2;
194                                                nBestLinkNode2 = iLinkNode1;
195                                        } 
196                                } 
197                                }
198                        }
199
200                        link1[iLink] = nBestLinkNode1;
201                        link2[iLink] = nBestLinkNode2;
202                        linked[nBestLinkNode1] = true;
203                        linked[nBestLinkNode2] = true;
204                }
205               
206                // 3. assign direction to links
207                boolean [] hasParent = new boolean [nNrOfAtts];
208                for (int iLink = 0; iLink < nNrOfAtts - 2; iLink++) {
209                        if (!hasParent[link1[iLink]]) {
210                                bayesNet.getParentSet(link1[iLink]).addParent(link2[iLink], instances);
211                                hasParent[link1[iLink]] = true;
212                        } else {
213                                if (hasParent[link2[iLink]]) {
214                                        throw new Exception("Bug condition found: too many arrows");
215                                }
216                                bayesNet.getParentSet(link2[iLink]).addParent(link1[iLink], instances);
217                                hasParent[link2[iLink]] = true;
218                        }
219                }
220
221        } // buildStructure
222
223
224        /**
225         * Returns an enumeration describing the available options.
226         *
227         * @return an enumeration of all the available options.
228         */
229        public Enumeration listOptions() {
230                return super.listOptions();
231        } // listOption
232
233        /**
234         * Parses a given list of options. <p/>
235         *
236         <!-- options-start -->
237         * Valid options are: <p/>
238         *
239         * <pre> -mbc
240         *  Applies a Markov Blanket correction to the network structure,
241         *  after a network structure is learned. This ensures that all
242         *  nodes in the network are part of the Markov blanket of the
243         *  classifier node.</pre>
244         *
245         * <pre> -S [BAYES|MDL|ENTROPY|AIC|CROSS_CLASSIC|CROSS_BAYES]
246         *  Score type (BAYES, BDeu, MDL, ENTROPY and AIC)</pre>
247         *
248         <!-- options-end -->
249         *
250         * @param options the list of options as an array of strings
251         * @throws Exception if an option is not supported
252         */
253        public void setOptions(String[] options) throws Exception {
254                super.setOptions(options);
255        } // setOptions
256       
257        /**
258         * Gets the current settings of the Classifier.
259         *
260         * @return an array of strings suitable for passing to setOptions
261         */
262        public String [] getOptions() {
263                return super.getOptions();
264        } // getOptions
265
266        /**
267         * This will return a string describing the classifier.
268         * @return The string.
269         */
270        public String globalInfo() {
271                return 
272                    "This Bayes Network learning algorithm determines the maximum weight spanning tree "
273                  + " and returns a Naive Bayes network augmented with a tree.\n\n"
274                  + "For more information see:\n\n"
275                  + getTechnicalInformation().toString();
276        } // globalInfo
277
278        /**
279         * Returns the revision string.
280         *
281         * @return              the revision
282         */
283        public String getRevision() {
284          return RevisionUtils.extract("$Revision: 5333 $");
285        }
286
287} // TAN
288
Note: See TracBrowser for help on using the repository browser.