/*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation; either version 2 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program; if not, write to the Free Software
* Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
*/
/*
* GridSearch.java
* Copyright (C) 2006 University of Waikato, Hamilton, New Zealand
*/
package weka.classifiers.meta;
import weka.classifiers.Classifier;
import weka.classifiers.AbstractClassifier;
import weka.classifiers.Evaluation;
import weka.classifiers.RandomizableSingleClassifierEnhancer;
import weka.classifiers.functions.LinearRegression;
import weka.core.AdditionalMeasureProducer;
import weka.core.Capabilities;
import weka.core.Debug;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.MathematicalExpression;
import weka.core.Option;
import weka.core.OptionHandler;
import weka.core.PropertyPath;
import weka.core.RevisionHandler;
import weka.core.RevisionUtils;
import weka.core.SelectedTag;
import weka.core.SerializedObject;
import weka.core.Summarizable;
import weka.core.Tag;
import weka.core.Utils;
import weka.core.Capabilities.Capability;
import weka.filters.Filter;
import weka.filters.supervised.attribute.PLSFilter;
import weka.filters.unsupervised.attribute.MathExpression;
import weka.filters.unsupervised.attribute.NumericCleaner;
import weka.filters.unsupervised.instance.Resample;
import java.beans.PropertyDescriptor;
import java.io.File;
import java.io.Serializable;
import java.util.Collections;
import java.util.Comparator;
import java.util.Enumeration;
import java.util.HashMap;
import java.util.Hashtable;
import java.util.Iterator;
import java.util.Random;
import java.util.Vector;
/**
* Performs a grid search of parameter pairs for a classifier (Y-axis, default is LinearRegression with the "Ridge" parameter) and the PLSFilter (X-axis, "# of Components") and chooses the best pair found for the actual predicting.
*
* The initial grid is worked on with 2-fold CV to determine the values of the parameter pairs for the selected type of evaluation (e.g., accuracy). The best point in the grid is then taken and a 10-fold CV is performed with the adjacent parameter pairs. If a better pair is found, then this will act as new center and another 10-fold CV will be performed (kind of hill-climbing). This process is repeated until no better pair is found or the best pair is on the border of the grid.
* In case the best pair is on the border, one can let GridSearch automatically extend the grid and continue the search. Check out the properties 'gridIsExtendable' (option '-extend-grid') and 'maxGridExtensions' (option '-max-grid-extensions <num>').
*
* GridSearch can handle doubles, integers (values are just cast to int) and booleans (0 is false, otherwise true). float, char and long are supported as well.
*
* The best filter/classifier setup can be accessed after the buildClassifier call via the getBestFilter/getBestClassifier methods.
* Note on the implementation: after the data has been passed through the filter, a default NumericCleaner filter is applied to the data in order to avoid numbers that are getting too small and might produce NaNs in other schemes.
*
*
* Valid options are:
*
*
-E <CC|RMSE|RRSE|MAE|RAE|COMB|ACC|KAP>
* Determines the parameter used for evaluation:
* CC = Correlation coefficient
* RMSE = Root mean squared error
* RRSE = Root relative squared error
* MAE = Mean absolute error
* RAE = Relative absolute error
* COMB = Combined = (1-abs(CC)) + RRSE + RAE
* ACC = Accuracy
* KAP = Kappa
* (default: CC)
*
*
-y-property <option>
* The Y option to test (without leading dash).
* (default: classifier.ridge)
*
*
-y-min <num>
* The minimum for Y.
* (default: -10)
*
*
-y-max <num>
* The maximum for Y.
* (default: +5)
*
*
-y-step <num>
* The step size for Y.
* (default: 1)
*
*
-y-base <num>
* The base for Y.
* (default: 10)
*
*
-y-expression <expr>
* The expression for Y.
* Available parameters:
* BASE
* FROM
* TO
* STEP
* I - the current iteration value
* (from 'FROM' to 'TO' with stepsize 'STEP')
* (default: 'pow(BASE,I)')
*
*
-filter <filter specification>
* The filter to use (on X axis). Full classname of filter to include,
* followed by scheme options.
* (default: weka.filters.supervised.attribute.PLSFilter)
*
*
-x-property <option>
* The X option to test (without leading dash).
* (default: filter.numComponents)
*
*
-x-min <num>
* The minimum for X.
* (default: +5)
*
*
-x-max <num>
* The maximum for X.
* (default: +20)
*
*
-x-step <num>
* The step size for X.
* (default: 1)
*
*
-x-base <num>
* The base for X.
* (default: 10)
*
*
-x-expression <expr>
* The expression for the X value.
* Available parameters:
* BASE
* MIN
* MAX
* STEP
* I - the current iteration value
* (from 'FROM' to 'TO' with stepsize 'STEP')
* (default: 'pow(BASE,I)')
*
*
-extend-grid
* Whether the grid can be extended.
* (default: no)
*
*
-max-grid-extensions <num>
* The maximum number of grid extensions (-1 is unlimited).
* (default: 3)
*
*
-sample-size <num>
* The size (in percent) of the sample to search the initial grid with.
* (default: 100)
*
*
-traversal <ROW-WISE|COLUMN-WISE>
* The type of traversal for the grid.
* (default: COLUMN-WISE)
*
*
-log-file <filename>
* The log file to log the messages to.
* (default: none)
*
*
-S <num>
* Random number seed.
* (default 1)
*
*
-D
* If set, classifier is run in debug mode and
* may output additional info to the console
*
*
-W
* Full name of base classifier.
* (default: weka.classifiers.functions.LinearRegression)
*
*
* Options specific to classifier weka.classifiers.functions.LinearRegression:
*
*
*
-D
* Produce debugging output.
* (default no debugging output)
*
*
-S <number of selection method>
* Set the attribute selection method to use. 1 = None, 2 = Greedy.
* (default 0 = M5' method)
*
*
-C
* Do not try to eliminate colinear attributes.
*
*
*
-R <double>
* Set ridge parameter (default 1.0e-8).
*
*
*
* Options specific to filter weka.filters.supervised.attribute.PLSFilter ('-filter'):
*
*
*
-D
* Turns on output of debugging information.
*
*
-C <num>
* The number of components to compute.
* (default: 20)
*
*
-U
* Updates the class attribute as well.
* (default: off)
*
*
-M
* Turns replacing of missing values on.
* (default: off)
*
*
-A <SIMPLS|PLS1>
* The algorithm to use.
* (default: PLS1)
*
*
-P <none|center|standardize>
* The type of preprocessing that is applied to the data.
* (default: center)
*
*
* Examples:
*
*
* Optimizing SMO with RBFKernel (C and gamma)
*
*
Set the evaluation to Accuracy.
*
Set the filter to weka.filters.AllFilter since we
* don't need any special data processing and we don't optimize the
* filter in this case (data gets always passed through filter!).
*
Set weka.classifiers.functions.SMO as classifier
* with weka.classifiers.functions.supportVector.RBFKernel
* as kernel.
*
*
Set the XProperty to "classifier.c", XMin to "1", XMax to "16",
* XStep to "1" and the XExpression to "I". This will test the "C"
* parameter of SMO for the values from 1 to 16.
*
Set the YProperty to "classifier.kernel.gamma", YMin to "-5",
* YMax to "2", YStep to "1" YBase to "10" and YExpression to
* "pow(BASE,I)". This will test the gamma of the RBFKernel with the
* values 10^-5, 10^-4,..,10^2.
*
*
*
* Optimizing PLSFilter with LinearRegression (# of components and ridge) - default setup
*
*
Set the evaluation to Correlation coefficient.
*
Set the filter to weka.filters.supervised.attribute.PLSFilter.
*
Set weka.classifiers.functions.LinearRegression as
* classifier and use no attribute selection and no elimination of
* colinear attributes.
*
Set the XProperty to "filter.numComponents", XMin to "5", XMax
* to "20" (this depends heavily on your dataset, should be no more
* than the number of attributes!), XStep to "1" and XExpression to
* "I". This will test the number of components the PLSFilter will
* produce from 5 to 20.
*
Set the YProperty to "classifier.ridge", YMin to "-10", YMax to
* "5", YStep to "1" and YExpression to "pow(BASE,I)". This will
* try ridge parameters from 10^-10 to 10^5.
*
*
*
*
* General notes:
*
*
Turn the debug flag on in order to see some progress output in the
* console
*
If you want to view the fitness landscape that GridSearch explores,
* select a log file. This log will then contain Gnuplot data and
* script block for viewing the landscape. Just copy paste those blocks
* into files named accordingly and run Gnuplot with them.
*
*
* @author Bernhard Pfahringer (bernhard at cs dot waikato dot ac dot nz)
* @author Geoff Holmes (geoff at cs dot waikato dot ac dot nz)
* @author fracpete (fracpete at waikato dot ac dot nz)
* @version $Revision: 5928 $
* @see PLSFilter
* @see LinearRegression
* @see NumericCleaner
*/
public class GridSearch
extends RandomizableSingleClassifierEnhancer
implements AdditionalMeasureProducer, Summarizable {
/**
 * A serializable version of Point2D.Double whose equality test is based on
 * {@link Utils#eq(double, double)} instead of an exact comparison.
 *
 * @see java.awt.geom.Point2D.Double
 */
protected class PointDouble
  extends java.awt.geom.Point2D.Double
  implements Serializable, RevisionHandler {

  /** for serialization */
  private static final long serialVersionUID = 7151661776161898119L;

  /**
   * Initializes the point with the given coordinates.
   *
   * @param x		the x value of the point
   * @param y		the y value of the point
   */
  public PointDouble(double x, double y) {
    super(x, y);
  }

  /**
   * Determines whether or not two points are equal. Both coordinates are
   * compared via {@link Utils#eq(double, double)}.
   *
   * NOTE(review): hashCode() is inherited unchanged from the superclass, so
   * two points considered equal here are not guaranteed to share a hash code;
   * avoid using instances as keys in hash-based collections.
   *
   * @param obj 	an object to be compared with this PointDouble
   * @return 		true if the object is a PointDouble with the same values;
   * 			false otherwise (including null and objects of other types)
   */
  @Override
  public boolean equals(Object obj) {
    // guard the cast: null and foreign types are simply "not equal"
    if (!(obj instanceof PointDouble))
      return false;

    PointDouble pd = (PointDouble) obj;

    return (Utils.eq(this.getX(), pd.getX()) && Utils.eq(this.getY(), pd.getY()));
  }

  /**
   * Returns a string representation of the point: the superclass
   * representation with everything in front of the opening bracket removed.
   *
   * @return		the point as string
   */
  @Override
  public String toString() {
    return super.toString().replaceAll(".*\\[", "[");
  }

  /**
   * Returns the revision string.
   *
   * @return		the revision
   */
  public String getRevision() {
    return RevisionUtils.extract("$Revision: 5928 $");
  }
}
/**
 * Integer-valued point (a java.awt.Point) that is flagged as serializable
 * and reports a revision.
 *
 * @see java.awt.Point
 */
protected class PointInt
  extends java.awt.Point
  implements Serializable, RevisionHandler {

  /** for serialization */
  private static final long serialVersionUID = -5900415163698021618L;

  /**
   * Creates the point at the given position.
   *
   * @param x		the x value of the point
   * @param y		the y value of the point
   */
  public PointInt(int x, int y) {
    super(x, y);
  }

  /**
   * Builds the textual form of the point by taking the superclass
   * representation and dropping whatever precedes the opening bracket.
   *
   * @return		the point as string
   */
  public String toString() {
    final String inherited = super.toString();
    final String shortened = inherited.replaceAll(".*\\[", "[");
    return shortened;
  }

  /**
   * Extracts the revision from the embedded revision keyword.
   *
   * @return		the revision
   */
  public String getRevision() {
    final String keyword = "$Revision: 5928 $";
    return RevisionUtils.extract(keyword);
  }
}
/**
 * For generating the parameter pairs in a grid. The grid is defined by a
 * minimum, maximum and step size per axis; points are addressed by integer
 * indices, with index 0 corresponding to the minimum of the axis.
 */
protected class Grid
  implements Serializable, RevisionHandler {

  /** for serialization */
  private static final long serialVersionUID = 7290732613611243139L;

  /** the minimum on the X axis */
  protected double m_MinX;

  /** the maximum on the X axis */
  protected double m_MaxX;

  /** the step size for the X axis */
  protected double m_StepX;

  /** the label for the X axis */
  protected String m_LabelX;

  /** the minimum on the Y axis */
  protected double m_MinY;

  /** the maximum on the Y axis */
  protected double m_MaxY;

  /** the step size for the Y axis */
  protected double m_StepY;

  /** the label for the Y axis */
  protected String m_LabelY;

  /** the number of points on the X axis */
  protected int m_Width;

  /** the number of points on the Y axis */
  protected int m_Height;

  /**
   * Initializes the grid with empty axis labels.
   *
   * @param minX 	the minimum on the X axis
   * @param maxX 	the maximum on the X axis
   * @param stepX 	the step size for the X axis
   * @param minY 	the minimum on the Y axis
   * @param maxY 	the maximum on the Y axis
   * @param stepY 	the step size for the Y axis
   * @throws IllegalArgumentException	if the setup is invalid, see
   * 			{@link #Grid(double, double, double, String, double, double, double, String)}
   */
  public Grid(double minX, double maxX, double stepX,
              double minY, double maxY, double stepY) {
    this(minX, maxX, stepX, "", minY, maxY, stepY, "");
  }

  /**
   * Initializes the grid and validates the setup.
   *
   * @param minX 	the minimum on the X axis
   * @param maxX 	the maximum on the X axis
   * @param stepX 	the step size for the X axis
   * @param labelX	the label for the X axis
   * @param minY 	the minimum on the Y axis
   * @param maxY 	the maximum on the Y axis
   * @param stepY 	the step size for the Y axis
   * @param labelY	the label for the Y axis
   * @throws IllegalArgumentException	if a minimum is not smaller than its
   * 			maximum, a step size is not positive, or the maximum
   * 			cannot be reached from the minimum with whole steps
   */
  public Grid(double minX, double maxX, double stepX, String labelX,
              double minY, double maxY, double stepY, String labelY) {
    super();

    m_MinX   = minX;
    m_MaxX   = maxX;
    m_StepX  = stepX;
    m_LabelX = labelX;
    m_MinY   = minY;
    m_MaxY   = maxY;
    m_StepY  = stepY;
    m_LabelY = labelY;

    // is min < max?
    if (m_MinX >= m_MaxX)
      throw new IllegalArgumentException("XMin must be smaller than XMax!");
    if (m_MinY >= m_MaxY)
      throw new IllegalArgumentException("YMin must be smaller than YMax!");

    // steps positive?
    if (m_StepX <= 0)
      throw new IllegalArgumentException("XStep must be a positive number!");
    if (m_StepY <= 0)
      throw new IllegalArgumentException("YStep must be a positive number!");

    // number of points per axis (borders included); computed only after the
    // step sizes are known to be positive
    m_Height = (int) StrictMath.round((m_MaxY - m_MinY) / m_StepY) + 1;
    m_Width  = (int) StrictMath.round((m_MaxX - m_MinX) / m_StepX) + 1;

    // check borders: min + (n-1)*step must land on max
    if (!Utils.eq(m_MinX + (m_Width-1)*m_StepX, m_MaxX))
      throw new IllegalArgumentException(
          "X axis doesn't match! Provided max: " + m_MaxX
          + ", calculated max via min and step size: "
          + (m_MinX + (m_Width-1)*m_StepX));
    if (!Utils.eq(m_MinY + (m_Height-1)*m_StepY, m_MaxY))
      throw new IllegalArgumentException(
          "Y axis doesn't match! Provided max: " + m_MaxY
          + ", calculated max via min and step size: "
          + (m_MinY + (m_Height-1)*m_StepY));
  }

  /**
   * Tests itself against the provided grid object. Compared are the
   * dimensions, the minima, the step sizes and the labels.
   *
   * NOTE(review): hashCode() is not overridden, so equal grids are not
   * guaranteed to share a hash code.
   *
   * @param o		the grid object to compare against
   * @return		true if the object is a Grid with the same setup;
   * 			false otherwise (including null and other types)
   */
  @Override
  public boolean equals(Object o) {
    // guard the cast: null and foreign types are simply "not equal"
    if (!(o instanceof Grid))
      return false;

    Grid g = (Grid) o;

    return (width() == g.width())
      && (height() == g.height())
      && (getMinX() == g.getMinX())
      && (getMinY() == g.getMinY())
      && (getStepX() == g.getStepX())
      && (getStepY() == g.getStepY())
      && getLabelX().equals(g.getLabelX())
      && getLabelY().equals(g.getLabelY());
  }

  /**
   * returns the left border
   *
   * @return 		the left border
   */
  public double getMinX() {
    return m_MinX;
  }

  /**
   * returns the right border
   *
   * @return 		the right border
   */
  public double getMaxX() {
    return m_MaxX;
  }

  /**
   * returns the step size on the X axis
   *
   * @return 		the step size
   */
  public double getStepX() {
    return m_StepX;
  }

  /**
   * returns the label for the X axis
   *
   * @return		the label
   */
  public String getLabelX() {
    return m_LabelX;
  }

  /**
   * returns the bottom border
   *
   * @return 		the bottom border
   */
  public double getMinY() {
    return m_MinY;
  }

  /**
   * returns the top border
   *
   * @return 		the top border
   */
  public double getMaxY() {
    return m_MaxY;
  }

  /**
   * returns the step size on the Y axis
   *
   * @return 		the step size
   */
  public double getStepY() {
    return m_StepY;
  }

  /**
   * returns the label for the Y axis
   *
   * @return		the label
   */
  public String getLabelY() {
    return m_LabelY;
  }

  /**
   * returns the number of points in the grid on the Y axis (incl. borders)
   *
   * @return 		the number of points in the grid on the Y axis
   */
  public int height() {
    return m_Height;
  }

  /**
   * returns the number of points in the grid on the X axis (incl. borders)
   *
   * @return 		the number of points in the grid on the X axis
   */
  public int width() {
    return m_Width;
  }

  /**
   * returns the values at the given point in the grid
   *
   * @param x		the x-th point on the X axis
   * @param y		the y-th point on the Y axis
   * @return		the value pair at the given position
   * @throws IllegalArgumentException	if an index is not smaller than the
   * 			number of points on its axis (negative indices are
   * 			not checked)
   */
  public PointDouble getValues(int x, int y) {
    if (x >= width())
      throw new IllegalArgumentException("Index out of scope on X axis (" + x + " >= " + width() + ")!");
    if (y >= height())
      throw new IllegalArgumentException("Index out of scope on Y axis (" + y + " >= " + height() + ")!");

    return new PointDouble(m_MinX + m_StepX*x, m_MinY + m_StepY*y);
  }

  /**
   * returns the closest index pair for the given value pair in the grid.
   * Values outside the grid are mapped to the nearest border index.
   *
   * @param values	the values to get the indices for
   * @return		the closest indices in the grid
   */
  public PointInt getLocation(PointDouble values) {
    int x = closestIndex(values.getX(), m_MinX, m_StepX, width());
    int y = closestIndex(values.getY(), m_MinY, m_StepY, height());

    return new PointInt(x, y);
  }

  /**
   * Determines the index on one axis whose grid value is closest to the
   * given value. The first index reaching the smallest distance (according
   * to {@link Utils#sm(double, double)}) wins.
   *
   * @param value	the value to locate
   * @param min		the minimum of the axis
   * @param step	the step size of the axis
   * @param count	the number of points on the axis
   * @return		the index of the closest point, 0 if none qualifies
   */
  private int closestIndex(double value, double min, double step, int count) {
    int best = 0;
    // start with the largest possible distance, so that values lying more
    // than one step beyond the maximum still resolve to the nearest border
    // instead of falling back to index 0
    double distance = Double.MAX_VALUE;

    for (int i = 0; i < count; i++) {
      double currDistance = StrictMath.abs(value - (min + step*i));
      if (Utils.sm(currDistance, distance)) {
        distance = currDistance;
        best     = i;
      }
    }

    return best;
  }

  /**
   * checks whether the given values are on the border of the grid
   *
   * @param values		the values to check
   * @return			true if the values are on the border
   */
  public boolean isOnBorder(PointDouble values) {
    return isOnBorder(getLocation(values));
  }

  /**
   * checks whether the given location is on the border of the grid
   *
   * @param location 		the location to check
   * @return			true if the location is on the border
   */
  public boolean isOnBorder(PointInt location) {
    return (location.getX() == 0)
      || (location.getX() == width() - 1)
      || (location.getY() == 0)
      || (location.getY() == height() - 1);
  }

  /**
   * returns a subgrid with the same step sizes, but different borders.
   * The Y value at index <code>bottom</code> becomes the new minimum and the
   * one at index <code>top</code> the new maximum of the Y axis.
   *
   * @param top	the top index
   * @param left	the left index
   * @param bottom	the bottom index
   * @param right	the right index
   * @return 		the Sub-Grid
   */
  public Grid subgrid(int top, int left, int bottom, int right) {
    return new Grid(
        getValues(left, top).getX(), getValues(right, top).getX(), getStepX(), getLabelX(),
        getValues(left, bottom).getY(), getValues(left, top).getY(), getStepY(), getLabelY());
  }

  /**
   * returns an extended grid that encompasses the given point. Each side the
   * point lies on or beyond is moved outwards by a whole number of steps
   * (one step if the point sits exactly on that border).
   *
   * NOTE(review): for a point outside the grid, the new border is placed at
   * the rounded distance, i.e., the point can end up on the new border
   * rather than strictly inside; a point less than half a step outside does
   * not move the border at all.
   *
   * @param values	the point that the grid should contain
   * @return 		the extended grid
   * @throws IllegalStateException	if the resulting grid equals this one
   */
  public Grid extend(PointDouble values) {
    double minX = getMinX();
    double maxX = getMaxX();
    double minY = getMinY();
    double maxY = getMaxY();

    // left
    if (Utils.smOrEq(values.getX(), getMinX()))
      minX = getMinX() - getStepX() * extensionSteps(getMinX() - values.getX(), getStepX());

    // right
    if (Utils.grOrEq(values.getX(), getMaxX()))
      maxX = getMaxX() + getStepX() * extensionSteps(values.getX() - getMaxX(), getStepX());

    // bottom
    if (Utils.smOrEq(values.getY(), getMinY()))
      minY = getMinY() - getStepY() * extensionSteps(getMinY() - values.getY(), getStepY());

    // top
    if (Utils.grOrEq(values.getY(), getMaxY()))
      maxY = getMaxY() + getStepY() * extensionSteps(values.getY() - getMaxY(), getStepY());

    Grid result = new Grid(minX, maxX, getStepX(), getLabelX(), minY, maxY, getStepY(), getLabelY());

    // did the grid really extend?
    if (equals(result))
      throw new IllegalStateException("Grid extension failed!");

    return result;
  }

  /**
   * Computes by how many steps a border has to be moved outwards.
   *
   * @param distance	the distance between the border and the point
   * @param step	the step size of the axis
   * @return		the rounded number of steps covering the distance, plus
   * 			one if the point is exactly on the border
   */
  private long extensionSteps(double distance, double step) {
    long steps = StrictMath.round(distance / step);

    // exactly on grid point? -> move one step further out
    if (Utils.eq(distance, 0))
      steps++;

    return steps;
  }

  /**
   * returns an Enumeration over all pairs in the given row
   *
   * @param y		the row to retrieve
   * @return		an Enumeration over all pairs
   * @see #getValues(int, int)
   */
  public Enumeration<PointDouble> row(int y) {
    Vector<PointDouble> result = new Vector<PointDouble>();

    for (int i = 0; i < width(); i++)
      result.add(getValues(i, y));

    return result.elements();
  }

  /**
   * returns an Enumeration over all pairs in the given column
   *
   * @param x		the column to retrieve
   * @return		an Enumeration over all pairs
   * @see #getValues(int, int)
   */
  public Enumeration<PointDouble> column(int x) {
    Vector<PointDouble> result = new Vector<PointDouble>();

    for (int i = 0; i < height(); i++)
      result.add(getValues(x, i));

    return result.elements();
  }

  /**
   * returns a string representation of the grid: one line per axis (range,
   * step size and, if not empty, the label) followed by the dimensions.
   *
   * @return a string representation
   */
  @Override
  public String toString() {
    StringBuilder result = new StringBuilder();

    result.append("X: " + m_MinX + " - " + m_MaxX + ", Step " + m_StepX);
    if (m_LabelX.length() != 0)
      result.append(" (" + m_LabelX + ")");
    result.append("\n");

    result.append("Y: " + m_MinY + " - " + m_MaxY + ", Step " + m_StepY);
    if (m_LabelY.length() != 0)
      result.append(" (" + m_LabelY + ")");
    result.append("\n");

    result.append("Dimensions (Rows x Columns): " + height() + " x " + width());

    return result.toString();
  }

  /**
   * Returns the revision string.
   *
   * @return		the revision
   */
  public String getRevision() {
    return RevisionUtils.extract("$Revision: 5928 $");
  }
}
/**
 * Container holding the performance measures obtained for one values-pair.
 * Instances can be ordered with the PerformanceComparator class.
 *
 * @see PerformanceComparator
 */
protected class Performance
  implements Serializable, RevisionHandler {

  /** for serialization */
  private static final long serialVersionUID = -4374706475277588755L;

  /** the value pair the classifier was built with */
  protected PointDouble m_Values;

  /** the Correlation coefficient */
  protected double m_CC;

  /** the Root mean squared error */
  protected double m_RMSE;

  /** the Root relative squared error */
  protected double m_RRSE;

  /** the Mean absolute error */
  protected double m_MAE;

  /** the Relative absolute error */
  protected double m_RAE;

  /** the Accuracy */
  protected double m_ACC;

  /** the kappa value */
  protected double m_Kappa;

  /**
   * Reads all measures from the evaluation and stores them alongside the
   * values-pair. Correlation coefficient, accuracy and kappa are set to NaN
   * if the evaluation fails to provide them.
   *
   * @param values		the values-pair
   * @param evaluation	the evaluation to extract the performance
   * 				measures from
   * @throws Exception	if retrieving of measures fails
   */
  public Performance(PointDouble values, Evaluation evaluation) throws Exception {
    super();

    m_Values = values;

    // error measures: a failure here is passed on to the caller
    m_RMSE = evaluation.rootMeanSquaredError();
    m_RRSE = evaluation.rootRelativeSquaredError();
    m_MAE  = evaluation.meanAbsoluteError();
    m_RAE  = evaluation.relativeAbsoluteError();

    // the remaining measures fall back to NaN when they cannot be obtained
    double cc;
    try {
      cc = evaluation.correlationCoefficient();
    }
    catch (Exception ex) {
      cc = Double.NaN;
    }
    m_CC = cc;

    double acc;
    try {
      acc = evaluation.pctCorrect();
    }
    catch (Exception ex) {
      acc = Double.NaN;
    }
    m_ACC = acc;

    double kappa;
    try {
      kappa = evaluation.kappa();
    }
    catch (Exception ex) {
      kappa = Double.NaN;
    }
    m_Kappa = kappa;
  }

  /**
   * Looks up the stored value for the requested type of measure.
   *
   * @param evaluation	the type of measure to return
   * @return 			the performance measure
   * @throws IllegalArgumentException	if the type is not supported
   */
  public double getPerformance(int evaluation) {
    switch (evaluation) {
      case EVALUATION_CC:
        return m_CC;
      case EVALUATION_RMSE:
        return m_RMSE;
      case EVALUATION_RRSE:
        return m_RRSE;
      case EVALUATION_MAE:
        return m_MAE;
      case EVALUATION_RAE:
        return m_RAE;
      case EVALUATION_COMBINED:
        return (1 - StrictMath.abs(m_CC)) + m_RRSE + m_RAE;
      case EVALUATION_ACC:
        return m_ACC;
      case EVALUATION_KAPPA:
        return m_Kappa;
      default:
        throw new IllegalArgumentException("Evaluation type '" + evaluation + "' not supported!");
    }
  }

  /**
   * Gives access to the values-pair this performance belongs to.
   *
   * @return 			the values-pair
   */
  public PointDouble getValues() {
    return m_Values;
  }

  /**
   * Describes this performance object for a single type of measure.
   *
   * @param evaluation	the type of performance to return
   * @return 			a string representation
   */
  public String toString(int evaluation) {
    return "Performance (" + getValues() + "): "
      + getPerformance(evaluation)
      + " (" + new SelectedTag(evaluation, TAGS_EVALUATION) + ")";
  }

  /**
   * Produces a tab-separated Gnuplot data line for this performance object.
   *
   * @param evaluation	the type of performance to return
   * @return 			the gnuplot string (x, y, z)
   */
  public String toGnuplot(int evaluation) {
    final PointDouble pair = getValues();
    return pair.getX() + "\t"
      + pair.getY() + "\t"
      + getPerformance(evaluation);
  }

  /**
   * Describes this performance object, listing the measure for every tag in
   * TAGS_EVALUATION.
   *
   * @return 			a string representation
   */
  public String toString() {
    StringBuilder text = new StringBuilder("Performance (" + getValues() + "): ");

    for (int idx = 0; idx < TAGS_EVALUATION.length; idx++) {
      if (idx > 0)
        text.append(", ");
      final int id = TAGS_EVALUATION[idx].getID();
      text.append(getPerformance(id));
      text.append(" (").append(new SelectedTag(id, TAGS_EVALUATION)).append(")");
    }

    return text.toString();
  }

  /**
   * Extracts the revision from the embedded revision keyword.
   *
   * @return		the revision
   */
  public String getRevision() {
    final String keyword = "$Revision: 5928 $";
    return RevisionUtils.extract(keyword);
  }
}
/**
 * A concrete Comparator for the Performance class.
 *
 * @see Performance
 */
protected class PerformanceComparator
  implements Comparator<Performance>, Serializable, RevisionHandler {

  /** for serialization */
  private static final long serialVersionUID = 6507592831825393847L;

  /** the performance measure to use for comparison
   * @see GridSearch#TAGS_EVALUATION */
  protected int m_Evaluation;

  /**
   * initializes the comparator with the given performance measure
   *
   * @param evaluation	the performance measure to use
   * @see GridSearch#TAGS_EVALUATION
   */
  public PerformanceComparator(int evaluation) {
    super();

    m_Evaluation = evaluation;
  }

  /**
   * returns the performance measure that's used to compare the objects
   *
   * @return the performance measure
   * @see GridSearch#TAGS_EVALUATION
   */
  public int getEvaluation() {
    return m_Evaluation;
  }

  /**
   * Compares its two arguments for order. Returns a negative integer,
   * zero, or a positive integer as the first argument is less than,
   * equal to, or greater than the second. For correlation coefficient,
   * accuracy and kappa the numeric order of the measure is used; for all
   * other measures the order is inverted.
   *
   * @param o1 	the first performance
   * @param o2 	the second performance
   * @return 		the order
   */
  public int compare(Performance o1, Performance o2) {
    int	result;
    double	p1;
    double	p2;

    p1 = o1.getPerformance(getEvaluation());
    p2 = o2.getPerformance(getEvaluation());

    if (Utils.sm(p1, p2))
      result = -1;
    else if (Utils.gr(p1, p2))
      result = 1;
    else
      result = 0;

    // only correlation coefficient/accuracy/kappa obey to this order, for the
    // errors (and the combination of all three), the smaller the number the
    // better -> hence invert them
    if (    (getEvaluation() != EVALUATION_CC)
         && (getEvaluation() != EVALUATION_ACC)
         && (getEvaluation() != EVALUATION_KAPPA) )
      result = -result;

    return result;
  }

  /**
   * Indicates whether some other object is "equal to" this Comparator.
   *
   * @param obj	the object to compare with
   * @return		true if the object is a PerformanceComparator using the
   * 			same evaluation type; false otherwise (including null
   * 			and objects of other types)
   */
  @Override
  public boolean equals(Object obj) {
    // equals must not throw for foreign types, they are just "not equal"
    if (!(obj instanceof PerformanceComparator))
      return false;

    return (m_Evaluation == ((PerformanceComparator) obj).m_Evaluation);
  }

  /**
   * Returns a hash code consistent with {@link #equals(Object)}.
   *
   * @return		the evaluation type
   */
  @Override
  public int hashCode() {
    return m_Evaluation;
  }

  /**
   * Returns the revision string.
   *
   * @return		the revision
   */
  public String getRevision() {
    return RevisionUtils.extract("$Revision: 5928 $");
  }
}
/**
 * Generates a 2-dim array for the performances from a grid for a certain
 * type. x-min/y-min is in the bottom-left corner, i.e., getTable()[0][0]
 * returns the performance for the x-min/y-max pair.
 */
protected class PerformanceTable
  implements Serializable, RevisionHandler {

  /** for serialization */
  private static final long serialVersionUID = 5486491313460338379L;

  /** the corresponding grid */
  protected Grid m_Grid;

  /** the performances */
  protected Vector<Performance> m_Performances;

  /** the type of performance the table was generated for */
  protected int m_Type;

  /** the table with the values */
  protected double[][] m_Table;

  /** the minimum performance */
  protected double m_Min;

  /** the maximum performance */
  protected double m_Max;

  /**
   * initializes the table and generates it immediately
   *
   * @param grid		the underlying grid
   * @param performances	the performances
   * @param type		the type of performance
   */
  public PerformanceTable(Grid grid, Vector<Performance> performances, int type) {
    super();

    m_Grid         = grid;
    m_Type         = type;
    m_Performances = performances;

    generate();
  }

  /**
   * generates the table and determines the minimum/maximum performance.
   * Cells without a corresponding performance keep the value 0; minimum and
   * maximum are 0 if there are no performances.
   */
  protected void generate() {
    Performance 	perf;
    int 		i;
    PointInt 		location;
    double		value;

    m_Table = new double[getGrid().height()][getGrid().width()];
    m_Min   = 0;
    m_Max   = 0;

    for (i = 0; i < getPerformances().size(); i++) {
      perf     = getPerformances().get(i);
      location = getGrid().getLocation(perf.getValues());
      value    = perf.getPerformance(getType());

      // rows are flipped: the highest Y index ends up in row 0
      m_Table[getGrid().height() - (int) location.getY() - 1][(int) location.getX()] = value;

      // determine min/max
      if (i == 0) {
        m_Min = value;
        m_Max = value;
      }
      else {
        if (value < m_Min)
          m_Min = value;
        if (value > m_Max)
          m_Max = value;
      }
    }
  }

  /**
   * returns the corresponding grid
   *
   * @return		the underlying grid
   */
  public Grid getGrid() {
    return m_Grid;
  }

  /**
   * returns the underlying performances
   *
   * @return		the underlying performances
   */
  public Vector<Performance> getPerformances() {
    return m_Performances;
  }

  /**
   * returns the type of performance
   *
   * @return		the type of performance
   */
  public int getType() {
    return m_Type;
  }

  /**
   * returns the generated table
   *
   * @return 		the performance table
   * @see		#m_Table
   * @see		#generate()
   */
  public double[][] getTable() {
    return m_Table;
  }

  /**
   * the minimum performance
   *
   * @return		the performance
   */
  public double getMin() {
    return m_Min;
  }

  /**
   * the maximum performance
   *
   * @return		the performance
   */
  public double getMax() {
    return m_Max;
  }

  /**
   * returns the table as string: a header line with the type of performance
   * and the axis labels, followed by one comma-separated line per table row
   *
   * @return		the table as string
   */
  @Override
  public String toString() {
    StringBuilder	result;
    int		i;
    int		n;

    result = new StringBuilder();
    result.append("Table ("
        + new SelectedTag(getType(), TAGS_EVALUATION).getSelectedTag().getReadable()
        + ") - "
        + "X: " + getGrid().getLabelX() + ", Y: " + getGrid().getLabelY()
        + ":\n");

    for (i = 0; i < getTable().length; i++) {
      if (i > 0)
        result.append("\n");

      for (n = 0; n < getTable()[i].length; n++) {
        if (n > 0)
          result.append(",");
        result.append(getTable()[i][n]);
      }
    }

    return result.toString();
  }

  /**
   * returns a string containing a gnuplot script+data file. The data block
   * holds one line per performance, the plot block the script to display it.
   *
   * @return		the data in gnuplot format
   */
  public String toGnuplot() {
    StringBuilder	result;
    Tag		type;

    result = new StringBuilder();
    type   = new SelectedTag(getType(), TAGS_EVALUATION).getSelectedTag();

    result.append("Gnuplot (" + type.getReadable() + "):\n");
    result.append("# begin 'gridsearch.data'\n");
    result.append("# " + type.getReadable() + "\n");
    for (Performance perf : getPerformances())
      result.append(perf.toGnuplot(type.getID()) + "\n");
    result.append("# end 'gridsearch.data'\n\n");

    result.append("# begin 'gridsearch.plot'\n");
    result.append("# " + type.getReadable() + "\n");
    result.append("set data style lines\n");
    result.append("set contour base\n");
    result.append("set surface\n");
    result.append("set title '" + m_Data.relationName() + "'\n");
    result.append("set xrange [" + getGrid().getMinX() + ":" + getGrid().getMaxX() + "]\n");
    result.append("set xlabel 'x (" + getFilter().getClass().getName() + ": " + getXProperty() + ")'\n");
    result.append("set yrange [" + getGrid().getMinY() + ":" + getGrid().getMaxY() + "]\n");
    result.append("set ylabel 'y - (" + getClassifier().getClass().getName() + ": " + getYProperty() + ")'\n");
    // pad the z axis by 10% of the observed range on either side
    result.append("set zrange [" + (getMin() - (getMax() - getMin())*0.1) + ":" + (getMax() + (getMax() - getMin())*0.1) + "]\n");
    result.append("set zlabel 'z - " + type.getReadable() + "'\n");
    result.append("set dgrid3d " + getGrid().height() + "," + getGrid().width() + ",1\n");
    result.append("show contour\n");
    result.append("splot 'gridsearch.data'\n");
    result.append("pause -1\n");
    result.append("# end 'gridsearch.plot'");

    return result.toString();
  }

  /**
   * Returns the revision string.
   *
   * @return		the revision
   */
  public String getRevision() {
    return RevisionUtils.extract("$Revision: 5928 $");
  }
}
/**
 * Represents a simple cache for performance objects. Entries are stored
 * under an ID string built from the number of cross-validation folds and
 * the X/Y coordinates of the grid point (see {@link #getID(int, PointDouble)}).
 */
protected class PerformanceCache
  implements Serializable, RevisionHandler {

  /** for serialization */
  private static final long serialVersionUID = 5838863230451530252L;

  /**
   * the cache for points in the grid that got calculated; typed so that
   * lookups need no unchecked cast
   */
  protected Hashtable<String, Performance> m_Cache = new Hashtable<String, Performance>();

  /**
   * returns the ID string for a cache item: the fold count and the X and Y
   * values of the point, separated by tabs
   *
   * @param cv		the number of folds in the cross-validation
   * @param values	the point in the grid
   * @return		the ID string
   */
  protected String getID(int cv, PointDouble values) {
    return cv + "\t" + values.getX() + "\t" + values.getY();
  }

  /**
   * checks whether the point was already calculated once
   *
   * @param cv		the number of folds in the cross-validation
   * @param values	the point in the grid
   * @return		true if the value is already cached
   */
  public boolean isCached(int cv, PointDouble values) {
    return (get(cv, values) != null);
  }

  /**
   * returns a cached performance object, null if not yet in the cache
   *
   * @param cv		the number of folds in the cross-validation
   * @param values	the point in the grid
   * @return		the cached performance item, null if not in cache
   */
  public Performance get(int cv, PointDouble values) {
    return m_Cache.get(getID(cv, values));
  }

  /**
   * adds the performance to the cache, keyed by the fold count and the
   * values of the performance object; an existing entry with the same key
   * is replaced
   *
   * @param cv		the number of folds in the cross-validation
   * @param p		the performance object to store
   */
  public void add(int cv, Performance p) {
    m_Cache.put(getID(cv, p.getValues()), p);
  }

  /**
   * returns a string representation of the cache
   *
   * @return		the string representation of the cache
   */
  @Override
  public String toString() {
    return m_Cache.toString();
  }

  /**
   * Returns the revision string.
   *
   * @return		the revision
   */
  public String getRevision() {
    return RevisionUtils.extract("$Revision: 5928 $");
  }
}
/** for serialization */
private static final long serialVersionUID = -3034773968581595348L;

/** evaluation via: Correlation coefficient */
public static final int EVALUATION_CC = 0;

/** evaluation via: Root mean squared error */
public static final int EVALUATION_RMSE = 1;

/** evaluation via: Root relative squared error */
public static final int EVALUATION_RRSE = 2;

/** evaluation via: Mean absolute error */
public static final int EVALUATION_MAE = 3;

/** evaluation via: Relative absolute error */
public static final int EVALUATION_RAE = 4;

/** evaluation via: Combined measure, labelled "(1-abs(CC)) + RRSE + RAE" in TAGS_EVALUATION */
public static final int EVALUATION_COMBINED = 5;

/** evaluation via: Accuracy */
public static final int EVALUATION_ACC = 6;

/** evaluation via: kappa statistic */
public static final int EVALUATION_KAPPA = 7;

/**
 * the available evaluation types; each tag carries the numeric ID, the
 * short ID string and the human-readable description
 */
public static final Tag[] TAGS_EVALUATION = {
  new Tag(EVALUATION_CC, "CC", "Correlation coefficient"),
  new Tag(EVALUATION_RMSE, "RMSE", "Root mean squared error"),
  new Tag(EVALUATION_RRSE, "RRSE", "Root relative squared error"),
  new Tag(EVALUATION_MAE, "MAE", "Mean absolute error"),
  // RAE is the relative absolute error (see EVALUATION_RAE); the label used to read "Root"
  new Tag(EVALUATION_RAE, "RAE", "Relative absolute error"),
  new Tag(EVALUATION_COMBINED, "COMB", "Combined = (1-abs(CC)) + RRSE + RAE"),
  new Tag(EVALUATION_ACC, "ACC", "Accuracy"),
  new Tag(EVALUATION_KAPPA, "KAP", "Kappa")
};
/** row-wise grid traversal */
public static final int TRAVERSAL_BY_ROW = 0;
/** column-wise grid traversal */
public static final int TRAVERSAL_BY_COLUMN = 1;
/** the available grid traversal types (row-wise or column-wise) */
public static final Tag[] TAGS_TRAVERSAL = {
new Tag(TRAVERSAL_BY_ROW, "row-wise", "row-wise"),
new Tag(TRAVERSAL_BY_COLUMN, "column-wise", "column-wise")
};
/** the prefix to indicate that the option is for the classifier
 * (used by the default Y property, "classifier.ridge") */
public final static String PREFIX_CLASSIFIER = "classifier.";
/** the prefix to indicate that the option is for the filter
 * (used by the default X property, "filter.numComponents") */
public final static String PREFIX_FILTER = "filter.";
/** the Filter (the default constructor installs a PLSFilter) */
protected Filter m_Filter;
/** the Filter with the best setup */
protected Filter m_BestFilter;
/** the Classifier with the best setup */
protected Classifier m_BestClassifier;
/** the best values (null by default) */
protected PointDouble m_Values = null;
/** the type of evaluation (one of the EVALUATION_* constants, default: EVALUATION_CC) */
protected int m_Evaluation = EVALUATION_CC;
/** the Y option to work on (without leading dash; a preceding 'classifier.'
 * means to set the option for the classifier, 'filter.' for the filter) */
protected String m_Y_Property = PREFIX_CLASSIFIER + "ridge";
/** the minimum of Y */
protected double m_Y_Min = -10;
/** the maximum of Y */
protected double m_Y_Max = +5;
/** the step size of Y */
protected double m_Y_Step = 1;
/** the base for Y */
protected double m_Y_Base = 10;
/**
 * The expression for the Y property. Available parameters for the
 * expression:
 * <ul>
 *   <li>BASE</li>
 *   <li>FROM (= min)</li>
 *   <li>TO (= max)</li>
 *   <li>STEP</li>
 *   <li>I - the current value (from 'from' to 'to' with stepsize 'step')</li>
 * </ul>
 *
 * @see MathematicalExpression
 * @see MathExpression
 */
protected String m_Y_Expression = "pow(BASE,I)";
/** the X option to work on (without leading dash; a preceding 'classifier.'
 * means to set the option for the classifier, 'filter.' for the filter) */
protected String m_X_Property = PREFIX_FILTER + "numComponents";
/** the minimum of X */
protected double m_X_Min = +5;
/** the maximum of X */
protected double m_X_Max = +20;
/** the step size of X */
protected double m_X_Step = 1;
/** the base for X */
protected double m_X_Base = 10;
/**
 * The expression for the X property. Available parameters for the
 * expression:
 * <ul>
 *   <li>BASE</li>
 *   <li>FROM (= min)</li>
 *   <li>TO (= max)</li>
 *   <li>STEP</li>
 *   <li>I - the current value (from 'from' to 'to' with stepsize 'step')</li>
 * </ul>
 *
 * @see MathematicalExpression
 * @see MathExpression
 */
protected String m_X_Expression = "I";
/** whether the grid can be extended */
protected boolean m_GridIsExtendable = false;
/** maximum number of grid extensions (-1 means unlimited) */
protected int m_MaxGridExtensions = 3;
/** the number of extensions performed */
protected int m_GridExtensionsPerformed = 0;
/** the sample size to search the initial grid with
 * (presumably a percentage of the data, default 100 - TODO confirm) */
protected double m_SampleSize = 100;
/** the traversal of the grid (one of the TRAVERSAL_* constants, default: column-wise) */
protected int m_Traversal = TRAVERSAL_BY_COLUMN;
/** the log file to use; initialised to the directory given by the
 * "user.dir" system property */
protected File m_LogFile = new File(System.getProperty("user.dir"));
/** the value-pairs grid */
protected Grid m_Grid;
/** the training data */
protected Instances m_Data;
/** the cache for points in the grid that got calculated */
protected PerformanceCache m_Cache;
/** whether all performances in the grid are the same */
protected boolean m_UniformPerformance = false;
/**
 * The default constructor. Installs a LinearRegression classifier (no
 * attribute selection, colinear attributes kept) and a PLSFilter, and
 * seeds the "best" classifier/filter with copies.
 */
public GridSearch() {
  super();

  // classifier: linear regression without attribute selection
  final LinearRegression regression = new LinearRegression();
  m_Classifier = regression;
  regression.setAttributeSelectionMethod(
      new SelectedTag(LinearRegression.SELECTION_NONE, LinearRegression.TAGS_SELECTION));
  regression.setEliminateColinearAttributes(false);

  // filter
  // NOTE(review): m_Filter stays a default PLSFilter, whereas the
  // standardizing/missing-value-replacing instance configured below is only
  // used as the source for m_BestFilter - confirm whether this is intended.
  m_Filter = new PLSFilter();
  final PLSFilter bestFilterTemplate = new PLSFilter();
  bestFilterTemplate.setPreprocessing(
      new SelectedTag(PLSFilter.PREPROCESSING_STANDARDIZE, PLSFilter.TAGS_PREPROCESSING));
  bestFilterTemplate.setReplaceMissing(true);

  // the two copies are attempted independently; a failure is only reported
  try {
    m_BestClassifier = AbstractClassifier.makeCopy(m_Classifier);
  }
  catch (Exception ex) {
    ex.printStackTrace();
  }

  try {
    m_BestFilter = Filter.makeCopy(bestFilterTemplate);
  }
  catch (Exception ex) {
    ex.printStackTrace();
  }
}
/**
 * Returns a string describing this classifier: what is searched, how the
 * search proceeds, which property types are supported and how the best
 * setup can be retrieved.
 *
 * @return a description suitable for displaying in the
 * explorer/experimenter gui
 */
public String globalInfo() {
  return
      "Performs a grid search of parameter pairs for a classifier "
    + "(Y-axis, default is LinearRegression with the \"Ridge\" parameter) "
    + "and the PLSFilter (X-axis, \"# of Components\") and chooses the best "
    + "pair found for the actual predicting.\n\n"
    + "The initial grid is worked on with 2-fold CV to determine the values "
    + "of the parameter pairs for the selected type of evaluation (e.g., "
    + "accuracy). The best point in the grid is then taken and a 10-fold CV "
    + "is performed with the adjacent parameter pairs. If a better pair is "
    + "found, then this will act as new center and another 10-fold CV will "
    + "be performed (kind of hill-climbing). This process is repeated until "
    + "no better pair is found or the best pair is on the border of the grid.\n"
    + "In case the best pair is on the border, one can let GridSearch "
    + "automatically extend the grid and continue the search. Check out the "
    + "properties 'gridIsExtendable' (option '-extend-grid') and "
    + "'maxGridExtensions' (option '-max-grid-extensions <num>').\n\n"
    + "GridSearch can handle doubles, integers (values are just cast to int) "
    + "and booleans (0 is false, otherwise true). float, char and long are "
    + "supported as well.\n\n"
    + "The best filter/classifier setup can be accessed after the buildClassifier "
    + "call via the getBestFilter/getBestClassifier methods.\n"
    + "Note on the implementation: after the data has been passed through "
    + "the filter, a default NumericCleaner filter is applied to the data in "
    + "order to avoid numbers that are getting too small and might produce "
    + "NaNs in other schemes.";
}
/**
 * Names the classifier that is used by default, i.e. the fully qualified
 * class name of LinearRegression.
 *
 * @return the classname of the default classifier
 */
protected String defaultClassifierString() {
  final Class<?> defaultType = LinearRegression.class;
  return defaultType.getName();
}
/**
* Gets an enumeration describing the available options.
*
* @return an enumeration of all the available options.
*/
public Enumeration listOptions(){
Vector result;
Enumeration en;
String desc;
SelectedTag tag;
int i;
result = new Vector();
desc = "";
for (i = 0; i < TAGS_EVALUATION.length; i++) {
tag = new SelectedTag(TAGS_EVALUATION[i].getID(), TAGS_EVALUATION);
desc += "\t" + tag.getSelectedTag().getIDStr()
+ " = " + tag.getSelectedTag().getReadable()
+ "\n";
}
result.addElement(new Option(
"\tDetermines the parameter used for evaluation:\n"
+ desc
+ "\t(default: " + new SelectedTag(EVALUATION_CC, TAGS_EVALUATION) + ")",
"E", 1, "-E " + Tag.toOptionList(TAGS_EVALUATION)));
result.addElement(new Option(
"\tThe Y option to test (without leading dash).\n"
+ "\t(default: " + PREFIX_CLASSIFIER + "ridge)",
"y-property", 1, "-y-property