source: src/main/java/weka/classifiers/functions/pace/ChisqMixture.java @ 4

Last change on this file since 4 was 4, checked in by gnappo, 14 years ago

Import di weka.

File size: 14.0 KB
Line 
1/*
2 *    This program is free software; you can redistribute it and/or modify
3 *    it under the terms of the GNU General Public License as published by
4 *    the Free Software Foundation; either version 2 of the License, or (at
5 *    your option) any later version.
6 *
7 *    This program is distributed in the hope that it will be useful, but
8 *    WITHOUT ANY WARRANTY; without even the implied warranty of
9 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
10 *    General Public License for more details.
11 *
12 *    You should have received a copy of the GNU General Public License
13 *    along with this program; if not, write to the Free Software
14 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.  */
15
16/*
17 *    ChisqMixture.java
18 *    Copyright (C) 2002 University of Waikato, Hamilton, New Zealand
19 *
20 */
21
22package weka.classifiers.functions.pace;
23
24import weka.core.RevisionUtils;
25import weka.core.matrix.DoubleVector;
26import weka.core.matrix.Maths;
27
28import java.util.Random;
29
30/**
31 * Class for manipulating chi-square mixture distributions. <p/>
32 *
33 * For more information see: <p/>
34 *
35 <!-- technical-plaintext-start -->
36 * Wang, Y (2000). A new approach to fitting linear models in high dimensional spaces. Hamilton, New Zealand.<br/>
37 * <br/>
38 * Wang, Y., Witten, I. H.: Modeling for optimal probability prediction. In: Proceedings of the Nineteenth International Conference in Machine Learning, Sydney, Australia, 650-657, 2002.
39 <!-- technical-plaintext-end -->
40 *
41 <!-- technical-bibtex-start -->
42 * BibTeX:
43 * <pre>
44 * &#64;phdthesis{Wang2000,
45 *    address = {Hamilton, New Zealand},
46 *    author = {Wang, Y},
47 *    school = {Department of Computer Science, University of Waikato},
48 *    title = {A new approach to fitting linear models in high dimensional spaces},
49 *    year = {2000}
50 * }
51 *
52 * &#64;inproceedings{Wang2002,
53 *    address = {Sydney, Australia},
54 *    author = {Wang, Y. and Witten, I. H.},
55 *    booktitle = {Proceedings of the Nineteenth International Conference in Machine Learning},
56 *    pages = {650-657},
57 *    title = {Modeling for optimal probability prediction},
58 *    year = {2002}
59 * }
60 * </pre>
61 * <p/>
62 <!-- technical-bibtex-end -->
63 *
64 * @author Yong Wang (yongwang@cs.waikato.ac.nz)
65 * @version $Revision: 1.5 $
66 */
67public class ChisqMixture 
68  extends MixtureDistribution {
69 
70  /** the separating threshold value */
71  protected double separatingThreshold = 0.05; 
72
73  /** the triming thresholding */
74  protected double trimingThreshold = 0.5;
75
76  protected double supportThreshold = 0.5;
77
78  protected int maxNumSupportPoints = 200; // for computational reason
79
80  protected int fittingIntervalLength = 3;
81   
82  protected double fittingIntervalThreshold = 0.5;
83
84  /** Contructs an empty ChisqMixture
85   */
86  public ChisqMixture() {}
87
88  /**
89   * Gets the separating threshold value. This value is used by the method
90   * separatable
91   *
92   * @return the separating threshold
93   */
94  public double getSeparatingThreshold() {
95    return separatingThreshold;
96  }
97 
98  /**
99   * Sets the separating threshold value
100   *
101   * @param t the threshold value
102   */
103  public void setSeparatingThreshold( double t ) {
104    separatingThreshold = t;
105  }
106
107  /**
108   * Gets the triming thresholding value. This value is usef by the method trim.
109   *
110   * @return the triming threshold
111   */
112  public double getTrimingThreshold() {
113    return trimingThreshold;
114  }
115
116  /**
117   * Sets the triming thresholding value.
118   *
119   * @param t the triming threshold
120   */
121  public void setTrimingThreshold( double t ){
122    trimingThreshold = t;
123  }
124
125  /**
126   *  Return true if a value can be considered for mixture estimation
127   *  separately from the data indexed between i0 and i1
128   * 
129   *  @param data the data supposedly generated from the mixture
130   *  @param i0 the index of the first element in the group
131   *  @param i1 the index of the last element in the group
132   *  @param x the value
133   *  @return true if the value can be considered
134   */
135  public boolean separable( DoubleVector data, int i0, int i1, double x ) {
136
137    DoubleVector dataSqrt = data.sqrt();
138    double xh = Math.sqrt( x );
139
140    NormalMixture m = new NormalMixture();
141    m.setSeparatingThreshold( separatingThreshold );
142    return m.separable( dataSqrt, i0, i1, xh );
143  }
144
145  /**
146   *  Contructs the set of support points for mixture estimation.
147   * 
148   *  @param data the data supposedly generated from the mixture
149   *  @param ne the number of extra data that are suppposedly discarded
150   *  earlier and not passed into here
151   *  @return the set of support points
152   */
153  public DoubleVector  supportPoints( DoubleVector data, int ne ) {
154
155    DoubleVector sp = new DoubleVector();
156    sp.setCapacity( data.size() + 1 );
157
158    if( data.get(0) < supportThreshold || ne != 0 ) 
159      sp.addElement( 0 );
160    for( int i = 0; i < data.size(); i++ ) 
161      if( data.get( i ) > supportThreshold ) 
162        sp.addElement( data.get(i) );
163       
164    // The following will be fixed later???
165    if( sp.size() > maxNumSupportPoints ) 
166      throw new IllegalArgumentException( "Too many support points. " );
167
168    return sp;
169  }
170   
171  /**
172   *  Contructs the set of fitting intervals for mixture estimation.
173   * 
174   *  @param data the data supposedly generated from the mixture
175   *  @return the set of fitting intervals
176   */
177  public PaceMatrix  fittingIntervals( DoubleVector data ) {
178
179    PaceMatrix a = new PaceMatrix( data.size() * 2, 2 );
180    DoubleVector v = data.sqrt();
181    int count = 0;
182    double left, right;
183    for( int i = 0; i < data.size(); i++ ) {
184      left = v.get(i) - fittingIntervalLength; 
185      if( left < fittingIntervalThreshold ) left = 0;
186      left = left * left;
187      right = data.get(i);
188      if( right < fittingIntervalThreshold ) 
189        right = fittingIntervalThreshold;
190      a.set( count, 0, left );
191      a.set( count, 1, right );
192      count++;
193    }
194    for( int i = 0; i < data.size(); i++ ) {
195      left = data.get(i);
196      if( left < fittingIntervalThreshold ) left = 0;
197      right = v.get(i) + fittingIntervalThreshold;
198      right = right * right;
199      a.set( count, 0, left );
200      a.set( count, 1, right );
201      count++;
202    }
203    a.setRowDimension( count );
204       
205    return a;
206  }
207   
208  /**
209   *  Contructs the probability matrix for mixture estimation, given a set
210   *  of support points and a set of intervals.
211   * 
212   *  @param s  the set of support points
213   *  @param intervals the intervals
214   *  @return the probability matrix
215   */
216  public PaceMatrix  probabilityMatrix(DoubleVector s, PaceMatrix intervals) {
217   
218    int ns = s.size();
219    int nr = intervals.getRowDimension();
220    PaceMatrix p = new PaceMatrix(nr, ns);
221       
222    for( int i = 0; i < nr; i++ ) {
223      for( int j = 0; j < ns; j++ ) {
224        p.set( i, j,
225               Maths.pchisq( intervals.get(i, 1), s.get(j) ) - 
226               Maths.pchisq( intervals.get(i, 0), s.get(j) ) );
227      }
228    }
229       
230    return p;
231  }
232   
233
234  /**
235   *  Returns the pace6 estimate of a single value.
236   * 
237   *  @param x the value
238   *  @return the pace6 estimate
239   */
240  public double  pace6 ( double x ) { 
241   
242    if( x > 100 ) return x; // pratical consideration. will modify later
243    DoubleVector points = mixingDistribution.getPointValues();
244    DoubleVector values = mixingDistribution.getFunctionValues(); 
245    DoubleVector mean = points.sqrt();
246       
247    DoubleVector d = Maths.dchisqLog( x, points );
248    d.minusEquals( d.max() );
249    d = d.map("java.lang.Math", "exp").timesEquals( values );
250    double atilde = mean.innerProduct( d ) / d.sum();
251    return atilde * atilde;
252  }
253
254  /**
255   *  Returns the pace6 estimate of a vector.
256   * 
257   *  @param x the vector
258   *  @return the pace6 estimate
259   */
260  public DoubleVector pace6( DoubleVector x ) {
261
262    DoubleVector pred = new DoubleVector( x.size() );
263    for(int i = 0; i < x.size(); i++ ) 
264      pred.set(i, pace6(x.get(i)) );
265    trim( pred );
266    return pred;
267  }
268
269  /**
270   *  Returns the pace2 estimate of a vector.
271   * 
272   *  @param x the vector
273   *  @return the pace2 estimate
274   */
275  public DoubleVector  pace2( DoubleVector x ) {
276   
277    DoubleVector chf = new DoubleVector( x.size() );
278    for(int i = 0; i < x.size(); i++ ) chf.set( i, hf( x.get(i) ) );
279
280    chf.cumulateInPlace();
281
282    int index = chf.indexOfMax();
283
284    DoubleVector copy = x.copy();
285    if( index < x.size()-1 ) copy.set( index + 1, x.size()-1, 0 );
286    trim( copy );
287    return copy;
288  }
289
290  /**
291   *  Returns the pace4 estimate of a vector.
292   * 
293   *  @param x the vector
294   *  @return the pace4 estimate
295   */
296  public DoubleVector  pace4( DoubleVector x ) {
297   
298    DoubleVector h = h( x );
299    DoubleVector copy = x.copy();
300    for( int i = 0; i < x.size(); i++ )
301      if( h.get(i) <= 0 ) copy.set(i, 0);
302    trim( copy );
303    return copy;
304  }
305
306  /**
307   * Trims the small values of the estaimte
308   *
309   * @param x the estimate vector
310   */
311  public void trim( DoubleVector x ) {
312   
313    for(int i = 0; i < x.size(); i++ ) {
314      if( x.get(i) <= trimingThreshold ) x.set(i, 0);
315    }
316  }
317   
318  /**
319   *  Computes the value of h(x) / f(x) given the mixture. The
320   *  implementation avoided overflow.
321   * 
322   *  @param AHat the value
323   *  @return the value of h(x) / f(x)
324   */
325  public double hf( double AHat ) {
326   
327    DoubleVector points = mixingDistribution.getPointValues();
328    DoubleVector values = mixingDistribution.getFunctionValues(); 
329
330    double x = Math.sqrt( AHat );
331    DoubleVector mean = points.sqrt();
332    DoubleVector d1 = Maths.dnormLog( x, mean, 1 );
333    double d1max = d1.max();
334    d1.minusEquals( d1max );
335    DoubleVector d2 = Maths.dnormLog( -x, mean, 1 );
336    d2.minusEquals( d1max );
337
338    d1 = d1.map("java.lang.Math", "exp");
339    d1.timesEquals( values ); 
340    d2 = d2.map("java.lang.Math", "exp");
341    d2.timesEquals( values ); 
342
343    return ( ( points.minus(x/2)).innerProduct( d1 ) - 
344             ( points.plus(x/2)).innerProduct( d2 ) ) 
345    / (d1.sum() + d2.sum());
346  }
347   
348  /**
349   *  Computes the value of h(x) given the mixture.
350   * 
351   *  @param AHat the value
352   *  @return the value of h(x)
353   */
354  public double h( double AHat ) {
355   
356    if( AHat == 0.0 ) return 0.0;
357    DoubleVector points = mixingDistribution.getPointValues();
358    DoubleVector values = mixingDistribution.getFunctionValues();
359       
360    double aHat = Math.sqrt( AHat );
361    DoubleVector aStar = points.sqrt();
362    DoubleVector d1 = Maths.dnorm( aHat, aStar, 1 ).timesEquals( values );
363    DoubleVector d2 = Maths.dnorm( -aHat, aStar, 1 ).timesEquals( values );
364
365    return points.minus(aHat/2).innerProduct( d1 ) - 
366           points.plus(aHat/2).innerProduct( d2 );
367  }
368   
369  /**
370   *  Computes the value of h(x) given the mixture, where x is a vector.
371   * 
372   *  @param AHat the vector
373   *  @return the value of h(x)
374   */
375  public DoubleVector h( DoubleVector AHat ) {
376   
377    DoubleVector h = new DoubleVector( AHat.size() );
378    for( int i = 0; i < AHat.size(); i++ ) 
379      h.set( i, h( AHat.get(i) ) );
380    return h;
381  }
382   
383  /**
384   *  Computes the value of f(x) given the mixture.
385   * 
386   *  @param x the value
387   *  @return the value of f(x)
388   */
389  public double f( double x ) {
390   
391    DoubleVector points = mixingDistribution.getPointValues();
392    DoubleVector values = mixingDistribution.getFunctionValues(); 
393
394    return Maths.dchisq(x, points).timesEquals(values).sum(); 
395  }
396   
397  /**
398   *  Computes the value of f(x) given the mixture, where x is a vector.
399   * 
400   *  @param x the vector
401   *  @return the value of f(x)
402   */
403  public DoubleVector f( DoubleVector x ) {
404   
405    DoubleVector f = new DoubleVector( x.size() );
406    for( int i = 0; i < x.size(); i++ ) 
407      f.set( i, h( f.get(i) ) );
408    return f;
409  }
410   
411  /**
412   * Converts to a string
413   *
414   * @return a string representation
415   */
416  public String  toString() {
417    return mixingDistribution.toString();
418  }
419 
420  /**
421   * Returns the revision string.
422   *
423   * @return            the revision
424   */
425  public String getRevision() {
426    return RevisionUtils.extract("$Revision: 1.5 $");
427  }
428   
429  /**
430   * Method to test this class
431   *
432   * @param args the commandline arguments
433   */
434  public static void  main(String args[]) {
435   
436    int n1 = 50;
437    int n2 = 50;
438    double ncp1 = 0;
439    double ncp2 = 10; 
440    double mu1 = Math.sqrt( ncp1 );
441    double mu2 = Math.sqrt( ncp2 );
442    DoubleVector a = Maths.rnorm( n1, mu1, 1, new Random() );
443    a = a.cat( Maths.rnorm(n2, mu2, 1, new Random()) );
444    DoubleVector aNormal = a;
445    a = a.square();
446    a.sort();
447       
448    DoubleVector means = (new DoubleVector( n1, mu1 )).cat(new DoubleVector(n2, mu2));
449       
450    System.out.println("==========================================================");
451    System.out.println("This is to test the estimation of the mixing\n" +
452                       "distribution of the mixture of non-central Chi-square\n" + 
453                       "distributions. The example mixture used is of the form: \n\n" + 
454                       "   0.5 * Chi^2_1(ncp1) + 0.5 * Chi^2_1(ncp2)\n" );
455
456    System.out.println("It also tests the PACE estimators. Quadratic losses of the\n" +
457                       "estimators are given, measuring their performance.");
458    System.out.println("==========================================================");
459    System.out.println( "ncp1 = " + ncp1 + " ncp2 = " + ncp2 +"\n" );
460
461    System.out.println( a.size() + " observations are: \n\n" + a );
462
463    System.out.println( "\nQuadratic loss of the raw data (i.e., the MLE) = " + 
464                        aNormal.sum2( means ) );
465    System.out.println("==========================================================");
466       
467    // find the mixing distribution
468    ChisqMixture d = new ChisqMixture();
469    d.fit( a, NNMMethod ); 
470    System.out.println( "The estimated mixing distribution is\n" + d ); 
471       
472    DoubleVector pred = d.pace2( a.rev() ).rev();
473    System.out.println( "\nThe PACE2 Estimate = \n" + pred );
474    System.out.println( "Quadratic loss = " + 
475                        pred.sqrt().times(aNormal.sign()).sum2( means ) );
476   
477    pred = d.pace4( a );
478    System.out.println( "\nThe PACE4 Estimate = \n" + pred );
479    System.out.println( "Quadratic loss = " + 
480                        pred.sqrt().times(aNormal.sign()).sum2( means ) );
481
482    pred = d.pace6( a );
483    System.out.println( "\nThe PACE6 Estimate = \n" + pred );
484    System.out.println( "Quadratic loss = " + 
485                        pred.sqrt().times(aNormal.sign()).sum2( means ) );
486  }
487}
Note: See TracBrowser for help on using the repository browser.