source: src/main/java/weka/estimators/EstimatorUtils.java @ 6

Last change on this file since 6 was 4, checked in by gnappo, 14 years ago

Import di weka.

File size: 9.6 KB
Line 
1/*
2 *    This program is free software; you can redistribute it and/or modify
3 *    it under the terms of the GNU General Public License as published by
4 *    the Free Software Foundation; either version 2 of the License, or
5 *    (at your option) any later version.
6 *
7 *    This program is distributed in the hope that it will be useful,
8 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
9 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
10 *    GNU General Public License for more details.
11 *
12 *    You should have received a copy of the GNU General Public License
13 *    along with this program; if not, write to the Free Software
14 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
15 */
16
17/*
18 *    EstimatorUtils.java
19 *    Copyright (C) 2004 University of Waikato, Hamilton, New Zealand
20 *
21 */
22
23package weka.estimators;
24
25import weka.core.Instance;
26import weka.core.Instances;
27import weka.core.RevisionHandler;
28import weka.core.RevisionUtils;
29
30import java.io.FileOutputStream;
31import java.io.PrintWriter;
32import java.util.Enumeration;
33import java.util.Vector;
34 
35/**
36 * Contains static utility functions for Estimators.<p>
37 *
38 * @author Gabi Schmidberger (gabi@cs.waikato.ac.nz)
39 * @version $Revision: 1.4 $
40 */
41public class EstimatorUtils
42  implements RevisionHandler {
43 
44  /**
45   * Find the minimum distance between values
46   * @param inst sorted instances, sorted
47   * @param attrIndex index of the attribute, they are sorted after
48   * @return the minimal distance
49   */
50  public static double findMinDistance(Instances inst, int attrIndex) {
51    double min = Double.MAX_VALUE;
52    int numInst = inst.numInstances();
53    double diff;
54    if (numInst < 2) return min;
55    int begin = -1;
56    Instance instance = null;
57    do { 
58      begin++;
59      if (begin < numInst) 
60        { instance = inst.instance(begin); }
61    } while (begin < numInst && instance.isMissing(attrIndex)); 
62
63    double secondValue = inst.instance(begin).value(attrIndex);
64    for (int i = begin; i < numInst && !inst.instance(i).isMissing(attrIndex);  i++) {
65      double firstValue = secondValue; 
66      secondValue = inst.instance(i).value(attrIndex);
67      if (secondValue != firstValue) {
68        diff = secondValue - firstValue;
69        if (diff < min && diff > 0.0) {
70          min = diff;
71        }
72      }
73    }
74    return min;
75  }
76
77  /**
78   * Find the minimum and the maximum of the attribute and return it in
79   * the last parameter..
80   * @param inst instances used to build the estimator
81   * @param attrIndex index of the attribute
82   * @param minMax the array to return minimum and maximum in
83   * @return number of not missing values
84   * @exception Exception if parameter minMax wasn't initialized properly
85   */
86  public static int getMinMax(Instances inst, int attrIndex, double [] minMax) 
87    throws Exception {
88    double min = Double.NaN;
89    double max = Double.NaN;
90    Instance instance = null;
91    int numNotMissing = 0;
92    if ((minMax == null) || (minMax.length < 2)) {
93      throw new Exception("Error in Program, privat method getMinMax");
94    }
95   
96    Enumeration enumInst = inst.enumerateInstances();
97    if (enumInst.hasMoreElements()) {
98      do {
99        instance = (Instance) enumInst.nextElement();
100      } while (instance.isMissing(attrIndex) && (enumInst.hasMoreElements()));
101     
102      // add values if not  missing
103      if (!instance.isMissing(attrIndex)) {
104        numNotMissing++;
105        min = instance.value(attrIndex);
106        max = instance.value(attrIndex);
107      }
108      while (enumInst.hasMoreElements()) {
109        instance = (Instance) enumInst.nextElement();
110        if (!instance.isMissing(attrIndex)) {
111          numNotMissing++;
112          if (instance.value(attrIndex) < min) {
113            min = (instance.value(attrIndex));
114          } else {
115            if (instance.value(attrIndex) > max) {           
116              max = (instance.value(attrIndex));
117            }
118          }
119        }
120      }
121    }
122    minMax[0] = min;
123    minMax[1] = max;
124    return numNotMissing;
125  }
126
127  /**
128   * Returns a dataset that contains all instances of a certain class value.
129   *
130   * @param data dataset to select the instances from
131   * @param attrIndex index of the relevant attribute
132   * @param classIndex index of the class attribute
133   * @param classValue the relevant class value
134   * @return a dataset with only
135   */
136  public static Vector getInstancesFromClass(Instances data, int attrIndex,
137                                             int classIndex,
138                                             double classValue, Instances workData) {
139    //Oops.pln("getInstancesFromClass classValue"+classValue+" workData"+data.numInstances());
140    Vector dataPlusInfo = new Vector(0);
141    int num = 0;
142    int numClassValue = 0;
143    //workData = new Instances(data, 0);
144    for (int i = 0; i < data.numInstances(); i++) {
145      if (!data.instance(i).isMissing(attrIndex)) {
146        num++;
147        if (data.instance(i).value(classIndex) == classValue) {
148          workData.add(data.instance(i));
149          numClassValue++;
150        }
151      }
152    } 
153
154    Double alphaFactor = new Double((double)numClassValue/(double)num);
155    dataPlusInfo.add(workData);
156    dataPlusInfo.add(alphaFactor);
157    return dataPlusInfo;
158  }
159
160
161  /**
162   * Returns a dataset that contains of all instances of a certain class value.
163   * @param data dataset to select the instances from
164   * @param classIndex index of the class attribute
165   * @param classValue the class value
166   * @return a dataset with only instances of one class value
167   */
168  public static Instances getInstancesFromClass(Instances data, int classIndex,
169                                                double classValue) {
170     Instances workData = new Instances(data, 0);
171    for (int i = 0; i < data.numInstances(); i++) {
172      if (data.instance(i).value(classIndex) == classValue) {
173        workData.add(data.instance(i));
174      }
175     
176    }
177    return workData;
178  }
179 
180   
181   
182  /**
183   * Output of an n points of a density curve.
184   * Filename is parameter f + ".curv".
185   *
186   * @param f string to build filename
187   * @param est
188   * @param min
189   * @param max
190   * @param numPoints
191   * @throws Exception if something goes wrong
192   */
193  public static void writeCurve(String f, Estimator est, 
194                                double min, double max,
195                                int numPoints) throws Exception {
196
197    PrintWriter output = null;
198    StringBuffer text = new StringBuffer("");
199   
200    if (f.length() != 0) {
201      // add attribute indexnumber to filename and extension .hist
202      String name = f + ".curv";
203      output = new PrintWriter(new FileOutputStream(name));
204    } else {
205      return;
206    }
207
208    double diff = (max - min) / ((double)numPoints - 1.0);
209    try {
210      text.append("" + min + " " + est.getProbability(min) + " \n");
211
212      for (double value = min + diff; value < max; value += diff) {
213        text.append("" + value + " " + est.getProbability(value) + " \n");
214      }
215      text.append("" + max + " " + est.getProbability(max) + " \n");
216    } catch (Exception ex) {
217      ex.printStackTrace();
218      System.out.println(ex.getMessage());
219    }
220    output.println(text.toString());   
221
222    // close output
223    if (output != null) {
224      output.close();
225    }
226  }
227
228  /**
229   * Output of an n points of a density curve.
230   * Filename is parameter f + ".curv".
231   *
232   * @param f string to build filename
233   * @param est
234   * @param classEst
235   * @param classIndex
236   * @param min
237   * @param max
238   * @param numPoints
239   * @throws Exception if something goes wrong
240   */
241  public static void writeCurve(String f, Estimator est, 
242                                Estimator classEst,
243                                double classIndex,
244                                double min, double max,
245                                int numPoints) throws Exception {
246
247    PrintWriter output = null;
248    StringBuffer text = new StringBuffer("");
249   
250    if (f.length() != 0) {
251      // add attribute indexnumber to filename and extension .hist
252      String name = f + ".curv";
253      output = new PrintWriter(new FileOutputStream(name));
254    } else {
255      return;
256    }
257
258    double diff = (max - min) / ((double)numPoints - 1.0);
259    try {
260      text.append("" + min + " " + 
261                  est.getProbability(min) * classEst.getProbability(classIndex)
262                  + " \n");
263
264      for (double value = min + diff; value < max; value += diff) {
265        text.append("" + value + " " + 
266                    est.getProbability(value) * classEst.getProbability(classIndex)
267                    + " \n");
268      }
269      text.append("" + max + " " +
270                  est.getProbability(max) * classEst.getProbability(classIndex)
271                  + " \n");
272    } catch (Exception ex) {
273      ex.printStackTrace();
274      System.out.println(ex.getMessage());
275    }
276    output.println(text.toString());   
277
278    // close output
279    if (output != null) {
280      output.close();
281    }
282  }
283
284 
285  /**
286   * Returns a dataset that contains of all instances of a certain value
287   * for the given attribute.
288   * @param data dataset to select the instances from
289   * @param index the index of the attribute 
290   * @param v the value
291   * @return a subdataset with only instances of one value for the attribute
292   */
293  public static Instances getInstancesFromValue(Instances data, int index,
294                                          double v) {
295    Instances workData = new Instances(data, 0);
296    for (int i = 0; i < data.numInstances(); i++) {
297      if (data.instance(i).value(index) == v) {
298        workData.add(data.instance(i));
299      }
300    } 
301    return workData;
302  }
303
304   
305  /**
306   * Returns a string representing the cutpoints
307   */
308  public static String cutpointsToString(double [] cutPoints, boolean [] cutAndLeft) {
309    StringBuffer text = new StringBuffer("");
310    if (cutPoints == null) {
311      text.append("\n# no cutpoints found - attribute \n"); 
312    } else {
313      text.append("\n#* "+cutPoints.length+" cutpoint(s) -\n"); 
314      for (int i = 0; i < cutPoints.length; i++) {
315        text.append("# "+cutPoints[i]+" "); 
316        text.append(""+cutAndLeft[i]+"\n");
317      }
318      text.append("# end\n");
319    }
320    return text.toString();
321  }
322 
323  /**
324   * Returns the revision string.
325   *
326   * @return            the revision
327   */
328  public String getRevision() {
329    return RevisionUtils.extract("$Revision: 1.4 $");
330  }
331}
Note: See TracBrowser for help on using the repository browser.