Commit e008ad1

mgaido91 authored and gatorsmile committed
[SPARK-24782][SQL] Simplify conf retrieval in SQL expressions
## What changes were proposed in this pull request?

The PR simplifies the retrieval of config in `size`, as we can access them from tasks too thanks to SPARK-24250.

## How was this patch tested?

existing UTs

Author: Marco Gaido <[email protected]>

Closes #21736 from mgaido91/SPARK-24605_followup.
1 parent ff7f6ef commit e008ad1

6 files changed, +57 -67 lines changed
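The gist of the change, as a minimal self-contained sketch (hypothetical `MyConf`/`SizeBefore`/`SizeAfter` names, not the real Spark classes): before SPARK-24250, a conf value had to be captured on the driver and threaded through an expression's constructor; afterwards, `SQLConf.get` also works inside tasks, so the expression can read the conf itself.

```scala
// Hypothetical stand-ins for illustration only; the real code uses
// org.apache.spark.sql.internal.SQLConf and Catalyst expressions.
object MyConf {
  // Plays the role of SQLConf.get, which since SPARK-24250 is also
  // usable from task code on executors.
  def legacySizeOfNull: Boolean =
    sys.props.getOrElse("legacySizeOfNull", "true").toBoolean
}

// Before: the conf value is captured on the driver and passed explicitly.
case class SizeBefore(child: String, legacySizeOfNull: Boolean)

// After: the expression reads the conf when it is constructed.
case class SizeAfter(child: String) {
  val legacySizeOfNull: Boolean = MyConf.legacySizeOfNull
}
```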

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala

Lines changed: 2 additions & 8 deletions
@@ -89,15 +89,9 @@ trait BinaryArrayExpressionWithImplicitCast extends BinaryExpression
       > SELECT _FUNC_(NULL);
        -1
   """)
-case class Size(
-    child: Expression,
-    legacySizeOfNull: Boolean)
-  extends UnaryExpression with ExpectsInputTypes {
+case class Size(child: Expression) extends UnaryExpression with ExpectsInputTypes {
 
-  def this(child: Expression) =
-    this(
-      child,
-      legacySizeOfNull = SQLConf.get.getConf(SQLConf.LEGACY_SIZE_OF_NULL))
+  val legacySizeOfNull = SQLConf.get.legacySizeOfNull
 
   override def dataType: DataType = IntegerType
   override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(ArrayType, MapType))
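For context, a sketch of what the `legacySizeOfNull` flag controls, seen from the SQL API. It assumes an active `SparkSession` named `spark`; the conf key shown is the one behind `SQLConf.LEGACY_SIZE_OF_NULL`.

```scala
// Sketch: toggling the legacy behavior of size(NULL).
spark.conf.set("spark.sql.legacy.sizeOfNull", "true")
spark.sql("SELECT size(NULL)").show()   // -1 (legacy behavior)

spark.conf.set("spark.sql.legacy.sizeOfNull", "false")
spark.sql("SELECT size(NULL)").show()   // NULL
```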

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala

Lines changed: 5 additions & 11 deletions
@@ -514,10 +514,11 @@ case class JsonToStructs(
     schema: DataType,
     options: Map[String, String],
     child: Expression,
-    timeZoneId: Option[String],
-    forceNullableSchema: Boolean)
+    timeZoneId: Option[String] = None)
   extends UnaryExpression with TimeZoneAwareExpression with CodegenFallback with ExpectsInputTypes {
 
+  val forceNullableSchema = SQLConf.get.getConf(SQLConf.FROM_JSON_FORCE_NULLABLE_SCHEMA)
+
   // The JSON input data might be missing certain fields. We force the nullability
   // of the user-provided schema to avoid data corruptions. In particular, the parquet-mr encoder
   // can generate incorrect files if values are missing in columns declared as non-nullable.

@@ -531,8 +532,7 @@ case class JsonToStructs(
       schema = JsonExprUtils.evalSchemaExpr(schema),
       options = options,
       child = child,
-      timeZoneId = None,
-      forceNullableSchema = SQLConf.get.getConf(SQLConf.FROM_JSON_FORCE_NULLABLE_SCHEMA))
+      timeZoneId = None)
 
   def this(child: Expression, schema: Expression) = this(child, schema, Map.empty[String, String])
 

@@ -541,13 +541,7 @@ case class JsonToStructs(
       schema = JsonExprUtils.evalSchemaExpr(schema),
       options = JsonExprUtils.convertToMapData(options),
       child = child,
-      timeZoneId = None,
-      forceNullableSchema = SQLConf.get.getConf(SQLConf.FROM_JSON_FORCE_NULLABLE_SCHEMA))
-
-  // Used in `org.apache.spark.sql.functions`
-  def this(schema: DataType, options: Map[String, String], child: Expression) =
-    this(schema, options, child, timeZoneId = None,
-      forceNullableSchema = SQLConf.get.getConf(SQLConf.FROM_JSON_FORCE_NULLABLE_SCHEMA))
+      timeZoneId = None)
 
   override def checkInputDataTypes(): TypeCheckResult = nullableSchema match {
     case _: StructType | ArrayType(_: StructType, _) | _: MapType =>
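With `forceNullableSchema` removed from the parameter list, callers construct the expression without it and the conf is picked up implicitly. A sketch of the new call shape (standard Catalyst imports; the literal JSON and timezone are illustrative):

```scala
import org.apache.spark.sql.catalyst.expressions.{JsonToStructs, Literal}
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}

// forceNullableSchema is no longer passed; it is read from
// SQLConf.FROM_JSON_FORCE_NULLABLE_SCHEMA when the expression is built.
val schema = StructType(StructField("a", IntegerType) :: Nil)
val expr = JsonToStructs(schema, Map.empty, Literal("""{"a": 1}"""), Some("GMT"))
```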

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala

Lines changed: 0 additions & 2 deletions
@@ -27,8 +27,6 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT
 
   /**
    * The active config object within the current scope.
-   * Note that if you want to refer config values during execution, you have to capture them
-   * in Driver and use the captured values in Executors.
    * See [[SQLConf.get]] for more information.
    */
   def conf: SQLConf = SQLConf.get
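The deleted caveat dated from when `SQLConf.get` was reliable only on the driver; since SPARK-24250 the conf is propagated to executors, so a read like the following sketch behaves the same in either place (`sizeOfNullBehavior` is a hypothetical helper, not Spark code):

```scala
import org.apache.spark.sql.internal.SQLConf

// Sketch: reading a conf through SQLConf.get, now safe in task code too.
def sizeOfNullBehavior(): String =
  if (SQLConf.get.legacySizeOfNull) "size(NULL) returns -1"
  else "size(NULL) returns NULL"
```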

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala

Lines changed: 16 additions & 11 deletions
@@ -24,43 +24,48 @@ import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.util.DateTimeTestUtils
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.array.ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH
 import org.apache.spark.unsafe.types.CalendarInterval
 
 class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
 
-  def testSize(legacySizeOfNull: Boolean, sizeOfNull: Any): Unit = {
+  def testSize(sizeOfNull: Any): Unit = {
     val a0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType))
     val a1 = Literal.create(Seq[Integer](), ArrayType(IntegerType))
     val a2 = Literal.create(Seq(1, 2), ArrayType(IntegerType))
 
-    checkEvaluation(Size(a0, legacySizeOfNull), 3)
-    checkEvaluation(Size(a1, legacySizeOfNull), 0)
-    checkEvaluation(Size(a2, legacySizeOfNull), 2)
+    checkEvaluation(Size(a0), 3)
+    checkEvaluation(Size(a1), 0)
+    checkEvaluation(Size(a2), 2)
 
     val m0 = Literal.create(Map("a" -> "a", "b" -> "b"), MapType(StringType, StringType))
     val m1 = Literal.create(Map[String, String](), MapType(StringType, StringType))
     val m2 = Literal.create(Map("a" -> "a"), MapType(StringType, StringType))
 
-    checkEvaluation(Size(m0, legacySizeOfNull), 2)
-    checkEvaluation(Size(m1, legacySizeOfNull), 0)
-    checkEvaluation(Size(m2, legacySizeOfNull), 1)
+    checkEvaluation(Size(m0), 2)
+    checkEvaluation(Size(m1), 0)
+    checkEvaluation(Size(m2), 1)
 
     checkEvaluation(
-      Size(Literal.create(null, MapType(StringType, StringType)), legacySizeOfNull),
+      Size(Literal.create(null, MapType(StringType, StringType))),
       expected = sizeOfNull)
     checkEvaluation(
-      Size(Literal.create(null, ArrayType(StringType)), legacySizeOfNull),
+      Size(Literal.create(null, ArrayType(StringType))),
       expected = sizeOfNull)
   }
 
   test("Array and Map Size - legacy") {
-    testSize(legacySizeOfNull = true, sizeOfNull = -1)
+    withSQLConf(SQLConf.LEGACY_SIZE_OF_NULL.key -> "true") {
+      testSize(sizeOfNull = -1)
+    }
   }
 
   test("Array and Map Size") {
-    testSize(legacySizeOfNull = false, sizeOfNull = null)
+    withSQLConf(SQLConf.LEGACY_SIZE_OF_NULL.key -> "false") {
+      testSize(sizeOfNull = null)
+    }
   }
 
   test("MapKeys/MapValues") {

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala

Lines changed: 32 additions & 33 deletions
@@ -392,7 +392,7 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with
     val jsonData = """{"a": 1}"""
     val schema = StructType(StructField("a", IntegerType) :: Nil)
     checkEvaluation(
-      JsonToStructs(schema, Map.empty, Literal(jsonData), gmtId, true),
+      JsonToStructs(schema, Map.empty, Literal(jsonData), gmtId),
       InternalRow(1)
     )
   }

@@ -401,13 +401,13 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with
     val jsonData = """{"a" 1}"""
     val schema = StructType(StructField("a", IntegerType) :: Nil)
     checkEvaluation(
-      JsonToStructs(schema, Map.empty, Literal(jsonData), gmtId, true),
+      JsonToStructs(schema, Map.empty, Literal(jsonData), gmtId),
       null
     )
 
     // Other modes should still return `null`.
     checkEvaluation(
-      JsonToStructs(schema, Map("mode" -> PermissiveMode.name), Literal(jsonData), gmtId, true),
+      JsonToStructs(schema, Map("mode" -> PermissiveMode.name), Literal(jsonData), gmtId),
       null
     )
   }

@@ -416,70 +416,70 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with
     val input = """[{"a": 1}, {"a": 2}]"""
     val schema = ArrayType(StructType(StructField("a", IntegerType) :: Nil))
     val output = InternalRow(1) :: InternalRow(2) :: Nil
-    checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId, true), output)
+    checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId), output)
   }
 
   test("from_json - input=object, schema=array, output=array of single row") {
     val input = """{"a": 1}"""
     val schema = ArrayType(StructType(StructField("a", IntegerType) :: Nil))
     val output = InternalRow(1) :: Nil
-    checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId, true), output)
+    checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId), output)
   }
 
   test("from_json - input=empty array, schema=array, output=empty array") {
     val input = "[ ]"
     val schema = ArrayType(StructType(StructField("a", IntegerType) :: Nil))
     val output = Nil
-    checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId, true), output)
+    checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId), output)
   }
 
   test("from_json - input=empty object, schema=array, output=array of single row with null") {
     val input = "{ }"
     val schema = ArrayType(StructType(StructField("a", IntegerType) :: Nil))
     val output = InternalRow(null) :: Nil
-    checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId, true), output)
+    checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId), output)
   }
 
   test("from_json - input=array of single object, schema=struct, output=single row") {
     val input = """[{"a": 1}]"""
     val schema = StructType(StructField("a", IntegerType) :: Nil)
     val output = InternalRow(1)
-    checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId, true), output)
+    checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId), output)
   }
 
   test("from_json - input=array, schema=struct, output=null") {
     val input = """[{"a": 1}, {"a": 2}]"""
     val schema = StructType(StructField("a", IntegerType) :: Nil)
     val output = null
-    checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId, true), output)
+    checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId), output)
   }
 
   test("from_json - input=empty array, schema=struct, output=null") {
     val input = """[]"""
     val schema = StructType(StructField("a", IntegerType) :: Nil)
     val output = null
-    checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId, true), output)
+    checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId), output)
   }
 
   test("from_json - input=empty object, schema=struct, output=single row with null") {
     val input = """{ }"""
     val schema = StructType(StructField("a", IntegerType) :: Nil)
     val output = InternalRow(null)
-    checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId, true), output)
+    checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId), output)
   }
 
   test("from_json null input column") {
     val schema = StructType(StructField("a", IntegerType) :: Nil)
     checkEvaluation(
-      JsonToStructs(schema, Map.empty, Literal.create(null, StringType), gmtId, true),
+      JsonToStructs(schema, Map.empty, Literal.create(null, StringType), gmtId),
       null
     )
   }
 
   test("SPARK-20549: from_json bad UTF-8") {
     val schema = StructType(StructField("a", IntegerType) :: Nil)
     checkEvaluation(
-      JsonToStructs(schema, Map.empty, Literal(badJson), gmtId, true),
+      JsonToStructs(schema, Map.empty, Literal(badJson), gmtId),
       null)
   }

@@ -491,14 +491,14 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with
     c.set(2016, 0, 1, 0, 0, 0)
     c.set(Calendar.MILLISECOND, 123)
     checkEvaluation(
-      JsonToStructs(schema, Map.empty, Literal(jsonData1), gmtId, true),
+      JsonToStructs(schema, Map.empty, Literal(jsonData1), gmtId),
       InternalRow(c.getTimeInMillis * 1000L)
     )
     // The result doesn't change because the json string includes timezone string ("Z" here),
     // which means the string represents the timestamp string in the timezone regardless of
     // the timeZoneId parameter.
     checkEvaluation(
-      JsonToStructs(schema, Map.empty, Literal(jsonData1), Option("PST"), true),
+      JsonToStructs(schema, Map.empty, Literal(jsonData1), Option("PST")),
       InternalRow(c.getTimeInMillis * 1000L)
     )
 

@@ -512,8 +512,7 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with
         schema,
         Map("timestampFormat" -> "yyyy-MM-dd'T'HH:mm:ss"),
         Literal(jsonData2),
-        Option(tz.getID),
-        true),
+        Option(tz.getID)),
       InternalRow(c.getTimeInMillis * 1000L)
     )
     checkEvaluation(

@@ -522,8 +521,7 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with
         Map("timestampFormat" -> "yyyy-MM-dd'T'HH:mm:ss",
           DateTimeUtils.TIMEZONE_OPTION -> tz.getID),
         Literal(jsonData2),
-        gmtId,
-        true),
+        gmtId),
       InternalRow(c.getTimeInMillis * 1000L)
     )
   }

@@ -532,7 +530,7 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with
   test("SPARK-19543: from_json empty input column") {
     val schema = StructType(StructField("a", IntegerType) :: Nil)
     checkEvaluation(
-      JsonToStructs(schema, Map.empty, Literal.create(" ", StringType), gmtId, true),
+      JsonToStructs(schema, Map.empty, Literal.create(" ", StringType), gmtId),
       null
     )
   }

@@ -687,23 +685,24 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with
 
   test("from_json missing fields") {
     for (forceJsonNullableSchema <- Seq(false, true)) {
-      val input =
-        """{
+      withSQLConf(SQLConf.FROM_JSON_FORCE_NULLABLE_SCHEMA.key -> forceJsonNullableSchema.toString) {
+        val input =
+          """{
           |  "a": 1,
           |  "c": "foo"
           |}
           |""".stripMargin
-      val jsonSchema = new StructType()
-        .add("a", LongType, nullable = false)
-        .add("b", StringType, nullable = false)
-        .add("c", StringType, nullable = false)
-      val output = InternalRow(1L, null, UTF8String.fromString("foo"))
-      val expr = JsonToStructs(
-        jsonSchema, Map.empty, Literal.create(input, StringType), gmtId, forceJsonNullableSchema)
-      checkEvaluation(expr, output)
-      val schema = expr.dataType
-      val schemaToCompare = if (forceJsonNullableSchema) jsonSchema.asNullable else jsonSchema
-      assert(schemaToCompare == schema)
+        val jsonSchema = new StructType()
+          .add("a", LongType, nullable = false)
+          .add("b", StringType, nullable = false)
+          .add("c", StringType, nullable = false)
+        val output = InternalRow(1L, null, UTF8String.fromString("foo"))
+        val expr = JsonToStructs(jsonSchema, Map.empty, Literal.create(input, StringType), gmtId)
+        checkEvaluation(expr, output)
+        val schema = expr.dataType
+        val schemaToCompare = if (forceJsonNullableSchema) jsonSchema.asNullable else jsonSchema
+        assert(schemaToCompare == schema)
+      }
     }
   }
 
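The reworked "missing fields" test keeps its key assertion: the declared schema is forced nullable only when the conf is on. In isolation, the conf-driven part looks roughly like this sketch (same `withSQLConf`, `gmtId`, and imports assumed as in the suite):

```scala
withSQLConf(SQLConf.FROM_JSON_FORCE_NULLABLE_SCHEMA.key -> "true") {
  val jsonSchema = new StructType().add("a", LongType, nullable = false)
  val expr = JsonToStructs(jsonSchema, Map.empty, Literal("""{"a": 1}"""), gmtId)
  // dataType reflects the conf value captured at construction time.
  assert(expr.dataType == jsonSchema.asNullable)
}
```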

sql/core/src/main/scala/org/apache/spark/sql/functions.scala

Lines changed: 2 additions & 2 deletions
@@ -3304,7 +3304,7 @@ object functions {
    * @since 2.2.0
    */
   def from_json(e: Column, schema: DataType, options: Map[String, String]): Column = withExpr {
-    new JsonToStructs(schema, options, e.expr)
+    JsonToStructs(schema, options, e.expr)
   }
 
   /**

@@ -3495,7 +3495,7 @@ object functions {
    * @group collection_funcs
    * @since 1.5.0
    */
-  def size(e: Column): Column = withExpr { new Size(e.expr) }
+  def size(e: Column): Column = withExpr { Size(e.expr) }
 
   /**
    * Sorts the input array for the given column in ascending order,
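The public `functions` API is unchanged by this commit; only the internal construction switched from the removed auxiliary constructors to the case-class `apply`. A usage sketch, assuming an active `SparkSession` named `spark`:

```scala
import org.apache.spark.sql.functions.{from_json, size}
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
import spark.implicits._

// from_json and size are called exactly as before.
val schema = StructType(StructField("a", IntegerType) :: Nil)
val df = Seq((Seq(1, 2, 3), """{"a": 1}""")).toDF("arr", "json")
df.select(size($"arr"), from_json($"json", schema)).show()
```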
