/*
 *    This program is free software; you can redistribute it and/or modify
 *    it under the terms of the GNU General Public License as published by
 *    the Free Software Foundation; either version 2 of the License, or
 *    (at your option) any later version.
 *
 *    This program is distributed in the hope that it will be useful,
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *    GNU General Public License for more details.
 *
 *    You should have received a copy of the GNU General Public License
 *    along with this program; if not, write to the Free Software
 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 */

/*
 * Copyright (C) 2002 University of Waikato 
 */

package weka.classifiers.meta;

import weka.classifiers.AbstractClassifierTest;
import weka.classifiers.Classifier;
import weka.classifiers.evaluation.EvaluationUtils;
import weka.classifiers.evaluation.NominalPrediction;
import weka.core.Attribute;
import weka.core.FastVector;
import weka.core.Instances;
import weka.core.NoSupportForMissingValuesException;
import weka.core.SelectedTag;
import weka.core.UnsupportedAttributeTypeException;
import weka.filters.Filter;
import weka.filters.unsupervised.attribute.RemoveType;
import weka.filters.unsupervised.attribute.ReplaceMissingValues;

import java.io.BufferedReader;
import java.io.InputStream;
import java.io.InputStreamReader;

import junit.framework.Test;
import junit.framework.TestSuite;

/**
 * Tests ThresholdSelector. Run from the command line with:<p>
 * java weka.classifiers.meta.ThresholdSelectorTest
 *
 * @author <a href="mailto:len@reeltwo.com">Len Trigg</a>
 * @author FracPete (fracpete at waikato dot ac dot nz)
 * @version $Revision: 1.8 $
 */
public class ThresholdSelectorTest 
  extends AbstractClassifierTest {

  /** Classpath location of the ARFF data set loaded before each test. */
  private static final String TEST_DATA =
    "weka/classifiers/data/ClassifierTest.arff";

  /**
   * Distribution handed to the dummy subclassifier by default. Its smallest
   * (0.25) and largest (1.0) entries are the bounds checked in
   * {@link #testRangeNone()}.
   */
  private static final double[] DIST1 = new double [] {
    0.25,
    0.375,
    0.5,
    0.625,
    0.75,
    0.875,
    1.0
  };

  /** A set of instances to test with */
  protected transient Instances m_Instances;

  /** Used to generate various types of predictions */
  protected transient EvaluationUtils m_Evaluation;

  /**
   * Creates the test case.
   *
   * @param name the name of the test, passed on to the superclass
   */
  public ThresholdSelectorTest(String name) { 
    super(name);  
  }

  /**
   * Called by JUnit before each test method. This implementation creates
   * the default classifier to test and loads a test set of Instances.
   *
   * @exception Exception if an error occurs reading the example instances.
   */
  protected void setUp() throws Exception {
    super.setUp();

    m_Evaluation = new EvaluationUtils();

    // Fail with a readable message instead of a bare NullPointerException
    // when the data file is missing from the classpath.
    InputStream stream = ClassLoader.getSystemResourceAsStream(TEST_DATA);
    assertNotNull("Test data not found on classpath: " + TEST_DATA, stream);

    BufferedReader reader = new BufferedReader(new InputStreamReader(stream));
    try {
      m_Instances = new Instances(reader);
    } finally {
      // Release the underlying stream once the data set has been read.
      reader.close();
    }
  }

  /** Creates a default ThresholdSelector */
  public Classifier getClassifier() {
    return getClassifier(DIST1);
  }

  /** Called by JUnit after each test method */
  protected void tearDown() {
    super.tearDown();

    // Drop everything setUp() created so it can be garbage collected.
    m_Evaluation = null;
    m_Instances = null;
  }

  /**
   * Creates a ThresholdSelector that returns predictions from a 
   * given distribution
   *
   * @param dist the distribution passed to the dummy subclassifier
   * @return a new <code>ThresholdSelector</code> wrapping a
   * <code>ThresholdSelectorDummyClassifier</code>
   */
  public Classifier getClassifier(double[] dist) {
    return getClassifier(new ThresholdSelectorDummyClassifier(dist));
  }

  /**
   * Creates a ThresholdSelector with the given subclassifier.
   *
   * @param classifier a <code>Classifier</code> to use as the
   * subclassifier
   * @return a new <code>ThresholdSelector</code>
   */
  public Classifier getClassifier(Classifier classifier) {
    ThresholdSelector t = new ThresholdSelector();
    t.setClassifier(classifier);
    return t;
  }

  /**
   * Returns the classifier under test cast to its concrete type.
   *
   * @return the current classifier as a <code>ThresholdSelector</code>
   */
  private ThresholdSelector selector() {
    return (ThresholdSelector) m_Classifier;
  }

  /**
   * Builds a model using the current classifier using the first
   * half of the current data for training, and generates a bunch of
   * predictions using the remaining half of the data for testing.
   * <p>
   * If the classifier rejects the data, the data is adapted and the
   * attempt is repeated: offending attribute types are removed (at most
   * three times), missing values are replaced, or the training data is
   * doubled when there are "Not enough instances".
   *
   * @return a <code>FastVector</code> containing the predictions.
   * @exception Exception if the classifier fails for a reason that
   * cannot be worked around by adapting the data.
   */
  protected FastVector useClassifier() throws Exception {

    Classifier dc = null;
    int tot = m_Instances.numInstances();
    int mid = tot / 2;
    Instances train = null;
    Instances test = null;
    try {
      train = new Instances(m_Instances, 0, mid);
      test = new Instances(m_Instances, mid, tot - mid);
      dc = m_Classifier;
    } catch (Exception ex) {
      ex.printStackTrace();
      fail("Problem setting up to use classifier: " + ex);
    }
    int counter = 0;
    do {
      try {
        return m_Evaluation.getTrainTestPredictions(dc, train, test);
      } catch (UnsupportedAttributeTypeException ex) {
        SelectedTag tag = null;
        boolean invert = false;
        String msg = ex.getMessage();
        if (msg == null) {
          // Without a message there is nothing to diagnose; surface the
          // original exception rather than a NullPointerException.
          throw ex;
        }
        if ((msg.indexOf("string") != -1) && 
            (msg.indexOf("attributes") != -1)) {
          System.err.println("\nDeleting string attributes.");
          tag = new SelectedTag(Attribute.STRING,
                                RemoveType.TAGS_ATTRIBUTETYPE);
        } else if ((msg.indexOf("only") != -1) && 
                   (msg.indexOf("nominal") != -1)) {
          System.err.println("\nDeleting non-nominal attributes.");
          tag = new SelectedTag(Attribute.NOMINAL,
                                RemoveType.TAGS_ATTRIBUTETYPE);
          invert = true;
        } else if ((msg.indexOf("only") != -1) && 
                   (msg.indexOf("numeric") != -1)) {
          System.err.println("\nDeleting non-numeric attributes.");
          tag = new SelectedTag(Attribute.NUMERIC,
                                RemoveType.TAGS_ATTRIBUTETYPE);
          invert = true;
        } else {
          throw ex;
        }
        RemoveType attFilter = new RemoveType();
        attFilter.setAttributeType(tag);
        attFilter.setInvertSelection(invert);
        attFilter.setInputFormat(train);
        train = Filter.useFilter(train, attFilter);
        attFilter.batchFinished();
        test = Filter.useFilter(test, attFilter);
        // Give up once attribute removal has been tried three times.
        counter++;
        if (counter > 2) {
          throw ex;
        }
      } catch (NoSupportForMissingValuesException ex2) {
        System.err.println("\nReplacing missing values.");
        ReplaceMissingValues rmFilter = new ReplaceMissingValues();
        rmFilter.setInputFormat(train);
        train = Filter.useFilter(train, rmFilter);
        rmFilter.batchFinished();
        test = Filter.useFilter(test, rmFilter);
      } catch (IllegalArgumentException ex3) {
        String msg = ex3.getMessage();
        if ((msg != null) && (msg.indexOf("Not enough instances") != -1)) {
          System.err.println("\nInflating training data.");
          // Double the training set by appending a copy of every instance.
          Instances trainNew = new Instances(train);
          for (int i = 0; i < train.numInstances(); i++) {
            trainNew.add(train.instance(i));
          }
          train = trainNew;
        } else {
          throw ex3;
        }
      }
    } while (true);
  }

  /**
   * Checks that, with range correction set to RANGE_NONE, the predicted
   * probabilities for class index 0 stay within [0.25, 1.0], i.e. within
   * the smallest and largest values of <code>DIST1</code>.
   *
   * @exception Exception if the classifier cannot be used
   */
  public void testRangeNone() throws Exception {
    
    int cind = 0;
    selector().setDesignatedClass(
      new SelectedTag(ThresholdSelector.OPTIMIZE_0,
                      ThresholdSelector.TAGS_OPTIMIZE));
    selector().setRangeCorrection(
      new SelectedTag(ThresholdSelector.RANGE_NONE,
                      ThresholdSelector.TAGS_RANGE));
    m_Instances.setClassIndex(1);
    FastVector result = useClassifier();
    assertTrue(result.size() != 0);
    double minp = 0;
    double maxp = 0;
    for (int i = 0; i < result.size(); i++) {
      NominalPrediction p = (NominalPrediction)result.elementAt(i);
      double prob = p.distribution()[cind];
      // The first prediction seeds both extremes.
      if ((i == 0) || (prob < minp)) minp = prob;
      if ((i == 0) || (prob > maxp)) maxp = prob;
    }
    assertTrue("Upper limit shouldn't increase", maxp <= 1.0);
    assertTrue("Lower limit shouldn't decrease", minp >= 0.25);
  }
  
  /**
   * Checks that predictions are produced for every designated-class
   * setting listed in <code>ThresholdSelector.TAGS_OPTIMIZE</code>.
   *
   * @exception Exception if the classifier cannot be used
   */
  public void testDesignatedClass() throws Exception {
    
    for (int i = 0; i < ThresholdSelector.TAGS_OPTIMIZE.length; i++) {
      selector().setDesignatedClass(
        new SelectedTag(ThresholdSelector.TAGS_OPTIMIZE[i].getID(),
                        ThresholdSelector.TAGS_OPTIMIZE));
      m_Instances.setClassIndex(1);
      FastVector result = useClassifier();
      assertTrue(result.size() != 0);
    }
  }

  /**
   * Checks that predictions are produced for every evaluation mode
   * listed in <code>ThresholdSelector.TAGS_EVAL</code>.
   *
   * @exception Exception if the classifier cannot be used
   */
  public void testEvaluationMode() throws Exception {
    
    for (int i = 0; i < ThresholdSelector.TAGS_EVAL.length; i++) {
      selector().setEvaluationMode(
        new SelectedTag(ThresholdSelector.TAGS_EVAL[i].getID(),
                        ThresholdSelector.TAGS_EVAL));
      m_Instances.setClassIndex(1);
      FastVector result = useClassifier();
      assertTrue(result.size() != 0);
    }
  }

  /**
   * Checks that a fold count of 0 is rejected with an
   * <code>IllegalArgumentException</code> and that predictions are
   * produced for the even fold counts 2 to 18.
   *
   * @exception Exception if the classifier cannot be used
   */
  public void testNumXValFolds() throws Exception {
    
    try {
      selector().setNumXValFolds(0);
      fail("Expected IllegalArgumentException");
    } catch (IllegalArgumentException e) {
      // OK
    }

    for (int i = 2; i < 20; i += 2) {
      selector().setNumXValFolds(i);
      m_Instances.setClassIndex(1);
      FastVector result = useClassifier();
      assertTrue(result.size() != 0);
    }
  }

  /**
   * Builds the suite containing all tests of this class.
   *
   * @return the test suite
   */
  public static Test suite() {
    return new TestSuite(ThresholdSelectorTest.class);
  }

  /**
   * Runs the suite with the JUnit text UI.
   *
   * @param args ignored
   */
  public static void main(String[] args){
    junit.textui.TestRunner.run(suite());
  }
}
