source: src/main/java/weka/classifiers/meta/ensembleSelection/ModelBag.java @ 14

Last change on this file since 14 was 4, checked in by gnappo, 14 years ago

Import di weka.

File size: 22.1 KB
Line 
1/*
2 *    This program is free software; you can redistribute it and/or modify
3 *    it under the terms of the GNU General Public License as published by
4 *    the Free Software Foundation; either version 2 of the License, or
5 *    (at your option) any later version.
6 *
7 *    This program is distributed in the hope that it will be useful,
8 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
9 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
10 *    GNU General Public License for more details.
11 *
12 *    You should have received a copy of the GNU General Public License
13 *    along with this program; if not, write to the Free Software
14 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
15 */
16
17/*
18 *    ModelBag.java
19 *    Copyright (C) 2006 David Michael
20 *
21 */
22
23package weka.classifiers.meta.ensembleSelection;
24
25import weka.classifiers.Evaluation;
26import weka.core.Instances;
27import weka.core.RevisionHandler;
28import weka.core.RevisionUtils;
29
30import java.util.Random;
31
32/**
33 * This class is responsible for the duties of a bag of models. It is designed
34 * for use with the EnsembleSelection meta classifier. It handles shuffling the
35 * models, doing sort initialization, performing forward selection/ backwards
36 * elimination, etc.
37 * <p/>
38 * We utilize a simple "virtual indexing" scheme inside. If we shuffle and/or
39 * sort the models, we change the "virtual" order around. The elements of the
40 * bag are always those elements with virtual index 0..(m_bagSize-1). Each
41 * "virtual" index maps to some real index in m_models. Not every model in
42 * m_models gets a virtual index... the virtual indexing is what defines the
43 * subset of models of which our Bag is composed. This makes it easy to refer to
44 * models in the bag, by their virtual index, while maintaining the original
45 * indexing for our clients.
46 *
47 * @author  David Michael
48 * @version $Revision: 1.2 $
49 */
50public class ModelBag
51  implements RevisionHandler {
52 
53  /**
54   * The "models", as a multidimensional array of predictions for the
55   * validation set. The first index is the model index, the second index is
56   * the index of the instance, and the third is the typical "class" index for
57   * a prediction's distribution. This is given to us in the constructor, and
58   * we never change it.
59   */
60  private double m_models[][][];
61 
62  /**
63   * Maps each model in our virtual indexing scheme to its original index as
64   * it is in m_models. The first m_bag_size elements here are considered our
65   * bag. Throughout the code, we use the index in to this array to refer to a
66   * model. When we shuffle the models, we really simply shuffle this array.
67   * When we want to refer back to the original model, it is easily looked up
68   * in this array. That is, if j = m_model_index[i], then m_models[j] is the
69   * model referred to by "virtual index" i. Models can easily be accessed by
70   * their virtual index using the "model()" method.
71   */
72  private int m_modelIndex[];
73 
74  /**
75   * The number of models in our bag. 1 <= m_bag_size <= m_models.length
76   */
77  private int m_bagSize;
78 
79  /**
80   * The total number of models chosen thus far for this bag. This value is
81   * important when calculating the predictions for the bag. (See
82   * computePredictions).
83   */
84  private int m_numChosen;
85 
86  /**
87   * The number of times each model has been chosen. Also can be thought of as
88   * the weight for each model. Indexed by the "virtual index".
89   */
90  private int m_timesChosen[];
91 
92  /**
93   * If true, print out debug information.
94   */
95  private boolean m_debug;
96 
97  /**
98   * Double representing the best performance achieved thus far in this bag.
99   * This Must be updated each time we make a change to the bag that improves
100   * performance. This is so that after all hillclimbing is completed, we can
101   * go back to the best ensemble that we encountered during hillclimbing.
102   */
103  private double m_bestPerformance;
104 
105  /**
106   * Array representing the weights for all the models which achieved the best
107   * performance thus far for the bag (i.e., the weights that achieved
108   * m_bestPerformance. This Must be updated each time we make a change to the
109   * bag (that improves performance, by calling updateBestTimesChosen. This is
110   * so that after all hillclimbing is completed, we can go back to the best
111   * ensemble that we encountered during hillclimbing. This array, unlike
112   * m_timesChosen, uses the original indexing as taken from m_models. That
113   * way, any time getModelWeights is called (which returns this array), the
114   * array is in the correct format for our client.
115   */
116  private int m_bestTimesChosen[];
117 
118  /**
119   * Constructor for ModelBag.
120   *
121   * @param models
122   *            The complete set of models from which to draw our bag. First
123   *            index is for the model, second is for the instance. The last
124   *            is a prediction distribution for that instance. Models are
125   *            represented by this array of predictions for validation data,
126   *            since that's all ensemble selection needs to know.
127   * @param bag_percent
128   *            The percentage of the set of given models that should be used
129   *            in the Model Bag.
130   * @param debug
131   *            Whether the ModelBag should print debug information.
132   *
133   */
134  public ModelBag(double models[][][], double bag_percent, boolean debug) {
135    m_debug = debug;
136    if (models.length == 0) {
137      throw new IllegalArgumentException(
138      "ModelBag needs at least 1 model.");
139    }
140    m_bagSize = (int) ((double) models.length * bag_percent);
141    m_models = models;
142    m_modelIndex = new int[m_models.length];
143    m_timesChosen = new int[m_models.length];
144    m_bestTimesChosen = m_timesChosen;
145    m_bestPerformance = 0.0;
146   
147    // Initially, no models are chosen.
148    m_numChosen = 0;
149    // Prepare our virtual indexing scheme. Initially, the indexes are
150    // the same as the original.
151    for (int i = 0; i < m_models.length; ++i) {
152      m_modelIndex[i] = i;
153      m_timesChosen[i] = 0;
154    }
155  }
156 
157  /**
158   * Swap model at virtual index i with model at virtual index j. This is used
159   * to shuffle the models. We do not change m_models, only the arrays which
160   * use the virtual indexing; m_modelIndex and m_timesChosen.
161   *
162   * @param i   first index
163   * @param j   second index
164   */
165  private void swap(int i, int j) {
166    if (i != j) {
167      int temp_index = m_modelIndex[i];
168      m_modelIndex[i] = m_modelIndex[j];
169      m_modelIndex[j] = temp_index;
170     
171      int tempWeight = m_timesChosen[i];
172      m_timesChosen[i] = m_timesChosen[j];
173      m_timesChosen[j] = tempWeight;
174    }
175  }
176 
177  /**
178   * Shuffle the models. The order in m_models is preserved, but we change our
179   * virtual indexes around.
180   *
181   * @param rand        the random number generator to use
182   */
183  public void shuffle(Random rand) {
184    if (m_models.length < 2)
185      return;
186   
187    for (int i = 0; i < m_models.length; ++i) {
188      int swap_index = rand.nextInt(m_models.length - 1);
189      if (swap_index >= i)
190        ++swap_index; // don't swap with itself
191      swap(i, swap_index);
192    }
193  }
194 
195  /**
196   * Convert an array of weights using virtual indices to an array of weights
197   * using real indices.
198   *
199   * @param virtual_weights     the virtual indices
200   * @return                    the real indices
201   */
202  private int[] virtualToRealWeights(int virtual_weights[]) {
203    int real_weights[] = new int[virtual_weights.length];
204    for (int i = 0; i < real_weights.length; ++i) {
205      real_weights[m_modelIndex[i]] = virtual_weights[i];
206    }
207    return real_weights;
208  }
209 
210  /**
211   *
212   */
213  private void updateBestTimesChosen() {
214    m_bestTimesChosen = virtualToRealWeights(m_timesChosen);
215  }
216 
217  /**
218   * Sort initialize the bag.
219   *
220   * @param num
221   *            the Maximum number of models to initialize with
222   * @param greedy
223   *            True if we do greedy addition, up to num. Greedy sort
224   *            initialization adds models (up to num) in order of best to
225   *            worst performance until performance no longer improves.
226   * @param instances
227   *            the data set (needed for performance evaluation)
228   * @param metric
229   *            metric for which to optimize. See EnsembleMetricHelper
230   * @return returns an array of indexes which were selected, in order
231   *         starting from the model with best performance.
232   * @throws Exception if something goes wrong
233   */
234  public int[] sortInitialize(int num, boolean greedy, Instances instances,
235      int metric) throws Exception {
236   
237    // First, get the performance of each model
238    double performance[] = new double[m_bagSize];
239    for (int i = 0; i < m_bagSize; ++i) {
240      performance[i] = evaluatePredictions(instances, model(i), metric);
241    }
242    int bestModels[] = new int[num]; // we'll use this to save model info
243    // Now sort the models by their performance... note we only need the
244    // first "num",
245    // so we don't actually bother to sort the whole thing... instead, we
246    // pick the num best
247    // by running num iterations of selection sort.
248    for (int i = 0; i < num; ++i) {
249      int max_index = i;
250      double max_value = performance[i];
251      for (int j = i + 1; j < m_bagSize; ++j) {
252        // Find the best model which we haven't already selected
253        if (performance[j] > max_value) {
254          max_value = performance[j];
255          max_index = j;
256        }
257      }
258      // Swap ith model in to the ith position (selection sort)
259      this.swap(i, max_index);
260      // swap performance numbers, too
261      double temp_perf = performance[i];
262      performance[i] = performance[max_index];
263      performance[max_index] = temp_perf;
264     
265      bestModels[i] = m_modelIndex[i];
266      if (!greedy) {
267        // If we're not being greedy, we just throw the model in
268        // no matter what
269        ++m_timesChosen[i];
270        ++m_numChosen;
271      }
272    }
273    // Now the best "num" models are all sorted and in position.
274    if (greedy) {
275      // If the "greedy" option was specified, do a smart sort
276      // initialization
277      // that adds models only so long as they help overall performance.
278      // This is what was done in the original Caruana paper.
279      double[][] tempPredictions = null;
280      double bestPerformance = 0.0;
281      if (num > 0) {
282        ++m_timesChosen[0];
283        ++m_numChosen;
284        updateBestTimesChosen();
285      }
286      for (int i = 1; i < num; ++i) {
287        tempPredictions = computePredictions(i, true);
288        double metric_value = evaluatePredictions(instances,
289            tempPredictions, metric);
290        if (metric_value > bestPerformance) {
291          // If performance improved, update the appropriate info.
292          bestPerformance = metric_value;
293          ++m_timesChosen[i];
294          ++m_numChosen;
295          updateBestTimesChosen();
296        } else {
297          // We found a model that doesn't help performance, so we
298          // stop adding models.
299          break;
300        }
301      }
302    }
303    updateBestTimesChosen();
304    if (m_debug) {
305      System.out.println("Sort Initialization added best " + m_numChosen
306          + " models to the bag.");
307    }
308    return bestModels;
309  }
310 
311  /**
312   * Add "weight" to the number of times each model in the bag was chosen.
313   * Typically for use with backward elimination.
314   *
315   * @param weight      the weight to add
316   */
317  public void weightAll(int weight) {
318    for (int i = 0; i < m_bagSize; ++i) {
319      m_timesChosen[i] += weight;
320      m_numChosen += weight;
321    }
322    updateBestTimesChosen();
323  }
324 
325  /**
326   * Forward select one model. Will add the model which has the best effect on
327   * performance. If replacement is false, and all models are chosen, no
328   * action is taken. If a model can be added, one always is (even if it hurts
329   * performance).
330   *
331   * @param withReplacement
332   *            whether a model can be added more than once.
333   * @param instances
334   *            The dataset, for calculating performance.
335   * @param metric
336   *            The metric to which we will optimize. See EnsembleMetricHelper
337   * @throws Exception if something goes wrong
338   */
339  public void forwardSelect(boolean withReplacement, Instances instances,
340      int metric) throws Exception {
341   
342    double bestPerformance = -1.0;
343    int bestIndex = -1;
344    double tempPredictions[][];
345    for (int i = 0; i < m_bagSize; ++i) {
346      // For each model in the bag
347      if ((m_timesChosen[i] == 0) || withReplacement) {
348        // If the model has not been chosen, or we're allowing
349        // replacement
350        // Get the predictions we would have if we add this model to the
351        // ensemble
352        tempPredictions = computePredictions(i, true);
353        // And find out how the hypothetical ensemble would perform.
354        double metric_value = evaluatePredictions(instances,
355            tempPredictions, metric);
356        if (metric_value > bestPerformance) {
357          // If it's better than our current best, make it our NEW
358          // best.
359          bestIndex = i;
360          bestPerformance = metric_value;
361        }
362      }
363    }
364    if (bestIndex == -1) {
365      // Replacement must be false, with more hillclimb iterations than
366      // models. Do nothing and return.
367      if (m_debug) {
368        System.out.println("Couldn't add model.  No action performed.");
369      }
370      return;
371    }
372    // We picked bestIndex as our best model. Update appropriate info.
373    m_timesChosen[bestIndex]++;
374    m_numChosen++;
375    if (bestPerformance > m_bestPerformance) {
376      // We find the peak of our performance over all hillclimb
377      // iterations.
378      // If this forwardSelect step improved our overall performance,
379      // update
380      // our best ensemble info.
381      updateBestTimesChosen();
382      m_bestPerformance = bestPerformance;
383    }
384  }
385 
386  /**
387   * Find the model whose removal will help the ensemble's performance the
388   * most, and remove it. If there is only one model left, we leave it in. If
389   * we can remove a model, we always do, even if it hurts performance.
390   *
391   * @param instances
392   *            The data set, for calculating performance
393   * @param metric
394   *            Metric to optimize for. See EnsembleMetricHelper.
395   * @throws Exception if something goes wrong
396   */
397  public void backwardEliminate(Instances instances, int metric)
398  throws Exception {
399   
400    // Find the best model to remove. I.e., model for which removal improves
401    // performance the most (or hurts it least), and remove it.
402    if (m_numChosen <= 1) {
403      // If we only have one model left, keep it, as a bag
404      // which chooses no models doesn't make much sense.
405      return;
406    }
407    double bestPerformance = -1.0;
408    int bestIndex = -1;
409    double tempPredictions[][];
410    for (int i = 0; i < m_bagSize; ++i) {
411      // For each model in the bag
412      if (m_timesChosen[i] > 0) {
413        // If the model has been chosen at least once,
414        // Get the predictions we would have if we remove this model
415        tempPredictions = computePredictions(i, false);
416        // And find out how the hypothetical ensemble would perform.
417        double metric_value = evaluatePredictions(instances,
418            tempPredictions, metric);
419        if (metric_value > bestPerformance) {
420          // If it's better than our current best, make it our NEW
421          // best.
422          bestIndex = i;
423          bestPerformance = metric_value;
424        }
425      }
426    }
427    if (bestIndex == -1) {
428      // The most likely cause of this is that we didn't have any models
429      // we could
430      // remove. Do nothing & return.
431      if (m_debug) {
432        System.out
433        .println("Couldn't remove model.  No action performed.");
434      }
435      return;
436    }
437    // We picked bestIndex as our best model. Update appropriate info.
438    m_timesChosen[bestIndex]--;
439    m_numChosen--;
440    if (m_debug) {
441      System.out.println("Removing model " + m_modelIndex[bestIndex]
442                                                          + " (" + bestIndex + ") " + bestPerformance);
443    }
444    if (bestPerformance > m_bestPerformance) {
445      // We find the peak of our performance over all hillclimb
446      // iterations.
447      // If this forwardSelect step improved our overall performance,
448      // update
449      // our best ensemble info.
450      updateBestTimesChosen();
451      m_bestPerformance = bestPerformance;
452    }
453    // return m_model_index[best_index]; //translate to original indexing
454    // and return
455  }
456 
457  /**
458   * Find the best action to perform, be it adding a model or removing a
459   * model, and perform it. Some action is always performed, even if it hurts
460   * performance.
461   *
462   * @param with_replacement
463   *            whether we can add a model more than once
464   * @param instances
465   *            The dataset, for determining performance.
466   * @param metric
467   *            The metric for which to optimize. See EnsembleMetricHelper.
468   * @throws Exception if something goes wrong
469   */
470  public void forwardSelectOrBackwardEliminate(boolean with_replacement,
471      Instances instances, int metric) throws Exception {
472   
473    // Find the best action to perform, be it adding a model or removing a
474    // model,
475    // and do it.
476    double bestPerformance = -1.0;
477    int bestIndex = -1;
478    boolean added = true;
479    double tempPredictions[][];
480    for (int i = 0; i < m_bagSize; ++i) {
481      // For each model in the bag:
482      // Try removing the model
483      if (m_timesChosen[i] > 0) {
484        // If the model has been chosen at least once,
485        // Get the predictions we would have if we remove this model
486        tempPredictions = computePredictions(i, false);
487        // And find out how the hypothetical ensemble would perform.
488        double metric_value = evaluatePredictions(instances,
489            tempPredictions, metric);
490        if (metric_value > bestPerformance) {
491          // If it's better than our current best, make it our NEW
492          // best.
493          bestIndex = i;
494          bestPerformance = metric_value;
495          added = false;
496        }
497      }
498      if ((m_timesChosen[i] == 0) || with_replacement) {
499        // If the model hasn't been chosen, or if we can choose it more
500        // than once, try adding it:
501        // Get the predictions we would have if we added the model
502        tempPredictions = computePredictions(i, true);
503        // And find out how the hypothetical ensemble would perform.
504        double metric_value = evaluatePredictions(instances,
505            tempPredictions, metric);
506        if (metric_value > bestPerformance) {
507          // If it's better than our current best, make it our NEW
508          // best.
509          bestIndex = i;
510          bestPerformance = metric_value;
511          added = true;
512        }
513      }
514    }
515    if (bestIndex == -1) {
516      // Shouldn't really happen. Possible (I think) if the model bag is
517      // empty. Just return.
518      if (m_debug) {
519        System.out.println("Couldn't add or remove model.  No action performed.");
520      }
521      return;
522    }
523    // Now we've found the best change to make:
524    // * bestIndex is the (virtual) index of the model we should change
525    // * added is true if the model should be added (false if should be
526    // removed)
527    int changeInWeight = added ? 1 : -1;
528    m_timesChosen[bestIndex] += changeInWeight;
529    m_numChosen += changeInWeight;
530    if (bestPerformance > m_bestPerformance) {
531      // We find the peak of our performance over all hillclimb
532      // iterations.
533      // If this forwardSelect step improved our overall performance,
534      // update
535      // our best ensemble info.
536      updateBestTimesChosen();
537      m_bestPerformance = bestPerformance;
538    }
539  }
540 
541  /**
542   * returns the model weights
543   *
544   * @return            the model weights
545   */
546  public int[] getModelWeights() {
547    return m_bestTimesChosen;
548  }
549 
550  /**
551   * Returns the "model" at the given virtual index. Here, by "model" we mean
552   * its predictions with respect to the validation set. This is just a
553   * convenience method, since we use the "virtual" index more than the real
554   * one inside this class.
555   *
556   * @param index
557   *            the "virtual" index - the one for internal use
558   * @return the predictions for the model for all validation instances.
559   */
560  private double[][] model(int index) {
561    return m_models[m_modelIndex[index]];
562  }
563 
564  /**
565   * Compute predictions based on the current model, adding (or removing) the
566   * model at the given (internal) index.
567   *
568   * @param index_to_change
569   *            index of model we're adding or removing
570   * @param add
571   *            whether we add it. If false, we remove it.
572   * @return the predictions for all validation instances
573   */
574  private double[][] computePredictions(int index_to_change, boolean add) {
575    double[][] predictions = new double[m_models[0].length][m_models[0][0].length];
576    for (int i = 0; i < m_bagSize; ++i) {
577      if (m_timesChosen[i] > 0) {
578        for (int j = 0; j < m_models[0].length; ++j) {
579          for (int k = 0; k < m_models[0][j].length; ++k) {
580            predictions[j][k] += model(i)[j][k] * m_timesChosen[i];
581          }
582        }
583      }
584    }
585    for (int j = 0; j < m_models[0].length; ++j) {
586      int change = add ? 1 : -1;
587      for (int k = 0; k < m_models[0][j].length; ++k) {
588        predictions[j][k] += change * model(index_to_change)[j][k];
589        predictions[j][k] /= (m_numChosen + change);
590      }
591    }
592    return predictions;
593  }
594 
595  /**
596   * Return the performance of the given predictions on the given instances
597   * with respect to the given metric (see EnsembleMetricHelper).
598   *
599   * @param instances
600   *            the validation data
601   * @param temp_predictions
602   *            the predictions to evaluate
603   * @param metric
604   *            the metric for which to optimize (see EnsembleMetricHelper)
605   * @return the performance
606   * @throws Exception if something goes wrong
607   */
608  private double evaluatePredictions(Instances instances,
609      double[][] temp_predictions, int metric) throws Exception {
610   
611    Evaluation eval = new Evaluation(instances);
612    for (int i = 0; i < instances.numInstances(); ++i) {
613      eval.evaluateModelOnceAndRecordPrediction(temp_predictions[i],
614          instances.instance(i));
615    }
616    return EnsembleMetricHelper.getMetric(eval, metric);
617  }
618 
619  /**
620   * Gets the individual performances of all the models in the bag.
621   *
622   * @param instances
623   *            The validation data, for which we want performance.
624   * @param metric
625   *            The desired metric (see EnsembleMetricHelper).
626   * @return the performance
627   * @throws Exception if something goes wrong
628   */
629  public double[] getIndividualPerformance(Instances instances, int metric)
630    throws Exception {
631   
632    double[] performance = new double[m_bagSize];
633    for (int i = 0; i < m_bagSize; ++i) {
634      performance[i] = evaluatePredictions(instances, model(i), metric);
635    }
636    return performance;
637  }
638 
639  /**
640   * Returns the revision string.
641   *
642   * @return            the revision
643   */
644  public String getRevision() {
645    return RevisionUtils.extract("$Revision: 1.2 $");
646  }
647}
Note: See TracBrowser for help on using the repository browser.