@@ -27,6 +27,7 @@ import org.apache.spark.mllib.clustering.{DistanceMeasure, KMeans => MLlibKMeans
2727import org .apache .spark .mllib .util .MLlibTestSparkContext
2828import org .apache .spark .sql .{DataFrame , Dataset , SparkSession }
2929import org .apache .spark .sql .functions ._
30+ import org .apache .spark .sql .types .{ArrayType , DoubleType , FloatType , IntegerType , StructType }
3031
3132private [clustering] case class TestRow (features : Vector )
3233
@@ -196,30 +197,43 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR
196197 }
197198
test("KMean with Array input") {
  // Checks that KMeans accepts Array[Double] and Array[Float] feature columns and that,
  // given numerically identical inputs and the same seed, both produce the same clustering.
  val featuresColNameD = "array_double_features"
  val featuresColNameF = "array_float_features"
  val predictionColNameD = "prediction_double"
  val predictionColNameF = "prediction_float"

  // Both UDFs narrow every value to Float precision on purpose: the Double array then holds
  // exactly the same numbers as the Float array, so the two fits are directly comparable.
  val doubleUDF = udf { (features: Vector) =>
    val featureArray = Array.fill[Double](features.size)(0.0)
    features.foreachActive((idx, value) => featureArray(idx) = value.toFloat)
    featureArray
  }
  val floatUDF = udf { (features: Vector) =>
    val featureArray = Array.fill[Float](features.size)(0.0f)
    features.foreachActive((idx, value) => featureArray(idx) = value.toFloat)
    featureArray
  }

  // Keep both array columns on the same rows so predictions can be compared row by row
  // instead of as an unordered, de-duplicated set of cluster ids.
  val arrayDataset = dataset
    .withColumn(featuresColNameD, doubleUDF(col("features")))
    .withColumn(featuresColNameF, floatUDF(col("features")))
    .drop("features")

  assert(arrayDataset.schema(featuresColNameD).dataType ===
    ArrayType(DoubleType, containsNull = false))
  assert(arrayDataset.schema(featuresColNameF).dataType ===
    ArrayType(FloatType, containsNull = false))

  val kmeansD = new KMeans()
    .setK(k)
    .setFeaturesCol(featuresColNameD)
    .setPredictionCol(predictionColNameD)
    .setSeed(1)
  val kmeansF = new KMeans()
    .setK(k)
    .setFeaturesCol(featuresColNameF)
    .setPredictionCol(predictionColNameF)
    .setSeed(1)
  val modelD = kmeansD.fit(arrayDataset)
  val modelF = kmeansF.fit(arrayDataset)

  // Each model appends its own prediction column, so both predictions end up on one row.
  val transformed = modelF.transform(modelD.transform(arrayDataset))
  val mismatches = transformed
    .filter(col(predictionColNameD) =!= col(predictionColNameF))
    .count()
  assert(mismatches === 0L)

  assert(modelD.computeCost(arrayDataset) === modelF.computeCost(arrayDataset))
}
224238
225239
0 commit comments