@@ -85,17 +85,50 @@ private[clustering] trait KMeansParams extends Params with HasMaxIter with HasFe
8585 def getInitSteps : Int = $(initSteps)
8686
8787 /**
88- * Validates and transforms the input schema.
88+ * Validates the input schema.
8989 * @param schema input schema
90- * @return output schema
9190 */
92- protected def validateAndTransformSchema (schema : StructType ): StructType = {
91+ protected def validateSchema (schema : StructType ): Unit = {
9392 val typeCandidates = List ( new VectorUDT ,
9493 new ArrayType (DoubleType , false ),
9594 new ArrayType (FloatType , false ))
9695 SchemaUtils .checkColumnTypes(schema, $(featuresCol), typeCandidates)
96+ }
97+ /**
98+ * Validates and transforms the input schema.
99+ * @param schema input schema
100+ * @return output schema
101+ */
102+ protected def validateAndTransformSchema (schema : StructType ): StructType = {
103+ validateSchema(schema)
97104 SchemaUtils .appendColumn(schema, $(predictionCol), IntegerType )
98105 }
106+
107+ /**
108+ * Converts the input features column to a `Vector` column.
109+ * @param dataset DataFrame containing the features column
110+ * @param colName name of the features column
111+ * @return a Column of `Vector` values derived from the features column
112+ */
113+ @ Since (" 2.4.0" )
114+ protected def featureToVector (dataset : Dataset [_], colName : String ): Column = {
115+ val featuresDataType = dataset.schema(colName).dataType
116+ featuresDataType match {
117+ case _ : VectorUDT => col(colName)
118+ case fdt : ArrayType =>
119+ val transferUDF = fdt.elementType match {
120+ case _ : FloatType => udf(f = (vector : Seq [Float ]) => {
121+ val featureArray = Array .fill[Double ](vector.size)(0.0 )
122+ vector.indices.foreach(idx => featureArray(idx) = vector(idx).toDouble)
123+ Vectors .dense(featureArray)
124+ })
125+ case _ : DoubleType => udf((vector : Seq [Double ]) => {
126+ Vectors .dense(vector.toArray)
127+ })
128+ }
129+ transferUDF(col(colName))
130+ }
131+ }
99132}
100133
101134/**
@@ -123,32 +156,13 @@ class KMeansModel private[ml] (
123156 @ Since (" 2.0.0" )
124157 def setPredictionCol (value : String ): this .type = set(predictionCol, value)
125158
126- @ Since (" 2.4.0" )
127- def featureToVector (dataset : Dataset [_], col : Column ): Column = {
128- val featuresDataType = dataset.schema(getFeaturesCol).dataType
129- val transferUDF = featuresDataType match {
130- case _ : VectorUDT => udf((vector : Vector ) => vector)
131- case fdt : ArrayType => fdt.elementType match {
132- case _ : FloatType => udf(f = (vector : Seq [Float ]) => {
133- val featureArray = Array .fill[Double ](vector.size)(0.0 )
134- vector.indices.foreach(idx => featureArray(idx) = vector(idx).toDouble)
135- Vectors .dense(featureArray)
136- })
137- case _ : DoubleType => udf((vector : Seq [Double ]) => {
138- Vectors .dense(vector.toArray)
139- })
140- }
141- }
142- transferUDF(col)
143- }
144-
145159 @ Since (" 2.0.0" )
146160 override def transform (dataset : Dataset [_]): DataFrame = {
147161 transformSchema(dataset.schema, logging = true )
148162
149163 val predictUDF = udf((vector : Vector ) => predict(vector))
150164
151- dataset.withColumn($(predictionCol), predictUDF(featureToVector(dataset, col( getFeaturesCol) )))
165+ dataset.withColumn($(predictionCol), predictUDF(featureToVector(dataset, getFeaturesCol)))
152166 }
153167
154168 @ Since (" 1.5.0" )
@@ -168,22 +182,9 @@ class KMeansModel private[ml] (
168182 // TODO: Replace the temp fix when we have proper evaluators defined for clustering.
169183 @ Since (" 2.0.0" )
170184 def computeCost (dataset : Dataset [_]): Double = {
171- val typeCandidates = List ( new VectorUDT ,
172- new ArrayType (DoubleType , false ),
173- new ArrayType (FloatType , false ))
174- SchemaUtils .checkColumnTypes(dataset.schema, $(featuresCol), typeCandidates)
185+ validateSchema(dataset.schema)
175186
176- /* val data: RDD[OldVector] = dataset.select(col($(featuresCol))).rdd.map {
177- case Row(point: Vector) => OldVectors.fromML(point)
178- case Row(point: Seq[_]) =>
179- val featureArray = Array.fill[Double](point.size)(0.0)
180- for (idx <- point.indices) {
181- featureArray(idx) = point(idx).toString.toDouble
182- }
183- OldVectors.fromML(Vectors.dense(featureArray))
184- }
185- */
186- val data : RDD [OldVector ] = dataset.select(featureToVector(dataset, col(getFeaturesCol)))
187+ val data : RDD [OldVector ] = dataset.select(featureToVector(dataset, getFeaturesCol))
187188 .rdd.map {
188189 case Row (point : Vector ) => OldVectors .fromML(point)
189190 }
@@ -344,45 +345,16 @@ class KMeans @Since("1.5.0") (
344345 @ Since (" 1.5.0" )
345346 def setSeed (value : Long ): this .type = set(seed, value)
346347
347- @ Since (" 2.4.0" )
348- def featureToVector (dataset : Dataset [_], col : Column ): Column = {
349- val featuresDataType = dataset.schema(getFeaturesCol).dataType
350- val transferUDF = featuresDataType match {
351- case _ : VectorUDT => udf((vector : Vector ) => vector)
352- case fdt : ArrayType => fdt.elementType match {
353- case _ : FloatType => udf(f = (vector : Seq [Float ]) => {
354- val featureArray = Array .fill[Double ](vector.size)(0.0 )
355- vector.indices.foreach(idx => featureArray(idx) = vector(idx).toDouble)
356- Vectors .dense(featureArray)
357- })
358- case _ : DoubleType => udf((vector : Seq [Double ]) => {
359- Vectors .dense(vector.toArray)
360- })
361- }
362- }
363- transferUDF(col)
364- }
365-
366348 @ Since (" 2.0.0" )
367349 override def fit (dataset : Dataset [_]): KMeansModel = {
368350 transformSchema(dataset.schema, logging = true )
369351
370352 val handlePersistence = dataset.storageLevel == StorageLevel .NONE
371- val instances : RDD [OldVector ] = dataset.select(featureToVector(dataset, col( getFeaturesCol) ))
353+ val instances : RDD [OldVector ] = dataset.select(featureToVector(dataset, getFeaturesCol))
372354 .rdd.map {
373355 case Row (point : Vector ) => OldVectors .fromML(point)
374356 }
375- /*
376- val instances: RDD[OldVector] = dataset.select(col($(featuresCol))).rdd.map {
377- case Row(point: Vector) => OldVectors.fromML(point)
378- case Row(point: Seq[_]) =>
379- val featureArray = Array.fill[Double](point.size)(0.0)
380- for (idx <- point.indices) {
381- featureArray(idx) = point(idx).toString.toDouble
382- }
383- OldVectors.fromML(Vectors.dense(featureArray))
384- }
385- */
357+
386358 if (handlePersistence) {
387359 instances.persist(StorageLevel .MEMORY_AND_DISK )
388360 }
0 commit comments