/*
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 2 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 */

/*
 *    SPegasos.java
 *    Copyright (C) 2009 University of Waikato, Hamilton, New Zealand
 *
 */

package weka.classifiers.functions;

import java.util.ArrayList;
import java.util.Enumeration;
import java.util.Vector;

import weka.classifiers.AbstractClassifier;
import weka.classifiers.UpdateableClassifier;
import weka.core.Capabilities;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.OptionHandler;
import weka.core.RevisionUtils;
import weka.core.SelectedTag;
import weka.core.Tag;
import weka.core.TechnicalInformation;
import weka.core.TechnicalInformationHandler;
import weka.core.Utils;
import weka.core.Capabilities.Capability;
import weka.core.TechnicalInformation.Field;
import weka.core.TechnicalInformation.Type;
import weka.filters.Filter;
import weka.filters.unsupervised.attribute.NominalToBinary;
import weka.filters.unsupervised.attribute.Normalize;
import weka.filters.unsupervised.attribute.ReplaceMissingValues;

/**
 <!-- globalinfo-start -->
 * Implements the stochastic variant of the Pegasos (Primal Estimated sub-GrAdient SOlver for SVM) method of Shalev-Shwartz et al. (2007). This implementation globally replaces all missing values and transforms nominal attributes into binary ones. It also normalizes all attributes, so the coefficients in the output are based on the normalized data. Can either minimize the hinge loss (SVM) or log loss (logistic regression). For more information, see<br/>
 * <br/>
 * S. Shalev-Shwartz, Y. Singer, N. Srebro: Pegasos: Primal Estimated sub-GrAdient SOlver for SVM. In: 24th International Conference on Machine Learning, 807-814, 2007.
 * <p/>
 <!-- globalinfo-end -->
 *
 <!-- technical-bibtex-start -->
 * BibTeX:
 * <pre>
 * &#64;inproceedings{Shalev-Shwartz2007,
 *    author = {S. Shalev-Shwartz and Y. Singer and N. Srebro},
 *    booktitle = {24th International Conference on Machine Learning},
 *    pages = {807-814},
 *    title = {Pegasos: Primal Estimated sub-GrAdient SOlver for SVM},
 *    year = {2007}
 * }
 * </pre>
 * <p/>
 <!-- technical-bibtex-end -->
 *
 <!-- options-start -->
 * Valid options are: <p/>
 *
 * <pre> -F
 *  Set the loss function to minimize. 0 = hinge loss (SVM), 1 = log loss (logistic regression).
 *  (default = 0)</pre>
 *
 * <pre> -L &lt;double&gt;
 *  The lambda regularization constant (default = 0.0001)</pre>
 *
 * <pre> -E &lt;integer&gt;
 *  The number of epochs to perform (batch learning only, default = 500)</pre>
 *
 * <pre> -N
 *  Don't normalize the data</pre>
 *
 * <pre> -M
 *  Don't replace missing values</pre>
 *
 <!-- options-end -->
 *
 * @author Mark Hall (mhall{[at]}pentaho{[dot]}com)
 * @version $Revision: 6105 $
 *
 */
public class SPegasos extends AbstractClassifier
  implements TechnicalInformationHandler, UpdateableClassifier,
             OptionHandler {

  /** For serialization */
  private static final long serialVersionUID = -3732968666673530290L;

  /** Replace missing values */
  protected ReplaceMissingValues m_replaceMissing;

  /** Convert nominal attributes to numerically coded binary ones */
  protected NominalToBinary m_nominalToBinary;

  /** Normalize the training data */
  protected Normalize m_normalize;

  /** The regularization parameter */
  protected double m_lambda = 0.0001;

  /** Stores the weights (+ bias in the last element) */
  protected double[] m_weights;

  /** Holds the current iteration number */
  protected double m_t;

  /**
   * The number of epochs to perform (batch learning). The total number
   * of iterations is m_epochs * number of instances.
   */
  protected int m_epochs = 500;

  /**
   * Turn off normalization of the input data. This option gets
   * forced for incremental training.
   */
  protected boolean m_dontNormalize = false;

  /**
   * Turn off global replacement of missing values. Missing values
   * will be ignored instead. This option gets forced for
   * incremental training.
   */
  protected boolean m_dontReplaceMissing = false;

  /** Holds the header of the training data */
  protected Instances m_data;

  /**
   * Returns default capabilities of the classifier.
   *
   * @return the capabilities of this classifier
   */
  public Capabilities getCapabilities() {
    Capabilities result = super.getCapabilities();
    result.disableAll();

    // attributes
    result.enable(Capability.NOMINAL_ATTRIBUTES);
    result.enable(Capability.NUMERIC_ATTRIBUTES);
    result.enable(Capability.MISSING_VALUES);

    // class
    result.enable(Capability.BINARY_CLASS);
    result.enable(Capability.MISSING_CLASS_VALUES);

    // instances
    result.setMinimumNumberInstances(0);

    return result;
  }

  /**
   * Returns the tip text for this property
   *
   * @return tip text for this property suitable for
   *         displaying in the explorer/experimenter gui
   */
  public String lambdaTipText() {
    return "The regularization constant. (default = 0.0001)";
  }

  /**
   * Set the value of lambda to use
   *
   * @param lambda the value of lambda to use
   */
  public void setLambda(double lambda) {
    m_lambda = lambda;
  }

  /**
   * Get the current value of lambda
   *
   * @return the current value of lambda
   */
  public double getLambda() {
    return m_lambda;
  }

  /**
   * Returns the tip text for this property
   *
   * @return tip text for this property suitable for
   *         displaying in the explorer/experimenter gui
   */
  public String epochsTipText() {
    return "The number of epochs to perform (batch learning). " +
      "The total number of iterations is epochs * num" +
      " instances.";
  }

  /**
   * Set the number of epochs to use
   *
   * @param e the number of epochs to use
   */
  public void setEpochs(int e) {
    m_epochs = e;
  }

  /**
   * Get current number of epochs
   *
   * @return the current number of epochs
   */
  public int getEpochs() {
    return m_epochs;
  }

  /**
   * Turn normalization off/on.
   *
   * @param m true if normalization is to be disabled.
   */
  public void setDontNormalize(boolean m) {
    m_dontNormalize = m;
  }

  /**
   * Get whether normalization has been turned off.
   *
   * @return true if normalization has been disabled.
   */
  public boolean getDontNormalize() {
    return m_dontNormalize;
  }

  /**
   * Returns the tip text for this property
   *
   * @return tip text for this property suitable for
   *         displaying in the explorer/experimenter gui
   */
  public String dontNormalizeTipText() {
    return "Turn normalization off";
  }

  /**
   * Turn global replacement of missing values off/on. If turned off,
   * then missing values are effectively ignored.
   *
   * @param m true if global replacement of missing values is to be
   *          turned off.
   */
  public void setDontReplaceMissing(boolean m) {
    m_dontReplaceMissing = m;
  }

  /**
   * Get whether global replacement of missing values has been
   * disabled.
   *
   * @return true if global replacement of missing values has been turned
   *         off
   */
  public boolean getDontReplaceMissing() {
    return m_dontReplaceMissing;
  }

  /**
   * Returns the tip text for this property
   *
   * @return tip text for this property suitable for
   *         displaying in the explorer/experimenter gui
   */
  public String dontReplaceMissingTipText() {
    return "Turn off global replacement of missing values";
  }

  /**
   * Set the loss function to use.
   *
   * @param function the loss function to use.
   */
  public void setLossFunction(SelectedTag function) {
    if (function.getTags() == TAGS_SELECTION) {
      m_loss = function.getSelectedTag().getID();
    }
  }

  /**
   * Get the current loss function.
   *
   * @return the current loss function.
   */
  public SelectedTag getLossFunction() {
    return new SelectedTag(m_loss, TAGS_SELECTION);
  }

  /**
   * Returns the tip text for this property
   *
   * @return tip text for this property suitable for
   *         displaying in the explorer/experimenter gui
   */
  public String lossFunctionTipText() {
    return "The loss function to use. Hinge loss (SVM) " +
      "or log loss (logistic regression).";
  }

  /**
   * Returns an enumeration describing the available options.
   *
   * @return an enumeration of all the available options.
   */
  public Enumeration<Option> listOptions() {

    Vector<Option> newVector = new Vector<Option>();

    newVector.add(new Option("\tSet the loss function to minimize. 0 = " +
      "hinge loss (SVM), 1 = log loss (logistic regression).\n" +
      "\t(default = 0)", "F", 1, "-F"));
    newVector.add(new Option("\tThe lambda regularization constant " +
      "(default = 0.0001)",
      "L", 1, "-L <double>"));
    newVector.add(new Option("\tThe number of epochs to perform (" +
      "batch learning only, default = 500)", "E", 1,
      "-E <integer>"));
    newVector.add(new Option("\tDon't normalize the data", "N", 0, "-N"));
    newVector.add(new Option("\tDon't replace missing values", "M", 0, "-M"));

    return newVector.elements();
  }

  /**
   * Parses a given list of options. <p/>
   *
   <!-- options-start -->
   * Valid options are: <p/>
   *
   * <pre> -F
   *  Set the loss function to minimize. 0 = hinge loss (SVM), 1 = log loss (logistic regression).
   *  (default = 0)</pre>
   *
   * <pre> -L &lt;double&gt;
   *  The lambda regularization constant (default = 0.0001)</pre>
   *
   * <pre> -E &lt;integer&gt;
   *  The number of epochs to perform (batch learning only, default = 500)</pre>
   *
   * <pre> -N
   *  Don't normalize the data</pre>
   *
   * <pre> -M
   *  Don't replace missing values</pre>
   *
   <!-- options-end -->
   *
   * @param options the list of options as an array of strings
   * @throws Exception if an option is not supported
   */
  public void setOptions(String[] options) throws Exception {
    reset();

    String lossString = Utils.getOption('F', options);
    if (lossString.length() != 0) {
      setLossFunction(new SelectedTag(Integer.parseInt(lossString),
        TAGS_SELECTION));
    } else {
      setLossFunction(new SelectedTag(HINGE, TAGS_SELECTION));
    }

    String lambdaString = Utils.getOption('L', options);
    if (lambdaString.length() > 0) {
      setLambda(Double.parseDouble(lambdaString));
    }

    String epochsString = Utils.getOption("E", options);
    if (epochsString.length() > 0) {
      setEpochs(Integer.parseInt(epochsString));
    }

    setDontNormalize(Utils.getFlag("N", options));
    setDontReplaceMissing(Utils.getFlag('M', options));
  }
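
  /*
   * Example (not part of the original source): setting options
   * programmatically via a command-line style string. The values shown
   * are arbitrary illustrations, not recommended settings.
   *
   *   SPegasos s = new SPegasos();
   *   s.setOptions(weka.core.Utils.splitOptions("-F 0 -L 0.001 -E 100 -N"));
   */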

  /**
   * Gets the current settings of the classifier.
   *
   * @return an array of strings suitable for passing to setOptions
   */
  public String[] getOptions() {
    ArrayList<String> options = new ArrayList<String>();

    options.add("-F"); options.add("" + getLossFunction().getSelectedTag().getID());
    options.add("-L"); options.add("" + getLambda());
    options.add("-E"); options.add("" + getEpochs());
    if (getDontNormalize()) {
      options.add("-N");
    }
    if (getDontReplaceMissing()) {
      options.add("-M");
    }

    return options.toArray(new String[0]);
  }

  /**
   * Returns a string describing the classifier.
   *
   * @return a description suitable for
   *         displaying in the explorer/experimenter gui
   */
  public String globalInfo() {
    return "Implements the stochastic variant of the Pegasos" +
      " (Primal Estimated sub-GrAdient SOlver for SVM)" +
      " method of Shalev-Shwartz et al. (2007). This implementation" +
      " globally replaces all missing values and transforms nominal" +
      " attributes into binary ones. It also normalizes all attributes," +
      " so the coefficients in the output are based on the normalized" +
      " data. Can either minimize the hinge loss (SVM) or log loss (" +
      "logistic regression). For more information, see\n\n" +
      getTechnicalInformation().toString();
  }

  /**
   * Returns an instance of a TechnicalInformation object, containing
   * detailed information about the technical background of this class,
   * e.g., paper reference or book this class is based on.
   *
   * @return the technical information about this class
   */
  public TechnicalInformation getTechnicalInformation() {
    TechnicalInformation result;

    result = new TechnicalInformation(Type.INPROCEEDINGS);
    result.setValue(Field.AUTHOR, "S. Shalev-Shwartz and Y. Singer and N. Srebro");
    result.setValue(Field.YEAR, "2007");
    result.setValue(Field.TITLE, "Pegasos: Primal Estimated sub-GrAdient " +
      "SOlver for SVM");
    result.setValue(Field.BOOKTITLE, "24th International Conference on " +
      "Machine Learning");
    result.setValue(Field.PAGES, "807-814");

    return result;
  }

  /**
   * Reset the classifier.
   */
  public void reset() {
    m_t = 1;
    m_weights = null;
    m_normalize = null;
    m_replaceMissing = null;
    m_nominalToBinary = null;
  }

  /**
   * Method for building the classifier.
   *
   * @param data the set of training instances.
   * @throws Exception if the classifier can't be built successfully.
   */
  public void buildClassifier(Instances data) throws Exception {
    reset();

    // can classifier handle the data?
    getCapabilities().testWithFail(data);

    data = new Instances(data);
    data.deleteWithMissingClass();

    if (data.numInstances() > 0 && !m_dontReplaceMissing) {
      m_replaceMissing = new ReplaceMissingValues();
      m_replaceMissing.setInputFormat(data);
      data = Filter.useFilter(data, m_replaceMissing);
    }

    // check for only numeric attributes
    boolean onlyNumeric = true;
    for (int i = 0; i < data.numAttributes(); i++) {
      if (i != data.classIndex()) {
        if (!data.attribute(i).isNumeric()) {
          onlyNumeric = false;
          break;
        }
      }
    }

    if (!onlyNumeric) {
      m_nominalToBinary = new NominalToBinary();
      m_nominalToBinary.setInputFormat(data);
      data = Filter.useFilter(data, m_nominalToBinary);
    }

    if (!m_dontNormalize && data.numInstances() > 0) {
      m_normalize = new Normalize();
      m_normalize.setInputFormat(data);
      data = Filter.useFilter(data, m_normalize);
    }

    // one weight per attribute (the class index is unused) plus the
    // bias term in the last element
    m_weights = new double[data.numAttributes() + 1];
    m_data = new Instances(data, 0);

    if (data.numInstances() > 0) {
      train(data);
    }
  }
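
  /*
   * Example (not part of the original source): batch training from an
   * ARFF file — the file name and class index are hypothetical.
   *
   *   Instances train = new Instances(
   *       new java.io.BufferedReader(new java.io.FileReader("data.arff")));
   *   train.setClassIndex(train.numAttributes() - 1);
   *   SPegasos p = new SPegasos();
   *   p.setLossFunction(new SelectedTag(HINGE, TAGS_SELECTION));
   *   p.buildClassifier(train); // filters the data, then runs m_epochs passes
   */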

  /** Loss function constant: hinge loss (SVM) */
  protected static final int HINGE = 0;

  /** Loss function constant: log loss (logistic regression) */
  protected static final int LOGLOSS = 1;

  /** The current loss function to minimize */
  protected int m_loss = HINGE;

  /** Loss functions to choose from */
  public static final Tag[] TAGS_SELECTION = {
    new Tag(HINGE, "Hinge loss (SVM)"),
    new Tag(LOGLOSS, "Log loss (logistic regression)")
  };

  /**
   * Computes the negative derivative of the loss with respect to the
   * margin z = y * (w.x + b).
   *
   * @param z the margin
   * @return the magnitude of the (sub)gradient step
   */
  protected double dloss(double z) {
    if (m_loss == HINGE) {
      // subgradient of max(0, 1 - z): 1 inside the margin, 0 outside
      return (z < 1) ? 1 : 0;
    }

    // log loss: -d/dz ln(1 + e^-z) = 1 / (e^z + 1), branching so that
    // Math.exp is never called on a large positive argument
    if (z < 0) {
      return 1.0 / (Math.exp(z) + 1.0);
    } else {
      double t = Math.exp(-z);
      return t / (t + 1);
    }
  }
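
  /*
   * For reference (not in the original source): with the margin z as
   * defined above, the two losses are
   *
   *   hinge: l(z) = max(0, 1 - z)     dloss(z) = 1 if z < 1, else 0
   *   log:   l(z) = ln(1 + e^(-z))    dloss(z) = 1 / (e^z + 1)
   *
   * so, e.g., dloss(0) is 1 under hinge loss and 0.5 under log loss.
   */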

  /**
   * Performs the actual Pegasos updates over the training data.
   *
   * @param data the (filtered) training instances
   */
  private void train(Instances data) {
    for (int e = 0; e < m_epochs; e++) {
      for (int i = 0; i < data.numInstances(); i++) {
        Instance instance = data.instance(i);

        double learningRate = 1.0 / (m_lambda * m_t);
        //double scale = 1.0 - learningRate * m_lambda;
        double scale = 1.0 - 1.0 / m_t;
        double y = (instance.classValue() == 0) ? -1 : 1;
        double wx = dotProd(instance, m_weights, instance.classIndex());
        double z = y * (wx + m_weights[m_weights.length - 1]);

        if (m_loss == LOGLOSS || (z < 1)) {
          double delta = learningRate * dloss(z);
          int n1 = instance.numValues();
          int n2 = data.numAttributes();
          // merge-style pass over the sparse instance (p1) and the dense
          // weight vector (p2): decay every non-class weight once, and add
          // the gradient step for attributes present in the instance
          for (int p1 = 0, p2 = 0; p2 < n2;) {
            int indS = (p1 < n1) ? instance.index(p1) : 0;
            int indP = p2;
            if (indP != data.classIndex()) {
              m_weights[indP] *= scale;
            }
            if (indS == indP) {
              if (indS != data.classIndex() &&
                  !instance.isMissingSparse(p1)) {
                //double m = learningRate * (instance.valueSparse(p1) * y);
                double m = delta * (instance.valueSparse(p1) * y);
                m_weights[indS] += m;
              }
              p1++;
            }
            p2++;
          }

          // update the bias
          m_weights[m_weights.length - 1] += delta * y;

          // project the weights onto the ball of radius 1/sqrt(lambda)
          double norm = 0;
          for (int k = 0; k < m_weights.length; k++) {
            if (k != data.classIndex()) {
              norm += (m_weights[k] * m_weights[k]);
            }
          }
          norm = Math.sqrt(norm);

          double scale2 = Math.min(1.0, (1.0 / (Math.sqrt(m_lambda) * norm)));
          if (scale2 < 1.0) {
            for (int j = 0; j < m_weights.length; j++) {
              m_weights[j] *= scale2;
            }
          }
        }
        m_t++;
      }
    }
  }
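
  /*
   * Note (not in the original source): each instance visit above is one
   * Pegasos step with learning rate eta_t = 1/(lambda * t). When the loss
   * is active (always for log loss; on a margin violation, z < 1, for
   * hinge loss), the non-class weights are decayed by
   * (1 - eta_t * lambda) = (1 - 1/t), the step eta_t * dloss(z) * y * x
   * is added, and the weights are projected back onto the ball of radius
   * 1/sqrt(lambda), following Shalev-Shwartz et al. (2007).
   */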

  /**
   * Inner product of a sparse instance with the weight vector, skipping
   * the class attribute, missing values and the bias element.
   *
   * @param inst1 the instance
   * @param weights the weight vector (bias in the last element)
   * @param classIndex the index of the class attribute
   * @return the inner product
   */
  protected static double dotProd(Instance inst1, double[] weights,
    int classIndex) {
    double result = 0;

    int n1 = inst1.numValues();
    int n2 = weights.length - 1;

    for (int p1 = 0, p2 = 0; p1 < n1 && p2 < n2;) {
      int ind1 = inst1.index(p1);
      int ind2 = p2;
      if (ind1 == ind2) {
        if (ind1 != classIndex && !inst1.isMissingSparse(p1)) {
          result += inst1.valueSparse(p1) * weights[p2];
        }
        p1++;
        p2++;
      } else if (ind1 > ind2) {
        p2++;
      } else {
        p1++;
      }
    }
    return (result);
  }

  /**
   * Updates the classifier with the given instance.
   *
   * @param instance the new training instance to include in the model
   * @exception Exception if the instance could not be incorporated in
   *            the model.
   */
  public void updateClassifier(Instance instance) throws Exception {
    if (!instance.classIsMissing()) {
      double learningRate = 1.0 / (m_lambda * m_t);
      //double scale = 1.0 - learningRate * m_lambda;
      double scale = 1.0 - 1.0 / m_t;
      double y = (instance.classValue() == 0) ? -1 : 1;
      double wx = dotProd(instance, m_weights, instance.classIndex());
      double z = y * (wx + m_weights[m_weights.length - 1]);

      if (m_loss == LOGLOSS || (z < 1)) {
        double delta = learningRate * dloss(z);
        int n1 = instance.numValues();
        int n2 = instance.numAttributes();
        // same merge-style update as in train(): decay every non-class
        // weight once, and add the gradient step for attributes present
        // in the instance
        for (int p1 = 0, p2 = 0; p2 < n2;) {
          int indS = (p1 < n1) ? instance.index(p1) : 0;
          int indP = p2;
          if (indP != instance.classIndex()) {
            m_weights[indP] *= scale;
          }
          if (indS == indP) {
            if (indS != instance.classIndex() &&
                !instance.isMissingSparse(p1)) {
              double m = delta * (instance.valueSparse(p1) * y);
              m_weights[indS] += m;
            }
            p1++;
          }
          p2++;
        }

        // update the bias
        m_weights[m_weights.length - 1] += delta * y;

        // project the weights onto the ball of radius 1/sqrt(lambda)
        double norm = 0;
        for (int k = 0; k < m_weights.length; k++) {
          if (k != instance.classIndex()) {
            norm += (m_weights[k] * m_weights[k]);
          }
        }
        norm = Math.sqrt(norm);

        double scale2 = Math.min(1.0, (1.0 / (Math.sqrt(m_lambda) * norm)));
        if (scale2 < 1.0) {
          for (int j = 0; j < m_weights.length; j++) {
            m_weights[j] *= scale2;
          }
        }
      }

      m_t++;
    }
  }
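
  /*
   * Note (not in the original source): unlike buildClassifier, this method
   * applies no filters to the incoming instance, and buildClassifier only
   * constructs the missing-value and normalization filters when it sees at
   * least one training instance. This is why the field comments above say
   * that normalization and missing-value replacement are forced off for
   * incremental training.
   */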

  /**
   * Computes the distribution for a given instance.
   *
   * @param inst the instance for which the distribution is computed
   * @return the distribution
   * @throws Exception if the distribution can't be computed successfully
   */
  public double[] distributionForInstance(Instance inst) throws Exception {
    double[] result = new double[2];

    if (m_replaceMissing != null) {
      m_replaceMissing.input(inst);
      inst = m_replaceMissing.output();
    }

    if (m_nominalToBinary != null) {
      m_nominalToBinary.input(inst);
      inst = m_nominalToBinary.output();
    }

    if (m_normalize != null) {
      m_normalize.input(inst);
      inst = m_normalize.output();
    }

    double wx = dotProd(inst, m_weights, inst.classIndex());
    double z = (wx + m_weights[m_weights.length - 1]);

    // log loss yields sigmoid probabilities; hinge loss yields a hard 0/1 vote
    if (z <= 0) {
      if (m_loss == LOGLOSS) {
        result[0] = 1.0 / (1.0 + Math.exp(z));
        result[1] = 1.0 - result[0];
      } else {
        result[0] = 1;
      }
    } else {
      if (m_loss == LOGLOSS) {
        result[1] = 1.0 / (1.0 + Math.exp(-z));
        result[0] = 1.0 - result[1];
      } else {
        result[1] = 1;
      }
    }
    return result;
  }

  /**
   * Prints out the classifier.
   *
   * @return a description of the classifier as a string
   */
  public String toString() {
    if (m_weights == null) {
      return "SPegasos: No model built yet.\n";
    }
    StringBuffer buff = new StringBuffer();
    buff.append("Loss function: ");
    if (m_loss == HINGE) {
      buff.append("Hinge loss (SVM)\n\n");
    } else {
      buff.append("Log loss (logistic regression)\n\n");
    }
    int printed = 0;

    for (int i = 0; i < m_weights.length - 1; i++) {
      if (i != m_data.classIndex()) {
        if (printed > 0) {
          buff.append(" + ");
        } else {
          buff.append("   ");
        }

        buff.append(Utils.doubleToString(m_weights[i], 12, 4) +
          " " + ((m_normalize != null) ? "(normalized) " : "") +
          m_data.attribute(i).name() + "\n");

        printed++;
      }
    }

    if (m_weights[m_weights.length - 1] > 0) {
      buff.append(" + " + Utils.doubleToString(m_weights[m_weights.length - 1], 12, 4));
    } else {
      buff.append(" - " + Utils.doubleToString(-m_weights[m_weights.length - 1], 12, 4));
    }

    return buff.toString();
  }

  /**
   * Returns the revision string.
   *
   * @return the revision
   */
  public String getRevision() {
    return RevisionUtils.extract("$Revision: 6105 $");
  }

  /**
   * Main method for testing this class.
   *
   * @param args the command-line options
   */
  public static void main(String[] args) {
    runClassifier(new SPegasos(), args);
  }
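
  /*
   * Example command line (not part of the original source; standard Weka
   * conventions, with a hypothetical ARFF path):
   *
   *   java weka.classifiers.functions.SPegasos -t data.arff -F 0 -L 1e-4 -E 500
   *
   * -t names the training file; the remaining flags are the options
   * documented in listOptions() above.
   */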
}