Skip to content
Closed
Original file line number Diff line number Diff line change
Expand Up @@ -799,16 +799,16 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String

private[this] def castToByteCode(from: DataType, ctx: CodegenContext): CastFunction = from match {
case StringType =>
val wrapper = ctx.freshName("wrapper")
ctx.addMutableState("UTF8String.IntWrapper", wrapper,
s"$wrapper = new UTF8String.IntWrapper();")
val wrapper = ctx.freshName("intWrapper")
(c, evPrim, evNull) =>
s"""
UTF8String.IntWrapper $wrapper = new UTF8String.IntWrapper();
if ($c.toByte($wrapper)) {
$evPrim = (byte) $wrapper.value;
} else {
$evNull = true;
}
$wrapper = null;
"""
case BooleanType =>
(c, evPrim, evNull) => s"$evPrim = $c ? (byte) 1 : (byte) 0;"
Expand All @@ -826,16 +826,16 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
from: DataType,
ctx: CodegenContext): CastFunction = from match {
case StringType =>
val wrapper = ctx.freshName("wrapper")
ctx.addMutableState("UTF8String.IntWrapper", wrapper,
s"$wrapper = new UTF8String.IntWrapper();")
val wrapper = ctx.freshName("intWrapper")
(c, evPrim, evNull) =>
s"""
UTF8String.IntWrapper $wrapper = new UTF8String.IntWrapper();
if ($c.toShort($wrapper)) {
$evPrim = (short) $wrapper.value;
} else {
$evNull = true;
}
$wrapper = null;
"""
case BooleanType =>
(c, evPrim, evNull) => s"$evPrim = $c ? (short) 1 : (short) 0;"
Expand All @@ -851,16 +851,16 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String

private[this] def castToIntCode(from: DataType, ctx: CodegenContext): CastFunction = from match {
case StringType =>
val wrapper = ctx.freshName("wrapper")
ctx.addMutableState("UTF8String.IntWrapper", wrapper,
s"$wrapper = new UTF8String.IntWrapper();")
val wrapper = ctx.freshName("intWrapper")
(c, evPrim, evNull) =>
s"""
UTF8String.IntWrapper $wrapper = new UTF8String.IntWrapper();
if ($c.toInt($wrapper)) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what if we create a new wrapper every time?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That would work well, too. However, it may add some overhead for allocating a small `IntWrapper` object each time and collecting it at GC.

Copy link
Contributor

@cloud-fan cloud-fan Nov 23, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shouldn't GC work very well in this case? My concern is, if there is no significant performance benefit, reusing global variables may be too hacky.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see. If we do not reuse global variables, we can simply use a local variable instead.

$evPrim = $wrapper.value;
} else {
$evNull = true;
}
$wrapper = null;
"""
case BooleanType =>
(c, evPrim, evNull) => s"$evPrim = $c ? 1 : 0;"
Expand All @@ -876,17 +876,17 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String

private[this] def castToLongCode(from: DataType, ctx: CodegenContext): CastFunction = from match {
case StringType =>
val wrapper = ctx.freshName("wrapper")
ctx.addMutableState("UTF8String.LongWrapper", wrapper,
s"$wrapper = new UTF8String.LongWrapper();")
val wrapper = ctx.freshName("longWrapper")

(c, evPrim, evNull) =>
s"""
UTF8String.LongWrapper $wrapper = new UTF8String.LongWrapper();
if ($c.toLong($wrapper)) {
$evPrim = $wrapper.value;
} else {
$evNull = true;
}
$wrapper = null;
"""
case BooleanType =>
(c, evPrim, evNull) => s"$evPrim = $c ? 1L : 0L;"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ case class CreateArray(children: Seq[Expression]) extends Expression {
val (preprocess, assigns, postprocess, arrayData) =
GenArrayData.genCodeToCreateArrayData(ctx, et, evals, false)
ev.copy(
code = preprocess + ctx.splitExpressions(assigns) + postprocess,
code = preprocess + assigns + postprocess,
value = arrayData,
isNull = "false")
}
Expand All @@ -77,24 +77,22 @@ private [sql] object GenArrayData {
*
* @param ctx a [[CodegenContext]]
* @param elementType data type of underlying array elements
* @param elementsCode a set of [[ExprCode]] for each element of an underlying array
* @param elementsCode concatenated set of [[ExprCode]] for each element of an underlying array
* @param isMapKey if true, throw an exception when the element is null
* @return (code pre-assignments, assignments to each array elements, code post-assignments,
* arrayData name)
* @return (code pre-assignments, concatenated assignments to each array elements,
* code post-assignments, arrayData name)
*/
def genCodeToCreateArrayData(
ctx: CodegenContext,
elementType: DataType,
elementsCode: Seq[ExprCode],
isMapKey: Boolean): (String, Seq[String], String, String) = {
val arrayName = ctx.freshName("array")
isMapKey: Boolean): (String, String, String, String) = {
val arrayDataName = ctx.freshName("arrayData")
val numElements = elementsCode.length

if (!ctx.isPrimitiveType(elementType)) {
val arrayName = ctx.freshName("arrayObject")
val genericArrayClass = classOf[GenericArrayData].getName
ctx.addMutableState("Object[]", arrayName,
s"$arrayName = new Object[$numElements];")

val assignments = elementsCode.zipWithIndex.map { case (eval, i) =>
val isNullAssignment = if (!isMapKey) {
Expand All @@ -110,17 +108,21 @@ private [sql] object GenArrayData {
}
"""
}
val assignmentString = ctx.splitExpressions(
expressions = assignments,
funcName = "apply",
extraArguments = ("Object[]", arrayDataName) :: Nil)

("",
assignments,
(s"Object[] $arrayName = new Object[$numElements];",
assignmentString,
s"final ArrayData $arrayDataName = new $genericArrayClass($arrayName);",
arrayDataName)
} else {
val arrayName = ctx.freshName("array")
val unsafeArraySizeInBytes =
UnsafeArrayData.calculateHeaderPortionInBytes(numElements) +
ByteArrayMethods.roundNumberOfBytesToNearestWord(elementType.defaultSize * numElements)
val baseOffset = Platform.BYTE_ARRAY_OFFSET
ctx.addMutableState("UnsafeArrayData", arrayDataName)

val primitiveValueTypeName = ctx.primitiveTypeName(elementType)
val assignments = elementsCode.zipWithIndex.map { case (eval, i) =>
Expand All @@ -137,14 +139,18 @@ private [sql] object GenArrayData {
}
"""
}
val assignmentString = ctx.splitExpressions(
expressions = assignments,
funcName = "apply",
extraArguments = ("UnsafeArrayData", arrayDataName) :: Nil)

(s"""
byte[] $arrayName = new byte[$unsafeArraySizeInBytes];
$arrayDataName = new UnsafeArrayData();
UnsafeArrayData $arrayDataName = new UnsafeArrayData();
Platform.putLong($arrayName, $baseOffset, $numElements);
$arrayDataName.pointTo($arrayName, $baseOffset, $unsafeArraySizeInBytes);
""",
assignments,
assignmentString,
"",
arrayDataName)
}
Expand Down Expand Up @@ -216,10 +222,10 @@ case class CreateMap(children: Seq[Expression]) extends Expression {
s"""
final boolean ${ev.isNull} = false;
$preprocessKeyData
${ctx.splitExpressions(assignKeys)}
$assignKeys
$postprocessKeyData
$preprocessValueData
${ctx.splitExpressions(assignValues)}
$assignValues
$postprocessValueData
final MapData ${ev.value} = new $mapClass($keyArrayData, $valueArrayData);
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -321,8 +321,7 @@ case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expressio

val termLastReplacement = ctx.freshName("lastReplacement")
val termLastReplacementInUTF8 = ctx.freshName("lastReplacementInUTF8")

val termResult = ctx.freshName("result")
val termResult = ctx.freshName("termResult")

val classNamePattern = classOf[Pattern].getCanonicalName
val classNameStringBuffer = classOf[java.lang.StringBuffer].getCanonicalName
Expand All @@ -334,8 +333,6 @@ case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expressio
ctx.addMutableState("String", termLastReplacement, s"${termLastReplacement} = null;")
ctx.addMutableState("UTF8String",
termLastReplacementInUTF8, s"${termLastReplacementInUTF8} = null;")
ctx.addMutableState(classNameStringBuffer,
termResult, s"${termResult} = new $classNameStringBuffer();")

val setEvNotNull = if (nullable) {
s"${ev.isNull} = false;"
Expand All @@ -355,14 +352,15 @@ case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expressio
${termLastReplacementInUTF8} = $rep.clone();
${termLastReplacement} = ${termLastReplacementInUTF8}.toString();
}
${termResult}.delete(0, ${termResult}.length());
$classNameStringBuffer ${termResult} = new $classNameStringBuffer();
java.util.regex.Matcher ${matcher} = ${termPattern}.matcher($subject.toString());

while (${matcher}.find()) {
${matcher}.appendReplacement(${termResult}, ${termLastReplacement});
}
${matcher}.appendTail(${termResult});
${ev.value} = UTF8String.fromString(${termResult}.toString());
${termResult} = null;
$setEvNotNull
"""
})
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import java.util.{Calendar, Locale, TimeZone}
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext
import org.apache.spark.sql.catalyst.util.DateTimeTestUtils._
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.catalyst.util.DateTimeUtils.TimeZoneGMT
Expand Down Expand Up @@ -845,4 +846,11 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
val outputOuter = Row.fromSeq((1 to N).map(_ => outputInner))
checkEvaluation(cast(Literal.create(inputOuter, fromOuter), toOuter), outputOuter)
}

test("SPARK-22570: Cast should not create a lot of instance variables") {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: instance variables -> global variables, to match the other two added tests.

val ctx = new CodegenContext
cast("1", IntegerType).genCode(ctx).code
cast("2", LongType).genCode(ctx).code
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: no need to call `.code`

assert(ctx.mutableStates.length == 0)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.{CodeAndComment, CodegenContext, CodeGenerator}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

CodeAndComment and CodeGenerator are not used.

import org.apache.spark.sql.types.{IntegerType, StringType}

/**
Expand Down Expand Up @@ -178,6 +179,14 @@ class RegexpExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(nonNullExpr, "num-num", row1)
}

test("SPARK-22570: RegExpReplace should not create a lot of global variables") {
  // Run codegen against a fresh context so that only this expression's
  // mutable states are counted.
  val codegenCtx = new CodegenContext
  val replaceExpr = RegExpReplace(Literal("100"), Literal("(\\d+)"), Literal("num"))
  replaceExpr.genCode(codegenCtx)
  // Exactly four global variables are expected to remain: lastRegex, pattern,
  // lastReplacement, and lastReplacementInUTF8.
  assert(codegenCtx.mutableStates.length == 4)
}

test("RegexExtract") {
val row1 = create_row("100-200", "(\\d+)-(\\d+)", 1)
val row2 = create_row("100-200", "(\\d+)-(\\d+)", 2)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.optimizer
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.{CodeAndComment, CodegenContext, CodeGenerator}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ditto.

import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Range}
import org.apache.spark.sql.catalyst.rules.RuleExecutor
Expand Down Expand Up @@ -164,6 +165,12 @@ class ComplexTypesSuite extends PlanTest{
comparePlans(Optimizer execute query, expected)
}

test("SPARK-22570: CreateArray should not create a lot of global variables") {
val ctx = new CodegenContext
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Something like this: we create a codegen context and run codegen. After that, instead of compiling the code to make sure it doesn't fail, we can just check the size of `ctx.mutableStates` to confirm that we don't add global variables during codegen.

CreateArray(Seq(Literal(1))).genCode(ctx).code
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ditto

assert(ctx.mutableStates.length == 0)
}

test("simplify map ops") {
val rel = relation
.select(
Expand Down