Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
improve
  • Loading branch information
Ngone51 committed May 26, 2020
commit 4546f357ca7c5b9e8b6ae8bd3bcaa96833bcc48e
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.types.{AbstractDataType, AnyDataType, DataType}
import org.apache.spark.sql.types.{AbstractDataType, AnyDataType, DataType, DecimalType, StructType}

/**
* User-defined function.
Expand Down Expand Up @@ -112,17 +112,21 @@ case class ScalaUDF(
val encoder = inputEncoders(i)
encoder match {
case Some(enc) =>
val fromRow = enc.resolveAndBind().createDeserializer()
if (enc.isSerializedAsStructForTopLevel) {
val fromRow = enc.resolveAndBind().createDeserializer()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

to be consistent, shall we bind with child.dataType.asInstanceOf[StructType].toAttributes?

row: Any => fromRow(row.asInstanceOf[InternalRow])
} else {
val child = children(i)
val attrs = new StructType().add(s"$child", child.dataType).toAttributes
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

child.toString can be expensive. how about "child"? The name doesn't matter anyway.

val fromRow = enc.resolveAndBind(attrs).createDeserializer()

value: Any =>
val row = new GenericInternalRow(1)
row.update(0, value)
fromRow(row)
}

case None => createToScalaConverter(dataType)
case _ => createToScalaConverter(dataType)
}
}
}
Expand Down
9 changes: 8 additions & 1 deletion sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -608,6 +608,13 @@ class UDFSuite extends QueryTest with SharedSparkSession {
val f3 = (s: Map[TestData, TestData]) => s.keys.head.key * s.values.head.value.toInt
val myUdf3 = udf(f3)
val df3 = Seq(("data", Map(TestData(50, "2") -> TestData(50, "2")))).toDF("col1", "col2")
checkAnswer(df2.select(myUdf2(Column("col2"))), Row(100) :: Nil)
checkAnswer(df3.select(myUdf3(Column("col2"))), Row(100) :: Nil)
}

test("case class as element of tuple") {
val f = (s: (TestData, Int)) => s._1.key * s._2
val myUdf = udf(f)
val df = Seq(("data", (TestData(50, "2"), 2))).toDF("col1", "col2")
checkAnswer(df.select(myUdf(Column("col2"))), Row(100) :: Nil)
}
}