/* * This program is free software; you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by * the Free Software Foundation; either version 2 of the License, or * (at your option) any later version. * * This program is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU General Public License for more details. * * You should have received a copy of the GNU General Public License * along with this program; if not, write to the Free Software * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA. */ /* * GridSearch.java * Copyright (C) 2006 University of Waikato, Hamilton, New Zealand */ package weka.classifiers.meta; import weka.classifiers.Classifier; import weka.classifiers.AbstractClassifier; import weka.classifiers.Evaluation; import weka.classifiers.RandomizableSingleClassifierEnhancer; import weka.classifiers.functions.LinearRegression; import weka.core.AdditionalMeasureProducer; import weka.core.Capabilities; import weka.core.Debug; import weka.core.Instance; import weka.core.Instances; import weka.core.MathematicalExpression; import weka.core.Option; import weka.core.OptionHandler; import weka.core.PropertyPath; import weka.core.RevisionHandler; import weka.core.RevisionUtils; import weka.core.SelectedTag; import weka.core.SerializedObject; import weka.core.Summarizable; import weka.core.Tag; import weka.core.Utils; import weka.core.Capabilities.Capability; import weka.filters.Filter; import weka.filters.supervised.attribute.PLSFilter; import weka.filters.unsupervised.attribute.MathExpression; import weka.filters.unsupervised.attribute.NumericCleaner; import weka.filters.unsupervised.instance.Resample; import java.beans.PropertyDescriptor; import java.io.File; import java.io.Serializable; import java.util.Collections; import java.util.Comparator; import java.util.Enumeration; 
import java.util.HashMap; import java.util.Hashtable; import java.util.Iterator; import java.util.Random; import java.util.Vector; /** * Performs a grid search of parameter pairs for a classifier (Y-axis, default is LinearRegression with the "Ridge" parameter) and the PLSFilter (X-axis, "# of Components") and chooses the best pair found for the actual prediction.
*
* The initial grid is worked on with 2-fold CV to determine the values of the parameter pairs for the selected type of evaluation (e.g., accuracy). The best point in the grid is then taken and a 10-fold CV is performed with the adjacent parameter pairs. If a better pair is found, then this will act as new center and another 10-fold CV will be performed (kind of hill-climbing). This process is repeated until no better pair is found or the best pair is on the border of the grid.
* In case the best pair is on the border, one can let GridSearch automatically extend the grid and continue the search. Check out the properties 'gridIsExtendable' (option '-extend-grid') and 'maxGridExtensions' (option '-max-grid-extensions <num>').
*
* GridSearch can handle doubles, integers (values are just cast to int) and booleans (0 is false, otherwise true). float, char and long are supported as well.
*
* The best filter/classifier setup can be accessed after the buildClassifier call via the getBestFilter/getBestClassifier methods.
* Note on the implementation: after the data has been passed through the filter, a default NumericCleaner filter is applied to the data in order to avoid numbers that are getting too small and might produce NaNs in other schemes. *

* * Valid options are:

* *

 -E <CC|RMSE|RRSE|MAE|RAE|COMB|ACC|KAP>
 *  Determines the parameter used for evaluation:
 *  CC = Correlation coefficient
 *  RMSE = Root mean squared error
 *  RRSE = Root relative squared error
 *  MAE = Mean absolute error
 *  RAE = Relative absolute error
 *  COMB = Combined = (1-abs(CC)) + RRSE + RAE
 *  ACC = Accuracy
 *  KAP = Kappa
 *  (default: CC)
* *
 -y-property <option>
 *  The Y option to test (without leading dash).
 *  (default: classifier.ridge)
* *
 -y-min <num>
 *  The minimum for Y.
 *  (default: -10)
* *
 -y-max <num>
 *  The maximum for Y.
 *  (default: +5)
* *
 -y-step <num>
 *  The step size for Y.
 *  (default: 1)
* *
 -y-base <num>
 *  The base for Y.
 *  (default: 10)
* *
 -y-expression <expr>
 *  The expression for Y.
 *  Available parameters:
 *   BASE
 *   FROM
 *   TO
 *   STEP
 *   I - the current iteration value
 *   (from 'FROM' to 'TO' with stepsize 'STEP')
 *  (default: 'pow(BASE,I)')
* *
 -filter <filter specification>
 *  The filter to use (on X axis). Full classname of filter to include, 
 *  followed by scheme options.
 *  (default: weka.filters.supervised.attribute.PLSFilter)
* *
 -x-property <option>
 *  The X option to test (without leading dash).
 *  (default: filter.numComponents)
* *
 -x-min <num>
 *  The minimum for X.
 *  (default: +5)
* *
 -x-max <num>
 *  The maximum for X.
 *  (default: +20)
* *
 -x-step <num>
 *  The step size for X.
 *  (default: 1)
* *
 -x-base <num>
 *  The base for X.
 *  (default: 10)
* *
 -x-expression <expr>
 *  The expression for the X value.
 *  Available parameters:
 *   BASE
 *   FROM
 *   TO
 *   STEP
 *   I - the current iteration value
 *   (from 'FROM' to 'TO' with stepsize 'STEP')
 *  (default: 'pow(BASE,I)')
* *
 -extend-grid
 *  Whether the grid can be extended.
 *  (default: no)
* *
 -max-grid-extensions <num>
 *  The maximum number of grid extensions (-1 is unlimited).
 *  (default: 3)
* *
 -sample-size <num>
 *  The size (in percent) of the sample to search the initial grid with.
 *  (default: 100)
* *
 -traversal <ROW-WISE|COLUMN-WISE>
 *  The type of traversal for the grid.
 *  (default: COLUMN-WISE)
* *
 -log-file <filename>
 *  The log file to log the messages to.
 *  (default: none)
* *
 -S <num>
 *  Random number seed.
 *  (default 1)
* *
 -D
 *  If set, classifier is run in debug mode and
 *  may output additional info to the console
* *
 -W
 *  Full name of base classifier.
 *  (default: weka.classifiers.functions.LinearRegression)
* *
 
 * Options specific to classifier weka.classifiers.functions.LinearRegression:
 * 
* *
 -D
 *  Produce debugging output.
 *  (default no debugging output)
* *
 -S <number of selection method>
 *  Set the attribute selection method to use. 1 = None, 2 = Greedy.
 *  (default 0 = M5' method)
* *
 -C
 *  Do not try to eliminate colinear attributes.
 * 
* *
 -R <double>
 *  Set ridge parameter (default 1.0e-8).
 * 
* *
 
 * Options specific to filter weka.filters.supervised.attribute.PLSFilter ('-filter'):
 * 
* *
 -D
 *  Turns on output of debugging information.
* *
 -C <num>
 *  The number of components to compute.
 *  (default: 20)
* *
 -U
 *  Updates the class attribute as well.
 *  (default: off)
* *
 -M
 *  Turns replacing of missing values on.
 *  (default: off)
* *
 -A <SIMPLS|PLS1>
 *  The algorithm to use.
 *  (default: PLS1)
* *
 -P <none|center|standardize>
 *  The type of preprocessing that is applied to the data.
 *  (default: center)
* * * Examples: * * * General notes: * * * @author Bernhard Pfahringer (bernhard at cs dot waikato dot ac dot nz) * @author Geoff Holmes (geoff at cs dot waikato dot ac dot nz) * @author fracpete (fracpete at waikato dot ac dot nz) * @version $Revision: 5928 $ * @see PLSFilter * @see LinearRegression * @see NumericCleaner */ public class GridSearch extends RandomizableSingleClassifierEnhancer implements AdditionalMeasureProducer, Summarizable { /** * a serializable version of Point2D.Double * * @see java.awt.geom.Point2D.Double */ protected class PointDouble extends java.awt.geom.Point2D.Double implements Serializable, RevisionHandler { /** for serialization */ private static final long serialVersionUID = 7151661776161898119L; /** * the default constructor * * @param x the x value of the point * @param y the y value of the point */ public PointDouble(double x, double y) { super(x, y); } /** * Determines whether or not two points are equal. * * @param obj an object to be compared with this PointDouble * @return true if the object to be compared has the same values; * false otherwise. */ public boolean equals(Object obj) { PointDouble pd; pd = (PointDouble) obj; return (Utils.eq(this.getX(), pd.getX()) && Utils.eq(this.getY(), pd.getY())); } /** * returns a string representation of the Point * * @return the point as string */ public String toString() { return super.toString().replaceAll(".*\\[", "["); } /** * Returns the revision string. 
* * @return the revision */ public String getRevision() { return RevisionUtils.extract("$Revision: 5928 $"); } } /** * a serializable version of Point * * @see java.awt.Point */ protected class PointInt extends java.awt.Point implements Serializable, RevisionHandler { /** for serialization */ private static final long serialVersionUID = -5900415163698021618L; /** * the default constructor * * @param x the x value of the point * @param y the y value of the point */ public PointInt(int x, int y) { super(x, y); } /** * returns a string representation of the Point * * @return the point as string */ public String toString() { return super.toString().replaceAll(".*\\[", "["); } /** * Returns the revision string. * * @return the revision */ public String getRevision() { return RevisionUtils.extract("$Revision: 5928 $"); } } /** * for generating the parameter pairs in a grid */ protected class Grid implements Serializable, RevisionHandler { /** for serialization */ private static final long serialVersionUID = 7290732613611243139L; /** the minimum on the X axis */ protected double m_MinX; /** the maximum on the X axis */ protected double m_MaxX; /** the step size for the X axis */ protected double m_StepX; /** the label for the X axis */ protected String m_LabelX; /** the minimum on the Y axis */ protected double m_MinY; /** the maximum on the Y axis */ protected double m_MaxY; /** the step size for the Y axis */ protected double m_StepY; /** the label for the Y axis */ protected String m_LabelY; /** the number of points on the X axis */ protected int m_Width; /** the number of points on the Y axis */ protected int m_Height; /** * initializes the grid * * @param minX the minimum on the X axis * @param maxX the maximum on the X axis * @param stepX the step size for the X axis * @param minY the minimum on the Y axis * @param maxY the maximum on the Y axis * @param stepY the step size for the Y axis */ public Grid(double minX, double maxX, double stepX, double minY, double 
maxY, double stepY) { this(minX, maxX, stepX, "", minY, maxY, stepY, ""); } /** * initializes the grid * * @param minX the minimum on the X axis * @param maxX the maximum on the X axis * @param stepX the step size for the X axis * @param labelX the label for the X axis * @param minY the minimum on the Y axis * @param maxY the maximum on the Y axis * @param stepY the step size for the Y axis * @param labelY the label for the Y axis */ public Grid(double minX, double maxX, double stepX, String labelX, double minY, double maxY, double stepY, String labelY) { super(); m_MinX = minX; m_MaxX = maxX; m_StepX = stepX; m_LabelX = labelX; m_MinY = minY; m_MaxY = maxY; m_StepY = stepY; m_LabelY = labelY; m_Height = (int) StrictMath.round((m_MaxY - m_MinY) / m_StepY) + 1; m_Width = (int) StrictMath.round((m_MaxX - m_MinX) / m_StepX) + 1; // is min < max? if (m_MinX >= m_MaxX) throw new IllegalArgumentException("XMin must be smaller than XMax!"); if (m_MinY >= m_MaxY) throw new IllegalArgumentException("YMin must be smaller than YMax!"); // steps positive? if (m_StepX <= 0) throw new IllegalArgumentException("XStep must be a positive number!"); if (m_StepY <= 0) throw new IllegalArgumentException("YStep must be a positive number!"); // check borders if (!Utils.eq(m_MinX + (m_Width-1)*m_StepX, m_MaxX)) throw new IllegalArgumentException( "X axis doesn't match! Provided max: " + m_MaxX + ", calculated max via min and step size: " + (m_MinX + (m_Width-1)*m_StepX)); if (!Utils.eq(m_MinY + (m_Height-1)*m_StepY, m_MaxY)) throw new IllegalArgumentException( "Y axis doesn't match! 
Provided max: " + m_MaxY + ", calculated max via min and step size: " + (m_MinY + (m_Height-1)*m_StepY)); } /** * Tests itself against the provided grid object * * @param o the grid object to compare against * @return if the two grids have the same setup */ public boolean equals(Object o) { boolean result; Grid g; g = (Grid) o; result = (width() == g.width()) && (height() == g.height()) && (getMinX() == g.getMinX()) && (getMinY() == g.getMinY()) && (getStepX() == g.getStepX()) && (getStepY() == g.getStepY()) && getLabelX().equals(g.getLabelX()) && getLabelY().equals(g.getLabelY()); return result; } /** * returns the left border * * @return the left border */ public double getMinX() { return m_MinX; } /** * returns the right border * * @return the right border */ public double getMaxX() { return m_MaxX; } /** * returns the step size on the X axis * * @return the step size */ public double getStepX() { return m_StepX; } /** * returns the label for the X axis * * @return the label */ public String getLabelX() { return m_LabelX; } /** * returns the bottom border * * @return the bottom border */ public double getMinY() { return m_MinY; } /** * returns the top border * * @return the top border */ public double getMaxY() { return m_MaxY; } /** * returns the step size on the Y axis * * @return the step size */ public double getStepY() { return m_StepY; } /** * returns the label for the Y axis * * @return the label */ public String getLabelY() { return m_LabelY; } /** * returns the number of points in the grid on the Y axis (incl. borders) * * @return the number of points in the grid on the Y axis */ public int height() { return m_Height; } /** * returns the number of points in the grid on the X axis (incl. 
borders) * * @return the number of points in the grid on the X axis */ public int width() { return m_Width; } /** * returns the values at the given point in the grid * * @param x the x-th point on the X axis * @param y the y-th point on the Y axis * @return the value pair at the given position */ public PointDouble getValues(int x, int y) { if (x >= width()) throw new IllegalArgumentException("Index out of scope on X axis (" + x + " >= " + width() + ")!"); if (y >= height()) throw new IllegalArgumentException("Index out of scope on Y axis (" + y + " >= " + height() + ")!"); return new PointDouble(m_MinX + m_StepX*x, m_MinY + m_StepY*y); } /** * returns the closest index pair for the given value pair in the grid. * * @param values the values to get the indices for * @return the closest indices in the grid */ public PointInt getLocation(PointDouble values) { PointInt result; int x; int y; double distance; double currDistance; int i; // determine x x = 0; distance = m_StepX; for (i = 0; i < width(); i++) { currDistance = StrictMath.abs(values.getX() - getValues(i, 0).getX()); if (Utils.sm(currDistance, distance)) { distance = currDistance; x = i; } } // determine y y = 0; distance = m_StepY; for (i = 0; i < height(); i++) { currDistance = StrictMath.abs(values.getY() - getValues(0, i).getY()); if (Utils.sm(currDistance, distance)) { distance = currDistance; y = i; } } result = new PointInt(x, y); return result; } /** * checks whether the given values are on the border of the grid * * @param values the values to check * @return true if the the values are on the border */ public boolean isOnBorder(PointDouble values) { return isOnBorder(getLocation(values)); } /** * checks whether the given location is on the border of the grid * * @param location the location to check * @return true if the the location is on the border */ public boolean isOnBorder(PointInt location) { if (location.getX() == 0) return true; else if (location.getX() == width() - 1) return true; if 
(location.getY() == 0) return true; else if (location.getY() == height() - 1) return true; else return false; } /** * returns a subgrid with the same step sizes, but different borders * * @param top the top index * @param left the left index * @param bottom the bottom index * @param right the right index * @return the Sub-Grid */ public Grid subgrid(int top, int left, int bottom, int right) { return new Grid( getValues(left, top).getX(), getValues(right, top).getX(), getStepX(), getLabelX(), getValues(left, bottom).getY(), getValues(left, top).getY(), getStepY(), getLabelY()); } /** * returns an extended grid that encompasses the given point (won't be on * the border of the grid). * * @param values the point that the grid should contain * @return the extended grid */ public Grid extend(PointDouble values) { double minX; double maxX; double minY; double maxY; double distance; Grid result; // left if (Utils.smOrEq(values.getX(), getMinX())) { distance = getMinX() - values.getX(); // exactly on grid point? if (Utils.eq(distance, 0)) minX = getMinX() - getStepX() * (StrictMath.round(distance / getStepX()) + 1); else minX = getMinX() - getStepX() * (StrictMath.round(distance / getStepX())); } else { minX = getMinX(); } // right if (Utils.grOrEq(values.getX(), getMaxX())) { distance = values.getX() - getMaxX(); // exactly on grid point? if (Utils.eq(distance, 0)) maxX = getMaxX() + getStepX() * (StrictMath.round(distance / getStepX()) + 1); else maxX = getMaxX() + getStepX() * (StrictMath.round(distance / getStepX())); } else { maxX = getMaxX(); } // bottom if (Utils.smOrEq(values.getY(), getMinY())) { distance = getMinY() - values.getY(); // exactly on grid point? 
if (Utils.eq(distance, 0)) minY = getMinY() - getStepY() * (StrictMath.round(distance / getStepY()) + 1); else minY = getMinY() - getStepY() * (StrictMath.round(distance / getStepY())); } else { minY = getMinY(); } // top if (Utils.grOrEq(values.getY(), getMaxY())) { distance = values.getY() - getMaxY(); // exactly on grid point? if (Utils.eq(distance, 0)) maxY = getMaxY() + getStepY() * (StrictMath.round(distance / getStepY()) + 1); else maxY = getMaxY() + getStepY() * (StrictMath.round(distance / getStepY())); } else { maxY = getMaxY(); } result = new Grid(minX, maxX, getStepX(), getLabelX(), minY, maxY, getStepY(), getLabelY()); // did the grid really extend? if (equals(result)) throw new IllegalStateException("Grid extension failed!"); return result; } /** * returns an Enumeration over all pairs in the given row * * @param y the row to retrieve * @return an Enumeration over all pairs * @see #getValues(int, int) */ public Enumeration row(int y) { Vector result; int i; result = new Vector(); for (i = 0; i < width(); i++) result.add(getValues(i, y)); return result.elements(); } /** * returns an Enumeration over all pairs in the given column * * @param x the column to retrieve * @return an Enumeration over all pairs * @see #getValues(int, int) */ public Enumeration column(int x) { Vector result; int i; result = new Vector(); for (i = 0; i < height(); i++) result.add(getValues(x, i)); return result.elements(); } /** * returns a string representation of the grid * * @return a string representation */ public String toString() { String result; result = "X: " + m_MinX + " - " + m_MaxX + ", Step " + m_StepX; if (m_LabelX.length() != 0) result += " (" + m_LabelX + ")"; result += "\n"; result += "Y: " + m_MinY + " - " + m_MaxY + ", Step " + m_StepY; if (m_LabelY.length() != 0) result += " (" + m_LabelY + ")"; result += "\n"; result += "Dimensions (Rows x Columns): " + height() + " x " + width(); return result; } /** * Returns the revision string. 
* * @return the revision */ public String getRevision() { return RevisionUtils.extract("$Revision: 5928 $"); } } /** * A helper class for storing the performance of a values-pair. * Can be sorted with the PerformanceComparator class. * * @see PerformanceComparator */ protected class Performance implements Serializable, RevisionHandler { /** for serialization */ private static final long serialVersionUID = -4374706475277588755L; /** the value pair the classifier was built with */ protected PointDouble m_Values; /** the Correlation coefficient */ protected double m_CC; /** the Root mean squared error */ protected double m_RMSE; /** the Root relative squared error */ protected double m_RRSE; /** the Mean absolute error */ protected double m_MAE; /** the Relative absolute error */ protected double m_RAE; /** the Accuracy */ protected double m_ACC; /** the kappa value */ protected double m_Kappa; /** * initializes the performance container * * @param values the values-pair * @param evaluation the evaluation to extract the performance * measures from * @throws Exception if retrieving of measures fails */ public Performance(PointDouble values, Evaluation evaluation) throws Exception { super(); m_Values = values; m_RMSE = evaluation.rootMeanSquaredError(); m_RRSE = evaluation.rootRelativeSquaredError(); m_MAE = evaluation.meanAbsoluteError(); m_RAE = evaluation.relativeAbsoluteError(); try { m_CC = evaluation.correlationCoefficient(); } catch (Exception e) { m_CC = Double.NaN; } try { m_ACC = evaluation.pctCorrect(); } catch (Exception e) { m_ACC = Double.NaN; } try { m_Kappa = evaluation.kappa(); } catch (Exception e) { m_Kappa = Double.NaN; } } /** * returns the performance measure * * @param evaluation the type of measure to return * @return the performance measure */ public double getPerformance(int evaluation) { double result; result = Double.NaN; switch (evaluation) { case EVALUATION_CC: result = m_CC; break; case EVALUATION_RMSE: result = m_RMSE; break; case 
EVALUATION_RRSE: result = m_RRSE; break; case EVALUATION_MAE: result = m_MAE; break; case EVALUATION_RAE: result = m_RAE; break; case EVALUATION_COMBINED: result = (1 - StrictMath.abs(m_CC)) + m_RRSE + m_RAE; break; case EVALUATION_ACC: result = m_ACC; break; case EVALUATION_KAPPA: result = m_Kappa; break; default: throw new IllegalArgumentException("Evaluation type '" + evaluation + "' not supported!"); } return result; } /** * returns the values-pair for this performance * * @return the values-pair */ public PointDouble getValues() { return m_Values; } /** * returns a string representation of this performance object * * @param evaluation the type of performance to return * @return a string representation */ public String toString(int evaluation) { String result; result = "Performance (" + getValues() + "): " + getPerformance(evaluation) + " (" + new SelectedTag(evaluation, TAGS_EVALUATION) + ")"; return result; } /** * returns a Gnuplot string of this performance object * * @param evaluation the type of performance to return * @return the gnuplot string (x, y, z) */ public String toGnuplot(int evaluation) { String result; result = getValues().getX() + "\t" + getValues().getY() + "\t" + getPerformance(evaluation); return result; } /** * returns a string representation of this performance object * * @return a string representation */ public String toString() { String result; int i; result = "Performance (" + getValues() + "): "; for (i = 0; i < TAGS_EVALUATION.length; i++) { if (i > 0) result += ", "; result += getPerformance(TAGS_EVALUATION[i].getID()) + " (" + new SelectedTag(TAGS_EVALUATION[i].getID(), TAGS_EVALUATION) + ")"; } return result; } /** * Returns the revision string. * * @return the revision */ public String getRevision() { return RevisionUtils.extract("$Revision: 5928 $"); } } /** * A concrete Comparator for the Performance class. 
* * @see Performance */ protected class PerformanceComparator implements Comparator, Serializable, RevisionHandler { /** for serialization */ private static final long serialVersionUID = 6507592831825393847L; /** the performance measure to use for comparison * @see GridSearch#TAGS_EVALUATION */ protected int m_Evaluation; /** * initializes the comparator with the given performance measure * * @param evaluation the performance measure to use * @see GridSearch#TAGS_EVALUATION */ public PerformanceComparator(int evaluation) { super(); m_Evaluation = evaluation; } /** * returns the performance measure that's used to compare the objects * * @return the performance measure * @see GridSearch#TAGS_EVALUATION */ public int getEvaluation() { return m_Evaluation; } /** * Compares its two arguments for order. Returns a negative integer, * zero, or a positive integer as the first argument is less than, * equal to, or greater than the second. * * @param o1 the first performance * @param o2 the second performance * @return the order */ public int compare(Performance o1, Performance o2) { int result; double p1; double p2; p1 = o1.getPerformance(getEvaluation()); p2 = o2.getPerformance(getEvaluation()); if (Utils.sm(p1, p2)) result = -1; else if (Utils.gr(p1, p2)) result = 1; else result = 0; // only correlation coefficient/accuracy/kappa obey to this order, for the // errors (and the combination of all three), the smaller the number the // better -> hence invert them if ( (getEvaluation() != EVALUATION_CC) && (getEvaluation() != EVALUATION_ACC) && (getEvaluation() != EVALUATION_KAPPA) ) result = -result; return result; } /** * Indicates whether some other object is "equal to" this Comparator. 
* * @param obj the object to compare with * @return true if the same evaluation type is used */ public boolean equals(Object obj) { if (!(obj instanceof PerformanceComparator)) throw new IllegalArgumentException("Must be PerformanceComparator!"); return (m_Evaluation == ((PerformanceComparator) obj).m_Evaluation); } /** * Returns the revision string. * * @return the revision */ public String getRevision() { return RevisionUtils.extract("$Revision: 5928 $"); } } /** * Generates a 2-dim array for the performances from a grid for a certain * type. x-min/y-min is in the bottom-left corner, i.e., getTable()[0][0] * returns the performance for the x-min/y-max pair. *
   * x-min     x-max
   * |-------------|
   *                - y-max
   *                |
   *                |
   *                - y-min
   * 
*/ protected class PerformanceTable implements Serializable, RevisionHandler { /** for serialization */ private static final long serialVersionUID = 5486491313460338379L; /** the corresponding grid */ protected Grid m_Grid; /** the performances */ protected Vector m_Performances; /** the type of performance the table was generated for */ protected int m_Type; /** the table with the values */ protected double[][] m_Table; /** the minimum performance */ protected double m_Min; /** the maximum performance */ protected double m_Max; /** * initializes the table * * @param grid the underlying grid * @param performances the performances * @param type the type of performance */ public PerformanceTable(Grid grid, Vector performances, int type) { super(); m_Grid = grid; m_Type = type; m_Performances = performances; generate(); } /** * generates the table */ protected void generate() { Performance perf; int i; PointInt location; m_Table = new double[getGrid().height()][getGrid().width()]; m_Min = 0; m_Max = 0; for (i = 0; i < getPerformances().size(); i++) { perf = (Performance) getPerformances().get(i); location = getGrid().getLocation(perf.getValues()); m_Table[getGrid().height() - (int) location.getY() - 1][(int) location.getX()] = perf.getPerformance(getType()); // determine min/max if (i == 0) { m_Min = perf.getPerformance(m_Type); m_Max = m_Min; } else { if (perf.getPerformance(m_Type) < m_Min) m_Min = perf.getPerformance(m_Type); if (perf.getPerformance(m_Type) > m_Max) m_Max = perf.getPerformance(m_Type); } } } /** * returns the corresponding grid * * @return the underlying grid */ public Grid getGrid() { return m_Grid; } /** * returns the underlying performances * * @return the underlying performances */ public Vector getPerformances() { return m_Performances; } /** * returns the type of performance * * @return the type of performance */ public int getType() { return m_Type; } /** * returns the generated table * * @return the performance table * @see #m_Table * @see 
#generate() */ public double[][] getTable() { return m_Table; } /** * the minimum performance * * @return the performance */ public double getMin() { return m_Min; } /** * the maximum performance * * @return the performance */ public double getMax() { return m_Max; } /** * returns the table as string * * @return the table as string */ public String toString() { String result; int i; int n; result = "Table (" + new SelectedTag(getType(), TAGS_EVALUATION).getSelectedTag().getReadable() + ") - " + "X: " + getGrid().getLabelX() + ", Y: " + getGrid().getLabelY() + ":\n"; for (i = 0; i < getTable().length; i++) { if (i > 0) result += "\n"; for (n = 0; n < getTable()[i].length; n++) { if (n > 0) result += ","; result += getTable()[i][n]; } } return result; } /** * returns a string containing a gnuplot script+data file * * @return the data in gnuplot format */ public String toGnuplot() { StringBuffer result; Tag type; int i; result = new StringBuffer(); type = new SelectedTag(getType(), TAGS_EVALUATION).getSelectedTag(); result.append("Gnuplot (" + type.getReadable() + "):\n"); result.append("# begin 'gridsearch.data'\n"); result.append("# " + type.getReadable() + "\n"); for (i = 0; i < getPerformances().size(); i++) result.append(getPerformances().get(i).toGnuplot(type.getID()) + "\n"); result.append("# end 'gridsearch.data'\n\n"); result.append("# begin 'gridsearch.plot'\n"); result.append("# " + type.getReadable() + "\n"); result.append("set data style lines\n"); result.append("set contour base\n"); result.append("set surface\n"); result.append("set title '" + m_Data.relationName() + "'\n"); result.append("set xrange [" + getGrid().getMinX() + ":" + getGrid().getMaxX() + "]\n"); result.append("set xlabel 'x (" + getFilter().getClass().getName() + ": " + getXProperty() + ")'\n"); result.append("set yrange [" + getGrid().getMinY() + ":" + getGrid().getMaxY() + "]\n"); result.append("set ylabel 'y - (" + getClassifier().getClass().getName() + ": " + getYProperty() + 
")'\n"); result.append("set zrange [" + (getMin() - (getMax() - getMin())*0.1) + ":" + (getMax() + (getMax() - getMin())*0.1) + "]\n"); result.append("set zlabel 'z - " + type.getReadable() + "'\n"); result.append("set dgrid3d " + getGrid().height() + "," + getGrid().width() + ",1\n"); result.append("show contour\n"); result.append("splot 'gridsearch.data'\n"); result.append("pause -1\n"); result.append("# end 'gridsearch.plot'"); return result.toString(); } /** * Returns the revision string. * * @return the revision */ public String getRevision() { return RevisionUtils.extract("$Revision: 5928 $"); } } /** * Represents a simple cache for performance objects. */ protected class PerformanceCache implements Serializable, RevisionHandler { /** for serialization */ private static final long serialVersionUID = 5838863230451530252L; /** the cache for points in the grid that got calculated */ protected Hashtable m_Cache = new Hashtable(); /** * returns the ID string for a cache item * * @param cv the number of folds in the cross-validation * @param values the point in the grid * @return the ID string */ protected String getID(int cv, PointDouble values) { return cv + "\t" + values.getX() + "\t" + values.getY(); } /** * checks whether the point was already calculated ones * * @param cv the number of folds in the cross-validation * @param values the point in the grid * @return true if the value is already cached */ public boolean isCached(int cv, PointDouble values) { return (get(cv, values) != null); } /** * returns a cached performance object, null if not yet in the cache * * @param cv the number of folds in the cross-validation * @param values the point in the grid * @return the cached performance item, null if not in cache */ public Performance get(int cv, PointDouble values) { return (Performance) m_Cache.get(getID(cv, values)); } /** * adds the performance to the cache * * @param cv the number of folds in the cross-validation * @param p the performance object to store 
*/ public void add(int cv, Performance p) { m_Cache.put(getID(cv, p.getValues()), p); } /** * returns a string representation of the cache * * @return the string representation of the cache */ public String toString() { return m_Cache.toString(); } /** * Returns the revision string. * * @return the revision */ public String getRevision() { return RevisionUtils.extract("$Revision: 5928 $"); } } /** for serialization */ private static final long serialVersionUID = -3034773968581595348L; /** evaluation via: Correlation coefficient */ public static final int EVALUATION_CC = 0; /** evaluation via: Root mean squared error */ public static final int EVALUATION_RMSE = 1; /** evaluation via: Root relative squared error */ public static final int EVALUATION_RRSE = 2; /** evaluation via: Mean absolute error */ public static final int EVALUATION_MAE = 3; /** evaluation via: Relative absolute error */ public static final int EVALUATION_RAE = 4; /** evaluation via: Combined = (1-CC) + RRSE + RAE */ public static final int EVALUATION_COMBINED = 5; /** evaluation via: Accuracy */ public static final int EVALUATION_ACC = 6; /** evaluation via: kappa statistic */ public static final int EVALUATION_KAPPA = 7; /** evaluation */ public static final Tag[] TAGS_EVALUATION = { new Tag(EVALUATION_CC, "CC", "Correlation coefficient"), new Tag(EVALUATION_RMSE, "RMSE", "Root mean squared error"), new Tag(EVALUATION_RRSE, "RRSE", "Root relative squared error"), new Tag(EVALUATION_MAE, "MAE", "Mean absolute error"), new Tag(EVALUATION_RAE, "RAE", "Root absolute error"), new Tag(EVALUATION_COMBINED, "COMB", "Combined = (1-abs(CC)) + RRSE + RAE"), new Tag(EVALUATION_ACC, "ACC", "Accuracy"), new Tag(EVALUATION_KAPPA, "KAP", "Kappa") }; /** row-wise grid traversal */ public static final int TRAVERSAL_BY_ROW = 0; /** column-wise grid traversal */ public static final int TRAVERSAL_BY_COLUMN = 1; /** traversal */ public static final Tag[] TAGS_TRAVERSAL = { new Tag(TRAVERSAL_BY_ROW, "row-wise", 
"row-wise"), new Tag(TRAVERSAL_BY_COLUMN, "column-wise", "column-wise") }; /** the prefix to indicate that the option is for the classifier */ public final static String PREFIX_CLASSIFIER = "classifier."; /** the prefix to indicate that the option is for the filter */ public final static String PREFIX_FILTER = "filter."; /** the Filter */ protected Filter m_Filter; /** the Filter with the best setup */ protected Filter m_BestFilter; /** the Classifier with the best setup */ protected Classifier m_BestClassifier; /** the best values */ protected PointDouble m_Values = null; /** the type of evaluation */ protected int m_Evaluation = EVALUATION_CC; /** the Y option to work on (without leading dash, preceding 'classifier.' * means to set the option for the classifier 'filter.' for the filter) */ protected String m_Y_Property = PREFIX_CLASSIFIER + "ridge"; /** the minimum of Y */ protected double m_Y_Min = -10; /** the maximum of Y */ protected double m_Y_Max = +5; /** the step size of Y */ protected double m_Y_Step = 1; /** the base for Y */ protected double m_Y_Base = 10; /** * The expression for the Y property. Available parameters for the * expression: *
   * <ul>
   *   <li>BASE</li>
   *   <li>FROM (= min)</li>
   *   <li>TO (= max)</li>
   *   <li>STEP</li>
   *   <li>I - the current value (from 'from' to 'to' with step size 'step')</li>
   * </ul>
*
   * @see MathematicalExpression
   * @see MathExpression
   */
  protected String m_Y_Expression = "pow(BASE,I)";

  /** the X option to work on (without leading dash; a leading 'classifier.'
   * means the option is set on the classifier, a leading 'filter.' means it
   * is set on the filter) */
  protected String m_X_Property = PREFIX_FILTER + "numComponents";

  /** the minimum of X */
  protected double m_X_Min = +5;

  /** the maximum of X */
  protected double m_X_Max = +20;

  /** the step size of X */
  protected double m_X_Step = 1;

  /** the base for X */
  protected double m_X_Base = 10;

  /**
   * The expression for the X property. Available parameters for the
   * expression:
   * <ul>
   *   <li>BASE</li>
   *   <li>FROM (= min)</li>
   *   <li>TO (= max)</li>
   *   <li>STEP</li>
   *   <li>I - the current value (from 'from' to 'to' with step size 'step')</li>
   * </ul>
* * @see MathematicalExpression * @see MathExpression */ protected String m_X_Expression = "I"; /** whether the grid can be extended */ protected boolean m_GridIsExtendable = false; /** maximum number of grid extensions (-1 means unlimited) */ protected int m_MaxGridExtensions = 3; /** the number of extensions performed */ protected int m_GridExtensionsPerformed = 0; /** the sample size to search the initial grid with */ protected double m_SampleSize = 100; /** the traversal */ protected int m_Traversal = TRAVERSAL_BY_COLUMN; /** the log file to use */ protected File m_LogFile = new File(System.getProperty("user.dir")); /** the value-pairs grid */ protected Grid m_Grid; /** the training data */ protected Instances m_Data; /** the cache for points in the grid that got calculated */ protected PerformanceCache m_Cache; /** whether all performances in the grid are the same */ protected boolean m_UniformPerformance = false; /** * the default constructor */ public GridSearch() { super(); // classifier m_Classifier = new LinearRegression(); ((LinearRegression) m_Classifier).setAttributeSelectionMethod(new SelectedTag(LinearRegression.SELECTION_NONE, LinearRegression.TAGS_SELECTION)); ((LinearRegression) m_Classifier).setEliminateColinearAttributes(false); // filter m_Filter = new PLSFilter(); PLSFilter filter = new PLSFilter(); filter.setPreprocessing(new SelectedTag(PLSFilter.PREPROCESSING_STANDARDIZE, PLSFilter.TAGS_PREPROCESSING)); filter.setReplaceMissing(true); try { m_BestClassifier = AbstractClassifier.makeCopy(m_Classifier); } catch (Exception e) { e.printStackTrace(); } try { m_BestFilter = Filter.makeCopy(filter); } catch (Exception e) { e.printStackTrace(); } } /** * Returns a string describing classifier * * @return a description suitable for displaying in the * explorer/experimenter gui */ public String globalInfo() { return "Performs a grid search of parameter pairs for the a classifier " + "(Y-axis, default is LinearRegression with the \"Ridge\" parameter) 
" + "and the PLSFilter (X-axis, \"# of Components\") and chooses the best " + "pair found for the actual predicting.\n\n" + "The initial grid is worked on with 2-fold CV to determine the values " + "of the parameter pairs for the selected type of evaluation (e.g., " + "accuracy). The best point in the grid is then taken and a 10-fold CV " + "is performed with the adjacent parameter pairs. If a better pair is " + "found, then this will act as new center and another 10-fold CV will " + "be performed (kind of hill-climbing). This process is repeated until " + "no better pair is found or the best pair is on the border of the grid.\n" + "In case the best pair is on the border, one can let GridSearch " + "automatically extend the grid and continue the search. Check out the " + "properties 'gridIsExtendable' (option '-extend-grid') and " + "'maxGridExtensions' (option '-max-grid-extensions ').\n\n" + "GridSearch can handle doubles, integers (values are just cast to int) " + "and booleans (0 is false, otherwise true). float, char and long are " + "supported as well.\n\n" + "The best filter/classifier setup can be accessed after the buildClassifier " + "call via the getBestFilter/getBestClassifier methods.\n" + "Note on the implementation: after the data has been passed through " + "the filter, a default NumericCleaner filter is applied to the data in " + "order to avoid numbers that are getting too small and might produce " + "NaNs in other schemes."; } /** * String describing default classifier. * * @return the classname of the default classifier */ protected String defaultClassifierString() { return LinearRegression.class.getName(); } /** * Gets an enumeration describing the available options. * * @return an enumeration of all the available options. 
*/ public Enumeration listOptions(){ Vector result; Enumeration en; String desc; SelectedTag tag; int i; result = new Vector(); desc = ""; for (i = 0; i < TAGS_EVALUATION.length; i++) { tag = new SelectedTag(TAGS_EVALUATION[i].getID(), TAGS_EVALUATION); desc += "\t" + tag.getSelectedTag().getIDStr() + " = " + tag.getSelectedTag().getReadable() + "\n"; } result.addElement(new Option( "\tDetermines the parameter used for evaluation:\n" + desc + "\t(default: " + new SelectedTag(EVALUATION_CC, TAGS_EVALUATION) + ")", "E", 1, "-E " + Tag.toOptionList(TAGS_EVALUATION))); result.addElement(new Option( "\tThe Y option to test (without leading dash).\n" + "\t(default: " + PREFIX_CLASSIFIER + "ridge)", "y-property", 1, "-y-property