source: tags/MetisMQIDemo/src/main/java/weka/classifiers/bayes/net/search/global/RepeatedHillClimber.java

Last change on this file was 29, checked in by gnappo, 15 years ago

Taggata versione per la demo e aggiunto branch.

File size: 9.7 KB
Line 
1/*
2 * This program is free software; you can redistribute it and/or modify
3 * it under the terms of the GNU General Public License as published by
4 * the Free Software Foundation; either version 2 of the License, or
5 * (at your option) any later version.
6 *
7 * This program is distributed in the hope that it will be useful,
8 * but WITHOUT ANY WARRANTY; without even the implied warranty of
9 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
10 * GNU General Public License for more details.
11 *
12 * You should have received a copy of the GNU General Public License
13 * along with this program; if not, write to the Free Software
14 * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
15 */
16
17/*
18 * RepeatedHillClimber.java
19 * Copyright (C) 2004 University of Waikato, Hamilton, New Zealand
20 *
21 */
22 
23package weka.classifiers.bayes.net.search.global;
24
25import weka.classifiers.bayes.BayesNet;
26import weka.classifiers.bayes.net.ParentSet;
27import weka.core.Instances;
28import weka.core.Option;
29import weka.core.RevisionUtils;
30import weka.core.Utils;
31
32import java.util.Enumeration;
33import java.util.Random;
34import java.util.Vector;
35
36/**
37 <!-- globalinfo-start -->
38 * This Bayes Network learning algorithm repeatedly uses hill climbing starting with a randomly generated network structure and return the best structure of the various runs.
39 * <p/>
40 <!-- globalinfo-end -->
41 *
42 <!-- options-start -->
43 * Valid options are: <p/>
44 *
45 * <pre> -U &lt;integer&gt;
46 *  Number of runs</pre>
47 *
48 * <pre> -A &lt;seed&gt;
49 *  Random number seed</pre>
50 *
51 * <pre> -P &lt;nr of parents&gt;
52 *  Maximum number of parents</pre>
53 *
54 * <pre> -R
55 *  Use arc reversal operation.
56 *  (default false)</pre>
57 *
58 * <pre> -N
59 *  Initial structure is empty (instead of Naive Bayes)</pre>
60 *
61 * <pre> -mbc
62 *  Applies a Markov Blanket correction to the network structure,
63 *  after a network structure is learned. This ensures that all
64 *  nodes in the network are part of the Markov blanket of the
65 *  classifier node.</pre>
66 *
67 * <pre> -S [LOO-CV|k-Fold-CV|Cumulative-CV]
68 *  Score type (LOO-CV,k-Fold-CV,Cumulative-CV)</pre>
69 *
70 * <pre> -Q
71 *  Use probabilistic or 0/1 scoring.
72 *  (default probabilistic scoring)</pre>
73 *
74 <!-- options-end -->
75 *
76 * @author Remco Bouckaert (rrb@xm.co.nz)
77 * @version $Revision: 1.6 $
78 */
79public class RepeatedHillClimber 
80    extends HillClimber {
81
82    /** for serialization */
83    static final long serialVersionUID = -7359197180460703069L;
84 
85    /** number of runs **/
86    int m_nRuns = 10;
87    /** random number seed **/
88    int m_nSeed = 1;
89    /** random number generator **/
90    Random m_random;
91
92        /**
93        * search determines the network structure/graph of the network
94        * with the repeated hill climbing.
95        *
96        * @param bayesNet the network to use
97        * @param instances the data to use
98        * @throws Exception if something goes wrong
99        **/
100        protected void search(BayesNet bayesNet, Instances instances) throws Exception {
101                m_random = new Random(getSeed());
102                // keeps track of score pf best structure found so far
103                double fBestScore;     
104                double fCurrentScore = calcScore(bayesNet);
105
106                // keeps track of best structure found so far
107                BayesNet bestBayesNet;
108
109                // initialize bestBayesNet
110                fBestScore = fCurrentScore;
111                bestBayesNet = new BayesNet();
112                bestBayesNet.m_Instances = instances;
113                bestBayesNet.initStructure();
114                copyParentSets(bestBayesNet, bayesNet);
115               
116               
117        // go do the search       
118        for (int iRun = 0; iRun < m_nRuns; iRun++) {
119                // generate random nework
120                generateRandomNet(bayesNet, instances);
121
122                // search
123                super.search(bayesNet, instances);
124
125                        // calculate score
126                        fCurrentScore = calcScore(bayesNet);
127
128                        // keep track of best network seen so far
129                        if (fCurrentScore > fBestScore) {
130                                fBestScore = fCurrentScore;
131                                copyParentSets(bestBayesNet, bayesNet);
132                        }
133        }
134       
135        // restore current network to best network
136                copyParentSets(bayesNet, bestBayesNet);
137               
138                // free up memory
139                bestBayesNet = null;
140    } // search
141
142        /**
143         *
144         * @param bayesNet
145         * @param instances
146         */
147        void generateRandomNet(BayesNet bayesNet, Instances instances) {
148                int nNodes = instances.numAttributes();
149                // clear network
150                for (int iNode = 0; iNode < nNodes; iNode++) {
151                        ParentSet parentSet = bayesNet.getParentSet(iNode);
152                        while (parentSet.getNrOfParents() > 0) {
153                                parentSet.deleteLastParent(instances);
154                        }
155                }
156               
157                // initialize as naive Bayes?
158                if (getInitAsNaiveBayes()) {
159                        int iClass = instances.classIndex();
160                        // initialize parent sets to have arrow from classifier node to
161                        // each of the other nodes
162                        for (int iNode = 0; iNode < nNodes; iNode++) {
163                                if (iNode != iClass) {
164                                        bayesNet.getParentSet(iNode).addParent(iClass, instances);
165                                }
166                        }
167                }
168
169                // insert random arcs
170                int nNrOfAttempts = m_random.nextInt(nNodes * nNodes);
171                for (int iAttempt = 0; iAttempt < nNrOfAttempts; iAttempt++) {
172                        int iTail = m_random.nextInt(nNodes);
173                        int iHead = m_random.nextInt(nNodes);
174                        if (bayesNet.getParentSet(iHead).getNrOfParents() < getMaxNrOfParents() &&
175                            addArcMakesSense(bayesNet, instances, iHead, iTail)) {
176                                        bayesNet.getParentSet(iHead).addParent(iTail, instances);
177                        }
178                }
179        } // generateRandomNet
180
181        /**
182         * copyParentSets copies parent sets of source to dest BayesNet
183         *
184         * @param dest destination network
185         * @param source source network
186         */
187        void copyParentSets(BayesNet dest, BayesNet source) {
188                int nNodes = source.getNrOfNodes();
189                // clear parent set first
190                for (int iNode = 0; iNode < nNodes; iNode++) {
191                        dest.getParentSet(iNode).copy(source.getParentSet(iNode));
192                }               
193        } // CopyParentSets
194
195
196    /**
197     * Returns the number of runs
198     *
199     * @return number of runs
200     */
201    public int getRuns() {
202        return m_nRuns;
203    } // getRuns
204
205    /**
206     * Sets the number of runs
207     *
208     * @param nRuns The number of runs to set
209     */
210    public void setRuns(int nRuns) {
211        m_nRuns = nRuns;
212    } // setRuns
213
214        /**
215         * Returns the random seed
216         *
217         * @return random number seed
218         */
219        public int getSeed() {
220                return m_nSeed;
221        } // getSeed
222
223        /**
224         * Sets the random number seed
225         *
226         * @param nSeed The number of the seed to set
227         */
228        public void setSeed(int nSeed) {
229                m_nSeed = nSeed;
230        } // setSeed
231
232        /**
233         * Returns an enumeration describing the available options.
234         *
235         * @return an enumeration of all the available options.
236         */
237        public Enumeration listOptions() {
238                Vector newVector = new Vector(4);
239
240                newVector.addElement(new Option("\tNumber of runs", "U", 1, "-U <integer>"));
241                newVector.addElement(new Option("\tRandom number seed", "A", 1, "-A <seed>"));
242
243                Enumeration enu = super.listOptions();
244                while (enu.hasMoreElements()) {
245                        newVector.addElement(enu.nextElement());
246                }
247                return newVector.elements();
248        } // listOptions
249
250        /**
251         * Parses a given list of options. <p/>
252         *
253         <!-- options-start -->
254         * Valid options are: <p/>
255         *
256         * <pre> -U &lt;integer&gt;
257         *  Number of runs</pre>
258         *
259         * <pre> -A &lt;seed&gt;
260         *  Random number seed</pre>
261         *
262         * <pre> -P &lt;nr of parents&gt;
263         *  Maximum number of parents</pre>
264         *
265         * <pre> -R
266         *  Use arc reversal operation.
267         *  (default false)</pre>
268         *
269         * <pre> -N
270         *  Initial structure is empty (instead of Naive Bayes)</pre>
271         *
272         * <pre> -mbc
273         *  Applies a Markov Blanket correction to the network structure,
274         *  after a network structure is learned. This ensures that all
275         *  nodes in the network are part of the Markov blanket of the
276         *  classifier node.</pre>
277         *
278         * <pre> -S [LOO-CV|k-Fold-CV|Cumulative-CV]
279         *  Score type (LOO-CV,k-Fold-CV,Cumulative-CV)</pre>
280         *
281         * <pre> -Q
282         *  Use probabilistic or 0/1 scoring.
283         *  (default probabilistic scoring)</pre>
284         *
285         <!-- options-end -->
286         *
287         * @param options the list of options as an array of strings
288         * @throws Exception if an option is not supported
289         */
290        public void setOptions(String[] options) throws Exception {
291                String sRuns = Utils.getOption('U', options);
292                if (sRuns.length() != 0) {
293                        setRuns(Integer.parseInt(sRuns));
294                }
295               
296                String sSeed = Utils.getOption('A', options);
297                if (sSeed.length() != 0) {
298                        setSeed(Integer.parseInt(sSeed));
299                }
300
301                super.setOptions(options);
302        } // setOptions
303
304        /**
305         * Gets the current settings of the search algorithm.
306         *
307         * @return an array of strings suitable for passing to setOptions
308         */
309        public String[] getOptions() {
310                String[] superOptions = super.getOptions();
311                String[] options = new String[7 + superOptions.length];
312                int current = 0;
313
314                options[current++] = "-U";
315                options[current++] = "" + getRuns();
316
317                options[current++] = "-A";
318                options[current++] = "" + getSeed();
319
320                // insert options from parent class
321                for (int iOption = 0; iOption < superOptions.length; iOption++) {
322                        options[current++] = superOptions[iOption];
323                }
324
325                // Fill up rest with empty strings, not nulls!
326                while (current < options.length) {
327                        options[current++] = "";
328                }
329                return options;
330        } // getOptions
331
332        /**
333         * This will return a string describing the classifier.
334         *
335         * @return The string.
336         */
337        public String globalInfo() {
338                return "This Bayes Network learning algorithm repeatedly uses hill climbing starting " +
339                "with a randomly generated network structure and return the best structure of the " +
340                "various runs.";
341        } // globalInfo
342       
343        /**
344         * @return a string to describe the Runs option.
345         */
346        public String runsTipText() {
347          return "Sets the number of times hill climbing is performed.";
348        } // runsTipText
349
350        /**
351         * @return a string to describe the Seed option.
352         */
353        public String seedTipText() {
354          return "Initialization value for random number generator." +
355          " Setting the seed allows replicability of experiments.";
356        } // seedTipText
357
358        /**
359         * Returns the revision string.
360         *
361         * @return              the revision
362         */
363        public String getRevision() {
364          return RevisionUtils.extract("$Revision: 1.6 $");
365        }
366}
Note: See TracBrowser for help on using the repository browser.