Commit 628c7b5
Author: Marcelo Vanzin
Parent: cc613b5

    [SPARKR] Match pyspark features in SparkR communication protocol.

File tree: 10 files changed, +210 −29 lines
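
Every socket between the R process and the JVM now performs a shared-secret handshake before any other traffic, matching what PySpark already does: the connecting side writes the secret as a length-prefixed string and waits for the literal reply "ok" (see doServerAuth() in sparkR.R below), while the JVM side verifies the secret via the new RAuthHelper and RBackendAuthHandler classes. The following is a minimal standalone sketch of the client half of that handshake, not Spark code; writeRString/readRString are hypothetical helpers that mirror SerDe's string framing (a 4-byte big-endian length covering the payload, the UTF-8 bytes, and a trailing null byte).

import java.io.{DataInputStream, DataOutputStream}
import java.net.Socket
import java.nio.charset.StandardCharsets.UTF_8

object HandshakeSketch {
  // Hypothetical helper mirroring SerDe.writeString / SparkR's writeString():
  // the length field counts the UTF-8 bytes plus the null terminator.
  def writeRString(out: DataOutputStream, s: String): Unit = {
    val bytes = s.getBytes(UTF_8)
    out.writeInt(bytes.length + 1)
    out.write(bytes)
    out.writeByte(0)
  }

  // Hypothetical counterpart of SerDe.readString: read the frame, drop the null.
  def readRString(in: DataInputStream): String = {
    val len = in.readInt()
    val buf = new Array[Byte](len)
    in.readFully(buf)
    new String(buf, 0, len - 1, UTF_8)
  }

  // Client half of the handshake, equivalent to doServerAuth() in sparkR.R:
  // send the secret, then require the literal reply "ok".
  def authToServer(sock: Socket, secret: String): Unit = {
    val out = new DataOutputStream(sock.getOutputStream())
    writeRString(out, secret)
    out.flush()
    val reply = readRString(new DataInputStream(sock.getInputStream()))
    if (reply != "ok") {
      sock.close()
      throw new IllegalStateException("Unexpected reply from server.")
    }
  }
}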

R/pkg/R/client.R (2 additions & 2 deletions)

@@ -19,7 +19,7 @@
 
 # Creates a SparkR client connection object
 # if one doesn't already exist
-connectBackend <- function(hostname, port, timeout) {
+connectBackend <- function(hostname, port, timeout, authSecret) {
   if (exists(".sparkRcon", envir = .sparkREnv)) {
     if (isOpen(.sparkREnv[[".sparkRCon"]])) {
       cat("SparkRBackend client connection already exists\n")
@@ -29,7 +29,7 @@ connectBackend <- function(hostname, port, timeout) {
 
   con <- socketConnection(host = hostname, port = port, server = FALSE,
                           blocking = TRUE, open = "wb", timeout = timeout)
-
+  doServerAuth(con, authSecret)
   assign(".sparkRCon", con, envir = .sparkREnv)
   con
 }
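
With this change the backend connection is authenticated as soon as it is opened: doServerAuth() (added to sparkR.R in this commit) stops with an error on any reply other than "ok", so a connection that fails the handshake is never cached as .sparkRCon.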

R/pkg/R/deserialize.R (7 additions & 3 deletions)

@@ -60,14 +60,18 @@ readTypedObject <- function(con, type) {
     stop(paste("Unsupported type for deserialization", type)))
 }
 
-readString <- function(con) {
-  stringLen <- readInt(con)
-  raw <- readBin(con, raw(), stringLen, endian = "big")
+readStringData <- function(con, len) {
+  raw <- readBin(con, raw(), len, endian = "big")
   string <- rawToChar(raw)
   Encoding(string) <- "UTF-8"
   string
 }
 
+readString <- function(con) {
+  stringLen <- readInt(con)
+  readStringData(con, stringLen)
+}
+
 readInt <- function(con) {
   readBin(con, integer(), n = 1, endian = "big")
 }

R/pkg/R/sparkR.R (34 additions & 5 deletions)

@@ -158,6 +158,10 @@ sparkR.sparkContext <- function(
                   " please use the --packages commandline instead", sep = ","))
     }
     backendPort <- existingPort
+    authSecret <- Sys.getenv("SPARKR_BACKEND_AUTH_SECRET")
+    if (nchar(authSecret) == 0) {
+      stop("Auth secret not provided in environment.")
+    }
   } else {
     path <- tempfile(pattern = "backend_port")
     submitOps <- getClientModeSparkSubmitOpts(
@@ -186,16 +190,27 @@ sparkR.sparkContext <- function(
     monitorPort <- readInt(f)
     rLibPath <- readString(f)
     connectionTimeout <- readInt(f)
+
+    # Don't use readString() so that we can provide a useful
+    # error message if the R and Java versions are mismatched.
+    authSecretLen = readInt(f)
+    if (length(authSecretLen) == 0 || authSecretLen == 0) {
+      stop("Unexpected EOF in JVM connection data. Mismatched versions?")
+    }
+    authSecret <- readStringData(f, authSecretLen)
     close(f)
     file.remove(path)
     if (length(backendPort) == 0 || backendPort == 0 ||
         length(monitorPort) == 0 || monitorPort == 0 ||
-        length(rLibPath) != 1) {
+        length(rLibPath) != 1 || length(authSecret) == 0) {
       stop("JVM failed to launch")
     }
-    assign(".monitorConn",
-           socketConnection(port = monitorPort, timeout = connectionTimeout),
-           envir = .sparkREnv)
+
+    monitorConn <- socketConnection(port = monitorPort, blocking = TRUE,
+                                    timeout = connectionTimeout, open = "wb")
+    doServerAuth(monitorConn, authSecret)
+
+    assign(".monitorConn", monitorConn, envir = .sparkREnv)
     assign(".backendLaunched", 1, envir = .sparkREnv)
     if (rLibPath != "") {
       assign(".libPath", rLibPath, envir = .sparkREnv)
@@ -205,7 +220,7 @@ sparkR.sparkContext <- function(
 
   .sparkREnv$backendPort <- backendPort
   tryCatch({
-    connectBackend("localhost", backendPort, timeout = connectionTimeout)
+    connectBackend("localhost", backendPort, timeout = connectionTimeout, authSecret = authSecret)
   },
   error = function(err) {
     stop("Failed to connect JVM\n")
@@ -687,3 +702,17 @@ sparkCheckInstall <- function(sparkHome, master, deployMode) {
     NULL
   }
 }
+
+# Utility function for sending auth data over a socket and checking the server's reply.
+doServerAuth <- function(con, authSecret) {
+  if (nchar(authSecret) == 0) {
+    stop("Auth secret not provided.")
+  }
+  writeString(con, authSecret)
+  flush(con)
+  reply <- readString(con)
+  if (reply != "ok") {
+    close(con)
+    stop("Unexpected reply from server.")
+  }
+}
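
The reader above consumes the connection-info file in exactly the order RBackend.main() writes it (see the RBackend.scala hunks further down). Note the two sources of the secret: when attaching to an existing backend it must come from the SPARKR_BACKEND_AUTH_SECRET environment variable, and when launching the backend it is the last field of the file. A sketch of the writer side for orientation, reusing the hypothetical writeRString helper from the HandshakeSketch above (Spark itself uses SerDe.writeString):

// A companion to HandshakeSketch — field order matches the reader in
// sparkR.sparkContext(): backendPort, monitorPort, rLibPath,
// connectionTimeout, then the auth secret.
def writeConnectionInfo(dos: java.io.DataOutputStream, backendPort: Int,
    monitorPort: Int, rLibPath: String, timeoutSecs: Int,
    secret: String): Unit = {
  dos.writeInt(backendPort)
  dos.writeInt(monitorPort)
  writeRString(dos, rLibPath)  // "" when no bundled R package path
  dos.writeInt(timeoutSecs)
  writeRString(dos, secret)    // R reads this length itself via readInt(), so a
                               // truncated file gives a clear mismatch error
}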

R/pkg/inst/worker/daemon.R (3 additions & 1 deletion)

@@ -28,7 +28,9 @@ suppressPackageStartupMessages(library(SparkR))
 
 port <- as.integer(Sys.getenv("SPARKR_WORKER_PORT"))
 inputCon <- socketConnection(
-    port = port, open = "rb", blocking = TRUE, timeout = connectionTimeout)
+    port = port, open = "wb", blocking = TRUE, timeout = connectionTimeout)
+
+SparkR:::doServerAuth(inputCon, Sys.getenv("SPARKR_WORKER_SECRET"))
 
 # Waits indefinitely for a socket connecion by default.
 selectTimeout <- NULL
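
Two things change here: the secret arrives out of band in the SPARKR_WORKER_SECRET environment variable, and the connection is opened "wb" instead of "rb" so the daemon can write the secret before its usual reads (the daemon continues to read from inputCon afterwards, so the mode change does not make the socket write-only in practice).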

R/pkg/inst/worker/worker.R (4 additions & 1 deletion)

@@ -100,9 +100,12 @@ suppressPackageStartupMessages(library(SparkR))
 
 port <- as.integer(Sys.getenv("SPARKR_WORKER_PORT"))
 inputCon <- socketConnection(
-    port = port, blocking = TRUE, open = "rb", timeout = connectionTimeout)
+    port = port, blocking = TRUE, open = "wb", timeout = connectionTimeout)
+SparkR:::doServerAuth(inputCon, Sys.getenv("SPARKR_WORKER_SECRET"))
+
 outputCon <- socketConnection(
     port = port, blocking = TRUE, open = "wb", timeout = connectionTimeout)
+SparkR:::doServerAuth(outputCon, Sys.getenv("SPARKR_WORKER_SECRET"))
 
 # read the index of the current partition inside the RDD
 partition <- SparkR:::readInt(inputCon)
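
The JVM-side counterpart of these worker handshakes lives in files not shown in this excerpt; presumably the worker launcher exports SPARKR_WORKER_SECRET and calls authClient() on each accepted connection. A rough sketch of that flow under those assumptions, using the RAuthHelper added by this commit (illustrative only, not Spark's actual worker-launch code):

package org.apache.spark.api.r  // RAuthHelper is private[spark]

import java.net.ServerSocket

import org.apache.spark.SparkConf

object WorkerAuthSketch {
  def main(args: Array[String]): Unit = {
    val server = new ServerSocket(0)
    val authHelper = new RAuthHelper(new SparkConf())

    // Launch the R worker with SPARKR_WORKER_PORT = server.getLocalPort
    // and SPARKR_WORKER_SECRET = authHelper.secret (launch code omitted), then:
    val sock = server.accept()
    authHelper.authClient(sock)  // reads the secret, replies "ok" or "err"
  }
}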

core/src/main/scala/org/apache/spark/api/r/RAuthHelper.scala (new file, 38 additions)

@@ -0,0 +1,38 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.api.r
+
+import java.io.{DataInputStream, DataOutputStream}
+import java.net.Socket
+
+import org.apache.spark.SparkConf
+import org.apache.spark.security.SocketAuthHelper
+
+private[spark] class RAuthHelper(conf: SparkConf) extends SocketAuthHelper(conf) {
+
+  override protected def readUtf8(s: Socket): String = {
+    SerDe.readString(new DataInputStream(s.getInputStream()))
+  }
+
+  override protected def writeUtf8(str: String, s: Socket): Unit = {
+    val out = s.getOutputStream()
+    SerDe.writeString(new DataOutputStream(out), str)
+    out.flush()
+  }
+
+}
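
The base class SocketAuthHelper is not part of this diff; it generates the secret (exposed as authHelper.secret in RBackend below) and performs the comparison, while the overrides above merely swap its default string encoding for SerDe's framing so the R side can keep using its existing writeString()/readString(). An inferred sketch of the base-class client-auth flow these overrides plug into, based on the "ok"/"err" replies visible in RBackendAuthHandler below — not Spark source:

// Inferred sketch of SocketAuthHelper.authClient, not part of this commit.
def authClient(s: Socket): Unit = {
  val clientSecret = readUtf8(s)  // via RAuthHelper: SerDe.readString(...)
  if (secret == clientSecret) {
    writeUtf8("ok", s)            // via RAuthHelper: SerDe.writeString(...)
  } else {
    writeUtf8("err", s)
    s.close()
    throw new IllegalArgumentException("Authentication failed.")
  }
}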

core/src/main/scala/org/apache/spark/api/r/RBackend.scala (37 additions & 6 deletions)

@@ -17,8 +17,8 @@
 
 package org.apache.spark.api.r
 
-import java.io.{DataOutputStream, File, FileOutputStream, IOException}
-import java.net.{InetAddress, InetSocketAddress, ServerSocket}
+import java.io.{DataInputStream, DataOutputStream, File, FileOutputStream, IOException}
+import java.net.{InetAddress, InetSocketAddress, ServerSocket, Socket}
 import java.util.concurrent.TimeUnit
 
 import io.netty.bootstrap.ServerBootstrap
@@ -32,6 +32,8 @@ import io.netty.handler.timeout.ReadTimeoutHandler
 
 import org.apache.spark.SparkConf
 import org.apache.spark.internal.Logging
+import org.apache.spark.network.util.JavaUtils
+import org.apache.spark.util.Utils
 
 /**
  * Netty-based backend server that is used to communicate between R and Java.
@@ -45,14 +47,15 @@ private[spark] class RBackend {
   /** Tracks JVM objects returned to R for this RBackend instance. */
   private[r] val jvmObjectTracker = new JVMObjectTracker
 
-  def init(): Int = {
+  def init(): (Int, RAuthHelper) = {
     val conf = new SparkConf()
     val backendConnectionTimeout = conf.getInt(
       "spark.r.backendConnectionTimeout", SparkRDefaults.DEFAULT_CONNECTION_TIMEOUT)
     bossGroup = new NioEventLoopGroup(
       conf.getInt("spark.r.numRBackendThreads", SparkRDefaults.DEFAULT_NUM_RBACKEND_THREADS))
     val workerGroup = bossGroup
     val handler = new RBackendHandler(this)
+    val authHelper = new RAuthHelper(conf)
 
     bootstrap = new ServerBootstrap()
       .group(bossGroup, workerGroup)
@@ -71,13 +74,16 @@
             new LengthFieldBasedFrameDecoder(Integer.MAX_VALUE, 0, 4, 0, 4))
           .addLast("decoder", new ByteArrayDecoder())
           .addLast("readTimeoutHandler", new ReadTimeoutHandler(backendConnectionTimeout))
+          .addLast(new RBackendAuthHandler(authHelper.secret))
           .addLast("handler", handler)
       }
     })
 
     channelFuture = bootstrap.bind(new InetSocketAddress("localhost", 0))
     channelFuture.syncUninterruptibly()
-    channelFuture.channel().localAddress().asInstanceOf[InetSocketAddress].getPort()
+
+    val port = channelFuture.channel().localAddress().asInstanceOf[InetSocketAddress].getPort()
+    (port, authHelper)
   }
 
   def run(): Unit = {
@@ -116,7 +122,7 @@ private[spark] object RBackend extends Logging {
     val sparkRBackend = new RBackend()
     try {
       // bind to random port
-      val boundPort = sparkRBackend.init()
+      val (boundPort, authHelper) = sparkRBackend.init()
       val serverSocket = new ServerSocket(0, 1, InetAddress.getByName("localhost"))
       val listenPort = serverSocket.getLocalPort()
       // Connection timeout is set by socket client. To make it configurable we will pass the
@@ -133,6 +139,7 @@
       dos.writeInt(listenPort)
       SerDe.writeString(dos, RUtils.rPackages.getOrElse(""))
       dos.writeInt(backendConnectionTimeout)
+      SerDe.writeString(dos, authHelper.secret)
       dos.close()
       f.renameTo(new File(path))
 
@@ -144,12 +151,35 @@
       val buf = new Array[Byte](1024)
       // shutdown JVM if R does not connect back in 10 seconds
       serverSocket.setSoTimeout(10000)
+
+      // Wait for the R process to connect back, ignoring any failed auth attempts. Allow
+      // a max number of connection attempts to avoid looping forever.
       try {
-        val inSocket = serverSocket.accept()
+        var remainingAttempts = 10
+        var inSocket: Socket = null
+        while (inSocket == null) {
+          inSocket = serverSocket.accept()
+          try {
+            authHelper.authClient(inSocket)
+          } catch {
+            case e: Exception =>
+              remainingAttempts -= 1
+              if (remainingAttempts == 0) {
+                val msg = "Too many failed authentication attempts."
+                logError(msg)
+                throw new IllegalStateException(msg)
+              }
+              logInfo("Client connection failed authentication.")
+              inSocket = null
+          }
+        }
+
         serverSocket.close()
+
         // wait for the end of socket, closed if R process die
         inSocket.getInputStream().read(buf)
       } finally {
+        serverSocket.close()
         sparkRBackend.close()
         System.exit(0)
       }
@@ -165,4 +195,5 @@
     }
     System.exit(0)
   }
+
 }
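
Design notes on the hunks above: the auth handler sits in the Netty pipeline ahead of the command handler, so no RPC reaches RBackendHandler before the secret is verified; the accept loop on the monitor socket is bounded at ten failed attempts so a misbehaving client cannot hold the JVM open indefinitely; and closing serverSocket both after a successful accept and in the finally block is harmless, since ServerSocket.close() is idempotent.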

core/src/main/scala/org/apache/spark/api/r/RBackendAuthHandler.scala (new file, 55 additions)

@@ -0,0 +1,55 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.api.r
+
+import java.io.{ByteArrayOutputStream, DataOutputStream}
+import java.nio.charset.StandardCharsets.UTF_8
+
+import io.netty.channel.{Channel, ChannelHandlerContext, SimpleChannelInboundHandler}
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.util.Utils
+
+/**
+ * Authentication handler for connections from the R process.
+ */
+private class RBackendAuthHandler(secret: String)
+  extends SimpleChannelInboundHandler[Array[Byte]] with Logging {
+
+  override def channelRead0(ctx: ChannelHandlerContext, msg: Array[Byte]): Unit = {
+    // The R code adds a null terminator to serialized strings, so ignore it here.
+    val clientSecret = new String(msg, 0, msg.length - 1, UTF_8)
+    try {
+      require(secret == clientSecret, "Auth secret mismatch.")
+      ctx.pipeline().remove(this)
+      writeReply("ok", ctx.channel())
+    } catch {
+      case e: Exception =>
+        logInfo("Authentication failure.", e)
+        writeReply("err", ctx.channel())
+        ctx.close()
+    }
+  }
+
+  private def writeReply(reply: String, chan: Channel): Unit = {
+    val out = new ByteArrayOutputStream()
+    SerDe.writeString(new DataOutputStream(out), reply)
+    chan.writeAndFlush(out.toByteArray())
+  }
+
+}
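
Two details worth noting in this handler: the secret is decoded with msg.length - 1 because SparkR's writeString() null-terminates its payload (the LengthFieldBasedFrameDecoder installed in RBackend has already stripped the 4-byte length prefix), and on success the handler removes itself from the pipeline before replying "ok", so subsequent messages flow directly to RBackendHandler with no per-message authentication overhead.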
