diff --git a/bin/spark-class b/bin/spark-class index 0d58d95c1aee..79af42c72c76 100755 --- a/bin/spark-class +++ b/bin/spark-class @@ -148,7 +148,7 @@ fi if [[ "$1" =~ org.apache.spark.tools.* ]]; then if test -z "$SPARK_TOOLS_JAR"; then echo "Failed to find Spark Tools Jar in $FWDIR/tools/target/scala-$SPARK_SCALA_VERSION/" 1>&2 - echo "You need to build Spark before running $1." 1>&2 + echo "You need to run \"build/sbt tools/package\" before running $1." 1>&2 exit 1 fi CLASSPATH="$CLASSPATH:$SPARK_TOOLS_JAR" diff --git a/core/src/main/java/org/apache/spark/JavaSparkListener.java b/core/src/main/java/org/apache/spark/JavaSparkListener.java new file mode 100644 index 000000000000..646496f31350 --- /dev/null +++ b/core/src/main/java/org/apache/spark/JavaSparkListener.java @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark; + +import org.apache.spark.scheduler.SparkListener; +import org.apache.spark.scheduler.SparkListenerApplicationEnd; +import org.apache.spark.scheduler.SparkListenerApplicationStart; +import org.apache.spark.scheduler.SparkListenerBlockManagerAdded; +import org.apache.spark.scheduler.SparkListenerBlockManagerRemoved; +import org.apache.spark.scheduler.SparkListenerEnvironmentUpdate; +import org.apache.spark.scheduler.SparkListenerExecutorAdded; +import org.apache.spark.scheduler.SparkListenerExecutorMetricsUpdate; +import org.apache.spark.scheduler.SparkListenerExecutorRemoved; +import org.apache.spark.scheduler.SparkListenerJobEnd; +import org.apache.spark.scheduler.SparkListenerJobStart; +import org.apache.spark.scheduler.SparkListenerStageCompleted; +import org.apache.spark.scheduler.SparkListenerStageSubmitted; +import org.apache.spark.scheduler.SparkListenerTaskEnd; +import org.apache.spark.scheduler.SparkListenerTaskGettingResult; +import org.apache.spark.scheduler.SparkListenerTaskStart; +import org.apache.spark.scheduler.SparkListenerUnpersistRDD; + +/** + * Java clients should extend this class instead of implementing + * SparkListener directly. This is to prevent java clients + * from breaking when new events are added to the SparkListener + * trait. + * + * This is a concrete class instead of abstract to enforce + * new events get added to both the SparkListener and this adapter + * in lockstep. + */ +public class JavaSparkListener implements SparkListener { + + @Override + public void onStageCompleted(SparkListenerStageCompleted stageCompleted) { } + + @Override + public void onStageSubmitted(SparkListenerStageSubmitted stageSubmitted) { } + + @Override + public void onTaskStart(SparkListenerTaskStart taskStart) { } + + @Override + public void onTaskGettingResult(SparkListenerTaskGettingResult taskGettingResult) { } + + @Override + public void onTaskEnd(SparkListenerTaskEnd taskEnd) { } + + @Override + public void onJobStart(SparkListenerJobStart jobStart) { } + + @Override + public void onJobEnd(SparkListenerJobEnd jobEnd) { } + + @Override + public void onEnvironmentUpdate(SparkListenerEnvironmentUpdate environmentUpdate) { } + + @Override + public void onBlockManagerAdded(SparkListenerBlockManagerAdded blockManagerAdded) { } + + @Override + public void onBlockManagerRemoved(SparkListenerBlockManagerRemoved blockManagerRemoved) { } + + @Override + public void onUnpersistRDD(SparkListenerUnpersistRDD unpersistRDD) { } + + @Override + public void onApplicationStart(SparkListenerApplicationStart applicationStart) { } + + @Override + public void onApplicationEnd(SparkListenerApplicationEnd applicationEnd) { } + + @Override + public void onExecutorMetricsUpdate(SparkListenerExecutorMetricsUpdate executorMetricsUpdate) { } + + @Override + public void onExecutorAdded(SparkListenerExecutorAdded executorAdded) { } + + @Override + public void onExecutorRemoved(SparkListenerExecutorRemoved executorRemoved) { } +} diff --git a/core/src/main/java/org/apache/spark/TaskContext.java b/core/src/main/java/org/apache/spark/TaskContext.java index 0d6973203eba..095f9fb94fdf 100644 --- a/core/src/main/java/org/apache/spark/TaskContext.java +++ b/core/src/main/java/org/apache/spark/TaskContext.java @@ -62,7 +62,7 @@ static void unset() { */ public abstract boolean isInterrupted(); - /** @deprecated: use isRunningLocally() */ + /** @deprecated use {@link #isRunningLocally()} */ @Deprecated public abstract boolean runningLocally(); @@ -87,19 +87,39 @@ static void unset() { * is for HadoopRDD to register a callback to close the input stream. * Will be called in any situation - success, failure, or cancellation. * - * @deprecated: use addTaskCompletionListener + * @deprecated use {@link #addTaskCompletionListener(scala.Function1)} * * @param f Callback function. */ @Deprecated public abstract void addOnCompleteCallback(final Function0 f); + /** + * The ID of the stage that this task belong to. + */ public abstract int stageId(); + /** + * The ID of the RDD partition that is computed by this task. + */ public abstract int partitionId(); + /** + * How many times this task has been attempted. The first task attempt will be assigned + * attemptNumber = 0, and subsequent attempts will have increasing attempt numbers. + */ + public abstract int attemptNumber(); + + /** @deprecated use {@link #taskAttemptId()}; it was renamed to avoid ambiguity. */ + @Deprecated public abstract long attemptId(); + /** + * An ID that is unique to this task attempt (within the same SparkContext, no two task attempts + * will share the same attempt ID). This is roughly equivalent to Hadoop's TaskAttemptID. + */ + public abstract long taskAttemptId(); + /** ::DeveloperApi:: */ @DeveloperApi public abstract TaskMetrics taskMetrics(); diff --git a/core/src/main/resources/org/apache/spark/ui/static/webui.css b/core/src/main/resources/org/apache/spark/ui/static/webui.css index 5751964b792c..f02b035a980b 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/webui.css +++ b/core/src/main/resources/org/apache/spark/ui/static/webui.css @@ -19,6 +19,7 @@ height: 50px; font-size: 15px; margin-bottom: 15px; + min-width: 1200px } .navbar .navbar-inner { @@ -39,12 +40,12 @@ .navbar .nav > li a { height: 30px; - line-height: 30px; + line-height: 2; } .navbar-text { height: 50px; - line-height: 50px; + line-height: 3.3; } table.sortable thead { @@ -170,7 +171,7 @@ span.additional-metric-title { } .version { - line-height: 30px; + line-height: 2.5; vertical-align: bottom; font-size: 12px; padding: 0; diff --git a/core/src/main/scala/org/apache/spark/CacheManager.scala b/core/src/main/scala/org/apache/spark/CacheManager.scala index 80da62c44edc..a0c0372b7f0e 100644 --- a/core/src/main/scala/org/apache/spark/CacheManager.scala +++ b/core/src/main/scala/org/apache/spark/CacheManager.scala @@ -44,7 +44,11 @@ private[spark] class CacheManager(blockManager: BlockManager) extends Logging { blockManager.get(key) match { case Some(blockResult) => // Partition is already materialized, so just return its values - context.taskMetrics.inputMetrics = Some(blockResult.inputMetrics) + val inputMetrics = blockResult.inputMetrics + val existingMetrics = context.taskMetrics + .getInputMetricsForReadMethod(inputMetrics.readMethod) + existingMetrics.addBytesRead(inputMetrics.bytesRead) + new InterruptibleIterator(context, blockResult.data.asInstanceOf[Iterator[T]]) case None => diff --git a/core/src/main/scala/org/apache/spark/HttpFileServer.scala b/core/src/main/scala/org/apache/spark/HttpFileServer.scala index edc3889c9ae5..677c5e0f89d7 100644 --- a/core/src/main/scala/org/apache/spark/HttpFileServer.scala +++ b/core/src/main/scala/org/apache/spark/HttpFileServer.scala @@ -24,6 +24,7 @@ import com.google.common.io.Files import org.apache.spark.util.Utils private[spark] class HttpFileServer( + conf: SparkConf, securityManager: SecurityManager, requestedPort: Int = 0) extends Logging { @@ -41,7 +42,7 @@ private[spark] class HttpFileServer( fileDir.mkdir() jarDir.mkdir() logInfo("HTTP File server directory is " + baseDir) - httpServer = new HttpServer(baseDir, securityManager, requestedPort, "HTTP file server") + httpServer = new HttpServer(conf, baseDir, securityManager, requestedPort, "HTTP file server") httpServer.start() serverUri = httpServer.uri logDebug("HTTP file server started at: " + serverUri) diff --git a/core/src/main/scala/org/apache/spark/HttpServer.scala b/core/src/main/scala/org/apache/spark/HttpServer.scala index 912558d0cab7..fa22787ce7ea 100644 --- a/core/src/main/scala/org/apache/spark/HttpServer.scala +++ b/core/src/main/scala/org/apache/spark/HttpServer.scala @@ -42,6 +42,7 @@ private[spark] class ServerStateException(message: String) extends Exception(mes * around a Jetty server. */ private[spark] class HttpServer( + conf: SparkConf, resourceBase: File, securityManager: SecurityManager, requestedPort: Int = 0, @@ -57,7 +58,7 @@ private[spark] class HttpServer( } else { logInfo("Starting HTTP Server") val (actualServer, actualPort) = - Utils.startServiceOnPort[Server](requestedPort, doStart, serverName) + Utils.startServiceOnPort[Server](requestedPort, doStart, conf, serverName) server = actualServer port = actualPort } diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala index c14764f77398..a0ce107f43b1 100644 --- a/core/src/main/scala/org/apache/spark/SparkConf.scala +++ b/core/src/main/scala/org/apache/spark/SparkConf.scala @@ -370,7 +370,9 @@ private[spark] object SparkConf { } /** - * Return whether the given config is a Spark port config. + * Return true if the given config matches either `spark.*.port` or `spark.port.*`. */ - def isSparkPortConf(name: String): Boolean = name.startsWith("spark.") && name.endsWith(".port") + def isSparkPortConf(name: String): Boolean = { + (name.startsWith("spark.") && name.endsWith(".port")) || name.startsWith("spark.port.") + } } diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index ff5d796ee276..6a354ed4d148 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -520,10 +520,9 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli /** Distribute a local Scala collection to form an RDD. * - * @note Parallelize acts lazily. If `seq` is a mutable collection and is - * altered after the call to parallelize and before the first action on the - * RDD, the resultant RDD will reflect the modified collection. Pass a copy of - * the argument to avoid this. + * @note Parallelize acts lazily. If `seq` is a mutable collection and is altered after the call + * to parallelize and before the first action on the RDD, the resultant RDD will reflect the + * modified collection. Pass a copy of the argument to avoid this. */ def parallelize[T: ClassTag](seq: Seq[T], numSlices: Int = defaultParallelism): RDD[T] = { new ParallelCollectionRDD[T](this, seq, numSlices, Map[Int, Seq[String]]()) diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index 43436a169700..4d418037bd33 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -312,7 +312,7 @@ object SparkEnv extends Logging { val httpFileServer = if (isDriver) { val fileServerPort = conf.getInt("spark.fileserver.port", 0) - val server = new HttpFileServer(securityManager, fileServerPort) + val server = new HttpFileServer(conf, securityManager, fileServerPort) server.initialize() conf.set("spark.fileserver.uri", server.serverUri) server diff --git a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala index afd2b85d33a7..9bb0c61e441f 100644 --- a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala +++ b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala @@ -22,14 +22,19 @@ import org.apache.spark.util.{TaskCompletionListener, TaskCompletionListenerExce import scala.collection.mutable.ArrayBuffer -private[spark] class TaskContextImpl(val stageId: Int, +private[spark] class TaskContextImpl( + val stageId: Int, val partitionId: Int, - val attemptId: Long, + override val taskAttemptId: Long, + override val attemptNumber: Int, val runningLocally: Boolean = false, val taskMetrics: TaskMetrics = TaskMetrics.empty) extends TaskContext with Logging { + // For backwards-compatibility; this method is now deprecated as of 1.3.0. + override def attemptId: Long = taskAttemptId + // List of callback functions to execute when the task completes. @transient private val onCompleteCallbacks = new ArrayBuffer[TaskCompletionListener] diff --git a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala index 31f0a462f84d..31d6958c403b 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala @@ -153,7 +153,8 @@ private[broadcast] object HttpBroadcast extends Logging { private def createServer(conf: SparkConf) { broadcastDir = Utils.createTempDir(Utils.getLocalDir(conf)) val broadcastPort = conf.getInt("spark.broadcast.port", 0) - server = new HttpServer(broadcastDir, securityManager, broadcastPort, "HTTP broadcast server") + server = + new HttpServer(conf, broadcastDir, securityManager, broadcastPort, "HTTP broadcast server") server.start() serverUri = server.uri logInfo("Broadcast server started at " + serverUri) diff --git a/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala b/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala index 2e1e52906cee..e5873ce724b9 100644 --- a/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala @@ -23,7 +23,7 @@ import scala.collection.mutable.ListBuffer import org.apache.log4j.Level -import org.apache.spark.util.MemoryParam +import org.apache.spark.util.{IntParam, MemoryParam} /** * Command-line parser for the driver client. @@ -51,8 +51,8 @@ private[spark] class ClientArguments(args: Array[String]) { parse(args.toList) def parse(args: List[String]): Unit = args match { - case ("--cores" | "-c") :: value :: tail => - cores = value.toInt + case ("--cores" | "-c") :: IntParam(value) :: tail => + cores = value parse(tail) case ("--memory" | "-m") :: MemoryParam(value) :: tail => diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index 955cbd6dab96..050ba91eb2bc 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -200,6 +200,7 @@ object SparkSubmit { // Yarn cluster only OptionAssigner(args.name, YARN, CLUSTER, clOption = "--name"), OptionAssigner(args.driverMemory, YARN, CLUSTER, clOption = "--driver-memory"), + OptionAssigner(args.driverCores, YARN, CLUSTER, clOption = "--driver-cores"), OptionAssigner(args.queue, YARN, CLUSTER, clOption = "--queue"), OptionAssigner(args.numExecutors, YARN, CLUSTER, clOption = "--num-executors"), OptionAssigner(args.executorMemory, YARN, CLUSTER, clOption = "--executor-memory"), diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala index f14ef4d29938..81ec08cb6d50 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala @@ -108,6 +108,9 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St .orElse(sparkProperties.get("spark.driver.memory")) .orElse(env.get("SPARK_DRIVER_MEMORY")) .orNull + driverCores = Option(driverCores) + .orElse(sparkProperties.get("spark.driver.cores")) + .orNull executorMemory = Option(executorMemory) .orElse(sparkProperties.get("spark.executor.memory")) .orElse(env.get("SPARK_EXECUTOR_MEMORY")) @@ -149,6 +152,11 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St // Global defaults. These should be keep to minimum to avoid confusing behavior. master = Option(master).getOrElse("local[*]") + // In YARN mode, app name can be set via SPARK_YARN_APP_NAME (see SPARK-5222) + if (master.startsWith("yarn")) { + name = Option(name).orElse(env.get("SPARK_YARN_APP_NAME")).orNull + } + // Set name from main class if not given name = Option(name).orElse(Option(mainClass)).orNull if (name == null && primaryResource != null) { @@ -401,6 +409,8 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St | --total-executor-cores NUM Total cores for all executors. | | YARN-only: + | --driver-cores NUM Number of cores used by the driver, only in cluster mode + | (Default: 1). | --executor-cores NUM Number of cores per executor (Default: 1). | --queue QUEUE_NAME The YARN queue to submit to (Default: "default"). | --num-executors NUM Number of executors to launch (Default: 2). diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala b/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala index ad7d81747c37..ede0a9dbefb8 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala @@ -38,8 +38,8 @@ private[spark] class ApplicationInfo( extends Serializable { @transient var state: ApplicationState.Value = _ - @transient var executors: mutable.HashMap[Int, ExecutorInfo] = _ - @transient var removedExecutors: ArrayBuffer[ExecutorInfo] = _ + @transient var executors: mutable.HashMap[Int, ExecutorDesc] = _ + @transient var removedExecutors: ArrayBuffer[ExecutorDesc] = _ @transient var coresGranted: Int = _ @transient var endTime: Long = _ @transient var appSource: ApplicationSource = _ @@ -55,12 +55,12 @@ private[spark] class ApplicationInfo( private def init() { state = ApplicationState.WAITING - executors = new mutable.HashMap[Int, ExecutorInfo] + executors = new mutable.HashMap[Int, ExecutorDesc] coresGranted = 0 endTime = -1L appSource = new ApplicationSource(this) nextExecutorId = 0 - removedExecutors = new ArrayBuffer[ExecutorInfo] + removedExecutors = new ArrayBuffer[ExecutorDesc] } private def newExecutorId(useID: Option[Int] = None): Int = { @@ -75,14 +75,14 @@ private[spark] class ApplicationInfo( } } - def addExecutor(worker: WorkerInfo, cores: Int, useID: Option[Int] = None): ExecutorInfo = { - val exec = new ExecutorInfo(newExecutorId(useID), this, worker, cores, desc.memoryPerSlave) + def addExecutor(worker: WorkerInfo, cores: Int, useID: Option[Int] = None): ExecutorDesc = { + val exec = new ExecutorDesc(newExecutorId(useID), this, worker, cores, desc.memoryPerSlave) executors(exec.id) = exec coresGranted += cores exec } - def removeExecutor(exec: ExecutorInfo) { + def removeExecutor(exec: ExecutorDesc) { if (executors.contains(exec.id)) { removedExecutors += executors(exec.id) executors -= exec.id diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ExecutorInfo.scala b/core/src/main/scala/org/apache/spark/deploy/master/ExecutorDesc.scala similarity index 95% rename from core/src/main/scala/org/apache/spark/deploy/master/ExecutorInfo.scala rename to core/src/main/scala/org/apache/spark/deploy/master/ExecutorDesc.scala index d417070c5101..5d620dfcabad 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ExecutorInfo.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ExecutorDesc.scala @@ -19,7 +19,7 @@ package org.apache.spark.deploy.master import org.apache.spark.deploy.{ExecutorDescription, ExecutorState} -private[spark] class ExecutorInfo( +private[spark] class ExecutorDesc( val id: Int, val application: ApplicationInfo, val worker: WorkerInfo, @@ -37,7 +37,7 @@ private[spark] class ExecutorInfo( override def equals(other: Any): Boolean = { other match { - case info: ExecutorInfo => + case info: ExecutorDesc => fullId == info.fullId && worker.id == info.worker.id && cores == info.cores && diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index 4b631ec63907..d92d99310a58 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -581,7 +581,7 @@ private[spark] class Master( } } - def launchExecutor(worker: WorkerInfo, exec: ExecutorInfo) { + def launchExecutor(worker: WorkerInfo, exec: ExecutorDesc) { logInfo("Launching executor " + exec.fullId + " on worker " + worker.id) worker.addExecutor(exec) worker.actor ! LaunchExecutor(masterUrl, diff --git a/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala b/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala index 473ddc23ff0f..e94aae93e449 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala @@ -38,7 +38,7 @@ private[spark] class WorkerInfo( Utils.checkHost(host, "Expected hostname") assert (port > 0) - @transient var executors: mutable.HashMap[String, ExecutorInfo] = _ // executorId => info + @transient var executors: mutable.HashMap[String, ExecutorDesc] = _ // executorId => info @transient var drivers: mutable.HashMap[String, DriverInfo] = _ // driverId => info @transient var state: WorkerState.Value = _ @transient var coresUsed: Int = _ @@ -70,13 +70,13 @@ private[spark] class WorkerInfo( host + ":" + port } - def addExecutor(exec: ExecutorInfo) { + def addExecutor(exec: ExecutorDesc) { executors(exec.fullId) = exec coresUsed += exec.cores memoryUsed += exec.memory } - def removeExecutor(exec: ExecutorInfo) { + def removeExecutor(exec: ExecutorDesc) { if (executors.contains(exec.fullId)) { executors -= exec.fullId coresUsed -= exec.cores diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala index 4588c130ef43..3aae2b95d739 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala @@ -27,7 +27,7 @@ import org.json4s.JValue import org.apache.spark.deploy.{ExecutorState, JsonProtocol} import org.apache.spark.deploy.DeployMessages.{MasterStateResponse, RequestMasterState} -import org.apache.spark.deploy.master.ExecutorInfo +import org.apache.spark.deploy.master.ExecutorDesc import org.apache.spark.ui.{UIUtils, WebUIPage} import org.apache.spark.util.Utils @@ -109,7 +109,7 @@ private[spark] class ApplicationPage(parent: MasterWebUI) extends WebUIPage("app UIUtils.basicSparkPage(content, "Application: " + app.desc.name) } - private def executorRow(executor: ExecutorInfo): Seq[Node] = { + private def executorRow(executor: ExecutorDesc): Seq[Node] = { {executor.id} diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala index c794a7bc3599..9a4adfbbb3d7 100644 --- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala @@ -71,7 +71,8 @@ private[spark] class CoarseGrainedExecutorBackend( val ser = env.closureSerializer.newInstance() val taskDesc = ser.deserialize[TaskDescription](data.value) logInfo("Got assigned task " + taskDesc.taskId) - executor.launchTask(this, taskDesc.taskId, taskDesc.name, taskDesc.serializedTask) + executor.launchTask(this, taskId = taskDesc.taskId, attemptNumber = taskDesc.attemptNumber, + taskDesc.name, taskDesc.serializedTask) } case KillTask(taskId, _, interruptThread) => diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index 0f99cd9f3b08..6660b98eb8ce 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -108,8 +108,13 @@ private[spark] class Executor( startDriverHeartbeater() def launchTask( - context: ExecutorBackend, taskId: Long, taskName: String, serializedTask: ByteBuffer) { - val tr = new TaskRunner(context, taskId, taskName, serializedTask) + context: ExecutorBackend, + taskId: Long, + attemptNumber: Int, + taskName: String, + serializedTask: ByteBuffer) { + val tr = new TaskRunner(context, taskId = taskId, attemptNumber = attemptNumber, taskName, + serializedTask) runningTasks.put(taskId, tr) threadPool.execute(tr) } @@ -134,7 +139,11 @@ private[spark] class Executor( private def gcTime = ManagementFactory.getGarbageCollectorMXBeans.map(_.getCollectionTime).sum class TaskRunner( - execBackend: ExecutorBackend, val taskId: Long, taskName: String, serializedTask: ByteBuffer) + execBackend: ExecutorBackend, + val taskId: Long, + val attemptNumber: Int, + taskName: String, + serializedTask: ByteBuffer) extends Runnable { @volatile private var killed = false @@ -180,7 +189,7 @@ private[spark] class Executor( // Run the actual task and measure its runtime. taskStart = System.currentTimeMillis() - val value = task.run(taskId.toInt) + val value = task.run(taskAttemptId = taskId, attemptNumber = attemptNumber) val taskFinish = System.currentTimeMillis() // If the task has been killed, let's fail it. @@ -370,6 +379,7 @@ private[spark] class Executor( if (!taskRunner.attemptedTask.isEmpty) { Option(taskRunner.task).flatMap(_.metrics).foreach { metrics => metrics.updateShuffleReadMetrics + metrics.updateInputMetrics() metrics.jvmGCTime = curGCTime - taskRunner.startGCTime if (isLocal) { // JobProgressListener will hold an reference of it during diff --git a/core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala index 2e23ae0a4f83..cfd672e1d8a9 100644 --- a/core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala @@ -28,6 +28,7 @@ import org.apache.mesos.Protos.{TaskStatus => MesosTaskStatus, _} import org.apache.spark.{Logging, TaskState, SparkConf, SparkEnv} import org.apache.spark.TaskState.TaskState import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.scheduler.cluster.mesos.{MesosTaskLaunchData} import org.apache.spark.util.{SignalLogger, Utils} private[spark] class MesosExecutorBackend @@ -77,11 +78,13 @@ private[spark] class MesosExecutorBackend override def launchTask(d: ExecutorDriver, taskInfo: TaskInfo) { val taskId = taskInfo.getTaskId.getValue.toLong + val taskData = MesosTaskLaunchData.fromByteString(taskInfo.getData) if (executor == null) { logError("Received launchTask but executor was null") } else { SparkHadoopUtil.get.runAsSparkUser { () => - executor.launchTask(this, taskId, taskInfo.getName, taskInfo.getData.asReadOnlyByteBuffer) + executor.launchTask(this, taskId = taskId, attemptNumber = taskData.attemptNumber, + taskInfo.getName, taskData.serializedTask) } } } diff --git a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala index 51b5328cb4c8..7eb10f95e023 100644 --- a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala +++ b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala @@ -17,6 +17,11 @@ package org.apache.spark.executor +import java.util.concurrent.atomic.AtomicLong + +import org.apache.spark.executor.DataReadMethod +import org.apache.spark.executor.DataReadMethod.DataReadMethod + import scala.collection.mutable.ArrayBuffer import org.apache.spark.annotation.DeveloperApi @@ -80,7 +85,17 @@ class TaskMetrics extends Serializable { * If this task reads from a HadoopRDD or from persisted data, metrics on how much data was read * are stored here. */ - var inputMetrics: Option[InputMetrics] = None + private var _inputMetrics: Option[InputMetrics] = None + + def inputMetrics = _inputMetrics + + /** + * This should only be used when recreating TaskMetrics, not when updating input metrics in + * executors + */ + private[spark] def setInputMetrics(inputMetrics: Option[InputMetrics]) { + _inputMetrics = inputMetrics + } /** * If this task writes data externally (e.g. to a distributed filesystem), metrics on how much @@ -133,6 +148,30 @@ class TaskMetrics extends Serializable { readMetrics } + /** + * Returns the input metrics object that the task should use. Currently, if + * there exists an input metric with the same readMethod, we return that one + * so the caller can accumulate bytes read. If the readMethod is different + * than previously seen by this task, we return a new InputMetric but don't + * record it. + * + * Once https://issues.apache.org/jira/browse/SPARK-5225 is addressed, + * we can store all the different inputMetrics (one per readMethod). + */ + private[spark] def getInputMetricsForReadMethod(readMethod: DataReadMethod): + InputMetrics =synchronized { + _inputMetrics match { + case None => + val metrics = new InputMetrics(readMethod) + _inputMetrics = Some(metrics) + metrics + case Some(metrics @ InputMetrics(method)) if method == readMethod => + metrics + case Some(InputMetrics(method)) => + new InputMetrics(readMethod) + } + } + /** * Aggregates shuffle read metrics for all registered dependencies into shuffleReadMetrics. */ @@ -146,6 +185,10 @@ class TaskMetrics extends Serializable { } _shuffleReadMetrics = Some(merged) } + + private[spark] def updateInputMetrics() = synchronized { + inputMetrics.foreach(_.updateBytesRead()) + } } private[spark] object TaskMetrics { @@ -179,10 +222,38 @@ object DataWriteMethod extends Enumeration with Serializable { */ @DeveloperApi case class InputMetrics(readMethod: DataReadMethod.Value) { + + private val _bytesRead: AtomicLong = new AtomicLong() + /** * Total bytes read. */ - var bytesRead: Long = 0L + def bytesRead: Long = _bytesRead.get() + @volatile @transient var bytesReadCallback: Option[() => Long] = None + + /** + * Adds additional bytes read for this read method. + */ + def addBytesRead(bytes: Long) = { + _bytesRead.addAndGet(bytes) + } + + /** + * Invoke the bytesReadCallback and mutate bytesRead. + */ + def updateBytesRead() { + bytesReadCallback.foreach { c => + _bytesRead.set(c()) + } + } + + /** + * Register a function that can be called to get up-to-date information on how many bytes the task + * has read from an input source. + */ + def setBytesReadCallback(f: Option[() => Long]) { + bytesReadCallback = f + } } /** diff --git a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala index 3340fca08014..03c4137ca0a8 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala @@ -174,7 +174,7 @@ private[nio] class ConnectionManager( serverChannel.socket.bind(new InetSocketAddress(port)) (serverChannel, serverChannel.socket.getLocalPort) } - Utils.startServiceOnPort[ServerSocketChannel](port, startService, name) + Utils.startServiceOnPort[ServerSocketChannel](port, startService, conf, name) serverChannel.register(selector, SelectionKey.OP_ACCEPT) val id = new ConnectionManagerId(Utils.localHostName, serverChannel.socket.getLocalPort) diff --git a/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala index 7ba1182f0ed2..1c13e2c37284 100644 --- a/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala @@ -95,7 +95,8 @@ private[spark] object CheckpointRDD extends Logging { val finalOutputName = splitIdToFile(ctx.partitionId) val finalOutputPath = new Path(outputDir, finalOutputName) - val tempOutputPath = new Path(outputDir, "." + finalOutputName + "-attempt-" + ctx.attemptId) + val tempOutputPath = + new Path(outputDir, "." + finalOutputName + "-attempt-" + ctx.attemptNumber) if (fs.exists(tempOutputPath)) { throw new IOException("Checkpoint failed: temporary path " + @@ -119,7 +120,7 @@ private[spark] object CheckpointRDD extends Logging { logInfo("Deleting tempOutputPath " + tempOutputPath) fs.delete(tempOutputPath, false) throw new IOException("Checkpoint failed: failed to save output of task: " - + ctx.attemptId + " and final output path does not exist") + + ctx.attemptNumber + " and final output path does not exist") } else { // Some other copy of this task must've finished before us and renamed it logInfo("Final output path " + finalOutputPath + " already exists; not overwriting it") diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala index 0001c2329c83..3b99d3a6cafd 100644 --- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala @@ -213,23 +213,24 @@ class HadoopRDD[K, V]( logInfo("Input split: " + split.inputSplit) val jobConf = getJobConf() - val inputMetrics = new InputMetrics(DataReadMethod.Hadoop) + val inputMetrics = context.taskMetrics + .getInputMetricsForReadMethod(DataReadMethod.Hadoop) + // Find a function that will return the FileSystem bytes read by this thread. Do this before // creating RecordReader, because RecordReader's constructor might read some bytes - val bytesReadCallback = if (split.inputSplit.value.isInstanceOf[FileSplit]) { - SparkHadoopUtil.get.getFSBytesReadOnThreadCallback( - split.inputSplit.value.asInstanceOf[FileSplit].getPath, jobConf) - } else { - None - } - if (bytesReadCallback.isDefined) { - context.taskMetrics.inputMetrics = Some(inputMetrics) - } + val bytesReadCallback = inputMetrics.bytesReadCallback.orElse( + split.inputSplit.value match { + case split: FileSplit => + SparkHadoopUtil.get.getFSBytesReadOnThreadCallback(split.getPath, jobConf) + case _ => None + } + ) + inputMetrics.setBytesReadCallback(bytesReadCallback) var reader: RecordReader[K, V] = null val inputFormat = getInputFormat(jobConf) HadoopRDD.addLocalConfiguration(new SimpleDateFormat("yyyyMMddHHmm").format(createTime), - context.stageId, theSplit.index, context.attemptId.toInt, jobConf) + context.stageId, theSplit.index, context.attemptNumber, jobConf) reader = inputFormat.getRecordReader(split.inputSplit.value, jobConf, Reporter.NULL) // Register an on-task-completion callback to close the input stream. @@ -237,8 +238,6 @@ class HadoopRDD[K, V]( val key: K = reader.createKey() val value: V = reader.createValue() - var recordsSinceMetricsUpdate = 0 - override def getNext() = { try { finished = !reader.next(key, value) @@ -246,16 +245,6 @@ class HadoopRDD[K, V]( case eof: EOFException => finished = true } - - // Update bytes read metric every few records - if (recordsSinceMetricsUpdate == HadoopRDD.RECORDS_BETWEEN_BYTES_READ_METRIC_UPDATES - && bytesReadCallback.isDefined) { - recordsSinceMetricsUpdate = 0 - val bytesReadFn = bytesReadCallback.get - inputMetrics.bytesRead = bytesReadFn() - } else { - recordsSinceMetricsUpdate += 1 - } (key, value) } @@ -263,14 +252,12 @@ class HadoopRDD[K, V]( try { reader.close() if (bytesReadCallback.isDefined) { - val bytesReadFn = bytesReadCallback.get - inputMetrics.bytesRead = bytesReadFn() + inputMetrics.updateBytesRead() } else if (split.inputSplit.value.isInstanceOf[FileSplit]) { // If we can't get the bytes read from the FS stats, fall back to the split size, // which may be inaccurate. try { - inputMetrics.bytesRead = split.inputSplit.value.getLength - context.taskMetrics.inputMetrics = Some(inputMetrics) + inputMetrics.addBytesRead(split.inputSplit.value.getLength) } catch { case e: java.io.IOException => logWarning("Unable to get input size to set InputMetrics for task", e) diff --git a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala index e55d03d391e0..890ec677c269 100644 --- a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala @@ -109,18 +109,19 @@ class NewHadoopRDD[K, V]( logInfo("Input split: " + split.serializableHadoopSplit) val conf = confBroadcast.value.value - val inputMetrics = new InputMetrics(DataReadMethod.Hadoop) + val inputMetrics = context.taskMetrics + .getInputMetricsForReadMethod(DataReadMethod.Hadoop) + // Find a function that will return the FileSystem bytes read by this thread. Do this before // creating RecordReader, because RecordReader's constructor might read some bytes - val bytesReadCallback = if (split.serializableHadoopSplit.value.isInstanceOf[FileSplit]) { - SparkHadoopUtil.get.getFSBytesReadOnThreadCallback( - split.serializableHadoopSplit.value.asInstanceOf[FileSplit].getPath, conf) - } else { - None - } - if (bytesReadCallback.isDefined) { - context.taskMetrics.inputMetrics = Some(inputMetrics) - } + val bytesReadCallback = inputMetrics.bytesReadCallback.orElse( + split.serializableHadoopSplit.value match { + case split: FileSplit => + SparkHadoopUtil.get.getFSBytesReadOnThreadCallback(split.getPath, conf) + case _ => None + } + ) + inputMetrics.setBytesReadCallback(bytesReadCallback) val attemptId = newTaskAttemptID(jobTrackerId, id, isMap = true, split.index, 0) val hadoopAttemptContext = newTaskAttemptContext(conf, attemptId) @@ -153,34 +154,19 @@ class NewHadoopRDD[K, V]( throw new java.util.NoSuchElementException("End of stream") } havePair = false - - // Update bytes read metric every few records - if (recordsSinceMetricsUpdate == HadoopRDD.RECORDS_BETWEEN_BYTES_READ_METRIC_UPDATES - && bytesReadCallback.isDefined) { - recordsSinceMetricsUpdate = 0 - val bytesReadFn = bytesReadCallback.get - inputMetrics.bytesRead = bytesReadFn() - } else { - recordsSinceMetricsUpdate += 1 - } - (reader.getCurrentKey, reader.getCurrentValue) } private def close() { try { reader.close() - - // Update metrics with final amount if (bytesReadCallback.isDefined) { - val bytesReadFn = bytesReadCallback.get - inputMetrics.bytesRead = bytesReadFn() + inputMetrics.updateBytesRead() } else if (split.serializableHadoopSplit.value.isInstanceOf[FileSplit]) { // If we can't get the bytes read from the FS stats, fall back to the split size, // which may be inaccurate. try { - inputMetrics.bytesRead = split.serializableHadoopSplit.value.getLength - context.taskMetrics.inputMetrics = Some(inputMetrics) + inputMetrics.addBytesRead(split.serializableHadoopSplit.value.getLength) } catch { case e: java.io.IOException => logWarning("Unable to get input size to set InputMetrics for task", e) diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala index 38f8f36a4a4d..e43e5066655b 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala @@ -978,12 +978,9 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) val writeShard = (context: TaskContext, iter: Iterator[(K,V)]) => { val config = wrappedConf.value - // Hadoop wants a 32-bit task attempt ID, so if ours is bigger than Int.MaxValue, roll it - // around by taking a mod. We expect that no task will be attempted 2 billion times. - val attemptNumber = (context.attemptId % Int.MaxValue).toInt /* "reduce task" */ val attemptId = newTaskAttemptID(jobtrackerID, stageId, isMap = false, context.partitionId, - attemptNumber) + context.attemptNumber) val hadoopContext = newTaskAttemptContext(config, attemptId) val format = outfmt.newInstance format match { @@ -1062,11 +1059,11 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) val config = wrappedConf.value // Hadoop wants a 32-bit task attempt ID, so if ours is bigger than Int.MaxValue, roll it // around by taking a mod. We expect that no task will be attempted 2 billion times. - val attemptNumber = (context.attemptId % Int.MaxValue).toInt + val taskAttemptId = (context.taskAttemptId % Int.MaxValue).toInt val (outputMetrics, bytesWrittenCallback) = initHadoopOutputMetrics(context, config) - writer.setup(context.stageId, context.partitionId, attemptNumber) + writer.setup(context.stageId, context.partitionId, taskAttemptId) writer.open() try { var recordsWritten = 0L diff --git a/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala index 87b22de6ae69..f12d0cffaba3 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala @@ -111,7 +111,8 @@ private object ParallelCollectionRDD { /** * Slice a collection into numSlices sub-collections. One extra thing we do here is to treat Range * collections specially, encoding the slices as other Ranges to minimize memory cost. This makes - * it efficient to run Spark over RDDs representing large sets of numbers. + * it efficient to run Spark over RDDs representing large sets of numbers. And if the collection + * is an inclusive Range, we use inclusive range for the last slice. */ def slice[T: ClassTag](seq: Seq[T], numSlices: Int): Seq[Seq[T]] = { if (numSlices < 1) { @@ -127,19 +128,15 @@ private object ParallelCollectionRDD { }) } seq match { - case r: Range.Inclusive => { - val sign = if (r.step < 0) { - -1 - } else { - 1 - } - slice(new Range( - r.start, r.end + sign, r.step).asInstanceOf[Seq[T]], numSlices) - } case r: Range => { - positions(r.length, numSlices).map({ - case (start, end) => + positions(r.length, numSlices).zipWithIndex.map({ case ((start, end), index) => + // If the range is inclusive, use inclusive range for the last slice + if (r.isInclusive && index == numSlices - 1) { + new Range.Inclusive(r.start + start * r.step, r.end, r.step) + } + else { new Range(r.start + start * r.step, r.start + end * r.step, r.step) + } }).toSeq.asInstanceOf[Seq[Seq[T]]] } case nr: NumericRange[_] => { diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 61d09d73e17c..3bca59e0646d 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -635,8 +635,8 @@ class DAGScheduler( try { val rdd = job.finalStage.rdd val split = rdd.partitions(job.partitions(0)) - val taskContext = - new TaskContextImpl(job.finalStage.id, job.partitions(0), 0, true) + val taskContext = new TaskContextImpl(job.finalStage.id, job.partitions(0), taskAttemptId = 0, + attemptNumber = 0, runningLocally = true) TaskContextHelper.setTaskContext(taskContext) try { val result = job.func(taskContext, rdd.iterator(split, taskContext)) @@ -661,7 +661,7 @@ class DAGScheduler( // completion events or stage abort stageIdToStage -= s.id jobIdToStageIds -= job.jobId - listenerBus.post(SparkListenerJobEnd(job.jobId, jobResult)) + listenerBus.post(SparkListenerJobEnd(job.jobId, clock.getTime(), jobResult)) } } @@ -710,7 +710,7 @@ class DAGScheduler( stage.latestInfo.stageFailed(stageFailedMessage) listenerBus.post(SparkListenerStageCompleted(stage.latestInfo)) } - listenerBus.post(SparkListenerJobEnd(job.jobId, JobFailed(error))) + listenerBus.post(SparkListenerJobEnd(job.jobId, clock.getTime(), JobFailed(error))) } } @@ -749,9 +749,11 @@ class DAGScheduler( logInfo("Missing parents: " + getMissingParentStages(finalStage)) val shouldRunLocally = localExecutionEnabled && allowLocal && finalStage.parents.isEmpty && partitions.length == 1 + val jobSubmissionTime = clock.getTime() if (shouldRunLocally) { // Compute very short actions like first() or take() with no parent stages locally. - listenerBus.post(SparkListenerJobStart(job.jobId, Seq.empty, properties)) + listenerBus.post( + SparkListenerJobStart(job.jobId, jobSubmissionTime, Seq.empty, properties)) runLocally(job) } else { jobIdToActiveJob(jobId) = job @@ -759,7 +761,8 @@ class DAGScheduler( finalStage.resultOfJob = Some(job) val stageIds = jobIdToStageIds(jobId).toArray val stageInfos = stageIds.flatMap(id => stageIdToStage.get(id).map(_.latestInfo)) - listenerBus.post(SparkListenerJobStart(job.jobId, stageInfos, properties)) + listenerBus.post( + SparkListenerJobStart(job.jobId, jobSubmissionTime, stageInfos, properties)) submitStage(finalStage) } } @@ -965,7 +968,8 @@ class DAGScheduler( if (job.numFinished == job.numPartitions) { markStageAsFinished(stage) cleanupStateForJobAndIndependentStages(job) - listenerBus.post(SparkListenerJobEnd(job.jobId, JobSucceeded)) + listenerBus.post( + SparkListenerJobEnd(job.jobId, clock.getTime(), JobSucceeded)) } // taskSucceeded runs some user code that might throw an exception. Make sure @@ -1234,7 +1238,7 @@ class DAGScheduler( if (ableToCancelStages) { job.listener.jobFailed(error) cleanupStateForJobAndIndependentStages(job) - listenerBus.post(SparkListenerJobEnd(job.jobId, JobFailed(error))) + listenerBus.post(SparkListenerJobEnd(job.jobId, clock.getTime(), JobFailed(error))) } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala index 27bf4f159907..30075c172bdb 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala @@ -168,6 +168,10 @@ private[spark] class EventLoggingListener( logEvent(event, flushLogger = true) override def onApplicationEnd(event: SparkListenerApplicationEnd) = logEvent(event, flushLogger = true) + override def onExecutorAdded(event: SparkListenerExecutorAdded) = + logEvent(event, flushLogger = true) + override def onExecutorRemoved(event: SparkListenerExecutorRemoved) = + logEvent(event, flushLogger = true) // No-op because logging every update would be overkill override def onExecutorMetricsUpdate(event: SparkListenerExecutorMetricsUpdate) { } diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala index b62b0c131269..e5d1eb767e10 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala @@ -25,6 +25,7 @@ import scala.collection.mutable import org.apache.spark.{Logging, TaskEndReason} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.executor.TaskMetrics +import org.apache.spark.scheduler.cluster.ExecutorInfo import org.apache.spark.storage.BlockManagerId import org.apache.spark.util.{Distribution, Utils} @@ -58,6 +59,7 @@ case class SparkListenerTaskEnd( @DeveloperApi case class SparkListenerJobStart( jobId: Int, + time: Long, stageInfos: Seq[StageInfo], properties: Properties = null) extends SparkListenerEvent { @@ -67,7 +69,11 @@ case class SparkListenerJobStart( } @DeveloperApi -case class SparkListenerJobEnd(jobId: Int, jobResult: JobResult) extends SparkListenerEvent +case class SparkListenerJobEnd( + jobId: Int, + time: Long, + jobResult: JobResult) + extends SparkListenerEvent @DeveloperApi case class SparkListenerEnvironmentUpdate(environmentDetails: Map[String, Seq[(String, String)]]) @@ -84,6 +90,14 @@ case class SparkListenerBlockManagerRemoved(time: Long, blockManagerId: BlockMan @DeveloperApi case class SparkListenerUnpersistRDD(rddId: Int) extends SparkListenerEvent +@DeveloperApi +case class SparkListenerExecutorAdded(executorId: String, executorInfo: ExecutorInfo) + extends SparkListenerEvent + +@DeveloperApi +case class SparkListenerExecutorRemoved(executorId: String) + extends SparkListenerEvent + /** * Periodic updates from executors. * @param execId executor id @@ -109,7 +123,8 @@ private[spark] case object SparkListenerShutdown extends SparkListenerEvent /** * :: DeveloperApi :: * Interface for listening to events from the Spark scheduler. Note that this is an internal - * interface which might change in different Spark releases. + * interface which might change in different Spark releases. Java clients should extend + * {@link JavaSparkListener} */ @DeveloperApi trait SparkListener { @@ -183,6 +198,16 @@ trait SparkListener { * Called when the driver receives task metrics from an executor in a heartbeat. */ def onExecutorMetricsUpdate(executorMetricsUpdate: SparkListenerExecutorMetricsUpdate) { } + + /** + * Called when the driver registers a new executor. + */ + def onExecutorAdded(executorAdded: SparkListenerExecutorAdded) { } + + /** + * Called when the driver removes an executor. + */ + def onExecutorRemoved(executorRemoved: SparkListenerExecutorRemoved) { } } /** diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala index e79ffd7a3587..e700c6af542f 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala @@ -70,6 +70,10 @@ private[spark] trait SparkListenerBus extends Logging { foreachListener(_.onApplicationEnd(applicationEnd)) case metricsUpdate: SparkListenerExecutorMetricsUpdate => foreachListener(_.onExecutorMetricsUpdate(metricsUpdate)) + case executorAdded: SparkListenerExecutorAdded => + foreachListener(_.onExecutorAdded(executorAdded)) + case executorRemoved: SparkListenerExecutorRemoved => + foreachListener(_.onExecutorRemoved(executorRemoved)) case SparkListenerShutdown => } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala index d7dde4fe3843..2367f7e2cf67 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala @@ -44,8 +44,16 @@ import org.apache.spark.util.Utils */ private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) extends Serializable { - final def run(attemptId: Long): T = { - context = new TaskContextImpl(stageId, partitionId, attemptId, runningLocally = false) + /** + * Called by Executor to run this task. + * + * @param taskAttemptId an identifier for this task attempt that is unique within a SparkContext. + * @param attemptNumber how many times this task has been attempted (0 for the first attempt) + * @return the result of the task + */ + final def run(taskAttemptId: Long, attemptNumber: Int): T = { + context = new TaskContextImpl(stageId = stageId, partitionId = partitionId, + taskAttemptId = taskAttemptId, attemptNumber = attemptNumber, runningLocally = false) TaskContextHelper.setTaskContext(context) context.taskMetrics.hostname = Utils.localHostName() taskThread = Thread.currentThread() diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskDescription.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskDescription.scala index 4c96b9e5fef6..1c7c81c488c3 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskDescription.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskDescription.scala @@ -27,6 +27,7 @@ import org.apache.spark.util.SerializableBuffer */ private[spark] class TaskDescription( val taskId: Long, + val attemptNumber: Int, val executorId: String, val name: String, val index: Int, // Index within this task's TaskSet diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index 466785091715..5c94c6bbcb37 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -487,7 +487,8 @@ private[spark] class TaskSetManager( taskName, taskId, host, taskLocality, serializedTask.limit)) sched.dagScheduler.taskStarted(task, info) - return Some(new TaskDescription(taskId, execId, taskName, index, serializedTask)) + return Some(new TaskDescription(taskId = taskId, attemptNumber = attemptNum, execId, + taskName, index, serializedTask)) } case _ => } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index fe9914b50bc5..5786d367464f 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -28,7 +28,7 @@ import akka.pattern.ask import akka.remote.{DisassociatedEvent, RemotingLifecycleEvent} import org.apache.spark.{ExecutorAllocationClient, Logging, SparkEnv, SparkException, TaskState} -import org.apache.spark.scheduler.{SchedulerBackend, SlaveLost, TaskDescription, TaskSchedulerImpl, WorkerOffer} +import org.apache.spark.scheduler._ import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ import org.apache.spark.util.{ActorLogReceive, SerializableBuffer, AkkaUtils, Utils} @@ -66,6 +66,8 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste // Number of executors requested from the cluster manager that have not registered yet private var numPendingExecutors = 0 + private val listenerBus = scheduler.sc.listenerBus + // Executors we have requested the cluster manager to kill that have not died yet private val executorsPendingToRemove = new HashSet[String] @@ -106,6 +108,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste logDebug(s"Decremented number of pending executors ($numPendingExecutors left)") } } + listenerBus.post(SparkListenerExecutorAdded(executorId, data)) makeOffers() } @@ -213,6 +216,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste totalCoreCount.addAndGet(-executorInfo.totalCores) totalRegisteredExecutors.addAndGet(-1) scheduler.executorLost(executorId, SlaveLost(reason)) + listenerBus.post(SparkListenerExecutorRemoved(executorId)) case None => logError(s"Asked to remove non-existent executor $executorId") } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorData.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorData.scala index b71bd5783d6d..eb52ddfb1eab 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorData.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorData.scala @@ -31,7 +31,7 @@ import akka.actor.{Address, ActorRef} private[cluster] class ExecutorData( val executorActor: ActorRef, val executorAddress: Address, - val executorHost: String , + override val executorHost: String, var freeCores: Int, - val totalCores: Int -) + override val totalCores: Int +) extends ExecutorInfo(executorHost, totalCores) diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/BooleanType.java b/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorInfo.scala similarity index 52% rename from sql/core/src/main/java/org/apache/spark/sql/api/java/BooleanType.java rename to core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorInfo.scala index 5a1f52725631..b4738e64c939 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/BooleanType.java +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorInfo.scala @@ -14,14 +14,32 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +package org.apache.spark.scheduler.cluster -package org.apache.spark.sql.api.java; +import org.apache.spark.annotation.DeveloperApi /** - * The data type representing boolean and Boolean values. - * - * {@code BooleanType} is represented by the singleton object {@link DataType#BooleanType}. + * :: DeveloperApi :: + * Stores information about an executor to pass from the scheduler to SparkListeners. */ -public class BooleanType extends DataType { - protected BooleanType() {} +@DeveloperApi +class ExecutorInfo( + val executorHost: String, + val totalCores: Int +) { + + def canEqual(other: Any): Boolean = other.isInstanceOf[ExecutorInfo] + + override def equals(other: Any): Boolean = other match { + case that: ExecutorInfo => + (that canEqual this) && + executorHost == that.executorHost && + totalCores == that.totalCores + case _ => false + } + + override def hashCode(): Int = { + val state = Seq(executorHost, totalCores) + state.map(_.hashCode()).foldLeft(0)((a, b) => 31 * a + b) + } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala index 10e6886c16a4..d252fe8595fb 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala @@ -22,14 +22,16 @@ import java.util.{ArrayList => JArrayList, List => JList} import java.util.Collections import scala.collection.JavaConversions._ -import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} +import scala.collection.mutable.{HashMap, HashSet} import org.apache.mesos.protobuf.ByteString import org.apache.mesos.{Scheduler => MScheduler} import org.apache.mesos._ -import org.apache.mesos.Protos.{TaskInfo => MesosTaskInfo, TaskState => MesosTaskState, _} +import org.apache.mesos.Protos.{TaskInfo => MesosTaskInfo, TaskState => MesosTaskState, + ExecutorInfo => MesosExecutorInfo, _} import org.apache.spark.{Logging, SparkContext, SparkException, TaskState} +import org.apache.spark.scheduler.cluster.ExecutorInfo import org.apache.spark.scheduler._ import org.apache.spark.util.Utils @@ -62,6 +64,9 @@ private[spark] class MesosSchedulerBackend( var classLoader: ClassLoader = null + // The listener bus to publish executor added/removed events. + val listenerBus = sc.listenerBus + @volatile var appId: String = _ override def start() { @@ -87,7 +92,7 @@ private[spark] class MesosSchedulerBackend( } } - def createExecutorInfo(execId: String): ExecutorInfo = { + def createExecutorInfo(execId: String): MesosExecutorInfo = { val executorSparkHome = sc.conf.getOption("spark.mesos.executor.home") .orElse(sc.getSparkHome()) // Fall back to driver Spark home for backward compatibility .getOrElse { @@ -141,7 +146,7 @@ private[spark] class MesosSchedulerBackend( Value.Scalar.newBuilder() .setValue(MemoryUtils.calculateTotalMemory(sc)).build()) .build() - ExecutorInfo.newBuilder() + MesosExecutorInfo.newBuilder() .setExecutorId(ExecutorID.newBuilder().setValue(execId).build()) .setCommand(command) .setData(ByteString.copyFrom(createExecArg())) @@ -237,6 +242,7 @@ private[spark] class MesosSchedulerBackend( } val slaveIdToOffer = usableOffers.map(o => o.getSlaveId.getValue -> o).toMap + val slaveIdToWorkerOffer = workerOffers.map(o => o.executorId -> o).toMap val mesosTasks = new HashMap[String, JArrayList[MesosTaskInfo]] @@ -260,6 +266,10 @@ private[spark] class MesosSchedulerBackend( val filters = Filters.newBuilder().setRefuseSeconds(1).build() // TODO: lower timeout? mesosTasks.foreach { case (slaveId, tasks) => + slaveIdToWorkerOffer.get(slaveId).foreach(o => + listenerBus.post(SparkListenerExecutorAdded(slaveId, + new ExecutorInfo(o.host, o.cores))) + ) d.launchTasks(Collections.singleton(slaveIdToOffer(slaveId).getId), tasks, filters) } @@ -296,7 +306,7 @@ private[spark] class MesosSchedulerBackend( .setExecutor(createExecutorInfo(slaveId)) .setName(task.name) .addResources(cpuResource) - .setData(ByteString.copyFrom(task.serializedTask)) + .setData(MesosTaskLaunchData(task.serializedTask, task.attemptNumber).toByteString) .build() } @@ -315,7 +325,7 @@ private[spark] class MesosSchedulerBackend( synchronized { if (status.getState == MesosTaskState.TASK_LOST && taskIdToSlaveId.contains(tid)) { // We lost the executor on this slave, so remember that it's gone - slaveIdsWithExecutors -= taskIdToSlaveId(tid) + removeExecutor(taskIdToSlaveId(tid)) } if (isFinished(status.getState)) { taskIdToSlaveId.remove(tid) @@ -344,12 +354,20 @@ private[spark] class MesosSchedulerBackend( override def frameworkMessage(d: SchedulerDriver, e: ExecutorID, s: SlaveID, b: Array[Byte]) {} + /** + * Remove executor associated with slaveId in a thread safe manner. + */ + private def removeExecutor(slaveId: String) = { + synchronized { + listenerBus.post(SparkListenerExecutorRemoved(slaveId)) + slaveIdsWithExecutors -= slaveId + } + } + private def recordSlaveLost(d: SchedulerDriver, slaveId: SlaveID, reason: ExecutorLossReason) { inClassLoader() { logInfo("Mesos slave lost: " + slaveId.getValue) - synchronized { - slaveIdsWithExecutors -= slaveId.getValue - } + removeExecutor(slaveId.getValue) scheduler.executorLost(slaveId.getValue, reason) } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosTaskLaunchData.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosTaskLaunchData.scala new file mode 100644 index 000000000000..4416ce92ade2 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosTaskLaunchData.scala @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler.cluster.mesos + +import java.nio.ByteBuffer + +import org.apache.mesos.protobuf.ByteString + +/** + * Wrapper for serializing the data sent when launching Mesos tasks. + */ +private[spark] case class MesosTaskLaunchData( + serializedTask: ByteBuffer, + attemptNumber: Int) { + + def toByteString: ByteString = { + val dataBuffer = ByteBuffer.allocate(4 + serializedTask.limit) + dataBuffer.putInt(attemptNumber) + dataBuffer.put(serializedTask) + ByteString.copyFrom(dataBuffer) + } +} + +private[spark] object MesosTaskLaunchData { + def fromByteString(byteString: ByteString): MesosTaskLaunchData = { + val byteBuffer = byteString.asReadOnlyByteBuffer() + val attemptNumber = byteBuffer.getInt // updates the position by 4 bytes + val serializedTask = byteBuffer.slice() // subsequence starting at the current position + MesosTaskLaunchData(serializedTask, attemptNumber) + } +} diff --git a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala index b3bd3110ac80..05b6fa54564b 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala @@ -76,7 +76,8 @@ private[spark] class LocalActor( val offers = Seq(new WorkerOffer(localExecutorId, localExecutorHostname, freeCores)) for (task <- scheduler.resourceOffers(offers).flatten) { freeCores -= scheduler.CPUS_PER_TASK - executor.launchTask(executorBackend, task.taskId, task.name, task.serializedTask) + executor.launchTask(executorBackend, taskId = task.taskId, attemptNumber = task.attemptNumber, + task.name, task.serializedTask) } } } diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala index d2947dcea4f7..d56e23ce4478 100644 --- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala @@ -29,7 +29,7 @@ import org.apache.spark._ import org.apache.spark.api.python.PythonBroadcast import org.apache.spark.broadcast.HttpBroadcast import org.apache.spark.network.nio.{PutBlock, GotBlock, GetBlock} -import org.apache.spark.scheduler.MapStatus +import org.apache.spark.scheduler.{CompressedMapStatus, HighlyCompressedMapStatus} import org.apache.spark.storage._ import org.apache.spark.util.BoundedPriorityQueue import org.apache.spark.util.collection.CompactBuffer @@ -207,7 +207,8 @@ private[serializer] object KryoSerializer { classOf[PutBlock], classOf[GotBlock], classOf[GetBlock], - classOf[MapStatus], + classOf[CompressedMapStatus], + classOf[HighlyCompressedMapStatus], classOf[CompactBuffer[_]], classOf[BlockManagerId], classOf[Array[Byte]], diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index d7b184f8a10e..8bc5a1cd18b6 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -34,10 +34,9 @@ import org.apache.spark.executor._ import org.apache.spark.io.CompressionCodec import org.apache.spark.network._ import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer} -import org.apache.spark.network.netty.{SparkTransportConf, NettyBlockTransferService} +import org.apache.spark.network.netty.SparkTransportConf import org.apache.spark.network.shuffle.ExternalShuffleClient import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo -import org.apache.spark.network.util.{ConfigProvider, TransportConf} import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle.ShuffleManager import org.apache.spark.shuffle.hash.HashShuffleManager @@ -54,7 +53,7 @@ private[spark] class BlockResult( readMethod: DataReadMethod.Value, bytes: Long) { val inputMetrics = new InputMetrics(readMethod) - inputMetrics.bytesRead = bytes + inputMetrics.addBytesRead(bytes) } /** @@ -120,7 +119,7 @@ private[spark] class BlockManager( private[spark] var shuffleServerId: BlockManagerId = _ // Client to read other executors' shuffle files. This is either an external service, or just the - // standard BlockTranserService to directly connect to other Executors. + // standard BlockTransferService to directly connect to other Executors. private[spark] val shuffleClient = if (externalShuffleServiceEnabled) { val transConf = SparkTransportConf.fromSparkConf(conf, numUsableCores) new ExternalShuffleClient(transConf, securityManager, securityManager.isAuthenticationEnabled()) diff --git a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala index 2a27d49d2de0..88fed833f922 100644 --- a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala @@ -201,7 +201,7 @@ private[spark] object JettyUtils extends Logging { } } - val (server, boundPort) = Utils.startServiceOnPort[Server](port, connect, serverName) + val (server, boundPort) = Utils.startServiceOnPort[Server](port, connect, conf, serverName) ServerInfo(server, boundPort, collection) } diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala index ea2d187a0e8e..81212708ba52 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala @@ -21,7 +21,6 @@ import scala.xml.{Node, NodeSeq} import javax.servlet.http.HttpServletRequest -import org.apache.spark.JobExecutionStatus import org.apache.spark.ui.{WebUIPage, UIUtils} import org.apache.spark.ui.jobs.UIData.JobUIData @@ -47,17 +46,17 @@ private[ui] class AllJobsPage(parent: JobsTab) extends WebUIPage("") { val lastStageData = lastStageInfo.flatMap { s => listener.stageIdToData.get((s.stageId, s.attemptId)) } - val isComplete = job.status == JobExecutionStatus.SUCCEEDED + val lastStageName = lastStageInfo.map(_.name).getOrElse("(Unknown Stage Name)") val lastStageDescription = lastStageData.flatMap(_.description).getOrElse("") val duration: Option[Long] = { - job.startTime.map { start => - val end = job.endTime.getOrElse(System.currentTimeMillis()) + job.submissionTime.map { start => + val end = job.completionTime.getOrElse(System.currentTimeMillis()) end - start } } val formattedDuration = duration.map(d => UIUtils.formatDuration(d)).getOrElse("Unknown") - val formattedSubmissionTime = job.startTime.map(UIUtils.formatDate).getOrElse("Unknown") + val formattedSubmissionTime = job.submissionTime.map(UIUtils.formatDate).getOrElse("Unknown") val detailUrl = "%s/jobs/job?id=%s".format(UIUtils.prependBaseUri(parent.basePath), job.jobId) @@ -68,7 +67,7 @@ private[ui] class AllJobsPage(parent: JobsTab) extends WebUIPage("") {
{lastStageDescription}
{lastStageName} - + {formattedSubmissionTime} {formattedDuration} @@ -101,11 +100,15 @@ private[ui] class AllJobsPage(parent: JobsTab) extends WebUIPage("") { val now = System.currentTimeMillis val activeJobsTable = - jobsTable(activeJobs.sortBy(_.startTime.getOrElse(-1L)).reverse) + jobsTable(activeJobs.sortBy(_.submissionTime.getOrElse(-1L)).reverse) val completedJobsTable = - jobsTable(completedJobs.sortBy(_.endTime.getOrElse(-1L)).reverse) + jobsTable(completedJobs.sortBy(_.completionTime.getOrElse(-1L)).reverse) val failedJobsTable = - jobsTable(failedJobs.sortBy(_.endTime.getOrElse(-1L)).reverse) + jobsTable(failedJobs.sortBy(_.completionTime.getOrElse(-1L)).reverse) + + val shouldShowActiveJobs = activeJobs.nonEmpty + val shouldShowCompletedJobs = completedJobs.nonEmpty + val shouldShowFailedJobs = failedJobs.nonEmpty val summary: NodeSeq =
@@ -121,27 +124,47 @@ private[ui] class AllJobsPage(parent: JobsTab) extends WebUIPage("") { Scheduling Mode: {listener.schedulingMode.map(_.toString).getOrElse("Unknown")} -
  • - Active Jobs: - {activeJobs.size} -
  • -
  • - Completed Jobs: - {completedJobs.size} -
  • -
  • - Failed Jobs: - {failedJobs.size} -
  • + { + if (shouldShowActiveJobs) { +
  • + Active Jobs: + {activeJobs.size} +
  • + } + } + { + if (shouldShowCompletedJobs) { +
  • + Completed Jobs: + {completedJobs.size} +
  • + } + } + { + if (shouldShowFailedJobs) { +
  • + Failed Jobs: + {failedJobs.size} +
  • + } + }
    - val content = summary ++ -

    Active Jobs ({activeJobs.size})

    ++ activeJobsTable ++ -

    Completed Jobs ({completedJobs.size})

    ++ completedJobsTable ++ -

    Failed Jobs ({failedJobs.size})

    ++ failedJobsTable - - val helpText = """A job is triggered by a action, like "count()" or "saveAsTextFile()".""" + + var content = summary + if (shouldShowActiveJobs) { + content ++=

    Active Jobs ({activeJobs.size})

    ++ + activeJobsTable + } + if (shouldShowCompletedJobs) { + content ++=

    Completed Jobs ({completedJobs.size})

    ++ + completedJobsTable + } + if (shouldShowFailedJobs) { + content ++=

    Failed Jobs ({failedJobs.size})

    ++ + failedJobsTable + } + val helpText = """A job is triggered by an action, like "count()" or "saveAsTextFile()".""" + " Click on a job's title to see information about the stages of tasks associated with" + " the job." diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala index 72935beb3a34..b0d3bed1300b 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala @@ -153,14 +153,13 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { val jobData: JobUIData = new JobUIData( jobId = jobStart.jobId, - startTime = Some(System.currentTimeMillis), - endTime = None, + submissionTime = Option(jobStart.time).filter(_ >= 0), stageIds = jobStart.stageIds, jobGroup = jobGroup, status = JobExecutionStatus.RUNNING) // Compute (a potential underestimate of) the number of tasks that will be run by this job. // This may be an underestimate because the job start event references all of the result - // stages's transitive stage dependencies, but some of these stages might be skipped if their + // stages' transitive stage dependencies, but some of these stages might be skipped if their // output is available from earlier runs. // See https://github.com/apache/spark/pull/3009 for a more extensive discussion. jobData.numTasks = { @@ -186,7 +185,8 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { logWarning(s"Job completed for unknown job ${jobEnd.jobId}") new JobUIData(jobId = jobEnd.jobId) } - jobData.endTime = Some(System.currentTimeMillis()) + jobData.completionTime = Option(jobEnd.time).filter(_ >= 0) + jobEnd.jobResult match { case JobSucceeded => completedJobs += jobData @@ -309,7 +309,7 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { override def onTaskEnd(taskEnd: SparkListenerTaskEnd) = synchronized { val info = taskEnd.taskInfo // If stage attempt id is -1, it means the DAGScheduler had no idea which attempt this task - // compeletion event is for. Let's just drop it here. This means we might have some speculation + // completion event is for. Let's just drop it here. This means we might have some speculation // tasks on the web ui that's never marked as complete. if (info != null && taskEnd.stageAttemptId != -1) { val stageData = stageIdToData.getOrElseUpdate((taskEnd.stageId, taskEnd.stageAttemptId), { diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala b/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala index 48fd7caa1a1e..01f7e23212c3 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala @@ -40,15 +40,15 @@ private[jobs] object UIData { class JobUIData( var jobId: Int = -1, - var startTime: Option[Long] = None, - var endTime: Option[Long] = None, + var submissionTime: Option[Long] = None, + var completionTime: Option[Long] = None, var stageIds: Seq[Int] = Seq.empty, var jobGroup: Option[String] = None, var status: JobExecutionStatus = JobExecutionStatus.UNKNOWN, /* Tasks */ // `numTasks` is a potential underestimate of the true number of tasks that this job will run. // This may be an underestimate because the job start event references all of the result - // stages's transitive stage dependencies, but some of these stages might be skipped if their + // stages' transitive stage dependencies, but some of these stages might be skipped if their // output is available from earlier runs. // See https://github.com/apache/spark/pull/3009 for a more extensive discussion. var numTasks: Int = 0, diff --git a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala index db2531dc171f..4c9b1e3c46f0 100644 --- a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala +++ b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala @@ -53,7 +53,7 @@ private[spark] object AkkaUtils extends Logging { val startService: Int => (ActorSystem, Int) = { actualPort => doCreateActorSystem(name, host, actualPort, conf, securityManager) } - Utils.startServiceOnPort(port, startService, name) + Utils.startServiceOnPort(port, startService, conf, name) } private def doCreateActorSystem( diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala index d94e8252650d..76709a230f83 100644 --- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala @@ -19,6 +19,8 @@ package org.apache.spark.util import java.util.{Properties, UUID} +import org.apache.spark.scheduler.cluster.ExecutorInfo + import scala.collection.JavaConverters._ import scala.collection.Map @@ -30,6 +32,7 @@ import org.apache.spark.executor._ import org.apache.spark.scheduler._ import org.apache.spark.storage._ import org.apache.spark._ +import org.apache.hadoop.hdfs.web.JsonUtil /** * Serializes SparkListener events to/from JSON. This protocol provides strong backwards- @@ -83,7 +86,10 @@ private[spark] object JsonProtocol { applicationStartToJson(applicationStart) case applicationEnd: SparkListenerApplicationEnd => applicationEndToJson(applicationEnd) - + case executorAdded: SparkListenerExecutorAdded => + executorAddedToJson(executorAdded) + case executorRemoved: SparkListenerExecutorRemoved => + executorRemovedToJson(executorRemoved) // These aren't used, but keeps compiler happy case SparkListenerShutdown => JNothing case SparkListenerExecutorMetricsUpdate(_, _) => JNothing @@ -136,6 +142,7 @@ private[spark] object JsonProtocol { val properties = propertiesToJson(jobStart.properties) ("Event" -> Utils.getFormattedClassName(jobStart)) ~ ("Job ID" -> jobStart.jobId) ~ + ("Submission Time" -> jobStart.time) ~ ("Stage Infos" -> jobStart.stageInfos.map(stageInfoToJson)) ~ // Added in Spark 1.2.0 ("Stage IDs" -> jobStart.stageIds) ~ ("Properties" -> properties) @@ -145,6 +152,7 @@ private[spark] object JsonProtocol { val jobResult = jobResultToJson(jobEnd.jobResult) ("Event" -> Utils.getFormattedClassName(jobEnd)) ~ ("Job ID" -> jobEnd.jobId) ~ + ("Completion Time" -> jobEnd.time) ~ ("Job Result" -> jobResult) } @@ -194,6 +202,16 @@ private[spark] object JsonProtocol { ("Timestamp" -> applicationEnd.time) } + def executorAddedToJson(executorAdded: SparkListenerExecutorAdded): JValue = { + ("Event" -> Utils.getFormattedClassName(executorAdded)) ~ + ("Executor ID" -> executorAdded.executorId) ~ + ("Executor Info" -> executorInfoToJson(executorAdded.executorInfo)) + } + + def executorRemovedToJson(executorRemoved: SparkListenerExecutorRemoved): JValue = { + ("Event" -> Utils.getFormattedClassName(executorRemoved)) ~ + ("Executor ID" -> executorRemoved.executorId) + } /** ------------------------------------------------------------------- * * JSON serialization methods for classes SparkListenerEvents depend on | @@ -362,6 +380,10 @@ private[spark] object JsonProtocol { ("Disk Size" -> blockStatus.diskSize) } + def executorInfoToJson(executorInfo: ExecutorInfo): JValue = { + ("Host" -> executorInfo.executorHost) ~ + ("Total Cores" -> executorInfo.totalCores) + } /** ------------------------------ * * Util JSON serialization methods | @@ -416,6 +438,8 @@ private[spark] object JsonProtocol { val unpersistRDD = Utils.getFormattedClassName(SparkListenerUnpersistRDD) val applicationStart = Utils.getFormattedClassName(SparkListenerApplicationStart) val applicationEnd = Utils.getFormattedClassName(SparkListenerApplicationEnd) + val executorAdded = Utils.getFormattedClassName(SparkListenerExecutorAdded) + val executorRemoved = Utils.getFormattedClassName(SparkListenerExecutorRemoved) (json \ "Event").extract[String] match { case `stageSubmitted` => stageSubmittedFromJson(json) @@ -431,6 +455,8 @@ private[spark] object JsonProtocol { case `unpersistRDD` => unpersistRDDFromJson(json) case `applicationStart` => applicationStartFromJson(json) case `applicationEnd` => applicationEndFromJson(json) + case `executorAdded` => executorAddedFromJson(json) + case `executorRemoved` => executorRemovedFromJson(json) } } @@ -469,6 +495,8 @@ private[spark] object JsonProtocol { def jobStartFromJson(json: JValue): SparkListenerJobStart = { val jobId = (json \ "Job ID").extract[Int] + val submissionTime = + Utils.jsonOption(json \ "Submission Time").map(_.extract[Long]).getOrElse(-1L) val stageIds = (json \ "Stage IDs").extract[List[JValue]].map(_.extract[Int]) val properties = propertiesFromJson(json \ "Properties") // The "Stage Infos" field was added in Spark 1.2.0 @@ -476,13 +504,15 @@ private[spark] object JsonProtocol { .map(_.extract[Seq[JValue]].map(stageInfoFromJson)).getOrElse { stageIds.map(id => new StageInfo(id, 0, "unknown", 0, Seq.empty, "unknown")) } - SparkListenerJobStart(jobId, stageInfos, properties) + SparkListenerJobStart(jobId, submissionTime, stageInfos, properties) } def jobEndFromJson(json: JValue): SparkListenerJobEnd = { val jobId = (json \ "Job ID").extract[Int] + val completionTime = + Utils.jsonOption(json \ "Completion Time").map(_.extract[Long]).getOrElse(-1L) val jobResult = jobResultFromJson(json \ "Job Result") - SparkListenerJobEnd(jobId, jobResult) + SparkListenerJobEnd(jobId, completionTime, jobResult) } def environmentUpdateFromJson(json: JValue): SparkListenerEnvironmentUpdate = { @@ -523,6 +553,16 @@ private[spark] object JsonProtocol { SparkListenerApplicationEnd((json \ "Timestamp").extract[Long]) } + def executorAddedFromJson(json: JValue): SparkListenerExecutorAdded = { + val executorId = (json \ "Executor ID").extract[String] + val executorInfo = executorInfoFromJson(json \ "Executor Info") + SparkListenerExecutorAdded(executorId, executorInfo) + } + + def executorRemovedFromJson(json: JValue): SparkListenerExecutorRemoved = { + val executorId = (json \ "Executor ID").extract[String] + SparkListenerExecutorRemoved(executorId) + } /** --------------------------------------------------------------------- * * JSON deserialization methods for classes SparkListenerEvents depend on | @@ -604,8 +644,8 @@ private[spark] object JsonProtocol { Utils.jsonOption(json \ "Shuffle Read Metrics").map(shuffleReadMetricsFromJson)) metrics.shuffleWriteMetrics = Utils.jsonOption(json \ "Shuffle Write Metrics").map(shuffleWriteMetricsFromJson) - metrics.inputMetrics = - Utils.jsonOption(json \ "Input Metrics").map(inputMetricsFromJson) + metrics.setInputMetrics( + Utils.jsonOption(json \ "Input Metrics").map(inputMetricsFromJson)) metrics.outputMetrics = Utils.jsonOption(json \ "Output Metrics").map(outputMetricsFromJson) metrics.updatedBlocks = @@ -638,7 +678,7 @@ private[spark] object JsonProtocol { def inputMetricsFromJson(json: JValue): InputMetrics = { val metrics = new InputMetrics( DataReadMethod.withName((json \ "Data Read Method").extract[String])) - metrics.bytesRead = (json \ "Bytes Read").extract[Long] + metrics.addBytesRead((json \ "Bytes Read").extract[Long]) metrics } @@ -745,6 +785,11 @@ private[spark] object JsonProtocol { BlockStatus(storageLevel, memorySize, diskSize, tachyonSize) } + def executorInfoFromJson(json: JValue): ExecutorInfo = { + val executorHost = (json \ "Host").extract[String] + val totalCores = (json \ "Total Cores").extract[Int] + new ExecutorInfo(executorHost, totalCores) + } /** -------------------------------- * * Util JSON deserialization methods | diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index c4f1898a2db1..2c04e4ddfbcb 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -701,7 +701,7 @@ private[spark] object Utils extends Logging { } } - private var customHostname: Option[String] = None + private var customHostname: Option[String] = sys.env.get("SPARK_LOCAL_HOSTNAME") /** * Allow setting a custom host name because when we run on Mesos we need to use the same @@ -1690,17 +1690,15 @@ private[spark] object Utils extends Logging { } /** - * Default maximum number of retries when binding to a port before giving up. + * Maximum number of retries when binding to a port before giving up. */ - val portMaxRetries: Int = { - if (sys.props.contains("spark.testing")) { + def portMaxRetries(conf: SparkConf): Int = { + val maxRetries = conf.getOption("spark.port.maxRetries").map(_.toInt) + if (conf.contains("spark.testing")) { // Set a higher number of retries for tests... - sys.props.get("spark.port.maxRetries").map(_.toInt).getOrElse(100) + maxRetries.getOrElse(100) } else { - Option(SparkEnv.get) - .flatMap(_.conf.getOption("spark.port.maxRetries")) - .map(_.toInt) - .getOrElse(16) + maxRetries.getOrElse(16) } } @@ -1709,17 +1707,18 @@ private[spark] object Utils extends Logging { * Each subsequent attempt uses 1 + the port used in the previous attempt (unless the port is 0). * * @param startPort The initial port to start the service on. - * @param maxRetries Maximum number of retries to attempt. - * A value of 3 means attempting ports n, n+1, n+2, and n+3, for example. * @param startService Function to start service on a given port. * This is expected to throw java.net.BindException on port collision. + * @param conf A SparkConf used to get the maximum number of retries when binding to a port. + * @param serviceName Name of the service. */ def startServiceOnPort[T]( startPort: Int, startService: Int => (T, Int), - serviceName: String = "", - maxRetries: Int = portMaxRetries): (T, Int) = { + conf: SparkConf, + serviceName: String = ""): (T, Int) = { val serviceString = if (serviceName.isEmpty) "" else s" '$serviceName'" + val maxRetries = portMaxRetries(conf) for (offset <- 0 to maxRetries) { // Do not increment port if startPort is 0, which is treated as a special port val tryPort = if (startPort == 0) { diff --git a/core/src/test/java/org/apache/spark/JavaAPISuite.java b/core/src/test/java/org/apache/spark/JavaAPISuite.java index 5ce299d05824..07b1e44d04be 100644 --- a/core/src/test/java/org/apache/spark/JavaAPISuite.java +++ b/core/src/test/java/org/apache/spark/JavaAPISuite.java @@ -820,7 +820,7 @@ public void persist() { @Test public void iterator() { JavaRDD rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5), 2); - TaskContext context = new TaskContextImpl(0, 0, 0L, false, new TaskMetrics()); + TaskContext context = new TaskContextImpl(0, 0, 0L, 0, false, new TaskMetrics()); Assert.assertEquals(1, rdd.iterator(rdd.partitions().get(0), context).next().intValue()); } diff --git a/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala b/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala index c0735f448d19..d7d9dc7b50f3 100644 --- a/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala @@ -66,7 +66,7 @@ class CacheManagerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar // in blockManager.put is a losing battle. You have been warned. blockManager = sc.env.blockManager cacheManager = sc.env.cacheManager - val context = new TaskContextImpl(0, 0, 0) + val context = new TaskContextImpl(0, 0, 0, 0) val computeValue = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY) val getValue = blockManager.get(RDDBlockId(rdd.id, split.index)) assert(computeValue.toList === List(1, 2, 3, 4)) @@ -81,7 +81,7 @@ class CacheManagerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar } whenExecuting(blockManager) { - val context = new TaskContextImpl(0, 0, 0) + val context = new TaskContextImpl(0, 0, 0, 0) val value = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY) assert(value.toList === List(5, 6, 7)) } @@ -94,7 +94,7 @@ class CacheManagerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar } whenExecuting(blockManager) { - val context = new TaskContextImpl(0, 0, 0, true) + val context = new TaskContextImpl(0, 0, 0, 0, true) val value = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY) assert(value.toList === List(1, 2, 3, 4)) } @@ -102,7 +102,7 @@ class CacheManagerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar test("verify task metrics updated correctly") { cacheManager = sc.env.cacheManager - val context = new TaskContextImpl(0, 0, 0) + val context = new TaskContextImpl(0, 0, 0, 0) cacheManager.getOrCompute(rdd3, split, context, StorageLevel.MEMORY_ONLY) assert(context.taskMetrics.updatedBlocks.getOrElse(Seq()).size === 2) } diff --git a/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala b/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala index f8bcde12a371..10a39990f80c 100644 --- a/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala +++ b/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala @@ -17,66 +17,185 @@ package org.apache.spark.metrics -import java.io.{FileWriter, PrintWriter, File} +import java.io.{File, FileWriter, PrintWriter} -import org.apache.spark.SharedSparkContext -import org.apache.spark.deploy.SparkHadoopUtil -import org.apache.spark.scheduler.{SparkListenerTaskEnd, SparkListener} +import scala.collection.mutable.ArrayBuffer import org.scalatest.FunSuite -import org.scalatest.Matchers import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.{Path, FileSystem} +import org.apache.hadoop.fs.{FileSystem, Path} +import org.apache.hadoop.io.{LongWritable, Text} +import org.apache.hadoop.mapreduce.lib.input.{TextInputFormat => NewTextInputFormat} -import scala.collection.mutable.ArrayBuffer +import org.apache.spark.SharedSparkContext +import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd} +import org.apache.spark.util.Utils + +class InputOutputMetricsSuite extends FunSuite with SharedSparkContext { -class InputOutputMetricsSuite extends FunSuite with SharedSparkContext with Matchers { - test("input metrics when reading text file with single split") { - val file = new File(getClass.getSimpleName + ".txt") - val pw = new PrintWriter(new FileWriter(file)) - pw.println("some stuff") - pw.println("some other stuff") - pw.println("yet more stuff") - pw.println("too much stuff") + @transient var tmpDir: File = _ + @transient var tmpFile: File = _ + @transient var tmpFilePath: String = _ + + override def beforeAll() { + super.beforeAll() + + tmpDir = Utils.createTempDir() + val testTempDir = new File(tmpDir, "test") + testTempDir.mkdir() + + tmpFile = new File(testTempDir, getClass.getSimpleName + ".txt") + val pw = new PrintWriter(new FileWriter(tmpFile)) + for (x <- 1 to 1000000) { + pw.println("s") + } pw.close() - file.deleteOnExit() - val taskBytesRead = new ArrayBuffer[Long]() - sc.addSparkListener(new SparkListener() { - override def onTaskEnd(taskEnd: SparkListenerTaskEnd) { - taskBytesRead += taskEnd.taskMetrics.inputMetrics.get.bytesRead - } - }) - sc.textFile("file://" + file.getAbsolutePath, 2).count() + // Path to tmpFile + tmpFilePath = "file://" + tmpFile.getAbsolutePath + } - // Wait for task end events to come in - sc.listenerBus.waitUntilEmpty(500) - assert(taskBytesRead.length == 2) - assert(taskBytesRead.sum >= file.length()) + override def afterAll() { + super.afterAll() + Utils.deleteRecursively(tmpDir) } - test("input metrics when reading text file with multiple splits") { - val file = new File(getClass.getSimpleName + ".txt") - val pw = new PrintWriter(new FileWriter(file)) - for (i <- 0 until 10000) { - pw.println("some stuff") + test("input metrics for old hadoop with coalesce") { + val bytesRead = runAndReturnBytesRead { + sc.textFile(tmpFilePath, 4).count() + } + val bytesRead2 = runAndReturnBytesRead { + sc.textFile(tmpFilePath, 4).coalesce(2).count() + } + assert(bytesRead != 0) + assert(bytesRead == bytesRead2) + assert(bytesRead2 >= tmpFile.length()) + } + + test("input metrics with cache and coalesce") { + // prime the cache manager + val rdd = sc.textFile(tmpFilePath, 4).cache() + rdd.collect() + + val bytesRead = runAndReturnBytesRead { + rdd.count() + } + val bytesRead2 = runAndReturnBytesRead { + rdd.coalesce(4).count() } - pw.close() - file.deleteOnExit() + // for count and coelesce, the same bytes should be read. + assert(bytesRead != 0) + assert(bytesRead2 == bytesRead) + } + + /** + * This checks the situation where we have interleaved reads from + * different sources. Currently, we only accumulate fron the first + * read method we find in the task. This test uses cartesian to create + * the interleaved reads. + * + * Once https://issues.apache.org/jira/browse/SPARK-5225 is fixed + * this test should break. + */ + test("input metrics with mixed read method") { + // prime the cache manager + val numPartitions = 2 + val rdd = sc.parallelize(1 to 100, numPartitions).cache() + rdd.collect() + + val rdd2 = sc.textFile(tmpFilePath, numPartitions) + + val bytesRead = runAndReturnBytesRead { + rdd.count() + } + val bytesRead2 = runAndReturnBytesRead { + rdd2.count() + } + + val cartRead = runAndReturnBytesRead { + rdd.cartesian(rdd2).count() + } + + assert(cartRead != 0) + assert(bytesRead != 0) + // We read from the first rdd of the cartesian once per partition. + assert(cartRead == bytesRead * numPartitions) + } + + test("input metrics for new Hadoop API with coalesce") { + val bytesRead = runAndReturnBytesRead { + sc.newAPIHadoopFile(tmpFilePath, classOf[NewTextInputFormat], classOf[LongWritable], + classOf[Text]).count() + } + val bytesRead2 = runAndReturnBytesRead { + sc.newAPIHadoopFile(tmpFilePath, classOf[NewTextInputFormat], classOf[LongWritable], + classOf[Text]).coalesce(5).count() + } + assert(bytesRead != 0) + assert(bytesRead2 == bytesRead) + assert(bytesRead >= tmpFile.length()) + } + + test("input metrics when reading text file") { + val bytesRead = runAndReturnBytesRead { + sc.textFile(tmpFilePath, 2).count() + } + assert(bytesRead >= tmpFile.length()) + } + + test("input metrics with interleaved reads") { + val numPartitions = 2 + val cartVector = 0 to 9 + val cartFile = new File(tmpDir, getClass.getSimpleName + "_cart.txt") + val cartFilePath = "file://" + cartFile.getAbsolutePath + + // write files to disk so we can read them later. + sc.parallelize(cartVector).saveAsTextFile(cartFilePath) + val aRdd = sc.textFile(cartFilePath, numPartitions) + + val tmpRdd = sc.textFile(tmpFilePath, numPartitions) + + val firstSize= runAndReturnBytesRead { + aRdd.count() + } + val secondSize = runAndReturnBytesRead { + tmpRdd.count() + } + + val cartesianBytes = runAndReturnBytesRead { + aRdd.cartesian(tmpRdd).count() + } + + // Computing the amount of bytes read for a cartesian operation is a little involved. + // Cartesian interleaves reads between two partitions eg. p1 and p2. + // Here are the steps: + // 1) First it creates an iterator for p1 + // 2) Creates an iterator for p2 + // 3) Reads the first element of p1 and then all the elements of p2 + // 4) proceeds to the next element of p1 + // 5) Creates a new iterator for p2 + // 6) rinse and repeat. + // As a result we read from the second partition n times where n is the number of keys in + // p1. Thus the math below for the test. + assert(cartesianBytes != 0) + assert(cartesianBytes == firstSize * numPartitions + (cartVector.length * secondSize)) + } + + private def runAndReturnBytesRead(job : => Unit): Long = { val taskBytesRead = new ArrayBuffer[Long]() sc.addSparkListener(new SparkListener() { override def onTaskEnd(taskEnd: SparkListenerTaskEnd) { taskBytesRead += taskEnd.taskMetrics.inputMetrics.get.bytesRead } }) - sc.textFile("file://" + file.getAbsolutePath, 2).count() - // Wait for task end events to come in + job + sc.listenerBus.waitUntilEmpty(500) - assert(taskBytesRead.length == 2) - assert(taskBytesRead.sum >= file.length()) + taskBytesRead.sum } test("output metrics when writing text file") { diff --git a/core/src/test/scala/org/apache/spark/rdd/ParallelCollectionSplitSuite.scala b/core/src/test/scala/org/apache/spark/rdd/ParallelCollectionSplitSuite.scala index 1b112f1a41ca..cd193ae4f523 100644 --- a/core/src/test/scala/org/apache/spark/rdd/ParallelCollectionSplitSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/ParallelCollectionSplitSuite.scala @@ -76,6 +76,7 @@ class ParallelCollectionSplitSuite extends FunSuite with Checkers { assert(slices(0).mkString(",") === (0 to 32).mkString(",")) assert(slices(1).mkString(",") === (33 to 66).mkString(",")) assert(slices(2).mkString(",") === (67 to 100).mkString(",")) + assert(slices(2).isInstanceOf[Range.Inclusive]) } test("empty data") { @@ -227,4 +228,28 @@ class ParallelCollectionSplitSuite extends FunSuite with Checkers { assert(slices.map(_.size).reduceLeft(_+_) === 100) assert(slices.forall(_.isInstanceOf[NumericRange[_]])) } + + test("inclusive ranges with Int.MaxValue and Int.MinValue") { + val data1 = 1 to Int.MaxValue + val slices1 = ParallelCollectionRDD.slice(data1, 3) + assert(slices1.size === 3) + assert(slices1.map(_.size).sum === Int.MaxValue) + assert(slices1(2).isInstanceOf[Range.Inclusive]) + val data2 = -2 to Int.MinValue by -1 + val slices2 = ParallelCollectionRDD.slice(data2, 3) + assert(slices2.size == 3) + assert(slices2.map(_.size).sum === Int.MaxValue) + assert(slices2(2).isInstanceOf[Range.Inclusive]) + } + + test("empty ranges with Int.MaxValue and Int.MinValue") { + val data1 = Int.MaxValue until Int.MaxValue + val slices1 = ParallelCollectionRDD.slice(data1, 5) + assert(slices1.size === 5) + for (i <- 0 until 5) assert(slices1(i).size === 0) + val data2 = Int.MaxValue until Int.MaxValue + val slices2 = ParallelCollectionRDD.slice(data2, 5) + assert(slices2.size === 5) + for (i <- 0 until 5) assert(slices2(i).size === 0) + } } diff --git a/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala index 271a90c6646b..1a9a0e857e54 100644 --- a/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala @@ -174,7 +174,7 @@ class PipedRDDSuite extends FunSuite with SharedSparkContext { } val hadoopPart1 = generateFakeHadoopPartition() val pipedRdd = new PipedRDD(nums, "printenv " + varName) - val tContext = new TaskContextImpl(0, 0, 0) + val tContext = new TaskContextImpl(0, 0, 0, 0) val rddIter = pipedRdd.compute(hadoopPart1, tContext) val arr = rddIter.toArray assert(arr(0) == "/some/path") diff --git a/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala index 1de7e130039a..437d8693c0b1 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala @@ -160,7 +160,7 @@ class EventLoggingListenerSuite extends FunSuite with BeforeAndAfter with Loggin */ private def testApplicationEventLogging(compressionCodec: Option[String] = None) { val conf = getLoggingConf(testDirPath, compressionCodec) - val sc = new SparkContext("local", "test", conf) + val sc = new SparkContext("local-cluster[2,2,512]", "test", conf) assert(sc.eventLogger.isDefined) val eventLogger = sc.eventLogger.get val expectedLogDir = testDir.toURI().toString() @@ -184,6 +184,7 @@ class EventLoggingListenerSuite extends FunSuite with BeforeAndAfter with Loggin val eventSet = mutable.Set( SparkListenerApplicationStart, SparkListenerBlockManagerAdded, + SparkListenerExecutorAdded, SparkListenerEnvironmentUpdate, SparkListenerJobStart, SparkListenerJobEnd, diff --git a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala index 24f41bf8cccd..0fb1bdd30d97 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala @@ -34,6 +34,8 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with Matchers /** Length of time to wait while draining listener events. */ val WAIT_TIMEOUT_MILLIS = 10000 + val jobCompletionTime = 1421191296660L + before { sc = new SparkContext("local", "SparkListenerSuite") } @@ -44,7 +46,7 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with Matchers bus.addListener(counter) // Listener bus hasn't started yet, so posting events should not increment counter - (1 to 5).foreach { _ => bus.post(SparkListenerJobEnd(0, JobSucceeded)) } + (1 to 5).foreach { _ => bus.post(SparkListenerJobEnd(0, jobCompletionTime, JobSucceeded)) } assert(counter.count === 0) // Starting listener bus should flush all buffered events @@ -54,7 +56,7 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with Matchers // After listener bus has stopped, posting events should not increment counter bus.stop() - (1 to 5).foreach { _ => bus.post(SparkListenerJobEnd(0, JobSucceeded)) } + (1 to 5).foreach { _ => bus.post(SparkListenerJobEnd(0, jobCompletionTime, JobSucceeded)) } assert(counter.count === 5) // Listener bus must not be started twice @@ -99,7 +101,7 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with Matchers bus.addListener(blockingListener) bus.start() - bus.post(SparkListenerJobEnd(0, JobSucceeded)) + bus.post(SparkListenerJobEnd(0, jobCompletionTime, JobSucceeded)) listenerStarted.acquire() // Listener should be blocked after start @@ -345,7 +347,7 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with Matchers bus.start() // Post events to all listeners, and wait until the queue is drained - (1 to 5).foreach { _ => bus.post(SparkListenerJobEnd(0, JobSucceeded)) } + (1 to 5).foreach { _ => bus.post(SparkListenerJobEnd(0, jobCompletionTime, JobSucceeded)) } assert(bus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)) // The exception should be caught, and the event should be propagated to other listeners diff --git a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerWithClusterSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerWithClusterSuite.scala new file mode 100644 index 000000000000..623a687c359a --- /dev/null +++ b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerWithClusterSuite.scala @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler + +import org.apache.spark.scheduler.cluster.ExecutorInfo +import org.apache.spark.{SparkContext, LocalSparkContext} + +import org.scalatest.{FunSuite, BeforeAndAfter, BeforeAndAfterAll} + +import scala.collection.mutable + +/** + * Unit tests for SparkListener that require a local cluster. + */ +class SparkListenerWithClusterSuite extends FunSuite with LocalSparkContext + with BeforeAndAfter with BeforeAndAfterAll { + + /** Length of time to wait while draining listener events. */ + val WAIT_TIMEOUT_MILLIS = 10000 + + before { + sc = new SparkContext("local-cluster[2,1,512]", "SparkListenerSuite") + } + + test("SparkListener sends executor added message") { + val listener = new SaveExecutorInfo + sc.addSparkListener(listener) + + val rdd1 = sc.parallelize(1 to 100, 4) + val rdd2 = rdd1.map(_.toString) + rdd2.setName("Target RDD") + rdd2.count() + + assert(sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)) + assert(listener.addedExecutorInfo.size == 2) + assert(listener.addedExecutorInfo("0").totalCores == 1) + assert(listener.addedExecutorInfo("1").totalCores == 1) + } + + private class SaveExecutorInfo extends SparkListener { + val addedExecutorInfo = mutable.Map[String, ExecutorInfo]() + + override def onExecutorAdded(executor: SparkListenerExecutorAdded) { + addedExecutorInfo(executor.executorId) = executor.executorInfo + } + } +} diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala index 561a5e9cd90c..057e22691602 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala @@ -45,13 +45,13 @@ class TaskContextSuite extends FunSuite with BeforeAndAfter with LocalSparkConte val task = new ResultTask[String, String]( 0, sc.broadcast(closureSerializer.serialize((rdd, func)).array), rdd.partitions(0), Seq(), 0) intercept[RuntimeException] { - task.run(0) + task.run(0, 0) } assert(TaskContextSuite.completed === true) } test("all TaskCompletionListeners should be called even if some fail") { - val context = new TaskContextImpl(0, 0, 0) + val context = new TaskContextImpl(0, 0, 0, 0) val listener = mock(classOf[TaskCompletionListener]) context.addTaskCompletionListener(_ => throw new Exception("blah")) context.addTaskCompletionListener(listener) @@ -63,6 +63,33 @@ class TaskContextSuite extends FunSuite with BeforeAndAfter with LocalSparkConte verify(listener, times(1)).onTaskCompletion(any()) } + + test("TaskContext.attemptNumber should return attempt number, not task id (SPARK-4014)") { + sc = new SparkContext("local[1,2]", "test") // use maxRetries = 2 because we test failed tasks + // Check that attemptIds are 0 for all tasks' initial attempts + val attemptIds = sc.parallelize(Seq(1, 2), 2).mapPartitions { iter => + Seq(TaskContext.get().attemptNumber).iterator + }.collect() + assert(attemptIds.toSet === Set(0)) + + // Test a job with failed tasks + val attemptIdsWithFailedTask = sc.parallelize(Seq(1, 2), 2).mapPartitions { iter => + val attemptId = TaskContext.get().attemptNumber + if (iter.next() == 1 && attemptId == 0) { + throw new Exception("First execution of task failed") + } + Seq(attemptId).iterator + }.collect() + assert(attemptIdsWithFailedTask.toSet === Set(0, 1)) + } + + test("TaskContext.attemptId returns taskAttemptId for backwards-compatibility (SPARK-4014)") { + sc = new SparkContext("local", "test") + val attemptIds = sc.parallelize(Seq(1, 2, 3, 4), 4).mapPartitions { iter => + Seq(TaskContext.get().attemptId).iterator + }.collect() + assert(attemptIds.toSet === Set(0, 1, 2, 3)) + } } private object TaskContextSuite { diff --git a/core/src/test/scala/org/apache/spark/scheduler/mesos/MesosSchedulerBackendSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/mesos/MesosSchedulerBackendSuite.scala index e60e70afd321..78a30a40bf19 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/mesos/MesosSchedulerBackendSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/mesos/MesosSchedulerBackendSuite.scala @@ -18,17 +18,20 @@ package org.apache.spark.scheduler.mesos import org.scalatest.FunSuite -import org.apache.spark.{scheduler, SparkConf, SparkContext, LocalSparkContext} -import org.apache.spark.scheduler.{TaskDescription, WorkerOffer, TaskSchedulerImpl} +import org.apache.spark.{SparkConf, SparkContext, LocalSparkContext} +import org.apache.spark.scheduler.{SparkListenerExecutorAdded, LiveListenerBus, + TaskDescription, WorkerOffer, TaskSchedulerImpl} +import org.apache.spark.scheduler.cluster.ExecutorInfo import org.apache.spark.scheduler.cluster.mesos.{MemoryUtils, MesosSchedulerBackend} import org.apache.mesos.SchedulerDriver -import org.apache.mesos.Protos._ -import org.scalatest.mock.EasyMockSugar +import org.apache.mesos.Protos.{ExecutorInfo => MesosExecutorInfo, _} import org.apache.mesos.Protos.Value.Scalar import org.easymock.{Capture, EasyMock} import java.nio.ByteBuffer import java.util.Collections import java.util +import org.scalatest.mock.EasyMockSugar + import scala.collection.mutable import scala.collection.mutable.ArrayBuffer @@ -52,11 +55,16 @@ class MesosSchedulerBackendSuite extends FunSuite with LocalSparkContext with Ea val driver = EasyMock.createMock(classOf[SchedulerDriver]) val taskScheduler = EasyMock.createMock(classOf[TaskSchedulerImpl]) + val listenerBus = EasyMock.createMock(classOf[LiveListenerBus]) + listenerBus.post(SparkListenerExecutorAdded("s1", new ExecutorInfo("host1", 2))) + EasyMock.replay(listenerBus) + val sc = EasyMock.createMock(classOf[SparkContext]) EasyMock.expect(sc.executorMemory).andReturn(100).anyTimes() EasyMock.expect(sc.getSparkHome()).andReturn(Option("/path")).anyTimes() EasyMock.expect(sc.executorEnvs).andReturn(new mutable.HashMap).anyTimes() EasyMock.expect(sc.conf).andReturn(new SparkConf).anyTimes() + EasyMock.expect(sc.listenerBus).andReturn(listenerBus) EasyMock.replay(sc) val minMem = MemoryUtils.calculateTotalMemory(sc).toInt @@ -80,7 +88,7 @@ class MesosSchedulerBackendSuite extends FunSuite with LocalSparkContext with Ea mesosOffers.get(2).getHostname, 2 )) - val taskDesc = new TaskDescription(1L, "s1", "n1", 0, ByteBuffer.wrap(new Array[Byte](0))) + val taskDesc = new TaskDescription(1L, 0, "s1", "n1", 0, ByteBuffer.wrap(new Array[Byte](0))) EasyMock.expect(taskScheduler.resourceOffers(EasyMock.eq(expectedWorkerOffers))).andReturn(Seq(Seq(taskDesc))) EasyMock.expect(taskScheduler.CPUS_PER_TASK).andReturn(2).anyTimes() EasyMock.replay(taskScheduler) diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala index 1eaabb93adbe..37b593b2c5f7 100644 --- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala @@ -89,7 +89,7 @@ class ShuffleBlockFetcherIteratorSuite extends FunSuite { ) val iterator = new ShuffleBlockFetcherIterator( - new TaskContextImpl(0, 0, 0), + new TaskContextImpl(0, 0, 0, 0), transfer, blockManager, blocksByAddress, @@ -154,7 +154,7 @@ class ShuffleBlockFetcherIteratorSuite extends FunSuite { val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( (remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq)) - val taskContext = new TaskContextImpl(0, 0, 0) + val taskContext = new TaskContextImpl(0, 0, 0, 0) val iterator = new ShuffleBlockFetcherIterator( taskContext, transfer, @@ -217,7 +217,7 @@ class ShuffleBlockFetcherIteratorSuite extends FunSuite { val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( (remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq)) - val taskContext = new TaskContextImpl(0, 0, 0) + val taskContext = new TaskContextImpl(0, 0, 0, 0) val iterator = new ShuffleBlockFetcherIterator( taskContext, transfer, diff --git a/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala b/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala index 787f4c2b5a8b..e85a436cdba1 100644 --- a/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala @@ -173,7 +173,7 @@ class UISeleniumSuite extends FunSuite with WebBrowser with Matchers { // Simulate fetch failures: val mappedData = data.map { x => val taskContext = TaskContext.get - if (taskContext.attemptId() == 1) { // Cause this stage to fail on its first attempt. + if (taskContext.attemptNumber == 0) { // Cause this stage to fail on its first attempt. val env = SparkEnv.get val bmAddress = env.blockManager.blockManagerId val shuffleId = shuffleHandle.shuffleId diff --git a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala index 12af60caf7d5..c9417ea1ed8f 100644 --- a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala @@ -28,6 +28,8 @@ import org.apache.spark.util.Utils class JobProgressListenerSuite extends FunSuite with LocalSparkContext with Matchers { + val jobSubmissionTime = 1421191042750L + val jobCompletionTime = 1421191296660L private def createStageStartEvent(stageId: Int) = { val stageInfo = new StageInfo(stageId, 0, stageId.toString, 0, null, "") @@ -46,12 +48,12 @@ class JobProgressListenerSuite extends FunSuite with LocalSparkContext with Matc val stageInfos = stageIds.map { stageId => new StageInfo(stageId, 0, stageId.toString, 0, null, "") } - SparkListenerJobStart(jobId, stageInfos) + SparkListenerJobStart(jobId, jobSubmissionTime, stageInfos) } private def createJobEndEvent(jobId: Int, failed: Boolean = false) = { val result = if (failed) JobFailed(new Exception("dummy failure")) else JobSucceeded - SparkListenerJobEnd(jobId, result) + SparkListenerJobEnd(jobId, jobCompletionTime, result) } private def runJob(listener: SparkListener, jobId: Int, shouldFail: Boolean = false) { @@ -231,8 +233,8 @@ class JobProgressListenerSuite extends FunSuite with LocalSparkContext with Matc taskMetrics.diskBytesSpilled = base + 5 taskMetrics.memoryBytesSpilled = base + 6 val inputMetrics = new InputMetrics(DataReadMethod.Hadoop) - taskMetrics.inputMetrics = Some(inputMetrics) - inputMetrics.bytesRead = base + 7 + taskMetrics.setInputMetrics(Some(inputMetrics)) + inputMetrics.addBytesRead(base + 7) val outputMetrics = new OutputMetrics(DataWriteMethod.Hadoop) taskMetrics.outputMetrics = Some(outputMetrics) outputMetrics.bytesWritten = base + 8 diff --git a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala index 63c2559c5c5f..db400b416291 100644 --- a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.util import java.util.Properties +import org.apache.spark.scheduler.cluster.ExecutorInfo import org.apache.spark.shuffle.MetadataFetchFailedException import scala.collection.Map @@ -33,6 +34,9 @@ import org.apache.spark.storage._ class JsonProtocolSuite extends FunSuite { + val jobSubmissionTime = 1421191042750L + val jobCompletionTime = 1421191296660L + test("SparkListenerEvent") { val stageSubmitted = SparkListenerStageSubmitted(makeStageInfo(100, 200, 300, 400L, 500L), properties) @@ -53,9 +57,9 @@ class JsonProtocolSuite extends FunSuite { val stageIds = Seq[Int](1, 2, 3, 4) val stageInfos = stageIds.map(x => makeStageInfo(x, x * 200, x * 300, x * 400L, x * 500L)) - SparkListenerJobStart(10, stageInfos, properties) + SparkListenerJobStart(10, jobSubmissionTime, stageInfos, properties) } - val jobEnd = SparkListenerJobEnd(20, JobSucceeded) + val jobEnd = SparkListenerJobEnd(20, jobCompletionTime, JobSucceeded) val environmentUpdate = SparkListenerEnvironmentUpdate(Map[String, Seq[(String, String)]]( "JVM Information" -> Seq(("GC speed", "9999 objects/s"), ("Java home", "Land of coffee")), "Spark Properties" -> Seq(("Job throughput", "80000 jobs/s, regardless of job type")), @@ -69,6 +73,9 @@ class JsonProtocolSuite extends FunSuite { val unpersistRdd = SparkListenerUnpersistRDD(12345) val applicationStart = SparkListenerApplicationStart("The winner of all", None, 42L, "Garfield") val applicationEnd = SparkListenerApplicationEnd(42L) + val executorAdded = SparkListenerExecutorAdded("exec1", + new ExecutorInfo("Hostee.awesome.com", 11)) + val executorRemoved = SparkListenerExecutorRemoved("exec2") testEvent(stageSubmitted, stageSubmittedJsonString) testEvent(stageCompleted, stageCompletedJsonString) @@ -85,6 +92,8 @@ class JsonProtocolSuite extends FunSuite { testEvent(unpersistRdd, unpersistRDDJsonString) testEvent(applicationStart, applicationStartJsonString) testEvent(applicationEnd, applicationEndJsonString) + testEvent(executorAdded, executorAddedJsonString) + testEvent(executorRemoved, executorRemovedJsonString) } test("Dependent Classes") { @@ -94,6 +103,7 @@ class JsonProtocolSuite extends FunSuite { testTaskMetrics(makeTaskMetrics( 33333L, 44444L, 55555L, 66666L, 7, 8, hasHadoopInput = false, hasOutput = false)) testBlockManagerId(BlockManagerId("Hong", "Kong", 500)) + testExecutorInfo(new ExecutorInfo("host", 43)) // StorageLevel testStorageLevel(StorageLevel.NONE) @@ -240,13 +250,31 @@ class JsonProtocolSuite extends FunSuite { val stageInfos = stageIds.map(x => makeStageInfo(x, x * 200, x * 300, x * 400, x * 500)) val dummyStageInfos = stageIds.map(id => new StageInfo(id, 0, "unknown", 0, Seq.empty, "unknown")) - val jobStart = SparkListenerJobStart(10, stageInfos, properties) + val jobStart = SparkListenerJobStart(10, jobSubmissionTime, stageInfos, properties) val oldEvent = JsonProtocol.jobStartToJson(jobStart).removeField({_._1 == "Stage Infos"}) val expectedJobStart = - SparkListenerJobStart(10, dummyStageInfos, properties) + SparkListenerJobStart(10, jobSubmissionTime, dummyStageInfos, properties) assertEquals(expectedJobStart, JsonProtocol.jobStartFromJson(oldEvent)) } + test("SparkListenerJobStart and SparkListenerJobEnd backward compatibility") { + // Prior to Spark 1.3.0, SparkListenerJobStart did not have a "Submission Time" property. + // Also, SparkListenerJobEnd did not have a "Completion Time" property. + val stageIds = Seq[Int](1, 2, 3, 4) + val stageInfos = stageIds.map(x => makeStageInfo(x * 10, x * 20, x * 30, x * 40, x * 50)) + val jobStart = SparkListenerJobStart(11, jobSubmissionTime, stageInfos, properties) + val oldStartEvent = JsonProtocol.jobStartToJson(jobStart) + .removeField({ _._1 == "Submission Time"}) + val expectedJobStart = SparkListenerJobStart(11, -1, stageInfos, properties) + assertEquals(expectedJobStart, JsonProtocol.jobStartFromJson(oldStartEvent)) + + val jobEnd = SparkListenerJobEnd(11, jobCompletionTime, JobSucceeded) + val oldEndEvent = JsonProtocol.jobEndToJson(jobEnd) + .removeField({ _._1 == "Completion Time"}) + val expectedJobEnd = SparkListenerJobEnd(11, -1, JobSucceeded) + assertEquals(expectedJobEnd, JsonProtocol.jobEndFromJson(oldEndEvent)) + } + /** -------------------------- * | Helper test running methods | * --------------------------- */ @@ -303,6 +331,10 @@ class JsonProtocolSuite extends FunSuite { assert(blockId === newBlockId) } + private def testExecutorInfo(info: ExecutorInfo) { + val newInfo = JsonProtocol.executorInfoFromJson(JsonProtocol.executorInfoToJson(info)) + assertEquals(info, newInfo) + } /** -------------------------------- * | Util methods for comparing events | @@ -335,6 +367,11 @@ class JsonProtocolSuite extends FunSuite { assertEquals(e1.jobResult, e2.jobResult) case (e1: SparkListenerEnvironmentUpdate, e2: SparkListenerEnvironmentUpdate) => assertEquals(e1.environmentDetails, e2.environmentDetails) + case (e1: SparkListenerExecutorAdded, e2: SparkListenerExecutorAdded) => + assert(e1.executorId == e1.executorId) + assertEquals(e1.executorInfo, e2.executorInfo) + case (e1: SparkListenerExecutorRemoved, e2: SparkListenerExecutorRemoved) => + assert(e1.executorId == e1.executorId) case (e1, e2) => assert(e1 === e2) case _ => fail("Events don't match in types!") @@ -387,6 +424,11 @@ class JsonProtocolSuite extends FunSuite { assert(info1.accumulables === info2.accumulables) } + private def assertEquals(info1: ExecutorInfo, info2: ExecutorInfo) { + assert(info1.executorHost == info2.executorHost) + assert(info1.totalCores == info2.totalCores) + } + private def assertEquals(metrics1: TaskMetrics, metrics2: TaskMetrics) { assert(metrics1.hostname === metrics2.hostname) assert(metrics1.executorDeserializeTime === metrics2.executorDeserializeTime) @@ -609,8 +651,8 @@ class JsonProtocolSuite extends FunSuite { if (hasHadoopInput) { val inputMetrics = new InputMetrics(DataReadMethod.Hadoop) - inputMetrics.bytesRead = d + e + f - t.inputMetrics = Some(inputMetrics) + inputMetrics.addBytesRead(d + e + f) + t.setInputMetrics(Some(inputMetrics)) } else { val sr = new ShuffleReadMetrics sr.remoteBytesRead = b + d @@ -1054,6 +1096,7 @@ class JsonProtocolSuite extends FunSuite { |{ | "Event": "SparkListenerJobStart", | "Job ID": 10, + | "Submission Time": 1421191042750, | "Stage Infos": [ | { | "Stage ID": 1, @@ -1328,6 +1371,7 @@ class JsonProtocolSuite extends FunSuite { |{ | "Event": "SparkListenerJobEnd", | "Job ID": 20, + | "Completion Time": 1421191296660, | "Job Result": { | "Result": "JobSucceeded" | } @@ -1407,4 +1451,24 @@ class JsonProtocolSuite extends FunSuite { | "Timestamp": 42 |} """ + + private val executorAddedJsonString = + """ + |{ + | "Event": "SparkListenerExecutorAdded", + | "Executor ID": "exec1", + | "Executor Info": { + | "Host": "Hostee.awesome.com", + | "Total Cores": 11 + | } + |} + """ + + private val executorRemovedJsonString = + """ + |{ + | "Event": "SparkListenerExecutorRemoved", + | "Executor ID": "exec2" + |} + """ } diff --git a/docs/configuration.md b/docs/configuration.md index f292bfbb7dcd..efbab4085317 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -102,11 +102,10 @@ of the most common options to set are: - spark.executor.memory - 512m + spark.driver.cores + 1 - Amount of memory to use per executor process, in the same format as JVM memory strings - (e.g. 512m, 2g). + Number of cores to use for the driver process, only in cluster mode. @@ -117,6 +116,14 @@ of the most common options to set are: (e.g. 512m, 2g). + + spark.executor.memory + 512m + + Amount of memory to use per executor process, in the same format as JVM memory strings + (e.g. 512m, 2g). + + spark.driver.maxResultSize 1g @@ -1228,7 +1235,7 @@ Apart from these, the following properties are also available, and may be useful - spark.streaming.receiver.writeAheadLogs.enable + spark.streaming.receiver.writeAheadLog.enable false Enable write ahead logs for receivers. All the input data received through receivers diff --git a/docs/ml-guide.md b/docs/ml-guide.md index 1c2e27341473..be178d7689fd 100644 --- a/docs/ml-guide.md +++ b/docs/ml-guide.md @@ -3,13 +3,16 @@ layout: global title: Spark ML Programming Guide --- -Spark ML is Spark's new machine learning package. It is currently an alpha component but is potentially a successor to [MLlib](mllib-guide.html). The `spark.ml` package aims to replace the old APIs with a cleaner, more uniform set of APIs which will help users create full machine learning pipelines. +`spark.ml` is a new package introduced in Spark 1.2, which aims to provide a uniform set of +high-level APIs that help users create and tune practical machine learning pipelines. +It is currently an alpha component, and we would like to hear back from the community about +how it fits real-world use cases and how it could be improved. -MLlib vs. Spark ML: - -* Users can use algorithms from either of the two packages, but APIs may differ. Currently, `spark.ml` offers a subset of the algorithms from `spark.mllib`. Since Spark ML is an alpha component, its API may change in future releases. -* Developers should contribute new algorithms to `spark.mllib` and can optionally contribute to `spark.ml`. See below for more details. -* Spark ML only has Scala and Java APIs, whereas MLlib also has a Python API. +Note that we will keep supporting and adding features to `spark.mllib` along with the +development of `spark.ml`. +Users should be comfortable using `spark.mllib` features and expect more features coming. +Developers should contribute new algorithms to `spark.mllib` and can optionally contribute +to `spark.ml`. **Table of Contents** @@ -686,17 +689,3 @@ Spark ML currently depends on MLlib and has the same dependencies. Please see the [MLlib Dependencies guide](mllib-guide.html#Dependencies) for more info. Spark ML also depends upon Spark SQL, but the relevant parts of Spark SQL do not bring additional dependencies. - -# Developers - -**Development plan** - -If all goes well, `spark.ml` will become the primary ML package at the time of the Spark 1.3 release. Initially, simple wrappers will be used to port algorithms to `spark.ml`, but eventually, code will be moved to `spark.ml` and `spark.mllib` will be deprecated. - -**Advice to developers** - -During the next development cycle, new algorithms should be contributed to `spark.mllib`, but we welcome patches sent to either package. If an algorithm is best expressed using the new API (e.g., feature transformers), we may ask for developers to use the new `spark.ml` API. -Wrappers for old and new algorithms can be contributed to `spark.ml`. - -Users will be able to use algorithms from either of the two packages. The main difficulty will be the differences in APIs between the two packages. - diff --git a/docs/mllib-guide.md b/docs/mllib-guide.md index efd7dda31071..39c64d06926b 100644 --- a/docs/mllib-guide.md +++ b/docs/mllib-guide.md @@ -35,16 +35,20 @@ MLlib is under active development. The APIs marked `Experimental`/`DeveloperApi` may change in future releases, and the migration guide below will explain all changes between releases. -# spark.ml: The New ML Package +# spark.ml: high-level APIs for ML pipelines -Spark 1.2 includes a new machine learning package called `spark.ml`, currently an alpha component but potentially a successor to `spark.mllib`. The `spark.ml` package aims to replace the old APIs with a cleaner, more uniform set of APIs which will help users create full machine learning pipelines. +Spark 1.2 includes a new package called `spark.ml`, which aims to provide a uniform set of +high-level APIs that help users create and tune practical machine learning pipelines. +It is currently an alpha component, and we would like to hear back from the community about +how it fits real-world use cases and how it could be improved. -See the **[spark.ml programming guide](ml-guide.html)** for more information on this package. - -Users can use algorithms from either of the two packages, but APIs may differ. Currently, `spark.ml` offers a subset of the algorithms from `spark.mllib`. +Note that we will keep supporting and adding features to `spark.mllib` along with the +development of `spark.ml`. +Users should be comfortable using `spark.mllib` features and expect more features coming. +Developers should contribute new algorithms to `spark.mllib` and can optionally contribute +to `spark.ml`. -Developers should contribute new algorithms to `spark.mllib` and can optionally contribute to `spark.ml`. -See the `spark.ml` programming guide linked above for more details. +See the **[spark.ml programming guide](ml-guide.html)** for more information on this package. # Dependencies diff --git a/docs/programming-guide.md b/docs/programming-guide.md index 5e0d5c15d706..2443fc29b470 100644 --- a/docs/programming-guide.md +++ b/docs/programming-guide.md @@ -913,7 +913,7 @@ for details. cogroup(otherDataset, [numTasks]) - When called on datasets of type (K, V) and (K, W), returns a dataset of (K, Iterable<V>, Iterable<W>) tuples. This operation is also called groupWith. + When called on datasets of type (K, V) and (K, W), returns a dataset of (K, (Iterable<V>, Iterable<W>)) tuples. This operation is also called groupWith. cartesian(otherDataset) @@ -1316,7 +1316,35 @@ For accumulator updates performed inside actions only, Spark guarantees t will only be applied once, i.e. restarted tasks will not update the value. In transformations, users should be aware of that each task's update may be applied more than once if tasks or job stages are re-executed. +Accumulators do not change the lazy evaluation model of Spark. If they are being updated within an operation on an RDD, their value is only updated once that RDD is computed as part of an action. Consequently, accumulator updates are not guaranteed to be executed when made within a lazy transformation like `map()`. The below code fragment demonstrates this property: +
    + +
    +{% highlight scala %} +val acc = sc.accumulator(0) +data.map(x => acc += x; f(x)) +// Here, acc is still 0 because no actions have cause the `map` to be computed. +{% endhighlight %} +
    + +
    +{% highlight java %} +Accumulator accum = sc.accumulator(0); +data.map(x -> accum.add(x); f(x);); +// Here, accum is still 0 because no actions have cause the `map` to be computed. +{% endhighlight %} +
    + +
    +{% highlight python %} +accum = sc.accumulator(0) +data.map(lambda x => acc.add(x); f(x)) +# Here, acc is still 0 because no actions have cause the `map` to be computed. +{% endhighlight %} +
    + +
    # Deploying to a Cluster diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md index 4f273098c5db..68ab127bcf08 100644 --- a/docs/running-on-yarn.md +++ b/docs/running-on-yarn.md @@ -29,6 +29,23 @@ Most of the configs are the same for Spark on YARN as for other deployment modes In cluster mode, use spark.driver.memory instead. + + spark.driver.cores + 1 + + Number of cores used by the driver in YARN cluster mode. + Since the driver is run in the same JVM as the YARN Application Master in cluster mode, this also controls the cores used by the YARN AM. + In client mode, use spark.yarn.am.cores to control the number of cores used by the YARN AM instead. + + + + spark.yarn.am.cores + 1 + + Number of cores to use for the YARN Application Master in client mode. + In cluster mode, use spark.driver.cores instead. + + spark.yarn.am.waitTime 100000 diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 729045b81a8c..be8c5c2c1522 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1333,9 +1333,9 @@ import org.apache.spark.sql._
    All data types of Spark SQL are located in the package of -`org.apache.spark.sql.api.java`. To access or create a data type, +`org.apache.spark.sql.types`. To access or create a data type, please use factory methods provided in -`org.apache.spark.sql.api.java.DataType`. +`org.apache.spark.sql.types.DataTypes`. @@ -1346,109 +1346,110 @@ please use factory methods provided in @@ -1458,7 +1459,7 @@ please use factory methods provided in
    ByteType byte or Byte - DataType.ByteType + DataTypes.ByteType
    ShortType short or Short - DataType.ShortType + DataTypes.ShortType
    IntegerType int or Integer - DataType.IntegerType + DataTypes.IntegerType
    LongType long or Long - DataType.LongType + DataTypes.LongType
    FloatType float or Float - DataType.FloatType + DataTypes.FloatType
    DoubleType double or Double - DataType.DoubleType + DataTypes.DoubleType
    DecimalType java.math.BigDecimal - DataType.DecimalType + DataTypes.createDecimalType()
    + DataTypes.createDecimalType(precision, scale).
    StringType String - DataType.StringType + DataTypes.StringType
    BinaryType byte[] - DataType.BinaryType + DataTypes.BinaryType
    BooleanType boolean or Boolean - DataType.BooleanType + DataTypes.BooleanType
    TimestampType java.sql.Timestamp - DataType.TimestampType + DataTypes.TimestampType
    DateType java.sql.Date - DataType.DateType + DataTypes.DateType
    ArrayType java.util.List - DataType.createArrayType(elementType)
    + DataTypes.createArrayType(elementType)
    Note: The value of containsNull will be true
    - DataType.createArrayType(elementType, containsNull). + DataTypes.createArrayType(elementType, containsNull).
    MapType java.util.Map - DataType.createMapType(keyType, valueType)
    + DataTypes.createMapType(keyType, valueType)
    Note: The value of valueContainsNull will be true.
    - DataType.createMapType(keyType, valueType, valueContainsNull)
    + DataTypes.createMapType(keyType, valueType, valueContainsNull)
    StructType org.apache.spark.sql.api.java.Row - DataType.createStructType(fields)
    + DataTypes.createStructType(fields)
    Note: fields is a List or an array of StructFields. Also, two fields with the same name are not allowed.
    The value type in Java of the data type of this field (For example, int for a StructField with the data type IntegerType) - DataType.createStructField(name, dataType, nullable) + DataTypes.createStructField(name, dataType, nullable)
    diff --git a/docs/streaming-programming-guide.md b/docs/streaming-programming-guide.md index 01450efe35e5..e37a2bb37b9a 100644 --- a/docs/streaming-programming-guide.md +++ b/docs/streaming-programming-guide.md @@ -1574,7 +1574,7 @@ To run a Spark Streaming applications, you need to have the following. recovery, thus ensuring zero data loss (discussed in detail in the [Fault-tolerance Semantics](#fault-tolerance-semantics) section). This can be enabled by setting the [configuration parameter](configuration.html#spark-streaming) - `spark.streaming.receiver.writeAheadLogs.enable` to `true`. However, these stronger semantics may + `spark.streaming.receiver.writeAheadLog.enable` to `true`. However, these stronger semantics may come at the cost of the receiving throughput of individual receivers. This can be corrected by running [more receivers in parallel](#level-of-parallelism-in-data-receiving) to increase aggregate throughput. Additionally, it is recommended that the replication of the diff --git a/examples/pom.xml b/examples/pom.xml index 002d4458c4b3..4b92147725f6 100644 --- a/examples/pom.xml +++ b/examples/pom.xml @@ -392,29 +392,6 @@ - - hbase-hadoop2 - - - hbase.profile - hadoop2 - - - - 0.98.7-hadoop2 - - - - hbase-hadoop1 - - - !hbase.profile - - - - 0.98.7-hadoop1 - - diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaCrossValidatorExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaCrossValidatorExample.java index f4b4f8d8c7b2..247d2a5e31a8 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaCrossValidatorExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaCrossValidatorExample.java @@ -33,9 +33,9 @@ import org.apache.spark.ml.tuning.CrossValidator; import org.apache.spark.ml.tuning.CrossValidatorModel; import org.apache.spark.ml.tuning.ParamGridBuilder; -import org.apache.spark.sql.api.java.JavaSQLContext; -import org.apache.spark.sql.api.java.JavaSchemaRDD; -import org.apache.spark.sql.api.java.Row; +import org.apache.spark.sql.SchemaRDD; +import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.Row; /** * A simple example demonstrating model selection using CrossValidator. @@ -55,7 +55,7 @@ public class JavaCrossValidatorExample { public static void main(String[] args) { SparkConf conf = new SparkConf().setAppName("JavaCrossValidatorExample"); JavaSparkContext jsc = new JavaSparkContext(conf); - JavaSQLContext jsql = new JavaSQLContext(jsc); + SQLContext jsql = new SQLContext(jsc); // Prepare training documents, which are labeled. List localTraining = Lists.newArrayList( @@ -71,8 +71,7 @@ public static void main(String[] args) { new LabeledDocument(9L, "a e c l", 0.0), new LabeledDocument(10L, "spark compile", 1.0), new LabeledDocument(11L, "hadoop software", 0.0)); - JavaSchemaRDD training = - jsql.applySchema(jsc.parallelize(localTraining), LabeledDocument.class); + SchemaRDD training = jsql.applySchema(jsc.parallelize(localTraining), LabeledDocument.class); // Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr. Tokenizer tokenizer = new Tokenizer() @@ -113,11 +112,11 @@ public static void main(String[] args) { new Document(5L, "l m n"), new Document(6L, "mapreduce spark"), new Document(7L, "apache hadoop")); - JavaSchemaRDD test = jsql.applySchema(jsc.parallelize(localTest), Document.class); + SchemaRDD test = jsql.applySchema(jsc.parallelize(localTest), Document.class); // Make predictions on test documents. cvModel uses the best model found (lrModel). cvModel.transform(test).registerAsTable("prediction"); - JavaSchemaRDD predictions = jsql.sql("SELECT id, text, score, prediction FROM prediction"); + SchemaRDD predictions = jsql.sql("SELECT id, text, score, prediction FROM prediction"); for (Row r: predictions.collect()) { System.out.println("(" + r.get(0) + ", " + r.get(1) + ") --> score=" + r.get(2) + ", prediction=" + r.get(3)); diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java index e25b271777ed..5b92655e2e83 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java @@ -28,9 +28,9 @@ import org.apache.spark.ml.classification.LogisticRegression; import org.apache.spark.mllib.linalg.Vectors; import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.sql.api.java.JavaSQLContext; -import org.apache.spark.sql.api.java.JavaSchemaRDD; -import org.apache.spark.sql.api.java.Row; +import org.apache.spark.sql.SchemaRDD; +import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.Row; /** * A simple example demonstrating ways to specify parameters for Estimators and Transformers. @@ -44,7 +44,7 @@ public class JavaSimpleParamsExample { public static void main(String[] args) { SparkConf conf = new SparkConf().setAppName("JavaSimpleParamsExample"); JavaSparkContext jsc = new JavaSparkContext(conf); - JavaSQLContext jsql = new JavaSQLContext(jsc); + SQLContext jsql = new SQLContext(jsc); // Prepare training data. // We use LabeledPoint, which is a JavaBean. Spark SQL can convert RDDs of JavaBeans @@ -54,7 +54,7 @@ public static void main(String[] args) { new LabeledPoint(0.0, Vectors.dense(2.0, 1.0, -1.0)), new LabeledPoint(0.0, Vectors.dense(2.0, 1.3, 1.0)), new LabeledPoint(1.0, Vectors.dense(0.0, 1.2, -0.5))); - JavaSchemaRDD training = jsql.applySchema(jsc.parallelize(localTraining), LabeledPoint.class); + SchemaRDD training = jsql.applySchema(jsc.parallelize(localTraining), LabeledPoint.class); // Create a LogisticRegression instance. This instance is an Estimator. LogisticRegression lr = new LogisticRegression(); @@ -94,14 +94,14 @@ public static void main(String[] args) { new LabeledPoint(1.0, Vectors.dense(-1.0, 1.5, 1.3)), new LabeledPoint(0.0, Vectors.dense(3.0, 2.0, -0.1)), new LabeledPoint(1.0, Vectors.dense(0.0, 2.2, -1.5))); - JavaSchemaRDD test = jsql.applySchema(jsc.parallelize(localTest), LabeledPoint.class); + SchemaRDD test = jsql.applySchema(jsc.parallelize(localTest), LabeledPoint.class); // Make predictions on test documents using the Transformer.transform() method. // LogisticRegression.transform will only use the 'features' column. // Note that model2.transform() outputs a 'probability' column instead of the usual 'score' // column since we renamed the lr.scoreCol parameter previously. model2.transform(test).registerAsTable("results"); - JavaSchemaRDD results = + SchemaRDD results = jsql.sql("SELECT features, label, probability, prediction FROM results"); for (Row r: results.collect()) { System.out.println("(" + r.get(0) + ", " + r.get(1) + ") -> prob=" + r.get(2) diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleTextClassificationPipeline.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleTextClassificationPipeline.java index 54f18014e4b2..74db449fada7 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleTextClassificationPipeline.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleTextClassificationPipeline.java @@ -21,6 +21,7 @@ import com.google.common.collect.Lists; +import org.apache.spark.SparkConf; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.ml.Pipeline; import org.apache.spark.ml.PipelineModel; @@ -28,10 +29,9 @@ import org.apache.spark.ml.classification.LogisticRegression; import org.apache.spark.ml.feature.HashingTF; import org.apache.spark.ml.feature.Tokenizer; -import org.apache.spark.sql.api.java.JavaSQLContext; -import org.apache.spark.sql.api.java.JavaSchemaRDD; -import org.apache.spark.sql.api.java.Row; -import org.apache.spark.SparkConf; +import org.apache.spark.sql.SchemaRDD; +import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.Row; /** * A simple text classification pipeline that recognizes "spark" from input text. It uses the Java @@ -46,7 +46,7 @@ public class JavaSimpleTextClassificationPipeline { public static void main(String[] args) { SparkConf conf = new SparkConf().setAppName("JavaSimpleTextClassificationPipeline"); JavaSparkContext jsc = new JavaSparkContext(conf); - JavaSQLContext jsql = new JavaSQLContext(jsc); + SQLContext jsql = new SQLContext(jsc); // Prepare training documents, which are labeled. List localTraining = Lists.newArrayList( @@ -54,8 +54,7 @@ public static void main(String[] args) { new LabeledDocument(1L, "b d", 0.0), new LabeledDocument(2L, "spark f g h", 1.0), new LabeledDocument(3L, "hadoop mapreduce", 0.0)); - JavaSchemaRDD training = - jsql.applySchema(jsc.parallelize(localTraining), LabeledDocument.class); + SchemaRDD training = jsql.applySchema(jsc.parallelize(localTraining), LabeledDocument.class); // Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr. Tokenizer tokenizer = new Tokenizer() @@ -80,11 +79,11 @@ public static void main(String[] args) { new Document(5L, "l m n"), new Document(6L, "mapreduce spark"), new Document(7L, "apache hadoop")); - JavaSchemaRDD test = jsql.applySchema(jsc.parallelize(localTest), Document.class); + SchemaRDD test = jsql.applySchema(jsc.parallelize(localTest), Document.class); // Make predictions on test documents. model.transform(test).registerAsTable("prediction"); - JavaSchemaRDD predictions = jsql.sql("SELECT id, text, score, prediction FROM prediction"); + SchemaRDD predictions = jsql.sql("SELECT id, text, score, prediction FROM prediction"); for (Row r: predictions.collect()) { System.out.println("(" + r.get(0) + ", " + r.get(1) + ") --> score=" + r.get(2) + ", prediction=" + r.get(3)); diff --git a/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQL.java b/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQL.java index 01c77bd44337..b70804635d5c 100644 --- a/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQL.java +++ b/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQL.java @@ -26,9 +26,9 @@ import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.api.java.function.Function; -import org.apache.spark.sql.api.java.JavaSQLContext; -import org.apache.spark.sql.api.java.JavaSchemaRDD; -import org.apache.spark.sql.api.java.Row; +import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.SchemaRDD; +import org.apache.spark.sql.Row; public class JavaSparkSQL { public static class Person implements Serializable { @@ -55,7 +55,7 @@ public void setAge(int age) { public static void main(String[] args) throws Exception { SparkConf sparkConf = new SparkConf().setAppName("JavaSparkSQL"); JavaSparkContext ctx = new JavaSparkContext(sparkConf); - JavaSQLContext sqlCtx = new JavaSQLContext(ctx); + SQLContext sqlCtx = new SQLContext(ctx); System.out.println("=== Data source: RDD ==="); // Load a text file and convert each line to a Java Bean. @@ -74,15 +74,15 @@ public Person call(String line) { }); // Apply a schema to an RDD of Java Beans and register it as a table. - JavaSchemaRDD schemaPeople = sqlCtx.applySchema(people, Person.class); + SchemaRDD schemaPeople = sqlCtx.applySchema(people, Person.class); schemaPeople.registerTempTable("people"); // SQL can be run over RDDs that have been registered as tables. - JavaSchemaRDD teenagers = sqlCtx.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19"); + SchemaRDD teenagers = sqlCtx.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19"); // The results of SQL queries are SchemaRDDs and support all the normal RDD operations. // The columns of a row in the result can be accessed by ordinal. - List teenagerNames = teenagers.map(new Function() { + List teenagerNames = teenagers.toJavaRDD().map(new Function() { @Override public String call(Row row) { return "Name: " + row.getString(0); @@ -99,13 +99,13 @@ public String call(Row row) { // Read in the parquet file created above. // Parquet files are self-describing so the schema is preserved. // The result of loading a parquet file is also a JavaSchemaRDD. - JavaSchemaRDD parquetFile = sqlCtx.parquetFile("people.parquet"); + SchemaRDD parquetFile = sqlCtx.parquetFile("people.parquet"); //Parquet files can also be registered as tables and then used in SQL statements. parquetFile.registerTempTable("parquetFile"); - JavaSchemaRDD teenagers2 = + SchemaRDD teenagers2 = sqlCtx.sql("SELECT name FROM parquetFile WHERE age >= 13 AND age <= 19"); - teenagerNames = teenagers2.map(new Function() { + teenagerNames = teenagers2.toJavaRDD().map(new Function() { @Override public String call(Row row) { return "Name: " + row.getString(0); @@ -120,7 +120,7 @@ public String call(Row row) { // The path can be either a single text file or a directory storing text files. String path = "examples/src/main/resources/people.json"; // Create a JavaSchemaRDD from the file(s) pointed by path - JavaSchemaRDD peopleFromJsonFile = sqlCtx.jsonFile(path); + SchemaRDD peopleFromJsonFile = sqlCtx.jsonFile(path); // Because the schema of a JSON dataset is automatically inferred, to write queries, // it is better to take a look at what is the schema. @@ -134,11 +134,11 @@ public String call(Row row) { peopleFromJsonFile.registerTempTable("people"); // SQL statements can be run by using the sql methods provided by sqlCtx. - JavaSchemaRDD teenagers3 = sqlCtx.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19"); + SchemaRDD teenagers3 = sqlCtx.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19"); // The results of SQL queries are JavaSchemaRDDs and support all the normal RDD operations. // The columns of a row in the result can be accessed by ordinal. - teenagerNames = teenagers3.map(new Function() { + teenagerNames = teenagers3.toJavaRDD().map(new Function() { @Override public String call(Row row) { return "Name: " + row.getString(0); } }).collect(); @@ -151,7 +151,7 @@ public String call(Row row) { List jsonData = Arrays.asList( "{\"name\":\"Yin\",\"address\":{\"city\":\"Columbus\",\"state\":\"Ohio\"}}"); JavaRDD anotherPeopleRDD = ctx.parallelize(jsonData); - JavaSchemaRDD peopleFromJsonRDD = sqlCtx.jsonRDD(anotherPeopleRDD); + SchemaRDD peopleFromJsonRDD = sqlCtx.jsonRDD(anotherPeopleRDD.rdd()); // Take a look at the schema of this new JavaSchemaRDD. peopleFromJsonRDD.printSchema(); @@ -164,8 +164,8 @@ public String call(Row row) { peopleFromJsonRDD.registerTempTable("people2"); - JavaSchemaRDD peopleWithCity = sqlCtx.sql("SELECT name, address.city FROM people2"); - List nameAndCity = peopleWithCity.map(new Function() { + SchemaRDD peopleWithCity = sqlCtx.sql("SELECT name, address.city FROM people2"); + List nameAndCity = peopleWithCity.toJavaRDD().map(new Function() { @Override public String call(Row row) { return "Name: " + row.getString(0) + ", City: " + row.getString(1); diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/CrossValidatorExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/CrossValidatorExample.scala index ce6bc066bd70..d8c7ef38ee46 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/CrossValidatorExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/CrossValidatorExample.scala @@ -106,5 +106,7 @@ object CrossValidatorExample { .foreach { case Row(id: Long, text: String, score: Double, prediction: Double) => println("(" + id + ", " + text + ") --> score=" + score + ", prediction=" + prediction) } + + sc.stop() } } diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala index 44d5b084c269..e8a2adff929c 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala @@ -97,5 +97,7 @@ object SimpleParamsExample { .foreach { case Row(features: Vector, label: Double, prob: Double, prediction: Double) => println("(" + features + ", " + label + ") -> prob=" + prob + ", prediction=" + prediction) } + + sc.stop() } } diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/SimpleTextClassificationPipeline.scala b/examples/src/main/scala/org/apache/spark/examples/ml/SimpleTextClassificationPipeline.scala index 92895a05e479..b9a6ef0229de 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/SimpleTextClassificationPipeline.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/SimpleTextClassificationPipeline.scala @@ -85,5 +85,7 @@ object SimpleTextClassificationPipeline { .foreach { case Row(id: Long, text: String, score: Double, prediction: Double) => println("(" + id + ", " + text + ") --> score=" + score + ", prediction=" + prediction) } + + sc.stop() } } diff --git a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala index 13943ed5442b..f333e3891b5f 100644 --- a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala +++ b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala @@ -80,7 +80,7 @@ class FlumeStreamSuite extends FunSuite with BeforeAndAfter with Matchers with L val socket = new ServerSocket(trialPort) socket.close() (null, trialPort) - })._2 + }, conf)._2 } /** Setup and start the streaming context */ diff --git a/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala b/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala index 30727dfa6443..fe53a29cba0c 100644 --- a/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala +++ b/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala @@ -32,6 +32,7 @@ import org.scalatest.concurrent.Eventually import org.apache.spark.streaming.{Milliseconds, StreamingContext} import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.dstream.ReceiverInputDStream +import org.apache.spark.SparkConf import org.apache.spark.util.Utils class MQTTStreamSuite extends FunSuite with Eventually with BeforeAndAfter { @@ -106,7 +107,7 @@ class MQTTStreamSuite extends FunSuite with Eventually with BeforeAndAfter { val socket = new ServerSocket(trialPort) socket.close() (null, trialPort) - })._2 + }, new SparkConf())._2 } def publishData(data: String): Unit = { diff --git a/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala b/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala index fdbee743e817..77d230eb4a12 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala @@ -18,12 +18,10 @@ package org.apache.spark.ml import scala.annotation.varargs -import scala.collection.JavaConverters._ import org.apache.spark.annotation.AlphaComponent import org.apache.spark.ml.param.{ParamMap, ParamPair, Params} import org.apache.spark.sql.SchemaRDD -import org.apache.spark.sql.api.java.JavaSchemaRDD /** * :: AlphaComponent :: @@ -66,40 +64,4 @@ abstract class Estimator[M <: Model[M]] extends PipelineStage with Params { def fit(dataset: SchemaRDD, paramMaps: Array[ParamMap]): Seq[M] = { paramMaps.map(fit(dataset, _)) } - - // Java-friendly versions of fit. - - /** - * Fits a single model to the input data with optional parameters. - * - * @param dataset input dataset - * @param paramPairs optional list of param pairs (overwrite embedded params) - * @return fitted model - */ - @varargs - def fit(dataset: JavaSchemaRDD, paramPairs: ParamPair[_]*): M = { - fit(dataset.schemaRDD, paramPairs: _*) - } - - /** - * Fits a single model to the input data with provided parameter map. - * - * @param dataset input dataset - * @param paramMap parameter map - * @return fitted model - */ - def fit(dataset: JavaSchemaRDD, paramMap: ParamMap): M = { - fit(dataset.schemaRDD, paramMap) - } - - /** - * Fits multiple models to the input data with multiple sets of parameters. - * - * @param dataset input dataset - * @param paramMaps an array of parameter maps - * @return fitted models, matching the input parameter maps - */ - def fit(dataset: JavaSchemaRDD, paramMaps: Array[ParamMap]): java.util.List[M] = { - fit(dataset.schemaRDD, paramMaps).asJava - } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala index 081a574beea5..ad6fed178fae 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala @@ -21,8 +21,9 @@ import scala.collection.mutable.ListBuffer import org.apache.spark.Logging import org.apache.spark.annotation.AlphaComponent -import org.apache.spark.ml.param.{Params, Param, ParamMap} -import org.apache.spark.sql.{SchemaRDD, StructType} +import org.apache.spark.ml.param.{Param, ParamMap} +import org.apache.spark.sql.SchemaRDD +import org.apache.spark.sql.types.StructType /** * :: AlphaComponent :: diff --git a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala index 23fbd228d01c..af56f9c43535 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala @@ -23,10 +23,9 @@ import org.apache.spark.Logging import org.apache.spark.annotation.AlphaComponent import org.apache.spark.ml.param._ import org.apache.spark.sql.SchemaRDD -import org.apache.spark.sql.api.java.JavaSchemaRDD import org.apache.spark.sql.catalyst.analysis.Star import org.apache.spark.sql.catalyst.expressions.ScalaUdf -import org.apache.spark.sql.catalyst.types._ +import org.apache.spark.sql.types._ /** * :: AlphaComponent :: @@ -55,29 +54,6 @@ abstract class Transformer extends PipelineStage with Params { * @return transformed dataset */ def transform(dataset: SchemaRDD, paramMap: ParamMap): SchemaRDD - - // Java-friendly versions of transform. - - /** - * Transforms the dataset with optional parameters. - * @param dataset input datset - * @param paramPairs optional list of param pairs, overwrite embedded params - * @return transformed dataset - */ - @varargs - def transform(dataset: JavaSchemaRDD, paramPairs: ParamPair[_]*): JavaSchemaRDD = { - transform(dataset.schemaRDD, paramPairs: _*).toJavaSchemaRDD - } - - /** - * Transforms the dataset with provided parameter map as additional parameters. - * @param dataset input dataset - * @param paramMap additional parameters, overwrite embedded params - * @return transformed dataset - */ - def transform(dataset: JavaSchemaRDD, paramMap: ParamMap): JavaSchemaRDD = { - transform(dataset.schemaRDD, paramMap).toJavaSchemaRDD - } } /** diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index 85b8899636ca..8c570812f831 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -26,6 +26,7 @@ import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.analysis.Star import org.apache.spark.sql.catalyst.dsl._ +import org.apache.spark.sql.types.{DoubleType, StructField, StructType} import org.apache.spark.storage.StorageLevel /** diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala index 0b0504e036ec..12473cb2b571 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala @@ -21,7 +21,8 @@ import org.apache.spark.annotation.AlphaComponent import org.apache.spark.ml._ import org.apache.spark.ml.param._ import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics -import org.apache.spark.sql.{DoubleType, Row, SchemaRDD} +import org.apache.spark.sql.{Row, SchemaRDD} +import org.apache.spark.sql.types.DoubleType /** * :: AlphaComponent :: diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala index e0bfb1e484a2..0956062643f2 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala @@ -22,7 +22,7 @@ import org.apache.spark.ml.UnaryTransformer import org.apache.spark.ml.param.{IntParam, ParamMap} import org.apache.spark.mllib.feature import org.apache.spark.mllib.linalg.{VectorUDT, Vector} -import org.apache.spark.sql.catalyst.types.DataType +import org.apache.spark.sql.types.DataType /** * :: AlphaComponent :: diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala index 896a6b83b67b..72825f6e0218 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala @@ -25,6 +25,7 @@ import org.apache.spark.mllib.linalg.{Vector, VectorUDT} import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.analysis.Star import org.apache.spark.sql.catalyst.dsl._ +import org.apache.spark.sql.types.{StructField, StructType} /** * Params for [[StandardScaler]] and [[StandardScalerModel]]. diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala index 9352f40f372d..e622a5cf9e6f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala @@ -20,7 +20,7 @@ package org.apache.spark.ml.feature import org.apache.spark.annotation.AlphaComponent import org.apache.spark.ml.UnaryTransformer import org.apache.spark.ml.param.ParamMap -import org.apache.spark.sql.{DataType, StringType, ArrayType} +import org.apache.spark.sql.types.{DataType, StringType, ArrayType} /** * :: AlphaComponent :: diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala index 194b9bfd9a9e..08fe99176424 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala @@ -24,7 +24,8 @@ import org.apache.spark.annotation.AlphaComponent import org.apache.spark.ml._ import org.apache.spark.ml.param.{IntParam, Param, ParamMap, Params} import org.apache.spark.mllib.util.MLUtils -import org.apache.spark.sql.{SchemaRDD, StructType} +import org.apache.spark.sql.SchemaRDD +import org.apache.spark.sql.types.StructType /** * Params for [[CrossValidator]] and [[CrossValidatorModel]]. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala index bf1faa25ef0e..adbd8266ed6f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala @@ -27,9 +27,8 @@ import breeze.linalg.{DenseVector => BDV, SparseVector => BSV, Vector => BV} import org.apache.spark.SparkException import org.apache.spark.mllib.util.NumericParser -import org.apache.spark.sql.catalyst.annotation.SQLUserDefinedType import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, Row} -import org.apache.spark.sql.catalyst.types._ +import org.apache.spark.sql.types._ /** * Represents a numeric vector, whose index type is Int and value type is Double. diff --git a/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java b/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java index 42846677ed28..47f1f46c6c26 100644 --- a/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java @@ -26,10 +26,9 @@ import org.apache.spark.mllib.regression.LabeledPoint; import org.apache.spark.ml.classification.LogisticRegression; import org.apache.spark.ml.feature.StandardScaler; -import org.apache.spark.sql.api.java.JavaSQLContext; -import org.apache.spark.sql.api.java.JavaSchemaRDD; -import static org.apache.spark.mllib.classification.LogisticRegressionSuite - .generateLogisticInputAsList; +import org.apache.spark.sql.SchemaRDD; +import org.apache.spark.sql.SQLContext; +import static org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInputAsList; /** * Test Pipeline construction and fitting in Java. @@ -37,13 +36,13 @@ public class JavaPipelineSuite { private transient JavaSparkContext jsc; - private transient JavaSQLContext jsql; - private transient JavaSchemaRDD dataset; + private transient SQLContext jsql; + private transient SchemaRDD dataset; @Before public void setUp() { jsc = new JavaSparkContext("local", "JavaPipelineSuite"); - jsql = new JavaSQLContext(jsc); + jsql = new SQLContext(jsc); JavaRDD points = jsc.parallelize(generateLogisticInputAsList(1.0, 1.0, 100, 42), 2); dataset = jsql.applySchema(points, LabeledPoint.class); @@ -66,7 +65,7 @@ public void pipeline() { .setStages(new PipelineStage[] {scaler, lr}); PipelineModel model = pipeline.fit(dataset); model.transform(dataset).registerTempTable("prediction"); - JavaSchemaRDD predictions = jsql.sql("SELECT label, score, prediction FROM prediction"); - predictions.collect(); + SchemaRDD predictions = jsql.sql("SELECT label, score, prediction FROM prediction"); + predictions.collectAsList(); } } diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java index 76eb7f00329f..2eba83335bb5 100644 --- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java @@ -26,21 +26,20 @@ import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.sql.api.java.JavaSQLContext; -import org.apache.spark.sql.api.java.JavaSchemaRDD; -import static org.apache.spark.mllib.classification.LogisticRegressionSuite - .generateLogisticInputAsList; +import org.apache.spark.sql.SchemaRDD; +import org.apache.spark.sql.SQLContext; +import static org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInputAsList; public class JavaLogisticRegressionSuite implements Serializable { private transient JavaSparkContext jsc; - private transient JavaSQLContext jsql; - private transient JavaSchemaRDD dataset; + private transient SQLContext jsql; + private transient SchemaRDD dataset; @Before public void setUp() { jsc = new JavaSparkContext("local", "JavaLogisticRegressionSuite"); - jsql = new JavaSQLContext(jsc); + jsql = new SQLContext(jsc); List points = generateLogisticInputAsList(1.0, 1.0, 100, 42); dataset = jsql.applySchema(jsc.parallelize(points, 2), LabeledPoint.class); } @@ -56,8 +55,8 @@ public void logisticRegression() { LogisticRegression lr = new LogisticRegression(); LogisticRegressionModel model = lr.fit(dataset); model.transform(dataset).registerTempTable("prediction"); - JavaSchemaRDD predictions = jsql.sql("SELECT label, score, prediction FROM prediction"); - predictions.collect(); + SchemaRDD predictions = jsql.sql("SELECT label, score, prediction FROM prediction"); + predictions.collectAsList(); } @Test @@ -68,8 +67,8 @@ public void logisticRegressionWithSetters() { LogisticRegressionModel model = lr.fit(dataset); model.transform(dataset, model.threshold().w(0.8)) // overwrite threshold .registerTempTable("prediction"); - JavaSchemaRDD predictions = jsql.sql("SELECT label, score, prediction FROM prediction"); - predictions.collect(); + SchemaRDD predictions = jsql.sql("SELECT label, score, prediction FROM prediction"); + predictions.collectAsList(); } @Test diff --git a/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java b/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java index a266ebd2071a..a9f1c4a2c3ca 100644 --- a/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java @@ -30,21 +30,20 @@ import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator; import org.apache.spark.ml.param.ParamMap; import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.sql.api.java.JavaSQLContext; -import org.apache.spark.sql.api.java.JavaSchemaRDD; -import static org.apache.spark.mllib.classification.LogisticRegressionSuite - .generateLogisticInputAsList; +import org.apache.spark.sql.SchemaRDD; +import org.apache.spark.sql.SQLContext; +import static org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInputAsList; public class JavaCrossValidatorSuite implements Serializable { private transient JavaSparkContext jsc; - private transient JavaSQLContext jsql; - private transient JavaSchemaRDD dataset; + private transient SQLContext jsql; + private transient SchemaRDD dataset; @Before public void setUp() { jsc = new JavaSparkContext("local", "JavaCrossValidatorSuite"); - jsql = new JavaSQLContext(jsc); + jsql = new SQLContext(jsc); List points = generateLogisticInputAsList(1.0, 1.0, 100, 42); dataset = jsql.applySchema(jsc.parallelize(points, 2), LabeledPoint.class); } diff --git a/pom.xml b/pom.xml index aadcdfd1083c..f4466e56c2a5 100644 --- a/pom.xml +++ b/pom.xml @@ -122,7 +122,7 @@ 1.0.4 2.4.1 ${hadoop.version} - 0.94.6 + 0.98.7-hadoop1 hbase 1.4.0 3.4.5 @@ -1130,6 +1130,7 @@ ${test_classpath} true + false @@ -1465,6 +1466,7 @@ 2.2.0 2.5.0 + 0.98.7-hadoop2 hadoop2 @@ -1475,6 +1477,7 @@ 2.3.0 2.5.0 0.9.0 + 0.98.7-hadoop2 3.1.1 hadoop2 @@ -1486,6 +1489,7 @@ 2.4.0 2.5.0 0.9.0 + 0.98.7-hadoop2 3.1.1 hadoop2 diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 51e8bd4cf641..0ccbfcb0c43f 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -60,6 +60,28 @@ object MimaExcludes { ProblemFilters.exclude[IncompatibleResultTypeProblem]( "org.apache.spark.streaming.flume.sink.SparkAvroCallbackHandler." + "removeAndGetProcessor") + ) ++ Seq( + // SPARK-5123 (SparkSQL data type change) - alpha component only + ProblemFilters.exclude[IncompatibleResultTypeProblem]( + "org.apache.spark.ml.feature.HashingTF.outputDataType"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]( + "org.apache.spark.ml.feature.Tokenizer.outputDataType"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]( + "org.apache.spark.ml.feature.Tokenizer.validateInputType"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]( + "org.apache.spark.ml.classification.LogisticRegressionModel.validateAndTransformSchema"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]( + "org.apache.spark.ml.classification.LogisticRegression.validateAndTransformSchema") + ) ++ Seq( + // SPARK-4014 + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.TaskContext.taskAttemptId"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.TaskContext.attemptNumber") + ) ++ Seq( + // SPARK-5166 Spark SQL API stabilization + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.Transformer.transform"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.Estimator.fit") ) case v if v.startsWith("1.2") => diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 46a54c681840..ded4b5443a90 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -114,17 +114,6 @@ object SparkBuild extends PomBuild { override val userPropertiesMap = System.getProperties.toMap - // Handle case where hadoop.version is set via profile. - // Needed only because we read back this property in sbt - // when we create the assembly jar. - val pom = loadEffectivePom(new File("pom.xml"), - profiles = profiles, - userProps = userPropertiesMap) - if (System.getProperty("hadoop.version") == null) { - System.setProperty("hadoop.version", - pom.getProperties.get("hadoop.version").asInstanceOf[String]) - } - lazy val MavenCompile = config("m2r") extend(Compile) lazy val publishLocalBoth = TaskKey[Unit]("publish-local", "publish local for m2 and ivy") @@ -254,10 +243,10 @@ object SQL { |import org.apache.spark.sql.catalyst.expressions._ |import org.apache.spark.sql.catalyst.plans.logical._ |import org.apache.spark.sql.catalyst.rules._ - |import org.apache.spark.sql.catalyst.types._ |import org.apache.spark.sql.catalyst.util._ |import org.apache.spark.sql.execution |import org.apache.spark.sql.test.TestSQLContext._ + |import org.apache.spark.sql.types._ |import org.apache.spark.sql.parquet.ParquetTestData""".stripMargin, cleanupCommands in console := "sparkContext.stop()" ) @@ -284,11 +273,11 @@ object Hive { |import org.apache.spark.sql.catalyst.expressions._ |import org.apache.spark.sql.catalyst.plans.logical._ |import org.apache.spark.sql.catalyst.rules._ - |import org.apache.spark.sql.catalyst.types._ |import org.apache.spark.sql.catalyst.util._ |import org.apache.spark.sql.execution |import org.apache.spark.sql.hive._ |import org.apache.spark.sql.hive.test.TestHive._ + |import org.apache.spark.sql.types._ |import org.apache.spark.sql.parquet.ParquetTestData""".stripMargin, cleanupCommands in console := "sparkContext.stop()", // Some of our log4j jars make it impossible to submit jobs from this JVM to Hive Map/Reduce @@ -303,16 +292,15 @@ object Assembly { import sbtassembly.Plugin._ import AssemblyKeys._ + val hadoopVersion = taskKey[String]("The version of hadoop that spark is compiled against.") + lazy val settings = assemblySettings ++ Seq( test in assembly := {}, - jarName in assembly <<= (version, moduleName) map { (v, mName) => - if (mName.contains("network-yarn")) { - // This must match the same name used in maven (see network/yarn/pom.xml) - "spark-" + v + "-yarn-shuffle.jar" - } else { - mName + "-" + v + "-hadoop" + System.getProperty("hadoop.version") + ".jar" - } + hadoopVersion := { + sys.props.get("hadoop.version") + .getOrElse(SbtPomKeys.effectivePom.value.getProperties.get("hadoop.version").asInstanceOf[String]) }, + jarName in assembly := s"${moduleName.value}-${version.value}-hadoop${hadoopVersion.value}.jar", mergeStrategy in assembly := { case PathList("org", "datanucleus", xs @ _*) => MergeStrategy.discard case m if m.toLowerCase.endsWith("manifest.mf") => MergeStrategy.discard @@ -323,7 +311,6 @@ object Assembly { case _ => MergeStrategy.first } ) - } object Unidoc { diff --git a/python/pyspark/context.py b/python/pyspark/context.py index 593d74bca5ff..64f6a3ca6bf4 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -319,7 +319,7 @@ def f(split, iterator): # Make sure we distribute data evenly if it's smaller than self.batchSize if "__len__" not in dir(c): c = list(c) # Make it a list so we can compute its length - batchSize = max(1, min(len(c) // numSlices, self._batchSize)) + batchSize = max(1, min(len(c) // numSlices, self._batchSize or 1024)) serializer = BatchedSerializer(self._unbatched_serializer, batchSize) serializer.dump_stream(c, tempFile) tempFile.close() diff --git a/python/pyspark/mllib/common.py b/python/pyspark/mllib/common.py index 33c49e239990..3c5ee66cd8b6 100644 --- a/python/pyspark/mllib/common.py +++ b/python/pyspark/mllib/common.py @@ -18,7 +18,7 @@ import py4j.protocol from py4j.protocol import Py4JJavaError from py4j.java_gateway import JavaObject -from py4j.java_collections import MapConverter, ListConverter, JavaArray, JavaList +from py4j.java_collections import ListConverter, JavaArray, JavaList from pyspark import RDD, SparkContext from pyspark.serializers import PickleSerializer, AutoBatchedSerializer @@ -70,9 +70,7 @@ def _py2java(sc, obj): obj = _to_java_object_rdd(obj) elif isinstance(obj, SparkContext): obj = obj._jsc - elif isinstance(obj, dict): - obj = MapConverter().convert(obj, sc._gateway._gateway_client) - elif isinstance(obj, (list, tuple)): + elif isinstance(obj, list) and (obj or isinstance(obj[0], JavaObject)): obj = ListConverter().convert(obj, sc._gateway._gateway_client) elif isinstance(obj, JavaObject): pass diff --git a/python/pyspark/mllib/linalg.py b/python/pyspark/mllib/linalg.py index 4f8491f43e45..7f21190ed8c2 100644 --- a/python/pyspark/mllib/linalg.py +++ b/python/pyspark/mllib/linalg.py @@ -510,6 +510,23 @@ def __eq__(self, other): and np.array_equal(other.indices, self.indices) and np.array_equal(other.values, self.values)) + def __getitem__(self, index): + inds = self.indices + vals = self.values + if not isinstance(index, int): + raise ValueError( + "Indices must be of type integer, got type %s" % type(index)) + if index < 0: + index += self.size + if index >= self.size or index < 0: + raise ValueError("Index %d out of bounds." % index) + + insert_index = np.searchsorted(inds, index) + row_ind = inds[insert_index] + if row_ind == index: + return vals[insert_index] + return 0. + def __ne__(self, other): return not self.__eq__(other) diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py index 1f48bc1219db..140c22b5fd4e 100644 --- a/python/pyspark/mllib/tests.py +++ b/python/pyspark/mllib/tests.py @@ -120,6 +120,18 @@ def test_conversion(self): dv = DenseVector(v) self.assertTrue(dv.array.dtype == 'float64') + def test_sparse_vector_indexing(self): + sv = SparseVector(4, {1: 1, 3: 2}) + self.assertEquals(sv[0], 0.) + self.assertEquals(sv[3], 2.) + self.assertEquals(sv[1], 1.) + self.assertEquals(sv[2], 0.) + self.assertEquals(sv[-1], 2) + self.assertEquals(sv[-2], 0) + self.assertEquals(sv[-4], 0) + for ind in [4, -5, 7.8]: + self.assertRaises(ValueError, sv.__getitem__, ind) + class ListTests(PySparkTestCase): diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index bd08c9a6d20d..b8bda835174b 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -181,6 +181,10 @@ def __init__(self, serializer, batchSize=UNLIMITED_BATCH_SIZE): def _batched(self, iterator): if self.batchSize == self.UNLIMITED_BATCH_SIZE: yield list(iterator) + elif hasattr(iterator, "__len__") and hasattr(iterator, "__getslice__"): + n = len(iterator) + for i in xrange(0, n, self.batchSize): + yield iterator[i: i + self.batchSize] else: items = [] count = 0 diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py index 0e8b398fc6b9..1990323249cf 100644 --- a/python/pyspark/sql.py +++ b/python/pyspark/sql.py @@ -807,14 +807,14 @@ def convert_struct(obj): return if isinstance(obj, tuple): - if hasattr(obj, "fields"): - d = dict(zip(obj.fields, obj)) - if hasattr(obj, "__FIELDS__"): + if hasattr(obj, "_fields"): + d = dict(zip(obj._fields, obj)) + elif hasattr(obj, "__FIELDS__"): d = dict(zip(obj.__FIELDS__, obj)) elif all(isinstance(x, tuple) and len(x) == 2 for x in obj): d = dict(obj) else: - raise ValueError("unexpected tuple: %s" % obj) + raise ValueError("unexpected tuple: %s" % str(obj)) elif isinstance(obj, dict): d = obj @@ -1281,14 +1281,14 @@ def registerFunction(self, name, f, returnType=StringType()): self._sc._gateway._gateway_client) includes = ListConverter().convert(self._sc._python_includes, self._sc._gateway._gateway_client) - self._ssql_ctx.registerPython(name, - bytearray(pickled_command), - env, - includes, - self._sc.pythonExec, - broadcast_vars, - self._sc._javaAccumulator, - returnType.json()) + self._ssql_ctx.udf().registerPython(name, + bytearray(pickled_command), + env, + includes, + self._sc.pythonExec, + broadcast_vars, + self._sc._javaAccumulator, + returnType.json()) def inferSchema(self, rdd, samplingRatio=None): """Infer and apply a schema to an RDD of L{Row}. @@ -1327,6 +1327,16 @@ def inferSchema(self, rdd, samplingRatio=None): >>> srdd = sqlCtx.inferSchema(nestedRdd2) >>> srdd.collect() [Row(f1=[[1, 2], [2, 3]], f2=[1, 2]), ..., f2=[2, 3])] + + >>> from collections import namedtuple + >>> CustomRow = namedtuple('CustomRow', 'field1 field2') + >>> rdd = sc.parallelize( + ... [CustomRow(field1=1, field2="row1"), + ... CustomRow(field1=2, field2="row2"), + ... CustomRow(field1=3, field2="row3")]) + >>> srdd = sqlCtx.inferSchema(rdd) + >>> srdd.collect()[0] + Row(field1=1, field2=u'row1') """ if isinstance(rdd, SchemaRDD): @@ -1448,7 +1458,7 @@ def applySchema(self, rdd, schema): jrdd = self._jvm.SerDeUtil.toJavaArray(rdd._to_java_object_rdd()) srdd = self._ssql_ctx.applySchemaToPythonRDD(jrdd.rdd(), schema.json()) - return SchemaRDD(srdd.toJavaSchemaRDD(), self) + return SchemaRDD(srdd, self) def registerRDDAsTable(self, rdd, tableName): """Registers the given RDD as a temporary table in the catalog. @@ -1477,7 +1487,7 @@ def parquetFile(self, path): >>> sorted(srdd.collect()) == sorted(srdd2.collect()) True """ - jschema_rdd = self._ssql_ctx.parquetFile(path).toJavaSchemaRDD() + jschema_rdd = self._ssql_ctx.parquetFile(path) return SchemaRDD(jschema_rdd, self) def jsonFile(self, path, schema=None, samplingRatio=1.0): @@ -1539,7 +1549,7 @@ def jsonFile(self, path, schema=None, samplingRatio=1.0): else: scala_datatype = self._ssql_ctx.parseDataType(schema.json()) srdd = self._ssql_ctx.jsonFile(path, scala_datatype) - return SchemaRDD(srdd.toJavaSchemaRDD(), self) + return SchemaRDD(srdd, self) def jsonRDD(self, rdd, schema=None, samplingRatio=1.0): """Loads an RDD storing one JSON object per string as a L{SchemaRDD}. @@ -1609,7 +1619,7 @@ def func(iterator): else: scala_datatype = self._ssql_ctx.parseDataType(schema.json()) srdd = self._ssql_ctx.jsonRDD(jrdd.rdd(), scala_datatype) - return SchemaRDD(srdd.toJavaSchemaRDD(), self) + return SchemaRDD(srdd, self) def sql(self, sqlQuery): """Return a L{SchemaRDD} representing the result of the given query. @@ -1620,7 +1630,7 @@ def sql(self, sqlQuery): >>> srdd2.collect() [Row(f1=1, f2=u'row1'), Row(f1=2, f2=u'row2'), Row(f1=3, f2=u'row3')] """ - return SchemaRDD(self._ssql_ctx.sql(sqlQuery).toJavaSchemaRDD(), self) + return SchemaRDD(self._ssql_ctx.sql(sqlQuery), self) def table(self, tableName): """Returns the specified table as a L{SchemaRDD}. @@ -1631,7 +1641,7 @@ def table(self, tableName): >>> sorted(srdd.collect()) == sorted(srdd2.collect()) True """ - return SchemaRDD(self._ssql_ctx.table(tableName).toJavaSchemaRDD(), self) + return SchemaRDD(self._ssql_ctx.table(tableName), self) def cacheTable(self, tableName): """Caches the specified table in-memory.""" @@ -1676,24 +1686,6 @@ def _ssql_ctx(self): def _get_hive_ctx(self): return self._jvm.HiveContext(self._jsc.sc()) - def hiveql(self, hqlQuery): - """ - DEPRECATED: Use sql() - """ - warnings.warn("hiveql() is deprecated as the sql function now parses using HiveQL by" + - "default. The SQL dialect for parsing can be set using 'spark.sql.dialect'", - DeprecationWarning) - return SchemaRDD(self._ssql_ctx.hiveql(hqlQuery).toJavaSchemaRDD(), self) - - def hql(self, hqlQuery): - """ - DEPRECATED: Use sql() - """ - warnings.warn("hql() is deprecated as the sql function now parses using HiveQL by" + - "default. The SQL dialect for parsing can be set using 'spark.sql.dialect'", - DeprecationWarning) - return self.hiveql(hqlQuery) - class LocalHiveContext(HiveContext): @@ -1706,12 +1698,6 @@ def _get_hive_ctx(self): return self._jvm.LocalHiveContext(self._jsc.sc()) -class TestHiveContext(HiveContext): - - def _get_hive_ctx(self): - return self._jvm.TestHiveContext(self._jsc.sc()) - - def _create_row(fields, values): row = Row(*values) row.__FIELDS__ = fields @@ -1836,7 +1822,7 @@ def __init__(self, jschema_rdd, sql_ctx): self.sql_ctx = sql_ctx self._sc = sql_ctx._sc clsName = jschema_rdd.getClass().getName() - assert clsName.endswith("JavaSchemaRDD"), "jschema_rdd must be JavaSchemaRDD" + assert clsName.endswith("SchemaRDD"), "jschema_rdd must be SchemaRDD" self._jschema_rdd = jschema_rdd self._id = None self.is_cached = False @@ -1870,7 +1856,7 @@ def limit(self, num): >>> srdd.limit(0).collect() [] """ - rdd = self._jschema_rdd.baseSchemaRDD().limit(num).toJavaSchemaRDD() + rdd = self._jschema_rdd.baseSchemaRDD().limit(num) return SchemaRDD(rdd, self.sql_ctx) def toJSON(self, use_unicode=False): @@ -2049,18 +2035,18 @@ def isCheckpointed(self): def getCheckpointFile(self): checkpointFile = self._jschema_rdd.getCheckpointFile() - if checkpointFile.isPresent(): + if checkpointFile.isDefined(): return checkpointFile.get() def coalesce(self, numPartitions, shuffle=False): - rdd = self._jschema_rdd.coalesce(numPartitions, shuffle) + rdd = self._jschema_rdd.coalesce(numPartitions, shuffle, None) return SchemaRDD(rdd, self.sql_ctx) def distinct(self, numPartitions=None): if numPartitions is None: rdd = self._jschema_rdd.distinct() else: - rdd = self._jschema_rdd.distinct(numPartitions) + rdd = self._jschema_rdd.distinct(numPartitions, None) return SchemaRDD(rdd, self.sql_ctx) def intersection(self, other): @@ -2071,7 +2057,7 @@ def intersection(self, other): raise ValueError("Can only intersect with another SchemaRDD") def repartition(self, numPartitions): - rdd = self._jschema_rdd.repartition(numPartitions) + rdd = self._jschema_rdd.repartition(numPartitions, None) return SchemaRDD(rdd, self.sql_ctx) def subtract(self, other, numPartitions=None): diff --git a/repl/pom.xml b/repl/pom.xml index 0bc8bccf90a6..ae7c31aef4f5 100644 --- a/repl/pom.xml +++ b/repl/pom.xml @@ -92,13 +92,6 @@ target/scala-${scala.binary.version}/classes target/scala-${scala.binary.version}/test-classes - - org.apache.maven.plugins - maven-deploy-plugin - - true - - org.codehaus.mojo diff --git a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkCommandLine.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkCommandLine.scala index 05816941b54b..6480e2d24e04 100644 --- a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkCommandLine.scala +++ b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkCommandLine.scala @@ -19,14 +19,21 @@ package org.apache.spark.repl import scala.tools.nsc.{Settings, CompilerCommand} import scala.Predef._ +import org.apache.spark.annotation.DeveloperApi /** * Command class enabling Spark-specific command line options (provided by * org.apache.spark.repl.SparkRunnerSettings). + * + * @example new SparkCommandLine(Nil).settings + * + * @param args The list of command line arguments + * @param settings The underlying settings to associate with this set of + * command-line options */ +@DeveloperApi class SparkCommandLine(args: List[String], override val settings: Settings) extends CompilerCommand(args, settings) { - def this(args: List[String], error: String => Unit) { this(args, new SparkRunnerSettings(error)) } diff --git a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkExprTyper.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkExprTyper.scala index f8432c8af6ed..5fb378112ef9 100644 --- a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkExprTyper.scala +++ b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkExprTyper.scala @@ -15,7 +15,7 @@ import scala.tools.nsc.ast.parser.Tokens.EOF import org.apache.spark.Logging -trait SparkExprTyper extends Logging { +private[repl] trait SparkExprTyper extends Logging { val repl: SparkIMain import repl._ diff --git a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkHelper.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkHelper.scala index 5340951d9133..955be17a73b8 100644 --- a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkHelper.scala +++ b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkHelper.scala @@ -17,6 +17,23 @@ package scala.tools.nsc +import org.apache.spark.annotation.DeveloperApi + +// NOTE: Forced to be public (and in scala.tools.nsc package) to access the +// settings "explicitParentLoader" method + +/** + * Provides exposure for the explicitParentLoader method on settings instances. + */ +@DeveloperApi object SparkHelper { + /** + * Retrieves the explicit parent loader for the provided settings. + * + * @param settings The settings whose explicit parent loader to retrieve + * + * @return The Optional classloader representing the explicit parent loader + */ + @DeveloperApi def explicitParentLoader(settings: Settings) = settings.explicitParentLoader } diff --git a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoop.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoop.scala index e56b74edba88..72c1a989999b 100644 --- a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoop.scala +++ b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoop.scala @@ -10,6 +10,8 @@ package org.apache.spark.repl import java.net.URL +import org.apache.spark.annotation.DeveloperApi + import scala.reflect.io.AbstractFile import scala.tools.nsc._ import scala.tools.nsc.backend.JavaPlatform @@ -57,20 +59,22 @@ import org.apache.spark.util.Utils * @author Lex Spoon * @version 1.2 */ -class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter, - val master: Option[String]) - extends AnyRef - with LoopCommands - with SparkILoopInit - with Logging -{ +@DeveloperApi +class SparkILoop( + private val in0: Option[BufferedReader], + protected val out: JPrintWriter, + val master: Option[String] +) extends AnyRef with LoopCommands with SparkILoopInit with Logging { def this(in0: BufferedReader, out: JPrintWriter, master: String) = this(Some(in0), out, Some(master)) def this(in0: BufferedReader, out: JPrintWriter) = this(Some(in0), out, None) def this() = this(None, new JPrintWriter(Console.out, true), None) - var in: InteractiveReader = _ // the input stream from which commands come - var settings: Settings = _ - var intp: SparkIMain = _ + private var in: InteractiveReader = _ // the input stream from which commands come + + // NOTE: Exposed in package for testing + private[repl] var settings: Settings = _ + + private[repl] var intp: SparkIMain = _ @deprecated("Use `intp` instead.", "2.9.0") def interpreter = intp @deprecated("Use `intp` instead.", "2.9.0") def interpreter_= (i: SparkIMain): Unit = intp = i @@ -123,6 +127,8 @@ class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter, } } + // NOTE: Must be public for visibility + @DeveloperApi var sparkContext: SparkContext = _ override def echoCommandMessage(msg: String) { @@ -130,45 +136,45 @@ class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter, } // def isAsync = !settings.Yreplsync.value - def isAsync = false + private[repl] def isAsync = false // lazy val power = new Power(intp, new StdReplVals(this))(tagOfStdReplVals, classTag[StdReplVals]) - def history = in.history + private def history = in.history /** The context class loader at the time this object was created */ protected val originalClassLoader = Utils.getContextOrSparkClassLoader // classpath entries added via :cp - var addedClasspath: String = "" + private var addedClasspath: String = "" /** A reverse list of commands to replay if the user requests a :replay */ - var replayCommandStack: List[String] = Nil + private var replayCommandStack: List[String] = Nil /** A list of commands to replay if the user requests a :replay */ - def replayCommands = replayCommandStack.reverse + private def replayCommands = replayCommandStack.reverse /** Record a command for replay should the user request a :replay */ - def addReplay(cmd: String) = replayCommandStack ::= cmd + private def addReplay(cmd: String) = replayCommandStack ::= cmd - def savingReplayStack[T](body: => T): T = { + private def savingReplayStack[T](body: => T): T = { val saved = replayCommandStack try body finally replayCommandStack = saved } - def savingReader[T](body: => T): T = { + private def savingReader[T](body: => T): T = { val saved = in try body finally in = saved } - def sparkCleanUp(){ + private def sparkCleanUp(){ echo("Stopping spark context.") intp.beQuietDuring { command("sc.stop()") } } /** Close the interpreter and set the var to null. */ - def closeInterpreter() { + private def closeInterpreter() { if (intp ne null) { sparkCleanUp() intp.close() @@ -179,14 +185,16 @@ class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter, class SparkILoopInterpreter extends SparkIMain(settings, out) { outer => - override lazy val formatting = new Formatting { + override private[repl] lazy val formatting = new Formatting { def prompt = SparkILoop.this.prompt } override protected def parentClassLoader = SparkHelper.explicitParentLoader(settings).getOrElse(classOf[SparkILoop].getClassLoader) } - /** Create a new interpreter. */ - def createInterpreter() { + /** + * Constructs a new interpreter. + */ + protected def createInterpreter() { require(settings != null) if (addedClasspath != "") settings.classpath.append(addedClasspath) @@ -207,7 +215,7 @@ class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter, } /** print a friendly help message */ - def helpCommand(line: String): Result = { + private def helpCommand(line: String): Result = { if (line == "") helpSummary() else uniqueCommand(line) match { case Some(lc) => echo("\n" + lc.longHelp) @@ -258,7 +266,7 @@ class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter, } /** Show the history */ - lazy val historyCommand = new LoopCommand("history", "show the history (optional num is commands to show)") { + private lazy val historyCommand = new LoopCommand("history", "show the history (optional num is commands to show)") { override def usage = "[num]" def defaultLines = 20 @@ -279,21 +287,21 @@ class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter, // When you know you are most likely breaking into the middle // of a line being typed. This softens the blow. - protected def echoAndRefresh(msg: String) = { + private[repl] def echoAndRefresh(msg: String) = { echo("\n" + msg) in.redrawLine() } - protected def echo(msg: String) = { + private[repl] def echo(msg: String) = { out println msg out.flush() } - protected def echoNoNL(msg: String) = { + private def echoNoNL(msg: String) = { out print msg out.flush() } /** Search the history */ - def searchHistory(_cmdline: String) { + private def searchHistory(_cmdline: String) { val cmdline = _cmdline.toLowerCase val offset = history.index - history.size + 1 @@ -302,14 +310,27 @@ class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter, } private var currentPrompt = Properties.shellPromptString + + /** + * Sets the prompt string used by the REPL. + * + * @param prompt The new prompt string + */ + @DeveloperApi def setPrompt(prompt: String) = currentPrompt = prompt - /** Prompt to print when awaiting input */ + + /** + * Represents the current prompt string used by the REPL. + * + * @return The current prompt string + */ + @DeveloperApi def prompt = currentPrompt import LoopCommand.{ cmd, nullary } /** Standard commands */ - lazy val standardCommands = List( + private lazy val standardCommands = List( cmd("cp", "", "add a jar or directory to the classpath", addClasspath), cmd("help", "[command]", "print this summary or command-specific help", helpCommand), historyCommand, @@ -333,7 +354,7 @@ class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter, ) /** Power user commands */ - lazy val powerCommands: List[LoopCommand] = List( + private lazy val powerCommands: List[LoopCommand] = List( // cmd("phase", "", "set the implicit phase for power commands", phaseCommand) ) @@ -459,7 +480,7 @@ class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter, } } - protected def newJavap() = new JavapClass(addToolsJarToLoader(), new SparkIMain.ReplStrippingWriter(intp)) { + private def newJavap() = new JavapClass(addToolsJarToLoader(), new SparkIMain.ReplStrippingWriter(intp)) { override def tryClass(path: String): Array[Byte] = { val hd :: rest = path split '.' toList; // If there are dots in the name, the first segment is the @@ -581,7 +602,12 @@ class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter, // } // } - /** Available commands */ + /** + * Provides a list of available commands. + * + * @return The list of commands + */ + @DeveloperApi def commands: List[LoopCommand] = standardCommands /*++ ( if (isReplPower) powerCommands else Nil )*/ @@ -613,7 +639,7 @@ class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter, * command() for each line of input, and stops when * command() returns false. */ - def loop() { + private def loop() { def readOneLine() = { out.flush() in readLine prompt @@ -642,7 +668,7 @@ class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter, } /** interpret all lines from a specified file */ - def interpretAllFrom(file: File) { + private def interpretAllFrom(file: File) { savingReader { savingReplayStack { file applyReader { reader => @@ -655,7 +681,7 @@ class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter, } /** create a new interpreter and replay the given commands */ - def replay() { + private def replay() { reset() if (replayCommandStack.isEmpty) echo("Nothing to replay.") @@ -665,7 +691,7 @@ class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter, echo("") } } - def resetCommand() { + private def resetCommand() { echo("Resetting repl state.") if (replayCommandStack.nonEmpty) { echo("Forgetting this session history:\n") @@ -681,13 +707,13 @@ class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter, reset() } - def reset() { + private def reset() { intp.reset() // unleashAndSetPhase() } /** fork a shell and run a command */ - lazy val shCommand = new LoopCommand("sh", "run a shell command (result is implicitly => List[String])") { + private lazy val shCommand = new LoopCommand("sh", "run a shell command (result is implicitly => List[String])") { override def usage = "" def apply(line: String): Result = line match { case "" => showUsage() @@ -698,14 +724,14 @@ class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter, } } - def withFile(filename: String)(action: File => Unit) { + private def withFile(filename: String)(action: File => Unit) { val f = File(filename) if (f.exists) action(f) else echo("That file does not exist") } - def loadCommand(arg: String) = { + private def loadCommand(arg: String) = { var shouldReplay: Option[String] = None withFile(arg)(f => { interpretAllFrom(f) @@ -714,7 +740,7 @@ class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter, Result(true, shouldReplay) } - def addAllClasspath(args: Seq[String]): Unit = { + private def addAllClasspath(args: Seq[String]): Unit = { var added = false var totalClasspath = "" for (arg <- args) { @@ -729,7 +755,7 @@ class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter, } } - def addClasspath(arg: String): Unit = { + private def addClasspath(arg: String): Unit = { val f = File(arg).normalize if (f.exists) { addedClasspath = ClassPath.join(addedClasspath, f.path) @@ -741,12 +767,12 @@ class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter, } - def powerCmd(): Result = { + private def powerCmd(): Result = { if (isReplPower) "Already in power mode." else enablePowerMode(false) } - def enablePowerMode(isDuringInit: Boolean) = { + private[repl] def enablePowerMode(isDuringInit: Boolean) = { // replProps.power setValue true // unleashAndSetPhase() // asyncEcho(isDuringInit, power.banner) @@ -759,12 +785,12 @@ class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter, // } // } - def asyncEcho(async: Boolean, msg: => String) { + private def asyncEcho(async: Boolean, msg: => String) { if (async) asyncMessage(msg) else echo(msg) } - def verbosity() = { + private def verbosity() = { // val old = intp.printResults // intp.printResults = !old // echo("Switched " + (if (old) "off" else "on") + " result printing.") @@ -773,7 +799,7 @@ class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter, /** Run one command submitted by the user. Two values are returned: * (1) whether to keep running, (2) the line to record for replay, * if any. */ - def command(line: String): Result = { + private[repl] def command(line: String): Result = { if (line startsWith ":") { val cmd = line.tail takeWhile (x => !x.isWhitespace) uniqueCommand(cmd) match { @@ -789,7 +815,7 @@ class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter, Iterator continually in.readLine("") takeWhile (x => x != null && cond(x)) } - def pasteCommand(): Result = { + private def pasteCommand(): Result = { echo("// Entering paste mode (ctrl-D to finish)\n") val code = readWhile(_ => true) mkString "\n" echo("\n// Exiting paste mode, now interpreting.\n") @@ -820,7 +846,7 @@ class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter, * read, go ahead and interpret it. Return the full string * to be recorded for replay, if any. */ - def interpretStartingWith(code: String): Option[String] = { + private def interpretStartingWith(code: String): Option[String] = { // signal completion non-completion input has been received in.completion.resetVerbosity() @@ -874,7 +900,7 @@ class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter, } // runs :load `file` on any files passed via -i - def loadFiles(settings: Settings) = settings match { + private def loadFiles(settings: Settings) = settings match { case settings: SparkRunnerSettings => for (filename <- settings.loadfiles.value) { val cmd = ":load " + filename @@ -889,7 +915,7 @@ class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter, * unless settings or properties are such that it should start * with SimpleReader. */ - def chooseReader(settings: Settings): InteractiveReader = { + private def chooseReader(settings: Settings): InteractiveReader = { if (settings.Xnojline.value || Properties.isEmacsShell) SimpleReader() else try new SparkJLineReader( @@ -903,8 +929,8 @@ class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter, } } - val u: scala.reflect.runtime.universe.type = scala.reflect.runtime.universe - val m = u.runtimeMirror(Utils.getSparkClassLoader) + private val u: scala.reflect.runtime.universe.type = scala.reflect.runtime.universe + private val m = u.runtimeMirror(Utils.getSparkClassLoader) private def tagOfStaticClass[T: ClassTag]: u.TypeTag[T] = u.TypeTag[T]( m, @@ -913,7 +939,7 @@ class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter, m.staticClass(classTag[T].runtimeClass.getName).toTypeConstructor.asInstanceOf[U # Type] }) - def process(settings: Settings): Boolean = savingContextLoader { + private def process(settings: Settings): Boolean = savingContextLoader { if (getMaster() == "yarn-client") System.setProperty("SPARK_YARN_MODE", "true") this.settings = settings @@ -972,6 +998,8 @@ class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter, true } + // NOTE: Must be public for visibility + @DeveloperApi def createSparkContext(): SparkContext = { val execUri = System.getenv("SPARK_EXECUTOR_URI") val jars = SparkILoop.getAddedJars @@ -979,7 +1007,7 @@ class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter, .setMaster(getMaster()) .setAppName("Spark shell") .setJars(jars) - .set("spark.repl.class.uri", intp.classServer.uri) + .set("spark.repl.class.uri", intp.classServerUri) if (execUri != null) { conf.set("spark.executor.uri", execUri) } @@ -1014,7 +1042,7 @@ class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter, } @deprecated("Use `process` instead", "2.9.0") - def main(settings: Settings): Unit = process(settings) + private def main(settings: Settings): Unit = process(settings) } object SparkILoop { @@ -1033,7 +1061,7 @@ object SparkILoop { // Designed primarily for use by test code: take a String with a // bunch of code, and prints out a transcript of what it would look // like if you'd just typed it into the repl. - def runForTranscript(code: String, settings: Settings): String = { + private[repl] def runForTranscript(code: String, settings: Settings): String = { import java.io.{ BufferedReader, StringReader, OutputStreamWriter } stringFromStream { ostream => @@ -1071,7 +1099,7 @@ object SparkILoop { /** Creates an interpreter loop with default settings and feeds * the given code to it as input. */ - def run(code: String, sets: Settings = new Settings): String = { + private[repl] def run(code: String, sets: Settings = new Settings): String = { import java.io.{ BufferedReader, StringReader, OutputStreamWriter } stringFromStream { ostream => @@ -1087,5 +1115,5 @@ object SparkILoop { } } } - def run(lines: List[String]): String = run(lines map (_ + "\n") mkString) + private[repl] def run(lines: List[String]): String = run(lines map (_ + "\n") mkString) } diff --git a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoopInit.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoopInit.scala index da4286c5e487..99bd777c04fd 100644 --- a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoopInit.scala +++ b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoopInit.scala @@ -19,7 +19,7 @@ import org.apache.spark.SPARK_VERSION /** * Machinery for the asynchronous initialization of the repl. */ -trait SparkILoopInit { +private[repl] trait SparkILoopInit { self: SparkILoop => /** Print a welcome message */ diff --git a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkIMain.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkIMain.scala index 646c68e60c2e..35fb62564502 100644 --- a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkIMain.scala +++ b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkIMain.scala @@ -39,6 +39,7 @@ import scala.util.control.ControlThrowable import org.apache.spark.{Logging, HttpServer, SecurityManager, SparkConf} import org.apache.spark.util.Utils +import org.apache.spark.annotation.DeveloperApi // /** directory to save .class files to */ // private class ReplVirtualDirectory(out: JPrintWriter) extends VirtualDirectory("((memory))", None) { @@ -84,17 +85,18 @@ import org.apache.spark.util.Utils * @author Moez A. Abdel-Gawad * @author Lex Spoon */ + @DeveloperApi class SparkIMain( initialSettings: Settings, val out: JPrintWriter, propagateExceptions: Boolean = false) extends SparkImports with Logging { imain => - val conf = new SparkConf() + private val conf = new SparkConf() - val SPARK_DEBUG_REPL: Boolean = (System.getenv("SPARK_DEBUG_REPL") == "1") + private val SPARK_DEBUG_REPL: Boolean = (System.getenv("SPARK_DEBUG_REPL") == "1") /** Local directory to save .class files too */ - lazy val outputDir = { + private lazy val outputDir = { val tmp = System.getProperty("java.io.tmpdir") val rootDir = conf.get("spark.repl.classdir", tmp) Utils.createTempDir(rootDir) @@ -103,13 +105,20 @@ import org.apache.spark.util.Utils echo("Output directory: " + outputDir) } - val virtualDirectory = new PlainFile(outputDir) // "directory" for classfiles + /** + * Returns the path to the output directory containing all generated + * class files that will be served by the REPL class server. + */ + @DeveloperApi + lazy val getClassOutputDirectory = outputDir + + private val virtualDirectory = new PlainFile(outputDir) // "directory" for classfiles /** Jetty server that will serve our classes to worker nodes */ - val classServerPort = conf.getInt("spark.replClassServer.port", 0) - val classServer = new HttpServer(outputDir, new SecurityManager(conf), classServerPort, "HTTP class server") + private val classServerPort = conf.getInt("spark.replClassServer.port", 0) + private val classServer = new HttpServer(conf, outputDir, new SecurityManager(conf), classServerPort, "HTTP class server") private var currentSettings: Settings = initialSettings - var printResults = true // whether to print result lines - var totalSilence = false // whether to print anything + private var printResults = true // whether to print result lines + private var totalSilence = false // whether to print anything private var _initializeComplete = false // compiler is initialized private var _isInitialized: Future[Boolean] = null // set up initialization future private var bindExceptions = true // whether to bind the lastException variable @@ -123,6 +132,14 @@ import org.apache.spark.util.Utils echo("Class server started, URI = " + classServer.uri) } + /** + * URI of the class server used to feed REPL compiled classes. + * + * @return The string representing the class server uri + */ + @DeveloperApi + def classServerUri = classServer.uri + /** We're going to go to some trouble to initialize the compiler asynchronously. * It's critical that nothing call into it until it's been initialized or we will * run into unrecoverable issues, but the perceived repl startup time goes @@ -141,17 +158,18 @@ import org.apache.spark.util.Utils () => { counter += 1 ; counter } } - def compilerClasspath: Seq[URL] = ( + private def compilerClasspath: Seq[URL] = ( if (isInitializeComplete) global.classPath.asURLs else new PathResolver(settings).result.asURLs // the compiler's classpath ) - def settings = currentSettings - def mostRecentLine = prevRequestList match { + // NOTE: Exposed to repl package since accessed indirectly from SparkIMain + private[repl] def settings = currentSettings + private def mostRecentLine = prevRequestList match { case Nil => "" case req :: _ => req.originalLine } // Run the code body with the given boolean settings flipped to true. - def withoutWarnings[T](body: => T): T = beQuietDuring { + private def withoutWarnings[T](body: => T): T = beQuietDuring { val saved = settings.nowarn.value if (!saved) settings.nowarn.value = true @@ -164,16 +182,28 @@ import org.apache.spark.util.Utils def this(settings: Settings) = this(settings, new NewLinePrintWriter(new ConsoleWriter, true)) def this() = this(new Settings()) - lazy val repllog: Logger = new Logger { + private lazy val repllog: Logger = new Logger { val out: JPrintWriter = imain.out val isInfo: Boolean = BooleanProp keyExists "scala.repl.info" val isDebug: Boolean = BooleanProp keyExists "scala.repl.debug" val isTrace: Boolean = BooleanProp keyExists "scala.repl.trace" } - lazy val formatting: Formatting = new Formatting { + private[repl] lazy val formatting: Formatting = new Formatting { val prompt = Properties.shellPromptString } - lazy val reporter: ConsoleReporter = new SparkIMain.ReplReporter(this) + + // NOTE: Exposed to repl package since used by SparkExprTyper and SparkILoop + private[repl] lazy val reporter: ConsoleReporter = new SparkIMain.ReplReporter(this) + + /** + * Determines if errors were reported (typically during compilation). + * + * @note This is not for runtime errors + * + * @return True if had errors, otherwise false + */ + @DeveloperApi + def isReportingErrors = reporter.hasErrors import formatting._ import reporter.{ printMessage, withoutTruncating } @@ -193,7 +223,8 @@ import org.apache.spark.util.Utils private def tquoted(s: String) = "\"\"\"" + s + "\"\"\"" // argument is a thunk to execute after init is done - def initialize(postInitSignal: => Unit) { + // NOTE: Exposed to repl package since used by SparkILoop + private[repl] def initialize(postInitSignal: => Unit) { synchronized { if (_isInitialized == null) { _isInitialized = io.spawn { @@ -203,15 +234,27 @@ import org.apache.spark.util.Utils } } } + + /** + * Initializes the underlying compiler/interpreter in a blocking fashion. + * + * @note Must be executed before using SparkIMain! + */ + @DeveloperApi def initializeSynchronous(): Unit = { if (!isInitializeComplete) { _initialize() assert(global != null, global) } } - def isInitializeComplete = _initializeComplete + private def isInitializeComplete = _initializeComplete /** the public, go through the future compiler */ + + /** + * The underlying compiler used to generate ASTs and execute code. + */ + @DeveloperApi lazy val global: Global = { if (isInitializeComplete) _compiler else { @@ -226,13 +269,13 @@ import org.apache.spark.util.Utils } } @deprecated("Use `global` for access to the compiler instance.", "2.9.0") - lazy val compiler: global.type = global + private lazy val compiler: global.type = global import global._ import definitions.{ScalaPackage, JavaLangPackage, termMember, typeMember} import rootMirror.{RootClass, getClassIfDefined, getModuleIfDefined, getRequiredModule, getRequiredClass} - implicit class ReplTypeOps(tp: Type) { + private implicit class ReplTypeOps(tp: Type) { def orElse(other: => Type): Type = if (tp ne NoType) tp else other def andAlso(fn: Type => Type): Type = if (tp eq NoType) tp else fn(tp) } @@ -240,7 +283,8 @@ import org.apache.spark.util.Utils // TODO: If we try to make naming a lazy val, we run into big time // scalac unhappiness with what look like cycles. It has not been easy to // reduce, but name resolution clearly takes different paths. - object naming extends { + // NOTE: Exposed to repl package since used by SparkExprTyper + private[repl] object naming extends { val global: imain.global.type = imain.global } with Naming { // make sure we don't overwrite their unwisely named res3 etc. @@ -254,22 +298,43 @@ import org.apache.spark.util.Utils } import naming._ - object deconstruct extends { + // NOTE: Exposed to repl package since used by SparkILoop + private[repl] object deconstruct extends { val global: imain.global.type = imain.global } with StructuredTypeStrings - lazy val memberHandlers = new { + // NOTE: Exposed to repl package since used by SparkImports + private[repl] lazy val memberHandlers = new { val intp: imain.type = imain } with SparkMemberHandlers import memberHandlers._ - /** Temporarily be quiet */ + /** + * Suppresses overwriting print results during the operation. + * + * @param body The block to execute + * @tparam T The return type of the block + * + * @return The result from executing the block + */ + @DeveloperApi def beQuietDuring[T](body: => T): T = { val saved = printResults printResults = false try body finally printResults = saved } + + /** + * Completely masks all output during the operation (minus JVM standard + * out and error). + * + * @param operation The block to execute + * @tparam T The return type of the block + * + * @return The result from executing the block + */ + @DeveloperApi def beSilentDuring[T](operation: => T): T = { val saved = totalSilence totalSilence = true @@ -277,10 +342,10 @@ import org.apache.spark.util.Utils finally totalSilence = saved } - def quietRun[T](code: String) = beQuietDuring(interpret(code)) + // NOTE: Exposed to repl package since used by SparkILoop + private[repl] def quietRun[T](code: String) = beQuietDuring(interpret(code)) - - private def logAndDiscard[T](label: String, alt: => T): PartialFunction[Throwable, T] = { + private def logAndDiscard[T](label: String, alt: => T): PartialFunction[Throwable, T] = { case t: ControlThrowable => throw t case t: Throwable => logDebug(label + ": " + unwrap(t)) @@ -298,14 +363,44 @@ import org.apache.spark.util.Utils finally bindExceptions = true } + /** + * Contains the code (in string form) representing a wrapper around all + * code executed by this instance. + * + * @return The wrapper code as a string + */ + @DeveloperApi def executionWrapper = _executionWrapper + + /** + * Sets the code to use as a wrapper around all code executed by this + * instance. + * + * @param code The wrapper code as a string + */ + @DeveloperApi def setExecutionWrapper(code: String) = _executionWrapper = code + + /** + * Clears the code used as a wrapper around all code executed by + * this instance. + */ + @DeveloperApi def clearExecutionWrapper() = _executionWrapper = "" /** interpreter settings */ - lazy val isettings = new SparkISettings(this) + private lazy val isettings = new SparkISettings(this) - /** Instantiate a compiler. Overridable. */ + /** + * Instantiates a new compiler used by SparkIMain. Overridable to provide + * own instance of a compiler. + * + * @param settings The settings to provide the compiler + * @param reporter The reporter to use for compiler output + * + * @return The compiler as a Global + */ + @DeveloperApi protected def newCompiler(settings: Settings, reporter: Reporter): ReplGlobal = { settings.outputDirs setSingleOutput virtualDirectory settings.exposeEmptyPackage.value = true @@ -320,13 +415,14 @@ import org.apache.spark.util.Utils * @note Currently only supports jars, not directories * @param urls The list of items to add to the compile and runtime classpaths */ + @DeveloperApi def addUrlsToClassPath(urls: URL*): Unit = { new Run // Needed to force initialization of "something" to correctly load Scala classes from jars urls.foreach(_runtimeClassLoader.addNewUrl) // Add jars/classes to runtime for execution updateCompilerClassPath(urls: _*) // Add jars/classes to compile time for compiling } - protected def updateCompilerClassPath(urls: URL*): Unit = { + private def updateCompilerClassPath(urls: URL*): Unit = { require(!global.forMSIL) // Only support JavaPlatform val platform = global.platform.asInstanceOf[JavaPlatform] @@ -342,7 +438,7 @@ import org.apache.spark.util.Utils global.invalidateClassPathEntries(urls.map(_.getPath): _*) } - protected def mergeUrlsIntoClassPath(platform: JavaPlatform, urls: URL*): MergedClassPath[AbstractFile] = { + private def mergeUrlsIntoClassPath(platform: JavaPlatform, urls: URL*): MergedClassPath[AbstractFile] = { // Collect our new jars/directories and add them to the existing set of classpaths val allClassPaths = ( platform.classPath.asInstanceOf[MergedClassPath[AbstractFile]].entries ++ @@ -365,7 +461,13 @@ import org.apache.spark.util.Utils new MergedClassPath(allClassPaths, platform.classPath.context) } - /** Parent classloader. Overridable. */ + /** + * Represents the parent classloader used by this instance. Can be + * overridden to provide alternative classloader. + * + * @return The classloader used as the parent loader of this instance + */ + @DeveloperApi protected def parentClassLoader: ClassLoader = SparkHelper.explicitParentLoader(settings).getOrElse( this.getClass.getClassLoader() ) @@ -382,16 +484,18 @@ import org.apache.spark.util.Utils shadow the old ones, and old code objects refer to the old definitions. */ - def resetClassLoader() = { + private def resetClassLoader() = { logDebug("Setting new classloader: was " + _classLoader) _classLoader = null ensureClassLoader() } - final def ensureClassLoader() { + private final def ensureClassLoader() { if (_classLoader == null) _classLoader = makeClassLoader() } - def classLoader: AbstractFileClassLoader = { + + // NOTE: Exposed to repl package since used by SparkILoop + private[repl] def classLoader: AbstractFileClassLoader = { ensureClassLoader() _classLoader } @@ -418,27 +522,58 @@ import org.apache.spark.util.Utils _runtimeClassLoader }) - def getInterpreterClassLoader() = classLoader + private def getInterpreterClassLoader() = classLoader // Set the current Java "context" class loader to this interpreter's class loader - def setContextClassLoader() = classLoader.setAsContext() + // NOTE: Exposed to repl package since used by SparkILoopInit + private[repl] def setContextClassLoader() = classLoader.setAsContext() - /** Given a simple repl-defined name, returns the real name of - * the class representing it, e.g. for "Bippy" it may return - * {{{ - * $line19.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$Bippy - * }}} + /** + * Returns the real name of a class based on its repl-defined name. + * + * ==Example== + * Given a simple repl-defined name, returns the real name of + * the class representing it, e.g. for "Bippy" it may return + * {{{ + * $line19.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$Bippy + * }}} + * + * @param simpleName The repl-defined name whose real name to retrieve + * + * @return Some real name if the simple name exists, else None */ + @DeveloperApi def generatedName(simpleName: String): Option[String] = { if (simpleName endsWith nme.MODULE_SUFFIX_STRING) optFlatName(simpleName.init) map (_ + nme.MODULE_SUFFIX_STRING) else optFlatName(simpleName) } - def flatName(id: String) = optFlatName(id) getOrElse id - def optFlatName(id: String) = requestForIdent(id) map (_ fullFlatName id) + // NOTE: Exposed to repl package since used by SparkILoop + private[repl] def flatName(id: String) = optFlatName(id) getOrElse id + // NOTE: Exposed to repl package since used by SparkILoop + private[repl] def optFlatName(id: String) = requestForIdent(id) map (_ fullFlatName id) + + /** + * Retrieves all simple names contained in the current instance. + * + * @return A list of sorted names + */ + @DeveloperApi def allDefinedNames = definedNameMap.keys.toList.sorted - def pathToType(id: String): String = pathToName(newTypeName(id)) - def pathToTerm(id: String): String = pathToName(newTermName(id)) + + private def pathToType(id: String): String = pathToName(newTypeName(id)) + // NOTE: Exposed to repl package since used by SparkILoop + private[repl] def pathToTerm(id: String): String = pathToName(newTermName(id)) + + /** + * Retrieves the full code path to access the specified simple name + * content. + * + * @param name The simple name of the target whose path to determine + * + * @return The full path used to access the specified target (name) + */ + @DeveloperApi def pathToName(name: Name): String = { if (definedNameMap contains name) definedNameMap(name) fullPath name @@ -457,13 +592,13 @@ import org.apache.spark.util.Utils } /** Stubs for work in progress. */ - def handleTypeRedefinition(name: TypeName, old: Request, req: Request) = { + private def handleTypeRedefinition(name: TypeName, old: Request, req: Request) = { for (t1 <- old.simpleNameOfType(name) ; t2 <- req.simpleNameOfType(name)) { logDebug("Redefining type '%s'\n %s -> %s".format(name, t1, t2)) } } - def handleTermRedefinition(name: TermName, old: Request, req: Request) = { + private def handleTermRedefinition(name: TermName, old: Request, req: Request) = { for (t1 <- old.compilerTypeOf get name ; t2 <- req.compilerTypeOf get name) { // Printing the types here has a tendency to cause assertion errors, like // assertion failed: fatal: has owner value x, but a class owner is required @@ -473,7 +608,7 @@ import org.apache.spark.util.Utils } } - def recordRequest(req: Request) { + private def recordRequest(req: Request) { if (req == null || referencedNameMap == null) return @@ -504,12 +639,12 @@ import org.apache.spark.util.Utils } } - def replwarn(msg: => String) { + private def replwarn(msg: => String) { if (!settings.nowarnings.value) printMessage(msg) } - def isParseable(line: String): Boolean = { + private def isParseable(line: String): Boolean = { beSilentDuring { try parse(line) match { case Some(xs) => xs.nonEmpty // parses as-is @@ -522,22 +657,32 @@ import org.apache.spark.util.Utils } } - def compileSourcesKeepingRun(sources: SourceFile*) = { + private def compileSourcesKeepingRun(sources: SourceFile*) = { val run = new Run() reporter.reset() run compileSources sources.toList (!reporter.hasErrors, run) } - /** Compile an nsc SourceFile. Returns true if there are - * no compilation errors, or false otherwise. + /** + * Compiles specified source files. + * + * @param sources The sequence of source files to compile + * + * @return True if successful, otherwise false */ + @DeveloperApi def compileSources(sources: SourceFile*): Boolean = compileSourcesKeepingRun(sources: _*)._1 - /** Compile a string. Returns true if there are no - * compilation errors, or false otherwise. + /** + * Compiles a string of code. + * + * @param code The string of code to compile + * + * @return True if successful, otherwise false */ + @DeveloperApi def compileString(code: String): Boolean = compileSources(new BatchSourceFile("