diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
index 179a348d5baa..831257d78fb4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
@@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.trees.TreeNodeRef
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.types.DataType
 
 /**
  * A pattern that matches any number of filter operations on top of another relational operator.
@@ -231,3 +232,48 @@ object Unions {
     case other => other :: Nil
   }
 }
+
+/**
+ * A pattern that finds joins with range conditions that can be evaluated using a range join.
+ *
+ * TODO support partial range joins.
+ */
+object ExtractRangeJoinKeys extends PredicateHelper {
+  type ReturnType = (LogicalPlan, LogicalPlan, Seq[Expression], Seq[Expression], Seq[Boolean])
+
+  def unapply(plan: LogicalPlan): Option[ReturnType] = plan match {
+    case Join(left, right, Inner, Some(And(RangePredicate(d1, l1, h1, equi1),
+        RangePredicate(d2, l2, h2, equi2)))) if d1 == d2 => {
+      // L.Low < R.High && R.Low < L.High
+      if (evaluateOrder(l1, h1, l2, h2, left, right)) {
+        Some((left, right, Seq(l1, h2), Seq(l2, h1), Seq(equi1, equi2)))
+      }
+      // R.Low < L.High && L.Low < R.High
+      else if (evaluateOrder(l1, h1, l2, h2, right, left)) {
+        Some((left, right, Seq(l2, h1), Seq(l1, h2), Seq(equi2, equi1)))
+      }
+      else None
+    }
+    case _ => None
+  }
+
+  def evaluateOrder(low1: Expression, high1: Expression,
+      low2: Expression, high2: Expression,
+      left: LogicalPlan, right: LogicalPlan): Boolean = {
+    canEvaluate(low1, left) && canEvaluate(high1, right) &&
+      canEvaluate(low2, right) && canEvaluate(high2, left)
+  }
+}
+
+/**
+ * A pattern that normalizes a comparison into a (dataType, low, high, inclusive) range
+ * expression. Note that the data type of the operands (not of the comparison itself, which is
+ * always boolean) is returned, so that ExtractRangeJoinKeys can check that both range
+ * conditions are evaluated over the same type.
+ */
+object RangePredicate {
+  def unapply(expression: Expression): Option[(DataType, Expression, Expression, Boolean)] =
+    expression match {
+      case LessThan(low, high) => Some((low.dataType, low, high, false))
+      case LessThanOrEqual(low, high) => Some((low.dataType, low, high, true))
+      case GreaterThan(high, low) => Some((low.dataType, low, high, false))
+      case GreaterThanOrEqual(high, low) => Some((low.dataType, low, high, true))
+      case _ => None
+    }
+}
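+
+// For illustration (a rough sketch; a.start, a.end, b.start and b.end stand for resolved
+// attributes of the same numeric type): the condition "a.start <= b.end AND b.start < a.end"
+// matches as And(RangePredicate(t, a.start, b.end, true), RangePredicate(t, b.start, a.end,
+// false)), and ExtractRangeJoinKeys then yields leftKeys = Seq(a.start, a.end),
+// rightKeys = Seq(b.start, b.end) and equality = Seq(true, false).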
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
index 6005d35f015a..aa567d07afac 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
@@ -315,6 +315,10 @@
     defaultValue = Some(false),
     doc = "")
 
+  val RANGE_JOIN = booleanConf("spark.sql.planner.rangeJoin",
+    defaultValue = Some(false),
+    doc = "When true, plan inner joins with range conditions as broadcast range joins.")
+
   // This is only used for the thriftserver
   val THRIFTSERVER_POOL = stringConf("spark.sql.thriftserver.scheduler.pool",
     doc = "Set a Fair Scheduler pool for a JDBC client session")
@@ -457,6 +461,9 @@
    */
   private[spark] def sortMergeJoinEnabled: Boolean = getConf(SORTMERGE_JOIN)
 
+  /** When true, the planner will use the range join operator (instead of BNL) for range queries. */
+  private[spark] def rangeJoinEnabled: Boolean = getConf(RANGE_JOIN)
+
   /**
    * When set to true, Spark SQL will use the Janino at runtime to generate custom bytecode
    * that evaluates expressions found in queries. In general this custom code runs much faster
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index 477dea916472..d0bd0bc9b7ff 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -868,6 +868,7 @@
       InMemoryScans ::
       ParquetOperations ::
       BasicOperators ::
+      BroadcastRangeJoin ::
       CartesianProduct ::
       BroadcastNestedLoopJoin :: Nil)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index ce25af58b6ca..bbce13823552 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -210,6 +210,33 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
     }
   }
 
+  object BroadcastRangeJoin extends Strategy {
+    private[this] def makeRangeJoin(
+        leftRangeKeys: Seq[Expression],
+        rightRangeKeys: Seq[Expression],
+        equality: Seq[Boolean],
+        buildSide: joins.BuildSide,
+        left: LogicalPlan,
+        right: LogicalPlan) = {
+      new joins.BroadcastRangeJoin(
+        leftRangeKeys,
+        rightRangeKeys,
+        equality,
+        buildSide,
+        planLater(left),
+        planLater(right))
+    }
+
+    def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
+      case ExtractRangeJoinKeys(CanBroadcast(left), right,
+          leftKeys, rightKeys, equality) if sqlContext.conf.rangeJoinEnabled =>
+        makeRangeJoin(leftKeys, rightKeys, equality, joins.BuildLeft, left, right) :: Nil
+      case ExtractRangeJoinKeys(left, CanBroadcast(right),
+          leftKeys, rightKeys, equality) if sqlContext.conf.rangeJoinEnabled =>
+        makeRangeJoin(leftKeys, rightKeys, equality, joins.BuildRight, left, right) :: Nil
+      case _ => Nil
+    }
+  }
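+
+  // Usage sketch (illustrative; assumes tables a and b with numeric start/end columns): with
+  // spark.sql.planner.rangeJoin set to true, an inner join such as
+  //   SELECT * FROM a JOIN b ON a.start <= b.end AND b.start < a.end
+  // should be planned as a BroadcastRangeJoin when one side is small enough to broadcast;
+  // otherwise the planner falls through to the nested loop based strategies.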
+
   object CartesianProduct extends Strategy {
     def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
       case logical.Join(left, right, _, None) =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastRangeJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastRangeJoin.scala
new file mode 100644
index 000000000000..cce45bad8b34
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastRangeJoin.scala
@@ -0,0 +1,411 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.joins
+
+import scala.annotation.tailrec
+import scala.collection.mutable
+import scala.concurrent._
+import scala.concurrent.duration._
+
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.util.TypeUtils
+import org.apache.spark.sql.execution.{BinaryNode, SparkPlan}
+import org.apache.spark.util.ThreadUtils
+
+/**
+ * Performs an inner range join on two tables. A range join typically has the following form:
+ *
+ *   SELECT A.*, B.*
+ *   FROM tableA A
+ *   JOIN tableB B
+ *     ON A.start <= B.end
+ *    AND A.end > B.start
+ *
+ * The implementation builds a range index from the (smaller) build side and broadcasts this
+ * index to all executors. The streaming side is then matched against the index. Compared to the
+ * typical solution (a broadcast nested loop join), this reduces the number of comparisons per
+ * streamed row from O(n) to O(log n), where n is the number of records in the build table.
+ *
+ * TODO NaN values
+ * TODO NULL values
+ * TODO Outer joins? The stream side is quite easy; the build side requires bookkeeping.
+ * TODO This join will maintain sort order. The build side rows will also be added in a lower
+ *      bound sorted fashion.
+ */
+@DeveloperApi
+case class BroadcastRangeJoin(
+    leftKeys: Seq[Expression],
+    rightKeys: Seq[Expression],
+    equality: Seq[Boolean],
+    buildSide: BuildSide,
+    left: SparkPlan,
+    right: SparkPlan)
+  extends BinaryNode {
+
+  private[this] lazy val (buildPlan, streamedPlan) = buildSide match {
+    case BuildLeft => (left, right)
+    case BuildRight => (right, left)
+  }
+
+  private[this] lazy val (buildKeys, streamedKeys) = buildSide match {
+    case BuildLeft => (leftKeys, rightKeys)
+    case BuildRight => (rightKeys, leftKeys)
+  }
+
+  override def output: Seq[Attribute] = left.output ++ right.output
+
+  @transient
+  private[this] lazy val buildSideKeyGenerator: Projection =
+    newProjection(buildKeys, buildPlan.output)
+
+  @transient
+  private[this] lazy val streamSideKeyGenerator: () => MutableProjection =
+    newMutableProjection(streamedKeys, streamedPlan.output)
+
+  private[this] val timeout: Duration = {
+    val timeoutValue = sqlContext.conf.broadcastTimeout
+    if (timeoutValue < 0) {
+      Duration.Inf
+    } else {
+      timeoutValue.seconds
+    }
+  }
+
+  // Construct the range index and broadcast it asynchronously.
+  @transient
+  private[this] val indexBroadcastFuture = future {
+    // Deal with equality. The flags follow (leftKeys, rightKeys) order, so they are reversed
+    // when the build side is the left side.
+    val Seq(allowLowEqual: Boolean, allowHighEqual: Boolean) = buildSide match {
+      case BuildLeft => equality.reverse
+      case BuildRight => equality
+    }
+
+    // Get the ordering for the data type.
+    val ordering = TypeUtils.getOrdering(buildKeys.head.dataType)
+
+    // Note that we use .execute().collect() because we don't want to convert data to Scala types.
+    // TODO find out if the result of a sort and a collect is still sorted.
+    val eventifier = RangeIndex.toRangeEvent(buildSideKeyGenerator, ordering)
+    val events = buildPlan.execute().map(_.copy()).collect().flatMap(eventifier)
+
+    // Create the index.
+    val index = RangeIndex.build(ordering, events, allowLowEqual, allowHighEqual)
+
+    // Broadcast the index.
+    sparkContext.broadcast(index)
+  }(BroadcastRangeJoin.broadcastRangeJoinExecutionContext)
+
+  override def doExecute(): RDD[InternalRow] = {
+    // Wait for the index to be constructed and broadcast.
+    val indexBC = Await.result(indexBroadcastFuture, timeout)
+
+    // Iterate over the streaming relation.
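+    // For each streamed row, the iterator below extracts the row's (low, high) pair with
+    // streamSideKeyGenerator and probes the broadcast index; index.intersect(low, high)
+    // returns the build rows whose range overlaps it.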
+    streamedPlan.execute().mapPartitions { stream =>
+      new Iterator[InternalRow] {
+        private[this] val index = indexBC.value
+        private[this] val streamSideKeys = streamSideKeyGenerator()
+        private[this] val join = new JoinedRow2 // TODO create our own join row...
+        private[this] var row: InternalRow = EmptyRow
+        private[this] var iterator: Iterator[InternalRow] = Iterator.empty
+
+        override final def hasNext: Boolean = {
+          var result = iterator.hasNext
+          while (!result && stream.hasNext) {
+            row = stream.next()
+            val lowHigh = streamSideKeys(row)
+            val low = lowHigh(0)
+            val high = lowHigh(1)
+            if (low != null && high != null) {
+              iterator = index.intersect(low, high)
+            }
+            result = iterator.hasNext
+          }
+          result
+        }
+
+        override final def next(): InternalRow = {
+          buildSide match {
+            case BuildRight => join(row, iterator.next())
+            case BuildLeft => join(iterator.next(), row)
+          }
+        }
+      }
+    }
+  }
+}
+
+private[joins] object BroadcastRangeJoin {
+  private val broadcastRangeJoinExecutionContext = ExecutionContext.fromExecutorService(
+    ThreadUtils.newDaemonCachedThreadPool("broadcast-range-join", 128))
+}
+
+private[joins] object RangeIndex {
+  // An event is a (key, flow, row) triple: flow 1 opens an interval at the key, -1 closes one,
+  // and 0 marks a single point.
+  type RangeEvent = (Any, Int, InternalRow)
+
+  /** Build a range index from an array of unsorted events. */
+  def build(ordering: Ordering[Any], events: Array[RangeEvent],
+      allowLowEqual: Boolean, allowHighEqual: Boolean): RangeIndex = {
+    val eventOrdering = Ordering.Tuple2(ordering, Ordering.Int)
+    val sortedEvents = events.sortBy(e => (e._1, e._2))(eventOrdering)
+    buildFromSorted(ordering, sortedEvents, allowLowEqual, allowHighEqual)
+  }
+
+  /** Build a range index from an array of sorted events. */
+  def buildFromSorted(ordering: Ordering[Any], events: Array[RangeEvent],
+      allowLowEqual: Boolean, allowHighEqual: Boolean): RangeIndex = {
+    // Persisted index components. A dummy null key (with a matching offset) is added up front.
+    // This makes searching easier and it allows us to deal gracefully with unbound keys.
+    val empty = Array.empty[InternalRow]
+    val keys = mutable.Buffer[Any](null)
+    val offsets = mutable.Buffer[Int](0)
+    val activatedRows = mutable.Buffer.empty[InternalRow]
+    val activeNewOffsets = mutable.Buffer.empty[Int]
+    val activeRows = mutable.Buffer.empty[Array[InternalRow]]
+
+    // Current state of the iteration.
+    var currentKey: Any = null
+    var currentActiveNewOffset: Int = -1
+    val currentActiveRows = mutable.Buffer.empty[InternalRow]
+
+    // Store the currently active rows.
+    def writeActiveRows(): Unit = {
+      activeNewOffsets += currentActiveNewOffset
+      if (currentActiveRows.isEmpty) activeRows += empty
+      else activeRows += currentActiveRows.toArray
+    }
+
+    events.foreach {
+      case (key, flow, row) =>
+        // Check if we have finished processing a key.
+        if (currentKey != key) {
+          writeActiveRows()
+          currentKey = key
+          currentActiveNewOffset = -1
+          keys += key
+          offsets += activatedRows.size
+        }
+
+        // Store the offset at which we start to add rows to the 'active' buffer.
+        if (flow >= 0 && currentActiveNewOffset == -1) {
+          currentActiveNewOffset = currentActiveRows.size
+        }
+
+        // Keep track of rows.
+        flow match {
+          case 1 =>
+            activatedRows += row
+            currentActiveRows += row
+          case 0 =>
+            activatedRows += row
+          case -1 =>
+            currentActiveRows -= row
+        }
+    }
+
+    // Store the final array of active rows.
+    writeActiveRows()
+
+    // Determine the corrections based on (in)equality.
+    val lowBoundEqualCorrection = if (allowLowEqual) 1 else 0
+    val highBoundEqualCorrection = if (allowHighEqual) 0 else 1
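+
+    // Worked sketch (illustrative): build rows r1 = [1, 3] and r2 = [2, 4] produce the sorted
+    // events (1, 1, r1), (2, 1, r2), (3, -1, r1), (4, -1, r2), which yield
+    //   keys:      null | 1    | 2        | 3    | 4
+    //   active:    []   | [r1] | [r1, r2] | [r2] | []
+    //   offsets:   0    | 0    | 1        | 2    | 2    (into activated = [r1, r2])
+    // A probe first scans the active rows of its lower-bound key and then the activated rows of
+    // the remaining keys in its range.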
+
+    // Create the index.
+    new RangeIndex(ordering, keys.toArray, offsets.toArray, activeNewOffsets.toArray,
+      activeRows.toArray, activatedRows.toArray, lowBoundEqualCorrection,
+      highBoundEqualCorrection)
+  }
+
+  /** Create a function that turns a row into its respective range events. */
+  def toRangeEvent(lowHighExtr: Projection, cmp: Ordering[Any]):
+      (InternalRow => Seq[RangeEvent]) = {
+    (row: InternalRow) => {
+      val Row(low, high) = lowHighExtr(row)
+      // Valid points and intervals.
+      if (low != null && high != null) {
+        val result = cmp.compare(low, high)
+        // Point.
+        if (result == 0) {
+          (low, 0, row) :: Nil
+        }
+        // Interval.
+        else if (result < 0) {
+          (low, 1, row) :: (high, -1, row) :: Nil
+        }
+        // Reversed interval (low > high) - cannot join on this record.
+        else Nil
+      }
+      // Nulls.
+      else Nil
+    }
+  }
+}
+
+/**
+ * A range index is a data structure for efficiently executing range queries. A range query has
+ * a lower and an upper bound; its result is an iterator of the rows that match the given
+ * constraints.
+ *
+ * @param ordering used for sorting keys, comparing keys and values, and retrieving the rows in
+ *                 a given interval.
+ * @param keys which are used for finding the active and activated rows. A key is added whenever
+ *             something changes - an event occurs, if you will - in the composition of the
+ *             active or activated rows.
+ * @param offsets into the activated array; offsets(i) marks where the rows activated at key i
+ *                start.
+ * @param activeNewOffsets contains, per key, the index at which the rows in the active array
+ *                         are new.
+ * @param active contains the rows that are 'active' between the current key and the next key.
+ *               This is only used for rows that span an interval.
+ * @param activated contains the rows that have been activated at the current key. This contains
+ *                  both rows that span an interval and rows that exist at only one point.
+ * @param lowBoundEqualCorrection correction to apply to the lower bound index, when the value
+ *                                queried equals the key.
+ * @param highBoundEqualCorrection correction to apply to the upper bound index, when the value
+ *                                 queried equals the key.
+ */
+private[joins] class RangeIndex(
+    private[this] val ordering: Ordering[Any],
+    private[this] val keys: Array[Any],
+    private[this] val offsets: Array[Int],
+    private[this] val activeNewOffsets: Array[Int],
+    private[this] val active: Array[Array[InternalRow]],
+    private[this] val activated: Array[InternalRow],
+    private[this] val lowBoundEqualCorrection: Int,
+    private[this] val highBoundEqualCorrection: Int) extends Serializable {
+
+  private[this] val maxKeyIndex = keys.length - 1
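+
+  // Search sketch (illustrative): with keys = [null, 1, 3, 5], closestLowerKey(4, 0) returns 2
+  // (the index of key 3), closestLowerKey(3, 1) returns 1 (the equality correction steps one
+  // key back), and closestLowerKey(0, 0) returns 0 (the dummy null key).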
+
+  /**
+   * Find the index of the closest key lower than or equal to the given value. When the value is
+   * equal to the found key, the result is corrected by the given equality correction.
+   *
+   * This method is tail recursive.
+   *
+   * @param value to find the closest lower or equal key index for.
+   * @param equalCorrection to correct the result with in case of equality.
+   * @param first index (inclusive) to start searching at.
+   * @param last index (inclusive) to stop searching at.
+   * @return the index of the closest lower or equal key.
+   */
+  @tailrec
+  final def closestLowerKey(value: Any, equalCorrection: Int,
+      first: Int = 0, last: Int = maxKeyIndex): Int = {
+    // Determine the mid point.
+    val mid = first + ((last - first + 1) >>> 1)
+
+    // Compare the value with the key at the mid point.
+    // Note that a value is always larger than NULL.
+    val key = keys(mid)
+    val cmp = if (key == null) 1
+      else ordering.compare(value, key)
+
+    // Value == Key: keys are unique, so we can stop.
+    if (cmp == 0) mid - equalCorrection
+    // No more elements left to search.
+    else if (first == last) mid
+    // Value > Key: search the top half of the key array.
+    else if (cmp > 0) closestLowerKey(value, equalCorrection, mid, last)
+    // Value < Key: search the lower half of the array.
+    else closestLowerKey(value, equalCorrection, first, mid - 1)
+  }
+
+  /**
+   * Calculate the intersection between the index and a given range.
+   *
+   * @param low point of the range. Note that a NULL value is currently interpreted as an
+   *            unbound (negative infinity) value.
+   * @param high point of the range to intersect with. Note that a NULL value is currently
+   *             interpreted as an unbound (positive infinity) value.
+   * @return an iterator containing all the rows that fall within the given range.
+   */
+  final def intersect(low: Any, high: Any): Iterator[InternalRow] = {
+    // Find the first index by searching for the last key lower than the low value.
+    val first = if (low == null) 0
+      else closestLowerKey(low, lowBoundEqualCorrection)
+
+    // Find the last index by searching for the last key lower than the high value.
+    val last = if (high == null || first == maxKeyIndex) maxKeyIndex
+      else closestLowerKey(high, highBoundEqualCorrection, first)
+
+    new Iterator[InternalRow] {
+      var activatedAvailable = first < last
+      var rowIndex = 0
+      var rows = active(first)
+      var rowLength = if (first <= last) rows.length
+        else activeNewOffsets(first)
+
+      override final def hasNext: Boolean = {
+        var result = rowIndex < rowLength
+        if (!result && activatedAvailable) {
+          activatedAvailable = false
+          rows = activated
+          rowIndex = offsets(first + 1)
+          rowLength = if (last == maxKeyIndex) activated.length
+            else offsets(last + 1)
+          result = rowIndex < rowLength
+        }
+        result
+      }
+
+      override final def next(): InternalRow = {
+        val row = rows(rowIndex)
+        rowIndex += 1
+        row
+      }
+    }
+  }
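+
+  // Intersection sketch (illustrative, reusing the layout example above with both bounds
+  // exclusive): intersect(2, 3) returns r1 and r2, intersect(3, 5) returns only r2, and
+  // intersect(null, null) returns every indexed row.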
+ */ + def toDebugString(maxKeys: Int = Int.MaxValue): String = { + val builder = new StringBuilder + builder.append("Index[lowBoundEqualCorrection = ") + builder.append(lowBoundEqualCorrection) + builder.append(", highBoundEqualCorrection = ") + builder.append(highBoundEqualCorrection) + builder.append("]") + val keysShown = math.min(keys.length, maxKeys) + val keysLeft = keys.length - keysShown + for (i <- 0 until keysShown) { + builder.append("\n +[") + builder.append(keys(i)) + builder.append("]@") + builder.append(offsets(i)) + builder.append("\n | Active: ") + builder.append(active(i).mkString(",")) + builder.append("\n | Activated: ") + val nextOffset = if (i == maxKeyIndex) activated.length + else offsets(i + 1) + builder.append(activated.slice(offsets(i), nextOffset).mkString(",")) + } + if (keysLeft > 0) { + builder.append("\n (") + builder.append(keysLeft) + builder.append(" keys left)") + } + builder.toString + } + + override def toString: String = toDebugString(10) +} + diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index 3dd24130af81..c3ae045be0a2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution import org.apache.spark.SparkFunSuite import org.apache.spark.sql.TestData._ import org.apache.spark.sql.catalyst.plans._ -import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, ShuffledHashJoin} +import org.apache.spark.sql.execution.joins.{BroadcastRangeJoin, BroadcastHashJoin, ShuffledHashJoin} import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.TestSQLContext._ import org.apache.spark.sql.test.TestSQLContext.implicits._ @@ -147,4 +147,19 @@ class PlannerSuite extends SparkFunSuite { val planned = planner.TakeOrderedAndProject(query) assert(planned.head.isInstanceOf[execution.TakeOrderedAndProject]) } + + test("range query join condition and broadcastable table will use RangeJoin optimization") { + val origRangeJoinSetting = conf.rangeJoinEnabled + setConf(SQLConf.RANGE_JOIN, true) + + val interval = Seq((1, 20), (30, 70)).toDF("low", "high") + val planned = testData. + join(broadcast(interval), $"low" <= $"key" && $"key" < $"high", "inner"). + queryExecution. + executedPlan + val broadcastRangeJoins = planned.collect{ case j: BroadcastRangeJoin => j } + assert(broadcastRangeJoins.size == 1, "Should use broadcast range join") + + setConf(SQLConf.RANGE_JOIN, origRangeJoinSetting) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/RangeIndexSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/RangeIndexSuite.scala new file mode 100644 index 000000000000..42db1f8eb881 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/RangeIndexSuite.scala @@ -0,0 +1,117 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
+
+    setConf(SQLConf.RANGE_JOIN, origRangeJoinSetting)
+  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/RangeIndexSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/RangeIndexSuite.scala
new file mode 100644
index 000000000000..42db1f8eb881
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/RangeIndexSuite.scala
@@ -0,0 +1,117 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.joins
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{BoundReference, InterpretedMutableProjection}
+import org.apache.spark.sql.catalyst.util.TypeUtils
+import org.apache.spark.sql.types.IntegerType
+
+class RangeIndexSuite extends SparkFunSuite {
+  private[this] val ordering = TypeUtils.getOrdering(IntegerType)
+
+  private[this] val projection = new InterpretedMutableProjection(Seq(
+    BoundReference(0, IntegerType, nullable = false),
+    BoundReference(1, IntegerType, nullable = false)))
+
+  private[this] val eventifier = RangeIndex.toRangeEvent(projection, ordering)
+
+  test("RangeIndex Point Query") {
+    val r1 = InternalRow(1, 1)
+    val r2 = InternalRow(2, 2)
+    val r3 = InternalRow(3, 3)
+    val r4 = InternalRow(4, 4)
+    val r5 = InternalRow(4, 4)
+    val rs = Array(r1, r2, r3, r4, r5).flatMap(eventifier)
+
+    // Low Bound Included - High Bound Excluded:
+    val index1 = RangeIndex.build(ordering, rs, allowLowEqual = true, allowHighEqual = false)
+
+    // All
+    assertResult(Seq(r1, r2, r3, r4, r5))(index1.intersect(null, null).toSeq)
+
+    // Bounds - l <= p && p < h
+    assertResult(Nil)(index1.intersect(5, null).toSeq)
+    assertResult(r3 :: r4 :: r5 :: Nil)(index1.intersect(3, null).toSeq)
+    assertResult(Nil)(index1.intersect(null, 1).toSeq)
+    assertResult(r1 :: Nil)(index1.intersect(null, 2).toSeq)
+
+    // Ranges
+    assertResult(r1 :: Nil)(index1.intersect(1, 2).toSeq)
+    assertResult(r3 :: r4 :: r5 :: Nil)(index1.intersect(3, 5).toSeq)
+
+    // Low Bound Excluded - High Bound Included:
+    val index2 = RangeIndex.build(ordering, rs, allowLowEqual = false, allowHighEqual = true)
+
+    // Bounds
+    assertResult(r4 :: r5 :: Nil)(index2.intersect(3, null).toSeq)
+    assertResult(r1 :: Nil)(index2.intersect(null, 1).toSeq)
+    assertResult(r1 :: r2 :: Nil)(index2.intersect(null, 2).toSeq)
+
+    // Ranges
+    assertResult(r2 :: Nil)(index2.intersect(1, 2).toSeq)
+    assertResult(r4 :: r5 :: Nil)(index2.intersect(3, 5).toSeq)
+  }
+
+  test("RangeIndex Interval Query") {
+    val r1 = InternalRow(1, 2)
+    val r3 = InternalRow(3, 4)
+    val r4 = InternalRow(4, 5)
+    val r5 = InternalRow(3, 6)
+    val rs = Array(r1, r3, r4, r5).flatMap(eventifier)
+
+    // Low Bound Excluded - High Bound Excluded (normal when intersecting intervals):
+    val index1 = RangeIndex.build(ordering, rs, allowLowEqual = false, allowHighEqual = false)
+
+    // All
+    assertResult(Seq(r1, r3, r5, r4))(index1.intersect(null, null).toSeq)
+
+    // Bounds
+    assertResult(r5 :: Nil)(index1.intersect(5, null).toSeq)
+    assertResult(r3 :: r5 :: r4 :: Nil)(index1.intersect(3, null).toSeq)
+    assertResult(Nil)(index1.intersect(null, 1).toSeq)
+    assertResult(r1 :: Nil)(index1.intersect(null, 2).toSeq)
+
+    // Ranges
+    assertResult(r1 :: Nil)(index1.intersect(1, 2).toSeq)
+    assertResult(r3 :: r5 :: r4 :: Nil)(index1.intersect(3, 5).toSeq)
+    assertResult(r5 :: r4 :: Nil)(index1.intersect(4, 5).toSeq)
+
+    // Points
+    assertResult(Nil)(index1.intersect(2, 2).toSeq)
+    assertResult(r5 :: Nil)(index1.intersect(4, 4).toSeq)
+
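+    // With inclusive bounds the endpoints now match as well: r1 = [1, 2] is returned for
+    // intersect(2, 2), and the touching intervals r3 = [3, 4] and r5 = [3, 6] are both returned
+    // for intersect(3, 3).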
+    // Low Bound Included - High Bound Included:
+    val index2 = RangeIndex.build(ordering, rs, allowLowEqual = true, allowHighEqual = true)
+
+    // Bounds
+    assertResult(r5 :: r4 :: Nil)(index2.intersect(5, null).toSeq)
+    assertResult(r1 :: Nil)(index2.intersect(null, 1).toSeq)
+    assertResult(r1 :: Nil)(index2.intersect(null, 2).toSeq)
+
+    // Ranges
+    assertResult(r1 :: r3 :: r5 :: Nil)(index2.intersect(1, 3).toSeq)
+    assertResult(r3 :: r5 :: r4 :: Nil)(index2.intersect(3, 5).toSeq)
+    assertResult(r3 :: r5 :: r4 :: Nil)(index2.intersect(4, 5).toSeq)
+
+    // Points
+    assertResult(r1 :: Nil)(index2.intersect(2, 2).toSeq)
+    assertResult(r3 :: r5 :: Nil)(index2.intersect(3, 3).toSeq)
+  }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/RangeJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/RangeJoinSuite.scala
new file mode 100644
index 000000000000..9b3c5544999a
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/RangeJoinSuite.scala
@@ -0,0 +1,167 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.joins
+
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
+import org.apache.spark.sql.execution.{SparkPlan, SparkPlanTest}
+
+class RangeJoinSuite extends SparkPlanTest {
+  val intervals1 = Seq(
+    (-1, 0),
+    (0, 1),
+    (0, 2),
+    (1, 5)
+  ).toDF("low1", "high1")
+  val intervalKeys1 = Seq("low1", "high1").map(UnresolvedAttribute.apply)
+
+  val intervals2 = Seq(
+    (-2, -1),
+    (-4, -2),
+    (1, 3),
+    (5, 7)
+  ).toDF("low2", "high2")
+  val intervalKeys2 = Seq("low2", "high2").map(UnresolvedAttribute.apply)
+
+  val points = Seq(-3, 1, 3, 6).map(Tuple1.apply).toDF("point")
+  val pointKeys = Seq("point", "point").map(UnresolvedAttribute.apply)
+
+  test("interval-point range join") {
+    // low1 <= point && point < high1
+    checkAnswer2(intervals1, points, (left: SparkPlan, right: SparkPlan) =>
+      BroadcastRangeJoin(
+        intervalKeys1,
+        pointKeys,
+        true :: false :: Nil,
+        BuildRight,
+        left,
+        right),
+      Seq(
+        (0, 2, 1),
+        (1, 5, 1),
+        (1, 5, 3)
+      ).map(Row.fromTuple))
+
+    // low1 < point && point < high1
+    checkAnswer2(intervals1, points, (left: SparkPlan, right: SparkPlan) =>
+      BroadcastRangeJoin(
+        intervalKeys1,
+        pointKeys,
+        false :: false :: Nil,
+        BuildRight,
+        left,
+        right),
+      Seq(
+        (0, 2, 1),
+        (1, 5, 3)
+      ).map(Row.fromTuple))
+
+    // low1 <= point && point <= high1
+    checkAnswer2(points, intervals1, (left: SparkPlan, right: SparkPlan) =>
+      BroadcastRangeJoin(
+        pointKeys,
+        intervalKeys1,
+        true :: true :: Nil,
+        BuildRight,
+        left,
+        right),
+      Seq(
+        (1, 0, 1),
+        (1, 0, 2),
+        (1, 1, 5),
+        (3, 1, 5)
+      ).map(Row.fromTuple))
+
+    // low1 < point && point < high1
+    checkAnswer2(intervals1, points, (left: SparkPlan, right: SparkPlan) =>
+      BroadcastRangeJoin(
+        intervalKeys1,
+        pointKeys,
+        false :: false :: Nil,
+        BuildLeft,
+        left,
+        right),
+      Seq(
+        (0, 2, 1),
+        (1, 5, 3)
+      ).map(Row.fromTuple))
+  }
+
+  test("interval-interval range join") {
+    // low1 <= high2 && low2 < high1
+    checkAnswer2(intervals1, intervals2, (left: SparkPlan, right: SparkPlan) =>
+      BroadcastRangeJoin(
+        intervalKeys1,
+        intervalKeys2,
+        true :: false :: Nil,
+        BuildRight,
+        left,
+        right),
+      Seq(
+        (-1, 0, -2, -1),
+        (0, 2, 1, 3),
+        (1, 5, 1, 3)
+      ).map(Row.fromTuple))
+
+    // low1 < high2 && low2 <= high1
+    checkAnswer2(intervals1, intervals2, (left: SparkPlan, right: SparkPlan) =>
+      BroadcastRangeJoin(
+        intervalKeys1,
+        intervalKeys2,
+        false :: true :: Nil,
+        BuildLeft,
+        left,
+        right),
+      Seq(
+        (0, 1, 1, 3),
+        (0, 2, 1, 3),
+        (1, 5, 1, 3),
+        (1, 5, 5, 7)
+      ).map(Row.fromTuple))
+
+    // low1 < high2 && low2 < high1
+    checkAnswer2(intervals1, intervals2, (left: SparkPlan, right: SparkPlan) =>
+      BroadcastRangeJoin(
+        intervalKeys1,
+        intervalKeys2,
+        false :: false :: Nil,
+        BuildRight,
+        left,
+        right),
+      Seq(
+        (0, 2, 1, 3),
+        (1, 5, 1, 3)
+      ).map(Row.fromTuple))
+
+    // low1 <= high2 && low2 <= high1
+    checkAnswer2(intervals1, intervals2, (left: SparkPlan, right: SparkPlan) =>
+      BroadcastRangeJoin(
+        intervalKeys1,
+        intervalKeys2,
+        true :: true :: Nil,
+        BuildLeft,
+        left,
+        right),
+      Seq(
+        (-1, 0, -2, -1),
+        (0, 1, 1, 3),
+        (0, 2, 1, 3),
+        (1, 5, 1, 3),
+        (1, 5, 5, 7)
+      ).map(Row.fromTuple))
+  }
+}
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
index 4684d48aff88..7a34bc86114d 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
@@ -454,6 +454,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) with Logging {
       LeftSemiJoin,
       HashJoin,
       BasicOperators,
+      BroadcastRangeJoin,
       CartesianProduct,
       BroadcastNestedLoopJoin
     )