source: src/main/java/weka/classifiers/bayes/net/ADNode.java @ 22

Last change on this file since 22 was 4, checked in by gnappo, 14 years ago

Import di weka.

File size: 10.5 KB
Line 
1/*
2 * This program is free software; you can redistribute it and/or modify
3 * it under the terms of the GNU General Public License as published by
4 * the Free Software Foundation; either version 2 of the License, or
5 * (at your option) any later version.
6 *
7 * This program is distributed in the hope that it will be useful,
8 * but WITHOUT ANY WARRANTY; without even the implied warranty of
9 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
10 * GNU General Public License for more details.
11 *
12 * You should have received a copy of the GNU General Public License
13 * along with this program; if not, write to the Free Software
14 * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
15 */
16
17/*
18 * ADNode.java
19 * Copyright (C) 2002 University of Waikato, Hamilton, New Zealand
20 *
21 */
22
23package weka.classifiers.bayes.net;
24
25import weka.core.FastVector;
26import weka.core.Instance;
27import weka.core.Instances;
28import weka.core.RevisionHandler;
29import weka.core.RevisionUtils;
30import weka.core.TechnicalInformation;
31import weka.core.TechnicalInformationHandler;
32import weka.core.TechnicalInformation.Field;
33import weka.core.TechnicalInformation.Type;
34
35import java.io.FileReader;
36import java.io.Serializable;
37
38/**
39 * The ADNode class implements the ADTree datastructure which increases
40 * the speed with which sub-contingency tables can be constructed from
41 * a data set in an Instances object. For details, see: <p/>
42 *
43 <!-- technical-plaintext-start -->
44 * Andrew W. Moore, Mary S. Lee (1998). Cached Sufficient Statistics for Efficient Machine Learning with Large Datasets. Journal of Artificial Intelligence Research. 8:67-91.
45 <!-- technical-plaintext-end -->
46 * <p/>
47 *
48 <!-- technical-bibtex-start -->
49 * BibTeX:
50 * <pre>
51 * &#64;article{Moore1998,
52 *    author = {Andrew W. Moore and Mary S. Lee},
53 *    journal = {Journal of Artificial Intelligence Research},
54 *    pages = {67-91},
55 *    title = {Cached Sufficient Statistics for Efficient Machine Learning with Large Datasets},
56 *    volume = {8},
57 *    year = {1998}
58 * }
59 * </pre>
60 * <p/>
61 <!-- technical-bibtex-end -->
62 *
63 * @author Remco Bouckaert (rrb@xm.co.nz)
64 * @version $Revision: 1.7 $
65 */
66public class ADNode 
67        implements Serializable, TechnicalInformationHandler, RevisionHandler {
68 
69        /** for serialization */
70        static final long serialVersionUID = 397409728366910204L;
71 
72        final static int MIN_RECORD_SIZE = 0;
73       
74        /** list of VaryNode children **/
75        public VaryNode [] m_VaryNodes;
76        /** list of Instance children (either m_Instances or m_VaryNodes is instantiated) **/
77        public Instance [] m_Instances;
78
79        /** count **/
80        public int m_nCount;
81
82        /** first node in VaryNode array **/
83        public int m_nStartNode;
84
85        /** Creates new ADNode */
86        public ADNode() {
87        }
88
89        /**
90         * Returns an instance of a TechnicalInformation object, containing
91         * detailed information about the technical background of this class,
92         * e.g., paper reference or book this class is based on.
93         *
94         * @return the technical information about this class
95         */
96        public TechnicalInformation getTechnicalInformation() {
97          TechnicalInformation  result;
98         
99          result = new TechnicalInformation(Type.ARTICLE);
100          result.setValue(Field.AUTHOR, "Andrew W. Moore and Mary S. Lee");
101          result.setValue(Field.YEAR, "1998");
102          result.setValue(Field.TITLE, "Cached Sufficient Statistics for Efficient Machine Learning with Large Datasets");
103          result.setValue(Field.JOURNAL, "Journal of Artificial Intelligence Research");
104          result.setValue(Field.VOLUME, "8");
105          result.setValue(Field.PAGES, "67-91");
106         
107          return result;
108        }
109
110        /** create sub tree
111         * @param iNode index of the lowest node in the tree
112         * @param nRecords set of records in instances to be considered
113         * @param instances data set
114         * @return VaryNode representing part of an ADTree
115         **/
116        public static VaryNode makeVaryNode(int iNode, FastVector nRecords, Instances instances) {
117                VaryNode _VaryNode = new VaryNode(iNode);
118                int nValues = instances.attribute(iNode).numValues();
119               
120
121                // reserve memory and initialize
122                FastVector [] nChildRecords = new FastVector[nValues];
123                for (int iChild = 0; iChild < nValues; iChild++) {
124                        nChildRecords[iChild] = new FastVector();
125                }
126                // divide the records among children
127                for (int iRecord = 0; iRecord < nRecords.size(); iRecord++) {
128                        int iInstance = ((Integer) nRecords.elementAt(iRecord)).intValue();
129                        nChildRecords[(int) instances.instance(iInstance).value(iNode)].addElement(new Integer(iInstance));
130                }
131
132                // find most common value
133                int nCount = nChildRecords[0].size();
134                int nMCV = 0; 
135                for (int iChild = 1; iChild < nValues; iChild++) {
136                        if (nChildRecords[iChild].size() > nCount) {
137                                nCount = nChildRecords[iChild].size();
138                                nMCV = iChild;
139                        }
140                }
141                _VaryNode.m_nMCV = nMCV;
142
143                // determine child nodes
144                _VaryNode.m_ADNodes = new ADNode[nValues];
145                for (int iChild = 0; iChild < nValues; iChild++) {
146                        if (iChild == nMCV || nChildRecords[iChild].size() == 0) {
147                                _VaryNode.m_ADNodes[iChild] = null;
148                        } else {
149                                _VaryNode.m_ADNodes[iChild] = makeADTree(iNode + 1, nChildRecords[iChild], instances);
150                        }
151                }
152                return _VaryNode;
153        } // MakeVaryNode
154
155        /**
156         * create sub tree
157         *
158         * @param iNode index of the lowest node in the tree
159         * @param nRecords set of records in instances to be considered
160         * @param instances data set
161         * @return ADNode representing an ADTree
162         */
163        public static ADNode makeADTree(int iNode, FastVector nRecords, Instances instances) {
164                ADNode _ADNode = new ADNode();
165                _ADNode.m_nCount = nRecords.size();
166                _ADNode.m_nStartNode = iNode;
167                if (nRecords.size() < MIN_RECORD_SIZE) {
168                  _ADNode.m_Instances = new Instance[nRecords.size()];
169                  for (int iInstance = 0; iInstance < nRecords.size(); iInstance++) {
170                    _ADNode.m_Instances[iInstance] = instances.instance(((Integer) nRecords.elementAt(iInstance)).intValue());
171                  }
172                } else {
173                  _ADNode.m_VaryNodes = new VaryNode[instances.numAttributes() - iNode];
174                  for (int iNode2 = iNode; iNode2 < instances.numAttributes(); iNode2++) {
175                          _ADNode.m_VaryNodes[iNode2 - iNode] = makeVaryNode(iNode2, nRecords, instances);
176                  }
177                }
178                return _ADNode;
179        } // MakeADTree
180
181        /**
182         * create AD tree from set of instances
183         *
184         * @param instances data set
185         * @return ADNode representing an ADTree
186         */
187        public static ADNode makeADTree(Instances instances) {
188          FastVector nRecords = new FastVector(instances.numInstances());
189          for (int iRecord = 0; iRecord < instances.numInstances(); iRecord++) {
190            nRecords.addElement(new Integer(iRecord));
191          }
192          return makeADTree(0, nRecords, instances);
193        } // MakeADTree
194       
195          /**
196           * get counts for specific instantiation of a set of nodes
197           *
198           * @param nCounts - array for storing counts
199           * @param nNodes - array of node indexes
200           * @param nOffsets - offset for nodes in nNodes in nCounts
201           * @param iNode - index into nNode indicating current node
202           * @param iOffset - Offset into nCounts due to nodes below iNode
203           * @param bSubstract - indicate whether counts should be added or substracted
204           */
205        public void getCounts(
206              int [] nCounts, 
207              int [] nNodes, 
208              int [] nOffsets, 
209              int iNode, 
210              int iOffset,
211              boolean bSubstract
212        ) {
213//for (int iNode2 = 0; iNode2 < nCounts.length; iNode2++) {
214//   System.out.print(nCounts[iNode2] + " ");
215//}
216//System.out.println();
217          if (iNode >= nNodes.length) {
218            if (bSubstract) {
219              nCounts[iOffset] -= m_nCount;
220            } else {
221              nCounts[iOffset] += m_nCount;
222            }
223            return;
224          } else {
225            if (m_VaryNodes != null) {
226              m_VaryNodes[nNodes[iNode] - m_nStartNode].getCounts(nCounts, nNodes, nOffsets, iNode, iOffset, this, bSubstract);
227            } else {
228              for (int iInstance = 0; iInstance < m_Instances.length; iInstance++) {
229                int iOffset2 = iOffset;
230                Instance instance = m_Instances[iInstance];
231                for (int iNode2 = iNode; iNode2 < nNodes.length; iNode2++) {
232                  iOffset2 = iOffset2 + nOffsets[iNode2] * (int) instance.value(nNodes[iNode2]);
233                }
234                if (bSubstract) {
235                        nCounts[iOffset2]--;
236                } else {
237                        nCounts[iOffset2]++;
238                }
239              }
240            }
241          }
242        } // getCounts
243
244
245        /**
246         * print is used for debugging only and shows the ADTree in ASCII graphics
247         */
248        public void print() {
249          String sTab = new String();for (int i = 0; i < m_nStartNode; i++) {
250              sTab = sTab + "  ";
251          }
252          System.out.println(sTab + "Count = " + m_nCount);
253          if (m_VaryNodes != null) {
254                  for (int iNode = 0; iNode < m_VaryNodes.length; iNode++) {
255                    System.out.println(sTab + "Node " + (iNode + m_nStartNode));
256                    m_VaryNodes[iNode].print(sTab);
257                  }
258          } else {
259              System.out.println(m_Instances);
260          }
261        }
262       
263        /**
264         * for testing only
265         *
266         * @param argv the commandline options
267         */
268        public static void main(String [] argv) {
269            try {
270                Instances instances = new Instances(new FileReader("\\iris.2.arff"));
271                ADNode ADTree = ADNode.makeADTree(instances);
272                int [] nCounts = new int[12];
273                int [] nNodes = new int[3];
274                int [] nOffsets = new int[3];
275                nNodes[0] = 0;
276                nNodes[1] = 3;
277                nNodes[2] = 4;
278                nOffsets[0] = 2;
279                nOffsets[1] = 1;
280                nOffsets[2] = 4;
281                ADTree.print();
282                ADTree.getCounts(nCounts, nNodes, nOffsets,0, 0, false); 
283               
284            } catch (Throwable t) {
285                t.printStackTrace();
286            }
287        } // main
288       
289        /**
290         * Returns the revision string.
291         *
292         * @return              the revision
293         */
294        public String getRevision() {
295          return RevisionUtils.extract("$Revision: 1.7 $");
296        }
297} // class ADNode
Note: See TracBrowser for help on using the repository browser.