Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Use TrueLiteral/FalseLiteral. Add java type and access property to Ex…
…prValue.
  • Loading branch information
viirya committed Feb 28, 2018
commit f59bb19a3fd04b24ea3077a12283777be0af437d
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.errors.attachTree
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, LiteralValue}
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, FalseLiteral}
import org.apache.spark.sql.types._

/**
Expand Down Expand Up @@ -75,7 +75,7 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean)
|$javaType ${ev.value} = ${ev.isNull} ? ${ctx.defaultValue(dataType)} : ($value);
""".stripMargin)
} else {
ev.copy(code = s"$javaType ${ev.value} = $value;", isNull = LiteralValue("false"))
ev.copy(code = s"$javaType ${ev.value} = $value;", isNull = FalseLiteral)
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,9 @@ abstract class Expression extends TreeNode[Expression] {
}.getOrElse {
val isNull = ctx.freshName("isNull")
val value = ctx.freshName("value")
val eval = doGenCode(ctx, ExprCode("", VariableValue(isNull), VariableValue(value)))
val eval = doGenCode(ctx, ExprCode("",
VariableValue(isNull, ExprType(ctx.JAVA_BOOLEAN, true)),
VariableValue(value, ExprType(ctx, dataType))))
reduceCodeSize(ctx, eval)
if (eval.code.nonEmpty) {
// Add `this` in the comment.
Expand All @@ -121,7 +123,7 @@ abstract class Expression extends TreeNode[Expression] {
val setIsNull = if (!eval.isNull.isInstanceOf[LiteralValue]) {
val globalIsNull = ctx.addMutableState(ctx.JAVA_BOOLEAN, "globalIsNull")
val localIsNull = eval.isNull
eval.isNull = GlobalValue(globalIsNull)
eval.isNull = GlobalValue(globalIsNull, ExprType(ctx.JAVA_BOOLEAN, true))
s"$globalIsNull = $localIsNull;"
} else {
""
Expand All @@ -140,7 +142,7 @@ abstract class Expression extends TreeNode[Expression] {
|}
""".stripMargin)

eval.value = VariableValue(newValue)
eval.value = VariableValue(newValue, ExprType(ctx, dataType))
eval.code = s"$javaType $newValue = $funcFullName(${ctx.INPUT_ROW});"
}
}
Expand Down Expand Up @@ -419,7 +421,7 @@ abstract class UnaryExpression extends Expression {
boolean ${ev.isNull} = false;
${childGen.code}
${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
$resultCode""", isNull = LiteralValue("false"))
$resultCode""", isNull = FalseLiteral)
}
}
}
Expand Down Expand Up @@ -519,7 +521,7 @@ abstract class BinaryExpression extends Expression {
${leftGen.code}
${rightGen.code}
${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
$resultCode""", isNull = LiteralValue("false"))
$resultCode""", isNull = FalseLiteral)
}
}
}
Expand Down Expand Up @@ -663,7 +665,7 @@ abstract class TernaryExpression extends Expression {
${midGen.code}
${rightGen.code}
${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
$resultCode""", isNull = LiteralValue("false"))
$resultCode""", isNull = FalseLiteral)
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, LiteralValue}
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, FalseLiteral}
import org.apache.spark.sql.types.{DataType, LongType}

/**
Expand Down Expand Up @@ -73,7 +73,7 @@ case class MonotonicallyIncreasingID() extends LeafExpression with Nondeterminis

ev.copy(code = s"""
final ${ctx.javaType(dataType)} ${ev.value} = $partitionMaskTerm + $countTerm;
$countTerm++;""", isNull = LiteralValue("false"))
$countTerm++;""", isNull = FalseLiteral)
}

override def prettyName: String = "monotonically_increasing_id"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, LiteralValue}
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, FalseLiteral}
import org.apache.spark.sql.types.{DataType, IntegerType}

/**
Expand Down Expand Up @@ -47,6 +47,6 @@ case class SparkPartitionID() extends LeafExpression with Nondeterministic {
ctx.addImmutableStateIfNotExists(ctx.JAVA_INT, idTerm)
ctx.addPartitionInitializationStatement(s"$idTerm = partitionIndex;")
ev.copy(code = s"final ${ctx.javaType(dataType)} ${ev.value} = $idTerm;",
isNull = LiteralValue("false"))
isNull = FalseLiteral)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -602,7 +602,8 @@ case class Least(children: Seq[Expression]) extends Expression {

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val evalChildren = children.map(_.genCode(ctx))
ev.isNull = GlobalValue(ctx.addMutableState(ctx.JAVA_BOOLEAN, ev.isNull))
ev.isNull = GlobalValue(ctx.addMutableState(ctx.JAVA_BOOLEAN, ev.isNull),
ExprType(ctx.JAVA_BOOLEAN, true))
val evals = evalChildren.map(eval =>
s"""
|${eval.code}
Expand Down Expand Up @@ -681,7 +682,8 @@ case class Greatest(children: Seq[Expression]) extends Expression {

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val evalChildren = children.map(_.genCode(ctx))
ev.isNull = GlobalValue(ctx.addMutableState(ctx.JAVA_BOOLEAN, ev.isNull))
ev.isNull = GlobalValue(ctx.addMutableState(ctx.JAVA_BOOLEAN, ev.isNull),
ExprType(ctx.JAVA_BOOLEAN, true))
val evals = evalChildren.map(eval =>
s"""
|${eval.code}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ import java.util.{Map => JavaMap}
import scala.collection.JavaConverters._
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
import scala.language.{existentials, implicitConversions}
import scala.language.existentials
import scala.util.control.NonFatal

import com.google.common.cache.{CacheBuilder, CacheLoader}
Expand Down Expand Up @@ -58,42 +58,6 @@ import org.apache.spark.util.{ParentClassLoader, Utils}
*/
case class ExprCode(var code: String, var isNull: ExprValue, var value: ExprValue)


// An abstraction that represents the evaluation result of [[ExprCode]].
abstract class ExprValue

object ExprValue {
implicit def exprValueToString(exprValue: ExprValue): String = exprValue.toString
}

// A literal evaluation of [[ExprCode]].
class LiteralValue(val value: String) extends ExprValue {
override def toString: String = value
}

object LiteralValue {
def apply(value: String): LiteralValue = new LiteralValue(value)
def unapply(literal: LiteralValue): Option[String] = Some(literal.value)
}

// A variable evaluation of [[ExprCode]].
case class VariableValue(val variableName: String) extends ExprValue {
override def toString: String = variableName
}

// A statement evaluation of [[ExprCode]].
case class StatementValue(val statement: String) extends ExprValue {
override def toString: String = statement
}

// A global variable evaluation of [[ExprCode]].
case class GlobalValue(val value: String) extends ExprValue {
override def toString: String = value
}

case object TrueLiteral extends LiteralValue("true")
case object FalseLiteral extends LiteralValue("false")

object ExprCode {
def forNonNullValue(value: ExprValue): ExprCode = {
ExprCode(code = "", isNull = FalseLiteral, value = value)
Expand Down Expand Up @@ -359,7 +323,8 @@ class CodegenContext {
case _: StructType | _: ArrayType | _: MapType => s"$value = $initCode.copy();"
case _ => s"$value = $initCode;"
}
ExprCode(code, LiteralValue("false"), GlobalValue(value))
ExprCode(code, FalseLiteral,
GlobalValue(value, ExprType(this, dataType)))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: this can go on one line

}

def declareMutableStates(): String = {
Expand Down Expand Up @@ -1244,7 +1209,8 @@ class CodegenContext {
// at least two nodes) as the cost of doing it is expected to be low.

subexprFunctions += s"${addNewFunction(fnName, fn)}($INPUT_ROW);"
val state = SubExprEliminationState(GlobalValue(isNull), GlobalValue(value))
val state = SubExprEliminationState(GlobalValue(isNull, ExprType(JAVA_BOOLEAN, true)),
GlobalValue(value, ExprType(this, expr.dataType)))
e.foreach(subExprEliminationExprs.put(_, state))
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ trait CodegenFallback extends Expression {
$placeHolder
Object $objectTerm = ((Expression) references[$idx]).eval($input);
${ctx.javaType(this.dataType)} ${ev.value} = (${ctx.boxedType(this.dataType)}) $objectTerm;
""", isNull = LiteralValue("false"))
""", isNull = FalseLiteral)
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.expressions.codegen

import scala.language.implicitConversions

import org.apache.spark.sql.types.DataType

// An abstraction that represents the evaluation result of [[ExprCode]].
// Concrete subclasses supply the Java type of the evaluation and state whether the
// evaluation can be referenced directly from anywhere.
abstract class ExprValue {

// The Java type of this evaluation, described as an [[ExprType]].
val javaType: ExprType

// Whether we can directly access the evaluation value anywhere.
// For example, a variable created outside a method cannot be accessed inside the method.
// For such cases, we may need to pass the evaluation as a parameter.
val canDirectAccess: Boolean
}

// Companion of [[ExprValue]] holding its implicit conversion.
object ExprValue {
// Implicitly converts an [[ExprValue]] to the String returned by its `toString`, so that an
// ExprValue can be supplied where a String is expected.
implicit def exprValueToString(exprValue: ExprValue): String = exprValue.toString
}

/**
 * A literal evaluation of [[ExprCode]].
 *
 * Equality is structural: two literals are equal when both their text and their Java type
 * are equal. Without this, a literal created through the companion `apply` would only ever
 * be equal to itself, even when it carries exactly the same text and type as another literal.
 *
 * @param value    the literal text; also returned by `toString`
 * @param javaType the Java type of the literal
 */
class LiteralValue(val value: String, val javaType: ExprType) extends ExprValue {
  override def toString: String = value

  // A literal is always directly accessible.
  override val canDirectAccess: Boolean = true

  override def equals(other: Any): Boolean = other match {
    case that: LiteralValue => that.value == value && that.javaType == javaType
    case _ => false
  }

  // Kept consistent with `equals`; the tuple hash is null-safe for both fields.
  override def hashCode(): Int = (value, javaType).hashCode
}

/** Factory and extractor for [[LiteralValue]]. */
object LiteralValue {

  /** Creates a literal with the given text and Java type. */
  def apply(value: String, javaType: ExprType): LiteralValue = {
    new LiteralValue(value, javaType)
  }

  /** Extracts the literal text and Java type, enabling `case LiteralValue(v, t)` patterns. */
  def unapply(literal: LiteralValue): Option[(String, ExprType)] = {
    val fields = (literal.value, literal.javaType)
    Some(fields)
  }
}

// A variable evaluation of [[ExprCode]].
case class VariableValue(
val variableName: String,
val javaType: ExprType,
val canDirectAccess: Boolean = false) extends ExprValue {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why isn't this fixed like for GlobalValue?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I want to give it a bit of flexibility for something like a static variable.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A static variable is a GlobalValue, isn't it? Considering that we should also be able to access it from methods in other internal classes, I don't see any use case where this flexibility is required, honestly...

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok. I'd leave it as fixed.

override def toString: String = variableName
}

/**
 * A statement evaluation of [[ExprCode]].
 *
 * @param statement       the statement text; also returned by `toString`
 * @param javaType        the Java type of the evaluation
 * @param canDirectAccess whether the evaluation can be accessed directly anywhere;
 *                        defaults to false
 */
case class StatementValue(
    statement: String,
    javaType: ExprType,
    canDirectAccess: Boolean = false)
  extends ExprValue {

  override def toString: String = statement
}

/**
 * A global variable evaluation of [[ExprCode]].
 *
 * @param value    the text of the evaluation; also returned by `toString`
 * @param javaType the Java type of the evaluation
 */
case class GlobalValue(value: String, javaType: ExprType) extends ExprValue {

  // A global value is always directly accessible.
  override val canDirectAccess: Boolean = true

  override def toString: String = value
}

/** The literal `true`, typed as a primitive `boolean`. */
case object TrueLiteral
  extends LiteralValue("true", ExprType("boolean", isPrimitive = true))

/** The literal `false`, typed as a primitive `boolean`. */
case object FalseLiteral
  extends LiteralValue("false", ExprType("boolean", isPrimitive = true))

/**
 * Represents the Java type of an evaluation.
 *
 * @param typeName    the name of the Java type
 * @param isPrimitive whether the type is primitive
 */
case class ExprType(typeName: String, isPrimitive: Boolean)
Copy link
Contributor

@mgaido91 mgaido91 Feb 28, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why is this isPrimitive needed? I think we can get rid of this and use the isPrimitive method when needed.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here the idea is to include java type information for an evaluation. Then we don't need to consult CodegenContext.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Then can't we move the method from CodegenContext to a static method and use that? Currently this information is never used, and I feel this is hard to maintain (even though I don't expect it to change frequently).

We can add a method like

def isPrimitive: Boolean = CodegenContext.isPrimitive(typeName)


/** Additional factory for [[ExprType]]. */
object ExprType {

  /**
   * Builds an [[ExprType]] for `dataType` from what `ctx` reports as its Java type name
   * (`ctx.javaType`) and its primitive flag (`ctx.isPrimitiveType`).
   */
  def apply(ctx: CodegenContext, dataType: DataType): ExprType = {
    val typeName = ctx.javaType(dataType)
    val primitive = ctx.isPrimitiveType(dataType)
    ExprType(typeName, primitive)
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], MutableP
|${ev.code}
|$isNull = ${ev.isNull};
|$value = ${ev.value};
""".stripMargin, GlobalValue(isNull), value, i)
""".stripMargin, GlobalValue(isNull, ExprType(ctx.JAVA_BOOLEAN, true)), value, i)
} else {
(s"""
|${ev.code}
Expand All @@ -83,7 +83,7 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], MutableP

val updates = validExpr.zip(projectionCodes).map {
case (e, (_, isNull, value, i)) =>
val ev = ExprCode("", isNull, GlobalValue(value))
val ev = ExprCode("", isNull, GlobalValue(value, ExprType(ctx, e.dataType)))
ctx.updateColumn("mutableRow", e.dataType, i, ev, e.nullable)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,8 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection]
val rowClass = classOf[GenericInternalRow].getName

val fieldWriters = schema.map(_.dataType).zipWithIndex.map { case (dt, i) =>
val converter = convertToSafe(ctx, StatementValue(ctx.getValue(tmpInput, dt, i.toString)), dt)
val converter = convertToSafe(ctx, StatementValue(ctx.getValue(tmpInput, dt, i.toString),
ExprType(ctx, dt)), dt)
s"""
if (!$tmpInput.isNullAt($i)) {
${converter.code}
Expand All @@ -74,7 +75,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection]
|final InternalRow $output = new $rowClass($values);
""".stripMargin

ExprCode(code, LiteralValue("false"), VariableValue(output))
ExprCode(code, FalseLiteral, VariableValue(output, ExprType("InternalRow", false)))
}

private def createCodeForArray(
Expand All @@ -90,7 +91,8 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection]
val arrayClass = classOf[GenericArrayData].getName

val elementConverter = convertToSafe(
ctx, StatementValue(ctx.getValue(tmpInput, elementType, index)), elementType)
ctx, StatementValue(ctx.getValue(tmpInput, elementType, index), ExprType(ctx, elementType)),
elementType)
val code = s"""
final ArrayData $tmpInput = $input;
final int $numElements = $tmpInput.numElements();
Expand All @@ -104,7 +106,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection]
final ArrayData $output = new $arrayClass($values);
"""

ExprCode(code, LiteralValue("false"), VariableValue(output))
ExprCode(code, FalseLiteral, VariableValue(output, ExprType("ArrayData", false)))
}

private def createCodeForMap(
Expand All @@ -125,7 +127,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection]
final MapData $output = new $mapClass(${keyConverter.value}, ${valueConverter.value});
"""

ExprCode(code, LiteralValue("false"), VariableValue(output))
ExprCode(code, FalseLiteral, VariableValue(output, ExprType("MapData", false)))
}

@tailrec
Expand All @@ -137,7 +139,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection]
case ArrayType(elementType, _) => createCodeForArray(ctx, input, elementType)
case MapType(keyType, valueType, _) => createCodeForMap(ctx, input, keyType, valueType)
case udt: UserDefinedType[_] => convertToSafe(ctx, input, udt.sqlType)
case _ => ExprCode("", LiteralValue("false"), input)
case _ => ExprCode("", FalseLiteral, input)
}

protected def create(expressions: Seq[Expression]): Projection = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,8 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
// Puts `input` in a local variable to avoid to re-evaluate it if it's a statement.
val tmpInput = ctx.freshName("tmpInput")
val fieldEvals = fieldTypes.zipWithIndex.map { case (dt, i) =>
ExprCode("", StatementValue(s"$tmpInput.isNullAt($i)"),
StatementValue(ctx.getValue(tmpInput, dt, i.toString)))
ExprCode("", StatementValue(s"$tmpInput.isNullAt($i)", ExprType(ctx.JAVA_BOOLEAN, true)),
StatementValue(ctx.getValue(tmpInput, dt, i.toString), ExprType(ctx, dt)))
}

s"""
Expand Down Expand Up @@ -348,7 +348,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
$writeExpressions
$updateRowSize
"""
ExprCode(code, LiteralValue("false"), GlobalValue(result))
ExprCode(code, FalseLiteral, GlobalValue(result, ExprType("UnsafeRow", false)))
}

protected def canonicalize(in: Seq[Expression]): Seq[Expression] =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ case class Size(child: Expression) extends UnaryExpression with ExpectsInputType
boolean ${ev.isNull} = false;
${childGen.code}
${ctx.javaType(dataType)} ${ev.value} = ${childGen.isNull} ? -1 :
(${childGen.value}).numElements();""", isNull = LiteralValue("false"))
(${childGen.value}).numElements();""", isNull = FalseLiteral)
}
}

Expand Down
Loading