source: src/main/java/weka/gui/boundaryvisualizer/BoundaryPanelDistributed.java @ 7

Last change on this file since 7 was 4, checked in by gnappo, 14 years ago

Import di weka.

File size: 20.1 KB
Line 
1/*
2 *    This program is free software; you can redistribute it and/or modify
3 *    it under the terms of the GNU General Public License as published by
4 *    the Free Software Foundation; either version 2 of the License, or
5 *    (at your option) any later version.
6 *
7 *    This program is distributed in the hope that it will be useful,
8 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
9 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
10 *    GNU General Public License for more details.
11 *
12 *    You should have received a copy of the GNU General Public License
13 *    along with this program; if not, write to the Free Software
14 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
15 */
16
17/*
18 *   BoundaryPanelDistrubuted.java
19 *   Copyright (C) 2003 University of Waikato, Hamilton, New Zealand
20 *
21 */
22
23package weka.gui.boundaryvisualizer;
24
25import weka.classifiers.Classifier;
26import weka.classifiers.AbstractClassifier;
27import weka.core.FastVector;
28import weka.core.Instances;
29import weka.core.Utils;
30import weka.experiment.Compute;
31import weka.experiment.RemoteExperimentEvent;
32import weka.experiment.RemoteExperimentListener;
33import weka.experiment.TaskStatusInfo;
34
35import java.awt.BorderLayout;
36import java.io.BufferedReader;
37import java.io.FileInputStream;
38import java.io.FileReader;
39import java.io.ObjectInputStream;
40import java.rmi.Naming;
41import java.util.Vector;
42
43/**
44 * This class extends BoundaryPanel with code for distributing the
45 * processing necessary to create a visualization among a list of
46 * remote machines. Specifically, a visualization is broken down and
47 * processed row by row using the available remote computers.
48 *
49 * @author <a href="mailto:mhall@cs.waikato.ac.nz">Mark Hall</a>
50 * @version $Revision: 5928 $
51 * @since 1.0
52 * @see BoundaryPanel
53 */
54public class BoundaryPanelDistributed
55  extends BoundaryPanel {
56
57  /** for serialization */
58  private static final long serialVersionUID = -1743284397893937776L;
59
60  /** a list of RemoteExperimentListeners */
61  protected Vector m_listeners = new Vector();
62
63  /** Holds the names of machines with remoteEngine servers running */
64  protected Vector m_remoteHosts = new Vector();
65 
66  /** The queue of available hosts */
67  private weka.core.Queue m_remoteHostsQueue = new weka.core.Queue();
68
69  /** The status of each of the remote hosts */
70  private int [] m_remoteHostsStatus;
71
72  /** The number of times tasks have failed on each remote host */
73  private int [] m_remoteHostFailureCounts;
74
75  protected static final int AVAILABLE=0;
76  protected static final int IN_USE=1;
77  protected static final int CONNECTION_FAILED=2;
78  protected static final int SOME_OTHER_FAILURE=3;
79
80  protected static final int MAX_FAILURES=3;
81
82  /** Set to true if MAX_FAILURES exceeded on all hosts or connections fail
83      on all hosts or user aborts plotting */
84  private boolean m_plottingAborted = false;
85
86  /** The number of hosts removed due to exceeding max failures */
87  private int m_removedHosts;
88
89  /** The count of failed sub-tasks */
90  private int m_failedCount;
91
92  /** The count of successfully completed sub-tasks */
93  private int m_finishedCount;
94
95  /** The queue of sub-tasks waiting to be processed */
96  private weka.core.Queue m_subExpQueue = new weka.core.Queue();
97
98  /** number of seconds between polling server */
99  private int m_minTaskPollTime = 1000;
100
101  private int [] m_hostPollingTime;
102
103  /**
104   * Creates a new <code>BoundaryPanelDistributed</code> instance.
105   *
106   * @param panelWidth width of the display
107   * @param panelHeight height of the display
108   */
109  public BoundaryPanelDistributed(int panelWidth, int panelHeight) {
110    super(panelWidth, panelHeight);
111  }
112
113  /**
114   * Set a list of host names of machines to distribute processing to
115   *
116   * @param remHosts a Vector of host names (Strings)
117   */
118  public void setRemoteHosts(Vector remHosts) {
119    m_remoteHosts = remHosts;
120  }
121
122  /**
123   * Add an object to the list of those interested in recieving update
124   * information from the RemoteExperiment
125   * @param r a listener
126   */
127  public void addRemoteExperimentListener(RemoteExperimentListener r) {
128    m_listeners.addElement(r);
129  }
130
131  protected void initialize() {
132    super.initialize();
133
134    m_plottingAborted = false;
135    m_finishedCount = 0;
136    m_failedCount = 0;
137
138    // initialize all remote hosts to available
139    m_remoteHostsStatus = new int [m_remoteHosts.size()];   
140    m_remoteHostFailureCounts = new int [m_remoteHosts.size()];
141
142    m_remoteHostsQueue = new weka.core.Queue();
143
144    if (m_remoteHosts.size() == 0) {
145      System.err.println("No hosts specified!");
146      System.exit(1);
147    }
148
149    // prime the hosts queue
150    m_hostPollingTime = new int [m_remoteHosts.size()];
151    for (int i=0;i<m_remoteHosts.size();i++) {
152      m_remoteHostsQueue.push(new Integer(i));
153      m_hostPollingTime[i] = m_minTaskPollTime;
154    }
155
156    // set up sub taskss (just holds the row numbers to be processed
157    m_subExpQueue = new weka.core.Queue();
158    for (int i = 0; i < m_panelHeight; i++) {
159      m_subExpQueue.push(new Integer(i));
160    }
161   
162    try {
163      // need to build classifier and data generator
164      m_classifier.buildClassifier(m_trainingData);
165    } catch (Exception ex) {
166      ex.printStackTrace();
167      System.exit(1);
168    }
169   
170    boolean [] attsToWeightOn;
171    // build DataGenerator
172    attsToWeightOn = new boolean[m_trainingData.numAttributes()];
173    attsToWeightOn[m_xAttribute] = true;
174    attsToWeightOn[m_yAttribute] = true;
175   
176    m_dataGenerator.setWeightingDimensions(attsToWeightOn);   
177    try {
178      m_dataGenerator.buildGenerator(m_trainingData);
179    } catch (Exception ex) {
180      ex.printStackTrace();
181      System.exit(1);
182    }
183  }
184
185  /**
186   * Start processing
187   *
188   * @exception Exception if an error occurs
189   */
190  public void start() throws Exception {
191    // done in the sub task
192    /*     m_numOfSamplesPerGenerator =
193           (int)Math.pow(m_samplesBase, m_trainingData.numAttributes()-3); */
194
195    m_stopReplotting = true;
196    if (m_trainingData == null) {
197      throw new Exception("No training data set (BoundaryPanel)");
198    }
199    if (m_classifier == null) {
200      throw new Exception("No classifier set (BoundaryPanel)");
201    }
202    if (m_dataGenerator == null) {
203      throw new Exception("No data generator set (BoundaryPanel)");
204    }
205    if (m_trainingData.attribute(m_xAttribute).isNominal() || 
206        m_trainingData.attribute(m_yAttribute).isNominal()) {
207      throw new Exception("Visualization dimensions must be numeric "
208                          +"(BoundaryPanel)");
209    }
210   
211    computeMinMaxAtts();
212    initialize();
213
214    // launch tasks on all available hosts
215    int totalHosts = m_remoteHostsQueue.size();
216    for (int i = 0; i < totalHosts; i++) {
217      availableHost(-1);
218      Thread.sleep(70);
219    }
220  }
221
222  /**
223   * Push a host back onto the list of available hosts and launch a waiting
224   * Task (if any).
225   *
226   * @param hostNum the number of the host to return to the queue. -1
227   * if no host to return.
228   */
229  protected synchronized void availableHost(int hostNum) {
230    if (hostNum >= 0) { 
231      if (m_remoteHostFailureCounts[hostNum] < MAX_FAILURES) {
232        m_remoteHostsQueue.push(new Integer(hostNum));
233      } else {
234        notifyListeners(false,true,false,"Max failures exceeded for host "
235                        +((String)m_remoteHosts.elementAt(hostNum))
236                        +". Removed from host list.");
237        m_removedHosts++;
238      }
239    }
240
241    // check for all sub exp complete or all hosts failed or failed count
242    // exceeded
243    if (m_failedCount == (MAX_FAILURES * m_remoteHosts.size())) {
244      m_plottingAborted = true;
245      notifyListeners(false,true,true,"Plotting aborted! Max failures "
246                      +"exceeded on all remote hosts.");
247      return;
248    }
249
250    /*    System.err.println("--------------");
251    System.err.println("exp q :"+m_subExpQueue.size());
252    System.err.println("host list size "+m_remoteHosts.size());
253    System.err.println("actual host list size "+m_remoteHostsQueue.size());
254    System.err.println("removed hosts "+m_removedHosts); */
255    if (m_subExpQueue.size() == 0 && 
256        (m_remoteHosts.size() == 
257         (m_remoteHostsQueue.size() + m_removedHosts))) {
258      if (m_plotTrainingData) {
259        plotTrainingData();
260      }
261      notifyListeners(false,true,true,"Plotting completed successfully.");
262
263      return;
264    }
265
266
267    if (checkForAllFailedHosts()) {
268      return;
269    }
270
271    if (m_plottingAborted && 
272        (m_remoteHostsQueue.size() + m_removedHosts) == 
273        m_remoteHosts.size()) {
274      notifyListeners(false,true,true,"Plotting aborted. All remote tasks "
275                      +"finished.");
276    }
277
278    if (!m_subExpQueue.empty() && !m_plottingAborted) {
279      if (!m_remoteHostsQueue.empty()) {
280        int availHost, waitingTask;
281        try {
282          availHost = ((Integer)m_remoteHostsQueue.pop()).intValue();
283          waitingTask = ((Integer)m_subExpQueue.pop()).intValue();
284          launchNext(waitingTask, availHost);
285        } catch (Exception ex) {
286          ex.printStackTrace();
287        }
288      }
289    }   
290  }
291
292  /**
293   * Inform all listeners of progress
294   * @param status true if this is a status type of message
295   * @param log true if this is a log type of message
296   * @param finished true if the remote task has finished
297   * @param message the message.
298   */
299  private synchronized void notifyListeners(boolean status, 
300                                            boolean log, 
301                                            boolean finished,
302                                            String message) {
303    if (m_listeners.size() > 0) {
304      for (int i=0;i<m_listeners.size();i++) {
305        RemoteExperimentListener r = 
306          (RemoteExperimentListener)(m_listeners.elementAt(i));
307        r.remoteExperimentStatus(new RemoteExperimentEvent(status,
308                                                           log,
309                                                           finished,
310                                                           message));
311      }
312    } else {
313      System.err.println(message);
314    }
315  }
316
317  /**
318   * Check to see if we have failed to connect to all hosts
319   */
320  private boolean checkForAllFailedHosts() {
321    boolean allbad = true;
322    for (int i = 0; i < m_remoteHostsStatus.length; i++) {
323      if (m_remoteHostsStatus[i] != CONNECTION_FAILED) {
324        allbad = false;
325        break;
326      }
327    }
328    if (allbad) {
329      m_plottingAborted = true;
330      notifyListeners(false,true,true,"Plotting aborted! All connections "
331                      +"to remote hosts failed.");
332    }
333    return allbad;
334  }
335
336  /**
337   * Increment the number of successfully completed sub experiments
338   */
339  protected synchronized void incrementFinished() {
340    m_finishedCount++;
341  }
342
343  /**
344   * Increment the overall number of failures and the number of failures for
345   * a particular host
346   * @param hostNum the index of the host to increment failure count
347   */
348  protected synchronized void incrementFailed(int hostNum) {
349    m_failedCount++;
350    m_remoteHostFailureCounts[hostNum]++;
351  }
352
353  /**
354   * Push an experiment back on the queue of waiting experiments
355   * @param expNum the index of the experiment to push onto the queue
356   */
357  protected synchronized void waitingTask(int expNum) {
358    m_subExpQueue.push(new Integer(expNum));
359  }
360
361  protected void launchNext(final int wtask, final int ah) {
362    Thread subTaskThread;
363    subTaskThread = new Thread() {
364        public void run() {
365          m_remoteHostsStatus[ah] = IN_USE;
366          //      m_subExpComplete[wtask] = TaskStatusInfo.PROCESSING;
367          RemoteBoundaryVisualizerSubTask vSubTask = 
368            new RemoteBoundaryVisualizerSubTask();
369          vSubTask.setXAttribute(m_xAttribute);
370          vSubTask.setYAttribute(m_yAttribute);
371          vSubTask.setRowNumber(wtask);
372          vSubTask.setPanelWidth(m_panelWidth);
373          vSubTask.setPanelHeight(m_panelHeight);
374          vSubTask.setPixHeight(m_pixHeight);
375          vSubTask.setPixWidth(m_pixWidth);
376          vSubTask.setClassifier(m_classifier);
377          vSubTask.setDataGenerator(m_dataGenerator);
378          vSubTask.setInstances(m_trainingData);
379          vSubTask.setMinMaxX(m_minX, m_maxX);
380          vSubTask.setMinMaxY(m_minY, m_maxY);
381          vSubTask.setNumSamplesPerRegion(m_numOfSamplesPerRegion);
382          vSubTask.setGeneratorSamplesBase(m_samplesBase);
383          try {
384            String name = "//"
385              +((String)m_remoteHosts.elementAt(ah))
386              +"/RemoteEngine";
387            Compute comp = (Compute) Naming.lookup(name);
388            // assess the status of the sub-exp
389            notifyListeners(false,true,false,"Starting row "
390                            +wtask
391                            +" on host "
392                            +((String)m_remoteHosts.elementAt(ah)));
393            Object subTaskId = comp.executeTask(vSubTask);
394            boolean finished = false;
395            TaskStatusInfo is = null;
396            long startTime = System.currentTimeMillis();
397            while (!finished) {
398              try {
399                Thread.sleep(Math.max(m_minTaskPollTime, 
400                                      m_hostPollingTime[ah]));
401               
402                TaskStatusInfo cs = (TaskStatusInfo)comp.
403                  checkStatus(subTaskId);
404                if (cs.getExecutionStatus() == TaskStatusInfo.FINISHED) {
405                  // push host back onto queue and try launching any waiting
406                  // sub-experiments
407                  long runTime = System.currentTimeMillis() - startTime;
408                  runTime /= 4;
409                  if (runTime < 1000) {
410                    runTime = 1000;
411                  }
412                  m_hostPollingTime[ah] = (int)runTime;
413
414                  // Extract the row from the result
415                  RemoteResult rr =  (RemoteResult)cs.getTaskResult();
416                  double [][] probs = rr.getProbabilities();
417                 
418                  for (int i = 0; i < m_panelWidth; i++) {
419                    m_probabilityCache[wtask][i] = probs[i];
420                    if (i < m_panelWidth-1) {
421                      plotPoint(i, wtask, probs[i], false);
422                    } else {
423                      plotPoint(i, wtask, probs[i], true);
424                    }
425                  }
426                  notifyListeners(false, true, false,  cs.getStatusMessage());
427                  m_remoteHostsStatus[ah] = AVAILABLE;
428                  incrementFinished();
429                  availableHost(ah);
430                  finished = true;
431                } else if (cs.getExecutionStatus() == 
432                           TaskStatusInfo.FAILED) {
433                  // a non connection related error---possibly host doesn't have
434                  // access to data sets or security policy is not set up
435                  // correctly or classifier(s) failed for some reason
436                  notifyListeners(false, true, false, 
437                                  cs.getStatusMessage());
438                  m_remoteHostsStatus[ah] = SOME_OTHER_FAILURE;
439                  //              m_subExpComplete[wexp] = TaskStatusInfo.FAILED;
440                  notifyListeners(false,true,false,"Row "+wtask
441                                  +" "+cs.getStatusMessage()
442                                  +". Scheduling for execution on another host.");
443                  incrementFailed(ah);
444                  // push experiment back onto queue
445                  waitingTask(wtask);   
446                  // push host back onto queue and try launching any waiting
447                  // Tasks. Host is pushed back on the queue as the
448                  // failure may be temporary.
449                  availableHost(ah);
450                  finished = true;
451                } else {
452                  if (is == null) {
453                    is = cs;
454                    notifyListeners(false, true, false, cs.getStatusMessage());
455                  } else {
456                    RemoteResult rr = (RemoteResult)cs.getTaskResult();
457                    if (rr != null) {
458                      int percentComplete = rr.getPercentCompleted();
459                      String timeRemaining = "";
460                      if (percentComplete > 0 && percentComplete < 100) {
461                        double timeSoFar = (double)System.currentTimeMillis() -
462                          (double)startTime;
463                        double timeToGo = 
464                          ((100.0 - percentComplete) 
465                           / (double)percentComplete) * timeSoFar;
466                        if (timeToGo < m_hostPollingTime[ah]) {
467                          m_hostPollingTime[ah] = (int)timeToGo;
468                        }
469                        String units = "seconds";
470                        timeToGo /= 1000.0;
471                        if (timeToGo > 60) {
472                          units = "minutes";
473                          timeToGo /= 60.0;
474                        }
475                        if (timeToGo > 60) {
476                          units = "hours";
477                          timeToGo /= 60.0;
478                        }
479                        timeRemaining = " (approx. time remaining "
480                          +Utils.doubleToString(timeToGo, 1)+" "+units+")";
481                      }
482                      if (percentComplete < 25 
483                          /*&& minTaskPollTime < 30000*/) {             
484                        if (percentComplete > 0) {
485                          m_hostPollingTime[ah] = 
486                            (int)((25.0 / (double)percentComplete) * 
487                                  m_hostPollingTime[ah]);
488                        } else {
489                          m_hostPollingTime[ah] *= 2;
490                        }
491                        if (m_hostPollingTime[ah] > 60000) {
492                          m_hostPollingTime[ah] = 60000;
493                        }
494                      }
495                      notifyListeners(false, true, false,
496                                      "Row "+wtask+" "+percentComplete
497                                      +"% complete"+timeRemaining+".");
498                    } else {
499                      notifyListeners(false, true, false,
500                                      "Row "+wtask+" queued on "
501                                      +((String)m_remoteHosts.
502                                        elementAt(ah)));
503                      if (m_hostPollingTime[ah] < 60000) {
504                        m_hostPollingTime[ah] *= 2;
505                      }
506                    }
507
508                    is = cs;
509                  }
510                }
511              } catch (InterruptedException ie) {
512                ie.printStackTrace();
513              }
514            }
515          } catch (Exception ce) {
516            m_remoteHostsStatus[ah] = CONNECTION_FAILED;
517            m_removedHosts++;
518            System.err.println(ce);
519            ce.printStackTrace();
520            notifyListeners(false,true,false,"Connection to "
521                            +((String)m_remoteHosts.elementAt(ah))
522                            +" failed. Scheduling row "
523                            +wtask
524                            +" for execution on another host.");
525            checkForAllFailedHosts();
526            waitingTask(wtask);
527          } finally {
528            if (isInterrupted()) {
529              System.err.println("Sub exp Interupted!");
530            }
531          }
532        }
533      };
534    subTaskThread.setPriority(Thread.MIN_PRIORITY);
535    subTaskThread.start();
536  }
537
538  /**
539   * Main method for testing this class
540   *
541   * @param args a <code>String[]</code> value
542   */
543  public static void main (String [] args) {
544    try {
545      if (args.length < 8) {
546        System.err.println("Usage : BoundaryPanelDistributed <dataset> "
547                           +"<class col> <xAtt> <yAtt> "
548                           +"<base> <# loc/pixel> <kernel bandwidth> "
549                           +"<display width> "
550                           +"<display height> <classifier "
551                           +"[classifier options]>");
552        System.exit(1);
553      }
554     
555      Vector hostNames = new Vector();
556      // try loading hosts file
557      try {
558        BufferedReader br = new BufferedReader(new FileReader("hosts.vis"));
559        String hostName = br.readLine();
560        while (hostName != null) {
561          System.out.println("Adding host "+hostName);
562          hostNames.add(hostName);
563          hostName = br.readLine();
564        }
565        br.close();
566      } catch (Exception ex) {
567        System.err.println("No hosts.vis file - create this file in "
568                           +"the current directory with one host name "
569                           +"per line, or use BoundaryPanel instead.");
570        System.exit(1);
571      }
572
573      final javax.swing.JFrame jf = 
574        new javax.swing.JFrame("Weka classification boundary visualizer");
575      jf.getContentPane().setLayout(new BorderLayout());
576
577      System.err.println("Loading instances from : "+args[0]);
578      java.io.Reader r = new java.io.BufferedReader(
579                         new java.io.FileReader(args[0]));
580      final Instances i = new Instances(r);
581      i.setClassIndex(Integer.parseInt(args[1]));
582
583      //      bv.setClassifier(new Logistic());
584      final int xatt = Integer.parseInt(args[2]);
585      final int yatt = Integer.parseInt(args[3]);
586      int base = Integer.parseInt(args[4]);
587      int loc = Integer.parseInt(args[5]);
588
589      int bandWidth = Integer.parseInt(args[6]);
590      int panelWidth = Integer.parseInt(args[7]);
591      int panelHeight = Integer.parseInt(args[8]);
592
593      final String classifierName = args[9];
594      final BoundaryPanelDistributed bv = 
595        new BoundaryPanelDistributed(panelWidth,panelHeight);
596      bv.addRemoteExperimentListener(new RemoteExperimentListener() {
597          public void remoteExperimentStatus(RemoteExperimentEvent e) {
598            if (e.m_experimentFinished) {
599              String classifierNameNew = 
600                classifierName.substring(classifierName.lastIndexOf('.')+1, 
601                                         classifierName.length());
602              bv.saveImage(classifierNameNew+"_"+i.relationName()
603                           +"_X"+xatt+"_Y"+yatt+".jpg");
604            } else {
605              System.err.println(e.m_messageString);
606            }
607          }
608        });
609      bv.setRemoteHosts(hostNames);
610
611      jf.getContentPane().add(bv, BorderLayout.CENTER);
612      jf.setSize(bv.getMinimumSize());
613      //      jf.setSize(200,200);
614      jf.addWindowListener(new java.awt.event.WindowAdapter() {
615          public void windowClosing(java.awt.event.WindowEvent e) {
616            jf.dispose();
617            System.exit(0);
618          }
619        });
620
621      jf.pack();
622      jf.setVisible(true);
623      //      bv.initialize();
624      bv.repaint();
625     
626
627      String [] argsR = null;
628      if (args.length > 10) {
629        argsR = new String [args.length-10];
630        for (int j = 10; j < args.length; j++) {
631          argsR[j-10] = args[j];
632        }
633      }
634      Classifier c = AbstractClassifier.forName(args[9], argsR);
635      KDDataGenerator dataGen = new KDDataGenerator();
636      dataGen.setKernelBandwidth(bandWidth);
637      bv.setDataGenerator(dataGen);
638      bv.setNumSamplesPerRegion(loc);
639      bv.setGeneratorSamplesBase(base);
640      bv.setClassifier(c);
641      bv.setTrainingData(i);
642      bv.setXAttribute(xatt);
643      bv.setYAttribute(yatt);
644
645      try {
646        // try and load a color map if one exists
647        FileInputStream fis = new FileInputStream("colors.ser");
648        ObjectInputStream ois = new ObjectInputStream(fis);
649        FastVector colors = (FastVector)ois.readObject();
650        bv.setColors(colors);   
651      } catch (Exception ex) {
652        System.err.println("No color map file");
653      }
654      bv.start();
655    } catch (Exception ex) {
656      ex.printStackTrace();
657    }
658  }
659}
Note: See TracBrowser for help on using the repository browser.