diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index 7694773c816b..ba5f8d70c819 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -303,7 +303,7 @@ class LogisticRegression @Since("1.2.0") ( throw new SparkException(msg) } - if (numClasses > 2) { + if (numClasses != 2) { val msg = s"Currently, LogisticRegression with ElasticNet in ML package only supports " + s"binary classification. Found $numClasses in the input dataset." logError(msg)