-
Notifications
You must be signed in to change notification settings - Fork 29k
[SPARK-9020][SQL] Support mutable state in code gen expressions #7392
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -56,6 +56,14 @@ class CodeGenContext { | |
| */ | ||
| val references: mutable.ArrayBuffer[Expression] = new mutable.ArrayBuffer[Expression]() | ||
|
|
||
| /** | ||
| * Holding expressions' mutable states like `Rand.rng`, and keep them as member variables | ||
| * in generated classes like `SpecificProjection`. | ||
| * Each element is a 3-tuple: java type, variable name, variable value. | ||
| */ | ||
| val mutableStates: mutable.ArrayBuffer[(String, String, Any)] = | ||
| mutable.ArrayBuffer.empty[(String, String, Any)] | ||
|
|
||
| val stringType: String = classOf[UTF8String].getName | ||
| val decimalType: String = classOf[Decimal].getName | ||
|
|
||
|
|
@@ -205,7 +213,7 @@ class CodeGenContext { | |
|
|
||
|
|
||
| abstract class GeneratedClass { | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. While you are at this, can you add scaladoc for GeneratedClass? It is not obvious what it does just by looking at it. |
||
| def generate(expressions: Array[Expression]): Any | ||
| def generate(expressions: Array[Expression], states: Array[Any]): Any | ||
| } | ||
|
|
||
| /** | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -70,17 +70,27 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[InternalR | |
| """ | ||
| }.mkString("\n") | ||
|
|
||
| val mutableStates = ctx.mutableStates.map { | ||
| case (jt, name, _) => s"private $jt $name;" | ||
| }.mkString("\n ") | ||
|
|
||
| val initStates = ctx.mutableStates.zipWithIndex.map { | ||
| case ((jt, name, _), index) => s"$name = (${ctx.boxedType(jt)}) states[$index];" | ||
| }.mkString("\n ") | ||
|
|
||
| val code = s""" | ||
| public SpecificOrdering generate($exprType[] expr) { | ||
| return new SpecificOrdering(expr); | ||
| public SpecificOrdering generate($exprType[] expr, Object[] states) { | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. can't we pass in the code to initialize the variables, rather than using an object array?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It will be hard to define the
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. And this initialization will happen only once, and the member variables can be primitive type, so boxing is not a problem here, like
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Maybe I'm missing something. What I'm saying is that the generated code can look something like:
class GenerateProjection456 {
  private long nextId123;
  public GenerateProjection456() {
    nextId123 = 50L;
  }
  ...
}
Does this not work?
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. i'm not worried about performance -- i just think it's ugly and unnecessary to pass the state around.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This works for literals, but how about objects? That's also the reason why we pass expressions this way...
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. do we have any expressions that require objects for now? if not, it's better to start with a simpler solution. if yes, then yes let's do this.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We need
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. can't we just pass the code to create XORShiftRandom in the constructor?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I just realized that... |
||
| return new SpecificOrdering(expr, states); | ||
| } | ||
|
|
||
| class SpecificOrdering extends ${classOf[BaseOrdering].getName} { | ||
|
|
||
| private $exprType[] expressions = null; | ||
| $mutableStates | ||
|
|
||
| public SpecificOrdering($exprType[] expr) { | ||
| public SpecificOrdering($exprType[] expr, Object[] states) { | ||
| expressions = expr; | ||
| $initStates | ||
| } | ||
|
|
||
| @Override | ||
|
|
@@ -93,6 +103,7 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[InternalR | |
|
|
||
| logDebug(s"Generated Ordering: $code") | ||
|
|
||
| compile(code).generate(ctx.references.toArray).asInstanceOf[BaseOrdering] | ||
| compile(code).generate(ctx.references.toArray, ctx.mutableStates.map(_._3).toArray) | ||
| .asInstanceOf[BaseOrdering] | ||
| } | ||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -151,85 +151,96 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { | |
| s"""if (!nullBits[$i]) arr[$i] = c$i;""" | ||
| }.mkString("\n ") | ||
|
|
||
| val mutableStates = ctx.mutableStates.map { | ||
| case (jt, name, _) => s"private $jt $name;" | ||
| }.mkString("\n ") | ||
|
|
||
| val initStates = ctx.mutableStates.zipWithIndex.map { | ||
| case ((jt, name, _), index) => s"$name = (${ctx.boxedType(jt)}) states[$index];" | ||
| }.mkString("\n ") | ||
|
|
||
| val code = s""" | ||
| public SpecificProjection generate($exprType[] expr) { | ||
| return new SpecificProjection(expr); | ||
| public SpecificProjection generate($exprType[] expr, Object[] states) { | ||
| return new SpecificProjection(expr, states); | ||
| } | ||
|
|
||
| class SpecificProjection extends ${classOf[BaseProject].getName} { | ||
| private $exprType[] expressions = null; | ||
| $mutableStates | ||
|
|
||
| public SpecificProjection($exprType[] expr) { | ||
| public SpecificProjection($exprType[] expr, Object[] states) { | ||
| expressions = expr; | ||
| $initStates | ||
| } | ||
|
|
||
| @Override | ||
| public Object apply(Object r) { | ||
| return new SpecificRow(expressions, (InternalRow) r); | ||
| return new SpecificRow((InternalRow) r); | ||
| } | ||
| } | ||
|
|
||
| final class SpecificRow extends ${classOf[MutableRow].getName} { | ||
| final class SpecificRow extends ${classOf[MutableRow].getName} { | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I made |
||
|
|
||
| $columns | ||
| $columns | ||
|
|
||
| public SpecificRow($exprType[] expressions, InternalRow i) { | ||
| $initColumns | ||
| } | ||
| public SpecificRow(InternalRow i) { | ||
| $initColumns | ||
| } | ||
|
|
||
| public int length() { return ${expressions.length};} | ||
| protected boolean[] nullBits = new boolean[${expressions.length}]; | ||
| public void setNullAt(int i) { nullBits[i] = true; } | ||
| public boolean isNullAt(int i) { return nullBits[i]; } | ||
| public int length() { return ${expressions.length};} | ||
| protected boolean[] nullBits = new boolean[${expressions.length}]; | ||
| public void setNullAt(int i) { nullBits[i] = true; } | ||
| public boolean isNullAt(int i) { return nullBits[i]; } | ||
|
|
||
| public Object get(int i) { | ||
| if (isNullAt(i)) return null; | ||
| switch (i) { | ||
| $getCases | ||
| public Object get(int i) { | ||
| if (isNullAt(i)) return null; | ||
| switch (i) { | ||
| $getCases | ||
| } | ||
| return null; | ||
| } | ||
| return null; | ||
| } | ||
| public void update(int i, Object value) { | ||
| if (value == null) { | ||
| setNullAt(i); | ||
| return; | ||
| public void update(int i, Object value) { | ||
| if (value == null) { | ||
| setNullAt(i); | ||
| return; | ||
| } | ||
| nullBits[i] = false; | ||
| switch (i) { | ||
| $updateCases | ||
| } | ||
| } | ||
| nullBits[i] = false; | ||
| switch (i) { | ||
| $updateCases | ||
| $specificAccessorFunctions | ||
| $specificMutatorFunctions | ||
|
|
||
| @Override | ||
| public int hashCode() { | ||
| int result = 37; | ||
| $hashUpdates | ||
| return result; | ||
| } | ||
| } | ||
| $specificAccessorFunctions | ||
| $specificMutatorFunctions | ||
|
|
||
| @Override | ||
| public int hashCode() { | ||
| int result = 37; | ||
| $hashUpdates | ||
| return result; | ||
| } | ||
|
|
||
| @Override | ||
| public boolean equals(Object other) { | ||
| if (other instanceof SpecificRow) { | ||
| SpecificRow row = (SpecificRow) other; | ||
| $columnChecks | ||
| return true; | ||
| @Override | ||
| public boolean equals(Object other) { | ||
| if (other instanceof SpecificRow) { | ||
| SpecificRow row = (SpecificRow) other; | ||
| $columnChecks | ||
| return true; | ||
| } | ||
| return super.equals(other); | ||
| } | ||
| return super.equals(other); | ||
| } | ||
|
|
||
| @Override | ||
| public InternalRow copy() { | ||
| Object[] arr = new Object[${expressions.length}]; | ||
| ${copyColumns} | ||
| return new ${classOf[GenericInternalRow].getName}(arr); | ||
| @Override | ||
| public InternalRow copy() { | ||
| Object[] arr = new Object[${expressions.length}]; | ||
| ${copyColumns} | ||
| return new ${classOf[GenericInternalRow].getName}(arr); | ||
| } | ||
| } | ||
| } | ||
| """ | ||
|
|
||
| logDebug(s"MutableRow, initExprs: ${expressions.mkString(",")} code:\n${code}") | ||
|
|
||
| compile(code).generate(ctx.references.toArray).asInstanceOf[Projection] | ||
| compile(code).generate(ctx.references.toArray, ctx.mutableStates.map(_._3).toArray) | ||
| .asInstanceOf[Projection] | ||
| } | ||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.expressions | |
| import org.apache.spark.TaskContext | ||
| import org.apache.spark.sql.catalyst.InternalRow | ||
| import org.apache.spark.sql.catalyst.expressions.LeafExpression | ||
| import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext} | ||
| import org.apache.spark.sql.types.{LongType, DataType} | ||
|
|
||
| /** | ||
|
|
@@ -40,13 +41,31 @@ private[sql] case class MonotonicallyIncreasingID() extends LeafExpression { | |
| */ | ||
| @transient private[this] var count: Long = 0L | ||
|
|
||
| @transient protected lazy val partitionMask = TaskContext.get() match { | ||
| case null => 0L | ||
| case _ => TaskContext.get().partitionId().toLong << 33 | ||
| } | ||
|
|
||
| override def nullable: Boolean = false | ||
|
|
||
| override def dataType: DataType = LongType | ||
|
|
||
| override def eval(input: InternalRow): Long = { | ||
| val currentCount = count | ||
| count += 1 | ||
| (TaskContext.get().partitionId().toLong << 33) + currentCount | ||
| partitionMask + currentCount | ||
| } | ||
|
|
||
| override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { | ||
| val countTerm = ctx.freshName("count") | ||
| val partitionMaskTerm = ctx.freshName("partitionMask") | ||
| ctx.mutableStates += (("long", countTerm, count)) | ||
| ctx.mutableStates += (("long", partitionMaskTerm, partitionMask)) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I think you should just explicitly initialize partitionMask, rather than relying on the current value (i.e. generate the code to initialize it with the task context) |
||
|
|
||
| ev.isNull = "false" | ||
| s""" | ||
| final ${ctx.javaType(dataType)} ${ev.primitive} = $partitionMaskTerm + $countTerm; | ||
| $countTerm++; | ||
| """ | ||
| } | ||
| } | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's add a function so we don't call += directly on the array buffer.
This way it also makes it clearer what the semantics are for each element in the tuple3.