fix tests
Davies Liu committed Feb 28, 2016
commit bc2c66b88070f8a4f743ba4cc18c0a8a59b9cd7b
@@ -17,8 +17,6 @@

 package org.apache.spark.sql.execution

-import scala.collection.mutable.ArrayBuffer
-
 import org.apache.spark.broadcast
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.SQLContext
@@ -29,7 +27,7 @@ import org.apache.spark.sql.catalyst.plans.physical.Partitioning
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.catalyst.util.toCommentSafeString
 import org.apache.spark.sql.execution.aggregate.TungstenAggregate
-import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, BuildLeft, BuildRight, SortMergeJoin}
+import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, SortMergeJoin}
 import org.apache.spark.sql.execution.metric.LongSQLMetricValue

 /**
@@ -169,10 +167,6 @@ case class InputAdapter(child: SparkPlan) extends UnaryNode with CodegenSupport
   override def outputPartitioning: Partitioning = child.outputPartitioning
   override def outputOrdering: Seq[SortOrder] = child.outputOrdering

-  override def doPrepare(): Unit = {
-    child.prepare()
-  }
-
   override def doExecute(): RDD[InternalRow] = {
     child.execute()
   }
@@ -181,8 +175,6 @@ case class InputAdapter(child: SparkPlan) extends UnaryNode with CodegenSupport
     child.doExecuteBroadcast()
   }

-  override def supportCodegen: Boolean = false
-
   override def upstreams(): Seq[RDD[InternalRow]] = {
     child.execute() :: Nil
   }
@@ -245,21 +237,15 @@ case class InputAdapter(child: SparkPlan) extends UnaryNode with CodegenSupport
  * doCodeGen() will create a CodeGenContext, which will hold a list of variables for input,
  * used to generated code for BoundReference.
  */
-case class WholeStageCodegen(child: CodegenSupport) extends UnaryNode with CodegenSupport {
-
-  override def supportCodegen: Boolean = false
+case class WholeStageCodegen(child: SparkPlan) extends UnaryNode with CodegenSupport {

   override def output: Seq[Attribute] = child.output
   override def outputPartitioning: Partitioning = child.outputPartitioning
   override def outputOrdering: Seq[SortOrder] = child.outputOrdering

-  override def doPrepare(): Unit = {
-    child.prepare()
-  }
-
   override def doExecute(): RDD[InternalRow] = {
     val ctx = new CodegenContext
-    val code = child.produce(ctx, this)
+    val code = child.asInstanceOf[CodegenSupport].produce(ctx, this)
     val references = ctx.references.toArray
     val source = s"""
       public Object generate(Object[] references) {
@@ -295,7 +281,7 @@ case class WholeStageCodegen(child: CodegenSupport) extends UnaryNode with Codeg
     // println(s"${CodeFormatter.format(cleanedSource)}")
     CodeGenerator.compile(cleanedSource)

-    val rdds = child.upstreams()
+    val rdds = child.asInstanceOf[CodegenSupport].upstreams()
     assert(rdds.size <= 2, "Up to two upstream RDDs can be supported")
     if (rdds.length == 1) {
       rdds.head.mapPartitions { iter =>
@@ -424,7 +410,7 @@ private[sql] case class CollapseCodegenStages(sqlContext: SQLContext) extends Ru
    */
   private def insertWholeStageCodegen(plan: SparkPlan): SparkPlan = plan match {
     case plan: CodegenSupport if supportCodegen(plan) =>
-      WholeStageCodegen(insertInputAdapter(plan).asInstanceOf[CodegenSupport])
+      WholeStageCodegen(insertInputAdapter(plan))
     case other =>
       other.withNewChildren(other.children.map(insertWholeStageCodegen))
   }
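For context on what the hunks above drive: whole-stage code generation collapses a pipeline of operators into one generated loop over the upstream rows, instead of chaining a pull-based iterator per operator. A minimal, Spark-free sketch of that idea follows; every name in it is illustrative, not Spark API.

// Conceptual sketch: the iterator-chained pipeline pays a virtual call per
// row per operator; the fused loop inlines Filter and Project into one pass,
// which is roughly what the code built via produce()/consume() amounts to.
object FusedLoopSketch {
  def main(args: Array[String]): Unit = {
    val input = (1 to 10).toArray

    // Volcano-style (interpreted) evaluation: one iterator per operator.
    val volcano = input.iterator.filter(_ % 2 == 0).map(_ * 10)

    // Whole-stage style: a single loop, no intermediate iterators.
    val fused = {
      val out = Array.newBuilder[Int]
      var i = 0
      while (i < input.length) {
        val row = input(i)
        if (row % 2 == 0) { // Filter, inlined
          out += row * 10   // Project, inlined
        }
        i += 1
      }
      out.result()
    }

    assert(volcano.toSeq == fused.toSeq) // both pipelines agree
    println(fused.mkString(", "))
  }
}

Note that the child.asInstanceOf[CodegenSupport] casts above are safe by construction: CollapseCodegenStages only builds a WholeStageCodegen around plans that matched the case plan: CodegenSupport if supportCodegen(plan) guard.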
@@ -24,6 +24,7 @@ import org.apache.spark.rdd.RDD
 import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
 import org.apache.spark.sql.catalyst.trees.TreeNodeRef
 import org.apache.spark.sql.internal.SQLConf

@@ -68,7 +69,7 @@ package object debug {
     }
   }

-  private[sql] case class DebugNode(child: SparkPlan) extends UnaryNode {
+  private[sql] case class DebugNode(child: SparkPlan) extends UnaryNode with CodegenSupport {
     def output: Seq[Attribute] = child.output

     implicit object SetAccumulatorParam extends AccumulatorParam[HashSet[String]] {
@@ -86,10 +87,11 @@
     /**
      * A collection of metrics for each column of output.
      * @param elementTypes the actual runtime types for the output. Useful when there are bugs
-     *                     causing the wrong data to be projected.
+     *        causing the wrong data to be projected.
      */
     case class ColumnMetrics(
-        elementTypes: Accumulator[HashSet[String]] = sparkContext.accumulator(HashSet.empty))
+      elementTypes: Accumulator[HashSet[String]] = sparkContext.accumulator(HashSet.empty))
+
     val tupleCount: Accumulator[Int] = sparkContext.accumulator[Int](0)

     val numColumns: Int = child.output.size
@@ -98,7 +100,7 @@
     def dumpStats(): Unit = {
       logDebug(s"== ${child.simpleString} ==")
       logDebug(s"Tuples output: ${tupleCount.value}")
-      child.output.zip(columnStats).foreach { case(attr, metric) =>
+      child.output.zip(columnStats).foreach { case (attr, metric) =>
         val actualDataTypes = metric.elementTypes.value.mkString("{", ",", "}")
         logDebug(s" ${attr.name} ${attr.dataType}: $actualDataTypes")
       }
@@ -108,6 +110,7 @@
       child.execute().mapPartitions { iter =>
         new Iterator[InternalRow] {
           def hasNext: Boolean = iter.hasNext
+
           def next(): InternalRow = {
             val currentRow = iter.next()
             tupleCount += 1
@@ -124,5 +127,17 @@
         }
       }
     }
+
+    override def upstreams(): Seq[RDD[InternalRow]] = {
+      child.asInstanceOf[CodegenSupport].upstreams()
+    }
+
+    override def doProduce(ctx: CodegenContext): String = {
+      child.asInstanceOf[CodegenSupport].produce(ctx, this)
+    }
+
+    override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = {
+      consume(ctx, input)
+    }
   }
 }
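With this hunk, DebugNode mixes in CodegenSupport and stays transparent inside a codegen stage: doProduce delegates to the child's produce, and doConsume forwards each row to DebugNode's own consumer unchanged. On the interpreted path (doExecute earlier in this file), it wraps the child's iterator to count tuples and record the runtime type observed in each column. A minimal, Spark-free sketch of that wrapper, with illustrative names (not Spark API):

object DebugIteratorSketch {
  // Passes rows through unchanged while counting them and recording the
  // runtime class seen in each column, mirroring DebugNode's tupleCount
  // and elementTypes accumulators.
  final class CountingIterator(iter: Iterator[Array[Any]], numColumns: Int)
      extends Iterator[Array[Any]] {
    var tupleCount = 0
    val columnTypes: Array[Set[String]] = Array.fill(numColumns)(Set.empty[String])

    def hasNext: Boolean = iter.hasNext

    def next(): Array[Any] = {
      val row = iter.next()
      tupleCount += 1
      var i = 0
      while (i < numColumns) {
        columnTypes(i) += row(i).getClass.getSimpleName
        i += 1
      }
      row // pass the row through unchanged
    }
  }

  def main(args: Array[String]): Unit = {
    val rows = Iterator[Array[Any]](Array(1, "a"), Array(2, "b"))
    val debug = new CountingIterator(rows, numColumns = 2)
    debug.foreach(_ => ()) // drain, as a downstream operator would
    println(s"tuples: ${debug.tupleCount}") // tuples: 2
    debug.columnTypes.zipWithIndex.foreach { case (types, i) =>
      println(s"col $i: ${types.mkString("{", ",", "}")}")
    }
  }
}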