Skip to content

Commit 416e71a

Browse files
Oscar D. Lara Yejasshivaram
authored andcommitted
[SPARK-13327][SPARKR] Added parameter validations for colnames<-
Author: Oscar D. Lara Yejas <[email protected]> Author: Oscar D. Lara Yejas <[email protected]> Closes #11220 from olarayej/SPARK-13312-3.
1 parent 88fa866 commit 416e71a

File tree

2 files changed

+32
-1
lines changed

2 files changed

+32
-1
lines changed

R/pkg/R/DataFrame.R

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -303,8 +303,28 @@ setMethod("colnames",
303303
#' @rdname columns
304304
#' @name colnames<-
305305
setMethod("colnames<-",
306-
signature(x = "DataFrame", value = "character"),
306+
signature(x = "DataFrame"),
307307
function(x, value) {
308+
309+
# Check parameter integrity
310+
if (class(value) != "character") {
311+
stop("Invalid column names.")
312+
}
313+
314+
if (length(value) != ncol(x)) {
315+
stop(
316+
"Column names must have the same length as the number of columns in the dataset.")
317+
}
318+
319+
if (any(is.na(value))) {
320+
stop("Column names cannot be NA.")
321+
}
322+
323+
# Check if the column names have . in it
324+
if (any(regexec(".", value, fixed=TRUE)[[1]][1] != -1)) {
325+
stop("Colum names cannot contain the '.' symbol.")
326+
}
327+
308328
sdf <- callJMethod(x@sdf, "toDF", as.list(value))
309329
dataFrame(sdf)
310330
})

R/pkg/inst/tests/testthat/test_sparkSQL.R

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -691,6 +691,17 @@ test_that("names() colnames() set the column names", {
691691
colnames(df) <- c("col3", "col4")
692692
expect_equal(names(df)[1], "col3")
693693

694+
expect_error(colnames(df) <- c("sepal.length", "sepal_width"),
695+
"Colum names cannot contain the '.' symbol.")
696+
expect_error(colnames(df) <- c(1, 2), "Invalid column names.")
697+
expect_error(colnames(df) <- c("a"),
698+
"Column names must have the same length as the number of columns in the dataset.")
699+
expect_error(colnames(df) <- c("1", NA), "Column names cannot be NA.")
700+
701+
# Note: if this test is broken, remove check for "." character on colnames<- method
702+
irisDF <- suppressWarnings(createDataFrame(sqlContext, iris))
703+
expect_equal(names(irisDF)[1], "Sepal_Length")
704+
694705
# Test base::colnames base::names
695706
m2 <- cbind(1, 1:4)
696707
expect_equal(colnames(m2, do.NULL = FALSE), c("col1", "col2"))

0 commit comments

Comments
 (0)