Skip to content

Commit 3ca3670

Browse files
viirya authored and gatorsmile committed
[SPARK-22001][ML][SQL] ImputerModel can do withColumn for all input columns at one pass
## What changes were proposed in this pull request?

SPARK-21690 makes a one-pass `Imputer` by parallelizing the computation of all input columns. When we transform a dataset with `ImputerModel`, we do `withColumn` on all input columns sequentially. We can also do this on all input columns at once by adding a `withColumns` API to `Dataset`. The new `withColumns` API is for internal use only for now.

## How was this patch tested?

Existing tests for `ImputerModel`'s change. Added tests for the `withColumns` API.

Author: Liang-Chi Hsieh <viirya@gmail.com>

Closes #19229 from viirya/SPARK-22001.
1 parent 02c91e0 commit 3ca3670

File tree

3 files changed

+86
-18
lines changed

3 files changed

+86
-18
lines changed

mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -223,20 +223,18 @@ class ImputerModel private[ml] (
223223

224224
/**
 * Replaces missing values in every configured input column with that column's
 * surrogate value (computed at fit time and stored in `surrogateDF`), emitting
 * the results as the corresponding output columns.
 *
 * All output columns are attached in a single pass via `Dataset.withColumns`
 * rather than one `withColumn` call per column, avoiding a long chain of
 * intermediate projections.
 *
 * @param dataset input data; must contain every column named in `$(inputCols)`
 * @return a DataFrame with all `$(outputCols)` added
 */
override def transform(dataset: Dataset[_]): DataFrame = {
  transformSchema(dataset.schema, logging = true)
  // One surrogate value per input column, in the same order as $(inputCols).
  val surrogates = surrogateDF.select($(inputCols).map(col): _*).head().toSeq

  // Build one replacement expression per input column; the i-th expression
  // becomes the i-th output column below.
  val newCols = $(inputCols).zip(surrogates).map { case (inputCol, surrogate) =>
    val inputType = dataset.schema(inputCol).dataType
    val ic = col(inputCol)
    // Substitute the surrogate for both nulls and the configured missing-value
    // sentinel, then cast back to the input column's original type.
    when(ic.isNull, surrogate)
      .when(ic === $(missingValue), surrogate)
      .otherwise(ic)
      .cast(inputType)
  }
  // Attach every output column in one pass.
  dataset.withColumns($(outputCols), newCols).toDF()
}
241239

242240
override def transformSchema(schema: StructType): StructType = {

sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala

Lines changed: 30 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2083,22 +2083,40 @@ class Dataset[T] private[sql](
20832083
* @group untypedrel
20842084
* @since 2.0.0
20852085
*/
2086-
// Single-column variant: delegates to withColumns with one-element sequences.
def withColumn(colName: String, col: Column): DataFrame = withColumns(Seq(colName), Seq(col))

/**
 * Returns a new Dataset by adding columns or replacing the existing columns that have
 * the same names.
 *
 * Internal use only for now. Existing columns that match a requested name (per the
 * session's resolver) are replaced in place, preserving the schema's column order;
 * requested names that match nothing are appended at the end in the order given.
 *
 * @param colNames names for the new or replaced columns, parallel to `cols`
 * @param cols column expressions, parallel to `colNames`
 * @throws IllegalArgumentException if the two sequences differ in length
 */
private[spark] def withColumns(colNames: Seq[String], cols: Seq[Column]): DataFrame = {
  require(colNames.size == cols.size,
    s"The size of column names: ${colNames.size} isn't equal to " +
      s"the size of columns: ${cols.size}")
  // Reject duplicate target names up front; duplicate detection honors the
  // session's case-sensitivity setting.
  SchemaUtils.checkColumnNameDuplication(
    colNames,
    "in given column names",
    sparkSession.sessionState.conf.caseSensitiveAnalysis)

  val resolver = sparkSession.sessionState.analyzer.resolver
  val output = queryExecution.analyzed.output

  val columnMap = colNames.zip(cols).toMap

  // Existing attributes, with any that match a requested name replaced in place
  // so the original column order is preserved.
  val replacedAndExistingColumns = output.map { field =>
    columnMap.find { case (colName, _) => resolver(field.name, colName) } match {
      case Some((colName, newCol)) => newCol.as(colName)
      case None => Column(field)
    }
  }

  // Requested columns that matched no existing attribute are appended at the end.
  // Iterate the ordered (name, col) pairs — not columnMap — so the appended
  // columns keep the caller's order (Map iteration order is unspecified).
  val newColumns = colNames.zip(cols).collect {
    case (colName, newCol) if !output.exists(f => resolver(f.name, colName)) =>
      newCol.as(colName)
  }

  select(replacedAndExistingColumns ++ newColumns : _*)
}
21032121

21042122
/**

sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -641,6 +641,49 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
641641
assert(df.schema.map(_.name) === Seq("key", "value", "newCol"))
642642
}
643643

644+
// withColumns: appends multiple new columns in one pass, rejects mismatched
// name/column counts, and rejects case-insensitive duplicate names by default.
test("withColumns") {
  val df = testData.toDF().withColumns(Seq("newCol1", "newCol2"),
    Seq(col("key") + 1, col("key") + 2))
  checkAnswer(
    df,
    testData.collect().map { case Row(key: Int, value: String) =>
      Row(key, value, key + 1, key + 2)
    }.toSeq)
  assert(df.schema.map(_.name) === Seq("key", "value", "newCol1", "newCol2"))

  // Mismatched sizes must fail fast with a clear message.
  val err = intercept[IllegalArgumentException] {
    testData.toDF().withColumns(Seq("newCol1"),
      Seq(col("key") + 1, col("key") + 2))
  }
  assert(
    err.getMessage.contains("The size of column names: 1 isn't equal to the size of columns: 2"))

  // Default analysis is case-insensitive, so newCol1/newCOL1 are duplicates.
  val err2 = intercept[AnalysisException] {
    testData.toDF().withColumns(Seq("newCol1", "newCOL1"),
      Seq(col("key") + 1, col("key") + 2))
  }
  assert(err2.getMessage.contains("Found duplicate column(s)"))
}
667+
668+
// withColumns under case-sensitive analysis: names differing only in case are
// distinct columns, while exact duplicates are still rejected.
test("withColumns: case sensitive") {
  withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") {
    val df = testData.toDF().withColumns(Seq("newCol1", "newCOL1"),
      Seq(col("key") + 1, col("key") + 2))
    checkAnswer(
      df,
      testData.collect().map { case Row(key: Int, value: String) =>
        Row(key, value, key + 1, key + 2)
      }.toSeq)
    assert(df.schema.map(_.name) === Seq("key", "value", "newCol1", "newCOL1"))

    // Exact-case duplicates are still an error even when case sensitive.
    val err = intercept[AnalysisException] {
      testData.toDF().withColumns(Seq("newCol1", "newCol1"),
        Seq(col("key") + 1, col("key") + 2))
    }
    assert(err.getMessage.contains("Found duplicate column(s)"))
  }
}
686+
644687
test("replace column using withColumn") {
645688
val df2 = sparkContext.parallelize(Array(1, 2, 3)).toDF("x")
646689
val df3 = df2.withColumn("x", df2("x") + 1)
@@ -649,6 +692,15 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
649692
Row(2) :: Row(3) :: Row(4) :: Nil)
650693
}
651694

695+
// withColumns: replacing an existing column ("x") and appending new ones
// ("newCol1", "newCol2") in the same call.
test("replace column using withColumns") {
  val df2 = sparkContext.parallelize(Array((1, 2), (2, 3), (3, 4))).toDF("x", "y")
  val df3 = df2.withColumns(Seq("x", "newCol1", "newCol2"),
    Seq(df2("x") + 1, df2("y"), df2("y") + 1))
  checkAnswer(
    df3.select("x", "newCol1", "newCol2"),
    Row(2, 2, 3) :: Row(3, 3, 4) :: Row(4, 4, 5) :: Nil)
}
703+
652704
test("drop column using drop") {
653705
val df = testData.drop("key")
654706
checkAnswer(

0 commit comments

Comments
 (0)