rdd = sparkContext.wholeTextFiles("hdfs://a-hdfs-path")
+ * }}}
*
* then `rdd` contains
* {{{
@@ -210,6 +232,84 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork
def wholeTextFiles(path: String): JavaPairRDD[String, String] =
new JavaPairRDD(sc.wholeTextFiles(path))
+ /**
+ * Read a directory of binary files from HDFS, a local file system (available on all nodes),
+ * or any Hadoop-supported file system URI as a byte array. Each file is read as a single
+ * record and returned in a key-value pair, where the key is the path of each file,
+ * the value is the content of each file.
+ *
+ * For example, if you have the following files:
+ * {{{
+ * hdfs://a-hdfs-path/part-00000
+ * hdfs://a-hdfs-path/part-00001
+ * ...
+ * hdfs://a-hdfs-path/part-nnnnn
+ * }}}
+ *
+ * Do
+ * `JavaPairRDD<String, PortableDataStream> rdd = sparkContext.binaryFiles("hdfs://a-hdfs-path")`,
+ *
+ * then `rdd` contains
+ * {{{
+ * (a-hdfs-path/part-00000, its content)
+ * (a-hdfs-path/part-00001, its content)
+ * ...
+ * (a-hdfs-path/part-nnnnn, its content)
+ * }}}
+ *
+ * @note Small files are preferred; very large files may cause bad performance.
+ *
+ * @param minPartitions A suggested minimum number of partitions for the input data.
+ */
+ def binaryFiles(path: String, minPartitions: Int): JavaPairRDD[String, PortableDataStream] =
+ new JavaPairRDD(sc.binaryFiles(path, minPartitions))
+
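A minimal Scala sketch of how this might be called (the local master, the four-partition hint, and the input path are illustrative assumptions; `PortableDataStream.toArray()` materializes a file's bytes):
{{{
import org.apache.spark.SparkContext._
import org.apache.spark.{SparkConf, SparkContext}

// Sketch: list every binary file under a (hypothetical) directory with its size in bytes.
val sc = new SparkContext(new SparkConf().setAppName("binaryFilesSketch").setMaster("local[2]"))
val files = sc.binaryFiles("hdfs://a-hdfs-path", minPartitions = 4)
files.mapValues(_.toArray().length)   // (path, PortableDataStream) -> (path, size in bytes)
  .collect()
  .foreach { case (path, size) => println(s"$path -> $size bytes") }
sc.stop()
}}}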
+ /**
+ * :: Experimental ::
+ *
+ * Read a directory of binary files from HDFS, a local file system (available on all nodes),
+ * or any Hadoop-supported file system URI as a byte array. Each file is read as a single
+ * record and returned in a key-value pair, where the key is the path of each file,
+ * the value is the content of each file.
+ *
+ * For example, if you have the following files:
+ * {{{
+ * hdfs://a-hdfs-path/part-00000
+ * hdfs://a-hdfs-path/part-00001
+ * ...
+ * hdfs://a-hdfs-path/part-nnnnn
+ * }}}
+ *
+ * Do
+ * `JavaPairRDD<String, PortableDataStream> rdd = sparkContext.binaryFiles("hdfs://a-hdfs-path")`,
+ *
+ * then `rdd` contains
+ * {{{
+ * (a-hdfs-path/part-00000, its content)
+ * (a-hdfs-path/part-00001, its content)
+ * ...
+ * (a-hdfs-path/part-nnnnn, its content)
+ * }}}
+ *
+ * @note Small files are preferred; very large files may cause bad performance.
+ */
+ @Experimental
+ def binaryFiles(path: String): JavaPairRDD[String, PortableDataStream] =
+ new JavaPairRDD(sc.binaryFiles(path, defaultMinPartitions))
+
+ /**
+ * :: Experimental ::
+ *
+ * Load data from a flat binary file, assuming the length of each record is constant.
+ *
+ * @param path Directory containing the input data files
+ * @param recordLength The fixed length (in bytes) of each record
+ * @return An RDD of data with values, represented as byte arrays
+ */
+ @Experimental
+ def binaryRecords(path: String, recordLength: Int): JavaRDD[Array[Byte]] = {
+ new JavaRDD(sc.binaryRecords(path, recordLength))
+ }
+
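A companion sketch for fixed-length records (the 16-byte record length and the file path are illustrative assumptions):
{{{
import org.apache.spark.{SparkConf, SparkContext}

// Sketch: every record in the (hypothetical) file is exactly 16 bytes long.
val sc = new SparkContext(new SparkConf().setAppName("binaryRecordsSketch").setMaster("local[2]"))
val records = sc.binaryRecords("hdfs://a-hdfs-path/records.bin", recordLength = 16)
println(s"record count: ${records.count()}")   // each element is an Array[Byte] of length 16
sc.stop()
}}}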
/** Get an RDD for a Hadoop SequenceFile with given key and value types.
*
* '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each
@@ -284,7 +384,8 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork
): JavaPairRDD[K, V] = {
implicit val ctagK: ClassTag[K] = ClassTag(keyClass)
implicit val ctagV: ClassTag[V] = ClassTag(valueClass)
- new JavaPairRDD(sc.hadoopRDD(conf, inputFormatClass, keyClass, valueClass, minPartitions))
+ val rdd = sc.hadoopRDD(conf, inputFormatClass, keyClass, valueClass, minPartitions)
+ new JavaHadoopRDD(rdd.asInstanceOf[HadoopRDD[K, V]])
}
/**
@@ -304,7 +405,8 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork
): JavaPairRDD[K, V] = {
implicit val ctagK: ClassTag[K] = ClassTag(keyClass)
implicit val ctagV: ClassTag[V] = ClassTag(valueClass)
- new JavaPairRDD(sc.hadoopRDD(conf, inputFormatClass, keyClass, valueClass))
+ val rdd = sc.hadoopRDD(conf, inputFormatClass, keyClass, valueClass)
+ new JavaHadoopRDD(rdd.asInstanceOf[HadoopRDD[K, V]])
}
/** Get an RDD for a Hadoop file with an arbitrary InputFormat.
@@ -323,7 +425,8 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork
): JavaPairRDD[K, V] = {
implicit val ctagK: ClassTag[K] = ClassTag(keyClass)
implicit val ctagV: ClassTag[V] = ClassTag(valueClass)
- new JavaPairRDD(sc.hadoopFile(path, inputFormatClass, keyClass, valueClass, minPartitions))
+ val rdd = sc.hadoopFile(path, inputFormatClass, keyClass, valueClass, minPartitions)
+ new JavaHadoopRDD(rdd.asInstanceOf[HadoopRDD[K, V]])
}
/** Get an RDD for a Hadoop file with an arbitrary InputFormat
@@ -341,8 +444,8 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork
): JavaPairRDD[K, V] = {
implicit val ctagK: ClassTag[K] = ClassTag(keyClass)
implicit val ctagV: ClassTag[V] = ClassTag(valueClass)
- new JavaPairRDD(sc.hadoopFile(path,
- inputFormatClass, keyClass, valueClass))
+ val rdd = sc.hadoopFile(path, inputFormatClass, keyClass, valueClass)
+ new JavaHadoopRDD(rdd.asInstanceOf[HadoopRDD[K, V]])
}
/**
@@ -362,7 +465,8 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork
conf: Configuration): JavaPairRDD[K, V] = {
implicit val ctagK: ClassTag[K] = ClassTag(kClass)
implicit val ctagV: ClassTag[V] = ClassTag(vClass)
- new JavaPairRDD(sc.newAPIHadoopFile(path, fClass, kClass, vClass, conf))
+ val rdd = sc.newAPIHadoopFile(path, fClass, kClass, vClass, conf)
+ new JavaNewHadoopRDD(rdd.asInstanceOf[NewHadoopRDD[K, V]])
}
/**
@@ -381,7 +485,8 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork
vClass: Class[V]): JavaPairRDD[K, V] = {
implicit val ctagK: ClassTag[K] = ClassTag(kClass)
implicit val ctagV: ClassTag[V] = ClassTag(vClass)
- new JavaPairRDD(sc.newAPIHadoopRDD(conf, fClass, kClass, vClass))
+ val rdd = sc.newAPIHadoopRDD(conf, fClass, kClass, vClass)
+ new JavaNewHadoopRDD(rdd.asInstanceOf[NewHadoopRDD[K, V]])
}
/** Build the union of two or more RDDs. */
@@ -414,6 +519,16 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork
def intAccumulator(initialValue: Int): Accumulator[java.lang.Integer] =
sc.accumulator(initialValue)(IntAccumulatorParam).asInstanceOf[Accumulator[java.lang.Integer]]
+ /**
+ * Create an [[org.apache.spark.Accumulator]] integer variable, which tasks can "add" values
+ * to using the `add` method. Only the master can access the accumulator's `value`.
+ *
+ * This version supports naming the accumulator for display in Spark's web UI.
+ */
+ def intAccumulator(initialValue: Int, name: String): Accumulator[java.lang.Integer] =
+ sc.accumulator(initialValue, name)(IntAccumulatorParam)
+ .asInstanceOf[Accumulator[java.lang.Integer]]
+
/**
* Create an [[org.apache.spark.Accumulator]] double variable, which tasks can "add" values
* to using the `add` method. Only the master can access the accumulator's `value`.
@@ -421,12 +536,31 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork
def doubleAccumulator(initialValue: Double): Accumulator[java.lang.Double] =
sc.accumulator(initialValue)(DoubleAccumulatorParam).asInstanceOf[Accumulator[java.lang.Double]]
+ /**
+ * Create an [[org.apache.spark.Accumulator]] double variable, which tasks can "add" values
+ * to using the `add` method. Only the master can access the accumulator's `value`.
+ *
+ * This version supports naming the accumulator for display in Spark's web UI.
+ */
+ def doubleAccumulator(initialValue: Double, name: String): Accumulator[java.lang.Double] =
+ sc.accumulator(initialValue, name)(DoubleAccumulatorParam)
+ .asInstanceOf[Accumulator[java.lang.Double]]
+
/**
* Create an [[org.apache.spark.Accumulator]] integer variable, which tasks can "add" values
* to using the `add` method. Only the master can access the accumulator's `value`.
*/
def accumulator(initialValue: Int): Accumulator[java.lang.Integer] = intAccumulator(initialValue)
+ /**
+ * Create an [[org.apache.spark.Accumulator]] integer variable, which tasks can "add" values
+ * to using the `add` method. Only the master can access the accumulator's `value`.
+ *
+ * This version supports naming the accumulator for display in Spark's web UI.
+ */
+ def accumulator(initialValue: Int, name: String): Accumulator[java.lang.Integer] =
+ intAccumulator(initialValue, name)
+
/**
* Create an [[org.apache.spark.Accumulator]] double variable, which tasks can "add" values
* to using the `add` method. Only the master can access the accumulator's `value`.
@@ -434,6 +568,16 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork
def accumulator(initialValue: Double): Accumulator[java.lang.Double] =
doubleAccumulator(initialValue)
+
+ /**
+ * Create an [[org.apache.spark.Accumulator]] double variable, which tasks can "add" values
+ * to using the `add` method. Only the master can access the accumulator's `value`.
+ *
+ * This version supports naming the accumulator for display in Spark's web UI.
+ */
+ def accumulator(initialValue: Double, name: String): Accumulator[java.lang.Double] =
+ doubleAccumulator(initialValue, name)
+
/**
* Create an [[org.apache.spark.Accumulator]] variable of a given type, which tasks can "add"
* values to using the `add` method. Only the master can access the accumulator's `value`.
@@ -441,6 +585,16 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork
def accumulator[T](initialValue: T, accumulatorParam: AccumulatorParam[T]): Accumulator[T] =
sc.accumulator(initialValue)(accumulatorParam)
+ /**
+ * Create an [[org.apache.spark.Accumulator]] variable of a given type, which tasks can "add"
+ * values to using the `add` method. Only the master can access the accumulator's `value`.
+ *
+ * This version supports naming the accumulator for display in Spark's web UI.
+ */
+ def accumulator[T](initialValue: T, name: String, accumulatorParam: AccumulatorParam[T])
+ : Accumulator[T] =
+ sc.accumulator(initialValue, name)(accumulatorParam)
+
/**
* Create an [[org.apache.spark.Accumulable]] shared variable of the given type, to which tasks
* can "add" values with `add`. Only the master can access the accumuable's `value`.
@@ -448,6 +602,16 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork
def accumulable[T, R](initialValue: T, param: AccumulableParam[T, R]): Accumulable[T, R] =
sc.accumulable(initialValue)(param)
+ /**
+ * Create an [[org.apache.spark.Accumulable]] shared variable of the given type, to which tasks
+ * can "add" values with `add`. Only the master can access the accumuable's `value`.
+ *
+ * This version supports naming the accumulator for display in Spark's web UI.
+ */
+ def accumulable[T, R](initialValue: T, name: String, param: AccumulableParam[T, R])
+ : Accumulable[T, R] =
+ sc.accumulable(initialValue, name)(param)
+
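A short Scala sketch of the named variants (the local master and the accumulator name are arbitrary; the name is what the web UI shows on the stage page):
{{{
import org.apache.spark.SparkConf
import org.apache.spark.api.java.JavaSparkContext

val jsc = new JavaSparkContext(new SparkConf().setAppName("namedAccumulators").setMaster("local[2]"))
val sum = jsc.accumulator(0.0, "running sum")                   // Accumulator[java.lang.Double]
jsc.sc.parallelize(1 to 100).foreach(i => sum.add(i.toDouble))  // tasks may only add
println(sum.value)                                              // only the driver reads the value
jsc.stop()
}}}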
/**
* Broadcast a read-only variable to the cluster, returning a
* [[org.apache.spark.broadcast.Broadcast]] object for reading it in distributed functions.
@@ -460,6 +624,8 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork
sc.stop()
}
+ override def close(): Unit = stop()
+
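`close()` simply delegates to `stop()`, presumably so the context can be treated as a `java.io.Closeable` (and used with Java's try-with-resources). From Scala, a try/finally sketch gives the same effect:
{{{
import org.apache.spark.SparkConf
import org.apache.spark.api.java.JavaSparkContext

val jsc = new JavaSparkContext(new SparkConf().setAppName("closeSketch").setMaster("local[2]"))
try {
  println(jsc.parallelize(java.util.Arrays.asList(1, 2, 3)).count())
} finally {
  jsc.close()   // equivalent to jsc.stop()
}
}}}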
/**
* Get Spark's home location from either a value set through the constructor,
* or the spark.home Java property, or the SPARK_HOME environment variable
@@ -471,7 +637,7 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork
* Add a file to be downloaded with this Spark job on every node.
* The `path` passed can be either a local file, a file in HDFS (or other Hadoop-supported
* filesystems), or an HTTP, HTTPS or FTP URI. To access the file in Spark jobs,
- * use `SparkFiles.get(path)` to find its download location.
+ * use `SparkFiles.get(fileName)` to find its download location.
*/
def addFile(path: String) {
sc.addFile(path)
diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaSparkStatusTracker.scala b/core/src/main/scala/org/apache/spark/api/java/JavaSparkStatusTracker.scala
new file mode 100644
index 0000000000000..3300cad9efbab
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaSparkStatusTracker.scala
@@ -0,0 +1,72 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.api.java
+
+import org.apache.spark.{SparkStageInfo, SparkJobInfo, SparkContext}
+
+/**
+ * Low-level status reporting APIs for monitoring job and stage progress.
+ *
+ * These APIs intentionally provide very weak consistency semantics; consumers of these APIs should
+ * be prepared to handle empty / missing information. For example, a job's stage ids may be known
+ * but the status API may not have any information about the details of those stages, so
+ * `getStageInfo` could potentially return `null` for a valid stage id.
+ *
+ * To limit memory usage, these APIs only provide information on recent jobs / stages. These APIs
+ * will provide information for the last `spark.ui.retainedStages` stages and
+ * `spark.ui.retainedJobs` jobs.
+ *
+ * NOTE: this class's constructor should be considered private and may be subject to change.
+ */
+class JavaSparkStatusTracker private[spark] (sc: SparkContext) {
+
+ /**
+ * Return a list of all known jobs in a particular job group. If `jobGroup` is `null`, then
+ * returns all known jobs that are not associated with a job group.
+ *
+ * The returned list may contain running, failed, and completed jobs, and may vary across
+ * invocations of this method. This method does not guarantee the order of the elements in
+ * its result.
+ */
+ def getJobIdsForGroup(jobGroup: String): Array[Int] = sc.statusTracker.getJobIdsForGroup(jobGroup)
+
+ /**
+ * Returns an array containing the ids of all active stages.
+ *
+ * This method does not guarantee the order of the elements in its result.
+ */
+ def getActiveStageIds(): Array[Int] = sc.statusTracker.getActiveStageIds()
+
+ /**
+ * Returns an array containing the ids of all active jobs.
+ *
+ * This method does not guarantee the order of the elements in its result.
+ */
+ def getActiveJobIds(): Array[Int] = sc.statusTracker.getActiveJobIds()
+
+ /**
+ * Returns job information, or `null` if the job info could not be found or was garbage collected.
+ */
+ def getJobInfo(jobId: Int): SparkJobInfo = sc.statusTracker.getJobInfo(jobId).orNull
+
+ /**
+ * Returns stage information, or `null` if the stage info could not be found or was
+ * garbage collected.
+ */
+ def getStageInfo(stageId: Int): SparkStageInfo = sc.statusTracker.getStageInfo(stageId).orNull
+}
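A sketch of polling the tracker from a second thread while a slow job runs. It assumes a local master and a `statusTracker` accessor on `JavaSparkContext` that returns this class (that accessor is not part of this file):
{{{
import org.apache.spark.SparkConf
import org.apache.spark.api.java.JavaSparkContext

val jsc = new JavaSparkContext(new SparkConf().setAppName("statusSketch").setMaster("local[2]"))
val tracker = jsc.statusTracker                          // assumed accessor on JavaSparkContext
val poller = new Thread(new Runnable {
  override def run(): Unit = {
    Thread.sleep(500)
    for (jobId <- tracker.getActiveJobIds()) {
      val info = tracker.getJobInfo(jobId)               // may be null (weak consistency)
      if (info != null) println(s"job $jobId runs stages ${info.stageIds().mkString(",")}")
    }
  }
})
poller.setDaemon(true)
poller.start()
jsc.sc.parallelize(1 to 1000, 10).map { i => Thread.sleep(5); i }.count()
jsc.stop()
}}}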
diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaUtils.scala b/core/src/main/scala/org/apache/spark/api/java/JavaUtils.scala
index 22810cb1c662d..b52d0a5028e84 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaUtils.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaUtils.scala
@@ -19,10 +19,20 @@ package org.apache.spark.api.java
import com.google.common.base.Optional
+import scala.collection.convert.Wrappers.MapWrapper
+
private[spark] object JavaUtils {
def optionToOptional[T](option: Option[T]): Optional[T] =
option match {
case Some(value) => Optional.of(value)
case None => Optional.absent()
}
+
+ // Workaround for SPARK-3926 / SI-8911
+ def mapAsSerializableJavaMap[A, B](underlying: collection.Map[A, B]) =
+ new SerializableMapWrapper(underlying)
+
+ class SerializableMapWrapper[A, B](underlying: collection.Map[A, B])
+ extends MapWrapper(underlying) with java.io.Serializable
+
}
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala b/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala
index adaa1ef6cf9ff..5ba66178e2b78 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala
@@ -17,8 +17,10 @@
package org.apache.spark.api.python
+import org.apache.spark.broadcast.Broadcast
import org.apache.spark.rdd.RDD
-import org.apache.spark.Logging
+import org.apache.spark.util.Utils
+import org.apache.spark.{Logging, SerializableWritable, SparkException}
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.io._
import scala.util.{Failure, Success, Try}
@@ -31,16 +33,17 @@ import org.apache.spark.annotation.Experimental
* transformation code by overriding the convert method.
*/
@Experimental
-trait Converter[T, U] extends Serializable {
+trait Converter[T, + U] extends Serializable {
def convert(obj: T): U
}
private[python] object Converter extends Logging {
- def getInstance(converterClass: Option[String]): Converter[Any, Any] = {
+ def getInstance(converterClass: Option[String],
+ defaultConverter: Converter[Any, Any]): Converter[Any, Any] = {
converterClass.map { cc =>
Try {
- val c = Class.forName(cc).newInstance().asInstanceOf[Converter[Any, Any]]
+ val c = Utils.classForName(cc).newInstance().asInstanceOf[Converter[Any, Any]]
logInfo(s"Loaded converter: $cc")
c
} match {
@@ -49,7 +52,7 @@ private[python] object Converter extends Logging {
logError(s"Failed to load converter: $cc")
throw err
}
- }.getOrElse { new DefaultConverter }
+ }.getOrElse { defaultConverter }
}
}
@@ -57,7 +60,8 @@ private[python] object Converter extends Logging {
* A converter that handles conversion of common [[org.apache.hadoop.io.Writable]] objects.
* Other objects are passed through without conversion.
*/
-private[python] class DefaultConverter extends Converter[Any, Any] {
+private[python] class WritableToJavaConverter(
+ conf: Broadcast[SerializableWritable[Configuration]]) extends Converter[Any, Any] {
/**
* Converts a [[org.apache.hadoop.io.Writable]] to the underlying primitive, String or
@@ -72,17 +76,29 @@ private[python] class DefaultConverter extends Converter[Any, Any] {
case fw: FloatWritable => fw.get()
case t: Text => t.toString
case bw: BooleanWritable => bw.get()
- case byw: BytesWritable => byw.getBytes
+ case byw: BytesWritable =>
+ val bytes = new Array[Byte](byw.getLength)
+ System.arraycopy(byw.getBytes(), 0, bytes, 0, byw.getLength)
+ bytes
case n: NullWritable => null
- case aw: ArrayWritable => aw.get().map(convertWritable(_))
- case mw: MapWritable => mapAsJavaMap(mw.map { case (k, v) =>
- (convertWritable(k), convertWritable(v))
- }.toMap)
+ case aw: ArrayWritable =>
+ // Due to erasure, all arrays appear as Object[] and they get pickled to Python tuples.
+ // Since we can't determine element types for empty arrays, we will not attempt to
+ // convert to primitive arrays (which get pickled to Python arrays). Users may want to
+ // write custom converters for arrays if they know the element types a priori.
+ aw.get().map(convertWritable(_))
+ case mw: MapWritable =>
+ val map = new java.util.HashMap[Any, Any]()
+ mw.foreach { case (k, v) =>
+ map.put(convertWritable(k), convertWritable(v))
+ }
+ map
+ case w: Writable => WritableUtils.clone(w, conf.value.value)
case other => other
}
}
- def convert(obj: Any): Any = {
+ override def convert(obj: Any): Any = {
obj match {
case writable: Writable =>
convertWritable(writable)
@@ -92,6 +108,47 @@ private[python] class DefaultConverter extends Converter[Any, Any] {
}
}
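The defensive copy in the `BytesWritable` case above exists because `getBytes()` exposes the backing buffer, which is typically longer than the logical length and may be reused by the record reader. A small sketch of the behavior it guards against:
{{{
import org.apache.hadoop.io.BytesWritable

val bw = new BytesWritable()
bw.set(Array[Byte](1, 2, 3), 0, 3)
println(bw.getBytes.length)   // capacity of the backing buffer, often larger than 3
println(bw.getLength)         // 3, the logical length the copy above respects
}}}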
+/**
+ * A converter that converts common types to [[org.apache.hadoop.io.Writable]]. Note that array
+ * types are not supported since the user needs to subclass [[org.apache.hadoop.io.ArrayWritable]]
+ * to set the type properly. See [[org.apache.spark.api.python.DoubleArrayWritable]] and
+ * [[org.apache.spark.api.python.DoubleArrayToWritableConverter]] for an example. They are used in
+ * PySpark RDD `saveAsNewAPIHadoopFile` doctest.
+ */
+private[python] class JavaToWritableConverter extends Converter[Any, Writable] {
+
+ /**
+ * Converts common data types to [[org.apache.hadoop.io.Writable]]. Note that array types are not
+ * supported out-of-the-box.
+ */
+ private def convertToWritable(obj: Any): Writable = {
+ import collection.JavaConversions._
+ obj match {
+ case i: java.lang.Integer => new IntWritable(i)
+ case d: java.lang.Double => new DoubleWritable(d)
+ case l: java.lang.Long => new LongWritable(l)
+ case f: java.lang.Float => new FloatWritable(f)
+ case s: java.lang.String => new Text(s)
+ case b: java.lang.Boolean => new BooleanWritable(b)
+ case aob: Array[Byte] => new BytesWritable(aob)
+ case null => NullWritable.get()
+ case map: java.util.Map[_, _] =>
+ val mapWritable = new MapWritable()
+ map.foreach { case (k, v) =>
+ mapWritable.put(convertToWritable(k), convertToWritable(v))
+ }
+ mapWritable
+ case other => throw new SparkException(
+ s"Data of type ${other.getClass.getName} cannot be used")
+ }
+ }
+
+ override def convert(obj: Any): Writable = obj match {
+ case writable: Writable => writable
+ case other => convertToWritable(other)
+ }
+}
+
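A sketch of the converter's behavior; since it is `private[python]`, the sketch assumes it runs from within the `org.apache.spark.api.python` package (types it does not recognize raise a `SparkException`):
{{{
package org.apache.spark.api.python

import org.apache.hadoop.io.MapWritable

object JavaToWritableConverterSketch {
  def main(args: Array[String]): Unit = {
    val converter = new JavaToWritableConverter
    println(converter.convert("spark"))       // Text("spark")
    println(converter.convert(Int.box(42)))   // IntWritable(42)
    val m = new java.util.HashMap[Any, Any]()
    m.put("answer", Int.box(42))
    println(converter.convert(m).isInstanceOf[MapWritable])   // true: nested values converted too
  }
}
}}}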
/** Utilities for working with Python objects <-> Hadoop-related objects */
private[python] object PythonHadoopUtil {
@@ -118,7 +175,7 @@ private[python] object PythonHadoopUtil {
/**
* Converts an RDD of key-value pairs, where key and/or value could be instances of
- * [[org.apache.hadoop.io.Writable]], into an RDD[(K, V)]
+ * [[org.apache.hadoop.io.Writable]], into an RDD of base types, or vice versa.
*/
def convertRDD[K, V](rdd: RDD[(K, V)],
keyConverter: Converter[Any, Any],
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index 462e09466bfa6..45beb8fc8c925 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -19,26 +19,29 @@ package org.apache.spark.api.python
import java.io._
import java.net._
-import java.nio.charset.Charset
import java.util.{List => JList, ArrayList => JArrayList, Map => JMap, Collections}
+import org.apache.spark.input.PortableDataStream
+
import scala.collection.JavaConversions._
-import scala.reflect.ClassTag
-import scala.util.Try
+import scala.collection.mutable
+import scala.language.existentials
-import net.razorvine.pickle.{Pickler, Unpickler}
+import com.google.common.base.Charsets.UTF_8
import org.apache.hadoop.conf.Configuration
-import org.apache.hadoop.mapred.{InputFormat, JobConf}
-import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat}
+import org.apache.hadoop.io.compress.CompressionCodec
+import org.apache.hadoop.mapred.{InputFormat, OutputFormat, JobConf}
+import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat, OutputFormat => NewOutputFormat}
import org.apache.spark._
+import org.apache.spark.SparkContext._
import org.apache.spark.api.java.{JavaSparkContext, JavaPairRDD, JavaRDD}
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.rdd.RDD
import org.apache.spark.util.Utils
-private[spark] class PythonRDD[T: ClassTag](
- parent: RDD[T],
+private[spark] class PythonRDD(
+ @transient parent: RDD[_],
command: Array[Byte],
envVars: JMap[String, String],
pythonIncludes: JList[String],
@@ -49,27 +52,39 @@ private[spark] class PythonRDD[T: ClassTag](
extends RDD[Array[Byte]](parent) {
val bufferSize = conf.getInt("spark.buffer.size", 65536)
+ val reuse_worker = conf.getBoolean("spark.python.worker.reuse", true)
- override def getPartitions = parent.partitions
+ override def getPartitions = firstParent.partitions
- override val partitioner = if (preservePartitoning) parent.partitioner else None
+ override val partitioner = if (preservePartitoning) firstParent.partitioner else None
override def compute(split: Partition, context: TaskContext): Iterator[Array[Byte]] = {
val startTime = System.currentTimeMillis
val env = SparkEnv.get
+ val localdir = env.blockManager.diskBlockManager.localDirs.map(
+ f => f.getPath()).mkString(",")
+ envVars += ("SPARK_LOCAL_DIRS" -> localdir) // it's also used in monitor thread
+ if (reuse_worker) {
+ envVars += ("SPARK_REUSE_WORKER" -> "1")
+ }
val worker: Socket = env.createPythonWorker(pythonExec, envVars.toMap)
// Start a thread to feed the process input from our parent's iterator
val writerThread = new WriterThread(env, worker, split, context)
- context.addOnCompleteCallback { () =>
+ var complete_cleanly = false
+ context.addTaskCompletionListener { context =>
writerThread.shutdownOnTaskCompletion()
-
- // Cleanup the worker socket. This will also cause the Python worker to exit.
- try {
- worker.close()
- } catch {
- case e: Exception => logWarning("Failed to close worker socket", e)
+ writerThread.join()
+ if (reuse_worker && complete_cleanly) {
+ env.releasePythonWorker(pythonExec, envVars.toMap, worker)
+ } else {
+ try {
+ worker.close()
+ } catch {
+ case e: Exception =>
+ logWarning("Failed to close worker socket", e)
+ }
}
}
@@ -109,13 +124,17 @@ private[spark] class PythonRDD[T: ClassTag](
val total = finishTime - startTime
logInfo("Times: total = %s, boot = %s, init = %s, finish = %s".format(total, boot,
init, finish))
+ val memoryBytesSpilled = stream.readLong()
+ val diskBytesSpilled = stream.readLong()
+ context.taskMetrics.memoryBytesSpilled += memoryBytesSpilled
+ context.taskMetrics.diskBytesSpilled += diskBytesSpilled
read()
case SpecialLengths.PYTHON_EXCEPTION_THROWN =>
// Signals that an exception has been thrown in python
val exLength = stream.readInt()
val obj = new Array[Byte](exLength)
stream.readFully(obj)
- throw new PythonException(new String(obj, "utf-8"),
+ throw new PythonException(new String(obj, UTF_8),
writerThread.exception.getOrElse(null))
case SpecialLengths.END_OF_DATA_SECTION =>
// We've finished the data section of the output, but we can still
@@ -127,14 +146,21 @@ private[spark] class PythonRDD[T: ClassTag](
stream.readFully(update)
accumulator += Collections.singletonList(update)
}
+ if (stream.readInt() == SpecialLengths.END_OF_STREAM) {
+ complete_cleanly = true
+ }
null
}
} catch {
- case e: Exception if context.interrupted =>
+ case e: Exception if context.isInterrupted =>
logDebug("Exception thrown after task interruption", e)
throw new TaskKilledException
+ case e: Exception if env.isStopped =>
+ logDebug("Exception thrown after context is stopped", e)
+ null // exit silently
+
case e: Exception if writerThread.exception.isDefined =>
logError("Python worker exited unexpectedly (crashed)", e)
logError("This may have been caused by a prior exception:", writerThread.exception.get)
@@ -170,13 +196,12 @@ private[spark] class PythonRDD[T: ClassTag](
/** Terminates the writer thread, ignoring any exceptions that may occur due to cleanup. */
def shutdownOnTaskCompletion() {
- assert(context.completed)
+ assert(context.isCompleted)
this.interrupt()
}
override def run(): Unit = Utils.logUncaughtExceptions {
try {
- SparkEnv.set(env)
val stream = new BufferedOutputStream(worker.getOutputStream, bufferSize)
val dataOut = new DataOutputStream(stream)
// Partition index
@@ -189,29 +214,51 @@ private[spark] class PythonRDD[T: ClassTag](
PythonRDD.writeUTF(include, dataOut)
}
// Broadcast variables
- dataOut.writeInt(broadcastVars.length)
+ val oldBids = PythonRDD.getWorkerBroadcasts(worker)
+ val newBids = broadcastVars.map(_.id).toSet
+ // number of different broadcasts
+ val cnt = oldBids.diff(newBids).size + newBids.diff(oldBids).size
+ dataOut.writeInt(cnt)
+ for (bid <- oldBids) {
+ if (!newBids.contains(bid)) {
+ // remove the broadcast from worker
+ dataOut.writeLong(- bid - 1) // bid >= 0
+ oldBids.remove(bid)
+ }
+ }
for (broadcast <- broadcastVars) {
- dataOut.writeLong(broadcast.id)
- dataOut.writeInt(broadcast.value.length)
- dataOut.write(broadcast.value)
+ if (!oldBids.contains(broadcast.id)) {
+ // send new broadcast
+ dataOut.writeLong(broadcast.id)
+ dataOut.writeInt(broadcast.value.length)
+ dataOut.write(broadcast.value)
+ oldBids.add(broadcast.id)
+ }
}
dataOut.flush()
// Serialized command:
dataOut.writeInt(command.length)
dataOut.write(command)
// Data values
- PythonRDD.writeIteratorToStream(parent.iterator(split, context), dataOut)
+ PythonRDD.writeIteratorToStream(firstParent.iterator(split, context), dataOut)
+ dataOut.writeInt(SpecialLengths.END_OF_DATA_SECTION)
+ dataOut.writeInt(SpecialLengths.END_OF_STREAM)
dataOut.flush()
} catch {
- case e: Exception if context.completed || context.interrupted =>
+ case e: Exception if context.isCompleted || context.isInterrupted =>
logDebug("Exception thrown after task completion (likely due to cleanup)", e)
+ worker.shutdownOutput()
case e: Exception =>
// We must avoid throwing exceptions here, because the thread uncaught exception handler
// will kill the whole executor (see org.apache.spark.executor.Executor).
_exception = e
+ worker.shutdownOutput()
} finally {
- Try(worker.shutdownOutput()) // kill Python worker process
+ // Release memory used by this thread for shuffles
+ env.shuffleMemoryManager.releaseMemoryForThisThread()
+ // Release memory used by this thread for unrolling blocks
+ env.blockManager.memoryStore.releaseUnrollMemoryForThisThread()
}
}
}
@@ -229,13 +276,13 @@ private[spark] class PythonRDD[T: ClassTag](
override def run() {
// Kill the worker if it is interrupted, checking until task completion.
// TODO: This has a race condition if interruption occurs, as completed may still become true.
- while (!context.interrupted && !context.completed) {
+ while (!context.isInterrupted && !context.isCompleted) {
Thread.sleep(2000)
}
- if (!context.completed) {
+ if (!context.isCompleted) {
try {
logWarning("Incomplete task interrupted: Attempting to kill Python Worker")
- env.destroyPythonWorker(pythonExec, envVars.toMap)
+ env.destroyPythonWorker(pythonExec, envVars.toMap, worker)
} catch {
case e: Exception =>
logError("Exception when trying to kill worker", e)
@@ -267,10 +314,18 @@ private object SpecialLengths {
val END_OF_DATA_SECTION = -1
val PYTHON_EXCEPTION_THROWN = -2
val TIMING_DATA = -3
+ val END_OF_STREAM = -4
}
private[spark] object PythonRDD extends Logging {
- val UTF8 = Charset.forName("UTF-8")
+
+ // remember the broadcasts sent to each worker
+ private val workerBroadcasts = new mutable.WeakHashMap[Socket, mutable.Set[Long]]()
+ private def getWorkerBroadcasts(worker: Socket) = {
+ synchronized {
+ workerBroadcasts.getOrElseUpdate(worker, new mutable.HashSet[Long]())
+ }
+ }
/**
* Adapter for calling SparkContext#runJob from Python.
@@ -295,18 +350,34 @@ private[spark] object PythonRDD extends Logging {
def readRDDFromFile(sc: JavaSparkContext, filename: String, parallelism: Int):
JavaRDD[Array[Byte]] = {
val file = new DataInputStream(new FileInputStream(filename))
- val objs = new collection.mutable.ArrayBuffer[Array[Byte]]
try {
- while (true) {
- val length = file.readInt()
- val obj = new Array[Byte](length)
- file.readFully(obj)
- objs.append(obj)
+ val objs = new collection.mutable.ArrayBuffer[Array[Byte]]
+ try {
+ while (true) {
+ val length = file.readInt()
+ val obj = new Array[Byte](length)
+ file.readFully(obj)
+ objs.append(obj)
+ }
+ } catch {
+ case eof: EOFException => {}
}
- } catch {
- case eof: EOFException => {}
+ JavaRDD.fromRDD(sc.sc.parallelize(objs, parallelism))
+ } finally {
+ file.close()
+ }
+ }
+
+ def readBroadcastFromFile(sc: JavaSparkContext, filename: String): Broadcast[Array[Byte]] = {
+ val file = new DataInputStream(new FileInputStream(filename))
+ try {
+ val length = file.readInt()
+ val obj = new Array[Byte](length)
+ file.readFully(obj)
+ sc.broadcast(obj)
+ } finally {
+ file.close()
}
- JavaRDD.fromRDD(sc.sc.parallelize(objs, parallelism))
}
def writeIteratorToStream[T](iter: Iterator[T], dataOut: DataOutputStream) {
@@ -326,22 +397,33 @@ private[spark] object PythonRDD extends Logging {
newIter.asInstanceOf[Iterator[String]].foreach { str =>
writeUTF(str, dataOut)
}
- case pair: Tuple2[_, _] =>
- pair._1 match {
- case bytePair: Array[Byte] =>
- newIter.asInstanceOf[Iterator[Tuple2[Array[Byte], Array[Byte]]]].foreach { pair =>
- dataOut.writeInt(pair._1.length)
- dataOut.write(pair._1)
- dataOut.writeInt(pair._2.length)
- dataOut.write(pair._2)
- }
- case stringPair: String =>
- newIter.asInstanceOf[Iterator[Tuple2[String, String]]].foreach { pair =>
- writeUTF(pair._1, dataOut)
- writeUTF(pair._2, dataOut)
- }
- case other =>
- throw new SparkException("Unexpected Tuple2 element type " + pair._1.getClass)
+ case stream: PortableDataStream =>
+ newIter.asInstanceOf[Iterator[PortableDataStream]].foreach { stream =>
+ val bytes = stream.toArray()
+ dataOut.writeInt(bytes.length)
+ dataOut.write(bytes)
+ }
+ case (key: String, stream: PortableDataStream) =>
+ newIter.asInstanceOf[Iterator[(String, PortableDataStream)]].foreach {
+ case (key, stream) =>
+ writeUTF(key, dataOut)
+ val bytes = stream.toArray()
+ dataOut.writeInt(bytes.length)
+ dataOut.write(bytes)
+ }
+ case (key: String, value: String) =>
+ newIter.asInstanceOf[Iterator[(String, String)]].foreach {
+ case (key, value) =>
+ writeUTF(key, dataOut)
+ writeUTF(value, dataOut)
+ }
+ case (key: Array[Byte], value: Array[Byte]) =>
+ newIter.asInstanceOf[Iterator[(Array[Byte], Array[Byte])]].foreach {
+ case (key, value) =>
+ dataOut.writeInt(key.length)
+ dataOut.write(key)
+ dataOut.writeInt(value.length)
+ dataOut.write(value)
}
case other =>
throw new SparkException("Unexpected element type " + first.getClass)
@@ -362,19 +444,17 @@ private[spark] object PythonRDD extends Logging {
valueClassMaybeNull: String,
keyConverterClass: String,
valueConverterClass: String,
- minSplits: Int) = {
+ minSplits: Int,
+ batchSize: Int) = {
val keyClass = Option(keyClassMaybeNull).getOrElse("org.apache.hadoop.io.Text")
val valueClass = Option(valueClassMaybeNull).getOrElse("org.apache.hadoop.io.Text")
- implicit val kcm = ClassTag(Class.forName(keyClass)).asInstanceOf[ClassTag[K]]
- implicit val vcm = ClassTag(Class.forName(valueClass)).asInstanceOf[ClassTag[V]]
- val kc = kcm.runtimeClass.asInstanceOf[Class[K]]
- val vc = vcm.runtimeClass.asInstanceOf[Class[V]]
-
+ val kc = Utils.classForName(keyClass).asInstanceOf[Class[K]]
+ val vc = Utils.classForName(valueClass).asInstanceOf[Class[V]]
val rdd = sc.sc.sequenceFile[K, V](path, kc, vc, minSplits)
- val keyConverter = Converter.getInstance(Option(keyConverterClass))
- val valueConverter = Converter.getInstance(Option(valueConverterClass))
- val converted = PythonHadoopUtil.convertRDD[K, V](rdd, keyConverter, valueConverter)
- JavaRDD.fromRDD(SerDeUtil.rddToPython(converted))
+ val confBroadcasted = sc.sc.broadcast(new SerializableWritable(sc.hadoopConfiguration()))
+ val converted = convertRDD(rdd, keyConverterClass, valueConverterClass,
+ new WritableToJavaConverter(confBroadcasted))
+ JavaRDD.fromRDD(SerDeUtil.pairRDDToPython(converted, batchSize))
}
/**
@@ -391,17 +471,16 @@ private[spark] object PythonRDD extends Logging {
valueClass: String,
keyConverterClass: String,
valueConverterClass: String,
- confAsMap: java.util.HashMap[String, String]) = {
- val conf = PythonHadoopUtil.mapToConf(confAsMap)
- val baseConf = sc.hadoopConfiguration()
- val mergedConf = PythonHadoopUtil.mergeConfs(baseConf, conf)
+ confAsMap: java.util.HashMap[String, String],
+ batchSize: Int) = {
+ val mergedConf = getMergedConf(confAsMap, sc.hadoopConfiguration())
val rdd =
newAPIHadoopRDDFromClassNames[K, V, F](sc,
Some(path), inputFormatClass, keyClass, valueClass, mergedConf)
- val keyConverter = Converter.getInstance(Option(keyConverterClass))
- val valueConverter = Converter.getInstance(Option(valueConverterClass))
- val converted = PythonHadoopUtil.convertRDD[K, V](rdd, keyConverter, valueConverter)
- JavaRDD.fromRDD(SerDeUtil.rddToPython(converted))
+ val confBroadcasted = sc.sc.broadcast(new SerializableWritable(mergedConf))
+ val converted = convertRDD(rdd, keyConverterClass, valueConverterClass,
+ new WritableToJavaConverter(confBroadcasted))
+ JavaRDD.fromRDD(SerDeUtil.pairRDDToPython(converted, batchSize))
}
/**
@@ -418,15 +497,16 @@ private[spark] object PythonRDD extends Logging {
valueClass: String,
keyConverterClass: String,
valueConverterClass: String,
- confAsMap: java.util.HashMap[String, String]) = {
+ confAsMap: java.util.HashMap[String, String],
+ batchSize: Int) = {
val conf = PythonHadoopUtil.mapToConf(confAsMap)
val rdd =
newAPIHadoopRDDFromClassNames[K, V, F](sc,
None, inputFormatClass, keyClass, valueClass, conf)
- val keyConverter = Converter.getInstance(Option(keyConverterClass))
- val valueConverter = Converter.getInstance(Option(valueConverterClass))
- val converted = PythonHadoopUtil.convertRDD[K, V](rdd, keyConverter, valueConverter)
- JavaRDD.fromRDD(SerDeUtil.rddToPython(converted))
+ val confBroadcasted = sc.sc.broadcast(new SerializableWritable(conf))
+ val converted = convertRDD(rdd, keyConverterClass, valueConverterClass,
+ new WritableToJavaConverter(confBroadcasted))
+ JavaRDD.fromRDD(SerDeUtil.pairRDDToPython(converted, batchSize))
}
private def newAPIHadoopRDDFromClassNames[K, V, F <: NewInputFormat[K, V]](
@@ -436,18 +516,14 @@ private[spark] object PythonRDD extends Logging {
keyClass: String,
valueClass: String,
conf: Configuration) = {
- implicit val kcm = ClassTag(Class.forName(keyClass)).asInstanceOf[ClassTag[K]]
- implicit val vcm = ClassTag(Class.forName(valueClass)).asInstanceOf[ClassTag[V]]
- implicit val fcm = ClassTag(Class.forName(inputFormatClass)).asInstanceOf[ClassTag[F]]
- val kc = kcm.runtimeClass.asInstanceOf[Class[K]]
- val vc = vcm.runtimeClass.asInstanceOf[Class[V]]
- val fc = fcm.runtimeClass.asInstanceOf[Class[F]]
- val rdd = if (path.isDefined) {
+ val kc = Utils.classForName(keyClass).asInstanceOf[Class[K]]
+ val vc = Utils.classForName(valueClass).asInstanceOf[Class[V]]
+ val fc = Utils.classForName(inputFormatClass).asInstanceOf[Class[F]]
+ if (path.isDefined) {
sc.sc.newAPIHadoopFile[K, V, F](path.get, fc, kc, vc, conf)
} else {
sc.sc.newAPIHadoopRDD[K, V, F](conf, fc, kc, vc)
}
- rdd
}
/**
@@ -464,17 +540,16 @@ private[spark] object PythonRDD extends Logging {
valueClass: String,
keyConverterClass: String,
valueConverterClass: String,
- confAsMap: java.util.HashMap[String, String]) = {
- val conf = PythonHadoopUtil.mapToConf(confAsMap)
- val baseConf = sc.hadoopConfiguration()
- val mergedConf = PythonHadoopUtil.mergeConfs(baseConf, conf)
+ confAsMap: java.util.HashMap[String, String],
+ batchSize: Int) = {
+ val mergedConf = getMergedConf(confAsMap, sc.hadoopConfiguration())
val rdd =
hadoopRDDFromClassNames[K, V, F](sc,
Some(path), inputFormatClass, keyClass, valueClass, mergedConf)
- val keyConverter = Converter.getInstance(Option(keyConverterClass))
- val valueConverter = Converter.getInstance(Option(valueConverterClass))
- val converted = PythonHadoopUtil.convertRDD[K, V](rdd, keyConverter, valueConverter)
- JavaRDD.fromRDD(SerDeUtil.rddToPython(converted))
+ val confBroadcasted = sc.sc.broadcast(new SerializableWritable(mergedConf))
+ val converted = convertRDD(rdd, keyConverterClass, valueConverterClass,
+ new WritableToJavaConverter(confBroadcasted))
+ JavaRDD.fromRDD(SerDeUtil.pairRDDToPython(converted, batchSize))
}
/**
@@ -491,15 +566,16 @@ private[spark] object PythonRDD extends Logging {
valueClass: String,
keyConverterClass: String,
valueConverterClass: String,
- confAsMap: java.util.HashMap[String, String]) = {
+ confAsMap: java.util.HashMap[String, String],
+ batchSize: Int) = {
val conf = PythonHadoopUtil.mapToConf(confAsMap)
val rdd =
hadoopRDDFromClassNames[K, V, F](sc,
None, inputFormatClass, keyClass, valueClass, conf)
- val keyConverter = Converter.getInstance(Option(keyConverterClass))
- val valueConverter = Converter.getInstance(Option(valueConverterClass))
- val converted = PythonHadoopUtil.convertRDD[K, V](rdd, keyConverter, valueConverter)
- JavaRDD.fromRDD(SerDeUtil.rddToPython(converted))
+ val confBroadcasted = sc.sc.broadcast(new SerializableWritable(conf))
+ val converted = convertRDD(rdd, keyConverterClass, valueConverterClass,
+ new WritableToJavaConverter(confBroadcasted))
+ JavaRDD.fromRDD(SerDeUtil.pairRDDToPython(converted, batchSize))
}
private def hadoopRDDFromClassNames[K, V, F <: InputFormat[K, V]](
@@ -509,22 +585,18 @@ private[spark] object PythonRDD extends Logging {
keyClass: String,
valueClass: String,
conf: Configuration) = {
- implicit val kcm = ClassTag(Class.forName(keyClass)).asInstanceOf[ClassTag[K]]
- implicit val vcm = ClassTag(Class.forName(valueClass)).asInstanceOf[ClassTag[V]]
- implicit val fcm = ClassTag(Class.forName(inputFormatClass)).asInstanceOf[ClassTag[F]]
- val kc = kcm.runtimeClass.asInstanceOf[Class[K]]
- val vc = vcm.runtimeClass.asInstanceOf[Class[V]]
- val fc = fcm.runtimeClass.asInstanceOf[Class[F]]
- val rdd = if (path.isDefined) {
+ val kc = Utils.classForName(keyClass).asInstanceOf[Class[K]]
+ val vc = Utils.classForName(valueClass).asInstanceOf[Class[V]]
+ val fc = Utils.classForName(inputFormatClass).asInstanceOf[Class[F]]
+ if (path.isDefined) {
sc.sc.hadoopFile(path.get, fc, kc, vc)
} else {
sc.sc.hadoopRDD(new JobConf(conf), fc, kc, vc)
}
- rdd
}
def writeUTF(str: String, dataOut: DataOutputStream) {
- val bytes = str.getBytes(UTF8)
+ val bytes = str.getBytes(UTF_8)
dataOut.writeInt(bytes.length)
dataOut.write(bytes)
}
@@ -540,41 +612,156 @@ private[spark] object PythonRDD extends Logging {
file.close()
}
+ private def getMergedConf(confAsMap: java.util.HashMap[String, String],
+ baseConf: Configuration): Configuration = {
+ val conf = PythonHadoopUtil.mapToConf(confAsMap)
+ PythonHadoopUtil.mergeConfs(baseConf, conf)
+ }
+
+ private def inferKeyValueTypes[K, V](rdd: RDD[(K, V)], keyConverterClass: String = null,
+ valueConverterClass: String = null): (Class[_], Class[_]) = {
+ // Peek at an element to figure out key/value types. Since Writables are not serializable,
+ // we cannot call first() on the converted RDD. Instead, we call first() on the original RDD
+ // and then convert locally.
+ val (key, value) = rdd.first()
+ val (kc, vc) = getKeyValueConverters(keyConverterClass, valueConverterClass,
+ new JavaToWritableConverter)
+ (kc.convert(key).getClass, vc.convert(value).getClass)
+ }
+
+ private def getKeyValueTypes(keyClass: String, valueClass: String):
+ Option[(Class[_], Class[_])] = {
+ for {
+ k <- Option(keyClass)
+ v <- Option(valueClass)
+ } yield (Utils.classForName(k), Utils.classForName(v))
+ }
+
+ private def getKeyValueConverters(keyConverterClass: String, valueConverterClass: String,
+ defaultConverter: Converter[Any, Any]): (Converter[Any, Any], Converter[Any, Any]) = {
+ val keyConverter = Converter.getInstance(Option(keyConverterClass), defaultConverter)
+ val valueConverter = Converter.getInstance(Option(valueConverterClass), defaultConverter)
+ (keyConverter, valueConverter)
+ }
+
/**
- * Convert an RDD of serialized Python dictionaries to Scala Maps
- * TODO: Support more Python types.
+ * Convert an RDD of key-value pairs from internal types to serializable types suitable for
+ * output, or vice versa.
*/
- def pythonToJavaMap(pyRDD: JavaRDD[Array[Byte]]): JavaRDD[Map[String, _]] = {
- pyRDD.rdd.mapPartitions { iter =>
- val unpickle = new Unpickler
- // TODO: Figure out why flatMap is necessay for pyspark
- iter.flatMap { row =>
- unpickle.loads(row) match {
- case objs: java.util.ArrayList[JMap[String, _] @unchecked] => objs.map(_.toMap)
- // Incase the partition doesn't have a collection
- case obj: JMap[String @unchecked, _] => Seq(obj.toMap)
- }
- }
- }
+ private def convertRDD[K, V](rdd: RDD[(K, V)],
+ keyConverterClass: String,
+ valueConverterClass: String,
+ defaultConverter: Converter[Any, Any]): RDD[(Any, Any)] = {
+ val (kc, vc) = getKeyValueConverters(keyConverterClass, valueConverterClass,
+ defaultConverter)
+ PythonHadoopUtil.convertRDD(rdd, kc, vc)
}
/**
- * Convert and RDD of Java objects to and RDD of serialized Python objects, that is usable by
- * PySpark.
+ * Output a Python RDD of key-value pairs as a Hadoop SequenceFile using the Writable types
+ * we convert from the RDD's key and value types. Note that keys and values can't be
+ * [[org.apache.hadoop.io.Writable]] types already, since Writables are not Java
+ * `Serializable` and we can't peek at them. The `path` can be on any Hadoop file system.
*/
- def javaToPython(jRDD: JavaRDD[Any]): JavaRDD[Array[Byte]] = {
- jRDD.rdd.mapPartitions { iter =>
- val pickle = new Pickler
- iter.map { row =>
- pickle.dumps(row)
- }
+ def saveAsSequenceFile[K, V, C <: CompressionCodec](
+ pyRDD: JavaRDD[Array[Byte]],
+ batchSerialized: Boolean,
+ path: String,
+ compressionCodecClass: String) = {
+ saveAsHadoopFile(
+ pyRDD, batchSerialized, path, "org.apache.hadoop.mapred.SequenceFileOutputFormat",
+ null, null, null, null, new java.util.HashMap(), compressionCodecClass)
+ }
+
+ /**
+ * Output a Python RDD of key-value pairs to any Hadoop file system, using old Hadoop
+ * `OutputFormat` in mapred package. Keys and values are converted to suitable output
+ * types using either user specified converters or, if not specified,
+ * [[org.apache.spark.api.python.JavaToWritableConverter]]. Post-conversion types
+ * `keyClass` and `valueClass` are automatically inferred if not specified. The passed-in
+ * `confAsMap` is merged with the default Hadoop conf associated with the SparkContext of
+ * this RDD.
+ */
+ def saveAsHadoopFile[K, V, F <: OutputFormat[_, _], C <: CompressionCodec](
+ pyRDD: JavaRDD[Array[Byte]],
+ batchSerialized: Boolean,
+ path: String,
+ outputFormatClass: String,
+ keyClass: String,
+ valueClass: String,
+ keyConverterClass: String,
+ valueConverterClass: String,
+ confAsMap: java.util.HashMap[String, String],
+ compressionCodecClass: String) = {
+ val rdd = SerDeUtil.pythonToPairRDD(pyRDD, batchSerialized)
+ val (kc, vc) = getKeyValueTypes(keyClass, valueClass).getOrElse(
+ inferKeyValueTypes(rdd, keyConverterClass, valueConverterClass))
+ val mergedConf = getMergedConf(confAsMap, pyRDD.context.hadoopConfiguration)
+ val codec = Option(compressionCodecClass).map(Utils.classForName(_).asInstanceOf[Class[C]])
+ val converted = convertRDD(rdd, keyConverterClass, valueConverterClass,
+ new JavaToWritableConverter)
+ val fc = Utils.classForName(outputFormatClass).asInstanceOf[Class[F]]
+ converted.saveAsHadoopFile(path, kc, vc, fc, new JobConf(mergedConf), codec=codec)
+ }
+
+ /**
+ * Output a Python RDD of key-value pairs to any Hadoop file system, using new Hadoop
+ * `OutputFormat` in mapreduce package. Keys and values are converted to suitable output
+ * types using either user specified converters or, if not specified,
+ * [[org.apache.spark.api.python.JavaToWritableConverter]]. Post-conversion types
+ * `keyClass` and `valueClass` are automatically inferred if not specified. The passed-in
+ * `confAsMap` is merged with the default Hadoop conf associated with the SparkContext of
+ * this RDD.
+ */
+ def saveAsNewAPIHadoopFile[K, V, F <: NewOutputFormat[_, _]](
+ pyRDD: JavaRDD[Array[Byte]],
+ batchSerialized: Boolean,
+ path: String,
+ outputFormatClass: String,
+ keyClass: String,
+ valueClass: String,
+ keyConverterClass: String,
+ valueConverterClass: String,
+ confAsMap: java.util.HashMap[String, String]) = {
+ val rdd = SerDeUtil.pythonToPairRDD(pyRDD, batchSerialized)
+ val (kc, vc) = getKeyValueTypes(keyClass, valueClass).getOrElse(
+ inferKeyValueTypes(rdd, keyConverterClass, valueConverterClass))
+ val mergedConf = getMergedConf(confAsMap, pyRDD.context.hadoopConfiguration)
+ val converted = convertRDD(rdd, keyConverterClass, valueConverterClass,
+ new JavaToWritableConverter)
+ val fc = Utils.classForName(outputFormatClass).asInstanceOf[Class[F]]
+ converted.saveAsNewAPIHadoopFile(path, kc, vc, fc, mergedConf)
+ }
+
+ /**
+ * Output a Python RDD of key-value pairs to any Hadoop file system, using a Hadoop conf
+ * converted from the passed-in `confAsMap`. The conf should set relevant output params (
+ * e.g., output path, output format, etc), in the same way as it would be configured for
+ * a Hadoop MapReduce job. Both old and new Hadoop OutputFormat APIs are supported
+ * (mapred vs. mapreduce). Keys/values are converted for output using either user specified
+ * converters or, by default, [[org.apache.spark.api.python.JavaToWritableConverter]].
+ */
+ def saveAsHadoopDataset[K, V](
+ pyRDD: JavaRDD[Array[Byte]],
+ batchSerialized: Boolean,
+ confAsMap: java.util.HashMap[String, String],
+ keyConverterClass: String,
+ valueConverterClass: String,
+ useNewAPI: Boolean) = {
+ val conf = PythonHadoopUtil.mapToConf(confAsMap)
+ val converted = convertRDD(SerDeUtil.pythonToPairRDD(pyRDD, batchSerialized),
+ keyConverterClass, valueConverterClass, new JavaToWritableConverter)
+ if (useNewAPI) {
+ converted.saveAsNewAPIHadoopDataset(conf)
+ } else {
+ converted.saveAsHadoopDataset(new JobConf(conf))
}
}
}
private
class BytesToString extends org.apache.spark.api.java.function.Function[Array[Byte], String] {
- override def call(arr: Array[Byte]) : String = new String(arr, PythonRDD.UTF8)
+ override def call(arr: Array[Byte]) : String = new String(arr, UTF_8)
}
/**
@@ -588,19 +775,30 @@ private class PythonAccumulatorParam(@transient serverHost: String, serverPort:
val bufferSize = SparkEnv.get.conf.getInt("spark.buffer.size", 65536)
+ /**
+ * We try to reuse a single Socket to transfer accumulator updates, as they are all added
+ * by the DAGScheduler's single-threaded actor anyway.
+ */
+ @transient var socket: Socket = _
+
+ def openSocket(): Socket = synchronized {
+ if (socket == null || socket.isClosed) {
+ socket = new Socket(serverHost, serverPort)
+ }
+ socket
+ }
+
override def zero(value: JList[Array[Byte]]): JList[Array[Byte]] = new JArrayList
override def addInPlace(val1: JList[Array[Byte]], val2: JList[Array[Byte]])
- : JList[Array[Byte]] = {
+ : JList[Array[Byte]] = synchronized {
if (serverHost == null) {
// This happens on the worker node, where we just want to remember all the updates
val1.addAll(val2)
val1
} else {
// This happens on the master, where we pass the updates to Python through a socket
- val socket = new Socket(serverHost, serverPort)
- // SPARK-2282: Immediately reuse closed sockets because we create one per task.
- socket.setReuseAddress(true)
+ val socket = openSocket()
val in = socket.getInputStream
val out = new DataOutputStream(new BufferedOutputStream(socket.getOutputStream, bufferSize))
out.writeInt(val2.size)
@@ -614,7 +812,6 @@ private class PythonAccumulatorParam(@transient serverHost: String, serverPort:
if (byteRead == -1) {
throw new SparkException("EOF reached before Python server acknowledged")
}
- socket.close()
null
}
}
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala b/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala
index 6d3e257c4d5df..be5ebfa9219d3 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala
@@ -29,7 +29,7 @@ private[spark] object PythonUtils {
val pythonPath = new ArrayBuffer[String]
for (sparkHome <- sys.env.get("SPARK_HOME")) {
pythonPath += Seq(sparkHome, "python").mkString(File.separator)
- pythonPath += Seq(sparkHome, "python", "lib", "py4j-0.8.1-src.zip").mkString(File.separator)
+ pythonPath += Seq(sparkHome, "python", "lib", "py4j-0.8.2.1-src.zip").mkString(File.separator)
}
pythonPath ++= SparkContext.jarOfObject(this)
pythonPath.mkString(File.pathSeparator)
@@ -40,28 +40,3 @@ private[spark] object PythonUtils {
paths.filter(_ != "").mkString(File.pathSeparator)
}
}
-
-
-/**
- * A utility class to redirect the child process's stdout or stderr.
- */
-private[spark] class RedirectThread(
- in: InputStream,
- out: OutputStream,
- name: String)
- extends Thread(name) {
-
- setDaemon(true)
- override def run() {
- scala.util.control.Exception.ignoring(classOf[IOException]) {
- // FIXME: We copy the stream on the level of bytes to avoid encoding problems.
- val buf = new Array[Byte](1024)
- var len = in.read(buf)
- while (len != -1) {
- out.write(buf, 0, len)
- out.flush()
- len = in.read(buf)
- }
- }
- }
-}
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala
index 759cbe2c46c52..e314408c067e9 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala
@@ -17,13 +17,14 @@
package org.apache.spark.api.python
-import java.io.{DataInputStream, InputStream, OutputStreamWriter}
+import java.io.{DataOutputStream, DataInputStream, InputStream, OutputStreamWriter}
import java.net.{InetAddress, ServerSocket, Socket, SocketException}
+import scala.collection.mutable
import scala.collection.JavaConversions._
import org.apache.spark._
-import org.apache.spark.util.Utils
+import org.apache.spark.util.{RedirectThread, Utils}
private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String, String])
extends Logging {
@@ -39,6 +40,12 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String
var daemon: Process = null
val daemonHost = InetAddress.getByAddress(Array(127, 0, 0, 1))
var daemonPort: Int = 0
+ val daemonWorkers = new mutable.WeakHashMap[Socket, Int]()
+ val idleWorkers = new mutable.Queue[Socket]()
+ var lastActivity = 0L
+ new MonitorThread().start()
+
+ var simpleWorkers = new mutable.WeakHashMap[Socket, Process]()
val pythonPath = PythonUtils.mergePythonPaths(
PythonUtils.sparkPythonPath,
@@ -47,6 +54,11 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String
def create(): Socket = {
if (useDaemon) {
+ synchronized {
+ if (idleWorkers.size > 0) {
+ return idleWorkers.dequeue()
+ }
+ }
createThroughDaemon()
} else {
createSimpleWorker()
@@ -58,19 +70,31 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String
* to avoid the high cost of forking from Java. This currently only works on UNIX-based systems.
*/
private def createThroughDaemon(): Socket = {
+
+ def createSocket(): Socket = {
+ val socket = new Socket(daemonHost, daemonPort)
+ val pid = new DataInputStream(socket.getInputStream).readInt()
+ if (pid < 0) {
+ throw new IllegalStateException("Python daemon failed to launch worker with code " + pid)
+ }
+ daemonWorkers.put(socket, pid)
+ socket
+ }
+
synchronized {
// Start the daemon if it hasn't been started
startDaemon()
// Attempt to connect, restart and retry once if it fails
try {
- new Socket(daemonHost, daemonPort)
+ createSocket()
} catch {
case exc: SocketException =>
- logWarning("Python daemon unexpectedly quit, attempting to restart")
+ logWarning("Failed to open socket to Python daemon:", exc)
+ logWarning("Assuming that daemon unexpectedly quit, attempting to restart")
stopDaemon()
startDaemon()
- new Socket(daemonHost, daemonPort)
+ createSocket()
}
}
}
@@ -84,10 +108,12 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String
serverSocket = new ServerSocket(0, 1, InetAddress.getByAddress(Array(127, 0, 0, 1)))
// Create and start the worker
- val pb = new ProcessBuilder(Seq(pythonExec, "-u", "-m", "pyspark.worker"))
+ val pb = new ProcessBuilder(Seq(pythonExec, "-m", "pyspark.worker"))
val workerEnv = pb.environment()
workerEnv.putAll(envVars)
workerEnv.put("PYTHONPATH", pythonPath)
+ // This is equivalent to setting the -u flag; we use it because ipython doesn't support -u:
+ workerEnv.put("PYTHONUNBUFFERED", "YES")
val worker = pb.start()
// Redirect worker stdout and stderr
@@ -101,7 +127,9 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String
// Wait for it to connect to our socket
serverSocket.setSoTimeout(10000)
try {
- return serverSocket.accept()
+ val socket = serverSocket.accept()
+ simpleWorkers.put(socket, worker)
+ return socket
} catch {
case e: Exception =>
throw new SparkException("Python worker did not connect back in time", e)
@@ -123,10 +151,12 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String
try {
// Create and start the daemon
- val pb = new ProcessBuilder(Seq(pythonExec, "-u", "-m", "pyspark.daemon"))
+ val pb = new ProcessBuilder(Seq(pythonExec, "-m", "pyspark.daemon"))
val workerEnv = pb.environment()
workerEnv.putAll(envVars)
workerEnv.put("PYTHONPATH", pythonPath)
+ // This is equivalent to setting the -u flag; we use it because ipython doesn't support -u:
+ workerEnv.put("PYTHONUNBUFFERED", "YES")
daemon = pb.start()
val in = new DataInputStream(daemon.getInputStream)
@@ -181,23 +211,99 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String
}
}
+ /**
+ * Monitor the idle workers and kill them once there has been no activity for longer than IDLE_WORKER_TIMEOUT_MS.
+ */
+ private class MonitorThread extends Thread(s"Idle Worker Monitor for $pythonExec") {
+
+ setDaemon(true)
+
+ override def run() {
+ while (true) {
+ synchronized {
+ if (lastActivity + IDLE_WORKER_TIMEOUT_MS < System.currentTimeMillis()) {
+ cleanupIdleWorkers()
+ lastActivity = System.currentTimeMillis()
+ }
+ }
+ Thread.sleep(10000)
+ }
+ }
+ }
+
+ private def cleanupIdleWorkers() {
+ while (idleWorkers.length > 0) {
+ val worker = idleWorkers.dequeue()
+ try {
+ // the worker will exit after closing the socket
+ worker.close()
+ } catch {
+ case e: Exception =>
+ logWarning("Failed to close worker socket", e)
+ }
+ }
+ }
+
private def stopDaemon() {
synchronized {
- // Request shutdown of existing daemon by sending SIGTERM
- if (daemon != null) {
- daemon.destroy()
+ if (useDaemon) {
+ cleanupIdleWorkers()
+
+ // Request shutdown of existing daemon by sending SIGTERM
+ if (daemon != null) {
+ daemon.destroy()
+ }
+
+ daemon = null
+ daemonPort = 0
+ } else {
+ // Use foreach rather than the lazy mapValues so every worker process is actually destroyed
+ simpleWorkers.values.foreach(_.destroy())
}
-
- daemon = null
- daemonPort = 0
}
}
def stop() {
stopDaemon()
}
+
+ def stopWorker(worker: Socket) {
+ synchronized {
+ if (useDaemon) {
+ if (daemon != null) {
+ daemonWorkers.get(worker).foreach { pid =>
+ // tell daemon to kill worker by pid
+ val output = new DataOutputStream(daemon.getOutputStream)
+ output.writeInt(pid)
+ output.flush()
+ daemon.getOutputStream.flush()
+ }
+ }
+ } else {
+ simpleWorkers.get(worker).foreach(_.destroy())
+ }
+ }
+ worker.close()
+ }
+
+ def releaseWorker(worker: Socket) {
+ if (useDaemon) {
+ synchronized {
+ lastActivity = System.currentTimeMillis()
+ idleWorkers.enqueue(worker)
+ }
+ } else {
+ // Clean up the worker socket. This will also cause the Python worker to exit.
+ try {
+ worker.close()
+ } catch {
+ case e: Exception =>
+ logWarning("Failed to close worker socket", e)
+ }
+ }
+ }
}
private object PythonWorkerFactory {
val PROCESS_WAIT_TIMEOUT_MS = 10000
+ val IDLE_WORKER_TIMEOUT_MS = 60000 // kill idle workers after 1 minute
}
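
For context, a minimal sketch of the worker-reuse cycle these changes introduce, assuming a private[spark] caller (such as PythonRDD) and a `python` executable on the PATH; everything outside the factory methods above is illustrative:

{{{
import java.net.Socket

// create() hands back an idle worker if one is queued; releaseWorker() returns it to
// the idle queue, where MonitorThread reaps it after IDLE_WORKER_TIMEOUT_MS.
val factory = new PythonWorkerFactory("python", Map("PYTHONUNBUFFERED" -> "YES"))
val worker: Socket = factory.create()
try {
  // ... exchange serialized task data with the Python worker over `worker` ...
} finally {
  factory.releaseWorker(worker)   // reused by the next create(), or killed after one idle minute
}
}}}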
diff --git a/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala b/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala
index 9a012e7254901..a4153aaa926f8 100644
--- a/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala
@@ -17,16 +17,149 @@
package org.apache.spark.api.python
-import scala.util.Try
-import org.apache.spark.rdd.RDD
-import org.apache.spark.Logging
-import scala.util.Success
+import java.nio.ByteOrder
+import java.util.{ArrayList => JArrayList}
+
+import org.apache.spark.api.java.JavaRDD
+
+import scala.collection.JavaConversions._
+import scala.collection.JavaConverters._
+import scala.collection.mutable
import scala.util.Failure
-import net.razorvine.pickle.Pickler
+import scala.util.Try
+
+import net.razorvine.pickle.{Unpickler, Pickler}
+import org.apache.spark.{Logging, SparkException}
+import org.apache.spark.rdd.RDD
/** Utilities for serialization / deserialization between Python and Java, using Pickle. */
-private[python] object SerDeUtil extends Logging {
+private[spark] object SerDeUtil extends Logging {
+ // Unpickle array.array generated by Python 2.6
+ class ArrayConstructor extends net.razorvine.pickle.objects.ArrayConstructor {
+ // /* Description of types */
+ // static struct arraydescr descriptors[] = {
+ // {'c', sizeof(char), c_getitem, c_setitem},
+ // {'b', sizeof(char), b_getitem, b_setitem},
+ // {'B', sizeof(char), BB_getitem, BB_setitem},
+ // #ifdef Py_USING_UNICODE
+ // {'u', sizeof(Py_UNICODE), u_getitem, u_setitem},
+ // #endif
+ // {'h', sizeof(short), h_getitem, h_setitem},
+ // {'H', sizeof(short), HH_getitem, HH_setitem},
+ // {'i', sizeof(int), i_getitem, i_setitem},
+ // {'I', sizeof(int), II_getitem, II_setitem},
+ // {'l', sizeof(long), l_getitem, l_setitem},
+ // {'L', sizeof(long), LL_getitem, LL_setitem},
+ // {'f', sizeof(float), f_getitem, f_setitem},
+ // {'d', sizeof(double), d_getitem, d_setitem},
+ // {'\0', 0, 0, 0} /* Sentinel */
+ // };
+ // TODO: support Py_UNICODE with 2 bytes
+ // FIXME: unpickling an array of floats is wrong in Pyrolite, so we reverse the
+ // machine code for float/double here to work around it.
+ // We should remove this once Pyrolite fixes the issue.
+ val machineCodes: Map[Char, Int] = if (ByteOrder.nativeOrder().equals(ByteOrder.BIG_ENDIAN)) {
+ Map('c' -> 1, 'B' -> 0, 'b' -> 1, 'H' -> 3, 'h' -> 5, 'I' -> 7, 'i' -> 9,
+ 'L' -> 11, 'l' -> 13, 'f' -> 14, 'd' -> 16, 'u' -> 21
+ )
+ } else {
+ Map('c' -> 1, 'B' -> 0, 'b' -> 1, 'H' -> 2, 'h' -> 4, 'I' -> 6, 'i' -> 8,
+ 'L' -> 10, 'l' -> 12, 'f' -> 15, 'd' -> 17, 'u' -> 20
+ )
+ }
+ override def construct(args: Array[Object]): Object = {
+ if (args.length == 1) {
+ construct(args ++ Array(""))
+ } else if (args.length == 2 && args(1).isInstanceOf[String]) {
+ val typecode = args(0).asInstanceOf[String].charAt(0)
+ val data: Array[Byte] = args(1).asInstanceOf[String].getBytes("ISO-8859-1")
+ construct(typecode, machineCodes(typecode), data)
+ } else {
+ super.construct(args)
+ }
+ }
+ }
+
+ private var initialized = false
+ // This should be called before trying to unpickle array.array from Python.
+ // In cluster mode, this should be called inside the closure so that it also runs on executors.
+ def initialize() = {
+ synchronized {
+ if (!initialized) {
+ Unpickler.registerConstructor("array", "array", new ArrayConstructor())
+ initialized = true
+ }
+ }
+ }
+ initialize()
+
+
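
A hedged sketch of why initialize() belongs inside closures in cluster mode (the Unpickler registration is per-JVM, so it must also run on executors); `pickledRdd` is an assumed RDD of pickled records:

{{{
import net.razorvine.pickle.Unpickler
import org.apache.spark.rdd.RDD

def unpickleAll(pickledRdd: RDD[Array[Byte]]): RDD[Any] =
  pickledRdd.mapPartitions { iter =>
    SerDeUtil.initialize()           // registers ArrayConstructor in each executor JVM
    val unpickle = new Unpickler
    iter.map(bytes => unpickle.loads(bytes))
  }
}}}

This mirrors what pythonToJava below does inside mapPartitions.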
+ /**
+ * Convert an RDD of Java objects to an RDD of arrays (no recursive conversions).
+ * It is only used by pyspark.sql.
+ */
+ def toJavaArray(jrdd: JavaRDD[Any]): JavaRDD[Array[_]] = {
+ jrdd.rdd.map {
+ case objs: JArrayList[_] =>
+ objs.toArray
+ case obj if obj.getClass.isArray =>
+ obj.asInstanceOf[Array[_]].toArray
+ }.toJavaRDD()
+ }
+
+ /**
+ * Choose the batch size based on the size of the pickled objects (kept roughly between 1 MB and 10 MB).
+ */
+ private[spark] class AutoBatchedPickler(iter: Iterator[Any]) extends Iterator[Array[Byte]] {
+ private val pickle = new Pickler()
+ private var batch = 1
+ private val buffer = new mutable.ArrayBuffer[Any]
+
+ override def hasNext: Boolean = iter.hasNext
+
+ override def next(): Array[Byte] = {
+ while (iter.hasNext && buffer.length < batch) {
+ buffer += iter.next()
+ }
+ val bytes = pickle.dumps(buffer.toArray)
+ val size = bytes.length
+ // keep the pickled batch size roughly between 1 MB and 10 MB
+ if (size < 1024 * 1024) {
+ batch *= 2
+ } else if (size > 1024 * 1024 * 10 && batch > 1) {
+ batch /= 2
+ }
+ buffer.clear()
+ bytes
+ }
+ }
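
The batch size adapts geometrically: it doubles while a pickled batch stays under 1 MB and halves when one exceeds 10 MB. A hedged usage sketch (the record contents are illustrative; the class is private[spark]):

{{{
// Pickle a partition's worth of records into size-bounded chunks of bytes.
val records: Iterator[Any] = Iterator.tabulate(100000)(i => "record-" + i)
val pickledChunks: Iterator[Array[Byte]] = new SerDeUtil.AutoBatchedPickler(records)
pickledChunks.foreach(chunk => println(s"chunk of ${chunk.length} bytes"))
}}}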
+
+ /**
+ * Convert an RDD of Java objects to an RDD of serialized Python objects, which is usable by
+ * PySpark.
+ */
+ private[spark] def javaToPython(jRDD: JavaRDD[_]): JavaRDD[Array[Byte]] = {
+ jRDD.rdd.mapPartitions { iter => new AutoBatchedPickler(iter) }
+ }
+
+ /**
+ * Convert an RDD of serialized Python objects to an RDD of Java objects that is usable by PySpark.
+ */
+ def pythonToJava(pyRDD: JavaRDD[Array[Byte]], batched: Boolean): JavaRDD[Any] = {
+ pyRDD.rdd.mapPartitions { iter =>
+ initialize()
+ val unpickle = new Unpickler
+ iter.flatMap { row =>
+ val obj = unpickle.loads(row)
+ if (batched) {
+ obj.asInstanceOf[JArrayList[_]].asScala
+ } else {
+ Seq(obj)
+ }
+ }
+ }.toJavaRDD()
+ }
private def checkPickle(t: (Any, Any)): (Boolean, Boolean) = {
val pickle = new Pickler
@@ -65,23 +198,43 @@ private[python] object SerDeUtil extends Logging {
* by PySpark. By default, if serialization fails, toString is called and the string
* representation is serialized
*/
- def rddToPython(rdd: RDD[(Any, Any)]): RDD[Array[Byte]] = {
+ def pairRDDToPython(rdd: RDD[(Any, Any)], batchSize: Int): RDD[Array[Byte]] = {
val (keyFailed, valueFailed) = checkPickle(rdd.first())
+
rdd.mapPartitions { iter =>
- val pickle = new Pickler
- iter.map { case (k, v) =>
- if (keyFailed && valueFailed) {
- pickle.dumps(Array(k.toString, v.toString))
- } else if (keyFailed) {
- pickle.dumps(Array(k.toString, v))
- } else if (!keyFailed && valueFailed) {
- pickle.dumps(Array(k, v.toString))
- } else {
- pickle.dumps(Array(k, v))
- }
+ val cleaned = iter.map { case (k, v) =>
+ val key = if (keyFailed) k.toString else k
+ val value = if (valueFailed) v.toString else v
+ Array[Any](key, value)
+ }
+ if (batchSize == 0) {
+ new AutoBatchedPickler(cleaned)
+ } else {
+ val pickle = new Pickler
+ cleaned.grouped(batchSize).map(batched => pickle.dumps(seqAsJavaList(batched)))
}
}
}
-}
+ /**
+ * Convert an RDD of serialized Python tuples (K, V) to RDD[(K, V)].
+ */
+ def pythonToPairRDD[K, V](pyRDD: RDD[Array[Byte]], batched: Boolean): RDD[(K, V)] = {
+ def isPair(obj: Any): Boolean = {
+ Option(obj.getClass.getComponentType).exists(!_.isPrimitive) &&
+ obj.asInstanceOf[Array[_]].length == 2
+ }
+ val rdd = pythonToJava(pyRDD, batched).rdd
+ rdd.first match {
+ case obj if isPair(obj) =>
+ // we only accept (K, V)
+ case other => throw new SparkException(
+ s"RDD element of type ${other.getClass.getName} cannot be used")
+ }
+ rdd.map { obj =>
+ val arr = obj.asInstanceOf[Array[_]]
+ (arr.head.asInstanceOf[K], arr.last.asInstanceOf[V])
+ }
+ }
+}
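
A hedged round-trip sketch of the two pair helpers above, assuming a local SparkContext and a fixed batch size (batchSize = 0 would route through AutoBatchedPickler instead):

{{{
import org.apache.spark.{SparkConf, SparkContext}

val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("serde-roundtrip"))
val pairs = sc.parallelize(Seq[(Any, Any)]((1, "a"), (2, "b"), (3, "c")))

val pickled = SerDeUtil.pairRDDToPython(pairs, batchSize = 10)                  // RDD[Array[Byte]]
val restored = SerDeUtil.pythonToPairRDD[Int, String](pickled, batched = true)  // RDD[(Int, String)]
restored.collect()   // expected: Array((1,a), (2,b), (3,c))

sc.stop()
}}}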
diff --git a/core/src/main/scala/org/apache/spark/api/python/WriteInputFormatTestDataGenerator.scala b/core/src/main/scala/org/apache/spark/api/python/WriteInputFormatTestDataGenerator.scala
index f0e3fb9aff5a0..c0cbd28a845be 100644
--- a/core/src/main/scala/org/apache/spark/api/python/WriteInputFormatTestDataGenerator.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/WriteInputFormatTestDataGenerator.scala
@@ -17,15 +17,17 @@
package org.apache.spark.api.python
-import org.apache.spark.SparkContext
-import org.apache.hadoop.io._
-import scala.Array
import java.io.{DataOutput, DataInput}
+
+import com.google.common.base.Charsets.UTF_8
+
+import org.apache.hadoop.io._
import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat
import org.apache.spark.api.java.JavaSparkContext
+import org.apache.spark.{SparkContext, SparkException}
/**
- * A class to test MsgPack serialization on the Scala side, that will be deserialized
+ * A class to test Pyrolite serialization on the Scala side; it will be deserialized
* in Python
* @param str
* @param int
@@ -54,7 +56,13 @@ case class TestWritable(var str: String, var int: Int, var double: Double) exten
}
}
-class TestConverter extends Converter[Any, Any] {
+private[python] class TestInputKeyConverter extends Converter[Any, Any] {
+ override def convert(obj: Any) = {
+ obj.asInstanceOf[IntWritable].get().toChar
+ }
+}
+
+private[python] class TestInputValueConverter extends Converter[Any, Any] {
import collection.JavaConversions._
override def convert(obj: Any) = {
val m = obj.asInstanceOf[MapWritable]
@@ -62,6 +70,38 @@ class TestConverter extends Converter[Any, Any] {
}
}
+private[python] class TestOutputKeyConverter extends Converter[Any, Any] {
+ override def convert(obj: Any) = {
+ new Text(obj.asInstanceOf[Int].toString)
+ }
+}
+
+private[python] class TestOutputValueConverter extends Converter[Any, Any] {
+ import collection.JavaConversions._
+ override def convert(obj: Any) = {
+ new DoubleWritable(obj.asInstanceOf[java.util.Map[Double, _]].keySet().head)
+ }
+}
+
+private[python] class DoubleArrayWritable extends ArrayWritable(classOf[DoubleWritable])
+
+private[python] class DoubleArrayToWritableConverter extends Converter[Any, Writable] {
+ override def convert(obj: Any) = obj match {
+ case arr if arr.getClass.isArray && arr.getClass.getComponentType == classOf[Double] =>
+ val daw = new DoubleArrayWritable
+ daw.set(arr.asInstanceOf[Array[Double]].map(new DoubleWritable(_)))
+ daw
+ case other => throw new SparkException(s"Data of type $other is not supported")
+ }
+}
+
+private[python] class WritableToDoubleArrayConverter extends Converter[Any, Array[Double]] {
+ override def convert(obj: Any): Array[Double] = obj match {
+ case daw : DoubleArrayWritable => daw.get().map(_.asInstanceOf[DoubleWritable].get())
+ case other => throw new SparkException(s"Data of type $other is not supported")
+ }
+}
+
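
A hedged sketch of how the two array converters above round-trip an Array[Double] through the Writable layer (they are private[python], so this assumes code in the same package); this is the conversion the SequenceFile test data below exercises:

{{{
val toWritable = new DoubleArrayToWritableConverter
val fromWritable = new WritableToDoubleArrayConverter

val writable = toWritable.convert(Array(1.0, 2.0, 3.0))   // a DoubleArrayWritable
val doubles = fromWritable.convert(writable)              // back to Array(1.0, 2.0, 3.0)
}}}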
/**
* This object contains method to generate SequenceFile test data and write it to a
* given directory (probably a temp directory)
@@ -97,7 +137,8 @@ object WriteInputFormatTestDataGenerator {
sc.parallelize(intKeys).saveAsSequenceFile(intPath)
sc.parallelize(intKeys.map{ case (k, v) => (k.toDouble, v) }).saveAsSequenceFile(doublePath)
sc.parallelize(intKeys.map{ case (k, v) => (k.toString, v) }).saveAsSequenceFile(textPath)
- sc.parallelize(intKeys.map{ case (k, v) => (k, v.getBytes) }).saveAsSequenceFile(bytesPath)
+ sc.parallelize(intKeys.map{ case (k, v) => (k, v.getBytes(UTF_8)) }
+ ).saveAsSequenceFile(bytesPath)
val bools = Seq((1, true), (2, true), (2, false), (3, true), (2, false), (1, false))
sc.parallelize(bools).saveAsSequenceFile(boolPath)
sc.parallelize(intKeys).map{ case (k, v) =>
@@ -106,19 +147,20 @@ object WriteInputFormatTestDataGenerator {
// Create test data for ArrayWritable
val data = Seq(
- (1, Array(1.0, 2.0, 3.0)),
+ (1, Array()),
(2, Array(3.0, 4.0, 5.0)),
(3, Array(4.0, 5.0, 6.0))
)
sc.parallelize(data, numSlices = 2)
.map{ case (k, v) =>
- (new IntWritable(k), new ArrayWritable(classOf[DoubleWritable], v.map(new DoubleWritable(_))))
- }.saveAsNewAPIHadoopFile[SequenceFileOutputFormat[IntWritable, ArrayWritable]](arrPath)
+ val va = new DoubleArrayWritable
+ va.set(v.map(new DoubleWritable(_)))
+ (new IntWritable(k), va)
+ }.saveAsNewAPIHadoopFile[SequenceFileOutputFormat[IntWritable, DoubleArrayWritable]](arrPath)
// Create test data for MapWritable, with keys DoubleWritable and values Text
val mapData = Seq(
- (1, Map(2.0 -> "aa")),
- (2, Map(3.0 -> "bb")),
+ (1, Map()),
(2, Map(1.0 -> "cc")),
(3, Map(2.0 -> "dd")),
(2, Map(1.0 -> "aa")),
@@ -126,19 +168,19 @@ object WriteInputFormatTestDataGenerator {
)
sc.parallelize(mapData, numSlices = 2).map{ case (i, m) =>
val mw = new MapWritable()
- val k = m.keys.head
- val v = m.values.head
- mw.put(new DoubleWritable(k), new Text(v))
+ m.foreach { case (k, v) =>
+ mw.put(new DoubleWritable(k), new Text(v))
+ }
(new IntWritable(i), mw)
}.saveAsSequenceFile(mapPath)
// Create test data for arbitrary custom writable TestWritable
val testClass = Seq(
- ("1", TestWritable("test1", 123, 54.0)),
- ("2", TestWritable("test2", 456, 8762.3)),
- ("1", TestWritable("test3", 123, 423.1)),
- ("3", TestWritable("test56", 456, 423.5)),
- ("2", TestWritable("test2", 123, 5435.2))
+ ("1", TestWritable("test1", 1, 1.0)),
+ ("2", TestWritable("test2", 2, 2.3)),
+ ("3", TestWritable("test3", 3, 3.1)),
+ ("5", TestWritable("test56", 5, 5.5)),
+ ("4", TestWritable("test4", 4, 4.2))
)
val rdd = sc.parallelize(testClass, numSlices = 2).map{ case (k, v) => (new Text(k), v) }
rdd.saveAsNewAPIHadoopFile(classPath,
diff --git a/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala
index 76956f6a345d1..a5ea478f231d7 100644
--- a/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala
+++ b/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala
@@ -20,6 +20,8 @@ package org.apache.spark.broadcast
import java.io.Serializable
import org.apache.spark.SparkException
+import org.apache.spark.Logging
+import org.apache.spark.util.Utils
import scala.reflect.ClassTag
@@ -37,7 +39,7 @@ import scala.reflect.ClassTag
*
* {{{
* scala> val broadcastVar = sc.broadcast(Array(1, 2, 3))
- * broadcastVar: spark.Broadcast[Array[Int]] = spark.Broadcast(b5c40191-a864-4c7d-b9bf-d87e1a4e787c)
+ * broadcastVar: org.apache.spark.broadcast.Broadcast[Array[Int]] = Broadcast(0)
*
* scala> broadcastVar.value
* res0: Array[Int] = Array(1, 2, 3)
@@ -52,7 +54,7 @@ import scala.reflect.ClassTag
* @param id A unique identifier for the broadcast variable.
* @tparam T Type of the data contained in the broadcast variable.
*/
-abstract class Broadcast[T: ClassTag](val id: Long) extends Serializable {
+abstract class Broadcast[T: ClassTag](val id: Long) extends Serializable with Logging {
/**
* Flag signifying whether the broadcast variable is valid
@@ -60,6 +62,8 @@ abstract class Broadcast[T: ClassTag](val id: Long) extends Serializable {
*/
@volatile private var _isValid = true
+ private var _destroySite = ""
+
/** Get the broadcasted value. */
def value: T = {
assertValid()
@@ -84,13 +88,26 @@ abstract class Broadcast[T: ClassTag](val id: Long) extends Serializable {
doUnpersist(blocking)
}
+
+ /**
+ * Destroy all data and metadata related to this broadcast variable. Use this with caution;
+ * once a broadcast variable has been destroyed, it cannot be used again.
+ * This method blocks until destroy has completed.
+ */
+ def destroy() {
+ destroy(blocking = true)
+ }
+
/**
* Destroy all data and metadata related to this broadcast variable. Use this with caution;
* once a broadcast variable has been destroyed, it cannot be used again.
+ * @param blocking Whether to block until destroy has completed
*/
private[spark] def destroy(blocking: Boolean) {
assertValid()
_isValid = false
+ _destroySite = Utils.getCallSite().shortForm
+ logInfo("Destroying %s (from %s)".format(toString, _destroySite))
doDestroy(blocking)
}
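
A hedged usage sketch of the new public destroy(), assuming an active SparkContext named `sc`:

{{{
val words = sc.broadcast(Array("a", "b", "c"))
println(words.value.mkString(","))   // a,b,c
words.destroy()                      // blocks until all data and metadata are removed
// A later words.value would throw a SparkException whose message names the destroy call site.
}}}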
@@ -106,25 +123,26 @@ abstract class Broadcast[T: ClassTag](val id: Long) extends Serializable {
* Actually get the broadcasted value. Concrete implementations of Broadcast class must
* define their own way to get the value.
*/
- private[spark] def getValue(): T
+ protected def getValue(): T
/**
* Actually unpersist the broadcasted value on the executors. Concrete implementations of
* Broadcast class must define their own logic to unpersist their own data.
*/
- private[spark] def doUnpersist(blocking: Boolean)
+ protected def doUnpersist(blocking: Boolean)
/**
* Actually destroy all data and metadata related to this broadcast variable.
* Implementation of Broadcast class must define their own logic to destroy their own
* state.
*/
- private[spark] def doDestroy(blocking: Boolean)
+ protected def doDestroy(blocking: Boolean)
/** Check if this broadcast is valid. If not valid, exception is thrown. */
- private[spark] def assertValid() {
+ protected def assertValid() {
if (!_isValid) {
- throw new SparkException("Attempted to use %s after it has been destroyed!".format(toString))
+ throw new SparkException(
+ "Attempted to use %s after it was destroyed (%s) ".format(toString, _destroySite))
}
}
diff --git a/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala b/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala
index a8c827030a1ef..6a187b40628a2 100644
--- a/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala
+++ b/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala
@@ -32,8 +32,19 @@ import org.apache.spark.annotation.DeveloperApi
*/
@DeveloperApi
trait BroadcastFactory {
+
def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager): Unit
+
+ /**
+ * Creates a new broadcast variable.
+ *
+ * @param value value to broadcast
+ * @param isLocal whether we are in local mode (single JVM process)
+ * @param id unique id representing this broadcast variable
+ */
def newBroadcast[T: ClassTag](value: T, isLocal: Boolean, id: Long): Broadcast[T]
+
def unbroadcast(id: Long, removeFromDriver: Boolean, blocking: Boolean): Unit
+
def stop(): Unit
}
diff --git a/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala b/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala
index c88be6aba6901..8f8a0b11f9f2e 100644
--- a/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala
+++ b/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala
@@ -39,7 +39,7 @@ private[spark] class BroadcastManager(
synchronized {
if (!initialized) {
val broadcastFactoryClass =
- conf.get("spark.broadcast.factory", "org.apache.spark.broadcast.HttpBroadcastFactory")
+ conf.get("spark.broadcast.factory", "org.apache.spark.broadcast.TorrentBroadcastFactory")
broadcastFactory =
Class.forName(broadcastFactoryClass).newInstance.asInstanceOf[BroadcastFactory]
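
Since TorrentBroadcastFactory is now the default, the previous HTTP-based behavior must be requested explicitly. A hedged configuration sketch:

{{{
import org.apache.spark.SparkConf

val conf = new SparkConf()
  .set("spark.broadcast.factory", "org.apache.spark.broadcast.HttpBroadcastFactory")
}}}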
diff --git a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala
index 4f6cabaff2b99..31f0a462f84d8 100644
--- a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala
+++ b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala
@@ -40,9 +40,9 @@ private[spark] class HttpBroadcast[T: ClassTag](
@transient var value_ : T, isLocal: Boolean, id: Long)
extends Broadcast[T](id) with Logging with Serializable {
- def getValue = value_
+ override protected def getValue() = value_
- val blockId = BroadcastBlockId(id)
+ private val blockId = BroadcastBlockId(id)
/*
* Broadcasted data is also stored in the BlockManager of the driver. The BlockManagerMaster
@@ -60,25 +60,25 @@ private[spark] class HttpBroadcast[T: ClassTag](
/**
* Remove all persisted state associated with this HTTP broadcast on the executors.
*/
- def doUnpersist(blocking: Boolean) {
+ override protected def doUnpersist(blocking: Boolean) {
HttpBroadcast.unpersist(id, removeFromDriver = false, blocking)
}
/**
* Remove all persisted state associated with this HTTP broadcast on the executors and driver.
*/
- def doDestroy(blocking: Boolean) {
+ override protected def doDestroy(blocking: Boolean) {
HttpBroadcast.unpersist(id, removeFromDriver = true, blocking)
}
/** Used by the JVM when serializing this object. */
- private def writeObject(out: ObjectOutputStream) {
+ private def writeObject(out: ObjectOutputStream): Unit = Utils.tryOrIOException {
assertValid()
out.defaultWriteObject()
}
/** Used by the JVM when deserializing this object. */
- private def readObject(in: ObjectInputStream) {
+ private def readObject(in: ObjectInputStream): Unit = Utils.tryOrIOException {
in.defaultReadObject()
HttpBroadcast.synchronized {
SparkEnv.get.blockManager.getSingle(blockId) match {
@@ -102,7 +102,7 @@ private[spark] class HttpBroadcast[T: ClassTag](
}
}
-private[spark] object HttpBroadcast extends Logging {
+private[broadcast] object HttpBroadcast extends Logging {
private var initialized = false
private var broadcastDir: File = null
private var compress: Boolean = false
@@ -152,7 +152,8 @@ private[spark] object HttpBroadcast extends Logging {
private def createServer(conf: SparkConf) {
broadcastDir = Utils.createTempDir(Utils.getLocalDir(conf))
- server = new HttpServer(broadcastDir, securityManager)
+ val broadcastPort = conf.getInt("spark.broadcast.port", 0)
+ server = new HttpServer(broadcastDir, securityManager, broadcastPort, "HTTP broadcast server")
server.start()
serverUri = server.uri
logInfo("Broadcast server started at " + serverUri)
@@ -160,23 +161,28 @@ private[spark] object HttpBroadcast extends Logging {
def getFile(id: Long) = new File(broadcastDir, BroadcastBlockId(id).name)
- def write(id: Long, value: Any) {
+ private def write(id: Long, value: Any) {
val file = getFile(id)
- val out: OutputStream = {
- if (compress) {
- compressionCodec.compressedOutputStream(new FileOutputStream(file))
- } else {
- new BufferedOutputStream(new FileOutputStream(file), bufferSize)
+ val fileOutputStream = new FileOutputStream(file)
+ try {
+ val out: OutputStream = {
+ if (compress) {
+ compressionCodec.compressedOutputStream(fileOutputStream)
+ } else {
+ new BufferedOutputStream(fileOutputStream, bufferSize)
+ }
}
+ val ser = SparkEnv.get.serializer.newInstance()
+ val serOut = ser.serializeStream(out)
+ serOut.writeObject(value)
+ serOut.close()
+ files += file
+ } finally {
+ fileOutputStream.close()
}
- val ser = SparkEnv.get.serializer.newInstance()
- val serOut = ser.serializeStream(out)
- serOut.writeObject(value)
- serOut.close()
- files += file
}
- def read[T: ClassTag](id: Long): T = {
+ private def read[T: ClassTag](id: Long): T = {
logDebug("broadcast read server: " + serverUri + " id: broadcast-" + id)
val url = serverUri + "/" + BroadcastBlockId(id).name
@@ -185,10 +191,12 @@ private[spark] object HttpBroadcast extends Logging {
logDebug("broadcast security enabled")
val newuri = Utils.constructURIForAuthentication(new URI(url), securityManager)
uc = newuri.toURL.openConnection()
+ uc.setConnectTimeout(httpReadTimeout)
uc.setAllowUserInteraction(false)
} else {
logDebug("broadcast not using security")
uc = new URL(url).openConnection()
+ uc.setConnectTimeout(httpReadTimeout)
}
val in = {
diff --git a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcastFactory.scala b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcastFactory.scala
index d5a031e2bbb59..c7ef02d572a19 100644
--- a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcastFactory.scala
+++ b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcastFactory.scala
@@ -27,21 +27,21 @@ import org.apache.spark.{SecurityManager, SparkConf}
* [[org.apache.spark.broadcast.HttpBroadcast]] for more details about this mechanism.
*/
class HttpBroadcastFactory extends BroadcastFactory {
- def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager) {
+ override def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager) {
HttpBroadcast.initialize(isDriver, conf, securityMgr)
}
- def newBroadcast[T: ClassTag](value_ : T, isLocal: Boolean, id: Long) =
+ override def newBroadcast[T: ClassTag](value_ : T, isLocal: Boolean, id: Long) =
new HttpBroadcast[T](value_, isLocal, id)
- def stop() { HttpBroadcast.stop() }
+ override def stop() { HttpBroadcast.stop() }
/**
* Remove all persisted state associated with the HTTP broadcast with the given ID.
* @param removeFromDriver Whether to remove state from the driver
* @param blocking Whether to block until unbroadcasted
*/
- def unbroadcast(id: Long, removeFromDriver: Boolean, blocking: Boolean) {
+ override def unbroadcast(id: Long, removeFromDriver: Boolean, blocking: Boolean) {
HttpBroadcast.unpersist(id, removeFromDriver, blocking)
}
}
diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala
index 734de37ba115d..94142d33369c7 100644
--- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala
+++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala
@@ -17,60 +17,133 @@
package org.apache.spark.broadcast
-import java.io.{ByteArrayInputStream, ObjectInputStream, ObjectOutputStream}
+import java.io._
+import java.nio.ByteBuffer
+import scala.collection.JavaConversions.asJavaEnumeration
import scala.reflect.ClassTag
-import scala.math
import scala.util.Random
import org.apache.spark.{Logging, SparkConf, SparkEnv, SparkException}
+import org.apache.spark.io.CompressionCodec
+import org.apache.spark.serializer.Serializer
import org.apache.spark.storage.{BroadcastBlockId, StorageLevel}
-import org.apache.spark.util.Utils
+import org.apache.spark.util.{ByteBufferInputStream, Utils}
+import org.apache.spark.util.io.ByteArrayChunkOutputStream
/**
- * A [[org.apache.spark.broadcast.Broadcast]] implementation that uses a BitTorrent-like
- * protocol to do a distributed transfer of the broadcasted data to the executors.
- * The mechanism is as follows. The driver divides the serializes the broadcasted data,
- * divides it into smaller chunks, and stores them in the BlockManager of the driver.
- * These chunks are reported to the BlockManagerMaster so that all the executors can
- * learn the location of those chunks. The first time the broadcast variable (sent as
- * part of task) is deserialized at a executor, all the chunks are fetched using
- * the BlockManager. When all the chunks are fetched (initially from the driver's
- * BlockManager), they are combined and deserialized to recreate the broadcasted data.
- * However, the chunks are also stored in the BlockManager and reported to the
- * BlockManagerMaster. As more executors fetch the chunks, BlockManagerMaster learns
- * multiple locations for each chunk. Hence, subsequent fetches of each chunk will be
- * made to other executors who already have those chunks, resulting in a distributed
- * fetching. This prevents the driver from being the bottleneck in sending out multiple
- * copies of the broadcast data (one per executor) as done by the
- * [[org.apache.spark.broadcast.HttpBroadcast]].
+ * A BitTorrent-like implementation of [[org.apache.spark.broadcast.Broadcast]].
+ *
+ * The mechanism is as follows:
+ *
+ * The driver divides the serialized object into small chunks and
+ * stores those chunks in the BlockManager of the driver.
+ *
+ * Each executor first attempts to fetch the object from its own BlockManager. If the object is
+ * not present, it uses remote fetches to pull the small chunks from the driver and/or
+ * other executors, if available. Once it has the chunks, it puts them in its own
+ * BlockManager, ready for other executors to fetch from.
+ *
+ * This prevents the driver from being the bottleneck in sending out multiple copies of the
+ * broadcast data (one per executor) as done by the [[org.apache.spark.broadcast.HttpBroadcast]].
+ *
+ * When initialized, TorrentBroadcast objects read SparkEnv.get.conf.
+ *
+ * @param obj object to broadcast
+ * @param id A unique identifier for the broadcast variable.
*/
-private[spark] class TorrentBroadcast[T: ClassTag](
- @transient var value_ : T, isLocal: Boolean, id: Long)
+private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long)
extends Broadcast[T](id) with Logging with Serializable {
- def getValue = value_
+ /**
+ * Value of the broadcast object on executors. This is reconstructed by [[readBroadcastBlock]],
+ * which builds this value by reading blocks from the driver and/or other executors.
+ *
+ * On the driver, if the value is required, it is read lazily from the block manager.
+ */
+ @transient private lazy val _value: T = readBroadcastBlock()
+
+ /** The compression codec to use, or None if compression is disabled */
+ @transient private var compressionCodec: Option[CompressionCodec] = _
+ /** Size of each block. Default value is 4MB. This value is only read by the broadcaster. */
+ @transient private var blockSize: Int = _
+
+ private def setConf(conf: SparkConf) {
+ compressionCodec = if (conf.getBoolean("spark.broadcast.compress", true)) {
+ Some(CompressionCodec.createCodec(conf))
+ } else {
+ None
+ }
+ blockSize = conf.getInt("spark.broadcast.blockSize", 4096) * 1024
+ }
+ setConf(SparkEnv.get.conf)
+
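
The two settings consulted by setConf above, as a hedged configuration sketch (the values shown are the defaults read by the code):

{{{
import org.apache.spark.SparkConf

val conf = new SparkConf()
  .set("spark.broadcast.compress", "true")    // wrap each block in a CompressionCodec
  .set("spark.broadcast.blockSize", "4096")   // block size in KB; 4096 KB = 4 MB
}}}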
+ private val broadcastId = BroadcastBlockId(id)
+
+ /** Total number of blocks this broadcast variable contains. */
+ private val numBlocks: Int = writeBlocks(obj)
- val broadcastId = BroadcastBlockId(id)
+ override protected def getValue() = {
+ _value
+ }
- TorrentBroadcast.synchronized {
- SparkEnv.get.blockManager.putSingle(
- broadcastId, value_, StorageLevel.MEMORY_AND_DISK, tellMaster = false)
+ /**
+ * Divide the object into multiple blocks and put those blocks in the block manager.
+ * @param value the object to divide
+ * @return number of blocks this broadcast variable is divided into
+ */
+ private def writeBlocks(value: T): Int = {
+ // Store a copy of the broadcast variable in the driver so that tasks run on the driver
+ // do not create a duplicate copy of the broadcast variable's value.
+ SparkEnv.get.blockManager.putSingle(broadcastId, value, StorageLevel.MEMORY_AND_DISK,
+ tellMaster = false)
+ val blocks =
+ TorrentBroadcast.blockifyObject(value, blockSize, SparkEnv.get.serializer, compressionCodec)
+ blocks.zipWithIndex.foreach { case (block, i) =>
+ SparkEnv.get.blockManager.putBytes(
+ BroadcastBlockId(id, "piece" + i),
+ block,
+ StorageLevel.MEMORY_AND_DISK_SER,
+ tellMaster = true)
+ }
+ blocks.length
}
- @transient var arrayOfBlocks: Array[TorrentBlock] = null
- @transient var totalBlocks = -1
- @transient var totalBytes = -1
- @transient var hasBlocks = 0
+ /** Fetch torrent blocks from the driver and/or other executors. */
+ private def readBlocks(): Array[ByteBuffer] = {
+ // Fetch chunks of data. Note that all these chunks are stored in the BlockManager and reported
+ // to the driver, so other executors can pull these chunks from this executor as well.
+ val blocks = new Array[ByteBuffer](numBlocks)
+ val bm = SparkEnv.get.blockManager
- if (!isLocal) {
- sendBroadcast()
+ for (pid <- Random.shuffle(Seq.range(0, numBlocks))) {
+ val pieceId = BroadcastBlockId(id, "piece" + pid)
+ logDebug(s"Reading piece $pieceId of $broadcastId")
+ // First try getLocalBytes because there is a chance that previous attempts to fetch the
+ // broadcast blocks have already fetched some of the blocks. In that case, some blocks
+ // would be available locally (on this executor).
+ def getLocal: Option[ByteBuffer] = bm.getLocalBytes(pieceId)
+ def getRemote: Option[ByteBuffer] = bm.getRemoteBytes(pieceId).map { block =>
+ // If we found the block from remote executors/driver's BlockManager, put the block
+ // in this executor's BlockManager.
+ SparkEnv.get.blockManager.putBytes(
+ pieceId,
+ block,
+ StorageLevel.MEMORY_AND_DISK_SER,
+ tellMaster = true)
+ block
+ }
+ val block: ByteBuffer = getLocal.orElse(getRemote).getOrElse(
+ throw new SparkException(s"Failed to get $pieceId of $broadcastId"))
+ blocks(pid) = block
+ }
+ blocks
}
/**
* Remove all persisted state associated with this Torrent broadcast on the executors.
*/
- def doUnpersist(blocking: Boolean) {
+ override protected def doUnpersist(blocking: Boolean) {
TorrentBroadcast.unpersist(id, removeFromDriver = false, blocking)
}
@@ -78,215 +151,79 @@ private[spark] class TorrentBroadcast[T: ClassTag](
* Remove all persisted state associated with this Torrent broadcast on the executors
* and driver.
*/
- def doDestroy(blocking: Boolean) {
+ override protected def doDestroy(blocking: Boolean) {
TorrentBroadcast.unpersist(id, removeFromDriver = true, blocking)
}
- def sendBroadcast() {
- val tInfo = TorrentBroadcast.blockifyObject(value_)
- totalBlocks = tInfo.totalBlocks
- totalBytes = tInfo.totalBytes
- hasBlocks = tInfo.totalBlocks
-
- // Store meta-info
- val metaId = BroadcastBlockId(id, "meta")
- val metaInfo = TorrentInfo(null, totalBlocks, totalBytes)
- TorrentBroadcast.synchronized {
- SparkEnv.get.blockManager.putSingle(
- metaId, metaInfo, StorageLevel.MEMORY_AND_DISK, tellMaster = true)
- }
-
- // Store individual pieces
- for (i <- 0 until totalBlocks) {
- val pieceId = BroadcastBlockId(id, "piece" + i)
- TorrentBroadcast.synchronized {
- SparkEnv.get.blockManager.putSingle(
- pieceId, tInfo.arrayOfBlocks(i), StorageLevel.MEMORY_AND_DISK, tellMaster = true)
- }
- }
- }
-
/** Used by the JVM when serializing this object. */
- private def writeObject(out: ObjectOutputStream) {
+ private def writeObject(out: ObjectOutputStream): Unit = Utils.tryOrIOException {
assertValid()
out.defaultWriteObject()
}
- /** Used by the JVM when deserializing this object. */
- private def readObject(in: ObjectInputStream) {
- in.defaultReadObject()
+ private def readBroadcastBlock(): T = Utils.tryOrIOException {
TorrentBroadcast.synchronized {
- SparkEnv.get.blockManager.getSingle(broadcastId) match {
+ setConf(SparkEnv.get.conf)
+ SparkEnv.get.blockManager.getLocal(broadcastId).map(_.data.next()) match {
case Some(x) =>
- value_ = x.asInstanceOf[T]
+ x.asInstanceOf[T]
case None =>
- val start = System.nanoTime
logInfo("Started reading broadcast variable " + id)
-
- // Initialize @transient variables that will receive garbage values from the master.
- resetWorkerVariables()
-
- if (receiveBroadcast()) {
- value_ = TorrentBroadcast.unBlockifyObject[T](arrayOfBlocks, totalBytes, totalBlocks)
-
- /* Store the merged copy in cache so that the next worker doesn't need to rebuild it.
- * This creates a trade-off between memory usage and latency. Storing copy doubles
- * the memory footprint; not storing doubles deserialization cost. Also,
- * this does not need to be reported to BlockManagerMaster since other executors
- * does not need to access this block (they only need to fetch the chunks,
- * which are reported).
- */
- SparkEnv.get.blockManager.putSingle(
- broadcastId, value_, StorageLevel.MEMORY_AND_DISK, tellMaster = false)
-
- // Remove arrayOfBlocks from memory once value_ is on local cache
- resetWorkerVariables()
- } else {
- logError("Reading broadcast variable " + id + " failed")
- }
-
- val time = (System.nanoTime - start) / 1e9
- logInfo("Reading broadcast variable " + id + " took " + time + " s")
+ val startTimeMs = System.currentTimeMillis()
+ val blocks = readBlocks()
+ logInfo("Reading broadcast variable " + id + " took" + Utils.getUsedTimeMs(startTimeMs))
+
+ val obj = TorrentBroadcast.unBlockifyObject[T](
+ blocks, SparkEnv.get.serializer, compressionCodec)
+ // Store the merged copy in BlockManager so other tasks on this executor don't
+ // need to re-fetch it.
+ SparkEnv.get.blockManager.putSingle(
+ broadcastId, obj, StorageLevel.MEMORY_AND_DISK, tellMaster = false)
+ obj
}
}
}
- private def resetWorkerVariables() {
- arrayOfBlocks = null
- totalBytes = -1
- totalBlocks = -1
- hasBlocks = 0
- }
-
- def receiveBroadcast(): Boolean = {
- // Receive meta-info about the size of broadcast data,
- // the number of chunks it is divided into, etc.
- val metaId = BroadcastBlockId(id, "meta")
- var attemptId = 10
- while (attemptId > 0 && totalBlocks == -1) {
- TorrentBroadcast.synchronized {
- SparkEnv.get.blockManager.getSingle(metaId) match {
- case Some(x) =>
- val tInfo = x.asInstanceOf[TorrentInfo]
- totalBlocks = tInfo.totalBlocks
- totalBytes = tInfo.totalBytes
- arrayOfBlocks = new Array[TorrentBlock](totalBlocks)
- hasBlocks = 0
-
- case None =>
- Thread.sleep(500)
- }
- }
- attemptId -= 1
- }
- if (totalBlocks == -1) {
- return false
- }
-
- /*
- * Fetch actual chunks of data. Note that all these chunks are stored in
- * the BlockManager and reported to the master, so that other executors
- * can find out and pull the chunks from this executor.
- */
- val recvOrder = new Random().shuffle(Array.iterate(0, totalBlocks)(_ + 1).toList)
- for (pid <- recvOrder) {
- val pieceId = BroadcastBlockId(id, "piece" + pid)
- TorrentBroadcast.synchronized {
- SparkEnv.get.blockManager.getSingle(pieceId) match {
- case Some(x) =>
- arrayOfBlocks(pid) = x.asInstanceOf[TorrentBlock]
- hasBlocks += 1
- SparkEnv.get.blockManager.putSingle(
- pieceId, arrayOfBlocks(pid), StorageLevel.MEMORY_AND_DISK, tellMaster = true)
-
- case None =>
- throw new SparkException("Failed to get " + pieceId + " of " + broadcastId)
- }
- }
- }
-
- hasBlocks == totalBlocks
- }
-
}
-private[spark] object TorrentBroadcast extends Logging {
- private lazy val BLOCK_SIZE = conf.getInt("spark.broadcast.blockSize", 4096) * 1024
- private var initialized = false
- private var conf: SparkConf = null
- def initialize(_isDriver: Boolean, conf: SparkConf) {
- TorrentBroadcast.conf = conf // TODO: we might have to fix it in tests
- synchronized {
- if (!initialized) {
- initialized = true
- }
- }
- }
+private object TorrentBroadcast extends Logging {
- def stop() {
- initialized = false
+ def blockifyObject[T: ClassTag](
+ obj: T,
+ blockSize: Int,
+ serializer: Serializer,
+ compressionCodec: Option[CompressionCodec]): Array[ByteBuffer] = {
+ val bos = new ByteArrayChunkOutputStream(blockSize)
+ val out: OutputStream = compressionCodec.map(c => c.compressedOutputStream(bos)).getOrElse(bos)
+ val ser = serializer.newInstance()
+ val serOut = ser.serializeStream(out)
+ serOut.writeObject[T](obj).close()
+ bos.toArrays.map(ByteBuffer.wrap)
}
- def blockifyObject[T](obj: T): TorrentInfo = {
- val byteArray = Utils.serialize[T](obj)
- val bais = new ByteArrayInputStream(byteArray)
-
- var blockNum = byteArray.length / BLOCK_SIZE
- if (byteArray.length % BLOCK_SIZE != 0) {
- blockNum += 1
- }
-
- val blocks = new Array[TorrentBlock](blockNum)
- var blockId = 0
-
- for (i <- 0 until (byteArray.length, BLOCK_SIZE)) {
- val thisBlockSize = math.min(BLOCK_SIZE, byteArray.length - i)
- val tempByteArray = new Array[Byte](thisBlockSize)
- bais.read(tempByteArray, 0, thisBlockSize)
-
- blocks(blockId) = new TorrentBlock(blockId, tempByteArray)
- blockId += 1
- }
- bais.close()
-
- val info = TorrentInfo(blocks, blockNum, byteArray.length)
- info.hasBlocks = blockNum
- info
- }
-
- def unBlockifyObject[T](
- arrayOfBlocks: Array[TorrentBlock],
- totalBytes: Int,
- totalBlocks: Int): T = {
- val retByteArray = new Array[Byte](totalBytes)
- for (i <- 0 until totalBlocks) {
- System.arraycopy(arrayOfBlocks(i).byteArray, 0, retByteArray,
- i * BLOCK_SIZE, arrayOfBlocks(i).byteArray.length)
- }
- Utils.deserialize[T](retByteArray, Thread.currentThread.getContextClassLoader)
+ def unBlockifyObject[T: ClassTag](
+ blocks: Array[ByteBuffer],
+ serializer: Serializer,
+ compressionCodec: Option[CompressionCodec]): T = {
+ require(blocks.nonEmpty, "Cannot unblockify an empty array of blocks")
+ val is = new SequenceInputStream(
+ asJavaEnumeration(blocks.iterator.map(block => new ByteBufferInputStream(block))))
+ val in: InputStream = compressionCodec.map(c => c.compressedInputStream(is)).getOrElse(is)
+ val ser = serializer.newInstance()
+ val serIn = ser.deserializeStream(in)
+ val obj = serIn.readObject[T]()
+ serIn.close()
+ obj
}
/**
* Remove all persisted blocks associated with this torrent broadcast on the executors.
* If removeFromDriver is true, also remove these persisted blocks on the driver.
*/
- def unpersist(id: Long, removeFromDriver: Boolean, blocking: Boolean) = synchronized {
+ def unpersist(id: Long, removeFromDriver: Boolean, blocking: Boolean) = {
+ logDebug(s"Unpersisting TorrentBroadcast $id")
SparkEnv.get.blockManager.master.removeBroadcast(id, removeFromDriver, blocking)
}
}
-
-private[spark] case class TorrentBlock(
- blockID: Int,
- byteArray: Array[Byte])
- extends Serializable
-
-private[spark] case class TorrentInfo(
- @transient arrayOfBlocks: Array[TorrentBlock],
- totalBlocks: Int,
- totalBytes: Int)
- extends Serializable {
-
- @transient var hasBlocks = 0
-}
diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcastFactory.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcastFactory.scala
index 1de8396a0e17f..fb024c12094f2 100644
--- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcastFactory.scala
+++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcastFactory.scala
@@ -28,21 +28,20 @@ import org.apache.spark.{SecurityManager, SparkConf}
*/
class TorrentBroadcastFactory extends BroadcastFactory {
- def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager) {
- TorrentBroadcast.initialize(isDriver, conf)
- }
+ override def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager) { }
- def newBroadcast[T: ClassTag](value_ : T, isLocal: Boolean, id: Long) =
- new TorrentBroadcast[T](value_, isLocal, id)
+ override def newBroadcast[T: ClassTag](value_ : T, isLocal: Boolean, id: Long) = {
+ new TorrentBroadcast[T](value_, id)
+ }
- def stop() { TorrentBroadcast.stop() }
+ override def stop() { }
/**
* Remove all persisted state associated with the torrent broadcast with the given ID.
* @param removeFromDriver Whether to remove state from the driver.
* @param blocking Whether to block until unbroadcasted
*/
- def unbroadcast(id: Long, removeFromDriver: Boolean, blocking: Boolean) {
+ override def unbroadcast(id: Long, removeFromDriver: Boolean, blocking: Boolean) {
TorrentBroadcast.unpersist(id, removeFromDriver, blocking)
}
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/ApplicationDescription.scala b/core/src/main/scala/org/apache/spark/deploy/ApplicationDescription.scala
index 86305d2ea8a09..65a1a8fd7e929 100644
--- a/core/src/main/scala/org/apache/spark/deploy/ApplicationDescription.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/ApplicationDescription.scala
@@ -22,7 +22,6 @@ private[spark] class ApplicationDescription(
val maxCores: Option[Int],
val memoryPerSlave: Int,
val command: Command,
- val sparkHome: Option[String],
var appUiUrl: String,
val eventLogDir: Option[String] = None)
extends Serializable {
diff --git a/core/src/main/scala/org/apache/spark/deploy/Client.scala b/core/src/main/scala/org/apache/spark/deploy/Client.scala
index c371dc3a51c73..f2687ce6b42b4 100644
--- a/core/src/main/scala/org/apache/spark/deploy/Client.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/Client.scala
@@ -17,8 +17,6 @@
package org.apache.spark.deploy
-import scala.collection.JavaConversions._
-import scala.collection.mutable.Map
import scala.concurrent._
import akka.actor._
@@ -29,12 +27,14 @@ import org.apache.log4j.{Level, Logger}
import org.apache.spark.{Logging, SecurityManager, SparkConf}
import org.apache.spark.deploy.DeployMessages._
import org.apache.spark.deploy.master.{DriverState, Master}
-import org.apache.spark.util.{AkkaUtils, Utils}
+import org.apache.spark.util.{ActorLogReceive, AkkaUtils, Utils}
/**
* Proxy that relays messages to the driver.
*/
-private class ClientActor(driverArgs: ClientArguments, conf: SparkConf) extends Actor with Logging {
+private class ClientActor(driverArgs: ClientArguments, conf: SparkConf)
+ extends Actor with ActorLogReceive with Logging {
+
var masterActor: ActorSelection = _
val timeout = AkkaUtils.askTimeout(conf)
@@ -50,9 +50,6 @@ private class ClientActor(driverArgs: ClientArguments, conf: SparkConf) extends
// TODO: We could add an env variable here and intercept it in `sc.addJar` that would
// truncate filesystem paths similar to what YARN does. For now, we just require
// people call `addJar` assuming the jar is in the same directory.
- val env = Map[String, String]()
- System.getenv().foreach{case (k, v) => env(k) = v}
-
val mainClass = "org.apache.spark.deploy.worker.DriverWrapper"
val classPathConf = "spark.driver.extraClassPath"
@@ -65,10 +62,13 @@ private class ClientActor(driverArgs: ClientArguments, conf: SparkConf) extends
cp.split(java.io.File.pathSeparator)
}
- val javaOptionsConf = "spark.driver.extraJavaOptions"
- val javaOpts = sys.props.get(javaOptionsConf)
+ val extraJavaOptsConf = "spark.driver.extraJavaOptions"
+ val extraJavaOpts = sys.props.get(extraJavaOptsConf)
+ .map(Utils.splitCommandString).getOrElse(Seq.empty)
+ val sparkJavaOpts = Utils.sparkJavaOpts(conf)
+ val javaOpts = sparkJavaOpts ++ extraJavaOpts
val command = new Command(mainClass, Seq("{{WORKER_URL}}", driverArgs.mainClass) ++
- driverArgs.driverOptions, env, classPathEntries, libraryPathEntries, javaOpts)
+ driverArgs.driverOptions, sys.env, classPathEntries, libraryPathEntries, javaOpts)
val driverDescription = new DriverDescription(
driverArgs.jarUrl,
@@ -109,13 +109,14 @@ private class ClientActor(driverArgs: ClientArguments, conf: SparkConf) extends
// Exception, if present
statusResponse.exception.map { e =>
println(s"Exception from cluster was: $e")
+ e.printStackTrace()
System.exit(-1)
}
System.exit(0)
}
}
- override def receive = {
+ override def receiveWithLogging = {
case SubmitDriverResponse(success, driverId, message) =>
println(message)
@@ -129,7 +130,7 @@ private class ClientActor(driverArgs: ClientArguments, conf: SparkConf) extends
println(s"Error connecting to master ${driverArgs.master} ($remoteAddress), exiting.")
System.exit(-1)
- case AssociationErrorEvent(cause, _, remoteAddress, _) =>
+ case AssociationErrorEvent(cause, _, remoteAddress, _, _) =>
println(s"Error connecting to master ${driverArgs.master} ($remoteAddress), exiting.")
println(s"Cause was: $cause")
System.exit(-1)
@@ -141,8 +142,10 @@ private class ClientActor(driverArgs: ClientArguments, conf: SparkConf) extends
*/
object Client {
def main(args: Array[String]) {
- println("WARNING: This client is deprecated and will be removed in a future version of Spark.")
- println("Use ./bin/spark-submit with \"--master spark://host:port\"")
+ if (!sys.props.contains("SPARK_SUBMIT")) {
+ println("WARNING: This client is deprecated and will be removed in a future version of Spark")
+ println("Use ./bin/spark-submit with \"--master spark://host:port\"")
+ }
val conf = new SparkConf()
val driverArgs = new ClientArguments(args)
@@ -154,8 +157,6 @@ object Client {
conf.set("akka.loglevel", driverArgs.logLevel.toString.replace("WARN", "WARNING"))
Logger.getRootLogger.setLevel(driverArgs.logLevel)
- // TODO: See if we can initialize akka so return messages are sent back using the same TCP
- // flow. Else, this (sadly) requires the DriverClient be routable from the Master.
val (actorSystem, _) = AkkaUtils.createActorSystem(
"driverClient", Utils.localHostName(), 0, conf, new SecurityManager(conf))
diff --git a/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala b/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala
index 39150deab863c..4e802e02c4149 100644
--- a/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala
@@ -17,6 +17,8 @@
package org.apache.spark.deploy
+import java.net.{URI, URISyntaxException}
+
import scala.collection.mutable.ListBuffer
import org.apache.log4j.Level
@@ -114,5 +116,12 @@ private[spark] class ClientArguments(args: Array[String]) {
}
object ClientArguments {
- def isValidJarUrl(s: String): Boolean = s.matches("(.+):(.+)jar")
+ def isValidJarUrl(s: String): Boolean = {
+ try {
+ val uri = new URI(s)
+ uri.getScheme != null && uri.getAuthority != null && s.endsWith("jar")
+ } catch {
+ case _: URISyntaxException => false
+ }
+ }
}
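
A hedged sketch of what the URI-based check accepts and rejects (the paths are illustrative):

{{{
ClientArguments.isValidJarUrl("hdfs://namenode:8020/jars/app.jar")   // true: scheme, authority, ends in jar
ClientArguments.isValidJarUrl("file:/local/app.jar")                 // false: no authority
ClientArguments.isValidJarUrl("/local/app.jar")                      // false: no scheme
ClientArguments.isValidJarUrl("hdfs://namenode:8020/jars/app.war")   // false: does not end in "jar"
}}}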
diff --git a/core/src/main/scala/org/apache/spark/deploy/Command.scala b/core/src/main/scala/org/apache/spark/deploy/Command.scala
index 32f3ba385084f..a2b263544c6a2 100644
--- a/core/src/main/scala/org/apache/spark/deploy/Command.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/Command.scala
@@ -25,5 +25,5 @@ private[spark] case class Command(
environment: Map[String, String],
classPathEntries: Seq[String],
libraryPathEntries: Seq[String],
- extraJavaOptions: Option[String] = None) {
+ javaOpts: Seq[String]) {
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala b/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala
index a7368f9f3dfbe..b9dd8557ee904 100644
--- a/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala
@@ -71,6 +71,8 @@ private[deploy] object DeployMessages {
case class RegisterWorkerFailed(message: String) extends DeployMessage
+ case class ReconnectWorker(masterUrl: String) extends DeployMessage
+
case class KillExecutor(masterUrl: String, appId: String, execId: Int) extends DeployMessage
case class LaunchExecutor(
diff --git a/core/src/main/scala/org/apache/spark/deploy/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/deploy/JsonProtocol.scala
index c4f5e294a393e..696f32a6f5730 100644
--- a/core/src/main/scala/org/apache/spark/deploy/JsonProtocol.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/JsonProtocol.scala
@@ -56,7 +56,6 @@ private[spark] object JsonProtocol {
("cores" -> obj.maxCores) ~
("memoryperslave" -> obj.memoryPerSlave) ~
("user" -> obj.user) ~
- ("sparkhome" -> obj.sparkHome) ~
("command" -> obj.command.toString)
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala b/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala
index 0d6751f3fa6d2..039c8719e2867 100644
--- a/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala
@@ -22,8 +22,8 @@ import java.net.URI
import scala.collection.mutable.ArrayBuffer
import scala.collection.JavaConversions._
-import org.apache.spark.api.python.{PythonUtils, RedirectThread}
-import org.apache.spark.util.Utils
+import org.apache.spark.api.python.PythonUtils
+import org.apache.spark.util.{RedirectThread, Utils}
/**
* A main class used by spark-submit to launch Python applications. It executes python as a
@@ -34,7 +34,8 @@ object PythonRunner {
val pythonFile = args(0)
val pyFiles = args(1)
val otherArgs = args.slice(2, args.length)
- val pythonExec = sys.env.get("PYSPARK_PYTHON").getOrElse("python") // TODO: get this from conf
+ val pythonExec =
+ sys.env.getOrElse("PYSPARK_DRIVER_PYTHON", sys.env.getOrElse("PYSPARK_PYTHON", "python"))
// Format python file paths before adding them to the PYTHONPATH
val formattedPythonFile = formatPath(pythonFile)
@@ -54,9 +55,11 @@ object PythonRunner {
val pythonPath = PythonUtils.mergePythonPaths(pathElements: _*)
// Launch Python process
- val builder = new ProcessBuilder(Seq(pythonExec, "-u", formattedPythonFile) ++ otherArgs)
+ val builder = new ProcessBuilder(Seq(pythonExec, formattedPythonFile) ++ otherArgs)
val env = builder.environment()
env.put("PYTHONPATH", pythonPath)
+ // This is equivalent to setting the -u flag; we use it because ipython doesn't support -u:
+ env.put("PYTHONUNBUFFERED", "YES") // value is needed to be set to a non-empty string
env.put("PYSPARK_GATEWAY_PORT", "" + gatewayServer.getListeningPort)
builder.redirectErrorStream(true) // Ugly but needed for stdout and stderr to synchronize
val process = builder.start()
@@ -84,8 +87,8 @@ object PythonRunner {
// Strip the URI scheme from the path
formattedPath =
new URI(formattedPath).getScheme match {
- case Utils.windowsDrive(d) if windows => formattedPath
case null => formattedPath
+ case Utils.windowsDrive(d) if windows => formattedPath
case _ => new URI(formattedPath).getPath
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala
index 148115d3ed351..60ee115e393ce 100644
--- a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala
@@ -17,22 +17,29 @@
package org.apache.spark.deploy
+import java.lang.reflect.Method
import java.security.PrivilegedExceptionAction
import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.{FileSystem, Path}
+import org.apache.hadoop.fs.FileSystem.Statistics
import org.apache.hadoop.mapred.JobConf
import org.apache.hadoop.security.Credentials
import org.apache.hadoop.security.UserGroupInformation
-import org.apache.spark.{Logging, SparkContext, SparkException}
+import org.apache.spark.{Logging, SparkContext, SparkConf, SparkException}
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.util.Utils
import scala.collection.JavaConversions._
/**
+ * :: DeveloperApi ::
* Contains util methods to interact with Hadoop from Spark.
*/
+@DeveloperApi
class SparkHadoopUtil extends Logging {
- val conf: Configuration = newConfiguration()
+ val conf: Configuration = newConfiguration(new SparkConf())
UserGroupInformation.setConfiguration(conf)
/**
@@ -64,11 +71,39 @@ class SparkHadoopUtil extends Logging {
}
}
+ @Deprecated
+ def newConfiguration(): Configuration = newConfiguration(null)
+
/**
* Return an appropriate (subclass) of Configuration. Creating config can initializes some Hadoop
* subsystems.
*/
- def newConfiguration(): Configuration = new Configuration()
+ def newConfiguration(conf: SparkConf): Configuration = {
+ val hadoopConf = new Configuration()
+
+ // Note: this null check is around more than just access to the "conf" object to maintain
+ // the behavior of the old implementation of this code, for backwards compatibility.
+ if (conf != null) {
+ // Explicitly check for S3 environment variables
+ if (System.getenv("AWS_ACCESS_KEY_ID") != null &&
+ System.getenv("AWS_SECRET_ACCESS_KEY") != null) {
+ hadoopConf.set("fs.s3.awsAccessKeyId", System.getenv("AWS_ACCESS_KEY_ID"))
+ hadoopConf.set("fs.s3n.awsAccessKeyId", System.getenv("AWS_ACCESS_KEY_ID"))
+ hadoopConf.set("fs.s3.awsSecretAccessKey", System.getenv("AWS_SECRET_ACCESS_KEY"))
+ hadoopConf.set("fs.s3n.awsSecretAccessKey", System.getenv("AWS_SECRET_ACCESS_KEY"))
+ }
+ // Copy any "spark.hadoop.foo=bar" system properties into conf as "foo=bar"
+ conf.getAll.foreach { case (key, value) =>
+ if (key.startsWith("spark.hadoop.")) {
+ hadoopConf.set(key.substring("spark.hadoop.".length), value)
+ }
+ }
+ val bufferSize = conf.get("spark.buffer.size", "65536")
+ hadoopConf.set("io.file.buffer.size", bufferSize)
+ }
+
+ hadoopConf
+ }
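
A hedged sketch of the spark.hadoop.* pass-through added above; the split-size key is just an arbitrary Hadoop property chosen for illustration:

{{{
import org.apache.spark.SparkConf

val sparkConf = new SparkConf()
  .set("spark.hadoop.mapreduce.input.fileinputformat.split.minsize", "134217728")
val hadoopConf = new SparkHadoopUtil().newConfiguration(sparkConf)

hadoopConf.get("mapreduce.input.fileinputformat.split.minsize")   // "134217728"
hadoopConf.get("io.file.buffer.size")                             // "65536", from spark.buffer.size
}}}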
/**
* Add any user credentials to the job conf which are necessary for running on a secure Hadoop
@@ -86,10 +121,68 @@ class SparkHadoopUtil extends Logging {
def getSecretKeyFromUserCredentials(key: String): Array[Byte] = { null }
- def loginUserFromKeytab(principalName: String, keytabFilename: String) {
+ def loginUserFromKeytab(principalName: String, keytabFilename: String) {
UserGroupInformation.loginUserFromKeytab(principalName, keytabFilename)
}
+ /**
+ * Returns a function that can be called to find Hadoop FileSystem bytes read. If
+ * getFSBytesReadOnThreadCallback is called from thread r at time t, the returned callback will
+ * return the bytes read on r since t. Reflection is required because thread-level FileSystem
+ * statistics are only available as of Hadoop 2.5 (see HADOOP-10688).
+ * Returns None if the required method can't be found.
+ */
+ private[spark] def getFSBytesReadOnThreadCallback(path: Path, conf: Configuration)
+ : Option[() => Long] = {
+ try {
+ val threadStats = getFileSystemThreadStatistics(path, conf)
+ val getBytesReadMethod = getFileSystemThreadStatisticsMethod("getBytesRead")
+ val f = () => threadStats.map(getBytesReadMethod.invoke(_).asInstanceOf[Long]).sum
+ val baselineBytesRead = f()
+ Some(() => f() - baselineBytesRead)
+ } catch {
+ case e: NoSuchMethodException => {
+ logDebug("Couldn't find method for retrieving thread-level FileSystem input data", e)
+ None
+ }
+ }
+ }
+
+ /**
+ * Returns a function that can be called to find Hadoop FileSystem bytes written. If
+ * getFSBytesWrittenOnThreadCallback is called from thread r at time t, the returned callback will
+ * return the bytes written on r since t. Reflection is required because thread-level FileSystem
+ * statistics are only available as of Hadoop 2.5 (see HADOOP-10688).
+ * Returns None if the required method can't be found.
+ */
+ private[spark] def getFSBytesWrittenOnThreadCallback(path: Path, conf: Configuration)
+ : Option[() => Long] = {
+ try {
+ val threadStats = getFileSystemThreadStatistics(path, conf)
+ val getBytesWrittenMethod = getFileSystemThreadStatisticsMethod("getBytesWritten")
+ val f = () => threadStats.map(getBytesWrittenMethod.invoke(_).asInstanceOf[Long]).sum
+ val baselineBytesWritten = f()
+ Some(() => f() - baselineBytesWritten)
+ } catch {
+ case e: NoSuchMethodException => {
+ logDebug("Couldn't find method for retrieving thread-level FileSystem output data", e)
+ None
+ }
+ }
+ }
+
+ private def getFileSystemThreadStatistics(path: Path, conf: Configuration): Seq[AnyRef] = {
+ val qualifiedPath = path.getFileSystem(conf).makeQualified(path)
+ val scheme = qualifiedPath.toUri().getScheme()
+ val stats = FileSystem.getAllStatistics().filter(_.getScheme().equals(scheme))
+ stats.map(Utils.invoke(classOf[Statistics], _, "getThreadStatistics"))
+ }
+
+ private def getFileSystemThreadStatisticsMethod(methodName: String): Method = {
+ val statisticsDataClass =
+ Class.forName("org.apache.hadoop.fs.FileSystem$Statistics$StatisticsData")
+ statisticsDataClass.getDeclaredMethod(methodName)
+ }
}
object SparkHadoopUtil {
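
As a rough usage sketch of the thread-level statistics callbacks added above: the helpers are `private[spark]`, so only Spark-internal code can call them, and they return `Some(callback)` only on Hadoop 2.5+ where the reflective lookup succeeds.
{{{
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path
import org.apache.spark.deploy.SparkHadoopUtil

val conf = new Configuration()
val path = new Path("hdfs://a-hdfs-path/part-00000")

// None on Hadoop versions without thread-level FileSystem statistics.
val maybeBytesRead = SparkHadoopUtil.get.getFSBytesReadOnThreadCallback(path, conf)

// ... perform reads on this thread ...

maybeBytesRead.foreach { bytesRead =>
  println(s"Bytes read on this thread since the callback was created: ${bytesRead()}")
}
}}}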
diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala
index b050dccb6d57f..8a62519bd2315 100644
--- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala
@@ -18,7 +18,7 @@
package org.apache.spark.deploy
import java.io.{File, PrintStream}
-import java.lang.reflect.InvocationTargetException
+import java.lang.reflect.{Modifier, InvocationTargetException}
import java.net.URL
import scala.collection.mutable.{ArrayBuffer, HashMap, Map}
@@ -27,116 +27,126 @@ import org.apache.spark.executor.ExecutorURLClassLoader
import org.apache.spark.util.Utils
/**
- * Scala code behind the spark-submit script. The script handles setting up the classpath with
- * relevant Spark dependencies and provides a layer over the different cluster managers and deploy
- * modes that Spark supports.
+ * Main gateway of launching a Spark application.
+ *
+ * This program handles setting up the classpath with relevant Spark dependencies and provides
+ * a layer over the different cluster managers and deploy modes that Spark supports.
*/
object SparkSubmit {
+
+ // Cluster managers
private val YARN = 1
private val STANDALONE = 2
private val MESOS = 4
private val LOCAL = 8
private val ALL_CLUSTER_MGRS = YARN | STANDALONE | MESOS | LOCAL
- private var clusterManager: Int = LOCAL
+ // Deploy modes
+ private val CLIENT = 1
+ private val CLUSTER = 2
+ private val ALL_DEPLOY_MODES = CLIENT | CLUSTER
- /**
- * Special primary resource names that represent shells rather than application jars.
- */
+ // A special jar name that indicates the class being run is inside of Spark itself, and therefore
+ // no user jar is needed.
+ private val SPARK_INTERNAL = "spark-internal"
+
+ // Special primary resource names that represent shells rather than application jars.
private val SPARK_SHELL = "spark-shell"
private val PYSPARK_SHELL = "pyspark-shell"
- def main(args: Array[String]) {
- val appArgs = new SparkSubmitArguments(args)
- if (appArgs.verbose) {
- printStream.println(appArgs)
- }
- val (childArgs, classpath, sysProps, mainClass) = createLaunchEnv(appArgs)
- launch(childArgs, classpath, sysProps, mainClass, appArgs.verbose)
- }
+ private val CLASS_NOT_FOUND_EXIT_STATUS = 101
// Exposed for testing
- private[spark] var printStream: PrintStream = System.err
private[spark] var exitFn: () => Unit = () => System.exit(-1)
-
+ private[spark] var printStream: PrintStream = System.err
+ private[spark] def printWarning(str: String) = printStream.println("Warning: " + str)
private[spark] def printErrorAndExit(str: String) = {
printStream.println("Error: " + str)
printStream.println("Run with --help for usage help or --verbose for debug output")
exitFn()
}
- private[spark] def printWarning(str: String) = printStream.println("Warning: " + str)
+
+ def main(args: Array[String]) {
+ val appArgs = new SparkSubmitArguments(args)
+ if (appArgs.verbose) {
+ printStream.println(appArgs)
+ }
+ val (childArgs, classpath, sysProps, mainClass) = createLaunchEnv(appArgs)
+ launch(childArgs, classpath, sysProps, mainClass, appArgs.verbose)
+ }
/**
- * @return a tuple containing the arguments for the child, a list of classpath
- * entries for the child, a list of system properties, a list of env vars
- * and the main class for the child
+ * @return a tuple containing
+ * (1) the arguments for the child process,
+ * (2) a list of classpath entries for the child,
+ * (3) a list of system properties and env vars, and
+ * (4) the main class for the child
*/
private[spark] def createLaunchEnv(args: SparkSubmitArguments)
: (ArrayBuffer[String], ArrayBuffer[String], Map[String, String], String) = {
- if (args.master.startsWith("local")) {
- clusterManager = LOCAL
- } else if (args.master.startsWith("yarn")) {
- clusterManager = YARN
- } else if (args.master.startsWith("spark")) {
- clusterManager = STANDALONE
- } else if (args.master.startsWith("mesos")) {
- clusterManager = MESOS
- } else {
- printErrorAndExit("Master must start with yarn, mesos, spark, or local")
- }
-
- // Because "yarn-cluster" and "yarn-client" encapsulate both the master
- // and deploy mode, we have some logic to infer the master and deploy mode
- // from each other if only one is specified, or exit early if they are at odds.
- if (args.deployMode == null &&
- (args.master == "yarn-standalone" || args.master == "yarn-cluster")) {
- args.deployMode = "cluster"
- }
- if (args.deployMode == "cluster" && args.master == "yarn-client") {
- printErrorAndExit("Deploy mode \"cluster\" and master \"yarn-client\" are not compatible")
- }
- if (args.deployMode == "client" &&
- (args.master == "yarn-standalone" || args.master == "yarn-cluster")) {
- printErrorAndExit("Deploy mode \"client\" and master \"" + args.master
- + "\" are not compatible")
- }
- if (args.deployMode == "cluster" && args.master.startsWith("yarn")) {
- args.master = "yarn-cluster"
- }
- if (args.deployMode != "cluster" && args.master.startsWith("yarn")) {
- args.master = "yarn-client"
- }
- val deployOnCluster = Option(args.deployMode).getOrElse("client") == "cluster"
-
- val childClasspath = new ArrayBuffer[String]()
+ // Values to return
val childArgs = new ArrayBuffer[String]()
+ val childClasspath = new ArrayBuffer[String]()
val sysProps = new HashMap[String, String]()
var childMainClass = ""
- val isPython = args.isPython
- val isYarnCluster = clusterManager == YARN && deployOnCluster
+ // Set the cluster manager
+ val clusterManager: Int = args.master match {
+ case m if m.startsWith("yarn") => YARN
+ case m if m.startsWith("spark") => STANDALONE
+ case m if m.startsWith("mesos") => MESOS
+ case m if m.startsWith("local") => LOCAL
+ case _ => printErrorAndExit("Master must start with yarn, spark, mesos, or local"); -1
+ }
- // For mesos, only client mode is supported
- if (clusterManager == MESOS && deployOnCluster) {
- printErrorAndExit("Cluster deploy mode is currently not supported for Mesos clusters.")
+ // Set the deploy mode; default is client mode
+ var deployMode: Int = args.deployMode match {
+ case "client" | null => CLIENT
+ case "cluster" => CLUSTER
+ case _ => printErrorAndExit("Deploy mode must be either client or cluster"); -1
}
- // For standalone, only client mode is supported
- if (clusterManager == STANDALONE && deployOnCluster) {
- printErrorAndExit("Cluster deploy mode is currently not supported for standalone clusters.")
+ // Because "yarn-cluster" and "yarn-client" encapsulate both the master
+ // and deploy mode, we have some logic to infer the master and deploy mode
+ // from each other if only one is specified, or exit early if they are at odds.
+ if (clusterManager == YARN) {
+ if (args.master == "yarn-standalone") {
+ printWarning("\"yarn-standalone\" is deprecated. Use \"yarn-cluster\" instead.")
+ args.master = "yarn-cluster"
+ }
+ (args.master, args.deployMode) match {
+ case ("yarn-cluster", null) =>
+ deployMode = CLUSTER
+ case ("yarn-cluster", "client") =>
+ printErrorAndExit("Client deploy mode is not compatible with master \"yarn-cluster\"")
+ case ("yarn-client", "cluster") =>
+ printErrorAndExit("Cluster deploy mode is not compatible with master \"yarn-client\"")
+ case (_, mode) =>
+ args.master = "yarn-" + Option(mode).getOrElse("client")
+ }
+
+ // Make sure YARN is included in our build if we're trying to use it
+ if (!Utils.classIsLoadable("org.apache.spark.deploy.yarn.Client") && !Utils.isTesting) {
+ printErrorAndExit(
+ "Could not load YARN classes. " +
+ "This copy of Spark may not have been compiled with YARN support.")
+ }
}
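
The match above normalizes the YARN master/deploy-mode combinations. A standalone, illustrative sketch of the same rules (not the actual SparkSubmit code, which mutates `args` in place):
{{{
// Returns (normalized master, effective deploy mode); errors on incompatible pairs.
def normalizeYarn(master: String, deployMode: Option[String]): (String, String) = {
  val m = if (master == "yarn-standalone") "yarn-cluster" else master
  (m, deployMode) match {
    case ("yarn-cluster", None)           => ("yarn-cluster", "cluster")
    case ("yarn-cluster", Some("client")) => sys.error("client mode incompatible with yarn-cluster")
    case ("yarn-client", Some("cluster")) => sys.error("cluster mode incompatible with yarn-client")
    case (_, mode)                        => ("yarn-" + mode.getOrElse("client"), mode.getOrElse("client"))
  }
}

assert(normalizeYarn("yarn", None) == ("yarn-client", "client"))
assert(normalizeYarn("yarn-standalone", None) == ("yarn-cluster", "cluster"))
}}}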
- // For shells, only client mode is applicable
- if (isShell(args.primaryResource) && deployOnCluster) {
- printErrorAndExit("Cluster deploy mode is not applicable to Spark shells.")
+ // The following modes are not supported or applicable
+ (clusterManager, deployMode) match {
+ case (MESOS, CLUSTER) =>
+ printErrorAndExit("Cluster deploy mode is currently not supported for Mesos clusters.")
+ case (_, CLUSTER) if args.isPython =>
+ printErrorAndExit("Cluster deploy mode is currently not supported for python applications.")
+ case (_, CLUSTER) if isShell(args.primaryResource) =>
+ printErrorAndExit("Cluster deploy mode is not applicable to Spark shells.")
+ case _ =>
}
// If we're running a python app, set the main class to our specific python runner
- if (isPython) {
- if (deployOnCluster) {
- printErrorAndExit("Cluster deploy mode is currently not supported for python.")
- }
+ if (args.isPython) {
if (args.primaryResource == PYSPARK_SHELL) {
args.mainClass = "py4j.GatewayServer"
args.childArgs = ArrayBuffer("--die-on-broken-pipe", "0")
@@ -148,27 +158,8 @@ object SparkSubmit {
args.files = mergeFileLists(args.files, args.primaryResource)
}
args.files = mergeFileLists(args.files, args.pyFiles)
- // Format python file paths properly before adding them to the PYTHONPATH
- sysProps("spark.submit.pyFiles") = PythonRunner.formatPaths(args.pyFiles).mkString(",")
- }
-
- // If we're deploying into YARN, use yarn.Client as a wrapper around the user class
- if (!deployOnCluster) {
- childMainClass = args.mainClass
- if (isUserJar(args.primaryResource)) {
- childClasspath += args.primaryResource
- }
- } else if (clusterManager == YARN) {
- childMainClass = "org.apache.spark.deploy.yarn.Client"
- childArgs += ("--jar", args.primaryResource)
- childArgs += ("--class", args.mainClass)
- }
-
- // Make sure YARN is included in our build if we're trying to use it
- if (clusterManager == YARN) {
- if (!Utils.classIsLoadable("org.apache.spark.deploy.yarn.Client") && !Utils.isTesting) {
- printErrorAndExit("Could not load YARN classes. " +
- "This copy of Spark may not have been compiled with YARN support.")
+ if (args.pyFiles != null) {
+ sysProps("spark.submit.pyFiles") = args.pyFiles
}
}
@@ -178,94 +169,137 @@ object SparkSubmit {
// A list of rules to map each argument to system properties or command-line options in
// each deploy mode; we iterate through these below
val options = List[OptionAssigner](
- OptionAssigner(args.master, ALL_CLUSTER_MGRS, false, sysProp = "spark.master"),
- OptionAssigner(args.name, ALL_CLUSTER_MGRS, false, sysProp = "spark.app.name"),
- OptionAssigner(args.name, YARN, true, clOption = "--name", sysProp = "spark.app.name"),
- OptionAssigner(args.driverExtraClassPath, STANDALONE | YARN, true,
+
+ // All cluster managers
+ OptionAssigner(args.master, ALL_CLUSTER_MGRS, ALL_DEPLOY_MODES, sysProp = "spark.master"),
+ OptionAssigner(args.name, ALL_CLUSTER_MGRS, ALL_DEPLOY_MODES, sysProp = "spark.app.name"),
+ OptionAssigner(args.jars, ALL_CLUSTER_MGRS, CLIENT, sysProp = "spark.jars"),
+ OptionAssigner(args.driverMemory, ALL_CLUSTER_MGRS, CLIENT,
+ sysProp = "spark.driver.memory"),
+ OptionAssigner(args.driverExtraClassPath, ALL_CLUSTER_MGRS, ALL_DEPLOY_MODES,
sysProp = "spark.driver.extraClassPath"),
- OptionAssigner(args.driverExtraJavaOptions, STANDALONE | YARN, true,
+ OptionAssigner(args.driverExtraJavaOptions, ALL_CLUSTER_MGRS, ALL_DEPLOY_MODES,
sysProp = "spark.driver.extraJavaOptions"),
- OptionAssigner(args.driverExtraLibraryPath, STANDALONE | YARN, true,
+ OptionAssigner(args.driverExtraLibraryPath, ALL_CLUSTER_MGRS, ALL_DEPLOY_MODES,
sysProp = "spark.driver.extraLibraryPath"),
- OptionAssigner(args.driverMemory, YARN, true, clOption = "--driver-memory"),
- OptionAssigner(args.driverMemory, STANDALONE, true, clOption = "--memory"),
- OptionAssigner(args.driverCores, STANDALONE, true, clOption = "--cores"),
- OptionAssigner(args.queue, YARN, true, clOption = "--queue"),
- OptionAssigner(args.queue, YARN, false, sysProp = "spark.yarn.queue"),
- OptionAssigner(args.numExecutors, YARN, true, clOption = "--num-executors"),
- OptionAssigner(args.numExecutors, YARN, false, sysProp = "spark.executor.instances"),
- OptionAssigner(args.executorMemory, YARN, true, clOption = "--executor-memory"),
- OptionAssigner(args.executorMemory, STANDALONE | MESOS | YARN, false,
+
+ // Standalone cluster only
+ OptionAssigner(args.jars, STANDALONE, CLUSTER, sysProp = "spark.jars"),
+ OptionAssigner(args.driverMemory, STANDALONE, CLUSTER, clOption = "--memory"),
+ OptionAssigner(args.driverCores, STANDALONE, CLUSTER, clOption = "--cores"),
+
+ // Yarn client only
+ OptionAssigner(args.queue, YARN, CLIENT, sysProp = "spark.yarn.queue"),
+ OptionAssigner(args.numExecutors, YARN, CLIENT, sysProp = "spark.executor.instances"),
+ OptionAssigner(args.executorCores, YARN, CLIENT, sysProp = "spark.executor.cores"),
+ OptionAssigner(args.files, YARN, CLIENT, sysProp = "spark.yarn.dist.files"),
+ OptionAssigner(args.archives, YARN, CLIENT, sysProp = "spark.yarn.dist.archives"),
+
+ // Yarn cluster only
+ OptionAssigner(args.name, YARN, CLUSTER, clOption = "--name"),
+ OptionAssigner(args.driverMemory, YARN, CLUSTER, clOption = "--driver-memory"),
+ OptionAssigner(args.queue, YARN, CLUSTER, clOption = "--queue"),
+ OptionAssigner(args.numExecutors, YARN, CLUSTER, clOption = "--num-executors"),
+ OptionAssigner(args.executorMemory, YARN, CLUSTER, clOption = "--executor-memory"),
+ OptionAssigner(args.executorCores, YARN, CLUSTER, clOption = "--executor-cores"),
+ OptionAssigner(args.files, YARN, CLUSTER, clOption = "--files"),
+ OptionAssigner(args.archives, YARN, CLUSTER, clOption = "--archives"),
+ OptionAssigner(args.jars, YARN, CLUSTER, clOption = "--addJars"),
+
+ // Other options
+ OptionAssigner(args.executorMemory, STANDALONE | MESOS | YARN, ALL_DEPLOY_MODES,
sysProp = "spark.executor.memory"),
- OptionAssigner(args.executorCores, YARN, true, clOption = "--executor-cores"),
- OptionAssigner(args.executorCores, YARN, false, sysProp = "spark.executor.cores"),
- OptionAssigner(args.totalExecutorCores, STANDALONE | MESOS, false,
+ OptionAssigner(args.totalExecutorCores, STANDALONE | MESOS, ALL_DEPLOY_MODES,
sysProp = "spark.cores.max"),
- OptionAssigner(args.files, YARN, false, sysProp = "spark.yarn.dist.files"),
- OptionAssigner(args.files, YARN, true, clOption = "--files"),
- OptionAssigner(args.files, LOCAL | STANDALONE | MESOS, false, sysProp = "spark.files"),
- OptionAssigner(args.files, LOCAL | STANDALONE | MESOS, true, sysProp = "spark.files"),
- OptionAssigner(args.archives, YARN, false, sysProp = "spark.yarn.dist.archives"),
- OptionAssigner(args.archives, YARN, true, clOption = "--archives"),
- OptionAssigner(args.jars, YARN, true, clOption = "--addJars"),
- OptionAssigner(args.jars, ALL_CLUSTER_MGRS, false, sysProp = "spark.jars")
+ OptionAssigner(args.files, LOCAL | STANDALONE | MESOS, ALL_DEPLOY_MODES,
+ sysProp = "spark.files")
)
- // For client mode make any added jars immediately visible on the classpath
- if (args.jars != null && !deployOnCluster) {
- for (jar <- args.jars.split(",")) {
- childClasspath += jar
+ // In client mode, launch the application main class directly
+ // In addition, add the main application jar and any added jars (if any) to the classpath
+ if (deployMode == CLIENT) {
+ childMainClass = args.mainClass
+ if (isUserJar(args.primaryResource)) {
+ childClasspath += args.primaryResource
}
+ if (args.jars != null) { childClasspath ++= args.jars.split(",") }
+ if (args.childArgs != null) { childArgs ++= args.childArgs }
}
+
// Map all arguments to command-line options or system properties for our chosen mode
for (opt <- options) {
- if (opt.value != null && deployOnCluster == opt.deployOnCluster &&
+ if (opt.value != null &&
+ (deployMode & opt.deployMode) != 0 &&
(clusterManager & opt.clusterManager) != 0) {
- if (opt.clOption != null) {
- childArgs += (opt.clOption, opt.value)
- }
- if (opt.sysProp != null) {
- sysProps.put(opt.sysProp, opt.value)
- }
+ if (opt.clOption != null) { childArgs += (opt.clOption, opt.value) }
+ if (opt.sysProp != null) { sysProps.put(opt.sysProp, opt.value) }
}
}
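
The cluster-manager and deploy-mode fields of `OptionAssigner` are bit masks, so one rule can apply to several combinations; the condition above is a plain bitwise membership test. A small illustration using the same encoding as the constants at the top of this object:
{{{
// Deploy modes encoded as bits, as in SparkSubmit
val CLIENT = 1
val CLUSTER = 2
val ALL_DEPLOY_MODES = CLIENT | CLUSTER

// A rule declared for ALL_DEPLOY_MODES matches either mode...
assert((CLIENT & ALL_DEPLOY_MODES) != 0)
assert((CLUSTER & ALL_DEPLOY_MODES) != 0)

// ...while a CLUSTER-only rule is skipped in client mode.
assert((CLIENT & CLUSTER) == 0)
}}}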
// Add the application jar automatically so the user doesn't have to call sc.addJar
// For YARN cluster mode, the jar is already distributed on each node as "app.jar"
// For python files, the primary resource is already distributed as a regular file
- if (!isYarnCluster && !isPython) {
- var jars = sysProps.get("spark.jars").map(x => x.split(",").toSeq).getOrElse(Seq())
+ val isYarnCluster = clusterManager == YARN && deployMode == CLUSTER
+ if (!isYarnCluster && !args.isPython) {
+ var jars = sysProps.get("spark.jars").map(x => x.split(",").toSeq).getOrElse(Seq.empty)
if (isUserJar(args.primaryResource)) {
jars = jars ++ Seq(args.primaryResource)
}
sysProps.put("spark.jars", jars.mkString(","))
}
- // Standalone cluster specific configurations
- if (deployOnCluster && clusterManager == STANDALONE) {
+ // In standalone-cluster mode, use Client as a wrapper around the user class
+ if (clusterManager == STANDALONE && deployMode == CLUSTER) {
+ childMainClass = "org.apache.spark.deploy.Client"
if (args.supervise) {
childArgs += "--supervise"
}
- childMainClass = "org.apache.spark.deploy.Client"
childArgs += "launch"
childArgs += (args.master, args.primaryResource, args.mainClass)
+ if (args.childArgs != null) {
+ childArgs ++= args.childArgs
+ }
}
- // Arguments to be passed to user program
- if (args.childArgs != null) {
- if (!deployOnCluster || clusterManager == STANDALONE) {
- childArgs ++= args.childArgs
- } else if (clusterManager == YARN) {
- for (arg <- args.childArgs) {
- childArgs += ("--arg", arg)
- }
+ // In yarn-cluster mode, use yarn.Client as a wrapper around the user class
+ if (isYarnCluster) {
+ childMainClass = "org.apache.spark.deploy.yarn.Client"
+ if (args.primaryResource != SPARK_INTERNAL) {
+ childArgs += ("--jar", args.primaryResource)
+ }
+ childArgs += ("--class", args.mainClass)
+ if (args.childArgs != null) {
+ args.childArgs.foreach { arg => childArgs += ("--arg", arg) }
+ }
+ }
+
+ // Load any properties specified through --conf and the default properties file
+ for ((k, v) <- args.sparkProperties) {
+ sysProps.getOrElseUpdate(k, v)
+ }
+
+ // Resolve paths in certain spark properties
+ val pathConfigs = Seq(
+ "spark.jars",
+ "spark.files",
+ "spark.yarn.jar",
+ "spark.yarn.dist.files",
+ "spark.yarn.dist.archives")
+ pathConfigs.foreach { config =>
+ // Replace old URIs with resolved URIs, if they exist
+ sysProps.get(config).foreach { oldValue =>
+ sysProps(config) = Utils.resolveURIs(oldValue)
}
}
- // Read from default spark properties, if any
- for ((k, v) <- args.getDefaultSparkProperties) {
- if (!sysProps.contains(k)) sysProps(k) = v
+ // Resolve and format python file paths properly before adding them to the PYTHONPATH.
+ // The resolving part is redundant in the case of --py-files, but necessary if the user
+ // explicitly sets `spark.submit.pyFiles` in his/her default properties file.
+ sysProps.get("spark.submit.pyFiles").foreach { pyFiles =>
+ val resolvedPyFiles = Utils.resolveURIs(pyFiles)
+ val formattedPyFiles = PythonRunner.formatPaths(resolvedPyFiles).mkString(",")
+ sysProps("spark.submit.pyFiles") = formattedPyFiles
}
(childArgs, childClasspath, sysProps, childMainClass)
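
The path-resolution pass above rewrites comma-separated entries in properties such as `spark.jars` so that bare local paths gain an explicit scheme, while entries that already carry one are left alone. A hedged illustration of the intended effect (the exact output depends on the working directory):
{{{
import org.apache.spark.util.Utils

val resolved = Utils.resolveURIs("libs/extra.jar,hdfs://nn:8020/shared/app.jar")
// e.g. "file:/current/working/dir/libs/extra.jar,hdfs://nn:8020/shared/app.jar"
println(resolved)
}}}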
@@ -297,8 +331,24 @@ object SparkSubmit {
System.setProperty(key, value)
}
- val mainClass = Class.forName(childMainClass, true, loader)
+ var mainClass: Class[_] = null
+
+ try {
+ mainClass = Class.forName(childMainClass, true, loader)
+ } catch {
+ case e: ClassNotFoundException =>
+ e.printStackTrace(printStream)
+ if (childMainClass.contains("thriftserver")) {
+ println(s"Failed to load main class $childMainClass.")
+ println("You need to build Spark with -Phive and -Phive-thriftserver.")
+ }
+ System.exit(CLASS_NOT_FOUND_EXIT_STATUS)
+ }
+
val mainMethod = mainClass.getMethod("main", new Array[String](0).getClass)
+ if (!Modifier.isStatic(mainMethod.getModifiers)) {
+ throw new IllegalStateException("The main method in the given main class must be static")
+ }
try {
mainMethod.invoke(null, childArgs.toArray)
} catch {
@@ -328,7 +378,7 @@ object SparkSubmit {
* Return whether the given primary resource represents a user jar.
*/
private def isUserJar(primaryResource: String): Boolean = {
- !isShell(primaryResource) && !isPython(primaryResource)
+ !isShell(primaryResource) && !isPython(primaryResource) && !isInternal(primaryResource)
}
/**
@@ -345,6 +395,10 @@ object SparkSubmit {
primaryResource.endsWith(".py") || primaryResource == PYSPARK_SHELL
}
+ private[spark] def isInternal(primaryResource: String): Boolean = {
+ primaryResource == SPARK_INTERNAL
+ }
+
/**
* Merge a sequence of comma-separated file lists, some of which may be null to indicate
* no files, into a single comma-separated string.
@@ -364,6 +418,6 @@ object SparkSubmit {
private[spark] case class OptionAssigner(
value: String,
clusterManager: Int,
- deployOnCluster: Boolean,
+ deployMode: Int,
clOption: String = null,
sysProp: String = null)
diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala
index 57655aa4c32b1..f0e9ee67f6a67 100644
--- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala
@@ -17,20 +17,17 @@
package org.apache.spark.deploy
-import java.io.{File, FileInputStream, IOException}
-import java.util.Properties
import java.util.jar.JarFile
-import scala.collection.JavaConversions._
import scala.collection.mutable.{ArrayBuffer, HashMap}
-import org.apache.spark.SparkException
import org.apache.spark.util.Utils
/**
* Parses and encapsulates arguments from the spark-submit script.
+ * The env argument is used for testing.
*/
-private[spark] class SparkSubmitArguments(args: Seq[String]) {
+private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, String] = sys.env) {
var master: String = null
var deployMode: String = null
var executorMemory: String = null
@@ -55,19 +52,15 @@ private[spark] class SparkSubmitArguments(args: Seq[String]) {
var verbose: Boolean = false
var isPython: Boolean = false
var pyFiles: String = null
+ val sparkProperties: HashMap[String, String] = new HashMap[String, String]()
- parseOpts(args.toList)
- loadDefaults()
- checkRequiredArguments()
-
- /** Return default present in the currently defined defaults file. */
- def getDefaultSparkProperties = {
+ /** Default properties present in the currently defined defaults file. */
+ lazy val defaultSparkProperties: HashMap[String, String] = {
val defaultProperties = new HashMap[String, String]()
if (verbose) SparkSubmit.printStream.println(s"Using properties file: $propertiesFile")
Option(propertiesFile).foreach { filename =>
- val file = new File(filename)
- SparkSubmitArguments.getPropertiesFromFile(file).foreach { case (k, v) =>
- if (k.startsWith("spark")) {
+ Utils.getPropertiesFromFile(filename).foreach { case (k, v) =>
+ if (k.startsWith("spark.")) {
defaultProperties(k) = v
if (verbose) SparkSubmit.printStream.println(s"Adding default property: $k=$v")
} else {
@@ -78,37 +71,55 @@ private[spark] class SparkSubmitArguments(args: Seq[String]) {
defaultProperties
}
- /** Fill in any undefined values based on the current properties file or built-in defaults. */
- private def loadDefaults(): Unit = {
+ // Set parameters from command line arguments
+ parseOpts(args.toList)
+ // Populate `sparkProperties` map from properties file
+ mergeDefaultSparkProperties()
+ // Use `sparkProperties` map along with env vars to fill in any missing parameters
+ loadEnvironmentArguments()
+
+ checkRequiredArguments()
+ /**
+ * Merge values from the default properties file with those specified through --conf.
+ * When this is called, `sparkProperties` is already filled with configs from the latter.
+ */
+ private def mergeDefaultSparkProperties(): Unit = {
// Use common defaults file, if not specified by user
- if (propertiesFile == null) {
- sys.env.get("SPARK_HOME").foreach { sparkHome =>
- val sep = File.separator
- val defaultPath = s"${sparkHome}${sep}conf${sep}spark-defaults.conf"
- val file = new File(defaultPath)
- if (file.exists()) {
- propertiesFile = file.getAbsolutePath
- }
+ propertiesFile = Option(propertiesFile).getOrElse(Utils.getDefaultPropertiesFile(env))
+ // Honor --conf before the defaults file
+ defaultSparkProperties.foreach { case (k, v) =>
+ if (!sparkProperties.contains(k)) {
+ sparkProperties(k) = v
}
}
+ }
- val defaultProperties = getDefaultSparkProperties
- // Use properties file as fallback for values which have a direct analog to
- // arguments in this script.
- master = Option(master).getOrElse(defaultProperties.get("spark.master").orNull)
+ /**
+ * Load arguments from environment variables, Spark properties etc.
+ */
+ private def loadEnvironmentArguments(): Unit = {
+ master = Option(master)
+ .orElse(sparkProperties.get("spark.master"))
+ .orElse(env.get("MASTER"))
+ .orNull
+ driverMemory = Option(driverMemory)
+ .orElse(sparkProperties.get("spark.driver.memory"))
+ .orElse(env.get("SPARK_DRIVER_MEMORY"))
+ .orNull
executorMemory = Option(executorMemory)
- .getOrElse(defaultProperties.get("spark.executor.memory").orNull)
+ .orElse(sparkProperties.get("spark.executor.memory"))
+ .orElse(env.get("SPARK_EXECUTOR_MEMORY"))
+ .orNull
executorCores = Option(executorCores)
- .getOrElse(defaultProperties.get("spark.executor.cores").orNull)
+ .orElse(sparkProperties.get("spark.executor.cores"))
+ .orNull
totalExecutorCores = Option(totalExecutorCores)
- .getOrElse(defaultProperties.get("spark.cores.max").orNull)
- name = Option(name).getOrElse(defaultProperties.get("spark.app.name").orNull)
- jars = Option(jars).getOrElse(defaultProperties.get("spark.jars").orNull)
-
- // This supports env vars in older versions of Spark
- master = Option(master).getOrElse(System.getenv("MASTER"))
- deployMode = Option(deployMode).getOrElse(System.getenv("DEPLOY_MODE"))
+ .orElse(sparkProperties.get("spark.cores.max"))
+ .orNull
+ name = Option(name).orElse(sparkProperties.get("spark.app.name")).orNull
+ jars = Option(jars).orElse(sparkProperties.get("spark.jars")).orNull
+ deployMode = Option(deployMode).orElse(env.get("DEPLOY_MODE")).orNull
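
The `Option(...).orElse(...)` chains above encode a fixed precedence: an explicit command-line value wins, then the merged Spark properties, then the legacy environment variable. A minimal sketch of that pattern in isolation (names are illustrative):
{{{
// Precedence: explicit argument > Spark property > environment variable > null
def resolve(cliValue: Option[String],
            propValue: Option[String],
            envValue: Option[String]): String = {
  cliValue.orElse(propValue).orElse(envValue).orNull
}

assert(resolve(Some("yarn-client"), Some("local[2]"), None) == "yarn-client")
assert(resolve(None, Some("local[2]"), Some("spark://host:7077")) == "local[2]")
assert(resolve(None, None, None) == null)
}}}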
// Try to set main class from JAR if no --class argument is given
if (mainClass == null && !isPython && primaryResource != null) {
@@ -134,7 +145,7 @@ private[spark] class SparkSubmitArguments(args: Seq[String]) {
}
/** Ensure that required fields exist. Call this only once all defaults are loaded. */
- private def checkRequiredArguments() = {
+ private def checkRequiredArguments(): Unit = {
if (args.length == 0) {
printUsageAndExit(-1)
}
@@ -161,7 +172,7 @@ private[spark] class SparkSubmitArguments(args: Seq[String]) {
}
if (master.startsWith("yarn")) {
- val hasHadoopEnv = sys.env.contains("HADOOP_CONF_DIR") || sys.env.contains("YARN_CONF_DIR")
+ val hasHadoopEnv = env.contains("HADOOP_CONF_DIR") || env.contains("YARN_CONF_DIR")
if (!hasHadoopEnv && !Utils.isTesting) {
throw new Exception(s"When running with master '$master' " +
"either HADOOP_CONF_DIR or YARN_CONF_DIR must be set in the environment.")
@@ -169,7 +180,7 @@ private[spark] class SparkSubmitArguments(args: Seq[String]) {
}
}
- override def toString = {
+ override def toString = {
s"""Parsed arguments:
| master $master
| deployMode $deployMode
@@ -195,17 +206,23 @@ private[spark] class SparkSubmitArguments(args: Seq[String]) {
| jars $jars
| verbose $verbose
|
- |Default properties from $propertiesFile:
- |${getDefaultSparkProperties.mkString(" ", "\n ", "\n")}
+ |Spark properties used, including those specified through
+ | --conf and those from the properties file $propertiesFile:
+ |${sparkProperties.mkString(" ", "\n ", "\n")}
""".stripMargin
}
/** Fill in values by parsing user options. */
private def parseOpts(opts: Seq[String]): Unit = {
+ val EQ_SEPARATED_OPT = """(--[^=]+)=(.+)""".r
+
// Delineates parsing of Spark options from parsing of user options.
- var inSparkOpts = true
parse(opts)
+ /**
+ * NOTE: If you add or remove spark-submit options,
+ * modify NOT ONLY this file but also utils.sh
+ */
def parse(opts: Seq[String]): Unit = opts match {
case ("--name") :: value :: tail =>
name = value
@@ -290,6 +307,13 @@ private[spark] class SparkSubmitArguments(args: Seq[String]) {
jars = Utils.resolveURIs(value)
parse(tail)
+ case ("--conf" | "-c") :: value :: tail =>
+ value.split("=", 2).toSeq match {
+ case Seq(k, v) => sparkProperties(k) = v
+ case _ => SparkSubmit.printErrorAndExit(s"Spark config without '=': $value")
+ }
+ parse(tail)
+
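Because the value is split with `split("=", 2)`, only the first `=` separates the key from the value, so a property value that itself contains `=` (for example a Java options string) survives unchanged. An illustrative check:
{{{
val value = "spark.driver.extraJavaOptions=-Dfoo=bar"
value.split("=", 2).toSeq match {
  case Seq(k, v) =>
    assert(k == "spark.driver.extraJavaOptions")
    assert(v == "-Dfoo=bar")
  case _ =>
    sys.error(s"Spark config without '=': $value")
}
}}}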
case ("--help" | "-h") :: tail =>
printUsageAndExit(0)
@@ -297,39 +321,27 @@ private[spark] class SparkSubmitArguments(args: Seq[String]) {
verbose = true
parse(tail)
+ case EQ_SEPARATED_OPT(opt, value) :: tail =>
+ parse(opt :: value :: tail)
+
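The `EQ_SEPARATED_OPT` extractor replaces the old ad-hoc `--foo=bar` handling: the regex splits the token once and the pieces are fed back to `parse` as if they had been passed as two separate arguments. For example:
{{{
val EQ_SEPARATED_OPT = """(--[^=]+)=(.+)""".r

"--name=MyApp" match {
  case EQ_SEPARATED_OPT(opt, value) =>
    assert(opt == "--name")
    assert(value == "MyApp")
}
}}}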
+ case value :: tail if value.startsWith("-") =>
+ SparkSubmit.printErrorAndExit(s"Unrecognized option '$value'.")
+
case value :: tail =>
- if (inSparkOpts) {
- value match {
- // convert --foo=bar to --foo bar
- case v if v.startsWith("--") && v.contains("=") && v.split("=").size == 2 =>
- val parts = v.split("=")
- parse(Seq(parts(0), parts(1)) ++ tail)
- case v if v.startsWith("-") =>
- val errMessage = s"Unrecognized option '$value'."
- SparkSubmit.printErrorAndExit(errMessage)
- case v =>
- primaryResource =
- if (!SparkSubmit.isShell(v)) {
- Utils.resolveURI(v).toString
- } else {
- v
- }
- inSparkOpts = false
- isPython = SparkSubmit.isPython(v)
- parse(tail)
+ primaryResource =
+ if (!SparkSubmit.isShell(value) && !SparkSubmit.isInternal(value)) {
+ Utils.resolveURI(value).toString
+ } else {
+ value
}
- } else {
- if (!value.isEmpty) {
- childArgs += value
- }
- parse(tail)
- }
+ isPython = SparkSubmit.isPython(value)
+ childArgs ++= tail
case Nil =>
}
}
- private def printUsageAndExit(exitCode: Int, unknownParam: Any = null) {
+ private def printUsageAndExit(exitCode: Int, unknownParam: Any = null): Unit = {
val outStream = SparkSubmit.printStream
if (unknownParam != null) {
outStream.println("Unknown/unsupported param " + unknownParam)
@@ -349,6 +361,8 @@ private[spark] class SparkSubmitArguments(args: Seq[String]) {
| on the PYTHONPATH for Python apps.
| --files FILES Comma-separated list of files to be placed in the working
| directory of each executor.
+ |
+ | --conf PROP=VALUE Arbitrary Spark configuration property.
| --properties-file FILE Path to a file from which to load extra properties. If not
| specified, this will look for conf/spark-defaults.conf.
|
@@ -381,23 +395,3 @@ private[spark] class SparkSubmitArguments(args: Seq[String]) {
SparkSubmit.exitFn()
}
}
-
-object SparkSubmitArguments {
- /** Load properties present in the given file. */
- def getPropertiesFromFile(file: File): Seq[(String, String)] = {
- require(file.exists(), s"Properties file $file does not exist")
- require(file.isFile(), s"Properties file $file is not a normal file")
- val inputStream = new FileInputStream(file)
- try {
- val properties = new Properties()
- properties.load(inputStream)
- properties.stringPropertyNames().toSeq.map(k => (k, properties(k).trim))
- } catch {
- case e: IOException =>
- val message = s"Failed when loading Spark properties file $file"
- throw new SparkException(message, e)
- } finally {
- inputStream.close()
- }
- }
-}
diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitDriverBootstrapper.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitDriverBootstrapper.scala
new file mode 100644
index 0000000000000..aa3743ca7df63
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitDriverBootstrapper.scala
@@ -0,0 +1,169 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy
+
+import java.io.File
+
+import scala.collection.JavaConversions._
+
+import org.apache.spark.util.{RedirectThread, Utils}
+
+/**
+ * Launch an application through Spark submit in client mode with the appropriate classpath,
+ * library paths, java options and memory. These properties of the JVM must be set before the
+ * driver JVM is launched. The sole purpose of this class is to avoid handling the complexity
+ * of parsing the properties file for such relevant configs in Bash.
+ *
+ * Usage: org.apache.spark.deploy.SparkSubmitDriverBootstrapper
+ */
+private[spark] object SparkSubmitDriverBootstrapper {
+
+ // Note: This class depends on the behavior of `bin/spark-class` and `bin/spark-submit`.
+ // Any changes made there must be reflected in this file.
+
+ def main(args: Array[String]): Unit = {
+
+ // This should be called only from `bin/spark-class`
+ if (!sys.env.contains("SPARK_CLASS")) {
+ System.err.println("SparkSubmitDriverBootstrapper must be called from `bin/spark-class`!")
+ System.exit(1)
+ }
+
+ val submitArgs = args
+ val runner = sys.env("RUNNER")
+ val classpath = sys.env("CLASSPATH")
+ val javaOpts = sys.env("JAVA_OPTS")
+ val defaultDriverMemory = sys.env("OUR_JAVA_MEM")
+
+ // Spark submit specific environment variables
+ val deployMode = sys.env("SPARK_SUBMIT_DEPLOY_MODE")
+ val propertiesFile = sys.env("SPARK_SUBMIT_PROPERTIES_FILE")
+ val bootstrapDriver = sys.env("SPARK_SUBMIT_BOOTSTRAP_DRIVER")
+ val submitDriverMemory = sys.env.get("SPARK_SUBMIT_DRIVER_MEMORY")
+ val submitLibraryPath = sys.env.get("SPARK_SUBMIT_LIBRARY_PATH")
+ val submitClasspath = sys.env.get("SPARK_SUBMIT_CLASSPATH")
+ val submitJavaOpts = sys.env.get("SPARK_SUBMIT_OPTS")
+
+ assume(runner != null, "RUNNER must be set")
+ assume(classpath != null, "CLASSPATH must be set")
+ assume(javaOpts != null, "JAVA_OPTS must be set")
+ assume(defaultDriverMemory != null, "OUR_JAVA_MEM must be set")
+ assume(deployMode == "client", "SPARK_SUBMIT_DEPLOY_MODE must be \"client\"!")
+ assume(propertiesFile != null, "SPARK_SUBMIT_PROPERTIES_FILE must be set")
+ assume(bootstrapDriver != null, "SPARK_SUBMIT_BOOTSTRAP_DRIVER must be set")
+
+ // Parse the properties file for the equivalent spark.driver.* configs
+ val properties = Utils.getPropertiesFromFile(propertiesFile)
+ val confDriverMemory = properties.get("spark.driver.memory")
+ val confLibraryPath = properties.get("spark.driver.extraLibraryPath")
+ val confClasspath = properties.get("spark.driver.extraClassPath")
+ val confJavaOpts = properties.get("spark.driver.extraJavaOptions")
+
+ // Favor Spark submit arguments over the equivalent configs in the properties file.
+ // Note that we do not actually use the Spark submit values for library path, classpath,
+ // and Java opts here, because we have already captured them in Bash.
+
+ val newDriverMemory = submitDriverMemory
+ .orElse(confDriverMemory)
+ .getOrElse(defaultDriverMemory)
+
+ val newClasspath =
+ if (submitClasspath.isDefined) {
+ classpath
+ } else {
+ classpath + confClasspath.map(sys.props("path.separator") + _).getOrElse("")
+ }
+
+ val newJavaOpts =
+ if (submitJavaOpts.isDefined) {
+ // SPARK_SUBMIT_OPTS is already captured in JAVA_OPTS
+ javaOpts
+ } else {
+ javaOpts + confJavaOpts.map(" " + _).getOrElse("")
+ }
+
+ val filteredJavaOpts = Utils.splitCommandString(newJavaOpts)
+ .filterNot(_.startsWith("-Xms"))
+ .filterNot(_.startsWith("-Xmx"))
+
+ // Build up command
+ val command: Seq[String] =
+ Seq(runner) ++
+ Seq("-cp", newClasspath) ++
+ filteredJavaOpts ++
+ Seq(s"-Xms$newDriverMemory", s"-Xmx$newDriverMemory") ++
+ Seq("org.apache.spark.deploy.SparkSubmit") ++
+ submitArgs
+
+ // Print the launch command. This follows closely the format used in `bin/spark-class`.
+ if (sys.env.contains("SPARK_PRINT_LAUNCH_COMMAND")) {
+ System.err.print("Spark Command: ")
+ System.err.println(command.mkString(" "))
+ System.err.println("========================================\n")
+ }
+
+ // Start the driver JVM
+ val filteredCommand = command.filter(_.nonEmpty)
+ val builder = new ProcessBuilder(filteredCommand)
+ val env = builder.environment()
+
+ if (submitLibraryPath.isEmpty && confLibraryPath.nonEmpty) {
+ val libraryPaths = confLibraryPath ++ sys.env.get(Utils.libraryPathEnvName)
+ env.put(Utils.libraryPathEnvName, libraryPaths.mkString(sys.props("path.separator")))
+ }
+
+ val process = builder.start()
+
+ // If we kill an app while it's running, its sub-process should be killed too.
+ Runtime.getRuntime().addShutdownHook(new Thread() {
+ override def run() = {
+ if (process != null) {
+ process.destroy()
+ sys.exit(process.waitFor())
+ }
+ }
+ })
+
+ // Redirect stdout and stderr from the child JVM
+ val stdoutThread = new RedirectThread(process.getInputStream, System.out, "redirect stdout")
+ val stderrThread = new RedirectThread(process.getErrorStream, System.err, "redirect stderr")
+ stdoutThread.start()
+ stderrThread.start()
+
+ // Redirect stdin to child JVM only if we're not running Windows. This is because the
+ // subprocess there already reads directly from our stdin, so we should avoid spawning a
+ // thread that contends with the subprocess in reading from System.in.
+ val isWindows = Utils.isWindows
+ val isSubprocess = sys.env.contains("IS_SUBPROCESS")
+ if (!isWindows) {
+ val stdinThread = new RedirectThread(System.in, process.getOutputStream, "redirect stdin")
+ stdinThread.start()
+ // Spark submit (JVM) may run as a subprocess, and so this JVM should terminate on
+ // broken pipe, signaling that the parent process has exited. This is the case if the
+ // application is launched directly from python, as in the PySpark shell. In Windows,
+ // the termination logic is handled in java_gateway.py
+ if (isSubprocess) {
+ stdinThread.join()
+ process.destroy()
+ }
+ }
+ val returnCode = process.waitFor()
+ sys.exit(returnCode)
+ }
+
+}
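
`RedirectThread` (from `org.apache.spark.util`) is used above to pump the child JVM's stdout and stderr into this process's streams. Its actual implementation lives elsewhere in the tree; conceptually it is just a daemon copy loop, roughly like the sketch below (an illustrative stand-in, not the real class):
{{{
import java.io.{InputStream, OutputStream}

// Hypothetical stand-in for org.apache.spark.util.RedirectThread:
// copy bytes from `in` to `out` until the source is exhausted.
class CopyStreamThread(in: InputStream, out: OutputStream, name: String)
  extends Thread(name) {
  setDaemon(true)
  override def run(): Unit = {
    val buf = new Array[Byte](1024)
    var n = in.read(buf)
    while (n != -1) {
      out.write(buf, 0, n)
      out.flush()
      n = in.read(buf)
    }
  }
}
}}}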
diff --git a/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala b/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala
index d38e9e79204c2..98a93d1fcb2a3 100644
--- a/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala
@@ -30,7 +30,7 @@ import org.apache.spark.{Logging, SparkConf, SparkException}
import org.apache.spark.deploy.{ApplicationDescription, ExecutorState}
import org.apache.spark.deploy.DeployMessages._
import org.apache.spark.deploy.master.Master
-import org.apache.spark.util.{Utils, AkkaUtils}
+import org.apache.spark.util.{ActorLogReceive, Utils, AkkaUtils}
/**
* Interface allowing applications to speak with a Spark deploy cluster. Takes a master URL,
@@ -56,7 +56,7 @@ private[spark] class AppClient(
var registered = false
var activeMasterUrl: String = null
- class ClientActor extends Actor with Logging {
+ class ClientActor extends Actor with ActorLogReceive with Logging {
var master: ActorSelection = null
var alreadyDisconnected = false // To avoid calling listener.disconnected() multiple times
var alreadyDead = false // To avoid calling listener.dead() multiple times
@@ -119,7 +119,7 @@ private[spark] class AppClient(
.contains(remoteUrl.hostPort)
}
- override def receive = {
+ override def receiveWithLogging = {
case RegisteredApplication(appId_, masterUrl) =>
appId = appId_
registered = true
@@ -154,7 +154,7 @@ private[spark] class AppClient(
logWarning(s"Connection to $address failed; waiting for master to reconnect...")
markDisconnected()
- case AssociationErrorEvent(cause, _, address, _) if isPossibleMaster(address) =>
+ case AssociationErrorEvent(cause, _, address, _, _) if isPossibleMaster(address) =>
logWarning(s"Could not connect to $address: $cause")
case StopAppClient =>
diff --git a/core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala b/core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala
index e15a87bd38fda..88a0862b96afe 100644
--- a/core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala
@@ -46,11 +46,10 @@ private[spark] object TestClient {
def main(args: Array[String]) {
val url = args(0)
val conf = new SparkConf
- val (actorSystem, port) = AkkaUtils.createActorSystem("spark", Utils.localIpAddress, 0,
+ val (actorSystem, _) = AkkaUtils.createActorSystem("spark", Utils.localIpAddress, 0,
conf = conf, securityManager = new SecurityManager(conf))
- val desc = new ApplicationDescription(
- "TestClient", Some(1), 512, Command("spark.deploy.client.TestExecutor", Seq(), Map(), Seq(),
- Seq()), Some("dummy-spark-home"), "ignored")
+ val desc = new ApplicationDescription("TestClient", Some(1), 512,
+ Command("spark.deploy.client.TestExecutor", Seq(), Map(), Seq(), Seq(), Seq()), "ignored")
val listener = new TestListener
val client = new AppClient(actorSystem, Array(url), desc, listener, new SparkConf)
client.start()
diff --git a/core/src/main/scala/org/apache/spark/deploy/history/ApplicationHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/ApplicationHistoryProvider.scala
index a0e8bd403a41d..fbe39b27649f6 100644
--- a/core/src/main/scala/org/apache/spark/deploy/history/ApplicationHistoryProvider.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/history/ApplicationHistoryProvider.scala
@@ -34,15 +34,15 @@ private[spark] abstract class ApplicationHistoryProvider {
*
* @return List of all known applications.
*/
- def getListing(): Seq[ApplicationHistoryInfo]
+ def getListing(): Iterable[ApplicationHistoryInfo]
/**
* Returns the Spark UI for a specific application.
*
* @param appId The application ID.
- * @return The application's UI, or null if application is not found.
+ * @return The application's UI, or None if application is not found.
*/
- def getAppUI(appId: String): SparkUI
+ def getAppUI(appId: String): Option[SparkUI]
/**
* Called when the server is shutting down.
diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala
index a8c9ac072449f..2d1609b973607 100644
--- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala
@@ -24,6 +24,7 @@ import scala.collection.mutable
import org.apache.hadoop.fs.{FileStatus, Path}
import org.apache.spark.{Logging, SecurityManager, SparkConf}
+import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.scheduler._
import org.apache.spark.ui.SparkUI
import org.apache.spark.util.Utils
@@ -31,22 +32,32 @@ import org.apache.spark.util.Utils
private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHistoryProvider
with Logging {
+ private val NOT_STARTED = ""
+
// Interval between each check for event log updates
private val UPDATE_INTERVAL_MS = conf.getInt("spark.history.fs.updateInterval",
conf.getInt("spark.history.updateInterval", 10)) * 1000
private val logDir = conf.get("spark.history.fs.logDirectory", null)
- if (logDir == null) {
- throw new IllegalArgumentException("Logging directory must be specified.")
- }
+ private val resolvedLogDir = Option(logDir)
+ .map { d => Utils.resolveURI(d) }
+ .getOrElse { throw new IllegalArgumentException("Logging directory must be specified.") }
- private val fs = Utils.getHadoopFileSystem(logDir)
+ private val fs = Utils.getHadoopFileSystem(resolvedLogDir,
+ SparkHadoopUtil.get.newConfiguration(conf))
// A timestamp of when the disk was last accessed to check for log updates
private var lastLogCheckTimeMs = -1L
- // List of applications, in order from newest to oldest.
- @volatile private var appList: Seq[ApplicationHistoryInfo] = Nil
+ // The modification time of the newest log detected during the last scan. This is used
+ // to ignore logs that are older during subsequent scans, to avoid processing data that
+ // is already known.
+ private var lastModifiedTime = -1L
+
+ // Mapping of application IDs to their metadata, in descending end time order. Apps are inserted
+ // into the map in order, so the LinkedHashMap maintains the correct ordering.
+ @volatile private var applications: mutable.LinkedHashMap[String, FsApplicationHistoryInfo]
+ = new mutable.LinkedHashMap()
/**
* A background thread that periodically checks for event log updates on disk.
@@ -76,14 +87,14 @@ private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHis
private def initialize() {
// Validate the log directory.
- val path = new Path(logDir)
+ val path = new Path(resolvedLogDir)
if (!fs.exists(path)) {
throw new IllegalArgumentException(
- "Logging directory specified does not exist: %s".format(logDir))
+ "Logging directory specified does not exist: %s".format(resolvedLogDir))
}
if (!fs.getFileStatus(path).isDir) {
throw new IllegalArgumentException(
- "Logging directory specified is not a directory: %s".format(logDir))
+ "Logging directory specified is not a directory: %s".format(resolvedLogDir))
}
checkForLogs()
@@ -91,19 +102,40 @@ private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHis
logCheckingThread.start()
}
- override def getListing() = appList
+ override def getListing() = applications.values
- override def getAppUI(appId: String): SparkUI = {
+ override def getAppUI(appId: String): Option[SparkUI] = {
try {
- val appLogDir = fs.getFileStatus(new Path(logDir, appId))
- loadAppInfo(appLogDir, true)._2
+ applications.get(appId).map { info =>
+ val (replayBus, appListener) = createReplayBus(fs.getFileStatus(
+ new Path(logDir, info.logDir)))
+ val ui = {
+ val conf = this.conf.clone()
+ val appSecManager = new SecurityManager(conf)
+ SparkUI.createHistoryUI(conf, replayBus, appSecManager, appId,
+ s"${HistoryServer.UI_PATH_PREFIX}/$appId")
+ // Do not call ui.bind() to avoid creating a new server for each application
+ }
+
+ replayBus.replay()
+
+ ui.setAppName(s"${appListener.appName.getOrElse(NOT_STARTED)} ($appId)")
+
+ val uiAclsEnabled = conf.getBoolean("spark.history.ui.acls.enable", false)
+ ui.getSecurityManager.setAcls(uiAclsEnabled)
+ // make sure to set admin acls before view acls so they are properly picked up
+ ui.getSecurityManager.setAdminAcls(appListener.adminAcls.getOrElse(""))
+ ui.getSecurityManager.setViewAcls(appListener.sparkUser.getOrElse(NOT_STARTED),
+ appListener.viewAcls.getOrElse(""))
+ ui
+ }
} catch {
- case e: FileNotFoundException => null
+ case e: FileNotFoundException => None
}
}
override def getConfig(): Map[String, String] =
- Map(("Event Log Location" -> logDir))
+ Map("Event Log Location" -> resolvedLogDir.toString)
/**
* Builds the application list based on the current contents of the log directory.
@@ -114,82 +146,81 @@ private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHis
lastLogCheckTimeMs = getMonotonicTimeMs()
logDebug("Checking for logs. Time is now %d.".format(lastLogCheckTimeMs))
try {
- val logStatus = fs.listStatus(new Path(logDir))
+ val logStatus = fs.listStatus(new Path(resolvedLogDir))
val logDirs = if (logStatus != null) logStatus.filter(_.isDir).toSeq else Seq[FileStatus]()
- val logInfos = logDirs.filter {
- dir => fs.isFile(new Path(dir.getPath(), EventLoggingListener.APPLICATION_COMPLETE))
- }
-
- val currentApps = Map[String, ApplicationHistoryInfo](
- appList.map(app => (app.id -> app)):_*)
- // For any application that either (i) is not listed or (ii) has changed since the last time
- // the listing was created (defined by the log dir's modification time), load the app's info.
- // Otherwise just reuse what's already in memory.
- val newApps = new mutable.ArrayBuffer[ApplicationHistoryInfo](logInfos.size)
- for (dir <- logInfos) {
- val curr = currentApps.getOrElse(dir.getPath().getName(), null)
- if (curr == null || curr.lastUpdated < getModificationTime(dir)) {
+ // Load all new logs from the log directory. Only directories that have a modification time
+ // later than the last known log directory will be loaded.
+ var newLastModifiedTime = lastModifiedTime
+ val logInfos = logDirs
+ .filter { dir =>
+ if (fs.isFile(new Path(dir.getPath(), EventLoggingListener.APPLICATION_COMPLETE))) {
+ val modTime = getModificationTime(dir)
+ newLastModifiedTime = math.max(newLastModifiedTime, modTime)
+ modTime > lastModifiedTime
+ } else {
+ false
+ }
+ }
+ .flatMap { dir =>
try {
- newApps += loadAppInfo(dir, false)._1
+ val (replayBus, appListener) = createReplayBus(dir)
+ replayBus.replay()
+ Some(new FsApplicationHistoryInfo(
+ dir.getPath().getName(),
+ appListener.appId.getOrElse(dir.getPath().getName()),
+ appListener.appName.getOrElse(NOT_STARTED),
+ appListener.startTime.getOrElse(-1L),
+ appListener.endTime.getOrElse(-1L),
+ getModificationTime(dir),
+ appListener.sparkUser.getOrElse(NOT_STARTED)))
} catch {
- case e: Exception => logError(s"Failed to load app info from directory $dir.")
+ case e: Exception =>
+ logInfo(s"Failed to load application log data from $dir.", e)
+ None
}
- } else {
- newApps += curr
}
- }
+ .sortBy { info => -info.endTime }
+
+ lastModifiedTime = newLastModifiedTime
+
+ // When there are new logs, merge the new list with the existing one, maintaining
+ // the expected ordering (descending end time). Maintaining the order is important
+ // to avoid having to sort the list every time there is a request for the log list.
+ if (!logInfos.isEmpty) {
+ val newApps = new mutable.LinkedHashMap[String, FsApplicationHistoryInfo]()
+ def addIfAbsent(info: FsApplicationHistoryInfo) = {
+ if (!newApps.contains(info.id)) {
+ newApps += (info.id -> info)
+ }
+ }
+
+ val newIterator = logInfos.iterator.buffered
+ val oldIterator = applications.values.iterator.buffered
+ while (newIterator.hasNext && oldIterator.hasNext) {
+ if (newIterator.head.endTime > oldIterator.head.endTime) {
+ addIfAbsent(newIterator.next)
+ } else {
+ addIfAbsent(oldIterator.next)
+ }
+ }
+ newIterator.foreach(addIfAbsent)
+ oldIterator.foreach(addIfAbsent)
- appList = newApps.sortBy { info => -info.endTime }
+ applications = newApps
+ }
} catch {
case t: Throwable => logError("Exception in checking for event log updates", t)
}
}
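
The buffered-iterator merge above is the classic merge step of merge sort: the freshly scanned logs and the previously known applications are each already sorted by descending end time, so a single linear pass produces the combined listing without re-sorting. A small standalone sketch of the same idea:
{{{
import scala.collection.mutable

// Merge two lists that are each sorted by descending end time, dropping duplicate IDs.
case class AppInfo(id: String, endTime: Long)

def merge(newApps: Seq[AppInfo], oldApps: Seq[AppInfo]): Seq[AppInfo] = {
  val merged = new mutable.LinkedHashMap[String, AppInfo]()
  def addIfAbsent(info: AppInfo): Unit = {
    if (!merged.contains(info.id)) merged += (info.id -> info)
  }
  val newIt = newApps.iterator.buffered
  val oldIt = oldApps.iterator.buffered
  while (newIt.hasNext && oldIt.hasNext) {
    if (newIt.head.endTime > oldIt.head.endTime) addIfAbsent(newIt.next())
    else addIfAbsent(oldIt.next())
  }
  newIt.foreach(addIfAbsent)
  oldIt.foreach(addIfAbsent)
  merged.values.toSeq
}

// merge(Seq(AppInfo("app-3", 30L)), Seq(AppInfo("app-2", 20L), AppInfo("app-1", 10L)))
// => Seq(AppInfo("app-3", 30), AppInfo("app-2", 20), AppInfo("app-1", 10))
}}}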
- /**
- * Parse the application's logs to find out the information we need to build the
- * listing page.
- *
- * When creating the listing of available apps, there is no need to load the whole UI for the
- * application. The UI is requested by the HistoryServer (by calling getAppInfo()) when the user
- * clicks on a specific application.
- *
- * @param logDir Directory with application's log files.
- * @param renderUI Whether to create the SparkUI for the application.
- * @return A 2-tuple `(app info, ui)`. `ui` will be null if `renderUI` is false.
- */
- private def loadAppInfo(logDir: FileStatus, renderUI: Boolean) = {
- val elogInfo = EventLoggingListener.parseLoggingInfo(logDir.getPath(), fs)
- val path = logDir.getPath
- val appId = path.getName
+ private def createReplayBus(logDir: FileStatus): (ReplayListenerBus, ApplicationEventListener) = {
+ val path = logDir.getPath()
+ val elogInfo = EventLoggingListener.parseLoggingInfo(path, fs)
val replayBus = new ReplayListenerBus(elogInfo.logPaths, fs, elogInfo.compressionCodec)
val appListener = new ApplicationEventListener
replayBus.addListener(appListener)
-
- val ui: SparkUI = if (renderUI) {
- val conf = this.conf.clone()
- val appSecManager = new SecurityManager(conf)
- new SparkUI(conf, appSecManager, replayBus, appId, "/history/" + appId)
- // Do not call ui.bind() to avoid creating a new server for each application
- } else {
- null
- }
-
- replayBus.replay()
- val appInfo = ApplicationHistoryInfo(
- appId,
- appListener.appName,
- appListener.startTime,
- appListener.endTime,
- getModificationTime(logDir),
- appListener.sparkUser)
-
- if (ui != null) {
- val uiAclsEnabled = conf.getBoolean("spark.history.ui.acls.enable", false)
- ui.getSecurityManager.setUIAcls(uiAclsEnabled)
- ui.getSecurityManager.setViewAcls(appListener.sparkUser, appListener.viewAcls)
- }
- (appInfo, ui)
+ (replayBus, appListener)
}
/** Return when this directory was last modified. */
@@ -212,3 +243,13 @@ private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHis
private def getMonotonicTimeMs() = System.nanoTime() / (1000 * 1000)
}
+
+private class FsApplicationHistoryInfo(
+ val logDir: String,
+ id: String,
+ name: String,
+ startTime: Long,
+ endTime: Long,
+ lastUpdated: Long,
+ sparkUser: String)
+ extends ApplicationHistoryInfo(id, name, startTime, endTime, lastUpdated, sparkUser)
diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala
index a958c837c2ff6..0e249e51a77d8 100644
--- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala
@@ -45,7 +45,7 @@ private[spark] class HistoryPage(parent: HistoryServer) extends WebUIPage("") {
- { providerConfig.map(e => <li><strong>{e._1}:</strong> {e._2}</li>) }
+ {providerConfig.map { case (k, v) => <li><strong>{k}:</strong> {v}</li> }}
{
if (allApps.size > 0) {
@@ -67,6 +67,7 @@ private[spark] class HistoryPage(parent: HistoryServer) extends WebUIPage("") {
}
private val appHeader = Seq(
+ "App ID",
"App Name",
"Started",
"Completed",
@@ -75,18 +76,19 @@ private[spark] class HistoryPage(parent: HistoryServer) extends WebUIPage("") {
"Last Updated")
private def appRow(info: ApplicationHistoryInfo): Seq[Node] = {
- val uiAddress = "/history/" + info.id
+ val uiAddress = HistoryServer.UI_PATH_PREFIX + s"/${info.id}"
val startTime = UIUtils.formatDate(info.startTime)
val endTime = UIUtils.formatDate(info.endTime)
val duration = UIUtils.formatDuration(info.endTime - info.startTime)
val lastUpdated = UIUtils.formatDate(info.lastUpdated)
- <td><a href={uiAddress}>{info.name}</a></td>
- <td>{startTime}</td>
- <td>{endTime}</td>
- <td>{duration}</td>
+ <td><a href={uiAddress}>{info.id}</a></td>
+ <td>{info.name}</td>
+ <td>{startTime}</td>
+ <td>{endTime}</td>
+ <td>{duration}</td>
<td>{info.sparkUser}</td>
- <td>{lastUpdated}</td>
+ <td>{lastUpdated}</td>
}
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala
index 56b38ddfc9313..ce00c0ffd21e0 100644
--- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala
@@ -25,9 +25,9 @@ import org.eclipse.jetty.servlet.{ServletContextHandler, ServletHolder}
import org.apache.spark.{Logging, SecurityManager, SparkConf}
import org.apache.spark.deploy.SparkHadoopUtil
-import org.apache.spark.ui.{WebUI, SparkUI, UIUtils}
+import org.apache.spark.ui.{SparkUI, UIUtils, WebUI}
import org.apache.spark.ui.JettyUtils._
-import org.apache.spark.util.{SignalLogger, Utils}
+import org.apache.spark.util.SignalLogger
/**
* A web server that renders SparkUIs of completed applications.
@@ -52,10 +52,7 @@ class HistoryServer(
private val appLoader = new CacheLoader[String, SparkUI] {
override def load(key: String): SparkUI = {
- val ui = provider.getAppUI(key)
- if (ui == null) {
- throw new NoSuchElementException()
- }
+ val ui = provider.getAppUI(key).getOrElse(throw new NoSuchElementException())
attachSparkUI(ui)
ui
}
@@ -114,7 +111,7 @@ class HistoryServer(
attachHandler(createStaticHandler(SparkUI.STATIC_RESOURCE_DIR, "/static"))
val contextHandler = new ServletContextHandler
- contextHandler.setContextPath("/history")
+ contextHandler.setContextPath(HistoryServer.UI_PATH_PREFIX)
contextHandler.addServlet(new ServletHolder(loaderServlet), "/*")
attachHandler(contextHandler)
}
@@ -172,10 +169,12 @@ class HistoryServer(
object HistoryServer extends Logging {
private val conf = new SparkConf
+ val UI_PATH_PREFIX = "/history"
+
def main(argStrings: Array[String]) {
SignalLogger.register(log)
initSecurity()
- val args = new HistoryServerArguments(conf, argStrings)
+ new HistoryServerArguments(conf, argStrings)
val securityManager = new SecurityManager(conf)
val providerName = conf.getOption("spark.history.provider")
diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerArguments.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerArguments.scala
index be9361b754fc3..5bce32a04d16d 100644
--- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerArguments.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerArguments.scala
@@ -25,6 +25,7 @@ import org.apache.spark.util.Utils
*/
private[spark] class HistoryServerArguments(conf: SparkConf, args: Array[String]) {
private var logDir: String = null
+ private var propertiesFile: String = null
parse(args.toList)
@@ -32,25 +33,35 @@ private[spark] class HistoryServerArguments(conf: SparkConf, args: Array[String]
args match {
case ("--dir" | "-d") :: value :: tail =>
logDir = value
+ conf.set("spark.history.fs.logDirectory", value)
+ System.setProperty("spark.history.fs.logDirectory", value)
parse(tail)
case ("--help" | "-h") :: tail =>
printUsageAndExit(0)
+ case ("--properties-file") :: value :: tail =>
+ propertiesFile = value
+ parse(tail)
+
case Nil =>
case _ =>
printUsageAndExit(1)
}
- if (logDir != null) {
- conf.set("spark.history.fs.logDirectory", logDir)
- }
}
+ // This mutates the SparkConf, so all accesses to it must be made after this line
+ Utils.loadDefaultSparkProperties(conf, propertiesFile)
+
private def printUsageAndExit(exitCode: Int) {
System.err.println(
"""
- |Usage: HistoryServer
+ |Usage: HistoryServer [options]
+ |
+ |Options:
+ | --properties-file FILE Path to a custom Spark properties file.
+ | Default is conf/spark-defaults.conf.
|
|Configuration options can be set by setting the corresponding JVM system property.
|History Server options are always available; additional options depend on the provider.
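The new `--properties-file` handling can also be exercised programmatically. A hedged sketch (object name and paths are placeholders, not part of this patch): `--dir` is mirrored into `spark.history.fs.logDirectory` during parsing, and the optional properties file is then loaded into the same SparkConf, so anything read from the conf afterwards sees both sources.
{{{
package org.apache.spark.deploy.history

import org.apache.spark.SparkConf

// Hypothetical driver for the argument parser; the class is private[spark], so this
// sketch has to live under the same package (for example as a test).
object HistoryServerArgumentsSketch {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf()
    new HistoryServerArguments(conf, Array(
      "--dir", "hdfs://namenode:8020/spark-events",
      "--properties-file", "/etc/spark/history-defaults.conf"))

    // After construction the conf reflects both the flags and the properties file.
    println(conf.get("spark.history.fs.logDirectory"))
  }
}
}}}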
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala b/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala
index 72d0589689e71..ad7d81747c377 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala
@@ -24,7 +24,9 @@ import scala.collection.mutable.ArrayBuffer
import akka.actor.ActorRef
+import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.deploy.ApplicationDescription
+import org.apache.spark.util.Utils
private[spark] class ApplicationInfo(
val startTime: Long,
@@ -46,6 +48,11 @@ private[spark] class ApplicationInfo(
init()
+ private def readObject(in: java.io.ObjectInputStream): Unit = Utils.tryOrIOException {
+ in.defaultReadObject()
+ init()
+ }
+
private def init() {
state = ApplicationState.WAITING
executors = new mutable.HashMap[Int, ExecutorInfo]
@@ -91,11 +98,13 @@ private[spark] class ApplicationInfo(
def retryCount = _retryCount
- def incrementRetryCount = {
+ def incrementRetryCount() = {
_retryCount += 1
_retryCount
}
+ def resetRetryCount() = _retryCount = 0
+
def markFinished(endState: ApplicationState.Value) {
state = endState
endTime = System.currentTimeMillis()
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ApplicationSource.scala b/core/src/main/scala/org/apache/spark/deploy/master/ApplicationSource.scala
index c87b66f047dc8..38db02cd2421b 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/ApplicationSource.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/ApplicationSource.scala
@@ -22,8 +22,8 @@ import com.codahale.metrics.{Gauge, MetricRegistry}
import org.apache.spark.metrics.source.Source
class ApplicationSource(val application: ApplicationInfo) extends Source {
- val metricRegistry = new MetricRegistry()
- val sourceName = "%s.%s.%s".format("application", application.desc.name,
+ override val metricRegistry = new MetricRegistry()
+ override val sourceName = "%s.%s.%s".format("application", application.desc.name,
System.currentTimeMillis())
metricRegistry.register(MetricRegistry.name("status"), new Gauge[String] {
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/DriverInfo.scala b/core/src/main/scala/org/apache/spark/deploy/master/DriverInfo.scala
index 33377931d6993..9d3d7938c6ccb 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/DriverInfo.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/DriverInfo.scala
@@ -19,7 +19,9 @@ package org.apache.spark.deploy.master
import java.util.Date
+import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.deploy.DriverDescription
+import org.apache.spark.util.Utils
private[spark] class DriverInfo(
val startTime: Long,
@@ -33,4 +35,17 @@ private[spark] class DriverInfo(
@transient var exception: Option[Exception] = None
/* Most recent worker assigned to this driver */
@transient var worker: Option[WorkerInfo] = None
+
+ init()
+
+ private def readObject(in: java.io.ObjectInputStream): Unit = Utils.tryOrIOException {
+ in.defaultReadObject()
+ init()
+ }
+
+ private def init(): Unit = {
+ state = DriverState.SUBMITTED
+ worker = None
+ exception = None
+ }
}
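ApplicationInfo and DriverInfo now follow the same pattern: transient state is rebuilt in `init()`, and a custom `readObject` re-runs it after default deserialization, with `Utils.tryOrIOException` turning failures into the `IOException` that Java serialization expects. A minimal, self-contained sketch of the same pattern (class and fields are invented for illustration):
{{{
import java.io.{IOException, ObjectInputStream}

import scala.collection.mutable

// Hypothetical serializable class whose cache must be rebuilt after deserialization,
// mirroring what ApplicationInfo/DriverInfo/WorkerInfo do via Utils.tryOrIOException.
class CachedRecord(val id: String) extends Serializable {

  @transient private var cache: mutable.Map[String, Int] = _

  init()

  private def init(): Unit = {
    cache = mutable.Map.empty[String, Int]
  }

  private def readObject(in: ObjectInputStream): Unit = {
    try {
      in.defaultReadObject()
      init() // transient fields are null here unless re-initialized
    } catch {
      case ioe: IOException => throw ioe
      case e: Exception => throw new IOException(e)
    }
  }
}
}}}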
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala b/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala
index aa85aa060d9c1..6ff2aa5244847 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala
@@ -18,10 +18,12 @@
package org.apache.spark.deploy.master
import java.io._
-
-import akka.serialization.Serialization
+import java.nio.ByteBuffer
import org.apache.spark.Logging
+import org.apache.spark.serializer.Serializer
+
+import scala.reflect.ClassTag
/**
* Stores data in a single on-disk directory with one file per application and worker.
@@ -32,69 +34,47 @@ import org.apache.spark.Logging
*/
private[spark] class FileSystemPersistenceEngine(
val dir: String,
- val serialization: Serialization)
+ val serialization: Serializer)
extends PersistenceEngine with Logging {
+ val serializer = serialization.newInstance()
new File(dir).mkdir()
- override def addApplication(app: ApplicationInfo) {
- val appFile = new File(dir + File.separator + "app_" + app.id)
- serializeIntoFile(appFile, app)
- }
-
- override def removeApplication(app: ApplicationInfo) {
- new File(dir + File.separator + "app_" + app.id).delete()
- }
-
- override def addDriver(driver: DriverInfo) {
- val driverFile = new File(dir + File.separator + "driver_" + driver.id)
- serializeIntoFile(driverFile, driver)
- }
-
- override def removeDriver(driver: DriverInfo) {
- new File(dir + File.separator + "driver_" + driver.id).delete()
- }
-
- override def addWorker(worker: WorkerInfo) {
- val workerFile = new File(dir + File.separator + "worker_" + worker.id)
- serializeIntoFile(workerFile, worker)
+ override def persist(name: String, obj: Object): Unit = {
+ serializeIntoFile(new File(dir + File.separator + name), obj)
}
- override def removeWorker(worker: WorkerInfo) {
- new File(dir + File.separator + "worker_" + worker.id).delete()
+ override def unpersist(name: String): Unit = {
+ new File(dir + File.separator + name).delete()
}
- override def readPersistedData(): (Seq[ApplicationInfo], Seq[DriverInfo], Seq[WorkerInfo]) = {
- val sortedFiles = new File(dir).listFiles().sortBy(_.getName)
- val appFiles = sortedFiles.filter(_.getName.startsWith("app_"))
- val apps = appFiles.map(deserializeFromFile[ApplicationInfo])
- val driverFiles = sortedFiles.filter(_.getName.startsWith("driver_"))
- val drivers = driverFiles.map(deserializeFromFile[DriverInfo])
- val workerFiles = sortedFiles.filter(_.getName.startsWith("worker_"))
- val workers = workerFiles.map(deserializeFromFile[WorkerInfo])
- (apps, drivers, workers)
+ override def read[T: ClassTag](prefix: String) = {
+ val files = new File(dir).listFiles().filter(_.getName.startsWith(prefix))
+ files.map(deserializeFromFile[T])
}
private def serializeIntoFile(file: File, value: AnyRef) {
val created = file.createNewFile()
if (!created) { throw new IllegalStateException("Could not create file: " + file) }
- val serializer = serialization.findSerializerFor(value)
- val serialized = serializer.toBinary(value)
+ val out = serializer.serializeStream(new FileOutputStream(file))
+ try {
+ out.writeObject(value)
+ } finally {
+ out.close()
+ }
- val out = new FileOutputStream(file)
- out.write(serialized)
- out.close()
}
- def deserializeFromFile[T](file: File)(implicit m: Manifest[T]): T = {
+ def deserializeFromFile[T](file: File)(implicit m: ClassTag[T]): T = {
val fileData = new Array[Byte](file.length().asInstanceOf[Int])
val dis = new DataInputStream(new FileInputStream(file))
- dis.readFully(fileData)
- dis.close()
+ try {
+ dis.readFully(fileData)
+ } finally {
+ dis.close()
+ }
- val clazz = m.runtimeClass.asInstanceOf[Class[T]]
- val serializer = serialization.serializerFor(clazz)
- serializer.fromBinary(fileData).asInstanceOf[T]
+ serializer.deserialize(ByteBuffer.wrap(fileData))
}
}
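With the engine reduced to persist/unpersist/read plus a pluggable Serializer, basic usage looks like the sketch below. This is a hedged example: the object name and the directory are placeholders, and it has to sit under the org.apache.spark package because the engine is private[spark].
{{{
package org.apache.spark.deploy.master

import org.apache.spark.SparkConf
import org.apache.spark.serializer.JavaSerializer

// Store a java-serializable value under a name, read it back by prefix, then delete it.
// This is exactly how the final addWorker/addApplication/addDriver helpers use the engine.
object FileSystemPersistenceSketch {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf()
    val engine = new FileSystemPersistenceEngine("/tmp/spark-recovery", new JavaSerializer(conf))

    engine.persist("worker_local-1", "placeholder worker record")
    println(engine.read[String]("worker_"))   // Seq("placeholder worker record")
    engine.unpersist("worker_local-1")
  }
}
}}}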
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/LeaderElectionAgent.scala b/core/src/main/scala/org/apache/spark/deploy/master/LeaderElectionAgent.scala
index 4433a2ec29be6..cf77c86d760cf 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/LeaderElectionAgent.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/LeaderElectionAgent.scala
@@ -17,30 +17,27 @@
package org.apache.spark.deploy.master
-import akka.actor.{Actor, ActorRef}
-
-import org.apache.spark.deploy.master.MasterMessages.ElectedLeader
+import org.apache.spark.annotation.DeveloperApi
/**
- * A LeaderElectionAgent keeps track of whether the current Master is the leader, meaning it
- * is the only Master serving requests.
- * In addition to the API provided, the LeaderElectionAgent will use of the following messages
- * to inform the Master of leader changes:
- * [[org.apache.spark.deploy.master.MasterMessages.ElectedLeader ElectedLeader]]
- * [[org.apache.spark.deploy.master.MasterMessages.RevokedLeadership RevokedLeadership]]
+ * :: DeveloperApi ::
+ *
+ * A LeaderElectionAgent keeps track of the current master and is the common interface for all election agents.
*/
-private[spark] trait LeaderElectionAgent extends Actor {
- // TODO: LeaderElectionAgent does not necessary to be an Actor anymore, need refactoring.
- val masterActor: ActorRef
+@DeveloperApi
+trait LeaderElectionAgent {
+ val masterActor: LeaderElectable
+ def stop() {} // to avoid noops in implementations.
}
-/** Single-node implementation of LeaderElectionAgent -- we're initially and always the leader. */
-private[spark] class MonarchyLeaderAgent(val masterActor: ActorRef) extends LeaderElectionAgent {
- override def preStart() {
- masterActor ! ElectedLeader
- }
+@DeveloperApi
+trait LeaderElectable {
+ def electedLeader()
+ def revokedLeadership()
+}
- override def receive = {
- case _ =>
- }
+/** Single-node implementation of LeaderElectionAgent -- we're initially and always the leader. */
+private[spark] class MonarchyLeaderAgent(val masterActor: LeaderElectable)
+ extends LeaderElectionAgent {
+ masterActor.electedLeader()
}
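Since LeaderElectionAgent and LeaderElectable are now public @DeveloperApi traits, a custom agent only needs to hold a LeaderElectable and tell it when leadership changes. A minimal sketch (class name invented) that behaves like the private MonarchyLeaderAgent above:
{{{
import org.apache.spark.deploy.master.{LeaderElectable, LeaderElectionAgent}

// With no external coordinator the lone master claims leadership as soon as the agent is
// constructed; a real agent would call electedLeader()/revokedLeadership() from its
// election callbacks instead.
class SingleNodeLeaderAgent(val masterActor: LeaderElectable) extends LeaderElectionAgent {
  masterActor.electedLeader()

  // Nothing to clean up; the trait's default stop() would work just as well.
  override def stop(): Unit = ()
}
}}}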
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala
index a304102a49086..021454e25804c 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala
@@ -17,6 +17,7 @@
package org.apache.spark.deploy.master
+import java.net.URLEncoder
import java.text.SimpleDateFormat
import java.util.Date
@@ -30,25 +31,26 @@ import akka.actor._
import akka.pattern.ask
import akka.remote.{DisassociatedEvent, RemotingLifecycleEvent}
import akka.serialization.SerializationExtension
-import org.apache.hadoop.fs.FileSystem
import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkException}
-import org.apache.spark.deploy.{ApplicationDescription, DriverDescription, ExecutorState}
+import org.apache.spark.deploy.{ApplicationDescription, DriverDescription,
+ ExecutorState, SparkHadoopUtil}
import org.apache.spark.deploy.DeployMessages._
+import org.apache.spark.deploy.history.HistoryServer
import org.apache.spark.deploy.master.DriverState.DriverState
import org.apache.spark.deploy.master.MasterMessages._
import org.apache.spark.deploy.master.ui.MasterWebUI
import org.apache.spark.metrics.MetricsSystem
import org.apache.spark.scheduler.{EventLoggingListener, ReplayListenerBus}
import org.apache.spark.ui.SparkUI
-import org.apache.spark.util.{AkkaUtils, SignalLogger, Utils}
+import org.apache.spark.util.{ActorLogReceive, AkkaUtils, SignalLogger, Utils}
private[spark] class Master(
host: String,
port: Int,
webUiPort: Int,
val securityMgr: SecurityManager)
- extends Actor with Logging {
+ extends Actor with ActorLogReceive with Logging with LeaderElectable {
import context.dispatcher // to use Akka's scheduler.schedule()
@@ -57,8 +59,8 @@ private[spark] class Master(
def createDateFormat = new SimpleDateFormat("yyyyMMddHHmmss") // For application IDs
val WORKER_TIMEOUT = conf.getLong("spark.worker.timeout", 60) * 1000
val RETAINED_APPLICATIONS = conf.getInt("spark.deploy.retainedApplications", 200)
+ val RETAINED_DRIVERS = conf.getInt("spark.deploy.retainedDrivers", 200)
val REAPER_ITERATIONS = conf.getInt("spark.dead.worker.persistence", 15)
- val RECOVERY_DIR = conf.get("spark.deploy.recoveryDirectory", "")
val RECOVERY_MODE = conf.get("spark.deploy.recoveryMode", "NONE")
val workers = new HashSet[WorkerInfo]
@@ -72,9 +74,7 @@ private[spark] class Master(
val waitingApps = new ArrayBuffer[ApplicationInfo]
val completedApps = new ArrayBuffer[ApplicationInfo]
var nextAppNumber = 0
-
val appIdToUI = new HashMap[String, SparkUI]
- val fileSystemsUsed = new HashSet[FileSystem]
val drivers = new HashSet[DriverInfo]
val completedDrivers = new ArrayBuffer[DriverInfo]
@@ -102,7 +102,7 @@ private[spark] class Master(
var persistenceEngine: PersistenceEngine = _
- var leaderElectionAgent: ActorRef = _
+ var leaderElectionAgent: LeaderElectionAgent = _
private var recoveryCompletionTask: Cancellable = _
@@ -129,23 +129,24 @@ private[spark] class Master(
masterMetricsSystem.start()
applicationMetricsSystem.start()
- persistenceEngine = RECOVERY_MODE match {
+ val (persistenceEngine_, leaderElectionAgent_) = RECOVERY_MODE match {
case "ZOOKEEPER" =>
logInfo("Persisting recovery state to ZooKeeper")
- new ZooKeeperPersistenceEngine(SerializationExtension(context.system), conf)
+ val zkFactory = new ZooKeeperRecoveryModeFactory(conf)
+ (zkFactory.createPersistenceEngine(), zkFactory.createLeaderElectionAgent(this))
case "FILESYSTEM" =>
- logInfo("Persisting recovery state to directory: " + RECOVERY_DIR)
- new FileSystemPersistenceEngine(RECOVERY_DIR, SerializationExtension(context.system))
+ val fsFactory = new FileSystemRecoveryModeFactory(conf)
+ (fsFactory.createPersistenceEngine(), fsFactory.createLeaderElectionAgent(this))
+ case "CUSTOM" =>
+ val clazz = Class.forName(conf.get("spark.deploy.recoveryMode.factory"))
+ val factory = clazz.getConstructor(conf.getClass)
+ .newInstance(conf).asInstanceOf[StandaloneRecoveryModeFactory]
+ (factory.createPersistenceEngine(), factory.createLeaderElectionAgent(this))
case _ =>
- new BlackHolePersistenceEngine()
+ (new BlackHolePersistenceEngine(), new MonarchyLeaderAgent(this))
}
-
- leaderElectionAgent = RECOVERY_MODE match {
- case "ZOOKEEPER" =>
- context.actorOf(Props(classOf[ZooKeeperLeaderElectionAgent], self, masterUrl, conf))
- case _ =>
- context.actorOf(Props(classOf[MonarchyLeaderAgent], self))
- }
+ persistenceEngine = persistenceEngine_
+ leaderElectionAgent = leaderElectionAgent_
}
override def preRestart(reason: Throwable, message: Option[Any]) {
@@ -154,19 +155,28 @@ private[spark] class Master(
}
override def postStop() {
+ masterMetricsSystem.report()
+ applicationMetricsSystem.report()
// prevent the CompleteRecovery message sending to restarted master
if (recoveryCompletionTask != null) {
recoveryCompletionTask.cancel()
}
webUi.stop()
- fileSystemsUsed.foreach(_.close())
masterMetricsSystem.stop()
applicationMetricsSystem.stop()
persistenceEngine.close()
- context.stop(leaderElectionAgent)
+ leaderElectionAgent.stop()
+ }
+
+ override def electedLeader() {
+ self ! ElectedLeader
+ }
+
+ override def revokedLeadership() {
+ self ! RevokedLeadership
}
- override def receive = {
+ override def receiveWithLogging = {
case ElectedLeader => {
val (storedApps, storedDrivers, storedWorkers) = persistenceEngine.readPersistedData()
state = if (storedApps.isEmpty && storedDrivers.isEmpty && storedWorkers.isEmpty) {
@@ -294,28 +304,34 @@ private[spark] class Master(
val execOption = idToApp.get(appId).flatMap(app => app.executors.get(execId))
execOption match {
case Some(exec) => {
+ val appInfo = idToApp(appId)
exec.state = state
+ if (state == ExecutorState.RUNNING) { appInfo.resetRetryCount() }
exec.application.driver ! ExecutorUpdated(execId, state, message, exitStatus)
if (ExecutorState.isFinished(state)) {
- val appInfo = idToApp(appId)
// Remove this executor from the worker and app
- logInfo("Removing executor " + exec.fullId + " because it is " + state)
+ logInfo(s"Removing executor ${exec.fullId} because it is $state")
appInfo.removeExecutor(exec)
exec.worker.removeExecutor(exec)
- val normalExit = exitStatus.exists(_ == 0)
+ val normalExit = exitStatus == Some(0)
// Only retry certain number of times so we don't go into an infinite loop.
- if (!normalExit && appInfo.incrementRetryCount < ApplicationState.MAX_NUM_RETRY) {
- schedule()
- } else if (!normalExit) {
- logError("Application %s with ID %s failed %d times, removing it".format(
- appInfo.desc.name, appInfo.id, appInfo.retryCount))
- removeApplication(appInfo, ApplicationState.FAILED)
+ if (!normalExit) {
+ if (appInfo.incrementRetryCount() < ApplicationState.MAX_NUM_RETRY) {
+ schedule()
+ } else {
+ val execs = appInfo.executors.values
+ if (!execs.exists(_.state == ExecutorState.RUNNING)) {
+ logError(s"Application ${appInfo.desc.name} with ID ${appInfo.id} failed " +
+ s"${appInfo.retryCount} times; removing it")
+ removeApplication(appInfo, ApplicationState.FAILED)
+ }
+ }
}
}
}
case None =>
- logWarning("Got status update for unknown executor " + appId + "/" + execId)
+ logWarning(s"Got status update for unknown executor $appId/$execId")
}
}
@@ -333,7 +349,14 @@ private[spark] class Master(
case Some(workerInfo) =>
workerInfo.lastHeartbeat = System.currentTimeMillis()
case None =>
- logWarning("Got heartbeat from unregistered worker " + workerId)
+ if (workers.map(_.id).contains(workerId)) {
+ logWarning(s"Got heartbeat from unregistered worker $workerId." +
+ " Asking it to re-register.")
+ sender ! ReconnectWorker(masterUrl)
+ } else {
+ logWarning(s"Got heartbeat from unregistered worker $workerId." +
+ " This worker was never registered, so ignoring the heartbeat.")
+ }
}
}
@@ -479,13 +502,26 @@ private[spark] class Master(
if (state != RecoveryState.ALIVE) { return }
// First schedule drivers, they take strict precedence over applications
- val shuffledWorkers = Random.shuffle(workers) // Randomization helps balance drivers
- for (worker <- shuffledWorkers if worker.state == WorkerState.ALIVE) {
- for (driver <- List(waitingDrivers: _*)) { // iterate over a copy of waitingDrivers
+ // Randomization helps balance drivers
+ val shuffledAliveWorkers = Random.shuffle(workers.toSeq.filter(_.state == WorkerState.ALIVE))
+ val numWorkersAlive = shuffledAliveWorkers.size
+ var curPos = 0
+
+ for (driver <- waitingDrivers.toList) { // iterate over a copy of waitingDrivers
+ // We assign workers to each waiting driver in a round-robin fashion. For each driver, we
+ // start from the last worker that was assigned a driver, and continue onwards until we have
+ // explored all alive workers.
+ var launched = false
+ var numWorkersVisited = 0
+ while (numWorkersVisited < numWorkersAlive && !launched) {
+ val worker = shuffledAliveWorkers(curPos)
+ numWorkersVisited += 1
if (worker.memoryFree >= driver.desc.mem && worker.coresFree >= driver.desc.cores) {
launchDriver(worker, driver)
waitingDrivers -= driver
+ launched = true
}
+ curPos = (curPos + 1) % numWorkersAlive
}
}
@@ -644,10 +680,7 @@ private[spark] class Master(
waitingApps -= app
// If application events are logged, use them to rebuild the UI
- if (!rebuildSparkUI(app)) {
- // Avoid broken links if the UI is not reconstructed
- app.desc.appUiUrl = ""
- }
+ rebuildSparkUI(app)
for (exec <- app.executors.values) {
exec.worker.removeExecutor(exec)
@@ -669,29 +702,52 @@ private[spark] class Master(
*/
def rebuildSparkUI(app: ApplicationInfo): Boolean = {
val appName = app.desc.name
- val eventLogDir = app.desc.eventLogDir.getOrElse { return false }
- val fileSystem = Utils.getHadoopFileSystem(eventLogDir)
- val eventLogInfo = EventLoggingListener.parseLoggingInfo(eventLogDir, fileSystem)
+ val notFoundBasePath = HistoryServer.UI_PATH_PREFIX + "/not-found"
+ val eventLogDir = app.desc.eventLogDir.getOrElse {
+ // Event logging is not enabled for this application
+ app.desc.appUiUrl = notFoundBasePath
+ return false
+ }
+
+ val appEventLogDir = EventLoggingListener.getLogDirPath(eventLogDir, app.id)
+ val fileSystem = Utils.getHadoopFileSystem(appEventLogDir,
+ SparkHadoopUtil.get.newConfiguration(conf))
+ val eventLogInfo = EventLoggingListener.parseLoggingInfo(appEventLogDir, fileSystem)
val eventLogPaths = eventLogInfo.logPaths
val compressionCodec = eventLogInfo.compressionCodec
- if (!eventLogPaths.isEmpty) {
- try {
- val replayBus = new ReplayListenerBus(eventLogPaths, fileSystem, compressionCodec)
- val ui = new SparkUI(
- new SparkConf, replayBus, appName + " (completed)", "/history/" + app.id)
- replayBus.replay()
- app.desc.appUiUrl = ui.basePath
- appIdToUI(app.id) = ui
- webUi.attachSparkUI(ui)
- return true
- } catch {
- case e: Exception =>
- logError("Exception in replaying log for application %s (%s)".format(appName, app.id), e)
- }
- } else {
- logWarning("Application %s (%s) has no valid logs: %s".format(appName, app.id, eventLogDir))
+
+ if (eventLogPaths.isEmpty) {
+ // Event logging is enabled for this application, but no event logs are found
+ val title = s"Application history not found (${app.id})"
+ var msg = s"No event logs found for application $appName in $appEventLogDir."
+ logWarning(msg)
+ msg += " Did you specify the correct logging directory?"
+ msg = URLEncoder.encode(msg, "UTF-8")
+ app.desc.appUiUrl = notFoundBasePath + s"?msg=$msg&title=$title"
+ return false
+ }
+
+ try {
+ val replayBus = new ReplayListenerBus(eventLogPaths, fileSystem, compressionCodec)
+ val ui = SparkUI.createHistoryUI(new SparkConf, replayBus, new SecurityManager(conf),
+ appName + " (completed)", HistoryServer.UI_PATH_PREFIX + s"/${app.id}")
+ replayBus.replay()
+ appIdToUI(app.id) = ui
+ webUi.attachSparkUI(ui)
+ // Application UI is successfully rebuilt, so link the Master UI to it
+ app.desc.appUiUrl = ui.getBasePath
+ true
+ } catch {
+ case e: Exception =>
+ // Relay exception message to application UI page
+ val title = s"Application history load error (${app.id})"
+ val exception = URLEncoder.encode(Utils.exceptionString(e), "UTF-8")
+ var msg = s"Exception in replaying log for application $appName!"
+ logError(msg, e)
+ msg = URLEncoder.encode(msg, "UTF-8")
+ app.desc.appUiUrl = notFoundBasePath + s"?msg=$msg&exception=$exception&title=$title"
+ false
}
- false
}
/** Generate a new app ID given a app's submission date */
@@ -744,11 +800,16 @@ private[spark] class Master(
case Some(driver) =>
logInfo(s"Removing driver: $driverId")
drivers -= driver
+ if (completedDrivers.size >= RETAINED_DRIVERS) {
+ val toRemove = math.max(RETAINED_DRIVERS / 10, 1)
+ completedDrivers.trimStart(toRemove)
+ }
completedDrivers += driver
persistenceEngine.removeDriver(driver)
driver.state = finalState
driver.exception = exception
driver.worker.foreach(w => w.removeDriver(driver))
+ schedule()
case None =>
logWarning(s"Asked to remove unknown driver: $driverId")
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/MasterArguments.scala b/core/src/main/scala/org/apache/spark/deploy/master/MasterArguments.scala
index a87781fb93850..e34bee7854292 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/MasterArguments.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/MasterArguments.scala
@@ -27,6 +27,7 @@ private[spark] class MasterArguments(args: Array[String], conf: SparkConf) {
var host = Utils.localHostName()
var port = 7077
var webUiPort = 8080
+ var propertiesFile: String = null
// Check for settings in environment variables
if (System.getenv("SPARK_MASTER_HOST") != null) {
@@ -38,12 +39,16 @@ private[spark] class MasterArguments(args: Array[String], conf: SparkConf) {
if (System.getenv("SPARK_MASTER_WEBUI_PORT") != null) {
webUiPort = System.getenv("SPARK_MASTER_WEBUI_PORT").toInt
}
- if (conf.contains("master.ui.port")) {
- webUiPort = conf.get("master.ui.port").toInt
- }
parse(args.toList)
+ // This mutates the SparkConf, so all accesses to it must be made after this line
+ propertiesFile = Utils.loadDefaultSparkProperties(conf, propertiesFile)
+
+ if (conf.contains("spark.master.ui.port")) {
+ webUiPort = conf.get("spark.master.ui.port").toInt
+ }
+
def parse(args: List[String]): Unit = args match {
case ("--ip" | "-i") :: value :: tail =>
Utils.checkHost(value, "ip no longer supported, please use hostname " + value)
@@ -63,7 +68,11 @@ private[spark] class MasterArguments(args: Array[String], conf: SparkConf) {
webUiPort = value
parse(tail)
- case ("--help" | "-h") :: tail =>
+ case ("--properties-file") :: value :: tail =>
+ propertiesFile = value
+ parse(tail)
+
+ case ("--help") :: tail =>
printUsageAndExit(0)
case Nil => {}
@@ -83,7 +92,9 @@ private[spark] class MasterArguments(args: Array[String], conf: SparkConf) {
" -i HOST, --ip HOST Hostname to listen on (deprecated, please use --host or -h) \n" +
" -h HOST, --host HOST Hostname to listen on\n" +
" -p PORT, --port PORT Port to listen on (default: 7077)\n" +
- " --webui-port PORT Port for web UI (default: 8080)")
+ " --webui-port PORT Port for web UI (default: 8080)\n" +
+ " --properties-file FILE Path to a custom Spark properties file.\n" +
+ " Default is conf/spark-defaults.conf.")
System.exit(exitCode)
}
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/MasterSource.scala b/core/src/main/scala/org/apache/spark/deploy/master/MasterSource.scala
index 36c1b87b7f684..9c3f79f1244b7 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/MasterSource.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/MasterSource.scala
@@ -22,8 +22,8 @@ import com.codahale.metrics.{Gauge, MetricRegistry}
import org.apache.spark.metrics.source.Source
private[spark] class MasterSource(val master: Master) extends Source {
- val metricRegistry = new MetricRegistry()
- val sourceName = "master"
+ override val metricRegistry = new MetricRegistry()
+ override val sourceName = "master"
// Gauge for worker numbers in cluster
metricRegistry.register(MetricRegistry.name("workers"), new Gauge[Int] {
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/PersistenceEngine.scala b/core/src/main/scala/org/apache/spark/deploy/master/PersistenceEngine.scala
index e3640ea4f7e64..2e0e1e7036ac8 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/PersistenceEngine.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/PersistenceEngine.scala
@@ -17,6 +17,10 @@
package org.apache.spark.deploy.master
+import org.apache.spark.annotation.DeveloperApi
+
+import scala.reflect.ClassTag
+
/**
* Allows Master to persist any state that is necessary in order to recover from a failure.
* The following semantics are required:
@@ -25,36 +29,70 @@ package org.apache.spark.deploy.master
* Given these two requirements, we will have all apps and workers persisted, but
* we might not have yet deleted apps or workers that finished (so their liveness must be verified
* during recovery).
+ *
+ * The implementation of this trait defines how name-object pairs are stored or retrieved.
*/
-private[spark] trait PersistenceEngine {
- def addApplication(app: ApplicationInfo)
+@DeveloperApi
+trait PersistenceEngine {
- def removeApplication(app: ApplicationInfo)
+ /**
+ * Defines how the object is serialized and persisted. Implementation will
+ * depend on the store used.
+ */
+ def persist(name: String, obj: Object)
- def addWorker(worker: WorkerInfo)
+ /**
+ * Defines how the object referred by its name is removed from the store.
+ */
+ def unpersist(name: String)
- def removeWorker(worker: WorkerInfo)
+ /**
+ * Returns all objects whose names match the given prefix. This defines how objects are
+ * read/deserialized back.
+ */
+ def read[T: ClassTag](prefix: String): Seq[T]
- def addDriver(driver: DriverInfo)
+ final def addApplication(app: ApplicationInfo): Unit = {
+ persist("app_" + app.id, app)
+ }
- def removeDriver(driver: DriverInfo)
+ final def removeApplication(app: ApplicationInfo): Unit = {
+ unpersist("app_" + app.id)
+ }
+
+ final def addWorker(worker: WorkerInfo): Unit = {
+ persist("worker_" + worker.id, worker)
+ }
+
+ final def removeWorker(worker: WorkerInfo): Unit = {
+ unpersist("worker_" + worker.id)
+ }
+
+ final def addDriver(driver: DriverInfo): Unit = {
+ persist("driver_" + driver.id, driver)
+ }
+
+ final def removeDriver(driver: DriverInfo): Unit = {
+ unpersist("driver_" + driver.id)
+ }
/**
* Returns the persisted data sorted by their respective ids (which implies that they're
* sorted by time of creation).
*/
- def readPersistedData(): (Seq[ApplicationInfo], Seq[DriverInfo], Seq[WorkerInfo])
+ final def readPersistedData(): (Seq[ApplicationInfo], Seq[DriverInfo], Seq[WorkerInfo]) = {
+ (read[ApplicationInfo]("app_"), read[DriverInfo]("driver_"), read[WorkerInfo]("worker_"))
+ }
def close() {}
}
private[spark] class BlackHolePersistenceEngine extends PersistenceEngine {
- override def addApplication(app: ApplicationInfo) {}
- override def removeApplication(app: ApplicationInfo) {}
- override def addWorker(worker: WorkerInfo) {}
- override def removeWorker(worker: WorkerInfo) {}
- override def addDriver(driver: DriverInfo) {}
- override def removeDriver(driver: DriverInfo) {}
-
- override def readPersistedData() = (Nil, Nil, Nil)
+
+ override def persist(name: String, obj: Object): Unit = {}
+
+ override def unpersist(name: String): Unit = {}
+
+ override def read[T: ClassTag](name: String): Seq[T] = Nil
+
}
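Because addApplication, removeWorker, readPersistedData and the other helpers are now final methods built on persist/unpersist/read, a concrete store only has to supply those three primitives. A minimal in-memory sketch (class name invented), for example for tests:
{{{
import scala.collection.mutable
import scala.reflect.ClassTag

import org.apache.spark.deploy.master.PersistenceEngine

// Keeps name -> object pairs in a map; read() filters by the "app_"/"driver_"/"worker_"
// prefixes that the trait's final helper methods use.
class InMemoryPersistenceEngine extends PersistenceEngine {

  private val store = mutable.LinkedHashMap.empty[String, Object]

  override def persist(name: String, obj: Object): Unit = store(name) = obj

  override def unpersist(name: String): Unit = store -= name

  override def read[T: ClassTag](prefix: String): Seq[T] =
    store.iterator.collect {
      case (name, obj) if name.startsWith(prefix) => obj.asInstanceOf[T]
    }.toSeq
}
}}}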
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/RecoveryModeFactory.scala b/core/src/main/scala/org/apache/spark/deploy/master/RecoveryModeFactory.scala
new file mode 100644
index 0000000000000..d9d36c1ed5f9f
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/deploy/master/RecoveryModeFactory.scala
@@ -0,0 +1,69 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy.master
+
+import org.apache.spark.{Logging, SparkConf}
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.serializer.JavaSerializer
+
+/**
+ * :: DeveloperApi ::
+ *
+ * Implementations of this class can be plugged in as an alternative recovery mode for Spark's
+ * standalone deploy mode.
+ *
+ */
+@DeveloperApi
+abstract class StandaloneRecoveryModeFactory(conf: SparkConf) {
+
+ /**
+ * PersistenceEngine defines how the persistent data (information about workers, drivers, etc.)
+ * is handled for recovery.
+ *
+ */
+ def createPersistenceEngine(): PersistenceEngine
+
+ /**
+ * Create an instance of LeaderElectionAgent that decides who gets elected as master.
+ */
+ def createLeaderElectionAgent(master: LeaderElectable): LeaderElectionAgent
+}
+
+/**
+ * The leader election agent in this case is a no-op: the single master is always the leader,
+ * since the actual recovery is done by restoring state from the filesystem.
+ */
+private[spark] class FileSystemRecoveryModeFactory(conf: SparkConf)
+ extends StandaloneRecoveryModeFactory(conf) with Logging {
+ val RECOVERY_DIR = conf.get("spark.deploy.recoveryDirectory", "")
+
+ def createPersistenceEngine() = {
+ logInfo("Persisting recovery state to directory: " + RECOVERY_DIR)
+ new FileSystemPersistenceEngine(RECOVERY_DIR, new JavaSerializer(conf))
+ }
+
+ def createLeaderElectionAgent(master: LeaderElectable) = new MonarchyLeaderAgent(master)
+}
+
+private[spark] class ZooKeeperRecoveryModeFactory(conf: SparkConf)
+ extends StandaloneRecoveryModeFactory(conf) {
+ def createPersistenceEngine() = new ZooKeeperPersistenceEngine(new JavaSerializer(conf), conf)
+
+ def createLeaderElectionAgent(master: LeaderElectable) =
+ new ZooKeeperLeaderElectionAgent(master, conf)
+}
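The CUSTOM branch added to Master above reads `spark.deploy.recoveryMode.factory` and reflectively calls a one-argument SparkConf constructor on that class. A hedged sketch of a third-party factory, wiring in the hypothetical InMemoryPersistenceEngine and SingleNodeLeaderAgent sketched earlier (none of these class names are Spark classes):
{{{
import org.apache.spark.SparkConf
import org.apache.spark.deploy.master.{LeaderElectable, LeaderElectionAgent,
  PersistenceEngine, StandaloneRecoveryModeFactory}

// The one-argument SparkConf constructor is what Master's CUSTOM branch looks up.
class InMemoryRecoveryModeFactory(conf: SparkConf) extends StandaloneRecoveryModeFactory(conf) {

  // Where and how recovery state is stored.
  def createPersistenceEngine(): PersistenceEngine = new InMemoryPersistenceEngine()

  // Who gets to be leader; a single-node agent suffices when there is no shared store.
  def createLeaderElectionAgent(master: LeaderElectable): LeaderElectionAgent =
    new SingleNodeLeaderAgent(master)
}
}}}
With such a class on the master's classpath, it would be selected by setting `spark.deploy.recoveryMode=CUSTOM` and `spark.deploy.recoveryMode.factory` to the fully qualified factory class name, matching the CUSTOM case in Master above.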
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala b/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala
index c5fa9cf7d7c2d..473ddc23ff0f3 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala
@@ -21,6 +21,7 @@ import scala.collection.mutable
import akka.actor.ActorRef
+import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.util.Utils
private[spark] class WorkerInfo(
@@ -50,7 +51,7 @@ private[spark] class WorkerInfo(
def coresFree: Int = cores - coresUsed
def memoryFree: Int = memory - memoryUsed
- private def readObject(in: java.io.ObjectInputStream) : Unit = {
+ private def readObject(in: java.io.ObjectInputStream): Unit = Utils.tryOrIOException {
in.defaultReadObject()
init()
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperLeaderElectionAgent.scala b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperLeaderElectionAgent.scala
index 285f9b014e291..8eaa0ad948519 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperLeaderElectionAgent.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperLeaderElectionAgent.scala
@@ -24,9 +24,8 @@ import org.apache.spark.deploy.master.MasterMessages._
import org.apache.curator.framework.CuratorFramework
import org.apache.curator.framework.recipes.leader.{LeaderLatchListener, LeaderLatch}
-private[spark] class ZooKeeperLeaderElectionAgent(val masterActor: ActorRef,
- masterUrl: String, conf: SparkConf)
- extends LeaderElectionAgent with LeaderLatchListener with Logging {
+private[spark] class ZooKeeperLeaderElectionAgent(val masterActor: LeaderElectable,
+ conf: SparkConf) extends LeaderLatchListener with LeaderElectionAgent with Logging {
val WORKING_DIR = conf.get("spark.deploy.zookeeper.dir", "/spark") + "/leader_election"
@@ -34,30 +33,21 @@ private[spark] class ZooKeeperLeaderElectionAgent(val masterActor: ActorRef,
private var leaderLatch: LeaderLatch = _
private var status = LeadershipStatus.NOT_LEADER
- override def preStart() {
+ start()
+ def start() {
logInfo("Starting ZooKeeper LeaderElection agent")
zk = SparkCuratorUtil.newClient(conf)
leaderLatch = new LeaderLatch(zk, WORKING_DIR)
leaderLatch.addListener(this)
-
leaderLatch.start()
}
- override def preRestart(reason: scala.Throwable, message: scala.Option[scala.Any]) {
- logError("LeaderElectionAgent failed...", reason)
- super.preRestart(reason, message)
- }
-
- override def postStop() {
+ override def stop() {
leaderLatch.close()
zk.close()
}
- override def receive = {
- case _ =>
- }
-
override def isLeader() {
synchronized {
// could have lost leadership by now.
@@ -85,10 +75,10 @@ private[spark] class ZooKeeperLeaderElectionAgent(val masterActor: ActorRef,
def updateLeadershipStatus(isLeader: Boolean) {
if (isLeader && status == LeadershipStatus.NOT_LEADER) {
status = LeadershipStatus.LEADER
- masterActor ! ElectedLeader
+ masterActor.electedLeader()
} else if (!isLeader && status == LeadershipStatus.LEADER) {
status = LeadershipStatus.NOT_LEADER
- masterActor ! RevokedLeadership
+ masterActor.revokedLeadership()
}
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala
index 834dfedee52ce..96c2139eb02f0 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala
@@ -19,72 +19,54 @@ package org.apache.spark.deploy.master
import scala.collection.JavaConversions._
-import akka.serialization.Serialization
import org.apache.curator.framework.CuratorFramework
import org.apache.zookeeper.CreateMode
import org.apache.spark.{Logging, SparkConf}
+import org.apache.spark.serializer.Serializer
+import java.nio.ByteBuffer
-class ZooKeeperPersistenceEngine(serialization: Serialization, conf: SparkConf)
+import scala.reflect.ClassTag
+
+
+private[spark] class ZooKeeperPersistenceEngine(val serialization: Serializer, conf: SparkConf)
extends PersistenceEngine
with Logging
{
val WORKING_DIR = conf.get("spark.deploy.zookeeper.dir", "/spark") + "/master_status"
val zk: CuratorFramework = SparkCuratorUtil.newClient(conf)
- SparkCuratorUtil.mkdir(zk, WORKING_DIR)
-
- override def addApplication(app: ApplicationInfo) {
- serializeIntoFile(WORKING_DIR + "/app_" + app.id, app)
- }
+ val serializer = serialization.newInstance()
- override def removeApplication(app: ApplicationInfo) {
- zk.delete().forPath(WORKING_DIR + "/app_" + app.id)
- }
+ SparkCuratorUtil.mkdir(zk, WORKING_DIR)
- override def addDriver(driver: DriverInfo) {
- serializeIntoFile(WORKING_DIR + "/driver_" + driver.id, driver)
- }
- override def removeDriver(driver: DriverInfo) {
- zk.delete().forPath(WORKING_DIR + "/driver_" + driver.id)
+ override def persist(name: String, obj: Object): Unit = {
+ serializeIntoFile(WORKING_DIR + "/" + name, obj)
}
- override def addWorker(worker: WorkerInfo) {
- serializeIntoFile(WORKING_DIR + "/worker_" + worker.id, worker)
+ override def unpersist(name: String): Unit = {
+ zk.delete().forPath(WORKING_DIR + "/" + name)
}
- override def removeWorker(worker: WorkerInfo) {
- zk.delete().forPath(WORKING_DIR + "/worker_" + worker.id)
+ override def read[T: ClassTag](prefix: String) = {
+ val file = zk.getChildren.forPath(WORKING_DIR).filter(_.startsWith(prefix))
+ file.map(deserializeFromFile[T]).flatten
}
override def close() {
zk.close()
}
- override def readPersistedData(): (Seq[ApplicationInfo], Seq[DriverInfo], Seq[WorkerInfo]) = {
- val sortedFiles = zk.getChildren().forPath(WORKING_DIR).toList.sorted
- val appFiles = sortedFiles.filter(_.startsWith("app_"))
- val apps = appFiles.map(deserializeFromFile[ApplicationInfo]).flatten
- val driverFiles = sortedFiles.filter(_.startsWith("driver_"))
- val drivers = driverFiles.map(deserializeFromFile[DriverInfo]).flatten
- val workerFiles = sortedFiles.filter(_.startsWith("worker_"))
- val workers = workerFiles.map(deserializeFromFile[WorkerInfo]).flatten
- (apps, drivers, workers)
- }
-
private def serializeIntoFile(path: String, value: AnyRef) {
- val serializer = serialization.findSerializerFor(value)
- val serialized = serializer.toBinary(value)
- zk.create().withMode(CreateMode.PERSISTENT).forPath(path, serialized)
+ val serialized = serializer.serialize(value)
+ zk.create().withMode(CreateMode.PERSISTENT).forPath(path, serialized.array())
}
- def deserializeFromFile[T](filename: String)(implicit m: Manifest[T]): Option[T] = {
+ def deserializeFromFile[T](filename: String)(implicit m: ClassTag[T]): Option[T] = {
val fileData = zk.getData().forPath(WORKING_DIR + "/" + filename)
- val clazz = m.runtimeClass.asInstanceOf[Class[T]]
- val serializer = serialization.serializerFor(clazz)
try {
- Some(serializer.fromBinary(fileData).asInstanceOf[T])
+ Some(serializer.deserialize(ByteBuffer.wrap(fileData)))
} catch {
case e: Exception => {
logWarning("Exception while reading persisted file, deleting", e)
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala
index 34fa1429c86de..4588c130ef439 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala
@@ -28,7 +28,7 @@ import org.json4s.JValue
import org.apache.spark.deploy.{ExecutorState, JsonProtocol}
import org.apache.spark.deploy.DeployMessages.{MasterStateResponse, RequestMasterState}
import org.apache.spark.deploy.master.ExecutorInfo
-import org.apache.spark.ui.{WebUIPage, UIUtils}
+import org.apache.spark.ui.{UIUtils, WebUIPage}
import org.apache.spark.util.Utils
private[spark] class ApplicationPage(parent: MasterWebUI) extends WebUIPage("app") {
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/HistoryNotFoundPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/HistoryNotFoundPage.scala
new file mode 100644
index 0000000000000..d8daff3e7fb9c
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/HistoryNotFoundPage.scala
@@ -0,0 +1,73 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy.master.ui
+
+import java.net.URLDecoder
+import javax.servlet.http.HttpServletRequest
+
+import scala.xml.Node
+
+import org.apache.spark.ui.{UIUtils, WebUIPage}
+
+private[spark] class HistoryNotFoundPage(parent: MasterWebUI)
+ extends WebUIPage("history/not-found") {
+
+ /**
+ * Render a page that conveys failure in loading application history.
+ *
+ * This accepts 3 HTTP parameters:
+ * msg = message to display to the user
+ * title = title of the page
+ * exception = detailed description of the exception in loading application history (if any)
+ *
+ * Parameters "msg" and "exception" are assumed to be UTF-8 encoded.
+ */
+ def render(request: HttpServletRequest): Seq[Node] = {
+ val titleParam = request.getParameter("title")
+ val msgParam = request.getParameter("msg")
+ val exceptionParam = request.getParameter("exception")
+
+ // If no parameters are specified, assume the user did not enable event logging
+ val defaultTitle = "Event logging is not enabled"
+ val defaultContent =
+   <div class="row-fluid">
+     <div class="span12" style="font-size:14px">
+       No event logs were found for this application! To
+       <a href="http://spark.apache.org/docs/latest/monitoring.html">enable event logging</a>,
+       set <span style="font-style:italic">spark.eventLog.enabled</span> to true and
+       <span style="font-style:italic">spark.eventLog.dir</span> to the directory to which your
+       event logs are written.
+     </div>
+   </div>
+
+ val title = Option(titleParam).getOrElse(defaultTitle)
+ val content = Option(msgParam)
+   .map { msg => URLDecoder.decode(msg, "UTF-8") }
+   .map { msg =>
+     <div class="row-fluid">
+       <div class="span12" style="font-size:14px">{msg}</div>
+     </div> ++
+     Option(exceptionParam)
+       .map { e => URLDecoder.decode(e, "UTF-8") }
+       .map { e => <pre>{e}</pre> }
+       .getOrElse(Seq.empty)
+   }.getOrElse(defaultContent)
+
+ UIUtils.basicSparkPage(content, title)
+ }
+}
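For reference, the link that lands on this page is assembled by the caller; Master.rebuildSparkUI in this patch builds the same kind of URL under HistoryServer.UI_PATH_PREFIX ("/history"). A hedged sketch (object name and values are invented):
{{{
import java.net.URLEncoder

// Builds a /history/not-found URL. Only "msg" (and "exception", if present) are
// URL-decoded a second time by the page itself; the title is read back as-is.
object HistoryNotFoundLink {
  def apply(appId: String, reason: String): String = {
    val title = URLEncoder.encode(s"Application history not found ($appId)", "UTF-8")
    val msg = URLEncoder.encode(reason, "UTF-8")
    s"/history/not-found?msg=$msg&title=$title"
  }

  def main(args: Array[String]): Unit =
    println(apply("app-20140801120000-0001", "No event logs found for this application."))
}
}}}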
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala
index a18b39fc95d64..d86ec1e03e45c 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala
@@ -21,14 +21,14 @@ import org.apache.spark.Logging
import org.apache.spark.deploy.master.Master
import org.apache.spark.ui.{SparkUI, WebUI}
import org.apache.spark.ui.JettyUtils._
-import org.apache.spark.util.{AkkaUtils, Utils}
+import org.apache.spark.util.AkkaUtils
/**
* Web UI server for the standalone master.
*/
private[spark]
class MasterWebUI(val master: Master, requestedPort: Int)
- extends WebUI(master.securityMgr, requestedPort, master.conf) with Logging {
+ extends WebUI(master.securityMgr, requestedPort, master.conf, name = "MasterUI") with Logging {
val masterActorRef = master.self
val timeout = AkkaUtils.askTimeout(master.conf)
@@ -38,6 +38,7 @@ class MasterWebUI(val master: Master, requestedPort: Int)
/** Initialize all components of the server. */
def initialize() {
attachPage(new ApplicationPage(this))
+ attachPage(new HistoryNotFoundPage(this))
attachPage(new MasterPage(this))
attachHandler(createStaticHandler(MasterWebUI.STATIC_RESOURCE_DIR, "/static"))
master.masterMetricsSystem.getServletHandlers.foreach(attachHandler)
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/CommandUtils.scala b/core/src/main/scala/org/apache/spark/deploy/worker/CommandUtils.scala
index 4af5bc3afad6c..28e9662db5da9 100644
--- a/core/src/main/scala/org/apache/spark/deploy/worker/CommandUtils.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/CommandUtils.scala
@@ -20,6 +20,8 @@ package org.apache.spark.deploy.worker
import java.io.{File, FileOutputStream, InputStream, IOException}
import java.lang.System._
+import scala.collection.Map
+
import org.apache.spark.Logging
import org.apache.spark.deploy.Command
import org.apache.spark.util.Utils
@@ -29,8 +31,30 @@ import org.apache.spark.util.Utils
*/
private[spark]
object CommandUtils extends Logging {
- def buildCommandSeq(command: Command, memory: Int, sparkHome: String): Seq[String] = {
- val runner = getEnv("JAVA_HOME", command).map(_ + "/bin/java").getOrElse("java")
+
+ /**
+ * Build a ProcessBuilder based on the given parameters.
+ * The `env` argument is exposed for testing.
+ */
+ def buildProcessBuilder(
+ command: Command,
+ memory: Int,
+ sparkHome: String,
+ substituteArguments: String => String,
+ classPaths: Seq[String] = Seq[String](),
+ env: Map[String, String] = sys.env): ProcessBuilder = {
+ val localCommand = buildLocalCommand(command, substituteArguments, classPaths, env)
+ val commandSeq = buildCommandSeq(localCommand, memory, sparkHome)
+ val builder = new ProcessBuilder(commandSeq: _*)
+ val environment = builder.environment()
+ for ((key, value) <- localCommand.environment) {
+ environment.put(key, value)
+ }
+ builder
+ }
+
+ private def buildCommandSeq(command: Command, memory: Int, sparkHome: String): Seq[String] = {
+ val runner = sys.env.get("JAVA_HOME").map(_ + "/bin/java").getOrElse("java")
// SPARK-698: do not call the run.cmd script, as process.destroy()
// fails to kill a process tree on Windows
@@ -38,16 +62,42 @@ object CommandUtils extends Logging {
command.arguments
}
- private def getEnv(key: String, command: Command): Option[String] =
- command.environment.get(key).orElse(Option(System.getenv(key)))
+ /**
+ * Build a command based on the given one, taking into account the local environment
+ * in which the command will run, substituting any placeholders and appending
+ * any extra class paths.
+ */
+ private def buildLocalCommand(
+ command: Command,
+ substituteArguments: String => String,
+ classPath: Seq[String] = Seq[String](),
+ env: Map[String, String]): Command = {
+ val libraryPathName = Utils.libraryPathEnvName
+ val libraryPathEntries = command.libraryPathEntries
+ val cmdLibraryPath = command.environment.get(libraryPathName)
+
+ val newEnvironment = if (libraryPathEntries.nonEmpty && libraryPathName.nonEmpty) {
+ val libraryPaths = libraryPathEntries ++ cmdLibraryPath ++ env.get(libraryPathName)
+ command.environment + ((libraryPathName, libraryPaths.mkString(File.pathSeparator)))
+ } else {
+ command.environment
+ }
+
+ Command(
+ command.mainClass,
+ command.arguments.map(substituteArguments),
+ newEnvironment,
+ command.classPathEntries ++ classPath,
+ Seq[String](), // library path already captured in environment variable
+ command.javaOpts)
+ }
/**
* Attention: this must always be aligned with the environment variables in the run scripts and
* the way the JAVA_OPTS are assembled there.
*/
- def buildJavaOpts(command: Command, memory: Int, sparkHome: String): Seq[String] = {
+ private def buildJavaOpts(command: Command, memory: Int, sparkHome: String): Seq[String] = {
val memoryOpts = Seq(s"-Xms${memory}M", s"-Xmx${memory}M")
- val extraOpts = command.extraJavaOptions.map(Utils.splitCommandString).getOrElse(Seq())
// Exists for backwards compatibility with older Spark versions
val workerLocalOpts = Option(getenv("SPARK_JAVA_OPTS")).map(Utils.splitCommandString)
@@ -57,25 +107,17 @@ object CommandUtils extends Logging {
logWarning("Set SPARK_LOCAL_DIRS for node-specific storage locations.")
}
- val libraryOpts =
- if (command.libraryPathEntries.size > 0) {
- val joined = command.libraryPathEntries.mkString(File.pathSeparator)
- Seq(s"-Djava.library.path=$joined")
- } else {
- Seq()
- }
-
- val permGenOpt = Seq("-XX:MaxPermSize=128m")
-
// Figure out our classpath with the external compute-classpath script
val ext = if (System.getProperty("os.name").startsWith("Windows")) ".cmd" else ".sh"
val classPath = Utils.executeAndGetOutput(
Seq(sparkHome + "/bin/compute-classpath" + ext),
- extraEnvironment=command.environment)
+ extraEnvironment = command.environment)
val userClassPath = command.classPathEntries ++ Seq(classPath)
+ val javaVersion = System.getProperty("java.version")
+ val permGenOpt = if (!javaVersion.startsWith("1.8")) Some("-XX:MaxPermSize=128m") else None
Seq("-cp", userClassPath.filterNot(_.isEmpty).mkString(File.pathSeparator)) ++
- permGenOpt ++ libraryOpts ++ extraOpts ++ workerLocalOpts ++ memoryOpts
+ permGenOpt ++ workerLocalOpts ++ command.javaOpts ++ memoryOpts
}
/** Spawn a thread that will redirect a given stream to a file */
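Both runners below now funnel through buildProcessBuilder. A hedged usage sketch follows: all values are placeholders, running it needs a real Spark distribution because the classpath is computed via bin/compute-classpath, and the object must live under org.apache.spark since CommandUtils is private[spark].
{{{
package org.apache.spark.deploy.worker

import org.apache.spark.deploy.Command

// Build, but do not start, the java command line for a hypothetical application.
object BuildProcessBuilderSketch {
  def main(args: Array[String]): Unit = {
    val command = Command(
      mainClass = "org.apache.spark.examples.SparkPi",
      arguments = Seq("{{WORKER_URL}}", "100"),
      environment = Map("SPARK_LOG_DIR" -> "/tmp/spark-logs"),
      classPathEntries = Seq.empty,
      libraryPathEntries = Seq.empty,
      javaOpts = Seq("-Dspark.test.dummy=true"))

    val builder = CommandUtils.buildProcessBuilder(
      command,
      memory = 512,                                // used for -Xms/-Xmx, in MB
      sparkHome = sys.env.getOrElse("SPARK_HOME", "."),
      substituteArguments = {
        case "{{WORKER_URL}}" => "spark://worker@localhost:7078"
        case other => other
      })

    println(builder.command())                     // the fully assembled command line
  }
}
}}}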
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala b/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala
index 662d37871e7a6..28cab36c7b9e2 100644
--- a/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala
@@ -23,21 +23,23 @@ import scala.collection.JavaConversions._
import scala.collection.Map
import akka.actor.ActorRef
-import com.google.common.base.Charsets
+import com.google.common.base.Charsets.UTF_8
import com.google.common.io.Files
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileUtil, Path}
-import org.apache.spark.Logging
-import org.apache.spark.deploy.{Command, DriverDescription}
+import org.apache.spark.{Logging, SparkConf}
+import org.apache.spark.deploy.{Command, DriverDescription, SparkHadoopUtil}
import org.apache.spark.deploy.DeployMessages.DriverStateChanged
import org.apache.spark.deploy.master.DriverState
import org.apache.spark.deploy.master.DriverState.DriverState
/**
* Manages the execution of one driver, including automatically restarting the driver on failure.
+ * This is currently only used in standalone cluster deploy mode.
*/
private[spark] class DriverRunner(
+ val conf: SparkConf,
val driverId: String,
val workDir: File,
val sparkHome: File,
@@ -74,17 +76,9 @@ private[spark] class DriverRunner(
// Make sure user application jar is on the classpath
// TODO: If we add ability to submit multiple jars they should also be added here
- val classPath = driverDesc.command.classPathEntries ++ Seq(s"$localJarFilename")
- val newCommand = Command(
- driverDesc.command.mainClass,
- driverDesc.command.arguments.map(substituteVariables),
- driverDesc.command.environment,
- classPath,
- driverDesc.command.libraryPathEntries,
- driverDesc.command.extraJavaOptions)
- val command = CommandUtils.buildCommandSeq(newCommand, driverDesc.mem,
- sparkHome.getAbsolutePath)
- launchDriver(command, driverDesc.command.environment, driverDir, driverDesc.supervise)
+ val builder = CommandUtils.buildProcessBuilder(driverDesc.command, driverDesc.mem,
+ sparkHome.getAbsolutePath, substituteVariables, Seq(localJarFilename))
+ launchDriver(builder, driverDir, driverDesc.supervise)
}
catch {
case e: Exception => finalException = Some(e)
@@ -143,8 +137,8 @@ private[spark] class DriverRunner(
val jarPath = new Path(driverDesc.jarUrl)
- val emptyConf = new Configuration()
- val jarFileSystem = jarPath.getFileSystem(emptyConf)
+ val hadoopConf = SparkHadoopUtil.get.newConfiguration(conf)
+ val jarFileSystem = jarPath.getFileSystem(hadoopConf)
val destPath = new File(driverDir.getAbsolutePath, jarPath.getName)
val jarFileName = jarPath.getName
@@ -153,7 +147,7 @@ private[spark] class DriverRunner(
if (!localJarFile.exists()) { // May already exist if running multiple workers on one node
logInfo(s"Copying user jar $jarPath to $destPath")
- FileUtil.copy(jarFileSystem, jarPath, destPath, false, emptyConf)
+ FileUtil.copy(jarFileSystem, jarPath, destPath, false, hadoopConf)
}
if (!localJarFile.exists()) { // Verify copy succeeded
@@ -163,11 +157,8 @@ private[spark] class DriverRunner(
localJarFilename
}
- private def launchDriver(command: Seq[String], envVars: Map[String, String], baseDir: File,
- supervise: Boolean) {
- val builder = new ProcessBuilder(command: _*).directory(baseDir)
- envVars.map{ case(k,v) => builder.environment().put(k, v) }
-
+ private def launchDriver(builder: ProcessBuilder, baseDir: File, supervise: Boolean) {
+ builder.directory(baseDir)
def initialize(process: Process) = {
// Redirect stdout and stderr to files
val stdout = new File(baseDir, "stdout")
@@ -175,8 +166,8 @@ private[spark] class DriverRunner(
val stderr = new File(baseDir, "stderr")
val header = "Launch Command: %s\n%s\n\n".format(
- command.mkString("\"", "\" \"", "\""), "=" * 40)
- Files.append(header, stderr, Charsets.UTF_8)
+ builder.command.mkString("\"", "\" \"", "\""), "=" * 40)
+ Files.append(header, stderr, UTF_8)
CommandUtils.redirectStream(process.getErrorStream, stderr)
}
runCommandWithRetry(ProcessBuilderLike(builder), initialize, supervise)
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala
index 467317dd9b44c..8ba6a01bbcb97 100644
--- a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala
@@ -19,8 +19,10 @@ package org.apache.spark.deploy.worker
import java.io._
+import scala.collection.JavaConversions._
+
import akka.actor.ActorRef
-import com.google.common.base.Charsets
+import com.google.common.base.Charsets.UTF_8
import com.google.common.io.Files
import org.apache.spark.{SparkConf, Logging}
@@ -30,6 +32,7 @@ import org.apache.spark.util.logging.FileAppender
/**
* Manages the execution of one executor process.
+ * This is currently only used in standalone mode.
*/
private[spark] class ExecutorRunner(
val appId: String,
@@ -41,7 +44,7 @@ private[spark] class ExecutorRunner(
val workerId: String,
val host: String,
val sparkHome: File,
- val workDir: File,
+ val executorDir: File,
val workerUrl: String,
val conf: SparkConf,
var state: ExecutorState.Value)
@@ -72,7 +75,7 @@ private[spark] class ExecutorRunner(
}
/**
- * kill executor process, wait for exit and notify worker to update resource status
+ * Kill executor process, wait for exit and notify worker to update resource status.
*
* @param message the exception message which caused the executor's death
*/
@@ -110,39 +113,25 @@ private[spark] class ExecutorRunner(
case "{{EXECUTOR_ID}}" => execId.toString
case "{{HOSTNAME}}" => host
case "{{CORES}}" => cores.toString
+ case "{{APP_ID}}" => appId
case other => other
}
- def getCommandSeq = {
- val command = Command(appDesc.command.mainClass,
- appDesc.command.arguments.map(substituteVariables) ++ Seq(appId), appDesc.command.environment,
- appDesc.command.classPathEntries, appDesc.command.libraryPathEntries,
- appDesc.command.extraJavaOptions)
- CommandUtils.buildCommandSeq(command, memory, sparkHome.getAbsolutePath)
- }
-
/**
* Download and run the executor described in our ApplicationDescription
*/
def fetchAndRunExecutor() {
try {
- // Create the executor's working directory
- val executorDir = new File(workDir, appId + "/" + execId)
- if (!executorDir.mkdirs()) {
- throw new IOException("Failed to create directory " + executorDir)
- }
-
// Launch the process
- val command = getCommandSeq
+ val builder = CommandUtils.buildProcessBuilder(appDesc.command, memory,
+ sparkHome.getAbsolutePath, substituteVariables)
+ val command = builder.command()
logInfo("Launch command: " + command.mkString("\"", "\" \"", "\""))
- val builder = new ProcessBuilder(command: _*).directory(executorDir)
- val env = builder.environment()
- for ((key, value) <- appDesc.command.environment) {
- env.put(key, value)
- }
+
+ builder.directory(executorDir)
// In case we are running this from within the Spark Shell, avoid creating a "scala"
// parent process for the executor command
- env.put("SPARK_LAUNCH_WITH_SCALA", "0")
+ builder.environment.put("SPARK_LAUNCH_WITH_SCALA", "0")
process = builder.start()
val header = "Spark Executor Command: %s\n%s\n\n".format(
command.mkString("\"", "\" \"", "\""), "=" * 40)
@@ -152,9 +141,11 @@ private[spark] class ExecutorRunner(
stdoutAppender = FileAppender(process.getInputStream, stdout, conf)
val stderr = new File(executorDir, "stderr")
- Files.write(header, stderr, Charsets.UTF_8)
+ Files.write(header, stderr, UTF_8)
stderrAppender = FileAppender(process.getErrorStream, stderr, conf)
+ state = ExecutorState.RUNNING
+ worker ! ExecutorStateChanged(appId, execId, state, None, None)
// Wait for it to exit; executor may exit with code 0 (when driver instructs it to shutdown)
// or with nonzero exit code
val exitCode = process.waitFor()
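The new {{APP_ID}} placeholder joins the existing substitutions, so the app id no longer has to be appended to the argument list by getCommandSeq. A small, self-contained illustration of how such a substitution function behaves; the concrete values are hypothetical and only mirror the cases shown in the hunk above.
{{{
// Hypothetical values; mirrors the substituteVariables cases shown above.
val substituteVariables: String => String = {
  case "{{EXECUTOR_ID}}" => "7"
  case "{{HOSTNAME}}"    => "worker-host.example.com"
  case "{{CORES}}"       => "4"
  case "{{APP_ID}}"      => "app-20141105120000-0001"
  case other             => other   // non-placeholder arguments pass through untouched
}

val arguments = Seq("{{APP_ID}}", "{{EXECUTOR_ID}}", "{{HOSTNAME}}", "{{CORES}}")
arguments.map(substituteVariables)
// => Seq("app-20141105120000-0001", "7", "worker-host.example.com", "4")
}}}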
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/StandaloneWorkerShuffleService.scala b/core/src/main/scala/org/apache/spark/deploy/worker/StandaloneWorkerShuffleService.scala
new file mode 100644
index 0000000000000..b9798963bab0a
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/StandaloneWorkerShuffleService.scala
@@ -0,0 +1,66 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy.worker
+
+import org.apache.spark.{Logging, SparkConf, SecurityManager}
+import org.apache.spark.network.TransportContext
+import org.apache.spark.network.netty.SparkTransportConf
+import org.apache.spark.network.sasl.SaslRpcHandler
+import org.apache.spark.network.server.TransportServer
+import org.apache.spark.network.shuffle.ExternalShuffleBlockHandler
+
+/**
+ * Provides a server from which Executors can read shuffle files (rather than reading directly from
+ * each other), so that shuffle files remain available even when the executors that wrote them are
+ * turned off or killed.
+ *
+ * Optionally requires SASL authentication in order to read. See [[SecurityManager]].
+ */
+private[worker]
+class StandaloneWorkerShuffleService(sparkConf: SparkConf, securityManager: SecurityManager)
+ extends Logging {
+
+ private val enabled = sparkConf.getBoolean("spark.shuffle.service.enabled", false)
+ private val port = sparkConf.getInt("spark.shuffle.service.port", 7337)
+ private val useSasl: Boolean = securityManager.isAuthenticationEnabled()
+
+ private val transportConf = SparkTransportConf.fromSparkConf(sparkConf, numUsableCores = 0)
+ private val blockHandler = new ExternalShuffleBlockHandler(transportConf)
+ private val transportContext: TransportContext = {
+ val handler = if (useSasl) new SaslRpcHandler(blockHandler, securityManager) else blockHandler
+ new TransportContext(transportConf, handler)
+ }
+
+ private var server: TransportServer = _
+
+ /** Starts the external shuffle service if the user has enabled it. */
+ def startIfEnabled() {
+ if (enabled) {
+ require(server == null, "Shuffle server already started")
+ logInfo(s"Starting shuffle service on port $port with useSasl = $useSasl")
+ server = transportContext.createServer(port)
+ }
+ }
+
+ def stop() {
+ if (enabled && server != null) {
+ server.close()
+ server = null
+ }
+ }
+}
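A minimal usage sketch for the new service, assuming code running inside the same package (the class is private[worker]) and a SparkConf/SecurityManager like the ones the Worker already holds; the port shown is just the documented default repeated for clarity.
{{{
import org.apache.spark.{SecurityManager, SparkConf}

val conf = new SparkConf()
  .set("spark.shuffle.service.enabled", "true")   // the service is a no-op without this
  .set("spark.shuffle.service.port", "7337")      // default port, shown explicitly

val shuffleService = new StandaloneWorkerShuffleService(conf, new SecurityManager(conf))
shuffleService.startIfEnabled()   // binds a TransportServer on the configured port
// ... worker runs ...
shuffleService.stop()             // closes the server; safe to call even if it never started
}}}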
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala
index ce425443051b0..ca262de832e25 100755
--- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala
@@ -18,12 +18,16 @@
package org.apache.spark.deploy.worker
import java.io.File
+import java.io.IOException
import java.text.SimpleDateFormat
-import java.util.Date
+import java.util.{UUID, Date}
+import java.util.concurrent.TimeUnit
+import scala.collection.JavaConversions._
import scala.collection.mutable.HashMap
import scala.concurrent.duration._
import scala.language.postfixOps
+import scala.util.Random
import akka.actor._
import akka.remote.{DisassociatedEvent, RemotingLifecycleEvent}
@@ -34,7 +38,7 @@ import org.apache.spark.deploy.DeployMessages._
import org.apache.spark.deploy.master.{DriverState, Master}
import org.apache.spark.deploy.worker.ui.WorkerWebUI
import org.apache.spark.metrics.MetricsSystem
-import org.apache.spark.util.{AkkaUtils, SignalLogger, Utils}
+import org.apache.spark.util.{ActorLogReceive, AkkaUtils, SignalLogger, Utils}
/**
* @param masterUrls Each url should look like spark://host:port.
@@ -51,7 +55,7 @@ private[spark] class Worker(
workDirPath: String = null,
val conf: SparkConf,
val securityMgr: SecurityManager)
- extends Actor with Logging {
+ extends Actor with ActorLogReceive with Logging {
import context.dispatcher
Utils.checkHost(host, "Expected hostname")
@@ -62,8 +66,22 @@ private[spark] class Worker(
// Send a heartbeat every (heartbeat timeout) / 4 milliseconds
val HEARTBEAT_MILLIS = conf.getLong("spark.worker.timeout", 60) * 1000 / 4
- val REGISTRATION_TIMEOUT = 20.seconds
- val REGISTRATION_RETRIES = 3
+ // Model retries to connect to the master after Hadoop's retry model.
+ // The first six reconnection attempts are made at shorter intervals (between 5 and 15 seconds);
+ // the next 10 attempts are made at longer intervals (between 30 and 90 seconds).
+ // A bit of randomness is introduced so that not all of the workers attempt to reconnect at
+ // the same time.
+ val INITIAL_REGISTRATION_RETRIES = 6
+ val TOTAL_REGISTRATION_RETRIES = INITIAL_REGISTRATION_RETRIES + 10
+ val FUZZ_MULTIPLIER_INTERVAL_LOWER_BOUND = 0.500
+ val REGISTRATION_RETRY_FUZZ_MULTIPLIER = {
+ val randomNumberGenerator = new Random(UUID.randomUUID.getMostSignificantBits)
+ randomNumberGenerator.nextDouble + FUZZ_MULTIPLIER_INTERVAL_LOWER_BOUND
+ }
+ val INITIAL_REGISTRATION_RETRY_INTERVAL = (math.round(10 *
+ REGISTRATION_RETRY_FUZZ_MULTIPLIER)).seconds
+ val PROLONGED_REGISTRATION_RETRY_INTERVAL = (math.round(60
+ * REGISTRATION_RETRY_FUZZ_MULTIPLIER)).seconds
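To make the interval arithmetic above concrete, a short illustration follows; the drawn multiplier is only an example value.
{{{
// nextDouble is in [0.0, 1.0), so the fuzz multiplier lands in [0.5, 1.5).
val fuzz = new scala.util.Random().nextDouble + 0.5   // e.g. 0.8
math.round(10 * fuzz)   // initial retry interval: 5 to 15 seconds (e.g. 8)
math.round(60 * fuzz)   // prolonged retry interval: 30 to 90 seconds (e.g. 48)
}}}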
val CLEANUP_ENABLED = conf.getBoolean("spark.worker.cleanup.enabled", false)
// How often worker will clean up old app folders
@@ -71,8 +89,7 @@ private[spark] class Worker(
// TTL for app folders/data; after TTL expires it will be cleaned up
val APP_DATA_RETENTION_SECS = conf.getLong("spark.worker.cleanup.appDataTtl", 7 * 24 * 3600)
-
- val masterLock: Object = new Object()
+ val testing: Boolean = sys.props.contains("spark.testing")
var master: ActorSelection = null
var masterAddress: Address = null
var activeMasterUrl: String = ""
@@ -81,13 +98,22 @@ private[spark] class Worker(
@volatile var registered = false
@volatile var connected = false
val workerId = generateWorkerId()
- val sparkHome = new File(Option(System.getenv("SPARK_HOME")).getOrElse("."))
+ val sparkHome =
+ if (testing) {
+ assert(sys.props.contains("spark.test.home"), "spark.test.home is not set!")
+ new File(sys.props("spark.test.home"))
+ } else {
+ new File(sys.env.get("SPARK_HOME").getOrElse("."))
+ }
var workDir: File = null
val executors = new HashMap[String, ExecutorRunner]
val finishedExecutors = new HashMap[String, ExecutorRunner]
val drivers = new HashMap[String, DriverRunner]
val finishedDrivers = new HashMap[String, DriverRunner]
+ // The shuffle service is not actually started unless configured.
+ val shuffleService = new StandaloneWorkerShuffleService(conf, securityMgr)
+
val publicAddress = {
val envVar = System.getenv("SPARK_PUBLIC_DNS")
if (envVar != null) envVar else host
@@ -96,6 +122,7 @@ private[spark] class Worker(
var coresUsed = 0
var memoryUsed = 0
+ var connectionAttemptCount = 0
val metricsSystem = MetricsSystem.createMetricsSystem("worker", conf, securityMgr)
val workerSource = new WorkerSource(this)
@@ -130,7 +157,8 @@ private[spark] class Worker(
logInfo("Spark home: " + sparkHome)
createWorkDir()
context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent])
- webUi = new WorkerWebUI(this, workDir, Some(webUiPort))
+ shuffleService.startIfEnabled()
+ webUi = new WorkerWebUI(this, workDir, webUiPort)
webUi.bind()
registerWithMaster()
@@ -139,21 +167,19 @@ private[spark] class Worker(
}
def changeMaster(url: String, uiUrl: String) {
- masterLock.synchronized {
- activeMasterUrl = url
- activeMasterWebUiUrl = uiUrl
- master = context.actorSelection(Master.toAkkaUrl(activeMasterUrl))
- masterAddress = activeMasterUrl match {
- case Master.sparkUrlRegex(_host, _port) =>
- Address("akka.tcp", Master.systemName, _host, _port.toInt)
- case x =>
- throw new SparkException("Invalid spark URL: " + x)
- }
- connected = true
+ activeMasterUrl = url
+ activeMasterWebUiUrl = uiUrl
+ master = context.actorSelection(Master.toAkkaUrl(activeMasterUrl))
+ masterAddress = activeMasterUrl match {
+ case Master.sparkUrlRegex(_host, _port) =>
+ Address("akka.tcp", Master.systemName, _host, _port.toInt)
+ case x =>
+ throw new SparkException("Invalid spark URL: " + x)
}
+ connected = true
}
- def tryRegisterAllMasters() {
+ private def tryRegisterAllMasters() {
for (masterUrl <- masterUrls) {
logInfo("Connecting to master " + masterUrl + "...")
val actor = context.actorSelection(Master.toAkkaUrl(masterUrl))
@@ -161,49 +187,82 @@ private[spark] class Worker(
}
}
- def registerWithMaster() {
- tryRegisterAllMasters()
- var retries = 0
- registrationRetryTimer = Some {
- context.system.scheduler.schedule(REGISTRATION_TIMEOUT, REGISTRATION_TIMEOUT) {
- Utils.tryOrExit {
- retries += 1
- if (registered) {
- registrationRetryTimer.foreach(_.cancel())
- } else if (retries >= REGISTRATION_RETRIES) {
- logError("All masters are unresponsive! Giving up.")
- System.exit(1)
- } else {
- tryRegisterAllMasters()
+ private def retryConnectToMaster() {
+ Utils.tryOrExit {
+ connectionAttemptCount += 1
+ if (registered) {
+ registrationRetryTimer.foreach(_.cancel())
+ registrationRetryTimer = None
+ } else if (connectionAttemptCount <= TOTAL_REGISTRATION_RETRIES) {
+ logInfo(s"Retrying connection to master (attempt # $connectionAttemptCount)")
+ tryRegisterAllMasters()
+ if (connectionAttemptCount == INITIAL_REGISTRATION_RETRIES) {
+ registrationRetryTimer.foreach(_.cancel())
+ registrationRetryTimer = Some {
+ context.system.scheduler.schedule(PROLONGED_REGISTRATION_RETRY_INTERVAL,
+ PROLONGED_REGISTRATION_RETRY_INTERVAL)(retryConnectToMaster)
}
}
+ } else {
+ logError("All masters are unresponsive! Giving up.")
+ System.exit(1)
}
}
}
- override def receive = {
+ def registerWithMaster() {
+ // DisassociatedEvent may be triggered multiple times, so don't attempt registration
+ // if there are outstanding registration attempts scheduled.
+ registrationRetryTimer match {
+ case None =>
+ registered = false
+ tryRegisterAllMasters()
+ connectionAttemptCount = 0
+ registrationRetryTimer = Some {
+ context.system.scheduler.schedule(INITIAL_REGISTRATION_RETRY_INTERVAL,
+ INITIAL_REGISTRATION_RETRY_INTERVAL)(retryConnectToMaster)
+ }
+ case Some(_) =>
+ logInfo("Not spawning another attempt to register with the master, since there is an" +
+ " attempt scheduled already.")
+ }
+ }
+
+ override def receiveWithLogging = {
case RegisteredWorker(masterUrl, masterWebUiUrl) =>
logInfo("Successfully registered with master " + masterUrl)
registered = true
changeMaster(masterUrl, masterWebUiUrl)
context.system.scheduler.schedule(0 millis, HEARTBEAT_MILLIS millis, self, SendHeartbeat)
if (CLEANUP_ENABLED) {
+ logInfo(s"Worker cleanup enabled; old application directories will be deleted in: $workDir")
context.system.scheduler.schedule(CLEANUP_INTERVAL_MILLIS millis,
CLEANUP_INTERVAL_MILLIS millis, self, WorkDirCleanup)
}
case SendHeartbeat =>
- masterLock.synchronized {
- if (connected) { master ! Heartbeat(workerId) }
- }
+ if (connected) { master ! Heartbeat(workerId) }
case WorkDirCleanup =>
// Spin up a separate thread (in a future) to do the dir cleanup; don't tie up worker actor
val cleanupFuture = concurrent.future {
- logInfo("Cleaning up oldest application directories in " + workDir + " ...")
- Utils.findOldFiles(workDir, APP_DATA_RETENTION_SECS)
- .foreach(Utils.deleteRecursively)
+ val appDirs = workDir.listFiles()
+ if (appDirs == null) {
+ throw new IOException("ERROR: Failed to list files in " + appDirs)
+ }
+ appDirs.filter { dir =>
+ // The directory is used by an application; check that the application is not running
+ // when cleaning up.
+ val appIdFromDir = dir.getName
+ val isAppStillRunning = executors.values.map(_.appId).contains(appIdFromDir)
+ dir.isDirectory && !isAppStillRunning &&
+ !Utils.doesDirectoryContainAnyNewFiles(dir, APP_DATA_RETENTION_SECS)
+ }.foreach { dir =>
+ logInfo(s"Removing directory: ${dir.getPath}")
+ Utils.deleteRecursively(dir)
+ }
}
+
cleanupFuture onFailure {
case e: Throwable =>
logError("App dir cleanup failed: " + e.getMessage, e)
@@ -226,45 +285,49 @@ private[spark] class Worker(
System.exit(1)
}
+ case ReconnectWorker(masterUrl) =>
+ logInfo(s"Master with url $masterUrl requested this worker to reconnect.")
+ registerWithMaster()
+
case LaunchExecutor(masterUrl, appId, execId, appDesc, cores_, memory_) =>
if (masterUrl != activeMasterUrl) {
logWarning("Invalid Master (" + masterUrl + ") attempted to launch executor.")
} else {
try {
logInfo("Asked to launch executor %s/%d for %s".format(appId, execId, appDesc.name))
+
+ // Create the executor's working directory
+ val executorDir = new File(workDir, appId + "/" + execId)
+ if (!executorDir.mkdirs()) {
+ throw new IOException("Failed to create directory " + executorDir)
+ }
+
val manager = new ExecutorRunner(appId, execId, appDesc, cores_, memory_,
- self, workerId, host,
- appDesc.sparkHome.map(userSparkHome => new File(userSparkHome)).getOrElse(sparkHome),
- workDir, akkaUrl, conf, ExecutorState.RUNNING)
+ self, workerId, host, sparkHome, executorDir, akkaUrl, conf, ExecutorState.LOADING)
executors(appId + "/" + execId) = manager
manager.start()
coresUsed += cores_
memoryUsed += memory_
- masterLock.synchronized {
- master ! ExecutorStateChanged(appId, execId, manager.state, None, None)
- }
+ master ! ExecutorStateChanged(appId, execId, manager.state, None, None)
} catch {
case e: Exception => {
- logError("Failed to launch executor %s/%d for %s".format(appId, execId, appDesc.name))
+ logError(s"Failed to launch executor $appId/$execId for ${appDesc.name}.", e)
if (executors.contains(appId + "/" + execId)) {
executors(appId + "/" + execId).kill()
executors -= appId + "/" + execId
}
- masterLock.synchronized {
- master ! ExecutorStateChanged(appId, execId, ExecutorState.FAILED, None, None)
- }
+ master ! ExecutorStateChanged(appId, execId, ExecutorState.FAILED,
+ Some(e.toString), None)
}
}
}
case ExecutorStateChanged(appId, execId, state, message, exitStatus) =>
- masterLock.synchronized {
- master ! ExecutorStateChanged(appId, execId, state, message, exitStatus)
- }
+ master ! ExecutorStateChanged(appId, execId, state, message, exitStatus)
val fullId = appId + "/" + execId
if (ExecutorState.isFinished(state)) {
executors.get(fullId) match {
- case Some(executor) =>
+ case Some(executor) =>
logInfo("Executor " + fullId + " finished with state " + state +
message.map(" message " + _).getOrElse("") +
exitStatus.map(" exitStatus " + _).getOrElse(""))
@@ -295,7 +358,7 @@ private[spark] class Worker(
case LaunchDriver(driverId, driverDesc) => {
logInfo(s"Asked to launch driver $driverId")
- val driver = new DriverRunner(driverId, workDir, sparkHome, driverDesc, self, akkaUrl)
+ val driver = new DriverRunner(conf, driverId, workDir, sparkHome, driverDesc, self, akkaUrl)
drivers(driverId) = driver
driver.start()
@@ -326,9 +389,7 @@ private[spark] class Worker(
case _ =>
logDebug(s"Driver $driverId changed state to $state")
}
- masterLock.synchronized {
- master ! DriverStateChanged(driverId, state, exception)
- }
+ master ! DriverStateChanged(driverId, state, exception)
val driver = drivers.remove(driverId).get
finishedDrivers(driverId) = driver
memoryUsed -= driver.driverDesc.mem
@@ -347,9 +408,10 @@ private[spark] class Worker(
}
}
- def masterDisconnected() {
+ private def masterDisconnected() {
logError("Connection to master failed! Waiting for master to reconnect...")
connected = false
+ registerWithMaster()
}
def generateWorkerId(): String = {
@@ -357,9 +419,11 @@ private[spark] class Worker(
}
override def postStop() {
+ metricsSystem.report()
registrationRetryTimer.foreach(_.cancel())
executors.values.foreach(_.kill())
drivers.values.foreach(_.kill())
+ shuffleService.stop()
webUi.stop()
metricsSystem.stop()
}
@@ -368,7 +432,8 @@ private[spark] class Worker(
private[spark] object Worker extends Logging {
def main(argStrings: Array[String]) {
SignalLogger.register(log)
- val args = new WorkerArguments(argStrings)
+ val conf = new SparkConf
+ val args = new WorkerArguments(argStrings, conf)
val (actorSystem, _) = startSystemAndActor(args.host, args.port, args.webUiPort, args.cores,
args.memory, args.masters, args.workDir)
actorSystem.awaitTermination()
@@ -381,7 +446,8 @@ private[spark] object Worker extends Logging {
cores: Int,
memory: Int,
masterUrls: Array[String],
- workDir: String, workerNumber: Option[Int] = None): (ActorSystem, Int) = {
+ workDir: String,
+ workerNumber: Option[Int] = None): (ActorSystem, Int) = {
// The LocalSparkCluster runs multiple local sparkWorkerX actor systems
val conf = new SparkConf
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala
index dc5158102054e..019cd70f2a229 100644
--- a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala
@@ -20,11 +20,12 @@ package org.apache.spark.deploy.worker
import java.lang.management.ManagementFactory
import org.apache.spark.util.{IntParam, MemoryParam, Utils}
+import org.apache.spark.SparkConf
/**
* Command-line parser for the worker.
*/
-private[spark] class WorkerArguments(args: Array[String]) {
+private[spark] class WorkerArguments(args: Array[String], conf: SparkConf) {
var host = Utils.localHostName()
var port = 0
var webUiPort = 8081
@@ -32,6 +33,7 @@ private[spark] class WorkerArguments(args: Array[String]) {
var memory = inferDefaultMemory()
var masters: Array[String] = null
var workDir: String = null
+ var propertiesFile: String = null
// Check for settings in environment variables
if (System.getenv("SPARK_WORKER_PORT") != null) {
@@ -40,8 +42,8 @@ private[spark] class WorkerArguments(args: Array[String]) {
if (System.getenv("SPARK_WORKER_CORES") != null) {
cores = System.getenv("SPARK_WORKER_CORES").toInt
}
- if (System.getenv("SPARK_WORKER_MEMORY") != null) {
- memory = Utils.memoryStringToMb(System.getenv("SPARK_WORKER_MEMORY"))
+ if (conf.getenv("SPARK_WORKER_MEMORY") != null) {
+ memory = Utils.memoryStringToMb(conf.getenv("SPARK_WORKER_MEMORY"))
}
if (System.getenv("SPARK_WORKER_WEBUI_PORT") != null) {
webUiPort = System.getenv("SPARK_WORKER_WEBUI_PORT").toInt
@@ -52,6 +54,15 @@ private[spark] class WorkerArguments(args: Array[String]) {
parse(args.toList)
+ // This mutates the SparkConf, so all accesses to it must be made after this line
+ propertiesFile = Utils.loadDefaultSparkProperties(conf, propertiesFile)
+
+ if (conf.contains("spark.worker.ui.port")) {
+ webUiPort = conf.get("spark.worker.ui.port").toInt
+ }
+
+ checkWorkerMemory()
+
def parse(args: List[String]): Unit = args match {
case ("--ip" | "-i") :: value :: tail =>
Utils.checkHost(value, "ip no longer supported, please use hostname " + value)
@@ -83,7 +94,11 @@ private[spark] class WorkerArguments(args: Array[String]) {
webUiPort = value
parse(tail)
- case ("--help" | "-h") :: tail =>
+ case ("--properties-file") :: value :: tail =>
+ propertiesFile = value
+ parse(tail)
+
+ case ("--help") :: tail =>
printUsageAndExit(0)
case value :: tail =>
@@ -118,7 +133,9 @@ private[spark] class WorkerArguments(args: Array[String]) {
" -i HOST, --ip IP Hostname to listen on (deprecated, please use --host or -h)\n" +
" -h HOST, --host HOST Hostname to listen on\n" +
" -p PORT, --port PORT Port to listen on (default: random)\n" +
- " --webui-port PORT Port for web UI (default: 8081)")
+ " --webui-port PORT Port for web UI (default: 8081)\n" +
+ " --properties-file FILE Path to a custom Spark properties file.\n" +
+ " Default is conf/spark-defaults.conf.")
System.exit(exitCode)
}
@@ -149,4 +166,11 @@ private[spark] class WorkerArguments(args: Array[String]) {
// Leave out 1 GB for the operating system, but don't return a negative memory size
math.max(totalMb - 1024, 512)
}
+
+ def checkWorkerMemory(): Unit = {
+ if (memory <= 0) {
+ val message = "Memory can't be 0, missing a M or G on the end of the memory specification?"
+ throw new IllegalStateException(message)
+ }
+ }
}
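A likely trigger for the new checkWorkerMemory guard is a SPARK_WORKER_MEMORY value given without a unit suffix: Utils.memoryStringToMb, used above to parse it, treats a bare number as a byte count and rounds small values down to 0 MB. A short illustration, with return values shown as comments (the helper is private[spark], so this assumes code inside Spark).
{{{
import org.apache.spark.util.Utils

Utils.memoryStringToMb("4g")    // 4096 -- gigabytes converted to megabytes
Utils.memoryStringToMb("512m")  // 512
Utils.memoryStringToMb("512")   // 0 -- no suffix: interpreted as bytes, hence the guard above
}}}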
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerSource.scala b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerSource.scala
index b7ddd8c816cbc..df1e01b23b932 100644
--- a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerSource.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerSource.scala
@@ -22,8 +22,8 @@ import com.codahale.metrics.{Gauge, MetricRegistry}
import org.apache.spark.metrics.source.Source
private[spark] class WorkerSource(val worker: Worker) extends Source {
- val sourceName = "worker"
- val metricRegistry = new MetricRegistry()
+ override val sourceName = "worker"
+ override val metricRegistry = new MetricRegistry()
metricRegistry.register(MetricRegistry.name("executors"), new Gauge[Int] {
override def getValue: Int = worker.executors.size
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala
index 530c147000904..63a8ac817b618 100644
--- a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala
@@ -22,13 +22,15 @@ import akka.remote.{AssociatedEvent, AssociationErrorEvent, AssociationEvent, Di
import org.apache.spark.Logging
import org.apache.spark.deploy.DeployMessages.SendHeartbeat
+import org.apache.spark.util.ActorLogReceive
/**
* Actor which connects to a worker process and terminates the JVM if the connection is severed.
* Provides fate sharing between a worker and its associated child processes.
*/
-private[spark] class WorkerWatcher(workerUrl: String) extends Actor
- with Logging {
+private[spark] class WorkerWatcher(workerUrl: String)
+ extends Actor with ActorLogReceive with Logging {
+
override def preStart() {
context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent])
@@ -48,11 +50,11 @@ private[spark] class WorkerWatcher(workerUrl: String) extends Actor
def exitNonZero() = if (isTesting) isShutDown = true else System.exit(-1)
- override def receive = {
+ override def receiveWithLogging = {
case AssociatedEvent(localAddress, remoteAddress, inbound) if isWorker(remoteAddress) =>
logInfo(s"Successfully connected to $workerUrl")
- case AssociationErrorEvent(cause, localAddress, remoteAddress, inbound)
+ case AssociationErrorEvent(cause, localAddress, remoteAddress, inbound, _)
if isWorker(remoteAddress) =>
// These logs may not be seen if the worker (and associated pipe) has died
logError(s"Could not initialize connection to worker $workerUrl. Exiting.")
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala
index b389cb546de6c..ecb358c399819 100644
--- a/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala
@@ -17,7 +17,6 @@
package org.apache.spark.deploy.worker.ui
-import java.io.File
import javax.servlet.http.HttpServletRequest
import scala.xml.Node
@@ -25,7 +24,7 @@ import scala.xml.Node
import org.apache.spark.ui.{WebUIPage, UIUtils}
import org.apache.spark.util.Utils
import org.apache.spark.Logging
-import org.apache.spark.util.logging.{FileAppender, RollingFileAppender}
+import org.apache.spark.util.logging.RollingFileAppender
private[spark] class LogPage(parent: WorkerWebUI) extends WebUIPage("logPage") with Logging {
private val worker = parent.worker
@@ -64,11 +63,11 @@ private[spark] class LogPage(parent: WorkerWebUI) extends WebUIPage("logPage") w
val offset = Option(request.getParameter("offset")).map(_.toLong)
val byteLength = Option(request.getParameter("byteLength")).map(_.toInt).getOrElse(defaultBytes)
- val (logDir, params) = (appId, executorId, driverId) match {
+ val (logDir, params, pageName) = (appId, executorId, driverId) match {
case (Some(a), Some(e), None) =>
- (s"${workDir.getPath}/$a/$e/", s"appId=$a&executorId=$e")
+ (s"${workDir.getPath}/$a/$e/", s"appId=$a&executorId=$e", s"$a/$e")
case (None, None, Some(d)) =>
- (s"${workDir.getPath}/$d/", s"driverId=$d")
+ (s"${workDir.getPath}/$d/", s"driverId=$d", d)
case _ =>
throw new Exception("Request must specify either application or driver identifiers")
}
@@ -120,7 +119,7 @@ private[spark] class LogPage(parent: WorkerWebUI) extends WebUIPage("logPage") w