Skip to content
This repository was archived by the owner on Nov 15, 2024. It is now read-only.

Commit f8deaf0

Browse files
WeichenXu123MatthewRBruce
authored andcommitted
[SPARK-21681][ML] fix bug of MLOR do not work correctly when featureStd contains zero (backport PR for 2.2)
## What changes were proposed in this pull request? This is backport PR of apache#18896 fix bug of MLOR do not work correctly when featureStd contains zero We can reproduce the bug through such dataset (features including zero variance), will generate wrong result (all coefficients becomes 0) ``` val multinomialDatasetWithZeroVar = { val nPoints = 100 val coefficients = Array( -0.57997, 0.912083, -0.371077, -0.16624, -0.84355, -0.048509) val xMean = Array(5.843, 3.0) val xVariance = Array(0.6856, 0.0) // including zero variance val testData = generateMultinomialLogisticInput( coefficients, xMean, xVariance, addIntercept = true, nPoints, seed) val df = sc.parallelize(testData, 4).toDF().withColumn("weight", lit(1.0)) df.cache() df } ``` ## How was this patch tested? testcase added. Author: WeichenXu <[email protected]> Closes apache#19026 from WeichenXu123/fix_mlor_zero_var_bug_2_2.
1 parent d654b74 commit f8deaf0

File tree

2 files changed

+82
-5
lines changed

2 files changed

+82
-5
lines changed

mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1727,11 +1727,13 @@ private class LogisticAggregator(
17271727

17281728
val margins = new Array[Double](numClasses)
17291729
features.foreachActive { (index, value) =>
1730-
val stdValue = value / localFeaturesStd(index)
1731-
var j = 0
1732-
while (j < numClasses) {
1733-
margins(j) += localCoefficients(index * numClasses + j) * stdValue
1734-
j += 1
1730+
if (localFeaturesStd(index) != 0.0 && value != 0.0) {
1731+
val stdValue = value / localFeaturesStd(index)
1732+
var j = 0
1733+
while (j < numClasses) {
1734+
margins(j) += localCoefficients(index * numClasses + j) * stdValue
1735+
j += 1
1736+
}
17351737
}
17361738
}
17371739
var i = 0

mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ class LogisticRegressionSuite
4545
@transient var smallMultinomialDataset: Dataset[_] = _
4646
@transient var binaryDataset: Dataset[_] = _
4747
@transient var multinomialDataset: Dataset[_] = _
48+
@transient var multinomialDatasetWithZeroVar: Dataset[_] = _
4849
private val eps: Double = 1e-5
4950

5051
override def beforeAll(): Unit = {
@@ -98,6 +99,23 @@ class LogisticRegressionSuite
9899
df.cache()
99100
df
100101
}
102+
103+
multinomialDatasetWithZeroVar = {
104+
val nPoints = 100
105+
val coefficients = Array(
106+
-0.57997, 0.912083, -0.371077,
107+
-0.16624, -0.84355, -0.048509)
108+
109+
val xMean = Array(5.843, 3.0)
110+
val xVariance = Array(0.6856, 0.0)
111+
112+
val testData = generateMultinomialLogisticInput(
113+
coefficients, xMean, xVariance, addIntercept = true, nPoints, seed)
114+
115+
val df = sc.parallelize(testData, 4).toDF().withColumn("weight", lit(1.0))
116+
df.cache()
117+
df
118+
}
101119
}
102120

103121
/**
@@ -111,6 +129,11 @@ class LogisticRegressionSuite
111129
multinomialDataset.rdd.map { case Row(label: Double, features: Vector, weight: Double) =>
112130
label + "," + weight + "," + features.toArray.mkString(",")
113131
}.repartition(1).saveAsTextFile("target/tmp/LogisticRegressionSuite/multinomialDataset")
132+
multinomialDatasetWithZeroVar.rdd.map {
133+
case Row(label: Double, features: Vector, weight: Double) =>
134+
label + "," + weight + "," + features.toArray.mkString(",")
135+
}.repartition(1)
136+
.saveAsTextFile("target/tmp/LogisticRegressionSuite/multinomialDatasetWithZeroVar")
114137
}
115138

116139
test("params") {
@@ -1391,6 +1414,58 @@ class LogisticRegressionSuite
13911414
assert(model2.interceptVector.toArray.sum ~== 0.0 absTol eps)
13921415
}
13931416

1417+
test("multinomial logistic regression with zero variance (SPARK-21681)") {
1418+
val sqlContext = multinomialDatasetWithZeroVar.sqlContext
1419+
import sqlContext.implicits._
1420+
val mlr = new LogisticRegression().setFamily("multinomial").setFitIntercept(true)
1421+
.setElasticNetParam(0.0).setRegParam(0.0).setStandardization(true).setWeightCol("weight")
1422+
1423+
val model = mlr.fit(multinomialDatasetWithZeroVar)
1424+
1425+
/*
1426+
Use the following R code to load the data and train the model using glmnet package.
1427+
library("glmnet")
1428+
data <- read.csv("path", header=FALSE)
1429+
label = as.factor(data$V1)
1430+
w = data$V2
1431+
features = as.matrix(data.frame(data$V3, data$V4))
1432+
coefficients = coef(glmnet(features, label, weights=w, family="multinomial",
1433+
alpha = 0, lambda = 0))
1434+
coefficients
1435+
$`0`
1436+
3 x 1 sparse Matrix of class "dgCMatrix"
1437+
s0
1438+
0.2658824
1439+
data.V3 0.1881871
1440+
data.V4 .
1441+
$`1`
1442+
3 x 1 sparse Matrix of class "dgCMatrix"
1443+
s0
1444+
0.53604701
1445+
data.V3 -0.02412645
1446+
data.V4 .
1447+
$`2`
1448+
3 x 1 sparse Matrix of class "dgCMatrix"
1449+
s0
1450+
-0.8019294
1451+
data.V3 -0.1640607
1452+
data.V4 .
1453+
*/
1454+
1455+
val coefficientsR = new DenseMatrix(3, 2, Array(
1456+
0.1881871, 0.0,
1457+
-0.02412645, 0.0,
1458+
-0.1640607, 0.0), isTransposed = true)
1459+
val interceptsR = Vectors.dense(0.2658824, 0.53604701, -0.8019294)
1460+
1461+
model.coefficientMatrix.colIter.foreach(v => assert(v.toArray.sum ~== 0.0 absTol eps))
1462+
1463+
assert(model.coefficientMatrix ~== coefficientsR relTol 0.05)
1464+
assert(model.coefficientMatrix.toArray.sum ~== 0.0 absTol eps)
1465+
assert(model.interceptVector ~== interceptsR relTol 0.05)
1466+
assert(model.interceptVector.toArray.sum ~== 0.0 absTol eps)
1467+
}
1468+
13941469
test("multinomial logistic regression with intercept without regularization with bound") {
13951470
// Bound constrained optimization with bound on one side.
13961471
val lowerBoundsOnCoefficients = Matrices.dense(3, 4, Array.fill(12)(1.0))

0 commit comments

Comments
 (0)