/*
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 2 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 */

/*
 * TLD.java
 * Copyright (C) 2005 University of Waikato, Hamilton, New Zealand
 *
 */

package weka.classifiers.mi;

import weka.classifiers.RandomizableClassifier;
import weka.core.Capabilities;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.MultiInstanceCapabilitiesHandler;
import weka.core.Optimization;
import weka.core.Option;
import weka.core.OptionHandler;
import weka.core.RevisionUtils;
import weka.core.TechnicalInformation;
import weka.core.TechnicalInformationHandler;
import weka.core.Utils;
import weka.core.Capabilities.Capability;
import weka.core.TechnicalInformation.Field;
import weka.core.TechnicalInformation.Type;

import java.util.Enumeration;
import java.util.Random;
import java.util.Vector;

/**
 <!-- globalinfo-start -->
 * Two-Level Distribution approach: changes the starting value of the search algorithm, supplements the cut-off modification and checks missing values.<br/>
 * <br/>
 * For more information see:<br/>
 * <br/>
 * Xin Xu (2003). Statistical learning in multiple instance problem. Hamilton, NZ.
 * <p/>
 <!-- globalinfo-end -->
 *
 <!-- technical-bibtex-start -->
 * BibTeX:
 * <pre>
 * &#64;mastersthesis{Xu2003,
 *    address = {Hamilton, NZ},
 *    author = {Xin Xu},
 *    note = {0657.594},
 *    school = {University of Waikato},
 *    title = {Statistical learning in multiple instance problem},
 *    year = {2003}
 * }
 * </pre>
 * <p/>
 <!-- technical-bibtex-end -->
 *
 <!-- options-start -->
 * Valid options are: <p/>
 *
 * <pre> -C
 *  Set whether or not to use the empirical
 *  log-odds cut-off instead of 0</pre>
 *
 * <pre> -R &lt;numOfRuns&gt;
 *  Set the number of multiple runs
 *  needed for searching the MLE.</pre>
 *
 * <pre> -S &lt;num&gt;
 *  Random number seed.
 *  (default 1)</pre>
 *
 * <pre> -D
 *  If set, classifier is run in debug mode and
 *  may output additional info to the console</pre>
 *
 <!-- options-end -->
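 *
 * A minimal command-line usage sketch (the dataset name is hypothetical;
 * assumes a multi-instance ARFF file with a bag identifier, a relational
 * attribute holding the instances, and a binary class):
 * <pre>
 * java weka.classifiers.mi.TLD -C -R 5 -S 1 -t musk1.arff
 * </pre>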
 *
 * @author Eibe Frank (eibe@cs.waikato.ac.nz)
 * @author Xin Xu (xx5@cs.waikato.ac.nz)
 * @version $Revision: 5481 $
 */
public class TLD
  extends RandomizableClassifier
  implements OptionHandler, MultiInstanceCapabilitiesHandler,
             TechnicalInformationHandler {

  /** for serialization */
  static final long serialVersionUID = 6657315525171152210L;

  /** The mean for each attribute of each positive exemplar */
  protected double[][] m_MeanP = null;

  /** The variance for each attribute of each positive exemplar */
  protected double[][] m_VarianceP = null;

  /** The mean for each attribute of each negative exemplar */
  protected double[][] m_MeanN = null;

  /** The variance for each attribute of each negative exemplar */
  protected double[][] m_VarianceN = null;

  /** The effective sum of weights of each positive exemplar in each dimension */
  protected double[][] m_SumP = null;

  /** The effective sum of weights of each negative exemplar in each dimension */
  protected double[][] m_SumN = null;

  /** The parameters to be estimated for each positive exemplar */
  protected double[] m_ParamsP = null;

  /** The parameters to be estimated for each negative exemplar */
  protected double[] m_ParamsN = null;

  /** The dimension of each exemplar, i.e. (numAttributes-2) */
  protected int m_Dimension = 0;

  /** The class label of each exemplar */
  protected double[] m_Class = null;

  /** The number of class labels in the data */
  protected int m_NumClasses = 2;

  /** The very small number representing zero */
  public static double ZERO = 1.0e-6;

  /** The number of runs to perform */
  protected int m_Run = 1;

  /** The cut-off on the log-odds used for classification */
  protected double m_Cutoff;

  /** Whether to use the empirical log-odds cut-off instead of 0 */
  protected boolean m_UseEmpiricalCutOff = false;

  /**
   * Returns a string describing this classifier
   *
   * @return a description of the classifier suitable for
   * displaying in the explorer/experimenter gui
   */
  public String globalInfo() {
    return
        "Two-Level Distribution approach: changes the starting value of "
      + "the search algorithm, supplements the cut-off modification and "
      + "checks missing values.\n\n"
      + "For more information see:\n\n"
      + getTechnicalInformation().toString();
  }

  /**
   * Returns an instance of a TechnicalInformation object, containing
   * detailed information about the technical background of this class,
   * e.g., paper reference or book this class is based on.
   *
   * @return the technical information about this class
   */
  public TechnicalInformation getTechnicalInformation() {
    TechnicalInformation result;

    result = new TechnicalInformation(Type.MASTERSTHESIS);
    result.setValue(Field.AUTHOR, "Xin Xu");
    result.setValue(Field.YEAR, "2003");
    result.setValue(Field.TITLE, "Statistical learning in multiple instance problem");
    result.setValue(Field.SCHOOL, "University of Waikato");
    result.setValue(Field.ADDRESS, "Hamilton, NZ");
    result.setValue(Field.NOTE, "0657.594");

    return result;
  }

  /**
   * Returns default capabilities of the classifier.
   *
   * @return the capabilities of this classifier
   */
  public Capabilities getCapabilities() {
    Capabilities result = super.getCapabilities();
    result.disableAll();

    // attributes
    result.enable(Capability.NOMINAL_ATTRIBUTES);
    result.enable(Capability.RELATIONAL_ATTRIBUTES);
    result.enable(Capability.MISSING_VALUES);

    // class
    result.enable(Capability.BINARY_CLASS);
    result.enable(Capability.MISSING_CLASS_VALUES);

    // other
    result.enable(Capability.ONLY_MULTIINSTANCE);

    return result;
  }

  /**
   * Returns the capabilities of this multi-instance classifier for the
   * relational data.
   *
   * @return the capabilities of this object
   * @see Capabilities
   */
  public Capabilities getMultiInstanceCapabilities() {
    Capabilities result = super.getCapabilities();
    result.disableAll();

    // attributes
    result.enable(Capability.NUMERIC_ATTRIBUTES);
    result.enable(Capability.MISSING_VALUES);

    // class
    result.disableAllClasses();
    result.enable(Capability.NO_CLASS);

    return result;
  }

  /**
   * Builds the classifier from the given training exemplars.
   *
   * @param exs the training exemplars
   * @throws Exception if the model cannot be built properly
   */
  public void buildClassifier(Instances exs) throws Exception {
    // can classifier handle the data?
    getCapabilities().testWithFail(exs);

    // remove instances with missing class
    exs = new Instances(exs);
    exs.deleteWithMissingClass();

    int numegs = exs.numInstances();
    m_Dimension = exs.attribute(1).relation().numAttributes();
    Instances pos = new Instances(exs, 0), neg = new Instances(exs, 0);

    for (int u = 0; u < numegs; u++) {
      Instance example = exs.instance(u);
      if (example.classValue() == 1)
        pos.add(example);
      else
        neg.add(example);
    }

    int pnum = pos.numInstances(), nnum = neg.numInstances();

    m_MeanP = new double[pnum][m_Dimension];
    m_VarianceP = new double[pnum][m_Dimension];
    m_SumP = new double[pnum][m_Dimension];
    m_MeanN = new double[nnum][m_Dimension];
    m_VarianceN = new double[nnum][m_Dimension];
    m_SumN = new double[nnum][m_Dimension];
    m_ParamsP = new double[4 * m_Dimension];
    m_ParamsN = new double[4 * m_Dimension];

    // Estimation of the parameters: as the start value for search
    double[] pSumVal = new double[m_Dimension], // for m
        nSumVal = new double[m_Dimension];
    double[] maxVarsP = new double[m_Dimension], // for a
        maxVarsN = new double[m_Dimension];
    // Mean of sample variances: for b, b=a/E(\sigma^2)+2
    double[] varMeanP = new double[m_Dimension],
        varMeanN = new double[m_Dimension];
    // Variances of sample means: for w, w=E[var(\mu)]/E[\sigma^2]
    double[] meanVarP = new double[m_Dimension],
        meanVarN = new double[m_Dimension];
    // number of exemplars without all values missing
    double[] numExsP = new double[m_Dimension],
        numExsN = new double[m_Dimension];
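
    // A sketch of the method-of-moments starting values derived below:
    // a is set to the largest sample variance, b = a/E[\sigma^2] + 2 so
    // that a/(b-2) matches the mean sample variance, w is the ratio of
    // the variance of the bag means to the mean within-bag variance, and
    // m is the average of the bag means.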

    // Extract metadata from both positive and negative bags
    for (int v = 0; v < pnum; v++) {
      /* Exemplar px = pos.exemplar(v);
         m_MeanP[v] = px.meanOrMode();
         m_VarianceP[v] = px.variance();
         Instances pxi = px.getInstances();
       */

      Instances pxi = pos.instance(v).relationalValue(1);
      for (int k = 0; k < pxi.numAttributes(); k++) {
        m_MeanP[v][k] = pxi.meanOrMode(k);
        m_VarianceP[v][k] = pxi.variance(k);
      }

      for (int w = 0, t = 0; w < m_Dimension; w++, t++) {
        //if((t==m_ClassIndex) || (t==m_IdIndex))
        //  t++;

        if (!Double.isNaN(m_MeanP[v][w])) {
          for (int u = 0; u < pxi.numInstances(); u++) {
            Instance ins = pxi.instance(u);
            if (!ins.isMissing(t))
              m_SumP[v][w] += ins.weight();
          }
          numExsP[w]++;
          pSumVal[w] += m_MeanP[v][w];
          meanVarP[w] += m_MeanP[v][w] * m_MeanP[v][w];
          if (maxVarsP[w] < m_VarianceP[v][w])
            maxVarsP[w] = m_VarianceP[v][w];
          varMeanP[w] += m_VarianceP[v][w];
          m_VarianceP[v][w] *= (m_SumP[v][w] - 1.0);
          if (m_VarianceP[v][w] < 0.0)
            m_VarianceP[v][w] = 0.0;
        }
      }
    }

    for (int v = 0; v < nnum; v++) {
      /* Exemplar nx = neg.exemplar(v);
         m_MeanN[v] = nx.meanOrMode();
         m_VarianceN[v] = nx.variance();
         Instances nxi = nx.getInstances();
       */
      Instances nxi = neg.instance(v).relationalValue(1);
      for (int k = 0; k < nxi.numAttributes(); k++) {
        m_MeanN[v][k] = nxi.meanOrMode(k);
        m_VarianceN[v][k] = nxi.variance(k);
      }

      for (int w = 0, t = 0; w < m_Dimension; w++, t++) {
        //if((t==m_ClassIndex) || (t==m_IdIndex))
        //  t++;

        if (!Double.isNaN(m_MeanN[v][w])) {
          for (int u = 0; u < nxi.numInstances(); u++)
            if (!nxi.instance(u).isMissing(t))
              m_SumN[v][w] += nxi.instance(u).weight();
          numExsN[w]++;
          nSumVal[w] += m_MeanN[v][w];
          meanVarN[w] += m_MeanN[v][w] * m_MeanN[v][w];
          if (maxVarsN[w] < m_VarianceN[v][w])
            maxVarsN[w] = m_VarianceN[v][w];
          varMeanN[w] += m_VarianceN[v][w];
          m_VarianceN[v][w] *= (m_SumN[v][w] - 1.0);
          if (m_VarianceN[v][w] < 0.0)
            m_VarianceN[v][w] = 0.0;
        }
      }
    }

    for (int w = 0; w < m_Dimension; w++) {
      pSumVal[w] /= numExsP[w];
      nSumVal[w] /= numExsN[w];
      if (numExsP[w] > 1)
        meanVarP[w] = meanVarP[w] / (numExsP[w] - 1.0)
            - pSumVal[w] * pSumVal[w] * numExsP[w] / (numExsP[w] - 1.0);
      if (numExsN[w] > 1)
        meanVarN[w] = meanVarN[w] / (numExsN[w] - 1.0)
            - nSumVal[w] * nSumVal[w] * numExsN[w] / (numExsN[w] - 1.0);
      varMeanP[w] /= numExsP[w];
      varMeanN[w] /= numExsN[w];
    }

    // Bounds and parameter values for each run
    double[][] bounds = new double[2][4];
    double[] pThisParam = new double[4],
        nThisParam = new double[4];

    // Initial values for parameters
    double a, b, w, m;

    // Optimize for one dimension
    for (int x = 0; x < m_Dimension; x++) {
      if (getDebug())
        System.err.println("\n\n!!!!!!!!!!!!!!!!!!!!!!???Dimension #" + x);

      // Positive exemplars: first run
      a = (maxVarsP[x] > ZERO) ? maxVarsP[x] : 1.0;
      if (varMeanP[x] <= ZERO) varMeanP[x] = ZERO; // modified by LinDong (09/2005)
      b = a / varMeanP[x] + 2.0; // a/(b-2) = E(\sigma^2)
      w = meanVarP[x] / varMeanP[x]; // E[var(\mu)] = w*E[\sigma^2]
      if (w <= ZERO) w = 1.0;

      m = pSumVal[x];
      pThisParam[0] = a; // a
      pThisParam[1] = b; // b
      pThisParam[2] = w; // w
      pThisParam[3] = m; // m

      // Negative exemplars: first run
      a = (maxVarsN[x] > ZERO) ? maxVarsN[x] : 1.0;
      if (varMeanN[x] <= ZERO) varMeanN[x] = ZERO; // modified by LinDong (09/2005)
      b = a / varMeanN[x] + 2.0; // a/(b-2) = E(\sigma^2)
      w = meanVarN[x] / varMeanN[x]; // E[var(\mu)] = w*E[\sigma^2]
      if (w <= ZERO) w = 1.0;

      m = nSumVal[x];
      nThisParam[0] = a; // a
      nThisParam[1] = b; // b
      nThisParam[2] = w; // w
      nThisParam[3] = m; // m

      // Bound constraints
      bounds[0][0] = ZERO;       // a > 0
      bounds[0][1] = 2.0 + ZERO; // b > 2
      bounds[0][2] = ZERO;       // w > 0
      bounds[0][3] = Double.NaN; // m is unbounded

      for (int t = 0; t < 4; t++) {
        bounds[1][t] = Double.NaN; // no upper bounds
        m_ParamsP[4 * x + t] = pThisParam[t];
        m_ParamsN[4 * x + t] = nThisParam[t];
      }
      double pminVal = Double.MAX_VALUE, nminVal = Double.MAX_VALUE;
      Random whichEx = new Random(m_Seed);
      TLD_Optm pOp = null, nOp = null;
      boolean isRunValid = true;
      double[] sumP = new double[pnum], meanP = new double[pnum],
          varP = new double[pnum];
      double[] sumN = new double[nnum], meanN = new double[nnum],
          varN = new double[nnum];

      // One dimension
      for (int p = 0; p < pnum; p++) {
        sumP[p] = m_SumP[p][x];
        meanP[p] = m_MeanP[p][x];
        varP[p] = m_VarianceP[p][x];
      }
      for (int q = 0; q < nnum; q++) {
        sumN[q] = m_SumN[q][x];
        meanN[q] = m_MeanN[q][x];
        varN[q] = m_VarianceN[q][x];
      }
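
      // Multi-restart search: run the optimizer m_Run times, keeping the
      // parameters with the smallest negative log-likelihood; after each
      // run the starting point is re-derived from a randomly chosen
      // exemplar's statistics (see below), and runs that produce NaN are
      // repeated with a fresh start.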
      for (int y = 0; y < m_Run;) {
        if (getDebug())
          System.err.println("\n\n!!!!!!!!!!!!!!!!!!!!!!???Run #" + y);
        double thisMin;

        if (getDebug())
          System.err.println("\nPositive exemplars");
        pOp = new TLD_Optm();
        pOp.setNum(sumP);
        pOp.setSSquare(varP);
        pOp.setXBar(meanP);

        pThisParam = pOp.findArgmin(pThisParam, bounds);
        while (pThisParam == null) {
          pThisParam = pOp.getVarbValues();
          if (getDebug())
            System.err.println("!!! 200 iterations finished, not enough!");
          pThisParam = pOp.findArgmin(pThisParam, bounds);
        }

        thisMin = pOp.getMinFunction();
        if (!Double.isNaN(thisMin) && (thisMin < pminVal)) {
          pminVal = thisMin;
          for (int z = 0; z < 4; z++)
            m_ParamsP[4 * x + z] = pThisParam[z];
        }

        if (Double.isNaN(thisMin)) {
          pThisParam = new double[4];
          isRunValid = false;
        }

        if (getDebug())
          System.err.println("\nNegative exemplars");
        nOp = new TLD_Optm();
        nOp.setNum(sumN);
        nOp.setSSquare(varN);
        nOp.setXBar(meanN);

        nThisParam = nOp.findArgmin(nThisParam, bounds);
        while (nThisParam == null) {
          nThisParam = nOp.getVarbValues();
          if (getDebug())
            System.err.println("!!! 200 iterations finished, not enough!");
          nThisParam = nOp.findArgmin(nThisParam, bounds);
        }
        thisMin = nOp.getMinFunction();
        if (!Double.isNaN(thisMin) && (thisMin < nminVal)) {
          nminVal = thisMin;
          for (int z = 0; z < 4; z++)
            m_ParamsN[4 * x + z] = nThisParam[z];
        }

        if (Double.isNaN(thisMin)) {
          nThisParam = new double[4];
          isRunValid = false;
        }

        if (!isRunValid) { y--; isRunValid = true; }

        if (++y < m_Run) {
          // Change the initial parameters and restart
          int pone = whichEx.nextInt(pnum), // Randomly pick one positive exemplar
              none = whichEx.nextInt(nnum);

          // Positive exemplars: next run
          while ((m_SumP[pone][x] <= 1.0) || Double.isNaN(m_MeanP[pone][x]))
            pone = whichEx.nextInt(pnum);

          a = m_VarianceP[pone][x] / (m_SumP[pone][x] - 1.0);
          if (a <= ZERO) a = m_ParamsN[4 * x]; // Change to negative params
          m = m_MeanP[pone][x];
          double sq = (m - m_ParamsP[4 * x + 3]) * (m - m_ParamsP[4 * x + 3]);

          b = a * m_ParamsP[4 * x + 2] / sq + 2.0; // b=a/Var+2, assuming Var=Sq/w'
          if ((b <= ZERO) || Double.isNaN(b) || Double.isInfinite(b))
            b = m_ParamsN[4 * x + 1];

          w = sq * (m_ParamsP[4 * x + 1] - 2.0) / m_ParamsP[4 * x]; // w=Sq/Var, assuming Var=a'/(b'-2)
          if ((w <= ZERO) || Double.isNaN(w) || Double.isInfinite(w))
            w = m_ParamsN[4 * x + 2];

          pThisParam[0] = a; // a
          pThisParam[1] = b; // b
          pThisParam[2] = w; // w
          pThisParam[3] = m; // m

          // Negative exemplars: next run
          while ((m_SumN[none][x] <= 1.0) || Double.isNaN(m_MeanN[none][x]))
            none = whichEx.nextInt(nnum);

          a = m_VarianceN[none][x] / (m_SumN[none][x] - 1.0);
          if (a <= ZERO) a = m_ParamsP[4 * x];
          m = m_MeanN[none][x];
          sq = (m - m_ParamsN[4 * x + 3]) * (m - m_ParamsN[4 * x + 3]);

          b = a * m_ParamsN[4 * x + 2] / sq + 2.0; // b=a/Var+2, assuming Var=Sq/w'
          if ((b <= ZERO) || Double.isNaN(b) || Double.isInfinite(b))
            b = m_ParamsP[4 * x + 1];

          w = sq * (m_ParamsN[4 * x + 1] - 2.0) / m_ParamsN[4 * x]; // w=Sq/Var, assuming Var=a'/(b'-2)
          if ((w <= ZERO) || Double.isNaN(w) || Double.isInfinite(w))
            w = m_ParamsP[4 * x + 2];

          nThisParam[0] = a; // a
          nThisParam[1] = b; // b
          nThisParam[2] = w; // w
          nThisParam[3] = m; // m
        }
      }
    }

    for (int x = 0, y = 0; x < m_Dimension; x++, y++) {
      //if((x==exs.classIndex()) || (x==exs.idIndex()))
      //  y++;
      a = m_ParamsP[4 * x]; b = m_ParamsP[4 * x + 1];
      w = m_ParamsP[4 * x + 2]; m = m_ParamsP[4 * x + 3];
      if (getDebug())
        System.err.println("\n\n???Positive: ( " + exs.attribute(1).relation().attribute(y) +
            "): a=" + a + ", b=" + b + ", w=" + w + ", m=" + m);

      a = m_ParamsN[4 * x]; b = m_ParamsN[4 * x + 1];
      w = m_ParamsN[4 * x + 2]; m = m_ParamsN[4 * x + 3];
      if (getDebug())
        System.err.println("???Negative: (" + exs.attribute(1).relation().attribute(y) +
            "): a=" + a + ", b=" + b + ", w=" + w + ", m=" + m);
    }

    if (m_UseEmpiricalCutOff) {
      // Find the empirical cut-off
      double[] pLogOdds = new double[pnum], nLogOdds = new double[nnum];
      for (int p = 0; p < pnum; p++)
        pLogOdds[p] =
            likelihoodRatio(m_SumP[p], m_MeanP[p], m_VarianceP[p]);

      for (int q = 0; q < nnum; q++)
        nLogOdds[q] =
            likelihoodRatio(m_SumN[q], m_MeanN[q], m_VarianceN[q]);

      // Update m_Cutoff
      findCutOff(pLogOdds, nLogOdds);
    }
    else
      m_Cutoff = -Math.log((double) pnum / (double) nnum);
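    // (equivalently log(nnum/pnum): a bag is then labelled positive only
    // when its likelihood ratio overcomes the class prior odds)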

    if (getDebug())
      System.err.println("???Cut-off=" + m_Cutoff);
  }

  /**
   * Classifies the given test exemplar.
   *
   * @param ex the given test exemplar
   * @return the classification
   * @throws Exception if the exemplar could not be classified
   * successfully
   */
  public double classifyInstance(Instance ex) throws Exception {
    //Exemplar ex = new Exemplar(e);
    Instances exi = ex.relationalValue(1);
    double[] n = new double[m_Dimension];
    double[] xBar = new double[m_Dimension];
    double[] sSq = new double[m_Dimension];
    for (int i = 0; i < exi.numAttributes(); i++) {
      xBar[i] = exi.meanOrMode(i);
      sSq[i] = exi.variance(i);
    }

    for (int w = 0, t = 0; w < m_Dimension; w++, t++) {
      //if((t==m_ClassIndex) || (t==m_IdIndex))
      //  t++;
      for (int u = 0; u < exi.numInstances(); u++)
        if (!exi.instance(u).isMissing(t))
          n[w] += exi.instance(u).weight();

      sSq[w] = sSq[w] * (n[w] - 1.0);
      if (sSq[w] <= 0.0)
        sSq[w] = 0.0;
    }

    double logOdds = likelihoodRatio(n, xBar, sSq);
    return (logOdds > m_Cutoff) ? 1 : 0;
  }
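
  /**
   * Computes the log-odds of a bag: the difference between its
   * log-likelihood under the positive and under the negative two-level
   * Gaussian model, summed over all dimensions; dimensions whose values
   * are all missing are skipped.
   *
   * @param n the effective sum of instance weights in each dimension
   * @param xBar the sample mean in each dimension
   * @param sSq the sum of squared deviations in each dimension
   * @return the log-likelihood ratio of the bag
   */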
  private double likelihoodRatio(double[] n, double[] xBar, double[] sSq) {
    double LLP = 0.0, LLN = 0.0;

    for (int x = 0; x < m_Dimension; x++) {
      if (Double.isNaN(xBar[x])) continue; // All missing values

      int halfN = ((int) n[x]) / 2;
      // Log-likelihood for positive
      double a = m_ParamsP[4 * x], b = m_ParamsP[4 * x + 1],
          w = m_ParamsP[4 * x + 2], m = m_ParamsP[4 * x + 3];
      LLP += 0.5 * b * Math.log(a) + 0.5 * (b + n[x] - 1.0) * Math.log(1.0 + n[x] * w)
          - 0.5 * (b + n[x]) * Math.log((1.0 + n[x] * w) * (a + sSq[x]) +
              n[x] * (xBar[x] - m) * (xBar[x] - m))
          - 0.5 * n[x] * Math.log(Math.PI);
      for (int y = 1; y <= halfN; y++)
        LLP += Math.log(b / 2.0 + n[x] / 2.0 - (double) y);

      if (n[x] / 2.0 > halfN) // n is odd
        LLP += TLD_Optm.diffLnGamma(b / 2.0);

      // Log-likelihood for negative
      a = m_ParamsN[4 * x];
      b = m_ParamsN[4 * x + 1];
      w = m_ParamsN[4 * x + 2];
      m = m_ParamsN[4 * x + 3];
      LLN += 0.5 * b * Math.log(a) + 0.5 * (b + n[x] - 1.0) * Math.log(1.0 + n[x] * w)
          - 0.5 * (b + n[x]) * Math.log((1.0 + n[x] * w) * (a + sSq[x]) +
              n[x] * (xBar[x] - m) * (xBar[x] - m))
          - 0.5 * n[x] * Math.log(Math.PI);
      for (int y = 1; y <= halfN; y++)
        LLN += Math.log(b / 2.0 + n[x] / 2.0 - (double) y);

      if (n[x] / 2.0 > halfN) // n is odd
        LLN += TLD_Optm.diffLnGamma(b / 2.0);
    }

    return LLP - LLN;
  }
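
  /**
   * Sets m_Cutoff to the log-odds threshold that maximizes the training
   * accuracy on the given log-odds values, breaking ties in favor of the
   * threshold closest to 0.
   *
   * @param pos the log-odds of the positive exemplars
   * @param neg the log-odds of the negative exemplars
   */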
  private void findCutOff(double[] pos, double[] neg) {
    int[] pOrder = Utils.sort(pos),
        nOrder = Utils.sort(neg);
    /*
       System.err.println("\n\n???Positive: ");
       for(int t=0; t<pOrder.length; t++)
         System.err.print(t+":"+Utils.doubleToString(pos[pOrder[t]],0,2)+" ");
       System.err.println("\n\n???Negative: ");
       for(int t=0; t<nOrder.length; t++)
         System.err.print(t+":"+Utils.doubleToString(neg[nOrder[t]],0,2)+" ");
     */
    int pNum = pos.length, nNum = neg.length, count, p = 0, n = 0;
    double fstAccu = 0.0, sndAccu = (double) pNum, split;
    double maxAccu = 0, minDistTo0 = Double.MAX_VALUE;

    // Skip continuous negatives
    for (; (n < nNum) && (pos[pOrder[0]] >= neg[nOrder[n]]); n++, fstAccu++);

    if (n >= nNum) { // totally separate
      m_Cutoff = (neg[nOrder[nNum - 1]] + pos[pOrder[0]]) / 2.0;
      //m_Cutoff = neg[nOrder[nNum-1]];
      return;
    }

    count = n;
    while ((p < pNum) && (n < nNum)) {
      // Compare the next in the two lists
      if (pos[pOrder[p]] >= neg[nOrder[n]]) { // Neg has less log-odds
        fstAccu += 1.0;
        split = neg[nOrder[n]];
        n++;
      }
      else {
        sndAccu -= 1.0;
        split = pos[pOrder[p]];
        p++;
      }
      count++;
      if ((fstAccu + sndAccu > maxAccu)
          || ((fstAccu + sndAccu == maxAccu) && (Math.abs(split) < minDistTo0))) {
        maxAccu = fstAccu + sndAccu;
        m_Cutoff = split;
        minDistTo0 = Math.abs(split);
      }
    }
  }

  /**
   * Returns an enumeration describing the available options
   *
   * @return an enumeration of all the available options
   */
  public Enumeration listOptions() {
    Vector result = new Vector();

    result.addElement(new Option(
        "\tSet whether or not to use the empirical\n"
        + "\tlog-odds cut-off instead of 0",
        "C", 0, "-C"));

    result.addElement(new Option(
        "\tSet the number of multiple runs\n"
        + "\tneeded for searching the MLE.",
        "R", 1, "-R <numOfRuns>"));

    Enumeration enu = super.listOptions();
    while (enu.hasMoreElements()) {
      result.addElement(enu.nextElement());
    }

    return result.elements();
  }

  /**
   * Parses a given list of options. <p/>
   *
   <!-- options-start -->
   * Valid options are: <p/>
   *
   * <pre> -C
   *  Set whether or not to use the empirical
   *  log-odds cut-off instead of 0</pre>
   *
   * <pre> -R &lt;numOfRuns&gt;
   *  Set the number of multiple runs
   *  needed for searching the MLE.</pre>
   *
   * <pre> -S &lt;num&gt;
   *  Random number seed.
   *  (default 1)</pre>
   *
   * <pre> -D
   *  If set, classifier is run in debug mode and
   *  may output additional info to the console</pre>
   *
   <!-- options-end -->
   *
   * @param options the list of options as an array of strings
   * @throws Exception if an option is not supported
   */
  public void setOptions(String[] options) throws Exception {
    setDebug(Utils.getFlag('D', options));

    setUsingCutOff(Utils.getFlag('C', options));

    String runString = Utils.getOption('R', options);
    if (runString.length() != 0)
      setNumRuns(Integer.parseInt(runString));
    else
      setNumRuns(1);

    super.setOptions(options);
  }

  /**
   * Gets the current settings of the Classifier.
   *
   * @return an array of strings suitable for passing to setOptions
   */
  public String[] getOptions() {
    Vector result;
    String[] options;
    int i;

    result = new Vector();
    options = super.getOptions();
    for (i = 0; i < options.length; i++)
      result.add(options[i]);

    if (getDebug())
      result.add("-D");

    if (getUsingCutOff())
      result.add("-C");

    result.add("-R");
    result.add("" + getNumRuns());

    return (String[]) result.toArray(new String[result.size()]);
  }

  /**
   * Returns the tip text for this property
   *
   * @return tip text for this property suitable for
   * displaying in the explorer/experimenter gui
   */
  public String numRunsTipText() {
    return "The number of runs to perform.";
  }

  /**
   * Sets the number of runs to perform.
   *
   * @param numRuns the number of runs to perform
   */
  public void setNumRuns(int numRuns) {
    m_Run = numRuns;
  }

  /**
   * Returns the number of runs to perform.
   *
   * @return the number of runs to perform
   */
  public int getNumRuns() {
    return m_Run;
  }

  /**
   * Returns the tip text for this property
   *
   * @return tip text for this property suitable for
   * displaying in the explorer/experimenter gui
   */
  public String usingCutOffTipText() {
    return "Whether to use an empirical cutoff.";
  }

  /**
   * Sets whether to use an empirical cutoff.
   *
   * @param cutOff whether to use an empirical cutoff
   */
  public void setUsingCutOff(boolean cutOff) {
    m_UseEmpiricalCutOff = cutOff;
  }

  /**
   * Returns whether an empirical cutoff is used.
   *
   * @return true if an empirical cutoff is used
   */
  public boolean getUsingCutOff() {
    return m_UseEmpiricalCutOff;
  }

  /**
   * Returns the revision string.
   *
   * @return the revision
   */
  public String getRevision() {
    return RevisionUtils.extract("$Revision: 5481 $");
  }

  /**
   * Main method for testing.
   *
   * @param args the options for the classifier
   */
  public static void main(String[] args) {
    runClassifier(new TLD(), args);
  }
}
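
/**
 * Helper class that carries out the actual numerical search: it minimizes
 * the negative log-likelihood of the one-dimensional TLD model over the
 * parameters {a, b, w, m}, given the per-exemplar effective weights,
 * sample means and sums of squared deviations.
 */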
class TLD_Optm extends Optimization {

  /** The effective sum of weights of each exemplar */
  private double[] num;
  /** The sum of squared deviations of each exemplar */
  private double[] sSq;
  /** The sample mean of each exemplar */
  private double[] xBar;

  public void setNum(double[] n) { num = n; }
  public void setSSquare(double[] s) { sSq = s; }
  public void setXBar(double[] x) { xBar = x; }

  /**
   * Compute Ln[Gamma(b+0.5)] - Ln[Gamma(b)]
   *
   * @param b the value in the above formula
   * @return the result
   */
  public static double diffLnGamma(double b) {
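    // Uses the classic six-term Lanczos series for ln Gamma, evaluated at
    // b+0.5 and at b, so that the common terms of the two approximations
    // cancel.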
    double[] coef = {76.18009172947146, -86.50532032941677,
        24.01409824083091, -1.231739572450155,
        0.1208650973866179e-2, -0.5395239384953e-5};
    double rt = -0.5;
    rt += (b + 1.0) * Math.log(b + 6.0) - (b + 0.5) * Math.log(b + 5.5);
    double series1 = 1.000000000190015, series2 = 1.000000000190015;
    for (int i = 0; i < 6; i++) {
      series1 += coef[i] / (b + 1.5 + (double) i);
      series2 += coef[i] / (b + 1.0 + (double) i);
    }

    rt += Math.log(series1 * b) - Math.log(series2 * (b + 0.5));
    return rt;
  }

  /**
   * Compute dLn[Gamma(x+0.5)]/dx - dLn[Gamma(x)]/dx
   *
   * @param x the value in the above formula
   * @return the result
   */
  protected double diffFstDervLnGamma(double x) {
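    // psi(x+0.5) - psi(x) = sum_{i>=0} 0.5/((x+i)*(x+i+0.5));
    // the loop sums this series until the terms become negligible.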
    double rt = 0, series = 1.0; // just to make the loop condition hold initially
    for (int i = 0; series >= m_Zero * 1e-3; i++) {
      series = 0.5 / ((x + (double) i) * (x + (double) i + 0.5));
      rt += series;
    }
    return rt;
  }

  /**
   * Compute {Ln[Gamma(x+0.5)]}'' - {Ln[Gamma(x)]}''
   *
   * @param x the value in the above formula
   * @return the result
   */
  protected double diffSndDervLnGamma(double x) {
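    // psi'(x+0.5) - psi'(x) = -sum_{i>=0} (x+i+0.25)/((x+i)^2*(x+i+0.5)^2);
    // again summed until the terms become negligible.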
    double rt = 0, series = 1.0; // just to make the loop condition hold initially
    for (int i = 0; series >= m_Zero * 1e-3; i++) {
      series = (x + (double) i + 0.25) /
          ((x + (double) i) * (x + (double) i) * (x + (double) i + 0.5) * (x + (double) i + 0.5));
      rt -= series;
    }
    return rt;
  }

  /**
   * Evaluates the objective function to be minimized: the negative
   * log-likelihood of the one-dimensional TLD model with parameters
   * x = {a, b, w, m}.
   *
   * @param x the current parameter values
   * @return the negative log-likelihood
   */
  protected double objectiveFunction(double[] x) {
    int numExs = num.length;
    double NLL = 0; // Negative Log-Likelihood

    double a = x[0], b = x[1], w = x[2], m = x[3];
    for (int j = 0; j < numExs; j++) {

      if (Double.isNaN(xBar[j])) continue; // All missing values

      NLL += 0.5 * (b + num[j]) *
          Math.log((1.0 + num[j] * w) * (a + sSq[j]) +
              num[j] * (xBar[j] - m) * (xBar[j] - m));

      if (Double.isNaN(NLL) && m_Debug) {
        System.err.println("???????????1: " + a + " " + b + " " + w + " " + m
            + "|x-: " + xBar[j] +
            "|n: " + num[j] + "|S^2: " + sSq[j]);
        System.exit(1);
      }

      // Doesn't affect optimization
      //NLL += 0.5*num[j]*Math.log(Math.PI);

      NLL -= 0.5 * (b + num[j] - 1.0) * Math.log(1.0 + num[j] * w);

      if (Double.isNaN(NLL) && m_Debug) {
        System.err.println("???????????2: " + a + " " + b + " " + w + " " + m
            + "|x-: " + xBar[j] +
            "|n: " + num[j] + "|S^2: " + sSq[j]);
        System.exit(1);
      }

      int halfNum = ((int) num[j]) / 2;
      for (int z = 1; z <= halfNum; z++)
        NLL -= Math.log(0.5 * b + 0.5 * num[j] - (double) z);

      if (0.5 * num[j] > halfNum) // num[j] is odd
        NLL -= diffLnGamma(0.5 * b);

      if (Double.isNaN(NLL) && m_Debug) {
        System.err.println("???????????3: " + a + " " + b + " " + w + " " + m
            + "|x-: " + xBar[j] +
            "|n: " + num[j] + "|S^2: " + sSq[j]);
        System.exit(1);
      }

      NLL -= 0.5 * Math.log(a) * b;
      if (Double.isNaN(NLL) && m_Debug) {
        System.err.println("???????????4:" + a + " " + b + " " + w + " " + m);
        System.exit(1);
      }
    }
    if (m_Debug)
      System.err.println("?????????????5: " + NLL);
    if (Double.isNaN(NLL))
      System.exit(1);

    return NLL;
  }

  /**
   * Evaluates the gradient of the negative log-likelihood with respect
   * to the parameters x = {a, b, w, m}.
   *
   * @param x the current parameter values
   * @return the gradient vector
   */
  protected double[] evaluateGradient(double[] x) {
    double[] g = new double[x.length];
    int numExs = num.length;

    double a = x[0], b = x[1], w = x[2], m = x[3];

    double da = 0.0, db = 0.0, dw = 0.0, dm = 0.0;
    for (int j = 0; j < numExs; j++) {

      if (Double.isNaN(xBar[j])) continue; // All missing values

      double denorm = (1.0 + num[j] * w) * (a + sSq[j]) +
          num[j] * (xBar[j] - m) * (xBar[j] - m);

      da += 0.5 * (b + num[j]) * (1.0 + num[j] * w) / denorm - 0.5 * b / a;

      db += 0.5 * Math.log(denorm)
          - 0.5 * Math.log(1.0 + num[j] * w)
          - 0.5 * Math.log(a);

      int halfNum = ((int) num[j]) / 2;
      for (int z = 1; z <= halfNum; z++)
        db -= 1.0 / (b + num[j] - 2.0 * (double) z);
      if (num[j] / 2.0 > halfNum) // num[j] is odd
        db -= 0.5 * diffFstDervLnGamma(0.5 * b);

      dw += 0.5 * (b + num[j]) * (a + sSq[j]) * num[j] / denorm -
          0.5 * (b + num[j] - 1.0) * num[j] / (1.0 + num[j] * w);

      dm += num[j] * (b + num[j]) * (m - xBar[j]) / denorm;
    }

    g[0] = da;
    g[1] = db;
    g[2] = dw;
    g[3] = dm;
    return g;
  }

  /**
   * Evaluates one row of the Hessian of the negative log-likelihood: the
   * second-order partial derivatives with respect to the variable given
   * by index and each of the parameters x = {a, b, w, m}.
   *
   * @param x the current parameter values
   * @param index the row of the Hessian to evaluate
   * @return the given row of the Hessian
   */
  protected double[] evaluateHessian(double[] x, int index) {
    double[] h = new double[x.length];

    // # of exemplars, # of dimensions
    // which dimension and which variable for 'index'
    int numExs = num.length;
    double a, b, w, m;
    // Take the 2nd-order derivative
    switch (index) {
      case 0: // a
        a = x[0]; b = x[1]; w = x[2]; m = x[3];

        for (int j = 0; j < numExs; j++) {
          if (Double.isNaN(xBar[j])) continue; // All missing values
          double denorm = (1.0 + num[j] * w) * (a + sSq[j]) +
              num[j] * (xBar[j] - m) * (xBar[j] - m);

          h[0] += 0.5 * b / (a * a)
              - 0.5 * (b + num[j]) * (1.0 + num[j] * w) * (1.0 + num[j] * w)
              / (denorm * denorm);

          h[1] += 0.5 * (1.0 + num[j] * w) / denorm - 0.5 / a;

          h[2] += 0.5 * num[j] * num[j] * (b + num[j]) *
              (xBar[j] - m) * (xBar[j] - m) / (denorm * denorm);

          h[3] -= num[j] * (b + num[j]) * (m - xBar[j])
              * (1.0 + num[j] * w) / (denorm * denorm);
        }
        break;

      case 1: // b
        a = x[0]; b = x[1]; w = x[2]; m = x[3];

        for (int j = 0; j < numExs; j++) {
          if (Double.isNaN(xBar[j])) continue; // All missing values
          double denorm = (1.0 + num[j] * w) * (a + sSq[j]) +
              num[j] * (xBar[j] - m) * (xBar[j] - m);

          h[0] += 0.5 * (1.0 + num[j] * w) / denorm - 0.5 / a;

          int halfNum = ((int) num[j]) / 2;
          for (int z = 1; z <= halfNum; z++)
            h[1] +=
                1.0 / ((b + num[j] - 2.0 * (double) z) * (b + num[j] - 2.0 * (double) z));
          if (num[j] / 2.0 > halfNum) // num[j] is odd
            h[1] -= 0.25 * diffSndDervLnGamma(0.5 * b);

          h[2] += 0.5 * (a + sSq[j]) * num[j] / denorm -
              0.5 * num[j] / (1.0 + num[j] * w);

          h[3] += num[j] * (m - xBar[j]) / denorm;
        }
        break;

      case 2: // w
        a = x[0]; b = x[1]; w = x[2]; m = x[3];

        for (int j = 0; j < numExs; j++) {
          if (Double.isNaN(xBar[j])) continue; // All missing values
          double denorm = (1.0 + num[j] * w) * (a + sSq[j]) +
              num[j] * (xBar[j] - m) * (xBar[j] - m);

          h[0] += 0.5 * num[j] * num[j] * (b + num[j]) *
              (xBar[j] - m) * (xBar[j] - m) / (denorm * denorm);

          h[1] += 0.5 * (a + sSq[j]) * num[j] / denorm -
              0.5 * num[j] / (1.0 + num[j] * w);

          h[2] += 0.5 * (b + num[j] - 1.0) * num[j] * num[j] /
              ((1.0 + num[j] * w) * (1.0 + num[j] * w)) -
              0.5 * (b + num[j]) * (a + sSq[j]) * (a + sSq[j]) *
                  num[j] * num[j] / (denorm * denorm);

          h[3] -= num[j] * num[j] * (b + num[j]) *
              (m - xBar[j]) * (a + sSq[j]) / (denorm * denorm);
        }
        break;

      case 3: // m
        a = x[0]; b = x[1]; w = x[2]; m = x[3];

        for (int j = 0; j < numExs; j++) {
          if (Double.isNaN(xBar[j])) continue; // All missing values
          double denorm = (1.0 + num[j] * w) * (a + sSq[j]) +
              num[j] * (xBar[j] - m) * (xBar[j] - m);

          h[0] -= num[j] * (b + num[j]) * (m - xBar[j])
              * (1.0 + num[j] * w) / (denorm * denorm);

          h[1] += num[j] * (m - xBar[j]) / denorm;

          h[2] -= num[j] * num[j] * (b + num[j]) *
              (m - xBar[j]) * (a + sSq[j]) / (denorm * denorm);

          h[3] += num[j] * (b + num[j]) *
              ((1.0 + num[j] * w) * (a + sSq[j]) -
                  num[j] * (m - xBar[j]) * (m - xBar[j]))
              / (denorm * denorm);
        }
    }

    return h;
  }

  /**
   * Returns the revision string.
   *
   * @return the revision
   */
  public String getRevision() {
    return RevisionUtils.extract("$Revision: 5481 $");
  }
}