[4] | 1 | /* |
---|
| 2 | * This program is free software; you can redistribute it and/or modify |
---|
| 3 | * it under the terms of the GNU General Public License as published by |
---|
| 4 | * the Free Software Foundation; either version 2 of the License, or |
---|
| 5 | * (at your option) any later version. |
---|
| 6 | * |
---|
| 7 | * This program is distributed in the hope that it will be useful, |
---|
| 8 | * but WITHOUT ANY WARRANTY; without even the implied warranty of |
---|
| 9 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the |
---|
| 10 | * GNU General Public License for more details. |
---|
| 11 | * |
---|
| 12 | * You should have received a copy of the GNU General Public License |
---|
| 13 | * along with this program; if not, write to the Free Software |
---|
| 14 | * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA. |
---|
| 15 | */ |
---|
| 16 | |
---|
| 17 | /* |
---|
| 18 | * TLD.java |
---|
| 19 | * Copyright (C) 2005 University of Waikato, Hamilton, New Zealand |
---|
| 20 | * |
---|
| 21 | */ |
---|
| 22 | |
---|
| 23 | package weka.classifiers.mi; |
---|
| 24 | |
---|
| 25 | import weka.classifiers.RandomizableClassifier; |
---|
| 26 | import weka.core.Capabilities; |
---|
| 27 | import weka.core.Instance; |
---|
| 28 | import weka.core.Instances; |
---|
| 29 | import weka.core.MultiInstanceCapabilitiesHandler; |
---|
| 30 | import weka.core.Optimization; |
---|
| 31 | import weka.core.Option; |
---|
| 32 | import weka.core.OptionHandler; |
---|
| 33 | import weka.core.RevisionUtils; |
---|
| 34 | import weka.core.TechnicalInformation; |
---|
| 35 | import weka.core.TechnicalInformationHandler; |
---|
| 36 | import weka.core.Utils; |
---|
| 37 | import weka.core.Capabilities.Capability; |
---|
| 38 | import weka.core.TechnicalInformation.Field; |
---|
| 39 | import weka.core.TechnicalInformation.Type; |
---|
| 40 | |
---|
| 41 | import java.util.Enumeration; |
---|
| 42 | import java.util.Random; |
---|
| 43 | import java.util.Vector; |
---|
| 44 | |
---|
| 45 | /** |
---|
| 46 | <!-- globalinfo-start --> |
---|
| 47 | * Two-Level Distribution approach, changes the starting value of the searching algorithm, supplement the cut-off modification and check missing values.<br/> |
---|
| 48 | * <br/> |
---|
| 49 | * For more information see:<br/> |
---|
| 50 | * <br/> |
---|
| 51 | * Xin Xu (2003). Statistical learning in multiple instance problem. Hamilton, NZ. |
---|
| 52 | * <p/> |
---|
| 53 | <!-- globalinfo-end --> |
---|
| 54 | * |
---|
| 55 | <!-- technical-bibtex-start --> |
---|
| 56 | * BibTeX: |
---|
| 57 | * <pre> |
---|
| 58 | * @mastersthesis{Xu2003, |
---|
| 59 | * address = {Hamilton, NZ}, |
---|
| 60 | * author = {Xin Xu}, |
---|
| 61 | * note = {0657.594}, |
---|
| 62 | * school = {University of Waikato}, |
---|
| 63 | * title = {Statistical learning in multiple instance problem}, |
---|
| 64 | * year = {2003} |
---|
| 65 | * } |
---|
| 66 | * </pre> |
---|
| 67 | * <p/> |
---|
| 68 | <!-- technical-bibtex-end --> |
---|
| 69 | * |
---|
| 70 | <!-- options-start --> |
---|
| 71 | * Valid options are: <p/> |
---|
| 72 | * |
---|
| 73 | * <pre> -C |
---|
| 74 | * Set whether or not use empirical |
---|
| 75 | * log-odds cut-off instead of 0</pre> |
---|
| 76 | * |
---|
| 77 | * <pre> -R <numOfRuns> |
---|
| 78 | * Set the number of multiple runs |
---|
| 79 | * needed for searching the MLE.</pre> |
---|
| 80 | * |
---|
| 81 | * <pre> -S <num> |
---|
| 82 | * Random number seed. |
---|
| 83 | * (default 1)</pre> |
---|
| 84 | * |
---|
| 85 | * <pre> -D |
---|
| 86 | * If set, classifier is run in debug mode and |
---|
| 87 | * may output additional info to the console</pre> |
---|
| 88 | * |
---|
| 89 | <!-- options-end --> |
---|
| 90 | * |
---|
| 91 | * @author Eibe Frank (eibe@cs.waikato.ac.nz) |
---|
| 92 | * @author Xin Xu (xx5@cs.waikato.ac.nz) |
---|
| 93 | * @version $Revision: 5481 $ |
---|
| 94 | */ |
---|
| 95 | public class TLD |
---|
| 96 | extends RandomizableClassifier |
---|
| 97 | implements OptionHandler, MultiInstanceCapabilitiesHandler, |
---|
| 98 | TechnicalInformationHandler { |
---|
| 99 | |
---|
| 100 | /** for serialization */ |
---|
| 101 | static final long serialVersionUID = 6657315525171152210L; |
---|
| 102 | |
---|
| 103 | /** The mean for each attribute of each positive exemplar */ |
---|
| 104 | protected double[][] m_MeanP = null; |
---|
| 105 | |
---|
| 106 | /** The variance for each attribute of each positive exemplar */ |
---|
| 107 | protected double[][] m_VarianceP = null; |
---|
| 108 | |
---|
| 109 | /** The mean for each attribute of each negative exemplar */ |
---|
| 110 | protected double[][] m_MeanN = null; |
---|
| 111 | |
---|
| 112 | /** The variance for each attribute of each negative exemplar */ |
---|
| 113 | protected double[][] m_VarianceN = null; |
---|
| 114 | |
---|
| 115 | /** The effective sum of weights of each positive exemplar in each dimension*/ |
---|
| 116 | protected double[][] m_SumP = null; |
---|
| 117 | |
---|
| 118 | /** The effective sum of weights of each negative exemplar in each dimension*/ |
---|
| 119 | protected double[][] m_SumN = null; |
---|
| 120 | |
---|
| 121 | /** The parameters to be estimated for each positive exemplar*/ |
---|
| 122 | protected double[] m_ParamsP = null; |
---|
| 123 | |
---|
| 124 | /** The parameters to be estimated for each negative exemplar*/ |
---|
| 125 | protected double[] m_ParamsN = null; |
---|
| 126 | |
---|
| 127 | /** The dimension of each exemplar, i.e. the number of attributes of the relational (bag) attribute at index 1 */ |
---|
| 128 | protected int m_Dimension = 0; |
---|
| 129 | |
---|
| 130 | /** The class label of each exemplar */ |
---|
| 131 | protected double[] m_Class = null; |
---|
| 132 | |
---|
| 133 | /** The number of class labels in the data */ |
---|
| 134 | protected int m_NumClasses = 2; |
---|
| 135 | |
---|
| 136 | /** The very small number representing zero */ |
---|
| 137 | static public double ZERO = 1.0e-6; |
---|
| 138 | |
---|
| 139 | /** The number of runs to perform */ |
---|
| 140 | protected int m_Run = 1; |
---|
| 141 | /** The log-odds cut-off: a bag is classified as class 1 if its log-odds exceed this value */ |
---|
| 142 | protected double m_Cutoff; |
---|
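| 143 | /** Whether the cut-off is searched empirically on the training log-odds instead of being set to -log(#positive/#negative) */ |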
---|
| 144 | protected boolean m_UseEmpiricalCutOff = false; |
---|
| 145 | |
---|
| 146 | /** |
---|
| 147 | * Returns a string describing this filter |
---|
| 148 | * |
---|
| 149 | * @return a description of the filter suitable for |
---|
| 150 | * displaying in the explorer/experimenter gui |
---|
| 151 | */ |
---|
| 152 | public String globalInfo() { |
---|
| 153 | return |
---|
| 154 | "Two-Level Distribution approach, changes the starting value of " |
---|
| 155 | + "the searching algorithm, supplement the cut-off modification and " |
---|
| 156 | + "check missing values.\n\n" |
---|
| 157 | + "For more information see:\n\n" |
---|
| 158 | + getTechnicalInformation().toString(); |
---|
| 159 | } |
---|
| 160 | |
---|
| 161 | /** |
---|
| 162 | * Returns an instance of a TechnicalInformation object, containing |
---|
| 163 | * detailed information about the technical background of this class, |
---|
| 164 | * e.g., paper reference or book this class is based on. |
---|
| 165 | * |
---|
| 166 | * @return the technical information about this class |
---|
| 167 | */ |
---|
| 168 | public TechnicalInformation getTechnicalInformation() { |
---|
| 169 | TechnicalInformation result; |
---|
| 170 | |
---|
| 171 | result = new TechnicalInformation(Type.MASTERSTHESIS); |
---|
| 172 | result.setValue(Field.AUTHOR, "Xin Xu"); |
---|
| 173 | result.setValue(Field.YEAR, "2003"); |
---|
| 174 | result.setValue(Field.TITLE, "Statistical learning in multiple instance problem"); |
---|
| 175 | result.setValue(Field.SCHOOL, "University of Waikato"); |
---|
| 176 | result.setValue(Field.ADDRESS, "Hamilton, NZ"); |
---|
| 177 | result.setValue(Field.NOTE, "0657.594"); |
---|
| 178 | |
---|
| 179 | return result; |
---|
| 180 | } |
---|
| 181 | |
---|
| 182 | /** |
---|
| 183 | * Returns default capabilities of the classifier. |
---|
| 184 | * |
---|
| 185 | * @return the capabilities of this classifier |
---|
| 186 | */ |
---|
| 187 | public Capabilities getCapabilities() { |
---|
| 188 | Capabilities result = super.getCapabilities(); |
---|
| 189 | result.disableAll(); |
---|
| 190 | |
---|
| 191 | // attributes |
---|
| 192 | result.enable(Capability.NOMINAL_ATTRIBUTES); |
---|
| 193 | result.enable(Capability.RELATIONAL_ATTRIBUTES); |
---|
| 194 | result.enable(Capability.MISSING_VALUES); |
---|
| 195 | |
---|
| 196 | // class |
---|
| 197 | result.enable(Capability.BINARY_CLASS); |
---|
| 198 | result.enable(Capability.MISSING_CLASS_VALUES); |
---|
| 199 | |
---|
| 200 | // other |
---|
| 201 | result.enable(Capability.ONLY_MULTIINSTANCE); |
---|
| 202 | |
---|
| 203 | return result; |
---|
| 204 | } |
---|
| 205 | |
---|
| 206 | /** |
---|
| 207 | * Returns the capabilities of this multi-instance classifier for the |
---|
| 208 | * relational data. |
---|
| 209 | * |
---|
| 210 | * @return the capabilities of this object |
---|
| 211 | * @see Capabilities |
---|
| 212 | */ |
---|
| 213 | public Capabilities getMultiInstanceCapabilities() { |
---|
| 214 | Capabilities result = super.getCapabilities(); |
---|
| 215 | result.disableAll(); |
---|
| 216 | |
---|
| 217 | // attributes |
---|
| 218 | result.enable(Capability.NUMERIC_ATTRIBUTES); |
---|
| 219 | result.enable(Capability.MISSING_VALUES); |
---|
| 220 | |
---|
| 221 | // class |
---|
| 222 | result.disableAllClasses(); |
---|
| 223 | result.enable(Capability.NO_CLASS); |
---|
| 224 | |
---|
| 225 | return result; |
---|
| 226 | } |
---|
| 227 | |
---|
| 228 | /** |
---|
| 229 | * |
---|
| 230 | * @param exs the training exemplars |
---|
| 231 | * @throws Exception if the model cannot be built properly |
---|
| 232 | */ |
---|
| 233 | public void buildClassifier(Instances exs)throws Exception{ |
---|
| 234 | // can classifier handle the data? |
---|
| 235 | getCapabilities().testWithFail(exs); |
---|
| 236 | |
---|
| 237 | // remove instances with missing class |
---|
| 238 | exs = new Instances(exs); |
---|
| 239 | exs.deleteWithMissingClass(); |
---|
| 240 | |
---|
| 241 | int numegs = exs.numInstances(); |
---|
| 242 | m_Dimension = exs.attribute(1).relation(). numAttributes(); |
---|
| 243 | Instances pos = new Instances(exs, 0), neg = new Instances(exs, 0); |
---|
| 244 | |
---|
| 245 | for(int u=0; u<numegs; u++){ |
---|
| 246 | Instance example = exs.instance(u); |
---|
| 247 | if(example.classValue() == 1) |
---|
| 248 | pos.add(example); |
---|
| 249 | else |
---|
| 250 | neg.add(example); |
---|
| 251 | } |
---|
| 252 | |
---|
| 253 | int pnum = pos.numInstances(), nnum = neg.numInstances(); |
---|
| 254 | |
---|
| 255 | m_MeanP = new double[pnum][m_Dimension]; |
---|
| 256 | m_VarianceP = new double[pnum][m_Dimension]; |
---|
| 257 | m_SumP = new double[pnum][m_Dimension]; |
---|
| 258 | m_MeanN = new double[nnum][m_Dimension]; |
---|
| 259 | m_VarianceN = new double[nnum][m_Dimension]; |
---|
| 260 | m_SumN = new double[nnum][m_Dimension]; |
---|
| 261 | m_ParamsP = new double[4*m_Dimension]; |
---|
| 262 | m_ParamsN = new double[4*m_Dimension]; |
---|
| 263 | |
---|
| 264 | // Estimation of the parameters: as the start value for search |
---|
| 265 | double[] pSumVal=new double[m_Dimension], // for m |
---|
| 266 | nSumVal=new double[m_Dimension]; |
---|
| 267 | double[] maxVarsP=new double[m_Dimension], // for a |
---|
| 268 | maxVarsN=new double[m_Dimension]; |
---|
| 269 | // Mean of sample variances: for b, b=a/E(\sigma^2)+2 |
---|
| 270 | double[] varMeanP = new double[m_Dimension], |
---|
| 271 | varMeanN = new double[m_Dimension]; |
---|
| 272 | // Variances of sample means: for w, w=E[var(\mu)]/E[\sigma^2] |
---|
| 273 | double[] meanVarP = new double[m_Dimension], |
---|
| 274 | meanVarN = new double[m_Dimension]; |
---|
| 275 | // number of exemplars without all values missing |
---|
| 276 | double[] numExsP = new double[m_Dimension], |
---|
| 277 | numExsN = new double[m_Dimension]; |
---|
| 278 | |
---|
| 279 | // Extract metadata fro both positive and negative bags |
---|
| 280 | for(int v=0; v < pnum; v++){ |
---|
| 281 | /*Exemplar px = pos.exemplar(v); |
---|
| 282 | m_MeanP[v] = px.meanOrMode(); |
---|
| 283 | m_VarianceP[v] = px.variance(); |
---|
| 284 | Instances pxi = px.getInstances(); |
---|
| 285 | */ |
---|
| 286 | |
---|
| 287 | Instances pxi = pos.instance(v).relationalValue(1); |
---|
| 288 | for (int k=0; k<pxi.numAttributes(); k++) { |
---|
| 289 | m_MeanP[v][k] = pxi.meanOrMode(k); |
---|
| 290 | m_VarianceP[v][k] = pxi.variance(k); |
---|
| 291 | } |
---|
| 292 | |
---|
| 293 | for (int w=0,t=0; w < m_Dimension; w++,t++){ |
---|
| 294 | //if((t==m_ClassIndex) || (t==m_IdIndex)) |
---|
| 295 | // t++; |
---|
| 296 | |
---|
| 297 | if(!Double.isNaN(m_MeanP[v][w])){ |
---|
| 298 | for(int u=0;u<pxi.numInstances();u++){ |
---|
| 299 | Instance ins = pxi.instance(u); |
---|
| 300 | if(!ins.isMissing(t)) |
---|
| 301 | m_SumP[v][w] += ins.weight(); |
---|
| 302 | } |
---|
| 303 | numExsP[w]++; |
---|
| 304 | pSumVal[w] += m_MeanP[v][w]; |
---|
| 305 | meanVarP[w] += m_MeanP[v][w]*m_MeanP[v][w]; |
---|
| 306 | if(maxVarsP[w] < m_VarianceP[v][w]) |
---|
| 307 | maxVarsP[w] = m_VarianceP[v][w]; |
---|
| 308 | varMeanP[w] += m_VarianceP[v][w]; |
---|
| 309 | m_VarianceP[v][w] *= (m_SumP[v][w]-1.0); |
---|
| 310 | if(m_VarianceP[v][w] < 0.0) |
---|
| 311 | m_VarianceP[v][w] = 0.0; |
---|
| 312 | } |
---|
| 313 | } |
---|
| 314 | } |
---|
| 315 | |
---|
| 316 | for(int v=0; v < nnum; v++){ |
---|
| 317 | /*Exemplar nx = neg.exemplar(v); |
---|
| 318 | m_MeanN[v] = nx.meanOrMode(); |
---|
| 319 | m_VarianceN[v] = nx.variance(); |
---|
| 320 | Instances nxi = nx.getInstances(); |
---|
| 321 | */ |
---|
| 322 | Instances nxi = neg.instance(v).relationalValue(1); |
---|
| 323 | for (int k=0; k<nxi.numAttributes(); k++) { |
---|
| 324 | m_MeanN[v][k] = nxi.meanOrMode(k); |
---|
| 325 | m_VarianceN[v][k] = nxi.variance(k); |
---|
| 326 | } |
---|
| 327 | |
---|
| 328 | for (int w=0,t=0; w < m_Dimension; w++,t++){ |
---|
| 329 | //if((t==m_ClassIndex) || (t==m_IdIndex)) |
---|
| 330 | // t++; |
---|
| 331 | |
---|
| 332 | if(!Double.isNaN(m_MeanN[v][w])){ |
---|
| 333 | for(int u=0;u<nxi.numInstances();u++) |
---|
| 334 | if(!nxi.instance(u).isMissing(t)) |
---|
| 335 | m_SumN[v][w] += nxi.instance(u).weight(); |
---|
| 336 | numExsN[w]++; |
---|
| 337 | nSumVal[w] += m_MeanN[v][w]; |
---|
| 338 | meanVarN[w] += m_MeanN[v][w]*m_MeanN[v][w]; |
---|
| 339 | if(maxVarsN[w] < m_VarianceN[v][w]) |
---|
| 340 | maxVarsN[w] = m_VarianceN[v][w]; |
---|
| 341 | varMeanN[w] += m_VarianceN[v][w]; |
---|
| 342 | m_VarianceN[v][w] *= (m_SumN[v][w]-1.0); |
---|
| 343 | if(m_VarianceN[v][w] < 0.0) |
---|
| 344 | m_VarianceN[v][w] = 0.0; |
---|
| 345 | } |
---|
| 346 | } |
---|
| 347 | } |
---|
| 348 | |
---|
| 349 | for(int w=0; w<m_Dimension; w++){ |
---|
| 350 | pSumVal[w] /= numExsP[w]; |
---|
| 351 | nSumVal[w] /= numExsN[w]; |
---|
| 352 | if(numExsP[w]>1) |
---|
| 353 | meanVarP[w] = meanVarP[w]/(numExsP[w]-1.0) |
---|
| 354 | - pSumVal[w]*numExsP[w]/(numExsP[w]-1.0); |
---|
| 355 | if(numExsN[w]>1) |
---|
| 356 | meanVarN[w] = meanVarN[w]/(numExsN[w]-1.0) |
---|
| 357 | - nSumVal[w]*numExsN[w]/(numExsN[w]-1.0); |
---|
| 358 | varMeanP[w] /= numExsP[w]; |
---|
| 359 | varMeanN[w] /= numExsN[w]; |
---|
| 360 | } |
---|
| 361 | |
---|
| 362 | //Bounds and parameter values for each run |
---|
| 363 | double[][] bounds = new double[2][4]; |
---|
| 364 | double[] pThisParam = new double[4], |
---|
| 365 | nThisParam = new double[4]; |
---|
| 366 | |
---|
| 367 | // Initial values for parameters |
---|
| 368 | double a, b, w, m; |
---|
| 369 | |
---|
| 370 | // Optimize for one dimension |
---|
| 371 | for (int x=0; x < m_Dimension; x++){ |
---|
| 372 | if (getDebug()) |
---|
| 373 | System.err.println("\n\n!!!!!!!!!!!!!!!!!!!!!!???Dimension #"+x); |
---|
| 374 | |
---|
| 375 | // Positive examplars: first run |
---|
| 376 | a = (maxVarsP[x]>ZERO) ? maxVarsP[x]:1.0; |
---|
| 377 | if (varMeanP[x]<=ZERO) varMeanP[x] = ZERO; // modified by LinDong (09/2005) |
---|
| 378 | b = a/varMeanP[x]+2.0; // a/(b-2) = E(\sigma^2) |
---|
| 379 | w = meanVarP[x]/varMeanP[x]; // E[var(\mu)] = w*E[\sigma^2] |
---|
| 380 | if(w<=ZERO) w=1.0; |
---|
| 381 | |
---|
| 382 | m = pSumVal[x]; |
---|
| 383 | pThisParam[0] = a; // a |
---|
| 384 | pThisParam[1] = b; // b |
---|
| 385 | pThisParam[2] = w; // w |
---|
| 386 | pThisParam[3] = m; // m |
---|
| 387 | |
---|
| 388 | // Negative examplars: first run |
---|
| 389 | a = (maxVarsN[x]>ZERO) ? maxVarsN[x]:1.0; |
---|
| 390 | if (varMeanN[x]<=ZERO) varMeanN[x] = ZERO; // modified by LinDong (09/2005) |
---|
| 391 | b = a/varMeanN[x]+2.0; // a/(b-2) = E(\sigma^2) |
---|
| 392 | w = meanVarN[x]/varMeanN[x]; // E[var(\mu)] = w*E[\sigma^2] |
---|
| 393 | if(w<=ZERO) w=1.0; |
---|
| 394 | |
---|
| 395 | m = nSumVal[x]; |
---|
| 396 | nThisParam[0] = a; // a |
---|
| 397 | nThisParam[1] = b; // b |
---|
| 398 | nThisParam[2] = w; // w |
---|
| 399 | nThisParam[3] = m; // m |
---|
| 400 | |
---|
| 401 | // Bound constraints |
---|
| 402 | bounds[0][0] = ZERO; // a > 0 |
---|
| 403 | bounds[0][1] = 2.0+ZERO; // b > 2 |
---|
| 404 | bounds[0][2] = ZERO; // w > 0 |
---|
| 405 | bounds[0][3] = Double.NaN; |
---|
| 406 | |
---|
| 407 | for(int t=0; t<4; t++){ |
---|
| 408 | bounds[1][t] = Double.NaN; |
---|
| 409 | m_ParamsP[4*x+t] = pThisParam[t]; |
---|
| 410 | m_ParamsN[4*x+t] = nThisParam[t]; |
---|
| 411 | } |
---|
| 412 | double pminVal=Double.MAX_VALUE, nminVal=Double.MAX_VALUE; |
---|
| 413 | Random whichEx = new Random(m_Seed); |
---|
| 414 | TLD_Optm pOp=null, nOp=null; |
---|
| 415 | boolean isRunValid = true; |
---|
| 416 | double[] sumP=new double[pnum], meanP=new double[pnum], |
---|
| 417 | varP=new double[pnum]; |
---|
| 418 | double[] sumN=new double[nnum], meanN=new double[nnum], |
---|
| 419 | varN=new double[nnum]; |
---|
| 420 | |
---|
| 421 | // One dimension |
---|
| 422 | for(int p=0; p<pnum; p++){ |
---|
| 423 | sumP[p] = m_SumP[p][x]; |
---|
| 424 | meanP[p] = m_MeanP[p][x]; |
---|
| 425 | varP[p] = m_VarianceP[p][x]; |
---|
| 426 | } |
---|
| 427 | for(int q=0; q<nnum; q++){ |
---|
| 428 | sumN[q] = m_SumN[q][x]; |
---|
| 429 | meanN[q] = m_MeanN[q][x]; |
---|
| 430 | varN[q] = m_VarianceN[q][x]; |
---|
| 431 | } |
---|
| 432 | |
---|
| 433 | for(int y=0; y<m_Run;){ |
---|
| 434 | if (getDebug()) |
---|
| 435 | System.err.println("\n\n!!!!!!!!!!!!!!!!!!!!!!???Run #"+y); |
---|
| 436 | double thisMin; |
---|
| 437 | |
---|
| 438 | if (getDebug()) |
---|
| 439 | System.err.println("\nPositive exemplars"); |
---|
| 440 | pOp = new TLD_Optm(); |
---|
| 441 | pOp.setNum(sumP); |
---|
| 442 | pOp.setSSquare(varP); |
---|
| 443 | pOp.setXBar(meanP); |
---|
| 444 | |
---|
| 445 | pThisParam = pOp.findArgmin(pThisParam, bounds); |
---|
| 446 | while(pThisParam==null){ |
---|
| 447 | pThisParam = pOp.getVarbValues(); |
---|
| 448 | if (getDebug()) |
---|
| 449 | System.err.println("!!! 200 iterations finished, not enough!"); |
---|
| 450 | pThisParam = pOp.findArgmin(pThisParam, bounds); |
---|
| 451 | } |
---|
| 452 | |
---|
| 453 | thisMin = pOp.getMinFunction(); |
---|
| 454 | if(!Double.isNaN(thisMin) && (thisMin<pminVal)){ |
---|
| 455 | pminVal = thisMin; |
---|
| 456 | for(int z=0; z<4; z++) |
---|
| 457 | m_ParamsP[4*x+z] = pThisParam[z]; |
---|
| 458 | } |
---|
| 459 | |
---|
| 460 | if(Double.isNaN(thisMin)){ |
---|
| 461 | pThisParam = new double[4]; |
---|
| 462 | isRunValid =false; |
---|
| 463 | } |
---|
| 464 | |
---|
| 465 | if (getDebug()) |
---|
| 466 | System.err.println("\nNegative exemplars"); |
---|
| 467 | nOp = new TLD_Optm(); |
---|
| 468 | nOp.setNum(sumN); |
---|
| 469 | nOp.setSSquare(varN); |
---|
| 470 | nOp.setXBar(meanN); |
---|
| 471 | |
---|
| 472 | nThisParam = nOp.findArgmin(nThisParam, bounds); |
---|
| 473 | while(nThisParam==null){ |
---|
| 474 | nThisParam = nOp.getVarbValues(); |
---|
| 475 | if (getDebug()) |
---|
| 476 | System.err.println("!!! 200 iterations finished, not enough!"); |
---|
| 477 | nThisParam = nOp.findArgmin(nThisParam, bounds); |
---|
| 478 | } |
---|
| 479 | thisMin = nOp.getMinFunction(); |
---|
| 480 | if(!Double.isNaN(thisMin) && (thisMin<nminVal)){ |
---|
| 481 | nminVal = thisMin; |
---|
| 482 | for(int z=0; z<4; z++) |
---|
| 483 | m_ParamsN[4*x+z] = nThisParam[z]; |
---|
| 484 | } |
---|
| 485 | |
---|
| 486 | if(Double.isNaN(thisMin)){ |
---|
| 487 | nThisParam = new double[4]; |
---|
| 488 | isRunValid =false; |
---|
| 489 | } |
---|
| 490 | |
---|
| 491 | if(!isRunValid){ y--; isRunValid=true; } |
---|
| 492 | |
---|
| 493 | if(++y<m_Run){ |
---|
| 494 | // Change the initial parameters and restart |
---|
| 495 | int pone = whichEx.nextInt(pnum), // Randomly pick one pos. exmpl. |
---|
| 496 | none = whichEx.nextInt(nnum); |
---|
| 497 | |
---|
| 498 | // Positive exemplars: next run |
---|
| 499 | while((m_SumP[pone][x]<=1.0)||Double.isNaN(m_MeanP[pone][x])) |
---|
| 500 | pone = whichEx.nextInt(pnum); |
---|
| 501 | |
---|
| 502 | a = m_VarianceP[pone][x]/(m_SumP[pone][x]-1.0); |
---|
| 503 | if(a<=ZERO) a=m_ParamsN[4*x]; // Change to negative params |
---|
| 504 | m = m_MeanP[pone][x]; |
---|
| 505 | double sq = (m-m_ParamsP[4*x+3])*(m-m_ParamsP[4*x+3]); |
---|
| 506 | |
---|
| 507 | b = a*m_ParamsP[4*x+2]/sq+2.0; // b=a/Var+2, assuming Var=Sq/w' |
---|
| 508 | if((b<=ZERO) || Double.isNaN(b) || Double.isInfinite(b)) |
---|
| 509 | b=m_ParamsN[4*x+1]; |
---|
| 510 | |
---|
| 511 | w = sq*(m_ParamsP[4*x+1]-2.0)/m_ParamsP[4*x];//w=Sq/Var, assuming Var=a'/(b'-2) |
---|
| 512 | if((w<=ZERO) || Double.isNaN(w) || Double.isInfinite(w)) |
---|
| 513 | w=m_ParamsN[4*x+2]; |
---|
| 514 | |
---|
| 515 | pThisParam[0] = a; // a |
---|
| 516 | pThisParam[1] = b; // b |
---|
| 517 | pThisParam[2] = w; // w |
---|
| 518 | pThisParam[3] = m; // m |
---|
| 519 | |
---|
| 520 | // Negative exemplars: next run |
---|
| 521 | while((m_SumN[none][x]<=1.0)||Double.isNaN(m_MeanN[none][x])) |
---|
| 522 | none = whichEx.nextInt(nnum); |
---|
| 523 | |
---|
| 524 | a = m_VarianceN[none][x]/(m_SumN[none][x]-1.0); |
---|
| 525 | if(a<=ZERO) a=m_ParamsP[4*x]; |
---|
| 526 | m = m_MeanN[none][x]; |
---|
| 527 | sq = (m-m_ParamsN[4*x+3])*(m-m_ParamsN[4*x+3]); |
---|
| 528 | |
---|
| 529 | b = a*m_ParamsN[4*x+2]/sq+2.0; // b=a/Var+2, assuming Var=Sq/w' |
---|
| 530 | if((b<=ZERO) || Double.isNaN(b) || Double.isInfinite(b)) |
---|
| 531 | b=m_ParamsP[4*x+1]; |
---|
| 532 | |
---|
| 533 | w = sq*(m_ParamsN[4*x+1]-2.0)/m_ParamsN[4*x];//w=Sq/Var, assuming Var=a'/(b'-2) |
---|
| 534 | if((w<=ZERO) || Double.isNaN(w) || Double.isInfinite(w)) |
---|
| 535 | w=m_ParamsP[4*x+2]; |
---|
| 536 | |
---|
| 537 | nThisParam[0] = a; // a |
---|
| 538 | nThisParam[1] = b; // b |
---|
| 539 | nThisParam[2] = w; // w |
---|
| 540 | nThisParam[3] = m; // m |
---|
| 541 | } |
---|
| 542 | } |
---|
| 543 | } |
---|
| 544 | |
---|
| 545 | for (int x=0, y=0; x<m_Dimension; x++, y++){ |
---|
| 546 | //if((x==exs.classIndex()) || (x==exs.idIndex())) |
---|
| 547 | //y++; |
---|
| 548 | a=m_ParamsP[4*x]; b=m_ParamsP[4*x+1]; |
---|
| 549 | w=m_ParamsP[4*x+2]; m=m_ParamsP[4*x+3]; |
---|
| 550 | if (getDebug()) |
---|
| 551 | System.err.println("\n\n???Positive: ( "+exs.attribute(1).relation().attribute(y)+ |
---|
| 552 | "): a="+a+", b="+b+", w="+w+", m="+m); |
---|
| 553 | |
---|
| 554 | a=m_ParamsN[4*x]; b=m_ParamsN[4*x+1]; |
---|
| 555 | w=m_ParamsN[4*x+2]; m=m_ParamsN[4*x+3]; |
---|
| 556 | if (getDebug()) |
---|
| 557 | System.err.println("???Negative: ("+exs.attribute(1).relation().attribute(y)+ |
---|
| 558 | "): a="+a+", b="+b+", w="+w+", m="+m); |
---|
| 559 | } |
---|
| 560 | |
---|
| 561 | if(m_UseEmpiricalCutOff){ |
---|
| 562 | // Find the empirical cut-off |
---|
| 563 | double[] pLogOdds=new double[pnum], nLogOdds=new double[nnum]; |
---|
| 564 | for(int p=0; p<pnum; p++) |
---|
| 565 | pLogOdds[p] = |
---|
| 566 | likelihoodRatio(m_SumP[p], m_MeanP[p], m_VarianceP[p]); |
---|
| 567 | |
---|
| 568 | for(int q=0; q<nnum; q++) |
---|
| 569 | nLogOdds[q] = |
---|
| 570 | likelihoodRatio(m_SumN[q], m_MeanN[q], m_VarianceN[q]); |
---|
| 571 | |
---|
| 572 | // Update m_Cutoff |
---|
| 573 | findCutOff(pLogOdds, nLogOdds); |
---|
| 574 | } |
---|
| 575 | else |
---|
| 576 | m_Cutoff = -Math.log((double)pnum/(double)nnum); |
---|
| 577 | |
---|
| 578 | if (getDebug()) |
---|
| 579 | System.err.println("???Cut-off="+m_Cutoff); |
---|
| 580 | } |
---|
| 581 | |
---|
| 582 | /** |
---|
| 583 | * |
---|
| 584 | * @param ex the given test exemplar |
---|
| 585 | * @return the classification |
---|
| 586 | * @throws Exception if the exemplar could not be classified |
---|
| 587 | * successfully |
---|
| 588 | */ |
---|
| 589 | public double classifyInstance(Instance ex)throws Exception{ |
---|
| 590 | //Exemplar ex = new Exemplar(e); |
---|
| 591 | Instances exi = ex.relationalValue(1); |
---|
| 592 | double[] n = new double[m_Dimension]; |
---|
| 593 | double [] xBar = new double[m_Dimension]; |
---|
| 594 | double [] sSq = new double[m_Dimension]; |
---|
| 595 | for (int i=0; i<exi.numAttributes() ; i++){ |
---|
| 596 | xBar[i] = exi.meanOrMode(i); |
---|
| 597 | sSq[i] = exi.variance(i); |
---|
| 598 | } |
---|
| 599 | |
---|
| 600 | for (int w=0, t=0; w < m_Dimension; w++, t++){ |
---|
| 601 | //if((t==m_ClassIndex) || (t==m_IdIndex)) |
---|
| 602 | //t++; |
---|
| 603 | for(int u=0;u<exi.numInstances();u++) |
---|
| 604 | if(!exi.instance(u).isMissing(t)) |
---|
| 605 | n[w] += exi.instance(u).weight(); |
---|
| 606 | |
---|
| 607 | sSq[w] = sSq[w]*(n[w]-1.0); |
---|
| 608 | if(sSq[w] <= 0.0) |
---|
| 609 | sSq[w] = 0.0; |
---|
| 610 | } |
---|
| 611 | |
---|
| 612 | double logOdds = likelihoodRatio(n, xBar, sSq); |
---|
| 613 | return (logOdds > m_Cutoff) ? 1 : 0 ; |
---|
| 614 | } |
---|
| 615 | |
---|
| 616 | private double likelihoodRatio(double[] n, double[] xBar, double[] sSq){ |
---|
| 617 | double LLP = 0.0, LLN = 0.0; |
---|
| 618 | |
---|
| 619 | for (int x=0; x<m_Dimension; x++){ |
---|
| 620 | if(Double.isNaN(xBar[x])) continue; // All missing values |
---|
| 621 | |
---|
| 622 | int halfN = ((int)n[x])/2; |
---|
| 623 | //Log-likelihood for positive |
---|
| 624 | double a=m_ParamsP[4*x], b=m_ParamsP[4*x+1], |
---|
| 625 | w=m_ParamsP[4*x+2], m=m_ParamsP[4*x+3]; |
---|
| 626 | LLP += 0.5*b*Math.log(a) + 0.5*(b+n[x]-1.0)*Math.log(1.0+n[x]*w) |
---|
| 627 | - 0.5*(b+n[x])*Math.log((1.0+n[x]*w)*(a+sSq[x])+ |
---|
| 628 | n[x]*(xBar[x]-m)*(xBar[x]-m)) |
---|
| 629 | - 0.5*n[x]*Math.log(Math.PI); |
---|
| 630 | for(int y=1; y<=halfN; y++) |
---|
| 631 | LLP += Math.log(b/2.0+n[x]/2.0-(double)y); |
---|
| 632 | |
---|
| 633 | if(n[x]/2.0 > halfN) // n is odd |
---|
| 634 | LLP += TLD_Optm.diffLnGamma(b/2.0); |
---|
| 635 | |
---|
| 636 | //Log-likelihood for negative |
---|
| 637 | a=m_ParamsN[4*x]; |
---|
| 638 | b=m_ParamsN[4*x+1]; |
---|
| 639 | w=m_ParamsN[4*x+2]; |
---|
| 640 | m=m_ParamsN[4*x+3]; |
---|
| 641 | LLN += 0.5*b*Math.log(a) + 0.5*(b+n[x]-1.0)*Math.log(1.0+n[x]*w) |
---|
| 642 | - 0.5*(b+n[x])*Math.log((1.0+n[x]*w)*(a+sSq[x])+ |
---|
| 643 | n[x]*(xBar[x]-m)*(xBar[x]-m)) |
---|
| 644 | - 0.5*n[x]*Math.log(Math.PI); |
---|
| 645 | for(int y=1; y<=halfN; y++) |
---|
| 646 | LLN += Math.log(b/2.0+n[x]/2.0-(double)y); |
---|
| 647 | |
---|
| 648 | if(n[x]/2.0 > halfN) // n is odd |
---|
| 649 | LLN += TLD_Optm.diffLnGamma(b/2.0); |
---|
| 650 | } |
---|
| 651 | |
---|
| 652 | return LLP - LLN; |
---|
| 653 | } |
---|
| 654 | |
---|
| 655 | private void findCutOff(double[] pos, double[] neg){ |
---|
| 656 | int[] pOrder = Utils.sort(pos), |
---|
| 657 | nOrder = Utils.sort(neg); |
---|
| 658 | /* |
---|
| 659 | System.err.println("\n\n???Positive: "); |
---|
| 660 | for(int t=0; t<pOrder.length; t++) |
---|
| 661 | System.err.print(t+":"+Utils.doubleToString(pos[pOrder[t]],0,2)+" "); |
---|
| 662 | System.err.println("\n\n???Negative: "); |
---|
| 663 | for(int t=0; t<nOrder.length; t++) |
---|
| 664 | System.err.print(t+":"+Utils.doubleToString(neg[nOrder[t]],0,2)+" "); |
---|
| 665 | */ |
---|
| 666 | int pNum = pos.length, nNum = neg.length, count, p=0, n=0; |
---|
| 667 | double fstAccu=0.0, sndAccu=(double)pNum, split; |
---|
| 668 | double maxAccu = 0, minDistTo0 = Double.MAX_VALUE; |
---|
| 669 | |
---|
| 670 | // Skip continuous negatives |
---|
| 671 | for(;(n<nNum)&&(pos[pOrder[0]]>=neg[nOrder[n]]); n++, fstAccu++); |
---|
| 672 | |
---|
| 673 | if(n>=nNum){ // totally seperate |
---|
| 674 | m_Cutoff = (neg[nOrder[nNum-1]]+pos[pOrder[0]])/2.0; |
---|
| 675 | //m_Cutoff = neg[nOrder[nNum-1]]; |
---|
| 676 | return; |
---|
| 677 | } |
---|
| 678 | |
---|
| 679 | count=n; |
---|
| 680 | while((p<pNum)&&(n<nNum)){ |
---|
| 681 | // Compare the next in the two lists |
---|
| 682 | if(pos[pOrder[p]]>=neg[nOrder[n]]){ // Neg has less log-odds |
---|
| 683 | fstAccu += 1.0; |
---|
| 684 | split=neg[nOrder[n]]; |
---|
| 685 | n++; |
---|
| 686 | } |
---|
| 687 | else{ |
---|
| 688 | sndAccu -= 1.0; |
---|
| 689 | split=pos[pOrder[p]]; |
---|
| 690 | p++; |
---|
| 691 | } |
---|
| 692 | count++; |
---|
| 693 | if((fstAccu+sndAccu > maxAccu) |
---|
| 694 | || ((fstAccu+sndAccu == maxAccu) && (Math.abs(split)<minDistTo0))){ |
---|
| 695 | maxAccu = fstAccu+sndAccu; |
---|
| 696 | m_Cutoff = split; |
---|
| 697 | minDistTo0 = Math.abs(split); |
---|
| 698 | } |
---|
| 699 | } |
---|
| 700 | } |
---|
| 701 | |
---|
| 702 | /** |
---|
| 703 | * Returns an enumeration describing the available options |
---|
| 704 | * |
---|
| 705 | * @return an enumeration of all the available options |
---|
| 706 | */ |
---|
| 707 | public Enumeration listOptions() { |
---|
| 708 | Vector result = new Vector(); |
---|
| 709 | |
---|
| 710 | result.addElement(new Option( |
---|
| 711 | "\tSet whether or not use empirical\n" |
---|
| 712 | + "\tlog-odds cut-off instead of 0", |
---|
| 713 | "C", 0, "-C")); |
---|
| 714 | |
---|
| 715 | result.addElement(new Option( |
---|
| 716 | "\tSet the number of multiple runs \n" |
---|
| 717 | + "\tneeded for searching the MLE.", |
---|
| 718 | "R", 1, "-R <numOfRuns>")); |
---|
| 719 | |
---|
| 720 | Enumeration enu = super.listOptions(); |
---|
| 721 | while (enu.hasMoreElements()) { |
---|
| 722 | result.addElement(enu.nextElement()); |
---|
| 723 | } |
---|
| 724 | |
---|
| 725 | return result.elements(); |
---|
| 726 | } |
---|
| 727 | |
---|
| 728 | /** |
---|
| 729 | * Parses a given list of options. <p/> |
---|
| 730 | * |
---|
| 731 | <!-- options-start --> |
---|
| 732 | * Valid options are: <p/> |
---|
| 733 | * |
---|
| 734 | * <pre> -C |
---|
| 735 | * Set whether or not use empirical |
---|
| 736 | * log-odds cut-off instead of 0</pre> |
---|
| 737 | * |
---|
| 738 | * <pre> -R <numOfRuns> |
---|
| 739 | * Set the number of multiple runs |
---|
| 740 | * needed for searching the MLE.</pre> |
---|
| 741 | * |
---|
| 742 | * <pre> -S <num> |
---|
| 743 | * Random number seed. |
---|
| 744 | * (default 1)</pre> |
---|
| 745 | * |
---|
| 746 | * <pre> -D |
---|
| 747 | * If set, classifier is run in debug mode and |
---|
| 748 | * may output additional info to the console</pre> |
---|
| 749 | * |
---|
| 750 | <!-- options-end --> |
---|
| 751 | * |
---|
| 752 | * @param options the list of options as an array of strings |
---|
| 753 | * @throws Exception if an option is not supported |
---|
| 754 | */ |
---|
| 755 | public void setOptions(String[] options) throws Exception{ |
---|
| 756 | setDebug(Utils.getFlag('D', options)); |
---|
| 757 | |
---|
| 758 | setUsingCutOff(Utils.getFlag('C', options)); |
---|
| 759 | |
---|
| 760 | String runString = Utils.getOption('R', options); |
---|
| 761 | if (runString.length() != 0) |
---|
| 762 | setNumRuns(Integer.parseInt(runString)); |
---|
| 763 | else |
---|
| 764 | setNumRuns(1); |
---|
| 765 | |
---|
| 766 | super.setOptions(options); |
---|
| 767 | } |
---|
| 768 | |
---|
| 769 | /** |
---|
| 770 | * Gets the current settings of the Classifier. |
---|
| 771 | * |
---|
| 772 | * @return an array of strings suitable for passing to setOptions |
---|
| 773 | */ |
---|
| 774 | public String[] getOptions() { |
---|
| 775 | Vector result; |
---|
| 776 | String[] options; |
---|
| 777 | int i; |
---|
| 778 | |
---|
| 779 | result = new Vector(); |
---|
| 780 | options = super.getOptions(); |
---|
| 781 | for (i = 0; i < options.length; i++) |
---|
| 782 | result.add(options[i]); |
---|
| 783 | |
---|
| 784 | if (getDebug()) |
---|
| 785 | result.add("-D"); |
---|
| 786 | |
---|
| 787 | if (getUsingCutOff()) |
---|
| 788 | result.add("-C"); |
---|
| 789 | |
---|
| 790 | result.add("-R"); |
---|
| 791 | result.add("" + getNumRuns()); |
---|
| 792 | |
---|
| 793 | return (String[]) result.toArray(new String[result.size()]); |
---|
| 794 | } |
---|
| 795 | |
---|
| 796 | /** |
---|
| 797 | * Returns the tip text for this property |
---|
| 798 | * |
---|
| 799 | * @return tip text for this property suitable for |
---|
| 800 | * displaying in the explorer/experimenter gui |
---|
| 801 | */ |
---|
| 802 | public String numRunsTipText() { |
---|
| 803 | return "The number of runs to perform."; |
---|
| 804 | } |
---|
| 805 | |
---|
| 806 | /** |
---|
| 807 | * Sets the number of runs to perform. |
---|
| 808 | * |
---|
| 809 | * @param numRuns the number of runs to perform |
---|
| 810 | */ |
---|
| 811 | public void setNumRuns(int numRuns) { |
---|
| 812 | m_Run = numRuns; |
---|
| 813 | } |
---|
| 814 | |
---|
| 815 | /** |
---|
| 816 | * Returns the number of runs to perform. |
---|
| 817 | * |
---|
| 818 | * @return the number of runs to perform |
---|
| 819 | */ |
---|
| 820 | public int getNumRuns() { |
---|
| 821 | return m_Run; |
---|
| 822 | } |
---|
| 823 | |
---|
| 824 | /** |
---|
| 825 | * Returns the tip text for this property |
---|
| 826 | * |
---|
| 827 | * @return tip text for this property suitable for |
---|
| 828 | * displaying in the explorer/experimenter gui |
---|
| 829 | */ |
---|
| 830 | public String usingCutOffTipText() { |
---|
| 831 | return "Whether to use an empirical cutoff."; |
---|
| 832 | } |
---|
| 833 | |
---|
| 834 | /** |
---|
| 835 | * Sets whether to use an empirical cutoff. |
---|
| 836 | * |
---|
| 837 | * @param cutOff whether to use an empirical cutoff |
---|
| 838 | */ |
---|
| 839 | public void setUsingCutOff (boolean cutOff) { |
---|
| 840 | m_UseEmpiricalCutOff = cutOff; |
---|
| 841 | } |
---|
| 842 | |
---|
| 843 | /** |
---|
| 844 | * Returns whether an empirical cutoff is used |
---|
| 845 | * |
---|
| 846 | * @return true if an empirical cutoff is used |
---|
| 847 | */ |
---|
| 848 | public boolean getUsingCutOff() { |
---|
| 849 | return m_UseEmpiricalCutOff; |
---|
| 850 | } |
---|
| 851 | |
---|
| 852 | /** |
---|
| 853 | * Returns the revision string. |
---|
| 854 | * |
---|
| 855 | * @return the revision |
---|
| 856 | */ |
---|
| 857 | public String getRevision() { |
---|
| 858 | return RevisionUtils.extract("$Revision: 5481 $"); |
---|
| 859 | } |
---|
| 860 | |
---|
| 861 | /** |
---|
| 862 | * Main method for testing. |
---|
| 863 | * |
---|
| 864 | * @param args the options for the classifier |
---|
| 865 | */ |
---|
| 866 | public static void main(String[] args) { |
---|
| 867 | runClassifier(new TLD(), args); |
---|
| 868 | } |
---|
| 869 | } |
---|
| 870 | |
---|
| 871 | class TLD_Optm extends Optimization { |
---|
| 872 | |
---|
| 873 | private double[] num; |
---|
| 874 | private double[] sSq; |
---|
| 875 | private double[] xBar; |
---|
| 876 | |
---|
| 877 | public void setNum(double[] n) {num = n;} |
---|
| 878 | public void setSSquare(double[] s){sSq = s;} |
---|
| 879 | public void setXBar(double[] x){xBar = x;} |
---|
| 880 | |
---|
| 881 | /** |
---|
| 882 | * Compute Ln[Gamma(b+0.5)] - Ln[Gamma(b)] |
---|
| 883 | * |
---|
| 884 | * @param b the value in the above formula |
---|
| 885 | * @return the result |
---|
| 886 | */ |
---|
| 887 | public static double diffLnGamma(double b){ |
---|
| 888 | double[] coef= {76.18009172947146, -86.50532032941677, |
---|
| 889 | 24.01409824083091, -1.231739572450155, |
---|
| 890 | 0.1208650973866179e-2, -0.5395239384953e-5}; |
---|
| 891 | double rt = -0.5; |
---|
| 892 | rt += (b+1.0)*Math.log(b+6.0) - (b+0.5)*Math.log(b+5.5); |
---|
| 893 | double series1=1.000000000190015, series2=1.000000000190015; |
---|
| 894 | for(int i=0; i<6; i++){ |
---|
| 895 | series1 += coef[i]/(b+1.5+(double)i); |
---|
| 896 | series2 += coef[i]/(b+1.0+(double)i); |
---|
| 897 | } |
---|
| 898 | |
---|
| 899 | rt += Math.log(series1*b)-Math.log(series2*(b+0.5)); |
---|
| 900 | return rt; |
---|
| 901 | } |
---|
| 902 | |
---|
| 903 | /** |
---|
| 904 | * Compute dLn[Gamma(x+0.5)]/dx - dLn[Gamma(x)]/dx |
---|
| 905 | * |
---|
| 906 | * @param x the value in the above formula |
---|
| 907 | * @return the result |
---|
| 908 | */ |
---|
| 909 | protected double diffFstDervLnGamma(double x){ |
---|
| 910 | double rt=0, series=1.0;// Just make it >0 |
---|
| 911 | for(int i=0;series>=m_Zero*1e-3;i++){ |
---|
| 912 | series = 0.5/((x+(double)i)*(x+(double)i+0.5)); |
---|
| 913 | rt += series; |
---|
| 914 | } |
---|
| 915 | return rt; |
---|
| 916 | } |
---|
| 917 | |
---|
| 918 | /** |
---|
| 919 | * Compute {Ln[Gamma(x+0.5)]}'' - {Ln[Gamma(x)]}'' |
---|
| 920 | * |
---|
| 921 | * @param x the value in the above formula |
---|
| 922 | * @return the result |
---|
| 923 | */ |
---|
| 924 | protected double diffSndDervLnGamma(double x){ |
---|
| 925 | double rt=0, series=1.0;// Just make it >0 |
---|
| 926 | for(int i=0;series>=m_Zero*1e-3;i++){ |
---|
| 927 | series = (x+(double)i+0.25)/ |
---|
| 928 | ((x+(double)i)*(x+(double)i)*(x+(double)i+0.5)*(x+(double)i+0.5)); |
---|
| 929 | rt -= series; |
---|
| 930 | } |
---|
| 931 | return rt; |
---|
| 932 | } |
---|
| 933 | |
---|
| 934 | /** |
---|
| 935 | * Implement this procedure to evaluate objective |
---|
| 936 | * function to be minimized |
---|
| 937 | */ |
---|
| 938 | protected double objectiveFunction(double[] x){ |
---|
| 939 | int numExs = num.length; |
---|
| 940 | double NLL = 0; // Negative Log-Likelihood |
---|
| 941 | |
---|
| 942 | double a=x[0], b=x[1], w=x[2], m=x[3]; |
---|
| 943 | for(int j=0; j < numExs; j++){ |
---|
| 944 | |
---|
| 945 | if(Double.isNaN(xBar[j])) continue; // All missing values |
---|
| 946 | |
---|
| 947 | NLL += 0.5*(b+num[j])* |
---|
| 948 | Math.log((1.0+num[j]*w)*(a+sSq[j]) + |
---|
| 949 | num[j]*(xBar[j]-m)*(xBar[j]-m)); |
---|
| 950 | |
---|
| 951 | if(Double.isNaN(NLL) && m_Debug){ |
---|
| 952 | System.err.println("???????????1: "+a+" "+b+" "+w+" "+m |
---|
| 953 | +"|x-: "+xBar[j] + |
---|
| 954 | "|n: "+num[j] + "|S^2: "+sSq[j]); |
---|
| 955 | System.exit(1); |
---|
| 956 | } |
---|
| 957 | |
---|
| 958 | // Doesn't affect optimization |
---|
| 959 | //NLL += 0.5*num[j]*Math.log(Math.PI); |
---|
| 960 | |
---|
| 961 | NLL -= 0.5*(b+num[j]-1.0)*Math.log(1.0+num[j]*w); |
---|
| 962 | |
---|
| 963 | |
---|
| 964 | if(Double.isNaN(NLL) && m_Debug){ |
---|
| 965 | System.err.println("???????????2: "+a+" "+b+" "+w+" "+m |
---|
| 966 | +"|x-: "+xBar[j] + |
---|
| 967 | "|n: "+num[j] + "|S^2: "+sSq[j]); |
---|
| 968 | System.exit(1); |
---|
| 969 | } |
---|
| 970 | |
---|
| 971 | int halfNum = ((int)num[j])/2; |
---|
| 972 | for(int z=1; z<=halfNum; z++) |
---|
| 973 | NLL -= Math.log(0.5*b+0.5*num[j]-(double)z); |
---|
| 974 | |
---|
| 975 | if(0.5*num[j] > halfNum) // num[j] is odd |
---|
| 976 | NLL -= diffLnGamma(0.5*b); |
---|
| 977 | |
---|
| 978 | if(Double.isNaN(NLL) && m_Debug){ |
---|
| 979 | System.err.println("???????????3: "+a+" "+b+" "+w+" "+m |
---|
| 980 | +"|x-: "+xBar[j] + |
---|
| 981 | "|n: "+num[j] + "|S^2: "+sSq[j]); |
---|
| 982 | System.exit(1); |
---|
| 983 | } |
---|
| 984 | |
---|
| 985 | NLL -= 0.5*Math.log(a)*b; |
---|
| 986 | if(Double.isNaN(NLL) && m_Debug){ |
---|
| 987 | System.err.println("???????????4:"+a+" "+b+" "+w+" "+m); |
---|
| 988 | System.exit(1); |
---|
| 989 | } |
---|
| 990 | } |
---|
| 991 | if(m_Debug) |
---|
| 992 | System.err.println("?????????????5: "+NLL); |
---|
| 993 | if(Double.isNaN(NLL)) |
---|
| 994 | System.exit(1); |
---|
| 995 | |
---|
| 996 | return NLL; |
---|
| 997 | } |
---|
| 998 | |
---|
| 999 | /** |
---|
| 1000 | * Subclass should implement this procedure to evaluate gradient |
---|
| 1001 | * of the objective function |
---|
| 1002 | */ |
---|
| 1003 | protected double[] evaluateGradient(double[] x){ |
---|
| 1004 | double[] g = new double[x.length]; |
---|
| 1005 | int numExs = num.length; |
---|
| 1006 | |
---|
| 1007 | double a=x[0],b=x[1],w=x[2],m=x[3]; |
---|
| 1008 | |
---|
| 1009 | double da=0.0, db=0.0, dw=0.0, dm=0.0; |
---|
| 1010 | for(int j=0; j < numExs; j++){ |
---|
| 1011 | |
---|
| 1012 | if(Double.isNaN(xBar[j])) continue; // All missing values |
---|
| 1013 | |
---|
| 1014 | double denorm = (1.0+num[j]*w)*(a+sSq[j]) + |
---|
| 1015 | num[j]*(xBar[j]-m)*(xBar[j]-m); |
---|
| 1016 | |
---|
| 1017 | da += 0.5*(b+num[j])*(1.0+num[j]*w)/denorm-0.5*b/a; |
---|
| 1018 | |
---|
| 1019 | db += 0.5*Math.log(denorm) |
---|
| 1020 | - 0.5*Math.log(1.0+num[j]*w) |
---|
| 1021 | - 0.5*Math.log(a); |
---|
| 1022 | |
---|
| 1023 | int halfNum = ((int)num[j])/2; |
---|
| 1024 | for(int z=1; z<=halfNum; z++) |
---|
| 1025 | db -= 1.0/(b+num[j]-2.0*(double)z); |
---|
| 1026 | if(num[j]/2.0 > halfNum) // num[j] is odd |
---|
| 1027 | db -= 0.5*diffFstDervLnGamma(0.5*b); |
---|
| 1028 | |
---|
| 1029 | dw += 0.5*(b+num[j])*(a+sSq[j])*num[j]/denorm - |
---|
| 1030 | 0.5*(b+num[j]-1.0)*num[j]/(1.0+num[j]*w); |
---|
| 1031 | |
---|
| 1032 | dm += num[j]*(b+num[j])*(m-xBar[j])/denorm; |
---|
| 1033 | } |
---|
| 1034 | |
---|
| 1035 | g[0] = da; |
---|
| 1036 | g[1] = db; |
---|
| 1037 | g[2] = dw; |
---|
| 1038 | g[3] = dm; |
---|
| 1039 | return g; |
---|
| 1040 | } |
---|
| 1041 | |
---|
| 1042 | /** |
---|
| 1043 | * Subclass should implement this procedure to evaluate second-order |
---|
| 1044 | * gradient of the objective function |
---|
| 1045 | */ |
---|
| 1046 | protected double[] evaluateHessian(double[] x, int index){ |
---|
| 1047 | double[] h = new double[x.length]; |
---|
| 1048 | |
---|
| 1049 | // # of exemplars, # of dimensions |
---|
| 1050 | // which dimension and which variable for 'index' |
---|
| 1051 | int numExs = num.length; |
---|
| 1052 | double a,b,w,m; |
---|
| 1053 | // Take the 2nd-order derivative |
---|
| 1054 | switch(index){ |
---|
| 1055 | case 0: // a |
---|
| 1056 | a=x[0];b=x[1];w=x[2];m=x[3]; |
---|
| 1057 | |
---|
| 1058 | for(int j=0; j < numExs; j++){ |
---|
| 1059 | if(Double.isNaN(xBar[j])) continue; //All missing values |
---|
| 1060 | double denorm = (1.0+num[j]*w)*(a+sSq[j]) + |
---|
| 1061 | num[j]*(xBar[j]-m)*(xBar[j]-m); |
---|
| 1062 | |
---|
| 1063 | h[0] += 0.5*b/(a*a) |
---|
| 1064 | - 0.5*(b+num[j])*(1.0+num[j]*w)*(1.0+num[j]*w) |
---|
| 1065 | /(denorm*denorm); |
---|
| 1066 | |
---|
| 1067 | h[1] += 0.5*(1.0+num[j]*w)/denorm - 0.5/a; |
---|
| 1068 | |
---|
| 1069 | h[2] += 0.5*num[j]*num[j]*(b+num[j])* |
---|
| 1070 | (xBar[j]-m)*(xBar[j]-m)/(denorm*denorm); |
---|
| 1071 | |
---|
| 1072 | h[3] -= num[j]*(b+num[j])*(m-xBar[j]) |
---|
| 1073 | *(1.0+num[j]*w)/(denorm*denorm); |
---|
| 1074 | } |
---|
| 1075 | break; |
---|
| 1076 | |
---|
| 1077 | case 1: // b |
---|
| 1078 | a=x[0];b=x[1];w=x[2];m=x[3]; |
---|
| 1079 | |
---|
| 1080 | for(int j=0; j < numExs; j++){ |
---|
| 1081 | if(Double.isNaN(xBar[j])) continue; //All missing values |
---|
| 1082 | double denorm = (1.0+num[j]*w)*(a+sSq[j]) + |
---|
| 1083 | num[j]*(xBar[j]-m)*(xBar[j]-m); |
---|
| 1084 | |
---|
| 1085 | h[0] += 0.5*(1.0+num[j]*w)/denorm - 0.5/a; |
---|
| 1086 | |
---|
| 1087 | int halfNum = ((int)num[j])/2; |
---|
| 1088 | for(int z=1; z<=halfNum; z++) |
---|
| 1089 | h[1] += |
---|
| 1090 | 1.0/((b+num[j]-2.0*(double)z)*(b+num[j]-2.0*(double)z)); |
---|
| 1091 | if(num[j]/2.0 > halfNum) // num[j] is odd |
---|
| 1092 | h[1] -= 0.25*diffSndDervLnGamma(0.5*b); |
---|
| 1093 | |
---|
| 1094 | h[2] += 0.5*(a+sSq[j])*num[j]/denorm - |
---|
| 1095 | 0.5*num[j]/(1.0+num[j]*w); |
---|
| 1096 | |
---|
| 1097 | h[3] += num[j]*(m-xBar[j])/denorm; |
---|
| 1098 | } |
---|
| 1099 | break; |
---|
| 1100 | |
---|
| 1101 | case 2: // w |
---|
| 1102 | a=x[0];b=x[1];w=x[2];m=x[3]; |
---|
| 1103 | |
---|
| 1104 | for(int j=0; j < numExs; j++){ |
---|
| 1105 | if(Double.isNaN(xBar[j])) continue; //All missing values |
---|
| 1106 | double denorm = (1.0+num[j]*w)*(a+sSq[j]) + |
---|
| 1107 | num[j]*(xBar[j]-m)*(xBar[j]-m); |
---|
| 1108 | |
---|
| 1109 | h[0] += 0.5*num[j]*num[j]*(b+num[j])* |
---|
| 1110 | (xBar[j]-m)*(xBar[j]-m)/(denorm*denorm); |
---|
| 1111 | |
---|
| 1112 | h[1] += 0.5*(a+sSq[j])*num[j]/denorm - |
---|
| 1113 | 0.5*num[j]/(1.0+num[j]*w); |
---|
| 1114 | |
---|
| 1115 | h[2] += 0.5*(b+num[j]-1.0)*num[j]*num[j]/ |
---|
| 1116 | ((1.0+num[j]*w)*(1.0+num[j]*w)) - |
---|
| 1117 | 0.5*(b+num[j])*(a+sSq[j])*(a+sSq[j])* |
---|
| 1118 | num[j]*num[j]/(denorm*denorm); |
---|
| 1119 | |
---|
| 1120 | h[3] -= num[j]*num[j]*(b+num[j])* |
---|
| 1121 | (m-xBar[j])*(a+sSq[j])/(denorm*denorm); |
---|
| 1122 | } |
---|
| 1123 | break; |
---|
| 1124 | |
---|
| 1125 | case 3: // m |
---|
| 1126 | a=x[0];b=x[1];w=x[2];m=x[3]; |
---|
| 1127 | |
---|
| 1128 | for(int j=0; j < numExs; j++){ |
---|
| 1129 | if(Double.isNaN(xBar[j])) continue; //All missing values |
---|
| 1130 | double denorm = (1.0+num[j]*w)*(a+sSq[j]) + |
---|
| 1131 | num[j]*(xBar[j]-m)*(xBar[j]-m); |
---|
| 1132 | |
---|
| 1133 | h[0] -= num[j]*(b+num[j])*(m-xBar[j]) |
---|
| 1134 | *(1.0+num[j]*w)/(denorm*denorm); |
---|
| 1135 | |
---|
| 1136 | h[1] += num[j]*(m-xBar[j])/denorm; |
---|
| 1137 | |
---|
| 1138 | h[2] -= num[j]*num[j]*(b+num[j])* |
---|
| 1139 | (m-xBar[j])*(a+sSq[j])/(denorm*denorm); |
---|
| 1140 | |
---|
| 1141 | h[3] += num[j]*(b+num[j])* |
---|
| 1142 | ((1.0+num[j]*w)*(a+sSq[j])- |
---|
| 1143 | num[j]*(m-xBar[j])*(m-xBar[j])) |
---|
| 1144 | /(denorm*denorm); |
---|
| 1145 | } |
---|
| 1146 | } |
---|
| 1147 | |
---|
| 1148 | return h; |
---|
| 1149 | } |
---|
| 1150 | |
---|
| 1151 | /** |
---|
| 1152 | * Returns the revision string. |
---|
| 1153 | * |
---|
| 1154 | * @return the revision |
---|
| 1155 | */ |
---|
| 1156 | public String getRevision() { |
---|
| 1157 | return RevisionUtils.extract("$Revision: 5481 $"); |
---|
| 1158 | } |
---|
| 1159 | } |
---|