48 changes: 42 additions & 6 deletions mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
@@ -22,7 +22,7 @@ import org.apache.hadoop.fs.Path
import org.apache.spark.SparkException
import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.linalg.{Vector, VectorUDT}
import org.apache.spark.ml.linalg.{Vector, Vectors, VectorUDT}
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.util._
@@ -32,7 +32,7 @@ import org.apache.spark.mllib.linalg.VectorImplicits._
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Dataset, Row}
import org.apache.spark.sql.functions.{col, udf}
import org.apache.spark.sql.types.{IntegerType, StructType}
import org.apache.spark.sql.types.{ArrayType, DoubleType, FloatType, IntegerType, StructType}
import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.VersionUtils.majorVersion

@@ -90,7 +90,12 @@ private[clustering] trait KMeansParams extends Params with HasMaxIter with HasFe
* @return output schema
*/
protected def validateAndTransformSchema(schema: StructType): StructType = {
SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT)
val typeCandidates = List( new VectorUDT,
new ArrayType(DoubleType, true),
Review comment (Member):
Thinking about this, let's actually disallow nullable columns. KMeans won't handle nulls properly.

Review comment (Member):
Also, IntelliJ may warn you here that boolean arguments should be passed as named arguments; that'd be nice to fix.

new ArrayType(DoubleType, false),
new ArrayType(FloatType, true),
new ArrayType(FloatType, false))
SchemaUtils.checkColumnTypes(schema, $(featuresCol), typeCandidates)
SchemaUtils.appendColumn(schema, $(predictionCol), IntegerType)
}
}
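Taken together, the two review comments above suggest accepting only non-nullable array columns and passing the boolean as a named argument. A minimal sketch of what validateAndTransformSchema might then look like (a hypothetical revision, not the code in this diff; it assumes the surrounding KMeansParams trait and the imports added above):

    protected def validateAndTransformSchema(schema: StructType): StructType = {
      // Nullable array columns are rejected, since KMeans cannot handle nulls;
      // containsNull is passed as a named argument to satisfy the IntelliJ warning.
      val typeCandidates = List(
        new VectorUDT,
        ArrayType(DoubleType, containsNull = false),
        ArrayType(FloatType, containsNull = false))
      SchemaUtils.checkColumnTypes(schema, $(featuresCol), typeCandidates)
      SchemaUtils.appendColumn(schema, $(predictionCol), IntegerType)
    }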
@@ -123,7 +128,21 @@ class KMeansModel private[ml] (
@Since("2.0.0")
override def transform(dataset: Dataset[_]): DataFrame = {
transformSchema(dataset.schema, logging = true)
val predictUDF = udf((vector: Vector) => predict(vector))
// val predictUDF = udf((vector: Vector) => predict(vector))
val predictUDF = if (dataset.schema($(featuresCol)).dataType.equals(new VectorUDT)) {
udf((vector: Vector) => predict(vector))
Review comment (Member):
Side note: I realized that "predict" will cause the whole model to be serialized and sent to workers. But that's actually OK since we do need to send most of the model data to make predictions and since there's not a clean way to just send the model weights. So I think my previous comment about copying "numClasses" to a local variable was not necessary. Don't bother reverting the change though.

}
Review comment (Member):
Scala style: } else {

else {
udf((vector: Seq[_]) => {
Review comment (Member):
Scala style: remove unnecessary { at end of line (IntelliJ should warn you about this)

val featureArray = Array.fill[Double](vector.size)(0.0)
Review comment (Member):
You shouldn't have to do the conversion in this convoluted (and less efficient) way. I'd recommend doing a match-case statement on dataset.schema; I think that will be the most succinct. Then you can handle Vector, Seq of Float, and Seq of Double separately, without conversions to strings.

Same for the similar cases below.

Review comment (Member):
Here's what I meant:

    val predictUDF = featuresDataType match {
      case _: VectorUDT =>
        udf((vector: Vector) => predict(vector))
      case fdt: ArrayType => fdt.elementType match {
        case _: FloatType =>
          ???
        case _: DoubleType =>
          ???
      }
    }

for (idx <- 0 until vector.size) {
featureArray(idx) = vector(idx).toString().toDouble
}
OldVectors.fromML(Vectors.dense(featureArray))
predict(Vectors.dense(featureArray))
})
}

dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
}
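One hypothetical way to fill in the reviewer's match-case sketch above (the ??? placeholders): each element type gets its own typed UDF, so no value goes through a string round-trip. This is a sketch assuming transformSchema has already restricted the column to VectorUDT, Array[Double], or Array[Float]; it is not the code in this PR.

    val featuresDataType = dataset.schema($(featuresCol)).dataType
    val predictUDF = featuresDataType match {
      case _: VectorUDT =>
        udf((vector: Vector) => predict(vector))
      case fdt: ArrayType => fdt.elementType match {
        case _: FloatType =>
          // Array[Float] columns arrive as Seq[Float]; widen each element to Double.
          udf((features: Seq[Float]) => predict(Vectors.dense(features.map(_.toDouble).toArray)))
        case _: DoubleType =>
          udf((features: Seq[Double]) => predict(Vectors.dense(features.toArray)))
      }
    }
    dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))

Matching on the schema once, outside the UDF, keeps the per-row closure small and avoids the toString/toDouble conversion the reviewer objected to.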

@@ -144,9 +163,20 @@ class KMeansModel private[ml] (
// TODO: Replace the temp fix when we have proper evaluators defined for clustering.
@Since("2.0.0")
def computeCost(dataset: Dataset[_]): Double = {
SchemaUtils.checkColumnType(dataset.schema, $(featuresCol), new VectorUDT)
val typeCandidates = List( new VectorUDT,
Review comment (Member):
You can reuse validateAndTransformSchema here.

new ArrayType(DoubleType, true),
new ArrayType(DoubleType, false),
new ArrayType(FloatType, true),
new ArrayType(FloatType, false))
SchemaUtils.checkColumnTypes(dataset.schema, $(featuresCol), typeCandidates)
val data: RDD[OldVector] = dataset.select(col($(featuresCol))).rdd.map {
case Row(point: Vector) => OldVectors.fromML(point)
case Row(point: Seq[_]) =>
val featureArray = Array.fill[Double](point.size)(0.0)
for (idx <- 0 until point.size) {
featureArray(idx) = point(idx).toString().toDouble
}
OldVectors.fromML(Vectors.dense(featureArray))
}
parentModel.computeCost(data)
}
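Following the review comment above, computeCost could call the shared validateAndTransformSchema instead of repeating the type-candidate list. A hedged sketch of just that change (the appended prediction column in the returned schema is simply discarded; the RDD conversion is kept as in this diff):

    @Since("2.0.0")
    def computeCost(dataset: Dataset[_]): Double = {
      // Reuse the column validation defined in KMeansParams; the returned schema is not needed here.
      validateAndTransformSchema(dataset.schema)
      val data: RDD[OldVector] = dataset.select(col($(featuresCol))).rdd.map {
        case Row(point: Vector) => OldVectors.fromML(point)
        case Row(point: Seq[_]) =>
          // Element-wise conversion as in this diff; a typed alternative is sketched after fit() below.
          OldVectors.dense(point.map(_.toString.toDouble).toArray)
      }
      parentModel.computeCost(data)
    }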
@@ -312,7 +342,13 @@ class KMeans @Since("1.5.0") (
val handlePersistence = dataset.storageLevel == StorageLevel.NONE
val instances: RDD[OldVector] = dataset.select(col($(featuresCol))).rdd.map {
case Row(point: Vector) => OldVectors.fromML(point)
}
case Row(point: Seq[_]) =>
val featureArray = Array.fill[Double](point.size)(0.0)
for (idx <- 0 until point.size) {
featureArray(idx) = point(idx).toString().toDouble
}
OldVectors.fromML(Vectors.dense(featureArray))
}
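The reviewer's "Same for the similar cases below" also applies to this conversion in fit(). A hedged sketch of a match on the column's data type that avoids the toString round-trip (hypothetical, and assuming the schema has already been validated so only these element types can occur):

    val instances: RDD[OldVector] = dataset.schema($(featuresCol)).dataType match {
      case _: VectorUDT =>
        dataset.select(col($(featuresCol))).rdd.map {
          case Row(point: Vector) => OldVectors.fromML(point)
        }
      case ArrayType(_: FloatType, _) =>
        // Array[Float] columns arrive as Seq[Float] rows; widen each element to Double.
        dataset.select(col($(featuresCol))).rdd.map {
          case Row(point: Seq[_]) => OldVectors.dense(point.asInstanceOf[Seq[Float]].map(_.toDouble).toArray)
        }
      case ArrayType(_: DoubleType, _) =>
        dataset.select(col($(featuresCol))).rdd.map {
          case Row(point: Seq[_]) => OldVectors.dense(point.asInstanceOf[Seq[Double]].toArray)
        }
    }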

if (handlePersistence) {
instances.persist(StorageLevel.MEMORY_AND_DISK)
mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala
@@ -26,6 +26,8 @@ import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
import org.apache.spark.mllib.clustering.{DistanceMeasure, KMeans => MLlibKMeans}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{ArrayType, DoubleType, FloatType, IntegerType, StructType}

private[clustering] case class TestRow(features: Vector)

@@ -194,6 +196,47 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR
assert(e.getCause.getMessage.contains("Cosine distance is not defined"))
}

test("KMean with Array input") {
val featuresColNameD = "array_double_features"
val featuresColNameF = "array_float_features"

val doubleUDF = udf { (features: Vector) =>
val featureArray = Array.fill[Double](features.size)(0.0)
features.foreachActive((idx, value) => featureArray(idx) = value.toFloat)
featureArray
}
val floatUDF = udf { (features: Vector) =>
val featureArray = Array.fill[Float](features.size)(0.0f)
features.foreachActive((idx, value) => featureArray(idx) = value.toFloat)
featureArray
}

val newdatasetD = dataset.withColumn(featuresColNameD, doubleUDF(col("features")))
.drop("features")
val newdatasetF = dataset.withColumn(featuresColNameF, floatUDF(col("features")))
.drop("features")

assert(newdatasetD.schema(featuresColNameD).dataType.equals(new ArrayType(DoubleType, false)))
assert(newdatasetF.schema(featuresColNameF).dataType.equals(new ArrayType(FloatType, false)))

val kmeansD = new KMeans().setK(k).setFeaturesCol(featuresColNameD).setSeed(1)
Review comment (Member):
Also do: setMaxIter(1) to make this a little faster.

val kmeansF = new KMeans().setK(k).setFeaturesCol(featuresColNameF).setSeed(1)
val modelD = kmeansD.fit(newdatasetD)
val modelF = kmeansF.fit(newdatasetF)

val transformedD = modelD.transform(newdatasetD)
val transformedF = modelF.transform(newdatasetF)

val predictDifference = transformedD.select("prediction")
.except(transformedF.select("prediction"))

assert(predictDifference.count() == 0)

assert(modelD.computeCost(newdatasetD) == modelF.computeCost(newdatasetF) )

Review comment (Member):
nit: remove unnecessary newline

}


test("read/write") {
def checkModelData(model: KMeansModel, model2: KMeansModel): Unit = {
assert(model.clusterCenters === model2.clusterCenters)