init
LuciferYang committed Oct 27, 2023
commit d2ffca1ae29a862c851a0cc0c2fad67a8881a210
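For context on the change set below: Scala 2.13 deprecates `Stream` in favour of `LazyList`, which is lazy in its head as well as its tail. A minimal sketch, assuming plain Scala 2.13 and no Spark dependencies, of the `Stream` -> `LazyList` correspondences this commit applies:

```scala
// Sketch of the API correspondences used throughout this commit (plain Scala 2.13).
object StreamToLazyListSketch extends App {
  val xs = Seq(1, 2, 3)

  // Stream(...)          -> LazyList(...)
  val literal: LazyList[Int] = LazyList(1, 2, 3)

  // seq.toStream         -> seq.to(LazyList)
  val fromSeq: LazyList[Int] = xs.to(LazyList)

  // Stream.empty         -> LazyList.empty
  val empty: LazyList[Int] = LazyList.empty

  // stream.append(more)  -> lazyList.lazyAppendedAll(more); the suffix is by-name
  val appended: LazyList[Int] = fromSeq.lazyAppendedAll(LazyList(4, 5))

  // `.force` still materializes every element of the lazy collection
  println(appended.map(_ * 10).force.toList) // List(10, 20, 30, 40, 50)
}
```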
@@ -67,7 +67,7 @@ trait AliasAwareOutputExpression extends SQLConfHelper {
/**
* Return a stream of expressions in which the original expression is projected with `aliasMap`.
*/
- protected def projectExpression(expr: Expression): Stream[Expression] = {
+ protected def projectExpression(expr: Expression): LazyList[Expression] = {
val outputSet = AttributeSet(outputExpressions.map(_.toAttribute))
expr.multiTransformDown {
// Mapping with aliases
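The `AliasAwareOutputExpression` change above only touches the return type: because `projectExpression` now hands back a `LazyList`, callers can bound how many projected alternatives are ever materialized. A standalone illustration of that caller-side bounding, using hypothetical names and no Catalyst types:

```scala
// Standalone sketch: a lazily generated list of alternatives from which the caller
// takes only what it needs (names are illustrative, not Catalyst code).
object BoundedAlternativesSketch extends App {
  // Pretend each alternative is expensive to produce; LazyList defers each one.
  def alternatives(base: String): LazyList[String] =
    LazyList.from(1).map { i =>
      println(s"producing alternative $i")
      s"$base#$i"
    }

  // Only the first two alternatives are ever produced, no matter how many exist.
  val firstTwo = alternatives("a AS x").take(2).toList
  println(firstTwo) // List(a AS x#1, a AS x#2)
}
```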
@@ -356,8 +356,8 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]]
val newArgs = mapProductIterator {
case s: StructType => s // Don't convert struct types to some other type of Seq[StructField]
// Handle Seq[TreeNode] in TreeNode parameters.
- case s: Stream[_] =>
-   // Stream is lazy so we need to force materialization
+ case s: LazyList[_] =>
+   // LazyList is lazy so we need to force materialization
s.map(mapChild).force
case s: Seq[_] =>
s.map(mapChild)
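A note on the `.force` in the hunk above: mapping a `LazyList` is deferred until its elements are demanded, so a tree-copying helper has to force the result to make sure `mapChild` runs for every child right away. A minimal sketch with illustrative names, not the actual `TreeNode` internals:

```scala
// Why forcing matters: without `.force`, the mapping function has not run at all.
object ForceMaterializationSketch extends App {
  var applied = 0
  def mapChild(x: Int): Int = { applied += 1; x + 1 }

  val children = LazyList(1, 2, 3)

  val deferred = children.map(mapChild)
  println(applied)             // 0: nothing has been evaluated yet

  val materialized = children.map(mapChild).force
  println(applied)             // 3: every child transformed eagerly
  println(materialized.toList) // List(2, 3, 4)
}
```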
@@ -557,7 +557,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]]
* @return the stream of alternatives
*/
def multiTransformDown(
- rule: PartialFunction[BaseType, Seq[BaseType]]): Stream[BaseType] = {
+ rule: PartialFunction[BaseType, Seq[BaseType]]): LazyList[BaseType] = {
multiTransformDownWithPruning(AlwaysProcess.fn, UnknownRuleId)(rule)
}

@@ -567,19 +567,19 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]]
*
* As it is very easy to generate enormous number of alternatives when the input tree is huge or
* when the rule returns many alternatives for many nodes, this function returns the alternatives
- * as a lazy `Stream` to be able to limit the number of alternatives generated at the caller side
- * as needed.
+ * as a lazy `LazyList` to be able to limit the number of alternatives generated at the caller
+ * side as needed.
*
* The purpose of this function to access the returned alternatives by the rule only if they are
- * needed so the rule can return a `Stream` whose elements are also lazily calculated.
+ * needed so the rule can return a `LazyList` whose elements are also lazily calculated.
* E.g. `multiTransform*` calls can be nested with the help of
* `MultiTransform.generateCartesianProduct()`.
*
* The rule should not apply or can return a one element `Seq` of original node to indicate that
* the original node without any transformation is a valid alternative.
*
* The rule can return `Seq.empty` to indicate that the original node should be pruned. In this
- * case `multiTransform()` returns an empty `Stream`.
+ * case `multiTransform()` returns an empty `LazyList`.
*
* Please consider the following examples of `input.multiTransformDown(rule)`:
*
@@ -593,7 +593,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]]
* `Add(a, b)` => `Seq(11, 12, 21, 22)`
*
* The output is:
- * `Stream(11, 12, 21, 22)`
+ * `LazyList(11, 12, 21, 22)`
*
* 2.
* In the previous example if we want to generate alternatives of `a` and `b` too then we need to
@@ -603,7 +603,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]]
* `Add(a, b)` => `Seq(11, 12, 21, 22, Add(a, b))`
*
* The output is:
- * `Stream(11, 12, 21, 22, Add(1, 10), Add(2, 10), Add(1, 20), Add(2, 20))`
+ * `LazyList(11, 12, 21, 22, Add(1, 10), Add(2, 10), Add(1, 20), Add(2, 20))`
*
* @param rule a function used to generate alternatives for a node
* @param cond a Lambda expression to prune tree traversals. If `cond.apply` returns false
@@ -619,15 +619,15 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]]
def multiTransformDownWithPruning(
cond: TreePatternBits => Boolean,
ruleId: RuleId = UnknownRuleId
- )(rule: PartialFunction[BaseType, Seq[BaseType]]): Stream[BaseType] = {
+ )(rule: PartialFunction[BaseType, Seq[BaseType]]): LazyList[BaseType] = {
if (!cond.apply(this) || isRuleIneffective(ruleId)) {
-   return Stream(this)
+   return LazyList(this)
}

// We could return `Seq(this)` if the `rule` doesn't apply and handle both
// - the doesn't apply
// - and the rule returns a one element `Seq(originalNode)`
- // cases together. The returned `Seq` can be a `Stream` and unfortunately it doesn't seem like
+ // cases together. The returned `Seq` can be a `LazyList` and unfortunately it doesn't seem like
// there is a way to match on a one element stream without eagerly computing the tail's head.
// This contradicts with the purpose of only taking the necessary elements from the
// alternatives. I.e. the "multiTransformDown is lazy" test case in `TreeNodeSuite` would fail.
@@ -641,18 +641,18 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]]
})
}

- val afterRulesStream = if (afterRules.isEmpty) {
+ val afterRulesLazyList = if (afterRules.isEmpty) {
if (ruleApplied) {
// If the rule returned with empty alternatives then prune
-     Stream.empty
+     LazyList.empty
} else {
// If the rule was not applied then keep the original node
this.markRuleAsIneffective(ruleId)
-     Stream(this)
+     LazyList(this)
}
} else {
// If the rule was applied then use the returned alternatives
-   afterRules.toStream.map { afterRule =>
+   afterRules.to(LazyList).map { afterRule =>
if (this fastEquals afterRule) {
this
} else {
@@ -662,13 +662,13 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]]
}
}

- afterRulesStream.flatMap { afterRule =>
+ afterRulesLazyList.flatMap { afterRule =>
if (afterRule.containsChild.nonEmpty) {
MultiTransform.generateCartesianProduct(
afterRule.children.map(c => () => c.multiTransformDownWithPruning(cond, ruleId)(rule)))
.map(afterRule.withNewChildren)
} else {
-     Stream(afterRule)
+     LazyList(afterRule)
}
}
}
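On the comment inside `multiTransformDownWithPruning` about matching a one-element stream: checking whether a `LazyList` has exactly one element forces the tail's state, which is precisely what the lazily produced alternatives must avoid. A small standalone sketch, assuming plain Scala 2.13:

```scala
// Deciding "is there exactly one element?" evaluates the lazily appended suffix.
object OneElementCheckSketch extends App {
  val alternatives: LazyList[Int] =
    LazyList(1).lazyAppendedAll {
      println("tail evaluated!") // the TreeNodeSuite helper throws here instead
      LazyList(2)
    }

  println(alternatives.head)                 // 1, suffix untouched
  val exactlyOne = alternatives.tail.isEmpty // prints "tail evaluated!" while deciding
  println(exactlyOne)                        // false
}
```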
@@ -792,7 +792,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]]
case other => other
}.view.force.toMap // `mapValues` is lazy and we need to force it to materialize
case d: DataType => d // Avoid unpacking Structs
- case args: Stream[_] => args.map(mapChild).force // Force materialization on stream
+ case args: LazyList[_] => args.map(mapChild).force // Force materialization on stream
case args: Iterable[_] => args.map(mapChild)
case nonChild: AnyRef => nonChild
case null => null
@@ -1321,8 +1321,8 @@ object MultiTransform {
* @param elementSeqs a list of sequences to build the cartesian product from
* @return the stream of generated `Seq` elements
*/
- def generateCartesianProduct[T](elementSeqs: Seq[() => Seq[T]]): Stream[Seq[T]] = {
-   elementSeqs.foldRight(Stream(Seq.empty[T]))((elements, elementTails) =>
+ def generateCartesianProduct[T](elementSeqs: Seq[() => Seq[T]]): LazyList[Seq[T]] = {
+   elementSeqs.foldRight(LazyList(Seq.empty[T]))((elements, elementTails) =>
for {
elementTail <- elementTails
element <- elements()
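The tail of `generateCartesianProduct` is collapsed in this diff view, so the sketch below is a hedged approximation rather than a copy (the `yield` body is assumed); it shows how folding the generator thunks into a `LazyList` keeps the cartesian product lazy:

```scala
// Hedged, standalone approximation of the helper above; the yield body is assumed.
object CartesianProductSketch extends App {
  def generateCartesianProduct[T](elementSeqs: Seq[() => Seq[T]]): LazyList[Seq[T]] =
    elementSeqs.foldRight(LazyList(Seq.empty[T])) { (elements, elementTails) =>
      for {
        elementTail <- elementTails
        element <- elements()
      } yield element +: elementTail
    }

  // The generators are thunks, so no combination is built until it is requested.
  val combos = generateCartesianProduct(Seq(() => Seq(1, 2), () => Seq(10, 20)))
  println(combos.take(2).toList) // List(List(1, 10), List(2, 10))
  println(combos.size)           // 4
}
```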
@@ -949,9 +949,9 @@ class TreeNodeSuite extends SparkFunSuite with SQLHelper {
}
}

- private def newErrorAfterStream(es: Expression*) = {
-   es.toStream.append(
-     throw new NoSuchElementException("Stream should not return more elements")
+ private def newErrorAfterLazyList(es: Expression*) = {
+   es.to(LazyList).lazyAppendedAll(
+     throw new NoSuchElementException("LazyList should not return more elements")
)
}
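What the helper above relies on: `lazyAppendedAll` takes its suffix by name, so the exception is only thrown if a caller iterates past the listed elements, which is how the suite detects an insufficiently lazy `multiTransformDown`. A standalone sketch of the same shape, using `Int` instead of `Expression`:

```scala
// The throwing suffix stays unevaluated as long as only the known elements are consumed.
object ErrorAfterLazyListSketch extends App {
  def errorAfter[A](as: A*): LazyList[A] =
    as.to(LazyList).lazyAppendedAll(
      throw new NoSuchElementException("LazyList should not return more elements"))

  val ll = errorAfter(1, 2)
  println(ll.take(2).toList) // List(1, 2): the suffix is never reached
  // ll.toList               // would throw: iteration went past the known elements
}
```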

@@ -975,8 +975,8 @@ class TreeNodeSuite extends SparkFunSuite with SQLHelper {
val e = Add(Add(Literal("a"), Literal("b")), Add(Literal("c"), Literal("d")))
val transformed = e.multiTransformDown {
case StringLiteral("a") => Seq(Literal(1), Literal(2), Literal(3))
case StringLiteral("b") => newErrorAfterStream(Literal(10))
case Add(StringLiteral("c"), StringLiteral("d"), _) => newErrorAfterStream(Literal(100))
case StringLiteral("b") => newErrorAfterLazyList(Literal(10))
case Add(StringLiteral("c"), StringLiteral("d"), _) => newErrorAfterLazyList(Literal(100))
}
val expected = for {
a <- Seq(Literal(1), Literal(2), Literal(3))
@@ -990,7 +990,7 @@ class TreeNodeSuite extends SparkFunSuite with SQLHelper {
val transformed2 = e.multiTransformDown {
case StringLiteral("a") => Seq(Literal(1), Literal(2), Literal(3))
case StringLiteral("b") => Seq(Literal(10), Literal(20), Literal(30))
case Add(StringLiteral("c"), StringLiteral("d"), _) => newErrorAfterStream(Literal(100))
case Add(StringLiteral("c"), StringLiteral("d"), _) => newErrorAfterLazyList(Literal(100))
}
val expected2 = for {
b <- Seq(Literal(10), Literal(20), Literal(30))
@@ -1055,7 +1055,7 @@ class TreeNodeSuite extends SparkFunSuite with SQLHelper {
test("multiTransformDown alternatives are generated only if needed") {
val e = Add(Add(Literal("a"), Literal("b")), Add(Literal("c"), Literal("d")))
val transformed = e.multiTransformDown {
case StringLiteral("a") => newErrorAfterStream()
case StringLiteral("a") => newErrorAfterLazyList()
case StringLiteral("b") => Seq.empty
}
assert(transformed.isEmpty)