From 21722fed755eee0f06692a0b29b3518c92a5eb41 Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Tue, 30 Jun 2015 22:17:03 -0400 Subject: [PATCH 1/8] Initial Range Join commit: Compiles & Style Checks work. --- .../sql/catalyst/planning/patterns.scala | 46 +++ .../org/apache/spark/sql/SQLContext.scala | 1 + .../spark/sql/execution/SparkStrategies.scala | 27 ++ .../execution/joins/BroadcastRangeJoin.scala | 312 ++++++++++++++++++ 4 files changed, 386 insertions(+) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastRangeJoin.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index 179a348d5baa..831257d78fb4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.trees.TreeNodeRef import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.types.DataType /** * A pattern that matches any number of filter operations on top of another relational operator. @@ -231,3 +232,48 @@ object Unions { case other => other :: Nil } } + +/** + * A pattern that finds joins with range conditions that can be evaluated using a range join. + * + * TODO support partial range joins. + */ +object ExtractRangeJoinKeys extends PredicateHelper { + type ReturnType = (LogicalPlan, LogicalPlan, Seq[Expression], Seq[Expression], Seq[Boolean]) + def unapply(plan: LogicalPlan): Option[ReturnType] = plan match { + case Join(left, right, Inner, Some(And(RangePredicate(d1, l1, h1, equi1), + RangePredicate(d2, l2, h2, equi2)))) if d1 == d2 => { + // L.Low < R.High && R.Low < L.High + if (evaluateOrder(l1, h1, l2, h2, left, right)) { + Some(left, right, Seq(l1, h2), Seq(l2, h1), Seq(equi1, equi2)) + } + // R.Low < L.High && L.Low < R.High + else if (evaluateOrder(l1, h1, l2, h2, right, left)) { + Some(left, right, Seq(l2, h1), Seq(l1, h2), Seq(equi2, equi1)) + } + else None + } + case _ => None + } + + def evaluateOrder(low1: Expression, high1: Expression, + low2: Expression, high2: Expression, + left: LogicalPlan, right: LogicalPlan): Boolean = { + canEvaluate(low1, left) && canEvaluate(high1, right) && + canEvaluate(low2, right) && canEvaluate(high2, left) + } +} + +/** + * A pattern that normalizes all range expressions. 
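+ *
+ * Every comparison is rewritten into the canonical tuple
+ * (dataType, low, high, allowEqual): both `a <= b` and `b >= a` yield
+ * (dataType, a, b, true), so ExtractRangeJoinKeys only has to reason about a
+ * single ordered form.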
+ */ +object RangePredicate { + def unapply(expression: Expression): Option[(DataType, Expression, Expression, Boolean)] = + expression match { + case LessThan(low, high) => Some(expression.dataType, low, high, false) + case LessThanOrEqual(low, high) => Some(expression.dataType, low, high, true) + case GreaterThan(high, low) => Some(expression.dataType, low, high, false) + case GreaterThanOrEqual(high, low) => Some(expression.dataType, low, high, true) + case _ => None + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 477dea916472..d0bd0bc9b7ff 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -868,6 +868,7 @@ class SQLContext(@transient val sparkContext: SparkContext) InMemoryScans :: ParquetOperations :: BasicOperators :: + BroadcastRangeJoin :: CartesianProduct :: BroadcastNestedLoopJoin :: Nil) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index ce25af58b6ca..2768f2137ee2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -210,6 +210,33 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { } } + object BroadcastRangeJoin extends Strategy { + private[this] def makeRangeJoin(leftRangeKeys: Seq[Expression], + rightRangeKeys: Seq[Expression], + equality: Seq[Boolean], + buildSide: joins.BuildSide, + left: LogicalPlan, + right: LogicalPlan) = { + new joins.BroadcastRangeJoin( + leftRangeKeys, + rightRangeKeys, + equality, + buildSide, + planLater(left), + planLater(right)) + } + + def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + case ExtractRangeJoinKeys(CanBroadcast(left), right, + leftKeys, rightKeys, equality) => + makeRangeJoin(leftKeys, rightKeys, equality, joins.BuildLeft, left, right) :: Nil + case ExtractRangeJoinKeys(left, CanBroadcast(right), + leftKeys, rightKeys, equality) => + makeRangeJoin(leftKeys, rightKeys, equality, joins.BuildRight, left, right) :: Nil + case _ => Nil + } + } + object CartesianProduct extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case logical.Join(left, right, _, None) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastRangeJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastRangeJoin.scala new file mode 100644 index 000000000000..4f58dc27f505 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastRangeJoin.scala @@ -0,0 +1,312 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.joins + +import org.apache.spark.sql.catalyst.util.TypeUtils +import org.apache.spark.util.ThreadUtils + +import scala.annotation.tailrec +import scala.collection.mutable +import scala.concurrent._ +import scala.concurrent.duration._ + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} + +/** + * Performs an inner range join on two tables. A range join typically has the following form: + * + * SELECT A.* + * ,B.* + * FROM tableA A + * JOIN tableB B + * ON A.start <= B.end + * AND A.end > B.start + * + * The implementation builds a range index from the smaller build side, broadcasts this index + * to all executors. The streaming side is then matched against the index. This reduces the number + * of comparisons made by log(n) (n is the number of records in the build table) over the + * typical solution (Nested Loop Join). + * + * TODO + * TODO NULL values + * TODO Equality + * TODO Outer joins? + */ +@DeveloperApi +case class BroadcastRangeJoin( + leftKeys: Seq[Expression], + rightKeys: Seq[Expression], + equality: Seq[Boolean], + buildSide: BuildSide, + left: SparkPlan, + right: SparkPlan) + extends BinaryNode { + + private[this] lazy val (buildPlan, streamedPlan) = buildSide match { + case BuildLeft => (left, right) + case BuildRight => (right, left) + } + + private[this] lazy val (buildKeys, streamedKeys) = buildSide match { + case BuildLeft => (leftKeys, rightKeys) + case BuildRight => (rightKeys, leftKeys) + } + + override def output: Seq[Attribute] = left.output ++ right.output + + @transient + private[this] lazy val buildSideKeyGenerator: Projection = + newProjection(buildKeys, buildPlan.output) + + @transient + private[this] lazy val streamSideKeyGenerator: () => MutableProjection = + newMutableProjection(streamedKeys, streamedPlan.output) + + private[this] val timeout: Duration = { + val timeoutValue = sqlContext.conf.broadcastTimeout + if (timeoutValue < 0) { + Duration.Inf + } else { + timeoutValue.seconds + } + } + + // Construct the range index. + @transient + private[this] val indexBroadcastFuture = future { + // Deal with equality. + val Seq(allowLowEqual: Boolean, allowHighEqual: Boolean) = buildSide match { + case BuildLeft => equality.head + case BuildRight => equality.reverse + } + + // Get the ordering for the datatype. + val ordering = TypeUtils.getOrdering(buildKeys.head.dataType) + + // Note that we use .execute().collect() because we don't want to convert data to Scala types + // TODO find out if the result of a sort and a collect is still sorted. + val events = buildPlan.execute().flatMap(RangeIndex.toRangeEvent( + buildSideKeyGenerator, ordering)).collect() + + // Create the index. + val index = RangeIndex.build(ordering, events, allowLowEqual, allowHighEqual) + + // Broadcast the index. + sparkContext.broadcast(index) + }(BroadcastRangeJoin.broadcastRangeJoinExecutionContext) + + override def doExecute(): RDD[InternalRow] = { + // Construct the range index. + val indexBC = Await.result(indexBroadcastFuture, timeout) + + // Iterate over the streaming relation. 
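+ // Every streamed partition probes the broadcast index locally: the
+ // (low, high) keys of each streamed row are extracted, and one joined row
+ // is emitted per build row whose range intersects them.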
+ streamedPlan.execute().mapPartitions { stream => + val index = indexBC.value + val streamSideKeys = streamSideKeyGenerator() + val join = new JoinedRow2 // TODO create our own join row... + stream.flatMap { sRow => + // Get the bounds. + val lowHigh = streamSideKeys(sRow) + val low = lowHigh(0) + val high = lowHigh(1) + + // Only allow non-null keys. + if (low != null && high != null) { + index.intersect(low, high).map { bRow => + buildSide match { + case BuildRight => join(sRow, bRow) + case BuildLeft => join(bRow, sRow) + } + } + } + else Iterator.empty + } + } + } +} + +private[joins] object BroadcastRangeJoin { + private val broadcastRangeJoinExecutionContext = ExecutionContext.fromExecutorService( + ThreadUtils.newDaemonCachedThreadPool("broadcast-range-join", 128)) +} + +private[joins] object RangeIndex { + type RangeEvent = (Any, Int, InternalRow) + + /** Build a range index from an array of unsorted events. */ + def build(ordering: Ordering[Any], events: Array[RangeEvent], + allowLowEqual: Boolean, allowHighEqual: Boolean): RangeIndex = { + buildFromSorted(ordering, events.sortBy(_._1)(ordering), allowLowEqual, allowHighEqual) + } + + /** Build a range index from an array of sorted events. */ + def buildFromSorted(ordering: Ordering[Any], events: Array[RangeEvent], + allowLowEqual: Boolean, allowHighEqual: Boolean): RangeIndex = { + // Persisted index components. A dummy null value is added to the array. This makes searching + // easier and it allows us to deal gracefully with unbound keys. + val keys = mutable.Buffer[Any](null) + val activatedRows = mutable.Buffer(Array.empty[InternalRow]) + val activeRows = mutable.Buffer(Array.empty[InternalRow]) + + // Current State of the iteration. + var currentKey: Any = null + var currentRowCount = 0 + val currentActivatedRows = mutable.Buffer.empty[InternalRow] + val currentActiveRows = mutable.Buffer.empty[InternalRow] + + // Store the current key state in the final results. + def finishKey = { + if (currentRowCount > 0) { + keys += currentKey + activatedRows += currentActivatedRows.toArray + activeRows += currentActiveRows.toArray + } + } + events.foreach { + case (key, flow, row) => + // Check if we have finished processing a key + if (currentKey != key) { + finishKey + currentKey = key + currentActivatedRows.clear() + currentRowCount = 0 + } + + // Keep track of rows. + flow match { + case 1 => + currentActiveRows += row + currentActivatedRows += row + case -1 => + currentActiveRows -= row + } + currentRowCount += 1 + } + + // Store the final events. + finishKey + + // Determine corrections based on equality + val lowBoundEqualCorrection = if (allowLowEqual) 0 else 1 + val highBoundEqualCorrection = if (allowHighEqual) 1 else 0 + + // Create the index. + new RangeIndex(ordering, keys.toArray, activeRows.toArray, activatedRows.toArray, + lowBoundEqualCorrection, highBoundEqualCorrection) + } + + /** Create a function that turns a row into its respective range events. */ + def toRangeEvent(lowHighExtr: Projection, cmp: Ordering[Any]): + (InternalRow => Seq[RangeEvent]) = { + (row: InternalRow) => { + val Row(low, high) = lowHighExtr(row) + // Valid points and intervals. + if (low != null && high != null && cmp.compare(low, high) <= 0) { + val copy = row.copy() + (low, 1, copy) ::(high, -1, copy) :: Nil + } + // Nulls + else Nil + } + } +} + +/** + * + * @param ordering used retrieve + * @param keys array + * @param active contains the rows that are 'active' between the current key and the next key. 
+ * @param activated contains the row that have been activated at the current key. + */ +private[joins] class RangeIndex( + private[this] val ordering: Ordering[Any], + private[this] val keys: Array[Any], + private[this] val active: Array[Array[InternalRow]], + private[this] val activated: Array[Array[InternalRow]], + private[this] val lowBoundEqualCorrection: Int, + private[this] val highBoundEqualCorrection: Int) { + + /** + * Find the index of the closest upper bound to the value given. We can correct the result of a + * match by passing an index correction to the function. + * + * @param value to find the index of the closest upper bound. + * @return the index of the closest upper bound. + */ + @tailrec + final def closestLowerKey(value: Any, equalCorrection: Int, + first: Int = 0, last: Int = keys.length): Int = { + if (first < last) { + val index = first + (last - first - 1) / 2 + val key = keys(index) + val cmp = if (key == null) 1 + else ordering.compare(key, value) + if (cmp == 0) index + equalCorrection + else if (cmp < 0) closestLowerKey(value, equalCorrection, first, index) + else closestLowerKey(value, equalCorrection, index + 1, last) + } else first + } + + /** + * Calculate the intersection between the index and a given range. + * + * @param low point of the range. Note that a NULL value is currently interpreted as an unbound + * (negative infinite) value. + * @param high point of the range to intersect with. Note that a NULL value is currently + * interpreted as an unbound (infinite) value. + * @return an iterator containing all the rows that fall within the given range. + */ + final def intersect(low: Any, high: Any): Iterator[InternalRow] = { + // Find first index by searching for closest the upper bound to the low value. + val first = if (low == null) 0 + else closestLowerKey(low, lowBoundEqualCorrection) + + // Find last index by searching for closest the upper bound to the high value. + val last = if (high == null || first == keys.length - 1) keys.length - 1 + else closestLowerKey(high, highBoundEqualCorrection, first) + + // Return the iterator. + new Iterator[InternalRow] { + var index = first + var rowIndex = 0 + var rows = active(index) + + def hasNext: Boolean = { + var result = rows != null && rowIndex < rows.length + while (!result && index < last) { + index += 1 + rowIndex = 0 + rows = activated(index) + result = rows != null && rowIndex < rows.length + } + result + } + + def next() = { + val row = rows(rowIndex) + rowIndex += 1 + row + } + } + } +} From 65ce5fff6d1d24d79233ca0a5f01b2ede86a57e4 Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Mon, 13 Jul 2015 18:27:03 -0400 Subject: [PATCH 2/8] Added Tests for Range Index. Ton of Bug Fixes. 
--- .../scala/org/apache/spark/sql/SQLConf.scala | 7 ++ .../spark/sql/execution/SparkStrategies.scala | 4 +- .../execution/joins/BroadcastRangeJoin.scala | 61 +++++++---- .../sql/execution/joins/RangeIndexSuite.scala | 100 ++++++++++++++++++ .../apache/spark/sql/hive/HiveContext.scala | 1 + 5 files changed, 149 insertions(+), 24 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/joins/RangeIndexSuite.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala index 6005d35f015a..aa567d07afac 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala @@ -315,6 +315,10 @@ private[spark] object SQLConf { defaultValue = Some(false), doc = "") + val RANGE_JOIN = booleanConf("spark.sql.planner.rangeJoin", + defaultValue = Some(false), + doc = "") + // This is only used for the thriftserver val THRIFTSERVER_POOL = stringConf("spark.sql.thriftserver.scheduler.pool", doc = "Set a Fair Scheduler pool for a JDBC client session") @@ -457,6 +461,9 @@ private[sql] class SQLConf extends Serializable with CatalystConf { */ private[spark] def sortMergeJoinEnabled: Boolean = getConf(SORTMERGE_JOIN) + /** When true the planner will use range join operator (instead of BNL) for range queries. */ + private[spark] def rangeJoinEnabled: Boolean = getConf(RANGE_JOIN) + /** * When set to true, Spark SQL will use the Janino at runtime to generate custom bytecode * that evaluates expressions found in queries. In general this custom code runs much faster diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 2768f2137ee2..bbce13823552 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -228,10 +228,10 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case ExtractRangeJoinKeys(CanBroadcast(left), right, - leftKeys, rightKeys, equality) => + leftKeys, rightKeys, equality) if sqlContext.conf.rangeJoinEnabled => makeRangeJoin(leftKeys, rightKeys, equality, joins.BuildLeft, left, right) :: Nil case ExtractRangeJoinKeys(left, CanBroadcast(right), - leftKeys, rightKeys, equality) => + leftKeys, rightKeys, equality) if sqlContext.conf.rangeJoinEnabled => makeRangeJoin(leftKeys, rightKeys, equality, joins.BuildRight, left, right) :: Nil case _ => Nil } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastRangeJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastRangeJoin.scala index 4f58dc27f505..9c2680e04efc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastRangeJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastRangeJoin.scala @@ -47,10 +47,9 @@ import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} * of comparisons made by log(n) (n is the number of records in the build table) over the * typical solution (Nested Loop Join). * - * TODO + * TODO NaN values * TODO NULL values - * TODO Equality - * TODO Outer joins? + * TODO Outer joins? 
StreamSide is quite easy/BuildSide requires bookkeeping and */ @DeveloperApi case class BroadcastRangeJoin( @@ -96,7 +95,7 @@ case class BroadcastRangeJoin( private[this] val indexBroadcastFuture = future { // Deal with equality. val Seq(allowLowEqual: Boolean, allowHighEqual: Boolean) = buildSide match { - case BuildLeft => equality.head + case BuildLeft => equality case BuildRight => equality.reverse } @@ -105,8 +104,8 @@ case class BroadcastRangeJoin( // Note that we use .execute().collect() because we don't want to convert data to Scala types // TODO find out if the result of a sort and a collect is still sorted. - val events = buildPlan.execute().flatMap(RangeIndex.toRangeEvent( - buildSideKeyGenerator, ordering)).collect() + val eventifier = RangeIndex.toRangeEvent(buildSideKeyGenerator, ordering) + val events = buildPlan.execute().collect().flatMap(eventifier) // Create the index. val index = RangeIndex.build(ordering, events, allowLowEqual, allowHighEqual) @@ -207,8 +206,8 @@ private[joins] object RangeIndex { finishKey // Determine corrections based on equality - val lowBoundEqualCorrection = if (allowLowEqual) 0 else 1 - val highBoundEqualCorrection = if (allowHighEqual) 1 else 0 + val lowBoundEqualCorrection = if (allowLowEqual) 1 else 0 + val highBoundEqualCorrection = if (allowHighEqual) 0 else 1 // Create the index. new RangeIndex(ordering, keys.toArray, activeRows.toArray, activatedRows.toArray, @@ -232,11 +231,23 @@ private[joins] object RangeIndex { } /** + * A range index is an data structure which can be used to efficiently execute range queries upon. A + * range query has a lower and an upper bound, the result of a range query is a iterator of rows + * that match the given constraints. * - * @param ordering used retrieve - * @param keys array + * @param ordering used for sorting keys, comparing keys and values, and retrieving the rows in a + * given interval. + * @param keys which are used for finding the active and activated rows. A key is used when + * something changes, an event occurs (if you will), in the composition of the active + * or activated rows. * @param active contains the rows that are 'active' between the current key and the next key. - * @param activated contains the row that have been activated at the current key. + * This is only used for rows that span an interval. + * @param activated contains the row that have been activated at the current key. This contains + * both rows that span an interval or only exist at one point + * @param lowBoundEqualCorrection correction to apply to the lower bound index, when the value + * queried equals the key. + * @param highBoundEqualCorrection correction to apply to the upper bound index, when the value + * queried equals the key. */ private[joins] class RangeIndex( private[this] val ordering: Ordering[Any], @@ -244,24 +255,29 @@ private[joins] class RangeIndex( private[this] val active: Array[Array[InternalRow]], private[this] val activated: Array[Array[InternalRow]], private[this] val lowBoundEqualCorrection: Int, - private[this] val highBoundEqualCorrection: Int) { + private[this] val highBoundEqualCorrection: Int) extends Serializable { /** - * Find the index of the closest upper bound to the value given. We can correct the result of a - * match by passing an index correction to the function. + * Find the index of the closest key lower than or equal to the value given. When a value is + * equal to the found key, the result is corrected. * - * @param value to find the index of the closest upper bound. 
+ * This method is tail recursive. + * + * @param value to find the closest lower or equal key index for. + * @param equalCorrection to correct the result with in case of equality. + * @param first index (inclusive) to start searching at. + * @param last index (inclusive) to stop searching at. * @return the index of the closest upper bound. */ @tailrec final def closestLowerKey(value: Any, equalCorrection: Int, - first: Int = 0, last: Int = keys.length): Int = { + first: Int = 0, last: Int = keys.length - 1): Int = { if (first < last) { - val index = first + (last - first - 1) / 2 + val index = first + ((last - first) >> 1) val key = keys(index) - val cmp = if (key == null) 1 - else ordering.compare(key, value) - if (cmp == 0) index + equalCorrection + val cmp = if (key == null) -1 + else ordering.compare(value, key) + if (cmp == 0) index - equalCorrection else if (cmp < 0) closestLowerKey(value, equalCorrection, first, index) else closestLowerKey(value, equalCorrection, index + 1, last) } else first @@ -277,11 +293,11 @@ private[joins] class RangeIndex( * @return an iterator containing all the rows that fall within the given range. */ final def intersect(low: Any, high: Any): Iterator[InternalRow] = { - // Find first index by searching for closest the upper bound to the low value. + // Find first index by searching for the last key lower than the low value. val first = if (low == null) 0 else closestLowerKey(low, lowBoundEqualCorrection) - // Find last index by searching for closest the upper bound to the high value. + // Find last index by searching for the last lower than the high value. val last = if (high == null || first == keys.length - 1) keys.length - 1 else closestLowerKey(high, highBoundEqualCorrection, first) @@ -310,3 +326,4 @@ private[joins] class RangeIndex( } } } + diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/RangeIndexSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/RangeIndexSuite.scala new file mode 100644 index 000000000000..fe1d9dc19a87 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/RangeIndexSuite.scala @@ -0,0 +1,100 @@ +package org.apache.spark.sql.execution.joins + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{BoundReference, InterpretedMutableProjection} +import org.apache.spark.sql.catalyst.util.TypeUtils +import org.apache.spark.sql.types.IntegerType + +class RangeIndexSuite extends SparkFunSuite { + private[this] val ordering = TypeUtils.getOrdering(IntegerType) + + private[this] val projection = new InterpretedMutableProjection(Seq( + BoundReference(0, IntegerType, nullable = false), + BoundReference(1, IntegerType, nullable = false))) + + private[this] val eventifier = RangeIndex.toRangeEvent(projection, ordering) + + test("RangeIndex Point Query") { + val r1 = InternalRow(1, 1) + val r2 = InternalRow(2, 2) + val r3 = InternalRow(3, 3) + val r4 = InternalRow(4, 4) + val r5 = InternalRow(4, 4) + val rs = Array(r1, r2, r3, r4, r5).flatMap(eventifier) + + // Low Bound Included - High Bound Excluded: + val index1 = RangeIndex.build(ordering, rs, allowLowEqual = true, allowHighEqual = false) + + // All + assertResult(Seq(r1, r2, r3, r4, r5))(index1.intersect(null, null).toSeq) + + // Bounds - l <= p && p < h + assertResult(Nil)(index1.intersect(5, null).toSeq) + assertResult(r3 :: r4 :: r5 :: Nil)(index1.intersect(3, null).toSeq) + assertResult(Nil)(index1.intersect(null, 1).toSeq) + 
assertResult(r1 :: Nil)(index1.intersect(null, 2).toSeq) + + // Ranges + assertResult(r1 :: Nil)(index1.intersect(1, 2).toSeq) + assertResult(r3 :: r4 :: r5 :: Nil)(index1.intersect(3, 5).toSeq) + + // Low Bound Excluded - High Bound Included: + val index2 = RangeIndex.build(ordering, rs, allowLowEqual = false, allowHighEqual = true) + + // Bounds + assertResult(r4 :: r5 :: Nil)(index2.intersect(3, null).toSeq) + assertResult(r1 :: Nil)(index2.intersect(null, 1).toSeq) + assertResult(r1 :: r2 :: Nil)(index2.intersect(null, 2).toSeq) + + // Ranges + assertResult(r2 :: Nil)(index2.intersect(1, 2).toSeq) + assertResult(r4 :: r5 :: Nil)(index2.intersect(3, 5).toSeq) + } + + test("RangeIndex Interval Query") { + val r1 = InternalRow(1, 2) + val r3 = InternalRow(3, 4) + val r4 = InternalRow(4, 5) + val r5 = InternalRow(3, 6) + val rs = Array(r1, r3, r4, r5).flatMap(eventifier) + + // Low Bound Excluded - High Bound Excluded (Normal when intersecting intervals): + val index1 = RangeIndex.build(ordering, rs, allowLowEqual = false, allowHighEqual = false) + + // All + assertResult(Seq(r1, r3, r5, r4))(index1.intersect(null, null).toSeq) + + // Bounds + assertResult(r5 :: Nil)(index1.intersect(5, null).toSeq) + assertResult(r3 :: r5 :: r4 :: Nil)(index1.intersect(3, null).toSeq) + assertResult(Nil)(index1.intersect(null, 1).toSeq) + assertResult(r1 :: Nil)(index1.intersect(null, 2).toSeq) + + // Ranges + assertResult(r1 :: Nil)(index1.intersect(1, 2).toSeq) + assertResult(r3 :: r5 :: r4 :: Nil)(index1.intersect(3, 5).toSeq) + assertResult(r5 :: r4 :: Nil)(index1.intersect(4, 5).toSeq) + + // Points + assertResult(Nil)(index1.intersect(2, 2).toSeq) + assertResult(r3 :: r5 :: Nil)(index1.intersect(3, 3).toSeq) + + // Low Bound Included - High Bound Included: + val index2 = RangeIndex.build(ordering, rs, allowLowEqual = true, allowHighEqual = true) + + // Bounds + assertResult(r5 :: r4 :: Nil)(index2.intersect(5, null).toSeq) + assertResult(r1 :: Nil)(index2.intersect(null, 1).toSeq) + assertResult(r1 :: Nil)(index2.intersect(null, 2).toSeq) + + // Ranges + assertResult(r1 :: r3 :: r5 :: Nil)(index2.intersect(1, 3).toSeq) + assertResult(r3 :: r5 :: r4 :: Nil)(index2.intersect(3, 5).toSeq) + assertResult(r3 :: r5 :: r4 :: Nil)(index2.intersect(4, 5).toSeq) + + // Points + assertResult(r1 :: Nil)(index2.intersect(2, 2).toSeq) + assertResult(r3 :: r5 :: Nil)(index2.intersect(3, 3).toSeq) + } +} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index 4684d48aff88..7a34bc86114d 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -454,6 +454,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) with Logging { LeftSemiJoin, HashJoin, BasicOperators, + BroadcastRangeJoin, CartesianProduct, BroadcastNestedLoopJoin ) From 773c009662d8d398d3caa1b6c7affd9c8902ef23 Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Tue, 14 Jul 2015 00:22:18 -0400 Subject: [PATCH 3/8] Add License to RangeIndexSuite --- .../sql/execution/joins/RangeIndexSuite.scala | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/RangeIndexSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/RangeIndexSuite.scala index fe1d9dc19a87..c6df5c83c273 100644 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/RangeIndexSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/RangeIndexSuite.scala @@ -1,3 +1,20 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.apache.spark.sql.execution.joins import org.apache.spark.SparkFunSuite From 6661c47af9177d9ad803d64b71c2a4bb8d0b6e39 Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Tue, 14 Jul 2015 12:03:54 -0400 Subject: [PATCH 4/8] Treat intervals and points differenty during index creation. --- .../sql/execution/joins/BroadcastRangeJoin.scala | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastRangeJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastRangeJoin.scala index 9c2680e04efc..18a29e348f5c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastRangeJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastRangeJoin.scala @@ -196,6 +196,8 @@ private[joins] object RangeIndex { case 1 => currentActiveRows += row currentActivatedRows += row + case 0 => + currentActivatedRows += row case -1 => currentActiveRows -= row } @@ -220,9 +222,19 @@ private[joins] object RangeIndex { (row: InternalRow) => { val Row(low, high) = lowHighExtr(row) // Valid points and intervals. - if (low != null && high != null && cmp.compare(low, high) <= 0) { + if (low != null && high != null) { + val result = cmp.compare(low, high) val copy = row.copy() - (low, 1, copy) ::(high, -1, copy) :: Nil + // Point + if (result == 0) { + (low, 0, copy) :: Nil + } + // Interval + else if (result < 0) { + (low, 1, copy) ::(high, -1, copy) :: Nil + } + // Reversed Interval (low > high) - Cannot join on this record. + else Nil } // Nulls else Nil From 6d205d46fd96bb9350a509a3c9d4535ca16bc195 Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Wed, 15 Jul 2015 00:12:08 -0400 Subject: [PATCH 5/8] Bug Fixes. Improved Iterator code. --- .../execution/joins/BroadcastRangeJoin.scala | 72 +++++++++++-------- 1 file changed, 44 insertions(+), 28 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastRangeJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastRangeJoin.scala index 18a29e348f5c..e67c347df5dd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastRangeJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastRangeJoin.scala @@ -95,8 +95,8 @@ case class BroadcastRangeJoin( private[this] val indexBroadcastFuture = future { // Deal with equality. 
val Seq(allowLowEqual: Boolean, allowHighEqual: Boolean) = buildSide match { - case BuildLeft => equality - case BuildRight => equality.reverse + case BuildLeft => equality.reverse + case BuildRight => equality } // Get the ordering for the datatype. @@ -105,7 +105,7 @@ case class BroadcastRangeJoin( // Note that we use .execute().collect() because we don't want to convert data to Scala types // TODO find out if the result of a sort and a collect is still sorted. val eventifier = RangeIndex.toRangeEvent(buildSideKeyGenerator, ordering) - val events = buildPlan.execute().collect().flatMap(eventifier) + val events = buildPlan.execute().map(_.copy()).collect().flatMap(eventifier) // Create the index. val index = RangeIndex.build(ordering, events, allowLowEqual, allowHighEqual) @@ -120,25 +120,34 @@ case class BroadcastRangeJoin( // Iterate over the streaming relation. streamedPlan.execute().mapPartitions { stream => - val index = indexBC.value - val streamSideKeys = streamSideKeyGenerator() - val join = new JoinedRow2 // TODO create our own join row... - stream.flatMap { sRow => - // Get the bounds. - val lowHigh = streamSideKeys(sRow) - val low = lowHigh(0) - val high = lowHigh(1) - - // Only allow non-null keys. - if (low != null && high != null) { - index.intersect(low, high).map { bRow => - buildSide match { - case BuildRight => join(sRow, bRow) - case BuildLeft => join(bRow, sRow) + new Iterator[InternalRow] { + private[this] val index = indexBC.value + private[this] val streamSideKeys = streamSideKeyGenerator() + private[this] val join = new JoinedRow2 // TODO create our own join row... + private[this] var row: InternalRow = EmptyRow + private[this] var iterator: Iterator[InternalRow] = Iterator.empty + + override final def hasNext: Boolean = { + var result = iterator.hasNext + while (!result && stream.hasNext) { + row = stream.next() + val lowHigh = streamSideKeys(row) + val low = lowHigh(0) + val high = lowHigh(1) + if (low != null && high != null) { + iterator = index.intersect(low, high) } + result = iterator.hasNext + } + result + } + + override final def next(): InternalRow = { + buildSide match { + case BuildRight => join(row, iterator.next()) + case BuildLeft => join(iterator.next(), row) } } - else Iterator.empty } } } @@ -224,14 +233,13 @@ private[joins] object RangeIndex { // Valid points and intervals. if (low != null && high != null) { val result = cmp.compare(low, high) - val copy = row.copy() // Point if (result == 0) { - (low, 0, copy) :: Nil + (low, 0, row) :: Nil } // Interval else if (result < 0) { - (low, 1, copy) ::(high, -1, copy) :: Nil + (low, 1, row) ::(high, -1, row) :: Nil } // Reversed Interval (low > high) - Cannot join on this record. else Nil @@ -285,13 +293,21 @@ private[joins] class RangeIndex( final def closestLowerKey(value: Any, equalCorrection: Int, first: Int = 0, last: Int = keys.length - 1): Int = { if (first < last) { - val index = first + ((last - first) >> 1) - val key = keys(index) - val cmp = if (key == null) -1 + // Determine the mid point. + val mid = first + ((last - first + 1) >>> 1) + + // Compare the value with the key at the mid point. + // Note that a value is always larger than NULL. + val key = keys(mid) + val cmp = if (key == null) 1 else ordering.compare(value, key) - if (cmp == 0) index - equalCorrection - else if (cmp < 0) closestLowerKey(value, equalCorrection, first, index) - else closestLowerKey(value, equalCorrection, index + 1, last) + + // Value == Key. Keys are unique so we can stop. 
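+ // On an exact hit the returned index is shifted down by the correction; the
+ // caller passes 0 or 1 per bound, which is how inclusive (<= / >=) versus
+ // exclusive (< / >) bounds are realized.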
+ if (cmp == 0) mid - equalCorrection + // Value > Key. + else if (cmp > 0) closestLowerKey(value, equalCorrection, mid, last) + // Value < Key. + else closestLowerKey(value, equalCorrection, first, mid - 1) } else first } From b405e45d931fb04b914858e75e3fa3cb07bc0394 Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Wed, 15 Jul 2015 23:06:05 -0400 Subject: [PATCH 6/8] Added more test and fixed a few boudary related bugs. --- .../execution/joins/BroadcastRangeJoin.scala | 162 ++++++++++++------ .../spark/sql/execution/PlannerSuite.scala | 17 +- .../sql/execution/joins/RangeIndexSuite.scala | 2 +- .../sql/execution/joins/RangeJoinSuite.scala | 119 +++++++++++++ 4 files changed, 244 insertions(+), 56 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/joins/RangeJoinSuite.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastRangeJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastRangeJoin.scala index e67c347df5dd..cce45bad8b34 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastRangeJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastRangeJoin.scala @@ -50,6 +50,8 @@ import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} * TODO NaN values * TODO NULL values * TODO Outer joins? StreamSide is quite easy/BuildSide requires bookkeeping and + * TODO This join will maintain sort order. The build side rows will also be added in a lower + * bound sorted fashion. */ @DeveloperApi case class BroadcastRangeJoin( @@ -164,7 +166,9 @@ private[joins] object RangeIndex { /** Build a range index from an array of unsorted events. */ def build(ordering: Ordering[Any], events: Array[RangeEvent], allowLowEqual: Boolean, allowHighEqual: Boolean): RangeIndex = { - buildFromSorted(ordering, events.sortBy(_._1)(ordering), allowLowEqual, allowHighEqual) + val eventOrdering = Ordering.Tuple2(ordering, Ordering.Int) + val sortedEvents = events.sortBy(e => (e._1, e._2))(eventOrdering) + buildFromSorted(ordering, sortedEvents, allowLowEqual, allowHighEqual) } /** Build a range index from an array of sorted events. */ @@ -172,57 +176,62 @@ private[joins] object RangeIndex { allowLowEqual: Boolean, allowHighEqual: Boolean): RangeIndex = { // Persisted index components. A dummy null value is added to the array. This makes searching // easier and it allows us to deal gracefully with unbound keys. + val empty = Array.empty[InternalRow] val keys = mutable.Buffer[Any](null) - val activatedRows = mutable.Buffer(Array.empty[InternalRow]) - val activeRows = mutable.Buffer(Array.empty[InternalRow]) + val offsets = mutable.Buffer[Int](0) + val activatedRows = mutable.Buffer.empty[InternalRow] + val activeNewOffsets = mutable.Buffer.empty[Int] + val activeRows = mutable.Buffer.empty[Array[InternalRow]] // Current State of the iteration. var currentKey: Any = null - var currentRowCount = 0 - val currentActivatedRows = mutable.Buffer.empty[InternalRow] + var currentActiveNewOffset: Int = -1 val currentActiveRows = mutable.Buffer.empty[InternalRow] - // Store the current key state in the final results. - def finishKey = { - if (currentRowCount > 0) { - keys += currentKey - activatedRows += currentActivatedRows.toArray - activeRows += currentActiveRows.toArray - } + // Store the currently active rows. 
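+ // Snapshots the interval rows still open at the current key, together with
+ // the offset that separates carried-over rows from newly activated ones.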
+ def writeActiveRows = { + activeNewOffsets += currentActiveNewOffset + if (currentActiveRows.isEmpty) activeRows += empty + else activeRows += currentActiveRows.toArray } events.foreach { case (key, flow, row) => // Check if we have finished processing a key if (currentKey != key) { - finishKey + writeActiveRows currentKey = key - currentActivatedRows.clear() - currentRowCount = 0 + currentActiveNewOffset = -1 + keys += key + offsets += activatedRows.size + } + + // Store the offset at which we are starting to add rows to the 'active' buffer. + if (flow >= 0 && currentActiveNewOffset == -1) { + currentActiveNewOffset = currentActiveRows.size } // Keep track of rows. flow match { case 1 => + activatedRows += row currentActiveRows += row - currentActivatedRows += row case 0 => - currentActivatedRows += row + activatedRows += row case -1 => currentActiveRows -= row } - currentRowCount += 1 } - // Store the final events. - finishKey + // Store the final array of activate rows. + writeActiveRows // Determine corrections based on equality val lowBoundEqualCorrection = if (allowLowEqual) 1 else 0 val highBoundEqualCorrection = if (allowHighEqual) 0 else 1 // Create the index. - new RangeIndex(ordering, keys.toArray, activeRows.toArray, activatedRows.toArray, - lowBoundEqualCorrection, highBoundEqualCorrection) + new RangeIndex(ordering, keys.toArray, offsets.toArray, activeNewOffsets.toArray, + activeRows.toArray, activatedRows.toArray, lowBoundEqualCorrection, highBoundEqualCorrection) } /** Create a function that turns a row into its respective range events. */ @@ -262,6 +271,7 @@ private[joins] object RangeIndex { * or activated rows. * @param active contains the rows that are 'active' between the current key and the next key. * This is only used for rows that span an interval. + * @param activeNewOffsets array contains the index at which the rows in the active array are new. * @param activated contains the row that have been activated at the current key. This contains * both rows that span an interval or only exist at one point * @param lowBoundEqualCorrection correction to apply to the lower bound index, when the value @@ -272,11 +282,15 @@ private[joins] object RangeIndex { private[joins] class RangeIndex( private[this] val ordering: Ordering[Any], private[this] val keys: Array[Any], + private[this] val offsets: Array[Int], + private[this] val activeNewOffsets: Array[Int], private[this] val active: Array[Array[InternalRow]], - private[this] val activated: Array[Array[InternalRow]], + private[this] val activated: Array[InternalRow], private[this] val lowBoundEqualCorrection: Int, private[this] val highBoundEqualCorrection: Int) extends Serializable { + private[this] val maxKeyIndex = keys.length - 1 + /** * Find the index of the closest key lower than or equal to the value given. When a value is * equal to the found key, the result is corrected. @@ -291,24 +305,24 @@ private[joins] class RangeIndex( */ @tailrec final def closestLowerKey(value: Any, equalCorrection: Int, - first: Int = 0, last: Int = keys.length - 1): Int = { - if (first < last) { - // Determine the mid point. - val mid = first + ((last - first + 1) >>> 1) - - // Compare the value with the key at the mid point. - // Note that a value is always larger than NULL. - val key = keys(mid) - val cmp = if (key == null) 1 - else ordering.compare(value, key) - - // Value == Key. Keys are unique so we can stop. - if (cmp == 0) mid - equalCorrection - // Value > Key. 
- else if (cmp > 0) closestLowerKey(value, equalCorrection, mid, last) - // Value < Key. - else closestLowerKey(value, equalCorrection, first, mid - 1) - } else first + first: Int = 0, last: Int = maxKeyIndex): Int = { + // Determine the mid point. + val mid = first + ((last - first + 1) >>> 1) + + // Compare the value with the key at the mid point. + // Note that a value is always larger than NULL. + val key = keys(mid) + val cmp = if (key == null) 1 + else ordering.compare(value, key) + + // Value == Key. Keys are unique so we can stop. + if (cmp == 0) mid - equalCorrection + // No more elements left to search. + else if (first == last) mid + // Value > Key: Search the top half of the key array. + else if (cmp > 0) closestLowerKey(value, equalCorrection, mid, last) + // Value < Key: Search the lower half of the array. + else closestLowerKey(value, equalCorrection, first, mid - 1) } /** @@ -326,32 +340,72 @@ private[joins] class RangeIndex( else closestLowerKey(low, lowBoundEqualCorrection) // Find last index by searching for the last lower than the high value. - val last = if (high == null || first == keys.length - 1) keys.length - 1 + val last = if (high == null || first == maxKeyIndex) maxKeyIndex else closestLowerKey(high, highBoundEqualCorrection, first) - // Return the iterator. new Iterator[InternalRow] { - var index = first + var activatedAvailable = first < last var rowIndex = 0 - var rows = active(index) - - def hasNext: Boolean = { - var result = rows != null && rowIndex < rows.length - while (!result && index < last) { - index += 1 - rowIndex = 0 - rows = activated(index) - result = rows != null && rowIndex < rows.length + var rows = active(first) + var rowLength = if (first <= last) rows.length + else activeNewOffsets(first) + + override final def hasNext: Boolean = { + var result = rowIndex < rowLength + if (!result && activatedAvailable) { + activatedAvailable = false + rows = activated + rowIndex = offsets(first + 1) + rowLength = if (last == maxKeyIndex) activated.length + else offsets(last + 1) + result = rowIndex < rowLength } result } - def next() = { + override final def next(): InternalRow = { val row = rows(rowIndex) rowIndex += 1 row } } } + + /** + * Create a textual representation of the index for debugging purposes. + * + * @param maxKeys maximum number of keys shows in the string. + * @return a textual representation of the index for debugging purposes. 
+ */ + def toDebugString(maxKeys: Int = Int.MaxValue): String = { + val builder = new StringBuilder + builder.append("Index[lowBoundEqualCorrection = ") + builder.append(lowBoundEqualCorrection) + builder.append(", highBoundEqualCorrection = ") + builder.append(highBoundEqualCorrection) + builder.append("]") + val keysShown = math.min(keys.length, maxKeys) + val keysLeft = keys.length - keysShown + for (i <- 0 until keysShown) { + builder.append("\n +[") + builder.append(keys(i)) + builder.append("]@") + builder.append(offsets(i)) + builder.append("\n | Active: ") + builder.append(active(i).mkString(",")) + builder.append("\n | Activated: ") + val nextOffset = if (i == maxKeyIndex) activated.length + else offsets(i + 1) + builder.append(activated.slice(offsets(i), nextOffset).mkString(",")) + } + if (keysLeft > 0) { + builder.append("\n (") + builder.append(keysLeft) + builder.append(" keys left)") + } + builder.toString + } + + override def toString: String = toDebugString(10) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index 3dd24130af81..c3ae045be0a2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution import org.apache.spark.SparkFunSuite import org.apache.spark.sql.TestData._ import org.apache.spark.sql.catalyst.plans._ -import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, ShuffledHashJoin} +import org.apache.spark.sql.execution.joins.{BroadcastRangeJoin, BroadcastHashJoin, ShuffledHashJoin} import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.TestSQLContext._ import org.apache.spark.sql.test.TestSQLContext.implicits._ @@ -147,4 +147,19 @@ class PlannerSuite extends SparkFunSuite { val planned = planner.TakeOrderedAndProject(query) assert(planned.head.isInstanceOf[execution.TakeOrderedAndProject]) } + + test("range query join condition and broadcastable table will use RangeJoin optimization") { + val origRangeJoinSetting = conf.rangeJoinEnabled + setConf(SQLConf.RANGE_JOIN, true) + + val interval = Seq((1, 20), (30, 70)).toDF("low", "high") + val planned = testData. + join(broadcast(interval), $"low" <= $"key" && $"key" < $"high", "inner"). + queryExecution. 
+ executedPlan + val broadcastRangeJoins = planned.collect{ case j: BroadcastRangeJoin => j } + assert(broadcastRangeJoins.size == 1, "Should use broadcast range join") + + setConf(SQLConf.RANGE_JOIN, origRangeJoinSetting) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/RangeIndexSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/RangeIndexSuite.scala index c6df5c83c273..42db1f8eb881 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/RangeIndexSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/RangeIndexSuite.scala @@ -95,7 +95,7 @@ class RangeIndexSuite extends SparkFunSuite { // Points assertResult(Nil)(index1.intersect(2, 2).toSeq) - assertResult(r3 :: r5 :: Nil)(index1.intersect(3, 3).toSeq) + assertResult(r5 :: Nil)(index1.intersect(4, 4).toSeq) // Low Bound Included - High Bound Included: val index2 = RangeIndex.build(ordering, rs, allowLowEqual = true, allowHighEqual = true) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/RangeJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/RangeJoinSuite.scala new file mode 100644 index 000000000000..4eff223724c4 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/RangeJoinSuite.scala @@ -0,0 +1,119 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.sql.execution.joins + +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute +import org.apache.spark.sql.execution.{SparkPlan, SparkPlanTest} + +class RangeJoinSuite extends SparkPlanTest { + val intervals1 = Seq( + (-1, 0), + (0, 1), + (0, 2), + (1, 5) + ).toDF("low1", "high1") + val intervalKeys1 = Seq("low1", "high1").map(UnresolvedAttribute.apply) + + val intervals2 = Seq( + (-2, -1), + (-4, -2), + (1 ,3), + (5, 7) + ).toDF("low2", "high2") + val intervalKeys2 = Seq("low2", "high2").map(UnresolvedAttribute.apply) + + val points = Seq(-3, 1, 3, 6).map(Tuple1.apply).toDF("point") + val pointKeys = Seq("point", "point").map(UnresolvedAttribute.apply) + + test("interval-point range join") { + // low1 <= point && point < high1 + checkAnswer2(intervals1, points, (left: SparkPlan, right: SparkPlan) => + BroadcastRangeJoin(intervalKeys1, pointKeys, true :: false :: Nil, BuildRight, left, right), + Seq( + (0, 2, 1), + (1, 5, 1), + (1, 5, 3) + ).map(Row.fromTuple)) + + // low1 <= point && point < high1 + checkAnswer2(intervals1, points, (left: SparkPlan, right: SparkPlan) => + BroadcastRangeJoin(intervalKeys1, pointKeys, false :: false :: Nil, BuildRight, left, right), + Seq( + (0, 2, 1), + (1, 5, 3) + ).map(Row.fromTuple)) + + // low <= point && point <= high1 + checkAnswer2(points, intervals1, (left: SparkPlan, right: SparkPlan) => + BroadcastRangeJoin(pointKeys, intervalKeys1, true :: true :: Nil, BuildRight, left, right), + Seq( + (1, 0, 1), + (1, 0, 2), + (1, 1, 5), + (3, 1, 5) + ).map(Row.fromTuple)) + + // low1 < point && point < high1 + checkAnswer2(intervals1, points, (left: SparkPlan, right: SparkPlan) => + BroadcastRangeJoin(intervalKeys1, pointKeys, false :: false :: Nil, BuildLeft, left, right), + Seq( + (0, 2, 1), + (1, 5, 3) + ).map(Row.fromTuple)) + } + + test("interval-interval range join") { + // low1 <= high2 && low2 < high1 + checkAnswer2(intervals1, intervals2, (left: SparkPlan, right: SparkPlan) => + BroadcastRangeJoin(intervalKeys1, intervalKeys2, true :: false :: Nil, BuildRight, left, right), + Seq( + (-1, 0, -2, -1), + (0, 2, 1, 3), + (1, 5, 1, 3) + ).map(Row.fromTuple)) + + // low1 < high2 && low2 <= high1 + checkAnswer2(intervals1, intervals2, (left: SparkPlan, right: SparkPlan) => + BroadcastRangeJoin(intervalKeys1, intervalKeys2, false :: true :: Nil, BuildLeft, left, right), + Seq( + (0, 1, 1, 3), + (0, 2, 1, 3), + (1, 5, 1, 3), + (1, 5, 5, 7) + ).map(Row.fromTuple)) + + // low1 < high2 && low2 < high1 + checkAnswer2(intervals1, intervals2, (left: SparkPlan, right: SparkPlan) => + BroadcastRangeJoin(intervalKeys1, intervalKeys2, false :: false :: Nil, BuildRight, left, right), + Seq( + (0, 2, 1, 3), + (1, 5, 1, 3) + ).map(Row.fromTuple)) + + // low1 <= high2 && low2 <= high1 + checkAnswer2(intervals1, intervals2, (left: SparkPlan, right: SparkPlan) => + BroadcastRangeJoin(intervalKeys1, intervalKeys2, true :: true :: Nil, BuildLeft, left, right), + Seq( + (-1, 0, -2, -1), + (0, 1, 1, 3), + (0, 2, 1, 3), + (1, 5, 1, 3), + (1, 5, 5, 7) + ).map(Row.fromTuple)) + } +} From 9cc1c7ef3da1058f18f82016733a9bd28e5cdb9a Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Wed, 15 Jul 2015 23:29:23 -0400 Subject: [PATCH 7/8] scalastyle --- .../sql/execution/joins/RangeJoinSuite.scala | 64 ++++++++++++++++--- 1 file changed, 56 insertions(+), 8 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/RangeJoinSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/RangeJoinSuite.scala index 4eff223724c4..b2e79ba80c27 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/RangeJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/RangeJoinSuite.scala @@ -43,7 +43,13 @@ class RangeJoinSuite extends SparkPlanTest { test("interval-point range join") { // low1 <= point && point < high1 checkAnswer2(intervals1, points, (left: SparkPlan, right: SparkPlan) => - BroadcastRangeJoin(intervalKeys1, pointKeys, true :: false :: Nil, BuildRight, left, right), + BroadcastRangeJoin( + intervalKeys1, + pointKeys, + true :: false :: Nil, + BuildRight, + left, + right), Seq( (0, 2, 1), (1, 5, 1), @@ -52,7 +58,13 @@ class RangeJoinSuite extends SparkPlanTest { // low1 <= point && point < high1 checkAnswer2(intervals1, points, (left: SparkPlan, right: SparkPlan) => - BroadcastRangeJoin(intervalKeys1, pointKeys, false :: false :: Nil, BuildRight, left, right), + BroadcastRangeJoin( + intervalKeys1, + pointKeys, + false :: false :: Nil, + BuildRight, + left, + right), Seq( (0, 2, 1), (1, 5, 3) @@ -60,7 +72,13 @@ class RangeJoinSuite extends SparkPlanTest { // low <= point && point <= high1 checkAnswer2(points, intervals1, (left: SparkPlan, right: SparkPlan) => - BroadcastRangeJoin(pointKeys, intervalKeys1, true :: true :: Nil, BuildRight, left, right), + BroadcastRangeJoin( + pointKeys, + intervalKeys1, + true :: true :: Nil, + BuildRight, + left, + right), Seq( (1, 0, 1), (1, 0, 2), @@ -70,7 +88,13 @@ class RangeJoinSuite extends SparkPlanTest { // low1 < point && point < high1 checkAnswer2(intervals1, points, (left: SparkPlan, right: SparkPlan) => - BroadcastRangeJoin(intervalKeys1, pointKeys, false :: false :: Nil, BuildLeft, left, right), + BroadcastRangeJoin( + intervalKeys1, + pointKeys, + false :: false :: Nil, + BuildLeft, + left, + right), Seq( (0, 2, 1), (1, 5, 3) @@ -80,7 +104,13 @@ class RangeJoinSuite extends SparkPlanTest { test("interval-interval range join") { // low1 <= high2 && low2 < high1 checkAnswer2(intervals1, intervals2, (left: SparkPlan, right: SparkPlan) => - BroadcastRangeJoin(intervalKeys1, intervalKeys2, true :: false :: Nil, BuildRight, left, right), + BroadcastRangeJoin( + intervalKeys1, + intervalKeys2, + true :: false :: Nil, + BuildRight, + left, + right), Seq( (-1, 0, -2, -1), (0, 2, 1, 3), @@ -89,7 +119,13 @@ class RangeJoinSuite extends SparkPlanTest { // low1 < high2 && low2 <= high1 checkAnswer2(intervals1, intervals2, (left: SparkPlan, right: SparkPlan) => - BroadcastRangeJoin(intervalKeys1, intervalKeys2, false :: true :: Nil, BuildLeft, left, right), + BroadcastRangeJoin( + intervalKeys1, + intervalKeys2, + false :: true :: Nil, + BuildLeft, + left, + right), Seq( (0, 1, 1, 3), (0, 2, 1, 3), @@ -99,7 +135,13 @@ class RangeJoinSuite extends SparkPlanTest { // low1 < high2 && low2 < high1 checkAnswer2(intervals1, intervals2, (left: SparkPlan, right: SparkPlan) => - BroadcastRangeJoin(intervalKeys1, intervalKeys2, false :: false :: Nil, BuildRight, left, right), + BroadcastRangeJoin( + intervalKeys1, + intervalKeys2, + false :: false :: Nil, + BuildRight, + left, + right), Seq( (0, 2, 1, 3), (1, 5, 1, 3) @@ -107,7 +149,13 @@ class RangeJoinSuite extends SparkPlanTest { // low1 <= high2 && low2 <= high1 checkAnswer2(intervals1, intervals2, (left: SparkPlan, right: SparkPlan) => - BroadcastRangeJoin(intervalKeys1, intervalKeys2, true :: true :: Nil, BuildLeft, left, right), + BroadcastRangeJoin( + intervalKeys1, + 
intervalKeys2, + true :: true :: Nil, + BuildLeft, + left, + right), Seq( (-1, 0, -2, -1), (0, 1, 1, 3), From 8204eaed1b9399f17415afc6ce178c845f29746f Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Wed, 15 Jul 2015 23:44:29 -0400 Subject: [PATCH 8/8] scalastyle... --- .../org/apache/spark/sql/execution/joins/RangeJoinSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/RangeJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/RangeJoinSuite.scala index b2e79ba80c27..9b3c5544999a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/RangeJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/RangeJoinSuite.scala @@ -32,7 +32,7 @@ class RangeJoinSuite extends SparkPlanTest { val intervals2 = Seq( (-2, -1), (-4, -2), - (1 ,3), + (1, 3), (5, 7) ).toDF("low2", "high2") val intervalKeys2 = Seq("low2", "high2").map(UnresolvedAttribute.apply)
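Taken together, the series gates the new operator behind the spark.sql.planner.rangeJoin flag (default false) and plans a BroadcastRangeJoin whenever an inner join carries a pure range condition and one side can be broadcast. A minimal end-to-end sketch, mirroring the PlannerSuite test above; the DataFrame contents and the sqlContext value are illustrative assumptions rather than part of the patch:

  import org.apache.spark.sql.functions.broadcast
  import sqlContext.implicits._

  // Opt in: the BroadcastRangeJoin strategy only fires when this flag is set.
  sqlContext.setConf("spark.sql.planner.rangeJoin", "true")

  val points = (1 to 100).map(Tuple1.apply).toDF("key")
  val intervals = Seq((1, 20), (30, 70)).toDF("low", "high")

  // An inner join on a pure range condition; broadcasting the small side lets
  // ExtractRangeJoinKeys match the predicate and BroadcastRangeJoin replace
  // the BroadcastNestedLoopJoin that would otherwise be planned.
  val joined = points.join(broadcast(intervals),
    $"low" <= $"key" && $"key" < $"high", "inner")
  joined.show()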