Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
[SPARK-10049][SPARKR] Support collecting data of ArraryType in DataFr…
…ame.
  • Loading branch information
Sun Rui committed Aug 26, 2015
commit 1f7ab9522c185cecb9f51ee882a28aa8ea534cc4
26 changes: 13 additions & 13 deletions R/pkg/R/DataFrame.R
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,7 @@ setMethod("names<-",
signature(x = "DataFrame"),
function(x, value) {
if (!is.null(value)) {
sdf <- callJMethod(x@sdf, "toDF", listToSeq(as.list(value)))
sdf <- callJMethod(x@sdf, "toDF", as.list(value))
dataFrame(sdf)
}
})
Expand Down Expand Up @@ -843,10 +843,10 @@ setMethod("groupBy",
function(x, ...) {
cols <- list(...)
if (length(cols) >= 1 && class(cols[[1]]) == "character") {
sgd <- callJMethod(x@sdf, "groupBy", cols[[1]], listToSeq(cols[-1]))
sgd <- callJMethod(x@sdf, "groupBy", cols[[1]], cols[-1])
} else {
jcol <- lapply(cols, function(c) { c@jc })
sgd <- callJMethod(x@sdf, "groupBy", listToSeq(jcol))
sgd <- callJMethod(x@sdf, "groupBy", jcol)
}
groupedData(sgd)
})
Expand Down Expand Up @@ -1053,7 +1053,7 @@ setMethod("[", signature(x = "DataFrame", i = "Column"),
#' }
setMethod("select", signature(x = "DataFrame", col = "character"),
function(x, col, ...) {
sdf <- callJMethod(x@sdf, "select", col, toSeq(...))
sdf <- callJMethod(x@sdf, "select", col, list(...))
dataFrame(sdf)
})

Expand All @@ -1064,7 +1064,7 @@ setMethod("select", signature(x = "DataFrame", col = "Column"),
jcols <- lapply(list(col, ...), function(c) {
c@jc
})
sdf <- callJMethod(x@sdf, "select", listToSeq(jcols))
sdf <- callJMethod(x@sdf, "select", jcols)
dataFrame(sdf)
})

Expand All @@ -1080,7 +1080,7 @@ setMethod("select",
col(c)@jc
}
})
sdf <- callJMethod(x@sdf, "select", listToSeq(cols))
sdf <- callJMethod(x@sdf, "select", cols)
dataFrame(sdf)
})

Expand All @@ -1107,7 +1107,7 @@ setMethod("selectExpr",
signature(x = "DataFrame", expr = "character"),
function(x, expr, ...) {
exprList <- list(expr, ...)
sdf <- callJMethod(x@sdf, "selectExpr", listToSeq(exprList))
sdf <- callJMethod(x@sdf, "selectExpr", exprList)
dataFrame(sdf)
})

Expand Down Expand Up @@ -1272,12 +1272,12 @@ setMethod("arrange",
signature(x = "DataFrame", col = "characterOrColumn"),
function(x, col, ...) {
if (class(col) == "character") {
sdf <- callJMethod(x@sdf, "sort", col, toSeq(...))
sdf <- callJMethod(x@sdf, "sort", col, list(...))
} else if (class(col) == "Column") {
jcols <- lapply(list(col, ...), function(c) {
c@jc
})
sdf <- callJMethod(x@sdf, "sort", listToSeq(jcols))
sdf <- callJMethod(x@sdf, "sort", jcols)
}
dataFrame(sdf)
})
Expand Down Expand Up @@ -1624,7 +1624,7 @@ setMethod("describe",
signature(x = "DataFrame", col = "character"),
function(x, col, ...) {
colList <- list(col, ...)
sdf <- callJMethod(x@sdf, "describe", listToSeq(colList))
sdf <- callJMethod(x@sdf, "describe", colList)
dataFrame(sdf)
})

Expand All @@ -1634,7 +1634,7 @@ setMethod("describe",
signature(x = "DataFrame"),
function(x) {
colList <- as.list(c(columns(x)))
sdf <- callJMethod(x@sdf, "describe", listToSeq(colList))
sdf <- callJMethod(x@sdf, "describe", colList)
dataFrame(sdf)
})

Expand Down Expand Up @@ -1691,7 +1691,7 @@ setMethod("dropna",

naFunctions <- callJMethod(x@sdf, "na")
sdf <- callJMethod(naFunctions, "drop",
as.integer(minNonNulls), listToSeq(as.list(cols)))
as.integer(minNonNulls), as.list(cols))
dataFrame(sdf)
})

Expand Down Expand Up @@ -1775,7 +1775,7 @@ setMethod("fillna",
sdf <- if (length(cols) == 0) {
callJMethod(naFunctions, "fill", value)
} else {
callJMethod(naFunctions, "fill", value, listToSeq(as.list(cols)))
callJMethod(naFunctions, "fill", value, as.list(cols))
}
dataFrame(sdf)
})
Expand Down
3 changes: 1 addition & 2 deletions R/pkg/R/column.R
Original file line number Diff line number Diff line change
Expand Up @@ -211,8 +211,7 @@ setMethod("cast",
setMethod("%in%",
signature(x = "Column"),
function(x, table) {
table <- listToSeq(as.list(table))
jc <- callJMethod(x@jc, "in", table)
jc <- callJMethod(x@jc, "in", as.list(table))
return(column(jc))
})

Expand Down
12 changes: 6 additions & 6 deletions R/pkg/R/functions.R
Original file line number Diff line number Diff line change
Expand Up @@ -1331,7 +1331,7 @@ setMethod("countDistinct",
x@jc
})
jc <- callJStatic("org.apache.spark.sql.functions", "countDistinct", x@jc,
listToSeq(jcol))
jcol)
column(jc)
})

Expand All @@ -1348,7 +1348,7 @@ setMethod("concat",
signature(x = "Column"),
function(x, ...) {
jcols <- lapply(list(x, ...), function(x) { x@jc })
jc <- callJStatic("org.apache.spark.sql.functions", "concat", listToSeq(jcols))
jc <- callJStatic("org.apache.spark.sql.functions", "concat", jcols)
column(jc)
})

Expand All @@ -1366,7 +1366,7 @@ setMethod("greatest",
function(x, ...) {
stopifnot(length(list(...)) > 0)
jcols <- lapply(list(x, ...), function(x) { x@jc })
jc <- callJStatic("org.apache.spark.sql.functions", "greatest", listToSeq(jcols))
jc <- callJStatic("org.apache.spark.sql.functions", "greatest", jcols)
column(jc)
})

Expand All @@ -1384,7 +1384,7 @@ setMethod("least",
function(x, ...) {
stopifnot(length(list(...)) > 0)
jcols <- lapply(list(x, ...), function(x) { x@jc })
jc <- callJStatic("org.apache.spark.sql.functions", "least", listToSeq(jcols))
jc <- callJStatic("org.apache.spark.sql.functions", "least", jcols)
column(jc)
})

Expand Down Expand Up @@ -1675,7 +1675,7 @@ setMethod("shiftRightUnsigned", signature(y = "Column", x = "numeric"),
#' @export
setMethod("concat_ws", signature(sep = "character", x = "Column"),
function(sep, x, ...) {
jcols <- listToSeq(lapply(list(x, ...), function(x) { x@jc }))
jcols <- lapply(list(x, ...), function(x) { x@jc })
jc <- callJStatic("org.apache.spark.sql.functions", "concat_ws", sep, jcols)
column(jc)
})
Expand Down Expand Up @@ -1723,7 +1723,7 @@ setMethod("expr", signature(x = "character"),
#' @export
setMethod("format_string", signature(format = "character", x = "Column"),
function(format, x, ...) {
jcols <- listToSeq(lapply(list(x, ...), function(arg) { arg@jc }))
jcols <- lapply(list(x, ...), function(arg) { arg@jc })
jc <- callJStatic("org.apache.spark.sql.functions",
"format_string",
format, jcols)
Expand Down
4 changes: 2 additions & 2 deletions R/pkg/R/group.R
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ setMethod("agg",
}
}
jcols <- lapply(cols, function(c) { c@jc })
sdf <- callJMethod(x@sgd, "agg", jcols[[1]], listToSeq(jcols[-1]))
sdf <- callJMethod(x@sgd, "agg", jcols[[1]], jcols[-1])
} else {
stop("agg can only support Column or character")
}
Expand All @@ -124,7 +124,7 @@ createMethod <- function(name) {
setMethod(name,
signature(x = "GroupedData"),
function(x, ...) {
sdf <- callJMethod(x@sgd, name, toSeq(...))
sdf <- callJMethod(x@sgd, name, list(...))
dataFrame(sdf)
})
}
Expand Down
2 changes: 1 addition & 1 deletion R/pkg/R/schema.R
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ structType.structField <- function(x, ...) {
})
stObj <- callJStatic("org.apache.spark.sql.api.r.SQLUtils",
"createStructType",
listToSeq(sfObjList))
sfObjList)
structType(stObj)
}

Expand Down
10 changes: 0 additions & 10 deletions R/pkg/R/utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -360,16 +360,6 @@ numToInt <- function(num) {
as.integer(num)
}

# create a Seq in JVM
toSeq <- function(...) {
callJStatic("org.apache.spark.sql.api.r.SQLUtils", "toSeq", list(...))
}

# create a Seq in JVM from a list
listToSeq <- function(l) {
callJStatic("org.apache.spark.sql.api.r.SQLUtils", "toSeq", l)
}

# Utility function to recursively traverse the Abstract Syntax Tree (AST) of a
# user defined function (UDF), and to examine variables in the UDF to decide
# if their values should be included in the new function environment.
Expand Down
96 changes: 63 additions & 33 deletions core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala
Original file line number Diff line number Diff line change
Expand Up @@ -125,29 +125,41 @@ private[r] class RBackendHandler(server: RBackend)
val methods = cls.getMethods
val selectedMethods = methods.filter(m => m.getName == methodName)
if (selectedMethods.length > 0) {
val methods = selectedMethods.filter { x =>
matchMethod(numArgs, args, x.getParameterTypes)
}
if (methods.isEmpty) {
val (index, convertedArgs) = matchMethod(
selectedMethods.map(_.getParameterTypes),
args)

if (index.isEmpty) {
logWarning(s"cannot find matching method ${cls}.$methodName. "
+ s"Candidates are:")
selectedMethods.foreach { method =>
logWarning(s"$methodName(${method.getParameterTypes.mkString(",")})")
}
throw new Exception(s"No matched method found for $cls.$methodName")
}
val ret = methods.head.invoke(obj, args : _*)

val ret = selectedMethods(index.get).invoke(obj, convertedArgs : _*)

// Write status bit
writeInt(dos, 0)
writeObject(dos, ret.asInstanceOf[AnyRef])
} else if (methodName == "<init>") {
// methodName should be "<init>" for constructor
val ctor = cls.getConstructors.filter { x =>
matchMethod(numArgs, args, x.getParameterTypes)
}.head
val ctors = cls.getConstructors
val (index, convertedArgs) = matchMethod(
ctors.map(_.getParameterTypes),
args)

if (index.isEmpty) {
logWarning(s"cannot find matching constructor for ${cls}. "
+ s"Candidates are:")
ctors.foreach { ctor =>
logWarning(s"$cls(${ctor.getParameterTypes.mkString(",")})")
}
throw new Exception(s"No matched constructor found for $cls")
}

val obj = ctor.newInstance(args : _*)
val obj = ctors(index.get).newInstance(convertedArgs : _*)

writeInt(dos, 0)
writeObject(dos, obj.asInstanceOf[AnyRef])
Expand All @@ -171,35 +183,53 @@ private[r] class RBackendHandler(server: RBackend)
}.toArray
}

// Checks if the arguments passed in args matches the parameter types.
// NOTE: Currently we do exact match. We may add type conversions later.
// Find a matching method in all methods of the same name of a class
// according to the passed arguments.
def matchMethod(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you also document what the returned objects are for this method ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sure. will add.

numArgs: Int,
args: Array[java.lang.Object],
parameterTypes: Array[Class[_]]): Boolean = {
if (parameterTypes.length != numArgs) {
return false
}
parameterTypesOfMethods: Array[Array[Class[_]]],
args: Array[Object]): (Option[Int], Array[Object]) = {
val numArgs = args.length

for (index <- 0 to parameterTypesOfMethods.length - 1) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

0 until xxx

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed

val parameterTypes = parameterTypesOfMethods(index)

if (parameterTypes.length == numArgs) {
val convertedArgs = new Array[Object](numArgs)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there any reason we shouldn't modify args in place and make a copy ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Arg is checked and converted one by one when matching a method. It is possible that after some args were converted, an arg fails, then all args have to be reverted to the original ones for matching next method.

It could be possible for now that checking all args for one method, and only after the checking passes, then do conversion, thus we can modify args in place.

I will refine the code.

Array.copy(args, 0, convertedArgs, 0, numArgs)

var argMatched = true
var i = 0
while (i < numArgs && argMatched) {
val parameterType = parameterTypes(i)
var parameterWrapperType = parameterType

// Convert native parameters to Object types as args is Array[Object] here
if (parameterType.isPrimitive) {
parameterWrapperType = parameterType match {
case java.lang.Integer.TYPE => classOf[java.lang.Integer]
case java.lang.Long.TYPE => classOf[java.lang.Integer]
case java.lang.Double.TYPE => classOf[java.lang.Double]
case java.lang.Boolean.TYPE => classOf[java.lang.Boolean]
case _ => parameterType
}
} else if (parameterType == classOf[Seq[Any]] && args(i).getClass.isArray) {
// Convert a Java array to scala Seq
convertedArgs(i) = args(i).asInstanceOf[Array[_]].toSeq
}
if (!parameterWrapperType.isInstance(convertedArgs(i))) {
argMatched = false
}
i = i + 1
}

for (i <- 0 to numArgs - 1) {
val parameterType = parameterTypes(i)
var parameterWrapperType = parameterType

// Convert native parameters to Object types as args is Array[Object] here
if (parameterType.isPrimitive) {
parameterWrapperType = parameterType match {
case java.lang.Integer.TYPE => classOf[java.lang.Integer]
case java.lang.Long.TYPE => classOf[java.lang.Integer]
case java.lang.Double.TYPE => classOf[java.lang.Double]
case java.lang.Boolean.TYPE => classOf[java.lang.Boolean]
case _ => parameterType
if (argMatched) {
// For now, we return the first matching method.
// TODO: find best method in matching methods.
return (Some(index), convertedArgs)
}
}
if (!parameterWrapperType.isInstance(args(i))) {
return false
}
}
true
(None, args)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,6 @@ private[r] object SQLUtils {
new JavaSparkContext(sqlCtx.sparkContext)
}

def toSeq[T](arr: Array[T]): Seq[T] = {
arr.toSeq
}

def createStructType(fields : Seq[StructField]): StructType = {
StructType(fields)
}
Expand Down