Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ import org.apache.spark.mllib.linalg.{Vector => OldVector, Vectors => OldVectors
import org.apache.spark.mllib.linalg.VectorImplicits._
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}
import org.apache.spark.sql.functions.{col, udf}
import org.apache.spark.sql.functions.udf
import org.apache.spark.sql.types.{ArrayType, DoubleType, FloatType, IntegerType, StructType}
import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.VersionUtils.majorVersion
Expand Down Expand Up @@ -94,6 +94,7 @@ private[clustering] trait KMeansParams extends Params with HasMaxIter with HasFe
val typeCandidates = List( new VectorUDT,
new ArrayType(DoubleType, false),
new ArrayType(FloatType, false))

SchemaUtils.checkColumnTypes(schema, $(featuresCol), typeCandidates)
}
/**
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Scala style: always put a blank line between methods.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ping: There needs to be a newline between the "}" of the previous method and the "/**" Scaladoc of the next method. Please start checking for this.

Expand Down
31 changes: 20 additions & 11 deletions mllib/src/main/scala/org/apache/spark/ml/util/DatasetUtils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@

package org.apache.spark.ml.util

import org.apache.spark.annotation.Since
import org.apache.spark.ml.linalg.{Vectors, VectorUDT}
import org.apache.spark.sql.{Column, Dataset}
import org.apache.spark.sql.functions.{col, udf}
Expand All @@ -27,28 +26,38 @@ import org.apache.spark.sql.types.{ArrayType, DoubleType, FloatType}
// Utility object holding a helper that casts a Dataset column to a Vector column.
private[spark] object DatasetUtils {

/**
* Preprocesses the input feature column into a Vector column.
* @param dataset Dataset with a column for features
* @param colName column name for features
* @return Vector feature column
* Cast a column in a Dataset to Vector type.
*
* The supported data types of the input column are
* - Vector
* - Array of Float or Double.
*
* Note: The returned column does not have Metadata.
*
* @param dataset input Dataset
* @param colName name of the column to cast.
* @return Vector column
* @throws IllegalArgumentException if the column is neither a Vector nor an Array of
*         Float or Double
*/
@Since("2.4.0")
def columnToVector(dataset: Dataset[_], colName: String): Column = {
// Dispatch on the data type declared for the column in the Dataset's schema.
// NOTE(review): the next two lines look like the pre-rename form of the two lines
// that follow them (diff residue) — confirm which pair is live.
val featuresDataType = dataset.schema(colName).dataType
featuresDataType match {
val columnDataType = dataset.schema(colName).dataType
columnDataType match {
// Already a Vector column: return it unchanged.
case _: VectorUDT => col(colName)
// Array column: choose a conversion UDF based on the element type.
case fdt: ArrayType =>
val transferUDF = fdt.elementType match {
// Float elements are widened to Double one by one before building the dense vector.
// NOTE(review): the featureArray and inputArray triplets below perform the same
// steps under two names (diff residue) — confirm which triplet is live.
case _: FloatType => udf(f = (vector: Seq[Float]) => {
val featureArray = Array.fill[Double](vector.size)(0.0)
vector.indices.foreach(idx => featureArray(idx) = vector(idx).toDouble)
Vectors.dense(featureArray)
val inputArray = Array.fill[Double](vector.size)(0.0)
vector.indices.foreach(idx => inputArray(idx) = vector(idx).toDouble)
Vectors.dense(inputArray)
})
// Double elements are copied straight into a dense vector.
case _: DoubleType => udf((vector: Seq[Double]) => {
Vectors.dense(vector.toArray)
})
// Any other array element type is rejected.
case other =>
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! I forgot about this since this was generalized.

throw new IllegalArgumentException(s"Array[$other] column cannot be cast to Vector")
}
// Apply the chosen UDF to the named column.
transferUDF(col(colName))
// Any column type that is neither Vector nor Array is rejected.
case other =>
throw new IllegalArgumentException(s"$other column cannot be cast to Vector")
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -220,25 +220,20 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR
.drop("features")
val newdatasetF = dataset.withColumn(featuresColNameF, floatUDF(col("features")))
.drop("features")

assert(newdatasetD.schema(featuresColNameD).dataType.equals(new ArrayType(DoubleType, false)))
assert(newdatasetF.schema(featuresColNameF).dataType.equals(new ArrayType(FloatType, false)))

val kmeansD = new KMeans().setK(k).setFeaturesCol(featuresColNameD).setSeed(1)
val kmeansF = new KMeans().setK(k).setFeaturesCol(featuresColNameF).setSeed(1)
val kmeansD = new KMeans().setK(k).setMaxIter(1).setFeaturesCol(featuresColNameD).setSeed(1)
val kmeansF = new KMeans().setK(k).setMaxIter(1).setFeaturesCol(featuresColNameF).setSeed(1)
val modelD = kmeansD.fit(newdatasetD)
val modelF = kmeansF.fit(newdatasetF)

val transformedD = modelD.transform(newdatasetD)
val transformedF = modelF.transform(newdatasetF)

val predictDifference = transformedD.select("prediction")
.except(transformedF.select("prediction"))

assert(predictDifference.count() == 0)

assert(modelD.computeCost(newdatasetD) == modelF.computeCost(newdatasetF) )

}


Expand Down