Commit 04caba5

apache#73 support limit offset
1 parent 4b21f59 commit 04caba5

4 files changed: +173 −0 lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala

Lines changed: 29 additions & 0 deletions
@@ -823,6 +823,35 @@ case class LocalLimit(limitExpr: Expression, child: LogicalPlan) extends OrderPr
   }
 }
 
+object LimitRange {
+  def apply(startExpr: Expression, endExpr: Expression, child: LogicalPlan): UnaryNode = {
+    LimitRange0(startExpr, endExpr, LocalLimit(endExpr, child))
+  }
+
+  def unapply(p: LimitRange0): Option[(Expression, Expression, LogicalPlan)] = {
+    p match {
+      case LimitRange0(le0, le1, LocalLimit(le2, child)) if le1 == le2 => Some((le0, le1, child))
+      case _ => None
+    }
+  }
+}
+
+/**
+ * A global (coordinated) range limit. This operator emits at most `endExpr - startExpr` rows
+ * in total.
+ *
+ * See [[Limit]] for more information.
+ */
+case class LimitRange0(startExpr: Expression, endExpr: Expression, child: LogicalPlan)
+  extends OrderPreservingUnaryNode {
+  override def output: Seq[Attribute] = child.output
+  override def maxRows: Option[Long] = {
+    (startExpr, endExpr) match {
+      case (IntegerLiteral(start), IntegerLiteral(end)) => Some(end - start)
+      case _ => None
+    }
+  }
+}
+
 /**
  * Aliased subquery.
  *
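Review note: the facade above is easiest to see on a concrete plan. A minimal sketch, assuming the Catalyst classes from this file are on the classpath; OneRowRelation() merely stands in for any child plan.

    import org.apache.spark.sql.catalyst.expressions.Literal
    import org.apache.spark.sql.catalyst.plans.logical._

    val child = OneRowRelation()
    // apply() pushes a per-partition LocalLimit(end) under the global range node:
    val plan = LimitRange(Literal(10), Literal(20), child)
    // plan == LimitRange0(Literal(10), Literal(20), LocalLimit(Literal(20), child))

    // unapply() only matches when the inner LocalLimit carries the same end expression:
    plan match {
      case LimitRange(s, e, c) => println(s"range [$s, $e) over $c")
      case _ => println("no LimitRange here")
    }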

sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala

Lines changed: 4 additions & 0 deletions
@@ -1809,6 +1809,10 @@ class Dataset[T] private[sql](
     Limit(Literal(n), logicalPlan)
   }
 
+  def limitRange(start: Int, end: Int): Dataset[T] = withTypedPlan {
+    LimitRange(Literal(start), Literal(end), logicalPlan)
+  }
+
   /**
    * Returns a new Dataset containing union of rows in this Dataset and another Dataset.
    *
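Usage sketch for the new Dataset API (hypothetical local session; limitRange(start, end) keeps the rows at positions [start, end), i.e. SQL's LIMIT (end - start) OFFSET start):

    import org.apache.spark.sql.SparkSession

    val spark = SparkSession.builder().master("local[*]").appName("limitRange").getOrCreate()
    import spark.implicits._

    val page = spark.range(0, 100).toDF("id").orderBy($"id").limitRange(10, 20)
    page.show()  // ids 10 through 19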

sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala

Lines changed: 18 additions & 0 deletions
@@ -74,8 +74,17 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
       case Limit(IntegerLiteral(limit), Project(projectList, Sort(order, true, child)))
           if limit < conf.topKSortFallbackThreshold =>
         TakeOrderedAndProjectExec(limit, order, projectList, planLater(child)) :: Nil
+      case LimitRange(IntegerLiteral(start), IntegerLiteral(end),
+          Project(projectList, Sort(order, true, child))) if end < conf.topKSortFallbackThreshold =>
+        TakeOrderedRangeAndProjectExec(start, end, order, projectList, planLater(child)) :: Nil
+      case LimitRange(IntegerLiteral(start), IntegerLiteral(end), Sort(order, true, child))
+          if end < conf.topKSortFallbackThreshold =>
+        TakeOrderedRangeAndProjectExec(start, end, order, child.output, planLater(child)) :: Nil
       case Limit(IntegerLiteral(limit), child) =>
         CollectLimitExec(limit, planLater(child)) :: Nil
+      case LimitRange(IntegerLiteral(start), IntegerLiteral(limit), child) =>
+        CollectLimitRangeExec(start, limit, planLater(child)) :: Nil
       case other => planLater(other) :: Nil
     }

@@ -84,6 +93,12 @@ case Limit(IntegerLiteral(limit), Sort(order, true, child))
       case Limit(IntegerLiteral(limit), Project(projectList, Sort(order, true, child)))
           if limit < conf.topKSortFallbackThreshold =>
         TakeOrderedAndProjectExec(limit, order, projectList, planLater(child)) :: Nil
+      case LimitRange(IntegerLiteral(start), IntegerLiteral(end),
+          Project(projectList, Sort(order, true, child))) if end < conf.topKSortFallbackThreshold =>
+        TakeOrderedRangeAndProjectExec(start, end, order, projectList, planLater(child)) :: Nil
+      case LimitRange(IntegerLiteral(start), IntegerLiteral(end), Sort(order, true, child))
+          if end < conf.topKSortFallbackThreshold =>
+        TakeOrderedRangeAndProjectExec(start, end, order, child.output, planLater(child)) :: Nil
       case _ => Nil
     }
   }

@@ -617,6 +632,9 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
         execution.LocalLimitExec(limit, planLater(child)) :: Nil
       case logical.GlobalLimit(IntegerLiteral(limit), child) =>
         execution.GlobalLimitExec(limit, planLater(child)) :: Nil
+      case logical.LimitRange(IntegerLiteral(start), IntegerLiteral(limit), child) =>
+        execution.RangeLimitExec(start, limit, planLater(child)) :: Nil
       case logical.Union(unionChildren) =>
         execution.UnionExec(unionChildren.map(planLater)) :: Nil
       case g @ logical.Generate(generator, _, outer, _, _, child) =>
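Which of the new physical operators the planner picks is easiest to check with explain(). A sketch continuing the session above; the plan text is illustrative, not captured from a run:

    // Ordered range below spark.sql.execution.topKSortFallbackThreshold
    // -> the single-pass ordered operator:
    spark.range(0, 100).toDF("id").orderBy($"id").limitRange(10, 20).explain()
    // == Physical Plan ==
    // TakeOrderedRangeAndProject(start=10, end=20, orderBy=[id ASC NULLS FIRST], output=[id])
    // +- ...

    // Unordered range as the final operator -> CollectLimitRangeExec:
    spark.range(0, 100).limitRange(10, 20).explain()
    // == Physical Plan ==
    // CollectLimitRange
    // +- ...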

sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala

Lines changed: 122 additions & 0 deletions
@@ -46,6 +46,26 @@ case class CollectLimitExec(limit: Int, child: SparkPlan) extends UnaryExecNode
   }
 }
 
+/**
+ * Take the elements in the range [`start`, `end`) and collect them to a single partition.
+ *
+ * This operator will be used when a logical `LimitRange` operation is the final operator in a
+ * logical plan, which happens when the user is collecting results back to the driver.
+ */
+case class CollectLimitRangeExec(start: Int, end: Int, child: SparkPlan) extends UnaryExecNode {
+  override def output: Seq[Attribute] = child.output
+  override def outputPartitioning: Partitioning = SinglePartition
+  // Drop the first `start` rows so the driver-side fast path honors the offset too.
+  override def executeCollect(): Array[InternalRow] = child.executeTake(end).drop(start)
+  private val serializer: Serializer = new UnsafeRowSerializer(child.output.size)
+  protected override def doExecute(): RDD[InternalRow] = {
+    val locallyLimited = child.execute().mapPartitionsInternal(_.take(end))
+    val shuffled = new ShuffledRowRDD(
+      ShuffleExchangeExec.prepareShuffleDependency(
+        locallyLimited, child.output, SinglePartition, serializer))
+    shuffled.mapPartitionsInternal(_.slice(start, end))
+  }
+}
+
 /**
  * Helper trait which defines methods that are shared by both
  * [[LocalLimitExec]] and [[GlobalLimitExec]].
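The execution path of CollectLimitRangeExec in miniature: every partition keeps at most end rows, the survivors are merged into a single partition, and the merged stream is sliced to [start, end). A plain-collections sketch of the same logic:

    val (start, end) = (10, 20)
    val partitions = Seq(0 until 8, 8 until 30, 30 until 50)   // three stand-in partitions

    val locallyLimited = partitions.map(_.iterator.take(end))  // per-partition take(end)
    val merged = locallyLimited.reduce(_ ++ _)                 // stands in for the shuffle
    val result = merged.slice(start, end).toVector             // Vector(10, 11, ..., 19)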
@@ -114,6 +134,43 @@ case class GlobalLimitExec(limit: Int, child: SparkPlan) extends BaseLimitExec {
   override def outputOrdering: Seq[SortOrder] = child.outputOrdering
 }
 
+/**
+ * Take the elements in the range [`start`, `limit`) of the child's single output partition.
+ */
+case class RangeLimitExec(start: Int, limit: Int, child: SparkPlan) extends BaseLimitExec {
+
+  override def requiredChildDistribution: List[Distribution] = AllTuples :: Nil
+
+  override def outputPartitioning: Partitioning = child.outputPartitioning
+
+  override def outputOrdering: Seq[SortOrder] = child.outputOrdering
+
+  override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
+    val stopEarly =
+      ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "stopEarly") // init as stopEarly = false
+
+    ctx.addNewFunction("stopEarly", s"""
+      @Override
+      protected boolean stopEarly() {
+        return $stopEarly;
+      }
+    """, inlineToOuterClass = true)
+    val countTerm = ctx.addMutableState(CodeGenerator.JAVA_INT, "count") // init as count = 0
+    s"""
+       | $countTerm += 1;
+       | if ($countTerm > $start && $countTerm <= $limit) {
+       |   ${consume(ctx, input)}
+       | } else if ($countTerm > $limit) {
+       |   $stopEarly = true;
+       | }
+     """.stripMargin
+  }
+
+  protected override def doExecute(): RDD[InternalRow] = child.execute().mapPartitions { iter =>
+    iter.slice(start, limit)
+  }
+}
+
 /**
  * Take the first limit elements as defined by the sortOrder, and do projection if needed.
  * This is logically equivalent to having a Limit operator after a [[SortExec]] operator,
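What the generated per-row code in doConsume does, re-expressed in plain Scala (a sketch; the real operator emits Java through CodegenContext): count each incoming row, emit only those with start < count <= limit, and raise stopEarly once count passes limit so the produced loop can quit the scan:

    var count = 0
    var stopEarly = false
    val (start, limit) = (10, 20)
    val out = scala.collection.mutable.ArrayBuffer.empty[Int]

    for (row <- 1 to 100 if !stopEarly) {
      count += 1
      if (count > start && count <= limit) out += row  // stands in for consume(ctx, input)
      else if (count > limit) stopEarly = true         // short-circuits the rest of the scan
    }
    // out == ArrayBuffer(11, 12, ..., 20): the rows after the first `start`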
@@ -176,3 +233,68 @@ case class TakeOrderedAndProjectExec(
     s"TakeOrderedAndProject(limit=$limit, orderBy=$orderByString, output=$outputString)"
   }
 }
+/**
+ * Take the elements in the range [`start`, `end`) as defined by the sortOrder, and do projection
+ * if needed. This is logically equivalent to having a LimitRange operator after a [[SortExec]]
+ * operator, or having a [[ProjectExec]] operator between them.
+ * This could have been named TopK, but Spark's top operator does the opposite in ordering
+ * so we name it TakeOrdered to avoid confusion.
+ */
+case class TakeOrderedRangeAndProjectExec(
+    start: Int,
+    end: Int,
+    sortOrder: Seq[SortOrder],
+    projectList: Seq[NamedExpression],
+    child: SparkPlan) extends UnaryExecNode {
+
+  override def output: Seq[Attribute] = {
+    projectList.map(_.toAttribute)
+  }
+
+  override def executeCollect(): Array[InternalRow] = {
+    val ord = new LazilyGeneratedOrdering(sortOrder, child.output)
+    val data = child.execute().map(_.copy()).takeOrdered(end)(ord).drop(start)
+    if (projectList != child.output) {
+      val proj = UnsafeProjection.create(projectList, child.output)
+      data.map(r => proj(r).copy())
+    } else {
+      data
+    }
+  }
+
+  private val serializer: Serializer = new UnsafeRowSerializer(child.output.size)
+
+  protected override def doExecute(): RDD[InternalRow] = {
+    val ord = new LazilyGeneratedOrdering(sortOrder, child.output)
+    // Phase 1: each partition keeps its `end` smallest rows under `ord`.
+    val localTopK: RDD[InternalRow] = {
+      child.execute().map(_.copy()).mapPartitions { iter =>
+        org.apache.spark.util.collection.Utils.takeOrdered(iter, end)(ord)
+      }
+    }
+    // Phase 2: shuffle the candidates to one partition, take the global top `end`,
+    // then drop the first `start` rows to apply the offset.
+    val shuffled = new ShuffledRowRDD(
+      ShuffleExchangeExec.prepareShuffleDependency(
+        localTopK, child.output, SinglePartition, serializer))
+    shuffled.mapPartitions { iter =>
+      val topK = org.apache.spark.util.collection.Utils.takeOrdered(iter.map(_.copy()), end)(ord)
+        .drop(start)
+      if (projectList != child.output) {
+        val proj = UnsafeProjection.create(projectList, child.output)
+        topK.map(r => proj(r))
+      } else {
+        topK
+      }
+    }
+  }
+
+  override def outputOrdering: Seq[SortOrder] = sortOrder
+
+  override def outputPartitioning: Partitioning = SinglePartition
+
+  override def simpleString: String = {
+    val orderByString = Utils.truncatedString(sortOrder, "[", ",", "]")
+    val outputString = Utils.truncatedString(output, "[", ",", "]")
+
+    s"TakeOrderedRangeAndProject" +
+      s"(start=$start, end=$end, orderBy=$orderByString, output=$outputString)"
+  }
+}
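The two-phase shape of doExecute above, in miniature: each partition keeps its end smallest rows, the candidates are shuffled into one partition, the global end smallest are taken again, and the first start are dropped. A plain-collections sketch:

    val ord = Ordering.Int
    val (start, end) = (2, 5)
    val parts = Seq(Seq(5, 1, 9, 3), Seq(2, 8, 4), Seq(7, 0, 6))

    val localTopK  = parts.map(_.sorted(ord).take(end))        // per-partition top-k
    val globalTopK = localTopK.flatten.sorted(ord).take(end)   // after the one-partition shuffle
    val result     = globalTopK.drop(start)                    // List(2, 3, 4)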
