Addressed feedback from Sean Owen
DB Tsai committed Feb 2, 2015
commit 4348426b0f36c7f8257ea6c990288e56ab6c233c
@@ -32,17 +32,29 @@ import org.apache.spark.rdd.RDD
  * @param intercept Intercept computed for this model. (Only used in Binary Logistic Regression.
  *                  In Multinomial Logistic Regression, the intercepts will not be a single values,
  *                  so the intercepts will be part of the weights.)
- * @param nClasses The number of possible outcomes for Multinomial Logistic Regression.
- *                 The default value is 2 which is Binary Logistic Regression.
  */
 class LogisticRegressionModel (
     override val weights: Vector,
-    override val intercept: Double,
-    nClasses: Int = 2)
+    override val intercept: Double)
   extends GeneralizedLinearModel(weights, intercept) with ClassificationModel with Serializable {
 
   private var threshold: Option[Double] = Some(0.5)
 
+  private var nClasses = 2
Contributor: Move numClasses to constructor and it should be a val.


+
+  /**
+   * :: Experimental ::
+   * Set the number of possible outcomes for k classes classification problem in
+   * Multinomial Logistic Regression.
+   * By default, it is binary logistic regression so k will be set to 2.
+   */
+  @Experimental
+  def setNumOfClasses(k: Int): this.type = {
Contributor: Remove this setter.

Contributor: setNumClasses, k->numClasses

+    assert(k > 1)
Contributor: assert -> require

+    nClasses = k
+    this
+  }
+
   /**
    * :: Experimental ::
    * Sets the threshold that separates positive predictions from negative predictions
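
For context, a minimal usage sketch of the setter-based API this commit introduces (all values hypothetical; assumes k = 3 classes and 2 features, so the flattened weights hold (k - 1) * (numFeatures + 1) = 6 entries, one row of feature weights plus a per-class intercept for each non-reference class):

```scala
import org.apache.spark.mllib.classification.LogisticRegressionModel
import org.apache.spark.mllib.linalg.Vectors

object SetterUsageSketch {
  def main(args: Array[String]): Unit = {
    // Hypothetical weights: two rows of (w1, w2, intercept), one per non-reference class.
    val weights = Vectors.dense(0.1, -0.3, 0.5, -0.2, 0.4, 0.1)

    // The constructor intercept is unused in the multinomial case; as the class doc
    // above notes, the per-class intercepts are packed into `weights`.
    val model = new LogisticRegressionModel(weights, 0.0).setNumOfClasses(3)

    // predict returns 0.0, 1.0, or 2.0, whichever class has the largest margin.
    println(model.predict(Vectors.dense(1.0, 2.0)))
  }
}
```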
@@ -77,50 +89,38 @@ class LogisticRegressionModel (
       }
     } else {
       val dataWithBiasSize = weightMatrix.size / (nClasses - 1)
-      val dataWithBias = if(dataWithBiasSize == dataMatrix.size) {
+      val dataWithBias = if (dataWithBiasSize == dataMatrix.size) {
         dataMatrix
-        } else {
+      } else {
         assert(dataMatrix.size + 1 == dataWithBiasSize)
         MLUtils.appendBias(dataMatrix)
       }
 
-      val margins = Array.ofDim[Double](nClasses)
-
       val weightsArray = weights match {
-          case dv: DenseVector => dv.values
-          case _ =>
-            throw new IllegalArgumentException(
-              s"weights only supports dense vector but got type ${weights.getClass}.")
+        case dv: DenseVector => dv.values
+        case _ =>
+          throw new IllegalArgumentException(
+            s"weights only supports dense vector but got type ${weights.getClass}.")
       }
 
-      var i = 0
-      while (i < nClasses - 1) {
+      val margins = (0 until nClasses - 1).map { i =>
         var margin = 0.0
         dataWithBias.foreachActive { (index, value) =>
           if (value != 0.0) margin += value * weightsArray((i * dataWithBiasSize) + index)
         }
-        margins(i + 1) = margin
-        i += 1
+        margin
       }
 
       /**
-       * Find the one with maximum margins. Note that `margins(0) == 0`.
+       * Find the one with maximum margins. If the maxMargin is negative, then the prediction
+       * result will be the first class.
        *
        * PS, if you want to compute the probabilities for each outcome instead of the outcome
        * with maximum probability, remember to subtract the maxMargin from margins if maxMargin
       * is positive to prevent overflow.
        */
-      var label = 0.0
-      var max = margins(0)
-      i = 0
-      while (i < nClasses) {
-        if (margins(i) > max) {
-          label = i
-          max = margins(i)
-        }
-        i += 1
-      }
-      label
+      val maxMargin = margins.max
+      if (maxMargin > 0) (margins.indexOf(maxMargin) + 1).toDouble else 0.0
     }
   }
 }
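
A side note on the decision rule above: class 0 is the reference class with an implicit margin of 0, so it wins whenever no explicit margin is positive. A tiny standalone sketch with hypothetical margin values:

```scala
// Margins for classes 1 to k-1; class 0 carries an implicit margin of 0.
val margins = Seq(-0.7, 1.3)
val maxMargin = margins.max
// Same rule as the diff: fall back to class 0 unless some margin beats 0.
val label = if (maxMargin > 0) (margins.indexOf(maxMargin) + 1).toDouble else 0.0
// margins = Seq(-0.7, 1.3)  => label = 2.0
// margins = Seq(-0.7, -1.3) => label = 0.0
```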
@@ -265,10 +265,12 @@ class LogisticRegressionWithLBFGS
   validators = List(DataValidators.binaryLabelValidator)
 
   /**
+   * :: Experimental ::
    * Set the number of possible outcomes for k classes classification problem in
    * Multinomial Logistic Regression.
    * By default, it is binary logistic regression so k will be set to 2.
    */
+  @Experimental
   def setNumOfClasses(k: Int): this.type = {
     assert(k > 1)
     numOfLinearPredictor = k - 1
@@ -279,6 +281,6 @@ class LogisticRegressionWithLBFGS
   }
 
   override protected def createModel(weights: Vector, intercept: Double) = {
-    new LogisticRegressionModel(weights, intercept, numOfLinearPredictor + 1)
+    new LogisticRegressionModel(weights, intercept).setNumOfClasses(numOfLinearPredictor + 1)
   }
 }
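
For completeness, a sketch of end-to-end training through the new trainer-side setter (toy data, all hypothetical; note that the binaryLabelValidator registered above still rejects labels greater than 1 at this commit, so the sketch disables validation):

```scala
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint

object TrainMultinomialSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("mlr-sketch").setMaster("local[2]"))

    // Tiny hypothetical 3-class training set.
    val data = sc.parallelize(Seq(
      LabeledPoint(0.0, Vectors.dense(0.0, 0.1)),
      LabeledPoint(1.0, Vectors.dense(1.0, 1.1)),
      LabeledPoint(2.0, Vectors.dense(2.0, 2.1))))

    val trainer = new LogisticRegressionWithLBFGS()
    trainer.setValidateData(false) // binaryLabelValidator predates multiclass labels here
    val model = trainer.setNumOfClasses(3).run(data) // k - 1 = 2 linear predictors

    println(model.predict(Vectors.dense(1.0, 1.0)))
    sc.stop()
  }
}
```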
@@ -175,63 +175,50 @@ class LogisticGradient extends Gradient {
           throw new IllegalArgumentException(
             s"cumGradient only supports dense vector but got type ${cumGradient.getClass}.")
         }
-        val margins = Array.ofDim[Double](n)
 
         // marginY is margins(label - 1) in the formula.
         var marginY = 0.0
         var maxMargin = Double.NegativeInfinity
         var maxMarginIndex = 0
         var sum = 0.0
 
-        var i = 0
-        while (i < n) {
+        val margins = (0 until n).map { i =>
           var margin = 0.0
           data.foreachActive { (index, value) =>
             if (value != 0.0) margin += value * weightsArray((i * dataSize) + index)
           }
-          margins(i) = margin
           if (i == label.toInt - 1) marginY = margin
           if (margin > maxMargin) {
             maxMargin = margin
             maxMarginIndex = i
           }
-          i += 1
+          margin
         }
 
-        if (maxMargin > 0) {
-          /**
-           * When maxMargin > 0, the original formula will cause overflow as we discuss
-           * in the previous comment.
-           * We address this by subtracting maxMargin from all the margins, so it's guaranteed
-           * that all of the new margins will be smaller than zero to prevent arithmetic overflow.
-           */
-          i = 0
-          while (i < n) {
-            margins(i) -= maxMargin
-            if (i == maxMarginIndex) {
-              sum += math.exp(-maxMargin)
-            } else {
-              sum += math.exp(margins(i))
-            }
-            i += 1
-          }
-        } else {
-          i = 0
-          while (i < n) {
-            sum += math.exp(margins(i))
-            i += 1
-          }
-        }
+        /**
+         * When maxMargin > 0, the original formula will cause overflow as we discuss
+         * in the previous comment.
+         * We address this by subtracting maxMargin from all the margins, so it's guaranteed
+         * that all of the new margins will be smaller than zero to prevent arithmetic overflow.
+         */
+        if (maxMargin > 0) for (i <- 0 until n) {
+          margins(i) -= maxMargin
+          if (i == maxMarginIndex) {
+            sum += math.exp(-maxMargin)
+          } else {
+            sum += math.exp(margins(i))
+          }
+        } else for (i <- 0 until n) {
+          sum += math.exp(margins(i))
+        }
 
-        i = 0
-        while (i < n) {
+        for (i <- 0 until n) {
           val multiplier = math.exp(margins(i)) / (sum + 1.0) - {
             if (label != 0.0 && label == i + 1) 1.0 else 0.0
           }
           data.foreachActive { (index, value) =>
             if (value != 0.0) cumGradientArray(i * dataSize + index) += multiplier * value
           }
-          i += 1
         }
 
         val loss = if (label > 0.0) math.log1p(sum) - marginY else math.log1p(sum)
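
A small numeric illustration of why the maxMargin subtraction above matters (hypothetical margin values): exponentiating large margins directly overflows, while shifting every margin by maxMargin first keeps each exponent at or below zero and leaves the sum finite.

```scala
val margins = Array(1000.0, 998.0)                 // hypothetical large margins
val naiveSum = margins.map(math.exp).sum           // Double.PositiveInfinity: overflow
val maxMargin = margins.max
val shiftedSum = margins.map(m => math.exp(m - maxMargin)).sum // ~1.135, finite
// The shift can be undone analytically, e.g. log(naiveSum) == maxMargin + log(shiftedSum).
println((naiveSum, shiftedSum))
```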