sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -2707,13 +2707,13 @@ class Analyzer(

case p => p transformExpressionsUp {

-case udf @ ScalaUDF(_, _, inputs, inputPrimitives, _, _, _, _)
-  if inputPrimitives.contains(true) =>
+case udf @ ScalaUDF(_, _, inputs, _, _, _, _)
+  if udf.inputPrimitives.contains(true) =>
// Otherwise, add special handling of null for fields that can't accept null.
// The result of operations like this, when passed null, is generally to return null.
-assert(inputPrimitives.length == inputs.length)
+assert(udf.inputPrimitives.length == inputs.length)

-val inputPrimitivesPair = inputPrimitives.zip(inputs)
+val inputPrimitivesPair = udf.inputPrimitives.zip(inputs)
val inputNullCheck = inputPrimitivesPair.collect {
case (isPrimitive, input) if isPrimitive && input.nullable =>
IsNull(input)

Large diffs are not rendered by default.
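The collapsed file above is presumably ScalaUDF.scala itself, where the Seq[Boolean] of primitive flags is dropped from the constructor in favor of per-argument input encoders, and inputPrimitives becomes a member derived from them (which is why the Analyzer hunk switches to udf.inputPrimitives). A minimal sketch of that derivation, assuming Spark 3.x catalyst APIs and simplified naming, not the literal contents of the hidden diff:

import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder

// Sketch: derive the per-argument "needs primitive null handling" flags
// from optional input encoders instead of taking them as raw booleans.
def inputPrimitives(inputEncoders: Seq[Option[ExpressionEncoder[_]]]): Seq[Boolean] =
  inputEncoders.map {
    case Some(enc) if !enc.isSerializedAsStruct =>
      // A flat encoder produces a single schema field, and that field is
      // non-nullable exactly when the external type is a Scala primitive
      // such as Int, Short, or Double.
      !enc.schema.head.nullable
    case _ =>
      // Struct-encoded inputs (case classes, tuples) or absent encoders
      // never trigger the rewrite.
      false
  }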

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
@@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.catalog.{CatalogStorageFormat, CatalogTable, CatalogTableType, InMemoryCatalog, SessionCatalog}
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.errors.TreeNodeException
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Count, Sum}

@@ -326,20 +327,21 @@ class AnalysisSuite extends AnalysisTest with Matchers {
}

// non-primitive parameters do not need special null handling
-val udf1 = ScalaUDF((s: String) => "x", StringType, string :: Nil, false :: Nil)
+val udf1 = ScalaUDF((s: String) => "x", StringType, string :: Nil,
+  Option(ExpressionEncoder[String]()) :: Nil)
val expected1 = udf1
checkUDF(udf1, expected1)

// only primitive parameter needs special null handling
val udf2 = ScalaUDF((s: String, d: Double) => "x", StringType, string :: double :: Nil,
-  false :: true :: Nil)
+  Option(ExpressionEncoder[String]()) :: Option(ExpressionEncoder[Double]()) :: Nil)
val expected2 =
If(IsNull(double), nullResult, udf2.copy(children = string :: KnownNotNull(double) :: Nil))
checkUDF(udf2, expected2)

// special null handling should apply to all primitive parameters
val udf3 = ScalaUDF((s: Short, d: Double) => "x", StringType, short :: double :: Nil,
-  true :: true :: Nil)
+  Option(ExpressionEncoder[Short]()) :: Option(ExpressionEncoder[Double]()) :: Nil)
val expected3 = If(
IsNull(short) || IsNull(double),
nullResult,
@@ -351,7 +353,7 @@
(s: Short, d: Double) => "x",
StringType,
short :: nonNullableDouble :: Nil,
-  true :: true :: Nil)
+  Option(ExpressionEncoder[Short]()) :: Option(ExpressionEncoder[Double]()) :: Nil)
val expected4 = If(
IsNull(short),
nullResult,
@@ -362,8 +364,12 @@
test("SPARK-24891 Fix HandleNullInputsForUDF rule") {
val a = testRelation.output(0)
val func = (x: Int, y: Int) => x + y
-val udf1 = ScalaUDF(func, IntegerType, a :: a :: Nil, false :: false :: Nil)
-val udf2 = ScalaUDF(func, IntegerType, a :: udf1 :: Nil, false :: false :: Nil)
+val udf1 = ScalaUDF(func, IntegerType, a :: a :: Nil,
+  Option(ExpressionEncoder[java.lang.Integer]()) ::
+  Option(ExpressionEncoder[java.lang.Integer]()) :: Nil)
+val udf2 = ScalaUDF(func, IntegerType, a :: udf1 :: Nil,
+  Option(ExpressionEncoder[java.lang.Integer]()) ::
+  Option(ExpressionEncoder[java.lang.Integer]()) :: Nil)
val plan = Project(Alias(udf2, "")() :: Nil, testRelation)
comparePlans(plan.analyze, plan.analyze.analyze)
}
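The expectations above capture the point of the HandleNullInputsForUDF rule: a Scala primitive parameter cannot hold null, so the analyzer wraps the UDF call in a null check that short-circuits to a null result instead of letting a null cell decay to a default value. A user-level illustration (hypothetical column name, assuming the standard org.apache.spark.sql.functions.udf API):

import org.apache.spark.sql.functions.udf

// (x: Int) is primitive: without the rewrite a null input cell could reach
// the function as 0; with it, the whole expression evaluates to null.
val plusOne = udf((x: Int) => x + 1)
// Hypothetical usage: rows where "maybeNull" is null yield null, not 1.
// df.select(plusOne($"maybeNull"))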
sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDFSuite.scala
@@ -20,17 +20,20 @@ package org.apache.spark.sql.catalyst.expressions
import java.util.Locale

import org.apache.spark.{SparkException, SparkFunSuite}
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{DecimalType, IntegerType, StringType}

class ScalaUDFSuite extends SparkFunSuite with ExpressionEvalHelper {

test("basic") {
-val intUdf = ScalaUDF((i: Int) => i + 1, IntegerType, Literal(1) :: Nil, true :: Nil)
+val intUdf = ScalaUDF((i: Int) => i + 1, IntegerType, Literal(1) :: Nil,
+  Option(ExpressionEncoder[Int]()) :: Nil)
checkEvaluation(intUdf, 2)

-val stringUdf = ScalaUDF((s: String) => s + "x", StringType, Literal("a") :: Nil, false :: Nil)
+val stringUdf = ScalaUDF((s: String) => s + "x", StringType, Literal("a") :: Nil,
+  Option(ExpressionEncoder[String]()) :: Nil)
checkEvaluation(stringUdf, "ax")
}

@@ -39,7 +42,7 @@ class ScalaUDFSuite extends SparkFunSuite with ExpressionEvalHelper {
(s: String) => s.toLowerCase(Locale.ROOT),
StringType,
Literal.create(null, StringType) :: Nil,
-  false :: Nil)
+  Option(ExpressionEncoder[String]()) :: Nil)

val e1 = intercept[SparkException](udf.eval())
assert(e1.getMessage.contains("Failed to execute user defined function"))
@@ -52,7 +55,8 @@ class ScalaUDFSuite extends SparkFunSuite with ExpressionEvalHelper {

test("SPARK-22695: ScalaUDF should not use global variables") {
val ctx = new CodegenContext
-ScalaUDF((s: String) => s + "x", StringType, Literal("a") :: Nil, false :: Nil).genCode(ctx)
+ScalaUDF((s: String) => s + "x", StringType, Literal("a") :: Nil,
+  Option(ExpressionEncoder[String]()) :: Nil).genCode(ctx)
assert(ctx.inlinedMutableStates.isEmpty)
}

@@ -61,7 +65,8 @@
val udf = ScalaUDF(
(a: java.math.BigDecimal) => a.multiply(new java.math.BigDecimal(100)),
DecimalType.SYSTEM_DEFAULT,
-  Literal(BigDecimal("12345678901234567890.123")) :: Nil, false :: Nil)
+  Literal(BigDecimal("12345678901234567890.123")) :: Nil,
+  Option(ExpressionEncoder[java.math.BigDecimal]()) :: Nil)
val e1 = intercept[ArithmeticException](udf.eval())
assert(e1.getMessage.contains("cannot be represented as Decimal"))
val e2 = intercept[SparkException] {
@@ -73,7 +78,8 @@
val udf = ScalaUDF(
(a: java.math.BigDecimal) => a.multiply(new java.math.BigDecimal(100)),
DecimalType.SYSTEM_DEFAULT,
-  Literal(BigDecimal("12345678901234567890.123")) :: Nil, false :: Nil)
+  Literal(BigDecimal("12345678901234567890.123")) :: Nil,
+  Option(ExpressionEncoder[java.math.BigDecimal]()) :: Nil)
checkEvaluation(udf, null)
}
}
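Across these test updates the migration is mechanical: every old true (primitive) flag becomes an encoder for a primitive type, and every false becomes an encoder for a reference or boxed type. The distinction is visible on the encoder itself, for example in a REPL (assuming the catalyst ExpressionEncoder these suites already import):

import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder

// A primitive external type encodes to a single non-nullable field...
ExpressionEncoder[Int]().schema.head.nullable    // false: needs null handling
// ...while a reference type keeps a nullable field.
ExpressionEncoder[String]().schema.head.nullable // true: nulls pass through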
sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSortsSuite.scala
@@ -22,6 +22,7 @@ import org.apache.spark.sql.catalyst.analysis.{Analyzer, EmptyFunctionRegistry}
import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog}
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
@@ -244,7 +245,8 @@ class EliminateSortsSuite extends PlanTest {
}

test("should not remove orderBy in groupBy clause with ScalaUDF as aggs") {
-val scalaUdf = ScalaUDF((s: Int) => s, IntegerType, 'a :: Nil, true :: Nil)
+val scalaUdf = ScalaUDF((s: Int) => s, IntegerType, 'a :: Nil,
+  Option(ExpressionEncoder[Int]()) :: Nil)
val projectPlan = testRelation.select('a, 'b)
val orderByPlan = projectPlan.orderBy('a.asc, 'b.desc)
val groupByPlan = orderByPlan.groupBy('a)(scalaUdf)
sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala
@@ -31,6 +31,7 @@ import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.{AliasIdentifier, FunctionIdentifier, InternalRow, TableIdentifier}
import org.apache.spark.sql.catalyst.catalog._
import org.apache.spark.sql.catalyst.dsl.expressions.DslString
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
import org.apache.spark.sql.catalyst.plans.{LeftOuter, NaturalJoin, SQLHelper}
@@ -594,7 +595,8 @@ class TreeNodeSuite extends SparkFunSuite with SQLHelper {
}

test("toJSON should not throws java.lang.StackOverflowError") {
-val udf = ScalaUDF(SelfReferenceUDF(), BooleanType, Seq("col1".attr), false :: Nil)
+val udf = ScalaUDF(SelfReferenceUDF(), BooleanType, Seq("col1".attr),
+  Option(ExpressionEncoder[String]()) :: Nil)
// Should not throw java.lang.StackOverflowError
udf.toJSON
}