/*
 *    This program is free software; you can redistribute it and/or modify
 *    it under the terms of the GNU General Public License as published by
 *    the Free Software Foundation; either version 2 of the License, or
 *    (at your option) any later version.
 *
 *    This program is distributed in the hope that it will be useful,
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *    GNU General Public License for more details.
 *
 *    You should have received a copy of the GNU General Public License
 *    along with this program; if not, write to the Free Software
 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 */

/*
 * PriorEstimation.java
 * Copyright (C) 2004 University of Waikato, Hamilton, New Zealand
 *
 */

package weka.associations;

import weka.core.Instances;
import weka.core.RevisionHandler;
import weka.core.RevisionUtils;
import weka.core.SpecialFunctions;
import weka.core.Utils;

import java.io.Serializable;
import java.util.Hashtable;
import java.util.Random;

/**
 * Class implementing the prior estimattion of the predictive apriori algorithm 
 * for mining association rules. 
 *
 * Reference: T. Scheffer (2001). <i>Finding Association Rules That Trade Support 
 * Optimally against Confidence</i>. Proc of the 5th European Conf.
 * on Principles and Practice of Knowledge Discovery in Databases (PKDD'01),
 * pp. 424-435. Freiburg, Germany: Springer-Verlag. <p>
 *
 * @author Stefan Mutter (mutter@cs.waikato.ac.nz)
 * @version $Revision: 1.7 $ */

 public class PriorEstimation
   implements Serializable, RevisionHandler {
    
    /** for serialization */
    private static final long serialVersionUID = 5570863216522496271L;

    /** The number of rnadom rules. */
    protected int m_numRandRules;
    
    /** The number of intervals. */
    protected int m_numIntervals;
    
    /** The random seed used for the random rule generation step. */
    protected static final int SEED = 0;
    
    /** The maximum number of attributes for which a prior can be estimated. */
    protected static final int MAX_N = 1024;
    
    /** The random number generator. */
    protected Random m_randNum;
    
    /** The instances for which association rules are mined. */
    protected Instances m_instances;
    
    /** Flag indicating whether standard association rules or class association rules are mined. */
    protected boolean m_CARs;
    
    /** Hashtable to store the confidence values of randomly generated rules. */    
    protected Hashtable m_distribution;
    
    /** Hashtable containing the estimated prior probabilities. */
    protected  Hashtable m_priors;
    
    /** Sums up the confidences of all rules with a certain length. */
    protected double m_sum;
    
    /** The mid points of the discrete intervals in which the interval [0,1] is divided. */
    protected double[] m_midPoints;
    
    
    
   /**
   * Constructor 
   *
   * @param instances the instances to be used for generating the associations
   * @param numRules the number of random rules used for generating the prior
   * @param numIntervals the number of intervals to discretise [0,1]
   * @param car flag indicating whether standard or class association rules are mined
   */
    public PriorEstimation(Instances instances,int numRules,int numIntervals,boolean car) {
        
       m_instances = instances;
       m_CARs = car;
       m_numRandRules = numRules;
       m_numIntervals = numIntervals;
       m_randNum = m_instances.getRandomNumberGenerator(SEED);
    }
    /**
   * Calculates the prior distribution.
   *
   * @exception Exception if prior can't be estimated successfully
   */
    public final void generateDistribution() throws Exception{
        
        boolean jump;
        int i,maxLength = m_instances.numAttributes(), count =0,count1=0, ruleCounter;
        int [] itemArray;
        m_distribution = new Hashtable(maxLength*m_numIntervals);
        RuleItem current;
        ItemSet generate;
        
        if(m_instances.numAttributes() == 0)
            throw new Exception("Dataset has no attributes!");
        if(m_instances.numAttributes() >= MAX_N)
            throw new Exception("Dataset has to many attributes for prior estimation!");
        if(m_instances.numInstances() == 0)
            throw new Exception("Dataset has no instances!");
        for (int h = 0; h < maxLength; h++) {
            if (m_instances.attribute(h).isNumeric())
                throw new Exception("Can't handle numeric attributes!");
        } 
        if(m_numIntervals  == 0 || m_numRandRules == 0)
            throw new Exception("Prior initialisation impossible");
       
        //calculate mid points for the intervals
        midPoints();
        
        //create random rules of length i and measure their support and if support >0 their confidence
        for(i = 1;i <= maxLength; i++){
            m_sum = 0;
            int j = 0;
            count = 0;
            count1 = 0;
            while(j < m_numRandRules){
                count++;
                jump =false;
                if(!m_CARs){
                    itemArray = randomRule(maxLength,i,m_randNum);
                    current = splitItemSet(m_randNum.nextInt(i), itemArray);
                }
                else{
                    itemArray = randomCARule(maxLength,i,m_randNum);
                    current = addCons(itemArray);
                }
                int [] ruleItem = new int[maxLength];
                for(int k =0; k < itemArray.length;k++){
                    if(current.m_premise.m_items[k] != -1)
                        ruleItem[k] = current.m_premise.m_items[k];
                    else
                        if(current.m_consequence.m_items[k] != -1)
                            ruleItem[k] = current.m_consequence.m_items[k];
                        else
                            ruleItem[k] = -1;
                }
                ItemSet rule = new ItemSet(ruleItem);
                updateCounters(rule);
                ruleCounter = rule.m_counter;
                if(ruleCounter > 0)
                    jump =true;
                updateCounters(current.m_premise);
                j++;
                if(jump){
                    buildDistribution((double)ruleCounter/(double)current.m_premise.m_counter, (double)i);
                }
             }
            
            //normalize
            if(m_sum > 0){
                for(int w = 0; w < m_midPoints.length;w++){
                    String key = (String.valueOf(m_midPoints[w])).concat(String.valueOf((double)i));
                    Double oldValue = (Double)m_distribution.remove(key);
                    if(oldValue == null){
                        m_distribution.put(key,new Double(1.0/m_numIntervals));
                        m_sum += 1.0/m_numIntervals;
                    }
                    else
                        m_distribution.put(key,oldValue);
                }
                for(int w = 0; w < m_midPoints.length;w++){
                    double conf =0;
                    String key = (String.valueOf(m_midPoints[w])).concat(String.valueOf((double)i));
                    Double oldValue = (Double)m_distribution.remove(key);
                    if(oldValue != null){
                        conf = oldValue.doubleValue() / m_sum;
                        m_distribution.put(key,new Double(conf));
                    }
                }
            }
            else{
                for(int w = 0; w < m_midPoints.length;w++){
                    String key = (String.valueOf(m_midPoints[w])).concat(String.valueOf((double)i));
                    m_distribution.put(key,new Double(1.0/m_numIntervals));
                }
            }
        }
        
    }
    
    /**
     * Constructs an item set of certain length randomly.
     * This method is used for standard association rule mining.
     * @param maxLength the number of attributes of the instances
     * @param actualLength the number of attributes that should be present in the item set
     * @param randNum the random number generator
     * @return a randomly constructed item set in form of an int array
     */
    public final int[] randomRule(int maxLength, int actualLength, Random randNum){
     
        int[] itemArray = new int[maxLength];
        for(int k =0;k < itemArray.length;k++)
            itemArray[k] = -1;
        int help =actualLength;
        if(help == maxLength){
            help = 0;
            for(int h = 0; h < itemArray.length; h++){
                itemArray[h] = m_randNum.nextInt((m_instances.attribute(h)).numValues());
            }
        }
        while(help > 0){
            int mark = randNum.nextInt(maxLength);
            if(itemArray[mark] == -1){
                help--;
                itemArray[mark] = m_randNum.nextInt((m_instances.attribute(mark)).numValues());
            }
       }
        return itemArray;
    }
    
    
    /**
     * Constructs an item set of certain length randomly.
     * This method is used for class association rule mining.
     * @param maxLength the number of attributes of the instances
     * @param actualLength the number of attributes that should be present in the item set
     * @param randNum the random number generator
     * @return a randomly constructed item set in form of an int array
     */
     public final int[] randomCARule(int maxLength, int actualLength, Random randNum){
     
        int[] itemArray = new int[maxLength];
        for(int k =0;k < itemArray.length;k++)
            itemArray[k] = -1;
        if(actualLength == 1)
            return itemArray;
        int help =actualLength-1;
        if(help == maxLength-1){
            help = 0;
            for(int h = 0; h < itemArray.length; h++){
                if(h != m_instances.classIndex()){
                    itemArray[h] = m_randNum.nextInt((m_instances.attribute(h)).numValues());
                }
            }
        }
        while(help > 0){
            int mark = randNum.nextInt(maxLength);
            if(itemArray[mark] == -1 && mark != m_instances.classIndex()){
                help--;
                itemArray[mark] = m_randNum.nextInt((m_instances.attribute(mark)).numValues());
            }
       }
        return itemArray;
    }
   
     /**
      * updates the distribution of the confidence values.
      * For every confidence value the interval to which it belongs is searched
      * and the confidence is added to the confidence already found in this
      * interval.
      * @param conf the confidence of the randomly created rule
      * @param length the legnth of the randomly created rule
      */     
    public final void buildDistribution(double conf, double length){
     
        double mPoint = findIntervall(conf);
        String key = (String.valueOf(mPoint)).concat(String.valueOf(length));
        m_sum += conf;
        Double oldValue = (Double)m_distribution.remove(key);
        if(oldValue != null)
            conf = conf + oldValue.doubleValue();
        m_distribution.put(key,new Double(conf));
        
    }
    
    /**
     * searches the mid point of the interval a given confidence value falls into
     * @param conf the confidence of a rule
     * @return the mid point of the interval the confidence belongs to
     */    
     public final double findIntervall(double conf){
        
        if(conf == 1.0)
            return m_midPoints[m_midPoints.length-1];
        int end   = m_midPoints.length-1;
        int start = 0;
        while (Math.abs(end-start) > 1) {
            int mid = (start + end) / 2;
            if (conf > m_midPoints[mid])
                start = mid+1;
            if (conf < m_midPoints[mid]) 
                end = mid-1;
            if(conf == m_midPoints[mid])
                return m_midPoints[mid];
        }
        if(Math.abs(conf-m_midPoints[start]) <=  Math.abs(conf-m_midPoints[end]))
            return m_midPoints[start];
        else
            return m_midPoints[end];
    }
    
    
     /**
      * calculates the numerator and the denominator of the prior equation
      * @param weighted indicates whether the numerator or the denominator is calculated
      * @param mPoint the mid Point of an interval
      * @return the numerator or denominator of the prior equation
      */     
    public final double calculatePriorSum(boolean weighted, double mPoint){
  
      double distr, sum =0, max = logbinomialCoefficient(m_instances.numAttributes(),(int)m_instances.numAttributes()/2);
      
      
      for(int i = 1; i <= m_instances.numAttributes(); i++){
              
          if(weighted){
            String key = (String.valueOf(mPoint)).concat(String.valueOf((double)i));
            Double hashValue = (Double)m_distribution.get(key);
            
            if(hashValue !=null)
                distr = hashValue.doubleValue();
            else
                distr = 0;
                //distr = 1.0/m_numIntervals;
            if(distr != 0){
              double addend = Utils.log2(distr) - max + Utils.log2((Math.pow(2,i)-1)) + logbinomialCoefficient(m_instances.numAttributes(),i);
              sum = sum + Math.pow(2,addend);
            }
          }
          else{
              double addend = Utils.log2((Math.pow(2,i)-1)) - max + logbinomialCoefficient(m_instances.numAttributes(),i);
              sum = sum + Math.pow(2,addend);
          }
      }
      return sum;
  }
    /**
     * Method that calculates the base 2 logarithm of a binomial coefficient
     * @param upperIndex upper Inedx of the binomial coefficient
     * @param lowerIndex lower index of the binomial coefficient
     * @return the base 2 logarithm of the binomial coefficient
     */    
   public static final double logbinomialCoefficient(int upperIndex, int lowerIndex){
   
     double result =1.0;
     if(upperIndex == lowerIndex || lowerIndex == 0)
         return result;
     result = SpecialFunctions.log2Binomial((double)upperIndex, (double)lowerIndex);
     return result;
   }
   
   /**
    * Method to estimate the prior probabilities
    * @throws Exception throws exception if the prior cannot be calculated
    * @return a hashtable containing the prior probabilities
    */   
   public final Hashtable estimatePrior() throws Exception{
   
       double distr, prior, denominator, mPoint;
       
       Hashtable m_priors = new Hashtable(m_numIntervals);
       denominator = calculatePriorSum(false,1.0);
       generateDistribution();
       for(int i = 0; i < m_numIntervals; i++){ 
            mPoint = m_midPoints[i];
            prior = calculatePriorSum(true,mPoint) / denominator;
            m_priors.put(new Double(mPoint), new Double(prior));
       }
       return m_priors;
   }  
   
   /**
    * split the interval [0,1] into a predefined number of intervals and calculates their mid points
    */   
   public final void midPoints(){
        
        m_midPoints = new double[m_numIntervals];
        for(int i = 0; i < m_numIntervals; i++)
            m_midPoints[i] = midPoint(1.0/m_numIntervals, i);
   }
     
   /**
    * calculates the mid point of an interval
    * @param size the size of each interval
    * @param number the number of the interval.
    * The intervals are numbered from 0 to m_numIntervals.
    * @return the mid point of the interval
    */   
   public double midPoint(double size, int number){
    
       return (size * (double)number) + (size / 2.0);
   }
    
   /**
    * returns an ordered array of all mid points
    * @return an ordered array of doubles conatining all midpoints
    */   
   public final double[] getMidPoints(){
    
       return m_midPoints;
   }
   
   
   /**
    * splits an item set into premise and consequence and constructs therefore
    * an association rule. The length of the premise is given. The attributes
    * for premise and consequence are chosen randomly. The result is a RuleItem.
    * @param premiseLength the length of the premise
    * @param itemArray a (randomly generated) item set
    * @return a randomly generated association rule stored in a RuleItem
    */   
    public final RuleItem splitItemSet (int premiseLength, int[] itemArray){
        
       int[] cons = new int[m_instances.numAttributes()];
       System.arraycopy(itemArray, 0, cons, 0, itemArray.length);
       int help = premiseLength;
       while(help > 0){
            int mark = m_randNum.nextInt(itemArray.length);
            if(cons[mark] != -1){
                help--;
                cons[mark] =-1;
            }
       }
       if(premiseLength == 0)
            for(int i =0; i < itemArray.length;i++)
                itemArray[i] = -1;
       else
           for(int i =0; i < itemArray.length;i++)
               if(cons[i] != -1)
                    itemArray[i] = -1;
       ItemSet premise = new ItemSet(itemArray);
       ItemSet consequence = new ItemSet(cons);
       RuleItem current = new RuleItem();
       current.m_premise = premise;
       current.m_consequence = consequence;
       return current;
    }

    /**
     * generates a class association rule out of a given premise.
     * It randomly chooses a class label as consequence.
     * @param itemArray the (randomly constructed) premise of the class association rule
     * @return a class association rule stored in a RuleItem
     */    
    public final RuleItem addCons (int[] itemArray){
        
        ItemSet premise = new ItemSet(itemArray);
        int[] cons = new int[itemArray.length];
        for(int i =0;i < itemArray.length;i++)
            cons[i] = -1;
        cons[m_instances.classIndex()] = m_randNum.nextInt((m_instances.attribute(m_instances.classIndex())).numValues());
        ItemSet consequence = new ItemSet(cons);
        RuleItem current = new RuleItem();
        current.m_premise = premise;
        current.m_consequence = consequence;
        return current;
    }
    
    /**
     * updates the support count of an item set
     * @param itemSet the item set
     */    
    public final void updateCounters(ItemSet itemSet){
        
        for (int i = 0; i < m_instances.numInstances(); i++) 
            itemSet.upDateCounter(m_instances.instance(i));
    }
    
    /**
     * Returns the revision string.
     * 
     * @return		the revision
     */
    public String getRevision() {
      return RevisionUtils.extract("$Revision: 1.7 $");
    }
}
