Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1410,8 +1410,45 @@ case class InitializeJavaBean(beanInstance: Expression, setters: Map[String, Exp
override def children: Seq[Expression] = beanInstance +: setters.values.toSeq
override def dataType: DataType = beanInstance.dataType

override def eval(input: InternalRow): Any =
throw new UnsupportedOperationException("Only code-generated evaluation is supported.")
private lazy val resolvedSetters = {
assert(beanInstance.dataType.isInstanceOf[ObjectType])

val ObjectType(beanClass) = beanInstance.dataType
setters.map {
case (name, expr) =>
// Looking for known type mapping.
// But also looking for general `Object`-type parameter for generic methods.
val paramTypes = ScalaReflection.expressionJavaClasses(Seq(expr)) ++ Seq(classOf[Object])
val methods = paramTypes.flatMap { fieldClass =>
try {
Some(beanClass.getDeclaredMethod(name, fieldClass))
} catch {
case e: NoSuchMethodException => None
}
}
if (methods.isEmpty) {
throw new NoSuchMethodException(s"""A method named "$name" is not declared """ +
"in any enclosing class nor any supertype")
}
methods.head -> expr
}
}

override def eval(input: InternalRow): Any = {
val instance = beanInstance.eval(input)
if (instance != null) {
val bean = instance.asInstanceOf[Object]
resolvedSetters.foreach {
case (setter, expr) =>
val paramVal = expr.eval(input)
// We don't call setter if input value is null.
if (paramVal != null) {
setter.invoke(bean, paramVal.asInstanceOf[AnyRef])
}
}
}
instance
}

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val instanceGen = beanInstance.genCode(ctx)
Expand All @@ -1424,7 +1461,9 @@ case class InitializeJavaBean(beanInstance: Expression, setters: Map[String, Exp
val fieldGen = fieldValue.genCode(ctx)
s"""
|${fieldGen.code}
|$javaBeanInstance.$setterMethod(${fieldGen.value});
|if (!${fieldGen.isNull}) {
| $javaBeanInstance.$setterMethod(${fieldGen.value});
|}
""".stripMargin
}
val initializeCode = ctx.splitExpressionsWithCurrentInputs(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,8 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks {

protected def checkEvaluation(
expression: => Expression, expected: Any, inputRow: InternalRow = EmptyRow): Unit = {
val expr = prepareEvaluation(expression)
// Make it as method to obtain fresh expression everytime.
def expr = prepareEvaluation(expression)
val catalystValue = CatalystTypeConverters.convertToCatalyst(expected)
checkEvaluationWithoutCodegen(expr, catalystValue, inputRow)
checkEvaluationWithGeneratedMutableProjection(expr, catalystValue, inputRow)
Expand Down Expand Up @@ -111,12 +112,14 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks {
val errMsg = intercept[T] {
eval
}.getMessage
if (errMsg != expectedErrMsg) {
if (!errMsg.contains(expectedErrMsg)) {
fail(s"Expected error message is `$expectedErrMsg`, but `$errMsg` found")
}
}
}
val expr = prepareEvaluation(expression)

// Make it as method to obtain fresh expression everytime.
def expr = prepareEvaluation(expression)
checkException(evaluateWithoutCodegen(expr, inputRow), "non-codegen mode")
checkException(evaluateWithGeneratedMutableProjection(expr, inputRow), "codegen mode")
if (GenerateUnsafeProjection.canSupport(expr.dataType)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,46 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
Invoke(funcSubObj, "binOp", DoubleType, inputSum), 0.75, InternalRow.apply(1, 0.25))
}

test("SPARK-23593: InitializeJavaBean should support interpreted execution") {
val list = new java.util.LinkedList[Int]()
list.add(1)

val initializeBean = InitializeJavaBean(Literal.fromObject(new java.util.LinkedList[Int]),
Map("add" -> Literal(1)))
checkEvaluation(initializeBean, list, InternalRow.fromSeq(Seq()))

val initializeWithNonexistingMethod = InitializeJavaBean(
Literal.fromObject(new java.util.LinkedList[Int]),
Map("nonexisting" -> Literal(1)))
checkExceptionInExpression[Exception](initializeWithNonexistingMethod,
InternalRow.fromSeq(Seq()),
"""A method named "nonexisting" is not declared in any enclosing class """ +
"nor any supertype")

val initializeWithWrongParamType = InitializeJavaBean(
Literal.fromObject(new TestBean),
Map("setX" -> Literal("1")))
intercept[Exception] {
evaluateWithoutCodegen(initializeWithWrongParamType, InternalRow.fromSeq(Seq()))
}.getMessage.contains(
"""A method named "setX" is not declared in any enclosing class """ +
"nor any supertype")
}

test("InitializeJavaBean doesn't call setters if input in null") {
val initializeBean = InitializeJavaBean(
Literal.fromObject(new TestBean),
Map("setNonPrimitive" -> Literal(null)))
evaluateWithoutCodegen(initializeBean, InternalRow.fromSeq(Seq()))
evaluateWithGeneratedMutableProjection(initializeBean, InternalRow.fromSeq(Seq()))

val initializeBean2 = InitializeJavaBean(
Literal.fromObject(new TestBean),
Map("setNonPrimitive" -> Literal("string")))
evaluateWithoutCodegen(initializeBean2, InternalRow.fromSeq(Seq()))
evaluateWithGeneratedMutableProjection(initializeBean2, InternalRow.fromSeq(Seq()))
}

test("SPARK-23585: UnwrapOption should support interpreted execution") {
val cls = classOf[Option[Int]]
val inputObject = BoundReference(0, ObjectType(cls), nullable = true)
Expand Down Expand Up @@ -278,3 +318,11 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
}
}
}

class TestBean extends Serializable {
private var x: Int = 0

def setX(i: Int): Unit = x = i
def setNonPrimitive(i: AnyRef): Unit =
assert(i != null, "this setter should not be called with null.")
}