Skip to content

Commit 602c8c9

Browse files
committed
Strictly check that a field name can be used as a valid identifier.
1 parent 2d3692e commit 602c8c9

File tree

3 files changed

+20
-14
lines changed

3 files changed

+20
-14
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717

1818
package org.apache.spark.sql.catalyst
1919

20+
import javax.lang.model.SourceVersion
21+
2022
import org.apache.commons.lang3.reflect.ConstructorUtils
2123

2224
import org.apache.spark.internal.Logging
@@ -539,9 +541,10 @@ object ScalaReflection extends ScalaReflection {
539541

540542
val params = getConstructorParameters(t)
541543
val fields = params.map { case (fieldName, fieldType) =>
542-
if (javaKeywords.contains(fieldName)) {
543-
throw new UnsupportedOperationException(s"`$fieldName` is a reserved keyword and " +
544-
"cannot be used as field name\n" + walkedTypePath)
544+
if (SourceVersion.isKeyword(fieldName) ||
545+
!SourceVersion.isIdentifier(encodeFieldNameToIdentifier(fieldName))) {
546+
throw new UnsupportedOperationException(s"`$fieldName` is not a valid identifier of " +
547+
"Java and cannot be used as field name\n" + walkedTypePath)
545548
}
546549

547550
// SPARK-26730 inputObject won't be null with If's guard below. And KnownNotNul
@@ -784,13 +787,6 @@ object ScalaReflection extends ScalaReflection {
784787
}
785788
}
786789

787-
private val javaKeywords = Set("abstract", "assert", "boolean", "break", "byte", "case", "catch",
788-
"char", "class", "const", "continue", "default", "do", "double", "else", "extends", "false",
789-
"final", "finally", "float", "for", "goto", "if", "implements", "import", "instanceof", "int",
790-
"interface", "long", "native", "new", "null", "package", "private", "protected", "public",
791-
"return", "short", "static", "strictfp", "super", "switch", "synchronized", "this", "throw",
792-
"throws", "transient", "true", "try", "void", "volatile", "while")
793-
794790
val typeJavaMapping = Map[DataType, Class[_]](
795791
BooleanType -> classOf[Boolean],
796792
ByteType -> classOf[Byte],
@@ -849,6 +845,10 @@ object ScalaReflection extends ScalaReflection {
849845
Seq.empty
850846
}
851847
}
848+
849+
def encodeFieldNameToIdentifier(fieldName: String): String = {
850+
TermName(fieldName).encodedName.toString
851+
}
852852
}
853853

854854
/**

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@ import org.apache.spark.{SparkConf, SparkEnv}
2828
import org.apache.spark.serializer._
2929
import org.apache.spark.sql.Row
3030
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, ScalaReflection}
31-
import org.apache.spark.sql.catalyst.ScalaReflection.universe.TermName
3231
import org.apache.spark.sql.catalyst.encoders.RowEncoder
3332
import org.apache.spark.sql.catalyst.expressions._
3433
import org.apache.spark.sql.catalyst.expressions.codegen._
@@ -311,7 +310,7 @@ case class Invoke(
311310
override def nullable: Boolean = targetObject.nullable || needNullCheck || returnNullable
312311
override def children: Seq[Expression] = targetObject +: arguments
313312

314-
private lazy val encodedFunctionName = TermName(functionName).encodedName.toString
313+
private lazy val encodedFunctionName = ScalaReflection.encodeFieldNameToIdentifier(functionName)
315314

316315
@transient lazy val method = targetObject.dataType match {
317316
case ObjectType(cls) =>

sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1222,12 +1222,18 @@ class DatasetSuite extends QueryTest
12221222
assert(result == Set(ClassData("a", 1) -> null, ClassData("b", 2) -> ClassData("x", 2)))
12231223
}
12241224

1225-
test("better error message when use java reserved keyword as field name") {
1225+
test("better error message when use invalid java identifier as field name") {
12261226
val e = intercept[UnsupportedOperationException] {
12271227
Seq(InvalidInJava(1)).toDS()
12281228
}
12291229
assert(e.getMessage.contains(
1230-
"`abstract` is a reserved keyword and cannot be used as field name"))
1230+
"`abstract` is not a valid identifier of Java and cannot be used as field name"))
1231+
1232+
val e2 = intercept[UnsupportedOperationException] {
1233+
Seq(InvalidInJava2(1)).toDS()
1234+
}
1235+
assert(e2.getMessage.contains(
1236+
"`0` is not a valid identifier of Java and cannot be used as field name"))
12311237
}
12321238

12331239
test("Dataset should support flat input object to be null") {
@@ -1965,6 +1971,7 @@ case class NestedStruct(f: ClassData)
19651971
case class DeepNestedStruct(f: NestedStruct)
19661972

19671973
case class InvalidInJava(`abstract`: Int)
1974+
case class InvalidInJava2(`0`: Int)
19681975

19691976
/**
19701977
* A class used to test serialization using encoders. This class throws exceptions when using

0 commit comments

Comments
 (0)