[4] | 1 | /* |
---|
| 2 | * This program is free software; you can redistribute it and/or modify |
---|
| 3 | * it under the terms of the GNU General Public License as published by |
---|
| 4 | * the Free Software Foundation; either version 2 of the License, or |
---|
| 5 | * (at your option) any later version. |
---|
| 6 | * |
---|
| 7 | * This program is distributed in the hope that it will be useful, |
---|
| 8 | * but WITHOUT ANY WARRANTY; without even the implied warranty of |
---|
| 9 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the |
---|
| 10 | * GNU General Public License for more details. |
---|
| 11 | * |
---|
| 12 | * You should have received a copy of the GNU General Public License |
---|
| 13 | * along with this program; if not, write to the Free Software |
---|
| 14 | * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA. |
---|
| 15 | */ |
---|
| 16 | |
---|
| 17 | /* |
---|
| 18 | * TLDSimple.java |
---|
| 19 | * Copyright (C) 2005 University of Waikato, Hamilton, New Zealand |
---|
| 20 | * |
---|
| 21 | */ |
---|
| 22 | |
---|
| 23 | package weka.classifiers.mi; |
---|
| 24 | |
---|
| 25 | import weka.classifiers.RandomizableClassifier; |
---|
| 26 | import weka.core.Capabilities; |
---|
| 27 | import weka.core.Instance; |
---|
| 28 | import weka.core.Instances; |
---|
| 29 | import weka.core.MultiInstanceCapabilitiesHandler; |
---|
| 30 | import weka.core.Optimization; |
---|
| 31 | import weka.core.Option; |
---|
| 32 | import weka.core.OptionHandler; |
---|
| 33 | import weka.core.RevisionUtils; |
---|
| 34 | import weka.core.TechnicalInformation; |
---|
| 35 | import weka.core.TechnicalInformationHandler; |
---|
| 36 | import weka.core.Utils; |
---|
| 37 | import weka.core.Capabilities.Capability; |
---|
| 38 | import weka.core.TechnicalInformation.Field; |
---|
| 39 | import weka.core.TechnicalInformation.Type; |
---|
| 40 | |
---|
| 41 | import java.util.Enumeration; |
---|
| 42 | import java.util.Random; |
---|
| 43 | import java.util.Vector; |
---|
| 44 | |
---|
| 45 | /** |
---|
| 46 | <!-- globalinfo-start --> |
---|
| 47 | * A simpler version of TLD, mu random but sigma^2 fixed and estimated via data.<br/> |
---|
| 48 | * <br/> |
---|
| 49 | * For more information see:<br/> |
---|
| 50 | * <br/> |
---|
| 51 | * Xin Xu (2003). Statistical learning in multiple instance problem. Hamilton, NZ. |
---|
| 52 | * <p/> |
---|
| 53 | <!-- globalinfo-end --> |
---|
| 54 | * |
---|
| 55 | <!-- technical-bibtex-start --> |
---|
| 56 | * BibTeX: |
---|
| 57 | * <pre> |
---|
| 58 | * @mastersthesis{Xu2003, |
---|
| 59 | * address = {Hamilton, NZ}, |
---|
| 60 | * author = {Xin Xu}, |
---|
| 61 | * note = {0657.594}, |
---|
| 62 | * school = {University of Waikato}, |
---|
| 63 | * title = {Statistical learning in multiple instance problem}, |
---|
| 64 | * year = {2003} |
---|
| 65 | * } |
---|
| 66 | * </pre> |
---|
| 67 | * <p/> |
---|
| 68 | <!-- technical-bibtex-end --> |
---|
| 69 | * |
---|
| 70 | <!-- options-start --> |
---|
| 71 | * Valid options are: <p/> |
---|
| 72 | * |
---|
| 73 | * <pre> -C |
---|
| 74 | * Set whether or not use empirical |
---|
| 75 | * log-odds cut-off instead of 0</pre> |
---|
| 76 | * |
---|
| 77 | * <pre> -R <numOfRuns> |
---|
| 78 | * Set the number of multiple runs |
---|
| 79 | * needed for searching the MLE.</pre> |
---|
| 80 | * |
---|
| 81 | * <pre> -S <num> |
---|
| 82 | * Random number seed. |
---|
| 83 | * (default 1)</pre> |
---|
| 84 | * |
---|
| 85 | * <pre> -D |
---|
| 86 | * If set, classifier is run in debug mode and |
---|
| 87 | * may output additional info to the console</pre> |
---|
| 88 | * |
---|
| 89 | <!-- options-end --> |
---|
| 90 | * |
---|
| 91 | * @author Eibe Frank (eibe@cs.waikato.ac.nz) |
---|
| 92 | * @author Xin Xu (xx5@cs.waikato.ac.nz) |
---|
| 93 | * @version $Revision: 5481 $ |
---|
| 94 | */ |
---|
public class TLDSimple
  extends RandomizableClassifier
  implements OptionHandler, MultiInstanceCapabilitiesHandler,
             TechnicalInformationHandler {

  /** for serialization */
  static final long serialVersionUID = 9040995947243286591L;

  /** The mean for each attribute of each positive exemplar (bag) */
  protected double[][] m_MeanP = null;

  /** The mean for each attribute of each negative exemplar (bag) */
  protected double[][] m_MeanN = null;

  /** The effective sum of weights of each positive exemplar in each dimension */
  protected double[][] m_SumP = null;

  /** The effective sum of weights of each negative exemplar in each dimension */
  protected double[][] m_SumN = null;

  /** Estimated sigma^2 in positive bags, one entry per dimension */
  protected double[] m_SgmSqP;

  /** Estimated sigma^2 in negative bags, one entry per dimension */
  protected double[] m_SgmSqN;

  /** The parameters (w, m per dimension) estimated for the positive class */
  protected double[] m_ParamsP = null;

  /** The parameters (w, m per dimension) estimated for the negative class */
  protected double[] m_ParamsN = null;

  /** The dimension of each exemplar, i.e. (numAttributes-2) */
  protected int m_Dimension = 0;

  /** The class label of each exemplar */
  protected double[] m_Class = null;

  /** The number of class labels in the data */
  protected int m_NumClasses = 2;

  /** The very small number representing zero */
  static public double ZERO = 1.0e-12;

  /** Number of optimization restarts used when searching for the MLE (option -R) */
  protected int m_Run = 1;

  /** Log-odds threshold above which a bag is classified as positive */
  protected double m_Cutoff;

  /** Whether to use an empirical log-odds cut-off instead of 0 (option -C) */
  protected boolean m_UseEmpiricalCutOff = false;

  /** Per-dimension accumulated log-likelihood ratio (updated as a side effect
   *  of likelihoodRatio) */
  private double[] m_LkRatio;

  /** String-free header structure of the relational (bag) attribute */
  private Instances m_Attribute = null;
---|
| 149 | /** |
---|
| 150 | * Returns a string describing this filter |
---|
| 151 | * |
---|
| 152 | * @return a description of the filter suitable for |
---|
| 153 | * displaying in the explorer/experimenter gui |
---|
| 154 | */ |
---|
| 155 | public String globalInfo() { |
---|
| 156 | return |
---|
| 157 | "A simpler version of TLD, mu random but sigma^2 fixed and estimated " |
---|
| 158 | + "via data.\n\n" |
---|
| 159 | + "For more information see:\n\n" |
---|
| 160 | + getTechnicalInformation().toString(); |
---|
| 161 | } |
---|
| 162 | |
---|
| 163 | /** |
---|
| 164 | * Returns an instance of a TechnicalInformation object, containing |
---|
| 165 | * detailed information about the technical background of this class, |
---|
| 166 | * e.g., paper reference or book this class is based on. |
---|
| 167 | * |
---|
| 168 | * @return the technical information about this class |
---|
| 169 | */ |
---|
| 170 | public TechnicalInformation getTechnicalInformation() { |
---|
| 171 | TechnicalInformation result; |
---|
| 172 | |
---|
| 173 | result = new TechnicalInformation(Type.MASTERSTHESIS); |
---|
| 174 | result.setValue(Field.AUTHOR, "Xin Xu"); |
---|
| 175 | result.setValue(Field.YEAR, "2003"); |
---|
| 176 | result.setValue(Field.TITLE, "Statistical learning in multiple instance problem"); |
---|
| 177 | result.setValue(Field.SCHOOL, "University of Waikato"); |
---|
| 178 | result.setValue(Field.ADDRESS, "Hamilton, NZ"); |
---|
| 179 | result.setValue(Field.NOTE, "0657.594"); |
---|
| 180 | |
---|
| 181 | return result; |
---|
| 182 | } |
---|
| 183 | |
---|
| 184 | /** |
---|
| 185 | * Returns default capabilities of the classifier. |
---|
| 186 | * |
---|
| 187 | * @return the capabilities of this classifier |
---|
| 188 | */ |
---|
| 189 | public Capabilities getCapabilities() { |
---|
| 190 | Capabilities result = super.getCapabilities(); |
---|
| 191 | result.disableAll(); |
---|
| 192 | |
---|
| 193 | // attributes |
---|
| 194 | result.enable(Capability.NOMINAL_ATTRIBUTES); |
---|
| 195 | result.enable(Capability.RELATIONAL_ATTRIBUTES); |
---|
| 196 | result.enable(Capability.MISSING_VALUES); |
---|
| 197 | |
---|
| 198 | // class |
---|
| 199 | result.enable(Capability.BINARY_CLASS); |
---|
| 200 | result.enable(Capability.MISSING_CLASS_VALUES); |
---|
| 201 | |
---|
| 202 | // other |
---|
| 203 | result.enable(Capability.ONLY_MULTIINSTANCE); |
---|
| 204 | |
---|
| 205 | return result; |
---|
| 206 | } |
---|
| 207 | |
---|
| 208 | /** |
---|
| 209 | * Returns the capabilities of this multi-instance classifier for the |
---|
| 210 | * relational data. |
---|
| 211 | * |
---|
| 212 | * @return the capabilities of this object |
---|
| 213 | * @see Capabilities |
---|
| 214 | */ |
---|
| 215 | public Capabilities getMultiInstanceCapabilities() { |
---|
| 216 | Capabilities result = super.getCapabilities(); |
---|
| 217 | result.disableAll(); |
---|
| 218 | |
---|
| 219 | // attributes |
---|
| 220 | result.enable(Capability.NOMINAL_ATTRIBUTES); |
---|
| 221 | result.enable(Capability.NUMERIC_ATTRIBUTES); |
---|
| 222 | result.enable(Capability.DATE_ATTRIBUTES); |
---|
| 223 | result.enable(Capability.MISSING_VALUES); |
---|
| 224 | |
---|
| 225 | // class |
---|
| 226 | result.disableAllClasses(); |
---|
| 227 | result.enable(Capability.NO_CLASS); |
---|
| 228 | |
---|
| 229 | return result; |
---|
| 230 | } |
---|
| 231 | |
---|
  /**
   * Builds the classifier: splits the training exemplars into positive and
   * negative bags, estimates per-dimension sigma^2 from the within-bag
   * variances, finds the MLE of the per-dimension parameters (w, m) for each
   * class via numeric optimization with m_Run random restarts, and finally
   * sets the log-odds cut-off (empirical if -C was given, otherwise
   * -log(pnum/nnum)).
   *
   * @param exs the training exemplars
   * @throws Exception if the model cannot be built properly
   */
  public void buildClassifier(Instances exs)throws Exception{
    // can classifier handle the data?
    getCapabilities().testWithFail(exs);

    // remove instances with missing class (work on a copy)
    exs = new Instances(exs);
    exs.deleteWithMissingClass();

    int numegs = exs.numInstances();
    // in the MI format, attribute 1 holds the relational (bag) data
    m_Dimension = exs.attribute(1).relation().numAttributes();
    m_Attribute = exs.attribute(1).relation().stringFreeStructure();
    Instances pos = new Instances(exs, 0), neg = new Instances(exs, 0);

    // Divide into two groups: class value 1 -> positive, else negative
    for(int u=0; u<numegs; u++){
      Instance example = exs.instance(u);
      if(example.classValue() == 1)
        pos.add(example);
      else
        neg.add(example);
    }
    int pnum = pos.numInstances(), nnum = neg.numInstances();

    // xBar (per-bag means) and n (per-bag effective weight sums)
    m_MeanP = new double[pnum][m_Dimension];
    m_SumP = new double[pnum][m_Dimension];
    m_MeanN = new double[nnum][m_Dimension];
    m_SumN = new double[nnum][m_Dimension];
    // w, m — two parameters per dimension, per class
    m_ParamsP = new double[2*m_Dimension];
    m_ParamsN = new double[2*m_Dimension];
    // \sigma^2 per dimension
    m_SgmSqP = new double[m_Dimension];
    m_SgmSqN = new double[m_Dimension];
    // S^2 — per-bag sample variances
    double[][] varP=new double[pnum][m_Dimension],
               varN=new double[nnum][m_Dimension];
    // number of exemplars per dimension that are not entirely missing
    double[] effNumExP=new double[m_Dimension],
             effNumExN=new double[m_Dimension];
    // accumulators for the method-of-moments starting values
    double[] pMM=new double[m_Dimension],
             nMM=new double[m_Dimension],
             pVM=new double[m_Dimension],
             nVM=new double[m_Dimension];
    // # of exemplars contributing no usable variance (e.g. single instance)
    double[] numOneInsExsP=new double[m_Dimension],
             numOneInsExsN=new double[m_Dimension];
    // sum_i(1/n_i)
    double[] pInvN = new double[m_Dimension], nInvN = new double[m_Dimension];

    // Extract metadata from the positive bags
    for(int v=0; v < pnum; v++){
      Instances pxi = pos.instance(v).relationalValue(1);
      for (int k=0; k<pxi.numAttributes(); k++) {
        m_MeanP[v][k] = pxi.meanOrMode(k);
        varP[v][k] = pxi.variance(k);
      }

      // w indexes the model dimension, t the bag attribute; with no class/id
      // columns inside the relational data they advance in lock-step
      for (int w=0,t=0; w < m_Dimension; w++,t++){
        if(varP[v][w] <= 0.0)
          varP[v][w] = 0.0;
        if(!Double.isNaN(m_MeanP[v][w])){

          // effective weight of this bag in this dimension (non-missing only)
          for(int u=0;u<pxi.numInstances();u++)
            if(!pxi.instance(u).isMissing(t))
              m_SumP[v][w] += pxi.instance(u).weight();

          pMM[w] += m_MeanP[v][w];
          pVM[w] += m_MeanP[v][w]*m_MeanP[v][w];
          if((m_SumP[v][w]>1) && (varP[v][w]>ZERO)){

            m_SgmSqP[w] += varP[v][w]*(m_SumP[v][w]-1.0)/m_SumP[v][w];

            effNumExP[w]++; // Not count exemplars with 1 instance
            pInvN[w] += 1.0/m_SumP[v][w];
          }
          else
            numOneInsExsP[w]++;
        }

      }
    }


    // Extract metadata from the negative bags (mirror of the loop above)
    for(int v=0; v < nnum; v++){
      Instances nxi = neg.instance(v).relationalValue(1);
      for (int k=0; k<nxi.numAttributes(); k++) {
        m_MeanN[v][k] = nxi.meanOrMode(k);
        varN[v][k] = nxi.variance(k);
      }

      for (int w=0,t=0; w < m_Dimension; w++,t++){

        if(varN[v][w] <= 0.0)
          varN[v][w] = 0.0;
        if(!Double.isNaN(m_MeanN[v][w])){
          for(int u=0;u<nxi.numInstances();u++)
            if(!nxi.instance(u).isMissing(t))
              m_SumN[v][w] += nxi.instance(u).weight();

          nMM[w] += m_MeanN[v][w];
          nVM[w] += m_MeanN[v][w]*m_MeanN[v][w];
          if((m_SumN[v][w]>1) && (varN[v][w]>ZERO)){
            m_SgmSqN[w] += varN[v][w]*(m_SumN[v][w]-1.0)/m_SumN[v][w];
            effNumExN[w]++; // Not count exemplars with 1 instance
            nInvN[w] += 1.0/m_SumN[v][w];
          }
          else
            numOneInsExsN[w]++;
        }
      }
    }

    // Expected \sigma^2.
    /* If m_SgmSqP[u] or m_SgmSqN[u] is 0, assign 0 to sigma^2;
     * otherwise m_SgmSqP / m_SgmSqN may become NaN.
     * Modified by Lin Dong (Sep. 2005)
     */
    for (int u=0; u < m_Dimension; u++){
      // For exemplars with only one instance, use avg(\sigma^2) of other exemplars
      if (m_SgmSqP[u]!=0)
        m_SgmSqP[u] /= (effNumExP[u]-pInvN[u]);
      else
        m_SgmSqP[u] = 0;
      if (m_SgmSqN[u]!=0)
        m_SgmSqN[u] /= (effNumExN[u]-nInvN[u]);
      else
        m_SgmSqN[u] = 0;

      // single-instance exemplars still count toward the moment estimates
      effNumExP[u] += numOneInsExsP[u];
      effNumExN[u] += numOneInsExsN[u];
      pMM[u] /= effNumExP[u];
      nMM[u] /= effNumExN[u];
      // unbiased sample variance of the bag means (starting value for w)
      pVM[u] = pVM[u]/(effNumExP[u]-1.0) - pMM[u]*pMM[u]*effNumExP[u]/(effNumExP[u]-1.0);
      nVM[u] = nVM[u]/(effNumExN[u]-1.0) - nMM[u]*nMM[u]*effNumExN[u]/(effNumExN[u]-1.0);
    }

    // Bounds and parameter values for each run
    double[][] bounds = new double[2][2];
    double[] pThisParam = new double[2],
             nThisParam = new double[2];

    // Initial values for parameters
    double w, m;
    Random whichEx = new Random(m_Seed);

    // Optimize each dimension independently
    for (int x=0; x < m_Dimension; x++){

      // Positive examplars: first run starts at the method-of-moments values
      pThisParam[0] = pVM[x]; // w
      if( pThisParam[0] <= ZERO)
        pThisParam[0] = 1.0;
      pThisParam[1] = pMM[x]; // m

      // Negative examplars: first run
      nThisParam[0] = nVM[x]; // w
      if(nThisParam[0] <= ZERO)
        nThisParam[0] = 1.0;
      nThisParam[1] = nMM[x]; // m

      // Bound constraints: only w is bounded (w > 0), m is free
      bounds[0][0] = ZERO; // w > 0
      bounds[0][1] = Double.NaN;
      bounds[1][0] = Double.NaN;
      bounds[1][1] = Double.NaN;

      double pminVal=Double.MAX_VALUE, nminVal=Double.MAX_VALUE;
      TLDSimple_Optm pOp=null, nOp=null;
      boolean isRunValid = true;
      double[] sumP=new double[pnum], meanP=new double[pnum];
      double[] sumN=new double[nnum], meanN=new double[nnum];

      // Slice out this dimension's statistics
      for(int p=0; p<pnum; p++){
        sumP[p] = m_SumP[p][x];
        meanP[p] = m_MeanP[p][x];
      }
      for(int q=0; q<nnum; q++){
        sumN[q] = m_SumN[q][x];
        meanN[q] = m_MeanN[q][x];
      }

      // Positive exemplars: m_Run restarts, keep the best (lowest) minimum
      for(int y=0; y<m_Run; y++){
        double thisMin;
        pOp = new TLDSimple_Optm();
        pOp.setNum(sumP);
        pOp.setSgmSq(m_SgmSqP[x]);
        if (getDebug())
          System.out.println("m_SgmSqP["+x+"]= " +m_SgmSqP[x]);
        pOp.setXBar(meanP);
        pThisParam = pOp.findArgmin(pThisParam, bounds);
        // findArgmin returns null when the iteration limit was reached:
        // resume from the current variable values until it converges
        while(pThisParam==null){
          pThisParam = pOp.getVarbValues();
          if (getDebug())
            System.out.println("!!! 200 iterations finished, not enough!");
          pThisParam = pOp.findArgmin(pThisParam, bounds);
        }

        thisMin = pOp.getMinFunction();
        if(!Double.isNaN(thisMin) && (thisMin<pminVal)){
          pminVal = thisMin;
          for(int z=0; z<2; z++)
            m_ParamsP[2*x+z] = pThisParam[z];
        }

        // A NaN minimum invalidates this run; repeat it with fresh parameters
        if(Double.isNaN(thisMin)){
          pThisParam = new double[2];
          isRunValid =false;
        }
        if(!isRunValid){ y--; isRunValid=true; }

        // Change the initial parameters and restart:
        // randomly pick one positive exemplar with a usable mean
        int pone = whichEx.nextInt(pnum);

        while(Double.isNaN(m_MeanP[pone][x]))
          pone = whichEx.nextInt(pnum);

        m = m_MeanP[pone][x];
        w = (m-pThisParam[1])*(m-pThisParam[1]);
        pThisParam[0] = w; // w
        pThisParam[1] = m; // m
      }

      // Negative exemplars: same restart scheme
      for(int y=0; y<m_Run; y++){
        double thisMin;
        nOp = new TLDSimple_Optm();
        nOp.setNum(sumN);
        nOp.setSgmSq(m_SgmSqN[x]);
        if (getDebug())
          System.out.println(m_SgmSqN[x]);
        nOp.setXBar(meanN);
        nThisParam = nOp.findArgmin(nThisParam, bounds);

        while(nThisParam==null){
          nThisParam = nOp.getVarbValues();
          if (getDebug())
            System.out.println("!!! 200 iterations finished, not enough!");
          nThisParam = nOp.findArgmin(nThisParam, bounds);
        }

        thisMin = nOp.getMinFunction();
        if(!Double.isNaN(thisMin) && (thisMin<nminVal)){
          nminVal = thisMin;
          for(int z=0; z<2; z++)
            m_ParamsN[2*x+z] = nThisParam[z];
        }

        if(Double.isNaN(thisMin)){
          nThisParam = new double[2];
          isRunValid =false;
        }

        if(!isRunValid){ y--; isRunValid=true; }

        // Change the initial parameters and restart:
        // randomly pick one negative exemplar with a usable mean
        int none = whichEx.nextInt(nnum);

        while(Double.isNaN(m_MeanN[none][x]))
          none = whichEx.nextInt(nnum);

        m = m_MeanN[none][x];
        w = (m-nThisParam[1])*(m-nThisParam[1]);
        nThisParam[0] = w; // w
        nThisParam[1] = m; // m
      }
    }

    // reset the diagnostic accumulator used by likelihoodRatio
    m_LkRatio = new double[m_Dimension];

    if(m_UseEmpiricalCutOff){
      // Find the empirical cut-off from the training bags' log-odds
      double[] pLogOdds=new double[pnum], nLogOdds=new double[nnum];
      for(int p=0; p<pnum; p++)
        pLogOdds[p] =
          likelihoodRatio(m_SumP[p], m_MeanP[p]);

      for(int q=0; q<nnum; q++)
        nLogOdds[q] =
          likelihoodRatio(m_SumN[q], m_MeanN[q]);

      // Update m_Cutoff
      findCutOff(pLogOdds, nLogOdds);
    }
    else
      m_Cutoff = -Math.log((double)pnum/(double)nnum);

    if (getDebug())
      System.err.println("\n\n???Cut-off="+m_Cutoff);
  }
---|
| 564 | |
---|
| 565 | /** |
---|
| 566 | * |
---|
| 567 | * @param ex the given test exemplar |
---|
| 568 | * @return the classification |
---|
| 569 | * @throws Exception if the exemplar could not be classified |
---|
| 570 | * successfully |
---|
| 571 | */ |
---|
| 572 | public double classifyInstance(Instance ex)throws Exception{ |
---|
| 573 | //Instance ex = new Exemplar(e); |
---|
| 574 | Instances exi = ex.relationalValue(1); |
---|
| 575 | double[] n = new double[m_Dimension]; |
---|
| 576 | double [] xBar = new double[m_Dimension]; |
---|
| 577 | for (int i=0; i<exi.numAttributes() ; i++) |
---|
| 578 | xBar[i] = exi.meanOrMode(i); |
---|
| 579 | |
---|
| 580 | for (int w=0, t=0; w < m_Dimension; w++, t++){ |
---|
| 581 | // if((t==m_ClassIndex) || (t==m_IdIndex)) |
---|
| 582 | //t++; |
---|
| 583 | for(int u=0;u<exi.numInstances();u++) |
---|
| 584 | if(!exi.instance(u).isMissing(t)) |
---|
| 585 | n[w] += exi.instance(u).weight(); |
---|
| 586 | } |
---|
| 587 | |
---|
| 588 | double logOdds = likelihoodRatio(n, xBar); |
---|
| 589 | return (logOdds > m_Cutoff) ? 1 : 0 ; |
---|
| 590 | } |
---|
| 591 | |
---|
| 592 | /** |
---|
| 593 | * Computes the distribution for a given exemplar |
---|
| 594 | * |
---|
| 595 | * @param ex the exemplar for which distribution is computed |
---|
| 596 | * @return the distribution |
---|
| 597 | * @throws Exception if the distribution can't be computed successfully |
---|
| 598 | */ |
---|
| 599 | public double[] distributionForInstance(Instance ex) throws Exception { |
---|
| 600 | |
---|
| 601 | double[] distribution = new double[2]; |
---|
| 602 | Instances exi = ex.relationalValue(1); |
---|
| 603 | double[] n = new double[m_Dimension]; |
---|
| 604 | double[] xBar = new double[m_Dimension]; |
---|
| 605 | for (int i = 0; i < exi.numAttributes() ; i++) |
---|
| 606 | xBar[i] = exi.meanOrMode(i); |
---|
| 607 | |
---|
| 608 | for (int w = 0, t = 0; w < m_Dimension; w++, t++){ |
---|
| 609 | for (int u = 0; u < exi.numInstances(); u++) |
---|
| 610 | if (!exi.instance(u).isMissing(t)) |
---|
| 611 | n[w] += exi.instance(u).weight(); |
---|
| 612 | } |
---|
| 613 | |
---|
| 614 | double logOdds = likelihoodRatio(n, xBar); |
---|
| 615 | |
---|
| 616 | // returned logOdds value has been divided by m_Dimension to avoid |
---|
| 617 | // Math.exp(logOdds) getting too large or too small, |
---|
| 618 | // that may result in two fixed distribution value (1 or 0). |
---|
| 619 | distribution[0] = 1 / (1 + Math.exp(logOdds)); // Prob. for class 0 (negative) |
---|
| 620 | distribution[1] = 1 - distribution[0]; |
---|
| 621 | |
---|
| 622 | return distribution; |
---|
| 623 | } |
---|
| 624 | |
---|
  /**
   * Computes the log-likelihood ratio (positive vs. negative model) for a bag
   * summarized by its per-dimension effective weights and means. Dimensions
   * whose mean is NaN (all values missing) are skipped.
   *
   * Side effect: accumulates the per-dimension ratio into m_LkRatio on every
   * call (including calls made during classification).
   *
   * NOTE(review): the return expression "LLP - LLN / m_Dimension" divides
   * only LLN by m_Dimension due to operator precedence. The comment in
   * distributionForInstance says the *whole* returned log-odds is divided by
   * m_Dimension, which would be "(LLP - LLN) / m_Dimension". Confirm the
   * intended formula before changing — the empirical cut-off is computed on
   * the same scale, so a change here affects classification.
   *
   * @param n    per-dimension effective sum of instance weights of the bag
   * @param xBar per-dimension mean of the bag
   * @return the (scaled) log-likelihood ratio
   */
  private double likelihoodRatio(double[] n, double[] xBar){
    double LLP = 0.0, LLN = 0.0;

    for (int x=0; x<m_Dimension; x++){
      if(Double.isNaN(xBar[x])) continue; // All missing values

      // Log-likelihood under the positive model
      double w=m_ParamsP[2*x], m=m_ParamsP[2*x+1];
      double llp = Math.log(w*n[x]+m_SgmSqP[x])
        + n[x]*(m-xBar[x])*(m-xBar[x])/(w*n[x]+m_SgmSqP[x]);
      LLP -= llp;

      // Log-likelihood under the negative model
      w=m_ParamsN[2*x]; m=m_ParamsN[2*x+1];
      double lln = Math.log(w*n[x]+m_SgmSqN[x])
        + n[x]*(m-xBar[x])*(m-xBar[x])/(w*n[x]+m_SgmSqN[x]);
      LLN -= lln;

      m_LkRatio[x] += llp - lln;
    }

    return LLP - LLN / m_Dimension;
  }
---|
| 654 | |
---|
  /**
   * Sets m_Cutoff to the empirical log-odds threshold that maximizes the
   * number of correctly separated training bags. Both lists are scanned in
   * ascending order; at each candidate split the accuracy is the number of
   * negatives below it (fstAccu) plus the number of positives at or above it
   * (sndAccu). Ties are broken in favor of the split closest to 0.
   *
   * @param pos log-odds of the positive training bags
   * @param neg log-odds of the negative training bags
   */
  private void findCutOff(double[] pos, double[] neg){
    int[] pOrder = Utils.sort(pos),
          nOrder = Utils.sort(neg);
    int pNum = pos.length, nNum = neg.length, count, p=0, n=0;
    // fstAccu: negatives correctly below the split;
    // sndAccu: positives correctly at/above it (all pNum to start)
    double fstAccu=0.0, sndAccu=(double)pNum, split;
    double maxAccu = 0, minDistTo0 = Double.MAX_VALUE;

    // Skip the run of negatives that lie below the smallest positive
    for(;(n<nNum)&&(pos[pOrder[0]]>=neg[nOrder[n]]); n++, fstAccu++);

    if(n>=nNum){ // totally seperate: cut midway between the two groups
      m_Cutoff = (neg[nOrder[nNum-1]]+pos[pOrder[0]])/2.0;
      return;
    }

    count=n;
    while((p<pNum)&&(n<nNum)){
      // Compare the next in the two lists
      if(pos[pOrder[p]]>=neg[nOrder[n]]){ // Neg has less log-odds
        fstAccu += 1.0;
        split=neg[nOrder[n]];
        n++;
      }
      else{
        sndAccu -= 1.0;
        split=pos[pOrder[p]];
        p++;
      }
      count++;

      // Keep the split with the best accuracy; on ties prefer the one
      // closest to zero log-odds
      if ((fstAccu+sndAccu > maxAccu) ||
          ((fstAccu+sndAccu == maxAccu) && (Math.abs(split)<minDistTo0))){
        maxAccu = fstAccu+sndAccu;
        m_Cutoff = split;
        minDistTo0 = Math.abs(split);
      }
    }
  }
---|
| 718 | |
---|
| 719 | /** |
---|
| 720 | * Returns an enumeration describing the available options |
---|
| 721 | * |
---|
| 722 | * @return an enumeration of all the available options |
---|
| 723 | */ |
---|
| 724 | public Enumeration listOptions() { |
---|
| 725 | Vector result = new Vector(); |
---|
| 726 | |
---|
| 727 | result.addElement(new Option( |
---|
| 728 | "\tSet whether or not use empirical\n" |
---|
| 729 | + "\tlog-odds cut-off instead of 0", |
---|
| 730 | "C", 0, "-C")); |
---|
| 731 | |
---|
| 732 | result.addElement(new Option( |
---|
| 733 | "\tSet the number of multiple runs \n" |
---|
| 734 | + "\tneeded for searching the MLE.", |
---|
| 735 | "R", 1, "-R <numOfRuns>")); |
---|
| 736 | |
---|
| 737 | Enumeration enu = super.listOptions(); |
---|
| 738 | while (enu.hasMoreElements()) { |
---|
| 739 | result.addElement(enu.nextElement()); |
---|
| 740 | } |
---|
| 741 | |
---|
| 742 | return result.elements(); |
---|
| 743 | } |
---|
| 744 | |
---|
| 745 | /** |
---|
| 746 | * Parses a given list of options. <p/> |
---|
| 747 | * |
---|
| 748 | <!-- options-start --> |
---|
| 749 | * Valid options are: <p/> |
---|
| 750 | * |
---|
| 751 | * <pre> -C |
---|
| 752 | * Set whether or not use empirical |
---|
| 753 | * log-odds cut-off instead of 0</pre> |
---|
| 754 | * |
---|
| 755 | * <pre> -R <numOfRuns> |
---|
| 756 | * Set the number of multiple runs |
---|
| 757 | * needed for searching the MLE.</pre> |
---|
| 758 | * |
---|
| 759 | * <pre> -S <num> |
---|
| 760 | * Random number seed. |
---|
| 761 | * (default 1)</pre> |
---|
| 762 | * |
---|
| 763 | * <pre> -D |
---|
| 764 | * If set, classifier is run in debug mode and |
---|
| 765 | * may output additional info to the console</pre> |
---|
| 766 | * |
---|
| 767 | <!-- options-end --> |
---|
| 768 | * |
---|
| 769 | * @param options the list of options as an array of strings |
---|
| 770 | * @throws Exception if an option is not supported |
---|
| 771 | */ |
---|
| 772 | public void setOptions(String[] options) throws Exception{ |
---|
| 773 | setDebug(Utils.getFlag('D', options)); |
---|
| 774 | |
---|
| 775 | setUsingCutOff(Utils.getFlag('C', options)); |
---|
| 776 | |
---|
| 777 | String runString = Utils.getOption('R', options); |
---|
| 778 | if (runString.length() != 0) |
---|
| 779 | setNumRuns(Integer.parseInt(runString)); |
---|
| 780 | else |
---|
| 781 | setNumRuns(1); |
---|
| 782 | |
---|
| 783 | super.setOptions(options); |
---|
| 784 | } |
---|
| 785 | |
---|
| 786 | /** |
---|
| 787 | * Gets the current settings of the Classifier. |
---|
| 788 | * |
---|
| 789 | * @return an array of strings suitable for passing to setOptions |
---|
| 790 | */ |
---|
| 791 | public String[] getOptions() { |
---|
| 792 | Vector result; |
---|
| 793 | String[] options; |
---|
| 794 | int i; |
---|
| 795 | |
---|
| 796 | result = new Vector(); |
---|
| 797 | options = super.getOptions(); |
---|
| 798 | for (i = 0; i < options.length; i++) |
---|
| 799 | result.add(options[i]); |
---|
| 800 | |
---|
| 801 | if (getDebug()) |
---|
| 802 | result.add("-D"); |
---|
| 803 | |
---|
| 804 | if (getUsingCutOff()) |
---|
| 805 | result.add("-C"); |
---|
| 806 | |
---|
| 807 | result.add("-R"); |
---|
| 808 | result.add("" + getNumRuns()); |
---|
| 809 | |
---|
| 810 | return (String[]) result.toArray(new String[result.size()]); |
---|
| 811 | } |
---|
| 812 | |
---|
| 813 | /** |
---|
| 814 | * Returns the tip text for this property |
---|
| 815 | * |
---|
| 816 | * @return tip text for this property suitable for |
---|
| 817 | * displaying in the explorer/experimenter gui |
---|
| 818 | */ |
---|
| 819 | public String numRunsTipText() { |
---|
| 820 | return "The number of runs to perform."; |
---|
| 821 | } |
---|
| 822 | |
---|
| 823 | /** |
---|
| 824 | * Sets the number of runs to perform. |
---|
| 825 | * |
---|
| 826 | * @param numRuns the number of runs to perform |
---|
| 827 | */ |
---|
| 828 | public void setNumRuns(int numRuns) { |
---|
| 829 | m_Run = numRuns; |
---|
| 830 | } |
---|
| 831 | |
---|
| 832 | /** |
---|
| 833 | * Returns the number of runs to perform. |
---|
| 834 | * |
---|
| 835 | * @return the number of runs to perform |
---|
| 836 | */ |
---|
| 837 | public int getNumRuns() { |
---|
| 838 | return m_Run; |
---|
| 839 | } |
---|
| 840 | |
---|
| 841 | /** |
---|
| 842 | * Returns the tip text for this property |
---|
| 843 | * |
---|
| 844 | * @return tip text for this property suitable for |
---|
| 845 | * displaying in the explorer/experimenter gui |
---|
| 846 | */ |
---|
| 847 | public String usingCutOffTipText() { |
---|
| 848 | return "Whether to use an empirical cutoff."; |
---|
| 849 | } |
---|
| 850 | |
---|
| 851 | /** |
---|
| 852 | * Sets whether to use an empirical cutoff. |
---|
| 853 | * |
---|
| 854 | * @param cutOff whether to use an empirical cutoff |
---|
| 855 | */ |
---|
| 856 | public void setUsingCutOff (boolean cutOff) { |
---|
| 857 | m_UseEmpiricalCutOff =cutOff; |
---|
| 858 | } |
---|
| 859 | |
---|
| 860 | /** |
---|
| 861 | * Returns whether an empirical cutoff is used |
---|
| 862 | * |
---|
| 863 | * @return true if an empirical cutoff is used |
---|
| 864 | */ |
---|
| 865 | public boolean getUsingCutOff() { |
---|
| 866 | return m_UseEmpiricalCutOff ; |
---|
| 867 | } |
---|
| 868 | |
---|
| 869 | /** |
---|
| 870 | * Gets a string describing the classifier. |
---|
| 871 | * |
---|
| 872 | * @return a string describing the classifer built. |
---|
| 873 | */ |
---|
| 874 | public String toString(){ |
---|
| 875 | StringBuffer text = new StringBuffer("\n\nTLDSimple:\n"); |
---|
| 876 | double sgm, w, m; |
---|
| 877 | for (int x=0, y=0; x<m_Dimension; x++, y++){ |
---|
| 878 | // if((x==m_ClassIndex) || (x==m_IdIndex)) |
---|
| 879 | //y++; |
---|
| 880 | sgm = m_SgmSqP[x]; |
---|
| 881 | w=m_ParamsP[2*x]; |
---|
| 882 | m=m_ParamsP[2*x+1]; |
---|
| 883 | text.append("\n"+m_Attribute.attribute(y).name()+"\nPositive: "+ |
---|
| 884 | "sigma^2="+sgm+", w="+w+", m="+m+"\n"); |
---|
| 885 | sgm = m_SgmSqN[x]; |
---|
| 886 | w=m_ParamsN[2*x]; |
---|
| 887 | m=m_ParamsN[2*x+1]; |
---|
| 888 | text.append("Negative: "+ |
---|
| 889 | "sigma^2="+sgm+", w="+w+", m="+m+"\n"); |
---|
| 890 | } |
---|
| 891 | |
---|
| 892 | return text.toString(); |
---|
| 893 | } |
---|
| 894 | |
---|
| 895 | /** |
---|
| 896 | * Returns the revision string. |
---|
| 897 | * |
---|
| 898 | * @return the revision |
---|
| 899 | */ |
---|
| 900 | public String getRevision() { |
---|
| 901 | return RevisionUtils.extract("$Revision: 5481 $"); |
---|
| 902 | } |
---|
| 903 | |
---|
| 904 | /** |
---|
| 905 | * Main method for testing. |
---|
| 906 | * |
---|
| 907 | * @param args the options for the classifier |
---|
| 908 | */ |
---|
| 909 | public static void main(String[] args) { |
---|
| 910 | runClassifier(new TLDSimple(), args); |
---|
| 911 | } |
---|
| 912 | } |
---|
| 913 | |
---|
| 914 | class TLDSimple_Optm extends Optimization { |
---|
| 915 | |
---|
| 916 | private double[] num; |
---|
| 917 | private double sSq; |
---|
| 918 | private double[] xBar; |
---|
| 919 | |
---|
| 920 | public void setNum(double[] n) {num = n;} |
---|
| 921 | public void setSgmSq(double s){ |
---|
| 922 | |
---|
| 923 | sSq = s; |
---|
| 924 | } |
---|
| 925 | public void setXBar(double[] x){xBar = x;} |
---|
| 926 | |
---|
| 927 | /** |
---|
| 928 | * Implement this procedure to evaluate objective |
---|
| 929 | * function to be minimized |
---|
| 930 | */ |
---|
| 931 | protected double objectiveFunction(double[] x){ |
---|
| 932 | int numExs = num.length; |
---|
| 933 | double NLL=0; // Negative Log-Likelihood |
---|
| 934 | |
---|
| 935 | double w=x[0], m=x[1]; |
---|
| 936 | for(int j=0; j < numExs; j++){ |
---|
| 937 | |
---|
| 938 | if(Double.isNaN(xBar[j])) continue; // All missing values |
---|
| 939 | double bag=0; |
---|
| 940 | |
---|
| 941 | bag += Math.log(w*num[j]+sSq); |
---|
| 942 | |
---|
| 943 | if(Double.isNaN(bag) && m_Debug){ |
---|
| 944 | System.out.println("???????????1: "+w+" "+m |
---|
| 945 | +"|x-: "+xBar[j] + |
---|
| 946 | "|n: "+num[j] + "|S^2: "+sSq); |
---|
| 947 | //System.exit(1); |
---|
| 948 | } |
---|
| 949 | |
---|
| 950 | bag += num[j]*(m-xBar[j])*(m-xBar[j])/(w*num[j]+sSq); |
---|
| 951 | if(Double.isNaN(bag) && m_Debug){ |
---|
| 952 | System.out.println("???????????2: "+w+" "+m |
---|
| 953 | +"|x-: "+xBar[j] + |
---|
| 954 | "|n: "+num[j] + "|S^2: "+sSq); |
---|
| 955 | //System.exit(1); |
---|
| 956 | } |
---|
| 957 | |
---|
| 958 | //if(bag<0) bag=0; |
---|
| 959 | NLL += bag; |
---|
| 960 | } |
---|
| 961 | |
---|
| 962 | //System.out.println("???????????NLL:"+NLL); |
---|
| 963 | return NLL; |
---|
| 964 | } |
---|
| 965 | |
---|
| 966 | /** |
---|
| 967 | * Subclass should implement this procedure to evaluate gradient |
---|
| 968 | * of the objective function |
---|
| 969 | */ |
---|
| 970 | protected double[] evaluateGradient(double[] x){ |
---|
| 971 | double[] g = new double[x.length]; |
---|
| 972 | int numExs = num.length; |
---|
| 973 | |
---|
| 974 | double w=x[0],m=x[1]; |
---|
| 975 | double dw=0.0, dm=0.0; |
---|
| 976 | |
---|
| 977 | for(int j=0; j < numExs; j++){ |
---|
| 978 | |
---|
| 979 | if(Double.isNaN(xBar[j])) continue; // All missing values |
---|
| 980 | dw += num[j]/(w*num[j]+sSq) |
---|
| 981 | - num[j]*num[j]*(m-xBar[j])*(m-xBar[j])/((w*num[j]+sSq)*(w*num[j]+sSq)); |
---|
| 982 | |
---|
| 983 | dm += 2.0*num[j]*(m-xBar[j])/(w*num[j]+sSq); |
---|
| 984 | } |
---|
| 985 | |
---|
| 986 | g[0] = dw; |
---|
| 987 | g[1] = dm; |
---|
| 988 | return g; |
---|
| 989 | } |
---|
| 990 | |
---|
| 991 | /** |
---|
| 992 | * Subclass should implement this procedure to evaluate second-order |
---|
| 993 | * gradient of the objective function |
---|
| 994 | */ |
---|
| 995 | protected double[] evaluateHessian(double[] x, int index){ |
---|
| 996 | double[] h = new double[x.length]; |
---|
| 997 | |
---|
| 998 | // # of exemplars, # of dimensions |
---|
| 999 | // which dimension and which variable for 'index' |
---|
| 1000 | int numExs = num.length; |
---|
| 1001 | double w,m; |
---|
| 1002 | // Take the 2nd-order derivative |
---|
| 1003 | switch(index){ |
---|
| 1004 | case 0: // w |
---|
| 1005 | w=x[0];m=x[1]; |
---|
| 1006 | |
---|
| 1007 | for(int j=0; j < numExs; j++){ |
---|
| 1008 | if(Double.isNaN(xBar[j])) continue; //All missing values |
---|
| 1009 | |
---|
| 1010 | h[0] += 2.0*Math.pow(num[j],3)*(m-xBar[j])*(m-xBar[j])/Math.pow(w*num[j]+sSq,3) |
---|
| 1011 | - num[j]*num[j]/((w*num[j]+sSq)*(w*num[j]+sSq)); |
---|
| 1012 | |
---|
| 1013 | h[1] -= 2.0*(m-xBar[j])*num[j]*num[j]/((num[j]*w+sSq)*(num[j]*w+sSq)); |
---|
| 1014 | } |
---|
| 1015 | break; |
---|
| 1016 | |
---|
| 1017 | case 1: // m |
---|
| 1018 | w=x[0];m=x[1]; |
---|
| 1019 | |
---|
| 1020 | for(int j=0; j < numExs; j++){ |
---|
| 1021 | if(Double.isNaN(xBar[j])) continue; //All missing values |
---|
| 1022 | |
---|
| 1023 | h[0] -= 2.0*(m-xBar[j])*num[j]*num[j]/((num[j]*w+sSq)*(num[j]*w+sSq)); |
---|
| 1024 | |
---|
| 1025 | h[1] += 2.0*num[j]/(w*num[j]+sSq); |
---|
| 1026 | } |
---|
| 1027 | } |
---|
| 1028 | |
---|
| 1029 | return h; |
---|
| 1030 | } |
---|
| 1031 | |
---|
| 1032 | /** |
---|
| 1033 | * Returns the revision string. |
---|
| 1034 | * |
---|
| 1035 | * @return the revision |
---|
| 1036 | */ |
---|
| 1037 | public String getRevision() { |
---|
| 1038 | return RevisionUtils.extract("$Revision: 5481 $"); |
---|
| 1039 | } |
---|
| 1040 | } |
---|