source: src/main/java/weka/associations/PriorEstimation.java @ 20

Last change on this file since 20 was 4, checked in by gnappo, 14 years ago

Import di weka.

File size: 18.1 KB
Line 
1/*
2 *    This program is free software; you can redistribute it and/or modify
3 *    it under the terms of the GNU General Public License as published by
4 *    the Free Software Foundation; either version 2 of the License, or
5 *    (at your option) any later version.
6 *
7 *    This program is distributed in the hope that it will be useful,
8 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
9 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
10 *    GNU General Public License for more details.
11 *
12 *    You should have received a copy of the GNU General Public License
13 *    along with this program; if not, write to the Free Software
14 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
15 */
16
17/*
18 * PriorEstimation.java
19 * Copyright (C) 2004 University of Waikato, Hamilton, New Zealand
20 *
21 */
22
23package weka.associations;
24
25import weka.core.Instances;
26import weka.core.RevisionHandler;
27import weka.core.RevisionUtils;
28import weka.core.SpecialFunctions;
29import weka.core.Utils;
30
31import java.io.Serializable;
32import java.util.Hashtable;
33import java.util.Random;
34
35/**
36 * Class implementing the prior estimattion of the predictive apriori algorithm
37 * for mining association rules.
38 *
39 * Reference: T. Scheffer (2001). <i>Finding Association Rules That Trade Support
40 * Optimally against Confidence</i>. Proc of the 5th European Conf.
41 * on Principles and Practice of Knowledge Discovery in Databases (PKDD'01),
42 * pp. 424-435. Freiburg, Germany: Springer-Verlag. <p>
43 *
44 * @author Stefan Mutter (mutter@cs.waikato.ac.nz)
45 * @version $Revision: 1.7 $ */
46
47 public class PriorEstimation
48   implements Serializable, RevisionHandler {
49   
50    /** for serialization */
51    private static final long serialVersionUID = 5570863216522496271L;
52
53    /** The number of rnadom rules. */
54    protected int m_numRandRules;
55   
56    /** The number of intervals. */
57    protected int m_numIntervals;
58   
59    /** The random seed used for the random rule generation step. */
60    protected static final int SEED = 0;
61   
62    /** The maximum number of attributes for which a prior can be estimated. */
63    protected static final int MAX_N = 1024;
64   
65    /** The random number generator. */
66    protected Random m_randNum;
67   
68    /** The instances for which association rules are mined. */
69    protected Instances m_instances;
70   
71    /** Flag indicating whether standard association rules or class association rules are mined. */
72    protected boolean m_CARs;
73   
74    /** Hashtable to store the confidence values of randomly generated rules. */   
75    protected Hashtable m_distribution;
76   
77    /** Hashtable containing the estimated prior probabilities. */
78    protected  Hashtable m_priors;
79   
80    /** Sums up the confidences of all rules with a certain length. */
81    protected double m_sum;
82   
83    /** The mid points of the discrete intervals in which the interval [0,1] is divided. */
84    protected double[] m_midPoints;
85   
86   
87   
88   /**
89   * Constructor
90   *
91   * @param instances the instances to be used for generating the associations
92   * @param numRules the number of random rules used for generating the prior
93   * @param numIntervals the number of intervals to discretise [0,1]
94   * @param car flag indicating whether standard or class association rules are mined
95   */
96    public PriorEstimation(Instances instances,int numRules,int numIntervals,boolean car) {
97       
98       m_instances = instances;
99       m_CARs = car;
100       m_numRandRules = numRules;
101       m_numIntervals = numIntervals;
102       m_randNum = m_instances.getRandomNumberGenerator(SEED);
103    }
104    /**
105   * Calculates the prior distribution.
106   *
107   * @exception Exception if prior can't be estimated successfully
108   */
109    public final void generateDistribution() throws Exception{
110       
111        boolean jump;
112        int i,maxLength = m_instances.numAttributes(), count =0,count1=0, ruleCounter;
113        int [] itemArray;
114        m_distribution = new Hashtable(maxLength*m_numIntervals);
115        RuleItem current;
116        ItemSet generate;
117       
118        if(m_instances.numAttributes() == 0)
119            throw new Exception("Dataset has no attributes!");
120        if(m_instances.numAttributes() >= MAX_N)
121            throw new Exception("Dataset has to many attributes for prior estimation!");
122        if(m_instances.numInstances() == 0)
123            throw new Exception("Dataset has no instances!");
124        for (int h = 0; h < maxLength; h++) {
125            if (m_instances.attribute(h).isNumeric())
126                throw new Exception("Can't handle numeric attributes!");
127        } 
128        if(m_numIntervals  == 0 || m_numRandRules == 0)
129            throw new Exception("Prior initialisation impossible");
130       
131        //calculate mid points for the intervals
132        midPoints();
133       
134        //create random rules of length i and measure their support and if support >0 their confidence
135        for(i = 1;i <= maxLength; i++){
136            m_sum = 0;
137            int j = 0;
138            count = 0;
139            count1 = 0;
140            while(j < m_numRandRules){
141                count++;
142                jump =false;
143                if(!m_CARs){
144                    itemArray = randomRule(maxLength,i,m_randNum);
145                    current = splitItemSet(m_randNum.nextInt(i), itemArray);
146                }
147                else{
148                    itemArray = randomCARule(maxLength,i,m_randNum);
149                    current = addCons(itemArray);
150                }
151                int [] ruleItem = new int[maxLength];
152                for(int k =0; k < itemArray.length;k++){
153                    if(current.m_premise.m_items[k] != -1)
154                        ruleItem[k] = current.m_premise.m_items[k];
155                    else
156                        if(current.m_consequence.m_items[k] != -1)
157                            ruleItem[k] = current.m_consequence.m_items[k];
158                        else
159                            ruleItem[k] = -1;
160                }
161                ItemSet rule = new ItemSet(ruleItem);
162                updateCounters(rule);
163                ruleCounter = rule.m_counter;
164                if(ruleCounter > 0)
165                    jump =true;
166                updateCounters(current.m_premise);
167                j++;
168                if(jump){
169                    buildDistribution((double)ruleCounter/(double)current.m_premise.m_counter, (double)i);
170                }
171             }
172           
173            //normalize
174            if(m_sum > 0){
175                for(int w = 0; w < m_midPoints.length;w++){
176                    String key = (String.valueOf(m_midPoints[w])).concat(String.valueOf((double)i));
177                    Double oldValue = (Double)m_distribution.remove(key);
178                    if(oldValue == null){
179                        m_distribution.put(key,new Double(1.0/m_numIntervals));
180                        m_sum += 1.0/m_numIntervals;
181                    }
182                    else
183                        m_distribution.put(key,oldValue);
184                }
185                for(int w = 0; w < m_midPoints.length;w++){
186                    double conf =0;
187                    String key = (String.valueOf(m_midPoints[w])).concat(String.valueOf((double)i));
188                    Double oldValue = (Double)m_distribution.remove(key);
189                    if(oldValue != null){
190                        conf = oldValue.doubleValue() / m_sum;
191                        m_distribution.put(key,new Double(conf));
192                    }
193                }
194            }
195            else{
196                for(int w = 0; w < m_midPoints.length;w++){
197                    String key = (String.valueOf(m_midPoints[w])).concat(String.valueOf((double)i));
198                    m_distribution.put(key,new Double(1.0/m_numIntervals));
199                }
200            }
201        }
202       
203    }
204   
205    /**
206     * Constructs an item set of certain length randomly.
207     * This method is used for standard association rule mining.
208     * @param maxLength the number of attributes of the instances
209     * @param actualLength the number of attributes that should be present in the item set
210     * @param randNum the random number generator
211     * @return a randomly constructed item set in form of an int array
212     */
213    public final int[] randomRule(int maxLength, int actualLength, Random randNum){
214     
215        int[] itemArray = new int[maxLength];
216        for(int k =0;k < itemArray.length;k++)
217            itemArray[k] = -1;
218        int help =actualLength;
219        if(help == maxLength){
220            help = 0;
221            for(int h = 0; h < itemArray.length; h++){
222                itemArray[h] = m_randNum.nextInt((m_instances.attribute(h)).numValues());
223            }
224        }
225        while(help > 0){
226            int mark = randNum.nextInt(maxLength);
227            if(itemArray[mark] == -1){
228                help--;
229                itemArray[mark] = m_randNum.nextInt((m_instances.attribute(mark)).numValues());
230            }
231       }
232        return itemArray;
233    }
234   
235   
236    /**
237     * Constructs an item set of certain length randomly.
238     * This method is used for class association rule mining.
239     * @param maxLength the number of attributes of the instances
240     * @param actualLength the number of attributes that should be present in the item set
241     * @param randNum the random number generator
242     * @return a randomly constructed item set in form of an int array
243     */
244     public final int[] randomCARule(int maxLength, int actualLength, Random randNum){
245     
246        int[] itemArray = new int[maxLength];
247        for(int k =0;k < itemArray.length;k++)
248            itemArray[k] = -1;
249        if(actualLength == 1)
250            return itemArray;
251        int help =actualLength-1;
252        if(help == maxLength-1){
253            help = 0;
254            for(int h = 0; h < itemArray.length; h++){
255                if(h != m_instances.classIndex()){
256                    itemArray[h] = m_randNum.nextInt((m_instances.attribute(h)).numValues());
257                }
258            }
259        }
260        while(help > 0){
261            int mark = randNum.nextInt(maxLength);
262            if(itemArray[mark] == -1 && mark != m_instances.classIndex()){
263                help--;
264                itemArray[mark] = m_randNum.nextInt((m_instances.attribute(mark)).numValues());
265            }
266       }
267        return itemArray;
268    }
269   
270     /**
271      * updates the distribution of the confidence values.
272      * For every confidence value the interval to which it belongs is searched
273      * and the confidence is added to the confidence already found in this
274      * interval.
275      * @param conf the confidence of the randomly created rule
276      * @param length the legnth of the randomly created rule
277      */     
278    public final void buildDistribution(double conf, double length){
279     
280        double mPoint = findIntervall(conf);
281        String key = (String.valueOf(mPoint)).concat(String.valueOf(length));
282        m_sum += conf;
283        Double oldValue = (Double)m_distribution.remove(key);
284        if(oldValue != null)
285            conf = conf + oldValue.doubleValue();
286        m_distribution.put(key,new Double(conf));
287       
288    }
289   
290    /**
291     * searches the mid point of the interval a given confidence value falls into
292     * @param conf the confidence of a rule
293     * @return the mid point of the interval the confidence belongs to
294     */   
295     public final double findIntervall(double conf){
296       
297        if(conf == 1.0)
298            return m_midPoints[m_midPoints.length-1];
299        int end   = m_midPoints.length-1;
300        int start = 0;
301        while (Math.abs(end-start) > 1) {
302            int mid = (start + end) / 2;
303            if (conf > m_midPoints[mid])
304                start = mid+1;
305            if (conf < m_midPoints[mid]) 
306                end = mid-1;
307            if(conf == m_midPoints[mid])
308                return m_midPoints[mid];
309        }
310        if(Math.abs(conf-m_midPoints[start]) <=  Math.abs(conf-m_midPoints[end]))
311            return m_midPoints[start];
312        else
313            return m_midPoints[end];
314    }
315   
316   
317     /**
318      * calculates the numerator and the denominator of the prior equation
319      * @param weighted indicates whether the numerator or the denominator is calculated
320      * @param mPoint the mid Point of an interval
321      * @return the numerator or denominator of the prior equation
322      */     
323    public final double calculatePriorSum(boolean weighted, double mPoint){
324 
325      double distr, sum =0, max = logbinomialCoefficient(m_instances.numAttributes(),(int)m_instances.numAttributes()/2);
326     
327     
328      for(int i = 1; i <= m_instances.numAttributes(); i++){
329             
330          if(weighted){
331            String key = (String.valueOf(mPoint)).concat(String.valueOf((double)i));
332            Double hashValue = (Double)m_distribution.get(key);
333           
334            if(hashValue !=null)
335                distr = hashValue.doubleValue();
336            else
337                distr = 0;
338                //distr = 1.0/m_numIntervals;
339            if(distr != 0){
340              double addend = Utils.log2(distr) - max + Utils.log2((Math.pow(2,i)-1)) + logbinomialCoefficient(m_instances.numAttributes(),i);
341              sum = sum + Math.pow(2,addend);
342            }
343          }
344          else{
345              double addend = Utils.log2((Math.pow(2,i)-1)) - max + logbinomialCoefficient(m_instances.numAttributes(),i);
346              sum = sum + Math.pow(2,addend);
347          }
348      }
349      return sum;
350  }
351    /**
352     * Method that calculates the base 2 logarithm of a binomial coefficient
353     * @param upperIndex upper Inedx of the binomial coefficient
354     * @param lowerIndex lower index of the binomial coefficient
355     * @return the base 2 logarithm of the binomial coefficient
356     */   
357   public static final double logbinomialCoefficient(int upperIndex, int lowerIndex){
358   
359     double result =1.0;
360     if(upperIndex == lowerIndex || lowerIndex == 0)
361         return result;
362     result = SpecialFunctions.log2Binomial((double)upperIndex, (double)lowerIndex);
363     return result;
364   }
365   
366   /**
367    * Method to estimate the prior probabilities
368    * @throws Exception throws exception if the prior cannot be calculated
369    * @return a hashtable containing the prior probabilities
370    */   
371   public final Hashtable estimatePrior() throws Exception{
372   
373       double distr, prior, denominator, mPoint;
374       
375       Hashtable m_priors = new Hashtable(m_numIntervals);
376       denominator = calculatePriorSum(false,1.0);
377       generateDistribution();
378       for(int i = 0; i < m_numIntervals; i++){ 
379            mPoint = m_midPoints[i];
380            prior = calculatePriorSum(true,mPoint) / denominator;
381            m_priors.put(new Double(mPoint), new Double(prior));
382       }
383       return m_priors;
384   } 
385   
386   /**
387    * split the interval [0,1] into a predefined number of intervals and calculates their mid points
388    */   
389   public final void midPoints(){
390       
391        m_midPoints = new double[m_numIntervals];
392        for(int i = 0; i < m_numIntervals; i++)
393            m_midPoints[i] = midPoint(1.0/m_numIntervals, i);
394   }
395     
396   /**
397    * calculates the mid point of an interval
398    * @param size the size of each interval
399    * @param number the number of the interval.
400    * The intervals are numbered from 0 to m_numIntervals.
401    * @return the mid point of the interval
402    */   
403   public double midPoint(double size, int number){
404   
405       return (size * (double)number) + (size / 2.0);
406   }
407   
408   /**
409    * returns an ordered array of all mid points
410    * @return an ordered array of doubles conatining all midpoints
411    */   
412   public final double[] getMidPoints(){
413   
414       return m_midPoints;
415   }
416   
417   
418   /**
419    * splits an item set into premise and consequence and constructs therefore
420    * an association rule. The length of the premise is given. The attributes
421    * for premise and consequence are chosen randomly. The result is a RuleItem.
422    * @param premiseLength the length of the premise
423    * @param itemArray a (randomly generated) item set
424    * @return a randomly generated association rule stored in a RuleItem
425    */   
426    public final RuleItem splitItemSet (int premiseLength, int[] itemArray){
427       
428       int[] cons = new int[m_instances.numAttributes()];
429       System.arraycopy(itemArray, 0, cons, 0, itemArray.length);
430       int help = premiseLength;
431       while(help > 0){
432            int mark = m_randNum.nextInt(itemArray.length);
433            if(cons[mark] != -1){
434                help--;
435                cons[mark] =-1;
436            }
437       }
438       if(premiseLength == 0)
439            for(int i =0; i < itemArray.length;i++)
440                itemArray[i] = -1;
441       else
442           for(int i =0; i < itemArray.length;i++)
443               if(cons[i] != -1)
444                    itemArray[i] = -1;
445       ItemSet premise = new ItemSet(itemArray);
446       ItemSet consequence = new ItemSet(cons);
447       RuleItem current = new RuleItem();
448       current.m_premise = premise;
449       current.m_consequence = consequence;
450       return current;
451    }
452
453    /**
454     * generates a class association rule out of a given premise.
455     * It randomly chooses a class label as consequence.
456     * @param itemArray the (randomly constructed) premise of the class association rule
457     * @return a class association rule stored in a RuleItem
458     */   
459    public final RuleItem addCons (int[] itemArray){
460       
461        ItemSet premise = new ItemSet(itemArray);
462        int[] cons = new int[itemArray.length];
463        for(int i =0;i < itemArray.length;i++)
464            cons[i] = -1;
465        cons[m_instances.classIndex()] = m_randNum.nextInt((m_instances.attribute(m_instances.classIndex())).numValues());
466        ItemSet consequence = new ItemSet(cons);
467        RuleItem current = new RuleItem();
468        current.m_premise = premise;
469        current.m_consequence = consequence;
470        return current;
471    }
472   
473    /**
474     * updates the support count of an item set
475     * @param itemSet the item set
476     */   
477    public final void updateCounters(ItemSet itemSet){
478       
479        for (int i = 0; i < m_instances.numInstances(); i++) 
480            itemSet.upDateCounter(m_instances.instance(i));
481    }
482   
483    /**
484     * Returns the revision string.
485     *
486     * @return          the revision
487     */
488    public String getRevision() {
489      return RevisionUtils.extract("$Revision: 1.7 $");
490    }
491}
Note: See TracBrowser for help on using the repository browser.