source: src/main/java/weka/classifiers/functions/pace/NormalMixture.java @ 28

/*
 *    This program is free software; you can redistribute it and/or modify
 *    it under the terms of the GNU General Public License as published by
 *    the Free Software Foundation; either version 2 of the License, or (at
 *    your option) any later version.
 *
 *    This program is distributed in the hope that it will be useful, but
 *    WITHOUT ANY WARRANTY; without even the implied warranty of
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 *    General Public License for more details.
 *
 *    You should have received a copy of the GNU General Public License
 *    along with this program; if not, write to the Free Software
 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.  */

/*
 *    NormalMixture.java
 *    Copyright (C) 2002 University of Waikato, Hamilton, New Zealand
 *
 */

package weka.classifiers.functions.pace;

import java.util.Random;

import weka.core.RevisionUtils;
import weka.core.matrix.DoubleVector;
import weka.core.matrix.Maths;

/**
 * Class for manipulating normal mixture distributions. <p>
 *
 * For more information see: <p/>
 *
 <!-- technical-plaintext-start -->
 * Wang, Y. (2000). A new approach to fitting linear models in high dimensional spaces. Hamilton, New Zealand.<br/>
 * <br/>
 * Wang, Y., Witten, I. H.: Modeling for optimal probability prediction. In: Proceedings of the Nineteenth International Conference on Machine Learning, Sydney, Australia, 650-657, 2002.
 <!-- technical-plaintext-end -->
 *
 <!-- technical-bibtex-start -->
 * BibTeX:
 * <pre>
 * &#64;phdthesis{Wang2000,
 *    address = {Hamilton, New Zealand},
 *    author = {Wang, Y.},
 *    school = {Department of Computer Science, University of Waikato},
 *    title = {A new approach to fitting linear models in high dimensional spaces},
 *    year = {2000}
 * }
 *
 * &#64;inproceedings{Wang2002,
 *    address = {Sydney, Australia},
 *    author = {Wang, Y. and Witten, I. H.},
 *    booktitle = {Proceedings of the Nineteenth International Conference on Machine Learning},
 *    pages = {650-657},
 *    title = {Modeling for optimal probability prediction},
 *    year = {2002}
 * }
 * </pre>
 * <p/>
 <!-- technical-bibtex-end -->
 *
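 * A minimal usage sketch, modelled on the {@code main} test at the end of
 * this class ({@code NNMMethod} is the mixing-distribution fitting method
 * inherited from {@link MixtureDistribution}): <pre>
 *   DoubleVector data = Maths.rnorm(100, 0, 1, new Random());
 *   NormalMixture m = new NormalMixture();
 *   m.fit(data, NNMMethod);
 *   DoubleVector estimate = m.empiricalBayesEstimate(data);
 * </pre>
 *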
 * @author Yong Wang (yongwang@cs.waikato.ac.nz)
 * @version $Revision: 1.5 $
 */
public class NormalMixture 
  extends MixtureDistribution {
 
  /** the separating threshold */
  protected double separatingThreshold = 0.05;

  /** the trimming threshold */
  protected double trimingThreshold = 0.7;

  /** the length of the fitting intervals */
  protected double fittingIntervalLength = 3;

  /**
   * Constructs an empty NormalMixture
   */
  public NormalMixture() {}

  /**
   * Gets the separating threshold value. This value is used by the method
   * separable.
   *
   * @return the separating threshold
   */
  public double getSeparatingThreshold(){
    return separatingThreshold;
  }

  /**
   * Sets the separating threshold value.
   * 
   * @param t the threshold value
   */
  public void setSeparatingThreshold( double t ){
    separatingThreshold = t;
  }

  /**
   * Gets the trimming threshold value. This value is used by the method
   * trim.
   *
   * @return the trimming threshold
   */
  public double getTrimingThreshold(){ 
    return trimingThreshold; 
  }

  /**
   * Sets the trimming threshold value.
   *
   * @param t the trimming threshold
   */
  public void setTrimingThreshold( double t ){
    trimingThreshold = t;
  }

  /**
   *  Returns true if a value can be considered for mixture estimation
   *  separately from the data indexed between i0 and i1.
   * 
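   *  <p>
   *  The test sums the tail probabilities Phi(-|x - data(i)|) over the
   *  group; x is considered separable when this sum falls below the
   *  separating threshold.
   * 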
   *  @param data the data supposedly generated from the mixture
   *  @param i0 the index of the first element in the group
   *  @param i1 the index of the last element in the group
   *  @param x the value
   *  @return true if the value can be considered separately
   */
  public boolean separable( DoubleVector data, int i0, int i1, double x ) {
    double p = 0;
    for( int i = i0; i <= i1; i++ ) {
      p += Maths.pnorm( - Math.abs(x - data.get(i)) );
    }
    return p < separatingThreshold;
  }

  /**
   *  Constructs the set of support points for mixture estimation.
   * 
   *  @param data the data supposedly generated from the mixture
   *  @param ne the number of extra data points that were supposedly
   *  discarded earlier and not passed in here
   *  @return the set of support points
   */
  public DoubleVector supportPoints( DoubleVector data, int ne ) {
    if( data.size() < 2 )
      throw new IllegalArgumentException("data size < 2");
       
    return data.copy();
  }
   
  /**
   *  Constructs the set of fitting intervals for mixture estimation.
   * 
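   *  <p>
   *  Each data point x contributes two intervals: [x, x +
   *  fittingIntervalLength] and [x - fittingIntervalLength, x], stored as
   *  rows of a two-column matrix of left and right endpoints.
   * 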
   *  @param data the data supposedly generated from the mixture
   *  @return the set of fitting intervals
   */
  public PaceMatrix fittingIntervals( DoubleVector data ) {
    DoubleVector left = data.cat( data.minus( fittingIntervalLength ) );
    DoubleVector right = data.plus( fittingIntervalLength ).cat( data );
       
    PaceMatrix a = new PaceMatrix(left.size(), 2);
       
    a.setMatrix(0, left.size()-1, 0, left);
    a.setMatrix(0, right.size()-1, 1, right);
       
    return a;
  }
   
  /**
   *  Constructs the probability matrix for mixture estimation, given a set
   *  of support points and a set of intervals.
   * 
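   *  <p>
   *  Entry (i, j) is the probability that a unit-variance normal centred
   *  at support point s(j) falls into interval i.
   * 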
   *  @param s  the set of support points
   *  @param intervals the intervals
   *  @return the probability matrix
   */
  public PaceMatrix probabilityMatrix( DoubleVector s, 
                                       PaceMatrix intervals ) {
   
    int ns = s.size();
    int nr = intervals.getRowDimension();
    PaceMatrix p = new PaceMatrix(nr, ns);
       
    for( int i = 0; i < nr; i++ ) {
      for( int j = 0; j < ns; j++ ) {
        p.set( i, j,
               Maths.pnorm( intervals.get(i, 1), s.get(j), 1 ) - 
               Maths.pnorm( intervals.get(i, 0), s.get(j), 1 ) );
      }
    }
       
    return p;
  }
   
  /**
   * Returns the empirical Bayes estimate of a single value.
   *
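   * <p>
   * The estimate is the posterior mean under the fitted discrete mixing
   * distribution: sum_j mu_j w_j / sum_j w_j with weights w_j proportional
   * to p_j * dnorm(x; mu_j, 1), computed in log space to avoid overflow.
   *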
   * @param x the value
   * @return the empirical Bayes estimate
   */
  public double empiricalBayesEstimate( double x ) { 
    if( Math.abs(x) > 10 ) return x; // practical consideration; modify later

    // log densities of x at each support point, shifted by their maximum
    // so that exponentiation cannot overflow
    DoubleVector d = 
      Maths.dnormLog( x, mixingDistribution.getPointValues(), 1 );
    d.minusEquals( d.max() );
    d = d.map("java.lang.Math", "exp");

    // weight by the mixing proportions and return the posterior mean
    d.timesEquals( mixingDistribution.getFunctionValues() );
    return mixingDistribution.getPointValues().innerProduct( d ) / d.sum();
  }

  /**
   * Returns the empirical Bayes estimate of a vector.
   *
   * @param x the vector
   * @return the empirical Bayes estimate
   */
  public DoubleVector empiricalBayesEstimate( DoubleVector x ) {
    DoubleVector pred = new DoubleVector( x.size() );
    for(int i = 0; i < x.size(); i++ ) 
      pred.set(i, empiricalBayesEstimate(x.get(i)) );
    trim( pred );
    return pred;
  }

  /**
   * Returns the optimal nested model estimate of a vector.
   *
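   * <p>
   * The components of x are scored with hf(), the scores are accumulated,
   * and every component beyond the index of the maximum cumulative score
   * is set to zero, so the chosen model keeps a prefix of the components
   * in their given order.
   *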
   * @param x the vector
   * @return the optimal nested model estimate
   */
  public DoubleVector nestedEstimate( DoubleVector x ) {
   
    // cumulative hf scores of the components, in the given order
    DoubleVector chf = new DoubleVector( x.size() );
    for(int i = 0; i < x.size(); i++ ) chf.set( i, hf( x.get(i) ) );
    chf.cumulateInPlace();

    // keep the prefix ending at the maximum cumulative score; zero the rest
    int index = chf.indexOfMax();
    DoubleVector copy = x.copy();
    if( index < x.size()-1 ) copy.set( index + 1, x.size()-1, 0 );
    trim( copy );
    return copy;
  }
 
  /**
   * Returns the estimate of optimal subset selection.
   *
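   * <p>
   * Unlike the nested estimate, any subset of components may be selected:
   * each component whose h() value is non-positive is set to zero
   * individually.
   *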
   * @param x the vector
   * @return the estimate of optimal subset selection
   */
  public DoubleVector subsetEstimate( DoubleVector x ) {

    DoubleVector h = h( x );
    DoubleVector copy = x.copy();
    for( int i = 0; i < x.size(); i++ )
      if( h.get(i) <= 0 ) copy.set(i, 0);
    trim( copy );
    return copy;
  }
 
  /**
   * Trims the small values of the estimate
   *
   * @param x the estimate vector
   */
  public void trim( DoubleVector x ) {
    for(int i = 0; i < x.size(); i++ ) {
      if( Math.abs(x.get(i)) <= trimingThreshold ) x.set(i, 0);
    }
  }
 
  /**
   *  Computes the value of h(x) / f(x) given the mixture. The
   *  implementation avoids overflow.
   * 
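   *  <p>
   *  The ratio is evaluated as sum_j (2*x*mu_j - mu_j^2) * w_j / sum_j w_j
   *  with weights w_j proportional to p_j * dnorm(x; mu_j, 1); the common
   *  scale factor introduced by the log-space shift cancels between the
   *  numerator and the denominator.
   * 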
   *  @param x the value
   *  @return the value of h(x) / f(x)
   */
  public double hf( double x ) {
    DoubleVector points = mixingDistribution.getPointValues();
    DoubleVector values = mixingDistribution.getFunctionValues(); 

    // shifted log densities; the shift cancels in the ratio below
    DoubleVector d = Maths.dnormLog( x, points, 1 );
    d.minusEquals( d.max() );

    d = (DoubleVector) d.map("java.lang.Math", "exp");
    d.timesEquals( values ); 

    return ((DoubleVector) points.times(2*x).minusEquals(x*x))
      .innerProduct( d ) / d.sum();
  }
   
  /**
   *  Computes the value of h(x) given the mixture.
   * 
   *  @param x the value
   *  @return the value of h(x)
   */
  public double h( double x ) {
    DoubleVector points = mixingDistribution.getPointValues();
    DoubleVector values = mixingDistribution.getFunctionValues(); 
    DoubleVector d = (DoubleVector) Maths.dnorm( x, points, 1 ).timesEquals( values ); 
    return ((DoubleVector) points.times(2*x).minusEquals(x*x))
      .innerProduct( d );
  }
   
  /**
   *  Computes the value of h(x) given the mixture, where x is a vector.
   * 
   *  @param x the vector
   *  @return the value of h(x)
   */
  public DoubleVector h( DoubleVector x ) {
    DoubleVector h = new DoubleVector( x.size() );
    for( int i = 0; i < x.size(); i++ ) 
      h.set( i, h( x.get(i) ) );
    return h;
  }
   
  /**
   *  Computes the value of f(x) given the mixture.
   * 
   *  @param x the value
   *  @return the value of f(x)
   */
  public double f( double x ) {
    DoubleVector points = mixingDistribution.getPointValues();
    DoubleVector values = mixingDistribution.getFunctionValues(); 
    // density of the unit-variance normal mixture at x
    return Maths.dnorm( x, points, 1 ).timesEquals( values ).sum();
  }
   
  /**
   *  Computes the value of f(x) given the mixture, where x is a vector.
   * 
   *  @param x the vector
   *  @return the value of f(x)
   */
  public DoubleVector f( DoubleVector x ) {
    DoubleVector f = new DoubleVector( x.size() );
    for( int i = 0; i < x.size(); i++ ) 
      f.set( i, f( x.get(i) ) );
    return f;
  }
   
  /**
   * Converts to a string
   *
   * @return a string representation
   */
  public String toString() {
    return mixingDistribution.toString();
  }
 
  /**
   * Returns the revision string.
   *
   * @return            the revision
   */
  public String getRevision() {
    return RevisionUtils.extract("$Revision: 1.5 $");
  }
   
  /**
   * Method to test this class
   *
   * @param args the commandline arguments - ignored
   */
  public static void main(String[] args) {
    int n1 = 50;
    int n2 = 50;
    double mu1 = 0;
    double mu2 = 5; 
    DoubleVector a = Maths.rnorm( n1, mu1, 1, new Random() );
    a = a.cat( Maths.rnorm( n2, mu2, 1, new Random() ) );
    DoubleVector means = (new DoubleVector( n1, mu1 )).cat(new DoubleVector(n2, mu2));

    System.out.println("==========================================================");
    System.out.println("This is to test the estimation of the mixing\n" +
            "distribution of the mixture of unit variance normal\n" + 
            "distributions. The example mixture used is of the form: \n\n" + 
            "   0.5 * N(mu1, 1) + 0.5 * N(mu2, 1)\n" );

    System.out.println("It also tests three estimators: the subset\n" +
            "selector, the nested model selector, and the empirical Bayes\n" +
            "estimator. Quadratic losses of the estimators are given, \n" +
            "and are taken as the measure of their performance.");
    System.out.println("==========================================================");
    System.out.println( "mu1 = " + mu1 + " mu2 = " + mu2 +"\n" );

    System.out.println( a.size() + " observations are: \n\n" + a );

    System.out.println( "\nQuadratic loss of the raw data (i.e., the MLE) = " + 
             a.sum2( means ) );
    System.out.println("==========================================================");

    // find the mixing distribution
    NormalMixture d = new NormalMixture();
    d.fit( a, NNMMethod ); 
    System.out.println( "The estimated mixing distribution is:\n" + d );
       
    DoubleVector pred = d.nestedEstimate( a.rev() ).rev();
    System.out.println( "\nThe Nested Estimate = \n" + pred );
    System.out.println( "Quadratic loss = " + pred.sum2( means ) );

    pred = d.subsetEstimate( a );
    System.out.println( "\nThe Subset Estimate = \n" + pred );
    System.out.println( "Quadratic loss = " + pred.sum2( means ) );

    pred = d.empiricalBayesEstimate( a );
    System.out.println( "\nThe Empirical Bayes Estimate = \n" + pred );
    System.out.println( "Quadratic loss = " + pred.sum2( means ) );
       
  }
}