Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,11 @@ object NestedColumnAliasing {
case _ => false
}

/**
 * Decides whether an extracted nested-field expression is eligible for aliasing.
 *
 * An expression qualifies only when both conditions hold:
 *  - no node in its tree is a `NamedLambdaVariable` (the expr id of an attribute coming
 *    from a lambda variable is not available, so it cannot be aliased), and
 *  - it refers to exactly one attribute.
 *
 * @param ev the expression to inspect
 * @return true when `ev` may be aliased, false otherwise
 */
private def canAlias(ev: Expression): Boolean = {
  val dependsOnLambdaVariable = ev.exists {
    case _: NamedLambdaVariable => true
    case _ => false
  }
  // The reference count is only consulted when no lambda variable is involved.
  if (dependsOnLambdaVariable) false else ev.references.size == 1
}

/**
* Returns two types of expressions:
* - Root references that are individually accessed
Expand All @@ -226,11 +231,11 @@ object NestedColumnAliasing {
*/
private def collectRootReferenceAndExtractValue(e: Expression): Seq[Expression] = e match {
case _: AttributeReference => Seq(e)
case GetStructField(_: ExtractValue | _: AttributeReference, _, _) => Seq(e)
case GetStructField(_: ExtractValue | _: AttributeReference, _, _) if canAlias(e) => Seq(e)
case GetArrayStructFields(_: MapValues |
_: MapKeys |
_: ExtractValue |
_: AttributeReference, _, _, _, _) => Seq(e)
_: AttributeReference, _, _, _, _) if canAlias(e) => Seq(e)
case es if es.children.nonEmpty => es.children.flatMap(collectRootReferenceAndExtractValue)
case _ => Seq.empty
}
Expand All @@ -249,13 +254,8 @@ object NestedColumnAliasing {
val otherRootReferences = new mutable.ArrayBuffer[AttributeReference]()
exprList.foreach { e =>
extractor(e).foreach {
// we can not alias the attr from lambda variable whose expr id is not available
case ev: ExtractValue if !ev.exists(_.isInstanceOf[NamedLambdaVariable]) =>
if (ev.references.size == 1) {
nestedFieldReferences.append(ev)
}
case ev: ExtractValue => nestedFieldReferences.append(ev)
case ar: AttributeReference => otherRootReferences.append(ar)
case _ => // ignore
}
}
val exclusiveAttrSet = AttributeSet(exclusiveAttrs ++ otherRootReferences)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -863,6 +863,27 @@ class NestedColumnAliasingSuite extends SchemaPruningTest {
// The plan is expected to be unchanged.
comparePlans(plan, RemoveNoopOperators.apply(optimized.get))
}

test("SPARK-48428: Do not pushdown when attr is used in expression with mutliple references") {
val query = contact
.limit(5)
.select(
GetStructField(GetStructField(CreateStruct(Seq($"id", $"employer")), 1), 0),
$"employer.id")
.analyze
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, I haven't touched this part of the code for a while. Can you give a detailed explanation of how the IllegalStateException can be thrown with this test case, to help reviewers understand the bug?

Copy link
Contributor Author

@eejbyfeldt eejbyfeldt Jun 25, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For this test case, with the code on master, collectRootReferenceAndExtractValue will collect two candidates for aliasing, $"employer.id" and GetStructField(GetStructField(CreateStruct(Seq($"id", $"employer")), 1), 0), and no root references. The second one gets filtered out due to this condition:
https://github.com/apache/spark/blob/v3.5.1/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasing.scala#L253

The logic after this will then consider only the access to $"employer.id" and will be unaware that $"employer" is also accessed inside the other expression. This causes it to perform the transformation/optimization incorrectly, which leads to the IllegalStateException.

With the fix, collectRootReferenceAndExtractValue will return $"employer.id" as a candidate and $"id" and $"employer" as root references. This means that when $"employer.id" is evaluated this time, it will be discarded due to this filter: https://github.com/apache/spark/blob/v3.5.1/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasing.scala#L264

Does that help understand the issue?


val optimized = Optimize.execute(query)

val expected = contact
.select($"id", $"employer")
.limit(5)
.select(
GetStructField(GetStructField(CreateStruct(Seq($"id", $"employer")), 1), 0),
$"employer.id")
.analyze

comparePlans(optimized, expected)
}
}

object NestedColumnAliasingSuite {
Expand Down