Skip to content

Commit 6d222a3

Browse files
committed
make sure the code works for Float type and add the unit test
1 parent badb0cc commit 6d222a3

File tree

2 files changed

+59
-28
lines changed

2 files changed

+59
-28
lines changed

mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala

Lines changed: 26 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -129,14 +129,21 @@ class KMeansModel private[ml] (
129129
override def transform(dataset: Dataset[_]): DataFrame = {
130130
transformSchema(dataset.schema, logging = true)
131131
// val predictUDF = udf((vector: Vector) => predict(vector))
132-
if (dataset.schema($(featuresCol)).dataType.equals(new VectorUDT)) {
133-
val predictUDF = udf((vector: Vector) => predict(vector))
134-
dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
135-
} else {
136-
val predictUDF = udf((vector: Seq[_]) =>
137-
predict(Vectors.dense(vector.asInstanceOf[Seq[Double]].toArray)))
138-
dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
132+
val predictUDF = if (dataset.schema($(featuresCol)).dataType.equals(new VectorUDT)) {
133+
udf((vector: Vector) => predict(vector))
139134
}
135+
else {
136+
udf((vector: Seq[_]) => {
137+
val featureArray = Array.fill[Double](vector.size)(0.0)
138+
for (idx <- 0 until vector.size) {
139+
featureArray(idx) = vector(idx).toString().toDouble
140+
}
141+
OldVectors.fromML(Vectors.dense(featureArray))
142+
predict(Vectors.dense(featureArray))
143+
})
144+
}
145+
146+
dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
140147
}
141148

142149
@Since("1.5.0")
@@ -164,6 +171,12 @@ class KMeansModel private[ml] (
164171
SchemaUtils.checkColumnTypes(dataset.schema, $(featuresCol), typeCandidates)
165172
val data: RDD[OldVector] = dataset.select(col($(featuresCol))).rdd.map {
166173
case Row(point: Vector) => OldVectors.fromML(point)
174+
case Row(point: Seq[_]) =>
175+
val featureArray = Array.fill[Double](point.size)(0.0)
176+
for (idx <- 0 until point.size) {
177+
featureArray(idx) = point(idx).toString().toDouble
178+
}
179+
OldVectors.fromML(Vectors.dense(featureArray))
167180
}
168181
parentModel.computeCost(data)
169182
}
@@ -330,8 +343,12 @@ class KMeans @Since("1.5.0") (
330343
val instances: RDD[OldVector] = dataset.select(col($(featuresCol))).rdd.map {
331344
case Row(point: Vector) => OldVectors.fromML(point)
332345
case Row(point: Seq[_]) =>
333-
OldVectors.fromML(Vectors.dense(point.asInstanceOf[Seq[Double]].toArray))
334-
}
346+
val featureArray = Array.fill[Double](point.size)(0.0)
347+
for (idx <- 0 until point.size) {
348+
featureArray(idx) = point(idx).toString().toDouble
349+
}
350+
OldVectors.fromML(Vectors.dense(featureArray))
351+
}
335352

336353
if (handlePersistence) {
337354
instances.persist(StorageLevel.MEMORY_AND_DISK)

mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala

Lines changed: 33 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ import org.apache.spark.mllib.clustering.{DistanceMeasure, KMeans => MLlibKMeans
2727
import org.apache.spark.mllib.util.MLlibTestSparkContext
2828
import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}
2929
import org.apache.spark.sql.functions._
30+
import org.apache.spark.sql.types.{ArrayType, DoubleType, FloatType, IntegerType, StructType}
3031

3132
private[clustering] case class TestRow(features: Vector)
3233

@@ -196,30 +197,43 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR
196197
}
197198

198199
test("KMean with Array input") {
199-
val featuresColName = "array_model_features"
200+
val featuresColNameD = "array_double_features"
201+
val featuresColNameF = "array_float_features"
200202

201-
val arrayUDF = udf { (features: Vector) =>
202-
features.toArray
203+
val doubleUDF = udf { (features: Vector) =>
204+
val featureArray = Array.fill[Double](features.size)(0.0)
205+
features.foreachActive((idx, value) => featureArray(idx) = value.toFloat)
206+
featureArray
207+
}
208+
val floatUDF = udf { (features: Vector) =>
209+
val featureArray = Array.fill[Float](features.size)(0.0f)
210+
features.foreachActive((idx, value) => featureArray(idx) = value.toFloat)
211+
featureArray
203212
}
204-
val newdataset = dataset.withColumn(featuresColName, arrayUDF(col("features")) )
205213

206-
val kmeans = new KMeans()
207-
.setFeaturesCol(featuresColName)
214+
val newdatasetD = dataset.withColumn(featuresColNameD, doubleUDF(col("features")))
215+
.drop("features")
216+
val newdatasetF = dataset.withColumn(featuresColNameF, floatUDF(col("features")))
217+
.drop("features")
208218

209-
assert(kmeans.getK === 2)
210-
assert(kmeans.getFeaturesCol === featuresColName)
211-
assert(kmeans.getPredictionCol === "prediction")
212-
assert(kmeans.getMaxIter === 20)
213-
assert(kmeans.getInitMode === MLlibKMeans.K_MEANS_PARALLEL)
214-
assert(kmeans.getInitSteps === 2)
215-
assert(kmeans.getTol === 1e-4)
216-
assert(kmeans.getDistanceMeasure === DistanceMeasure.EUCLIDEAN)
217-
val model = kmeans.setMaxIter(1).fit(newdataset)
219+
assert(newdatasetD.schema(featuresColNameD).dataType.equals(new ArrayType(DoubleType, false)))
220+
assert(newdatasetF.schema(featuresColNameF).dataType.equals(new ArrayType(FloatType, false)))
221+
222+
val kmeansD = new KMeans().setK(k).setFeaturesCol(featuresColNameD).setSeed(1)
223+
val kmeansF = new KMeans().setK(k).setFeaturesCol(featuresColNameF).setSeed(1)
224+
val modelD = kmeansD.fit(newdatasetD)
225+
val modelF = kmeansF.fit(newdatasetF)
226+
227+
val transformedD = modelD.transform(newdatasetD)
228+
val transformedF = modelF.transform(newdatasetF)
229+
230+
val predictDifference = transformedD.select("prediction")
231+
.except(transformedF.select("prediction"))
232+
233+
assert(predictDifference.count() == 0)
234+
235+
assert(modelD.computeCost(newdatasetD) == modelF.computeCost(newdatasetF) )
218236

219-
MLTestingUtils.checkCopyAndUids(kmeans, model)
220-
assert(model.hasSummary)
221-
val copiedModel = model.copy(ParamMap.empty)
222-
assert(copiedModel.hasSummary)
223237
}
224238

225239

0 commit comments

Comments
 (0)