@@ -2871,20 +2871,15 @@ class Analyzer(
       case udf: ScalaUDF if udf.inputEncoders.nonEmpty =>
         val boundEncoders = udf.inputEncoders.zipWithIndex.map { case (encOpt, i) =>
           val dataType = udf.children(i).dataType
-          if (dataType.existsRecursively(_.isInstanceOf[UserDefinedType[_]])) {
-            // for UDT, we use `CatalystTypeConverters`
-            None
-          } else {
-            encOpt.map { enc =>
-              val attrs = if (enc.isSerializedAsStructForTopLevel) {
-                dataType.asInstanceOf[StructType].toAttributes
-              } else {
-                // the field name doesn't matter here, so we use
-                // a simple literal to avoid any overhead
-                new StructType().add("input", dataType).toAttributes
-              }
-              enc.resolveAndBind(attrs)
-            }
+          encOpt.map { enc =>
+            val attrs = if (enc.isSerializedAsStructForTopLevel) {
+              dataType.asInstanceOf[StructType].toAttributes
+            } else {
+              // the field name doesn't matter here, so we use
+              // a simple literal to avoid any overhead
+              new StructType().add("input", dataType).toAttributes
+            }
+            enc.resolveAndBind(attrs)
           }
         }
         udf.copy(inputEncoders = boundEncoders)
@@ -93,8 +93,7 @@ object Cast {
             toField.nullable)
       }

-    case (udt1: UserDefinedType[_], udt2: UserDefinedType[_]) if udt1.userClass == udt2.userClass =>
-      true
+    case (udt1: UserDefinedType[_], udt2: UserDefinedType[_]) if udt2.acceptsType(udt1) => true

     case _ => false
   }
@@ -157,6 +156,8 @@ object Cast
         resolvableNullability(f1.nullable, f2.nullable) && canUpCast(f1.dataType, f2.dataType)
       }

+    case (from: UserDefinedType[_], to: UserDefinedType[_]) if to.acceptsType(from) => true
+
     case _ => false
   }

@@ -810,8 +811,7 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
         castArray(from.asInstanceOf[ArrayType].elementType, array.elementType)
       case map: MapType => castMap(from.asInstanceOf[MapType], map)
       case struct: StructType => castStruct(from.asInstanceOf[StructType], struct)
-      case udt: UserDefinedType[_]
-        if udt.userClass == from.asInstanceOf[UserDefinedType[_]].userClass =>
+      case udt: UserDefinedType[_] if udt.acceptsType(from) =>
         identity[Any]
       case _: UserDefinedType[_] =>
         throw new SparkException(s"Cannot cast $from to $to.")
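
Note on the three Cast hunks above: each one replaces an exact userClass equality check with UserDefinedType.acceptsType, which (per the intent of this change) also treats a UDT over a subclass of the expected user class as castable. A minimal standalone sketch of that looser check, using made-up classes (Fruit, Apple, FakeUDT) rather than Spark's real UDT machinery:

// Standalone sketch (hypothetical classes, not Spark source) of why an
// acceptsType-style check is looser than requiring identical user classes.
class Fruit
class Apple extends Fruit // Apple is a subtype of Fruit

// Stand-in for a UserDefinedType: accept another UDT whose user class is
// assignable to ours (assumed to mirror what acceptsType does for UDTs).
final case class FakeUDT(userClass: Class[_]) {
  def acceptsType(other: FakeUDT): Boolean =
    userClass.isAssignableFrom(other.userClass)
}

object AcceptsTypeSketch extends App {
  val fruitUDT = FakeUDT(classOf[Fruit])
  val appleUDT = FakeUDT(classOf[Apple])

  println(fruitUDT.userClass == appleUDT.userClass) // false: the old equality check rejects the subtype
  println(fruitUDT.acceptsType(appleUDT))           // true: the new check accepts a UDT over a subclass
  println(appleUDT.acceptsType(fruitUDT))           // false: a supertype is still rejected
}

Under that assumption, casting a value of a subtype UDT to its base UDT becomes an identity conversion instead of falling through to the "Cannot cast" SparkException, which is what the doSubTypeUDF test below exercises.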
@@ -108,7 +108,6 @@ case class ScalaUDF(
    * - UDF which doesn't provide inputEncoders, e.g., untyped Scala UDF and Java UDF
    * - type which isn't supported by `ExpressionEncoder`, e.g., Any
    * - primitive types, in order to use `identity` for better performance
-   * - UserDefinedType which isn't fully supported by `ExpressionEncoder`
    * For other cases like case class, Option[T], we use `ExpressionEncoder` instead since
    * `CatalystTypeConverters` doesn't support these data types.
    *
@@ -121,8 +120,7 @@ case class ScalaUDF(
       val useEncoder =
         !(inputEncoders.isEmpty || // for untyped Scala UDF and Java UDF
           inputEncoders(i).isEmpty || // for types aren't supported by encoder, e.g. Any
-          inputPrimitives(i) || // for primitive types
-          dataType.existsRecursively(_.isInstanceOf[UserDefinedType[_]]))
+          inputPrimitives(i)) // for primitive types

       if (useEncoder) {
         val enc = inputEncoders(i).get
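
With the UDT bullet removed from the scaladoc and the UDT clause removed from the condition above, only missing encoders and primitive inputs still bypass ExpressionEncoder; UDT-typed inputs now take the encoder path like other non-primitive types. A hedged sketch of just that boolean decision, lifted out of ScalaUDF (the harness names and the Option[AnyRef] stand-in are illustrative, not Spark's API):

// Sketch of the updated useEncoder decision, pulled out of ScalaUDF purely for
// illustration; Option[AnyRef] stands in for Option[ExpressionEncoder[_]] here.
object UseEncoderSketch {
  def useEncoderFor(
      inputEncoders: Seq[Option[AnyRef]],
      inputPrimitives: Seq[Boolean],
      i: Int): Boolean =
    !(inputEncoders.isEmpty ||    // untyped Scala UDF or Java UDF: no encoders at all
      inputEncoders(i).isEmpty || // type not supported by ExpressionEncoder, e.g. Any
      inputPrimitives(i))         // primitives keep the identity fast path
  // The former fourth clause, dataType.existsRecursively(_.isInstanceOf[UserDefinedType[_]]),
  // is gone, so UDT-typed inputs are no longer forced onto CatalystTypeConverters.

  def main(args: Array[String]): Unit = {
    val encoders = Seq(Option(new AnyRef), Option(new AnyRef))
    val primitives = Seq(true, false)
    println(useEncoderFor(encoders, primitives, 0)) // false: primitive input keeps the identity path
    println(useEncoderFor(encoders, primitives, 1)) // true: non-primitive with an encoder, e.g. a UDT column
  }
}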
@@ -275,11 +275,11 @@ class UserDefinedTypeSuite extends QueryTest with SharedSparkSession with Parque

     // this worked already before the fix SPARK-19311:
     // return type of doUDF equals parameter type of doOtherUDF
-    sql("SELECT doOtherUDF(doUDF(41))")
+    checkAnswer(sql("SELECT doOtherUDF(doUDF(41))"), Row(41) :: Nil)

     // this one passes only with the fix SPARK-19311:
     // return type of doSubUDF is a subtype of the parameter type of doOtherUDF
-    sql("SELECT doOtherUDF(doSubTypeUDF(42))")
+    checkAnswer(sql("SELECT doOtherUDF(doSubTypeUDF(42))"), Row(42) :: Nil)
   }

   test("except on UDT") {