Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
Add wrappers for codegen output.
  • Loading branch information
viirya committed Dec 21, 2017
commit 5ace8b83b7c90cd5a6a451812ac9c1087aaa1c29
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.errors.attachTree
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, LiteralValue}
import org.apache.spark.sql.types._

/**
Expand Down Expand Up @@ -75,7 +75,7 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean)
|$javaType ${ev.value} = ${ev.isNull} ? ${ctx.defaultValue(dataType)} : ($value);
""".stripMargin)
} else {
ev.copy(code = s"$javaType ${ev.value} = $value;", isNull = "false")
ev.copy(code = s"$javaType ${ev.value} = $value;", isNull = LiteralValue("false"))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: shall we introduce a TrueLiteral and FalseLiteral?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it should be useful to the isNull field.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks like a good idea.

}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ abstract class Expression extends TreeNode[Expression] {
}.getOrElse {
val isNull = ctx.freshName("isNull")
val value = ctx.freshName("value")
val eval = doGenCode(ctx, ExprCode("", isNull, value))
val eval = doGenCode(ctx, ExprCode("", VariableValue(isNull), VariableValue(value)))
reduceCodeSize(ctx, eval)
if (eval.code.nonEmpty) {
// Add `this` in the comment.
Expand All @@ -118,10 +118,10 @@ abstract class Expression extends TreeNode[Expression] {
private def reduceCodeSize(ctx: CodegenContext, eval: ExprCode): Unit = {
// TODO: support whole stage codegen too
if (eval.code.trim.length > 1024 && ctx.INPUT_ROW != null && ctx.currentVars == null) {
val setIsNull = if (eval.isNull != "false" && eval.isNull != "true") {
val setIsNull = if ("false" != eval.isNull && "true" != eval.isNull) {
val globalIsNull = ctx.addMutableState(ctx.JAVA_BOOLEAN, "globalIsNull")
val localIsNull = eval.isNull
eval.isNull = globalIsNull
eval.isNull = GlobalValue(globalIsNull)
s"$globalIsNull = $localIsNull;"
} else {
""
Expand All @@ -140,7 +140,7 @@ abstract class Expression extends TreeNode[Expression] {
|}
""".stripMargin)

eval.value = newValue
eval.value = VariableValue(newValue)
eval.code = s"$javaType $newValue = $funcFullName(${ctx.INPUT_ROW});"
}
}
Expand Down Expand Up @@ -419,7 +419,7 @@ abstract class UnaryExpression extends Expression {
boolean ${ev.isNull} = false;
${childGen.code}
${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
$resultCode""", isNull = "false")
$resultCode""", isNull = LiteralValue("false"))
}
}
}
Expand Down Expand Up @@ -519,7 +519,7 @@ abstract class BinaryExpression extends Expression {
${leftGen.code}
${rightGen.code}
${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
$resultCode""", isNull = "false")
$resultCode""", isNull = LiteralValue("false"))
}
}
}
Expand Down Expand Up @@ -663,7 +663,7 @@ abstract class TernaryExpression extends Expression {
${midGen.code}
${rightGen.code}
${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
$resultCode""", isNull = "false")
$resultCode""", isNull = LiteralValue("false"))
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, LiteralValue}
import org.apache.spark.sql.types.{DataType, LongType}

/**
Expand Down Expand Up @@ -72,7 +72,7 @@ case class MonotonicallyIncreasingID() extends LeafExpression with Nondeterminis

ev.copy(code = s"""
final ${ctx.javaType(dataType)} ${ev.value} = $partitionMaskTerm + $countTerm;
$countTerm++;""", isNull = "false")
$countTerm++;""", isNull = LiteralValue("false"))
}

override def prettyName: String = "monotonically_increasing_id"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, LiteralValue}
import org.apache.spark.sql.types.{DataType, IntegerType}

/**
Expand All @@ -45,6 +45,7 @@ case class SparkPartitionID() extends LeafExpression with Nondeterministic {
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val idTerm = ctx.addMutableState(ctx.JAVA_INT, "partitionId")
ctx.addPartitionInitializationStatement(s"$idTerm = partitionIndex;")
ev.copy(code = s"final ${ctx.javaType(dataType)} ${ev.value} = $idTerm;", isNull = "false")
ev.copy(code = s"final ${ctx.javaType(dataType)} ${ev.value} = $idTerm;",
isNull = LiteralValue("false"))
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ import java.util.{Map => JavaMap}
import scala.collection.JavaConverters._
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
import scala.language.existentials
import scala.language.{existentials, implicitConversions}
import scala.util.control.NonFatal

import com.google.common.cache.{CacheBuilder, CacheLoader}
Expand Down Expand Up @@ -56,7 +56,36 @@ import org.apache.spark.util.{ParentClassLoader, Utils}
* @param value A term for a (possibly primitive) value of the result of the evaluation. Not
* valid if `isNull` is set to `true`.
*/
case class ExprCode(var code: String, var isNull: String, var value: String)
case class ExprCode(var code: String, var isNull: ExprValue, var value: ExprValue)


// An abstraction that represents the evaluation result of [[ExprCode]].
abstract class ExprValue
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should classify ExprValue by our needs, not by java definitions. Thinking about the needs, we wanna know: 1) if this value is accessible anywhere and we don't need to carry it via method parameters. 2) if this value needs to be carried with parameters, do we need to generate a parameter name or use this value directly?

So basically we can combine LiteralValue and GlobalValue.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IMHO I prefer this approach because in the future we might need to distinguish these two cases, thus I think is a good thing to let them be distinct.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For now LiteralValue and GlobalValue can be seen as the same effectively, as they are all accessible anywhere and we don't need to carry it via method parameters.

I don't have strong preference here.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@kiszk WDYT?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In summary, I have no strong preference.

In the future, we will want to distinguish Literal and Global for some optimizations. This is already one of optimizations for Literal.

If this PR just focuses on classifying types between arguments and non-arguments, it is fine to combine Literal and Global. Then, another PR will separate one type into Literal and Global.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If no strong preference for combining them, I'd keep it as two concepts for now, if we foresee the need to distinguish them.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@cloud-fan What do you think?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK let's keep it.


object ExprValue {
implicit def exprValueToString(exprValue: ExprValue): String = exprValue.toString
}

// A literal evaluation of [[ExprCode]].
case class LiteralValue(val value: String) extends ExprValue {
override def toString: String = value
}

// A variable evaluation of [[ExprCode]].
case class VariableValue(val variableName: String) extends ExprValue {
override def toString: String = variableName
}

// A statement evaluation of [[ExprCode]].
case class StatementValue(val statement: String) extends ExprValue {
override def toString: String = statement
}

// A global variable evaluation of [[ExprCode]].
case class GlobalValue(val value: String) extends ExprValue {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

for compacted global variables, we may get something like arr[1] while arr is a global variable. Is arr[1] a statement or global variable?

Copy link
Member Author

@viirya viirya Dec 21, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is considered as global variable now, as it can be accessed globally and don't/can't/shouldn't be a parameter. Actually we don't want to take global variables as parameters.

override def toString: String = value
}


/**
* State used for subexpression elimination.
Expand All @@ -66,7 +95,7 @@ case class ExprCode(var code: String, var isNull: String, var value: String)
* @param value A term for a value of a common sub-expression. Not valid if `isNull`
* is set to `true`.
*/
case class SubExprEliminationState(isNull: String, value: String)
case class SubExprEliminationState(isNull: ExprValue, value: ExprValue)

/**
* Codes and common subexpressions mapping used for subexpression elimination.
Expand Down Expand Up @@ -264,7 +293,7 @@ class CodegenContext {
case _: StructType | _: ArrayType | _: MapType => s"$value = $initCode.copy();"
case _ => s"$value = $initCode;"
}
ExprCode(code, "false", value)
ExprCode(code, LiteralValue("false"), GlobalValue(value))
}

def declareMutableStates(): String = {
Expand Down Expand Up @@ -1144,7 +1173,7 @@ class CodegenContext {
// at least two nodes) as the cost of doing it is expected to be low.

subexprFunctions += s"${addNewFunction(fnName, fn)}($INPUT_ROW);"
val state = SubExprEliminationState(isNull, value)
val state = SubExprEliminationState(GlobalValue(isNull), GlobalValue(value))
e.foreach(subExprEliminationExprs.put(_, state))
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ trait CodegenFallback extends Expression {
$placeHolder
Object $objectTerm = ((Expression) references[$idx]).eval($input);
${ctx.javaType(this.dataType)} ${ev.value} = (${ctx.boxedType(this.dataType)}) $objectTerm;
""", isNull = "false")
""", isNull = LiteralValue("false"))
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], MutableP
val exprVals = ctx.generateExpressions(validExpr, useSubexprElimination)

// 4-tuples: (code for projection, isNull variable name, value variable name, column index)
val projectionCodes: Seq[(String, String, String, Int)] = exprVals.zip(index).map {
val projectionCodes: Seq[(String, ExprValue, String, Int)] = exprVals.zip(index).map {
case (ev, i) =>
val e = expressions(i)
val value = ctx.addMutableState(ctx.javaType(e.dataType), "value")
Expand All @@ -69,7 +69,7 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], MutableP
|${ev.code}
|$isNull = ${ev.isNull};
|$value = ${ev.value};
""".stripMargin, isNull, value, i)
""".stripMargin, GlobalValue(isNull), value, i)
} else {
(s"""
|${ev.code}
Expand All @@ -83,7 +83,7 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], MutableP

val updates = validExpr.zip(projectionCodes).map {
case (e, (_, isNull, value, i)) =>
val ev = ExprCode("", isNull, value)
val ev = ExprCode("", isNull, GlobalValue(value))
ctx.updateColumn("mutableRow", e.dataType, i, ev, e.nullable)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection]
val rowClass = classOf[GenericInternalRow].getName

val fieldWriters = schema.map(_.dataType).zipWithIndex.map { case (dt, i) =>
val converter = convertToSafe(ctx, ctx.getValue(tmpInput, dt, i.toString), dt)
val converter = convertToSafe(ctx, StatementValue(ctx.getValue(tmpInput, dt, i.toString)), dt)
s"""
if (!$tmpInput.isNullAt($i)) {
${converter.code}
Expand All @@ -74,7 +74,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection]
|final InternalRow $output = new $rowClass($values);
""".stripMargin

ExprCode(code, "false", output)
ExprCode(code, LiteralValue("false"), VariableValue(output))
}

private def createCodeForArray(
Expand All @@ -90,7 +90,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection]
val arrayClass = classOf[GenericArrayData].getName

val elementConverter = convertToSafe(
ctx, ctx.getValue(tmpInput, elementType, index), elementType)
ctx, StatementValue(ctx.getValue(tmpInput, elementType, index)), elementType)
val code = s"""
final ArrayData $tmpInput = $input;
final int $numElements = $tmpInput.numElements();
Expand All @@ -104,7 +104,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection]
final ArrayData $output = new $arrayClass($values);
"""

ExprCode(code, "false", output)
ExprCode(code, LiteralValue("false"), VariableValue(output))
}

private def createCodeForMap(
Expand All @@ -125,19 +125,19 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection]
final MapData $output = new $mapClass(${keyConverter.value}, ${valueConverter.value});
"""

ExprCode(code, "false", output)
ExprCode(code, LiteralValue("false"), VariableValue(output))
}

@tailrec
private def convertToSafe(
ctx: CodegenContext,
input: String,
input: ExprValue,
dataType: DataType): ExprCode = dataType match {
case s: StructType => createCodeForStruct(ctx, input, s)
case ArrayType(elementType, _) => createCodeForArray(ctx, input, elementType)
case MapType(keyType, valueType, _) => createCodeForMap(ctx, input, keyType, valueType)
case udt: UserDefinedType[_] => convertToSafe(ctx, input, udt.sqlType)
case _ => ExprCode("", "false", input)
case _ => ExprCode("", LiteralValue("false"), input)
}

protected def create(expressions: Seq[Expression]): Projection = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,8 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
// Puts `input` in a local variable to avoid to re-evaluate it if it's a statement.
val tmpInput = ctx.freshName("tmpInput")
val fieldEvals = fieldTypes.zipWithIndex.map { case (dt, i) =>
ExprCode("", s"$tmpInput.isNullAt($i)", ctx.getValue(tmpInput, dt, i.toString))
ExprCode("", StatementValue(s"$tmpInput.isNullAt($i)"),
StatementValue(ctx.getValue(tmpInput, dt, i.toString)))
}

s"""
Expand Down Expand Up @@ -347,7 +348,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
$writeExpressions
$updateRowSize
"""
ExprCode(code, "false", result)
ExprCode(code, LiteralValue("false"), GlobalValue(result))
}

protected def canonicalize(in: Seq[Expression]): Seq[Expression] =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ import java.util.Comparator

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodegenFallback, ExprCode}
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData}
import org.apache.spark.sql.types._

Expand Down Expand Up @@ -55,7 +55,7 @@ case class Size(child: Expression) extends UnaryExpression with ExpectsInputType
boolean ${ev.isNull} = false;
${childGen.code}
${ctx.javaType(dataType)} ${ev.value} = ${childGen.isNull} ? -1 :
(${childGen.value}).numElements();""", isNull = "false")
(${childGen.value}).numElements();""", isNull = LiteralValue("false"))
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,8 @@ case class CreateArray(children: Seq[Expression]) extends Expression {
GenArrayData.genCodeToCreateArrayData(ctx, et, evals, false)
ev.copy(
code = preprocess + assigns + postprocess,
value = arrayData,
isNull = "false")
value = VariableValue(arrayData),
isNull = LiteralValue("false"))
}

override def prettyName: String = "array"
Expand Down Expand Up @@ -378,7 +378,7 @@ case class CreateNamedStruct(children: Seq[Expression]) extends CreateNamedStruc
|$valuesCode
|final InternalRow ${ev.value} = new $rowClass($values);
|$values = null;
""".stripMargin, isNull = "false")
""".stripMargin, isNull = LiteralValue("false"))
}

override def prettyName: String = "named_struct"
Expand All @@ -394,7 +394,7 @@ case class CreateNamedStruct(children: Seq[Expression]) extends CreateNamedStruc
case class CreateNamedStructUnsafe(children: Seq[Expression]) extends CreateNamedStructLike {
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val eval = GenerateUnsafeProjection.createCode(ctx, valExprs)
ExprCode(code = eval.code, isNull = "false", value = eval.value)
ExprCode(code = eval.code, isNull = LiteralValue("false"), value = eval.value)
}

override def prettyName: String = "named_struct_unsafe"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ import java.util.{Calendar, TimeZone}
import scala.util.control.NonFatal

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodegenFallback, ExprCode}
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
Expand Down Expand Up @@ -673,7 +673,7 @@ abstract class UnixTime
case StringType if right.foldable =>
val df = classOf[DateFormat].getName
if (formatter == null) {
ExprCode("", "true", ctx.defaultValue(dataType))
ExprCode("", LiteralValue("true"), LiteralValue(ctx.defaultValue(dataType)))
} else {
val formatterName = ctx.addReferenceObj("formatter", formatter, df)
val eval1 = left.genCode(ctx)
Expand Down Expand Up @@ -808,7 +808,7 @@ case class FromUnixTime(sec: Expression, format: Expression, timeZoneId: Option[
val df = classOf[DateFormat].getName
if (format.foldable) {
if (formatter == null) {
ExprCode("", "true", "(UTF8String) null")
ExprCode("", LiteralValue("true"), LiteralValue("(UTF8String) null"))
} else {
val formatterName = ctx.addReferenceObj("formatter", formatter, df)
val t = left.genCode(ctx)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import scala.collection.mutable
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodegenFallback, ExprCode}
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.util.{ArrayData, MapData}
import org.apache.spark.sql.types._

Expand Down Expand Up @@ -218,7 +218,7 @@ case class Stack(children: Seq[Expression]) extends Generator {
s"$wrapperClass<InternalRow>",
ev.value,
v => s"$v = $wrapperClass$$.MODULE$$.make($rowData);", useFreshName = false)
ev.copy(code = code, isNull = "false")
ev.copy(code = code, isNull = LiteralValue("false"))
}
}

Expand Down
Loading