Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 14 additions & 1 deletion core/src/main/scala/org/apache/spark/TaskContext.scala
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,20 @@ object TaskContext {
*/
def get(): TaskContext = taskContext.get

private val taskContext: ThreadLocal[TaskContext] = new ThreadLocal[TaskContext]
/**
 * Returns the partition id of the currently active TaskContext, or 0 when no
 * TaskContext is active (for example, during local execution).
 */
def getPartitionId(): Int = taskContext.get() match {
  case null => 0
  case active => active.partitionId()
}

private[this] val taskContext: ThreadLocal[TaskContext] = new ThreadLocal[TaskContext]

// Note: protected[spark] instead of private[spark] to prevent the following two from
// showing up in JavaDoc.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,18 @@ class CodeGenContext {
*/
val references: mutable.ArrayBuffer[Expression] = new mutable.ArrayBuffer[Expression]()

/**
* Holding expressions' mutable states like `MonotonicallyIncreasingID.count` as a
* 3-tuple: java type, variable name, code to init it.
* They will be kept as member variables in generated classes like `SpecificProjection`.
*/
val mutableStates: mutable.ArrayBuffer[(String, String, String)] =
mutable.ArrayBuffer.empty[(String, String, String)]

/**
 * Records one piece of mutable state by appending the
 * (java type, variable name, init code) triple to `mutableStates`.
 */
def addMutableState(javaType: String, variableName: String, initialValue: String): Unit = {
  val state = (javaType, variableName, initialValue)
  mutableStates += state
}

val stringType: String = classOf[UTF8String].getName
val decimalType: String = classOf[Decimal].getName

Expand Down Expand Up @@ -203,7 +215,10 @@ class CodeGenContext {
def isPrimitiveType(dt: DataType): Boolean = isPrimitiveType(javaType(dt))
}


/**
* A wrapper for generated class, defines a `generate` method so that we can pass extra objects
* into generated class.
*/
abstract class GeneratedClass {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

While you are at it, can you add scaladoc for GeneratedClass? It is not obvious what it does just by looking at it.

def generate(expressions: Array[Expression]): Any
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,9 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu
${ctx.setColumn("mutableRow", e.dataType, i, evaluationCode.primitive)};
"""
}.mkString("\n")
val mutableStates = ctx.mutableStates.map { case (javaType, variableName, initialValue) =>
s"private $javaType $variableName = $initialValue;"
}.mkString("\n ")
val code = s"""
public Object generate($exprType[] expr) {
return new SpecificProjection(expr);
Expand All @@ -55,6 +58,7 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu

private $exprType[] expressions = null;
private $mutableRowType mutableRow = null;
$mutableStates

public SpecificProjection($exprType[] expr) {
expressions = expr;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,30 +46,47 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[InternalR
protected def create(ordering: Seq[SortOrder]): Ordering[InternalRow] = {
val ctx = newCodeGenContext()

val comparisons = ordering.zipWithIndex.map { case (order, i) =>
val evalA = order.child.gen(ctx)
val evalB = order.child.gen(ctx)
val comparisons = ordering.map { order =>
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In interpreted mode, we use the same expression to evaluate both rows, which means we keep only one copy of the mutable state for that expression. However, in GenerateOrdering we call order.child.gen(ctx) twice and thus keep 2 copies of the mutable state for that expression. This is inconsistent and may produce a different comparison result, so I fixed it here.

However, should we allow stateful expressions in order by?
cc @davies @rxin

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have not figured out a case that needs a stateful ordering; could we delay it until we really need it?

val eval = order.child.gen(ctx)
val asc = order.direction == Ascending
val isNullA = ctx.freshName("isNullA")
val primitiveA = ctx.freshName("primitiveA")
val isNullB = ctx.freshName("isNullB")
val primitiveB = ctx.freshName("primitiveB")
s"""
i = a;
${evalA.code}
boolean $isNullA;
${ctx.javaType(order.child.dataType)} $primitiveA;
{
${eval.code}
$isNullA = ${eval.isNull};
$primitiveA = ${eval.primitive};
}
i = b;
${evalB.code}
if (${evalA.isNull} && ${evalB.isNull}) {
boolean $isNullB;
${ctx.javaType(order.child.dataType)} $primitiveB;
{
${eval.code}
$isNullB = ${eval.isNull};
$primitiveB = ${eval.primitive};
}
if ($isNullA && $isNullB) {
// Nothing
} else if (${evalA.isNull}) {
} else if ($isNullA) {
return ${if (order.direction == Ascending) "-1" else "1"};
} else if (${evalB.isNull}) {
} else if ($isNullB) {
return ${if (order.direction == Ascending) "1" else "-1"};
} else {
int comp = ${ctx.genComp(order.child.dataType, evalA.primitive, evalB.primitive)};
int comp = ${ctx.genComp(order.child.dataType, primitiveA, primitiveB)};
if (comp != 0) {
return ${if (asc) "comp" else "-comp"};
}
}
"""
}.mkString("\n")

val mutableStates = ctx.mutableStates.map { case (javaType, variableName, initialValue) =>
s"private $javaType $variableName = $initialValue;"
}.mkString("\n ")
val code = s"""
public SpecificOrdering generate($exprType[] expr) {
return new SpecificOrdering(expr);
Expand All @@ -78,6 +95,7 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[InternalR
class SpecificOrdering extends ${classOf[BaseOrdering].getName} {

private $exprType[] expressions = null;
$mutableStates

public SpecificOrdering($exprType[] expr) {
expressions = expr;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,13 +40,17 @@ object GeneratePredicate extends CodeGenerator[Expression, (InternalRow) => Bool
protected def create(predicate: Expression): ((InternalRow) => Boolean) = {
val ctx = newCodeGenContext()
val eval = predicate.gen(ctx)
val mutableStates = ctx.mutableStates.map { case (javaType, variableName, initialValue) =>
s"private $javaType $variableName = $initialValue;"
}.mkString("\n ")
val code = s"""
public SpecificPredicate generate($exprType[] expr) {
return new SpecificPredicate(expr);
}

class SpecificPredicate extends ${classOf[Predicate].getName} {
private final $exprType[] expressions;
$mutableStates
public SpecificPredicate($exprType[] expr) {
expressions = expr;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -151,79 +151,84 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] {
s"""if (!nullBits[$i]) arr[$i] = c$i;"""
}.mkString("\n ")

val mutableStates = ctx.mutableStates.map { case (javaType, variableName, initialValue) =>
s"private $javaType $variableName = $initialValue;"
}.mkString("\n ")

val code = s"""
public SpecificProjection generate($exprType[] expr) {
return new SpecificProjection(expr);
}

class SpecificProjection extends ${classOf[BaseProject].getName} {
private $exprType[] expressions = null;
$mutableStates

public SpecificProjection($exprType[] expr) {
expressions = expr;
}

@Override
public Object apply(Object r) {
return new SpecificRow(expressions, (InternalRow) r);
return new SpecificRow((InternalRow) r);
}
}

final class SpecificRow extends ${classOf[MutableRow].getName} {
final class SpecificRow extends ${classOf[MutableRow].getName} {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I made SpecificRow an inner class of SpecificProjection here, so that we can access these mutable states easily.


$columns
$columns

public SpecificRow($exprType[] expressions, InternalRow i) {
$initColumns
}
public SpecificRow(InternalRow i) {
$initColumns
}

public int length() { return ${expressions.length};}
protected boolean[] nullBits = new boolean[${expressions.length}];
public void setNullAt(int i) { nullBits[i] = true; }
public boolean isNullAt(int i) { return nullBits[i]; }
public int length() { return ${expressions.length};}
protected boolean[] nullBits = new boolean[${expressions.length}];
public void setNullAt(int i) { nullBits[i] = true; }
public boolean isNullAt(int i) { return nullBits[i]; }

public Object get(int i) {
if (isNullAt(i)) return null;
switch (i) {
$getCases
public Object get(int i) {
if (isNullAt(i)) return null;
switch (i) {
$getCases
}
return null;
}
return null;
}
public void update(int i, Object value) {
if (value == null) {
setNullAt(i);
return;
public void update(int i, Object value) {
if (value == null) {
setNullAt(i);
return;
}
nullBits[i] = false;
switch (i) {
$updateCases
}
}
nullBits[i] = false;
switch (i) {
$updateCases
$specificAccessorFunctions
$specificMutatorFunctions

@Override
public int hashCode() {
int result = 37;
$hashUpdates
return result;
}
}
$specificAccessorFunctions
$specificMutatorFunctions

@Override
public int hashCode() {
int result = 37;
$hashUpdates
return result;
}

@Override
public boolean equals(Object other) {
if (other instanceof SpecificRow) {
SpecificRow row = (SpecificRow) other;
$columnChecks
return true;
@Override
public boolean equals(Object other) {
if (other instanceof SpecificRow) {
SpecificRow row = (SpecificRow) other;
$columnChecks
return true;
}
return super.equals(other);
}
return super.equals(other);
}

@Override
public InternalRow copy() {
Object[] arr = new Object[${expressions.length}];
${copyColumns}
return new ${classOf[GenericInternalRow].getName}(arr);
@Override
public InternalRow copy() {
Object[] arr = new Object[${expressions.length}];
${copyColumns}
return new ${classOf[GenericInternalRow].getName}(arr);
}
}
}
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.TaskContext
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode}
import org.apache.spark.sql.types.{DataType, DoubleType}
import org.apache.spark.util.Utils
import org.apache.spark.util.random.XORShiftRandom
Expand All @@ -38,11 +39,7 @@ abstract class RDG(seed: Long) extends LeafExpression with Serializable {
* Record ID within each partition. By being transient, the Random Number Generator is
* reset every time we serialize and deserialize it.
*/
@transient protected lazy val partitionId = TaskContext.get() match {
case null => 0
case _ => TaskContext.get().partitionId()
}
@transient protected lazy val rng = new XORShiftRandom(seed + partitionId)
@transient protected lazy val rng = new XORShiftRandom(seed + TaskContext.getPartitionId)

override def deterministic: Boolean = false

Expand All @@ -61,6 +58,17 @@ case class Rand(seed: Long) extends RDG(seed) {
case IntegerLiteral(s) => s
case _ => throw new AnalysisException("Input argument to rand must be an integer literal.")
})

/**
 * Generates Java code that assigns `rng.nextDouble()` to this expression's
 * result variable and marks the result as never null.
 *
 * The random number generator is registered on `ctx` as mutable state and is
 * initialised with `seed + TaskContext.getPartitionId()`. The seed is emitted
 * as a Java `long` literal (note the `L` suffix): a bare integer literal is an
 * `int` in Java, so a seed outside the int range would fail to compile and a
 * smaller one would be added to the partition id using int arithmetic.
 *
 * @param ctx codegen context that receives the generator as mutable state
 * @param ev  holder for the generated result; its `isNull` is set to "false"
 * @return the Java statement that computes the result
 */
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
  val rngTerm = ctx.freshName("rng")
  val className = classOf[XORShiftRandom].getCanonicalName
  ctx.addMutableState(className, rngTerm,
    s"new $className(${seed}L + org.apache.spark.TaskContext.getPartitionId())")
  ev.isNull = "false"
  s"""
final ${ctx.javaType(dataType)} ${ev.primitive} = $rngTerm.nextDouble();
"""
}
}

/** Generate a random column with i.i.d. gaussian random distribution. */
Expand All @@ -73,4 +81,15 @@ case class Randn(seed: Long) extends RDG(seed) {
case IntegerLiteral(s) => s
case _ => throw new AnalysisException("Input argument to rand must be an integer literal.")
})

/**
 * Generates Java code that assigns `rng.nextGaussian()` to this expression's
 * result variable and marks the result as never null.
 *
 * The random number generator is registered on `ctx` as mutable state and is
 * initialised with `seed + TaskContext.getPartitionId()`. The seed is emitted
 * as a Java `long` literal (note the `L` suffix): a bare integer literal is an
 * `int` in Java, so a seed outside the int range would fail to compile and a
 * smaller one would be added to the partition id using int arithmetic.
 *
 * @param ctx codegen context that receives the generator as mutable state
 * @param ev  holder for the generated result; its `isNull` is set to "false"
 * @return the Java statement that computes the result
 */
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
  val rngTerm = ctx.freshName("rng")
  val className = classOf[XORShiftRandom].getCanonicalName
  ctx.addMutableState(className, rngTerm,
    s"new $className(${seed}L + org.apache.spark.TaskContext.getPartitionId())")
  ev.isNull = "false"
  s"""
final ${ctx.javaType(dataType)} ${ev.primitive} = $rngTerm.nextGaussian();
"""
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.expressions
import org.apache.spark.TaskContext
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.LeafExpression
import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext}
import org.apache.spark.sql.types.{LongType, DataType}

/**
Expand All @@ -40,13 +41,29 @@ private[sql] case class MonotonicallyIncreasingID() extends LeafExpression {
*/
@transient private[this] var count: Long = 0L

@transient private lazy val partitionMask = TaskContext.getPartitionId.toLong << 33

override def nullable: Boolean = false

override def dataType: DataType = LongType

override def eval(input: InternalRow): Long = {
val currentCount = count
count += 1
(TaskContext.get().partitionId().toLong << 33) + currentCount
partitionMask + currentCount
}

/**
 * Generates Java code that sets the result to `partitionMask + count` and then
 * increments `count`. Both values are registered on `ctx` as `long` mutable
 * states: the counter starts at 0 and the mask is the current partition id
 * shifted left by 33 bits. The result is marked as never null.
 */
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
  val counterVar = ctx.freshName("count")
  val maskVar = ctx.freshName("partitionMask")
  val maskInit = "((long) org.apache.spark.TaskContext.getPartitionId()) << 33"

  ctx.addMutableState(ctx.JAVA_LONG, counterVar, "0L")
  ctx.addMutableState(ctx.JAVA_LONG, maskVar, maskInit)
  ev.isNull = "false"

  val resultType = ctx.javaType(dataType)
  s"""
final $resultType ${ev.primitive} = $maskVar + $counterVar;
$counterVar++;
"""
}
}
Loading