[SPARKR] Match pyspark features in SparkR communication protocol.
(cherry picked from commit 628c7b5)
Signed-off-by: Marcelo Vanzin <[email protected]>
(cherry picked from commit 16cd9ac)
Signed-off-by: Marcelo Vanzin <[email protected]>
Marcelo Vanzin committed May 10, 2018
commit f96d13dd67b2d39e5fff80a6bc4a1b1fc36745c6
4 changes: 2 additions & 2 deletions R/pkg/R/client.R
@@ -19,7 +19,7 @@
 
 # Creates a SparkR client connection object
 # if one doesn't already exist
-connectBackend <- function(hostname, port, timeout) {
+connectBackend <- function(hostname, port, timeout, authSecret) {
   if (exists(".sparkRcon", envir = .sparkREnv)) {
     if (isOpen(.sparkREnv[[".sparkRCon"]])) {
       cat("SparkRBackend client connection already exists\n")
@@ -29,7 +29,7 @@ connectBackend <- function(hostname, port, timeout) {
 
   con <- socketConnection(host = hostname, port = port, server = FALSE,
                           blocking = TRUE, open = "wb", timeout = timeout)
-
+  doServerAuth(con, authSecret)
   assign(".sparkRCon", con, envir = .sparkREnv)
   con
 }
10 changes: 7 additions & 3 deletions R/pkg/R/deserialize.R
@@ -60,14 +60,18 @@ readTypedObject <- function(con, type) {
          stop(paste("Unsupported type for deserialization", type)))
 }
 
-readString <- function(con) {
-  stringLen <- readInt(con)
-  raw <- readBin(con, raw(), stringLen, endian = "big")
+readStringData <- function(con, len) {
+  raw <- readBin(con, raw(), len, endian = "big")
   string <- rawToChar(raw)
   Encoding(string) <- "UTF-8"
   string
 }
 
+readString <- function(con) {
+  stringLen <- readInt(con)
+  readStringData(con, stringLen)
+}
+
 readInt <- function(con) {
   readBin(con, integer(), n = 1, endian = "big")
 }
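For context, these helpers assume the JVM writes every string as a 4-byte big-endian length followed by the raw UTF-8 payload; splitting readString() into readInt() plus readStringData() lets callers read the length on its own. A minimal JVM-side sketch of that framing, assuming SerDe-style strings whose length counts a trailing '\0' (helper names here are illustrative, not Spark's):

import java.io.{DataInputStream, DataOutputStream}

object StringFramingSketch {
  // Write: 4-byte big-endian length, then UTF-8 bytes, then a '\0'
  // (assumed to be counted in the length, matching the R reader).
  def writeFramed(out: DataOutputStream, value: String): Unit = {
    val utf8 = value.getBytes("UTF-8")
    out.writeInt(utf8.length + 1)
    out.write(utf8)
    out.writeByte(0)
  }

  // Read: the mirror of readStringData(con, len) above.
  def readFramed(in: DataInputStream): String = {
    val len = in.readInt()
    val bytes = new Array[Byte](len)
    in.readFully(bytes)
    new String(bytes, 0, len - 1, "UTF-8") // drop the terminator
  }
}

sparkR.R below relies on exactly this split to fail with a clear error when the length itself cannot be read.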
39 changes: 34 additions & 5 deletions R/pkg/R/sparkR.R
@@ -161,6 +161,10 @@ sparkR.sparkContext <- function(
                  " please use the --packages commandline instead", sep = ","))
     }
     backendPort <- existingPort
+    authSecret <- Sys.getenv("SPARKR_BACKEND_AUTH_SECRET")
+    if (nchar(authSecret) == 0) {
+      stop("Auth secret not provided in environment.")
+    }
   } else {
     path <- tempfile(pattern = "backend_port")
     submitOps <- getClientModeSparkSubmitOpts(
@@ -189,16 +193,27 @@ sparkR.sparkContext <- function(
     monitorPort <- readInt(f)
     rLibPath <- readString(f)
     connectionTimeout <- readInt(f)
+
+    # Don't use readString() so that we can provide a useful
+    # error message if the R and Java versions are mismatched.
+    authSecretLen = readInt(f)
+    if (length(authSecretLen) == 0 || authSecretLen == 0) {
+      stop("Unexpected EOF in JVM connection data. Mismatched versions?")
+    }
+    authSecret <- readStringData(f, authSecretLen)
     close(f)
     file.remove(path)
     if (length(backendPort) == 0 || backendPort == 0 ||
         length(monitorPort) == 0 || monitorPort == 0 ||
-        length(rLibPath) != 1) {
+        length(rLibPath) != 1 || length(authSecret) == 0) {
       stop("JVM failed to launch")
     }
-    assign(".monitorConn",
-           socketConnection(port = monitorPort, timeout = connectionTimeout),
-           envir = .sparkREnv)
+
+    monitorConn <- socketConnection(port = monitorPort, blocking = TRUE,
+                                    timeout = connectionTimeout, open = "wb")
+    doServerAuth(monitorConn, authSecret)
+
+    assign(".monitorConn", monitorConn, envir = .sparkREnv)
     assign(".backendLaunched", 1, envir = .sparkREnv)
     if (rLibPath != "") {
       assign(".libPath", rLibPath, envir = .sparkREnv)
@@ -208,7 +223,7 @@ sparkR.sparkContext <- function(
 
   .sparkREnv$backendPort <- backendPort
   tryCatch({
-    connectBackend("localhost", backendPort, timeout = connectionTimeout)
+    connectBackend("localhost", backendPort, timeout = connectionTimeout, authSecret = authSecret)
   },
   error = function(err) {
     stop("Failed to connect JVM\n")
@@ -632,3 +647,17 @@ sparkCheckInstall <- function(sparkHome, master, deployMode) {
     NULL
   }
 }
+
+# Utility function for sending auth data over a socket and checking the server's reply.
+doServerAuth <- function(con, authSecret) {
+  if (nchar(authSecret) == 0) {
+    stop("Auth secret not provided.")
+  }
+  writeString(con, authSecret)
+  flush(con)
+  reply <- readString(con)
+  if (reply != "ok") {
+    close(con)
+    stop("Unexpected reply from server.")
+  }
+}
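doServerAuth() is the client half of the handshake: write the secret as one length-prefixed string, read one string back, and require the literal reply "ok". A plain-socket sketch of the matching server half (the real backend does this inside a Netty handler, RBackendAuthHandler, later in this diff; the object name and framing details here are illustrative assumptions):

import java.io.{DataInputStream, DataOutputStream}
import java.net.ServerSocket

object AuthServerSketch {
  // Accept one connection and run the check doServerAuth() expects:
  // reply "ok" when the secrets match, "err" (and close) otherwise.
  def serveOneAuth(server: ServerSocket, secret: String): Unit = {
    val sock = server.accept()
    val in = new DataInputStream(sock.getInputStream())
    val out = new DataOutputStream(sock.getOutputStream())

    val len = in.readInt()                   // length written by R's writeString()
    val buf = new Array[Byte](len)
    in.readFully(buf)
    val clientSecret = new String(buf, 0, len - 1, "UTF-8") // drop the trailing '\0'

    val reply = if (clientSecret == secret) "ok" else "err"
    val replyUtf8 = reply.getBytes("UTF-8")
    out.writeInt(replyUtf8.length + 1)       // mirror the same framing back
    out.write(replyUtf8)
    out.writeByte(0)
    out.flush()
    if (reply != "ok") sock.close()
  }
}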
4 changes: 3 additions & 1 deletion R/pkg/inst/worker/daemon.R
@@ -28,7 +28,9 @@ suppressPackageStartupMessages(library(SparkR))
 
 port <- as.integer(Sys.getenv("SPARKR_WORKER_PORT"))
 inputCon <- socketConnection(
-    port = port, open = "rb", blocking = TRUE, timeout = connectionTimeout)
+    port = port, open = "wb", blocking = TRUE, timeout = connectionTimeout)
+
+SparkR:::doServerAuth(inputCon, Sys.getenv("SPARKR_WORKER_SECRET"))
 
 while (TRUE) {
   ready <- socketSelect(list(inputCon))
5 changes: 4 additions & 1 deletion R/pkg/inst/worker/worker.R
@@ -100,9 +100,12 @@ suppressPackageStartupMessages(library(SparkR))
 
 port <- as.integer(Sys.getenv("SPARKR_WORKER_PORT"))
 inputCon <- socketConnection(
-    port = port, blocking = TRUE, open = "rb", timeout = connectionTimeout)
+    port = port, blocking = TRUE, open = "wb", timeout = connectionTimeout)
+SparkR:::doServerAuth(inputCon, Sys.getenv("SPARKR_WORKER_SECRET"))
+
 outputCon <- socketConnection(
   port = port, blocking = TRUE, open = "wb", timeout = connectionTimeout)
+SparkR:::doServerAuth(outputCon, Sys.getenv("SPARKR_WORKER_SECRET"))
 
 # read the index of the current partition inside the RDD
 partition <- SparkR:::readInt(inputCon)
38 changes: 38 additions & 0 deletions core/src/main/scala/org/apache/spark/api/r/RAuthHelper.scala
@@ -0,0 +1,38 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.api.r

import java.io.{DataInputStream, DataOutputStream}
import java.net.Socket

import org.apache.spark.SparkConf
import org.apache.spark.security.SocketAuthHelper

private[spark] class RAuthHelper(conf: SparkConf) extends SocketAuthHelper(conf) {

  override protected def readUtf8(s: Socket): String = {
    SerDe.readString(new DataInputStream(s.getInputStream()))
  }

  override protected def writeUtf8(str: String, s: Socket): Unit = {
    val out = s.getOutputStream()
    SerDe.writeString(new DataOutputStream(out), str)
    out.flush()
  }

}
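RAuthHelper only overrides how strings cross the socket; the secret generation and comparison live in the parent SocketAuthHelper. The sketch below is our reading of the contract these two hooks plug into (it matches the "ok"/"err" replies doServerAuth checks for), not the actual Spark implementation:

import java.net.Socket

// Sketch of the SocketAuthHelper contract as used by RBackend.main below.
abstract class SocketAuthHelperSketch(val secret: String) {
  protected def readUtf8(s: Socket): String
  protected def writeUtf8(str: String, s: Socket): Unit

  // Server side: read the client's secret and answer "ok" or "err".
  def authClient(s: Socket): Unit = {
    val clientSecret = readUtf8(s)
    if (clientSecret == secret) {
      writeUtf8("ok", s)
    } else {
      writeUtf8("err", s)
      s.close()
      throw new IllegalArgumentException("Auth secret mismatch.")
    }
  }
}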
43 changes: 37 additions & 6 deletions core/src/main/scala/org/apache/spark/api/r/RBackend.scala
@@ -17,8 +17,8 @@
 
 package org.apache.spark.api.r
 
-import java.io.{DataOutputStream, File, FileOutputStream, IOException}
-import java.net.{InetAddress, InetSocketAddress, ServerSocket}
+import java.io.{DataInputStream, DataOutputStream, File, FileOutputStream, IOException}
+import java.net.{InetAddress, InetSocketAddress, ServerSocket, Socket}
 import java.util.concurrent.TimeUnit
 
 import io.netty.bootstrap.ServerBootstrap
@@ -32,6 +32,8 @@ import io.netty.handler.timeout.ReadTimeoutHandler
 
 import org.apache.spark.SparkConf
 import org.apache.spark.internal.Logging
+import org.apache.spark.network.util.JavaUtils
+import org.apache.spark.util.Utils
 
 /**
  * Netty-based backend server that is used to communicate between R and Java.
@@ -45,14 +47,15 @@ private[spark] class RBackend {
   /** Tracks JVM objects returned to R for this RBackend instance. */
   private[r] val jvmObjectTracker = new JVMObjectTracker
 
-  def init(): Int = {
+  def init(): (Int, RAuthHelper) = {
     val conf = new SparkConf()
     val backendConnectionTimeout = conf.getInt(
       "spark.r.backendConnectionTimeout", SparkRDefaults.DEFAULT_CONNECTION_TIMEOUT)
     bossGroup = new NioEventLoopGroup(
       conf.getInt("spark.r.numRBackendThreads", SparkRDefaults.DEFAULT_NUM_RBACKEND_THREADS))
     val workerGroup = bossGroup
     val handler = new RBackendHandler(this)
+    val authHelper = new RAuthHelper(conf)
 
     bootstrap = new ServerBootstrap()
       .group(bossGroup, workerGroup)
@@ -71,13 +74,16 @@ private[spark] class RBackend {
             new LengthFieldBasedFrameDecoder(Integer.MAX_VALUE, 0, 4, 0, 4))
           .addLast("decoder", new ByteArrayDecoder())
           .addLast("readTimeoutHandler", new ReadTimeoutHandler(backendConnectionTimeout))
+          .addLast(new RBackendAuthHandler(authHelper.secret))
           .addLast("handler", handler)
       }
     })
 
     channelFuture = bootstrap.bind(new InetSocketAddress("localhost", 0))
     channelFuture.syncUninterruptibly()
-    channelFuture.channel().localAddress().asInstanceOf[InetSocketAddress].getPort()
+
+    val port = channelFuture.channel().localAddress().asInstanceOf[InetSocketAddress].getPort()
+    (port, authHelper)
   }
 
   def run(): Unit = {
@@ -116,7 +122,7 @@ private[spark] object RBackend extends Logging {
     val sparkRBackend = new RBackend()
     try {
       // bind to random port
-      val boundPort = sparkRBackend.init()
+      val (boundPort, authHelper) = sparkRBackend.init()
      val serverSocket = new ServerSocket(0, 1, InetAddress.getByName("localhost"))
      val listenPort = serverSocket.getLocalPort()
      // Connection timeout is set by socket client. To make it configurable we will pass the
@@ -133,6 +139,7 @@ private[spark] object RBackend extends Logging {
       dos.writeInt(listenPort)
       SerDe.writeString(dos, RUtils.rPackages.getOrElse(""))
       dos.writeInt(backendConnectionTimeout)
+      SerDe.writeString(dos, authHelper.secret)
       dos.close()
       f.renameTo(new File(path))
 
@@ -144,12 +151,35 @@ private[spark] object RBackend extends Logging {
       val buf = new Array[Byte](1024)
       // shutdown JVM if R does not connect back in 10 seconds
       serverSocket.setSoTimeout(10000)
+
+      // Wait for the R process to connect back, ignoring any failed auth attempts. Allow
+      // a max number of connection attempts to avoid looping forever.
       try {
-        val inSocket = serverSocket.accept()
+        var remainingAttempts = 10
+        var inSocket: Socket = null
+        while (inSocket == null) {
+          inSocket = serverSocket.accept()
+          try {
+            authHelper.authClient(inSocket)
+          } catch {
+            case e: Exception =>
+              remainingAttempts -= 1
+              if (remainingAttempts == 0) {
+                val msg = "Too many failed authentication attempts."
+                logError(msg)
+                throw new IllegalStateException(msg)
+              }
+              logInfo("Client connection failed authentication.")
+              inSocket = null
+          }
+        }
+
         serverSocket.close()
+
         // wait for the end of socket, closed if R process die
         inSocket.getInputStream().read(buf)
       } finally {
         serverSocket.close()
         sparkRBackend.close()
         System.exit(0)
       }
@@ -165,4 +195,5 @@ private[spark] object RBackend extends Logging {
     }
     System.exit(0)
   }
+
 }
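For reference, the connection-info file that RBackend.main writes now carries five fields, and the reads in sparkR.R earlier in this diff consume them in the same order:

// Write side (RBackend.main)                 Read side (sparkR.sparkContext)
// dos.writeInt(boundPort)                ->  backendPort <- readInt(f)
// dos.writeInt(listenPort)               ->  monitorPort <- readInt(f)
// SerDe.writeString(dos, rPackages)      ->  rLibPath <- readString(f)
// dos.writeInt(backendConnectionTimeout) ->  connectionTimeout <- readInt(f)
// SerDe.writeString(dos, authHelper.secret)
//                                        ->  authSecretLen = readInt(f)
//                                            authSecret <- readStringData(f, authSecretLen)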
55 changes: 55 additions & 0 deletions core/src/main/scala/org/apache/spark/api/r/RBackendAuthHandler.scala
@@ -0,0 +1,55 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.api.r

import java.io.{ByteArrayOutputStream, DataOutputStream}
import java.nio.charset.StandardCharsets.UTF_8

import io.netty.channel.{Channel, ChannelHandlerContext, SimpleChannelInboundHandler}

import org.apache.spark.internal.Logging
import org.apache.spark.util.Utils

/**
 * Authentication handler for connections from the R process.
 */
private class RBackendAuthHandler(secret: String)
  extends SimpleChannelInboundHandler[Array[Byte]] with Logging {

  override def channelRead0(ctx: ChannelHandlerContext, msg: Array[Byte]): Unit = {
    // The R code adds a null terminator to serialized strings, so ignore it here.
    val clientSecret = new String(msg, 0, msg.length - 1, UTF_8)
    try {
      require(secret == clientSecret, "Auth secret mismatch.")
      ctx.pipeline().remove(this)
      writeReply("ok", ctx.channel())
    } catch {
      case e: Exception =>
        logInfo("Authentication failure.", e)
        writeReply("err", ctx.channel())
        ctx.close()
    }
  }

  private def writeReply(reply: String, chan: Channel): Unit = {
    val out = new ByteArrayOutputStream()
    SerDe.writeString(new DataOutputStream(out), reply)
    chan.writeAndFlush(out.toByteArray())
  }

}
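The msg.length - 1 above follows from the pipeline built in RBackend.init(): LengthFieldBasedFrameDecoder strips the 4-byte length prefix before this handler runs, so msg is only the payload, still ending with the '\0' that R's writeString() appends. An illustrative trace for a secret of "abc", assuming the length field counts the terminator:

// On the wire (R -> JVM) for writeString(con, "abc"):
//   00 00 00 04 | 61 62 63 00      // length 4 = "abc" plus '\0'
// After LengthFieldBasedFrameDecoder(Integer.MAX_VALUE, 0, 4, 0, 4)
// (initialBytesToStrip = 4, so the length field is removed):
//   msg == Array[Byte](0x61, 0x62, 0x63, 0x00)
// Hence:
//   new String(msg, 0, msg.length - 1, UTF_8) == "abc"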