[SPARK-2309][MLlib] Multinomial Logistic Regression #3833
Changes from 5 commits
```diff
@@ -18,30 +18,42 @@
 package org.apache.spark.mllib.classification

 import org.apache.spark.annotation.Experimental
-import org.apache.spark.mllib.linalg.Vector
+import org.apache.spark.mllib.linalg.BLAS.dot
+import org.apache.spark.mllib.linalg.{DenseVector, Vector}
 import org.apache.spark.mllib.optimization._
 import org.apache.spark.mllib.regression._
-import org.apache.spark.mllib.util.DataValidators
+import org.apache.spark.mllib.util.{DataValidators, MLUtils}
 import org.apache.spark.rdd.RDD

 /**
- * Classification model trained using Logistic Regression.
+ * Classification model trained using Multinomial/Binary Logistic Regression.
  *
  * @param weights Weights computed for every feature.
- * @param intercept Intercept computed for this model.
+ * @param intercept Intercept computed for this model. (Only used in Binary Logistic Regression.
+ *                  In Multinomial Logistic Regression, the intercepts will not be a single value,
+ *                  so the intercepts will be part of the weights.)
+ * @param featureSize the dimension of the features
+ * @param numClasses the number of possible outcomes for the k-class classification problem in
+ *                   Multinomial Logistic Regression. By default, it is binary logistic regression,
+ *                   so numClasses will be set to 2.
  */
 class LogisticRegressionModel (
     override val weights: Vector,
-    override val intercept: Double)
+    override val intercept: Double,
```
Review comment (Member): This will change this API, no? Is that OK? Also, optional args are a bit Java-unfriendly. Maybe at least a helper constructor here that sets …

Reply (Author): Addressed, thanks.
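The Java-unfriendliness of optional arguments raised above is conventionally handled in Scala with an auxiliary constructor rather than a default parameter value, which is the route the author takes. A minimal standalone sketch of the pattern (the class and members here are illustrative, not Spark's actual API):

```scala
// Illustrative sketch: an auxiliary constructor instead of a default argument,
// so Java callers can still use the old two-argument signature.
class Model(
    val weights: Array[Double],
    val intercept: Double,
    val numClasses: Int) {

  // Backward-compatible constructor: defaults to binary classification.
  def this(weights: Array[Double], intercept: Double) =
    this(weights, intercept, 2)
}

object ConstructorDemo {
  def main(args: Array[String]): Unit = {
    val m = new Model(Array(1.0, 2.0), 0.5) // old call site still compiles
    println(m.numClasses) // prints 2
  }
}
```

Unlike a Scala default argument, the two-argument constructor is a real JVM constructor, so it stays visible to Java code and to previously compiled callers.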
```diff
+    val featureSize: Int,
+    val numClasses: Int)
   extends GeneralizedLinearModel(weights, intercept) with ClassificationModel with Serializable {

+  def this(weights: Vector, intercept: Double, featureSize: Int) =
```
Review comment (Contributor): We need to keep the original constructor to be backward compatible.
```diff
+    this(weights, intercept, featureSize, 2)

   private var threshold: Option[Double] = Some(0.5)

   /**
    * :: Experimental ::
-   * Sets the threshold that separates positive predictions from negative predictions. An example
-   * with prediction score greater than or equal to this threshold is identified as a positive,
-   * and negative otherwise. The default value is 0.5.
+   * Sets the threshold that separates positive predictions from negative predictions
+   * in Binary Logistic Regression. An example with prediction score greater than or equal to
+   * this threshold is identified as a positive, and negative otherwise. The default value is 0.5.
    */
```
Review comment (Member): And this has no meaning for multinomial, unless it's the threshold for making any prediction at all.

Reply (Author): I think the model should have an API to predict a probability, and we can have another transformer to take the threshold, so we can reuse the logic for all probabilistic models. I would like to remove the threshold stuff from LOR entirely. @mengxr what do you think?
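The refactoring the author floats here — the model emits probabilities and thresholding lives in a separate, reusable step — can be sketched without any Spark dependency. All names below are hypothetical, not Spark's API:

```scala
object ThresholdDemo {
  // A probabilistic model's raw output: a score in (0, 1) from the margin.
  def sigmoid(margin: Double): Double = 1.0 / (1.0 + math.exp(-margin))

  // A reusable thresholding step, independent of any particular model:
  // scores above the threshold map to the positive class.
  def applyThreshold(score: Double, threshold: Double = 0.5): Double =
    if (score > threshold) 1.0 else 0.0

  def main(args: Array[String]): Unit = {
    val score = sigmoid(1.2)             // roughly 0.77
    println(applyThreshold(score))       // prints 1.0 (default threshold 0.5)
    println(applyThreshold(score, 0.9))  // prints 0.0 (stricter threshold)
  }
}
```

Because the thresholding function never touches model internals, the same step would work for any model that can emit a probability, which is the reuse the comment is after.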
```diff
   @Experimental
   def setThreshold(threshold: Double): this.type = {

@@ -61,20 +73,68 @@ class LogisticRegressionModel (

   override protected def predictPoint(dataMatrix: Vector, weightMatrix: Vector,
```
Review comment (Member): I know this was in the existing code, but …

Reply (Author): I thought about having weights as a matrix, but it requires changing so many places. For example, the gradient object has to change, the underlying …

Review comment (Member): The argument to the gradient calculation is properly a vector of weights, so that need not change for API reasons. So is it just having to do the translation? It's a line of code, I think, although it requires a copy. Maybe someone else can weigh in with an opinion too.
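On the vector-versus-matrix question above: the translation the reviewer mentions amounts to flattening a coefficient matrix into the row-major flat layout that a vector-based optimizer consumes. A standalone sketch with illustrative names:

```scala
object FlattenWeights {
  // Illustrative: the optimizer works on a flat weight vector, so a
  // (numClasses - 1) x (numFeatures + 1) coefficient matrix is stored
  // row-major, one block of (numFeatures + 1) entries per linear predictor.
  def flatten(matrix: Array[Array[Double]]): Array[Double] = matrix.flatten

  // Recover coefficient (i, j) from the flat layout, given the row length.
  def entry(flat: Array[Double], rowLen: Int, i: Int, j: Int): Double =
    flat(i * rowLen + j)

  def main(args: Array[String]): Unit = {
    val m = Array(Array(1.0, 2.0, 3.0), Array(4.0, 5.0, 6.0))
    val flat = flatten(m)
    println(flat.mkString(","))   // prints 1.0,2.0,3.0,4.0,5.0,6.0
    println(entry(flat, 3, 1, 2)) // prints 6.0
  }
}
```

As the reviewer notes, the translation itself is a one-liner; the cost is the copy, which is what the discussion weighs against changing the gradient API.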
```diff
       intercept: Double) = {
-    val margin = weightMatrix.toBreeze.dot(dataMatrix.toBreeze) + intercept
-    val score = 1.0 / (1.0 + math.exp(-margin))
-    threshold match {
-      case Some(t) => if (score > t) 1.0 else 0.0
-      case None => score
+    require(dataMatrix.size == featureSize)
+
+    // If dataMatrix and weightMatrix have the same dimension, it's binary logistic regression.
+    if (numClasses == 2) {
+      require(featureSize == weightMatrix.size)
+      val margin = dot(weights, dataMatrix) + intercept
+      val score = 1.0 / (1.0 + math.exp(-margin))
+      threshold match {
+        case Some(t) => if (score > t) 1.0 else 0.0
+        case None => score
+      }
+    } else {
+      val dataWithBiasSize = weightMatrix.size / (numClasses - 1)
+
+      val weightsArray = weights match {
+        case dv: DenseVector => dv.values
+        case _ =>
+          throw new IllegalArgumentException(
+            s"weights only supports dense vector but got type ${weights.getClass}.")
+      }
+
+      val margins = (0 until numClasses - 1).map { i =>
+        var margin = 0.0
+        dataMatrix.foreachActive { (index, value) =>
+          if (value != 0.0) margin += value * weightsArray((i * dataWithBiasSize) + index)
+        }
+        // Intercept is required to be added into margin.
+        if (dataMatrix.size + 1 == dataWithBiasSize) {
+          margin += weightsArray((i * dataWithBiasSize) + dataMatrix.size)
+        }
+        margin
+      }
+
+      /**
+       * Find the one with maximum margins. If the maxMargin is negative, then the prediction
+       * result will be the first class.
+       *
+       * PS, if you want to compute the probabilities for each outcome instead of the outcome
+       * with maximum probability, remember to subtract the maxMargin from margins if maxMargin
+       * is positive to prevent overflow.
+       */
+      var bestClass = 0
+      var maxMargin = 0.0
+      var i = 0
+      while (i < margins.size) {
+        if (margins(i) > maxMargin) {
+          maxMargin = margins(i)
+          bestClass = i + 1
+        }
+        i += 1
+      }
+      bestClass.toDouble
+    }
   }
 }
```
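The multinomial branch of predictPoint can be exercised outside Spark. This standalone sketch mirrors the diff's logic — one weight block per class 1..k-1, class 0 as the pivot with an implicit margin of 0, and an optional intercept stored as the last entry of each block — using illustrative names and plain arrays:

```scala
object MultinomialPredict {
  // Mirrors the PR's predictPoint: margins for classes 1..k-1 against a
  // flattened weight array; class 0 wins unless some margin is strictly positive.
  def predict(features: Array[Double], weights: Array[Double], numClasses: Int): Int = {
    val dataWithBiasSize = weights.length / (numClasses - 1)
    val withBias = dataWithBiasSize == features.length + 1

    val margins = (0 until numClasses - 1).map { i =>
      var margin = 0.0
      var j = 0
      while (j < features.length) {
        margin += features(j) * weights(i * dataWithBiasSize + j)
        j += 1
      }
      // The intercept, if present, is the last entry of each weight block.
      if (withBias) margin += weights(i * dataWithBiasSize + features.length)
      margin
    }

    var bestClass = 0
    var maxMargin = 0.0
    var i = 0
    while (i < margins.size) {
      if (margins(i) > maxMargin) { maxMargin = margins(i); bestClass = i + 1 }
      i += 1
    }
    bestClass
  }

  def main(args: Array[String]): Unit = {
    // 3 classes, 2 features, with bias: two weight blocks of size 3.
    val w = Array(1.0, 0.0, -0.5,  // class 1 block: w11, w12, b1
                  -1.0, 2.0, 0.0)  // class 2 block: w21, w22, b2
    println(predict(Array(2.0, 0.0), w, 3)) // prints 1: margins (1.5, -2.0)
    println(predict(Array(0.0, 1.0), w, 3)) // prints 2: margins (-0.5, 2.0)
    println(predict(Array(0.0, 0.0), w, 3)) // prints 0: no margin above 0
  }
}
```

Note the asymmetry the doc comment describes: because class 0's margin is implicitly zero, it is returned whenever every explicit margin is non-positive.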
```diff
 /**
- * Train a classification model for Logistic Regression using Stochastic Gradient Descent. By
- * default L2 regularization is used, which can be changed via
- * [[LogisticRegressionWithSGD.optimizer]].
- * NOTE: Labels used in Logistic Regression should be {0, 1}.
+ * Train a classification model for Binary Logistic Regression
+ * using Stochastic Gradient Descent. By default L2 regularization is used,
+ * which can be changed via [[LogisticRegressionWithSGD.optimizer]].
+ * NOTE: Labels used in Logistic Regression should be {0, 1, ..., k - 1}
+ * for a k-class classification problem.
+ * Using [[LogisticRegressionWithLBFGS]] is recommended over this.
 */
 class LogisticRegressionWithSGD private (

@@ -100,7 +160,7 @@ class LogisticRegressionWithSGD private (
   def this() = this(1.0, 100, 0.01, 1.0)

   override protected def createModel(weights: Vector, intercept: Double) = {
-    new LogisticRegressionModel(weights, intercept)
+    new LogisticRegressionModel(weights, intercept, numFeatures)
   }
 }
```
```diff
@@ -194,9 +254,10 @@ object LogisticRegressionWithSGD {
 }

 /**
- * Train a classification model for Logistic Regression using Limited-memory BFGS.
- * Standard feature scaling and L2 regularization are used by default.
- * NOTE: Labels used in Logistic Regression should be {0, 1}
+ * Train a classification model for Multinomial/Binary Logistic Regression using
+ * Limited-memory BFGS. Standard feature scaling and L2 regularization are used by default.
+ * NOTE: Labels used in Logistic Regression should be {0, 1, ..., k - 1}
+ * for a k-class classification problem.
 */
 class LogisticRegressionWithLBFGS
   extends GeneralizedLinearAlgorithm[LogisticRegressionModel] with Serializable {

@@ -205,9 +266,33 @@ class LogisticRegressionWithLBFGS

   override val optimizer = new LBFGS(new LogisticGradient, new SquaredL2Updater)

-  override protected val validators = List(DataValidators.binaryLabelValidator)
+  override protected val validators = List(multiLabelValidator)
+
+  private def multiLabelValidator: RDD[LabeledPoint] => Boolean = { data =>
+    if (numOfLinearPredictor > 1) {
+      DataValidators.multiLabelValidator(numOfLinearPredictor + 1)(data)
+    } else {
+      DataValidators.binaryLabelValidator(data)
+    }
+  }
+
+  /**
+   * :: Experimental ::
+   * Set the number of possible outcomes for the k-class classification problem in
+   * Multinomial Logistic Regression.
+   * By default, it is binary logistic regression so k will be set to 2.
+   */
+  @Experimental
+  def setNumClasses(numClasses: Int): this.type = {
+    require(numClasses > 1)
+    numOfLinearPredictor = numClasses - 1
+    if (numClasses > 2) {
+      optimizer.setGradient(new LogisticGradient(numClasses))
+    }
+    this
+  }

   override protected def createModel(weights: Vector, intercept: Double) = {
-    new LogisticRegressionModel(weights, intercept)
+    new LogisticRegressionModel(weights, intercept, numFeatures, numOfLinearPredictor + 1)
   }
 }
```
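The doc comment inside predictPoint notes that computing per-class probabilities, rather than only the argmax, requires subtracting maxMargin before exponentiating when it is positive. A standalone sketch of that overflow-safe computation for this pivot parameterization (illustrative, not part of the PR):

```scala
object SoftmaxDemo {
  // Probabilities for classes 0..k-1 given margins for classes 1..k-1
  // (class 0's margin is implicitly 0). Subtracting a positive maxMargin
  // keeps math.exp from overflowing to Infinity for large margins.
  def probabilities(margins: Seq[Double]): Seq[Double] = {
    val maxMargin = math.max(0.0, margins.max)
    val shifted = 0.0 +: margins // prepend the pivot class 0
    val exps = shifted.map(m => math.exp(m - maxMargin))
    val sum = exps.sum
    exps.map(_ / sum)
  }

  def main(args: Array[String]): Unit = {
    // Margins this large would overflow a naive exp; the shift avoids that.
    val p = probabilities(Seq(1000.0, 999.0))
    println(p.sum)            // approximately 1.0
    println(p.indexOf(p.max)) // prints 1: class 1 has the largest probability
  }
}
```

Without the shift, `math.exp(1000.0)` is `Infinity` and every probability degenerates to `NaN`; with it, the same ratios are computed from well-scaled exponentials.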
Review comment: featureSize -> numFeatures (to be consistent with numClasses).