diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/PowerIterationClusteringSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/PowerIterationClusteringSuite.scala
index b7072728d48f..55b460f1a452 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/clustering/PowerIterationClusteringSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/PowerIterationClusteringSuite.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.ml.clustering
 
+import scala.collection.mutable
+
 import org.apache.spark.{SparkException, SparkFunSuite}
 import org.apache.spark.ml.util.DefaultReadWriteTest
 import org.apache.spark.mllib.util.MLlibTestSparkContext
@@ -76,12 +78,15 @@ class PowerIterationClusteringSuite extends SparkFunSuite
       .setMaxIter(40)
       .setWeightCol("weight")
       .assignClusters(data)
-    val localAssignments = assignments
-      .select('id, 'cluster)
-      .as[(Long, Int)].collect().toSet
-    val expectedResult = (0 until n1).map(x => (x, 1)).toSet ++
-      (n1 until n).map(x => (x, 0)).toSet
-    assert(localAssignments === expectedResult)
+      .select("id", "cluster")
+      .as[(Long, Int)]
+      .collect()
+
+    val predictions = Array.fill(2)(mutable.Set.empty[Long])
+    assignments.foreach {
+      case (id, cluster) => predictions(cluster) += id
+    }
+    assert(predictions.toSet === Set((0 until n1).toSet, (n1 until n).toSet))
 
     val assignments2 = new PowerIterationClustering()
       .setK(2)
@@ -89,10 +94,15 @@ class PowerIterationClusteringSuite extends SparkFunSuite
       .setInitMode("degree")
       .setWeightCol("weight")
       .assignClusters(data)
-    val localAssignments2 = assignments2
-      .select('id, 'cluster)
-      .as[(Long, Int)].collect().toSet
-    assert(localAssignments2 === expectedResult)
+      .select("id", "cluster")
+      .as[(Long, Int)]
+      .collect()
+
+    val predictions2 = Array.fill(2)(mutable.Set.empty[Long])
+    assignments2.foreach {
+      case (id, cluster) => predictions2(cluster) += id
+    }
+    assert(predictions2.toSet === Set((0 until n1).toSet, (n1 until n).toSet))
 
   }
 
   test("supported input types") {