source: src/main/java/weka/classifiers/mi/TLDSimple.java @ 28

Last change on this file since 28 was 4, checked in by gnappo, 14 years ago

Import di weka.

File size: 29.8 KB
Line 
1/*
2 *    This program is free software; you can redistribute it and/or modify
3 *    it under the terms of the GNU General Public License as published by
4 *    the Free Software Foundation; either version 2 of the License, or
5 *    (at your option) any later version.
6 *
7 *    This program is distributed in the hope that it will be useful,
8 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
9 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
10 *    GNU General Public License for more details.
11 *
12 *    You should have received a copy of the GNU General Public License
13 *    along with this program; if not, write to the Free Software
14 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
15 */
16
17/*
18 * TLDSimple.java
19 * Copyright (C) 2005 University of Waikato, Hamilton, New Zealand
20 *
21 */
22
23package weka.classifiers.mi;
24
25import weka.classifiers.RandomizableClassifier;
26import weka.core.Capabilities;
27import weka.core.Instance;
28import weka.core.Instances;
29import weka.core.MultiInstanceCapabilitiesHandler;
30import weka.core.Optimization;
31import weka.core.Option;
32import weka.core.OptionHandler;
33import weka.core.RevisionUtils;
34import weka.core.TechnicalInformation;
35import weka.core.TechnicalInformationHandler;
36import weka.core.Utils;
37import weka.core.Capabilities.Capability;
38import weka.core.TechnicalInformation.Field;
39import weka.core.TechnicalInformation.Type;
40
41import java.util.Enumeration;
42import java.util.Random;
43import java.util.Vector;
44
45/**
46 <!-- globalinfo-start -->
47 * A simpler version of TLD, mu random but sigma^2 fixed and estimated via data.<br/>
48 * <br/>
49 * For more information see:<br/>
50 * <br/>
51 * Xin Xu (2003). Statistical learning in multiple instance problem. Hamilton, NZ.
52 * <p/>
53 <!-- globalinfo-end -->
54 *
55 <!-- technical-bibtex-start -->
56 * BibTeX:
57 * <pre>
58 * &#64;mastersthesis{Xu2003,
59 *    address = {Hamilton, NZ},
60 *    author = {Xin Xu},
61 *    note = {0657.594},
62 *    school = {University of Waikato},
63 *    title = {Statistical learning in multiple instance problem},
64 *    year = {2003}
65 * }
66 * </pre>
67 * <p/>
68 <!-- technical-bibtex-end -->
69 *
70 <!-- options-start -->
71 * Valid options are: <p/>
72 *
73 * <pre> -C
74 *  Set whether or not use empirical
75 *  log-odds cut-off instead of 0</pre>
76 *
77 * <pre> -R &lt;numOfRuns&gt;
78 *  Set the number of multiple runs
79 *  needed for searching the MLE.</pre>
80 *
81 * <pre> -S &lt;num&gt;
82 *  Random number seed.
83 *  (default 1)</pre>
84 *
85 * <pre> -D
86 *  If set, classifier is run in debug mode and
87 *  may output additional info to the console</pre>
88 *
89 <!-- options-end -->
90 *
91 * @author Eibe Frank (eibe@cs.waikato.ac.nz)
92 * @author Xin Xu (xx5@cs.waikato.ac.nz)
93 * @version $Revision: 5481 $
94 */
95public class TLDSimple 
96  extends RandomizableClassifier
97  implements OptionHandler, MultiInstanceCapabilitiesHandler,
98             TechnicalInformationHandler {
99
100  /** for serialization */
101  static final long serialVersionUID = 9040995947243286591L;
102 
103  /** The mean for each attribute of each positive exemplar */
104  protected double[][] m_MeanP = null;
105
106  /** The mean for each attribute of each negative exemplar */
107  protected double[][] m_MeanN = null;
108
109  /** The effective sum of weights of each positive exemplar in each dimension*/
110  protected double[][] m_SumP = null;
111
112  /** The effective sum of weights of each negative exemplar in each dimension*/
113  protected double[][] m_SumN = null;
114
115  /** Estimated sigma^2 in positive bags*/
116  protected double[] m_SgmSqP;
117
118  /** Estimated sigma^2 in negative bags*/
119  protected double[] m_SgmSqN;
120
121  /** The parameters to be estimated for each positive exemplar*/
122  protected double[] m_ParamsP = null;
123
124  /** The parameters to be estimated for each negative exemplar*/
125  protected double[] m_ParamsN = null;
126
127  /** The dimension of each exemplar, i.e. (numAttributes-2) */
128  protected int m_Dimension = 0;
129
130  /** The class label of each exemplar */
131  protected double[] m_Class = null;
132
133  /** The number of class labels in the data */
134  protected int m_NumClasses = 2;
135
136  /** The very small number representing zero */
137  static public double ZERO = 1.0e-12;   
138
139  protected int m_Run = 1;
140
141  protected double m_Cutoff;
142
143  protected boolean m_UseEmpiricalCutOff = false;   
144
145  private double[] m_LkRatio;
146
147  private Instances m_Attribute = null;
148
149  /**
150   * Returns a string describing this filter
151   *
152   * @return a description of the filter suitable for
153   * displaying in the explorer/experimenter gui
154   */
155  public String globalInfo() {
156    return 
157        "A simpler version of TLD, mu random but sigma^2 fixed and estimated "
158      + "via data.\n\n"
159      + "For more information see:\n\n"
160      + getTechnicalInformation().toString();
161  }
162 
163  /**
164   * Returns an instance of a TechnicalInformation object, containing
165   * detailed information about the technical background of this class,
166   * e.g., paper reference or book this class is based on.
167   *
168   * @return the technical information about this class
169   */
170  public TechnicalInformation getTechnicalInformation() {
171    TechnicalInformation        result;
172   
173    result = new TechnicalInformation(Type.MASTERSTHESIS);
174    result.setValue(Field.AUTHOR, "Xin Xu");
175    result.setValue(Field.YEAR, "2003");
176    result.setValue(Field.TITLE, "Statistical learning in multiple instance problem");
177    result.setValue(Field.SCHOOL, "University of Waikato");
178    result.setValue(Field.ADDRESS, "Hamilton, NZ");
179    result.setValue(Field.NOTE, "0657.594");
180   
181    return result;
182  }
183
184  /**
185   * Returns default capabilities of the classifier.
186   *
187   * @return      the capabilities of this classifier
188   */
189  public Capabilities getCapabilities() {
190    Capabilities result = super.getCapabilities();
191    result.disableAll();
192
193    // attributes
194    result.enable(Capability.NOMINAL_ATTRIBUTES);
195    result.enable(Capability.RELATIONAL_ATTRIBUTES);
196    result.enable(Capability.MISSING_VALUES);
197
198    // class
199    result.enable(Capability.BINARY_CLASS);
200    result.enable(Capability.MISSING_CLASS_VALUES);
201   
202    // other
203    result.enable(Capability.ONLY_MULTIINSTANCE);
204   
205    return result;
206  }
207
208  /**
209   * Returns the capabilities of this multi-instance classifier for the
210   * relational data.
211   *
212   * @return            the capabilities of this object
213   * @see               Capabilities
214   */
215  public Capabilities getMultiInstanceCapabilities() {
216    Capabilities result = super.getCapabilities();
217    result.disableAll();
218   
219    // attributes
220    result.enable(Capability.NOMINAL_ATTRIBUTES);
221    result.enable(Capability.NUMERIC_ATTRIBUTES);
222    result.enable(Capability.DATE_ATTRIBUTES);
223    result.enable(Capability.MISSING_VALUES);
224
225    // class
226    result.disableAllClasses();
227    result.enable(Capability.NO_CLASS);
228   
229    return result;
230  }
231
232  /**
233   *
234   * @param exs the training exemplars
235   * @throws Exception if the model cannot be built properly
236   */   
237  public void buildClassifier(Instances exs)throws Exception{
238    // can classifier handle the data?
239    getCapabilities().testWithFail(exs);
240
241    // remove instances with missing class
242    exs = new Instances(exs);
243    exs.deleteWithMissingClass();
244   
245    int numegs = exs.numInstances();
246    m_Dimension = exs.attribute(1).relation().numAttributes();
247    m_Attribute = exs.attribute(1).relation().stringFreeStructure();
248    Instances pos = new Instances(exs, 0), neg = new Instances(exs, 0);
249
250    // Divide into two groups
251    for(int u=0; u<numegs; u++){
252      Instance example = exs.instance(u);
253      if(example.classValue() == 1)
254        pos.add(example);
255      else
256        neg.add(example);
257    }   
258    int pnum = pos.numInstances(), nnum = neg.numInstances();   
259
260    // xBar, n
261    m_MeanP = new double[pnum][m_Dimension];
262    m_SumP = new double[pnum][m_Dimension];
263    m_MeanN = new double[nnum][m_Dimension];
264    m_SumN = new double[nnum][m_Dimension];
265    // w, m
266    m_ParamsP = new double[2*m_Dimension];
267    m_ParamsN = new double[2*m_Dimension];
268    // \sigma^2
269    m_SgmSqP = new double[m_Dimension];
270    m_SgmSqN = new double[m_Dimension];
271    // S^2
272    double[][] varP=new double[pnum][m_Dimension], 
273      varN=new double[nnum][m_Dimension];
274    // numOfEx 'e' without all missing
275    double[] effNumExP=new double[m_Dimension], 
276      effNumExN=new double[m_Dimension];
277    // For the starting values
278    double[] pMM=new double[m_Dimension], 
279      nMM=new double[m_Dimension],
280      pVM=new double[m_Dimension],
281      nVM=new double[m_Dimension];
282    // # of exemplars with only one instance
283    double[] numOneInsExsP=new double[m_Dimension],
284      numOneInsExsN=new double[m_Dimension];
285    // sum_i(1/n_i)
286    double[] pInvN = new double[m_Dimension], nInvN = new double[m_Dimension];
287
288    // Extract metadata from both positive and negative bags
289    for(int v=0; v < pnum; v++){
290      //Instance px = pos.instance(v);
291      Instances pxi =  pos.instance(v).relationalValue(1);
292      for (int k=0; k<pxi.numAttributes(); k++) {
293        m_MeanP[v][k] = pxi.meanOrMode(k);
294        varP[v][k] = pxi.variance(k);
295      }
296
297      for (int w=0,t=0; w < m_Dimension; w++,t++){             
298        //if((t==m_ClassIndex) || (t==m_IdIndex))
299        //  t++;       
300        if(varP[v][w] <= 0.0)
301          varP[v][w] = 0.0;
302        if(!Double.isNaN(m_MeanP[v][w])){
303
304          for(int u=0;u<pxi.numInstances();u++)
305            if(!pxi.instance(u).isMissing(t))                       
306              m_SumP[v][w] += pxi.instance(u).weight();
307
308          pMM[w] += m_MeanP[v][w];
309          pVM[w] += m_MeanP[v][w]*m_MeanP[v][w];                   
310          if((m_SumP[v][w]>1) && (varP[v][w]>ZERO)){   
311
312            m_SgmSqP[w] += varP[v][w]*(m_SumP[v][w]-1.0)/m_SumP[v][w];
313
314            //m_SgmSqP[w] += varP[v][w]*(m_SumP[v][w]-1.0);
315            effNumExP[w]++; // Not count exemplars with 1 instance
316            pInvN[w] += 1.0/m_SumP[v][w];
317            //pInvN[w] += m_SumP[v][w];
318          }
319          else
320            numOneInsExsP[w]++;
321        }
322
323      }                     
324    }
325
326
327    for(int v=0; v < nnum; v++){
328      //Instance nx = neg.instance(v);
329      Instances nxi = neg.instance(v).relationalValue(1);
330      for (int k=0; k<nxi.numAttributes(); k++) {
331        m_MeanN[v][k] = nxi.meanOrMode(k);
332        varN[v][k] = nxi.variance(k);
333      }
334      //Instances nxi =  nx.getInstances();
335
336      for (int w=0,t=0; w < m_Dimension; w++,t++){
337
338        //if((t==m_ClassIndex) || (t==m_IdIndex))
339        //  t++;       
340        if(varN[v][w] <= 0.0)
341          varN[v][w] = 0.0;
342        if(!Double.isNaN(m_MeanN[v][w])){
343          for(int u=0;u<nxi.numInstances();u++)
344            if(!nxi.instance(u).isMissing(t))
345              m_SumN[v][w] += nxi.instance(u).weight(); 
346
347          nMM[w] += m_MeanN[v][w]; 
348          nVM[w] += m_MeanN[v][w]*m_MeanN[v][w];
349          if((m_SumN[v][w]>1) && (varN[v][w]>ZERO)){                   
350            m_SgmSqN[w] += varN[v][w]*(m_SumN[v][w]-1.0)/m_SumN[v][w];
351            //m_SgmSqN[w] += varN[v][w]*(m_SumN[v][w]-1.0);
352            effNumExN[w]++; // Not count exemplars with 1 instance
353            nInvN[w] += 1.0/m_SumN[v][w];
354            //nInvN[w] += m_SumN[v][w];
355          }
356          else
357            numOneInsExsN[w]++;
358        }                                       
359      }
360    }
361
362    // Expected \sigma^2
363    /* if m_SgmSqP[u] or m_SgmSqN[u] is 0, assign 0 to sigma^2.
364     * Otherwise, may cause k m_SgmSqP / m_SgmSqN to be NaN.
365     * Modified by Lin Dong (Sep. 2005)
366     */
367    for (int u=0; u < m_Dimension; u++){
368      // For exemplars with only one instance, use avg(\sigma^2) of other exemplars
369      if (m_SgmSqP[u]!=0)
370        m_SgmSqP[u] /= (effNumExP[u]-pInvN[u]);
371      else
372        m_SgmSqP[u] = 0;
373      if (m_SgmSqN[u]!=0)
374        m_SgmSqN[u] /= (effNumExN[u]-nInvN[u]);
375      else
376        m_SgmSqN[u] = 0;
377
378      //m_SgmSqP[u] /= (pInvN[u]-effNumExP[u]);
379      //m_SgmSqN[u] /= (nInvN[u]-effNumExN[u]);
380      effNumExP[u] += numOneInsExsP[u];
381      effNumExN[u] += numOneInsExsN[u];
382      pMM[u] /= effNumExP[u];
383      nMM[u] /= effNumExN[u];
384      pVM[u] = pVM[u]/(effNumExP[u]-1.0) - pMM[u]*pMM[u]*effNumExP[u]/(effNumExP[u]-1.0);
385      nVM[u] = nVM[u]/(effNumExN[u]-1.0) - nMM[u]*nMM[u]*effNumExN[u]/(effNumExN[u]-1.0);
386    }
387
388    //Bounds and parameter values for each run
389    double[][] bounds = new double[2][2];
390    double[] pThisParam = new double[2], 
391      nThisParam = new double[2];
392
393    // Initial values for parameters
394    double w, m;
395    Random whichEx = new Random(m_Seed);
396
397    // Optimize for one dimension
398    for (int x=0; x < m_Dimension; x++){     
399      // System.out.println("\n\n!!!!!!!!!!!!!!!!!!!!!!???Dimension #"+x);
400
401      // Positive examplars: first run
402      pThisParam[0] = pVM[x];  // w
403      if( pThisParam[0] <= ZERO)
404        pThisParam[0] = 1.0;
405      pThisParam[1] = pMM[x];  // m
406
407      // Negative examplars: first run
408      nThisParam[0] = nVM[x];  // w
409      if(nThisParam[0] <= ZERO)
410        nThisParam[0] = 1.0;
411      nThisParam[1] = nMM[x];  // m
412
413      // Bound constraints
414      bounds[0][0] = ZERO; // w > 0
415      bounds[0][1] = Double.NaN;
416      bounds[1][0] = Double.NaN; 
417      bounds[1][1] = Double.NaN;
418
419      double pminVal=Double.MAX_VALUE, nminVal=Double.MAX_VALUE; 
420      TLDSimple_Optm pOp=null, nOp=null;       
421      boolean isRunValid = true;
422      double[] sumP=new double[pnum], meanP=new double[pnum];
423      double[] sumN=new double[nnum], meanN=new double[nnum];
424
425      // One dimension
426      for(int p=0; p<pnum; p++){
427        sumP[p] = m_SumP[p][x];
428        meanP[p] = m_MeanP[p][x];
429      }
430      for(int q=0; q<nnum; q++){
431        sumN[q] = m_SumN[q][x];
432        meanN[q] = m_MeanN[q][x];
433      }
434
435      for(int y=0; y<m_Run; y++){
436        //System.out.println("\n\n!!!!!!!!!Positive exemplars: Run #"+y);
437        double thisMin;
438        pOp = new TLDSimple_Optm();
439        pOp.setNum(sumP);
440        pOp.setSgmSq(m_SgmSqP[x]);
441        if (getDebug())
442          System.out.println("m_SgmSqP["+x+"]= " +m_SgmSqP[x]);
443        pOp.setXBar(meanP);
444        //pOp.setDebug(true);
445        pThisParam = pOp.findArgmin(pThisParam, bounds);
446        while(pThisParam==null){
447          pThisParam = pOp.getVarbValues();                 
448          if (getDebug())
449            System.out.println("!!! 200 iterations finished, not enough!");
450          pThisParam = pOp.findArgmin(pThisParam, bounds);
451        }       
452
453        thisMin = pOp.getMinFunction();
454        if(!Double.isNaN(thisMin) && (thisMin<pminVal)){
455          pminVal = thisMin;
456          for(int z=0; z<2; z++)
457            m_ParamsP[2*x+z] = pThisParam[z];
458        }
459
460        if(Double.isNaN(thisMin)){
461          pThisParam = new double[2];
462          isRunValid =false;
463        }
464        if(!isRunValid){ y--; isRunValid=true; } 
465
466        // Change the initial parameters and restart
467        int pone = whichEx.nextInt(pnum);
468
469        // Positive exemplars: next run
470        while(Double.isNaN(m_MeanP[pone][x]))
471          pone = whichEx.nextInt(pnum);
472
473        m = m_MeanP[pone][x];
474        w = (m-pThisParam[1])*(m-pThisParam[1]);
475        pThisParam[0] = w;  // w
476        pThisParam[1] = m;  // m           
477      }
478
479      for(int y=0; y<m_Run; y++){
480        //System.out.println("\n\n!!!!!!!!!Negative exemplars: Run #"+y);
481        double thisMin;
482        nOp = new TLDSimple_Optm();
483        nOp.setNum(sumN);
484        nOp.setSgmSq(m_SgmSqN[x]);
485        if (getDebug())
486          System.out.println(m_SgmSqN[x]);
487        nOp.setXBar(meanN);
488        //nOp.setDebug(true);
489        nThisParam = nOp.findArgmin(nThisParam, bounds);
490
491        while(nThisParam==null){       
492          nThisParam = nOp.getVarbValues();
493          if (getDebug())
494            System.out.println("!!! 200 iterations finished, not enough!");
495          nThisParam = nOp.findArgmin(nThisParam, bounds);
496        }                       
497
498        thisMin = nOp.getMinFunction(); 
499        if(!Double.isNaN(thisMin) && (thisMin<nminVal)){
500          nminVal = thisMin;
501          for(int z=0; z<2; z++)
502            m_ParamsN[2*x+z] = nThisParam[z];     
503        }
504
505        if(Double.isNaN(thisMin)){
506          nThisParam = new double[2];
507          isRunValid =false;
508        }
509
510        if(!isRunValid){ y--; isRunValid=true; }               
511
512        // Change the initial parameters and restart               
513        int none = whichEx.nextInt(nnum);// Randomly pick one pos. exmpl.
514
515        // Negative exemplars: next run
516        while(Double.isNaN(m_MeanN[none][x]))
517          none = whichEx.nextInt(nnum);
518
519        m = m_MeanN[none][x];
520        w = (m-nThisParam[1])*(m-nThisParam[1]);
521        nThisParam[0] = w;  // w
522        nThisParam[1] = m;  // m                       
523      }                             
524    }
525
526    m_LkRatio = new double[m_Dimension];
527
528    if(m_UseEmpiricalCutOff){   
529      // Find the empirical cut-off
530      double[] pLogOdds=new double[pnum], nLogOdds=new double[nnum]; 
531      for(int p=0; p<pnum; p++)
532        pLogOdds[p] = 
533          likelihoodRatio(m_SumP[p], m_MeanP[p]);
534
535      for(int q=0; q<nnum; q++)
536        nLogOdds[q] = 
537          likelihoodRatio(m_SumN[q], m_MeanN[q]);
538
539      // Update m_Cutoff
540      findCutOff(pLogOdds, nLogOdds);
541    }
542    else
543      m_Cutoff = -Math.log((double)pnum/(double)nnum);
544
545    /*
546       for(int x=0, y=0; x<m_Dimension; x++, y++){
547       if((x==exs.classIndex()) || (x==exs.idIndex()))
548       y++;
549
550       w=m_ParamsP[2*x]; m=m_ParamsP[2*x+1];
551       System.err.println("\n\n???Positive: ( "+exs.attribute(y)+
552       "):  w="+w+", m="+m+", sgmSq="+m_SgmSqP[x]);
553
554       w=m_ParamsN[2*x]; m=m_ParamsN[2*x+1];
555       System.err.println("???Negative: ("+exs.attribute(y)+
556       "):  w="+w+", m="+m+", sgmSq="+m_SgmSqN[x]+
557       "\nAvg. log-likelihood ratio in training data="
558       +(m_LkRatio[x]/(pnum+nnum)));
559       }       
560       */
561    if (getDebug())
562      System.err.println("\n\n???Cut-off="+m_Cutoff);
563  }       
564
565  /**
566   *
567   * @param ex the given test exemplar
568   * @return the classification
569   * @throws Exception if the exemplar could not be classified
570   * successfully
571   */
572  public double classifyInstance(Instance ex)throws Exception{
573    //Instance ex = new Exemplar(e);
574    Instances exi = ex.relationalValue(1);
575    double[] n = new double[m_Dimension];
576    double [] xBar = new double[m_Dimension];
577    for (int i=0; i<exi.numAttributes() ; i++)
578      xBar[i] = exi.meanOrMode(i);
579
580    for (int w=0, t=0; w < m_Dimension; w++, t++){
581      // if((t==m_ClassIndex) || (t==m_IdIndex))
582      //t++;   
583      for(int u=0;u<exi.numInstances();u++)
584        if(!exi.instance(u).isMissing(t))
585          n[w] += exi.instance(u).weight();
586    }
587
588    double logOdds = likelihoodRatio(n, xBar);
589    return (logOdds > m_Cutoff) ? 1 : 0 ;
590  }
591 
592  /**
593   * Computes the distribution for a given exemplar
594   *
595   * @param ex the exemplar for which distribution is computed
596   * @return the distribution
597   * @throws Exception if the distribution can't be computed successfully
598   */
599  public double[] distributionForInstance(Instance ex) throws Exception {
600   
601    double[] distribution = new double[2];
602    Instances exi = ex.relationalValue(1);
603    double[] n = new double[m_Dimension];
604    double[] xBar = new double[m_Dimension];
605    for (int i = 0; i < exi.numAttributes() ; i++)
606      xBar[i] = exi.meanOrMode(i);
607   
608    for (int w = 0, t = 0; w < m_Dimension; w++, t++){
609      for (int u = 0; u < exi.numInstances(); u++)
610        if (!exi.instance(u).isMissing(t))
611          n[w] += exi.instance(u).weight();
612    }
613   
614    double logOdds = likelihoodRatio(n, xBar);
615   
616    // returned logOdds value has been divided by m_Dimension to avoid
617    // Math.exp(logOdds) getting too large or too small,
618    // that may result in two fixed distribution value (1 or 0).
619    distribution[0] = 1 / (1 + Math.exp(logOdds)); // Prob. for class 0 (negative)
620    distribution[1] = 1 - distribution[0];
621   
622    return distribution;
623  }     
624
625  /**
626   * Compute the log-likelihood ratio
627   */
628  private double likelihoodRatio(double[] n, double[] xBar){   
629    double LLP = 0.0, LLN = 0.0;
630
631    for (int x=0; x<m_Dimension; x++){
632      if(Double.isNaN(xBar[x])) continue; // All missing values
633      //if(Double.isNaN(xBar[x]) || (m_ParamsP[2*x] <= ZERO)
634      //  || (m_ParamsN[2*x]<=ZERO))
635      //        continue; // All missing values
636
637      //Log-likelihood for positive
638      double w=m_ParamsP[2*x], m=m_ParamsP[2*x+1];
639      double llp = Math.log(w*n[x]+m_SgmSqP[x])
640        + n[x]*(m-xBar[x])*(m-xBar[x])/(w*n[x]+m_SgmSqP[x]);
641      LLP -= llp;
642
643      //Log-likelihood for negative
644      w=m_ParamsN[2*x]; m=m_ParamsN[2*x+1]; 
645      double lln = Math.log(w*n[x]+m_SgmSqN[x])
646        + n[x]*(m-xBar[x])*(m-xBar[x])/(w*n[x]+m_SgmSqN[x]);
647      LLN -= lln;
648
649      m_LkRatio[x] += llp - lln;
650    }
651
652    return LLP - LLN / m_Dimension;
653  }
654
655  private void findCutOff(double[] pos, double[] neg){
656    int[] pOrder = Utils.sort(pos),
657      nOrder = Utils.sort(neg);
658    /*
659       System.err.println("\n\n???Positive: ");
660       for(int t=0; t<pOrder.length; t++)
661       System.err.print(t+":"+Utils.doubleToString(pos[pOrder[t]],0,2)+" ");
662       System.err.println("\n\n???Negative: ");
663       for(int t=0; t<nOrder.length; t++)
664       System.err.print(t+":"+Utils.doubleToString(neg[nOrder[t]],0,2)+" ");
665       */
666    int pNum = pos.length, nNum = neg.length, count, p=0, n=0; 
667    double fstAccu=0.0, sndAccu=(double)pNum, split; 
668    double maxAccu = 0, minDistTo0 = Double.MAX_VALUE;
669
670    // Skip continuous negatives       
671    for(;(n<nNum)&&(pos[pOrder[0]]>=neg[nOrder[n]]); n++, fstAccu++);
672
673    if(n>=nNum){ // totally seperate
674      m_Cutoff = (neg[nOrder[nNum-1]]+pos[pOrder[0]])/2.0;     
675      //m_Cutoff = neg[nOrder[nNum-1]];
676      return; 
677    }   
678
679    count=n;
680    while((p<pNum)&&(n<nNum)){
681      // Compare the next in the two lists
682      if(pos[pOrder[p]]>=neg[nOrder[n]]){ // Neg has less log-odds
683        fstAccu += 1.0;   
684        split=neg[nOrder[n]];
685        n++;     
686      }
687      else{
688        sndAccu -= 1.0;
689        split=pos[pOrder[p]];
690        p++;
691      }           
692      count++;
693      /*
694         double entropy=0.0, cover=(double)count;
695         if(fstAccu>0.0)
696         entropy -= fstAccu*Math.log(fstAccu/cover);
697         if(sndAccu>0.0)
698         entropy -= sndAccu*Math.log(sndAccu/(total-cover));
699
700         if(entropy < minEntropy){
701         minEntropy = entropy;
702      //find the next smallest
703      //double next = neg[nOrder[n]];
704      //if(pos[pOrder[p]]<neg[nOrder[n]])
705      //    next = pos[pOrder[p]];     
706      //m_Cutoff = (split+next)/2.0;
707      m_Cutoff = split;
708         }
709         */
710      if ((fstAccu+sndAccu > maxAccu) || 
711          ((fstAccu+sndAccu == maxAccu) && (Math.abs(split)<minDistTo0))){
712        maxAccu = fstAccu+sndAccu;
713        m_Cutoff = split;
714        minDistTo0 = Math.abs(split);
715     }     
716    }           
717  }
718
719  /**
720   * Returns an enumeration describing the available options
721   *
722   * @return an enumeration of all the available options
723   */
724  public Enumeration listOptions() {
725    Vector result = new Vector();
726   
727    result.addElement(new Option(
728          "\tSet whether or not use empirical\n"
729          + "\tlog-odds cut-off instead of 0",
730          "C", 0, "-C"));
731   
732    result.addElement(new Option(
733          "\tSet the number of multiple runs \n"
734          + "\tneeded for searching the MLE.",
735          "R", 1, "-R <numOfRuns>"));
736   
737    Enumeration enu = super.listOptions();
738    while (enu.hasMoreElements()) {
739      result.addElement(enu.nextElement());
740    }
741
742    return result.elements();
743  }
744
745  /**
746   * Parses a given list of options. <p/>
747   *
748   <!-- options-start -->
749   * Valid options are: <p/>
750   *
751   * <pre> -C
752   *  Set whether or not use empirical
753   *  log-odds cut-off instead of 0</pre>
754   *
755   * <pre> -R &lt;numOfRuns&gt;
756   *  Set the number of multiple runs
757   *  needed for searching the MLE.</pre>
758   *
759   * <pre> -S &lt;num&gt;
760   *  Random number seed.
761   *  (default 1)</pre>
762   *
763   * <pre> -D
764   *  If set, classifier is run in debug mode and
765   *  may output additional info to the console</pre>
766   *
767   <!-- options-end -->
768   *
769   * @param options the list of options as an array of strings
770   * @throws Exception if an option is not supported
771   */
772  public void setOptions(String[] options) throws Exception{
773    setDebug(Utils.getFlag('D', options));
774
775    setUsingCutOff(Utils.getFlag('C', options));
776
777    String runString = Utils.getOption('R', options);
778    if (runString.length() != 0) 
779      setNumRuns(Integer.parseInt(runString));
780    else 
781      setNumRuns(1);
782
783    super.setOptions(options);
784  }
785
786  /**
787   * Gets the current settings of the Classifier.
788   *
789   * @return an array of strings suitable for passing to setOptions
790   */
791  public String[] getOptions() {
792    Vector        result;
793    String[]      options;
794    int           i;
795   
796    result  = new Vector();
797    options = super.getOptions();
798    for (i = 0; i < options.length; i++)
799      result.add(options[i]);
800
801    if (getDebug())
802      result.add("-D");
803   
804    if (getUsingCutOff())
805      result.add("-C");
806
807    result.add("-R");
808    result.add("" + getNumRuns());
809
810    return (String[]) result.toArray(new String[result.size()]);
811  }
812
813  /**
814   * Returns the tip text for this property
815   *
816   * @return tip text for this property suitable for
817   * displaying in the explorer/experimenter gui
818   */
819  public String numRunsTipText() {
820    return "The number of runs to perform.";
821  }
822 
823  /**
824   * Sets the number of runs to perform.
825   *
826   * @param numRuns   the number of runs to perform
827   */
828  public void setNumRuns(int numRuns) {
829    m_Run = numRuns;
830  }
831
832  /**
833   * Returns the number of runs to perform.
834   *
835   * @return          the number of runs to perform
836   */
837  public int getNumRuns() {
838    return m_Run;
839  }
840
841  /**
842   * Returns the tip text for this property
843   *
844   * @return tip text for this property suitable for
845   * displaying in the explorer/experimenter gui
846   */
847  public String usingCutOffTipText() {
848    return "Whether to use an empirical cutoff.";
849  }
850
851  /**
852   * Sets whether to use an empirical cutoff.
853   *
854   * @param cutOff      whether to use an empirical cutoff
855   */
856  public void setUsingCutOff (boolean cutOff) {
857    m_UseEmpiricalCutOff =cutOff;
858  }
859
860  /**
861   * Returns whether an empirical cutoff is used
862   *
863   * @return            true if an empirical cutoff is used
864   */
865  public boolean getUsingCutOff() {
866    return m_UseEmpiricalCutOff ;
867  }
868
869  /**
870   * Gets a string describing the classifier.
871   *
872   * @return a string describing the classifer built.
873   */
874  public String toString(){
875    StringBuffer text = new StringBuffer("\n\nTLDSimple:\n");
876    double sgm, w, m;
877    for (int x=0, y=0; x<m_Dimension; x++, y++){
878      // if((x==m_ClassIndex) || (x==m_IdIndex))
879      //y++;
880      sgm = m_SgmSqP[x];
881      w=m_ParamsP[2*x]; 
882      m=m_ParamsP[2*x+1];
883      text.append("\n"+m_Attribute.attribute(y).name()+"\nPositive: "+
884          "sigma^2="+sgm+", w="+w+", m="+m+"\n");
885      sgm = m_SgmSqN[x];
886      w=m_ParamsN[2*x]; 
887      m=m_ParamsN[2*x+1];
888      text.append("Negative: "+
889          "sigma^2="+sgm+", w="+w+", m="+m+"\n");
890    }
891
892    return text.toString();
893  }     
894 
895  /**
896   * Returns the revision string.
897   *
898   * @return            the revision
899   */
900  public String getRevision() {
901    return RevisionUtils.extract("$Revision: 5481 $");
902  }
903
904  /**
905   * Main method for testing.
906   *
907   * @param args the options for the classifier
908   */
909  public static void main(String[] args) {     
910    runClassifier(new TLDSimple(), args);
911  }
912}
913
914class TLDSimple_Optm extends Optimization {
915
916  private double[] num;
917  private double sSq;
918  private double[] xBar;
919
920  public void setNum(double[] n) {num = n;}
921  public void setSgmSq(double s){
922
923    sSq = s;
924  }
925  public void setXBar(double[] x){xBar = x;}
926
927  /**
928   * Implement this procedure to evaluate objective
929   * function to be minimized
930   */
931  protected double objectiveFunction(double[] x){
932    int numExs = num.length;
933    double NLL=0; // Negative Log-Likelihood
934
935    double w=x[0], m=x[1]; 
936    for(int j=0; j < numExs; j++){
937
938      if(Double.isNaN(xBar[j])) continue; // All missing values
939      double bag=0; 
940
941      bag += Math.log(w*num[j]+sSq);
942
943      if(Double.isNaN(bag) && m_Debug){
944        System.out.println("???????????1: "+w+" "+m
945            +"|x-: "+xBar[j] + 
946            "|n: "+num[j] + "|S^2: "+sSq);
947        //System.exit(1);
948      }
949
950      bag += num[j]*(m-xBar[j])*(m-xBar[j])/(w*num[j]+sSq);                 
951      if(Double.isNaN(bag) && m_Debug){
952        System.out.println("???????????2: "+w+" "+m
953            +"|x-: "+xBar[j] + 
954            "|n: "+num[j] + "|S^2: "+sSq);
955        //System.exit(1);
956      }               
957
958      //if(bag<0) bag=0;
959      NLL += bag;
960    }
961
962    //System.out.println("???????????NLL:"+NLL);
963    return NLL;
964  }
965
966  /**
967   * Subclass should implement this procedure to evaluate gradient
968   * of the objective function
969   */
970  protected double[] evaluateGradient(double[] x){
971    double[] g = new double[x.length];
972    int numExs = num.length;
973
974    double w=x[0],m=x[1];       
975    double dw=0.0, dm=0.0;
976
977    for(int j=0; j < numExs; j++){
978
979      if(Double.isNaN(xBar[j])) continue; // All missing values     
980      dw += num[j]/(w*num[j]+sSq) 
981        - num[j]*num[j]*(m-xBar[j])*(m-xBar[j])/((w*num[j]+sSq)*(w*num[j]+sSq));
982
983      dm += 2.0*num[j]*(m-xBar[j])/(w*num[j]+sSq);
984    }
985
986    g[0] = dw;
987    g[1] = dm;
988    return g;
989  }
990
991  /**
992   * Subclass should implement this procedure to evaluate second-order
993   * gradient of the objective function
994   */
995  protected double[] evaluateHessian(double[] x, int index){
996    double[] h = new double[x.length];
997
998    // # of exemplars, # of dimensions
999    // which dimension and which variable for 'index'
1000    int numExs = num.length;
1001    double w,m;
1002    // Take the 2nd-order derivative
1003    switch(index){     
1004      case 0: // w   
1005        w=x[0];m=x[1];
1006
1007        for(int j=0; j < numExs; j++){
1008          if(Double.isNaN(xBar[j])) continue; //All missing values
1009
1010          h[0] += 2.0*Math.pow(num[j],3)*(m-xBar[j])*(m-xBar[j])/Math.pow(w*num[j]+sSq,3)
1011            - num[j]*num[j]/((w*num[j]+sSq)*(w*num[j]+sSq));
1012
1013          h[1] -= 2.0*(m-xBar[j])*num[j]*num[j]/((num[j]*w+sSq)*(num[j]*w+sSq));               
1014        }
1015        break;
1016
1017      case 1: // m
1018        w=x[0];m=x[1];
1019
1020        for(int j=0; j < numExs; j++){
1021          if(Double.isNaN(xBar[j])) continue; //All missing values
1022
1023          h[0] -= 2.0*(m-xBar[j])*num[j]*num[j]/((num[j]*w+sSq)*(num[j]*w+sSq));
1024
1025          h[1] += 2.0*num[j]/(w*num[j]+sSq);                           
1026        }
1027    }
1028
1029    return h;
1030  }
1031 
1032  /**
1033   * Returns the revision string.
1034   *
1035   * @return            the revision
1036   */
1037  public String getRevision() {
1038    return RevisionUtils.extract("$Revision: 5481 $");
1039  }
1040}
Note: See TracBrowser for help on using the repository browser.