DataFrameNaFunctions.scala
@@ -395,42 +395,43 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
 
   private def fillMap(values: Seq[(String, Any)]): DataFrame = {
     // Error handling
-    values.foreach { case (colName, replaceValue) =>
+    val attrToValue = AttributeMap(values.map { case (colName, replaceValue) =>
       // Check column name exists
-      df.resolve(colName)
-
+      val attr = df.resolve(colName) match {
+        case a: Attribute => a
+        case _ => throw new UnsupportedOperationException(
+          s"Nested field ${colName} is not supported.")
+      }
       // Check data type
       replaceValue match {
         case _: jl.Double | _: jl.Float | _: jl.Integer | _: jl.Long | _: jl.Boolean | _: String =>
           // This is good
         case _ => throw new IllegalArgumentException(
          s"Unsupported value type ${replaceValue.getClass.getName} ($replaceValue).")
       }
-    }
-
-    val columnEquals = df.sparkSession.sessionState.analyzer.resolver
-    val projections = df.schema.fields.map { f =>
-      values.find { case (k, _) => columnEquals(k, f.name) }.map { case (_, v) =>
-        v match {
-          case v: jl.Float => fillCol[Float](f, v)
-          case v: jl.Double => fillCol[Double](f, v)
-          case v: jl.Long => fillCol[Long](f, v)
-          case v: jl.Integer => fillCol[Integer](f, v)
-          case v: jl.Boolean => fillCol[Boolean](f, v.booleanValue())
-          case v: String => fillCol[String](f, v)
-        }
-      }.getOrElse(df.col(f.name))
+      attr -> replaceValue
+    })
+
+    val output = df.queryExecution.analyzed.output
+    val projections = output.map {
+      attr => attrToValue.get(attr).map {
+        case v: jl.Float => fillCol[Float](attr, v)
+        case v: jl.Double => fillCol[Double](attr, v)
+        case v: jl.Long => fillCol[Long](attr, v)
+        case v: jl.Integer => fillCol[Integer](attr, v)
+        case v: jl.Boolean => fillCol[Boolean](attr, v.booleanValue())
+        case v: String => fillCol[String](attr, v)
+      }.getOrElse(Column(attr))
     }
     df.select(projections : _*)
   }
 
   /**
-   * Returns a [[Column]] expression that replaces null value in `col` with `replacement`.
-   * It selects a column based on its name.
+   * Returns a [[Column]] expression that replaces null value in column defined by `attr`
+   * with `replacement`.
    */
-  private def fillCol[T](col: StructField, replacement: T): Column = {
-    val quotedColName = "`" + col.name + "`"
-    fillCol(col.dataType, col.name, df.col(quotedColName), replacement)
+  private def fillCol[T](attr: Attribute, replacement: T): Column = {
+    fillCol(attr.dataType, attr.name, Column(attr), replacement)
   }
 
   /**
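
For orientation (not part of the diff): the change resolves each key of the fill map to an Attribute once, stores it in an AttributeMap, and then projects Column(attr) over the analyzed output instead of re-selecting every field by name. Below is a minimal usage sketch of the behaviour the new tests exercise, assuming a local SparkSession; the object name and app name are illustrative only.

import org.apache.spark.sql.SparkSession

object FillMapDottedColumnSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("SPARK-34417-sketch").getOrCreate()
    import spark.implicits._

    // Same data shape as the new tests: one column whose name contains a dot.
    val df = Seq(("abc", 23L), ("def", 44L), (null, 0L)).toDF("ColWith.Dot", "Col")

    // The map key is backtick-quoted, as in the tests, so the dotted name is
    // resolved as a single top-level column rather than as nested field access.
    df.na.fill(Map("`ColWith.Dot`" -> "n/a")).show()

    spark.stop()
  }
}
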
DataFrameNaFunctionsSuite.scala
@@ -460,4 +460,29 @@ class DataFrameNaFunctionsSuite extends QueryTest with SharedSparkSession {
       Row(0, 0L, 0.toShort, 0.toByte, Float.NaN, Double.NaN) ::
       Row(0, 0L, 0.toShort, 0.toByte, Float.NaN, Double.NaN) :: Nil)
   }
+
+  test("SPARK-34417 - test fillMap() for column with a dot in the name") {
+    val na = "n/a"
+    checkAnswer(
+      Seq(("abc", 23L), ("def", 44L), (null, 0L)).toDF("ColWith.Dot", "Col")
+        .na.fill(Map("`ColWith.Dot`" -> na)),
+      Row("abc", 23) :: Row("def", 44L) :: Row(na, 0L) :: Nil)
+  }
+
+  test("SPARK-34417 - test fillMap() for qualified-column with a dot in the name") {
+    val na = "n/a"
+    checkAnswer(
+      Seq(("abc", 23L), ("def", 44L), (null, 0L)).toDF("ColWith.Dot", "Col").as("testDF")
+        .na.fill(Map("testDF.`ColWith.Dot`" -> na)),
+      Row("abc", 23) :: Row("def", 44L) :: Row(na, 0L) :: Nil)
+  }
+
+  test("SPARK-34417 - test fillMap() for column without a dot in the name" +
+    " and dataframe with another column having a dot in the name") {
+    val na = "n/a"
+    checkAnswer(
+      Seq(("abc", 23L), ("def", 44L), (null, 0L)).toDF("Col", "ColWith.Dot")
+        .na.fill(Map("Col" -> na)),
+      Row("abc", 23) :: Row("def", 44L) :: Row(na, 0L) :: Nil)
+  }
 }
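
A companion sketch (also not part of the diff) for the third test case: filling a plainly named column in a frame that merely contains a dotted column name. The old implementation's fallback branch re-selected every field via df.col(f.name) without quoting; with the patch, unfilled columns are kept as Column(attr) built from the already-resolved attribute, so the dotted neighbour needs no quoting at the call site. Assumes a local SparkSession; names are illustrative.

import org.apache.spark.sql.SparkSession

object FillMapPlainColumnSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("SPARK-34417-sketch-2").getOrCreate()
    import spark.implicits._

    val df = Seq(("abc", 23L), ("def", 44L), (null, 0L)).toDF("Col", "ColWith.Dot")

    // Explicit selection of the dotted column still needs backticks...
    df.select(df.col("`ColWith.Dot`")).show()

    // ...but filling the unrelated "Col" works without touching the dotted name,
    // since the patched fillMap keeps unfilled columns as Column(attr).
    df.na.fill(Map("Col" -> "n/a")).show()

    spark.stop()
  }
}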