-
Notifications
You must be signed in to change notification settings - Fork 29k
[SPARK-31826][SQL] Support composed type of case class for typed Scala UDF #28645
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
25c881b
4546f35
c1a2d1c
d13f1c7
bb26320
5e0f445
8576d28
1afbdf5
86035fa
21e8aaf
2527c69
26fc42b
9e986f0
7568e8c
6b384b4
1c82558
21ae72b
bdbd45b
4db6401
3e97fa5
e6bb55d
f29a62a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -256,7 +256,8 @@ class Analyzer( | |
| Batch("Nondeterministic", Once, | ||
| PullOutNondeterministic), | ||
| Batch("UDF", Once, | ||
| HandleNullInputsForUDF), | ||
| HandleNullInputsForUDF, | ||
| ResolveEncodersInUDF), | ||
| Batch("UpdateNullability", Once, | ||
| UpdateAttributeNullability), | ||
| Batch("Subquery", Once, | ||
|
|
@@ -2847,6 +2848,45 @@ class Analyzer( | |
| } | ||
| } | ||
|
|
||
| /** | ||
| * Resolve the encoders for the UDF by explicitly given the attributes. We give the | ||
| * attributes explicitly in order to handle the case where the data type of the input | ||
| * value is not the same with the internal schema of the encoder, which could cause | ||
| * data loss. For example, the encoder should not cast the input value to Decimal(38, 18) | ||
| * if the actual data type is Decimal(30, 0). | ||
| * | ||
| * The resolved encoders then will be used to deserialize the internal row to Scala value. | ||
| */ | ||
| object ResolveEncodersInUDF extends Rule[LogicalPlan] { | ||
| override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp { | ||
| case p if !p.resolved => p // Skip unresolved nodes. | ||
|
|
||
| case p => p transformExpressionsUp { | ||
|
|
||
| case udf: ScalaUDF if udf.inputEncoders.nonEmpty => | ||
| val boundEncoders = udf.inputEncoders.zipWithIndex.map { case (encOpt, i) => | ||
| val dataType = udf.children(i).dataType | ||
| if (dataType.existsRecursively(_.isInstanceOf[UserDefinedType[_]])) { | ||
| // for UDT, we use `CatalystTypeConverters` | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The encoder does support UDT. We can figure it out later.
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It does, but it just doesn't support upcasting from a subclass to its parent class. So, when the input data type from the child is a subclass of the input parameter data type of the UDF, I think this may need a separate fix. |
||
| None | ||
| } else { | ||
| encOpt.map { enc => | ||
| val attrs = if (enc.isSerializedAsStructForTopLevel) { | ||
| dataType.asInstanceOf[StructType].toAttributes | ||
| } else { | ||
| // the field name doesn't matter here, so we use | ||
| // a simple literal to avoid any overhead | ||
| new StructType().add("input", dataType).toAttributes | ||
| } | ||
cloud-fan marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| enc.resolveAndBind(attrs) | ||
| } | ||
| } | ||
| } | ||
| udf.copy(inputEncoders = boundEncoders) | ||
| } | ||
| } | ||
| } | ||
|
|
||
| /** | ||
| * Check and add proper window frames for all window functions. | ||
| */ | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -17,14 +17,13 @@ | |
|
|
||
| package org.apache.spark.sql.catalyst.expressions | ||
|
|
||
| import scala.collection.mutable | ||
|
|
||
| import org.apache.spark.SparkException | ||
| import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} | ||
| import org.apache.spark.sql.catalyst.CatalystTypeConverters.{createToCatalystConverter, createToScalaConverter => catalystCreateToScalaConverter, isPrimitive} | ||
| import org.apache.spark.sql.catalyst.InternalRow | ||
| import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder | ||
| import org.apache.spark.sql.catalyst.expressions.codegen._ | ||
| import org.apache.spark.sql.catalyst.expressions.codegen.Block._ | ||
| import org.apache.spark.sql.types.{AbstractDataType, AnyDataType, DataType} | ||
| import org.apache.spark.sql.types.{AbstractDataType, AnyDataType, DataType, UserDefinedType} | ||
|
|
||
| /** | ||
| * User-defined function. | ||
|
|
@@ -103,21 +102,46 @@ case class ScalaUDF( | |
| } | ||
| } | ||
|
|
||
| private def createToScalaConverter(i: Int, dataType: DataType): Any => Any = { | ||
| if (inputEncoders.isEmpty) { | ||
| // for untyped Scala UDF | ||
| CatalystTypeConverters.createToScalaConverter(dataType) | ||
| } else { | ||
| val encoder = inputEncoders(i) | ||
| if (encoder.isDefined && encoder.get.isSerializedAsStructForTopLevel) { | ||
| val fromRow = encoder.get.resolveAndBind().createDeserializer() | ||
| /** | ||
| * Create the converter which converts the catalyst data type to the scala data type. | ||
| * We use `CatalystTypeConverters` to create the converter for: | ||
| * - UDF which doesn't provide inputEncoders, e.g., untyped Scala UDF and Java UDF | ||
| * - type which isn't supported by `ExpressionEncoder`, e.g., Any | ||
| * - primitive types, in order to use `identity` for better performance | ||
| * - UserDefinedType which isn't fully supported by `ExpressionEncoder` | ||
| * For other cases like case class, Option[T], we use `ExpressionEncoder` instead since | ||
| * `CatalystTypeConverters` doesn't support these data types. | ||
| * | ||
| * @param i the index of the child | ||
| * @param dataType the output data type of the i-th child | ||
| * @return the converter and a boolean value to indicate whether the converter is | ||
| * created by using `ExpressionEncoder`. | ||
| */ | ||
| private def scalaConverter(i: Int, dataType: DataType): (Any => Any, Boolean) = { | ||
| val useEncoder = | ||
| !(inputEncoders.isEmpty || // for untyped Scala UDF and Java UDF | ||
| inputEncoders(i).isEmpty || // for types aren't supported by encoder, e.g. Any | ||
| inputPrimitives(i) || // for primitive types | ||
| dataType.existsRecursively(_.isInstanceOf[UserDefinedType[_]])) | ||
|
|
||
| if (useEncoder) { | ||
| val enc = inputEncoders(i).get | ||
| val fromRow = enc.createDeserializer() | ||
| val converter = if (enc.isSerializedAsStructForTopLevel) { | ||
| row: Any => fromRow(row.asInstanceOf[InternalRow]) | ||
| } else { | ||
| CatalystTypeConverters.createToScalaConverter(dataType) | ||
| val inputRow = new GenericInternalRow(1) | ||
| value: Any => inputRow.update(0, value); fromRow(inputRow) | ||
| } | ||
| (converter, true) | ||
| } else { // use CatalystTypeConverters | ||
| (catalystCreateToScalaConverter(dataType), false) | ||
| } | ||
| } | ||
|
|
||
| private def createToScalaConverter(i: Int, dataType: DataType): Any => Any = | ||
| scalaConverter(i, dataType)._1 | ||
|
|
||
| // scalastyle:off line.size.limit | ||
|
|
||
| /** This method has been generated by this script | ||
|
|
@@ -1045,10 +1069,11 @@ case class ScalaUDF( | |
| ev: ExprCode): ExprCode = { | ||
| val converterClassName = classOf[Any => Any].getName | ||
|
|
||
| // The type converters for inputs and the result. | ||
| val converters: Array[Any => Any] = children.zipWithIndex.map { case (c, i) => | ||
| createToScalaConverter(i, c.dataType) | ||
| }.toArray :+ CatalystTypeConverters.createToCatalystConverter(dataType) | ||
| // The type converters for inputs and the result | ||
| val (converters, useEncoders): (Array[Any => Any], Array[Boolean]) = | ||
| (children.zipWithIndex.map { case (c, i) => | ||
| scalaConverter(i, c.dataType) | ||
| }.toArray :+ (createToCatalystConverter(dataType), false)).unzip | ||
| val convertersTerm = ctx.addReferenceObj("converters", converters, s"$converterClassName[]") | ||
| val errorMsgTerm = ctx.addReferenceObj("errMsg", udfErrorMessage) | ||
| val resultTerm = ctx.freshName("result") | ||
|
|
@@ -1064,12 +1089,26 @@ case class ScalaUDF( | |
| val (funcArgs, initArgs) = evals.zipWithIndex.zip(children.map(_.dataType)).map { | ||
| case ((eval, i), dt) => | ||
| val argTerm = ctx.freshName("arg") | ||
| val initArg = if (CatalystTypeConverters.isPrimitive(dt)) { | ||
| // Check `inputPrimitives` when it's not empty in order to figure out the Option | ||
| // type as non primitive type, e.g., Option[Int]. Fall back to `isPrimitive` when | ||
| // `inputPrimitives` is empty for other cases, e.g., Java UDF, untyped Scala UDF | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. so untyped Scala UDF doesn't support
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yea. We require the encoder to support |
||
| val primitive = (inputPrimitives.isEmpty && isPrimitive(dt)) || | ||
| (inputPrimitives.nonEmpty && inputPrimitives(i)) | ||
| val initArg = if (primitive) { | ||
| val convertedTerm = ctx.freshName("conv") | ||
| s""" | ||
| |${CodeGenerator.boxedType(dt)} $convertedTerm = ${eval.value}; | ||
| |Object $argTerm = ${eval.isNull} ? null : $convertedTerm; | ||
| """.stripMargin | ||
| } else if (useEncoders(i)) { | ||
| s""" | ||
| |Object $argTerm = null; | ||
| |if (${eval.isNull}) { | ||
| | $argTerm = $convertersTerm[$i].apply(null); | ||
| |} else { | ||
| | $argTerm = $convertersTerm[$i].apply(${eval.value}); | ||
| |} | ||
| """.stripMargin | ||
| } else { | ||
| s"Object $argTerm = ${eval.isNull} ? null : $convertersTerm[$i].apply(${eval.value});" | ||
| } | ||
|
|
@@ -1081,7 +1120,7 @@ case class ScalaUDF( | |
| val resultConverter = s"$convertersTerm[${children.length}]" | ||
| val boxedType = CodeGenerator.boxedType(dataType) | ||
|
|
||
| val funcInvokation = if (CatalystTypeConverters.isPrimitive(dataType) | ||
| val funcInvokation = if (isPrimitive(dataType) | ||
| // If the output is nullable, the returned value must be unwrapped from the Option | ||
| && !nullable) { | ||
| s"$resultTerm = ($boxedType)$getFuncResult" | ||
|
|
@@ -1112,7 +1151,7 @@ case class ScalaUDF( | |
| """.stripMargin) | ||
| } | ||
|
|
||
| private[this] val resultConverter = CatalystTypeConverters.createToCatalystConverter(dataType) | ||
| private[this] val resultConverter = createToCatalystConverter(dataType) | ||
|
|
||
| lazy val udfErrorMessage = { | ||
| val funcCls = function.getClass.getSimpleName | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.