Skip to content
Closed
Show file tree
Hide file tree
Changes from 24 commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
1d6b718
continuous shuffle read RDD
jose-torres May 15, 2018
b5d1008
docs
jose-torres May 17, 2018
af40769
Merge remote-tracking branch 'apache/master' into readerRddMaster
jose-torres May 17, 2018
46456dc
fix ctor
jose-torres May 17, 2018
2ea8a6f
multiple partition test
jose-torres May 17, 2018
955ac79
unset task context after test
jose-torres May 17, 2018
8cefb72
conf from RDD
jose-torres May 18, 2018
f91bfe7
endpoint name
jose-torres May 18, 2018
2590292
testing bool
jose-torres May 18, 2018
859e6e4
tests
jose-torres May 18, 2018
b23b7bb
take instead of poll
jose-torres May 18, 2018
97f7e8f
add interface
jose-torres May 18, 2018
de21b1c
clarify comment
jose-torres May 18, 2018
7dcf51a
multiple
jose-torres May 18, 2018
ad0b5aa
writer with 1 reader partition
jose-torres May 25, 2018
c9adee5
docs and iface
jose-torres May 25, 2018
63d38d8
Merge remote-tracking branch 'apache/master' into writerTask
jose-torres May 25, 2018
331f437
increment epoch
jose-torres May 25, 2018
f3ce675
undo oop
jose-torres May 25, 2018
e0108d7
make rdd loop
jose-torres May 25, 2018
f400651
remote write RDD
jose-torres May 25, 2018
1aaad8d
rename classes
jose-torres May 25, 2018
59890d4
combine suites
jose-torres May 25, 2018
af1508c
fully rm old suite
jose-torres May 25, 2018
65837ac
reorder tests
jose-torres May 29, 2018
a68fae2
return future
jose-torres May 31, 2018
98d55e4
finish getting rid of old name
jose-torres May 31, 2018
e6b9118
synchronous
jose-torres May 31, 2018
629455b
finish rename
jose-torres May 31, 2018
cb6d42b
add timeouts
jose-torres Jun 13, 2018
59d6ff7
unalign
jose-torres Jun 13, 2018
f90388c
add note
jose-torres Jun 13, 2018
4bbdeae
parallel
jose-torres Jun 13, 2018
e57531d
fix compile
jose-torres Jun 13, 2018
cff37c4
fix compile
jose-torres Jun 13, 2018
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,10 @@ case class ContinuousShuffleReadPartition(
// Initialized only on the executor, and only once even as we call compute() multiple times.
lazy val (reader: ContinuousShuffleReader, endpoint) = {
val env = SparkEnv.get.rpcEnv
val receiver = new UnsafeRowReceiver(queueSize, numShuffleWriters, epochIntervalMs, env)
val receiver = new RPCContinuousShuffleReader(
queueSize, numShuffleWriters, epochIntervalMs, env)
val endpoint = env.setupEndpoint(s"UnsafeRowReceiver-${UUID.randomUUID()}", receiver)

TaskContext.get().addTaskCompletionListener { ctx =>
env.stop(endpoint)
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.execution.streaming.continuous.shuffle

import org.apache.spark.sql.catalyst.expressions.UnsafeRow

/**
* Trait for writing to a continuous processing shuffle.
*/
trait ContinuousShuffleWriter {
def write(epoch: Iterator[UnsafeRow]): Unit
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I dont think its the right interface. The ContinuousShuffleWriter interface should be for writing the shuffled rows. The implementation should not be responsible for actually deciding partitions (i.e. outputPartitioner.getPartition(row)), as you dont want to re-implement the partitioning in every implementation. So I think the interface should be def write(row: UnsafeRow, partitionId: Int)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it's better encapsulation to re-implement the partitioning in every ContinuousShuffleWriter implementation than to re-implement it in every ContinuousShuffleWriter user. (Note that the non-continuous ShuffleWriter has precedent for this: it uses the same interface, and all implementations of ShuffleWriter do re-implement partitioning.)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see. That's fair.

}
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,6 @@ package org.apache.spark.sql.execution.streaming.continuous.shuffle
import java.util.concurrent._
import java.util.concurrent.atomic.AtomicBoolean

import scala.collection.mutable

import org.apache.spark.internal.Logging
import org.apache.spark.rpc.{RpcCallContext, RpcEnv, ThreadSafeRpcEndpoint}
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
Expand All @@ -48,7 +46,7 @@ private[shuffle] case class ReceiverEpochMarker(writerId: Int) extends UnsafeRow
* TODO: Support multiple source tasks. We need to output a single epoch marker once all
* source tasks have sent one.
*/
private[shuffle] class UnsafeRowReceiver(
private[shuffle] class RPCContinuousShuffleReader(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point. Caught what I think are the rest.

queueSize: Int,
numShuffleWriters: Int,
epochIntervalMs: Long,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.execution.streaming.continuous.shuffle

import org.apache.spark.Partitioner
import org.apache.spark.rpc.RpcEndpointRef
import org.apache.spark.sql.catalyst.expressions.UnsafeRow

/**
* A [[ContinuousShuffleWriter]] sending data to [[RPCContinuousShuffleReader]] instances.
*
* @param writerId The partition ID of this writer.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: we don't use vertical alignment as they will introduce unnecessary changes in future.

* @param outputPartitioner The partitioner on the reader side of the shuffle.
* @param endpoints The [[RPCContinuousShuffleReader]] endpoints to write to. Indexed by
* partition ID within outputPartitioner.
*/
class RPCContinuousShuffleWriter(
writerId: Int,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: rename to partitionId?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I worry that partitionId is ambiguous with the partition to which the shuffle data is being written.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok makes sense.

outputPartitioner: Partitioner,
endpoints: Array[RpcEndpointRef]) extends ContinuousShuffleWriter {

if (outputPartitioner.numPartitions != 1) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

any reason to disable it ? this should work rt?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe so, but there's no way to test whether it will work until we implement the scheduling support for distributing the addresses of each of the multiple readers.

throw new IllegalArgumentException("multiple readers not yet supported")
}

if (outputPartitioner.numPartitions != endpoints.length) {
throw new IllegalArgumentException(s"partitioner size ${outputPartitioner.numPartitions} did " +
s"not match endpoint count ${endpoints.length}")
}

def write(epoch: Iterator[UnsafeRow]): Unit = {
while (epoch.hasNext) {
val row = epoch.next()
endpoints(outputPartitioner.getPartition(row)).ask[Unit](ReceiverRow(writerId, row))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What about the case where the send fails? the result seem to be ignored here..

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cc @zsxwing

It's my understanding that the RPC framework guarantees messages will be sent in the order that they're ask()ed, and that it's therefore not possible for a single row to fail to be sent while the ones before and after it succeed. If this is the case, then we don't need to handle it here - the query will just start failing to make progress.

If it's not the case, we'll need a more clever solution. Maybe have the epoch marker message contain a count for the number of rows that are supposed to be in the epoch?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A reliable channel (first case) seems like a requirement for correctness. In that case I think the query can just be restarted from the last successful epoch as soon as a failure is detected.

Copy link
Contributor Author

@jose-torres jose-torres May 31, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Discussed offline with @zsxwing. It's actually not valid to be sending these async at all - the framework will retry e.g. connection failures on the next row, so we can end up committing an epoch before we detect that a row within it has failed to send. We need to just make these synchronous.

This will incur a slight round-trip latency penalty for now, but as mentioned earlier the TCP-based shuffle is what we actually plan to be production quality. I'm hoping to begin work on it after I finish one more PR on top of this. So I think the latency should be fine for now.

}

endpoints.foreach(_.ask[Unit](ReceiverEpochMarker(writerId)))
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,36 @@

package org.apache.spark.sql.execution.streaming.continuous.shuffle

import org.apache.spark.{TaskContext, TaskContextImpl}
import scala.collection.mutable

import org.apache.spark.{HashPartitioner, Partition, TaskContext, TaskContextImpl}
import org.apache.spark.rpc.RpcEndpointRef
import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection}
import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection, UnsafeRow}
import org.apache.spark.sql.execution.streaming.continuous.shuffle._
import org.apache.spark.sql.streaming.StreamTest
import org.apache.spark.sql.types.{DataType, IntegerType, StringType}
import org.apache.spark.unsafe.types.UTF8String

class ContinuousShuffleReadSuite extends StreamTest {
class ContinuousShuffleSuite extends StreamTest {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Discussed offline. Merged these tests into the earlier test suite. Name the combined one appropriately.

// In this unit test, we emulate that we're in the task thread where
// ContinuousShuffleReadRDD.compute() will be evaluated. This requires a task context
// thread local to be set.
var ctx: TaskContextImpl = _

override def beforeEach(): Unit = {
super.beforeEach()
ctx = TaskContext.empty()
TaskContext.setTaskContext(ctx)
}

private def unsafeRow(value: Int) = {
override def afterEach(): Unit = {
ctx.markTaskCompleted(None)
TaskContext.unset()
ctx = null
super.afterEach()
}

private implicit def unsafeRow(value: Int) = {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just curious: is there a reason to rearrange functions, this and below twos? Looks like they're same except changing this function to implicit.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

And where it leverages the implicit attribute of this method? I'm not sure it is really needed, but I'm review on Github page so I might be missing here.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

writer.write(Iterator(1, 2, 3)) and such leverages the implicit.

UnsafeProjection.create(Array(IntegerType : DataType))(
new GenericInternalRow(Array(value: Any)))
}
Expand All @@ -40,22 +60,129 @@ class ContinuousShuffleReadSuite extends StreamTest {
messages.foreach(endpoint.askSync[Unit](_))
}

// In this unit test, we emulate that we're in the task thread where
// ContinuousShuffleReadRDD.compute() will be evaluated. This requires a task context
// thread local to be set.
var ctx: TaskContextImpl = _
private def readRDDEndpoint(rdd: ContinuousShuffleReadRDD) = {
rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint
}

override def beforeEach(): Unit = {
super.beforeEach()
ctx = TaskContext.empty()
TaskContext.setTaskContext(ctx)
private def readEpoch(rdd: ContinuousShuffleReadRDD) = {
rdd.compute(rdd.partitions(0), ctx).toSeq.map(_.getInt(0))
}

override def afterEach(): Unit = {
ctx.markTaskCompleted(None)
TaskContext.unset()
ctx = null
super.afterEach()
test("one epoch") {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: i generally put the simplest test first (likely to be the reader tests since they dont depend on writer) and the more complex, e2e-ish tests later (writers since they needs readers).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Reordered.

val reader = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1)
val writer = new RPCContinuousShuffleWriter(
0, new HashPartitioner(1), Array(readRDDEndpoint(reader)))

writer.write(Iterator(1, 2, 3))

assert(readEpoch(reader) == Seq(1, 2, 3))
}

test("multiple epochs") {
val reader = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1)
val writer = new RPCContinuousShuffleWriter(
0, new HashPartitioner(1), Array(readRDDEndpoint(reader)))

writer.write(Iterator(1, 2, 3))
writer.write(Iterator(4, 5, 6))

assert(readEpoch(reader) == Seq(1, 2, 3))
assert(readEpoch(reader) == Seq(4, 5, 6))
}

test("empty epochs") {
val reader = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1)
val writer = new RPCContinuousShuffleWriter(
0, new HashPartitioner(1), Array(readRDDEndpoint(reader)))

writer.write(Iterator())
writer.write(Iterator(1, 2))
writer.write(Iterator())
writer.write(Iterator())
writer.write(Iterator(3, 4))
writer.write(Iterator())

assert(readEpoch(reader) == Seq())
assert(readEpoch(reader) == Seq(1, 2))
assert(readEpoch(reader) == Seq())
assert(readEpoch(reader) == Seq())
assert(readEpoch(reader) == Seq(3, 4))
assert(readEpoch(reader) == Seq())
}

test("blocks waiting for writer") {
val reader = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1)
val writer = new RPCContinuousShuffleWriter(
0, new HashPartitioner(1), Array(readRDDEndpoint(reader)))

val readerEpoch = reader.compute(reader.partitions(0), ctx)

val readRowThread = new Thread {
override def run(): Unit = {
assert(readerEpoch.toSeq.map(_.getInt(0)) == Seq(1))
}
}
readRowThread.start()

eventually(timeout(streamingTimeout)) {
assert(readRowThread.getState == Thread.State.TIMED_WAITING)
}

// Once we write the epoch the thread should stop waiting and succeed.
writer.write(Iterator(1))
readRowThread.join()
}

test("multiple writer partitions") {
val numWriterPartitions = 3

val reader = new ContinuousShuffleReadRDD(
sparkContext, numPartitions = 1, numShuffleWriters = numWriterPartitions)
val writers = (0 until 3).map { idx =>
new RPCContinuousShuffleWriter(idx, new HashPartitioner(1), Array(readRDDEndpoint(reader)))
}

writers(0).write(Iterator(1, 4, 7))
writers(1).write(Iterator(2, 5))
writers(2).write(Iterator(3, 6))

writers(0).write(Iterator(4, 7, 10))
writers(1).write(Iterator(5, 8))
writers(2).write(Iterator(6, 9))

// Since there are multiple asynchronous writers, the original row sequencing is not guaranteed.
// The epochs should be deterministically preserved, however.
assert(readEpoch(reader).toSet == Seq(1, 2, 3, 4, 5, 6, 7).toSet)
assert(readEpoch(reader).toSet == Seq(4, 5, 6, 7, 8, 9, 10).toSet)
}

test("reader epoch only ends when all writer partitions write it") {
val numWriterPartitions = 3

val reader = new ContinuousShuffleReadRDD(
sparkContext, numPartitions = 1, numShuffleWriters = numWriterPartitions)
val writers = (0 until 3).map { idx =>
new RPCContinuousShuffleWriter(idx, new HashPartitioner(1), Array(readRDDEndpoint(reader)))
}

writers(1).write(Iterator())
writers(2).write(Iterator())

val readerEpoch = reader.compute(reader.partitions(0), ctx)

val readEpochMarkerThread = new Thread {
override def run(): Unit = {
assert(!readerEpoch.hasNext)
}
}

readEpochMarkerThread.start()
eventually(timeout(streamingTimeout)) {
assert(readEpochMarkerThread.getState == Thread.State.TIMED_WAITING)
}

writers(0).write(Iterator())
readEpochMarkerThread.join()
}

test("receiver stopped with row last") {
Expand All @@ -70,7 +197,7 @@ class ContinuousShuffleReadSuite extends StreamTest {
ctx.markTaskCompleted(None)
val receiver = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].reader
eventually(timeout(streamingTimeout)) {
assert(receiver.asInstanceOf[UnsafeRowReceiver].stopped.get())
assert(receiver.asInstanceOf[RPCContinuousShuffleReader].stopped.get())
}
}

Expand All @@ -86,11 +213,11 @@ class ContinuousShuffleReadSuite extends StreamTest {
ctx.markTaskCompleted(None)
val receiver = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].reader
eventually(timeout(streamingTimeout)) {
assert(receiver.asInstanceOf[UnsafeRowReceiver].stopped.get())
assert(receiver.asInstanceOf[RPCContinuousShuffleReader].stopped.get())
}
}

test("one epoch") {
test("reader - one epoch") {
val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1)
val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint
send(
Expand All @@ -105,7 +232,7 @@ class ContinuousShuffleReadSuite extends StreamTest {
assert(iter.toSeq.map(_.getInt(0)) == Seq(111, 222, 333))
}

test("multiple epochs") {
test("reader - multiple epochs") {
val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1)
val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint
send(
Expand All @@ -124,7 +251,7 @@ class ContinuousShuffleReadSuite extends StreamTest {
assert(secondEpoch.toSeq.map(_.getInt(0)) == Seq(222, 333))
}

test("empty epochs") {
test("reader - empty epochs") {
val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1)
val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint

Expand All @@ -148,7 +275,7 @@ class ContinuousShuffleReadSuite extends StreamTest {
assert(rdd.compute(rdd.partitions(0), ctx).isEmpty)
}

test("multiple partitions") {
test("reader - multiple partitions") {
val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 5)
// Send all data before processing to ensure there's no crossover.
for (p <- rdd.partitions) {
Expand All @@ -169,7 +296,7 @@ class ContinuousShuffleReadSuite extends StreamTest {
}
}

test("blocks waiting for new rows") {
test("reader - blocks waiting for new rows") {
val rdd = new ContinuousShuffleReadRDD(
sparkContext, numPartitions = 1, epochIntervalMs = Long.MaxValue)
val epoch = rdd.compute(rdd.partitions(0), ctx)
Expand All @@ -195,7 +322,7 @@ class ContinuousShuffleReadSuite extends StreamTest {
}
}

test("multiple writers") {
test("reader - multiple writers") {
val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1, numShuffleWriters = 3)
val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint
send(
Expand All @@ -213,7 +340,7 @@ class ContinuousShuffleReadSuite extends StreamTest {
Set("writer0-row0", "writer1-row0", "writer2-row0"))
}

test("epoch only ends when all writers send markers") {
test("reader - epoch only ends when all writers send markers") {
val rdd = new ContinuousShuffleReadRDD(
sparkContext, numPartitions = 1, numShuffleWriters = 3, epochIntervalMs = Long.MaxValue)
val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint
Expand All @@ -233,6 +360,7 @@ class ContinuousShuffleReadSuite extends StreamTest {

// After checking the right rows, block until we get an epoch marker indicating there's no next.
// (Also fail the assertion if for some reason we get a row.)

val readEpochMarkerThread = new Thread {
override def run(): Unit = {
assert(!epoch.hasNext)
Expand All @@ -254,7 +382,7 @@ class ContinuousShuffleReadSuite extends StreamTest {
readEpochMarkerThread.join()
}

test("writer epochs non aligned") {
test("reader - writer epochs non aligned") {
val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1, numShuffleWriters = 3)
val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint
// We send multiple epochs for 0, then multiple for 1, then multiple for 2. The receiver should
Expand Down