mllib/src/main/scala/org/apache/spark/ml/evaluation/ClusteringEvaluator.scala
@@ -20,6 +20,7 @@ package org.apache.spark.ml.evaluation
 import org.apache.spark.SparkContext
 import org.apache.spark.annotation.{Experimental, Since}
 import org.apache.spark.broadcast.Broadcast
+import org.apache.spark.ml.attribute.AttributeGroup
 import org.apache.spark.ml.linalg.{BLAS, DenseVector, SparseVector, Vector, Vectors, VectorUDT}
 import org.apache.spark.ml.param.{Param, ParamMap, ParamValidators}
 import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasPredictionCol}
@@ -170,6 +171,15 @@ private[evaluation] abstract class Silhouette {
   def overallScore(df: DataFrame, scoreColumn: Column): Double = {
     df.select(avg(scoreColumn)).collect()(0).getDouble(0)
   }
+
+  protected def getNumberOfFeatures(dataFrame: DataFrame, columnName: String): Int = {
+    val group = AttributeGroup.fromStructField(dataFrame.schema(columnName))
+    if (group.size < 0) {
+      dataFrame.select(col(columnName)).first().getAs[Vector](0).size
+    } else {
+      group.size
+    }
+  }
 }

 /**
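Reviewer note (illustration, not part of the patch): `getNumberOfFeatures` prefers the `AttributeGroup` metadata attached to the column and falls back to fetching the first row, which triggers a Spark job, only when the schema carries no size. A minimal sketch of the two paths, assuming a local `SparkSession` named `spark`:

import org.apache.spark.ml.attribute.AttributeGroup
import org.apache.spark.ml.linalg.Vectors

import spark.implicits._

val df = Seq(Tuple1(Vectors.dense(1.0, 2.0, 3.0))).toDF("features")

// No metadata attached: the group size is unknown (-1), so the helper
// must look at the first row to learn the vector size.
assert(AttributeGroup.fromStructField(df.schema("features")).size == -1)

// With an AttributeGroup attached, the size is read from the schema alone
// and no extra job is run.
val meta = new AttributeGroup("features", 3).toMetadata()
val withMeta = df.select($"features".as("features", meta))
assert(AttributeGroup.fromStructField(withMeta.schema("features")).size == 3)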
@@ -360,7 +370,7 @@ private[evaluation] object SquaredEuclideanSilhouette extends Silhouette {
       df: DataFrame,
       predictionCol: String,
       featuresCol: String): Map[Double, ClusterStats] = {
-    val numFeatures = df.select(col(featuresCol)).first().getAs[Vector](0).size
+    val numFeatures = getNumberOfFeatures(df, featuresCol)
     val clustersStatsRDD = df.select(
       col(predictionCol).cast(DoubleType), col(featuresCol), col("squaredNorm"))
       .rdd
@@ -552,8 +562,11 @@ private[evaluation] object CosineSilhouette extends Silhouette {
    * @return A [[scala.collection.immutable.Map]] which associates each cluster id to
    *         its statistics (i.e., the precomputed values `N` and `$\Omega_{\Gamma}$`).
    */
-  def computeClusterStats(df: DataFrame, predictionCol: String): Map[Double, (Vector, Long)] = {
-    val numFeatures = df.select(col(normalizedFeaturesColName)).first().getAs[Vector](0).size
+  def computeClusterStats(
+      df: DataFrame,
+      featuresCol: String,
+      predictionCol: String): Map[Double, (Vector, Long)] = {
+    val numFeatures = getNumberOfFeatures(df, featuresCol)
     val clustersStatsRDD = df.select(
       col(predictionCol).cast(DoubleType), col(normalizedFeaturesColName))
       .rdd
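Reviewer note: the new `featuresCol` parameter may look redundant since the statistics are still computed over `normalizedFeaturesColName`, but the normalized column is produced by a UDF and therefore carries no `AttributeGroup` metadata; only the original features column can answer the size question from the schema.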
@@ -626,7 +639,8 @@ private[evaluation] object CosineSilhouette extends Silhouette {
       normalizeFeatureUDF(col(featuresCol)))

     // compute aggregate values for clusters needed by the algorithm
-    val clustersStatsMap = computeClusterStats(dfWithNormalizedFeatures, predictionCol)
+    val clustersStatsMap = computeClusterStats(dfWithNormalizedFeatures, featuresCol,
+      predictionCol)

     // Silhouette is reasonable only when the number of clusters is greater than 1
     assert(clustersStatsMap.size > 1, "Number of clusters must be greater than one.")
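Reviewer note: after this change both distance measures resolve the feature count through the same helper. For reference, a typical invocation (hedged sketch; `predictions` stands for any DataFrame with `features` and `prediction` columns, e.g. the output of `KMeansModel.transform`):

val evaluator = new ClusteringEvaluator()
  .setFeaturesCol("features")
  .setPredictionCol("prediction")
  .setDistanceMeasure("cosine")  // or the default, "squaredEuclidean"
val silhouette = evaluator.evaluate(predictions)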
mllib/src/test/scala/org/apache/spark/ml/evaluation/ClusteringEvaluatorSuite.scala
@@ -17,7 +17,9 @@

 package org.apache.spark.ml.evaluation

-import org.apache.spark.SparkFunSuite
+import org.apache.spark.{SparkException, SparkFunSuite}
+import org.apache.spark.ml.attribute.AttributeGroup
+import org.apache.spark.ml.linalg.Vector
 import org.apache.spark.ml.param.ParamsSuite
 import org.apache.spark.ml.util.DefaultReadWriteTest
 import org.apache.spark.ml.util.TestingUtils._
@@ -100,4 +102,23 @@ class ClusteringEvaluatorSuite
     }
   }

+  test("SPARK-23568: we should use metadata to determine features number") {
+    val attributesNum = irisDataset.select("features").rdd.first().getAs[Vector](0).size
+    val attrGroup = new AttributeGroup("features", attributesNum)
+    val df = irisDataset.select($"features".as("features", attrGroup.toMetadata()), $"label")
+    require(AttributeGroup.fromStructField(df.schema("features"))
+      .numAttributes.isDefined, "numAttributes metadata should be defined")
+    val evaluator = new ClusteringEvaluator()
+      .setFeaturesCol("features")
+      .setPredictionCol("label")
+
+    // with the proper metadata, the result is computed correctly
+    assert(evaluator.evaluate(df) ~== 0.6564679231 relTol 1e-5)
+
+    val wrongAttrGroup = new AttributeGroup("features", attributesNum + 1)
+    val dfWrong = irisDataset.select($"features".as("features", wrongAttrGroup.toMetadata()),
+      $"label")
+    // with wrong metadata, the evaluator throws an exception
+    intercept[SparkException](evaluator.evaluate(dfWrong))
+  }
 }
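Reviewer note on the failure mode exercised by the last assertion: the cluster-stats aggregation sizes its accumulator vectors from the (here deliberately wrong) metadata, so the mismatch surfaces inside the job and reaches the driver wrapped in a SparkException. A hedged sketch of the underlying dimension check (org.apache.spark.ml.linalg.BLAS is private[spark], so this only compiles inside Spark itself):

import org.apache.spark.ml.linalg.{BLAS, Vectors}

val point = Vectors.dense(1.0, 2.0, 3.0)  // actual vector size: 3
val acc = Vectors.zeros(4).toDense        // accumulator sized from metadata: 4
BLAS.axpy(1.0, point, acc)                // fails: vector sizes do not match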