Commit cd988c7

Move featureToVector into KMeansParams and add the Scala docs.
Add validateSchema and use it in computeCost; addresses the review comments from @jkbradley.
1 parent 009b918 commit cd988c7

File tree

  • mllib/src/main/scala/org/apache/spark/ml/clustering

1 file changed

mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala

Lines changed: 41 additions & 69 deletions
@@ -85,17 +85,50 @@ private[clustering] trait KMeansParams extends Params with HasMaxIter with HasFe
   def getInitSteps: Int = $(initSteps)
 
   /**
-   * Validates and transforms the input schema.
+   * Validates the input schema.
    * @param schema input schema
-   * @return output schema
    */
-  protected def validateAndTransformSchema(schema: StructType): StructType = {
+  protected def validateSchema(schema: StructType): Unit = {
     val typeCandidates = List( new VectorUDT,
       new ArrayType(DoubleType, false),
       new ArrayType(FloatType, false))
     SchemaUtils.checkColumnTypes(schema, $(featuresCol), typeCandidates)
+  }
+  /**
+   * Validates and transforms the input schema.
+   * @param schema input schema
+   * @return output schema
+   */
+  protected def validateAndTransformSchema(schema: StructType): StructType = {
+    validateSchema(schema)
     SchemaUtils.appendColumn(schema, $(predictionCol), IntegerType)
   }
+
+  /**
+   * preprocessing the input feature column to Vector
+   * @param dataset DataFrame with columns for features
+   * @param colName column name for features
+   * @return Vector feature column
+   */
+  @Since("2.4.0")
+  protected def featureToVector(dataset: Dataset[_], colName: String): Column = {
+    val featuresDataType = dataset.schema(colName).dataType
+    featuresDataType match {
+      case _: VectorUDT => col(colName)
+      case fdt: ArrayType =>
+        val transferUDF = fdt.elementType match {
+          case _: FloatType => udf(f = (vector: Seq[Float]) => {
+            val featureArray = Array.fill[Double](vector.size)(0.0)
+            vector.indices.foreach(idx => featureArray(idx) = vector(idx).toDouble)
+            Vectors.dense(featureArray)
+          })
+          case _: DoubleType => udf((vector: Seq[Double]) => {
+            Vectors.dense(vector.toArray)
+          })
+        }
+        transferUDF(col(colName))
+    }
+  }
 }
 
 /**
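
For orientation: the new featureToVector above passes VectorUDT columns through unchanged and wraps array columns in a converting UDF. A minimal standalone sketch of that conversion, assuming a local SparkSession and illustrative names (none of these are part of the patch):

    import org.apache.spark.ml.linalg.Vectors
    import org.apache.spark.sql.SparkSession
    import org.apache.spark.sql.functions.{col, udf}

    object FeatureToVectorSketch {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder().master("local[*]").appName("sketch").getOrCreate()
        import spark.implicits._

        // An Array[Double] features column, the case the DoubleType branch handles.
        val df = Seq(Seq(0.0, 0.1), Seq(9.0, 9.1)).toDF("features")

        // Mirrors the DoubleType case above: Seq[Double] => dense ml Vector.
        val toVector = udf((xs: Seq[Double]) => Vectors.dense(xs.toArray))
        df.withColumn("featuresVec", toVector(col("features"))).show(false)

        spark.stop()
      }
    }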
@@ -123,32 +156,13 @@ class KMeansModel private[ml] (
   @Since("2.0.0")
   def setPredictionCol(value: String): this.type = set(predictionCol, value)
 
-  @Since("2.4.0")
-  def featureToVector(dataset: Dataset[_], col: Column): Column = {
-    val featuresDataType = dataset.schema(getFeaturesCol).dataType
-    val transferUDF = featuresDataType match {
-      case _: VectorUDT => udf((vector: Vector) => vector)
-      case fdt: ArrayType => fdt.elementType match {
-        case _: FloatType => udf(f = (vector: Seq[Float]) => {
-          val featureArray = Array.fill[Double](vector.size)(0.0)
-          vector.indices.foreach(idx => featureArray(idx) = vector(idx).toDouble)
-          Vectors.dense(featureArray)
-        })
-        case _: DoubleType => udf((vector: Seq[Double]) => {
-          Vectors.dense(vector.toArray)
-        })
-      }
-    }
-    transferUDF(col)
-  }
-
   @Since("2.0.0")
   override def transform(dataset: Dataset[_]): DataFrame = {
     transformSchema(dataset.schema, logging = true)
 
     val predictUDF = udf((vector: Vector) => predict(vector))
 
-    dataset.withColumn($(predictionCol), predictUDF(featureToVector(dataset, col(getFeaturesCol))))
+    dataset.withColumn($(predictionCol), predictUDF(featureToVector(dataset, getFeaturesCol)))
   }
 
   @Since("1.5.0")
@@ -168,22 +182,9 @@ class KMeansModel private[ml] (
   // TODO: Replace the temp fix when we have proper evaluators defined for clustering.
   @Since("2.0.0")
   def computeCost(dataset: Dataset[_]): Double = {
-    val typeCandidates = List( new VectorUDT,
-      new ArrayType(DoubleType, false),
-      new ArrayType(FloatType, false))
-    SchemaUtils.checkColumnTypes(dataset.schema, $(featuresCol), typeCandidates)
+    validateSchema(dataset.schema)
 
-    /* val data: RDD[OldVector] = dataset.select(col($(featuresCol))).rdd.map {
-      case Row(point: Vector) => OldVectors.fromML(point)
-      case Row(point: Seq[_]) =>
-        val featureArray = Array.fill[Double](point.size)(0.0)
-        for (idx <- point.indices) {
-          featureArray(idx) = point(idx).toString.toDouble
-        }
-        OldVectors.fromML(Vectors.dense(featureArray))
-    }
-    */
-    val data: RDD[OldVector] = dataset.select(featureToVector(dataset, col(getFeaturesCol)))
+    val data: RDD[OldVector] = dataset.select(featureToVector(dataset, getFeaturesCol))
       .rdd.map {
         case Row(point: Vector) => OldVectors.fromML(point)
       }
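
computeCost now reuses validateSchema and featureToVector instead of duplicating the type checks inline, so it accepts the same three column types. Continuing the illustrative example:

    // Within Set Sum of Squared Errors over the array-typed features.
    val wssse = model.computeCost(df)
    println(s"WSSSE = $wssse")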
@@ -344,45 +345,16 @@ class KMeans @Since("1.5.0") (
   @Since("1.5.0")
   def setSeed(value: Long): this.type = set(seed, value)
 
-  @Since("2.4.0")
-  def featureToVector(dataset: Dataset[_], col: Column): Column = {
-    val featuresDataType = dataset.schema(getFeaturesCol).dataType
-    val transferUDF = featuresDataType match {
-      case _: VectorUDT => udf((vector: Vector) => vector)
-      case fdt: ArrayType => fdt.elementType match {
-        case _: FloatType => udf(f = (vector: Seq[Float]) => {
-          val featureArray = Array.fill[Double](vector.size)(0.0)
-          vector.indices.foreach(idx => featureArray(idx) = vector(idx).toDouble)
-          Vectors.dense(featureArray)
-        })
-        case _: DoubleType => udf((vector: Seq[Double]) => {
-          Vectors.dense(vector.toArray)
-        })
-      }
-    }
-    transferUDF(col)
-  }
-
   @Since("2.0.0")
   override def fit(dataset: Dataset[_]): KMeansModel = {
     transformSchema(dataset.schema, logging = true)
 
     val handlePersistence = dataset.storageLevel == StorageLevel.NONE
-    val instances: RDD[OldVector] = dataset.select(featureToVector(dataset, col(getFeaturesCol)))
+    val instances: RDD[OldVector] = dataset.select(featureToVector(dataset, getFeaturesCol))
       .rdd.map {
         case Row(point: Vector) => OldVectors.fromML(point)
       }
-    /*
-    val instances: RDD[OldVector] = dataset.select(col($(featuresCol))).rdd.map {
-      case Row(point: Vector) => OldVectors.fromML(point)
-      case Row(point: Seq[_]) =>
-        val featureArray = Array.fill[Double](point.size)(0.0)
-        for (idx <- point.indices) {
-          featureArray(idx) = point(idx).toString.toDouble
-        }
-        OldVectors.fromML(Vectors.dense(featureArray))
-    }
-    */
+
     if (handlePersistence) {
       instances.persist(StorageLevel.MEMORY_AND_DISK)
     }
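
Taken together, the patch lets fit consume raw array columns with no VectorAssembler step. An end-to-end sketch under the same assumptions (local session, illustrative data and names):

    import org.apache.spark.ml.clustering.KMeans

    // Hypothetical training data: an Array[Double] column named "features".
    val train = Seq(
      Seq(0.0, 0.0), Seq(0.1, 0.1),
      Seq(9.0, 9.0), Seq(9.1, 9.1)
    ).toDF("features")

    val kmeans = new KMeans().setK(2).setSeed(1L)
    val model = kmeans.fit(train) // accepted once featureToVector handles arrays
    model.clusterCenters.foreach(println)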
