diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala index 3d10b084a8db..242a065a58e5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala @@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.types.{AbstractDataType, AnyDataType, DataType, UserDefinedType} +import org.apache.spark.util.Utils /** * User-defined function. @@ -1152,7 +1153,7 @@ case class ScalaUDF( private[this] val resultConverter = createToCatalystConverter(dataType) lazy val udfErrorMessage = { - val funcCls = function.getClass.getSimpleName + val funcCls = Utils.getSimpleName(function.getClass) val inputTypes = children.map(_.dataType.catalogString).mkString(", ") val outputType = dataType.catalogString s"Failed to execute user defined function($funcCls: ($inputTypes) => $outputType)" diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala index 5c1fe265c15d..8b7e9ecfe4e4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala @@ -19,7 +19,9 @@ package org.apache.spark.sql import java.math.BigDecimal +import org.apache.spark.SparkException import org.apache.spark.sql.api.java._ +import org.apache.spark.sql.catalyst.encoders.OuterScopes import org.apache.spark.sql.catalyst.plans.logical.Project import org.apache.spark.sql.execution.{QueryExecution, SimpleMode} import org.apache.spark.sql.execution.columnar.InMemoryRelation @@ -669,4 +671,30 @@ class UDFSuite extends QueryTest with SharedSparkSession { val df = Seq(Array(Some(TestData(50, "2")), None)).toDF("col") checkAnswer(df.select(myUdf(Column("col"))), Row(100) :: Nil) } + + object MalformedClassObject extends Serializable { + class MalformedNonPrimitiveFunction extends (String => Int) with Serializable { + override def apply(v1: String): Int = v1.toInt / 0 + } + + class MalformedPrimitiveFunction extends (Int => Int) with Serializable { + override def apply(v1: Int): Int = v1 / 0 + } + } + + test("SPARK-32238: Use Utils.getSimpleName to avoid hitting Malformed class name") { + OuterScopes.addOuterScope(MalformedClassObject) + val f1 = new MalformedClassObject.MalformedNonPrimitiveFunction() + val f2 = new MalformedClassObject.MalformedPrimitiveFunction() + + val e1 = intercept[SparkException] { + Seq("20").toDF("col").select(udf(f1).apply(Column("col"))).collect() + } + assert(e1.getMessage.contains("UDFSuite$MalformedClassObject$MalformedNonPrimitiveFunction")) + + val e2 = intercept[SparkException] { + Seq(20).toDF("col").select(udf(f2).apply(Column("col"))).collect() + } + assert(e2.getMessage.contains("UDFSuite$MalformedClassObject$MalformedPrimitiveFunction")) + } }