Skip to content

Commit 999ec13

Browse files
kiszk authored and cloud-fan committed
[SPARK-22570][SQL] Avoid creating a lot of global variables by using a local variable with allocation of an object in generated code
## What changes were proposed in this pull request? This PR reduces the number of global variables in generated code by replacing a global variable with a local variable whose object is allocated every time it is used. When a lot of global variables are generated, the generated code may hit the 64K constant pool limit. This PR reduces the number of generated global variables in the following three operations: * `Cast` from String to primitive byte/short/int/long * `RegExpReplace` * `CreateArray` I intentionally left [this part](https://github.com/apache/spark/blob/master/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala#L595-L603) unchanged. This is because that variable holds a class that is dynamically generated. In other words, it is not possible to reuse one class. ## How was this patch tested? Added test cases. Author: Kazuaki Ishizaki <ishizaki@jp.ibm.com> Closes #19797 from kiszk/SPARK-22570.
1 parent 932bd09 commit 999ec13

File tree

6 files changed

+61
-33
lines changed

6 files changed

+61
-33
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -799,16 +799,16 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
799799

800800
private[this] def castToByteCode(from: DataType, ctx: CodegenContext): CastFunction = from match {
801801
case StringType =>
802-
val wrapper = ctx.freshName("wrapper")
803-
ctx.addMutableState("UTF8String.IntWrapper", wrapper,
804-
s"$wrapper = new UTF8String.IntWrapper();")
802+
val wrapper = ctx.freshName("intWrapper")
805803
(c, evPrim, evNull) =>
806804
s"""
805+
UTF8String.IntWrapper $wrapper = new UTF8String.IntWrapper();
807806
if ($c.toByte($wrapper)) {
808807
$evPrim = (byte) $wrapper.value;
809808
} else {
810809
$evNull = true;
811810
}
811+
$wrapper = null;
812812
"""
813813
case BooleanType =>
814814
(c, evPrim, evNull) => s"$evPrim = $c ? (byte) 1 : (byte) 0;"
@@ -826,16 +826,16 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
826826
from: DataType,
827827
ctx: CodegenContext): CastFunction = from match {
828828
case StringType =>
829-
val wrapper = ctx.freshName("wrapper")
830-
ctx.addMutableState("UTF8String.IntWrapper", wrapper,
831-
s"$wrapper = new UTF8String.IntWrapper();")
829+
val wrapper = ctx.freshName("intWrapper")
832830
(c, evPrim, evNull) =>
833831
s"""
832+
UTF8String.IntWrapper $wrapper = new UTF8String.IntWrapper();
834833
if ($c.toShort($wrapper)) {
835834
$evPrim = (short) $wrapper.value;
836835
} else {
837836
$evNull = true;
838837
}
838+
$wrapper = null;
839839
"""
840840
case BooleanType =>
841841
(c, evPrim, evNull) => s"$evPrim = $c ? (short) 1 : (short) 0;"
@@ -851,16 +851,16 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
851851

852852
private[this] def castToIntCode(from: DataType, ctx: CodegenContext): CastFunction = from match {
853853
case StringType =>
854-
val wrapper = ctx.freshName("wrapper")
855-
ctx.addMutableState("UTF8String.IntWrapper", wrapper,
856-
s"$wrapper = new UTF8String.IntWrapper();")
854+
val wrapper = ctx.freshName("intWrapper")
857855
(c, evPrim, evNull) =>
858856
s"""
857+
UTF8String.IntWrapper $wrapper = new UTF8String.IntWrapper();
859858
if ($c.toInt($wrapper)) {
860859
$evPrim = $wrapper.value;
861860
} else {
862861
$evNull = true;
863862
}
863+
$wrapper = null;
864864
"""
865865
case BooleanType =>
866866
(c, evPrim, evNull) => s"$evPrim = $c ? 1 : 0;"
@@ -876,17 +876,17 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
876876

877877
private[this] def castToLongCode(from: DataType, ctx: CodegenContext): CastFunction = from match {
878878
case StringType =>
879-
val wrapper = ctx.freshName("wrapper")
880-
ctx.addMutableState("UTF8String.LongWrapper", wrapper,
881-
s"$wrapper = new UTF8String.LongWrapper();")
879+
val wrapper = ctx.freshName("longWrapper")
882880

883881
(c, evPrim, evNull) =>
884882
s"""
883+
UTF8String.LongWrapper $wrapper = new UTF8String.LongWrapper();
885884
if ($c.toLong($wrapper)) {
886885
$evPrim = $wrapper.value;
887886
} else {
888887
$evNull = true;
889888
}
889+
$wrapper = null;
890890
"""
891891
case BooleanType =>
892892
(c, evPrim, evNull) => s"$evPrim = $c ? 1L : 0L;"

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala

Lines changed: 21 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ case class CreateArray(children: Seq[Expression]) extends Expression {
6363
val (preprocess, assigns, postprocess, arrayData) =
6464
GenArrayData.genCodeToCreateArrayData(ctx, et, evals, false)
6565
ev.copy(
66-
code = preprocess + ctx.splitExpressions(assigns) + postprocess,
66+
code = preprocess + assigns + postprocess,
6767
value = arrayData,
6868
isNull = "false")
6969
}
@@ -77,24 +77,22 @@ private [sql] object GenArrayData {
7777
*
7878
* @param ctx a [[CodegenContext]]
7979
* @param elementType data type of underlying array elements
80-
* @param elementsCode a set of [[ExprCode]] for each element of an underlying array
80+
* @param elementsCode concatenated set of [[ExprCode]] for each element of an underlying array
8181
* @param isMapKey if true, throw an exception when the element is null
82-
* @return (code pre-assignments, assignments to each array elements, code post-assignments,
83-
* arrayData name)
82+
* @return (code pre-assignments, concatenated assignments to each array elements,
83+
* code post-assignments, arrayData name)
8484
*/
8585
def genCodeToCreateArrayData(
8686
ctx: CodegenContext,
8787
elementType: DataType,
8888
elementsCode: Seq[ExprCode],
89-
isMapKey: Boolean): (String, Seq[String], String, String) = {
90-
val arrayName = ctx.freshName("array")
89+
isMapKey: Boolean): (String, String, String, String) = {
9190
val arrayDataName = ctx.freshName("arrayData")
9291
val numElements = elementsCode.length
9392

9493
if (!ctx.isPrimitiveType(elementType)) {
94+
val arrayName = ctx.freshName("arrayObject")
9595
val genericArrayClass = classOf[GenericArrayData].getName
96-
ctx.addMutableState("Object[]", arrayName,
97-
s"$arrayName = new Object[$numElements];")
9896

9997
val assignments = elementsCode.zipWithIndex.map { case (eval, i) =>
10098
val isNullAssignment = if (!isMapKey) {
@@ -110,17 +108,21 @@ private [sql] object GenArrayData {
110108
}
111109
"""
112110
}
111+
val assignmentString = ctx.splitExpressions(
112+
expressions = assignments,
113+
funcName = "apply",
114+
extraArguments = ("Object[]", arrayDataName) :: Nil)
113115

114-
("",
115-
assignments,
116+
(s"Object[] $arrayName = new Object[$numElements];",
117+
assignmentString,
116118
s"final ArrayData $arrayDataName = new $genericArrayClass($arrayName);",
117119
arrayDataName)
118120
} else {
121+
val arrayName = ctx.freshName("array")
119122
val unsafeArraySizeInBytes =
120123
UnsafeArrayData.calculateHeaderPortionInBytes(numElements) +
121124
ByteArrayMethods.roundNumberOfBytesToNearestWord(elementType.defaultSize * numElements)
122125
val baseOffset = Platform.BYTE_ARRAY_OFFSET
123-
ctx.addMutableState("UnsafeArrayData", arrayDataName)
124126

125127
val primitiveValueTypeName = ctx.primitiveTypeName(elementType)
126128
val assignments = elementsCode.zipWithIndex.map { case (eval, i) =>
@@ -137,14 +139,18 @@ private [sql] object GenArrayData {
137139
}
138140
"""
139141
}
142+
val assignmentString = ctx.splitExpressions(
143+
expressions = assignments,
144+
funcName = "apply",
145+
extraArguments = ("UnsafeArrayData", arrayDataName) :: Nil)
140146

141147
(s"""
142148
byte[] $arrayName = new byte[$unsafeArraySizeInBytes];
143-
$arrayDataName = new UnsafeArrayData();
149+
UnsafeArrayData $arrayDataName = new UnsafeArrayData();
144150
Platform.putLong($arrayName, $baseOffset, $numElements);
145151
$arrayDataName.pointTo($arrayName, $baseOffset, $unsafeArraySizeInBytes);
146152
""",
147-
assignments,
153+
assignmentString,
148154
"",
149155
arrayDataName)
150156
}
@@ -216,10 +222,10 @@ case class CreateMap(children: Seq[Expression]) extends Expression {
216222
s"""
217223
final boolean ${ev.isNull} = false;
218224
$preprocessKeyData
219-
${ctx.splitExpressions(assignKeys)}
225+
$assignKeys
220226
$postprocessKeyData
221227
$preprocessValueData
222-
${ctx.splitExpressions(assignValues)}
228+
$assignValues
223229
$postprocessValueData
224230
final MapData ${ev.value} = new $mapClass($keyArrayData, $valueArrayData);
225231
"""

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -321,8 +321,7 @@ case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expressio
321321

322322
val termLastReplacement = ctx.freshName("lastReplacement")
323323
val termLastReplacementInUTF8 = ctx.freshName("lastReplacementInUTF8")
324-
325-
val termResult = ctx.freshName("result")
324+
val termResult = ctx.freshName("termResult")
326325

327326
val classNamePattern = classOf[Pattern].getCanonicalName
328327
val classNameStringBuffer = classOf[java.lang.StringBuffer].getCanonicalName
@@ -334,8 +333,6 @@ case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expressio
334333
ctx.addMutableState("String", termLastReplacement, s"${termLastReplacement} = null;")
335334
ctx.addMutableState("UTF8String",
336335
termLastReplacementInUTF8, s"${termLastReplacementInUTF8} = null;")
337-
ctx.addMutableState(classNameStringBuffer,
338-
termResult, s"${termResult} = new $classNameStringBuffer();")
339336

340337
val setEvNotNull = if (nullable) {
341338
s"${ev.isNull} = false;"
@@ -355,14 +352,15 @@ case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expressio
355352
${termLastReplacementInUTF8} = $rep.clone();
356353
${termLastReplacement} = ${termLastReplacementInUTF8}.toString();
357354
}
358-
${termResult}.delete(0, ${termResult}.length());
355+
$classNameStringBuffer ${termResult} = new $classNameStringBuffer();
359356
java.util.regex.Matcher ${matcher} = ${termPattern}.matcher($subject.toString());
360357

361358
while (${matcher}.find()) {
362359
${matcher}.appendReplacement(${termResult}, ${termLastReplacement});
363360
}
364361
${matcher}.appendTail(${termResult});
365362
${ev.value} = UTF8String.fromString(${termResult}.toString());
363+
${termResult} = null;
366364
$setEvNotNull
367365
"""
368366
})

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ import java.util.{Calendar, Locale, TimeZone}
2323
import org.apache.spark.SparkFunSuite
2424
import org.apache.spark.sql.Row
2525
import org.apache.spark.sql.catalyst.InternalRow
26+
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext
2627
import org.apache.spark.sql.catalyst.util.DateTimeTestUtils._
2728
import org.apache.spark.sql.catalyst.util.DateTimeUtils
2829
import org.apache.spark.sql.catalyst.util.DateTimeUtils.TimeZoneGMT
@@ -845,4 +846,11 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
845846
val outputOuter = Row.fromSeq((1 to N).map(_ => outputInner))
846847
checkEvaluation(cast(Literal.create(inputOuter, fromOuter), toOuter), outputOuter)
847848
}
849+
850+
test("SPARK-22570: Cast should not create a lot of global variables") {
851+
val ctx = new CodegenContext
852+
cast("1", IntegerType).genCode(ctx)
853+
cast("2", LongType).genCode(ctx)
854+
assert(ctx.mutableStates.length == 0)
855+
}
848856
}

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RegexpExpressionsSuite.scala

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,8 @@ package org.apache.spark.sql.catalyst.expressions
2020
import org.apache.spark.SparkFunSuite
2121
import org.apache.spark.sql.AnalysisException
2222
import org.apache.spark.sql.catalyst.dsl.expressions._
23-
import org.apache.spark.sql.types.{IntegerType, StringType}
23+
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext
24+
import org.apache.spark.sql.types.StringType
2425

2526
/**
2627
* Unit tests for regular expression (regexp) related SQL expressions.
@@ -178,6 +179,14 @@ class RegexpExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
178179
checkEvaluation(nonNullExpr, "num-num", row1)
179180
}
180181

182+
test("SPARK-22570: RegExpReplace should not create a lot of global variables") {
183+
val ctx = new CodegenContext
184+
RegExpReplace(Literal("100"), Literal("(\\d+)"), Literal("num")).genCode(ctx)
185+
// four global variables (lastRegex, pattern, lastReplacement, and lastReplacementInUTF8)
186+
// are always required
187+
assert(ctx.mutableStates.length == 4)
188+
}
189+
181190
test("RegexExtract") {
182191
val row1 = create_row("100-200", "(\\d+)-(\\d+)", 1)
183192
val row2 = create_row("100-200", "(\\d+)-(\\d+)", 2)

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.optimizer
2020
import org.apache.spark.sql.catalyst.dsl.expressions._
2121
import org.apache.spark.sql.catalyst.dsl.plans._
2222
import org.apache.spark.sql.catalyst.expressions._
23+
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext
2324
import org.apache.spark.sql.catalyst.plans.PlanTest
2425
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Range}
2526
import org.apache.spark.sql.catalyst.rules.RuleExecutor
@@ -164,6 +165,12 @@ class ComplexTypesSuite extends PlanTest{
164165
comparePlans(Optimizer execute query, expected)
165166
}
166167

168+
test("SPARK-22570: CreateArray should not create a lot of global variables") {
169+
val ctx = new CodegenContext
170+
CreateArray(Seq(Literal(1))).genCode(ctx)
171+
assert(ctx.mutableStates.length == 0)
172+
}
173+
167174
test("simplify map ops") {
168175
val rel = relation
169176
.select(

0 commit comments

Comments
 (0)