[SPARK-31826][SQL] Support composed type of case class for typed Scala UDF #28645
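For context, a minimal hedged sketch of what this change enables: a typed Scala UDF whose parameter is a case class, or a composed type of one such as `Seq[Point]`. The `Point` case class, column names, and data below are hypothetical, chosen only for illustration; they do not come from the PR.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.udf

// Hypothetical case class standing in for any user-defined product type.
case class Point(x: Double, y: Double)

val spark = SparkSession.builder().master("local").appName("demo").getOrCreate()
import spark.implicits._

// With this PR, a typed Scala UDF can take a case class directly,
// as well as composed types of it such as Seq[Point].
val norm = udf((p: Point) => math.sqrt(p.x * p.x + p.y * p.y))
val total = udf((ps: Seq[Point]) => ps.map(_.x).sum)

val df = Seq((Point(3.0, 4.0), Seq(Point(1.0, 0.0), Point(2.0, 0.0)))).toDF("p", "ps")
df.select(norm($"p"), total($"ps")).show()
```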
Changes from 1 commit
Analyzer.scala:

```diff
@@ -257,7 +257,7 @@ class Analyzer(
       PullOutNondeterministic),
     Batch("UDF", Once,
       HandleNullInputsForUDF,
-      PrepareDeserializerForUDF),
+      ResolveEncodersInUDF),
     Batch("UpdateNullability", Once,
       UpdateAttributeNullability),
     Batch("Subquery", Once,
```
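In short: the analysis rule formerly called `PrepareDeserializerForUDF` eagerly created deserializers and stored them in `ScalaUDF.inputDeserializers`; renamed to `ResolveEncodersInUDF`, it now only resolves and binds the input encoders, and `ScalaUDF` derives the deserializers itself when it builds its input converters (see the `scalaConverter` hunk below).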
```diff
@@ -2814,7 +2814,7 @@ class Analyzer(
       case p => p transformExpressionsUp {
 
-        case udf @ ScalaUDF(_, _, inputs, _, _, _, _, _)
+        case udf @ ScalaUDF(_, _, inputs, _, _, _, _)
             if udf.inputPrimitives.contains(true) =>
           // Otherwise, add special handling of null for fields that can't accept null.
           // The result of operations like this, when passed null, is generally to return null.
```
```diff
@@ -2848,20 +2848,17 @@ class Analyzer(
     }
   }
 
-  object PrepareDeserializerForUDF extends Rule[LogicalPlan] {
+  object ResolveEncodersInUDF extends Rule[LogicalPlan] {
     override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp {
       case p if !p.resolved => p // Skip unresolved nodes.
 
       case p => p transformExpressionsUp {
 
-        case udf @ ScalaUDF(_, _, inputs, encoders, _, _, _, desers)
-            if encoders.nonEmpty && desers.isEmpty =>
-          val deserializers = encoders.zipWithIndex.map { case (encOpt, i) =>
+        case udf @ ScalaUDF(_, _, inputs, encoders, _, _, _) if encoders.nonEmpty =>
+          val resolvedEncoders = encoders.zipWithIndex.map { case (encOpt, i) =>
             val dataType = inputs(i).dataType
-            if (CatalystTypeConverters.isPrimitive(dataType) ||
-                dataType.isInstanceOf[UserDefinedType[_]]) {
-              // primitive/UDT data types use `CatalystTypeConverters` to
-              // convert internal data to external data.
+            if (dataType.isInstanceOf[UserDefinedType[_]]) {
+              // for UDT, we use `CatalystTypeConverters`
```
Review thread on the `// for UDT, we use CatalystTypeConverters` line:

Contributor: encoder does support UDT. We can figure it out later.

Author: It does, but it just doesn't support upcasting from a subclass to its parent class. So when the input data type coming from the child is a subclass of the UDF's input parameter data type, I think this may need a separate fix.
```diff
               None
             } else {
               encOpt.map { enc =>
```
```diff
@@ -2870,13 +2867,13 @@ class Analyzer(
                 } else {
                   // the field name doesn't matter here, so we use
                   // a simple literal to avoid any overhead
-                  new StructType().add(s"input", dataType).toAttributes
+                  new StructType().add("input", dataType).toAttributes
                 }
-                enc.resolveAndBind(attrs).createDeserializer()
+                enc.resolveAndBind(attrs)
               }
             }
           }
-          udf.copy(inputDeserializers = deserializers)
+          udf.copy(inputEncoders = resolvedEncoders)
       }
     }
   }
```
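Outside the diff itself, a hedged sketch of what resolving and binding an encoder accomplishes, using catalyst's internal encoder API as it stands in this era of Spark. `Point` is hypothetical, and this binds against the encoder's own schema, whereas the rule above binds against the UDF child's data type:

```scala
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder

// Hypothetical input type for a typed UDF.
case class Point(x: Double, y: Double)

val enc = ExpressionEncoder[Point]()
// resolveAndBind resolves the encoder's deserializer expressions and binds
// them to a concrete attribute schema (here, the encoder's own schema).
val resolved = enc.resolveAndBind()
// Only a resolved encoder can produce a working deserializer.
val fromRow = resolved.createDeserializer()
// The deserializer turns an internal row back into the external type.
val p: Point = fromRow(InternalRow(1.0, 2.0))
assert(p == Point(1.0, 2.0))
```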
ScalaUDF.scala:

```diff
@@ -21,10 +21,9 @@ import org.apache.spark.SparkException
 import org.apache.spark.sql.catalyst.CatalystTypeConverters.{createToCatalystConverter, createToScalaConverter, isPrimitive}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
-import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder.Deserializer
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.expressions.codegen.Block._
-import org.apache.spark.sql.types.{AbstractDataType, AnyDataType, DataType}
+import org.apache.spark.sql.types.{AbstractDataType, AnyDataType, DataType, UserDefinedType}
 
 /**
  * User-defined function.
```
```diff
@@ -41,9 +40,6 @@ import org.apache.spark.sql.types.{AbstractDataType, AnyDataType, DataType}
  * @param nullable  True if the UDF can return null value.
  * @param udfDeterministic  True if the UDF is deterministic. Deterministic UDF returns same result
  *                          each time it is invoked with a particular input.
- * @param inputDeserializers  deserializers used to convert internal inputs from children
- *                            to external data types, and they will only be instantiated
- *                            by `PrepareInputDeserializerForUDF`.
  */
 case class ScalaUDF(
     function: AnyRef,
```
```diff
@@ -52,19 +48,11 @@ case class ScalaUDF(
     inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Nil,
     udfName: Option[String] = None,
     nullable: Boolean = true,
-    udfDeterministic: Boolean = true,
-    inputDeserializers: Seq[Option[Deserializer[_]]] = Nil)
+    udfDeterministic: Boolean = true)
   extends Expression with NonSQLExpression with UserDefinedExpression {
 
   override lazy val deterministic: Boolean = udfDeterministic && children.forall(_.deterministic)
 
-  override lazy val canonicalized: Expression = {
-    val canonicalizedChildren = children.map(_.canonicalized)
-    // if the canonicalized children and inputEncoders are equal,
-    // then we must have equal inputDeserializers as well.
-    this.copy(inputDeserializers = Nil).withNewChildren(canonicalizedChildren)
-  }
-
   override def toString: String = s"${udfName.getOrElse("UDF")}(${children.mkString(", ")})"
 
   /**
```
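Per the removed comment, the `canonicalized` override existed only to strip `inputDeserializers` before comparison; with that field gone, the default canonicalization over the remaining fields, including `inputEncoders`, suffices, so the override can be dropped outright.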
```diff
@@ -115,27 +103,25 @@ case class ScalaUDF(
   }
 
   private def scalaConverter(i: Int, dataType: DataType): Any => Any = {
-    if (inputEncoders.isEmpty) {
-      // for untyped Scala UDF
+    if (inputEncoders.isEmpty || // for untyped Scala UDF
+        dataType.isInstanceOf[UserDefinedType[_]]) {
       createToScalaConverter(dataType)
-    } else if (isPrimitive(dataType)) {
-      identity
     } else {
-      val encoder = inputEncoders(i)
-      encoder match {
-        case Some(enc) =>
-          val fromRow = inputDeserializers(i).get
-          if (enc.isSerializedAsStructForTopLevel) {
-            row: Any => fromRow(row.asInstanceOf[InternalRow])
-          } else {
-            value: Any =>
-              val row = new GenericInternalRow(1)
-              row.update(0, value)
-              fromRow(row)
-          }
-        // e.g. for UDF types
-        case _ => createToScalaConverter(dataType)
+      val encoderOpt = inputEncoders(i)
+      assert(encoderOpt.isDefined, s"ScalaUDF expects an encoder of ${i}th child, but got empty.")
+      val enc = encoderOpt.get
+      if (isPrimitive(dataType) && !enc.schema.head.nullable) {
+        createToScalaConverter(dataType)
+      } else {
+        val fromRow = enc.createDeserializer()
+        if (enc.isSerializedAsStructForTopLevel) {
+          row: Any => fromRow(row.asInstanceOf[InternalRow])
+        } else {
+          value: Any =>
+            val row = new GenericInternalRow(1)
+            row.update(0, value)
+            fromRow(row)
+        }
+      }
     }
   }
```
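To make the two deserialization branches above concrete, a hedged sketch using the same internal catalyst APIs (types chosen for illustration): a case class encoder is serialized as a struct at the top level, so the value arriving at the UDF is already an `InternalRow`; other types, such as `Seq[Int]`, must first be wrapped into a single-field row, which is what the `value: Any => ...` branch of `scalaConverter` does.

```scala
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
import org.apache.spark.sql.catalyst.util.ArrayData

case class Point(x: Double, y: Double) // hypothetical

// Top-level struct: the catalyst value is already an InternalRow.
val structEnc = ExpressionEncoder[Point]().resolveAndBind()
assert(structEnc.isSerializedAsStructForTopLevel)
val point = structEnc.createDeserializer()(InternalRow(1.0, 2.0))

// Non-struct top level: wrap the raw value in a one-field row first.
val seqEnc = ExpressionEncoder[Seq[Int]]().resolveAndBind()
assert(!seqEnc.isSerializedAsStructForTopLevel)
val row = new GenericInternalRow(1)
row.update(0, ArrayData.toArrayData(Array(1, 2, 3)))
val xs: Seq[Int] = seqEnc.createDeserializer()(row)
```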