/* * This program is free software; you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by * the Free Software Foundation; either version 2 of the License, or * (at your option) any later version. * * This program is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU General Public License for more details. * * You should have received a copy of the GNU General Public License * along with this program; if not, write to the Free Software * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA. */ /* * RacedIncrementalLogitBoost.java * Copyright (C) 2002 University of Waikato, Hamilton, New Zealand * */ package weka.classifiers.meta; import weka.classifiers.Classifier; import weka.classifiers.AbstractClassifier; import weka.classifiers.RandomizableSingleClassifierEnhancer; import weka.classifiers.UpdateableClassifier; import weka.classifiers.rules.ZeroR; import weka.core.Attribute; import weka.core.Capabilities; import weka.core.FastVector; import weka.core.Instance; import weka.core.Instances; import weka.core.Option; import weka.core.RevisionHandler; import weka.core.RevisionUtils; import weka.core.SelectedTag; import weka.core.Tag; import weka.core.Utils; import weka.core.WeightedInstancesHandler; import weka.core.Capabilities.Capability; import java.io.Serializable; import java.util.Enumeration; import java.util.Random; import java.util.Vector; /** * Classifier for incremental learning of large datasets by way of racing logit-boosted committees. *
 * <p/>
 *
 * Valid options are: <p/>
 *
 * <pre> -C &lt;num&gt;
 *  Minimum size of chunks.
 *  (default 500)</pre>
 *
 * <pre> -M &lt;num&gt;
 *  Maximum size of chunks.
 *  (default 2000)</pre>
 *
 * <pre> -V &lt;num&gt;
 *  Size of validation set.
 *  (default 1000)</pre>
 *
 * <pre> -P &lt;pruning type&gt;
 *  Committee pruning to perform.
 *  0=none, 1=log likelihood (default)</pre>
 *
 * <pre> -Q
 *  Use resampling for boosting.</pre>
 *
 * <pre> -S &lt;num&gt;
 *  Random number seed.
 *  (default 1)</pre>
 *
 * <pre> -D
 *  If set, classifier is run in debug mode and
 *  may output additional info to the console</pre>
 *
 * <pre> -W
 *  Full name of base classifier.
 *  (default: weka.classifiers.trees.DecisionStump)</pre>
 *
 * <pre>
 * Options specific to classifier weka.classifiers.trees.DecisionStump:
 * </pre>
 *
 * <pre> -D
 *  If set, classifier is run in debug mode and
 *  may output additional info to the console</pre>
 *
 * Options after -- are passed to the designated learner.<p/>
*
* @author Richard Kirkby (rkirkby@cs.waikato.ac.nz)
* @author Eibe Frank (eibe@cs.waikato.ac.nz)
* @version $Revision: 5987 $
*/
public class RacedIncrementalLogitBoost
extends RandomizableSingleClassifierEnhancer
implements UpdateableClassifier {
/** for serialization */
static final long serialVersionUID = 908598343772170052L;
/** no pruning */
public static final int PRUNETYPE_NONE = 0;
/** log likelihood pruning */
public static final int PRUNETYPE_LOGLIKELIHOOD = 1;
/** The pruning types (the tag set accepted by setPruningType/getPruningType) */
public static final Tag [] TAGS_PRUNETYPE = {
new Tag(PRUNETYPE_NONE, "No pruning"),
new Tag(PRUNETYPE_LOGLIKELIHOOD, "Log likelihood pruning")
};
/** The committees */
protected FastVector m_committees;
/** The pruning type used (one of the PRUNETYPE_* constants; option -P) */
protected int m_PruningType = PRUNETYPE_LOGLIKELIHOOD;
/** Whether to use resampling (option -Q) */
protected boolean m_UseResampling = false;
/** The number of classes */
protected int m_NumClasses;
/** A threshold for responses (Friedman suggests between 2 and 4) */
protected static final double Z_MAX = 4;
/** Dummy dataset with a numeric class */
protected Instances m_NumericClassData;
/** The actual class attribute (for getting class names) */
protected Attribute m_ClassAttribute;
/** The minimum chunk size used for training (option -C, default 500) */
protected int m_minChunkSize = 500;
/** The maximum chunk size used for training (option -M, default 2000) */
protected int m_maxChunkSize = 2000;
/** The size of the validation set (option -V, default 1000) */
protected int m_validationChunkSize = 1000;
/** The number of instances consumed */
protected int m_numInstancesConsumed;
/** The instances used for validation */
protected Instances m_validationSet;
/** The instances currently in memory for training */
protected Instances m_currentSet;
/** The current best committee (null until one has been selected) */
protected Committee m_bestCommittee;
/** The default scheme used when committees aren't ready; built lazily from m_validationSet (e.g. by toString) */
protected ZeroR m_zeroR = null;
/** Whether the validation set has recently been changed (forces m_zeroR to be rebuilt) */
protected boolean m_validationSetChanged;
/** The maximum number of instances required for processing */
protected int m_maxBatchSizeRequired;
/** The random number generator used */
protected Random m_RandomInstance = null;
/**
 * Creates a new instance that uses a DecisionStump as the base learner.
 */
public RacedIncrementalLogitBoost() {
  super();
  // The default base learner; keep in sync with defaultClassifierString().
  this.m_Classifier = new weka.classifiers.trees.DecisionStump();
}
/**
 * Names the base learner used when none is specified.
 *
 * @return the fully-qualified class name of the default base classifier
 */
protected String defaultClassifierString() {
  final String defaultName = "weka.classifiers.trees.DecisionStump";
  return defaultName;
}
/**
* Class representing a committee of LogitBoosted models
*/
protected class Committee
implements Serializable, RevisionHandler {
/** for serialization */
static final long serialVersionUID = 5559880306684082199L;
/** number of instances from m_currentSet handed to each new boosting round (see update()) */
protected int m_chunkSize;
/** number eaten from m_currentSet */
protected int m_instancesConsumed;
/** the committee members; starts empty in the constructor */
protected FastVector m_models;
/** presumably the most recently computed validation error; initialised to 1.0 — TODO confirm */
protected double m_lastValidationError;
/** presumably the most recently computed validation log likelihood; initialised to Double.MAX_VALUE — TODO confirm */
protected double m_lastLogLikelihood;
/** presumably flags m_lastValidationError as stale; initialised to true */
protected boolean m_modelHasChanged;
/** presumably flags m_lastLogLikelihood as stale; initialised to true */
protected boolean m_modelHasChangedLL;
/** sized [m_validationChunkSize][m_NumClasses]; presumably per-validation-instance, per-class committee outputs — verify */
protected double[][] m_validationFs;
/** same shape as m_validationFs; presumably the candidate outputs after adding a new model — verify */
protected double[][] m_newValidationFs;
/**
 * Creates an empty committee that trains on chunks of the given size.
 *
 * @param chunkSize number of instances used to train each new model
 */
public Committee(int chunkSize) {
  this.m_chunkSize = chunkSize;
  this.m_instancesConsumed = 0;
  this.m_models = new FastVector();

  // Initial estimates match the fallbacks the outer class reports when no
  // committee is available (100% error, Double.MAX_VALUE log likelihood).
  this.m_lastValidationError = 1.0;
  this.m_lastLogLikelihood = Double.MAX_VALUE;

  // Nothing has been evaluated yet, so both cached estimates are stale.
  this.m_modelHasChanged = true;
  this.m_modelHasChangedLL = true;

  // One row per validation instance, one column per class.
  this.m_validationFs = new double[m_validationChunkSize][m_NumClasses];
  this.m_newValidationFs = new double[m_validationChunkSize][m_NumClasses];
}
/**
* update the committee
*
* @return true if the committee has changed
* @throws Exception if anything goes wrong
*/
public boolean update() throws Exception {
boolean hasChanged = false;
while (m_currentSet.numInstances() - m_instancesConsumed >= m_chunkSize) {
Classifier[] newModel = boost(new Instances(m_currentSet, m_instancesConsumed, m_chunkSize));
for (int i=0; i -C <num>
* Minimum size of chunks.
* (default 500)
*
* -M <num>
* Maximum size of chunks.
* (default 2000)
*
* -V <num>
* Size of validation set.
* (default 1000)
*
* -P <pruning type>
* Committee pruning to perform.
* 0=none, 1=log likelihood (default)
*
* -Q
* Use resampling for boosting.
*
* -S <num>
* Random number seed.
* (default 1)
*
* -D
* If set, classifier is run in debug mode and
* may output additional info to the console
*
* -W
* Full name of base classifier.
* (default: weka.classifiers.trees.DecisionStump)
*
*
* Options specific to classifier weka.classifiers.trees.DecisionStump:
*
*
* -D
* If set, classifier is run in debug mode and
* may output additional info to the console
*
*
* @param options the list of options as an array of strings
* @throws Exception if an option is not supported
*/
public void setOptions(String[] options) throws Exception {
  // The look-ups run in the same order as before, and super.setOptions runs
  // last, presumably so it only sees options not consumed here (confirm
  // against Utils.getOption).
  String value;

  value = Utils.getOption('C', options);
  setMinChunkSize(value.length() == 0 ? 500 : Integer.parseInt(value));

  value = Utils.getOption('M', options);
  setMaxChunkSize(value.length() == 0 ? 2000 : Integer.parseInt(value));

  value = Utils.getOption('V', options);
  setValidationChunkSize(value.length() == 0 ? 1000 : Integer.parseInt(value));

  value = Utils.getOption('P', options);
  int pruneId = (value.length() == 0) ? PRUNETYPE_LOGLIKELIHOOD : Integer.parseInt(value);
  setPruningType(new SelectedTag(pruneId, TAGS_PRUNETYPE));

  setUseResampling(Utils.getFlag('Q', options));

  super.setOptions(options);
}
/**
 * Builds the command-line options describing the current configuration.
 * The array always has 9 slots more than the superclass options; unused
 * trailing slots are filled with empty strings.
 *
 * @return an array of strings suitable for passing to setOptions
 */
public String [] getOptions() {
  String[] inherited = super.getOptions();
  String[] result = new String[inherited.length + 9];
  int pos = 0;

  if (getUseResampling()) {
    result[pos++] = "-Q";
  }
  result[pos++] = "-C";
  result[pos++] = String.valueOf(getMinChunkSize());
  result[pos++] = "-M";
  result[pos++] = String.valueOf(getMaxChunkSize());
  result[pos++] = "-V";
  result[pos++] = String.valueOf(getValidationChunkSize());
  result[pos++] = "-P";
  result[pos++] = String.valueOf(m_PruningType);

  System.arraycopy(inherited, 0, result, pos, inherited.length);
  for (int i = pos + inherited.length; i < result.length; i++) {
    result[i] = "";
  }
  return result;
}
/**
 * Returns a short description of this classifier.
 *
 * @return a description of the classifier suitable for
 * displaying in the explorer/experimenter gui
 */
public String globalInfo() {
  final String info = "Classifier for incremental learning of large datasets by way of racing logit-boosted committees.";
  return info;
}
/**
 * Sets the base learner. LogitBoost fits regression models to working
 * responses, so the learner must support a numeric class.
 *
 * @param newClassifier the classifier to use.
 * @throws IllegalArgumentException if base classifier cannot handle numeric
 * class
 */
public void setClassifier(Classifier newClassifier) {
  boolean supportsNumericClass =
    newClassifier.getCapabilities().handles(Capability.NUMERIC_CLASS);
  if (!supportsNumericClass) {
    throw new IllegalArgumentException("Base classifier cannot handle numeric class!");
  }
  super.setClassifier(newClassifier);
}
/**
 * Returns the tip text for the minChunkSize property.
 *
 * @return tip text for this property suitable for
 * displaying in the explorer/experimenter gui
 */
public String minChunkSizeTipText() {
  final String tip = "The minimum number of instances to train the base learner with.";
  return tip;
}
/**
 * Sets the smallest chunk size used for training (option -C).
 *
 * @param chunkSize the minimum chunk size
 */
public void setMinChunkSize(int chunkSize) {
  this.m_minChunkSize = chunkSize;
}
/**
 * Returns the smallest chunk size used for training.
 *
 * @return the chunk size
 */
public int getMinChunkSize() {
  return this.m_minChunkSize;
}
/**
 * Returns the tip text for the maxChunkSize property.
 *
 * @return tip text for this property suitable for
 * displaying in the explorer/experimenter gui
 */
public String maxChunkSizeTipText() {
  final String tip = "The maximum number of instances to train the base learner with. The chunk sizes used will start at minChunkSize and grow twice as large for as many times as they are less than or equal to the maximum size.";
  return tip;
}
/**
 * Sets the largest chunk size used for training (option -M).
 *
 * @param chunkSize the maximum chunk size
 */
public void setMaxChunkSize(int chunkSize) {
  this.m_maxChunkSize = chunkSize;
}
/**
 * Returns the largest chunk size used for training.
 *
 * @return the chunk size
 */
public int getMaxChunkSize() {
  return this.m_maxChunkSize;
}
/**
 * Returns the tip text for the validationChunkSize property.
 *
 * @return tip text for this property suitable for
 * displaying in the explorer/experimenter gui
 */
public String validationChunkSizeTipText() {
  final String tip = "The number of instances to hold out for validation. These instances will be taken from the beginning of the stream, so learning will not start until these instances have been consumed first.";
  return tip;
}
/**
 * Sets how many instances are held out for validation (option -V).
 *
 * @param chunkSize the validation chunk size
 */
public void setValidationChunkSize(int chunkSize) {
  this.m_validationChunkSize = chunkSize;
}
/**
 * Returns how many instances are held out for validation.
 *
 * @return the chunk size
 */
public int getValidationChunkSize() {
  return this.m_validationChunkSize;
}
/**
 * Returns the tip text for the pruningType property.
 *
 * @return tip text for this property suitable for
 * displaying in the explorer/experimenter gui
 */
public String pruningTypeTipText() {
  final String tip = "The pruning method to use within each committee. Log likelihood pruning will discard new models if they have a negative effect on the log likelihood of the validation data.";
  return tip;
}
/**
 * Sets the committee pruning type (option -P). Tags that do not come from
 * TAGS_PRUNETYPE are silently ignored.
 *
 * @param pruneType the pruning type
 */
public void setPruningType(SelectedTag pruneType) {
  if (pruneType.getTags() != TAGS_PRUNETYPE) {
    return;
  }
  m_PruningType = pruneType.getSelectedTag().getID();
}
/**
 * Returns the committee pruning type as a tag from TAGS_PRUNETYPE.
 *
 * @return the type
 */
public SelectedTag getPruningType() {
  SelectedTag current = new SelectedTag(m_PruningType, TAGS_PRUNETYPE);
  return current;
}
/**
 * Returns the tip text for the useResampling property.
 *
 * @return tip text for this property suitable for
 * displaying in the explorer/experimenter gui
 */
public String useResamplingTipText() {
  final String tip = "Force the use of resampling data rather than using the weight-handling capabilities of the base classifier. Resampling is always used if the base classifier cannot handle weighted instances.";
  return tip;
}
/**
 * Turns forced resampling on or off (option -Q).
 *
 * @param r true if resampling should be done
 */
public void setUseResampling(boolean r) {
  this.m_UseResampling = r;
}
/**
 * Reports whether forced resampling is enabled.
 *
 * @return true if resampling output is on
 */
public boolean getUseResampling() {
  return this.m_UseResampling;
}
/**
 * Returns the chunk size of the current best committee.
 *
 * @return the best committee chunk size, or 0 if there is no best committee yet
 */
public int getBestCommitteeChunkSize() {
  return (m_bestCommittee == null) ? 0 : m_bestCommittee.chunkSize();
}
/**
 * Returns how many models the current best committee contains.
 *
 * @return the number of members, or 0 if there is no best committee yet
 */
public int getBestCommitteeSize() {
  return (m_bestCommittee == null) ? 0 : m_bestCommittee.committeeSize();
}
/**
 * Returns the best committee's validation error as a percentage.
 * Falls back to 100.0 when there is no committee or evaluation fails.
 *
 * @return the best committee's error
 */
public double getBestCommitteeErrorEstimate() {
  if (m_bestCommittee == null) {
    return 100.0;
  }
  try {
    return m_bestCommittee.validationError() * 100.0;
  } catch (Exception e) {
    System.err.println(e.getMessage());
    return 100.0;
  }
}
/**
 * Returns the best committee's log likelihood on the validation data.
 * Falls back to Double.MAX_VALUE when there is no committee or evaluation fails.
 *
 * @return best committee's log likelihood
 */
public double getBestCommitteeLLEstimate() {
  if (m_bestCommittee == null) {
    return Double.MAX_VALUE;
  }
  try {
    return m_bestCommittee.logLikelihood();
  } catch (Exception e) {
    System.err.println(e.getMessage());
    return Double.MAX_VALUE;
  }
}
/**
 * Returns description of the boosted classifier.
 * <p>
 * If no best committee exists yet, a ZeroR model is (re)built from the
 * validation set when that set has changed or no ZeroR exists, and its
 * description is returned instead. If building ZeroR fails, the error
 * message is printed to standard error, the partially built ZeroR is
 * discarded, and the "no model built yet" message is returned.
 *
 * @return description of the boosted classifier as a string
 */
public String toString() {
  if (m_bestCommittee != null) {
    return m_bestCommittee.toString();
  }
  if ((m_validationSetChanged || m_zeroR == null) && m_validationSet != null
      && m_validationSet.numInstances() > 0) {
    m_zeroR = new ZeroR();
    try {
      m_zeroR.buildClassifier(m_validationSet);
    } catch (Exception e) {
      // Previously swallowed silently, which left an unbuilt ZeroR that was
      // then reported as the fallback model. Report the problem like the
      // other getters do and drop the broken model (it is retried next call).
      System.err.println(e.getMessage());
      m_zeroR = null;
    }
    m_validationSetChanged = false;
  }
  if (m_zeroR != null) {
    return ("RacedIncrementalLogitBoost: insufficient data to build model, resorting to ZeroR:\n\n"
            + m_zeroR.toString());
  }
  return ("RacedIncrementalLogitBoost: no model built yet.");
}
/**
 * Returns the revision string extracted from the version-control keyword.
 *
 * @return the revision
 */
public String getRevision() {
  String keyword = "$Revision: 5987 $";
  return RevisionUtils.extract(keyword);
}
/**
 * Command-line entry point: runs this classifier with the given arguments.
 *
 * @param argv the commandline parameters
 */
public static void main(String[] argv) {
  RacedIncrementalLogitBoost scheme = new RacedIncrementalLogitBoost();
  runClassifier(scheme, argv);
}
}