Analyzer.scala:

```diff
@@ -65,9 +65,8 @@ class Analyzer(
 
   lazy val batches: Seq[Batch] = Seq(
     Batch("Substitution", fixedPoint,
-      CTESubstitution ::
-      WindowsSubstitution ::
-      Nil : _*),
+      CTESubstitution,
+      WindowsSubstitution),
     Batch("Resolution", fixedPoint,
       ResolveRelations ::
       ResolveReferences ::
```
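A side note on this first hunk: `Batch` takes its rules as varargs, so the old `rule :: rule :: Nil : _*` idiom (build a `List`, then expand it) and the new direct-argument form are equivalent. A minimal sketch of the two call styles, using a hypothetical `sum` function rather than the actual `Batch` constructor:

```scala
def sum(xs: Int*): Int = xs.sum

sum(1, 2)               // direct varargs, as in the new code
sum(1 :: 2 :: Nil: _*)  // cons up a List and expand it with : _*, as in the old code
```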
```diff
@@ -84,7 +83,8 @@ class Analyzer(
       HiveTypeCoercion.typeCoercionRules ++
       extendedResolutionRules : _*),
     Batch("Nondeterministic", Once,
-      PullOutNondeterministic),
+      PullOutNondeterministic,
+      ComputeCurrentTime),
     Batch("UDF", Once,
       HandleNullInputsForUDF),
     Batch("Cleanup", fixedPoint,
@@ -1076,7 +1076,7 @@ class Analyzer(
     override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
       case p if !p.resolved => p // Skip unresolved nodes.
 
-      case plan => plan transformExpressionsUp {
+      case p => p transformExpressionsUp {
 
         case udf @ ScalaUDF(func, _, inputs, _) =>
          val parameterTypes = ScalaReflection.getParameterTypes(func)
@@ -1162,3 +1162,20 @@ object CleanupAliases extends Rule[LogicalPlan] {
     }
   }
 }
+
+/**
+ * Computes the current date and time to make sure we return the same result in a single query.
+ */
+object ComputeCurrentTime extends Rule[LogicalPlan] {
+  def apply(plan: LogicalPlan): LogicalPlan = {
+    val dateExpr = CurrentDate()
+    val timeExpr = CurrentTimestamp()
+    val currentDate = Literal.create(dateExpr.eval(EmptyRow), dateExpr.dataType)
+    val currentTime = Literal.create(timeExpr.eval(EmptyRow), timeExpr.dataType)
+
+    plan transformAllExpressions {
+      case CurrentDate() => currentDate
+      case CurrentTimestamp() => currentTime
+    }
+  }
+}
```
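For illustration, a minimal sketch of the new rule in action, assuming Spark's catalyst module is on the classpath (the plan construction mirrors the test below):

```scala
import org.apache.spark.sql.catalyst.analysis.ComputeCurrentTime
import org.apache.spark.sql.catalyst.expressions.{Alias, CurrentTimestamp}
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project}

// A projection that calls current_timestamp() twice.
val plan = Project(
  Seq(Alias(CurrentTimestamp(), "a")(), Alias(CurrentTimestamp(), "b")()),
  LocalRelation())

// The rule evaluates the clock once and rewrites both calls to that same
// Literal, so the two columns are guaranteed to agree within the query.
val rewritten = ComputeCurrentTime(plan)
```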
AnalysisSuite.scala:

```diff
@@ -22,6 +22,7 @@ import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.dsl.plans._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.types._
 
 class AnalysisSuite extends AnalysisTest {
@@ -218,4 +219,41 @@ class AnalysisSuite extends AnalysisTest {
       udf4)
     // checkUDF(udf4, expected4)
   }
+
+  test("analyzer should replace current_timestamp with literals") {
+    val in = Project(Seq(Alias(CurrentTimestamp(), "a")(), Alias(CurrentTimestamp(), "b")()),
+      LocalRelation())
+
+    val min = System.currentTimeMillis() * 1000
+    val plan = in.analyze.asInstanceOf[Project]
+    val max = (System.currentTimeMillis() + 1) * 1000
+
+    val lits = new scala.collection.mutable.ArrayBuffer[Long]
+    plan.transformAllExpressions { case e: Literal =>
+      lits += e.value.asInstanceOf[Long]
+      e
+    }
+    assert(lits.size == 2)
+    assert(lits(0) >= min && lits(0) <= max)
+    assert(lits(1) >= min && lits(1) <= max)
```
**Review comment (Contributor)**, on the assertions above: should we also assert that `lits(0) == lits(1)`?
```diff
+  }
+
+  test("analyzer should replace current_date with literals") {
+    val in = Project(Seq(Alias(CurrentDate(), "a")(), Alias(CurrentDate(), "b")()), LocalRelation())
+
+    val min = DateTimeUtils.millisToDays(System.currentTimeMillis())
+    val plan = in.analyze.asInstanceOf[Project]
+    val max = DateTimeUtils.millisToDays(System.currentTimeMillis())
+
+    val lits = new scala.collection.mutable.ArrayBuffer[Int]
+    plan.transformAllExpressions { case e: Literal =>
+      lits += e.value.asInstanceOf[Int]
+      e
+    }
+    assert(lits.size == 2)
+    assert(lits(0) >= min && lits(0) <= max)
+    assert(lits(1) >= min && lits(1) <= max)
+    assert(lits(0) == lits(1))
+  }
 }
```
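The final `assert(lits(0) == lits(1))` is the property the rule exists to guarantee (and what the review comment above asks for in the first test): without the rewrite, every evaluation of `CurrentTimestamp()` reads the system clock again, so two calls in one query can disagree. A small sketch of that unpatched behavior, assuming the same catalyst classpath as the suite:

```scala
import org.apache.spark.sql.catalyst.expressions.{CurrentTimestamp, EmptyRow}

val ts = CurrentTimestamp()
val first = ts.eval(EmptyRow).asInstanceOf[Long]  // microseconds since the epoch
Thread.sleep(2)                                   // let the wall clock advance
val second = ts.eval(EmptyRow).asInstanceOf[Long]

// Two direct evaluations observe different clock readings; after
// ComputeCurrentTime, both would share one pre-computed Literal.
assert(first != second)
```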