62 changes: 51 additions & 11 deletions core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -85,6 +85,14 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli

val startTime = System.currentTimeMillis()

  @volatile private var stopped: Boolean = false

  private def assertNotStopped(): Unit = {
    if (stopped) {
      throw new IllegalStateException("Cannot call methods on a stopped SparkContext")
    }
  }

/**
* Create a SparkContext that loads settings from system properties (for instance, when
* launching with ./bin/spark-submit).
@@ -526,6 +534,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
* the argument to avoid this.
*/
def parallelize[T: ClassTag](seq: Seq[T], numSlices: Int = defaultParallelism): RDD[T] = {
Comment from the PR author:

In 1.2, calling this when SparkContext was stopped would throw a NullPointerException:

scala> sc.parallelize(1 to 100)
java.lang.NullPointerException
    at org.apache.spark.SparkContext.defaultParallelism(SparkContext.scala:1461)
    at org.apache.spark.SparkContext.parallelize$default$2(SparkContext.scala:521)
    at $iwC$$iwC$$iwC$$iwC.<init>(<console>:13)
    at $iwC$$iwC$$iwC.<init>(<console>:18)
    at $iwC$$iwC.<init>(<console>:20)
    at $iwC.<init>(<console>:22)
    at <init>(<console>:24)
    at .<init>(<console>:28)
    at .<clinit>(<console>)
    at .<init>(<console>:7)
    at .<clinit>(<console>)
    at $print(<console>)
    at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
    at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:57)
    at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
    at java.lang.reflect.Method.invoke(Method.java:606)
    at org.apache.spark.repl.SparkIMain$ReadEvalPrint.call(SparkIMain.scala:852)
    at org.apache.spark.repl.SparkIMain$Request.loadAndRun(SparkIMain.scala:1125)
    at org.apache.spark.repl.SparkIMain.loadAndRunReq$1(SparkIMain.scala:674)
    at org.apache.spark.repl.SparkIMain.interpret(SparkIMain.scala:705)
    at org.apache.spark.repl.SparkIMain.interpret(SparkIMain.scala:669)
    at org.apache.spark.repl.SparkILoop.reallyInterpret$1(SparkILoop.scala:828)
    at org.apache.spark.repl.SparkILoop.interpretStartingWith(SparkILoop.scala:873)
    at org.apache.spark.repl.SparkILoop.command(SparkILoop.scala:785)
    at org.apache.spark.repl.SparkILoop.processLine$1(SparkILoop.scala:628)
    at org.apache.spark.repl.SparkILoop.innerLoop$1(SparkILoop.scala:636)
    at org.apache.spark.repl.SparkILoop.loop(SparkILoop.scala:641)
    at org.apache.spark.repl.SparkILoop$$anonfun$process$1.apply$mcZ$sp(SparkILoop.scala:968)
    at org.apache.spark.repl.SparkILoop$$anonfun$process$1.apply(SparkILoop.scala:916)
    at org.apache.spark.repl.SparkILoop$$anonfun$process$1.apply(SparkILoop.scala:916)
    at scala.tools.nsc.util.ScalaClassLoader$.savingContextLoader(ScalaClassLoader.scala:135)
    at org.apache.spark.repl.SparkILoop.process(SparkILoop.scala:916)
    at org.apache.spark.repl.SparkILoop.process(SparkILoop.scala:1011)
    at org.apache.spark.repl.Main$.main(Main.scala:31)
    at org.apache.spark.repl.Main.main(Main.scala)
    at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
    at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:57)
    at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
    at java.lang.reflect.Method.invoke(Method.java:606)
    at org.apache.spark.deploy.SparkSubmit$.launch(SparkSubmit.scala:358)
    at org.apache.spark.deploy.SparkSubmit$.main(SparkSubmit.scala:75)
    at org.apache.spark.deploy.SparkSubmit.main(SparkSubmit.scala)

assertNotStopped()
new ParallelCollectionRDD[T](this, seq, numSlices, Map[Int, Seq[String]]())
}
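
To illustrate, a minimal sketch of the post-patch behavior (not code from the PR; the object and app names are illustrative):

```scala
import org.apache.spark.{SparkConf, SparkContext}

object StoppedContextDemo {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local").setAppName("demo"))
    sc.stop()
    try {
      // parallelize's default numSlices evaluates defaultParallelism at the call
      // site; both it and parallelize itself are now guarded, so the call fails
      // fast with IllegalStateException instead of the old NullPointerException.
      sc.parallelize(1 to 100)
    } catch {
      case e: IllegalStateException =>
        println(e.getMessage) // "Cannot call methods on a stopped SparkContext"
    }
  }
}
```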

@@ -541,6 +550,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
* location preferences (hostnames of Spark nodes) for each object.
* Create a new partition for each collection item. */
def makeRDD[T: ClassTag](seq: Seq[(T, Seq[String])]): RDD[T] = {
assertNotStopped()
val indexToPrefs = seq.zipWithIndex.map(t => (t._2, t._1._2)).toMap
new ParallelCollectionRDD[T](this, seq.map(_._1), seq.size, indexToPrefs)
}
@@ -550,6 +560,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
* Hadoop-supported file system URI, and return it as an RDD of Strings.
*/
def textFile(path: String, minPartitions: Int = defaultMinPartitions): RDD[String] = {
assertNotStopped()
Comment from the PR author:

Same for textFile:

scala> sc.textFile("/usr/share/dict/words")
java.lang.NullPointerException
    at org.apache.spark.SparkContext.defaultParallelism(SparkContext.scala:1461)
    at org.apache.spark.SparkContext.defaultMinPartitions(SparkContext.scala:1468)
    at org.apache.spark.SparkContext.textFile$default$2(SparkContext.scala:545)

hadoopFile(path, classOf[TextInputFormat], classOf[LongWritable], classOf[Text],
minPartitions).map(pair => pair._2.toString).setName(path)
}
@@ -583,6 +594,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
*/
def wholeTextFiles(path: String, minPartitions: Int = defaultMinPartitions):
RDD[(String, String)] = {
assertNotStopped()
val job = new NewHadoopJob(hadoopConfiguration)
NewFileInputFormat.addInputPath(job, new Path(path))
val updateConf = job.getConfiguration
@@ -628,6 +640,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
@Experimental
def binaryFiles(path: String, minPartitions: Int = defaultMinPartitions):
RDD[(String, PortableDataStream)] = {
assertNotStopped()
val job = new NewHadoopJob(hadoopConfiguration)
NewFileInputFormat.addInputPath(job, new Path(path))
val updateConf = job.getConfiguration
@@ -652,6 +665,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
@Experimental
def binaryRecords(path: String, recordLength: Int, conf: Configuration = hadoopConfiguration)
: RDD[Array[Byte]] = {
assertNotStopped()
conf.setInt(FixedLengthBinaryInputFormat.RECORD_LENGTH_PROPERTY, recordLength)
val br = newAPIHadoopFile[LongWritable, BytesWritable, FixedLengthBinaryInputFormat](path,
classOf[FixedLengthBinaryInputFormat],
@@ -685,6 +699,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
valueClass: Class[V],
minPartitions: Int = defaultMinPartitions
): RDD[(K, V)] = {
assertNotStopped()
// Add necessary security credentials to the JobConf before broadcasting it.
SparkHadoopUtil.get.addCredentials(conf)
new HadoopRDD(this, conf, inputFormatClass, keyClass, valueClass, minPartitions)
@@ -704,6 +719,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
valueClass: Class[V],
minPartitions: Int = defaultMinPartitions
): RDD[(K, V)] = {
assertNotStopped()
// A Hadoop configuration can be about 10 KB, which is pretty big, so broadcast it.
val confBroadcast = broadcast(new SerializableWritable(hadoopConfiguration))
val setInputPathsFunc = (jobConf: JobConf) => FileInputFormat.setInputPaths(jobConf, path)
@@ -783,6 +799,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
kClass: Class[K],
vClass: Class[V],
conf: Configuration = hadoopConfiguration): RDD[(K, V)] = {
assertNotStopped()
val job = new NewHadoopJob(conf)
NewFileInputFormat.addInputPath(job, new Path(path))
val updatedConf = job.getConfiguration
@@ -803,6 +820,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
fClass: Class[F],
kClass: Class[K],
vClass: Class[V]): RDD[(K, V)] = {
assertNotStopped()
new NewHadoopRDD(this, fClass, kClass, vClass, conf)
}

@@ -818,6 +836,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
valueClass: Class[V],
minPartitions: Int
): RDD[(K, V)] = {
assertNotStopped()
val inputFormatClass = classOf[SequenceFileInputFormat[K, V]]
hadoopFile(path, inputFormatClass, keyClass, valueClass, minPartitions)
}
@@ -829,9 +848,10 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
* If you plan to directly cache Hadoop writable objects, you should first copy them using
* a `map` function.
* */
def sequenceFile[K, V](path: String, keyClass: Class[K], valueClass: Class[V]
): RDD[(K, V)] =
def sequenceFile[K, V](path: String, keyClass: Class[K], valueClass: Class[V]): RDD[(K, V)] = {
assertNotStopped()
sequenceFile(path, keyClass, valueClass, defaultMinPartitions)
}

/**
* Version of sequenceFile() for types implicitly convertible to Writables through a
@@ -859,6 +879,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
(implicit km: ClassTag[K], vm: ClassTag[V],
kcf: () => WritableConverter[K], vcf: () => WritableConverter[V])
: RDD[(K, V)] = {
assertNotStopped()
val kc = kcf()
val vc = vcf()
val format = classOf[SequenceFileInputFormat[Writable, Writable]]
@@ -880,6 +901,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
path: String,
minPartitions: Int = defaultMinPartitions
): RDD[T] = {
assertNotStopped()
sequenceFile(path, classOf[NullWritable], classOf[BytesWritable], minPartitions)
.flatMap(x => Utils.deserialize[Array[T]](x._2.getBytes, Utils.getContextOrSparkClassLoader))
}
@@ -955,6 +977,13 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
* The variable will be sent to each cluster only once.
*/
def broadcast[T: ClassTag](value: T): Broadcast[T] = {
Comment from the PR author:

Broadcast, on the other hand, throws an NPE:

scala> sc.broadcast(0)
java.lang.NullPointerException
    at org.apache.spark.broadcast.TorrentBroadcast.<init>(TorrentBroadcast.scala:79)
    at org.apache.spark.broadcast.TorrentBroadcastFactory.newBroadcast(TorrentBroadcastFactory.scala:34)
    at org.apache.spark.broadcast.TorrentBroadcastFactory.newBroadcast(TorrentBroadcastFactory.scala:29)
    at org.apache.spark.broadcast.BroadcastManager.newBroadcast(BroadcastManager.scala:62)
    at org.apache.spark.SparkContext.broadcast(SparkContext.scala:951)
    at $iwC$$iwC$$iwC$$iwC.<init>(<console>:13)
    at $iwC$$iwC$$iwC.<init>(<console>:18)
    at $iwC$$iwC.<init>(<console>:20)
    at $iwC.<init>(<console>:22)
    at <init>(<console>:24)
    at .<init>(<console>:28)
    at .<clinit>(<console>)
    at .<init>(<console>:7)
    at .<clinit>(<console>)
    at $print(<console>)
    at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
    at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:57)
    at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
    at java.lang.reflect.Method.invoke(Method.java:606)

assertNotStopped()
if (classOf[RDD[_]].isAssignableFrom(classTag[T].runtimeClass)) {
Comment from the PR author:

Actually, maybe this check should go somewhere else, since I think that it might technically have been safe to create a broadcast variable with an RDD, even though doing anything with it would trigger errors.

A follow-up comment from the PR author:

I've changed this in my latest patch; we log a warning here and any errors are caught by the more general "display an error about RDD nesting if the sc field is null" check.

// This is a warning instead of an exception in order to avoid breaking user programs that
// might have created RDD broadcast variables but not used them:
logWarning("Can not directly broadcast RDDs; instead, call collect() and "
+ "broadcast the result (see SPARK-5063)")
}
val bc = env.broadcastManager.newBroadcast[T](value, isLocal)
val callSite = getCallSite
logInfo("Created broadcast " + bc.id + " from " + callSite.shortForm)
@@ -1047,6 +1076,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
* memory available for caching.
*/
def getExecutorMemoryStatus: Map[String, (Long, Long)] = {
Comment from the PR author:

This throws an error, so I'll keep it:

scala> sc.getExecutorMemoryStatus
org.apache.spark.SparkException: Error sending message as actor is null [message = GetMemoryStatus]
    at org.apache.spark.util.AkkaUtils$.askWithReply(AkkaUtils.scala:178)
    at org.apache.spark.storage.BlockManagerMaster.askDriverWithReply(BlockManagerMaster.scala:221)
    at org.apache.spark.storage.BlockManagerMaster.getMemoryStatus(BlockManagerMaster.scala:148)
    at org.apache.spark.SparkContext.getExecutorMemoryStatus(SparkContext.scala:1039)
    at $iwC$$iwC$$iwC$$iwC.<init>(<console>:13)
    at $iwC$$iwC$$iwC.<init>(<console>:18)
    at $iwC$$iwC.<init>(<console>:20)
    at $iwC.<init>(<console>:22)
    at <init>(<console>:24)
    at .<init>(<console>:28)
    at .<clinit>(<console>)
    at .<init>(<console>:7)
    at .<clinit>(<console>)
    at $print(<console>)

assertNotStopped()
env.blockManager.master.getMemoryStatus.map { case(blockManagerId, mem) =>
(blockManagerId.host + ":" + blockManagerId.port, mem)
}
@@ -1059,6 +1089,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
*/
@DeveloperApi
def getRDDStorageInfo: Array[RDDInfo] = {
assertNotStopped()
Comment from the PR author:

Same here:

scala> sc.getRDDStorageInfo
org.apache.spark.SparkException: Error sending message as actor is null [message = GetStorageStatus]
    at org.apache.spark.util.AkkaUtils$.askWithReply(AkkaUtils.scala:178)
    at org.apache.spark.storage.BlockManagerMaster.askDriverWithReply(BlockManagerMaster.scala:221)
    at org.apache.spark.storage.BlockManagerMaster.getStorageStatus(BlockManagerMaster.scala:152)
    at org.apache.spark.SparkContext.getExecutorStorageStatus(SparkContext.scala:1068)
    at org.apache.spark.SparkContext.getRDDStorageInfo(SparkContext.scala:1052)
    at $iwC$$iwC$$iwC$$iwC.<init>(<console>:13)
    at $iwC$$iwC$$iwC.<init>(<console>:18)
    at $iwC$$iwC.<init>(<console>:20)
    at $iwC.<init>(<console>:22)
    at <init>(<console>:24)
    at .<init>(<console>:28)
    at .<clinit>(<console>)
    at .<init>(<console>:7)
    at .<clinit>(<console>)
    at $print(<console>)
    at sun.reflect.NativeMethodAccessorImpl.invoke0(N

val rddInfos = persistentRdds.values.map(RDDInfo.fromRdd).toArray
StorageUtils.updateRddInfo(rddInfos, getExecutorStorageStatus)
rddInfos.filter(_.isCached)
@@ -1076,6 +1107,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
*/
@DeveloperApi
def getExecutorStorageStatus: Array[StorageStatus] = {
assertNotStopped()
env.blockManager.master.getStorageStatus
}

@@ -1085,6 +1117,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
*/
@DeveloperApi
def getAllPools: Seq[Schedulable] = {
assertNotStopped()
// TODO(xiajunluan): We should take nested pools into account
taskScheduler.rootPool.schedulableQueue.toSeq
}
@@ -1095,13 +1128,15 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
*/
@DeveloperApi
def getPoolForName(pool: String): Option[Schedulable] = {
assertNotStopped()
Option(taskScheduler.rootPool.schedulableNameToSchedulable.get(pool))
}

/**
* Return current scheduling mode
*/
def getSchedulingMode: SchedulingMode.SchedulingMode = {
assertNotStopped()
taskScheduler.schedulingMode
}

@@ -1207,16 +1242,14 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
SparkContext.SPARK_CONTEXT_CONSTRUCTOR_LOCK.synchronized {
postApplicationEnd()
ui.foreach(_.stop())
// Do this only if not stopped already - best case effort.
// prevent NPE if stopped more than once.
val dagSchedulerCopy = dagScheduler
dagScheduler = null
if (dagSchedulerCopy != null) {
if (!stopped) {
stopped = true
env.metricsSystem.report()
metadataCleaner.cancel()
env.actorSystem.stop(heartbeatReceiver)
cleaner.foreach(_.stop())
dagSchedulerCopy.stop()
dagScheduler.stop()
dagScheduler = null
taskScheduler = null
// TODO: Cache.stop()?
env.stop()
Expand Down Expand Up @@ -1290,8 +1323,8 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
partitions: Seq[Int],
allowLocal: Boolean,
resultHandler: (Int, U) => Unit) {
if (dagScheduler == null) {
throw new SparkException("SparkContext has been shutdown")
if (stopped) {
throw new IllegalStateException("SparkContext has been shutdown")
}
val callSite = getCallSite
val cleanedFunc = clean(func)
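
A sketch of what this change means for actions on RDDs that outlive their context (assumes local mode; it matches the new RDDSuite test added below):

```scala
// Assumes an active local-mode SparkContext `sc`; names are illustrative.
val data = sc.parallelize(1 to 100) // created while the context is alive
sc.stop()
// count() reaches sc.runJob, which now throws
// IllegalStateException("SparkContext has been shutdown")
// instead of an NPE from the nulled-out dagScheduler.
data.count()
```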
@@ -1378,6 +1411,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
func: (TaskContext, Iterator[T]) => U,
evaluator: ApproximateEvaluator[U, R],
timeout: Long): PartialResult[R] = {
assertNotStopped()
val callSite = getCallSite
logInfo("Starting job: " + callSite.shortForm)
val start = System.nanoTime
@@ -1400,6 +1434,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
resultHandler: (Int, U) => Unit,
resultFunc: => R): SimpleFutureAction[R] =
{
assertNotStopped()
val cleanF = clean(processPartition)
val callSite = getCallSite
val waiter = dagScheduler.submitJob(
@@ -1418,11 +1453,13 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
* for more information.
*/
def cancelJobGroup(groupId: String) {
assertNotStopped()
dagScheduler.cancelJobGroup(groupId)
}

/** Cancel all jobs that have been scheduled or are running. */
def cancelAllJobs() {
assertNotStopped()
dagScheduler.cancelAllJobs()
}

@@ -1469,7 +1506,10 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
def getCheckpointDir = checkpointDir

/** Default level of parallelism to use when not given by user (e.g. parallelize and makeRDD). */
def defaultParallelism: Int = taskScheduler.defaultParallelism
def defaultParallelism: Int = {
Comment from the PR author:

This throws an exception because taskScheduler is null:

scala> sc.defaultParallelism
java.lang.NullPointerException
    at org.apache.spark.SparkContext.defaultParallelism(SparkContext.scala:1461)
    at $iwC$$iwC$$iwC$$iwC.<init>(<console>:13)
    at $iwC$$iwC$$iwC.<init>(<console>:18)
    at $iwC$$iwC.<init>(<console>:20)
    at $iwC.<init>(<console>:22)
    at <init>(<console>:24)
    at .<init>(<console>:28)
    at .<clinit>(<console>)
    at .<init>(<console>:7)
    at .<clinit>(<console>)
    at $print(<console>)

assertNotStopped()
taskScheduler.defaultParallelism
}

/** Default min number of partitions for Hadoop RDDs when not given by user */
@deprecated("use defaultMinPartitions", "1.0.0")
19 changes: 18 additions & 1 deletion core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -76,10 +76,27 @@ import org.apache.spark.util.random.{BernoulliSampler, PoissonSampler, Bernoulli
* on RDD internals.
*/
abstract class RDD[T: ClassTag](
@transient private var sc: SparkContext,
@transient private var _sc: SparkContext,
@transient private var deps: Seq[Dependency[_]]
) extends Serializable with Logging {

if (classOf[RDD[_]].isAssignableFrom(elementClassTag.runtimeClass)) {
Comment from the PR author:

Similarly, this should perhaps be a warning instead of an exception in order to avoid any possibility of breaking odd corner-case 1.2.1 apps. I'll change this to a warning and leave the sc getter as an exception.

// This is a warning instead of an exception in order to avoid breaking user programs that
// might have defined nested RDDs without running jobs with them.
logWarning("Spark does not support nested RDDs (see SPARK-5063)")
}

  private def sc: SparkContext = {
    if (_sc == null) {
      throw new SparkException(
        "RDD transformations and actions can only be invoked by the driver, not inside of other " +
        "transformations; for example, rdd1.map(x => rdd2.values.count() * x) is invalid because " +
        "the values transformation and count action cannot be performed inside of the rdd1.map " +
        "transformation. For more information, see SPARK-5063.")
    }
    _sc
  }

/** Construct an RDD with just a one-to-one dependency on one parent */
def this(@transient oneParent: RDD[_]) =
this(oneParent.context , List(new OneToOneDependency(oneParent)))
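
To make the rule concrete, a sketch of the kind of code this getter rejects, next to the driver-side alternative (it mirrors the new RDDSuite tests below; assumes an active SparkContext `sc`):

```scala
val rdd1 = sc.parallelize(1 to 100)
val rdd2 = sc.parallelize(1 to 100)

// Invalid: rdd2.count() is an action invoked inside rdd1's map transformation.
// On executors rdd2's transient _sc field is null after deserialization, so the
// getter above throws the SparkException referencing SPARK-5063:
// rdd1.map(x => x * rdd2.count()).collect()

// Valid: run the action on the driver first, then close over the plain result.
val n = rdd2.count()
val scaled = rdd1.map(x => x * n).collect()
```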
core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala
@@ -170,6 +170,15 @@ class BroadcastSuite extends FunSuite with LocalSparkContext {
testPackage.runCallSiteTest(sc)
}

test("Broadcast variables cannot be created after SparkContext is stopped (SPARK-5065)") {
sc = new SparkContext("local", "test")
sc.stop()
val thrown = intercept[IllegalStateException] {
sc.broadcast(Seq(1, 2, 3))
}
assert(thrown.getMessage.toLowerCase.contains("stopped"))
}

/**
* Verify the persistence of state associated with an HttpBroadcast in either local mode or
* local-cluster mode (when distributed = true).
@@ -349,8 +358,7 @@ class BroadcastSuite extends FunSuite with LocalSparkContext {
package object testPackage extends Assertions {

def runCallSiteTest(sc: SparkContext) {
val rdd = sc.makeRDD(Array(1, 2, 3, 4), 2)
val broadcast = sc.broadcast(rdd)
val broadcast = sc.broadcast(Array(1, 2, 3, 4))
broadcast.destroy()
val thrown = intercept[SparkException] { broadcast.value }
assert(thrown.getMessage.contains("BroadcastSuite.scala"))
40 changes: 40 additions & 0 deletions core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
@@ -918,4 +918,44 @@ class RDDSuite extends FunSuite with SharedSparkContext {
mutableDependencies += dep
}
}

test("nested RDDs are not supported (SPARK-5063)") {
val rdd: RDD[Int] = sc.parallelize(1 to 100)
val rdd2: RDD[Int] = sc.parallelize(1 to 100)
val thrown = intercept[SparkException] {
val nestedRDD: RDD[RDD[Int]] = rdd.mapPartitions { x => Seq(rdd2.map(x => x)).iterator }
nestedRDD.count()
}
assert(thrown.getMessage.contains("SPARK-5063"))
}

test("actions cannot be performed inside of transformations (SPARK-5063)") {
val rdd: RDD[Int] = sc.parallelize(1 to 100)
val rdd2: RDD[Int] = sc.parallelize(1 to 100)
val thrown = intercept[SparkException] {
rdd.map(x => x * rdd2.count).collect()
}
assert(thrown.getMessage.contains("SPARK-5063"))
}

test("cannot run actions after SparkContext has been stopped (SPARK-5063)") {
val existingRDD = sc.parallelize(1 to 100)
sc.stop()
val thrown = intercept[IllegalStateException] {
existingRDD.count()
}
assert(thrown.getMessage.contains("shutdown"))
}

test("cannot call methods on a stopped SparkContext (SPARK-5063)") {
sc.stop()
def assertFails(block: => Any): Unit = {
val thrown = intercept[IllegalStateException] {
block
}
assert(thrown.getMessage.contains("stopped"))
}
assertFails { sc.parallelize(1 to 100) }
assertFails { sc.textFile("/nonexistent-path") }
}
}
8 changes: 8 additions & 0 deletions python/pyspark/context.py
@@ -229,6 +229,14 @@ def _ensure_initialized(cls, instance=None, gateway=None):
else:
SparkContext._active_spark_context = instance

    def __getnewargs__(self):
        # This method is called when attempting to pickle SparkContext, which is always an error:
        raise Exception(
            "It appears that you are attempting to reference SparkContext from a broadcast "
            "variable, action, or transformation. SparkContext can only be used on the driver, "
            "not in code that is run on workers. For more information, see SPARK-5063."
        )

    def __enter__(self):
        """
        Enable 'with SparkContext(...) as sc: app(sc)' syntax.