diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
index a924f10fb366..5e59eb33b4bd 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
@@ -419,11 +419,13 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
     }
   }
 
+  private def simpleClassName: String = Utils.getSimpleName(this.getClass)
+
   /**
    * Returns the name of this type of TreeNode. Defaults to the class name.
    * Note that we remove the "Exec" suffix for physical operators here.
    */
-  def nodeName: String = getClass.getSimpleName.replaceAll("Exec$", "")
+  def nodeName: String = simpleClassName.replaceAll("Exec$", "")
 
   /**
    * The arguments that should be included in the arg string. Defaults to the `productIterator`.
@@ -610,7 +612,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
   protected def jsonFields: List[JField] = {
     val fieldNames = getConstructorParameterNames(getClass)
     val fieldValues = productIterator.toSeq ++ otherCopyArgs
-    assert(fieldNames.length == fieldValues.length, s"${getClass.getSimpleName} fields: " +
+    assert(fieldNames.length == fieldValues.length, s"$simpleClassName fields: " +
       fieldNames.mkString(", ") + s", values: " +
       fieldValues.map(_.toString).mkString(", "))
     fieldNames.zip(fieldValues).map {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala
index e37cf8a8e217..883f673cc4e4 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala
@@ -32,7 +32,7 @@ import org.apache.spark.sql.catalyst.{AliasIdentifier, FunctionIdentifier, Inter
 import org.apache.spark.sql.catalyst.catalog._
 import org.apache.spark.sql.catalyst.dsl.expressions.DslString
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodegenFallback, ExprCode}
 import org.apache.spark.sql.catalyst.plans.{LeftOuter, NaturalJoin}
 import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, SubqueryAlias, Union}
 import org.apache.spark.sql.catalyst.plans.physical.{IdentityBroadcastMode, RoundRobinPartitioning, SinglePartition}
@@ -617,4 +617,39 @@ class TreeNodeSuite extends SparkFunSuite {
     val expected = Coalesce(Stream(Literal(1), Literal(3)))
     assert(result === expected)
   }
+
+  object MalformedClassObject extends Serializable {
+    // Backport notes: this class inline-expands TaggingExpression from Spark 3.1
+    case class MalformedNameExpression(child: Expression) extends UnaryExpression {
+      override def nullable: Boolean = child.nullable
+      override def dataType: DataType = child.dataType
+
+      override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode =
+        child.genCode(ctx)
+
+      override def eval(input: InternalRow): Any = child.eval(input)
+    }
+  }
+
+  test("SPARK-32999: TreeNode.nodeName should not throw malformed class name error") {
+    val testTriggersExpectedError = try {
+      classOf[MalformedClassObject.MalformedNameExpression].getSimpleName
+      false
+    } catch {
+      case ex: java.lang.InternalError if ex.getMessage.contains("Malformed class name") =>
+        true
+      case ex: Throwable => throw ex
+    }
+    // This test case only applies on older JDK versions (e.g. JDK8u), and doesn't trigger the
+    // issue on newer JDK versions (e.g. JDK11u).
+    assume(testTriggersExpectedError, "the test case didn't trigger malformed class name error")
+
+    val expr = MalformedClassObject.MalformedNameExpression(Literal(1))
+    try {
+      expr.nodeName
+    } catch {
+      case ex: java.lang.InternalError if ex.getMessage.contains("Malformed class name") =>
+        fail("TreeNode.nodeName should not throw malformed class name error")
+    }
+  }
 }
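
Note on the fix (not part of the patch): on JDK 8, java.lang.Class#getSimpleName can throw java.lang.InternalError("Malformed class name") for classes nested inside Scala objects, because the enclosing-class name parsing in that JDK does not handle the dollar-separated names the Scala compiler emits, which is why the test above declares MalformedNameExpression inside MalformedClassObject. Utils.getSimpleName works around this by catching the error and deriving a simple name from getName instead. The sketch below is only a minimal illustration of that fallback idea under assumed names (SimpleNameDemo, safeSimpleName); it is not Spark's actual implementation.

// Illustration only: assumed names, not Spark's Utils.getSimpleName source.
object SimpleNameDemo {

  // A class nested in a Scala object; in deeper nestings (as in the test suite
  // above), calling getSimpleName on its Class can throw
  // java.lang.InternalError("Malformed class name") on JDK 8.
  case class Inner(x: Int)

  // Try the JDK's getSimpleName first; if it throws, strip the package and
  // outer-class prefixes from the binary name (e.g. "pkg.SimpleNameDemo$Inner").
  def safeSimpleName(cls: Class[_]): String =
    try {
      cls.getSimpleName
    } catch {
      case _: InternalError =>
        val binaryName = cls.getName
        val withoutPackage = binaryName.substring(binaryName.lastIndexOf('.') + 1)
        withoutPackage.split('$').filter(_.nonEmpty).lastOption.getOrElse(withoutPackage)
    }

  def main(args: Array[String]): Unit = {
    // Prints "Inner" whether or not getSimpleName throws on the running JDK.
    println(safeSimpleName(classOf[Inner]))
  }
}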