source: src/main/java/weka/classifiers/functions/RBFNetwork.java @ 7

Last change on this file since 7 was 4, checked in by gnappo, 14 years ago

Import of weka.

File size: 15.7 KB
/*
 *    This program is free software; you can redistribute it and/or modify
 *    it under the terms of the GNU General Public License as published by
 *    the Free Software Foundation; either version 2 of the License, or
 *    (at your option) any later version.
 *
 *    This program is distributed in the hope that it will be useful,
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *    GNU General Public License for more details.
 *
 *    You should have received a copy of the GNU General Public License
 *    along with this program; if not, write to the Free Software
 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 */


/*
 *    RBFNetwork.java
 *    Copyright (C) 2004 University of Waikato, Hamilton, New Zealand
 *
 */
package weka.classifiers.functions;

import weka.classifiers.Classifier;
import weka.classifiers.AbstractClassifier;
import weka.clusterers.MakeDensityBasedClusterer;
import weka.clusterers.SimpleKMeans;
import weka.core.Capabilities;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.OptionHandler;
import weka.core.RevisionUtils;
import weka.core.SelectedTag;
import weka.core.Utils;
import weka.filters.Filter;
import weka.filters.unsupervised.attribute.ClusterMembership;
import weka.filters.unsupervised.attribute.Standardize;

import java.util.Enumeration;
import java.util.Vector;

/**
 <!-- globalinfo-start -->
 * Class that implements a normalized Gaussian radial basis function network.<br/>
 * It uses the k-means clustering algorithm to provide the basis functions and learns either a logistic regression (discrete class problems) or linear regression (numeric class problems) on top of that. Symmetric multivariate Gaussians are fit to the data from each cluster. If the class is nominal it uses the given number of clusters per class. It standardizes all numeric attributes to zero mean and unit variance.
 * <p/>
 <!-- globalinfo-end -->
 *
 <!-- options-start -->
 * Valid options are: <p/>
 *
 * <pre> -B &lt;number&gt;
 *  Set the number of clusters (basis functions) to generate. (default = 2).</pre>
 *
 * <pre> -S &lt;seed&gt;
 *  Set the random seed to be used by K-means. (default = 1).</pre>
 *
 * <pre> -R &lt;ridge&gt;
 *  Set the ridge value for the logistic or linear regression.</pre>
 *
 * <pre> -M &lt;number&gt;
 *  Set the maximum number of iterations for the logistic regression. (default -1, until convergence).</pre>
 *
 * <pre> -W &lt;number&gt;
 *  Set the minimum standard deviation for the clusters. (default 0.1).</pre>
 *
 <!-- options-end -->
 *
 * @author Mark Hall
 * @author Eibe Frank
 * @version $Revision: 5928 $
 */
public class RBFNetwork extends AbstractClassifier implements OptionHandler {

  /** for serialization */
  static final long serialVersionUID = -3669814959712675720L;

  /** The logistic regression for classification problems */
  private Logistic m_logistic;

  /** The linear regression for numeric problems */
  private LinearRegression m_linear;

  /** The filter for producing the meta data */
  private ClusterMembership m_basisFilter;

  /** Filter used for normalizing the data */
  private Standardize m_standardize;

  /** The number of clusters (basis functions to generate) */
  private int m_numClusters = 2;

  /** The ridge parameter for the logistic regression. */
  protected double m_ridge = 1e-8;

  /** The maximum number of iterations for logistic regression. */
  private int m_maxIts = -1;

  /** The seed to pass on to K-means */
  private int m_clusteringSeed = 1;

  /** The minimum standard deviation */
  private double m_minStdDev = 0.1;

  /** a ZeroR model in case no model can be built from the data */
  private Classifier m_ZeroR;

  /**
   * Returns a string describing this classifier
   * @return a description of the classifier suitable for
   * displaying in the explorer/experimenter gui
   */
  public String globalInfo() {
116    return "Class that implements a normalized Gaussian radial basis" 
117      + "basis function network.\n"
118      + "It uses the k-means clustering algorithm to provide the basis "
119      + "functions and learns either a logistic regression (discrete "
120      + "class problems) or linear regression (numeric class problems) "
121      + "on top of that. Symmetric multivariate Gaussians are fit to "
122      + "the data from each cluster. If the class is "
123      + "nominal it uses the given number of clusters per class."
124      + "It standardizes all numeric "
125      + "attributes to zero mean and unit variance." ;
  }

  /**
   * Returns default capabilities of the classifier, i.e., an "or" of
   * Logistic and LinearRegression.
   *
   * @return      the capabilities of this classifier
   * @see         Logistic
   * @see         LinearRegression
   */
  public Capabilities getCapabilities() {
    Capabilities result = new Logistic().getCapabilities();
    result.or(new LinearRegression().getCapabilities());
    Capabilities classes = result.getClassCapabilities();
    result.and(new SimpleKMeans().getCapabilities());
    result.or(classes);
    return result;
  }
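
  /*
   * Illustrative sketch (not part of the original class): checking up front
   * whether this classifier can handle a given dataset via the Capabilities
   * API used above. The variable "data" is an assumption standing for an
   * Instances object whose class index has already been set.
   *
   *   RBFNetwork rbf = new RBFNetwork();
   *   try {
   *     rbf.getCapabilities().testWithFail(data);  // throws if the data is unsupported
   *     rbf.buildClassifier(data);
   *   } catch (Exception e) {
   *     System.err.println("RBFNetwork cannot handle this data: " + e.getMessage());
   *   }
   */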

  /**
   * Builds the classifier
   *
   * @param instances the training data
   * @throws Exception if the classifier could not be built successfully
   */
  public void buildClassifier(Instances instances) throws Exception {

    // can classifier handle the data?
    getCapabilities().testWithFail(instances);

    // remove instances with missing class
    instances = new Instances(instances);
    instances.deleteWithMissingClass();

    // only class? -> build ZeroR model
    if (instances.numAttributes() == 1) {
      System.err.println(
          "Cannot build model (only class attribute present in data!), "
          + "using ZeroR model instead!");
      m_ZeroR = new weka.classifiers.rules.ZeroR();
      m_ZeroR.buildClassifier(instances);
      return;
    }
    else {
      m_ZeroR = null;
    }

    m_standardize = new Standardize();
    m_standardize.setInputFormat(instances);
    instances = Filter.useFilter(instances, m_standardize);

    SimpleKMeans sk = new SimpleKMeans();
    sk.setNumClusters(m_numClusters);
    sk.setSeed(m_clusteringSeed);
    MakeDensityBasedClusterer dc = new MakeDensityBasedClusterer();
    dc.setClusterer(sk);
    dc.setMinStdDev(m_minStdDev);
    m_basisFilter = new ClusterMembership();
    m_basisFilter.setDensityBasedClusterer(dc);
    m_basisFilter.setInputFormat(instances);
    Instances transformed = Filter.useFilter(instances, m_basisFilter);

    if (instances.classAttribute().isNominal()) {
      m_linear = null;
      m_logistic = new Logistic();
      m_logistic.setRidge(m_ridge);
      m_logistic.setMaxIts(m_maxIts);
      m_logistic.buildClassifier(transformed);
    } else {
      m_logistic = null;
      m_linear = new LinearRegression();
      m_linear.setAttributeSelectionMethod(new SelectedTag(LinearRegression.SELECTION_NONE,
                                                           LinearRegression.TAGS_SELECTION));
      m_linear.setRidge(m_ridge);
      m_linear.buildClassifier(transformed);
    }
  }
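
  /*
   * Usage sketch (illustrative only): training an RBFNetwork from Java code.
   * The file name "training.arff" and the choice of the last attribute as the
   * class are assumptions made for the example, not requirements of this class.
   *
   *   Instances data =
   *     weka.core.converters.ConverterUtils.DataSource.read("training.arff");
   *   data.setClassIndex(data.numAttributes() - 1);
   *
   *   RBFNetwork rbf = new RBFNetwork();
   *   rbf.setNumClusters(3);        // -B 3: three basis functions
   *   rbf.setClusteringSeed(1);     // -S 1: seed for K-means
   *   rbf.setRidge(1e-8);           // -R: ridge for the logistic/linear model
   *   rbf.buildClassifier(data);    // standardize -> cluster -> fit regression
   *   System.out.println(rbf);
   */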

  /**
   * Computes the distribution for a given instance
   *
   * @param instance the instance for which distribution is computed
   * @return the distribution
   * @throws Exception if the distribution can't be computed successfully
   */
  public double [] distributionForInstance(Instance instance) 
    throws Exception {

    // default model?
    if (m_ZeroR != null) {
      return m_ZeroR.distributionForInstance(instance);
    }

    m_standardize.input(instance);
    m_basisFilter.input(m_standardize.output());
    Instance transformed = m_basisFilter.output();

    return ((instance.classAttribute().isNominal()
             ? m_logistic.distributionForInstance(transformed)
             : m_linear.distributionForInstance(transformed)));
  }
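
  /*
   * Prediction sketch (illustrative only), assuming "rbf" was trained as in
   * the sketch after buildClassifier and "test" is an Instances object with
   * the same structure as the training data. distributionForInstance returns
   * class probabilities for nominal classes and a single-element array for
   * numeric classes.
   *
   *   Instance first = test.instance(0);
   *   double[] dist = rbf.distributionForInstance(first);
   *   double predicted = rbf.classifyInstance(first);  // most likely class index,
   *                                                    // or the numeric prediction
   */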

  /**
   * Returns a description of this classifier as a String
   *
   * @return a description of this classifier
   */
  public String toString() {

    // only ZeroR model?
    if (m_ZeroR != null) {
      StringBuffer buf = new StringBuffer();
      buf.append(this.getClass().getName().replaceAll(".*\\.", "") + "\n");
      buf.append(this.getClass().getName().replaceAll(".*\\.", "").replaceAll(".", "=") + "\n\n");
      buf.append("Warning: No model could be built, hence ZeroR model is used:\n\n");
      buf.append(m_ZeroR.toString());
      return buf.toString();
    }

    if (m_basisFilter == null) {
      return "No classifier built yet!";
    }

    StringBuffer sb = new StringBuffer();
    sb.append("Radial basis function network\n");
    sb.append((m_linear == null) 
              ? "(Logistic regression "
              : "(Linear regression ");
    sb.append("applied to K-means clusters as basis functions):\n\n");
    sb.append((m_linear == null)
              ? m_logistic.toString()
              : m_linear.toString());
    return sb.toString();
  }

  /**
   * Returns the tip text for this property
   * @return tip text for this property suitable for
   * displaying in the explorer/experimenter gui
   */
  public String maxItsTipText() {
    return "Maximum number of iterations for the logistic regression to perform. "
      +"Only applied to discrete class problems.";
  }

  /**
   * Get the value of MaxIts.
   *
   * @return Value of MaxIts.
   */
  public int getMaxIts() {

    return m_maxIts;
  }

  /**
   * Set the value of MaxIts.
   *
   * @param newMaxIts Value to assign to MaxIts.
   */
  public void setMaxIts(int newMaxIts) {

    m_maxIts = newMaxIts;
  }

  /**
   * Returns the tip text for this property
   * @return tip text for this property suitable for
   * displaying in the explorer/experimenter gui
   */
  public String ridgeTipText() {
    return "Set the Ridge value for the logistic or linear regression.";
  }

  /**
   * Sets the ridge value for logistic or linear regression.
   *
   * @param ridge the ridge
   */
  public void setRidge(double ridge) {
    m_ridge = ridge;
  }

  /**
   * Gets the ridge value.
   *
   * @return the ridge
   */
  public double getRidge() {
    return m_ridge;
  }

  /**
   * Returns the tip text for this property
   * @return tip text for this property suitable for
   * displaying in the explorer/experimenter gui
   */
  public String numClustersTipText() {
    return "The number of clusters for K-Means to generate.";
  }

  /**
   * Set the number of clusters for K-means to generate.
   *
   * @param numClusters the number of clusters to generate.
   */
  public void setNumClusters(int numClusters) {
    if (numClusters > 0) {
      m_numClusters = numClusters;
    }
  }

  /**
   * Return the number of clusters to generate.
   *
   * @return the number of clusters to generate.
   */
  public int getNumClusters() {
    return m_numClusters;
  }

  /**
   * Returns the tip text for this property
   * @return tip text for this property suitable for
   * displaying in the explorer/experimenter gui
   */
  public String clusteringSeedTipText() {
    return "The random seed to pass on to K-means.";
  }

  /**
   * Set the random seed to be passed on to K-means.
   *
   * @param seed a seed value.
   */
  public void setClusteringSeed(int seed) {
    m_clusteringSeed = seed;
  }

  /**
   * Get the random seed used by K-means.
   *
   * @return the seed value.
   */
  public int getClusteringSeed() {
    return m_clusteringSeed;
  }

  /**
   * Returns the tip text for this property
   * @return tip text for this property suitable for
   * displaying in the explorer/experimenter gui
   */
  public String minStdDevTipText() {
    return "Sets the minimum standard deviation for the clusters.";
  }

  /**
   * Get the MinStdDev value.
   * @return the MinStdDev value.
   */
  public double getMinStdDev() {
    return m_minStdDev;
  }

  /**
   * Set the MinStdDev value.
   * @param newMinStdDev The new MinStdDev value.
   */
  public void setMinStdDev(double newMinStdDev) {
    m_minStdDev = newMinStdDev;
  }


  /**
   * Returns an enumeration describing the available options
   *
   * @return an enumeration of all the available options
   */
  public Enumeration listOptions() {
    Vector newVector = new Vector(5);

    newVector.addElement(new Option("\tSet the number of clusters (basis functions) "
                                    +"to generate. (default = 2).",
                                    "B", 1, "-B <number>"));
    newVector.addElement(new Option("\tSet the random seed to be used by K-means. "
                                    +"(default = 1).",
                                    "S", 1, "-S <seed>"));
    newVector.addElement(new Option("\tSet the ridge value for the logistic or "
                                    +"linear regression.",
                                    "R", 1, "-R <ridge>"));
    newVector.addElement(new Option("\tSet the maximum number of iterations "
                                    +"for the logistic regression."
                                    + " (default -1, until convergence).",
                                    "M", 1, "-M <number>"));
    newVector.addElement(new Option("\tSet the minimum standard "
                                    +"deviation for the clusters."
                                    + " (default 0.1).",
                                    "W", 1, "-W <number>"));
    return newVector.elements();
  }

  /**
   * Parses a given list of options. <p/>
   *
   <!-- options-start -->
   * Valid options are: <p/>
   *
   * <pre> -B &lt;number&gt;
   *  Set the number of clusters (basis functions) to generate. (default = 2).</pre>
   *
   * <pre> -S &lt;seed&gt;
   *  Set the random seed to be used by K-means. (default = 1).</pre>
   *
   * <pre> -R &lt;ridge&gt;
   *  Set the ridge value for the logistic or linear regression.</pre>
   *
   * <pre> -M &lt;number&gt;
   *  Set the maximum number of iterations for the logistic regression. (default -1, until convergence).</pre>
   *
   * <pre> -W &lt;number&gt;
   *  Set the minimum standard deviation for the clusters. (default 0.1).</pre>
   *
   <!-- options-end -->
   *
   * @param options the list of options as an array of strings
   * @throws Exception if an option is not supported
   */
  public void setOptions(String[] options) throws Exception {
    setDebug(Utils.getFlag('D', options));

    String ridgeString = Utils.getOption('R', options);
    if (ridgeString.length() != 0) {
      m_ridge = Double.parseDouble(ridgeString);
    } else {
      m_ridge = 1.0e-8;
    }

    String maxItsString = Utils.getOption('M', options);
    if (maxItsString.length() != 0) {
      m_maxIts = Integer.parseInt(maxItsString);
    } else {
      m_maxIts = -1;
    }

    String numClustersString = Utils.getOption('B', options);
    if (numClustersString.length() != 0) {
      setNumClusters(Integer.parseInt(numClustersString));
    }

    String seedString = Utils.getOption('S', options);
    if (seedString.length() != 0) {
      setClusteringSeed(Integer.parseInt(seedString));
    }
    String stdString = Utils.getOption('W', options);
    if (stdString.length() != 0) {
      setMinStdDev(Double.parseDouble(stdString));
    }
    Utils.checkForRemainingOptions(options);
  }
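
  /*
   * Options sketch (illustrative only): the same settings can be applied
   * through this option-handling API. The option string below simply mirrors
   * the flags documented above; Utils.splitOptions is the standard Weka helper
   * for turning a command line into a String array.
   *
   *   RBFNetwork rbf = new RBFNetwork();
   *   rbf.setOptions(weka.core.Utils.splitOptions("-B 3 -S 1 -R 1.0E-8 -M -1 -W 0.1"));
   */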

  /**
   * Gets the current settings of the classifier.
   *
   * @return an array of strings suitable for passing to setOptions
   */
  public String [] getOptions() {

    String [] options = new String [10];
    int current = 0;

    options[current++] = "-B";
    options[current++] = "" + m_numClusters;
    options[current++] = "-S";
    options[current++] = "" + m_clusteringSeed;
    options[current++] = "-R";
    options[current++] = ""+m_ridge;
    options[current++] = "-M";
    options[current++] = ""+m_maxIts;
    options[current++] = "-W";
    options[current++] = ""+m_minStdDev;

    while (current < options.length) 
      options[current++] = "";
    return options;
  }

  /**
   * Returns the revision string.
   *
   * @return            the revision
   */
  public String getRevision() {
    return RevisionUtils.extract("$Revision: 5928 $");
  }

  /**
   * Main method for testing this class.
   *
   * @param argv should contain the command line arguments to the
   * scheme (see Evaluation)
   */
  public static void main(String [] argv) {
    runClassifier(new RBFNetwork(), argv);
  }
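
  /*
   * Command-line sketch (illustrative only): main delegates to the standard
   * Weka evaluation harness, so the usual evaluation flags apply in addition
   * to the scheme options listed above. The file name "training.arff" is an
   * assumption made for the example.
   *
   *   java weka.classifiers.functions.RBFNetwork -t training.arff -B 3 -S 1
   */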
}