Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
improve tests
  • Loading branch information
Ngone51 committed May 27, 2020
commit bb26320160b2401df54b5a2a33768284a9ec2bf0
29 changes: 29 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -617,4 +617,33 @@ class UDFSuite extends QueryTest with SharedSparkSession {
val df = Seq(("data", (TestData(50, "2"), 2))).toDF("col1", "col2")
checkAnswer(df.select(myUdf(Column("col2"))), Row(100) :: Nil)
}

test("case class as generic type of Option") {
val f = (o: Option[TestData]) => o.map(t => t.key * t.value.toInt)
val myUdf = udf(f)
val df = Seq(("data", Some(TestData(50, "2")))).toDF("col1", "col2")
checkAnswer(df.select(myUdf(Column("col2"))), Row(100) :: Nil)
}

test("more input fields than expect for case class") {
val f = (t: TestData2) => t.a * t.b
val myUdf = udf(f)
val df = Seq(("data", TestData4(50, 2, 2))).toDF("col1", "col2")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

let's avoid creating too many TestData variants.

here we can create a dataframe directly: spark.range(1).select(lit(50).as("a"), ...)

checkAnswer(df.select(myUdf(Column("col2"))), Row(100) :: Nil)
}

test("less input fields than expect for case class") {
val f = (t: TestData4) => t.a * t.b * t.c
val myUdf = udf(f)
val df = Seq(("data", TestData2(50, 2))).toDF("col1", "col2")
val error = intercept[AnalysisException] (df.select(myUdf(Column("col2"))))
assert(error.getMessage.contains("cannot resolve '`c`' given input columns: [a, b]"))
}

test("wrong order of input fields for case class") {
val f = (t: TestData) => t.key * t.value.toInt
val myUdf = udf(f)
val df = Seq(("data", TestData5("2", 50))).toDF("col1", "col2")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ditto

checkAnswer(df.select(myUdf(Column("col2"))), Row(100) :: Nil)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,8 @@ private[sql] object SQLTestData {
case class TestData(key: Int, value: String)
case class TestData2(a: Int, b: Int)
case class TestData3(a: Int, b: Option[Int])
case class TestData4(a: Int, b: Int, c: Int)
case class TestData5(value: String, key: Int)
case class LargeAndSmallInts(a: Int, b: Int)
case class DecimalData(a: BigDecimal, b: BigDecimal)
case class BinaryData(a: Array[Byte], b: Int)
Expand Down