[SPARK-23975][ML] Allow Clustering to take Arrays of Double as input features #21081
mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala

```diff
@@ -22,7 +22,7 @@ import org.apache.hadoop.fs.Path
 import org.apache.spark.SparkException
 import org.apache.spark.annotation.{Experimental, Since}
 import org.apache.spark.ml.{Estimator, Model}
-import org.apache.spark.ml.linalg.{Vector, VectorUDT}
+import org.apache.spark.ml.linalg.{Vector, Vectors, VectorUDT}
 import org.apache.spark.ml.param._
 import org.apache.spark.ml.param.shared._
 import org.apache.spark.ml.util._
```
```diff
@@ -32,7 +32,7 @@ import org.apache.spark.mllib.linalg.VectorImplicits._
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{DataFrame, Dataset, Row}
 import org.apache.spark.sql.functions.{col, udf}
-import org.apache.spark.sql.types.{IntegerType, StructType}
+import org.apache.spark.sql.types.{ArrayType, DoubleType, FloatType, IntegerType, StructType}
 import org.apache.spark.storage.StorageLevel
 import org.apache.spark.util.VersionUtils.majorVersion
```
```diff
@@ -90,7 +90,12 @@ private[clustering] trait KMeansParams extends Params with HasMaxIter with HasFe
    * @return output schema
    */
   protected def validateAndTransformSchema(schema: StructType): StructType = {
-    SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT)
+    val typeCandidates = List(new VectorUDT,
+      new ArrayType(DoubleType, true),
+      new ArrayType(DoubleType, false),
+      new ArrayType(FloatType, true),
+      new ArrayType(FloatType, false))
+    SchemaUtils.checkColumnTypes(schema, $(featuresCol), typeCandidates)
     SchemaUtils.appendColumn(schema, $(predictionCol), IntegerType)
   }
 }
```
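For context on the new call: `SchemaUtils.checkColumnTypes` fails fast when the actual column type is not among the candidates. A rough standalone sketch of that behavior (the real helper is Spark-internal; this stand-in is illustrative only):

```scala
import org.apache.spark.sql.types.{DataType, StructType}

// Illustrative stand-in for the Spark-internal helper: accept the column
// if its DataType equals any candidate type, otherwise fail fast.
def checkColumnTypes(schema: StructType, colName: String, candidates: Seq[DataType]): Unit = {
  val actual = schema(colName).dataType
  require(candidates.exists(actual.equals),
    s"Column $colName must be of type ${candidates.mkString(", ")} but was actually $actual.")
}
```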
```diff
@@ -123,8 +128,14 @@ class KMeansModel private[ml] (
   @Since("2.0.0")
   override def transform(dataset: Dataset[_]): DataFrame = {
     transformSchema(dataset.schema, logging = true)
-    val predictUDF = udf((vector: Vector) => predict(vector))
-    dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
+    if (dataset.schema($(featuresCol)).dataType.equals(new VectorUDT)) {
+      val predictUDF = udf((vector: Vector) => predict(vector))
+      dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
+    } else {
+      val predictUDF = udf((vector: Seq[_]) =>
+        predict(Vectors.dense(vector.asInstanceOf[Seq[Double]].toArray)))
+      dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
+    }
   }
 
   @Since("1.5.0")
```
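One thing to keep in mind while reviewing: `asInstanceOf[Seq[Double]]` assumes double elements, yet the schema check above also admits `ArrayType(FloatType, _)`. An element-wise conversion covering both element types might look like this sketch (not part of this commit):

```scala
import org.apache.spark.ml.linalg.{Vector, Vectors}

// Illustrative only: convert a row's array column to a dense vector while
// tolerating both Float and Double elements.
def seqToVector(xs: Seq[Any]): Vector =
  Vectors.dense(xs.map {
    case d: Double => d
    case f: Float  => f.toDouble
    case other     => throw new IllegalArgumentException(
      s"Unsupported element type: ${other.getClass}")
  }.toArray)
```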
```diff
@@ -144,7 +156,12 @@ class KMeansModel private[ml] (
   // TODO: Replace the temp fix when we have proper evaluators defined for clustering.
   @Since("2.0.0")
   def computeCost(dataset: Dataset[_]): Double = {
-    SchemaUtils.checkColumnType(dataset.schema, $(featuresCol), new VectorUDT)
+    val typeCandidates = List(new VectorUDT,
+      new ArrayType(DoubleType, true),
+      new ArrayType(DoubleType, false),
+      new ArrayType(FloatType, true),
+      new ArrayType(FloatType, false))
+    SchemaUtils.checkColumnTypes(dataset.schema, $(featuresCol), typeCandidates)
     val data: RDD[OldVector] = dataset.select(col($(featuresCol))).rdd.map {
       case Row(point: Vector) => OldVectors.fromML(point)
     }
```
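Note that the `Row` extraction here still matches only `Vector` points, so a caller with an array-typed column would need to convert first. A hypothetical caller-side sketch (`df`, `model`, and the column names are illustrative):

```scala
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.sql.functions.{col, udf}

// Hypothetical workaround: densify an Array[Double] column into ML vectors
// before handing the frame to computeCost.
val toVector = udf((xs: Seq[Double]) => Vectors.dense(xs.toArray))
val vectorized = df.withColumn("features", toVector(col("array_features")))
val cost = model.setFeaturesCol("features").computeCost(vectorized)
```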
```diff
@@ -312,6 +329,8 @@ class KMeans @Since("1.5.0") (
     val handlePersistence = dataset.storageLevel == StorageLevel.NONE
     val instances: RDD[OldVector] = dataset.select(col($(featuresCol))).rdd.map {
       case Row(point: Vector) => OldVectors.fromML(point)
+      case Row(point: Seq[_]) =>
+        OldVectors.fromML(Vectors.dense(point.asInstanceOf[Seq[Double]].toArray))
     }
 
     if (handlePersistence) {
```
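With the extra `Seq` case in place, an array column can be passed to `fit` directly. A minimal end-to-end sketch of the new behavior (names are illustrative):

```scala
import org.apache.spark.ml.clustering.KMeans
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[*]").appName("kmeans-array").getOrCreate()
import spark.implicits._

// Feature column typed as Array[Double]; no manual Vector conversion needed.
val df = Seq(
  Array(0.0, 0.1),
  Array(0.2, 0.0),
  Array(9.0, 8.8),
  Array(9.2, 9.1)
).toDF("features")

val model = new KMeans().setK(2).setFeaturesCol("features").fit(df)
model.transform(df).show()
```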
mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala

```diff
@@ -26,6 +26,7 @@ import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
 import org.apache.spark.mllib.clustering.{DistanceMeasure, KMeans => MLlibKMeans}
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}
+import org.apache.spark.sql.functions._
 
 private[clustering] case class TestRow(features: Vector)
 
```
```diff
@@ -194,6 +195,34 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR
     assert(e.getCause.getMessage.contains("Cosine distance is not defined"))
   }
 
+  test("KMeans with Array input") {
+    val featuresColName = "array_model_features"
+    val arrayUDF = udf { (features: Vector) =>
+      features.toArray
+    }
+    val newDataset = dataset.withColumn(featuresColName, arrayUDF(col("features")))
+
+    val kmeans = new KMeans()
+      .setFeaturesCol(featuresColName)
+
+    assert(kmeans.getK === 2)
+    assert(kmeans.getFeaturesCol === featuresColName)
+    assert(kmeans.getPredictionCol === "prediction")
+    assert(kmeans.getMaxIter === 20)
+    assert(kmeans.getInitMode === MLlibKMeans.K_MEANS_PARALLEL)
+    assert(kmeans.getInitSteps === 2)
+    assert(kmeans.getTol === 1e-4)
+    assert(kmeans.getDistanceMeasure === DistanceMeasure.EUCLIDEAN)
+
+    val model = kmeans.setMaxIter(1).fit(newDataset)
+    MLTestingUtils.checkCopyAndUids(kmeans, model)
+
+    assert(model.hasSummary)
+    val copiedModel = model.copy(ParamMap.empty)
+    assert(copiedModel.hasSummary)
+  }
+
   test("read/write") {
     def checkModelData(model: KMeansModel, model2: KMeansModel): Unit = {
       assert(model.clusterCenters === model2.clusterCenters)
```
Thinking about this, let's actually disallow nullable columns. KMeans won't handle nulls properly.
Also, IntelliJ may warn you about the bare boolean arguments here; passing them as named arguments would be a nice fix while you're at it.
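Combining both suggestions, the candidate list might end up looking like this sketch (a hypothetical follow-up, not in this commit):

```scala
import org.apache.spark.ml.linalg.VectorUDT
import org.apache.spark.sql.types.{ArrayType, DoubleType, FloatType}

// Possible follow-up combining both review suggestions: named boolean
// arguments, and nullable arrays rejected since KMeans cannot handle nulls.
val typeCandidates = List(
  new VectorUDT,
  ArrayType(DoubleType, containsNull = false),
  ArrayType(FloatType, containsNull = false))
```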