/*
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 2 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 */

/*
 * MIDD.java
 * Copyright (C) 2005 University of Waikato, Hamilton, New Zealand
 *
 */
23 | package weka.classifiers.mi; |
---|
24 | |
---|
25 | import weka.classifiers.Classifier; |
---|
26 | import weka.classifiers.AbstractClassifier; |
---|
27 | import weka.core.Capabilities; |
---|
28 | import weka.core.FastVector; |
---|
29 | import weka.core.Instance; |
---|
30 | import weka.core.Instances; |
---|
31 | import weka.core.MultiInstanceCapabilitiesHandler; |
---|
32 | import weka.core.Optimization; |
---|
33 | import weka.core.Option; |
---|
34 | import weka.core.OptionHandler; |
---|
35 | import weka.core.RevisionUtils; |
---|
36 | import weka.core.SelectedTag; |
---|
37 | import weka.core.Tag; |
---|
38 | import weka.core.TechnicalInformation; |
---|
39 | import weka.core.TechnicalInformationHandler; |
---|
40 | import weka.core.Utils; |
---|
41 | import weka.core.Capabilities.Capability; |
---|
42 | import weka.core.TechnicalInformation.Field; |
---|
43 | import weka.core.TechnicalInformation.Type; |
---|
44 | import weka.filters.Filter; |
---|
45 | import weka.filters.unsupervised.attribute.Normalize; |
---|
46 | import weka.filters.unsupervised.attribute.ReplaceMissingValues; |
---|
47 | import weka.filters.unsupervised.attribute.Standardize; |
---|
48 | |
---|
49 | import java.util.Enumeration; |
---|
50 | import java.util.Vector; |
---|
51 | |
---|
52 | /** |
---|
53 | <!-- globalinfo-start --> |
---|
54 | * Re-implement the Diverse Density algorithm, changes the testing procedure.<br/> |
---|
55 | * <br/> |
---|
56 | * Oded Maron (1998). Learning from ambiguity.<br/> |
---|
57 | * <br/> |
---|
58 | * O. Maron, T. Lozano-Perez (1998). A Framework for Multiple Instance Learning. Neural Information Processing Systems. 10. |
---|
59 | * <p/> |
---|
60 | <!-- globalinfo-end --> |
---|
61 | * |
---|
62 | <!-- technical-bibtex-start --> |
---|
63 | * BibTeX: |
---|
64 | * <pre> |
---|
65 | * @phdthesis{Maron1998, |
---|
66 | * author = {Oded Maron}, |
---|
67 | * school = {Massachusetts Institute of Technology}, |
---|
68 | * title = {Learning from ambiguity}, |
---|
69 | * year = {1998} |
---|
70 | * } |
---|
71 | * |
---|
72 | * @article{Maron1998, |
---|
73 | * author = {O. Maron and T. Lozano-Perez}, |
---|
74 | * journal = {Neural Information Processing Systems}, |
---|
75 | * title = {A Framework for Multiple Instance Learning}, |
---|
76 | * volume = {10}, |
---|
77 | * year = {1998} |
---|
78 | * } |
---|
79 | * </pre> |
---|
80 | * <p/> |
---|
81 | <!-- technical-bibtex-end --> |
---|
82 | * |
---|
83 | <!-- options-start --> |
---|
84 | * Valid options are: <p/> |
---|
85 | * |
---|
86 | * <pre> -D |
---|
87 | * Turn on debugging output.</pre> |
---|
88 | * |
---|
89 | * <pre> -N <num> |
---|
90 | * Whether to 0=normalize/1=standardize/2=neither. |
---|
91 | * (default 1=standardize)</pre> |
---|
92 | * |
---|
93 | <!-- options-end --> |
---|
94 | * |
---|
95 | * @author Eibe Frank (eibe@cs.waikato.ac.nz) |
---|
96 | * @author Xin Xu (xx5@cs.waikato.ac.nz) |
---|
97 | * @version $Revision: 5928 $ |
---|
98 | */ |
---|
99 | public class MIDD |
---|
100 | extends AbstractClassifier |
---|
101 | implements OptionHandler, MultiInstanceCapabilitiesHandler, |
---|
102 | TechnicalInformationHandler { |
---|
103 | |
---|
104 | /** for serialization */ |
---|
105 | static final long serialVersionUID = 4263507733600536168L; |
---|
106 | |
---|
107 | /** The index of the class attribute */ |
---|
108 | protected int m_ClassIndex; |
---|
109 | |
---|
110 | protected double[] m_Par; |
---|
111 | |
---|
112 | /** The number of the class labels */ |
---|
113 | protected int m_NumClasses; |
---|
114 | |
---|
115 | /** Class labels for each bag */ |
---|
116 | protected int[] m_Classes; |
---|
117 | |
---|
118 | /** MI data */ |
---|
119 | protected double[][][] m_Data; |
---|
120 | |
---|
121 | /** All attribute names */ |
---|
122 | protected Instances m_Attributes; |
---|
123 | |
---|
124 | /** The filter used to standardize/normalize all values. */ |
---|
125 | protected Filter m_Filter = null; |
---|
126 | |
---|
127 | /** Whether to normalize/standardize/neither, default:standardize */ |
---|
128 | protected int m_filterType = FILTER_STANDARDIZE; |
---|
129 | |
---|
130 | /** Normalize training data */ |
---|
131 | public static final int FILTER_NORMALIZE = 0; |
---|
132 | /** Standardize training data */ |
---|
133 | public static final int FILTER_STANDARDIZE = 1; |
---|
134 | /** No normalization/standardization */ |
---|
135 | public static final int FILTER_NONE = 2; |
---|
136 | /** The filter to apply to the training data */ |
---|
137 | public static final Tag [] TAGS_FILTER = { |
---|
138 | new Tag(FILTER_NORMALIZE, "Normalize training data"), |
---|
139 | new Tag(FILTER_STANDARDIZE, "Standardize training data"), |
---|
140 | new Tag(FILTER_NONE, "No normalization/standardization"), |
---|
141 | }; |
---|
142 | |
---|
143 | /** The filter used to get rid of missing values. */ |
---|
144 | protected ReplaceMissingValues m_Missing = new ReplaceMissingValues(); |
---|
145 | |
---|
146 | /** |
---|
147 | * Returns a string describing this filter |
---|
148 | * |
---|
149 | * @return a description of the filter suitable for |
---|
150 | * displaying in the explorer/experimenter gui |
---|
151 | */ |
---|
152 | public String globalInfo() { |
---|
153 | return |
---|
154 | "Re-implement the Diverse Density algorithm, changes the testing " |
---|
155 | + "procedure.\n\n" |
---|
156 | + getTechnicalInformation().toString(); |
---|
157 | } |
---|
158 | |
---|
159 | /** |
---|
160 | * Returns an instance of a TechnicalInformation object, containing |
---|
161 | * detailed information about the technical background of this class, |
---|
162 | * e.g., paper reference or book this class is based on. |
---|
163 | * |
---|
164 | * @return the technical information about this class |
---|
165 | */ |
---|
166 | public TechnicalInformation getTechnicalInformation() { |
---|
167 | TechnicalInformation result; |
---|
168 | TechnicalInformation additional; |
---|
169 | |
---|
170 | result = new TechnicalInformation(Type.PHDTHESIS); |
---|
171 | result.setValue(Field.AUTHOR, "Oded Maron"); |
---|
172 | result.setValue(Field.YEAR, "1998"); |
---|
173 | result.setValue(Field.TITLE, "Learning from ambiguity"); |
---|
174 | result.setValue(Field.SCHOOL, "Massachusetts Institute of Technology"); |
---|
175 | |
---|
176 | additional = result.add(Type.ARTICLE); |
---|
177 | additional.setValue(Field.AUTHOR, "O. Maron and T. Lozano-Perez"); |
---|
178 | additional.setValue(Field.YEAR, "1998"); |
---|
179 | additional.setValue(Field.TITLE, "A Framework for Multiple Instance Learning"); |
---|
180 | additional.setValue(Field.JOURNAL, "Neural Information Processing Systems"); |
---|
181 | additional.setValue(Field.VOLUME, "10"); |
---|
182 | |
---|
183 | return result; |
---|
184 | } |
---|
185 | |
---|
186 | /** |
---|
187 | * Returns an enumeration describing the available options |
---|
188 | * |
---|
189 | * @return an enumeration of all the available options |
---|
190 | */ |
---|
191 | public Enumeration listOptions() { |
---|
192 | Vector result = new Vector(); |
---|
193 | |
---|
194 | result.addElement(new Option( |
---|
195 | "\tTurn on debugging output.", |
---|
196 | "D", 0, "-D")); |
---|
197 | |
---|
198 | result.addElement(new Option( |
---|
199 | "\tWhether to 0=normalize/1=standardize/2=neither.\n" |
---|
200 | + "\t(default 1=standardize)", |
---|
201 | "N", 1, "-N <num>")); |
---|
202 | |
---|
203 | return result.elements(); |
---|
204 | } |
---|
205 | |
---|
206 | /** |
---|
207 | * Parses a given list of options. <p/> |
---|
208 | * |
---|
209 | <!-- options-start --> |
---|
210 | * Valid options are: <p/> |
---|
211 | * |
---|
212 | * <pre> -D |
---|
213 | * Turn on debugging output.</pre> |
---|
214 | * |
---|
215 | * <pre> -N <num> |
---|
216 | * Whether to 0=normalize/1=standardize/2=neither. |
---|
217 | * (default 1=standardize)</pre> |
---|
218 | * |
---|
219 | <!-- options-end --> |
---|
220 | * |
---|
221 | * @param options the list of options as an array of strings |
---|
222 | * @throws Exception if an option is not supported |
---|
223 | */ |
---|
224 | public void setOptions(String[] options) throws Exception { |
---|
225 | setDebug(Utils.getFlag('D', options)); |
---|
226 | |
---|
227 | String nString = Utils.getOption('N', options); |
---|
228 | if (nString.length() != 0) { |
---|
229 | setFilterType(new SelectedTag(Integer.parseInt(nString), TAGS_FILTER)); |
---|
230 | } else { |
---|
231 | setFilterType(new SelectedTag(FILTER_STANDARDIZE, TAGS_FILTER)); |
---|
232 | } |
---|
233 | } |
---|
234 | |
---|
235 | /** |
---|
236 | * Gets the current settings of the classifier. |
---|
237 | * |
---|
238 | * @return an array of strings suitable for passing to setOptions |
---|
239 | */ |
---|
240 | public String[] getOptions() { |
---|
241 | Vector result; |
---|
242 | |
---|
243 | result = new Vector(); |
---|
244 | |
---|
245 | if (getDebug()) |
---|
246 | result.add("-D"); |
---|
247 | |
---|
248 | result.add("-N"); |
---|
249 | result.add("" + m_filterType); |
---|
250 | |
---|
251 | return (String[]) result.toArray(new String[result.size()]); |
---|
252 | } |
---|
253 | |
---|
254 | /** |
---|
255 | * Returns the tip text for this property |
---|
256 | * |
---|
257 | * @return tip text for this property suitable for |
---|
258 | * displaying in the explorer/experimenter gui |
---|
259 | */ |
---|
260 | public String filterTypeTipText() { |
---|
261 | return "The filter type for transforming the training data."; |
---|
262 | } |
---|
263 | |
---|
264 | /** |
---|
265 | * Gets how the training data will be transformed. Will be one of |
---|
266 | * FILTER_NORMALIZE, FILTER_STANDARDIZE, FILTER_NONE. |
---|
267 | * |
---|
268 | * @return the filtering mode |
---|
269 | */ |
---|
270 | public SelectedTag getFilterType() { |
---|
271 | return new SelectedTag(m_filterType, TAGS_FILTER); |
---|
272 | } |
---|
273 | |
---|
274 | /** |
---|
275 | * Sets how the training data will be transformed. Should be one of |
---|
276 | * FILTER_NORMALIZE, FILTER_STANDARDIZE, FILTER_NONE. |
---|
277 | * |
---|
278 | * @param newType the new filtering mode |
---|
279 | */ |
---|
280 | public void setFilterType(SelectedTag newType) { |
---|
281 | |
---|
282 | if (newType.getTags() == TAGS_FILTER) { |
---|
283 | m_filterType = newType.getSelectedTag().getID(); |
---|
284 | } |
---|
285 | } |
---|
286 | |
---|
287 | private class OptEng |
---|
288 | extends Optimization { |
---|
289 | |
---|
290 | /** |
---|
291 | * Evaluate objective function |
---|
292 | * @param x the current values of variables |
---|
293 | * @return the value of the objective function |
---|
294 | */ |
---|
295 | protected double objectiveFunction(double[] x){ |
---|
296 | double nll = 0; // -LogLikelihood |
---|
297 | for(int i=0; i<m_Classes.length; i++){ // ith bag |
---|
298 | int nI = m_Data[i][0].length; // numInstances in ith bag |
---|
299 | double bag = 0.0; // NLL of pos bag |
---|
300 | |
---|
301 | for(int j=0; j<nI; j++){ |
---|
302 | double ins=0.0; |
---|
303 | for(int k=0; k<m_Data[i].length; k++) |
---|
304 | ins += (m_Data[i][k][j]-x[k*2])*(m_Data[i][k][j]-x[k*2])* |
---|
305 | x[k*2+1]*x[k*2+1]; |
---|
306 | ins = Math.exp(-ins); |
---|
307 | ins = 1.0-ins; |
---|
308 | |
---|
309 | if(m_Classes[i] == 1) |
---|
310 | bag += Math.log(ins); |
---|
311 | else{ |
---|
312 | if(ins<=m_Zero) ins=m_Zero; |
---|
313 | nll -= Math.log(ins); |
---|
314 | } |
---|
315 | } |
---|
316 | |
---|
317 | if(m_Classes[i] == 1){ |
---|
318 | bag = 1.0 - Math.exp(bag); |
---|
319 | if(bag<=m_Zero) bag=m_Zero; |
---|
320 | nll -= Math.log(bag); |
---|
321 | } |
---|
322 | } |
---|
323 | return nll; |
---|
324 | } |
---|
325 | |
---|
326 | /** |
---|
327 | * Evaluate Jacobian vector |
---|
328 | * @param x the current values of variables |
---|
329 | * @return the gradient vector |
---|
330 | */ |
---|
331 | protected double[] evaluateGradient(double[] x){ |
---|
332 | double[] grad = new double[x.length]; |
---|
333 | for(int i=0; i<m_Classes.length; i++){ // ith bag |
---|
334 | int nI = m_Data[i][0].length; // numInstances in ith bag |
---|
335 | |
---|
336 | double denom=0.0; |
---|
337 | double[] numrt = new double[x.length]; |
---|
338 | |
---|
339 | for(int j=0; j<nI; j++){ |
---|
340 | double exp=0.0; |
---|
341 | for(int k=0; k<m_Data[i].length; k++) |
---|
342 | exp += (m_Data[i][k][j]-x[k*2])*(m_Data[i][k][j]-x[k*2]) |
---|
343 | *x[k*2+1]*x[k*2+1]; |
---|
344 | exp = Math.exp(-exp); |
---|
345 | exp = 1.0-exp; |
---|
346 | if(m_Classes[i]==1) |
---|
347 | denom += Math.log(exp); |
---|
348 | |
---|
349 | if(exp<=m_Zero) exp=m_Zero; |
---|
350 | // Instance-wise update |
---|
351 | for(int p=0; p<m_Data[i].length; p++){ // pth variable |
---|
352 | numrt[2*p] += (1.0-exp)*2.0*(x[2*p]-m_Data[i][p][j])*x[p*2+1]*x[p*2+1] |
---|
353 | /exp; |
---|
354 | numrt[2*p+1] += 2.0*(1.0-exp)*(x[2*p]-m_Data[i][p][j])*(x[2*p]-m_Data[i][p][j]) |
---|
355 | *x[p*2+1]/exp; |
---|
356 | } |
---|
357 | } |
---|
358 | |
---|
359 | // Bag-wise update |
---|
360 | denom = 1.0-Math.exp(denom); |
---|
361 | if(denom <= m_Zero) denom = m_Zero; |
---|
362 | for(int q=0; q<m_Data[i].length; q++){ |
---|
363 | if(m_Classes[i]==1){ |
---|
364 | grad[2*q] += numrt[2*q]*(1.0-denom)/denom; |
---|
365 | grad[2*q+1] += numrt[2*q+1]*(1.0-denom)/denom; |
---|
366 | }else{ |
---|
367 | grad[2*q] -= numrt[2*q]; |
---|
368 | grad[2*q+1] -= numrt[2*q+1]; |
---|
369 | } |
---|
370 | } |
---|
371 | } // one bag |
---|
372 | |
---|
373 | return grad; |
---|
374 | } |
---|
375 | |
---|
376 | /** |
---|
377 | * Returns the revision string. |
---|
378 | * |
---|
379 | * @return the revision |
---|
380 | */ |
---|
381 | public String getRevision() { |
---|
382 | return RevisionUtils.extract("$Revision: 5928 $"); |
---|
383 | } |
---|
384 | } |
---|
385 | |
---|
386 | /** |
---|
387 | * Returns default capabilities of the classifier. |
---|
388 | * |
---|
389 | * @return the capabilities of this classifier |
---|
390 | */ |
---|
391 | public Capabilities getCapabilities() { |
---|
392 | Capabilities result = super.getCapabilities(); |
---|
393 | result.disableAll(); |
---|
394 | |
---|
395 | // attributes |
---|
396 | result.enable(Capability.NOMINAL_ATTRIBUTES); |
---|
397 | result.enable(Capability.RELATIONAL_ATTRIBUTES); |
---|
398 | result.enable(Capability.MISSING_VALUES); |
---|
399 | |
---|
400 | // class |
---|
401 | result.enable(Capability.BINARY_CLASS); |
---|
402 | result.enable(Capability.MISSING_CLASS_VALUES); |
---|
403 | |
---|
404 | // other |
---|
405 | result.enable(Capability.ONLY_MULTIINSTANCE); |
---|
406 | |
---|
407 | return result; |
---|
408 | } |
---|
409 | |
---|
410 | /** |
---|
411 | * Returns the capabilities of this multi-instance classifier for the |
---|
412 | * relational data. |
---|
413 | * |
---|
414 | * @return the capabilities of this object |
---|
415 | * @see Capabilities |
---|
416 | */ |
---|
417 | public Capabilities getMultiInstanceCapabilities() { |
---|
418 | Capabilities result = super.getCapabilities(); |
---|
419 | result.disableAll(); |
---|
420 | |
---|
421 | // attributes |
---|
422 | result.enable(Capability.NOMINAL_ATTRIBUTES); |
---|
423 | result.enable(Capability.NUMERIC_ATTRIBUTES); |
---|
424 | result.enable(Capability.DATE_ATTRIBUTES); |
---|
425 | result.enable(Capability.MISSING_VALUES); |
---|
426 | |
---|
427 | // class |
---|
428 | result.disableAllClasses(); |
---|
429 | result.enable(Capability.NO_CLASS); |
---|
430 | |
---|
431 | return result; |
---|
432 | } |
---|
433 | |
---|
434 | /** |
---|
435 | * Builds the classifier |
---|
436 | * |
---|
437 | * @param train the training data to be used for generating the |
---|
438 | * boosted classifier. |
---|
439 | * @throws Exception if the classifier could not be built successfully |
---|
440 | */ |
---|
441 | public void buildClassifier(Instances train) throws Exception { |
---|
442 | // can classifier handle the data? |
---|
443 | getCapabilities().testWithFail(train); |
---|
444 | |
---|
445 | // remove instances with missing class |
---|
446 | train = new Instances(train); |
---|
447 | train.deleteWithMissingClass(); |
---|
448 | |
---|
449 | m_ClassIndex = train.classIndex(); |
---|
450 | m_NumClasses = train.numClasses(); |
---|
451 | |
---|
452 | int nR = train.attribute(1).relation().numAttributes(); |
---|
453 | int nC = train.numInstances(); |
---|
454 | FastVector maxSzIdx=new FastVector(); |
---|
455 | int maxSz=0; |
---|
456 | int [] bagSize=new int [nC]; |
---|
457 | Instances datasets= new Instances(train.attribute(1).relation(),0); |
---|
458 | |
---|
459 | m_Data = new double [nC][nR][]; // Data values |
---|
460 | m_Classes = new int [nC]; // Class values |
---|
461 | m_Attributes = datasets.stringFreeStructure(); |
---|
462 | if (m_Debug) { |
---|
463 | System.out.println("Extracting data..."); |
---|
464 | } |
---|
465 | |
---|
466 | for(int h=0; h<nC; h++) {//h_th bag |
---|
467 | Instance current = train.instance(h); |
---|
468 | m_Classes[h] = (int)current.classValue(); // Class value starts from 0 |
---|
469 | Instances currInsts = current.relationalValue(1); |
---|
470 | for (int i=0; i<currInsts.numInstances();i++){ |
---|
471 | Instance inst=currInsts.instance(i); |
---|
472 | datasets.add(inst); |
---|
473 | } |
---|
474 | |
---|
475 | int nI = currInsts.numInstances(); |
---|
476 | bagSize[h]=nI; |
---|
477 | if(m_Classes[h]==1){ |
---|
478 | if(nI>maxSz){ |
---|
479 | maxSz=nI; |
---|
480 | maxSzIdx=new FastVector(1); |
---|
481 | maxSzIdx.addElement(new Integer(h)); |
---|
482 | } |
---|
483 | else if(nI == maxSz) |
---|
484 | maxSzIdx.addElement(new Integer(h)); |
---|
485 | } |
---|
486 | |
---|
487 | } |
---|
488 | |
---|
489 | /* filter the training data */ |
---|
490 | if (m_filterType == FILTER_STANDARDIZE) |
---|
491 | m_Filter = new Standardize(); |
---|
492 | else if (m_filterType == FILTER_NORMALIZE) |
---|
493 | m_Filter = new Normalize(); |
---|
494 | else |
---|
495 | m_Filter = null; |
---|
496 | |
---|
497 | if (m_Filter!=null) { |
---|
498 | m_Filter.setInputFormat(datasets); |
---|
499 | datasets = Filter.useFilter(datasets, m_Filter); |
---|
500 | } |
---|
501 | |
---|
502 | m_Missing.setInputFormat(datasets); |
---|
503 | datasets = Filter.useFilter(datasets, m_Missing); |
---|
504 | |
---|
505 | |
---|
506 | int instIndex=0; |
---|
507 | int start=0; |
---|
508 | for(int h=0; h<nC; h++) { |
---|
509 | for (int i = 0; i < datasets.numAttributes(); i++) { |
---|
510 | // initialize m_data[][][] |
---|
511 | m_Data[h][i] = new double[bagSize[h]]; |
---|
512 | instIndex=start; |
---|
513 | for (int k=0; k<bagSize[h]; k++){ |
---|
514 | m_Data[h][i][k]=datasets.instance(instIndex).value(i); |
---|
515 | instIndex ++; |
---|
516 | } |
---|
517 | } |
---|
518 | start=instIndex; |
---|
519 | } |
---|
520 | |
---|
521 | |
---|
522 | if (m_Debug) { |
---|
523 | System.out.println("\nIteration History..." ); |
---|
524 | } |
---|
525 | |
---|
526 | double[] x = new double[nR*2], tmp = new double[x.length]; |
---|
527 | double[][] b = new double[2][x.length]; |
---|
528 | |
---|
529 | OptEng opt; |
---|
530 | double nll, bestnll = Double.MAX_VALUE; |
---|
531 | for (int t=0; t<x.length; t++){ |
---|
532 | b[0][t] = Double.NaN; |
---|
533 | b[1][t] = Double.NaN; |
---|
534 | } |
---|
535 | |
---|
536 | // Largest Positive exemplar |
---|
537 | for(int s=0; s<maxSzIdx.size(); s++){ |
---|
538 | int exIdx = ((Integer)maxSzIdx.elementAt(s)).intValue(); |
---|
539 | for(int p=0; p<m_Data[exIdx][0].length; p++){ |
---|
540 | for (int q=0; q < nR;q++){ |
---|
541 | x[2*q] = m_Data[exIdx][q][p]; // pick one instance |
---|
542 | x[2*q+1] = 1.0; |
---|
543 | } |
---|
544 | |
---|
545 | opt = new OptEng(); |
---|
546 | //opt.setDebug(m_Debug); |
---|
547 | tmp = opt.findArgmin(x, b); |
---|
548 | while(tmp==null){ |
---|
549 | tmp = opt.getVarbValues(); |
---|
550 | if (m_Debug) |
---|
551 | System.out.println("200 iterations finished, not enough!"); |
---|
552 | tmp = opt.findArgmin(tmp, b); |
---|
553 | } |
---|
554 | nll = opt.getMinFunction(); |
---|
555 | |
---|
556 | if(nll < bestnll){ |
---|
557 | bestnll = nll; |
---|
558 | m_Par = tmp; |
---|
559 | tmp = new double[x.length]; // Save memory |
---|
560 | if (m_Debug) |
---|
561 | System.out.println("!!!!!!!!!!!!!!!!Smaller NLL found: "+nll); |
---|
562 | } |
---|
563 | if (m_Debug) |
---|
564 | System.out.println(exIdx+": -------------<Converged>--------------"); |
---|
565 | } |
---|
566 | } |
---|
567 | } |
---|
568 | |
---|
569 | /** |
---|
570 | * Computes the distribution for a given exemplar |
---|
571 | * |
---|
572 | * @param exmp the exemplar for which distribution is computed |
---|
573 | * @return the distribution |
---|
574 | * @throws Exception if the distribution can't be computed successfully |
---|
575 | */ |
---|
576 | public double[] distributionForInstance(Instance exmp) |
---|
577 | throws Exception { |
---|
578 | |
---|
579 | // Extract the data |
---|
580 | Instances ins = exmp.relationalValue(1); |
---|
581 | if(m_Filter!=null) |
---|
582 | ins = Filter.useFilter(ins, m_Filter); |
---|
583 | |
---|
584 | ins = Filter.useFilter(ins, m_Missing); |
---|
585 | |
---|
586 | int nI = ins.numInstances(), nA = ins.numAttributes(); |
---|
587 | double[][] dat = new double [nI][nA]; |
---|
588 | for(int j=0; j<nI; j++){ |
---|
589 | for(int k=0; k<nA; k++){ |
---|
590 | dat[j][k] = ins.instance(j).value(k); |
---|
591 | } |
---|
592 | } |
---|
593 | |
---|
594 | // Compute the probability of the bag |
---|
595 | double [] distribution = new double[2]; |
---|
596 | distribution[0]=0.0; // log-Prob. for class 0 |
---|
597 | |
---|
598 | for(int i=0; i<nI; i++){ |
---|
599 | double exp = 0.0; |
---|
600 | for(int r=0; r<nA; r++) |
---|
601 | exp += (m_Par[r*2]-dat[i][r])*(m_Par[r*2]-dat[i][r])* |
---|
602 | m_Par[r*2+1]*m_Par[r*2+1]; |
---|
603 | exp = Math.exp(-exp); |
---|
604 | |
---|
605 | // Prob. updated for one instance |
---|
606 | distribution[0] += Math.log(1.0-exp); |
---|
607 | } |
---|
608 | |
---|
609 | distribution[0] = Math.exp(distribution[0]); |
---|
610 | distribution[1] = 1.0-distribution[0]; |
---|
611 | |
---|
612 | return distribution; |
---|
613 | } |
---|
614 | |
---|
615 | /** |
---|
616 | * Gets a string describing the classifier. |
---|
617 | * |
---|
618 | * @return a string describing the classifer built. |
---|
619 | */ |
---|
620 | public String toString() { |
---|
621 | |
---|
622 | //double CSq = m_LLn - m_LL; |
---|
623 | //int df = m_NumPredictors; |
---|
624 | String result = "Diverse Density"; |
---|
625 | if (m_Par == null) { |
---|
626 | return result + ": No model built yet."; |
---|
627 | } |
---|
628 | |
---|
629 | result += "\nCoefficients...\n" |
---|
630 | + "Variable Point Scale\n"; |
---|
631 | for (int j = 0, idx=0; j < m_Par.length/2; j++, idx++) { |
---|
632 | result += m_Attributes.attribute(idx).name(); |
---|
633 | result += " "+Utils.doubleToString(m_Par[j*2], 12, 4); |
---|
634 | result += " "+Utils.doubleToString(m_Par[j*2+1], 12, 4)+"\n"; |
---|
635 | } |
---|
636 | |
---|
637 | return result; |
---|
638 | } |
---|
639 | |
---|
640 | /** |
---|
641 | * Returns the revision string. |
---|
642 | * |
---|
643 | * @return the revision |
---|
644 | */ |
---|
645 | public String getRevision() { |
---|
646 | return RevisionUtils.extract("$Revision: 5928 $"); |
---|
647 | } |
---|
648 | |
---|
649 | /** |
---|
650 | * Main method for testing this class. |
---|
651 | * |
---|
652 | * @param argv should contain the command line arguments to the |
---|
653 | * scheme (see Evaluation) |
---|
654 | */ |
---|
655 | public static void main(String[] argv) { |
---|
656 | runClassifier(new MIDD(), argv); |
---|
657 | } |
---|
658 | } |
---|