source: tags/MetisMQIDemo/src/main/java/weka/classifiers/bayes/net/search/global/SimulatedAnnealing.java

Last change on this file was 29, checked in by gnappo, 15 years ago

Taggata versione per la demo e aggiunto branch.

File size: 13.1 KB
Line 
1/*
2 * This program is free software; you can redistribute it and/or modify
3 * it under the terms of the GNU General Public License as published by
4 * the Free Software Foundation; either version 2 of the License, or
5 * (at your option) any later version.
6 *
7 * This program is distributed in the hope that it will be useful,
8 * but WITHOUT ANY WARRANTY; without even the implied warranty of
9 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
10 * GNU General Public License for more details.
11 *
12 * You should have received a copy of the GNU General Public License
13 * along with this program; if not, write to the Free Software
14 * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
15 */
16
17/*
18 * SimulatedAnnealing.java
19 * Copyright (C) 2004 University of Waikato, Hamilton, New Zealand
20 *
21 */
22 
23package weka.classifiers.bayes.net.search.global;
24
25import weka.classifiers.bayes.BayesNet;
26import weka.core.Instances;
27import weka.core.Option;
28import weka.core.RevisionUtils;
29import weka.core.TechnicalInformation;
30import weka.core.TechnicalInformation.Type;
31import weka.core.TechnicalInformation.Field;
32import weka.core.TechnicalInformationHandler;
33import weka.core.Utils;
34
35import java.util.Enumeration;
36import java.util.Random;
37import java.util.Vector;
38
39/**
40 <!-- globalinfo-start -->
41 * This Bayes Network learning algorithm uses the general purpose search method of simulated annealing to find a well scoring network structure.<br/>
42 * <br/>
43 * For more information see:<br/>
44 * <br/>
45 * R.R. Bouckaert (1995). Bayesian Belief Networks: from Construction to Inference. Utrecht, Netherlands.
46 * <p/>
47 <!-- globalinfo-end -->
48 *
49 <!-- technical-bibtex-start -->
50 * BibTeX:
51 * <pre>
52 * &#64;phdthesis{Bouckaert1995,
53 *    address = {Utrecht, Netherlands},
54 *    author = {R.R. Bouckaert},
55 *    institution = {University of Utrecht},
56 *    title = {Bayesian Belief Networks: from Construction to Inference},
57 *    year = {1995}
58 * }
59 * </pre>
60 * <p/>
61 <!-- technical-bibtex-end -->
62 *
63 <!-- options-start -->
64 * Valid options are: <p/>
65 *
66 * <pre> -A &lt;float&gt;
67 *  Start temperature</pre>
68 *
69 * <pre> -U &lt;integer&gt;
70 *  Number of runs</pre>
71 *
72 * <pre> -D &lt;float&gt;
73 *  Delta temperature</pre>
74 *
75 * <pre> -R &lt;seed&gt;
76 *  Random number seed</pre>
77 *
78 * <pre> -mbc
79 *  Applies a Markov Blanket correction to the network structure,
80 *  after a network structure is learned. This ensures that all
81 *  nodes in the network are part of the Markov blanket of the
82 *  classifier node.</pre>
83 *
84 * <pre> -S [LOO-CV|k-Fold-CV|Cumulative-CV]
85 *  Score type (LOO-CV,k-Fold-CV,Cumulative-CV)</pre>
86 *
87 * <pre> -Q
88 *  Use probabilistic or 0/1 scoring.
89 *  (default probabilistic scoring)</pre>
90 *
91 <!-- options-end -->
92 *
93 * @author Remco Bouckaert (rrb@xm.co.nz)
94 * @version $Revision: 1.6 $
95 */
96public class SimulatedAnnealing 
97        extends GlobalScoreSearchAlgorithm
98        implements TechnicalInformationHandler {
99
100        /** for serialization */
101        static final long serialVersionUID = -5482721887881010916L;
102
103        /** start temperature **/
104        double m_fTStart = 10;
105
106        /** change in temperature at every run **/
107        double m_fDelta = 0.999;
108
109        /** number of runs **/
110        int m_nRuns = 10000;
111
112        /** use the arc reversal operator **/
113        boolean m_bUseArcReversal = false;
114
115        /** random number seed **/
116        int m_nSeed = 1;
117
118        /** random number generator **/
119        Random m_random;
120
121        /**
122         * Returns an instance of a TechnicalInformation object, containing
123         * detailed information about the technical background of this class,
124         * e.g., paper reference or book this class is based on.
125         *
126         * @return the technical information about this class
127         */
128        public TechnicalInformation getTechnicalInformation() {
129          TechnicalInformation  result;
130         
131          result = new TechnicalInformation(Type.PHDTHESIS);
132          result.setValue(Field.AUTHOR, "R.R. Bouckaert");
133          result.setValue(Field.YEAR, "1995");
134          result.setValue(Field.TITLE, "Bayesian Belief Networks: from Construction to Inference");
135          result.setValue(Field.INSTITUTION, "University of Utrecht");
136          result.setValue(Field.ADDRESS, "Utrecht, Netherlands");
137         
138          return result;
139        }
140       
141    /**
142     *
143     * @param bayesNet the bayes net to use
144     * @param instances the data to use
145     * @throws Exception if something goes wrong
146     */
147    public void search (BayesNet bayesNet, Instances instances) throws Exception {
148                m_random = new Random(m_nSeed);
149               
150        // determine base scores
151                double fCurrentScore = calcScore(bayesNet);
152
153                // keep track of best scoring network
154                double fBestScore = fCurrentScore;
155                BayesNet bestBayesNet = new BayesNet();
156                bestBayesNet.m_Instances = instances;
157                bestBayesNet.initStructure();
158                copyParentSets(bestBayesNet, bayesNet);
159
160        double fTemp = m_fTStart;
161        for (int iRun = 0; iRun < m_nRuns; iRun++) {
162            boolean bRunSucces = false;
163            double fDeltaScore = 0.0;
164            while (!bRunSucces) {
165                    // pick two nodes at random
166                    int iTailNode = Math.abs(m_random.nextInt()) % instances.numAttributes();
167                    int iHeadNode = Math.abs(m_random.nextInt()) % instances.numAttributes();
168                    while (iTailNode == iHeadNode) {
169                            iHeadNode = Math.abs(m_random.nextInt()) % instances.numAttributes();
170                    }
171                    if (isArc(bayesNet, iHeadNode, iTailNode)) {
172                    bRunSucces = true;
173                        // either try a delete
174                    bayesNet.getParentSet(iHeadNode).deleteParent(iTailNode, instances);
175                    double fScore = calcScore(bayesNet);
176                    fDeltaScore = fScore - fCurrentScore;
177//System.out.println("Try delete " + iTailNode + "->" + iHeadNode + " dScore = " + fDeltaScore);                   
178                    if (fTemp * Math.log((Math.abs(m_random.nextInt()) % 10000)/10000.0  + 1e-100) < fDeltaScore) {
179//System.out.println("success!!!");                   
180                                                fCurrentScore = fScore;
181                    } else {
182                        // roll back
183                        bayesNet.getParentSet(iHeadNode).addParent(iTailNode, instances);
184                    }
185                    } else {
186                        // try to add an arc
187                        if (addArcMakesSense(bayesNet, instances, iHeadNode, iTailNode)) {
188                        bRunSucces = true;
189                        double fScore = calcScoreWithExtraParent(iHeadNode, iTailNode);
190                        fDeltaScore = fScore - fCurrentScore;
191//System.out.println("Try add " + iTailNode + "->" + iHeadNode + " dScore = " + fDeltaScore);                   
192                        if (fTemp * Math.log((Math.abs(m_random.nextInt()) % 10000)/10000.0  + 1e-100) < fDeltaScore) {
193//System.out.println("success!!!");                   
194                            bayesNet.getParentSet(iHeadNode).addParent(iTailNode, instances);
195                                                        fCurrentScore = fScore;
196                        }
197                        }
198                    }
199            }
200                        if (fCurrentScore > fBestScore) {
201                                copyParentSets(bestBayesNet, bayesNet);                         
202                        }
203            fTemp = fTemp * m_fDelta;
204        }
205
206                copyParentSets(bayesNet, bestBayesNet);
207    } // buildStructure
208       
209        /** CopyParentSets copies parent sets of source to dest BayesNet
210         * @param dest destination network
211         * @param source source network
212         */
213        void copyParentSets(BayesNet dest, BayesNet source) {
214                int nNodes = source.getNrOfNodes();
215                // clear parent set first
216                for (int iNode = 0; iNode < nNodes; iNode++) {
217                        dest.getParentSet(iNode).copy(source.getParentSet(iNode));
218                }               
219        } // CopyParentSets
220
221    /**
222     * @return double
223     */
224    public double getDelta() {
225        return m_fDelta;
226    }
227
228    /**
229     * @return double
230     */
231    public double getTStart() {
232        return m_fTStart;
233    }
234
235    /**
236     * @return int
237     */
238    public int getRuns() {
239        return m_nRuns;
240    }
241
242    /**
243     * Sets the m_fDelta.
244     * @param fDelta The m_fDelta to set
245     */
246    public void setDelta(double fDelta) {
247        m_fDelta = fDelta;
248    }
249
250    /**
251     * Sets the m_fTStart.
252     * @param fTStart The m_fTStart to set
253     */
254    public void setTStart(double fTStart) {
255        m_fTStart = fTStart;
256    }
257
258    /**
259     * Sets the m_nRuns.
260     * @param nRuns The m_nRuns to set
261     */
262    public void setRuns(int nRuns) {
263        m_nRuns = nRuns;
264    }
265
266        /**
267        * @return random number seed
268        */
269        public int getSeed() {
270                return m_nSeed;
271        } // getSeed
272
273        /**
274         * Sets the random number seed
275         * @param nSeed The number of the seed to set
276         */
277        public void setSeed(int nSeed) {
278                m_nSeed = nSeed;
279        } // setSeed
280
281        /**
282         * Returns an enumeration describing the available options.
283         *
284         * @return an enumeration of all the available options.
285         */
286        public Enumeration listOptions() {
287                Vector newVector = new Vector(3);
288
289                newVector.addElement(new Option("\tStart temperature", "A", 1, "-A <float>"));
290                newVector.addElement(new Option("\tNumber of runs", "U", 1, "-U <integer>"));
291                newVector.addElement(new Option("\tDelta temperature", "D", 1, "-D <float>"));
292                newVector.addElement(new Option("\tRandom number seed", "R", 1, "-R <seed>"));
293
294                Enumeration enu = super.listOptions();
295                while (enu.hasMoreElements()) {
296                        newVector.addElement(enu.nextElement());
297                }
298                return newVector.elements();
299        }
300
301        /**
302         * Parses a given list of options. <p/>
303         *
304         <!-- options-start -->
305         * Valid options are: <p/>
306         *
307         * <pre> -A &lt;float&gt;
308         *  Start temperature</pre>
309         *
310         * <pre> -U &lt;integer&gt;
311         *  Number of runs</pre>
312         *
313         * <pre> -D &lt;float&gt;
314         *  Delta temperature</pre>
315         *
316         * <pre> -R &lt;seed&gt;
317         *  Random number seed</pre>
318         *
319         * <pre> -mbc
320         *  Applies a Markov Blanket correction to the network structure,
321         *  after a network structure is learned. This ensures that all
322         *  nodes in the network are part of the Markov blanket of the
323         *  classifier node.</pre>
324         *
325         * <pre> -S [LOO-CV|k-Fold-CV|Cumulative-CV]
326         *  Score type (LOO-CV,k-Fold-CV,Cumulative-CV)</pre>
327         *
328         * <pre> -Q
329         *  Use probabilistic or 0/1 scoring.
330         *  (default probabilistic scoring)</pre>
331         *
332         <!-- options-end -->
333         *
334         * @param options the list of options as an array of strings
335         * @throws Exception if an option is not supported
336         */
337        public void setOptions(String[] options) throws Exception {
338                String sTStart = Utils.getOption('A', options);
339                if (sTStart.length() != 0) {
340                        setTStart(Double.parseDouble(sTStart));
341                }
342                String sRuns = Utils.getOption('U', options);
343                if (sRuns.length() != 0) {
344                        setRuns(Integer.parseInt(sRuns));
345                }
346                String sDelta = Utils.getOption('D', options);
347                if (sDelta.length() != 0) {
348                        setDelta(Double.parseDouble(sDelta));
349                }
350                String sSeed = Utils.getOption('R', options);
351                if (sSeed.length() != 0) {
352                        setSeed(Integer.parseInt(sSeed));
353                }
354                super.setOptions(options);
355        }
356
357        /**
358         * Gets the current settings of the search algorithm.
359         *
360         * @return an array of strings suitable for passing to setOptions
361         */
362        public String[] getOptions() {
363                String[] superOptions = super.getOptions();
364                String[] options = new String[8 + superOptions.length];
365                int current = 0;
366                options[current++] = "-A";
367                options[current++] = "" + getTStart();
368
369                options[current++] = "-U";
370                options[current++] = "" + getRuns();
371
372                options[current++] = "-D";
373                options[current++] = "" + getDelta();
374
375                options[current++] = "-R";
376                options[current++] = "" + getSeed();
377
378                // insert options from parent class
379                for (int iOption = 0; iOption < superOptions.length; iOption++) {
380                        options[current++] = superOptions[iOption];
381                }
382
383                // Fill up rest with empty strings, not nulls!
384                while (current < options.length) {
385                        options[current++] = "";
386                }
387                return options;
388        }
389
390        /**
391         * This will return a string describing the classifier.
392         * @return The string.
393         */
394        public String globalInfo() {
395                return 
396                    "This Bayes Network learning algorithm uses the general purpose search method "
397                  + "of simulated annealing to find a well scoring network structure.\n\n"
398                  + "For more information see:\n\n"
399                  + getTechnicalInformation().toString();
400        } // globalInfo
401       
402        /**
403         * @return a string to describe the TStart option.
404         */
405        public String TStartTipText() {
406          return "Sets the start temperature of the simulated annealing search. "+
407          "The start temperature determines the probability that a step in the 'wrong' direction in the " +
408          "search space is accepted. The higher the temperature, the higher the probability of acceptance.";
409        } // TStartTipText
410
411        /**
412         * @return a string to describe the Runs option.
413         */
414        public String runsTipText() {
415          return "Sets the number of iterations to be performed by the simulated annealing search.";
416        } // runsTipText
417       
418        /**
419         * @return a string to describe the Delta option.
420         */
421        public String deltaTipText() {
422          return "Sets the factor with which the temperature (and thus the acceptance probability of " +
423                "steps in the wrong direction in the search space) is decreased in each iteration.";
424        } // deltaTipText
425
426        /**
427         * @return a string to describe the Seed option.
428         */
429        public String seedTipText() {
430          return "Initialization value for random number generator." +
431          " Setting the seed allows replicability of experiments.";
432        } // seedTipText
433
434        /**
435         * Returns the revision string.
436         *
437         * @return              the revision
438         */
439        public String getRevision() {
440          return RevisionUtils.extract("$Revision: 1.6 $");
441        }
442} // SimulatedAnnealing
Note: See TracBrowser for help on using the repository browser.