source: src/main/java/weka/classifiers/trees/m5/Rule.java @ 19

Last change on this file since 19 was 4, checked in by gnappo, 14 years ago

Import di weka.

File size: 14.4 KB
Line 
1/*
2 *    This program is free software; you can redistribute it and/or modify
3 *    it under the terms of the GNU General Public License as published by
4 *    the Free Software Foundation; either version 2 of the License, or
5 *    (at your option) any later version.
6 *
7 *    This program is distributed in the hope that it will be useful,
8 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
9 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
10 *    GNU General Public License for more details.
11 *
12 *    You should have received a copy of the GNU General Public License
13 *    along with this program; if not, write to the Free Software
14 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
15 */
16
17/*
18 *    Rule.java
19 *    Copyright (C) 2000 University of Waikato, Hamilton, New Zealand
20 *
21 */
22
23package weka.classifiers.trees.m5;
24
25import weka.core.Instance;
26import weka.core.Instances;
27import weka.core.RevisionHandler;
28import weka.core.RevisionUtils;
29import weka.core.Utils;
30
31import java.io.Serializable;
32
33/**
34 * Generates a single m5 tree or rule
35 *
36 * @author Mark Hall
37 * @version $Revision: 1.15 $
38 */
39public class Rule
40  implements Serializable, RevisionHandler {
41
42  /** for serialization */
43  private static final long serialVersionUID = -4458627451682483204L;
44
45  protected static int LEFT = 0;
46  protected static int RIGHT = 1;
47
48  /**
49   * the instances covered by this rule
50   */
51  private Instances m_instances;
52
53  /**
54   * the class index
55   */
56  private int m_classIndex;
57
58  /**
59   * the number of attributes
60   */
61  private int m_numAttributes;
62
63  /**
64   * the number of instances in the dataset
65   */
66  private int m_numInstances;
67
68  /**
69   * the indexes of the attributes used to split on for this rule
70   */
71  private int[] m_splitAtts;
72
73  /**
74   * the corresponding values of the split points
75   */
76  private double[] m_splitVals;
77
78  /**
79   * the corresponding internal nodes. Used for smoothing rules.
80   */
81  private RuleNode[] m_internalNodes;
82
83  /**
84   * the corresponding relational operators (0 = "<=", 1 = ">")
85   */
86  private int[] m_relOps;
87
88  /**
89   * the leaf encapsulating the linear model for this rule
90   */
91  private RuleNode m_ruleModel;
92
93  /**
94   * the top of the m5 tree for this rule
95   */
96  protected RuleNode m_topOfTree;
97
98  /**
99   * the standard deviation of the class for all the instances
100   */
101  private double m_globalStdDev;
102
103  /**
104   * the absolute deviation of the class for all the instances
105   */
106  private double m_globalAbsDev;
107
108  /**
109   * the instances covered by this rule
110   */
111  private Instances m_covered;
112
113  /**
114   * the number of instances covered by this rule
115   */
116  private int m_numCovered;
117
118  /**
119   * the instances not covered by this rule
120   */
121  private Instances m_notCovered;
122
123  /**
124   * use a pruned m5 tree rather than make a rule
125   */
126  private boolean m_useTree;
127
128  /**
129   * use the original m5 smoothing procedure
130   */
131  private boolean m_smoothPredictions;
132
133  /**
134   * Save instances at each node in an M5 tree for visualization purposes.
135   */
136  private boolean m_saveInstances;
137
138  /**
139   * Make a regression tree instead of a model tree
140   */
141  private boolean m_regressionTree;
142
143  /**
144   * Build unpruned tree/rule
145   */
146  private boolean m_useUnpruned;
147
148  /**
149   * The minimum number of instances to allow at a leaf node
150   */
151  private double m_minNumInstances;
152
153  /**
154   * Constructor declaration
155   *
156   */
157  public Rule() {
158    m_useTree = false;
159    m_smoothPredictions = false;
160    m_useUnpruned = false;
161    m_minNumInstances = 4;
162  }
163
164  /**
165   * Generates a single rule or m5 model tree.
166   *
167   * @param data set of instances serving as training data
168   * @exception Exception if the rule has not been generated
169   * successfully
170   */
171  public void buildClassifier(Instances data) throws Exception {
172    m_instances = null;
173    m_topOfTree = null;
174    m_covered = null;
175    m_notCovered = null;
176    m_ruleModel = null;
177    m_splitAtts = null;
178    m_splitVals = null;
179    m_relOps = null;
180    m_internalNodes = null;
181    m_instances = data;
182    m_classIndex = m_instances.classIndex();
183    m_numAttributes = m_instances.numAttributes();
184    m_numInstances = m_instances.numInstances();
185
186    // first calculate global deviation of class attribute
187    m_globalStdDev = Rule.stdDev(m_classIndex, m_instances);
188    m_globalAbsDev = Rule.absDev(m_classIndex, m_instances);
189
190    m_topOfTree = new RuleNode(m_globalStdDev, m_globalAbsDev, null);
191    m_topOfTree.setSaveInstances(m_saveInstances);
192    m_topOfTree.setRegressionTree(m_regressionTree);
193    m_topOfTree.setMinNumInstances(m_minNumInstances);
194    m_topOfTree.buildClassifier(m_instances);
195
196
197    if (!m_useUnpruned) {
198      m_topOfTree.prune();
199    } else {
200      m_topOfTree.installLinearModels();
201    }
202
203    if (m_smoothPredictions) {
204      m_topOfTree.installSmoothedModels();
205    }
206    //m_topOfTree.printAllModels();
207    m_topOfTree.numLeaves(0);
208
209    if (!m_useTree) {     
210      makeRule();
211      // save space
212      //      m_topOfTree = null;
213    }
214
215    // save space
216    m_instances = new Instances(m_instances, 0);
217   
218  } 
219
220  /**
221   * Calculates a prediction for an instance using this rule
222   * or M5 model tree
223   *
224   * @param instance the instance whos class value is to be predicted
225   * @return the prediction
226   * @exception Exception if a prediction can't be made.
227   */
228  public double classifyInstance(Instance instance) throws Exception {
229    if (m_useTree) {
230      return m_topOfTree.classifyInstance(instance);
231    } 
232
233    // does the instance pass the rule's conditions?
234    if (m_splitAtts.length > 0) {
235      for (int i = 0; i < m_relOps.length; i++) {
236        if (m_relOps[i] == LEFT)    // left
237         {
238          if (instance.value(m_splitAtts[i]) > m_splitVals[i]) {
239            throw new Exception("Rule does not classify instance");
240          } 
241        } else {
242          if (instance.value(m_splitAtts[i]) <= m_splitVals[i]) {
243            throw new Exception("Rule does not classify instance");
244          } 
245        } 
246      } 
247    } 
248
249    // the linear model's prediction for this rule
250    return m_ruleModel.classifyInstance(instance);
251  } 
252
253  /**
254   * Returns the top of the tree.
255   */
256  public RuleNode topOfTree() {
257
258    return m_topOfTree;
259  }
260
261  /**
262   * Make the single best rule from a pruned m5 model tree
263   *
264   * @exception Exception if something goes wrong.
265   */
266  private void makeRule() throws Exception {
267    RuleNode[] best_leaf = new RuleNode[1];
268    double[]   best_cov = new double[1];
269    RuleNode   temp;
270
271    m_notCovered = new Instances(m_instances, 0);
272    m_covered = new Instances(m_instances, 0);
273    best_cov[0] = -1;
274    best_leaf[0] = null;
275
276    m_topOfTree.findBestLeaf(best_cov, best_leaf);
277
278    temp = best_leaf[0];
279
280    if (temp == null) {
281      throw new Exception("Unable to generate rule!");
282    } 
283
284    // save the linear model for this rule
285    m_ruleModel = temp;
286
287    int count = 0;
288
289    while (temp.parentNode() != null) {
290      count++;
291      temp = temp.parentNode();
292    } 
293
294    temp = best_leaf[0];
295    m_relOps = new int[count];
296    m_splitAtts = new int[count];
297    m_splitVals = new double[count];
298    if (m_smoothPredictions) {
299      m_internalNodes = new RuleNode[count];
300    }
301
302    // trace back to the root
303    int i = 0;
304
305    while (temp.parentNode() != null) {
306      m_splitAtts[i] = temp.parentNode().splitAtt();
307      m_splitVals[i] = temp.parentNode().splitVal();
308
309      if (temp.parentNode().leftNode() == temp) {
310        m_relOps[i] = LEFT;
311        //      temp.parentNode().m_right = null;
312      } else {
313        m_relOps[i] = RIGHT;
314        //      temp.parentNode().m_left = null;
315      }
316
317      if (m_smoothPredictions) {
318        m_internalNodes[i] = temp.parentNode();
319      }
320
321      temp = temp.parentNode();
322      i++;
323    } 
324
325    // now assemble the covered and uncovered instances
326    boolean ok;
327
328    for (i = 0; i < m_numInstances; i++) {
329      ok = true;
330
331      for (int j = 0; j < m_relOps.length; j++) {
332        if (m_relOps[j] == LEFT)
333         {
334          if (m_instances.instance(i).value(m_splitAtts[j]) 
335                  > m_splitVals[j]) {
336            m_notCovered.add(m_instances.instance(i));
337            ok = false;
338            break;
339          } 
340        } else {
341          if (m_instances.instance(i).value(m_splitAtts[j]) 
342                  <= m_splitVals[j]) {
343            m_notCovered.add(m_instances.instance(i));
344            ok = false;
345            break;
346          } 
347        } 
348      } 
349
350      if (ok) {
351        m_numCovered++;
352        //      m_covered.add(m_instances.instance(i));
353      } 
354    } 
355  } 
356
357  /**
358   * Return a description of the m5 tree or rule
359   *
360   * @return a description of the m5 tree or rule as a String
361   */
362  public String toString() {
363    if (m_useTree) {
364      return treeToString();
365    } else {
366      return ruleToString();
367    } 
368  } 
369
370  /**
371   * Return a description of the m5 tree
372   *
373   * @return a description of the m5 tree as a String
374   */
375  private String treeToString() {
376    StringBuffer text = new StringBuffer();
377
378    if (m_topOfTree == null) {
379      return "Tree/Rule has not been built yet!";
380    } 
381
382    text.append("M5 "
383                + ((m_useUnpruned)
384                   ? "unpruned "
385                   : "pruned ")
386                + ((m_regressionTree) 
387                   ? "regression "
388                   : "model ")
389                +"tree:\n");
390
391    if (m_smoothPredictions == true) {
392      text.append("(using smoothed linear models)\n");
393    } 
394
395    text.append(m_topOfTree.treeToString(0));
396    text.append(m_topOfTree.printLeafModels());
397    text.append("\nNumber of Rules : " + m_topOfTree.numberOfLinearModels());
398
399    return text.toString();
400  } 
401
402  /**
403   * Return a description of the rule
404   *
405   * @return a description of the rule as a String
406   */
407  private String ruleToString() {
408    StringBuffer text = new StringBuffer();
409
410    if (m_splitAtts.length > 0) {
411      text.append("IF\n");
412
413      for (int i = m_splitAtts.length - 1; i >= 0; i--) {
414        text.append("\t" + m_covered.attribute(m_splitAtts[i]).name() + " ");
415
416        if (m_relOps[i] == 0) {
417          text.append("<= ");
418        } else {
419          text.append("> ");
420        } 
421
422        text.append(Utils.doubleToString(m_splitVals[i], 1, 3) + "\n");
423      } 
424
425      text.append("THEN\n");
426    } 
427
428    if (m_ruleModel != null) {
429      try {
430        text.append(m_ruleModel.printNodeLinearModel());
431        text.append(" [" + m_numCovered/*m_covered.numInstances()*/);
432
433        if (m_globalAbsDev > 0.0) {
434          text.append("/"+Utils.doubleToString((100 * 
435                                                   m_ruleModel.
436                                                   rootMeanSquaredError() / 
437                                                   m_globalStdDev), 1, 3) 
438                      + "%]\n\n");
439        } else {
440          text.append("]\n\n");
441        } 
442      } catch (Exception e) {
443        return "Can't print rule";
444      } 
445    } 
446   
447    //    System.out.println(m_instances);
448    return text.toString();
449  } 
450
451  /**
452   * Use unpruned tree/rules
453   *
454   * @param unpruned true if unpruned tree/rules are to be generated
455   */
456  public void setUnpruned(boolean unpruned) {
457    m_useUnpruned = unpruned;
458  }
459
460  /**
461   * Get whether unpruned tree/rules are being generated
462   *
463   * @return true if unpruned tree/rules are to be generated
464   */
465  public boolean getUnpruned() {
466    return m_useUnpruned;
467  }
468
469  /**
470   * Use an m5 tree rather than generate rules
471   *
472   * @param u true if m5 tree is to be used
473   */
474  public void setUseTree(boolean u) {
475    m_useTree = u;
476  } 
477
478  /**
479   * get whether an m5 tree is being used rather than rules
480   *
481   * @return true if an m5 tree is being used.
482   */
483  public boolean getUseTree() {
484    return m_useTree;
485  } 
486
487  /**
488   * Smooth predictions
489   *
490   * @param s true if smoothing is to be used
491   */
492  public void setSmoothing(boolean s) {
493    m_smoothPredictions = s;
494  } 
495
496  /**
497   * Get whether or not smoothing has been turned on
498   *
499   * @return true if smoothing is being used
500   */
501  public boolean getSmoothing() {
502    return m_smoothPredictions;
503  } 
504
505  /**
506   * Get the instances not covered by this rule
507   *
508   * @return the instances not covered
509   */
510  public Instances notCoveredInstances() {
511    return m_notCovered;
512  } 
513
514//    /**
515//     * Get the instances covered by this rule
516//     *
517//     * @return the instances covered by this rule
518//     */
519//    public Instances coveredInstances() {
520//      return m_covered;
521//    }
522
523  /**
524   * Returns the standard deviation value of the supplied attribute index.
525   *
526   * @param attr an attribute index
527   * @param inst the instances
528   * @return the standard deviation value
529   */
530  protected static final double stdDev(int attr, Instances inst) {
531    int i,count=0;
532    double sd,va,sum=0.0,sqrSum=0.0,value;
533   
534    for(i = 0; i <= inst.numInstances() - 1; i++) {
535      count++;
536      value = inst.instance(i).value(attr);
537      sum +=  value;
538      sqrSum += value * value;
539    }
540   
541    if(count > 1) {
542      va = (sqrSum - sum * sum / count) / count;
543      va = Math.abs(va);
544      sd = Math.sqrt(va);
545    } else {
546      sd = 0.0;
547    }
548
549    return sd;
550  }
551
552  /**
553   * Returns the absolute deviation value of the supplied attribute index.
554   *
555   * @param attr an attribute index
556   * @param inst the instances
557   * @return the absolute deviation value
558   */
559  protected static final double absDev(int attr, Instances inst) {
560    int i;
561    double average=0.0,absdiff=0.0,absDev;
562   
563    for(i = 0; i <= inst.numInstances()-1; i++) {
564      average  += inst.instance(i).value(attr);
565    }
566    if(inst.numInstances() > 1) {
567      average /= (double)inst.numInstances();
568      for(i=0; i <= inst.numInstances()-1; i++) {
569        absdiff += Math.abs(inst.instance(i).value(attr) - average);
570      }
571      absDev = absdiff / (double)inst.numInstances();
572    } else {
573      absDev = 0.0;
574    }
575   
576    return absDev;
577  }
578
579  /**
580   * Sets whether instances at each node in an M5 tree should be saved
581   * for visualization purposes. Default is to save memory.
582   *
583   * @param save a <code>boolean</code> value
584   */
585  protected void setSaveInstances(boolean save) {
586    m_saveInstances = save;
587  }
588
589  /**
590   * Get the value of regressionTree.
591   *
592   * @return Value of regressionTree.
593   */
594  public boolean getRegressionTree() {
595   
596    return m_regressionTree;
597  }
598 
599  /**
600   * Set the value of regressionTree.
601   *
602   * @param newregressionTree Value to assign to regressionTree.
603   */
604  public void setRegressionTree(boolean newregressionTree) {
605   
606    m_regressionTree = newregressionTree;
607  }
608
609  /**
610   * Set the minumum number of instances to allow at a leaf node
611   *
612   * @param minNum the minimum number of instances
613   */
614  public void setMinNumInstances(double minNum) {
615    m_minNumInstances = minNum;
616  }
617
618  /**
619   * Get the minimum number of instances to allow at a leaf node
620   *
621   * @return a <code>double</code> value
622   */
623  public double getMinNumInstances() {
624    return m_minNumInstances;
625  }
626
627  public RuleNode getM5RootNode() {
628    return m_topOfTree;
629  }
630 
631  /**
632   * Returns the revision string.
633   *
634   * @return            the revision
635   */
636  public String getRevision() {
637    return RevisionUtils.extract("$Revision: 1.15 $");
638  }
639}
Note: See TracBrowser for help on using the repository browser.