Skip to content
Prev Previous commit
use expression dsl
  • Loading branch information
peter-toth committed Sep 20, 2023
commit 580f97b0e458a87509ccf6882075cf7330062d54
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,13 @@

package org.apache.spark.sql.streaming

import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, EqualTo, GreaterThan, LessThan}
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Literal}
import org.apache.spark.sql.execution.LocalTableScanExec
import org.apache.spark.sql.execution.streaming.StreamingSymmetricHashJoinHelper.JoinConditionSplitPredicates
import org.apache.spark.sql.types._

class StreamingSymmetricHashJoinHelperSuite extends StreamTest {
import org.apache.spark.sql.functions._

val leftAttributeA = AttributeReference("a", IntegerType)()
val leftAttributeB = AttributeReference("b", IntegerType)()
val rightAttributeC = AttributeReference("c", IntegerType)()
Expand All @@ -44,12 +43,7 @@ class StreamingSymmetricHashJoinHelperSuite extends StreamTest {
test("only literals") {
// Literal-only conjuncts end up on the left side because that's the first bucket they fit in.
// There's no semantic reason they couldn't be in any bucket.
val predicate =
And(
And(
LessThan(lit(1).expr, lit(5).expr),
LessThan(lit(6).expr, lit(7).expr)),
EqualTo(lit(0).expr, lit(-1).expr))
val predicate = Literal(1) < Literal(5) && Literal(6) < Literal(7) && Literal(0) === Literal(-1)
val split = JoinConditionSplitPredicates(Some(predicate), left, right)

assert(split.leftSideOnly.contains(predicate))
Expand All @@ -60,11 +54,7 @@ class StreamingSymmetricHashJoinHelperSuite extends StreamTest {

test("only left") {
val predicate =
And(
And(
GreaterThan(leftAttributeA, lit(1).expr),
GreaterThan(leftAttributeB, lit(5).expr)),
LessThan(leftAttributeA, leftAttributeB))
leftAttributeA > Literal(1) && leftAttributeB > Literal(5) && leftAttributeA < leftAttributeB
val split = JoinConditionSplitPredicates(Some(predicate), left, right)

assert(split.leftSideOnly.contains(predicate))
Expand All @@ -74,12 +64,8 @@ class StreamingSymmetricHashJoinHelperSuite extends StreamTest {
}

test("only right") {
val predicate =
And(
And(
GreaterThan(rightAttributeC, lit(1).expr),
GreaterThan(rightAttributeD, lit(5).expr)),
LessThan(rightAttributeD, rightAttributeC))
val predicate = rightAttributeC > Literal(1) && rightAttributeD > Literal(5) &&
rightAttributeD < rightAttributeC
val split = JoinConditionSplitPredicates(Some(predicate), left, right)

assert(split.leftSideOnly.isEmpty)
Expand All @@ -90,66 +76,55 @@ class StreamingSymmetricHashJoinHelperSuite extends StreamTest {

test("mixed conjuncts") {
val predicate =
And(
And(
And(
GreaterThan(leftAttributeA, leftAttributeB),
GreaterThan(rightAttributeC, rightAttributeD)),
EqualTo(leftAttributeA, rightAttributeC)),
EqualTo(lit(1).expr, lit(1).expr))
(leftAttributeA > leftAttributeB
&& rightAttributeC > rightAttributeD
&& leftAttributeA === rightAttributeC
&& Literal(1) === Literal(1))
val split = JoinConditionSplitPredicates(Some(predicate), left, right)

assert(split.leftSideOnly.contains(
And(GreaterThan(leftAttributeA, leftAttributeB), EqualTo(lit(1).expr, lit(1).expr))))
leftAttributeA > leftAttributeB && Literal(1) === Literal(1)))
assert(split.rightSideOnly.contains(
And(GreaterThan(rightAttributeC, rightAttributeD), EqualTo(lit(1).expr, lit(1).expr))))
assert(split.bothSides.contains(EqualTo(leftAttributeA, rightAttributeC)))
rightAttributeC > rightAttributeD && Literal(1) === Literal(1)))
assert(split.bothSides.contains((leftAttributeA === rightAttributeC)))
assert(split.full.contains(predicate))
}

test("conjuncts after nondeterministic") {
val predicate =
And(
And(
And(
And(
GreaterThan(rand(9).expr, lit(0).expr),
GreaterThan(leftAttributeA, leftAttributeB)),
GreaterThan(rightAttributeC, rightAttributeD)),
EqualTo(leftAttributeA, rightAttributeC)),
EqualTo(lit(1).expr, lit(1).expr))
(rand(9) > Literal(0)
&& leftAttributeA > leftAttributeB
&& rightAttributeC > rightAttributeD
&& leftAttributeA === rightAttributeC
&& Literal(1) === Literal(1))
val split = JoinConditionSplitPredicates(Some(predicate), left, right)

assert(split.leftSideOnly.contains(
And(GreaterThan(leftAttributeA, leftAttributeB), EqualTo(lit(1).expr, lit(1).expr))))
leftAttributeA > leftAttributeB && Literal(1) === Literal(1)))
assert(split.rightSideOnly.contains(
And(GreaterThan(rightAttributeC, rightAttributeD), EqualTo(lit(1).expr, lit(1).expr))))
rightAttributeC > rightAttributeD && Literal(1) === Literal(1)))
assert(split.bothSides.contains(
And(EqualTo(leftAttributeA, rightAttributeC), GreaterThan(rand(9).expr, lit(0).expr))))
leftAttributeA === rightAttributeC && rand(9).expr > Literal(0)))
assert(split.full.contains(predicate))
}


test("conjuncts before nondeterministic") {
val randCol = rand()
val randAttribute = rand(0)
val predicate =
And(
And(
And(
And(
GreaterThan(leftAttributeA, leftAttributeB),
GreaterThan(rightAttributeC, rightAttributeD)),
EqualTo(leftAttributeA, rightAttributeC)),
EqualTo(lit(1).expr, lit(1).expr)),
GreaterThan(randCol.expr, lit(0).expr))
(leftAttributeA > leftAttributeB
&& rightAttributeC > rightAttributeD
&& leftAttributeA === rightAttributeC
&& Literal(1) === Literal(1)
&& randAttribute > Literal(0))
val split = JoinConditionSplitPredicates(Some(predicate), left, right)

assert(split.leftSideOnly.contains(
And(GreaterThan(leftAttributeA, leftAttributeB), EqualTo(lit(1).expr, lit(1).expr))))
leftAttributeA > leftAttributeB && Literal(1) === Literal(1)))
assert(split.rightSideOnly.contains(
And(GreaterThan(rightAttributeC, rightAttributeD), EqualTo(lit(1).expr, lit(1).expr))))
rightAttributeC > rightAttributeD && Literal(1) === Literal(1)))
assert(split.bothSides.contains(
And(EqualTo(leftAttributeA, rightAttributeC), GreaterThan(randCol.expr, lit(0).expr))))
leftAttributeA === rightAttributeC && randAttribute > Literal(0)))
assert(split.full.contains(predicate))
}
}