[SPARK-2309][MLlib] Multinomial Logistic Regression #3833
Changes from 1 commit
```diff
@@ -32,17 +32,29 @@ import org.apache.spark.rdd.RDD
  * @param intercept Intercept computed for this model. (Only used in Binary Logistic Regression.
  *                  In Multinomial Logistic Regression, the intercepts will not be a single values,
  *                  so the intercepts will be part of the weights.)
- * @param nClasses The number of possible outcomes for Multinomial Logistic Regression.
- *                 The default value is 2 which is Binary Logistic Regression.
  */
 class LogisticRegressionModel (
     override val weights: Vector,
-    override val intercept: Double,
-    nClasses: Int = 2)
+    override val intercept: Double)
   extends GeneralizedLinearModel(weights, intercept) with ClassificationModel with Serializable {
 
   private var threshold: Option[Double] = Some(0.5)
 
+  private var nClasses = 2
+
+  /**
+   * :: Experimental ::
+   * Set the number of possible outcomes for k classes classification problem in
+   * Multinomial Logistic Regression.
+   * By default, it is binary logistic regression so k will be set to 2.
+   */
+  @Experimental
+  def setNumOfClasses(k: Int): this.type = {
+    assert(k > 1)
+    nClasses = k
+    this
+  }
+
   /**
    * :: Experimental ::
    * Sets the threshold that separates positive predictions from negative predictions
```

Contributor (inline review comment on `def setNumOfClasses(k: Int): this.type = {`): Remove this setter.
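As context for how `nClasses` interacts with the flattened `weights` vector in this patch: the first class is the pivot, so the model keeps `nClasses - 1` coefficient rows, each of length `weightMatrix.size / (nClasses - 1)` (the features plus a bias term when the intercepts are folded into the weights). A small illustrative helper, not part of the PR; `classRow`, `numFeatures`, and `interceptsInWeights` are hypothetical names introduced only for this sketch:

```scala
// Illustrative sketch, not from this PR: recover the coefficient row for class c
// (1 <= c < nClasses) from the flattened weights array, assuming the
// (nClasses - 1) x dataWithBiasSize layout used by predictPoint in the diff below.
def classRow(
    weights: Array[Double],
    nClasses: Int,
    numFeatures: Int,
    interceptsInWeights: Boolean,
    c: Int): Array[Double] = {
  require(nClasses > 1, "need at least two classes")
  require(c >= 1 && c < nClasses, "class 0 is the pivot and has no explicit row")
  val rowSize = if (interceptsInWeights) numFeatures + 1 else numFeatures
  require(weights.length == (nClasses - 1) * rowSize, "unexpected weights length")
  weights.slice((c - 1) * rowSize, c * rowSize)
}
```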
```diff
@@ -77,50 +89,38 @@ class LogisticRegressionModel (
       }
     } else {
       val dataWithBiasSize = weightMatrix.size / (nClasses - 1)
-      val dataWithBias = if(dataWithBiasSize == dataMatrix.size) {
+      val dataWithBias = if (dataWithBiasSize == dataMatrix.size) {
         dataMatrix
       } else {
         assert(dataMatrix.size + 1 == dataWithBiasSize)
         MLUtils.appendBias(dataMatrix)
       }
 
-      val margins = Array.ofDim[Double](nClasses)
-
       val weightsArray = weights match {
         case dv: DenseVector => dv.values
         case _ =>
           throw new IllegalArgumentException(
             s"weights only supports dense vector but got type ${weights.getClass}.")
       }
 
-      var i = 0
-      while (i < nClasses - 1) {
+      val margins = (0 until nClasses - 1).map { i =>
         var margin = 0.0
         dataWithBias.foreachActive { (index, value) =>
           if (value != 0.0) margin += value * weightsArray((i * dataWithBiasSize) + index)
         }
-        margins(i + 1) = margin
-        i += 1
+        margin
       }
 
       /**
-       * Find the one with maximum margins. Note that `margins(0) == 0`.
+       * Find the one with maximum margins. If the maxMargin is negative, then the prediction
+       * result will be the first class.
+       *
+       * PS, if you want to compute the probabilities for each outcome instead of the outcome
+       * with maximum probability, remember to subtract the maxMargin from margins if maxMargin
+       * is positive to prevent overflow.
       */
-      var label = 0.0
-      var max = margins(0)
-      i = 0
-      while (i < nClasses) {
-        if (margins(i) > max) {
-          label = i
-          max = margins(i)
-        }
-        i += 1
-      }
-      label
+      val maxMargin = margins.max
+      if (maxMargin > 0) (margins.indexOf(maxMargin) + 1).toDouble else 0.0
     }
   }
 }
```
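The overflow remark in the new comment can be made concrete. This is a minimal sketch, not part of the patch, assuming `marginsTail` holds the `nClasses - 1` margins computed above (the pivot class's margin is implicitly zero): subtracting the maximum margin before exponentiating keeps every argument to `exp` non-positive.

```scala
// Hypothetical helper, not from this PR: per-class probabilities from the
// pivot-based margins, shifted by the max margin for numerical stability.
def classProbabilities(marginsTail: Seq[Double]): Array[Double] = {
  val margins = 0.0 +: marginsTail       // margin of the pivot (first) class is 0
  val maxMargin = margins.max            // >= 0, since the pivot margin is included
  val exps = margins.map(m => math.exp(m - maxMargin)) // each value is in (0, 1]
  val sum = exps.sum
  exps.map(_ / sum).toArray
}
```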
```diff
@@ -265,10 +265,12 @@ class LogisticRegressionWithLBFGS
   validators = List(DataValidators.binaryLabelValidator)
 
   /**
+   * :: Experimental ::
    * Set the number of possible outcomes for k classes classification problem in
    * Multinomial Logistic Regression.
    * By default, it is binary logistic regression so k will be set to 2.
    */
+  @Experimental
   def setNumOfClasses(k: Int): this.type = {
     assert(k > 1)
     numOfLinearPredictor = k - 1
```

```diff
@@ -279,6 +281,6 @@ class LogisticRegressionWithLBFGS
   }
 
   override protected def createModel(weights: Vector, intercept: Double) = {
-    new LogisticRegressionModel(weights, intercept, numOfLinearPredictor + 1)
+    new LogisticRegressionModel(weights, intercept).setNumOfClasses(numOfLinearPredictor + 1)
   }
 }
```
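For orientation, a usage sketch against the API in this patch. It assumes the `setNumOfClasses` setter shown above; whether multiclass labels pass the configured `binaryLabelValidator` depends on parts of the patch outside these hunks, and `training` is a hypothetical input RDD.

```scala
import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.rdd.RDD

// `training` is a hypothetical RDD[LabeledPoint] whose labels lie in {0, 1, ..., k - 1}.
def trainMultinomial(training: RDD[LabeledPoint], k: Int) = {
  new LogisticRegressionWithLBFGS()
    .setNumOfClasses(k)  // k == 2 keeps the default binary behaviour
    .run(training)       // createModel then sets nClasses = numOfLinearPredictor + 1 = k
}
```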
Review comment: Move `numClasses` to constructor and it should be a `val`.
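For illustration only, one reading of this suggestion. The sketch below is not code from this PR; the class name is hypothetical and `predictPoint` is reduced to a placeholder so the snippet compiles on its own against spark-mllib.

```scala
import org.apache.spark.mllib.classification.ClassificationModel
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.regression.GeneralizedLinearModel

// Sketch of the suggested shape: numClasses becomes an immutable constructor
// parameter instead of a private var with a setter.
class SketchLogisticRegressionModel(
    override val weights: Vector,
    override val intercept: Double,
    val numClasses: Int = 2)
  extends GeneralizedLinearModel(weights, intercept)
  with ClassificationModel with Serializable {

  require(numClasses > 1, s"numClasses must be at least 2 but got $numClasses")

  override protected def predictPoint(
      dataMatrix: Vector,
      weightMatrix: Vector,
      intercept: Double): Double = 0.0  // placeholder; see the diff above for the real logic
}
```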