diff --git a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
index 569a5fb99376..d536926b137d 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
@@ -194,4 +194,6 @@ abstract class PredictionModel[FeaturesType, M <: PredictionModel[FeaturesType,
    * This internal method is used to implement [[transform()]] and output [[predictionCol]].
    */
   protected def predict(features: FeaturesType): Double
+
+  def transformInstance(features: FeaturesType): Double = predict(features)
 }
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala
index a3a2b55adc25..70648fe6ea1f 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala
@@ -95,6 +95,8 @@ abstract class UnaryTransformer[IN, OUT, T <: UnaryTransformer[IN, OUT, T]]
    */
   protected def createTransformFunc: IN => OUT
 
+  def transformInstance(input: IN): OUT = createTransformFunc(input)
+
   /**
    * Returns the data type of the output column.
    */
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala
index fa9634fdfa7e..dfcec47c591d 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala
@@ -76,25 +76,9 @@ final class Binarizer @Since("1.4.0") (@Since("1.4.0") override val uid: String)
     val outputSchema = transformSchema(dataset.schema, logging = true)
     val schema = dataset.schema
     val inputType = schema($(inputCol)).dataType
-    val td = $(threshold)
-
-    val binarizerDouble = udf { in: Double => if (in > td) 1.0 else 0.0 }
-    val binarizerVector = udf { (data: Vector) =>
-      val indices = ArrayBuilder.make[Int]
-      val values = ArrayBuilder.make[Double]
-
-      data.foreachActive { (index, value) =>
-        if (value > td) {
-          indices += index
-          values += 1.0
-        }
-      }
-
-      Vectors.sparse(data.size, indices.result(), values.result()).compressed
-    }
-
+    val binarizerDouble = udf { (x: Double) => transformInstance(x) }
+    val binarizerVector = udf { (x: Vector) => transformInstance(x) }
     val metadata = outputSchema($(outputCol)).metadata
-
     inputType match {
       case DoubleType =>
         dataset.select(col("*"), binarizerDouble(col($(inputCol))).as($(outputCol), metadata))
@@ -103,6 +87,22 @@ final class Binarizer @Since("1.4.0") (@Since("1.4.0") override val uid: String)
     }
   }
 
+  def transformInstance(in: Double): Double = if (in > $(threshold)) 1.0 else 0.0
+
+  def transformInstance(data: Vector): Vector = {
+    val indices = ArrayBuilder.make[Int]
+    val values = ArrayBuilder.make[Double]
+    val td = $(threshold)
+    data.foreachActive { (index, value) =>
+      if (value > td) {
+        indices += index
+        values += 1.0
+      }
+    }
+    Vectors.sparse(data.size, indices.result(), values.result()).compressed
+  }
+
+
   @Since("1.4.0")
   override def transformSchema(schema: StructType): StructType = {
     val inputType = schema($(inputCol)).dataType
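A minimal sketch of how the single-instance entry points added to Binarizer above could be exercised once this patch is applied; the threshold and inputs are illustrative, not taken from the patch:

import org.apache.spark.ml.feature.Binarizer
import org.apache.spark.ml.linalg.Vectors

val binarizer = new Binarizer()
  .setInputCol("feature")
  .setOutputCol("binarized")
  .setThreshold(0.5)

// Both overloads work on a single value, without building a DataFrame.
binarizer.transformInstance(0.8)                      // 1.0
binarizer.transformInstance(Vectors.dense(0.1, 0.9))  // sparse vector with 1.0 at index 1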
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
index caffc39e2be1..7bcd54c61117 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
@@ -78,14 +78,15 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String)
   @Since("2.0.0")
   override def transform(dataset: Dataset[_]): DataFrame = {
     transformSchema(dataset.schema)
-    val bucketizer = udf { feature: Double =>
-      Bucketizer.binarySearchForBuckets($(splits), feature)
-    }
+    val bucketizer = udf { transformInstance _ }
     val newCol = bucketizer(dataset($(inputCol)))
     val newField = prepOutputField(dataset.schema)
     dataset.withColumn($(outputCol), newCol, newField.metadata)
   }
 
+  def transformInstance(feature: Double): Double =
+    Bucketizer.binarySearchForBuckets($(splits), feature)
+
   private def prepOutputField(schema: StructType): StructField = {
     val buckets = $(splits).sliding(2).map(bucket => bucket.mkString(", ")).toArray
     val attr = new NominalAttribute(name = Some($(outputCol)), isOrdinal = Some(true),
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala
index 1c329267d70d..02d4f842f332 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala
@@ -148,12 +148,13 @@ final class ChiSqSelectorModel private[ml] (
   override def transform(dataset: Dataset[_]): DataFrame = {
     val transformedSchema = transformSchema(dataset.schema, logging = true)
     val newField = transformedSchema.last
+    val selector = udf(transformInstance _)
+    dataset.withColumn($(outputCol), selector(col($(featuresCol))), newField.metadata)
+  }
 
+  def transformInstance(v: Vector): Vector = {
     // TODO: Make the transformer natively in ml framework to avoid extra conversion.
-    val transformer: Vector => Vector = v => chiSqSelector.transform(OldVectors.fromML(v)).asML
-
-    val selector = udf(transformer)
-    dataset.withColumn($(outputCol), selector(col($(featuresCol))), newField.metadata)
+    chiSqSelector.transform(OldVectors.fromML(v)).asML
   }
 
   @Since("1.6.0")
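An illustrative sketch (not part of the patch) of Bucketizer's new method, which maps a raw value straight to its bucket index; the splits below are made up:

import org.apache.spark.ml.feature.Bucketizer

val bucketizer = new Bucketizer()
  .setInputCol("value")
  .setOutputCol("bucket")
  .setSplits(Array(Double.NegativeInfinity, 0.0, 10.0, Double.PositiveInfinity))

bucketizer.transformInstance(3.5)  // 1.0, the index of the [0.0, 10.0) bucket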
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala
index 3250fe55980d..814e16cddc79 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala
@@ -21,7 +21,7 @@ import org.apache.hadoop.fs.Path
 import org.apache.spark.annotation.{Experimental, Since}
 import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.ml.{Estimator, Model}
-import org.apache.spark.ml.linalg.{Vectors, VectorUDT}
+import org.apache.spark.ml.linalg.{Vector, Vectors, VectorUDT}
 import org.apache.spark.ml.param._
 import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
 import org.apache.spark.ml.util._
@@ -246,28 +246,29 @@ class CountVectorizerModel(
       val dict = vocabulary.zipWithIndex.toMap
       broadcastDict = Some(dataset.sparkSession.sparkContext.broadcast(dict))
     }
+    val vectorizer = udf { transformInstance _ }
+    dataset.withColumn($(outputCol), vectorizer(col($(inputCol))))
+  }
+
+  def transformInstance(document: Seq[String]): Vector = {
     val dictBr = broadcastDict.get
     val minTf = $(minTF)
-    val vectorizer = udf { (document: Seq[String]) =>
-      val termCounts = new OpenHashMap[Int, Double]
-      var tokenCount = 0L
-      document.foreach { term =>
-        dictBr.value.get(term) match {
-          case Some(index) => termCounts.changeValue(index, 1.0, _ + 1.0)
-          case None => // ignore terms not in the vocabulary
-        }
-        tokenCount += 1
-      }
-      val effectiveMinTF = if (minTf >= 1.0) minTf else tokenCount * minTf
-      val effectiveCounts = if ($(binary)) {
-        termCounts.filter(_._2 >= effectiveMinTF).map(p => (p._1, 1.0)).toSeq
-      } else {
-        termCounts.filter(_._2 >= effectiveMinTF).toSeq
+    val termCounts = new OpenHashMap[Int, Double]
+    var tokenCount = 0L
+    document.foreach { term =>
+      dictBr.value.get(term) match {
+        case Some(index) => termCounts.changeValue(index, 1.0, _ + 1.0)
+        case None => // ignore terms not in the vocabulary
       }
-
-      Vectors.sparse(dictBr.value.size, effectiveCounts)
+      tokenCount += 1
     }
-    dataset.withColumn($(outputCol), vectorizer(col($(inputCol))))
+    val effectiveMinTF = if (minTf >= 1.0) minTf else tokenCount * minTf
+    val effectiveCounts = if ($(binary)) {
+      termCounts.filter(_._2 >= effectiveMinTF).map(p => (p._1, 1.0)).toSeq
+    } else {
+      termCounts.filter(_._2 >= effectiveMinTF).toSeq
+    }
+    Vectors.sparse(dictBr.value.size, effectiveCounts)
   }
 
   @Since("1.5.0")
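A hedged sketch of the intended single-document call for CountVectorizerModel. As written above, transformInstance reads broadcastDict, which is only populated inside transform(), so this assumes transform() has been invoked on some DataFrame first; the vocabulary, tokens, and someDF below are illustrative placeholders:

import org.apache.spark.ml.feature.CountVectorizerModel

val cvModel = new CountVectorizerModel(Array("a", "b", "c"))
  .setInputCol("words")
  .setOutputCol("features")

// cvModel.transform(someDF)  // would populate the broadcast dictionary used below
cvModel.transformInstance(Seq("a", "b", "a"))  // roughly (3, [0, 1], [2.0, 1.0])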
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala
index 6ca7336cd048..15eec5f24478 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala
@@ -20,6 +20,7 @@ package org.apache.spark.ml.feature
 import org.apache.spark.annotation.{Experimental, Since}
 import org.apache.spark.ml.Transformer
 import org.apache.spark.ml.attribute.AttributeGroup
+import org.apache.spark.ml.linalg.{Vector, Vectors}
 import org.apache.spark.ml.param._
 import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
 import org.apache.spark.ml.util._
@@ -82,7 +83,11 @@ class HashingTF @Since("1.4.0") (@Since("1.4.0") override val uid: String)
 
   /** @group setParam */
   @Since("1.2.0")
-  def setNumFeatures(value: Int): this.type = set(numFeatures, value)
+  def setNumFeatures(value: Int): this.type = {
+    set(numFeatures, value)
+    hashingTF = new feature.HashingTF($(numFeatures)).setBinary($(binary))
+    this
+  }
 
   /** @group getParam */
   @Since("2.0.0")
@@ -90,18 +95,28 @@ class HashingTF @Since("1.4.0") (@Since("1.4.0") override val uid: String)
 
   /** @group setParam */
   @Since("2.0.0")
-  def setBinary(value: Boolean): this.type = set(binary, value)
+  def setBinary(value: Boolean): this.type = {
+    set(binary, value)
+    hashingTF = new feature.HashingTF($(numFeatures)).setBinary($(binary))
+    this
+  }
 
   @Since("2.0.0")
   override def transform(dataset: Dataset[_]): DataFrame = {
     val outputSchema = transformSchema(dataset.schema)
-    val hashingTF = new feature.HashingTF($(numFeatures)).setBinary($(binary))
-    // TODO: Make the hashingTF.transform natively in ml framework to avoid extra conversion.
-    val t = udf { terms: Seq[_] => hashingTF.transform(terms).asML }
+    val t = udf { transformInstance _ }
     val metadata = outputSchema($(outputCol)).metadata
     dataset.select(col("*"), t(col($(inputCol))).as($(outputCol), metadata))
   }
 
+  /** Updated by the setters when parameters change */
+  private var hashingTF = new feature.HashingTF($(numFeatures)).setBinary($(binary))
+
+  def transformInstance(terms: Seq[_]): Vector = {
+    // TODO: Make the hashingTF.transform natively in ml framework to avoid extra conversion.
+    hashingTF.transform(terms).asML
+  }
+
   @Since("1.4.0")
   override def transformSchema(schema: StructType): StructType = {
     val inputType = schema($(inputCol)).dataType
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala
index cf03a2845ced..bda7ec598906 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala
@@ -133,11 +133,14 @@ class IDFModel private[ml] (
   @Since("2.0.0")
   override def transform(dataset: Dataset[_]): DataFrame = {
     transformSchema(dataset.schema, logging = true)
-    // TODO: Make the idfModel.transform natively in ml framework to avoid extra conversion.
-    val idf = udf { vec: Vector => idfModel.transform(OldVectors.fromML(vec)).asML }
+    val idf = udf { transformInstance _ }
     dataset.withColumn($(outputCol), idf(col($(inputCol))))
   }
 
+  def transformInstance(vec: Vector): Vector = {
+    // TODO: Make the idfModel.transform natively in ml framework to avoid extra conversion.
+    idfModel.transform(OldVectors.fromML(vec)).asML}
+
   @Since("1.4.0")
   override def transformSchema(schema: StructType): StructType = {
     validateAndTransformSchema(schema)
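HashingTF needs no fitting, so its per-instance hook can be sketched directly (the feature count and terms below are illustrative); fitted models such as IDFModel expose the same shape of call once obtained from IDF.fit:

import org.apache.spark.ml.feature.HashingTF

val hashingTF = new HashingTF()
  .setInputCol("words")
  .setOutputCol("tf")
  .setNumFeatures(1 << 10)

// Term-frequency vector for one document, no DataFrame involved.
hashingTF.transformInstance(Seq("spark", "ml", "spark"))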
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/MaxAbsScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/MaxAbsScaler.scala
index 31a58152671c..cb133a70fb43 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/MaxAbsScaler.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MaxAbsScaler.scala
@@ -125,13 +125,15 @@ class MaxAbsScalerModel private[ml] (
   @Since("2.0.0")
   override def transform(dataset: Dataset[_]): DataFrame = {
     transformSchema(dataset.schema, logging = true)
+    val reScale = udf { transformInstance _ }
+    dataset.withColumn($(outputCol), reScale(col($(inputCol))))
+  }
+
+  def transformInstance(vector: Vector): Vector = {
     // TODO: this looks hack, we may have to handle sparse and dense vectors separately.
     val maxAbsUnzero = Vectors.dense(maxAbs.toArray.map(x => if (x == 0) 1 else x))
-    val reScale = udf { (vector: Vector) =>
-      val brz = vector.asBreeze / maxAbsUnzero.asBreeze
-      Vectors.fromBreeze(brz)
-    }
-    dataset.withColumn($(outputCol), reScale(col($(inputCol))))
+    val brz = vector.asBreeze / maxAbsUnzero.asBreeze
+    Vectors.fromBreeze(brz)
   }
 
   @Since("2.0.0")
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala
index dd5a1f9b41fc..c0c23212b5c5 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala
@@ -173,25 +173,24 @@ class MinMaxScalerModel private[ml] (
 
   @Since("2.0.0")
   override def transform(dataset: Dataset[_]): DataFrame = {
+    val reScale = udf { transformInstance _ }
+    dataset.withColumn($(outputCol), reScale(col($(inputCol))))
+  }
+
+  def transformInstance(vector: Vector): Vector = {
     val originalRange = (originalMax.asBreeze - originalMin.asBreeze).toArray
     val minArray = originalMin.toArray
-
-    val reScale = udf { (vector: Vector) =>
-      val scale = $(max) - $(min)
-
-      // 0 in sparse vector will probably be rescaled to non-zero
-      val values = vector.toArray
-      val size = values.length
-      var i = 0
-      while (i < size) {
-        val raw = if (originalRange(i) != 0) (values(i) - minArray(i)) / originalRange(i) else 0.5
-        values(i) = raw * scale + $(min)
-        i += 1
-      }
-      Vectors.dense(values)
+    val scale = $(max) - $(min)
+    // 0 in sparse vector will probably be rescaled to non-zero
+    val values = vector.toArray
+    val size = values.length
+    var i = 0
+    while (i < size) {
+      val raw = if (originalRange(i) != 0) (values(i) - minArray(i)) / originalRange(i) else 0.5
+      values(i) = raw * scale + $(min)
+      i += 1
     }
-
-    dataset.withColumn($(outputCol), reScale(col($(inputCol))))
+    Vectors.dense(values)
   }
 
   @Since("1.5.0")
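The scaler models keep private[ml] constructors, so a single-vector call goes through a fitted model. A minimal sketch, where trainingDF is a placeholder for an existing DataFrame with a Vector column named "features":

import org.apache.spark.ml.feature.MinMaxScaler
import org.apache.spark.ml.linalg.Vectors

val scalerModel = new MinMaxScaler()
  .setInputCol("features")
  .setOutputCol("scaled")
  .fit(trainingDF)  // trainingDF is assumed, not defined in this patch

scalerModel.transformInstance(Vectors.dense(1.0, -2.0, 3.0))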
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala
index b89c85991f39..0fc637a63476 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala
@@ -149,15 +149,17 @@ class PCAModel private[ml] (
   @Since("2.0.0")
   override def transform(dataset: Dataset[_]): DataFrame = {
     transformSchema(dataset.schema, logging = true)
-    val pcaModel = new feature.PCAModel($(k),
-      OldMatrices.fromML(pc).asInstanceOf[OldDenseMatrix],
-      OldVectors.fromML(explainedVariance).asInstanceOf[OldDenseVector])
+    val pcaOp = udf { transformInstance _ }
+    dataset.withColumn($(outputCol), pcaOp(col($(inputCol))))
+  }
 
-    // TODO: Make the transformer natively in ml framework to avoid extra conversion.
-    val transformer: Vector => Vector = v => pcaModel.transform(OldVectors.fromML(v)).asML
+  lazy val pcaModel = new feature.PCAModel($(k),
+    OldMatrices.fromML(pc).asInstanceOf[OldDenseMatrix],
+    OldVectors.fromML(explainedVariance).asInstanceOf[OldDenseVector])
 
-    val pcaOp = udf(transformer)
-    dataset.withColumn($(outputCol), pcaOp(col($(inputCol))))
+  def transformInstance(v: Vector): Vector = {
+    // TODO: Make the transformer natively in ml framework to avoid extra conversion.
+    pcaModel.transform(OldVectors.fromML(v)).asML
   }
 
   @Since("1.5.0")
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
index 5e1bacf876ca..5d08a8e77b6e 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
@@ -164,13 +164,15 @@ class StandardScalerModel private[ml] (
   @Since("2.0.0")
   override def transform(dataset: Dataset[_]): DataFrame = {
     transformSchema(dataset.schema, logging = true)
-    val scaler = new feature.StandardScalerModel(std, mean, $(withStd), $(withMean))
+    val scale = udf(transformInstance _)
+    dataset.withColumn($(outputCol), scale(col($(inputCol))))
+  }
 
-    // TODO: Make the transformer natively in ml framework to avoid extra conversion.
-    val transformer: Vector => Vector = v => scaler.transform(OldVectors.fromML(v)).asML
+  private lazy val scaler = new feature.StandardScalerModel(std, mean, $(withStd), $(withMean))
 
-    val scale = udf(transformer)
-    dataset.withColumn($(outputCol), scale(col($(inputCol))))
+  def transformInstance(v: Vector): Vector = {
+    // TODO: Make the transformer natively in ml framework to avoid extra conversion.
+    scaler.transform(OldVectors.fromML(v)).asML
   }
 
   @Since("1.4.0")
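PCAModel and StandardScalerModel follow the same fitted-model pattern as the scalers above. The generic hook added to UnaryTransformer earlier in this patch can be sketched with Tokenizer, which needs no fitting; the input text is illustrative:

import org.apache.spark.ml.feature.Tokenizer

val tokenizer = new Tokenizer()
  .setInputCol("text")
  .setOutputCol("words")

tokenizer.transformInstance("Logistic Regression Models")  // Seq("logistic", "regression", "models")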
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala
index 1a6f42f773cd..4bd650c94a7a 100755
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala
@@ -60,7 +60,12 @@ class StopWordsRemover @Since("1.5.0") (@Since("1.5.0") override val uid: String)
 
   /** @group setParam */
   @Since("1.5.0")
-  def setStopWords(value: Array[String]): this.type = set(stopWords, value)
+  def setStopWords(value: Array[String]): this.type = {
+    set(stopWords, value)
+    stopWordsSet = if ($(caseSensitive)) $(stopWords).toSet
+      else $(stopWords).map(toLower(_)).toSet
+    this
+  }
 
   /** @group getParam */
   @Since("1.5.0")
@@ -77,7 +82,12 @@ class StopWordsRemover @Since("1.5.0") (@Since("1.5.0") override val uid: String)
 
   /** @group setParam */
   @Since("1.5.0")
-  def setCaseSensitive(value: Boolean): this.type = set(caseSensitive, value)
+  def setCaseSensitive(value: Boolean): this.type = {
+    set(caseSensitive, value)
+    stopWordsSet = if ($(caseSensitive)) $(stopWords).toSet
+      else $(stopWords).map(toLower(_)).toSet
+    this
+  }
 
   /** @group getParam */
   @Since("1.5.0")
@@ -88,21 +98,24 @@ class StopWordsRemover @Since("1.5.0") (@Since("1.5.0") override val uid: String)
   @Since("2.0.0")
   override def transform(dataset: Dataset[_]): DataFrame = {
     val outputSchema = transformSchema(dataset.schema)
-    val t = if ($(caseSensitive)) {
-      val stopWordsSet = $(stopWords).toSet
-      udf { terms: Seq[String] =>
-        terms.filter(s => !stopWordsSet.contains(s))
-      }
+    val t = udf { transformInstance _ }
+    val metadata = outputSchema($(outputCol)).metadata
+    dataset.select(col("*"), t(col($(inputCol))).as($(outputCol), metadata))
+  }
+
+  private val toLower = (s: String) => if (s != null) s.toLowerCase else s
+
+  /** Updated by the setters when parameters change */
+  private var stopWordsSet = if ($(caseSensitive)) $(stopWords).toSet
+    else $(stopWords).map(toLower(_)).toSet
+
+  def transformInstance(terms: Seq[String]): Seq[String] = {
+    if ($(caseSensitive)) {
+      terms.filter(s => !stopWordsSet.contains(s))
     } else {
       // TODO: support user locale (SPARK-15064)
-      val toLower = (s: String) => if (s != null) s.toLowerCase else s
-      val lowerStopWords = $(stopWords).map(toLower(_)).toSet
-      udf { terms: Seq[String] =>
-        terms.filter(s => !lowerStopWords.contains(toLower(s)))
-      }
+      terms.filter(s => !stopWordsSet.contains(toLower(s)))
     }
-    val metadata = outputSchema($(outputCol)).metadata
-    dataset.select(col("*"), t(col($(inputCol))).as($(outputCol), metadata))
   }
 
   @Since("1.5.0")
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
index 0f7337ce6b55..4230dd07536e 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
@@ -165,15 +165,7 @@ class StringIndexerModel (
       return dataset.toDF
     }
     validateAndTransformSchema(dataset.schema)
-
-    val indexer = udf { label: String =>
-      if (labelToIndex.contains(label)) {
-        labelToIndex(label)
-      } else {
-        throw new SparkException(s"Unseen label: $label.")
-      }
-    }
-
+    val indexer = udf { transformInstance _ }
     val metadata = NominalAttribute.defaultAttr
       .withName($(outputCol)).withValues(labels).toMetadata()
     // If we are skipping invalid records, filter them out.
@@ -189,6 +181,15 @@ class StringIndexerModel (
       indexer(dataset($(inputCol)).cast(StringType)).as($(outputCol), metadata))
   }
 
+  def transformInstance(label: String): Double = {
+    // Throws an exception if the label has not been seen before
+    if (labelToIndex.contains(label)) {
+      labelToIndex(label)
+    } else {
+      throw new SparkException(s"Unseen label: $label.")
+    }
+  }
+
   @Since("1.4.0")
   override def transformSchema(schema: StructType): StructType = {
     if (schema.fieldNames.contains($(inputCol))) {
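StringIndexerModel has a public constructor that takes the label array, so the per-label behaviour, including the unseen-label exception kept in the code above, can be sketched directly (labels are illustrative):

import org.apache.spark.ml.feature.StringIndexerModel

val indexerModel = new StringIndexerModel(Array("a", "b", "c"))
  .setInputCol("category")
  .setOutputCol("categoryIndex")

indexerModel.transformInstance("b")  // 1.0
// indexerModel.transformInstance("d") would throw SparkException: Unseen label: d.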
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala
index 4939dabd987e..29ef6dc19a9a 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala
@@ -97,7 +97,7 @@ class VectorAssembler @Since("1.4.0") (@Since("1.4.0") override val uid: String)
 
     // Data transformation.
     val assembleFunc = udf { r: Row =>
-      VectorAssembler.assemble(r.toSeq: _*)
+      transformInstance(r.toSeq)
     }
     val args = $(inputCols).map { c =>
       schema(c).dataType match {
@@ -110,6 +110,9 @@ class VectorAssembler @Since("1.4.0") (@Since("1.4.0") override val uid: String)
     dataset.select(col("*"), assembleFunc(struct(args: _*)).as($(outputCol), metadata))
   }
 
+  /** Takes a sequence of values or vectors and assembles them into a single vector. */
+  def transformInstance(seq: Seq[Any]): Vector = VectorAssembler.assemble(seq: _*)
+
   @Since("1.4.0")
   override def transformSchema(schema: StructType): StructType = {
     val inputColNames = $(inputCols)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala
index 52db996c841b..a6718ee0464c 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala
@@ -363,11 +363,13 @@ class VectorIndexerModel private[ml] (
   override def transform(dataset: Dataset[_]): DataFrame = {
     transformSchema(dataset.schema, logging = true)
     val newField = prepOutputField(dataset.schema)
-    val transformUDF = udf { (vector: Vector) => transformFunc(vector) }
+    val transformUDF = udf { transformInstance _ }
    val newCol = transformUDF(dataset($(inputCol)))
     dataset.withColumn($(outputCol), newCol, newField.metadata)
   }
 
+  def transformInstance(vector: Vector): Vector = transformFunc(vector)
+
   @Since("1.4.0")
   override def transformSchema(schema: StructType): StructType = {
     val dataType = new VectorUDT