Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -172,8 +172,8 @@ class Analyzer(
* expressions which equal GroupBy expressions with Literal(null), if those expressions
* are not set for this grouping set (according to the bit mask).
*/
private[this] def expand(g: GroupingSets): Seq[GroupExpression] = {
val result = new scala.collection.mutable.ArrayBuffer[GroupExpression]
private[this] def expand(g: GroupingSets): Seq[Seq[Expression]] = {
val result = new scala.collection.mutable.ArrayBuffer[Seq[Expression]]

g.bitmasks.foreach { bitmask =>
// get the non selected grouping attributes according to the bit mask
Expand All @@ -194,7 +194,7 @@ class Analyzer(
Literal.create(bitmask, IntegerType)
})

result += GroupExpression(substitution)
result += substitution
}

result.toSeq
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -230,18 +230,6 @@ abstract class UnaryExpression extends Expression with trees.UnaryNode[Expressio
}
}

// TODO Semantically we probably not need GroupExpression
// All we need is holding the Seq[Expression], and ONLY used in doing the
// expressions transformation correctly. Probably will be removed since it's
// not like a real expressions.
case class GroupExpression(children: Seq[Expression]) extends Expression {
self: Product =>
override def eval(input: Row): Any = throw new UnsupportedOperationException
override def nullable: Boolean = false
override def foldable: Boolean = false
override def dataType: DataType = throw new UnsupportedOperationException
}

/**
* Expressions that require a specific `DataType` as input should implement this trait
* so that the proper type conversions can be performed in the analyzer.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,17 +81,16 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy
}
}

val newArgs = productIterator.map {
def recursiveTransform(arg: Any): AnyRef = arg match {
case e: Expression => transformExpressionDown(e)
case Some(e: Expression) => Some(transformExpressionDown(e))
case m: Map[_, _] => m
case d: DataType => d // Avoid unpacking Structs
case seq: Traversable[_] => seq.map {
case e: Expression => transformExpressionDown(e)
case other => other
}
case seq: Traversable[_] => seq.map(recursiveTransform)
case other: AnyRef => other
}.toArray
}

val newArgs = productIterator.map(recursiveTransform).toArray

if (changed) makeCopy(newArgs) else this
}
Expand All @@ -114,17 +113,16 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy
}
}

val newArgs = productIterator.map {
def recursiveTransform(arg: Any): AnyRef = arg match {
case e: Expression => transformExpressionUp(e)
case Some(e: Expression) => Some(transformExpressionUp(e))
case m: Map[_, _] => m
case d: DataType => d // Avoid unpacking Structs
case seq: Traversable[_] => seq.map {
case e: Expression => transformExpressionUp(e)
case other => other
}
case seq: Traversable[_] => seq.map(recursiveTransform)
case other: AnyRef => other
}.toArray
}

val newArgs = productIterator.map(recursiveTransform).toArray

if (changed) makeCopy(newArgs) else this
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ case class Window(
* @param child Child operator
*/
case class Expand(
projections: Seq[GroupExpression],
projections: Seq[Seq[Expression]],
output: Seq[Attribute],
child: LogicalPlan) extends UnaryNode {
override def statistics: Statistics = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,11 @@ case class Dummy(optKey: Option[Expression]) extends Expression {
override def eval(input: Row): Any = null.asInstanceOf[Any]
}

case class ComplexPlan(exprs: Seq[Seq[Expression]])
extends org.apache.spark.sql.catalyst.plans.logical.LeafNode {
override def output: Seq[Attribute] = Nil
}

class TreeNodeSuite extends SparkFunSuite {
test("top node changed") {
val after = Literal(1) transform { case Literal(1, _) => Literal(2) }
Expand Down Expand Up @@ -220,4 +225,13 @@ class TreeNodeSuite extends SparkFunSuite {
assert(expected === actual)
}
}

test("transformExpressions on nested expression sequence") {
val plan = ComplexPlan(Seq(Seq(Literal(1)), Seq(Literal(2))))
val actual = plan.transformExpressions {
case Literal(value, _) => Literal(value.toString)
}
val expected = ComplexPlan(Seq(Seq(Literal("1")), Seq(Literal("2"))))
assert(expected === actual)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ import org.apache.spark.sql.catalyst.plans.physical.{UnknownPartitioning, Partit
*/
@DeveloperApi
case class Expand(
projections: Seq[GroupExpression],
projections: Seq[Seq[Expression]],
output: Seq[Attribute],
child: SparkPlan)
extends UnaryNode {
Expand All @@ -49,7 +49,7 @@ case class Expand(
// workers via closure. However we can't assume the Projection
// is serializable because of the code gen, so we have to
// create the projections within each of the partition processing.
val groups = projections.map(ee => newProjection(ee.children, child.output)).toArray
val groups = projections.map(ee => newProjection(ee, child.output)).toArray

new Iterator[Row] {
private[this] var result: Row = _
Expand Down