source: src/main/java/weka/classifiers/mi/TLD.java @ 8

Last change on this file since 8 was 4, checked in by gnappo, 14 years ago

Import di weka.

File size: 34.2 KB
Line 
1/*
2 *    This program is free software; you can redistribute it and/or modify
3 *    it under the terms of the GNU General Public License as published by
4 *    the Free Software Foundation; either version 2 of the License, or
5 *    (at your option) any later version.
6 *
7 *    This program is distributed in the hope that it will be useful,
8 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
9 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
10 *    GNU General Public License for more details.
11 *
12 *    You should have received a copy of the GNU General Public License
13 *    along with this program; if not, write to the Free Software
14 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
15 */
16
17/*
18 * TLD.java
19 * Copyright (C) 2005 University of Waikato, Hamilton, New Zealand
20 *
21 */
22
23package weka.classifiers.mi;
24
25import weka.classifiers.RandomizableClassifier;
26import weka.core.Capabilities;
27import weka.core.Instance;
28import weka.core.Instances;
29import weka.core.MultiInstanceCapabilitiesHandler;
30import weka.core.Optimization;
31import weka.core.Option;
32import weka.core.OptionHandler;
33import weka.core.RevisionUtils;
34import weka.core.TechnicalInformation;
35import weka.core.TechnicalInformationHandler;
36import weka.core.Utils;
37import weka.core.Capabilities.Capability;
38import weka.core.TechnicalInformation.Field;
39import weka.core.TechnicalInformation.Type;
40
41import java.util.Enumeration;
42import java.util.Random;
43import java.util.Vector;
44
45/**
46 <!-- globalinfo-start -->
47 * Two-Level Distribution approach, changes the starting value of the searching algorithm, supplement the cut-off modification and check missing values.<br/>
48 * <br/>
49 * For more information see:<br/>
50 * <br/>
51 * Xin Xu (2003). Statistical learning in multiple instance problem. Hamilton, NZ.
52 * <p/>
53 <!-- globalinfo-end -->
54 *
55 <!-- technical-bibtex-start -->
56 * BibTeX:
57 * <pre>
58 * &#64;mastersthesis{Xu2003,
59 *    address = {Hamilton, NZ},
60 *    author = {Xin Xu},
61 *    note = {0657.594},
62 *    school = {University of Waikato},
63 *    title = {Statistical learning in multiple instance problem},
64 *    year = {2003}
65 * }
66 * </pre>
67 * <p/>
68 <!-- technical-bibtex-end -->
69 *
70 <!-- options-start -->
71 * Valid options are: <p/>
72 *
73 * <pre> -C
74 *  Set whether or not use empirical
75 *  log-odds cut-off instead of 0</pre>
76 *
77 * <pre> -R &lt;numOfRuns&gt;
78 *  Set the number of multiple runs
79 *  needed for searching the MLE.</pre>
80 *
81 * <pre> -S &lt;num&gt;
82 *  Random number seed.
83 *  (default 1)</pre>
84 *
85 * <pre> -D
86 *  If set, classifier is run in debug mode and
87 *  may output additional info to the console</pre>
88 *
89 <!-- options-end -->
90 *
91 * @author Eibe Frank (eibe@cs.waikato.ac.nz)
92 * @author Xin Xu (xx5@cs.waikato.ac.nz)
93 * @version $Revision: 5481 $
94 */
95public class TLD 
96  extends RandomizableClassifier
97  implements OptionHandler, MultiInstanceCapabilitiesHandler,
98             TechnicalInformationHandler {
99
100  /** for serialization */
101  static final long serialVersionUID = 6657315525171152210L;
102 
103  /** The mean for each attribute of each positive exemplar */
104  protected double[][] m_MeanP = null;
105
106  /** The variance for each attribute of each positive exemplar */
107  protected double[][] m_VarianceP = null;
108
109  /** The mean for each attribute of each negative exemplar */
110  protected double[][] m_MeanN = null;
111
112  /** The variance for each attribute of each negative exemplar */
113  protected double[][] m_VarianceN = null;
114
115  /** The effective sum of weights of each positive exemplar in each dimension*/
116  protected double[][] m_SumP = null;
117
118  /** The effective sum of weights of each negative exemplar in each dimension*/
119  protected double[][] m_SumN = null;
120
121  /** The parameters to be estimated for each positive exemplar*/
122  protected double[] m_ParamsP = null;
123
124  /** The parameters to be estimated for each negative exemplar*/
125  protected double[] m_ParamsN = null;
126
127  /** The dimension of each exemplar, i.e. (numAttributes-2) */
128  protected int m_Dimension = 0;
129
130  /** The class label of each exemplar */
131  protected double[] m_Class = null;
132
133  /** The number of class labels in the data */
134  protected int m_NumClasses = 2;
135
136  /** The very small number representing zero */
137  static public double ZERO = 1.0e-6;   
138
139  /** The number of runs to perform */
140  protected int m_Run = 1;
141
142  protected double m_Cutoff;
143
144  protected boolean m_UseEmpiricalCutOff = false;   
145
146  /**
147   * Returns a string describing this filter
148   *
149   * @return a description of the filter suitable for
150   * displaying in the explorer/experimenter gui
151   */
152  public String globalInfo() {
153    return 
154        "Two-Level Distribution approach, changes the starting value of "
155      + "the searching algorithm, supplement the cut-off modification and "
156      + "check missing values.\n\n"
157      + "For more information see:\n\n"
158      + getTechnicalInformation().toString();
159  }
160 
161  /**
162   * Returns an instance of a TechnicalInformation object, containing
163   * detailed information about the technical background of this class,
164   * e.g., paper reference or book this class is based on.
165   *
166   * @return the technical information about this class
167   */
168  public TechnicalInformation getTechnicalInformation() {
169    TechnicalInformation        result;
170   
171    result = new TechnicalInformation(Type.MASTERSTHESIS);
172    result.setValue(Field.AUTHOR, "Xin Xu");
173    result.setValue(Field.YEAR, "2003");
174    result.setValue(Field.TITLE, "Statistical learning in multiple instance problem");
175    result.setValue(Field.SCHOOL, "University of Waikato");
176    result.setValue(Field.ADDRESS, "Hamilton, NZ");
177    result.setValue(Field.NOTE, "0657.594");
178   
179    return result;
180  }
181
182  /**
183   * Returns default capabilities of the classifier.
184   *
185   * @return      the capabilities of this classifier
186   */
187  public Capabilities getCapabilities() {
188    Capabilities result = super.getCapabilities();
189    result.disableAll();
190
191    // attributes
192    result.enable(Capability.NOMINAL_ATTRIBUTES);
193    result.enable(Capability.RELATIONAL_ATTRIBUTES);
194    result.enable(Capability.MISSING_VALUES);
195
196    // class
197    result.enable(Capability.BINARY_CLASS);
198    result.enable(Capability.MISSING_CLASS_VALUES);
199   
200    // other
201    result.enable(Capability.ONLY_MULTIINSTANCE);
202   
203    return result;
204  }
205
206  /**
207   * Returns the capabilities of this multi-instance classifier for the
208   * relational data.
209   *
210   * @return            the capabilities of this object
211   * @see               Capabilities
212   */
213  public Capabilities getMultiInstanceCapabilities() {
214    Capabilities result = super.getCapabilities();
215    result.disableAll();
216   
217    // attributes
218    result.enable(Capability.NUMERIC_ATTRIBUTES);
219    result.enable(Capability.MISSING_VALUES);
220
221    // class
222    result.disableAllClasses();
223    result.enable(Capability.NO_CLASS);
224   
225    return result;
226  }
227
228  /**
229   *
230   * @param exs the training exemplars
231   * @throws Exception if the model cannot be built properly
232   */   
233  public void buildClassifier(Instances exs)throws Exception{
234    // can classifier handle the data?
235    getCapabilities().testWithFail(exs);
236
237    // remove instances with missing class
238    exs = new Instances(exs);
239    exs.deleteWithMissingClass();
240   
241    int numegs = exs.numInstances();
242    m_Dimension = exs.attribute(1).relation(). numAttributes();
243    Instances pos = new Instances(exs, 0), neg = new Instances(exs, 0);
244
245    for(int u=0; u<numegs; u++){
246      Instance example = exs.instance(u);
247      if(example.classValue() == 1)
248        pos.add(example);
249      else
250        neg.add(example);
251    }
252
253    int pnum = pos.numInstances(), nnum = neg.numInstances();   
254
255    m_MeanP = new double[pnum][m_Dimension];
256    m_VarianceP = new double[pnum][m_Dimension];
257    m_SumP = new double[pnum][m_Dimension];
258    m_MeanN = new double[nnum][m_Dimension];
259    m_VarianceN = new double[nnum][m_Dimension];
260    m_SumN = new double[nnum][m_Dimension];
261    m_ParamsP = new double[4*m_Dimension];
262    m_ParamsN = new double[4*m_Dimension];
263
264    // Estimation of the parameters: as the start value for search
265    double[] pSumVal=new double[m_Dimension], // for m
266      nSumVal=new double[m_Dimension]; 
267    double[] maxVarsP=new double[m_Dimension], // for a
268      maxVarsN=new double[m_Dimension]; 
269    // Mean of sample variances: for b, b=a/E(\sigma^2)+2
270    double[] varMeanP = new double[m_Dimension],
271      varMeanN = new double[m_Dimension]; 
272    // Variances of sample means: for w, w=E[var(\mu)]/E[\sigma^2]
273    double[] meanVarP = new double[m_Dimension],
274      meanVarN = new double[m_Dimension];
275    // number of exemplars without all values missing
276    double[] numExsP = new double[m_Dimension],
277      numExsN = new double[m_Dimension];
278
279    // Extract metadata fro both positive and negative bags
280    for(int v=0; v < pnum; v++){
281      /*Exemplar px = pos.exemplar(v);
282        m_MeanP[v] = px.meanOrMode();
283        m_VarianceP[v] = px.variance();
284        Instances pxi =  px.getInstances();
285        */
286
287      Instances pxi =  pos.instance(v).relationalValue(1);
288      for (int k=0; k<pxi.numAttributes(); k++) { 
289        m_MeanP[v][k] = pxi.meanOrMode(k);
290        m_VarianceP[v][k] = pxi.variance(k);
291      }
292
293      for (int w=0,t=0; w < m_Dimension; w++,t++){             
294        //if((t==m_ClassIndex) || (t==m_IdIndex))
295        //  t++;               
296
297        if(!Double.isNaN(m_MeanP[v][w])){
298          for(int u=0;u<pxi.numInstances();u++){
299            Instance ins = pxi.instance(u);                     
300            if(!ins.isMissing(t))
301              m_SumP[v][w] += ins.weight();                       
302          }   
303          numExsP[w]++; 
304          pSumVal[w] += m_MeanP[v][w];
305          meanVarP[w] += m_MeanP[v][w]*m_MeanP[v][w];   
306          if(maxVarsP[w] < m_VarianceP[v][w])
307            maxVarsP[w] = m_VarianceP[v][w];
308          varMeanP[w] += m_VarianceP[v][w];
309          m_VarianceP[v][w] *= (m_SumP[v][w]-1.0);
310          if(m_VarianceP[v][w] < 0.0)
311            m_VarianceP[v][w] = 0.0;
312        }
313      }
314    }
315
316    for(int v=0; v < nnum; v++){
317      /*Exemplar nx = neg.exemplar(v);
318        m_MeanN[v] = nx.meanOrMode();
319        m_VarianceN[v] = nx.variance();
320        Instances nxi =  nx.getInstances();
321        */
322      Instances nxi =  neg.instance(v).relationalValue(1);
323      for (int k=0; k<nxi.numAttributes(); k++) {
324        m_MeanN[v][k] = nxi.meanOrMode(k);
325        m_VarianceN[v][k] = nxi.variance(k);
326      }
327
328      for (int w=0,t=0; w < m_Dimension; w++,t++){             
329        //if((t==m_ClassIndex) || (t==m_IdIndex))
330        //  t++;               
331
332        if(!Double.isNaN(m_MeanN[v][w])){
333          for(int u=0;u<nxi.numInstances();u++)
334            if(!nxi.instance(u).isMissing(t))
335              m_SumN[v][w] += nxi.instance(u).weight();
336          numExsN[w]++;         
337          nSumVal[w] += m_MeanN[v][w];
338          meanVarN[w] += m_MeanN[v][w]*m_MeanN[v][w]; 
339          if(maxVarsN[w] < m_VarianceN[v][w])
340            maxVarsN[w] = m_VarianceN[v][w];
341          varMeanN[w] += m_VarianceN[v][w];
342          m_VarianceN[v][w] *= (m_SumN[v][w]-1.0);
343          if(m_VarianceN[v][w] < 0.0)
344            m_VarianceN[v][w] = 0.0;
345        }
346      }
347    }
348
349    for(int w=0; w<m_Dimension; w++){
350      pSumVal[w] /= numExsP[w];
351      nSumVal[w] /= numExsN[w];
352      if(numExsP[w]>1)
353        meanVarP[w] = meanVarP[w]/(numExsP[w]-1.0) 
354          - pSumVal[w]*numExsP[w]/(numExsP[w]-1.0);
355      if(numExsN[w]>1)
356        meanVarN[w] = meanVarN[w]/(numExsN[w]-1.0) 
357          - nSumVal[w]*numExsN[w]/(numExsN[w]-1.0);
358      varMeanP[w] /= numExsP[w];
359      varMeanN[w] /= numExsN[w];
360    }
361
362    //Bounds and parameter values for each run
363    double[][] bounds = new double[2][4];
364    double[] pThisParam = new double[4], 
365      nThisParam = new double[4];
366
367    // Initial values for parameters
368    double a, b, w, m;
369
370    // Optimize for one dimension
371    for (int x=0; x < m_Dimension; x++){
372      if (getDebug())
373        System.err.println("\n\n!!!!!!!!!!!!!!!!!!!!!!???Dimension #"+x);
374
375      // Positive examplars: first run
376      a = (maxVarsP[x]>ZERO) ? maxVarsP[x]:1.0; 
377      if (varMeanP[x]<=ZERO)   varMeanP[x] = ZERO;  // modified by LinDong (09/2005)
378      b = a/varMeanP[x]+2.0; // a/(b-2) = E(\sigma^2)
379      w = meanVarP[x]/varMeanP[x]; // E[var(\mu)] = w*E[\sigma^2]           
380      if(w<=ZERO)  w=1.0;
381
382      m = pSumVal[x];     
383      pThisParam[0] = a;    // a
384      pThisParam[1] = b;  // b
385      pThisParam[2] = w;  // w
386      pThisParam[3] = m;  // m
387
388      // Negative examplars: first run
389      a = (maxVarsN[x]>ZERO) ? maxVarsN[x]:1.0; 
390      if (varMeanN[x]<=ZERO)   varMeanN[x] = ZERO; // modified by LinDong (09/2005)
391      b = a/varMeanN[x]+2.0; // a/(b-2) = E(\sigma^2)
392      w = meanVarN[x]/varMeanN[x]; // E[var(\mu)] = w*E[\sigma^2]           
393      if(w<=ZERO) w=1.0;
394
395      m = nSumVal[x];     
396      nThisParam[0] = a;    // a
397      nThisParam[1] = b;  // b
398      nThisParam[2] = w;  // w
399      nThisParam[3] = m;  // m
400
401      // Bound constraints
402      bounds[0][0] = ZERO; // a > 0
403      bounds[0][1] = 2.0+ZERO;  // b > 2
404      bounds[0][2] = ZERO; // w > 0
405      bounds[0][3] = Double.NaN;
406
407      for(int t=0; t<4; t++){
408        bounds[1][t] = Double.NaN;
409        m_ParamsP[4*x+t] = pThisParam[t];       
410        m_ParamsN[4*x+t] = nThisParam[t];
411      }
412      double pminVal=Double.MAX_VALUE, nminVal=Double.MAX_VALUE;
413      Random whichEx = new Random(m_Seed); 
414      TLD_Optm pOp=null, nOp=null;     
415      boolean isRunValid = true;
416      double[] sumP=new double[pnum], meanP=new double[pnum],
417        varP=new double[pnum];
418      double[] sumN=new double[nnum], meanN=new double[nnum],
419        varN=new double[nnum];
420
421      // One dimension
422      for(int p=0; p<pnum; p++){
423        sumP[p] = m_SumP[p][x];
424        meanP[p] = m_MeanP[p][x];
425        varP[p] = m_VarianceP[p][x];
426      }
427      for(int q=0; q<nnum; q++){
428        sumN[q] = m_SumN[q][x];
429        meanN[q] = m_MeanN[q][x];
430        varN[q] = m_VarianceN[q][x];
431      }
432
433      for(int y=0; y<m_Run;){
434        if (getDebug())
435          System.err.println("\n\n!!!!!!!!!!!!!!!!!!!!!!???Run #"+y);
436        double thisMin;
437
438        if (getDebug())
439          System.err.println("\nPositive exemplars");
440        pOp = new TLD_Optm();
441        pOp.setNum(sumP);
442        pOp.setSSquare(varP);
443        pOp.setXBar(meanP);
444
445        pThisParam = pOp.findArgmin(pThisParam, bounds);
446        while(pThisParam==null){
447          pThisParam = pOp.getVarbValues();                 
448          if (getDebug())
449            System.err.println("!!! 200 iterations finished, not enough!");
450          pThisParam = pOp.findArgmin(pThisParam, bounds);
451        }       
452
453        thisMin = pOp.getMinFunction();
454        if(!Double.isNaN(thisMin) && (thisMin<pminVal)){
455          pminVal = thisMin;
456          for(int z=0; z<4; z++)
457            m_ParamsP[4*x+z] = pThisParam[z];
458        }
459
460        if(Double.isNaN(thisMin)){
461          pThisParam = new double[4];
462          isRunValid =false;
463        }
464
465        if (getDebug())
466          System.err.println("\nNegative exemplars");
467        nOp = new TLD_Optm();
468        nOp.setNum(sumN);
469        nOp.setSSquare(varN);
470        nOp.setXBar(meanN);
471
472        nThisParam = nOp.findArgmin(nThisParam, bounds);
473        while(nThisParam==null){
474          nThisParam = nOp.getVarbValues();
475          if (getDebug())
476            System.err.println("!!! 200 iterations finished, not enough!");
477          nThisParam = nOp.findArgmin(nThisParam, bounds);
478        }       
479        thisMin = nOp.getMinFunction();
480        if(!Double.isNaN(thisMin) && (thisMin<nminVal)){
481          nminVal = thisMin;
482          for(int z=0; z<4; z++)
483            m_ParamsN[4*x+z] = nThisParam[z];     
484        }
485
486        if(Double.isNaN(thisMin)){
487          nThisParam = new double[4];
488          isRunValid =false;
489        }
490
491        if(!isRunValid){ y--; isRunValid=true; }               
492
493        if(++y<m_Run){
494          // Change the initial parameters and restart             
495          int pone = whichEx.nextInt(pnum), // Randomly pick one pos. exmpl.
496              none = whichEx.nextInt(nnum);
497
498          // Positive exemplars: next run
499          while((m_SumP[pone][x]<=1.0)||Double.isNaN(m_MeanP[pone][x]))
500            pone = whichEx.nextInt(pnum);
501
502          a = m_VarianceP[pone][x]/(m_SumP[pone][x]-1.0);               
503          if(a<=ZERO) a=m_ParamsN[4*x]; // Change to negative params
504          m = m_MeanP[pone][x];
505          double sq = (m-m_ParamsP[4*x+3])*(m-m_ParamsP[4*x+3]);
506
507          b = a*m_ParamsP[4*x+2]/sq+2.0; // b=a/Var+2, assuming Var=Sq/w'
508          if((b<=ZERO) || Double.isNaN(b) || Double.isInfinite(b))
509            b=m_ParamsN[4*x+1];
510
511          w = sq*(m_ParamsP[4*x+1]-2.0)/m_ParamsP[4*x];//w=Sq/Var, assuming Var=a'/(b'-2)
512          if((w<=ZERO) || Double.isNaN(w) || Double.isInfinite(w))
513            w=m_ParamsN[4*x+2];
514
515          pThisParam[0] = a;    // a
516          pThisParam[1] = b;  // b
517          pThisParam[2] = w;  // w
518          pThisParam[3] = m;  // m         
519
520          // Negative exemplars: next run
521          while((m_SumN[none][x]<=1.0)||Double.isNaN(m_MeanN[none][x]))
522            none = whichEx.nextInt(nnum);           
523
524          a = m_VarianceN[none][x]/(m_SumN[none][x]-1.0);       
525          if(a<=ZERO) a=m_ParamsP[4*x];       
526          m = m_MeanN[none][x];
527          sq = (m-m_ParamsN[4*x+3])*(m-m_ParamsN[4*x+3]);
528
529          b = a*m_ParamsN[4*x+2]/sq+2.0; // b=a/Var+2, assuming Var=Sq/w'
530          if((b<=ZERO) || Double.isNaN(b) || Double.isInfinite(b))
531            b=m_ParamsP[4*x+1];
532
533          w = sq*(m_ParamsN[4*x+1]-2.0)/m_ParamsN[4*x];//w=Sq/Var, assuming Var=a'/(b'-2)
534          if((w<=ZERO) || Double.isNaN(w) || Double.isInfinite(w))
535            w=m_ParamsP[4*x+2];
536
537          nThisParam[0] = a;    // a
538          nThisParam[1] = b;  // b
539          nThisParam[2] = w;  // w
540          nThisParam[3] = m;  // m                     
541        }
542      }                     
543    }
544
545    for (int x=0, y=0; x<m_Dimension; x++, y++){
546      //if((x==exs.classIndex()) || (x==exs.idIndex()))
547      //y++;
548      a=m_ParamsP[4*x]; b=m_ParamsP[4*x+1]; 
549      w=m_ParamsP[4*x+2]; m=m_ParamsP[4*x+3];
550      if (getDebug())
551        System.err.println("\n\n???Positive: ( "+exs.attribute(1).relation().attribute(y)+
552          "): a="+a+", b="+b+", w="+w+", m="+m);
553
554      a=m_ParamsN[4*x]; b=m_ParamsN[4*x+1]; 
555      w=m_ParamsN[4*x+2]; m=m_ParamsN[4*x+3];
556      if (getDebug())
557        System.err.println("???Negative: ("+exs.attribute(1).relation().attribute(y)+
558          "): a="+a+", b="+b+", w="+w+", m="+m);
559    }
560
561    if(m_UseEmpiricalCutOff){   
562      // Find the empirical cut-off
563      double[] pLogOdds=new double[pnum], nLogOdds=new double[nnum]; 
564      for(int p=0; p<pnum; p++)
565        pLogOdds[p] = 
566          likelihoodRatio(m_SumP[p], m_MeanP[p], m_VarianceP[p]);
567
568      for(int q=0; q<nnum; q++)
569        nLogOdds[q] = 
570          likelihoodRatio(m_SumN[q], m_MeanN[q], m_VarianceN[q]);
571
572      // Update m_Cutoff
573      findCutOff(pLogOdds, nLogOdds);
574    }
575    else
576      m_Cutoff = -Math.log((double)pnum/(double)nnum);
577
578    if (getDebug())
579      System.err.println("???Cut-off="+m_Cutoff);
580  }       
581
582  /**
583   *
584   * @param ex the given test exemplar
585   * @return the classification
586   * @throws Exception if the exemplar could not be classified
587   * successfully
588   */
589  public double classifyInstance(Instance ex)throws Exception{
590    //Exemplar ex = new Exemplar(e);
591    Instances exi = ex.relationalValue(1);
592    double[] n = new double[m_Dimension];
593    double [] xBar = new double[m_Dimension];
594    double [] sSq = new double[m_Dimension];
595    for (int i=0; i<exi.numAttributes() ; i++){
596      xBar[i] = exi.meanOrMode(i);
597      sSq[i] = exi.variance(i);
598    }
599
600    for (int w=0, t=0; w < m_Dimension; w++, t++){
601      //if((t==m_ClassIndex) || (t==m_IdIndex))
602      //t++;   
603      for(int u=0;u<exi.numInstances();u++)
604        if(!exi.instance(u).isMissing(t))
605          n[w] += exi.instance(u).weight();
606
607      sSq[w] = sSq[w]*(n[w]-1.0);
608      if(sSq[w] <= 0.0)
609        sSq[w] = 0.0;
610    }
611
612    double logOdds = likelihoodRatio(n, xBar, sSq);
613    return (logOdds > m_Cutoff) ? 1 : 0 ;
614  }
615
616  private double likelihoodRatio(double[] n, double[] xBar, double[] sSq){     
617    double LLP = 0.0, LLN = 0.0;
618
619    for (int x=0; x<m_Dimension; x++){
620      if(Double.isNaN(xBar[x])) continue; // All missing values
621
622      int halfN = ((int)n[x])/2;       
623      //Log-likelihood for positive
624      double a=m_ParamsP[4*x], b=m_ParamsP[4*x+1], 
625             w=m_ParamsP[4*x+2], m=m_ParamsP[4*x+3];
626      LLP += 0.5*b*Math.log(a) + 0.5*(b+n[x]-1.0)*Math.log(1.0+n[x]*w)
627        - 0.5*(b+n[x])*Math.log((1.0+n[x]*w)*(a+sSq[x])+
628            n[x]*(xBar[x]-m)*(xBar[x]-m))
629        - 0.5*n[x]*Math.log(Math.PI);
630      for(int y=1; y<=halfN; y++)
631        LLP += Math.log(b/2.0+n[x]/2.0-(double)y);
632
633      if(n[x]/2.0 > halfN) // n is odd
634        LLP += TLD_Optm.diffLnGamma(b/2.0);
635
636      //Log-likelihood for negative
637      a=m_ParamsN[4*x];
638      b=m_ParamsN[4*x+1]; 
639      w=m_ParamsN[4*x+2];
640      m=m_ParamsN[4*x+3];
641      LLN += 0.5*b*Math.log(a) + 0.5*(b+n[x]-1.0)*Math.log(1.0+n[x]*w)
642        - 0.5*(b+n[x])*Math.log((1.0+n[x]*w)*(a+sSq[x])+
643            n[x]*(xBar[x]-m)*(xBar[x]-m))
644        - 0.5*n[x]*Math.log(Math.PI);
645      for(int y=1; y<=halfN; y++)
646        LLN += Math.log(b/2.0+n[x]/2.0-(double)y);     
647
648      if(n[x]/2.0 > halfN) // n is odd
649        LLN += TLD_Optm.diffLnGamma(b/2.0);   
650    }
651
652    return LLP - LLN;
653  }
654
655  private void findCutOff(double[] pos, double[] neg){
656    int[] pOrder = Utils.sort(pos),
657      nOrder = Utils.sort(neg);
658    /*
659       System.err.println("\n\n???Positive: ");
660       for(int t=0; t<pOrder.length; t++)
661       System.err.print(t+":"+Utils.doubleToString(pos[pOrder[t]],0,2)+" ");
662       System.err.println("\n\n???Negative: ");
663       for(int t=0; t<nOrder.length; t++)
664       System.err.print(t+":"+Utils.doubleToString(neg[nOrder[t]],0,2)+" ");
665       */
666    int pNum = pos.length, nNum = neg.length, count, p=0, n=0; 
667    double fstAccu=0.0, sndAccu=(double)pNum, split; 
668    double maxAccu = 0, minDistTo0 = Double.MAX_VALUE;
669
670    // Skip continuous negatives       
671    for(;(n<nNum)&&(pos[pOrder[0]]>=neg[nOrder[n]]); n++, fstAccu++);
672
673    if(n>=nNum){ // totally seperate
674      m_Cutoff = (neg[nOrder[nNum-1]]+pos[pOrder[0]])/2.0;     
675      //m_Cutoff = neg[nOrder[nNum-1]];
676      return; 
677    }   
678
679    count=n;
680    while((p<pNum)&&(n<nNum)){
681      // Compare the next in the two lists
682      if(pos[pOrder[p]]>=neg[nOrder[n]]){ // Neg has less log-odds
683        fstAccu += 1.0;   
684        split=neg[nOrder[n]];
685        n++;     
686      }
687      else{
688        sndAccu -= 1.0;
689        split=pos[pOrder[p]];
690        p++;
691      }           
692      count++;
693      if((fstAccu+sndAccu > maxAccu) 
694          || ((fstAccu+sndAccu == maxAccu) && (Math.abs(split)<minDistTo0))){
695        maxAccu = fstAccu+sndAccu;
696        m_Cutoff = split;
697        minDistTo0 = Math.abs(split);
698      }     
699    }           
700  }
701
702  /**
703   * Returns an enumeration describing the available options
704   *
705   * @return an enumeration of all the available options
706   */
707  public Enumeration listOptions() {
708    Vector result = new Vector();
709   
710    result.addElement(new Option(
711          "\tSet whether or not use empirical\n"
712          + "\tlog-odds cut-off instead of 0",
713          "C", 0, "-C"));
714   
715    result.addElement(new Option(
716          "\tSet the number of multiple runs \n"
717          + "\tneeded for searching the MLE.",
718          "R", 1, "-R <numOfRuns>"));
719   
720    Enumeration enu = super.listOptions();
721    while (enu.hasMoreElements()) {
722      result.addElement(enu.nextElement());
723    }
724
725    return result.elements();
726  }
727
728  /**
729   * Parses a given list of options. <p/>
730   *
731   <!-- options-start -->
732   * Valid options are: <p/>
733   *
734   * <pre> -C
735   *  Set whether or not use empirical
736   *  log-odds cut-off instead of 0</pre>
737   *
738   * <pre> -R &lt;numOfRuns&gt;
739   *  Set the number of multiple runs
740   *  needed for searching the MLE.</pre>
741   *
742   * <pre> -S &lt;num&gt;
743   *  Random number seed.
744   *  (default 1)</pre>
745   *
746   * <pre> -D
747   *  If set, classifier is run in debug mode and
748   *  may output additional info to the console</pre>
749   *
750   <!-- options-end -->
751   *
752   * @param options the list of options as an array of strings
753   * @throws Exception if an option is not supported
754   */
755  public void setOptions(String[] options) throws Exception{
756    setDebug(Utils.getFlag('D', options));
757
758    setUsingCutOff(Utils.getFlag('C', options));
759
760    String runString = Utils.getOption('R', options);
761    if (runString.length() != 0) 
762      setNumRuns(Integer.parseInt(runString));
763    else 
764      setNumRuns(1);
765
766    super.setOptions(options);
767  }
768
769  /**
770   * Gets the current settings of the Classifier.
771   *
772   * @return an array of strings suitable for passing to setOptions
773   */
774  public String[] getOptions() {
775    Vector        result;
776    String[]      options;
777    int           i;
778   
779    result  = new Vector();
780    options = super.getOptions();
781    for (i = 0; i < options.length; i++)
782      result.add(options[i]);
783
784    if (getDebug())
785      result.add("-D");
786   
787    if (getUsingCutOff())
788      result.add("-C");
789
790    result.add("-R");
791    result.add("" + getNumRuns());
792
793    return (String[]) result.toArray(new String[result.size()]);
794  }
795
796  /**
797   * Returns the tip text for this property
798   *
799   * @return tip text for this property suitable for
800   * displaying in the explorer/experimenter gui
801   */
802  public String numRunsTipText() {
803    return "The number of runs to perform.";
804  }
805
806  /**
807   * Sets the number of runs to perform.
808   *
809   * @param numRuns   the number of runs to perform
810   */
811  public void setNumRuns(int numRuns) {
812    m_Run = numRuns;
813  }
814
815  /**
816   * Returns the number of runs to perform.
817   *
818   * @return          the number of runs to perform
819   */
820  public int getNumRuns() {
821    return m_Run;
822  }
823
824  /**
825   * Returns the tip text for this property
826   *
827   * @return tip text for this property suitable for
828   * displaying in the explorer/experimenter gui
829   */
830  public String usingCutOffTipText() {
831    return "Whether to use an empirical cutoff.";
832  }
833
834  /**
835   * Sets whether to use an empirical cutoff.
836   *
837   * @param cutOff      whether to use an empirical cutoff
838   */
839  public void setUsingCutOff (boolean cutOff) {
840    m_UseEmpiricalCutOff = cutOff;
841  }
842
843  /**
844   * Returns whether an empirical cutoff is used
845   *
846   * @return            true if an empirical cutoff is used
847   */
848  public boolean getUsingCutOff() {
849    return m_UseEmpiricalCutOff;
850  }
851 
852  /**
853   * Returns the revision string.
854   *
855   * @return            the revision
856   */
857  public String getRevision() {
858    return RevisionUtils.extract("$Revision: 5481 $");
859  }
860
861  /**
862   * Main method for testing.
863   *
864   * @param args the options for the classifier
865   */
866  public static void main(String[] args) {     
867    runClassifier(new TLD(), args);
868  }
869}
870
871class TLD_Optm extends Optimization {
872
873  private double[] num;
874  private double[] sSq;
875  private double[] xBar;
876
877  public void setNum(double[] n) {num = n;}
878  public void setSSquare(double[] s){sSq = s;}
879  public void setXBar(double[] x){xBar = x;}
880
881  /**
882   * Compute Ln[Gamma(b+0.5)] - Ln[Gamma(b)]
883   *
884   * @param b the value in the above formula
885   * @return the result
886   */   
887  public static double diffLnGamma(double b){
888    double[] coef= {76.18009172947146, -86.50532032941677,
889      24.01409824083091, -1.231739572450155, 
890      0.1208650973866179e-2, -0.5395239384953e-5};
891    double rt = -0.5;
892    rt += (b+1.0)*Math.log(b+6.0) - (b+0.5)*Math.log(b+5.5);
893    double series1=1.000000000190015, series2=1.000000000190015;
894    for(int i=0; i<6; i++){
895      series1 += coef[i]/(b+1.5+(double)i);
896      series2 += coef[i]/(b+1.0+(double)i);
897    }
898
899    rt += Math.log(series1*b)-Math.log(series2*(b+0.5));
900    return rt;
901  }
902
903  /**
904   * Compute dLn[Gamma(x+0.5)]/dx - dLn[Gamma(x)]/dx
905   *
906   * @param x the value in the above formula
907   * @return the result
908   */   
909  protected double diffFstDervLnGamma(double x){
910    double rt=0, series=1.0;// Just make it >0
911    for(int i=0;series>=m_Zero*1e-3;i++){
912      series = 0.5/((x+(double)i)*(x+(double)i+0.5));
913      rt += series;
914    }
915    return rt;
916  }
917
918  /**
919   * Compute {Ln[Gamma(x+0.5)]}'' - {Ln[Gamma(x)]}''
920   *
921   * @param x the value in the above formula
922   * @return the result
923   */   
924  protected double diffSndDervLnGamma(double x){
925    double rt=0, series=1.0;// Just make it >0
926    for(int i=0;series>=m_Zero*1e-3;i++){
927      series = (x+(double)i+0.25)/
928        ((x+(double)i)*(x+(double)i)*(x+(double)i+0.5)*(x+(double)i+0.5));
929      rt -= series;
930    }
931    return rt;
932  }
933
934  /**
935   * Implement this procedure to evaluate objective
936   * function to be minimized
937   */
938  protected double objectiveFunction(double[] x){
939    int numExs = num.length;
940    double NLL = 0; // Negative Log-Likelihood
941
942    double a=x[0], b=x[1], w=x[2], m=x[3];
943    for(int j=0; j < numExs; j++){
944
945      if(Double.isNaN(xBar[j])) continue; // All missing values
946
947      NLL += 0.5*(b+num[j])*
948        Math.log((1.0+num[j]*w)*(a+sSq[j]) + 
949            num[j]*(xBar[j]-m)*(xBar[j]-m));       
950
951      if(Double.isNaN(NLL) && m_Debug){
952        System.err.println("???????????1: "+a+" "+b+" "+w+" "+m
953            +"|x-: "+xBar[j] + 
954            "|n: "+num[j] + "|S^2: "+sSq[j]);
955        System.exit(1);
956      }
957
958      // Doesn't affect optimization
959      //NLL += 0.5*num[j]*Math.log(Math.PI);           
960
961      NLL -= 0.5*(b+num[j]-1.0)*Math.log(1.0+num[j]*w);
962
963
964      if(Double.isNaN(NLL) && m_Debug){
965        System.err.println("???????????2: "+a+" "+b+" "+w+" "+m
966            +"|x-: "+xBar[j] + 
967            "|n: "+num[j] + "|S^2: "+sSq[j]);
968        System.exit(1);
969      }
970
971      int halfNum = ((int)num[j])/2;
972      for(int z=1; z<=halfNum; z++)
973        NLL -= Math.log(0.5*b+0.5*num[j]-(double)z);
974
975      if(0.5*num[j] > halfNum) // num[j] is odd
976        NLL -= diffLnGamma(0.5*b);
977
978      if(Double.isNaN(NLL) && m_Debug){
979        System.err.println("???????????3: "+a+" "+b+" "+w+" "+m
980            +"|x-: "+xBar[j] + 
981            "|n: "+num[j] + "|S^2: "+sSq[j]);
982        System.exit(1);
983      }                         
984
985      NLL -= 0.5*Math.log(a)*b;
986      if(Double.isNaN(NLL) && m_Debug){
987        System.err.println("???????????4:"+a+" "+b+" "+w+" "+m);
988        System.exit(1);
989      }     
990    }
991    if(m_Debug)
992      System.err.println("?????????????5: "+NLL);
993    if(Double.isNaN(NLL))         
994      System.exit(1);
995
996    return NLL;
997  }
998
999  /**
1000   * Subclass should implement this procedure to evaluate gradient
1001   * of the objective function
1002   */
1003  protected double[] evaluateGradient(double[] x){
1004    double[] g = new double[x.length];
1005    int numExs = num.length;
1006
1007    double a=x[0],b=x[1],w=x[2],m=x[3];
1008
1009    double da=0.0, db=0.0, dw=0.0, dm=0.0; 
1010    for(int j=0; j < numExs; j++){
1011
1012      if(Double.isNaN(xBar[j])) continue; // All missing values
1013
1014      double denorm = (1.0+num[j]*w)*(a+sSq[j]) + 
1015        num[j]*(xBar[j]-m)*(xBar[j]-m);
1016
1017      da += 0.5*(b+num[j])*(1.0+num[j]*w)/denorm-0.5*b/a;
1018
1019      db += 0.5*Math.log(denorm) 
1020        - 0.5*Math.log(1.0+num[j]*w)
1021        - 0.5*Math.log(a);
1022
1023      int halfNum = ((int)num[j])/2;
1024      for(int z=1; z<=halfNum; z++)
1025        db -= 1.0/(b+num[j]-2.0*(double)z);             
1026      if(num[j]/2.0 > halfNum) // num[j] is odd
1027        db -= 0.5*diffFstDervLnGamma(0.5*b);           
1028
1029      dw += 0.5*(b+num[j])*(a+sSq[j])*num[j]/denorm -
1030        0.5*(b+num[j]-1.0)*num[j]/(1.0+num[j]*w);
1031
1032      dm += num[j]*(b+num[j])*(m-xBar[j])/denorm;
1033    }
1034
1035    g[0] = da;
1036    g[1] = db;
1037    g[2] = dw;
1038    g[3] = dm;
1039    return g;
1040  }
1041
1042  /**
1043   * Subclass should implement this procedure to evaluate second-order
1044   * gradient of the objective function
1045   */
1046  protected double[] evaluateHessian(double[] x, int index){
1047    double[] h = new double[x.length];
1048
1049    // # of exemplars, # of dimensions
1050    // which dimension and which variable for 'index'
1051    int numExs = num.length;
1052    double a,b,w,m;
1053    // Take the 2nd-order derivative
1054    switch(index){
1055      case 0:  // a       
1056        a=x[0];b=x[1];w=x[2];m=x[3];
1057
1058        for(int j=0; j < numExs; j++){
1059          if(Double.isNaN(xBar[j])) continue; //All missing values
1060          double denorm = (1.0+num[j]*w)*(a+sSq[j]) + 
1061            num[j]*(xBar[j]-m)*(xBar[j]-m);
1062
1063          h[0] += 0.5*b/(a*a) 
1064            - 0.5*(b+num[j])*(1.0+num[j]*w)*(1.0+num[j]*w)
1065            /(denorm*denorm);
1066
1067          h[1] += 0.5*(1.0+num[j]*w)/denorm - 0.5/a;
1068
1069          h[2] += 0.5*num[j]*num[j]*(b+num[j])*
1070            (xBar[j]-m)*(xBar[j]-m)/(denorm*denorm);
1071
1072          h[3] -= num[j]*(b+num[j])*(m-xBar[j])
1073            *(1.0+num[j]*w)/(denorm*denorm);
1074        }
1075        break;
1076
1077      case 1: // b     
1078        a=x[0];b=x[1];w=x[2];m=x[3];
1079
1080        for(int j=0; j < numExs; j++){
1081          if(Double.isNaN(xBar[j])) continue; //All missing values
1082          double denorm = (1.0+num[j]*w)*(a+sSq[j]) + 
1083            num[j]*(xBar[j]-m)*(xBar[j]-m);
1084
1085          h[0] += 0.5*(1.0+num[j]*w)/denorm - 0.5/a;
1086
1087          int halfNum = ((int)num[j])/2;
1088          for(int z=1; z<=halfNum; z++)
1089            h[1] += 
1090              1.0/((b+num[j]-2.0*(double)z)*(b+num[j]-2.0*(double)z));
1091          if(num[j]/2.0 > halfNum) // num[j] is odd
1092            h[1] -= 0.25*diffSndDervLnGamma(0.5*b); 
1093
1094          h[2] += 0.5*(a+sSq[j])*num[j]/denorm -
1095            0.5*num[j]/(1.0+num[j]*w);
1096
1097          h[3] += num[j]*(m-xBar[j])/denorm;
1098        }
1099        break;
1100
1101      case 2: // w   
1102        a=x[0];b=x[1];w=x[2];m=x[3];
1103
1104        for(int j=0; j < numExs; j++){
1105          if(Double.isNaN(xBar[j])) continue; //All missing values
1106          double denorm = (1.0+num[j]*w)*(a+sSq[j]) + 
1107            num[j]*(xBar[j]-m)*(xBar[j]-m);
1108
1109          h[0] += 0.5*num[j]*num[j]*(b+num[j])*
1110            (xBar[j]-m)*(xBar[j]-m)/(denorm*denorm);
1111
1112          h[1] += 0.5*(a+sSq[j])*num[j]/denorm -
1113            0.5*num[j]/(1.0+num[j]*w);
1114
1115          h[2] += 0.5*(b+num[j]-1.0)*num[j]*num[j]/
1116            ((1.0+num[j]*w)*(1.0+num[j]*w)) -
1117            0.5*(b+num[j])*(a+sSq[j])*(a+sSq[j])*
1118            num[j]*num[j]/(denorm*denorm);
1119
1120          h[3] -= num[j]*num[j]*(b+num[j])*
1121            (m-xBar[j])*(a+sSq[j])/(denorm*denorm);
1122        }
1123        break;
1124
1125      case 3: // m
1126        a=x[0];b=x[1];w=x[2];m=x[3];
1127
1128        for(int j=0; j < numExs; j++){
1129          if(Double.isNaN(xBar[j])) continue; //All missing values
1130          double denorm = (1.0+num[j]*w)*(a+sSq[j]) + 
1131            num[j]*(xBar[j]-m)*(xBar[j]-m);
1132
1133          h[0] -= num[j]*(b+num[j])*(m-xBar[j])
1134            *(1.0+num[j]*w)/(denorm*denorm);
1135
1136          h[1] += num[j]*(m-xBar[j])/denorm;
1137
1138          h[2] -= num[j]*num[j]*(b+num[j])*
1139            (m-xBar[j])*(a+sSq[j])/(denorm*denorm);
1140
1141          h[3] += num[j]*(b+num[j])*
1142            ((1.0+num[j]*w)*(a+sSq[j])-
1143             num[j]*(m-xBar[j])*(m-xBar[j]))
1144            /(denorm*denorm);
1145        }
1146    }
1147
1148    return h;
1149  }
1150 
1151  /**
1152   * Returns the revision string.
1153   *
1154   * @return            the revision
1155   */
1156  public String getRevision() {
1157    return RevisionUtils.extract("$Revision: 5481 $");
1158  }
1159}
Note: See TracBrowser for help on using the repository browser.