source: src/main/java/weka/associations/HotSpot.java @ 16

Last change on this file since 16 was 4, checked in by gnappo, 14 years ago

Import di weka.

File size: 42.1 KB
Line 
1/*
2 *    This program is free software; you can redistribute it and/or modify
3 *    it under the terms of the GNU General Public License as published by
4 *    the Free Software Foundation; either version 2 of the License, or
5 *    (at your option) any later version.
6 *
7 *    This program is distributed in the hope that it will be useful,
8 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
9 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
10 *    GNU General Public License for more details.
11 *
12 *    You should have received a copy of the GNU General Public License
13 *    along with this program; if not, write to the Free Software
14 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
15 */
16
17/*
18 *    HotSpot.java
19 *    Copyright (C) 2008 University of Waikato, Hamilton, New Zealand
20 *
21 */
22
23package weka.associations;
24
25import java.util.PriorityQueue;
26import java.util.HashMap;
27import java.util.ArrayList;
28import java.util.Vector;
29import java.util.Enumeration;
30import java.io.Serializable;
31import weka.core.Instances;
32import weka.core.Instance;
33import weka.core.Attribute;
34import weka.core.Utils;
35import weka.core.OptionHandler;
36import weka.core.Option;
37import weka.core.SingleIndex;
38import weka.core.Drawable;
39import weka.core.Capabilities.Capability;
40import weka.core.Capabilities;
41import weka.core.CapabilitiesHandler;
42import weka.core.RevisionHandler;
43import weka.core.RevisionUtils;
44
45/**
46 <!-- globalinfo-start -->
47 * HotSpot learns a set of rules (displayed in a tree-like structure) that maximize/minimize a target variable/value of interest. With a nominal target, one might want to look for segments of the data where there is a high probability of a minority value occuring (given the constraint of a minimum support). For a numeric target, one might be interested in finding segments where this is higher on average than in the whole data set. For example, in a health insurance scenario, find which health insurance groups are at the highest risk (have the highest claim ratio), or, which groups have the highest average insurance payout.
48 * <p/>
49 <!-- globalinfo-end -->
50 *
51 <!-- options-start -->
52 * Valid options are: <p/>
53 *
54 * <pre> -c &lt;num | first | last&gt;
55 *  The target index. (default = last)</pre>
56 *
57 * <pre> -V &lt;num | first | last&gt;
58 *  The target value (nominal target only, default = first)</pre>
59 *
60 * <pre> -L
61 *  Minimize rather than maximize.</pre>
62 *
63 * <pre> -S &lt;num&gt;
64 *  Minimum value count (nominal target)/segment size (numeric target).
65 *  Values between 0 and 1 are
66 *  interpreted as a percentage of
67 *  the total population; values &gt; 1 are
68 *  interpreted as an absolute number of
69 *  instances (default = 0.3)</pre>
70 *
71 * <pre> -M &lt;num&gt;
72 *  Maximum branching factor (default = 2)</pre>
73 *
74 * <pre> -I &lt;num&gt;
75 *  Minimum improvement in target value in order
76 *  to add a new branch/test (default = 0.01 (1%))</pre>
77 *
78 * <pre> -D
79 *  Output debugging info (duplicate rule lookup
80 *  hash table stats)</pre>
81 *
82 <!-- options-end -->
83 *
84 * @author Mark Hall (mhall{[at]}pentaho{[dot]}org
85 * @version $Revision: 6081 $
86 */
87public class HotSpot
88  implements Associator, OptionHandler, RevisionHandler, 
89             CapabilitiesHandler, Drawable, Serializable {
90
91  static final long serialVersionUID = 42972325096347677L;
92
93  /** index of the target attribute */
94  protected SingleIndex m_targetSI = new SingleIndex("last");
95  protected int m_target;
96 
97  /** Support as a fraction of the total training set */
98  protected double m_support;
99 
100  /** Support as an instance count */
101  private int m_supportCount;
102
103  /** The global value of the attribute of interest (mean or probability) */
104  protected double m_globalTarget;
105
106  /** The minimum improvement necessary to justify adding a test */
107  protected double m_minImprovement;
108
109  /** Actual global support of the target value (discrete target only) */
110  protected int m_globalSupport;
111
112  /** For discrete target, the index of the value of interest */
113  protected SingleIndex m_targetIndexSI = new SingleIndex("first");
114  protected int m_targetIndex;
115
116  /** At each level of the tree consider at most this number extensions */
117  protected int m_maxBranchingFactor;
118
119  /** Number of instances in the full data */
120  protected int m_numInstances;
121
122  /** The head of the tree */
123  protected HotNode m_head;
124
125  /** Header of the training data */
126  protected Instances m_header;
127
128  /** Debugging stuff */
129  protected int m_lookups = 0;
130  protected int m_insertions = 0;
131  protected int m_hits = 0;
132
133  protected boolean m_debug;
134 
135  /** Minimize, rather than maximize the target */
136  protected boolean m_minimize;
137
138  /** Error messages relating to too large/small support values */
139  protected String m_errorMessage;
140
141  /** Rule lookup table */
142  protected HashMap<HotSpotHashKey, String> m_ruleLookup;
143
144  /**
145   * Constructor
146   */
147  public HotSpot() {
148    resetOptions();
149  }
150
151  /**
152   * Returns a string describing this classifier
153   * @return a description of the classifier suitable for
154   * displaying in the explorer/experimenter gui
155   */
156  public String globalInfo() {
157    return "HotSpot learns a set of rules (displayed in a tree-like structure) "
158      + "that maximize/minimize a target variable/value of interest. "
159      + "With a nominal target, one might want to look for segments of the "
160      + "data where there is a high probability of a minority value occuring ("
161      + "given the constraint of a minimum support). For a numeric target, "
162      + "one might be interested in finding segments where this is higher "
163      + "on average than in the whole data set. For example, in a health "
164      + "insurance scenario, find which health insurance groups are at "
165      + "the highest risk (have the highest claim ratio), or, which groups "
166      + "have the highest average insurance payout.";
167  }
168 
169  /**
170   * Returns default capabilities of HotSpot
171   *
172   * @return      the capabilities of HotSpot
173   */
174  public Capabilities getCapabilities() {
175    Capabilities result = new Capabilities(this);
176    result.disableAll();
177
178    // attributes
179    result.enable(Capability.NOMINAL_ATTRIBUTES);
180    result.enable(Capability.NUMERIC_ATTRIBUTES);
181    result.enable(Capability.MISSING_VALUES);
182
183    // class
184    result.enable(Capability.NO_CLASS);
185    //result.enable(Capability.NUMERIC_CLASS);
186   // result.enable(Capability.NOMINAL_CLASS);
187
188   
189    return result;
190  }
191
192  /**
193   * Hash key class for sets of attribute, value tests
194   */
195  protected class HotSpotHashKey {
196    // split values, one for each attribute (0 indicates att not used).
197    // for nominal indexes, 1 is added so that 0 can indicate not used.
198    protected double[] m_splitValues;
199
200    // 0 = not used, 1 = "<=", 2 = "=", 3 = ">"
201    protected byte[] m_testTypes;
202
203    protected boolean m_computed = false;
204    protected int m_key;
205   
206    public HotSpotHashKey(double[] splitValues, byte[] testTypes) {
207      m_splitValues = splitValues.clone();
208      m_testTypes = testTypes.clone();
209    }
210
211    public boolean equals(Object b) {
212      if ((b == null) || !(b.getClass().equals(this.getClass()))) {
213        return false;
214      }
215      HotSpotHashKey comp = (HotSpotHashKey)b;
216      boolean ok = true;
217      for (int i = 0; i < m_splitValues.length; i++) {
218        if (m_splitValues[i] != comp.m_splitValues[i] ||
219            m_testTypes[i] != comp.m_testTypes[i]) {
220          ok = false;
221          break;
222        }
223      }
224      return ok;
225    }
226
227    public int hashCode() {
228
229      if (m_computed) {
230        return m_key;
231      } else {
232        int hv = 0;
233        for (int i = 0; i < m_splitValues.length; i++) {
234          hv += (m_splitValues[i] * 5 * i);
235          hv += (m_testTypes[i] * i * 3);
236        }
237        m_computed = true;
238
239        m_key = hv;
240      }
241      return m_key;
242    }
243  }
244
245  /**
246   * Build the tree
247   *
248   * @param instances the training instances
249   * @throws Exception if something goes wrong
250   */
251  public void buildAssociations(Instances instances) throws Exception {
252   
253    // can associator handle the data?
254    getCapabilities().testWithFail(instances);
255   
256    m_errorMessage = null;
257    m_targetSI.setUpper(instances.numAttributes() - 1);
258    m_target = m_targetSI.getIndex();
259    Instances inst = new Instances(instances);
260    inst.setClassIndex(m_target);
261    inst.deleteWithMissingClass();
262
263    if (inst.attribute(m_target).isNominal()) {
264      m_targetIndexSI.setUpper(inst.attribute(m_target).numValues() - 1);
265      m_targetIndex = m_targetIndexSI.getIndex();
266    } else {
267      m_targetIndexSI.setUpper(1); // just to stop this SingleIndex from moaning
268    }
269   
270    if (m_support <= 0) {
271      throw new Exception("Support must be greater than zero.");
272    }
273
274    m_numInstances = inst.numInstances();
275    if (m_support >= 1) {
276      m_supportCount = (int)m_support;
277      m_support = m_support / (double)m_numInstances;
278    }
279    m_supportCount = (int)Math.floor((m_support * m_numInstances) + 0.5d);
280    //    m_supportCount = (int)(m_support * m_numInstances);
281    if (m_supportCount < 1) {
282      m_supportCount = 1;
283    }
284
285    m_header = new Instances(inst, 0);
286
287    if (inst.attribute(m_target).isNumeric()) {
288      if (m_supportCount > m_numInstances) {
289        m_errorMessage = "Error: support set to more instances than there are in the data!";
290        return;
291      }
292      m_globalTarget = inst.meanOrMode(m_target);
293    } else {
294      double[] probs = new double[inst.attributeStats(m_target).nominalCounts.length];
295      for (int i = 0; i < probs.length; i++) {
296        probs[i] = (double)inst.attributeStats(m_target).nominalCounts[i];
297      }
298      m_globalSupport = (int)probs[m_targetIndex];
299      // check that global support is greater than min support
300      if (m_globalSupport < m_supportCount) {
301        m_errorMessage = "Error: minimum support " + m_supportCount
302          + " is too high. Target value " 
303          + m_header.attribute(m_target).value(m_targetIndex) + " has support " 
304          + m_globalSupport + ".";
305      }
306
307      Utils.normalize(probs);
308      m_globalTarget = probs[m_targetIndex];
309      /*      System.err.println("Global target " + m_globalTarget);
310              System.err.println("Min support count " + m_supportCount);  */
311    }
312   
313    m_ruleLookup = new HashMap<HotSpotHashKey, String>();
314    double[] splitVals = new double[m_header.numAttributes()];
315    byte[] tests = new byte[m_header.numAttributes()];
316
317    m_head = new HotNode(inst, m_globalTarget, splitVals, tests);
318    //    m_head = new HotNode(inst, m_globalTarget);
319  }
320
321  /**
322   * Return the tree as a string
323   *
324   * @return a String containing the tree
325   */
326  public String toString() {
327    StringBuffer buff = new StringBuffer();
328    buff.append("\nHot Spot\n========");
329    if (m_errorMessage != null) {
330      buff.append("\n\n" + m_errorMessage + "\n\n");
331      return buff.toString();
332    }
333    if (m_head == null) {
334      buff.append("No model built!");
335      return buff.toString();
336    }
337    buff.append("\nTotal population: ");
338    buff.append("" + m_numInstances + " instances");
339    buff.append("\nTarget attribute: " + m_header.attribute(m_target).name());
340    if (m_header.attribute(m_target).isNominal()) {
341      buff.append("\nTarget value: " + m_header.attribute(m_target).value(m_targetIndex));
342      buff.append(" [value count in total population: " + m_globalSupport + " instances ("
343                  + Utils.doubleToString((m_globalTarget * 100.0), 2) + "%)]");
344
345      buff.append("\nMinimum value count for segments: ");
346    } else {
347      buff.append("\nMinimum segment size: ");
348    }
349    buff.append("" + m_supportCount + " instances (" 
350                + Utils.doubleToString((m_support * 100.0), 2) 
351                + "% of total population)");
352    buff.append("\nMaximum branching factor: " + m_maxBranchingFactor);
353    buff.append("\nMinimum improvement in target: " 
354                + Utils.doubleToString((m_minImprovement * 100.0), 2) + "%");
355   
356    buff.append("\n\n");
357    buff.append(m_header.attribute(m_target).name());
358    if (m_header.attribute(m_target).isNumeric()) {
359      buff.append(" (" + Utils.doubleToString(m_globalTarget, 4) + ")");
360    } else {
361      buff.append("=" + m_header.attribute(m_target).value(m_targetIndex) + " (");
362      buff.append("" + Utils.doubleToString((m_globalTarget * 100.0), 2) + "% [");
363      buff.append("" + m_globalSupport
364                  + "/" + m_numInstances + "])");
365    }
366   
367    m_head.dumpTree(0, buff);
368    buff.append("\n");
369    if (m_debug) {
370      buff.append("\n=== Duplicate rule lookup hashtable stats ===\n");
371      buff.append("Insertions: "+ m_insertions);
372      buff.append("\nLookups : "+ m_lookups);
373      buff.append("\nHits: "+ m_hits);
374      buff.append("\n");
375    }
376    return buff.toString();
377  }
378
379  public String graph() throws Exception {
380    System.err.println("Here");
381    m_head.assignIDs(-1);
382
383    StringBuffer text = new StringBuffer();
384   
385    text.append("digraph HotSpot {\n");
386    text.append("rankdir=LR;\n");
387    text.append("N0 [label=\"" 
388                + m_header.attribute(m_target).name());
389   
390    if (m_header.attribute(m_target).isNumeric()) {
391      text.append("\\n(" + Utils.doubleToString(m_globalTarget, 4) + ")");
392    } else {
393      text.append("=" + m_header.attribute(m_target).value(m_targetIndex) + "\\n(");
394      text.append("" + Utils.doubleToString((m_globalTarget * 100.0), 2) + "% [");
395      text.append("" + m_globalSupport
396                  + "/" + m_numInstances + "])");
397    }
398    text.append("\" shape=plaintext]\n");
399
400    m_head.graphHotSpot(text);
401
402    text.append("}\n");
403    return text.toString();
404  }
405
406  /**
407   * Inner class representing a node/leaf in the tree
408   */
409  protected class HotNode implements Serializable {
410    /**
411     * An inner class holding data on a particular attribute value test
412     */
413    protected class HotTestDetails 
414      implements Comparable<HotTestDetails>,
415                 Serializable {
416      public double m_merit;
417      public int m_support;
418      public int m_subsetSize;
419      public int m_splitAttIndex;
420      public double m_splitValue;
421      public boolean m_lessThan;
422
423      public HotTestDetails(int attIndex, 
424                            double splitVal, 
425                            boolean lessThan,
426                            int support,
427                            int subsetSize,
428                            double merit) {
429        m_merit = merit;
430        m_splitAttIndex = attIndex;
431        m_splitValue = splitVal;
432        m_lessThan = lessThan;
433        m_support = support;
434        m_subsetSize = subsetSize;
435      }
436
437      // reverse order for maximize as PriorityQueue has the least element at the head
438      public int compareTo(HotTestDetails comp) {
439        int result = 0;
440        if (m_minimize) {
441          if (m_merit == comp.m_merit) {
442            // larger support is better
443            if (m_support == comp.m_support) {
444            } else if (m_support > comp.m_support) {
445              result = -1;
446            } else {
447              result = 1;
448            }
449          } else if (m_merit < comp.m_merit) {
450            result = -1;
451          } else {
452            result = 1;
453          }
454        } else {
455          if (m_merit == comp.m_merit) {
456            // larger support is better
457            if (m_support == comp.m_support) {
458            } else if (m_support > comp.m_support) {
459              result = -1;
460            } else {
461              result = 1;
462            }
463          } else if (m_merit < comp.m_merit) {
464            result = 1;
465          } else {
466            result = -1;
467          }
468        }
469        return result;
470      }
471    }
472
473    // the instances at this node
474    protected Instances m_insts;
475
476    // the value (to beat) of the target for these instances
477    protected double m_targetValue;
478
479    // child nodes
480    protected HotNode[] m_children;
481    protected HotTestDetails[] m_testDetails;
482
483    public int m_id;
484
485    /**
486     * Constructor
487     *
488     * @param insts the instances at this node
489     * @param targetValue the target value
490     * @param splitVals the values of attributes split on so far down this branch
491     * @param tests the types of tests corresponding to the split values (<=, =, >)
492     */
493    public HotNode(Instances insts, 
494                   double targetValue, 
495                   double[] splitVals,
496                   byte[] tests) {
497      m_insts = insts;
498      m_targetValue = targetValue;
499      PriorityQueue<HotTestDetails> splitQueue = new PriorityQueue<HotTestDetails>();
500
501      // Consider each attribute
502      for (int i = 0; i < m_insts.numAttributes(); i++) {
503        if (i != m_target) {
504          if (m_insts.attribute(i).isNominal()) {
505            evaluateNominal(i, splitQueue);
506          } else {
507            evaluateNumeric(i, splitQueue);
508          }
509        }
510      }
511
512      if (splitQueue.size() > 0) {
513        int queueSize = splitQueue.size();
514
515        // count how many of the potential children are unique
516        ArrayList<HotTestDetails> newCandidates = new ArrayList<HotTestDetails>();
517        ArrayList<HotSpotHashKey> keyList = new ArrayList<HotSpotHashKey>();
518
519        for (int i = 0; i < queueSize; i++) {
520          if (newCandidates.size() < m_maxBranchingFactor) {
521            HotTestDetails temp = splitQueue.poll();
522            double[] newSplitVals = splitVals.clone();
523            byte[] newTests = tests.clone();
524            newSplitVals[temp.m_splitAttIndex] = temp.m_splitValue + 1;
525            newTests[temp.m_splitAttIndex] = 
526              (m_header.attribute(temp.m_splitAttIndex).isNominal())
527              ? (byte)2 // ==
528              : (temp.m_lessThan)
529              ? (byte)1
530              : (byte)3;
531            HotSpotHashKey key = new HotSpotHashKey(newSplitVals, newTests);
532            m_lookups++;
533            if (!m_ruleLookup.containsKey(key)) {
534              // insert it into the hash table
535              m_ruleLookup.put(key, "");           
536              newCandidates.add(temp);
537              keyList.add(key);
538              m_insertions++;
539            } else {
540              m_hits++;
541            }
542          } else {
543            break;
544          }
545        }
546
547        m_children = new HotNode[(newCandidates.size() < m_maxBranchingFactor)
548                                 ? newCandidates.size()
549                                 : m_maxBranchingFactor];
550        // save the details of the tests at this node
551        m_testDetails = new HotTestDetails[m_children.length];
552        for (int i = 0; i < m_children.length; i++) {
553          m_testDetails[i] = newCandidates.get(i);
554        }
555
556        // save memory
557        splitQueue = null;
558        newCandidates = null;
559        m_insts = new Instances(m_insts, 0);
560
561        // process the children
562        for (int i = 0; i < m_children.length; i++) {
563          Instances subset = subset(insts, m_testDetails[i]);
564          HotSpotHashKey tempKey = keyList.get(i);
565          m_children[i] = new HotNode(subset, m_testDetails[i].m_merit, 
566                                      tempKey.m_splitValues, tempKey.m_testTypes);
567
568        }
569      }
570    }
571
572    /**
573     * Create a subset of instances that correspond to the supplied test details
574     *
575     * @param insts the instances to create the subset from
576     * @param test the details of the split
577     */
578    private Instances subset(Instances insts, HotTestDetails test) {
579      Instances sub = new Instances(insts, insts.numInstances());
580      for (int i = 0; i < insts.numInstances(); i++) {
581        Instance temp = insts.instance(i);
582        if (!temp.isMissing(test.m_splitAttIndex)) {
583          if (insts.attribute(test.m_splitAttIndex).isNominal()) {
584            if (temp.value(test.m_splitAttIndex) == test.m_splitValue) {
585              sub.add(temp);
586            }
587          } else {
588            if (test.m_lessThan) {
589              if (temp.value(test.m_splitAttIndex) <= test.m_splitValue) {
590                sub.add(temp);
591              }
592            } else {
593              if (temp.value(test.m_splitAttIndex) > test.m_splitValue) {
594                sub.add(temp);
595              }
596            }
597          }
598        }
599      }
600      sub.compactify();
601      return sub;
602    }
603
604    /**
605     * Evaluate a numeric attribute for a potential split
606     *
607     * @param attIndex the index of the attribute
608     * @param pq the priority queue of candidtate splits
609     */
610    private void evaluateNumeric(int attIndex, PriorityQueue<HotTestDetails> pq) {
611      Instances tempInsts = m_insts;
612      tempInsts.sort(attIndex);
613     
614      // target sums/counts
615      double targetLeft = 0;
616      double targetRight = 0;
617
618      int numMissing = 0;
619      // count missing values and sum/counts for the initial right subset
620      for (int i = tempInsts.numInstances() - 1; i >= 0; i--) {
621        if (!tempInsts.instance(i).isMissing(attIndex)) {
622          targetRight += (tempInsts.attribute(m_target).isNumeric())
623            ? (tempInsts.instance(i).value(m_target))
624            : ((tempInsts.instance(i).value(m_target) == m_targetIndex)
625               ? 1
626               : 0);
627        } else {
628          numMissing++;
629        }
630      }
631     
632      // are there still enough instances?
633      if (tempInsts.numInstances() - numMissing <= m_supportCount) {
634        return;
635      }
636     
637      double bestMerit = 0.0;
638      double bestSplit = 0.0;
639      double bestSupport = 0.0;
640      double bestSubsetSize = 0;
641      boolean lessThan = true;
642
643      // denominators
644      double leftCount = 0;
645      double rightCount = tempInsts.numInstances() - numMissing;
646           
647      /*      targetRight = (tempInsts.attribute(m_target).isNumeric())
648        ? tempInsts.attributeStats(m_target).numericStats.sum
649        : tempInsts.attributeStats(m_target).nominalCounts[m_targetIndex]; */
650      //      targetRight = tempInsts.attributeStats(attIndexnominalCounts[m_targetIndex];
651
652      // consider all splits
653      for (int i = 0; i < tempInsts.numInstances() - numMissing; i++) {
654        Instance inst = tempInsts.instance(i);
655
656        if (tempInsts.attribute(m_target).isNumeric()) {
657          targetLeft += inst.value(m_target);
658          targetRight -= inst.value(m_target);
659        } else {
660          if ((int)inst.value(m_target) == m_targetIndex) {
661            targetLeft++;
662            targetRight--;
663          }         
664        }
665        leftCount++;
666        rightCount--;
667       
668        // move to the end of any ties
669        if (i < tempInsts.numInstances() - 1 &&
670            inst.value(attIndex) == tempInsts.instance(i + 1).value(attIndex)) {
671          continue;
672        }
673
674        // evaluate split
675        if (tempInsts.attribute(m_target).isNominal()) {
676          if (targetLeft >= m_supportCount) {
677            double delta = (m_minimize) 
678              ? (bestMerit - (targetLeft / leftCount))
679              : ((targetLeft / leftCount) - bestMerit);
680            //            if (targetLeft / leftCount > bestMerit) {
681            if (delta > 0) {
682              bestMerit = targetLeft / leftCount;
683              bestSplit = inst.value(attIndex);
684              bestSupport = targetLeft;
685              bestSubsetSize = leftCount;
686              lessThan = true;
687              //            } else if (targetLeft / leftCount == bestMerit) {
688            } else if (delta == 0) {
689              // break ties in favour of higher support
690              if (targetLeft > bestSupport) {
691                bestMerit = targetLeft / leftCount;
692                bestSplit = inst.value(attIndex);
693                bestSupport = targetLeft;
694                bestSubsetSize = leftCount;
695                lessThan = true;
696              }
697            }
698          }
699
700          if (targetRight >= m_supportCount) {
701            double delta = (m_minimize) 
702              ? (bestMerit - (targetRight / rightCount))
703              : ((targetRight / rightCount) - bestMerit);
704            //            if (targetRight / rightCount > bestMerit) {
705            if (delta > 0) {
706              bestMerit = targetRight / rightCount;
707              bestSplit = inst.value(attIndex);
708              bestSupport = targetRight;
709              bestSubsetSize = rightCount;
710              lessThan = false;
711              //            } else if (targetRight / rightCount == bestMerit) {
712            } else if (delta == 0) {
713              // break ties in favour of higher support
714              if (targetRight > bestSupport) {
715                bestMerit = targetRight / rightCount;
716                bestSplit = inst.value(attIndex);
717                bestSupport = targetRight;
718                bestSubsetSize = rightCount;
719                lessThan = false;
720              }
721            }
722          } 
723        } else {
724          if (leftCount >= m_supportCount) {
725            double delta = (m_minimize) 
726              ? (bestMerit - (targetLeft / leftCount))
727              : ((targetLeft / leftCount) - bestMerit);
728            //            if (targetLeft / leftCount > bestMerit) {
729            if (delta > 0) {
730              bestMerit = targetLeft / leftCount;
731              bestSplit = inst.value(attIndex);
732              bestSupport = leftCount;
733              bestSubsetSize = leftCount;
734              lessThan = true;
735              //            } else if (targetLeft / leftCount == bestMerit) {
736            } else if (delta == 0) {
737              // break ties in favour of higher support
738              if (leftCount > bestSupport) {
739                bestMerit = targetLeft / leftCount;
740                bestSplit = inst.value(attIndex);
741                bestSupport = leftCount;
742                bestSubsetSize = leftCount;
743                lessThan = true;
744              }
745            }
746          }
747
748          if (rightCount >= m_supportCount) {
749            double delta = (m_minimize) 
750              ? (bestMerit - (targetRight / rightCount))
751              : ((targetRight / rightCount) - bestMerit);
752            //            if (targetRight / rightCount > bestMerit) {
753            if (delta > 0) {
754              bestMerit = targetRight / rightCount;
755              bestSplit = inst.value(attIndex);
756              bestSupport = rightCount;
757              bestSubsetSize = rightCount;
758              lessThan = false;
759              //            } else if (targetRight / rightCount == bestMerit) {
760            } else if (delta == 0) {
761              // break ties in favour of higher support
762              if (rightCount > bestSupport) {
763                bestMerit = targetRight / rightCount;
764                bestSplit = inst.value(attIndex);
765                bestSupport = rightCount;
766                bestSubsetSize = rightCount;
767                lessThan = false;
768              }
769            }
770          }         
771        }
772      }
773
774      double delta = (m_minimize)
775        ? m_targetValue - bestMerit
776        : bestMerit - m_targetValue;
777
778      // Have we found a candidate split?
779      if (bestSupport > 0 && (delta / m_targetValue >= m_minImprovement)) {
780        /*        System.err.println("Evaluating " + tempInsts.attribute(attIndex).name());
781        System.err.println("Merit : " + bestMerit);
782        System.err.println("Support : " + bestSupport); */
783        //        double suppFraction = bestSupport / m_numInstances;
784
785        HotTestDetails newD = new HotTestDetails(attIndex, bestSplit, 
786                                                 lessThan, (int)bestSupport, 
787                                                 (int)bestSubsetSize, 
788                                                 bestMerit);
789        pq.add(newD);
790      }
791    }
792
793    /**
794     * Evaluate a nominal attribute for a potential split
795     *
796     * @param attIndex the index of the attribute
797     * @param pq the priority queue of candidtate splits
798     */
799    private void evaluateNominal(int attIndex, PriorityQueue<HotTestDetails> pq) {
800      int[] counts = m_insts.attributeStats(attIndex).nominalCounts;
801      boolean ok = false;
802      // only consider attribute values that result in subsets that meet/exceed min support
803      for (int i = 0; i < m_insts.attribute(attIndex).numValues(); i++) {
804        if (counts[i] >= m_supportCount) {
805          ok = true;
806          break;
807        }
808      }
809      if (ok) {
810        double[] subsetMerit = 
811          new double[m_insts.attribute(attIndex).numValues()];
812
813        for (int i = 0; i < m_insts.numInstances(); i++) {
814          Instance temp = m_insts.instance(i);
815          if (!temp.isMissing(attIndex)) {
816            int attVal = (int)temp.value(attIndex);
817            if (m_insts.attribute(m_target).isNumeric()) {
818              subsetMerit[attVal] += temp.value(m_target);
819            } else {
820              subsetMerit[attVal] += 
821                ((int)temp.value(m_target) == m_targetIndex)
822                ? 1.0
823                : 0;
824            }
825          }
826        }
827       
828        // add to queue if it meets min support and exceeds the merit for the full set
829        for (int i = 0; i < m_insts.attribute(attIndex).numValues(); i++) {
830          // does the subset based on this value have enough instances, and, furthermore,
831          // does the target value (nominal only) occur enough times to exceed min support
832          if (counts[i] >= m_supportCount && 
833              ((m_insts.attribute(m_target).isNominal())
834              ? (subsetMerit[i] >= m_supportCount) // nominal only test
835               : true)) { 
836            double merit = subsetMerit[i] / counts[i]; //subsetMerit[i][1];
837            double delta = (m_minimize)
838              ? m_targetValue - merit
839              : merit - m_targetValue;
840
841            if (delta / m_targetValue >= m_minImprovement) {
842              double support =
843                (m_insts.attribute(m_target).isNominal())
844                ? subsetMerit[i]
845                : counts[i];
846
847              HotTestDetails newD = new HotTestDetails(attIndex, (double)i, 
848                                                       false, (int)support,
849                                                       counts[i], 
850                                                       merit);
851              pq.add(newD);
852            }
853          }
854        }
855      }
856    }
857
858    public int assignIDs(int lastID) {
859      int currentLastID = lastID + 1;
860      m_id = currentLastID;
861      if (m_children != null) {
862        for (int i = 0; i < m_children.length; i++) {
863          currentLastID = m_children[i].assignIDs(currentLastID);
864        }
865      }
866      return currentLastID;
867    }
868
869    private void addNodeDetails(StringBuffer buff, int i, String spacer) {
870      buff.append(m_header.attribute(m_testDetails[i].m_splitAttIndex).name());
871      if (m_header.attribute(m_testDetails[i].m_splitAttIndex).isNumeric()) {
872        if (m_testDetails[i].m_lessThan) {
873          buff.append(" <= ");
874        } else {
875          buff.append(" > ");
876        }
877        buff.append(Utils.doubleToString(m_testDetails[i].m_splitValue, 4));
878      } else {
879        buff.append(" = " + m_header.
880                    attribute(m_testDetails[i].m_splitAttIndex).
881                    value((int)m_testDetails[i].m_splitValue));
882      }
883
884      if (m_header.attribute(m_target).isNumeric()) {
885        buff.append(spacer + "(" + Utils.doubleToString(m_testDetails[i].m_merit, 4) + " ["
886                    + m_testDetails[i].m_support + "])");
887      } else {
888        buff.append(spacer + "(" + Utils.doubleToString((m_testDetails[i].m_merit * 100.0), 2) + "% ["
889                    + m_testDetails[i].m_support 
890                    + "/" + m_testDetails[i].m_subsetSize + "])");
891      }
892    }
893
894    private void graphHotSpot(StringBuffer text) {
895      if (m_children != null) {
896        for (int i = 0; i < m_children.length; i++) {
897          text.append("N" + m_children[i].m_id);
898          text.append(" [label=\"");
899          addNodeDetails(text, i, "\\n");
900          text.append("\" shape=plaintext]\n");
901          m_children[i].graphHotSpot(text);
902          text.append("N" + m_id + "->" + "N" + m_children[i].m_id + "\n");
903        }
904      }
905    }
906
907    /**
908     * Traverse the tree to create a string description
909     *
910     * @param depth the depth at this point in the tree
911     * @param buff the string buffer to append node details to
912     */
913    protected void dumpTree(int depth, StringBuffer buff) {
914      if (m_children == null) {
915        //        buff.append("\n");
916      } else {
917        for (int i = 0; i < m_children.length; i++) {
918          buff.append("\n  ");
919          for (int j = 0; j < depth; j++) {
920            buff.append("|   ");
921          }
922          addNodeDetails(buff, i, " ");
923
924          m_children[i].dumpTree(depth + 1, buff);
925        }
926      }
927    }
928  }
929
930  /**
931   * Returns the tip text for this property
932   * @return tip text for this property suitable for
933   * displaying in the explorer/experimenter gui
934   */
935  public String targetTipText() {
936    return "The target attribute of interest.";
937  }
938
939  /**
940   * Set the target index
941   *
942   * @param target the target index as a string (1-based)
943   */
944  public void setTarget(String target) {
945    m_targetSI.setSingleIndex(target);
946  }
947
948  /**
949   * Get the target index as a string
950   *
951   * @return the target index (1-based)
952   */
953  public String getTarget() {
954    return m_targetSI.getSingleIndex();
955  }
956
957  /**
958   * Returns the tip text for this property
959   * @return tip text for this property suitable for
960   * displaying in the explorer/experimenter gui
961   */
962  public String targetIndexTipText() {
963    return "The value of the target (nominal attributes only) of interest.";
964  }
965
966  /**
967   * For a nominal target, set the index of the value of interest (1-based)
968   *
969   * @param index the index of the nominal value of interest
970   */
971  public void setTargetIndex(String index) {
972    m_targetIndexSI.setSingleIndex(index);
973  }
974
975  /**
976   * For a nominal target, get the index of the value of interest (1-based)
977   *
978   * @return the index of the nominal value of interest
979   */
980  public String getTargetIndex() {
981    return m_targetIndexSI.getSingleIndex();
982  }
983
984  /**
985   * Returns the tip text for this property
986   * @return tip text for this property suitable for
987   * displaying in the explorer/experimenter gui
988   */
989  public String minimizeTargetTipText() {
990    return "Minimize rather than maximize the target.";
991  }
992
993  /**
994   * Set whether to minimize the target rather than maximize
995   *
996   * @param m true if target is to be minimized
997   */
998  public void setMinimizeTarget(boolean m) {
999    m_minimize = m;
1000  }
1001
1002  /**
1003   * Get whether to minimize the target rather than maximize
1004   *
1005   * @return true if target is to be minimized
1006   */
1007  public boolean getMinimizeTarget() {
1008    return m_minimize;
1009  }
1010
1011  /**
1012   * Returns the tip text for this property
1013   * @return tip text for this property suitable for
1014   * displaying in the explorer/experimenter gui
1015   */
1016  public String supportTipText() {
1017    return "The minimum support. Values between 0 and 1 are interpreted "
1018      + "as a percentage of the total population; values > 1 are "
1019      + "interpreted as an absolute number of instances";
1020  }
1021
1022  /**
1023   * Get the minimum support
1024   *
1025   * @return the minimum support
1026   */
1027  public double getSupport() {
1028    return m_support;
1029  }
1030
1031  /**
1032   * Set the minimum support
1033   *
1034   * @param s the minimum support
1035   */
1036  public void setSupport(double s) {
1037    m_support = s;
1038  }
1039
1040  /**
1041   * Returns the tip text for this property
1042   * @return tip text for this property suitable for
1043   * displaying in the explorer/experimenter gui
1044   */
1045  public String maxBranchingFactorTipText() {
1046    return "Maximum branching factor. The maximum number of children "
1047      + "to consider extending each node with.";
1048  }
1049
1050  /**
1051   * Set the maximum branching factor
1052   *
1053   * @param b the maximum branching factor
1054   */
1055  public void setMaxBranchingFactor(int b) {
1056    m_maxBranchingFactor = b;
1057  }
1058
1059  /**
1060   * Get the maximum branching factor
1061   *
1062   * @return the maximum branching factor
1063   */
1064  public int getMaxBranchingFactor() {
1065    return m_maxBranchingFactor;
1066  }
1067
1068  /**
1069   * Returns the tip text for this property
1070   * @return tip text for this property suitable for
1071   * displaying in the explorer/experimenter gui
1072   */
1073  public String minImprovementTipText() {
1074    return "Minimum improvement in target value in order to "
1075      + "consider adding a new branch/test";
1076  }
1077
1078  /**
1079   * Set the minimum improvement in the target necessary to add a test
1080   *
1081   * @param i the minimum improvement
1082   */
1083  public void setMinImprovement(double i) {
1084    m_minImprovement = i;
1085  }
1086
1087  /**
1088   * Get the minimum improvement in the target necessary to add a test
1089   *
1090   * @return the minimum improvement
1091   */
1092  public double getMinImprovement() {
1093    return m_minImprovement;
1094  }
1095
1096  /**
1097   * Returns the tip text for this property
1098   * @return tip text for this property suitable for
1099   * displaying in the explorer/experimenter gui
1100   */
1101  public String debugTipText() {
1102    return "Output debugging info (duplicate rule lookup hash table stats).";
1103  }
1104
1105  /**
1106   * Set whether debugging info is output
1107   *
1108   * @param d true to output debugging info
1109   */
1110  public void setDebug(boolean d) {
1111    m_debug = d;
1112  }
1113
1114  /**
1115   * Get whether debugging info is output
1116   *
1117   * @return true if outputing debugging info
1118   */
1119  public boolean getDebug() {
1120    return m_debug;
1121  }
1122
1123  /**
1124   * Returns an enumeration describing the available options.
1125   *
1126   * @return an enumeration of all the available options.
1127   */
1128  public Enumeration listOptions() {
1129    Vector newVector = new Vector();
1130    newVector.addElement(new Option("\tThe target index. (default = last)",
1131                                    "c", 1,
1132                                    "-c <num | first | last>"));
1133    newVector.addElement(new Option("\tThe target value (nominal target only, default = first)",
1134                                    "V", 1,
1135                                    "-V <num | first | last>"));
1136    newVector.addElement(new Option("\tMinimize rather than maximize.", "L", 0, "-L"));
1137    newVector.addElement(new Option("\tMinimum value count (nominal target)/segment size "
1138                                    + "(numeric target)."
1139                                    +"\n\tValues between 0 and 1 are "
1140                                    + "\n\tinterpreted as a percentage of "
1141                                    + "\n\tthe total population; values > 1 are "
1142                                    + "\n\tinterpreted as an absolute number of "
1143                                    +"\n\tinstances (default = 0.3)",
1144                                    "-S", 1,
1145                                    "-S <num>"));
1146    newVector.addElement(new Option("\tMaximum branching factor (default = 2)",
1147                                    "-M", 1,
1148                                    "-M <num>"));
1149    newVector.addElement(new Option("\tMinimum improvement in target value in order "
1150                                    + "\n\tto add a new branch/test (default = 0.01 (1%))",
1151                                    "-I", 1,
1152                                    "-I <num>"));
1153    newVector.addElement(new Option("\tOutput debugging info (duplicate rule lookup "
1154                                    + "\n\thash table stats)", "-D", 0, "-D"));
1155    return newVector.elements();
1156  }
1157
1158  /**
1159   * Reset options to their defaults
1160   */
1161  public void resetOptions() {
1162    m_support = 0.33;
1163    m_minImprovement = 0.01; // 1%
1164    m_maxBranchingFactor = 2;
1165    m_minimize = false;
1166    m_debug = false;
1167    setTarget("last");
1168    setTargetIndex("first");
1169    m_errorMessage = null;
1170  }
1171
1172  /**
1173   * Parses a given list of options. <p/>
1174   *
1175   <!-- options-start -->
1176   * Valid options are: <p/>
1177   *
1178   * <pre> -c &lt;num | first | last&gt;
1179   *  The target index. (default = last)</pre>
1180   *
1181   * <pre> -V &lt;num | first | last&gt;
1182   *  The target value (nominal target only, default = first)</pre>
1183   *
1184   * <pre> -L
1185   *  Minimize rather than maximize.</pre>
1186   *
1187   * <pre> -S &lt;num&gt;
1188   *  Minimum value count (nominal target)/segment size (numeric target).
1189   *  Values between 0 and 1 are
1190   *  interpreted as a percentage of
1191   *  the total population; values &gt; 1 are
1192   *  interpreted as an absolute number of
1193   *  instances (default = 0.3)</pre>
1194   *
1195   * <pre> -M &lt;num&gt;
1196   *  Maximum branching factor (default = 2)</pre>
1197   *
1198   * <pre> -I &lt;num&gt;
1199   *  Minimum improvement in target value in order
1200   *  to add a new branch/test (default = 0.01 (1%))</pre>
1201   *
1202   * <pre> -D
1203   *  Output debugging info (duplicate rule lookup
1204   *  hash table stats)</pre>
1205   *
1206   <!-- options-end -->
1207   *
1208   * @param options the list of options as an array of strings
1209   * @exception Exception if an option is not supported
1210   */
1211  public void setOptions(String[] options) throws Exception {
1212    resetOptions();
1213
1214    String tempString = Utils.getOption('c', options);
1215    if (tempString.length() != 0) {
1216      setTarget(tempString);
1217    }
1218   
1219    tempString = Utils.getOption('V', options);
1220    if (tempString.length() != 0) {
1221      setTargetIndex(tempString);
1222    }
1223
1224    setMinimizeTarget(Utils.getFlag('L', options));
1225
1226    tempString = Utils.getOption('S', options);
1227    if (tempString.length() != 0) {
1228      setSupport(Double.parseDouble(tempString));
1229    }
1230
1231    tempString = Utils.getOption('M', options);
1232    if (tempString.length() != 0) {
1233      setMaxBranchingFactor(Integer.parseInt(tempString));
1234    }
1235
1236    tempString = Utils.getOption('I', options);
1237    if (tempString.length() != 0) {
1238      setMinImprovement(Double.parseDouble(tempString));
1239    }
1240
1241    setDebug(Utils.getFlag('D', options));
1242  }
1243
1244  /**
1245   * Gets the current settings of HotSpot.
1246   *
1247   * @return an array of strings suitable for passing to setOptions
1248   */
1249  public String [] getOptions() {
1250    String[] options = new String[12];
1251    int current = 0;
1252   
1253    options[current++] = "-c"; options[current++] = getTarget();
1254    options[current++] = "-V"; options[current++] = getTargetIndex();
1255    if (getMinimizeTarget()) {
1256      options[current++] = "-L";
1257    }
1258    options[current++] = "-S"; options[current++] = "" + getSupport();
1259    options[current++] = "-M"; options[current++] = "" + getMaxBranchingFactor();
1260    options[current++] = "-I"; options[current++] = "" + getMinImprovement();
1261    if (getDebug()) {
1262      options[current++] = "-D";
1263    }
1264
1265    while (current < options.length) {
1266      options[current++] = "";
1267    }
1268
1269    return options;
1270  }
1271
1272  /**
1273   * Returns the revision string.
1274   *
1275   * @return            the revision
1276   */
1277  public String getRevision() {
1278    return RevisionUtils.extract("$Revision: 6081 $");
1279  }
1280
1281  /**
1282   *  Returns the type of graph this scheme
1283   *  represents.
1284   *  @return Drawable.TREE
1285   */   
1286  public int graphType() {
1287    return Drawable.TREE;
1288  }
1289
1290  /**
1291   * Main method for testing this class.
1292   *
1293   * @param args the options
1294   */
1295  public static void main(String[] args) {
1296    try {
1297      HotSpot h = new HotSpot();
1298      AbstractAssociator.runAssociator(new HotSpot(), args);
1299    } catch (Exception ex) {
1300      ex.printStackTrace();
1301    }
1302  }
1303}
1304
Note: See TracBrowser for help on using the repository browser.