From fd45c7aa90670301779b56fe84e6f2d6afb191d1 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 14 Jul 2015 14:45:33 +0800 Subject: [PATCH 1/5] Support mutable state in code gen expressions --- .../expressions/codegen/CodeGenerator.scala | 10 +- .../codegen/GenerateMutableProjection.scala | 20 +++- .../codegen/GenerateOrdering.scala | 19 ++- .../codegen/GeneratePredicate.scala | 20 +++- .../codegen/GenerateProjection.scala | 113 ++++++++++-------- .../sql/catalyst/expressions/random.scala | 19 +++ .../MonotonicallyIncreasingID.scala | 21 +++- .../expressions/SparkPartitionID.scala | 16 ++- 8 files changed, 172 insertions(+), 66 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 9f6329bbda4e..53aecb467d00 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -56,6 +56,14 @@ class CodeGenContext { */ val references: mutable.ArrayBuffer[Expression] = new mutable.ArrayBuffer[Expression]() + /** + * Holding expressions' mutable states like `Rand.rng`, and keep them as member variables + * in generated classes like `SpecificProjection`. + * Each element is a 3-tuple: java type, variable name, variable value. + */ + val mutableStates: mutable.ArrayBuffer[(String, String, Any)] = + mutable.ArrayBuffer.empty[(String, String, Any)] + val stringType: String = classOf[UTF8String].getName val decimalType: String = classOf[Decimal].getName @@ -205,7 +213,7 @@ class CodeGenContext { abstract class GeneratedClass { - def generate(expressions: Array[Expression]): Any + def generate(expressions: Array[Expression], states: Array[Any]): Any } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala index addb8023d9c0..022ccec41524 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala @@ -46,19 +46,30 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu ${ctx.setColumn("mutableRow", e.dataType, i, evaluationCode.primitive)}; """ }.mkString("\n") + + val mutableStates = ctx.mutableStates.map { + case (jt, name, _) => s"private $jt $name;" + }.mkString("\n ") + + val initStates = ctx.mutableStates.zipWithIndex.map { + case ((jt, name, _), index) => s"$name = (${ctx.boxedType(jt)}) states[$index];" + }.mkString("\n ") + val code = s""" - public Object generate($exprType[] expr) { - return new SpecificProjection(expr); + public Object generate($exprType[] expr, Object[] states) { + return new SpecificProjection(expr, states); } class SpecificProjection extends ${classOf[BaseMutableProjection].getName} { private $exprType[] expressions = null; private $mutableRowType mutableRow = null; + $mutableStates - public SpecificProjection($exprType[] expr) { + public SpecificProjection($exprType[] expr, Object[] states) { expressions = expr; mutableRow = new $genericMutableRowType(${expressions.size}); + $initStates } public ${classOf[BaseMutableProjection].getName} target($mutableRowType row) 
{ @@ -84,7 +95,8 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu val c = compile(code) () => { - c.generate(ctx.references.toArray).asInstanceOf[MutableProjection] + c.generate(ctx.references.toArray, ctx.mutableStates.map(_._3).toArray) + .asInstanceOf[MutableProjection] } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala index d05dfc108e63..5b4cfb58c8df 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala @@ -70,17 +70,27 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[InternalR """ }.mkString("\n") + val mutableStates = ctx.mutableStates.map { + case (jt, name, _) => s"private $jt $name;" + }.mkString("\n ") + + val initStates = ctx.mutableStates.zipWithIndex.map { + case ((jt, name, _), index) => s"$name = (${ctx.boxedType(jt)}) states[$index];" + }.mkString("\n ") + val code = s""" - public SpecificOrdering generate($exprType[] expr) { - return new SpecificOrdering(expr); + public SpecificOrdering generate($exprType[] expr, Object[] states) { + return new SpecificOrdering(expr, states); } class SpecificOrdering extends ${classOf[BaseOrdering].getName} { private $exprType[] expressions = null; + $mutableStates - public SpecificOrdering($exprType[] expr) { + public SpecificOrdering($exprType[] expr, Object[] states) { expressions = expr; + $initStates } @Override @@ -93,6 +103,7 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[InternalR logDebug(s"Generated Ordering: $code") - compile(code).generate(ctx.references.toArray).asInstanceOf[BaseOrdering] + compile(code).generate(ctx.references.toArray, ctx.mutableStates.map(_._3).toArray) + .asInstanceOf[BaseOrdering] } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala index 274a42cb6908..c9698ff0486e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala @@ -40,15 +40,26 @@ object GeneratePredicate extends CodeGenerator[Expression, (InternalRow) => Bool protected def create(predicate: Expression): ((InternalRow) => Boolean) = { val ctx = newCodeGenContext() val eval = predicate.gen(ctx) + + val mutableStates = ctx.mutableStates.map { + case (jt, name, _) => s"private $jt $name;" + }.mkString("\n ") + + val initStates = ctx.mutableStates.zipWithIndex.map { + case ((jt, name, _), index) => s"$name = (${ctx.boxedType(jt)}) states[$index];" + }.mkString("\n ") + val code = s""" - public SpecificPredicate generate($exprType[] expr) { - return new SpecificPredicate(expr); + public SpecificPredicate generate($exprType[] expr, Object[] states) { + return new SpecificPredicate(expr, states); } class SpecificPredicate extends ${classOf[Predicate].getName} { private final $exprType[] expressions; - public SpecificPredicate($exprType[] expr) { + $mutableStates + public SpecificPredicate($exprType[] expr, Object[] states) { expressions = expr; + $initStates } @Override @@ -60,7 +71,8 @@ object 
GeneratePredicate extends CodeGenerator[Expression, (InternalRow) => Bool logDebug(s"Generated predicate '$predicate':\n$code") - val p = compile(code).generate(ctx.references.toArray).asInstanceOf[Predicate] + val p = compile(code).generate(ctx.references.toArray, ctx.mutableStates.map(_._3).toArray) + .asInstanceOf[Predicate] (r: InternalRow) => p.eval(r) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala index 3c7ee9cc1659..2a40cd0a1e11 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala @@ -151,85 +151,96 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { s"""if (!nullBits[$i]) arr[$i] = c$i;""" }.mkString("\n ") + val mutableStates = ctx.mutableStates.map { + case (jt, name, _) => s"private $jt $name;" + }.mkString("\n ") + + val initStates = ctx.mutableStates.zipWithIndex.map { + case ((jt, name, _), index) => s"$name = (${ctx.boxedType(jt)}) states[$index];" + }.mkString("\n ") + val code = s""" - public SpecificProjection generate($exprType[] expr) { - return new SpecificProjection(expr); + public SpecificProjection generate($exprType[] expr, Object[] states) { + return new SpecificProjection(expr, states); } class SpecificProjection extends ${classOf[BaseProject].getName} { private $exprType[] expressions = null; + $mutableStates - public SpecificProjection($exprType[] expr) { + public SpecificProjection($exprType[] expr, Object[] states) { expressions = expr; + $initStates } @Override public Object apply(Object r) { - return new SpecificRow(expressions, (InternalRow) r); + return new SpecificRow((InternalRow) r); } - } - final class SpecificRow extends ${classOf[MutableRow].getName} { + final class SpecificRow extends ${classOf[MutableRow].getName} { - $columns + $columns - public SpecificRow($exprType[] expressions, InternalRow i) { - $initColumns - } + public SpecificRow(InternalRow i) { + $initColumns + } - public int length() { return ${expressions.length};} - protected boolean[] nullBits = new boolean[${expressions.length}]; - public void setNullAt(int i) { nullBits[i] = true; } - public boolean isNullAt(int i) { return nullBits[i]; } + public int length() { return ${expressions.length};} + protected boolean[] nullBits = new boolean[${expressions.length}]; + public void setNullAt(int i) { nullBits[i] = true; } + public boolean isNullAt(int i) { return nullBits[i]; } - public Object get(int i) { - if (isNullAt(i)) return null; - switch (i) { - $getCases + public Object get(int i) { + if (isNullAt(i)) return null; + switch (i) { + $getCases + } + return null; } - return null; - } - public void update(int i, Object value) { - if (value == null) { - setNullAt(i); - return; + public void update(int i, Object value) { + if (value == null) { + setNullAt(i); + return; + } + nullBits[i] = false; + switch (i) { + $updateCases + } } - nullBits[i] = false; - switch (i) { - $updateCases + $specificAccessorFunctions + $specificMutatorFunctions + + @Override + public int hashCode() { + int result = 37; + $hashUpdates + return result; } - } - $specificAccessorFunctions - $specificMutatorFunctions - - @Override - public int hashCode() { - int result = 37; - $hashUpdates - return result; - } - @Override - public boolean 
equals(Object other) { - if (other instanceof SpecificRow) { - SpecificRow row = (SpecificRow) other; - $columnChecks - return true; + @Override + public boolean equals(Object other) { + if (other instanceof SpecificRow) { + SpecificRow row = (SpecificRow) other; + $columnChecks + return true; + } + return super.equals(other); } - return super.equals(other); - } - @Override - public InternalRow copy() { - Object[] arr = new Object[${expressions.length}]; - ${copyColumns} - return new ${classOf[GenericInternalRow].getName}(arr); + @Override + public InternalRow copy() { + Object[] arr = new Object[${expressions.length}]; + ${copyColumns} + return new ${classOf[GenericInternalRow].getName}(arr); + } } } """ logDebug(s"MutableRow, initExprs: ${expressions.mkString(",")} code:\n${code}") - compile(code).generate(ctx.references.toArray).asInstanceOf[Projection] + compile(code).generate(ctx.references.toArray, ctx.mutableStates.map(_._3).toArray) + .asInstanceOf[Projection] } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/random.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/random.scala index 6cdc3000382e..a53b9be7831c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/random.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/random.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.TaskContext import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode} import org.apache.spark.sql.types.{DataType, DoubleType} import org.apache.spark.util.Utils import org.apache.spark.util.random.XORShiftRandom @@ -61,6 +62,15 @@ case class Rand(seed: Long) extends RDG(seed) { case IntegerLiteral(s) => s case _ => throw new AnalysisException("Input argument to rand must be an integer literal.") }) + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val rngTerm = ctx.freshName("rng") + ctx.mutableStates += ((rng.getClass.getCanonicalName, rngTerm, rng)) + ev.isNull = "false" + s""" + final ${ctx.javaType(dataType)} ${ev.primitive} = $rngTerm.nextDouble(); + """ + } } /** Generate a random column with i.i.d. gaussian random distribution. 
*/ @@ -73,4 +83,13 @@ case class Randn(seed: Long) extends RDG(seed) { case IntegerLiteral(s) => s case _ => throw new AnalysisException("Input argument to rand must be an integer literal.") }) + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val rngTerm = ctx.freshName("rng") + ctx.mutableStates += ((rng.getClass.getCanonicalName, rngTerm, rng)) + ev.isNull = "false" + s""" + final ${ctx.javaType(dataType)} ${ev.primitive} = $rngTerm.nextGaussian(); + """ + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/MonotonicallyIncreasingID.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/MonotonicallyIncreasingID.scala index 437d143e53f3..160f2cf75a8e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/MonotonicallyIncreasingID.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/MonotonicallyIncreasingID.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.expressions import org.apache.spark.TaskContext import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.LeafExpression +import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext} import org.apache.spark.sql.types.{LongType, DataType} /** @@ -40,6 +41,11 @@ private[sql] case class MonotonicallyIncreasingID() extends LeafExpression { */ @transient private[this] var count: Long = 0L + @transient protected lazy val partitionMask = TaskContext.get() match { + case null => 0L + case _ => TaskContext.get().partitionId().toLong << 33 + } + override def nullable: Boolean = false override def dataType: DataType = LongType @@ -47,6 +53,19 @@ private[sql] case class MonotonicallyIncreasingID() extends LeafExpression { override def eval(input: InternalRow): Long = { val currentCount = count count += 1 - (TaskContext.get().partitionId().toLong << 33) + currentCount + partitionMask + currentCount + } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val countTerm = ctx.freshName("count") + val partitionMaskTerm = ctx.freshName("partitionMask") + ctx.mutableStates += (("long", countTerm, count)) + ctx.mutableStates += (("long", partitionMaskTerm, partitionMask)) + + ev.isNull = "false" + s""" + final ${ctx.javaType(dataType)} ${ev.primitive} = $partitionMaskTerm + $countTerm; + $countTerm++; + """ } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/SparkPartitionID.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/SparkPartitionID.scala index 822d3d8c9108..7e59a6387a98 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/SparkPartitionID.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/SparkPartitionID.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.expressions import org.apache.spark.TaskContext import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.LeafExpression +import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext} import org.apache.spark.sql.types.{IntegerType, DataType} @@ -32,5 +33,18 @@ private[sql] case object SparkPartitionID extends LeafExpression { override def dataType: DataType = IntegerType - override def eval(input: InternalRow): Int = TaskContext.get().partitionId() + @transient private lazy val partitionId = TaskContext.get() match { + case null => 
0 + case _ => TaskContext.get().partitionId() + } + + override def eval(input: InternalRow): Int = partitionId + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val idTerm = ctx.freshName("partitionId") + ctx.mutableStates += (("int", idTerm, partitionId)) + ev.isNull = "false" + ev.primitive = idTerm + "" + } } From d43b65d9e8df54f0d53d7456a61c2388121d87e2 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 14 Jul 2015 15:10:22 +0800 Subject: [PATCH 2/5] address comments --- .../scala/org/apache/spark/TaskContext.scala | 8 ++++++- .../expressions/codegen/CodeGenerator.scala | 19 ++++++++++------- .../codegen/GenerateMutableProjection.scala | 21 +++++-------------- .../codegen/GenerateOrdering.scala | 20 +++++------------- .../codegen/GeneratePredicate.scala | 21 +++++-------------- .../codegen/GenerateProjection.scala | 20 +++++------------- .../sql/catalyst/expressions/random.scala | 14 ++++++------- .../MonotonicallyIncreasingID.scala | 10 ++++----- .../expressions/SparkPartitionID.scala | 10 +++------ 9 files changed, 53 insertions(+), 90 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/TaskContext.scala b/core/src/main/scala/org/apache/spark/TaskContext.scala index d09e17dea091..57ecb3508acf 100644 --- a/core/src/main/scala/org/apache/spark/TaskContext.scala +++ b/core/src/main/scala/org/apache/spark/TaskContext.scala @@ -32,7 +32,13 @@ object TaskContext { */ def get(): TaskContext = taskContext.get - private val taskContext: ThreadLocal[TaskContext] = new ThreadLocal[TaskContext] + /** + * Returns the partition id of currently active TaskContext. It will return 0 + * if there is no active TaskContext for cases like local execution. + */ + def getPartitionId(): Int = Option(taskContext.get).map(_.partitionId).getOrElse(0) + + private[this] val taskContext: ThreadLocal[TaskContext] = new ThreadLocal[TaskContext] // Note: protected[spark] instead of private[spark] to prevent the following two from // showing up in JavaDoc. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 53aecb467d00..52c69844fea1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -56,13 +56,15 @@ class CodeGenContext { */ val references: mutable.ArrayBuffer[Expression] = new mutable.ArrayBuffer[Expression]() + val mutableStates: mutable.ArrayBuffer[String] = mutable.ArrayBuffer.empty[String] + /** - * Holding expressions' mutable states like `Rand.rng`, and keep them as member variables - * in generated classes like `SpecificProjection`. - * Each element is a 3-tuple: java type, variable name, variable value. + * Register expressions' mutable states like `MonotonicallyIncreasingID.count`, they will be + * kept as member variables in generated classes like `SpecificProjection`. 
*/ - val mutableStates: mutable.ArrayBuffer[(String, String, Any)] = - mutable.ArrayBuffer.empty[(String, String, Any)] + def addMutableState(javaType: String, variableName: String, initialValue: String): Unit = { + mutableStates += s"private $javaType $variableName = $initialValue;" + } val stringType: String = classOf[UTF8String].getName val decimalType: String = classOf[Decimal].getName @@ -211,9 +213,12 @@ class CodeGenContext { def isPrimitiveType(dt: DataType): Boolean = isPrimitiveType(javaType(dt)) } - +/** + * A wrapper for generated class, defines a `generate` method so that we can pass extra objects + * into generated class. + */ abstract class GeneratedClass { - def generate(expressions: Array[Expression], states: Array[Any]): Any + def generate(expressions: Array[Expression]): Any } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala index 022ccec41524..0b55ea047f1a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala @@ -46,30 +46,20 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu ${ctx.setColumn("mutableRow", e.dataType, i, evaluationCode.primitive)}; """ }.mkString("\n") - - val mutableStates = ctx.mutableStates.map { - case (jt, name, _) => s"private $jt $name;" - }.mkString("\n ") - - val initStates = ctx.mutableStates.zipWithIndex.map { - case ((jt, name, _), index) => s"$name = (${ctx.boxedType(jt)}) states[$index];" - }.mkString("\n ") - val code = s""" - public Object generate($exprType[] expr, Object[] states) { - return new SpecificProjection(expr, states); + public Object generate($exprType[] expr) { + return new SpecificProjection(expr); } class SpecificProjection extends ${classOf[BaseMutableProjection].getName} { private $exprType[] expressions = null; private $mutableRowType mutableRow = null; - $mutableStates + ${ctx.mutableStates.mkString("\n ")} - public SpecificProjection($exprType[] expr, Object[] states) { + public SpecificProjection($exprType[] expr) { expressions = expr; mutableRow = new $genericMutableRowType(${expressions.size}); - $initStates } public ${classOf[BaseMutableProjection].getName} target($mutableRowType row) { @@ -95,8 +85,7 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu val c = compile(code) () => { - c.generate(ctx.references.toArray, ctx.mutableStates.map(_._3).toArray) - .asInstanceOf[MutableProjection] + c.generate(ctx.references.toArray).asInstanceOf[MutableProjection] } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala index 5b4cfb58c8df..b75c5eac9b8d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala @@ -70,27 +70,18 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[InternalR """ }.mkString("\n") - val mutableStates = ctx.mutableStates.map { - case (jt, name, _) => s"private $jt $name;" - }.mkString("\n ") - - val initStates = 
ctx.mutableStates.zipWithIndex.map { - case ((jt, name, _), index) => s"$name = (${ctx.boxedType(jt)}) states[$index];" - }.mkString("\n ") - val code = s""" - public SpecificOrdering generate($exprType[] expr, Object[] states) { - return new SpecificOrdering(expr, states); + public SpecificOrdering generate($exprType[] expr) { + return new SpecificOrdering(expr); } class SpecificOrdering extends ${classOf[BaseOrdering].getName} { private $exprType[] expressions = null; - $mutableStates + ${ctx.mutableStates.mkString("\n ")} - public SpecificOrdering($exprType[] expr, Object[] states) { + public SpecificOrdering($exprType[] expr) { expressions = expr; - $initStates } @Override @@ -103,7 +94,6 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[InternalR logDebug(s"Generated Ordering: $code") - compile(code).generate(ctx.references.toArray, ctx.mutableStates.map(_._3).toArray) - .asInstanceOf[BaseOrdering] + compile(code).generate(ctx.references.toArray).asInstanceOf[BaseOrdering] } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala index c9698ff0486e..471731c33707 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala @@ -40,26 +40,16 @@ object GeneratePredicate extends CodeGenerator[Expression, (InternalRow) => Bool protected def create(predicate: Expression): ((InternalRow) => Boolean) = { val ctx = newCodeGenContext() val eval = predicate.gen(ctx) - - val mutableStates = ctx.mutableStates.map { - case (jt, name, _) => s"private $jt $name;" - }.mkString("\n ") - - val initStates = ctx.mutableStates.zipWithIndex.map { - case ((jt, name, _), index) => s"$name = (${ctx.boxedType(jt)}) states[$index];" - }.mkString("\n ") - val code = s""" - public SpecificPredicate generate($exprType[] expr, Object[] states) { - return new SpecificPredicate(expr, states); + public SpecificPredicate generate($exprType[] expr) { + return new SpecificPredicate(expr); } class SpecificPredicate extends ${classOf[Predicate].getName} { private final $exprType[] expressions; - $mutableStates - public SpecificPredicate($exprType[] expr, Object[] states) { + ${ctx.mutableStates.mkString("\n ")} + public SpecificPredicate($exprType[] expr) { expressions = expr; - $initStates } @Override @@ -71,8 +61,7 @@ object GeneratePredicate extends CodeGenerator[Expression, (InternalRow) => Bool logDebug(s"Generated predicate '$predicate':\n$code") - val p = compile(code).generate(ctx.references.toArray, ctx.mutableStates.map(_._3).toArray) - .asInstanceOf[Predicate] + val p = compile(code).generate(ctx.references.toArray).asInstanceOf[Predicate] (r: InternalRow) => p.eval(r) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala index 2a40cd0a1e11..016ab5dcd0d7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala @@ -151,26 +151,17 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { s"""if (!nullBits[$i]) 
arr[$i] = c$i;""" }.mkString("\n ") - val mutableStates = ctx.mutableStates.map { - case (jt, name, _) => s"private $jt $name;" - }.mkString("\n ") - - val initStates = ctx.mutableStates.zipWithIndex.map { - case ((jt, name, _), index) => s"$name = (${ctx.boxedType(jt)}) states[$index];" - }.mkString("\n ") - val code = s""" - public SpecificProjection generate($exprType[] expr, Object[] states) { - return new SpecificProjection(expr, states); + public SpecificProjection generate($exprType[] expr) { + return new SpecificProjection(expr); } class SpecificProjection extends ${classOf[BaseProject].getName} { private $exprType[] expressions = null; - $mutableStates + ${ctx.mutableStates.mkString("\n ")} - public SpecificProjection($exprType[] expr, Object[] states) { + public SpecificProjection($exprType[] expr) { expressions = expr; - $initStates } @Override @@ -240,7 +231,6 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { logDebug(s"MutableRow, initExprs: ${expressions.mkString(",")} code:\n${code}") - compile(code).generate(ctx.references.toArray, ctx.mutableStates.map(_._3).toArray) - .asInstanceOf[Projection] + compile(code).generate(ctx.references.toArray).asInstanceOf[Projection] } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/random.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/random.scala index a53b9be7831c..e10ba5539666 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/random.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/random.scala @@ -39,11 +39,7 @@ abstract class RDG(seed: Long) extends LeafExpression with Serializable { * Record ID within each partition. By being transient, the Random Number Generator is * reset every time we serialize and deserialize it. 
*/ - @transient protected lazy val partitionId = TaskContext.get() match { - case null => 0 - case _ => TaskContext.get().partitionId() - } - @transient protected lazy val rng = new XORShiftRandom(seed + partitionId) + @transient protected lazy val rng = new XORShiftRandom(seed + TaskContext.getPartitionId) override def deterministic: Boolean = false @@ -65,7 +61,9 @@ case class Rand(seed: Long) extends RDG(seed) { override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { val rngTerm = ctx.freshName("rng") - ctx.mutableStates += ((rng.getClass.getCanonicalName, rngTerm, rng)) + val className = classOf[XORShiftRandom].getCanonicalName + ctx.addMutableState(className, rngTerm, + s"new $className($seed + org.apache.spark.TaskContext.getPartitionId())") ev.isNull = "false" s""" final ${ctx.javaType(dataType)} ${ev.primitive} = $rngTerm.nextDouble(); @@ -86,7 +84,9 @@ case class Randn(seed: Long) extends RDG(seed) { override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { val rngTerm = ctx.freshName("rng") - ctx.mutableStates += ((rng.getClass.getCanonicalName, rngTerm, rng)) + val className = classOf[XORShiftRandom].getCanonicalName + ctx.addMutableState(className, rngTerm, + s"new $className($seed + org.apache.spark.TaskContext.getPartitionId())") ev.isNull = "false" s""" final ${ctx.javaType(dataType)} ${ev.primitive} = $rngTerm.nextGaussian(); diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/MonotonicallyIncreasingID.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/MonotonicallyIncreasingID.scala index 160f2cf75a8e..69a37750d752 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/MonotonicallyIncreasingID.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/MonotonicallyIncreasingID.scala @@ -41,10 +41,7 @@ private[sql] case class MonotonicallyIncreasingID() extends LeafExpression { */ @transient private[this] var count: Long = 0L - @transient protected lazy val partitionMask = TaskContext.get() match { - case null => 0L - case _ => TaskContext.get().partitionId().toLong << 33 - } + @transient private lazy val partitionMask = TaskContext.getPartitionId.toLong << 33 override def nullable: Boolean = false @@ -59,8 +56,9 @@ private[sql] case class MonotonicallyIncreasingID() extends LeafExpression { override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { val countTerm = ctx.freshName("count") val partitionMaskTerm = ctx.freshName("partitionMask") - ctx.mutableStates += (("long", countTerm, count)) - ctx.mutableStates += (("long", partitionMaskTerm, partitionMask)) + ctx.addMutableState(ctx.JAVA_LONG, countTerm, "0L") + ctx.addMutableState(ctx.JAVA_LONG, partitionMaskTerm, + "((long) org.apache.spark.TaskContext.getPartitionId()) << 33") ev.isNull = "false" s""" diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/SparkPartitionID.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/SparkPartitionID.scala index 7e59a6387a98..5f1b514f2cff 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/SparkPartitionID.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/SparkPartitionID.scala @@ -33,18 +33,14 @@ private[sql] case object SparkPartitionID extends LeafExpression { override def dataType: DataType = IntegerType - @transient private lazy val partitionId = TaskContext.get() match { - case null => 0 - case _ => 
TaskContext.get().partitionId() - } + @transient private lazy val partitionId = TaskContext.getPartitionId override def eval(input: InternalRow): Int = partitionId override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { val idTerm = ctx.freshName("partitionId") - ctx.mutableStates += (("int", idTerm, partitionId)) + ctx.addMutableState(ctx.JAVA_INT, idTerm, "org.apache.spark.TaskContext.getPartitionId()") ev.isNull = "false" - ev.primitive = idTerm - "" + s"final ${ctx.javaType(dataType)} ${ev.primitive} = $idTerm;" } } From 318f41d09e442f0d23a87d16a0af46c0ccfc79c3 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 15 Jul 2015 10:59:16 +0800 Subject: [PATCH 3/5] address more comments --- .../main/scala/org/apache/spark/TaskContext.scala | 9 ++++++++- .../catalyst/expressions/codegen/CodeGenerator.scala | 12 +++++++----- .../codegen/GenerateMutableProjection.scala | 5 ++++- .../expressions/codegen/GenerateOrdering.scala | 6 ++++-- .../expressions/codegen/GeneratePredicate.scala | 5 ++++- .../expressions/codegen/GenerateProjection.scala | 6 +++++- 6 files changed, 32 insertions(+), 11 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/TaskContext.scala b/core/src/main/scala/org/apache/spark/TaskContext.scala index 57ecb3508acf..248339148d9b 100644 --- a/core/src/main/scala/org/apache/spark/TaskContext.scala +++ b/core/src/main/scala/org/apache/spark/TaskContext.scala @@ -36,7 +36,14 @@ object TaskContext { * Returns the partition id of currently active TaskContext. It will return 0 * if there is no active TaskContext for cases like local execution. */ - def getPartitionId(): Int = Option(taskContext.get).map(_.partitionId).getOrElse(0) + def getPartitionId(): Int = { + val tc = taskContext.get() + if (tc == null) { + 0 + } else { + tc.partitionId() + } + } private[this] val taskContext: ThreadLocal[TaskContext] = new ThreadLocal[TaskContext] diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 52c69844fea1..328d635de874 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -56,14 +56,16 @@ class CodeGenContext { */ val references: mutable.ArrayBuffer[Expression] = new mutable.ArrayBuffer[Expression]() - val mutableStates: mutable.ArrayBuffer[String] = mutable.ArrayBuffer.empty[String] - /** - * Register expressions' mutable states like `MonotonicallyIncreasingID.count`, they will be - * kept as member variables in generated classes like `SpecificProjection`. + * Holding expressions' mutable states like `MonotonicallyIncreasingID.count` as a + * 3-tuple: java type, variable name, code to init it. + * They will be kept as member variables in generated classes like `SpecificProjection`. 
*/ + val mutableStates: mutable.ArrayBuffer[(String, String, String)] = + mutable.ArrayBuffer.empty[(String, String, String)] + def addMutableState(javaType: String, variableName: String, initialValue: String): Unit = { - mutableStates += s"private $javaType $variableName = $initialValue;" + mutableStates += ((javaType, variableName, initialValue)) } val stringType: String = classOf[UTF8String].getName diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala index 0b55ea047f1a..283545fdeb9f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala @@ -46,6 +46,9 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu ${ctx.setColumn("mutableRow", e.dataType, i, evaluationCode.primitive)}; """ }.mkString("\n") + val mutableStates = ctx.mutableStates.map { case (jt, name, init) => + s"private $jt $name = $init;" + }.mkString("\n ") val code = s""" public Object generate($exprType[] expr) { return new SpecificProjection(expr); @@ -55,7 +58,7 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu private $exprType[] expressions = null; private $mutableRowType mutableRow = null; - ${ctx.mutableStates.mkString("\n ")} + $mutableStates public SpecificProjection($exprType[] expr) { expressions = expr; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala index b75c5eac9b8d..fda7e24f63da 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala @@ -69,7 +69,9 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[InternalR } """ }.mkString("\n") - + val mutableStates = ctx.mutableStates.map { case (jt, name, init) => + s"private $jt $name = $init;" + }.mkString("\n ") val code = s""" public SpecificOrdering generate($exprType[] expr) { return new SpecificOrdering(expr); @@ -78,7 +80,7 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[InternalR class SpecificOrdering extends ${classOf[BaseOrdering].getName} { private $exprType[] expressions = null; - ${ctx.mutableStates.mkString("\n ")} + $mutableStates public SpecificOrdering($exprType[] expr) { expressions = expr; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala index 471731c33707..b4bdfbfbc348 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala @@ -40,6 +40,9 @@ object GeneratePredicate extends CodeGenerator[Expression, (InternalRow) => Bool protected def create(predicate: Expression): ((InternalRow) => Boolean) = { val ctx = newCodeGenContext() val eval = predicate.gen(ctx) + val mutableStates = ctx.mutableStates.map { case (jt, name, init) 
=> + s"private $jt $name = $init;" + }.mkString("\n ") val code = s""" public SpecificPredicate generate($exprType[] expr) { return new SpecificPredicate(expr); @@ -47,7 +50,7 @@ object GeneratePredicate extends CodeGenerator[Expression, (InternalRow) => Bool class SpecificPredicate extends ${classOf[Predicate].getName} { private final $exprType[] expressions; - ${ctx.mutableStates.mkString("\n ")} + $mutableStates public SpecificPredicate($exprType[] expr) { expressions = expr; } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala index 016ab5dcd0d7..5061709c2e96 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala @@ -151,6 +151,10 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { s"""if (!nullBits[$i]) arr[$i] = c$i;""" }.mkString("\n ") + val mutableStates = ctx.mutableStates.map { case (jt, name, init) => + s"private $jt $name = $init;" + }.mkString("\n ") + val code = s""" public SpecificProjection generate($exprType[] expr) { return new SpecificProjection(expr); @@ -158,7 +162,7 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { class SpecificProjection extends ${classOf[BaseProject].getName} { private $exprType[] expressions = null; - ${ctx.mutableStates.mkString("\n ")} + $mutableStates public SpecificProjection($exprType[] expr) { expressions = expr; From 73144d8ae27fb5bcc19835040647f8f2f4b78b3e Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 15 Jul 2015 16:35:49 +0800 Subject: [PATCH 4/5] naming improvement --- .../expressions/codegen/GenerateMutableProjection.scala | 4 ++-- .../sql/catalyst/expressions/codegen/GenerateOrdering.scala | 4 ++-- .../sql/catalyst/expressions/codegen/GeneratePredicate.scala | 4 ++-- .../sql/catalyst/expressions/codegen/GenerateProjection.scala | 4 ++-- 4 files changed, 8 insertions(+), 8 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala index 283545fdeb9f..71e47d4f9b62 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala @@ -46,8 +46,8 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu ${ctx.setColumn("mutableRow", e.dataType, i, evaluationCode.primitive)}; """ }.mkString("\n") - val mutableStates = ctx.mutableStates.map { case (jt, name, init) => - s"private $jt $name = $init;" + val mutableStates = ctx.mutableStates.map { case (javaType, variableName, initialValue) => + s"private $javaType $variableName = $initialValue;" }.mkString("\n ") val code = s""" public Object generate($exprType[] expr) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala index fda7e24f63da..b2411e77c359 100644 --- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala @@ -69,8 +69,8 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[InternalR } """ }.mkString("\n") - val mutableStates = ctx.mutableStates.map { case (jt, name, init) => - s"private $jt $name = $init;" + val mutableStates = ctx.mutableStates.map { case (javaType, variableName, initialValue) => + s"private $javaType $variableName = $initialValue;" }.mkString("\n ") val code = s""" public SpecificOrdering generate($exprType[] expr) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala index b4bdfbfbc348..9e5a745d512e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala @@ -40,8 +40,8 @@ object GeneratePredicate extends CodeGenerator[Expression, (InternalRow) => Bool protected def create(predicate: Expression): ((InternalRow) => Boolean) = { val ctx = newCodeGenContext() val eval = predicate.gen(ctx) - val mutableStates = ctx.mutableStates.map { case (jt, name, init) => - s"private $jt $name = $init;" + val mutableStates = ctx.mutableStates.map { case (javaType, variableName, initialValue) => + s"private $javaType $variableName = $initialValue;" }.mkString("\n ") val code = s""" public SpecificPredicate generate($exprType[] expr) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala index 5061709c2e96..3e5ca308dc31 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala @@ -151,8 +151,8 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { s"""if (!nullBits[$i]) arr[$i] = c$i;""" }.mkString("\n ") - val mutableStates = ctx.mutableStates.map { case (jt, name, init) => - s"private $jt $name = $init;" + val mutableStates = ctx.mutableStates.map { case (javaType, variableName, initialValue) => + s"private $javaType $variableName = $initialValue;" }.mkString("\n ") val code = s""" From eb3a221d95839ff57b7579f87cbceaaf9ff19653 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 15 Jul 2015 17:30:01 +0800 Subject: [PATCH 5/5] fix order --- .../codegen/GenerateOrdering.scala | 33 ++++++++++++++----- 1 file changed, 24 insertions(+), 9 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala index b2411e77c359..856ff9f1f96f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala @@ -46,23 +46,38 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[InternalR protected def create(ordering: Seq[SortOrder]): Ordering[InternalRow] = { 
val ctx = newCodeGenContext() - val comparisons = ordering.zipWithIndex.map { case (order, i) => - val evalA = order.child.gen(ctx) - val evalB = order.child.gen(ctx) + val comparisons = ordering.map { order => + val eval = order.child.gen(ctx) val asc = order.direction == Ascending + val isNullA = ctx.freshName("isNullA") + val primitiveA = ctx.freshName("primitiveA") + val isNullB = ctx.freshName("isNullB") + val primitiveB = ctx.freshName("primitiveB") s""" i = a; - ${evalA.code} + boolean $isNullA; + ${ctx.javaType(order.child.dataType)} $primitiveA; + { + ${eval.code} + $isNullA = ${eval.isNull}; + $primitiveA = ${eval.primitive}; + } i = b; - ${evalB.code} - if (${evalA.isNull} && ${evalB.isNull}) { + boolean $isNullB; + ${ctx.javaType(order.child.dataType)} $primitiveB; + { + ${eval.code} + $isNullB = ${eval.isNull}; + $primitiveB = ${eval.primitive}; + } + if ($isNullA && $isNullB) { // Nothing - } else if (${evalA.isNull}) { + } else if ($isNullA) { return ${if (order.direction == Ascending) "-1" else "1"}; - } else if (${evalB.isNull}) { + } else if ($isNullB) { return ${if (order.direction == Ascending) "1" else "-1"}; } else { - int comp = ${ctx.genComp(order.child.dataType, evalA.primitive, evalB.primitive)}; + int comp = ${ctx.genComp(order.child.dataType, primitiveA, primitiveB)}; if (comp != 0) { return ${if (asc) "comp" else "-comp"}; }
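
Taken together, the series replaces patch 1's approach (passing live Scala objects into the generated class through `generate(expressions, states)`) with the approach of patches 2 and 3: an expression registers a (java type, variable name, initialization code) triple via `ctx.addMutableState`, and each generator renders those triples as initialized private members of the generated class. The following standalone Scala sketch illustrates that pattern in miniature. It is not part of the patches and does not use the real Spark classes; the names `ToyCodeGenContext`, `genIncreasingId` and `SpecificToyProjection` are invented for illustration only.

import scala.collection.mutable

object MutableStateSketch extends App {

  // Minimal stand-in for CodeGenContext: expressions register per-instance mutable
  // state as (java type, variable name, java initialization code).
  class ToyCodeGenContext {
    val mutableStates = mutable.ArrayBuffer.empty[(String, String, String)]
    private var id = 0
    def freshName(prefix: String): String = { id += 1; s"$prefix$id" }
    def addMutableState(javaType: String, variableName: String, initCode: String): Unit =
      mutableStates += ((javaType, variableName, initCode))
  }

  // Toy analogue of MonotonicallyIncreasingID.genCode: the expression needs a mutable
  // counter, so it registers one instead of closing over a Scala variable.
  def genIncreasingId(ctx: ToyCodeGenContext): String = {
    val countTerm = ctx.freshName("count")
    ctx.addMutableState("long", countTerm, "0L")
    s"final long value = $countTerm; $countTerm++;"
  }

  // Toy analogue of what GenerateProjection/GeneratePredicate/GenerateOrdering do:
  // render every registered state as an initialized private member of the generated class.
  def generateClass(ctx: ToyCodeGenContext, evalCode: String): String = {
    val memberDecls = ctx.mutableStates
      .map { case (javaType, name, init) => s"private $javaType $name = $init;" }
      .mkString("\n  ")
    s"""class SpecificToyProjection {
       |  $memberDecls
       |  public void apply() {
       |    $evalCode
       |  }
       |}""".stripMargin
  }

  val ctx = new ToyCodeGenContext
  val evalCode = genIncreasingId(ctx)
  // Prints Java source containing "private long count1 = 0L;" as a member variable,
  // which is what lets the generated code keep state across invocations.
  println(generateClass(ctx, evalCode))
}

Because the state is expressed as initialization code (a string) rather than as a runtime object, the generated Java class stays self-contained: nothing has to be boxed and handed through `generate(...)`, and values such as the partition id can be obtained directly inside the generated source via `org.apache.spark.TaskContext.getPartitionId()`, as the final versions of `Rand`, `Randn`, `MonotonicallyIncreasingID` and `SparkPartitionID` do in the patches above.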