Skip to content
31 changes: 25 additions & 6 deletions mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import org.apache.hadoop.fs.Path
import org.apache.spark.SparkException
import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.linalg.{Vector, VectorUDT}
import org.apache.spark.ml.linalg.{Vector, Vectors, VectorUDT}
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.util._
Expand All @@ -32,7 +32,7 @@ import org.apache.spark.mllib.linalg.VectorImplicits._
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Dataset, Row}
import org.apache.spark.sql.functions.{col, udf}
import org.apache.spark.sql.types.{IntegerType, StructType}
import org.apache.spark.sql.types.{ArrayType, DoubleType, FloatType, IntegerType, StructType}
import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.VersionUtils.majorVersion

Expand Down Expand Up @@ -90,7 +90,12 @@ private[clustering] trait KMeansParams extends Params with HasMaxIter with HasFe
* @return output schema
*/
protected def validateAndTransformSchema(schema: StructType): StructType = {
  // The features column may be an ML vector or an array of doubles/floats.
  // Arrays whose elements may be null are rejected on purpose: KMeans does not
  // handle null entries properly. Boolean arguments are passed by name so the
  // intent is readable at the call site.
  val typeCandidates = List(
    new VectorUDT,
    ArrayType(DoubleType, containsNull = false),
    ArrayType(FloatType, containsNull = false))
  // A single multi-type check; a separate VectorUDT-only check would reject
  // array columns before this one is reached.
  SchemaUtils.checkColumnTypes(schema, $(featuresCol), typeCandidates)
  SchemaUtils.appendColumn(schema, $(predictionCol), IntegerType)
}
}
Expand Down Expand Up @@ -123,8 +128,15 @@ class KMeansModel private[ml] (
@Since("2.0.0")
override def transform(dataset: Dataset[_]): DataFrame = {
  transformSchema(dataset.schema, logging = true)
  // Pick the prediction UDF from the declared type of the features column so
  // that the withColumn call below is written only once.
  val predictUDF = dataset.schema($(featuresCol)).dataType match {
    case ArrayType(FloatType, _) =>
      // Float arrays reach a Scala UDF as Seq[Float]; the elements must be
      // widened explicitly, because casting the sequence to Seq[Double] fails
      // at runtime when the boxed Float values are unboxed as Double.
      udf((values: Seq[Float]) => predict(Vectors.dense(values.map(_.toDouble).toArray)))
    case _: ArrayType =>
      udf((values: Seq[Double]) => predict(Vectors.dense(values.toArray)))
    case _ =>
      udf((vector: Vector) => predict(vector))
  }
  dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
}

@Since("1.5.0")
Expand All @@ -144,7 +156,12 @@ class KMeansModel private[ml] (
// TODO: Replace the temp fix when we have proper evaluators defined for clustering.
@Since("2.0.0")
def computeCost(dataset: Dataset[_]): Double = {
SchemaUtils.checkColumnType(dataset.schema, $(featuresCol), new VectorUDT)
val typeCandidates = List( new VectorUDT,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can reuse validateAndTransformSchema here.

new ArrayType(DoubleType, true),
new ArrayType(DoubleType, false),
new ArrayType(FloatType, true),
new ArrayType(FloatType, false))
SchemaUtils.checkColumnTypes(dataset.schema, $(featuresCol), typeCandidates)
val data: RDD[OldVector] = dataset.select(col($(featuresCol))).rdd.map {
case Row(point: Vector) => OldVectors.fromML(point)
}
Expand Down Expand Up @@ -312,6 +329,8 @@ class KMeans @Since("1.5.0") (
val handlePersistence = dataset.storageLevel == StorageLevel.NONE
val instances: RDD[OldVector] = dataset.select(col($(featuresCol))).rdd.map {
case Row(point: Vector) => OldVectors.fromML(point)
case Row(point: Seq[_]) =>
OldVectors.fromML(Vectors.dense(point.asInstanceOf[Seq[Double]].toArray))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure this will work with arrays of FloatType. Make sure to test it.

}

if (handlePersistence) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
import org.apache.spark.mllib.clustering.{DistanceMeasure, KMeans => MLlibKMeans}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}
import org.apache.spark.sql.functions._

private[clustering] case class TestRow(features: Vector)

Expand Down Expand Up @@ -194,6 +195,34 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR
assert(e.getCause.getMessage.contains("Cosine distance is not defined"))
}

test("KMean with Array input") {
  val featuresColName = "array_model_features"
  val toArrayUDF = udf { (features: Vector) => features.toArray }
  // Drop the original vector column so the estimator cannot accidentally use it.
  val arrayDataset = dataset
    .withColumn(featuresColName, toArrayUDF(col("features")))
    .drop("features")

  val kmeans = new KMeans()
    .setFeaturesCol(featuresColName)
    .setMaxIter(1)
  assert(kmeans.getFeaturesCol === featuresColName)

  // Only behaviour specific to array input is checked here; parameter
  // defaults, copying and summaries are not relevant to this test.
  val model = kmeans.fit(arrayDataset)
  val predictions = model.transform(arrayDataset)
  assert(predictions.columns.contains(kmeans.getPredictionCol))
  assert(predictions.count() === arrayDataset.count())
}


test("read/write") {
def checkModelData(model: KMeansModel, model2: KMeansModel): Unit = {
assert(model.clusterCenters === model2.clusterCenters)
Expand Down