Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Support ArrayType in createDataFrame().
  • Loading branch information
Sun Rui committed Aug 27, 2015
commit 9b5bd05fa4fdcdb14d3d004c52274a1a6264ff8c
4 changes: 2 additions & 2 deletions R/pkg/R/SQLContext.R
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ infer_type <- function(x) {
stopifnot(length(x) > 0)
names <- names(x)
if (is.null(names)) {
list(type = "array", elementType = infer_type(x[[1]]), containsNull = TRUE)
paste0("array<", infer_type(x[[1]]), ">")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just to clarify: is this meant to support vectors of the form c(1, 2, 3), etc.?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This one is for lists. The next changed line is the one for vectors.

} else {
# StructType
types <- lapply(x, infer_type)
Expand All @@ -59,7 +59,7 @@ infer_type <- function(x) {
do.call(structType, fields)
}
} else if (length(x) > 1) {
list(type = "array", elementType = type, containsNull = TRUE)
paste0("array<", infer_type(x[[1]]), ">")
} else {
type
}
Expand Down
52 changes: 33 additions & 19 deletions R/pkg/R/schema.R
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,35 @@ structField.jobj <- function(x) {
obj
}

# Validate that `type` names a data type accepted for a DataFrame field.
#
# A type is accepted when it is one of the primitive type names listed below,
# or when it has the form "array<ELEM>" where ELEM is itself an accepted type
# (checked recursively, so nested arrays are allowed).
#
# @param type the type name to validate
# @return NULL when the type is accepted; otherwise an error is raised naming
#         the innermost type string that could not be validated.
checkType <- function(type) {
  primitiveTypes <- c("byte", "integer", "float", "double", "numeric",
                      "character", "string", "binary", "raw",
                      "logical", "boolean", "timestamp", "date")

  if (!(type %in% primitiveTypes)) {
    # Not a primitive: the only other accepted shape is array<ELEM>.
    # On a match, element 1 is the whole string and element 2 is ELEM.
    captured <- regmatches(type, regexec("^array<(.*)>$", type))[[1]]
    if (length(captured) < 2) {
      stop(paste("Unsupported type for Dataframe:", type))
    }
    # Validate the element type; this raises if ELEM is not supported.
    checkType(captured[2])
  }

  return()
}

structField.character <- function(x, type, nullable = TRUE) {
if (class(x) != "character") {
stop("Field name must be a string.")
Expand All @@ -124,28 +153,13 @@ structField.character <- function(x, type, nullable = TRUE) {
if (class(nullable) != "logical") {
stop("nullable must be either TRUE or FALSE")
}
options <- c("byte",
"integer",
"float",
"double",
"numeric",
"character",
"string",
"binary",
"raw",
"logical",
"boolean",
"timestamp",
"date")
dataType <- if (type %in% options) {
type
} else {
stop(paste("Unsupported type for Dataframe:", type))
}

checkType(type)

sfObj <- callJStatic("org.apache.spark.sql.api.r.SQLUtils",
"createStructField",
x,
dataType,
type,
nullable)
structField(sfObj)
}
Expand Down
22 changes: 15 additions & 7 deletions R/pkg/inst/tests/test_sparkSQL.R
Original file line number Diff line number Diff line change
Expand Up @@ -64,10 +64,8 @@ test_that("infer types", {
expect_equal(infer_type(TRUE), "boolean")
expect_equal(infer_type(as.Date("2015-03-11")), "date")
expect_equal(infer_type(as.POSIXlt("2015-03-11 12:13:04.043")), "timestamp")
expect_equal(infer_type(c(1L, 2L)),
list(type = "array", elementType = "integer", containsNull = TRUE))
expect_equal(infer_type(list(1L, 2L)),
list(type = "array", elementType = "integer", containsNull = TRUE))
expect_equal(infer_type(c(1L, 2L)), "array<integer>")
expect_equal(infer_type(list(1L, 2L)), "array<integer>")
testStruct <- infer_type(list(a = 1L, b = "2"))
expect_equal(class(testStruct), "structType")
checkStructField(testStruct$fields()[[1]], "a", "IntegerType", TRUE)
Expand Down Expand Up @@ -244,8 +242,7 @@ test_that("create DataFrame with different data types", {
expect_equal(collect(df), data.frame(l, stringsAsFactors = FALSE))
})

# TODO: enable this test after fix serialization for nested object
#test_that("create DataFrame with nested array and struct", {
test_that("create DataFrame with nested array and struct", {
# e <- new.env()
# assign("n", 3L, envir = e)
# l <- list(1:10, list("a", "b"), e, list(a="aa", b=3L))
Expand All @@ -255,7 +252,18 @@ test_that("create DataFrame with different data types", {
# expect_equal(count(df), 1)
# ldf <- collect(df)
# expect_equal(ldf[1,], l[[1]])
#})


# ArrayType only for now
l <- list(as.list(1:10), list("a", "b"))
df <- createDataFrame(sqlContext, list(l), c("a", "b"))
expect_equal(dtypes(df), list(c("a", "array<int>"), c("b", "array<string>")))
expect_equal(count(df), 1)
ldf <- collect(df)
expect_equal(names(ldf), c("a", "b"))
expect_equal(ldf[1, 1][[1]], l[[1]])
expect_equal(ldf[1, 2][[1]], l[[2]])
})

test_that("Collect DataFrame with complex types", {
# only ArrayType now
Expand Down
10 changes: 10 additions & 0 deletions sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ import org.apache.spark.sql.catalyst.expressions.{Alias, Expression, NamedExpres
import org.apache.spark.sql.types._
import org.apache.spark.sql.{Column, DataFrame, GroupedData, Row, SQLContext, SaveMode}

import scala.util.matching.Regex

private[r] object SQLUtils {
def createSQLContext(jsc: JavaSparkContext): SQLContext = {
new SQLContext(jsc)
Expand All @@ -39,6 +41,11 @@ private[r] object SQLUtils {
StructType(fields)
}

// Adds an `r` interpolator to string literals so that r"..." yields a Regex,
// which lets a regex literal be used directly as a pattern in a `match`.
// The literal parts are concatenated to form the pattern, and one group name
// ("x") is supplied for each interpolated hole.
implicit class RegexContext(sc: StringContext) {
  def r: Regex = {
    val pattern = sc.parts.mkString
    val groupNames = sc.parts.tail.map(_ => "x")
    new Regex(pattern, groupNames: _*)
  }
}

def getSQLDataType(dataType: String): DataType = {
dataType match {
case "byte" => org.apache.spark.sql.types.ByteType
Expand All @@ -54,6 +61,9 @@ private[r] object SQLUtils {
case "boolean" => org.apache.spark.sql.types.BooleanType
case "timestamp" => org.apache.spark.sql.types.TimestampType
case "date" => org.apache.spark.sql.types.DateType
case r"\Aarray<(.*)${elemType}>\Z" => {
org.apache.spark.sql.types.ArrayType(getSQLDataType(elemType))
}
case _ => throw new IllegalArgumentException(s"Invaid type $dataType")
}
}
Expand Down