Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import org.apache.spark.sql.catalyst.expressions.objects._
import org.apache.spark.sql.catalyst.util.{ArrayData, MapData}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
import org.apache.spark.util.Utils


/**
Expand Down Expand Up @@ -377,6 +378,23 @@ object ScalaReflection extends ScalaReflection {
expressions.Literal.create(null, ObjectType(cls)),
newInstance
)

case t if isSubtype(t, localTypeOf[Enumeration#Value]) =>
// package example
// object Foo extends Enumeration {
// type Foo = Value
// val E1, E2 = Value
// }
// the fullName of tpe is example.Foo.Foo, but we need example.Foo so that
// we can call example.Foo.withName to deserialize string to enumeration.
val parent = t.asInstanceOf[TypeRef].pre.typeSymbol.asClass
val cls = mirror.runtimeClass(parent)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

follow scala reflection.

StaticInvoke(
cls,
ObjectType(getClassFromType(t)),
"withName",
createDeserializerForString(path, false) :: Nil,
returnNullable = false)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should also be true?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, if I'm correct, it would be good to include a test case for this case. What happens if you actually do operations on the null value stored as an enum? Is the nullability of the resulting schema correct. Do operations like .isNotNull work correctly?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we need returnNullable = false since Enumeration.withName will never return null.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, you are correct.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What happens if you actually do operations on the null value stored as an enum

We will get an exception, and the behavior is same as Java enum what we supported.

        case other if other.isEnum =>
          createSerializerForString(
            Invoke(inputObject, "name", ObjectType(classOf[String]), returnNullable = false))

}
}

Expand Down Expand Up @@ -562,6 +580,14 @@ object ScalaReflection extends ScalaReflection {
}
createSerializerForObject(inputObject, fields)

case t if isSubtype(t, localTypeOf[Enumeration#Value]) =>
createSerializerForString(
Invoke(
inputObject,
"toString",
ObjectType(classOf[java.lang.String]),
returnNullable = false))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

true?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same as withName, toString also return a non-null value.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, I see. I agree.


case _ =>
throw new UnsupportedOperationException(
s"No Encoder found for $tpe\n" + walkedTypePath)
Expand Down Expand Up @@ -739,6 +765,8 @@ object ScalaReflection extends ScalaReflection {
val Schema(dataType, nullable) = schemaFor(fieldType)
StructField(fieldName, dataType, nullable)
}), nullable = true)
case t if isSubtype(t, localTypeOf[Enumeration#Value]) =>
Schema(StringType, nullable = true)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So we are returning the schema generated from enum as nullable, but the serializer expression is not nullable (because it does not produce null). Why is that?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are two side of nullable.

  • For schema, nullable means the column can be null value.
  • For returnNullable which used in Invoke and StaticInvoke means the method promise will never produce a null value.

Some code get from Invoke

def invoke(
      obj: Any,
      method: Method,
      arguments: Seq[Expression],
      input: InternalRow,
      dataType: DataType): Any = {
    val args = arguments.map(e => e.eval(input).asInstanceOf[Object])
    if (needNullCheck && args.exists(_ == null)) {
      // return null if one of arguments is null
      null
    } else {
      val ret = method.invoke(obj, args: _*)
      val boxedClass = ScalaReflection.typeBoxedJavaMapping.get(dataType)
      if (boxedClass.isDefined) {
        boxedClass.get.cast(ret)
      } else {
        ret
      }
    }
  }

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So effectively, in the serialized form, we are allowing nulls to be present in the column mapping to the enum field. But if there is indeed a row with a null in that column and we attempt to deserialize that rows, then it will cause a runtime failure...isnt it?

If this understanding is correct, then serializing will never produce null, and even if there is a null, serializing will fail. Then why keep this column nullable = true? I am not necessarily opposed to it, I am just trying to understand the rationale. cc @marmbrus

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But if there is indeed a row with a null in that column and we attempt to deserialize that rows, then it will cause a runtime failure...isnt it

Actually we will get null back. It's a trick that if input value is null the serializing will return null.

case other =>
throw new UnsupportedOperationException(s"Schema for type $other is not supported")
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import java.sql.{Date, Timestamp}
import scala.reflect.runtime.universe.TypeTag

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.FooEnum.FooEnum
import org.apache.spark.sql.catalyst.analysis.UnresolvedExtractValue
import org.apache.spark.sql.catalyst.expressions.{CreateNamedStruct, Expression, If, SpecificInternalRow, UpCast}
import org.apache.spark.sql.catalyst.expressions.objects.{AssertNotNull, NewInstance}
Expand Down Expand Up @@ -90,6 +91,13 @@ case class FooWithAnnotation(f1: String @FooAnnotation, f2: Option[String] @FooA

case class SpecialCharAsFieldData(`field.1`: String, `field 2`: String)

object FooEnum extends Enumeration {
type FooEnum = Value
val E1, E2 = Value
}

case class FooClassWithEnum(i: Int, e: FooEnum)

object TestingUDT {
@SQLUserDefinedType(udt = classOf[NestedStructUDT])
class NestedStruct(val a: Integer, val b: Long, val c: Double)
Expand Down Expand Up @@ -437,4 +445,11 @@ class ScalaReflectionSuite extends SparkFunSuite {
StructField("f2", StringType))))
assert(deserializerFor[FooWithAnnotation].dataType == ObjectType(classOf[FooWithAnnotation]))
}

test("SPARK-32585: Support scala enumeration in ScalaReflection") {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would consider putting a test case in ExpressionEncoderSuite as well as I think that checks various combinations of evaluation.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, not realized this test.

assert(serializerFor[FooClassWithEnum].dataType == StructType(Seq(
StructField("i", IntegerType, false),
StructField("e", StringType, true))))
assert(deserializerFor[FooClassWithEnum].dataType == ObjectType(classOf[FooClassWithEnum]))
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ import scala.collection.mutable.ArrayBuffer
import scala.reflect.runtime.universe.TypeTag

import org.apache.spark.sql.{Encoder, Encoders}
import org.apache.spark.sql.catalyst.{OptionalData, PrimitiveData}
import org.apache.spark.sql.catalyst.{FooClassWithEnum, FooEnum, OptionalData, PrimitiveData}
import org.apache.spark.sql.catalyst.analysis.AnalysisTest
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.AttributeReference
Expand Down Expand Up @@ -389,6 +389,14 @@ class ExpressionEncoderSuite extends CodegenInterpretedPlanTest with AnalysisTes
assert(e.getMessage.contains("tuple with more than 22 elements are not supported"))
}

encodeDecodeTest((1, FooEnum.E1), "Tuple with Int and scala Enum")
encodeDecodeTest((null, FooEnum.E1, FooEnum.E2), "Tuple with Null and scala Enum")
encodeDecodeTest(Seq(FooEnum.E1, null), "Seq with scala Enum")
encodeDecodeTest(Map("key" -> FooEnum.E1), "Map with String key and scala Enum")
encodeDecodeTest(Map(FooEnum.E1 -> "value"), "Map with scala Enum key and String value")
encodeDecodeTest(FooClassWithEnum(1, FooEnum.E1), "case class with Int and scala Enum")
encodeDecodeTest(FooEnum.E1, "scala Enum")

// Scala / Java big decimals ----------------------------------------------------------

encodeDecodeTest(BigDecimal(("9" * 20) + "." + "9" * 18),
Expand Down
15 changes: 14 additions & 1 deletion sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ import org.scalatest.exceptions.TestFailedException
import org.scalatest.prop.TableDrivenPropertyChecks._

import org.apache.spark.{SparkException, TaskContext}
import org.apache.spark.sql.catalyst.ScroogeLikeExample
import org.apache.spark.sql.catalyst.{FooClassWithEnum, FooEnum, ScroogeLikeExample}
import org.apache.spark.sql.catalyst.encoders.{OuterScopes, RowEncoder}
import org.apache.spark.sql.catalyst.plans.{LeftAnti, LeftSemi}
import org.apache.spark.sql.catalyst.util.sideBySide
Expand Down Expand Up @@ -1926,6 +1926,19 @@ class DatasetSuite extends QueryTest
}
}
}

test("SPARK-32585: Support scala enumeration in ScalaReflection") {
checkDataset(
Seq(FooClassWithEnum(1, FooEnum.E1), FooClassWithEnum(2, FooEnum.E2)).toDS(),
Seq(FooClassWithEnum(1, FooEnum.E1), FooClassWithEnum(2, FooEnum.E2)): _*
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Related to the above comment, I would add a test case where the enum value is null.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added the null test.

)

// test null
checkDataset(
Seq(FooClassWithEnum(1, null), FooClassWithEnum(2, FooEnum.E2)).toDS(),
Seq(FooClassWithEnum(1, null), FooClassWithEnum(2, FooEnum.E2)): _*
)
}
}

object AssertExecutionId {
Expand Down