source: src/main/java/weka/classifiers/bayes/blr/LaplacePriorImpl.java @ 9

Last change on this file since 9 was 4, checked in by gnappo, 14 years ago

Import di weka.

File size: 4.5 KB
Line 
1/*
2 *    This program is free software; you can redistribute it and/or modify
3 *    it under the terms of the GNU General Public License as published by
4 *    the Free Software Foundation; either version 2 of the License, or
5 *    (at your option) any later version.
6 *
7 *    This program is distributed in the hope that it will be useful,
8 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
9 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
10 *    GNU General Public License for more details.
11 *
12 *    You should have received a copy of the GNU General Public License
13 *    along with this program; if not, write to the Free Software
14 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
15 */
16
17/*
18 *    LaplacePriorImpl.java
19 *    Copyright (C) 2008 Illinois Institute of Technology
20 *
21 */
22package weka.classifiers.bayes.blr;
23
24import weka.classifiers.bayes.BayesianLogisticRegression;
25import weka.core.Instance;
26import weka.core.Instances;
27import weka.core.RevisionUtils;
28
29/**
30 * Implementation of the Gaussian Prior update function based on modified
31 *  CLG Algorithm (CLG-Lasso) with a certain Trust Region Update based
32 * on Laplace Priors.
33 *
34 * @author Navendu Garg(gargnav@iit.edu)
35 * @version $Revision: 4899 $
36 */
37public class LaplacePriorImpl
38  extends Prior {
39 
40  /** for serialization. */
41  private static final long serialVersionUID = 2353576123257012607L;
42 
43  Instances m_Instances;
44  double Beta;
45  double Hyperparameter;
46  double DeltaUpdate;
47  double[] R;
48  double Delta;
49
50  /**
51   * Update function specific to Laplace Prior.
52   */
53  public double update(int j, Instances instances, double beta,
54    double hyperparameter, double[] r, double deltaV) {
55    double sign = 0.0;
56    double change = 0.0;
57    DeltaUpdate = 0.0;
58    m_Instances = instances;
59    Beta = beta;
60    Hyperparameter = hyperparameter;
61    R = r;
62    Delta = deltaV;
63
64    if (Beta == 0) {
65      sign = 1.0;
66      DeltaUpdate = laplaceUpdate(j, sign);
67
68      if (DeltaUpdate <= 0.0) { // positive direction failed.
69        sign = -1.0;
70        DeltaUpdate = laplaceUpdate(j, sign);
71
72        if (DeltaUpdate >= 0.0) {
73          DeltaUpdate = 0;
74        }
75      }
76    } else {
77      sign = Beta / Math.abs(Beta);
78      DeltaUpdate = laplaceUpdate(j, sign);
79      change = Beta + DeltaUpdate;
80      change = change / Math.abs(change);
81
82      if (change < 0) {
83        DeltaUpdate = 0 - Beta;
84      }
85    }
86
87    return DeltaUpdate;
88  }
89
90  /**
91   * This is the CLG-lasso update function described in the
92
93  *<pre>
94  * &#64;TechReport{blrtext04,
95  *author = {Alexander Genkin and David D. Lewis and David Madigan},
96  *title = {Large-scale bayesian logistic regression for text categorization},
97  *institution = {DIMACS},
98  *year = {2004},
99  *url = "http://www.stat.rutgers.edu/~madigan/PAPERS/shortFat-v3a.pdf",
100  *OPTannote = {}
101  *}</pre>
102   *
103   * @param j
104   * @return double value
105   */
106  public double laplaceUpdate(int j, double sign) {
107    double value = 0.0;
108    double numerator = 0.0;
109    double denominator = 0.0;
110
111    Instance instance;
112
113    for (int i = 0; i < m_Instances.numInstances(); i++) {
114      instance = m_Instances.instance(i);
115
116      if (instance.value(j) != 0) {
117        numerator += (instance.value(j) * BayesianLogisticRegression.classSgn(instance.classValue()) * (1.0 / (1.0 +
118        Math.exp(R[i]))));
119        denominator += (instance.value(j) * instance.value(j) * BayesianLogisticRegression.bigF(R[i],
120          Delta * instance.value(j)));
121      }
122    }
123
124    numerator -= (Math.sqrt(2.0 / Hyperparameter) * sign);
125
126    if (denominator != 0.0) {
127      value = numerator / denominator;
128    }
129
130    return value;
131  }
132
133  /**
134   * Computes the log-likelihood values using the implementation in the Prior class.
135   * @param betas
136   * @param instances
137   */
138  public void computeLogLikelihood(double[] betas, Instances instances) {
139    //Basic implementation done in the prior class.
140    super.computelogLikelihood(betas, instances);
141  }
142
143  /**
144   * This function computes the penalty term specific to Laplacian distribution.
145   * @param betas
146   * @param hyperparameters
147   */
148  public void computePenalty(double[] betas, double[] hyperparameters) {
149    penalty = 0.0;
150
151    double lambda = 0.0;
152
153    for (int j = 0; j < betas.length; j++) {
154      lambda = Math.sqrt(hyperparameters[j]);
155      penalty += (Math.log(2) - Math.log(lambda) +
156      (lambda * Math.abs(betas[j])));
157    }
158
159    penalty = 0 - penalty;
160  }
161 
162  /**
163   * Returns the revision string.
164   *
165   * @return            the revision
166   */
167  public String getRevision() {
168    return RevisionUtils.extract("$Revision: 4899 $");
169  }
170}
Note: See TracBrowser for help on using the repository browser.