@@ -46,6 +46,26 @@ case class CollectLimitExec(limit: Int, child: SparkPlan) extends UnaryExecNode
   }
 }
 
+/**
+ * Take the elements in the range [start, end) and collect them to a single partition.
+ *
+ * This operator will be used when a logical `Limit` operation is the final operator in a
+ * logical plan, which happens when the user is collecting results back to the driver.
+ */
+case class CollectLimitRangeExec(start: Int, end: Int, child: SparkPlan) extends UnaryExecNode {
+  override def output: Seq[Attribute] = child.output
+  override def outputPartitioning: Partitioning = SinglePartition
+  // Take the first `end` rows, then drop the first `start` to leave the range.
+  override def executeCollect(): Array[InternalRow] = child.executeTake(end).drop(start)
+  private val serializer: Serializer = new UnsafeRowSerializer(child.output.size)
+  protected override def doExecute(): RDD[InternalRow] = {
+    // Each partition keeps at most `end` rows; no later row can fall in the range.
+    val locallyLimited = child.execute().mapPartitionsInternal(_.take(end))
+    val shuffled = new ShuffledRowRDD(
+      ShuffleExchangeExec.prepareShuffleDependency(
+        locallyLimited, child.output, SinglePartition, serializer))
+    shuffled.mapPartitionsInternal(_.slice(start, end))
+  }
+}
+
 /**
  * Helper trait which defines methods that are shared by both
  * [[LocalLimitExec]] and [[GlobalLimitExec]].
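
To make the two-phase evaluation in CollectLimitRangeExec concrete, here is a minimal sketch of the same logic on plain Scala collections, outside Spark. The partition contents and the (start, end) values are illustrative assumptions, and the sketch assumes the shuffled output is concatenated in partition order:

    // Take rows in the global range [start, end) across several "partitions".
    val start = 2
    val end = 5
    val partitions = Seq(Seq(1, 2, 3, 4), Seq(5, 6, 7), Seq(8, 9))

    // Phase 1: each partition keeps at most `end` rows locally; a row past
    // position `end` in any partition can never fall inside the global range.
    val locallyLimited = partitions.map(_.take(end))

    // Phase 2: concatenate into a single "partition" and slice out the range.
    val result = locallyLimited.flatten.iterator.slice(start, end).toList
    // result == List(3, 4, 5)
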
@@ -114,6 +134,43 @@ case class GlobalLimitExec(limit: Int, child: SparkPlan) extends BaseLimitExec {
   override def outputOrdering: Seq[SortOrder] = child.outputOrdering
 }
 
+/**
+ * Take the elements in the range [start, limit) from the child's single output partition.
+ */
+case class RangeLimitExec(start: Int, limit: Int, child: SparkPlan) extends BaseLimitExec {
+
+  override def requiredChildDistribution: List[Distribution] = AllTuples :: Nil
+
+  override def outputPartitioning: Partitioning = child.outputPartitioning
+
+  override def outputOrdering: Seq[SortOrder] = child.outputOrdering
+
+  override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
+    val stopEarly =
+      ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "stopEarly") // init as stopEarly = false
+
+    ctx.addNewFunction("stopEarly", s"""
+      @Override
+      protected boolean stopEarly() {
+        return $stopEarly;
+      }
+    """, inlineToOuterClass = true)
+    val countTerm = ctx.addMutableState(CodeGenerator.JAVA_INT, "count") // init as count = 0
+    // Consume rows whose 1-based position falls in (start, limit]; once the
+    // counter passes `limit`, flag stopEarly so the produce loop terminates.
+    s"""
+       | $countTerm += 1;
+       | if ($countTerm > $start && $countTerm <= $limit) {
+       |   ${consume(ctx, input)}
+       | } else if ($countTerm > $limit) {
+       |   $stopEarly = true;
+       | }
+     """.stripMargin
+  }
+
+  protected override def doExecute(): RDD[InternalRow] = child.execute().mapPartitions { iter =>
+    iter.slice(start, limit)
+  }
+}
+
 /**
  * Take the first limit elements as defined by the sortOrder, and do projection if needed.
  * This is logically equivalent to having a Limit operator after a [[SortExec]] operator,
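
The doConsume template above drives whole-stage codegen: each produced row increments a counter, rows whose 1-based position falls in (start, limit] are consumed, and the generated stopEarly() hook lets the produce loop terminate once the range is exhausted. A minimal Scala sketch of that per-row check, with illustrative values start = 2 and limit = 5:

    var count = 0
    var stopEarly = false
    val consumed = scala.collection.mutable.ArrayBuffer.empty[Int]

    // `takeWhile` plays the role of the generated stopEarly() check.
    for (row <- (1 to 100).iterator.takeWhile(_ => !stopEarly)) {
      count += 1
      if (count > 2 && count <= 5) consumed += row // rows 3, 4, 5 pass through
      else if (count > 5) stopEarly = true         // stop scanning past the range
    }
    // consumed == ArrayBuffer(3, 4, 5); the scan stops after row 6
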
@@ -176,3 +233,68 @@ case class TakeOrderedAndProjectExec(
176233 s " TakeOrderedAndProject(limit= $limit, orderBy= $orderByString, output= $outputString) "
177234 }
178235}
+/**
+ * Take the elements in the range [start, end) as defined by the sortOrder, and do
+ * projection if needed. This is logically equivalent to having a Limit operator after
+ * a [[SortExec]] operator, or having a [[ProjectExec]] operator between them.
+ * This could have been named TopK, but Spark's top operator does the opposite in ordering
+ * so we name it TakeOrdered to avoid confusion.
+ */
+case class TakeOrderedRangeAndProjectExec(
+    start: Int,
+    end: Int,
+    sortOrder: Seq[SortOrder],
+    projectList: Seq[NamedExpression],
+    child: SparkPlan) extends UnaryExecNode {
+
+  override def output: Seq[Attribute] = {
+    projectList.map(_.toAttribute)
+  }
+
+  override def executeCollect(): Array[InternalRow] = {
+    val ord = new LazilyGeneratedOrdering(sortOrder, child.output)
+    // Take the global top `end` rows, then drop the first `start` to leave the range.
+    val data = child.execute().map(_.copy()).takeOrdered(end)(ord).drop(start)
+    if (projectList != child.output) {
+      val proj = UnsafeProjection.create(projectList, child.output)
+      data.map(r => proj(r).copy())
+    } else {
+      data
+    }
+  }
+
+  private val serializer: Serializer = new UnsafeRowSerializer(child.output.size)
+
+  protected override def doExecute(): RDD[InternalRow] = {
+    val ord = new LazilyGeneratedOrdering(sortOrder, child.output)
+    // Each partition keeps only its `end` smallest rows under `ord`.
+    val localTopK: RDD[InternalRow] = {
+      child.execute().map(_.copy()).mapPartitions { iter =>
+        org.apache.spark.util.collection.Utils.takeOrdered(iter, end)(ord)
+      }
+    }
+    val shuffled = new ShuffledRowRDD(
+      ShuffleExchangeExec.prepareShuffleDependency(
+        localTopK, child.output, SinglePartition, serializer))
+    shuffled.mapPartitions { iter =>
+      // Merge the per-partition top-K, then drop the first `start` rows.
+      val topK = org.apache.spark.util.collection.Utils.takeOrdered(iter.map(_.copy()), end)(ord)
+        .drop(start)
+      if (projectList != child.output) {
+        val proj = UnsafeProjection.create(projectList, child.output)
+        topK.map(r => proj(r))
+      } else {
+        topK
+      }
+    }
+  }
+
+  override def outputOrdering: Seq[SortOrder] = sortOrder
+
+  override def outputPartitioning: Partitioning = SinglePartition
+
+  override def simpleString: String = {
+    val orderByString = Utils.truncatedString(sortOrder, "[", ",", "]")
+    val outputString = Utils.truncatedString(output, "[", ",", "]")
+
+    s"TakeOrderedRangeAndProject" +
+      s"(start=$start, end=$end, orderBy=$orderByString, output=$outputString)"
+  }
+}
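
As with the non-range variant, correctness rests on the fact that a partition's `end` smallest rows are the only candidates for the global top `end`. Here is a minimal sketch of the ordered-range logic on plain Scala collections; the data and the (start, end) values are illustrative assumptions, with Ordering.Int standing in for LazilyGeneratedOrdering:

    val start = 1
    val end = 4
    val partitions = Seq(Seq(9, 3, 7), Seq(1, 8, 2), Seq(6, 4))

    // Phase 1: each partition keeps only its `end` smallest rows.
    val localTopK = partitions.map(_.sorted.take(end))

    // Phase 2: merge, take the global top `end`, then drop the first `start`
    // rows to leave the requested range.
    val result = localTopK.flatten.sorted.take(end).drop(start)
    // result == List(2, 3, 4): rows at sorted positions [start, end)
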