From 0dcc9923816114ae6e413af1ca8f195b150e0b5d Mon Sep 17 00:00:00 2001
From: Liang-Chi Hsieh
Date: Wed, 8 Jul 2015 10:55:09 +0800
Subject: [PATCH 1/7] Add float coercion on SparkR.

---
 R/pkg/R/deserialize.R                                  | 1 +
 core/src/main/scala/org/apache/spark/api/r/SerDe.scala | 4 ++++
 2 files changed, 5 insertions(+)

diff --git a/R/pkg/R/deserialize.R b/R/pkg/R/deserialize.R
index d961bbc38368..7d1f6b0819ed 100644
--- a/R/pkg/R/deserialize.R
+++ b/R/pkg/R/deserialize.R
@@ -23,6 +23,7 @@
 # Int -> integer
 # String -> character
 # Boolean -> logical
+# Float -> double
 # Double -> double
 # Long -> double
 # Array[Byte] -> raw
diff --git a/core/src/main/scala/org/apache/spark/api/r/SerDe.scala b/core/src/main/scala/org/apache/spark/api/r/SerDe.scala
index 56adc857d4ce..d5b4260bf452 100644
--- a/core/src/main/scala/org/apache/spark/api/r/SerDe.scala
+++ b/core/src/main/scala/org/apache/spark/api/r/SerDe.scala
@@ -179,6 +179,7 @@ private[spark] object SerDe {
   // Int -> integer
   // String -> character
   // Boolean -> logical
+  // Float -> double
   // Double -> double
   // Long -> double
   // Array[Byte] -> raw
@@ -215,6 +216,9 @@ private[spark] object SerDe {
       case "long" | "java.lang.Long" =>
         writeType(dos, "double")
         writeDouble(dos, value.asInstanceOf[Long].toDouble)
+      case "float" | "java.lang.Float" =>
+        writeType(dos, "double")
+        writeDouble(dos, value.asInstanceOf[Float].toDouble)
       case "double" | "java.lang.Double" =>
         writeType(dos, "double")
         writeDouble(dos, value.asInstanceOf[Double])

From 8db3244918c05173f0b28d2c929b322dd245bdc2 Mon Sep 17 00:00:00 2001
From: Liang-Chi Hsieh
Date: Wed, 8 Jul 2015 17:09:01 +0800
Subject: [PATCH 2/7] schema also needs to support float. add test case.

---
 R/pkg/R/schema.R                                         | 1 +
 R/pkg/inst/tests/test_sparkSQL.R                         | 8 ++++++++
 .../main/scala/org/apache/spark/sql/api/r/SQLUtils.scala | 1 +
 3 files changed, 10 insertions(+)

diff --git a/R/pkg/R/schema.R b/R/pkg/R/schema.R
index 15e2bdbd55d7..06df43068768 100644
--- a/R/pkg/R/schema.R
+++ b/R/pkg/R/schema.R
@@ -123,6 +123,7 @@ structField.character <- function(x, type, nullable = TRUE) {
   }
   options <- c("byte",
                "integer",
+               "float",
                "double",
                "numeric",
                "character",
diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R
index b0ea38854304..59c71eb68208 100644
--- a/R/pkg/inst/tests/test_sparkSQL.R
+++ b/R/pkg/inst/tests/test_sparkSQL.R
@@ -108,6 +108,14 @@ test_that("create DataFrame from RDD", {
   expect_equal(count(df), 10)
   expect_equal(columns(df), c("a", "b"))
   expect_equal(dtypes(df), list(c("a", "int"), c("b", "string")))
+
+  localDF <- data.frame(name=c("John", "Smith", "Sarah"), age=c(19, 23, 18), height=c(164.10, 181.4, 173.7))
+  schema <- structType(structField("name", "string"), structField("age", "integer"), structField("height", "float"))
+  df <- createDataFrame(sqlContext, localDF, schema)
+  expect_is(df, "DataFrame")
+  expect_equal(count(df), 3)
+  expect_equal(columns(df), c("name", "age", "height"))
+  expect_equal(dtypes(df), list(c("name", "string"), c("age", "double"), c("height", "double")))
 })
 
 test_that("convert NAs to null type in DataFrames", {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala
index 43b62f0e822f..e640642a9d7e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala
@@ -47,6 +47,7 @@ private[r] object SQLUtils {
     dataType match {
       case "byte" => org.apache.spark.sql.types.ByteType
       case "integer" => org.apache.spark.sql.types.IntegerType
+      case "float" => org.apache.spark.sql.types.FloatType
       case "double" => org.apache.spark.sql.types.DoubleType
       case "numeric" => org.apache.spark.sql.types.DoubleType
      case "character" => org.apache.spark.sql.types.StringType
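Taken together, PATCH 1 and PATCH 2 let a SparkR schema declare a "float" field: SQLUtils maps the string "float" to FloatType on the JVM side, while SerDe widens any java.lang.Float it sends back to R into a double, since R has no 32-bit float type. Below is a minimal, standalone Scala sketch of that write-side widening; the object name, the error handling, and the reduced writeType mapping are illustrative simplifications, not the actual SerDe code.

import java.io.{ByteArrayOutputStream, DataOutputStream}

object FloatWriteSketch {
  // SerDe tags each value with a one-character type code before the payload;
  // both Float and Double values are tagged "double" for the R side.
  def writeType(dos: DataOutputStream, typeStr: String): Unit = typeStr match {
    case "double" => dos.writeByte('d'.toInt)
    case other    => throw new IllegalArgumentException(s"unsupported type: $other")
  }

  def writeObject(dos: DataOutputStream, value: Object): Unit =
    value.getClass.getName match {
      case "java.lang.Float" =>
        writeType(dos, "double")
        dos.writeDouble(value.asInstanceOf[java.lang.Float].doubleValue())
      case "java.lang.Double" =>
        writeType(dos, "double")
        dos.writeDouble(value.asInstanceOf[java.lang.Double].doubleValue())
      case other =>
        throw new IllegalArgumentException(s"unsupported class: $other")
    }

  def main(args: Array[String]): Unit = {
    val bos = new ByteArrayOutputStream()
    val dos = new DataOutputStream(bos)
    writeObject(dos, java.lang.Float.valueOf(176.5f))
    writeObject(dos, java.lang.Double.valueOf(176.5))
    dos.flush()
    // Each value takes 1 tag byte + 8 payload bytes, so 18 bytes in total.
    println(s"${bos.size()} bytes written")
  }
}

Either way the R user sees a plain double, which is why the PATCH 2 test expects a dtype of "double" for a column declared as "float".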
"integer" => org.apache.spark.sql.types.IntegerType + case "float" => org.apache.spark.sql.types.FloatType case "double" => org.apache.spark.sql.types.DoubleType case "numeric" => org.apache.spark.sql.types.DoubleType case "character" => org.apache.spark.sql.types.StringType From 6f9159dac8126cb1b714f9d37ed59aa932d5fad8 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 9 Jul 2015 10:42:53 +0800 Subject: [PATCH 3/7] Add another test case. --- R/pkg/inst/tests/test_sparkSQL.R | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R index 59c71eb68208..362e28940348 100644 --- a/R/pkg/inst/tests/test_sparkSQL.R +++ b/R/pkg/inst/tests/test_sparkSQL.R @@ -116,6 +116,16 @@ test_that("create DataFrame from RDD", { expect_equal(count(df), 3) expect_equal(columns(df), c("name", "age", "height")) expect_equal(dtypes(df), list(c("name", "string"), c("age", "double"), c("height", "double"))) + + hiveCtx <- tryCatch({ + newJObject("org.apache.spark.sql.hive.test.TestHiveContext", ssc) + }, error = function(err) { + skip("Hive is not build with SparkSQL, skipped") + }) + sql(hiveCtx, "CREATE TABLE people (name string, age double, height float) row format delimited fields terminated by ','") + insertInto(df, "people") + expect_equal(sql(hiveCtx, "SELECT age from people"), c(19, 23, 18)) + expect_equal(sql(hiveCtx, "SELECT height from people"), c(164.10, 181.4, 173.7)) }) test_that("convert NAs to null type in DataFrames", { From 30c2a404b4760849910d49419a1d65f5fc8c582f Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Fri, 10 Jul 2015 01:31:45 +0800 Subject: [PATCH 4/7] Update test case. --- R/pkg/inst/tests/test_sparkSQL.R | 15 ++++----------- 1 file changed, 4 insertions(+), 11 deletions(-) diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R index 362e28940348..1afffe5584cc 100644 --- a/R/pkg/inst/tests/test_sparkSQL.R +++ b/R/pkg/inst/tests/test_sparkSQL.R @@ -109,23 +109,16 @@ test_that("create DataFrame from RDD", { expect_equal(columns(df), c("a", "b")) expect_equal(dtypes(df), list(c("a", "int"), c("b", "string"))) - localDF <- data.frame(name=c("John", "Smith", "Sarah"), age=c(19, 23, 18), height=c(164.10, 181.4, 173.7)) - schema <- structType(structField("name", "string"), structField("age", "integer"), structField("height", "float")) - df <- createDataFrame(sqlContext, localDF, schema) - expect_is(df, "DataFrame") - expect_equal(count(df), 3) - expect_equal(columns(df), c("name", "age", "height")) - expect_equal(dtypes(df), list(c("name", "string"), c("age", "double"), c("height", "double"))) - + df <- jsonFile(sqlContext, jsonPathNa) hiveCtx <- tryCatch({ newJObject("org.apache.spark.sql.hive.test.TestHiveContext", ssc) }, error = function(err) { skip("Hive is not build with SparkSQL, skipped") }) - sql(hiveCtx, "CREATE TABLE people (name string, age double, height float) row format delimited fields terminated by ','") + sql(hiveCtx, "CREATE TABLE people (name string, age double, height float)") insertInto(df, "people") - expect_equal(sql(hiveCtx, "SELECT age from people"), c(19, 23, 18)) - expect_equal(sql(hiveCtx, "SELECT height from people"), c(164.10, 181.4, 173.7)) + expect_equal(sql(hiveCtx, "SELECT age from people WHERE name = 'Bob'"), c(16)) + expect_equal(sql(hiveCtx, "SELECT height from people WHERE name ='Bob'"), c(176.5)) }) test_that("convert NAs to null type in DataFrames", { From 733015a29defded46c1b844a622e10e5ed1ee571 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh 
From 733015a29defded46c1b844a622e10e5ed1ee571 Mon Sep 17 00:00:00 2001
From: Liang-Chi Hsieh
Date: Fri, 10 Jul 2015 12:19:46 +0800
Subject: [PATCH 5/7] Add test case for DataFrame with float type.

---
 R/pkg/inst/tests/test_sparkSQL.R | 7 +++++++
 1 file changed, 7 insertions(+)

diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R
index 1afffe5584cc..c801282a6164 100644
--- a/R/pkg/inst/tests/test_sparkSQL.R
+++ b/R/pkg/inst/tests/test_sparkSQL.R
@@ -119,6 +119,13 @@ test_that("create DataFrame from RDD", {
   insertInto(df, "people")
   expect_equal(sql(hiveCtx, "SELECT age from people WHERE name = 'Bob'"), c(16))
   expect_equal(sql(hiveCtx, "SELECT height from people WHERE name ='Bob'"), c(176.5))
+
+  schema <- structType(structField("name", "string"), structField("age", "integer"),
+                       structField("height", "float"))
+  df2 <- createDataFrame(sqlContext, df.toRDD, schema)
+  expect_equal(columns(df2), c("name", "age", "height"))
+  expect_equal(dtypes(df2), list(c("name", "string"), c("age", "int"), c("height", "float")))
+  expect_equal(collect(where(df2, df2$name == "Bob")), c("Bob", 16, 176.5))
 })
 
 test_that("convert NAs to null type in DataFrames", {

From dbf0c1bf6d0234ee5df55ca9040430e9ab3da872 Mon Sep 17 00:00:00 2001
From: Liang-Chi Hsieh
Date: Tue, 14 Jul 2015 13:45:42 +0800
Subject: [PATCH 6/7] Implicitly convert Double to Float based on provided schema.

---
 R/pkg/inst/tests/test_sparkSQL.R                     |  8 ++++++++
 .../main/scala/org/apache/spark/api/r/SerDe.scala    | 13 +++++++++++--
 .../scala/org/apache/spark/sql/api/r/SQLUtils.scala  |  6 +++---
 3 files changed, 22 insertions(+), 5 deletions(-)

diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R
index c801282a6164..76f74f80834a 100644
--- a/R/pkg/inst/tests/test_sparkSQL.R
+++ b/R/pkg/inst/tests/test_sparkSQL.R
@@ -126,6 +126,14 @@ test_that("create DataFrame from RDD", {
   expect_equal(columns(df2), c("name", "age", "height"))
   expect_equal(dtypes(df2), list(c("name", "string"), c("age", "int"), c("height", "float")))
   expect_equal(collect(where(df2, df2$name == "Bob")), c("Bob", 16, 176.5))
+
+  localDF <- data.frame(name=c("John", "Smith", "Sarah"), age=c(19, 23, 18), height=c(164.10, 181.4, 173.7))
+  df <- createDataFrame(sqlContext, localDF, schema)
+  expect_is(df, "DataFrame")
+  expect_equal(count(df), 3)
+  expect_equal(columns(df), c("name", "age", "height"))
+  expect_equal(dtypes(df), list(c("name", "string"), c("age", "int"), c("height", "float")))
+  expect_equal(collect(where(df, df$name == "John")), c("John", 19, 164.10))
 })
 
 test_that("convert NAs to null type in DataFrames", {
diff --git a/core/src/main/scala/org/apache/spark/api/r/SerDe.scala b/core/src/main/scala/org/apache/spark/api/r/SerDe.scala
index d5b4260bf452..f5592d038bcc 100644
--- a/core/src/main/scala/org/apache/spark/api/r/SerDe.scala
+++ b/core/src/main/scala/org/apache/spark/api/r/SerDe.scala
@@ -46,9 +46,18 @@ private[spark] object SerDe {
     dis.readByte().toChar
   }
 
-  def readObject(dis: DataInputStream): Object = {
+  def readObject(dis: DataInputStream, typeName: String = ""): Object = {
     val dataType = readObjectType(dis)
-    readTypedObject(dis, dataType)
+    val data = readTypedObject(dis, dataType)
+    doConversion(data, dataType, typeName)
+  }
+
+  def doConversion(data: Object, dataType: Char, typeName: String): Object = {
+    dataType match {
+      case 'd' if typeName == "Float" =>
+        new java.lang.Float(data.asInstanceOf[java.lang.Double])
+      case _ => data
+    }
   }
 
   def readTypedObject(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala
index e640642a9d7e..04150656dec6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala
@@ -69,7 +69,7 @@ private[r] object SQLUtils {
 
   def createDF(rdd: RDD[Array[Byte]], schema: StructType, sqlContext: SQLContext): DataFrame = {
     val num = schema.fields.size
-    val rowRDD = rdd.map(bytesToRow)
+    val rowRDD = rdd.map(bytesToRow(_, schema))
     sqlContext.createDataFrame(rowRDD, schema)
   }
 
@@ -77,12 +77,12 @@ private[r] object SQLUtils {
     df.map(r => rowToRBytes(r))
   }
 
-  private[this] def bytesToRow(bytes: Array[Byte]): Row = {
+  private[this] def bytesToRow(bytes: Array[Byte], schema: StructType): Row = {
     val bis = new ByteArrayInputStream(bytes)
     val dis = new DataInputStream(bis)
     val num = SerDe.readInt(dis)
     Row.fromSeq((0 until num).map { i =>
-      SerDe.readObject(dis)
+      SerDe.readObject(dis, schema.fields(i).dataType.typeName)
     }.toSeq)
   }

From c86dc0e0ddbe86fb23d46565ac931ddf6983ab9d Mon Sep 17 00:00:00 2001
From: Liang-Chi Hsieh
Date: Wed, 15 Jul 2015 12:46:12 +0800
Subject: [PATCH 7/7] For comments.

---
 .../main/scala/org/apache/spark/api/r/SerDe.scala    | 13 ++-----------
 .../scala/org/apache/spark/sql/api/r/SQLUtils.scala  | 10 +++++++++-
 2 files changed, 11 insertions(+), 12 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/api/r/SerDe.scala b/core/src/main/scala/org/apache/spark/api/r/SerDe.scala
index f5592d038bcc..d5b4260bf452 100644
--- a/core/src/main/scala/org/apache/spark/api/r/SerDe.scala
+++ b/core/src/main/scala/org/apache/spark/api/r/SerDe.scala
@@ -46,18 +46,9 @@ private[spark] object SerDe {
     dis.readByte().toChar
   }
 
-  def readObject(dis: DataInputStream, typeName: String = ""): Object = {
+  def readObject(dis: DataInputStream): Object = {
     val dataType = readObjectType(dis)
-    val data = readTypedObject(dis, dataType)
-    doConversion(data, dataType, typeName)
-  }
-
-  def doConversion(data: Object, dataType: Char, typeName: String): Object = {
-    dataType match {
-      case 'd' if typeName == "Float" =>
-        new java.lang.Float(data.asInstanceOf[java.lang.Double])
-      case _ => data
-    }
+    readTypedObject(dis, dataType)
   }
 
   def readTypedObject(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala
index 04150656dec6..92861ab038f1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala
@@ -77,12 +77,20 @@ private[r] object SQLUtils {
     df.map(r => rowToRBytes(r))
   }
 
+  private[this] def doConversion(data: Object, dataType: DataType): Object = {
+    data match {
+      case d: java.lang.Double if dataType == FloatType =>
+        new java.lang.Float(d)
+      case _ => data
+    }
+  }
+
   private[this] def bytesToRow(bytes: Array[Byte], schema: StructType): Row = {
     val bis = new ByteArrayInputStream(bytes)
     val dis = new DataInputStream(bis)
     val num = SerDe.readInt(dis)
     Row.fromSeq((0 until num).map { i =>
-      SerDe.readObject(dis, schema.fields(i).dataType.typeName)
+      doConversion(SerDe.readObject(dis), schema.fields(i).dataType)
     }.toSeq)
   }
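Where the series lands after review (PATCH 7): SerDe stays type-agnostic and always hands values from R back as java.lang.Double, and SQLUtils narrows a value to java.lang.Float only when the target schema field is FloatType. The standalone Scala sketch below mirrors the shape of that private doConversion helper; the object name and the driver code in main are illustrative, while the Spark SQL type imports are real.

import org.apache.spark.sql.types.{DataType, DoubleType, FloatType}

object SchemaDrivenNarrowingSketch {
  // Narrow a boxed Double to a boxed Float only when the schema asks for FloatType;
  // every other value passes through untouched.
  def doConversion(data: Object, dataType: DataType): Object = data match {
    case d: java.lang.Double if dataType == FloatType =>
      java.lang.Float.valueOf(d.floatValue())
    case _ => data
  }

  def main(args: Array[String]): Unit = {
    val height: Object = java.lang.Double.valueOf(176.5)
    println(doConversion(height, FloatType).getClass.getName)  // java.lang.Float
    println(doConversion(height, DoubleType).getClass.getName) // java.lang.Double
  }
}

Because bytesToRow looks up schema.fields(i).dataType for each field, the narrowing happens field by field as createDF deserializes the RDD of serialized rows.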