source: src/main/java/weka/classifiers/ParallelIteratedSingleClassifierEnhancer.java @ 22

Last change on this file since 22 was 4, checked in by gnappo, 14 years ago

Import di weka.

File size: 8.1 KB
Line 
1/*
2 *    This program is free software; you can redistribute it and/or modify
3 *    it under the terms of the GNU General Public License as published by
4 *    the Free Software Foundation; either version 2 of the License, or
5 *    (at your option) any later version.
6 *
7 *    This program is distributed in the hope that it will be useful,
8 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
9 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
10 *    GNU General Public License for more details.
11 *
12 *    You should have received a copy of the GNU General Public License
13 *    along with this program; if not, write to the Free Software
14 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
15 */
16
17/*
18 *    ParallelIteratedSingleClassifierEnhancer.java
19 *    Copyright (C) 2009 University of Waikato, Hamilton, New Zealand
20 *
21 */
22
23package weka.classifiers;
24
25import java.util.Enumeration;
26import java.util.Vector;
27import java.util.concurrent.LinkedBlockingQueue;
28import java.util.concurrent.ThreadPoolExecutor;
29import java.util.concurrent.TimeUnit;
30
31import weka.core.Instances;
32import weka.core.Option;
33import weka.core.Utils;
34
35/**
36 * Abstract utility class for handling settings common to
37 * meta classifiers that build an ensemble in parallel from a single
38 * base learner.
39 *
40 * @author Mark Hall (mhall{[at]}pentaho{[dot]}com)
41 * @version $Revision: 6041 $
42 */
43public abstract class ParallelIteratedSingleClassifierEnhancer extends
44    IteratedSingleClassifierEnhancer {
45
46  /** For serialization */
47  private static final long serialVersionUID = -5026378741833046436L;
48
49  /** The number of threads to have executing at any one time */
50  protected int m_numExecutionSlots = 1;
51
52  /** Pool of threads to train models with */
53  protected transient ThreadPoolExecutor m_executorPool;
54
55  /** The number of classifiers completed so far */
56  protected int m_completed;
57
58  /**
59   * The number of classifiers that experienced a failure of some sort
60   * during construction
61   */
62  protected int m_failed;
63
64  /**
65   * Returns an enumeration describing the available options.
66   *
67   * @return an enumeration of all the available options.
68   */
69  public Enumeration listOptions() {
70
71    Vector newVector = new Vector(2);
72
73    newVector.addElement(new Option(
74              "\tNumber of execution slots.\n"
75              + "\t(default 1 - i.e. no parallelism)",
76              "num-slots", 1, "-num-slots <num>"));
77
78    Enumeration enu = super.listOptions();
79    while (enu.hasMoreElements()) {
80      newVector.addElement(enu.nextElement());
81    }
82    return newVector.elements();
83  }
84
85  /**
86   * Parses a given list of options. Valid options are:<p>
87   *
88   * -Z num <br>
89   * Set the number of execution slots to use (default 1 - i.e. no parallelism). <p>
90   *
91   * Options after -- are passed to the designated classifier.<p>
92   *
93   * @param options the list of options as an array of strings
94   * @exception Exception if an option is not supported
95   */
96  public void setOptions(String[] options) throws Exception {
97
98    String iterations = Utils.getOption("num-slots", options);
99    if (iterations.length() != 0) {
100      setNumExecutionSlots(Integer.parseInt(iterations));
101    } else {
102      setNumExecutionSlots(1);
103    }
104
105    super.setOptions(options);
106  }
107
108  /**
109   * Gets the current settings of the classifier.
110   *
111   * @return an array of strings suitable for passing to setOptions
112   */
113  public String [] getOptions() {
114
115    String [] superOptions = super.getOptions();
116    String [] options = new String [superOptions.length + 2];
117
118    int current = 0;
119    options[current++] = "-num-slots";
120    options[current++] = "" + getNumExecutionSlots();
121
122    System.arraycopy(superOptions, 0, options, current,
123                     superOptions.length);
124
125    return options;
126  }
127
128  /**
129   * Set the number of execution slots (threads) to use for building the
130   * members of the ensemble.
131   *
132   * @param numSlots the number of slots to use.
133   */
134  public void setNumExecutionSlots(int numSlots) {
135    m_numExecutionSlots = numSlots;
136  }
137
138  /**
139   * Get the number of execution slots (threads) to use for building
140   * the members of the ensemble.
141   *
142   * @return the number of slots to use
143   */
144  public int getNumExecutionSlots() {
145    return m_numExecutionSlots;
146  }
147
148  /**
149   * Returns the tip text for this property
150   * @return tip text for this property suitable for
151   * displaying in the explorer/experimenter gui
152   */
153  public String numExecutionSlotsTipText() {
154    return "The number of execution slots (threads) to use for " +
155      "constructing the ensemble.";
156  }
157
158  /**
159   * Stump method for building the classifiers
160   *
161   * @param data the training data to be used for generating the ensemble
162   * @exception Exception if the classifier could not be built successfully
163   */
164  public void buildClassifier(Instances data) throws Exception {
165    super.buildClassifier(data);
166
167    if (m_numExecutionSlots < 1) {
168      throw new Exception("Number of execution slots needs to be >= 1!");
169    }
170
171    if (m_numExecutionSlots > 1) {
172      startExecutorPool();
173    }
174    m_completed = 0;
175    m_failed = 0;
176  }
177
178  /**
179   * Start the pool of execution threads
180   */
181  protected void startExecutorPool() {
182    if (m_executorPool != null) {
183      m_executorPool.shutdownNow();
184    }
185
186    m_executorPool = new ThreadPoolExecutor(m_numExecutionSlots, m_numExecutionSlots,
187        120, TimeUnit.SECONDS, new LinkedBlockingQueue<Runnable>());
188  }
189
190  private synchronized void block(boolean tf) {
191    if (tf) {
192      try {
193        wait();
194      } catch (InterruptedException ex) {
195      }
196    } else {
197      notifyAll();
198    }
199  }
200
201  /**
202   * Does the actual construction of the ensemble
203   *
204   * @throws Exception if something goes wrong during the training
205   * process
206   */
207  protected synchronized void buildClassifiers() throws Exception {
208
209    for (int i = 0; i < m_Classifiers.length; i++) {
210      if (m_numExecutionSlots > 1) {
211        final Classifier currentClassifier = m_Classifiers[i];
212        final int iteration = i;
213        if (m_Debug) {
214          System.out.print("Training classifier (" + (i +1) + ")");
215        }
216        Runnable newTask = new Runnable() {
217          public void run() {
218            try {
219              currentClassifier.buildClassifier(getTrainingSet(iteration));
220              completedClassifier(iteration, true);
221            } catch (Exception ex) {
222              ex.printStackTrace();
223              completedClassifier(iteration, false);
224            }
225          }
226        };
227
228        // launch this task
229        m_executorPool.execute(newTask);
230      } else {
231        m_Classifiers[i].buildClassifier(getTrainingSet(i));
232      }
233    }
234
235    if (m_numExecutionSlots > 1 && m_completed + m_failed < m_Classifiers.length) {
236      block(true);
237    }
238  }
239
240  /**
241   * Records the completion of the training of a single classifier. Unblocks if
242   * all classifiers have been trained.
243   *
244   * @param iteration the iteration that has completed
245   * @param success whether the classifier trained successfully
246   */
247  protected synchronized void completedClassifier(int iteration,
248      boolean success) {
249    m_completed++;
250
251    if (!success) {
252      m_failed++;
253      if (m_Debug) {
254        System.err.println("Iteration " + iteration + " failed!");
255      }
256    }
257
258    if (m_completed + m_failed == m_Classifiers.length) {
259      if (m_failed > 0) {
260        if (m_Debug) {
261          System.err.println("Problem building classifiers - some iterations failed.");
262        }
263      }
264
265      // have to shut the pool down or program executes as a server
266      // and when running from the command line does not return to the
267      // prompt
268      m_executorPool.shutdown();
269      block(false);
270    }
271  }
272
273  /**
274   * Gets a training set for a particular iteration. Implementations need
275   * to be careful with thread safety and should probably be synchronized
276   * to be on the safe side.
277   *
278   * @param iteration the number of the iteration for the requested training set
279   * @return the training set for the supplied iteration number
280   * @throws Exception if something goes wrong.
281   */
282  protected abstract Instances getTrainingSet(int iteration) throws Exception;
283}
Note: See TracBrowser for help on using the repository browser.