source: src/main/java/weka/classifiers/functions/pace/MixtureDistribution.java @ 16

Last change on this file since 16 was 4, checked in by gnappo, 14 years ago

Import di weka.

File size: 8.9 KB
Line 
1/*
2 *    This program is free software; you can redistribute it and/or modify
3 *    it under the terms of the GNU General Public License as published by
4 *    the Free Software Foundation; either version 2 of the License, or (at
5 *    your option) any later version.
6 *
7 *    This program is distributed in the hope that it will be useful, but
8 *    WITHOUT ANY WARRANTY; without even the implied warranty of
9 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
10 *    General Public License for more details.
11 *
12 *    You should have received a copy of the GNU General Public License
13 *    along with this program; if not, write to the Free Software
14 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.  */
15
16/*
17 *    MixtureDistribution.java
18 *    Copyright (C) 2002 University of Waikato, Hamilton, New Zealand
19 *
20 */
21
22package weka.classifiers.functions.pace;
23
24import weka.core.RevisionHandler;
25import weka.core.TechnicalInformation;
26import weka.core.TechnicalInformationHandler;
27import weka.core.TechnicalInformation.Field;
28import weka.core.TechnicalInformation.Type;
29import weka.core.matrix.DoubleVector;
30import weka.core.matrix.IntVector;
31
32/**
33 * Abtract class for manipulating mixture distributions. <p>
34 *
35 * REFERENCES <p>
36 *
37 * Wang, Y. (2000). "A new approach to fitting linear models in high
38 * dimensional spaces." PhD Thesis. Department of Computer Science,
39 * University of Waikato, New Zealand. <p>
40 *
41 * Wang, Y. and Witten, I. H. (2002). "Modeling for optimal probability
42 * prediction." Proceedings of ICML'2002. Sydney. <p>
43 *
44 * @author Yong Wang (yongwang@cs.waikato.ac.nz)
45 * @version $Revision: 1.5 $ */
46
47public abstract class MixtureDistribution
48  implements TechnicalInformationHandler, RevisionHandler {
49 
50  protected DiscreteFunction mixingDistribution;
51
52  /** The nonnegative-measure-based method */
53  public static final int NNMMethod = 1; 
54   
55  /** The probability-measure-based method */
56  public static final int PMMethod = 2;
57
58  // The CDF-based method
59  // public static final int CDFMethod = 3;
60   
61  // The method based on the Kolmogrov and von Mises measure
62  // public static final int ModifiedCDFMethod = 4;
63
64  /**
65   * Returns an instance of a TechnicalInformation object, containing
66   * detailed information about the technical background of this class,
67   * e.g., paper reference or book this class is based on.
68   *
69   * @return the technical information about this class
70   */
71  public TechnicalInformation getTechnicalInformation() {
72    TechnicalInformation        result;
73    TechnicalInformation        additional;
74   
75    result = new TechnicalInformation(Type.PHDTHESIS);
76    result.setValue(Field.AUTHOR, "Wang, Y");
77    result.setValue(Field.YEAR, "2000");
78    result.setValue(Field.TITLE, "A new approach to fitting linear models in high dimensional spaces");
79    result.setValue(Field.SCHOOL, "Department of Computer Science, University of Waikato");
80    result.setValue(Field.ADDRESS, "Hamilton, New Zealand");
81
82    additional = result.add(Type.INPROCEEDINGS);
83    additional.setValue(Field.AUTHOR, "Wang, Y. and Witten, I. H.");
84    additional.setValue(Field.YEAR, "2002");
85    additional.setValue(Field.TITLE, "Modeling for optimal probability prediction");
86    additional.setValue(Field.BOOKTITLE, "Proceedings of the Nineteenth International Conference in Machine Learning");
87    additional.setValue(Field.YEAR, "2002");
88    additional.setValue(Field.PAGES, "650-657");
89    additional.setValue(Field.ADDRESS, "Sydney, Australia");
90   
91    return result;
92  }
93   
94  /**
95   * Gets the mixing distribution
96   *
97   * @return the mixing distribution
98   */
99  public DiscreteFunction getMixingDistribution() {
100    return mixingDistribution;
101  }
102
103  /** Sets the mixing distribution
104   *  @param d the mixing distribution
105   */
106  public void  setMixingDistribution( DiscreteFunction d ) {
107    mixingDistribution = d;
108  }
109
110  /** Fits the mixture (or mixing) distribution to the data. The default
111   *  method is the nonnegative-measure-based method.
112   * @param data the data, supposedly generated from the mixture model */
113  public void fit( DoubleVector data ) {
114    fit( data, NNMMethod );
115  }
116
117  /** Fits the mixture (or mixing) distribution to the data.
118   *  @param data the data supposedly generated from the mixture
119   *  @param method the method to be used. Refer to the static final
120   *  variables of this class. */
121  public void fit( DoubleVector data, int method ) {
122    DoubleVector data2 = (DoubleVector) data.clone();
123    if( data2.unsorted() ) data2.sort();
124
125    int n = data2.size();
126    int start = 0;
127    DoubleVector subset;
128    DiscreteFunction d = new DiscreteFunction();
129    for( int i = 0; i < n-1; i++ ) {
130      if( separable( data2, start, i, data2.get(i+1) ) &&
131          separable( data2, i+1, n-1, data2.get(i) ) ) {
132        subset = (DoubleVector) data2.subvector( start, i );
133        d.plusEquals( fitForSingleCluster( subset, method ).
134                      timesEquals(i - start + 1) );
135        start = i + 1;
136      }
137    }
138    subset = (DoubleVector) data2.subvector( start, n-1 );
139    d.plusEquals( fitForSingleCluster( subset, method ).
140                  timesEquals(n - start) ); 
141    d.sort();
142    d.normalize();
143    mixingDistribution = d;
144  }
145   
146  /**
147   *  Fits the mixture (or mixing) distribution to the data. The data is
148   *  not pre-clustered for computational efficiency.
149   * 
150   *  @param data the data supposedly generated from the mixture
151   *  @param method the method to be used. Refer to the static final
152   *  variables of this class.
153   *  @return the generated distribution
154   */
155  public DiscreteFunction fitForSingleCluster( DoubleVector data, 
156                                               int method ) {
157   
158    if( data.size() < 2 ) return new DiscreteFunction( data );
159    DoubleVector sp = supportPoints( data, 0 );
160    PaceMatrix fi = fittingIntervals( data );
161    PaceMatrix pm = probabilityMatrix( sp, fi );
162    PaceMatrix epm = new 
163      PaceMatrix( empiricalProbability( data, fi ).
164                  timesEquals( 1. / data.size() ) );
165   
166    IntVector pvt = (IntVector) IntVector.seq(0, sp.size()-1);
167    DoubleVector weights;
168   
169    switch( method ) {
170    case NNMMethod: 
171      weights = pm.nnls( epm, pvt );
172      break;
173    case PMMethod:
174      weights = pm.nnlse1( epm, pvt );
175      break;
176    default: 
177      throw new IllegalArgumentException("unknown method");
178    }
179   
180    DoubleVector sp2 = new DoubleVector( pvt.size() );
181    for( int i = 0; i < sp2.size(); i++ ){
182      sp2.set( i, sp.get(pvt.get(i)) );
183    }
184   
185    DiscreteFunction d = new DiscreteFunction( sp2, weights );
186    d.sort();
187    d.normalize();
188    return d;
189  }
190   
191  /**
192   *  Return true if a value can be considered for mixture estimatino
193   *  separately from the data indexed between i0 and i1
194   * 
195   *  @param data the data supposedly generated from the mixture
196   *  @param i0 the index of the first element in the group
197   *  @param i1 the index of the last element in the group
198   *  @param x the value
199   *  @return true if a value can be considered
200   */
201  public abstract boolean separable( DoubleVector data, 
202                                     int i0, int i1, double x );
203   
204  /**
205   *  Contructs the set of support points for mixture estimation.
206   * 
207   *  @param data the data supposedly generated from the mixture
208   *  @param ne the number of extra data that are suppposedly discarded
209   *  earlier and not passed into here
210   *  @return the set of support points
211   */
212  public abstract DoubleVector  supportPoints( DoubleVector data, int ne );
213   
214  /**
215   *  Contructs the set of fitting intervals for mixture estimation.
216   * 
217   *  @param data the data supposedly generated from the mixture
218   *  @return the set of fitting intervals
219   */
220  public abstract PaceMatrix  fittingIntervals( DoubleVector data );
221 
222  /**
223   *  Contructs the probability matrix for mixture estimation, given a set
224   *  of support points and a set of intervals.
225   * 
226   *  @param s  the set of support points
227   *  @param intervals the intervals
228   *  @return the probability matrix
229   */
230  public abstract PaceMatrix  probabilityMatrix( DoubleVector s, 
231                                                 PaceMatrix intervals );
232   
233  /**
234   *  Computes the empirical probabilities of the data over a set of
235   *  intervals.
236   * 
237   *  @param data the data
238   *  @param intervals the intervals
239   *  @return the empirical probabilities
240   */
241  public PaceMatrix  empiricalProbability( DoubleVector data, 
242                                           PaceMatrix intervals )
243  {
244    int n = data.size();
245    int k = intervals.getRowDimension();
246    PaceMatrix epm = new PaceMatrix( k, 1, 0 );
247   
248    double point;
249    for( int j = 0; j < n; j ++ ) {
250      for(int i = 0; i < k; i++ ) {
251        point = 0.0;
252        if( intervals.get(i, 0) == data.get(j) || 
253            intervals.get(i, 1) == data.get(j) ) point = 0.5;
254        else if( intervals.get(i, 0) < data.get(j) && 
255                 intervals.get(i, 1) > data.get(j) ) point = 1.0;
256        epm.setPlus( i, 0, point);
257      }
258    }
259    return epm;
260  }
261 
262  /**
263   * Converts to a string
264   *
265   * @return a string representation
266   */
267  public String  toString() 
268  {
269    return "The mixing distribution:\n" + mixingDistribution.toString();
270  }
271   
272}
273
Note: See TracBrowser for help on using the repository browser.