diff --git a/assembly/pom.xml b/assembly/pom.xml index 12940adc54221..56ae693543345 100644 --- a/assembly/pom.xml +++ b/assembly/pom.xml @@ -204,16 +204,16 @@ - - hbase - - - org.apache.spark - spark-hbase_${scala.binary.version} - ${project.version} - - - + + hbase + + + org.apache.spark + spark-hbase_${scala.binary.version} + ${project.version} + + + spark-ganglia-lgpl diff --git a/bin/compute-classpath.sh b/bin/compute-classpath.sh index b7943aacacd06..2ac345ba5cd04 100755 --- a/bin/compute-classpath.sh +++ b/bin/compute-classpath.sh @@ -114,10 +114,6 @@ fi datanucleus_jars="$(find "$datanucleus_dir" 2>/dev/null | grep "datanucleus-.*\\.jar")" datanucleus_jars="$(echo "$datanucleus_jars" | tr "\n" : | sed s/:$//g)" -hive_files=$("$JAR_CMD" -tf "$ASSEMBLY_JAR" org/apache/hadoop/hive/ql/exec 2>/dev/null) - -hive_files=$("$JAR_CMD" -tf "$ASSEMBLY_JAR" org/apache/hadoop/hive/ql/exec 2>/dev/null) - if [ -n "$datanucleus_jars" ]; then hive_files=$("$JAR_CMD" -tf "$ASSEMBLY_JAR" org/apache/hadoop/hive/ql/exec 2>/dev/null) if [ -n "$hive_files" ]; then @@ -126,7 +122,6 @@ if [ -n "$datanucleus_jars" ]; then fi fi - # Add test classes if we're running from SBT or Maven with SPARK_TESTING set to 1 if [[ $SPARK_TESTING == 1 ]]; then CLASSPATH="$CLASSPATH:$FWDIR/core/target/scala-$SCALA_VERSION/test-classes" @@ -137,8 +132,8 @@ if [[ $SPARK_TESTING == 1 ]]; then CLASSPATH="$CLASSPATH:$FWDIR/streaming/target/scala-$SCALA_VERSION/test-classes" CLASSPATH="$CLASSPATH:$FWDIR/sql/catalyst/target/scala-$SCALA_VERSION/test-classes" CLASSPATH="$CLASSPATH:$FWDIR/sql/core/target/scala-$SCALA_VERSION/test-classes" - CLASSPATH="$CLASSPATH:$FWDIR/sql/hive/target/scala-$SCALA_VERSION/test-classes" CLASSPATH="$CLASSPATH:$FWDIR/sql/hbase/target/scala-$SCALA_VERSION/test-classes" + CLASSPATH="$CLASSPATH:$FWDIR/sql/hive/target/scala-$SCALA_VERSION/test-classes" fi # Add hadoop conf dir if given -- otherwise FileSystem.*, etc fail ! @@ -150,5 +145,5 @@ fi if [ -n "$YARN_CONF_DIR" ]; then CLASSPATH="$CLASSPATH:$YARN_CONF_DIR" fi -echo "$CLASSPATH" +echo "$CLASSPATH" diff --git a/bin/pyspark b/bin/pyspark index 6655725ef8e8e..96f30a260a09e 100755 --- a/bin/pyspark +++ b/bin/pyspark @@ -50,22 +50,47 @@ fi . "$FWDIR"/bin/load-spark-env.sh -# Figure out which Python executable to use +# In Spark <= 1.1, setting IPYTHON=1 would cause the driver to be launched using the `ipython` +# executable, while the worker would still be launched using PYSPARK_PYTHON. +# +# In Spark 1.2, we removed the documentation of the IPYTHON and IPYTHON_OPTS variables and added +# PYSPARK_DRIVER_PYTHON and PYSPARK_DRIVER_PYTHON_OPTS to allow IPython to be used for the driver. +# Now, users can simply set PYSPARK_DRIVER_PYTHON=ipython to use IPython and set +# PYSPARK_DRIVER_PYTHON_OPTS to pass options when starting the Python driver +# (e.g. PYSPARK_DRIVER_PYTHON_OPTS='notebook'). This supports full customization of the IPython +# and executor Python executables. +# +# For backwards-compatibility, we retain the old IPYTHON and IPYTHON_OPTS variables. 
+ +# Determine the Python executable to use if PYSPARK_PYTHON or PYSPARK_DRIVER_PYTHON isn't set: +if hash python2.7 2>/dev/null; then + # Attempt to use Python 2.7, if installed: + DEFAULT_PYTHON="python2.7" +else + DEFAULT_PYTHON="python" +fi + +# Determine the Python executable to use for the driver: +if [[ -n "$IPYTHON_OPTS" || "$IPYTHON" == "1" ]]; then + # If IPython options are specified, assume user wants to run IPython + # (for backwards-compatibility) + PYSPARK_DRIVER_PYTHON_OPTS="$PYSPARK_DRIVER_PYTHON_OPTS $IPYTHON_OPTS" + PYSPARK_DRIVER_PYTHON="ipython" +elif [[ -z "$PYSPARK_DRIVER_PYTHON" ]]; then + PYSPARK_DRIVER_PYTHON="${PYSPARK_PYTHON:-"$DEFAULT_PYTHON"}" +fi + +# Determine the Python executable to use for the executors: if [[ -z "$PYSPARK_PYTHON" ]]; then - if [[ "$IPYTHON" = "1" || -n "$IPYTHON_OPTS" ]]; then - # for backward compatibility - PYSPARK_PYTHON="ipython" + if [[ $PYSPARK_DRIVER_PYTHON == *ipython* && $DEFAULT_PYTHON != "python2.7" ]]; then + echo "IPython requires Python 2.7+; please install python2.7 or set PYSPARK_PYTHON" 1>&2 + exit 1 else - PYSPARK_PYTHON="python" + PYSPARK_PYTHON="$DEFAULT_PYTHON" fi fi export PYSPARK_PYTHON -if [[ -z "$PYSPARK_PYTHON_OPTS" && -n "$IPYTHON_OPTS" ]]; then - # for backward compatibility - PYSPARK_PYTHON_OPTS="$IPYTHON_OPTS" -fi - # Add the PySpark classes to the Python path: export PYTHONPATH="$SPARK_HOME/python/:$PYTHONPATH" export PYTHONPATH="$SPARK_HOME/python/lib/py4j-0.8.2.1-src.zip:$PYTHONPATH" @@ -93,9 +118,9 @@ if [[ -n "$SPARK_TESTING" ]]; then unset YARN_CONF_DIR unset HADOOP_CONF_DIR if [[ -n "$PYSPARK_DOC_TEST" ]]; then - exec "$PYSPARK_PYTHON" -m doctest $1 + exec "$PYSPARK_DRIVER_PYTHON" -m doctest $1 else - exec "$PYSPARK_PYTHON" $1 + exec "$PYSPARK_DRIVER_PYTHON" $1 fi exit fi @@ -111,5 +136,5 @@ if [[ "$1" =~ \.py$ ]]; then else # PySpark shell requires special handling downstream export PYSPARK_SHELL=1 - exec "$PYSPARK_PYTHON" $PYSPARK_PYTHON_OPTS + exec "$PYSPARK_DRIVER_PYTHON" $PYSPARK_DRIVER_PYTHON_OPTS fi diff --git a/bin/spark-class b/bin/spark-class index e8201c18d52de..91d858bc063d0 100755 --- a/bin/spark-class +++ b/bin/spark-class @@ -105,7 +105,7 @@ else exit 1 fi fi -JAVA_VERSION=$("$RUNNER" -version 2>&1 | sed 's/.* version "\(.*\)\.\(.*\)\..*"/\1\2/; 1q') +JAVA_VERSION=$("$RUNNER" -version 2>&1 | grep 'version' | sed 's/.* version "\(.*\)\.\(.*\)\..*"/\1\2/; 1q') # Set JAVA_OPTS to be able to load native libraries and to set heap size if [ "$JAVA_VERSION" -ge 18 ]; then diff --git a/bin/spark-shell.cmd b/bin/spark-shell.cmd index 2ee60b4e2a2b3..8f90ba5a0b3b8 100755 --- a/bin/spark-shell.cmd +++ b/bin/spark-shell.cmd @@ -17,6 +17,7 @@ rem See the License for the specific language governing permissions and rem limitations under the License. rem -set SPARK_HOME=%~dp0.. +rem This is the entry point for running Spark shell. To avoid polluting the +rem environment, it just launches a new cmd to do the real work. -cmd /V /E /C %SPARK_HOME%\bin\spark-submit.cmd --class org.apache.spark.repl.Main %* spark-shell +cmd /V /E /C %~dp0spark-shell2.cmd %* diff --git a/bin/spark-shell2.cmd b/bin/spark-shell2.cmd new file mode 100644 index 0000000000000..2ee60b4e2a2b3 --- /dev/null +++ b/bin/spark-shell2.cmd @@ -0,0 +1,22 @@ +@echo off + +rem +rem Licensed to the Apache Software Foundation (ASF) under one or more +rem contributor license agreements. See the NOTICE file distributed with +rem this work for additional information regarding copyright ownership. 
+rem The ASF licenses this file to You under the Apache License, Version 2.0 +rem (the "License"); you may not use this file except in compliance with +rem the License. You may obtain a copy of the License at +rem +rem http://www.apache.org/licenses/LICENSE-2.0 +rem +rem Unless required by applicable law or agreed to in writing, software +rem distributed under the License is distributed on an "AS IS" BASIS, +rem WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +rem See the License for the specific language governing permissions and +rem limitations under the License. +rem + +set SPARK_HOME=%~dp0.. + +cmd /V /E /C %SPARK_HOME%\bin\spark-submit.cmd --class org.apache.spark.repl.Main %* spark-shell diff --git a/bin/spark-submit.cmd b/bin/spark-submit.cmd index cf6046d1547ad..8f3b84c7b971d 100644 --- a/bin/spark-submit.cmd +++ b/bin/spark-submit.cmd @@ -17,52 +17,7 @@ rem See the License for the specific language governing permissions and rem limitations under the License. rem -rem NOTE: Any changes in this file must be reflected in SparkSubmitDriverBootstrapper.scala! +rem This is the entry point for running Spark submit. To avoid polluting the +rem environment, it just launches a new cmd to do the real work. -set SPARK_HOME=%~dp0.. -set ORIG_ARGS=%* - -rem Reset the values of all variables used -set SPARK_SUBMIT_DEPLOY_MODE=client -set SPARK_SUBMIT_PROPERTIES_FILE=%SPARK_HOME%\conf\spark-defaults.conf -set SPARK_SUBMIT_DRIVER_MEMORY= -set SPARK_SUBMIT_LIBRARY_PATH= -set SPARK_SUBMIT_CLASSPATH= -set SPARK_SUBMIT_OPTS= -set SPARK_SUBMIT_BOOTSTRAP_DRIVER= - -:loop -if [%1] == [] goto continue - if [%1] == [--deploy-mode] ( - set SPARK_SUBMIT_DEPLOY_MODE=%2 - ) else if [%1] == [--properties-file] ( - set SPARK_SUBMIT_PROPERTIES_FILE=%2 - ) else if [%1] == [--driver-memory] ( - set SPARK_SUBMIT_DRIVER_MEMORY=%2 - ) else if [%1] == [--driver-library-path] ( - set SPARK_SUBMIT_LIBRARY_PATH=%2 - ) else if [%1] == [--driver-class-path] ( - set SPARK_SUBMIT_CLASSPATH=%2 - ) else if [%1] == [--driver-java-options] ( - set SPARK_SUBMIT_OPTS=%2 - ) - shift -goto loop -:continue - -rem For client mode, the driver will be launched in the same JVM that launches -rem SparkSubmit, so we may need to read the properties file for any extra class -rem paths, library paths, java options and memory early on. Otherwise, it will -rem be too late by the time the driver JVM has started. - -if [%SPARK_SUBMIT_DEPLOY_MODE%] == [client] ( - if exist %SPARK_SUBMIT_PROPERTIES_FILE% ( - rem Parse the properties file only if the special configs exist - for /f %%i in ('findstr /r /c:"^[\t ]*spark.driver.memory" /c:"^[\t ]*spark.driver.extra" ^ - %SPARK_SUBMIT_PROPERTIES_FILE%') do ( - set SPARK_SUBMIT_BOOTSTRAP_DRIVER=1 - ) - ) -) - -cmd /V /E /C %SPARK_HOME%\bin\spark-class.cmd org.apache.spark.deploy.SparkSubmit %ORIG_ARGS% +cmd /V /E /C %~dp0spark-submit2.cmd %* diff --git a/bin/spark-submit2.cmd b/bin/spark-submit2.cmd new file mode 100644 index 0000000000000..cf6046d1547ad --- /dev/null +++ b/bin/spark-submit2.cmd @@ -0,0 +1,68 @@ +@echo off + +rem +rem Licensed to the Apache Software Foundation (ASF) under one or more +rem contributor license agreements. See the NOTICE file distributed with +rem this work for additional information regarding copyright ownership. +rem The ASF licenses this file to You under the Apache License, Version 2.0 +rem (the "License"); you may not use this file except in compliance with +rem the License. 
You may obtain a copy of the License at +rem +rem http://www.apache.org/licenses/LICENSE-2.0 +rem +rem Unless required by applicable law or agreed to in writing, software +rem distributed under the License is distributed on an "AS IS" BASIS, +rem WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +rem See the License for the specific language governing permissions and +rem limitations under the License. +rem + +rem NOTE: Any changes in this file must be reflected in SparkSubmitDriverBootstrapper.scala! + +set SPARK_HOME=%~dp0.. +set ORIG_ARGS=%* + +rem Reset the values of all variables used +set SPARK_SUBMIT_DEPLOY_MODE=client +set SPARK_SUBMIT_PROPERTIES_FILE=%SPARK_HOME%\conf\spark-defaults.conf +set SPARK_SUBMIT_DRIVER_MEMORY= +set SPARK_SUBMIT_LIBRARY_PATH= +set SPARK_SUBMIT_CLASSPATH= +set SPARK_SUBMIT_OPTS= +set SPARK_SUBMIT_BOOTSTRAP_DRIVER= + +:loop +if [%1] == [] goto continue + if [%1] == [--deploy-mode] ( + set SPARK_SUBMIT_DEPLOY_MODE=%2 + ) else if [%1] == [--properties-file] ( + set SPARK_SUBMIT_PROPERTIES_FILE=%2 + ) else if [%1] == [--driver-memory] ( + set SPARK_SUBMIT_DRIVER_MEMORY=%2 + ) else if [%1] == [--driver-library-path] ( + set SPARK_SUBMIT_LIBRARY_PATH=%2 + ) else if [%1] == [--driver-class-path] ( + set SPARK_SUBMIT_CLASSPATH=%2 + ) else if [%1] == [--driver-java-options] ( + set SPARK_SUBMIT_OPTS=%2 + ) + shift +goto loop +:continue + +rem For client mode, the driver will be launched in the same JVM that launches +rem SparkSubmit, so we may need to read the properties file for any extra class +rem paths, library paths, java options and memory early on. Otherwise, it will +rem be too late by the time the driver JVM has started. + +if [%SPARK_SUBMIT_DEPLOY_MODE%] == [client] ( + if exist %SPARK_SUBMIT_PROPERTIES_FILE% ( + rem Parse the properties file only if the special configs exist + for /f %%i in ('findstr /r /c:"^[\t ]*spark.driver.memory" /c:"^[\t ]*spark.driver.extra" ^ + %SPARK_SUBMIT_PROPERTIES_FILE%') do ( + set SPARK_SUBMIT_BOOTSTRAP_DRIVER=1 + ) + ) +) + +cmd /V /E /C %SPARK_HOME%\bin\spark-class.cmd org.apache.spark.deploy.SparkSubmit %ORIG_ARGS% diff --git a/core/src/main/java/org/apache/spark/TaskContext.java b/core/src/main/java/org/apache/spark/TaskContext.java index 4e6d708af0ea7..2d998d4c7a5d9 100644 --- a/core/src/main/java/org/apache/spark/TaskContext.java +++ b/core/src/main/java/org/apache/spark/TaskContext.java @@ -18,131 +18,55 @@ package org.apache.spark; import java.io.Serializable; -import java.util.ArrayList; -import java.util.Collections; -import java.util.List; import scala.Function0; import scala.Function1; import scala.Unit; -import scala.collection.JavaConversions; import org.apache.spark.annotation.DeveloperApi; import org.apache.spark.executor.TaskMetrics; import org.apache.spark.util.TaskCompletionListener; -import org.apache.spark.util.TaskCompletionListenerException; /** -* :: DeveloperApi :: -* Contextual information about a task which can be read or mutated during execution. -*/ -@DeveloperApi -public class TaskContext implements Serializable { - - private int stageId; - private int partitionId; - private long attemptId; - private boolean runningLocally; - private TaskMetrics taskMetrics; - - /** - * :: DeveloperApi :: - * Contextual information about a task which can be read or mutated during execution. 
- * - * @param stageId stage id - * @param partitionId index of the partition - * @param attemptId the number of attempts to execute this task - * @param runningLocally whether the task is running locally in the driver JVM - * @param taskMetrics performance metrics of the task - */ - @DeveloperApi - public TaskContext(int stageId, int partitionId, long attemptId, boolean runningLocally, - TaskMetrics taskMetrics) { - this.attemptId = attemptId; - this.partitionId = partitionId; - this.runningLocally = runningLocally; - this.stageId = stageId; - this.taskMetrics = taskMetrics; - } - - /** - * :: DeveloperApi :: - * Contextual information about a task which can be read or mutated during execution. - * - * @param stageId stage id - * @param partitionId index of the partition - * @param attemptId the number of attempts to execute this task - * @param runningLocally whether the task is running locally in the driver JVM - */ - @DeveloperApi - public TaskContext(int stageId, int partitionId, long attemptId, boolean runningLocally) { - this.attemptId = attemptId; - this.partitionId = partitionId; - this.runningLocally = runningLocally; - this.stageId = stageId; - this.taskMetrics = TaskMetrics.empty(); - } - + * Contextual information about a task which can be read or mutated during + * execution. To access the TaskContext for a running task use + * TaskContext.get(). + */ +public abstract class TaskContext implements Serializable { /** - * :: DeveloperApi :: - * Contextual information about a task which can be read or mutated during execution. - * - * @param stageId stage id - * @param partitionId index of the partition - * @param attemptId the number of attempts to execute this task + * Return the currently active TaskContext. This can be called inside of + * user functions to access contextual information about running tasks. */ - @DeveloperApi - public TaskContext(int stageId, int partitionId, long attemptId) { - this.attemptId = attemptId; - this.partitionId = partitionId; - this.runningLocally = false; - this.stageId = stageId; - this.taskMetrics = TaskMetrics.empty(); + public static TaskContext get() { + return taskContext.get(); } private static ThreadLocal taskContext = new ThreadLocal(); - /** - * :: Internal API :: - * This is spark internal API, not intended to be called from user programs. - */ - public static void setTaskContext(TaskContext tc) { + static void setTaskContext(TaskContext tc) { taskContext.set(tc); } - public static TaskContext get() { - return taskContext.get(); - } - - /** :: Internal API :: */ - public static void unset() { + static void unset() { taskContext.remove(); } - // List of callback functions to execute when the task completes. - private transient List onCompleteCallbacks = - new ArrayList(); - - // Whether the corresponding task has been killed. - private volatile boolean interrupted = false; - - // Whether the task has completed. - private volatile boolean completed = false; - /** - * Checks whether the task has completed. + * Whether the task has completed. */ - public boolean isCompleted() { - return completed; - } + public abstract boolean isCompleted(); /** - * Checks whether the task has been killed. + * Whether the task has been killed. 
*/ - public boolean isInterrupted() { - return interrupted; - } + public abstract boolean isInterrupted(); + + /** @deprecated: use isRunningLocally() */ + @Deprecated + public abstract boolean runningLocally(); + + public abstract boolean isRunningLocally(); /** * Add a (Java friendly) listener to be executed on task completion. @@ -150,10 +74,7 @@ public boolean isInterrupted() { *
* An example use is for HadoopRDD to register a callback to close the input stream. */ - public TaskContext addTaskCompletionListener(TaskCompletionListener listener) { - onCompleteCallbacks.add(listener); - return this; - } + public abstract TaskContext addTaskCompletionListener(TaskCompletionListener listener); /** * Add a listener in the form of a Scala closure to be executed on task completion. @@ -161,109 +82,27 @@ public TaskContext addTaskCompletionListener(TaskCompletionListener listener) { *
* An example use is for HadoopRDD to register a callback to close the input stream. */ - public TaskContext addTaskCompletionListener(final Function1 f) { - onCompleteCallbacks.add(new TaskCompletionListener() { - @Override - public void onTaskCompletion(TaskContext context) { - f.apply(context); - } - }); - return this; - } + public abstract TaskContext addTaskCompletionListener(final Function1 f); /** * Add a callback function to be executed on task completion. An example use * is for HadoopRDD to register a callback to close the input stream. * Will be called in any situation - success, failure, or cancellation. * - * Deprecated: use addTaskCompletionListener - * + * @deprecated: use addTaskCompletionListener + * * @param f Callback function. */ @Deprecated - public void addOnCompleteCallback(final Function0 f) { - onCompleteCallbacks.add(new TaskCompletionListener() { - @Override - public void onTaskCompletion(TaskContext context) { - f.apply(); - } - }); - } - - /** - * ::Internal API:: - * Marks the task as completed and triggers the listeners. - */ - public void markTaskCompleted() throws TaskCompletionListenerException { - completed = true; - List errorMsgs = new ArrayList(2); - // Process complete callbacks in the reverse order of registration - List revlist = - new ArrayList(onCompleteCallbacks); - Collections.reverse(revlist); - for (TaskCompletionListener tcl: revlist) { - try { - tcl.onTaskCompletion(this); - } catch (Throwable e) { - errorMsgs.add(e.getMessage()); - } - } - - if (!errorMsgs.isEmpty()) { - throw new TaskCompletionListenerException(JavaConversions.asScalaBuffer(errorMsgs)); - } - } - - /** - * ::Internal API:: - * Marks the task for interruption, i.e. cancellation. - */ - public void markInterrupted() { - interrupted = true; - } - - @Deprecated - /** Deprecated: use getStageId() */ - public int stageId() { - return stageId; - } - - @Deprecated - /** Deprecated: use getPartitionId() */ - public int partitionId() { - return partitionId; - } - - @Deprecated - /** Deprecated: use getAttemptId() */ - public long attemptId() { - return attemptId; - } - - @Deprecated - /** Deprecated: use isRunningLocally() */ - public boolean runningLocally() { - return runningLocally; - } - - public boolean isRunningLocally() { - return runningLocally; - } + public abstract void addOnCompleteCallback(final Function0 f); - public int getStageId() { - return stageId; - } + public abstract int stageId(); - public int getPartitionId() { - return partitionId; - } + public abstract int partitionId(); - public long getAttemptId() { - return attemptId; - } + public abstract long attemptId(); - /** ::Internal API:: */ - public TaskMetrics taskMetrics() { - return taskMetrics; - } + /** ::DeveloperApi:: */ + @DeveloperApi + public abstract TaskMetrics taskMetrics(); } diff --git a/core/src/main/java/org/apache/spark/api/java/JavaFutureAction.java b/core/src/main/java/org/apache/spark/api/java/JavaFutureAction.java new file mode 100644 index 0000000000000..0ad189633e427 --- /dev/null +++ b/core/src/main/java/org/apache/spark/api/java/JavaFutureAction.java @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.java; + + +import java.util.List; +import java.util.concurrent.Future; + +public interface JavaFutureAction extends Future { + + /** + * Returns the job IDs run by the underlying async operation. + * + * This returns the current snapshot of the job list. Certain operations may run multiple + * jobs, so multiple calls to this method may return different lists. + */ + List jobIds(); +} diff --git a/core/src/main/scala/org/apache/spark/FutureAction.scala b/core/src/main/scala/org/apache/spark/FutureAction.scala index e8f761eaa5799..d5c8f9d76c476 100644 --- a/core/src/main/scala/org/apache/spark/FutureAction.scala +++ b/core/src/main/scala/org/apache/spark/FutureAction.scala @@ -17,20 +17,21 @@ package org.apache.spark -import scala.concurrent._ -import scala.concurrent.duration.Duration -import scala.util.Try +import java.util.Collections +import java.util.concurrent.TimeUnit -import org.apache.spark.annotation.Experimental +import org.apache.spark.api.java.JavaFutureAction import org.apache.spark.rdd.RDD import org.apache.spark.scheduler.{JobFailed, JobSucceeded, JobWaiter} +import scala.concurrent._ +import scala.concurrent.duration.Duration +import scala.util.{Failure, Try} + /** - * :: Experimental :: * A future for the result of an action to support cancellation. This is an extension of the * Scala Future interface to support cancellation. */ -@Experimental trait FutureAction[T] extends Future[T] { // Note that we redefine methods of the Future trait here explicitly so we can specify a different // documentation (with reference to the word "action"). @@ -69,6 +70,11 @@ trait FutureAction[T] extends Future[T] { */ override def isCompleted: Boolean + /** + * Returns whether the action has been cancelled. + */ + def isCancelled: Boolean + /** * The value of this Future. * @@ -96,15 +102,16 @@ trait FutureAction[T] extends Future[T] { /** - * :: Experimental :: * A [[FutureAction]] holding the result of an action that triggers a single job. Examples include * count, collect, reduce. */ -@Experimental class SimpleFutureAction[T] private[spark](jobWaiter: JobWaiter[_], resultFunc: => T) extends FutureAction[T] { + @volatile private var _cancelled: Boolean = false + override def cancel() { + _cancelled = true jobWaiter.cancel() } @@ -143,6 +150,8 @@ class SimpleFutureAction[T] private[spark](jobWaiter: JobWaiter[_], resultFunc: } override def isCompleted: Boolean = jobWaiter.jobFinished + + override def isCancelled: Boolean = _cancelled override def value: Option[Try[T]] = { if (jobWaiter.jobFinished) { @@ -164,12 +173,10 @@ class SimpleFutureAction[T] private[spark](jobWaiter: JobWaiter[_], resultFunc: /** - * :: Experimental :: * A [[FutureAction]] for actions that could trigger multiple Spark jobs. Examples include take, * takeSample. Cancellation works by setting the cancelled flag to true and interrupting the * action thread if it is being blocked by a job. */ -@Experimental class ComplexFutureAction[T] extends FutureAction[T] { // Pointer to the thread that is executing the action. It is set when the action is run. 
@@ -222,7 +229,7 @@ class ComplexFutureAction[T] extends FutureAction[T] { // If the action hasn't been cancelled yet, submit the job. The check and the submitJob // command need to be in an atomic block. val job = this.synchronized { - if (!cancelled) { + if (!isCancelled) { rdd.context.submitJob(rdd, processPartition, partitions, resultHandler, resultFunc) } else { throw new SparkException("Action has been cancelled") @@ -243,10 +250,7 @@ class ComplexFutureAction[T] extends FutureAction[T] { } } - /** - * Returns whether the promise has been cancelled. - */ - def cancelled: Boolean = _cancelled + override def isCancelled: Boolean = _cancelled @throws(classOf[InterruptedException]) @throws(classOf[scala.concurrent.TimeoutException]) @@ -271,3 +275,55 @@ class ComplexFutureAction[T] extends FutureAction[T] { def jobIds = jobs } + +private[spark] +class JavaFutureActionWrapper[S, T](futureAction: FutureAction[S], converter: S => T) + extends JavaFutureAction[T] { + + import scala.collection.JavaConverters._ + + override def isCancelled: Boolean = futureAction.isCancelled + + override def isDone: Boolean = { + // According to java.util.Future's Javadoc, this returns True if the task was completed, + // whether that completion was due to successful execution, an exception, or a cancellation. + futureAction.isCancelled || futureAction.isCompleted + } + + override def jobIds(): java.util.List[java.lang.Integer] = { + Collections.unmodifiableList(futureAction.jobIds.map(Integer.valueOf).asJava) + } + + private def getImpl(timeout: Duration): T = { + // This will throw TimeoutException on timeout: + Await.ready(futureAction, timeout) + futureAction.value.get match { + case scala.util.Success(value) => converter(value) + case Failure(exception) => + if (isCancelled) { + throw new CancellationException("Job cancelled").initCause(exception) + } else { + // java.util.Future.get() wraps exceptions in ExecutionException + throw new ExecutionException("Exception thrown by job", exception) + } + } + } + + override def get(): T = getImpl(Duration.Inf) + + override def get(timeout: Long, unit: TimeUnit): T = + getImpl(Duration.fromNanos(unit.toNanos(timeout))) + + override def cancel(mayInterruptIfRunning: Boolean): Boolean = synchronized { + if (isDone) { + // According to java.util.Future's Javadoc, this should return false if the task is completed. + false + } else { + // We're limited in terms of the semantics we can provide here; our cancellation is + // asynchronous and doesn't provide a mechanism to not cancel if the job is running. + futureAction.cancel() + true + } + } + +} diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 396cdd1247e07..ac7935b8c231e 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -21,6 +21,7 @@ import scala.language.implicitConversions import java.io._ import java.net.URI +import java.util.Arrays import java.util.concurrent.atomic.AtomicInteger import java.util.{Properties, UUID} import java.util.UUID.randomUUID @@ -237,6 +238,9 @@ class SparkContext(config: SparkConf) extends Logging { // For tests, do not enable the UI None } + + // Bind the UI before starting the task scheduler to communicate + // the bound port to the cluster manager properly ui.foreach(_.bind()) /** A default Hadoop Configuration for the Hadoop code (e.g. file systems) that we reuse. 
*/ @@ -814,6 +818,8 @@ class SparkContext(config: SparkConf) extends Logging { */ def broadcast[T: ClassTag](value: T): Broadcast[T] = { val bc = env.broadcastManager.newBroadcast[T](value, isLocal) + val callSite = getCallSite + logInfo("Created broadcast " + bc.id + " from " + callSite.shortForm) cleaner.foreach(_.registerBroadcastForCleanup(bc)) bc } @@ -1429,7 +1435,10 @@ object SparkContext extends Logging { simpleWritableConverter[Boolean, BooleanWritable](_.get) implicit def bytesWritableConverter(): WritableConverter[Array[Byte]] = { - simpleWritableConverter[Array[Byte], BytesWritable](_.getBytes) + simpleWritableConverter[Array[Byte], BytesWritable](bw => + // getBytes method returns array which is longer then data to be returned + Arrays.copyOfRange(bw.getBytes, 0, bw.getLength) + ) } implicit def stringWritableConverter(): WritableConverter[String] = diff --git a/core/src/main/scala/org/apache/spark/TaskContextHelper.scala b/core/src/main/scala/org/apache/spark/TaskContextHelper.scala new file mode 100644 index 0000000000000..4636c4600a01a --- /dev/null +++ b/core/src/main/scala/org/apache/spark/TaskContextHelper.scala @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +/** + * This class exists to restrict the visibility of TaskContext setters. + */ +private [spark] object TaskContextHelper { + + def setTaskContext(tc: TaskContext): Unit = TaskContext.setTaskContext(tc) + + def unset(): Unit = TaskContext.unset() + +} diff --git a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala new file mode 100644 index 0000000000000..afd2b85d33a77 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark + +import org.apache.spark.executor.TaskMetrics +import org.apache.spark.util.{TaskCompletionListener, TaskCompletionListenerException} + +import scala.collection.mutable.ArrayBuffer + +private[spark] class TaskContextImpl(val stageId: Int, + val partitionId: Int, + val attemptId: Long, + val runningLocally: Boolean = false, + val taskMetrics: TaskMetrics = TaskMetrics.empty) + extends TaskContext + with Logging { + + // List of callback functions to execute when the task completes. + @transient private val onCompleteCallbacks = new ArrayBuffer[TaskCompletionListener] + + // Whether the corresponding task has been killed. + @volatile private var interrupted: Boolean = false + + // Whether the task has completed. + @volatile private var completed: Boolean = false + + override def addTaskCompletionListener(listener: TaskCompletionListener): this.type = { + onCompleteCallbacks += listener + this + } + + override def addTaskCompletionListener(f: TaskContext => Unit): this.type = { + onCompleteCallbacks += new TaskCompletionListener { + override def onTaskCompletion(context: TaskContext): Unit = f(context) + } + this + } + + @deprecated("use addTaskCompletionListener", "1.1.0") + override def addOnCompleteCallback(f: () => Unit) { + onCompleteCallbacks += new TaskCompletionListener { + override def onTaskCompletion(context: TaskContext): Unit = f() + } + } + + /** Marks the task as completed and triggers the listeners. */ + private[spark] def markTaskCompleted(): Unit = { + completed = true + val errorMsgs = new ArrayBuffer[String](2) + // Process complete callbacks in the reverse order of registration + onCompleteCallbacks.reverse.foreach { listener => + try { + listener.onTaskCompletion(this) + } catch { + case e: Throwable => + errorMsgs += e.getMessage + logError("Error in TaskCompletionListener", e) + } + } + if (errorMsgs.nonEmpty) { + throw new TaskCompletionListenerException(errorMsgs) + } + } + + /** Marks the task for interruption, i.e. cancellation. */ + private[spark] def markInterrupted(): Unit = { + interrupted = true + } + + override def isCompleted: Boolean = completed + + override def isRunningLocally: Boolean = runningLocally + + override def isInterrupted: Boolean = interrupted +} + diff --git a/core/src/main/scala/org/apache/spark/TestUtils.scala b/core/src/main/scala/org/apache/spark/TestUtils.scala index 8ca731038e528..e72826dc25f41 100644 --- a/core/src/main/scala/org/apache/spark/TestUtils.scala +++ b/core/src/main/scala/org/apache/spark/TestUtils.scala @@ -26,6 +26,8 @@ import scala.collection.JavaConversions._ import javax.tools.{JavaFileObject, SimpleJavaFileObject, ToolProvider} import com.google.common.io.Files +import org.apache.spark.util.Utils + /** * Utilities for tests. Included in main codebase since it's used by multiple * projects. @@ -42,8 +44,7 @@ private[spark] object TestUtils { * in order to avoid interference between tests. 
*/ def createJarWithClasses(classNames: Seq[String], value: String = ""): URL = { - val tempDir = Files.createTempDir() - tempDir.deleteOnExit() + val tempDir = Utils.createTempDir() val files = for (name <- classNames) yield createCompiledClass(name, tempDir, value) val jarFile = new File(tempDir, "testJar-%s.jar".format(System.currentTimeMillis())) createJar(files, jarFile) diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala index 0846225e4f992..c38b96528d037 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala @@ -35,6 +35,7 @@ import org.apache.spark.Partitioner._ import org.apache.spark.SparkContext.rddToPairRDDFunctions import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.JavaSparkContext.fakeClassTag +import org.apache.spark.api.java.JavaUtils.mapAsSerializableJavaMap import org.apache.spark.api.java.function.{Function => JFunction, Function2 => JFunction2, PairFunction} import org.apache.spark.partial.{BoundedDouble, PartialResult} import org.apache.spark.rdd.{OrderedRDDFunctions, RDD} @@ -265,10 +266,10 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) * before sending results to a reducer, similarly to a "combiner" in MapReduce. */ def reduceByKeyLocally(func: JFunction2[V, V, V]): java.util.Map[K, V] = - mapAsJavaMap(rdd.reduceByKeyLocally(func)) + mapAsSerializableJavaMap(rdd.reduceByKeyLocally(func)) /** Count the number of elements for each key, and return the result to the master as a Map. */ - def countByKey(): java.util.Map[K, Long] = mapAsJavaMap(rdd.countByKey()) + def countByKey(): java.util.Map[K, Long] = mapAsSerializableJavaMap(rdd.countByKey()) /** * :: Experimental :: @@ -277,7 +278,7 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) */ @Experimental def countByKeyApprox(timeout: Long): PartialResult[java.util.Map[K, BoundedDouble]] = - rdd.countByKeyApprox(timeout).map(mapAsJavaMap) + rdd.countByKeyApprox(timeout).map(mapAsSerializableJavaMap) /** * :: Experimental :: @@ -287,7 +288,7 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) @Experimental def countByKeyApprox(timeout: Long, confidence: Double = 0.95) : PartialResult[java.util.Map[K, BoundedDouble]] = - rdd.countByKeyApprox(timeout, confidence).map(mapAsJavaMap) + rdd.countByKeyApprox(timeout, confidence).map(mapAsSerializableJavaMap) /** * Aggregate the values of each key, using given combine functions and a neutral "zero value". @@ -614,7 +615,8 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) /** * Return the key-value pairs in this RDD to the master as a Map. 
*/ - def collectAsMap(): java.util.Map[K, V] = mapAsJavaMap(rdd.collectAsMap()) + def collectAsMap(): java.util.Map[K, V] = mapAsSerializableJavaMap(rdd.collectAsMap()) + /** * Pass each value in the key-value pair RDD through a map function without changing the keys; diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala index 545bc0e9e99ed..efb8978f7ce12 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala @@ -21,15 +21,18 @@ import java.util.{Comparator, List => JList, Iterator => JIterator} import java.lang.{Iterable => JIterable, Long => JLong} import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.reflect.ClassTag import com.google.common.base.Optional import org.apache.hadoop.io.compress.CompressionCodec -import org.apache.spark.{FutureAction, Partition, SparkContext, TaskContext} +import org.apache.spark._ +import org.apache.spark.SparkContext._ import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.JavaPairRDD._ import org.apache.spark.api.java.JavaSparkContext.fakeClassTag +import org.apache.spark.api.java.JavaUtils.mapAsSerializableJavaMap import org.apache.spark.api.java.function.{Function => JFunction, Function2 => JFunction2, _} import org.apache.spark.partial.{BoundedDouble, PartialResult} import org.apache.spark.rdd.RDD @@ -293,8 +296,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { * Applies a function f to all elements of this RDD. */ def foreach(f: VoidFunction[T]) { - val cleanF = rdd.context.clean((x: T) => f.call(x)) - rdd.foreach(cleanF) + rdd.foreach(x => f.call(x)) } /** @@ -390,7 +392,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { * combine step happens locally on the master, equivalent to running a single reduce task. */ def countByValue(): java.util.Map[T, java.lang.Long] = - mapAsJavaMap(rdd.countByValue().map((x => (x._1, new java.lang.Long(x._2))))) + mapAsSerializableJavaMap(rdd.countByValue().map((x => (x._1, new java.lang.Long(x._2))))) /** * (Experimental) Approximate version of countByValue(). @@ -399,13 +401,13 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { timeout: Long, confidence: Double ): PartialResult[java.util.Map[T, BoundedDouble]] = - rdd.countByValueApprox(timeout, confidence).map(mapAsJavaMap) + rdd.countByValueApprox(timeout, confidence).map(mapAsSerializableJavaMap) /** * (Experimental) Approximate version of countByValue(). */ def countByValueApprox(timeout: Long): PartialResult[java.util.Map[T, BoundedDouble]] = - rdd.countByValueApprox(timeout).map(mapAsJavaMap) + rdd.countByValueApprox(timeout).map(mapAsSerializableJavaMap) /** * Take the first num elements of the RDD. This currently scans the partitions *one by one*, so @@ -575,16 +577,44 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { def name(): String = rdd.name /** - * :: Experimental :: - * The asynchronous version of the foreach action. - * - * @param f the function to apply to all the elements of the RDD - * @return a FutureAction for the action + * The asynchronous version of `count`, which returns a + * future for counting the number of elements in this RDD. 
*/ - @Experimental - def foreachAsync(f: VoidFunction[T]): FutureAction[Unit] = { - import org.apache.spark.SparkContext._ - rdd.foreachAsync(x => f.call(x)) + def countAsync(): JavaFutureAction[JLong] = { + new JavaFutureActionWrapper[Long, JLong](rdd.countAsync(), JLong.valueOf) + } + + /** + * The asynchronous version of `collect`, which returns a future for + * retrieving an array containing all of the elements in this RDD. + */ + def collectAsync(): JavaFutureAction[JList[T]] = { + new JavaFutureActionWrapper(rdd.collectAsync(), (x: Seq[T]) => x.asJava) + } + + /** + * The asynchronous version of the `take` action, which returns a + * future for retrieving the first `num` elements of this RDD. + */ + def takeAsync(num: Int): JavaFutureAction[JList[T]] = { + new JavaFutureActionWrapper(rdd.takeAsync(num), (x: Seq[T]) => x.asJava) } + /** + * The asynchronous version of the `foreach` action, which + * applies a function f to all the elements of this RDD. + */ + def foreachAsync(f: VoidFunction[T]): JavaFutureAction[Void] = { + new JavaFutureActionWrapper[Unit, Void](rdd.foreachAsync(x => f.call(x)), + { x => null.asInstanceOf[Void] }) + } + + /** + * The asynchronous version of the `foreachPartition` action, which + * applies a function f to each partition of this RDD. + */ + def foreachPartitionAsync(f: VoidFunction[java.util.Iterator[T]]): JavaFutureAction[Void] = { + new JavaFutureActionWrapper[Unit, Void](rdd.foreachPartitionAsync(x => f.call(x)), + { x => null.asInstanceOf[Void] }) + } } diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaUtils.scala b/core/src/main/scala/org/apache/spark/api/java/JavaUtils.scala index 22810cb1c662d..b52d0a5028e84 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaUtils.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaUtils.scala @@ -19,10 +19,20 @@ package org.apache.spark.api.java import com.google.common.base.Optional +import scala.collection.convert.Wrappers.MapWrapper + private[spark] object JavaUtils { def optionToOptional[T](option: Option[T]): Optional[T] = option match { case Some(value) => Optional.of(value) case None => Optional.absent() } + + // Workaround for SPARK-3926 / SI-8911 + def mapAsSerializableJavaMap[A, B](underlying: collection.Map[A, B]) = + new SerializableMapWrapper(underlying) + + class SerializableMapWrapper[A, B](underlying: collection.Map[A, B]) + extends MapWrapper(underlying) with java.io.Serializable + } diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index c74f86548ef85..29ca751519abd 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -23,10 +23,9 @@ import java.nio.charset.Charset import java.util.{List => JList, ArrayList => JArrayList, Map => JMap, Collections} import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.collection.mutable import scala.language.existentials -import scala.reflect.ClassTag -import scala.util.{Try, Success, Failure} import net.razorvine.pickle.{Pickler, Unpickler} @@ -42,7 +41,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.util.Utils private[spark] class PythonRDD( - parent: RDD[_], + @transient parent: RDD[_], command: Array[Byte], envVars: JMap[String, String], pythonIncludes: JList[String], @@ -55,9 +54,9 @@ private[spark] class PythonRDD( val bufferSize = conf.getInt("spark.buffer.size", 65536) val 
reuse_worker = conf.getBoolean("spark.python.worker.reuse", true) - override def getPartitions = parent.partitions + override def getPartitions = firstParent.partitions - override val partitioner = if (preservePartitoning) parent.partitioner else None + override val partitioner = if (preservePartitoning) firstParent.partitioner else None override def compute(split: Partition, context: TaskContext): Iterator[Array[Byte]] = { val startTime = System.currentTimeMillis @@ -234,7 +233,7 @@ private[spark] class PythonRDD( dataOut.writeInt(command.length) dataOut.write(command) // Data values - PythonRDD.writeIteratorToStream(parent.iterator(split, context), dataOut) + PythonRDD.writeIteratorToStream(firstParent.iterator(split, context), dataOut) dataOut.writeInt(SpecialLengths.END_OF_DATA_SECTION) dataOut.flush() } catch { @@ -748,6 +747,7 @@ private[spark] object PythonRDD extends Logging { def pythonToJavaMap(pyRDD: JavaRDD[Array[Byte]]): JavaRDD[Map[String, _]] = { pyRDD.rdd.mapPartitions { iter => val unpickle = new Unpickler + SerDeUtil.initialize() iter.flatMap { row => unpickle.loads(row) match { // in case of objects are pickled in batch mode @@ -787,7 +787,7 @@ private[spark] object PythonRDD extends Logging { }.toJavaRDD() } - private class AutoBatchedPickler(iter: Iterator[Any]) extends Iterator[Array[Byte]] { + private[spark] class AutoBatchedPickler(iter: Iterator[Any]) extends Iterator[Array[Byte]] { private val pickle = new Pickler() private var batch = 1 private val buffer = new mutable.ArrayBuffer[Any] @@ -824,11 +824,12 @@ private[spark] object PythonRDD extends Logging { */ def pythonToJava(pyRDD: JavaRDD[Array[Byte]], batched: Boolean): JavaRDD[Any] = { pyRDD.rdd.mapPartitions { iter => + SerDeUtil.initialize() val unpickle = new Unpickler iter.flatMap { row => val obj = unpickle.loads(row) if (batched) { - obj.asInstanceOf[JArrayList[_]] + obj.asInstanceOf[JArrayList[_]].asScala } else { Seq(obj) } diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala index 71bdf0fe1b917..e314408c067e9 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala @@ -108,10 +108,12 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String serverSocket = new ServerSocket(0, 1, InetAddress.getByAddress(Array(127, 0, 0, 1))) // Create and start the worker - val pb = new ProcessBuilder(Seq(pythonExec, "-u", "-m", "pyspark.worker")) + val pb = new ProcessBuilder(Seq(pythonExec, "-m", "pyspark.worker")) val workerEnv = pb.environment() workerEnv.putAll(envVars) workerEnv.put("PYTHONPATH", pythonPath) + // This is equivalent to setting the -u flag; we use it because ipython doesn't support -u: + workerEnv.put("PYTHONUNBUFFERED", "YES") val worker = pb.start() // Redirect worker stdout and stderr @@ -149,10 +151,12 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String try { // Create and start the daemon - val pb = new ProcessBuilder(Seq(pythonExec, "-u", "-m", "pyspark.daemon")) + val pb = new ProcessBuilder(Seq(pythonExec, "-m", "pyspark.daemon")) val workerEnv = pb.environment() workerEnv.putAll(envVars) workerEnv.put("PYTHONPATH", pythonPath) + // This is equivalent to setting the -u flag; we use it because ipython doesn't support -u: + workerEnv.put("PYTHONUNBUFFERED", "YES") daemon = pb.start() val in = new 
DataInputStream(daemon.getInputStream) diff --git a/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala b/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala index 7903457b17e13..ebdc3533e0992 100644 --- a/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala +++ b/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala @@ -29,7 +29,7 @@ import org.apache.spark.{Logging, SparkException} import org.apache.spark.rdd.RDD /** Utilities for serialization / deserialization between Python and Java, using Pickle. */ -private[python] object SerDeUtil extends Logging { +private[spark] object SerDeUtil extends Logging { // Unpickle array.array generated by Python 2.6 class ArrayConstructor extends net.razorvine.pickle.objects.ArrayConstructor { // /* Description of types */ @@ -76,9 +76,18 @@ private[python] object SerDeUtil extends Logging { } } + private var initialized = false + // This should be called before trying to unpickle array.array from Python + // In cluster mode, this should be put in closure def initialize() = { - Unpickler.registerConstructor("array", "array", new ArrayConstructor()) + synchronized{ + if (!initialized) { + Unpickler.registerConstructor("array", "array", new ArrayConstructor()) + initialized = true + } + } } + initialize() private def checkPickle(t: (Any, Any)): (Boolean, Boolean) = { val pickle = new Pickler @@ -143,6 +152,7 @@ private[python] object SerDeUtil extends Logging { obj.asInstanceOf[Array[_]].length == 2 } pyRDD.mapPartitions { iter => + initialize() val unpickle = new Unpickler val unpickled = if (batchSerialized) { diff --git a/core/src/main/scala/org/apache/spark/deploy/Client.scala b/core/src/main/scala/org/apache/spark/deploy/Client.scala index 065ddda50e65e..f2687ce6b42b4 100644 --- a/core/src/main/scala/org/apache/spark/deploy/Client.scala +++ b/core/src/main/scala/org/apache/spark/deploy/Client.scala @@ -130,7 +130,7 @@ private class ClientActor(driverArgs: ClientArguments, conf: SparkConf) println(s"Error connecting to master ${driverArgs.master} ($remoteAddress), exiting.") System.exit(-1) - case AssociationErrorEvent(cause, _, remoteAddress, _) => + case AssociationErrorEvent(cause, _, remoteAddress, _, _) => println(s"Error connecting to master ${driverArgs.master} ($remoteAddress), exiting.") println(s"Cause was: $cause") System.exit(-1) diff --git a/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala b/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala index a7368f9f3dfbe..b9dd8557ee904 100644 --- a/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala @@ -71,6 +71,8 @@ private[deploy] object DeployMessages { case class RegisterWorkerFailed(message: String) extends DeployMessage + case class ReconnectWorker(masterUrl: String) extends DeployMessage + case class KillExecutor(masterUrl: String, appId: String, execId: Int) extends DeployMessage case class LaunchExecutor( diff --git a/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala b/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala index 79b4d7ea41a33..af94b05ce3847 100644 --- a/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala @@ -34,7 +34,8 @@ object PythonRunner { val pythonFile = args(0) val pyFiles = args(1) val otherArgs = args.slice(2, args.length) - val pythonExec = sys.env.get("PYSPARK_PYTHON").getOrElse("python") // TODO: get this 
from conf + val pythonExec = + sys.env.getOrElse("PYSPARK_DRIVER_PYTHON", sys.env.getOrElse("PYSPARK_PYTHON", "python")) // Format python file paths before adding them to the PYTHONPATH val formattedPythonFile = formatPath(pythonFile) @@ -57,6 +58,7 @@ object PythonRunner { val builder = new ProcessBuilder(Seq(pythonExec, formattedPythonFile) ++ otherArgs) val env = builder.environment() env.put("PYTHONPATH", pythonPath) + // This is equivalent to setting the -u flag; we use it because ipython doesn't support -u: env.put("PYTHONUNBUFFERED", "YES") // value is needed to be set to a non-empty string env.put("PYSPARK_GATEWAY_PORT", "" + gatewayServer.getListeningPort) builder.redirectErrorStream(true) // Ugly but needed for stdout and stderr to synchronize diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala index 57b251ff47714..72a452e0aefb5 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala @@ -17,14 +17,11 @@ package org.apache.spark.deploy -import java.io.{File, FileInputStream, IOException} -import java.util.Properties import java.util.jar.JarFile import scala.collection.JavaConversions._ import scala.collection.mutable.{ArrayBuffer, HashMap} -import org.apache.spark.SparkException import org.apache.spark.util.Utils /** @@ -63,9 +60,8 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St val defaultProperties = new HashMap[String, String]() if (verbose) SparkSubmit.printStream.println(s"Using properties file: $propertiesFile") Option(propertiesFile).foreach { filename => - val file = new File(filename) - SparkSubmitArguments.getPropertiesFromFile(file).foreach { case (k, v) => - if (k.startsWith("spark")) { + Utils.getPropertiesFromFile(filename).foreach { case (k, v) => + if (k.startsWith("spark.")) { defaultProperties(k) = v if (verbose) SparkSubmit.printStream.println(s"Adding default property: $k=$v") } else { @@ -90,19 +86,7 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St */ private def mergeSparkProperties(): Unit = { // Use common defaults file, if not specified by user - if (propertiesFile == null) { - val sep = File.separator - val sparkHomeConfig = env.get("SPARK_HOME").map(sparkHome => s"${sparkHome}${sep}conf") - val confDir = env.get("SPARK_CONF_DIR").orElse(sparkHomeConfig) - - confDir.foreach { sparkConfDir => - val defaultPath = s"${sparkConfDir}${sep}spark-defaults.conf" - val file = new File(defaultPath) - if (file.exists()) { - propertiesFile = file.getAbsolutePath - } - } - } + propertiesFile = Option(propertiesFile).getOrElse(Utils.getDefaultPropertiesFile(env)) val properties = HashMap[String, String]() properties.putAll(defaultSparkProperties) @@ -397,23 +381,3 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St SparkSubmit.exitFn() } } - -object SparkSubmitArguments { - /** Load properties present in the given file. 
*/ - def getPropertiesFromFile(file: File): Seq[(String, String)] = { - require(file.exists(), s"Properties file $file does not exist") - require(file.isFile(), s"Properties file $file is not a normal file") - val inputStream = new FileInputStream(file) - try { - val properties = new Properties() - properties.load(inputStream) - properties.stringPropertyNames().toSeq.map(k => (k, properties(k).trim)) - } catch { - case e: IOException => - val message = s"Failed when loading Spark properties file $file" - throw new SparkException(message, e) - } finally { - inputStream.close() - } - } -} diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitDriverBootstrapper.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitDriverBootstrapper.scala index a64170a47bc1c..0125330589da5 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitDriverBootstrapper.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitDriverBootstrapper.scala @@ -68,7 +68,7 @@ private[spark] object SparkSubmitDriverBootstrapper { assume(bootstrapDriver != null, "SPARK_SUBMIT_BOOTSTRAP_DRIVER must be set") // Parse the properties file for the equivalent spark.driver.* configs - val properties = SparkSubmitArguments.getPropertiesFromFile(new File(propertiesFile)).toMap + val properties = Utils.getPropertiesFromFile(propertiesFile) val confDriverMemory = properties.get("spark.driver.memory") val confLibraryPath = properties.get("spark.driver.extraLibraryPath") val confClasspath = properties.get("spark.driver.extraClassPath") diff --git a/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala b/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala index 32790053a6be8..98a93d1fcb2a3 100644 --- a/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala +++ b/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala @@ -154,7 +154,7 @@ private[spark] class AppClient( logWarning(s"Connection to $address failed; waiting for master to reconnect...") markDisconnected() - case AssociationErrorEvent(cause, _, address, _) if isPossibleMaster(address) => + case AssociationErrorEvent(cause, _, address, _, _) if isPossibleMaster(address) => logWarning(s"Could not connect to $address: $cause") case StopAppClient => diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerArguments.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerArguments.scala index 25fc76c23e0fb..5bce32a04d16d 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerArguments.scala @@ -18,12 +18,14 @@ package org.apache.spark.deploy.history import org.apache.spark.SparkConf +import org.apache.spark.util.Utils /** * Command-line parser for the master. 
*/ private[spark] class HistoryServerArguments(conf: SparkConf, args: Array[String]) { private var logDir: String = null + private var propertiesFile: String = null parse(args.toList) @@ -32,11 +34,16 @@ private[spark] class HistoryServerArguments(conf: SparkConf, args: Array[String] case ("--dir" | "-d") :: value :: tail => logDir = value conf.set("spark.history.fs.logDirectory", value) + System.setProperty("spark.history.fs.logDirectory", value) parse(tail) case ("--help" | "-h") :: tail => printUsageAndExit(0) + case ("--properties-file") :: value :: tail => + propertiesFile = value + parse(tail) + case Nil => case _ => @@ -44,10 +51,17 @@ private[spark] class HistoryServerArguments(conf: SparkConf, args: Array[String] } } + // This mutates the SparkConf, so all accesses to it must be made after this line + Utils.loadDefaultSparkProperties(conf, propertiesFile) + private def printUsageAndExit(exitCode: Int) { System.err.println( """ - |Usage: HistoryServer + |Usage: HistoryServer [options] + | + |Options: + | --properties-file FILE Path to a custom Spark properties file. + | Default is conf/spark-defaults.conf. | |Configuration options can be set by setting the corresponding JVM system property. |History Server options are always available; additional options depend on the provider. diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index f98b531316a3d..3b6bb9fe128a4 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -341,7 +341,14 @@ private[spark] class Master( case Some(workerInfo) => workerInfo.lastHeartbeat = System.currentTimeMillis() case None => - logWarning("Got heartbeat from unregistered worker " + workerId) + if (workers.map(_.id).contains(workerId)) { + logWarning(s"Got heartbeat from unregistered worker $workerId." + + " Asking it to re-register.") + sender ! ReconnectWorker(masterUrl) + } else { + logWarning(s"Got heartbeat from unregistered worker $workerId." 
+ + " This worker was never registered, so ignoring the heartbeat.") + } } } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/MasterArguments.scala b/core/src/main/scala/org/apache/spark/deploy/master/MasterArguments.scala index 4b0dbbe543d3f..e34bee7854292 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/MasterArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/MasterArguments.scala @@ -27,6 +27,7 @@ private[spark] class MasterArguments(args: Array[String], conf: SparkConf) { var host = Utils.localHostName() var port = 7077 var webUiPort = 8080 + var propertiesFile: String = null // Check for settings in environment variables if (System.getenv("SPARK_MASTER_HOST") != null) { @@ -38,12 +39,16 @@ private[spark] class MasterArguments(args: Array[String], conf: SparkConf) { if (System.getenv("SPARK_MASTER_WEBUI_PORT") != null) { webUiPort = System.getenv("SPARK_MASTER_WEBUI_PORT").toInt } + + parse(args.toList) + + // This mutates the SparkConf, so all accesses to it must be made after this line + propertiesFile = Utils.loadDefaultSparkProperties(conf, propertiesFile) + if (conf.contains("spark.master.ui.port")) { webUiPort = conf.get("spark.master.ui.port").toInt } - parse(args.toList) - def parse(args: List[String]): Unit = args match { case ("--ip" | "-i") :: value :: tail => Utils.checkHost(value, "ip no longer supported, please use hostname " + value) @@ -63,7 +68,11 @@ private[spark] class MasterArguments(args: Array[String], conf: SparkConf) { webUiPort = value parse(tail) - case ("--help" | "-h") :: tail => + case ("--properties-file") :: value :: tail => + propertiesFile = value + parse(tail) + + case ("--help") :: tail => printUsageAndExit(0) case Nil => {} @@ -83,7 +92,9 @@ private[spark] class MasterArguments(args: Array[String], conf: SparkConf) { " -i HOST, --ip HOST Hostname to listen on (deprecated, please use --host or -h) \n" + " -h HOST, --host HOST Hostname to listen on\n" + " -p PORT, --port PORT Port to listen on (default: 7077)\n" + - " --webui-port PORT Port for web UI (default: 8080)") + " --webui-port PORT Port for web UI (default: 8080)\n" + + " --properties-file FILE Path to a custom Spark properties file.\n" + + " Default is conf/spark-defaults.conf.") System.exit(exitCode) } } diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala index 71650cd773bcf..71d7385b08eb9 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala @@ -111,13 +111,14 @@ private[spark] class ExecutorRunner( case "{{EXECUTOR_ID}}" => execId.toString case "{{HOSTNAME}}" => host case "{{CORES}}" => cores.toString + case "{{APP_ID}}" => appId case other => other } def getCommandSeq = { val command = Command( appDesc.command.mainClass, - appDesc.command.arguments.map(substituteVariables) ++ Seq(appId), + appDesc.command.arguments.map(substituteVariables), appDesc.command.environment, appDesc.command.classPathEntries, appDesc.command.libraryPathEntries, diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala index 9b52cb06fb6fa..c4a8ec2e5e7b0 100755 --- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala @@ -20,12 +20,14 @@ package org.apache.spark.deploy.worker import 
java.io.File import java.io.IOException import java.text.SimpleDateFormat -import java.util.Date +import java.util.{UUID, Date} +import java.util.concurrent.TimeUnit import scala.collection.JavaConversions._ import scala.collection.mutable.HashMap import scala.concurrent.duration._ import scala.language.postfixOps +import scala.util.Random import akka.actor._ import akka.remote.{DisassociatedEvent, RemotingLifecycleEvent} @@ -64,8 +66,22 @@ private[spark] class Worker( // Send a heartbeat every (heartbeat timeout) / 4 milliseconds val HEARTBEAT_MILLIS = conf.getLong("spark.worker.timeout", 60) * 1000 / 4 - val REGISTRATION_TIMEOUT = 20.seconds - val REGISTRATION_RETRIES = 3 + // Model retries to connect to the master, after Hadoop's model. + // The first six attempts to reconnect are in shorter intervals (between 5 and 15 seconds) + // Afterwards, the next 10 attempts are between 30 and 90 seconds. + // A bit of randomness is introduced so that not all of the workers attempt to reconnect at + // the same time. + val INITIAL_REGISTRATION_RETRIES = 6 + val TOTAL_REGISTRATION_RETRIES = INITIAL_REGISTRATION_RETRIES + 10 + val FUZZ_MULTIPLIER_INTERVAL_LOWER_BOUND = 0.500 + val REGISTRATION_RETRY_FUZZ_MULTIPLIER = { + val randomNumberGenerator = new Random(UUID.randomUUID.getMostSignificantBits) + randomNumberGenerator.nextDouble + FUZZ_MULTIPLIER_INTERVAL_LOWER_BOUND + } + val INITIAL_REGISTRATION_RETRY_INTERVAL = (math.round(10 * + REGISTRATION_RETRY_FUZZ_MULTIPLIER)).seconds + val PROLONGED_REGISTRATION_RETRY_INTERVAL = (math.round(60 + * REGISTRATION_RETRY_FUZZ_MULTIPLIER)).seconds val CLEANUP_ENABLED = conf.getBoolean("spark.worker.cleanup.enabled", false) // How often worker will clean up old app folders @@ -103,6 +119,7 @@ private[spark] class Worker( var coresUsed = 0 var memoryUsed = 0 + var connectionAttemptCount = 0 val metricsSystem = MetricsSystem.createMetricsSystem("worker", conf, securityMgr) val workerSource = new WorkerSource(this) @@ -158,7 +175,7 @@ private[spark] class Worker( connected = true } - def tryRegisterAllMasters() { + private def tryRegisterAllMasters() { for (masterUrl <- masterUrls) { logInfo("Connecting to master " + masterUrl + "...") val actor = context.actorSelection(Master.toAkkaUrl(masterUrl)) @@ -166,26 +183,47 @@ private[spark] class Worker( } } - def registerWithMaster() { - tryRegisterAllMasters() - var retries = 0 - registrationRetryTimer = Some { - context.system.scheduler.schedule(REGISTRATION_TIMEOUT, REGISTRATION_TIMEOUT) { - Utils.tryOrExit { - retries += 1 - if (registered) { - registrationRetryTimer.foreach(_.cancel()) - } else if (retries >= REGISTRATION_RETRIES) { - logError("All masters are unresponsive! Giving up.") - System.exit(1) - } else { - tryRegisterAllMasters() + private def retryConnectToMaster() { + Utils.tryOrExit { + connectionAttemptCount += 1 + logInfo(s"Attempting to connect to master (attempt # $connectionAttemptCount") + if (registered) { + registrationRetryTimer.foreach(_.cancel()) + registrationRetryTimer = None + } else if (connectionAttemptCount <= TOTAL_REGISTRATION_RETRIES) { + tryRegisterAllMasters() + if (connectionAttemptCount == INITIAL_REGISTRATION_RETRIES) { + registrationRetryTimer.foreach(_.cancel()) + registrationRetryTimer = Some { + context.system.scheduler.schedule(PROLONGED_REGISTRATION_RETRY_INTERVAL, + PROLONGED_REGISTRATION_RETRY_INTERVAL)(retryConnectToMaster) } } + } else { + logError("All masters are unresponsive! 
Giving up.") + System.exit(1) } } } + def registerWithMaster() { + // DisassociatedEvent may be triggered multiple times, so don't attempt registration + // if there are outstanding registration attempts scheduled. + registrationRetryTimer match { + case None => + registered = false + tryRegisterAllMasters() + connectionAttemptCount = 0 + registrationRetryTimer = Some { + context.system.scheduler.schedule(INITIAL_REGISTRATION_RETRY_INTERVAL, + INITIAL_REGISTRATION_RETRY_INTERVAL)(retryConnectToMaster) + } + case Some(_) => + logInfo("Not spawning another attempt to register with the master, since there is an" + + " attempt scheduled already.") + } + } + override def receiveWithLogging = { case RegisteredWorker(masterUrl, masterWebUiUrl) => logInfo("Successfully registered with master " + masterUrl) @@ -243,6 +281,10 @@ private[spark] class Worker( System.exit(1) } + case ReconnectWorker(masterUrl) => + logInfo(s"Master with url $masterUrl requested this worker to reconnect.") + registerWithMaster() + case LaunchExecutor(masterUrl, appId, execId, appDesc, cores_, memory_) => if (masterUrl != activeMasterUrl) { logWarning("Invalid Master (" + masterUrl + ") attempted to launch executor.") @@ -362,9 +404,10 @@ private[spark] class Worker( } } - def masterDisconnected() { + private def masterDisconnected() { logError("Connection to master failed! Waiting for master to reconnect...") connected = false + registerWithMaster() } def generateWorkerId(): String = { diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala index 1e295aaa48c30..019cd70f2a229 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala @@ -33,6 +33,7 @@ private[spark] class WorkerArguments(args: Array[String], conf: SparkConf) { var memory = inferDefaultMemory() var masters: Array[String] = null var workDir: String = null + var propertiesFile: String = null // Check for settings in environment variables if (System.getenv("SPARK_WORKER_PORT") != null) { @@ -41,21 +42,27 @@ private[spark] class WorkerArguments(args: Array[String], conf: SparkConf) { if (System.getenv("SPARK_WORKER_CORES") != null) { cores = System.getenv("SPARK_WORKER_CORES").toInt } - if (System.getenv("SPARK_WORKER_MEMORY") != null) { - memory = Utils.memoryStringToMb(System.getenv("SPARK_WORKER_MEMORY")) + if (conf.getenv("SPARK_WORKER_MEMORY") != null) { + memory = Utils.memoryStringToMb(conf.getenv("SPARK_WORKER_MEMORY")) } if (System.getenv("SPARK_WORKER_WEBUI_PORT") != null) { webUiPort = System.getenv("SPARK_WORKER_WEBUI_PORT").toInt } - if (conf.contains("spark.worker.ui.port")) { - webUiPort = conf.get("spark.worker.ui.port").toInt - } if (System.getenv("SPARK_WORKER_DIR") != null) { workDir = System.getenv("SPARK_WORKER_DIR") } parse(args.toList) + // This mutates the SparkConf, so all accesses to it must be made after this line + propertiesFile = Utils.loadDefaultSparkProperties(conf, propertiesFile) + + if (conf.contains("spark.worker.ui.port")) { + webUiPort = conf.get("spark.worker.ui.port").toInt + } + + checkWorkerMemory() + def parse(args: List[String]): Unit = args match { case ("--ip" | "-i") :: value :: tail => Utils.checkHost(value, "ip no longer supported, please use hostname " + value) @@ -87,7 +94,11 @@ private[spark] class WorkerArguments(args: Array[String], conf: SparkConf) { webUiPort = value parse(tail) - case ("--help" | 
"-h") :: tail => + case ("--properties-file") :: value :: tail => + propertiesFile = value + parse(tail) + + case ("--help") :: tail => printUsageAndExit(0) case value :: tail => @@ -122,7 +133,9 @@ private[spark] class WorkerArguments(args: Array[String], conf: SparkConf) { " -i HOST, --ip IP Hostname to listen on (deprecated, please use --host or -h)\n" + " -h HOST, --host HOST Hostname to listen on\n" + " -p PORT, --port PORT Port to listen on (default: random)\n" + - " --webui-port PORT Port for web UI (default: 8081)") + " --webui-port PORT Port for web UI (default: 8081)\n" + + " --properties-file FILE Path to a custom Spark properties file.\n" + + " Default is conf/spark-defaults.conf.") System.exit(exitCode) } @@ -153,4 +166,11 @@ private[spark] class WorkerArguments(args: Array[String], conf: SparkConf) { // Leave out 1 GB for the operating system, but don't return a negative memory size math.max(totalMb - 1024, 512) } + + def checkWorkerMemory(): Unit = { + if (memory <= 0) { + val message = "Memory can't be 0, missing a M or G on the end of the memory specification?" + throw new IllegalStateException(message) + } + } } diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala index 6d0d0bbe5ecec..63a8ac817b618 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala @@ -54,7 +54,7 @@ private[spark] class WorkerWatcher(workerUrl: String) case AssociatedEvent(localAddress, remoteAddress, inbound) if isWorker(remoteAddress) => logInfo(s"Successfully connected to $workerUrl") - case AssociationErrorEvent(cause, localAddress, remoteAddress, inbound) + case AssociationErrorEvent(cause, localAddress, remoteAddress, inbound, _) if isWorker(remoteAddress) => // These logs may not be seen if the worker (and associated pipe) has died logError(s"Could not initialize connection to worker $workerUrl. Exiting.") diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala index 06061edfc0844..c40a3e16675ad 100644 --- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala @@ -152,6 +152,9 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging { "Usage: CoarseGrainedExecutorBackend " + " [] ") System.exit(1) + + // NB: These arguments are provided by SparkDeploySchedulerBackend (for standalone mode) + // and CoarseMesosSchedulerBackend (for mesos mode). case 5 => run(args(0), args(1), args(2), args(3).toInt, args(4), None) case x if x > 5 => diff --git a/core/src/main/scala/org/apache/spark/network/ManagedBuffer.scala b/core/src/main/scala/org/apache/spark/network/ManagedBuffer.scala index a4409181ec907..4c9ca97a2a6b7 100644 --- a/core/src/main/scala/org/apache/spark/network/ManagedBuffer.scala +++ b/core/src/main/scala/org/apache/spark/network/ManagedBuffer.scala @@ -66,13 +66,27 @@ sealed abstract class ManagedBuffer { final class FileSegmentManagedBuffer(val file: File, val offset: Long, val length: Long) extends ManagedBuffer { + /** + * Memory mapping is expensive and can destabilize the JVM (SPARK-1145, SPARK-3889). + * Avoid unless there's a good reason not to. 
+ */ + private val MIN_MEMORY_MAP_BYTES = 2 * 1024 * 1024; + override def size: Long = length override def nioByteBuffer(): ByteBuffer = { var channel: FileChannel = null try { channel = new RandomAccessFile(file, "r").getChannel - channel.map(MapMode.READ_ONLY, offset, length) + // Just copy the buffer if it's sufficiently small, as memory mapping has a high overhead. + if (length < MIN_MEMORY_MAP_BYTES) { + val buf = ByteBuffer.allocate(length.toInt) + channel.read(buf, offset) + buf.flip() + buf + } else { + channel.map(MapMode.READ_ONLY, offset, length) + } } catch { case e: IOException => Try(channel.size).toOption match { diff --git a/core/src/main/scala/org/apache/spark/network/nio/Connection.scala b/core/src/main/scala/org/apache/spark/network/nio/Connection.scala index f368209980f93..4f6f5e235811d 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/Connection.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/Connection.scala @@ -20,11 +20,14 @@ package org.apache.spark.network.nio import java.net._ import java.nio._ import java.nio.channels._ +import java.util.concurrent.ConcurrentLinkedQueue import java.util.LinkedList import org.apache.spark._ +import scala.collection.JavaConversions._ import scala.collection.mutable.{ArrayBuffer, HashMap} +import scala.util.control.NonFatal private[nio] abstract class Connection(val channel: SocketChannel, val selector: Selector, @@ -51,7 +54,7 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector, @volatile private var closed = false var onCloseCallback: Connection => Unit = null - var onExceptionCallback: (Connection, Exception) => Unit = null + val onExceptionCallbacks = new ConcurrentLinkedQueue[(Connection, Throwable) => Unit] var onKeyInterestChangeCallback: (Connection, Int) => Unit = null val remoteAddress = getRemoteAddress() @@ -130,20 +133,24 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector, onCloseCallback = callback } - def onException(callback: (Connection, Exception) => Unit) { - onExceptionCallback = callback + def onException(callback: (Connection, Throwable) => Unit) { + onExceptionCallbacks.add(callback) } def onKeyInterestChange(callback: (Connection, Int) => Unit) { onKeyInterestChangeCallback = callback } - def callOnExceptionCallback(e: Exception) { - if (onExceptionCallback != null) { - onExceptionCallback(this, e) - } else { - logError("Error in connection to " + getRemoteConnectionManagerId() + - " and OnExceptionCallback not registered", e) + def callOnExceptionCallbacks(e: Throwable) { + onExceptionCallbacks foreach { + callback => + try { + callback(this, e) + } catch { + case NonFatal(e) => { + logWarning("Ignored error in onExceptionCallback", e) + } + } } } @@ -323,7 +330,7 @@ class SendingConnection(val address: InetSocketAddress, selector_ : Selector, } catch { case e: Exception => { logError("Error connecting to " + address, e) - callOnExceptionCallback(e) + callOnExceptionCallbacks(e) } } } @@ -348,7 +355,7 @@ class SendingConnection(val address: InetSocketAddress, selector_ : Selector, } catch { case e: Exception => { logWarning("Error finishing connection to " + address, e) - callOnExceptionCallback(e) + callOnExceptionCallbacks(e) } } true @@ -393,7 +400,7 @@ class SendingConnection(val address: InetSocketAddress, selector_ : Selector, } catch { case e: Exception => { logWarning("Error writing in connection to " + getRemoteConnectionManagerId(), e) - callOnExceptionCallback(e) + callOnExceptionCallbacks(e) close() return false } 
@@ -420,7 +427,7 @@ class SendingConnection(val address: InetSocketAddress, selector_ : Selector, case e: Exception => logError("Exception while reading SendingConnection to " + getRemoteConnectionManagerId(), e) - callOnExceptionCallback(e) + callOnExceptionCallbacks(e) close() } @@ -577,7 +584,7 @@ private[spark] class ReceivingConnection( } catch { case e: Exception => { logWarning("Error reading from connection to " + getRemoteConnectionManagerId(), e) - callOnExceptionCallback(e) + callOnExceptionCallbacks(e) close() return false } diff --git a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala index 01cd27a907eea..bda4bf50932c3 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala @@ -34,6 +34,8 @@ import scala.language.postfixOps import org.apache.spark._ import org.apache.spark.util.Utils +import scala.util.Try +import scala.util.control.NonFatal private[nio] class ConnectionManager( port: Int, @@ -51,14 +53,23 @@ private[nio] class ConnectionManager( class MessageStatus( val message: Message, val connectionManagerId: ConnectionManagerId, - completionHandler: MessageStatus => Unit) { + completionHandler: Try[Message] => Unit) { - /** This is non-None if message has been ack'd */ - var ackMessage: Option[Message] = None + def success(ackMessage: Message) { + if (ackMessage == null) { + failure(new NullPointerException) + } + else { + completionHandler(scala.util.Success(ackMessage)) + } + } + + def failWithoutAck() { + completionHandler(scala.util.Failure(new IOException("Failed without being ACK'd"))) + } - def markDone(ackMessage: Option[Message]) { - this.ackMessage = ackMessage - completionHandler(this) + def failure(e: Throwable) { + completionHandler(scala.util.Failure(e)) } } @@ -72,14 +83,32 @@ private[nio] class ConnectionManager( conf.getInt("spark.core.connection.handler.threads.max", 60), conf.getInt("spark.core.connection.handler.threads.keepalive", 60), TimeUnit.SECONDS, new LinkedBlockingDeque[Runnable](), - Utils.namedThreadFactory("handle-message-executor")) + Utils.namedThreadFactory("handle-message-executor")) { + + override def afterExecute(r: Runnable, t: Throwable): Unit = { + super.afterExecute(r, t) + if (t != null && NonFatal(t)) { + logError("Error in handleMessageExecutor is not handled properly", t) + } + } + + } private val handleReadWriteExecutor = new ThreadPoolExecutor( conf.getInt("spark.core.connection.io.threads.min", 4), conf.getInt("spark.core.connection.io.threads.max", 32), conf.getInt("spark.core.connection.io.threads.keepalive", 60), TimeUnit.SECONDS, new LinkedBlockingDeque[Runnable](), - Utils.namedThreadFactory("handle-read-write-executor")) + Utils.namedThreadFactory("handle-read-write-executor")) { + + override def afterExecute(r: Runnable, t: Throwable): Unit = { + super.afterExecute(r, t) + if (t != null && NonFatal(t)) { + logError("Error in handleReadWriteExecutor is not handled properly", t) + } + } + + } // Use a different, yet smaller, thread pool - infrequently used with very short lived tasks : // which should be executed asap @@ -88,7 +117,16 @@ private[nio] class ConnectionManager( conf.getInt("spark.core.connection.connect.threads.max", 8), conf.getInt("spark.core.connection.connect.threads.keepalive", 60), TimeUnit.SECONDS, new LinkedBlockingDeque[Runnable](), - Utils.namedThreadFactory("handle-connect-executor")) + 
Utils.namedThreadFactory("handle-connect-executor")) { + + override def afterExecute(r: Runnable, t: Throwable): Unit = { + super.afterExecute(r, t) + if (t != null && NonFatal(t)) { + logError("Error in handleConnectExecutor is not handled properly", t) + } + } + + } private val serverChannel = ServerSocketChannel.open() // used to track the SendingConnections waiting to do SASL negotiation @@ -153,17 +191,24 @@ private[nio] class ConnectionManager( } handleReadWriteExecutor.execute(new Runnable { override def run() { - var register: Boolean = false try { - register = conn.write() - } finally { - writeRunnableStarted.synchronized { - writeRunnableStarted -= key - val needReregister = register || conn.resetForceReregister() - if (needReregister && conn.changeInterestForWrite()) { - conn.registerInterest() + var register: Boolean = false + try { + register = conn.write() + } finally { + writeRunnableStarted.synchronized { + writeRunnableStarted -= key + val needReregister = register || conn.resetForceReregister() + if (needReregister && conn.changeInterestForWrite()) { + conn.registerInterest() + } } } + } catch { + case NonFatal(e) => { + logError("Error when writing to " + conn.getRemoteConnectionManagerId(), e) + conn.callOnExceptionCallbacks(e) + } } } } ) @@ -187,16 +232,23 @@ private[nio] class ConnectionManager( } handleReadWriteExecutor.execute(new Runnable { override def run() { - var register: Boolean = false try { - register = conn.read() - } finally { - readRunnableStarted.synchronized { - readRunnableStarted -= key - if (register && conn.changeInterestForRead()) { - conn.registerInterest() + var register: Boolean = false + try { + register = conn.read() + } finally { + readRunnableStarted.synchronized { + readRunnableStarted -= key + if (register && conn.changeInterestForRead()) { + conn.registerInterest() + } } } + } catch { + case NonFatal(e) => { + logError("Error when reading from " + conn.getRemoteConnectionManagerId(), e) + conn.callOnExceptionCallbacks(e) + } } } } ) @@ -213,19 +265,25 @@ private[nio] class ConnectionManager( handleConnectExecutor.execute(new Runnable { override def run() { + try { + var tries: Int = 10 + while (tries >= 0) { + if (conn.finishConnect(false)) return + // Sleep ? + Thread.sleep(1) + tries -= 1 + } - var tries: Int = 10 - while (tries >= 0) { - if (conn.finishConnect(false)) return - // Sleep ? - Thread.sleep(1) - tries -= 1 + // fallback to previous behavior : we should not really come here since this method was + // triggered since channel became connectable : but at times, the first finishConnect need + // not succeed : hence the loop to retry a few 'times'. + conn.finishConnect(true) + } catch { + case NonFatal(e) => { + logError("Error when finishConnect for " + conn.getRemoteConnectionManagerId(), e) + conn.callOnExceptionCallbacks(e) + } } - - // fallback to previous behavior : we should not really come here since this method was - // triggered since channel became connectable : but at times, the first finishConnect need - // not succeed : hence the loop to retry a few 'times'. 
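The afterExecute overrides added to the handler, read/write and connect pools above all follow one shape. A minimal standalone sketch of that pattern, with illustrative pool sizes rather than Spark's configured values:

    import java.util.concurrent.{LinkedBlockingDeque, ThreadPoolExecutor, TimeUnit}
    import scala.util.control.NonFatal

    // Surface non-fatal task failures instead of letting them disappear silently.
    // afterExecute sees a non-null Throwable only for tasks started via execute();
    // tasks submitted via submit() keep their exception inside the returned Future.
    val pool = new ThreadPoolExecutor(
        4, 32, 60, TimeUnit.SECONDS, new LinkedBlockingDeque[Runnable]()) {
      override def afterExecute(r: Runnable, t: Throwable): Unit = {
        super.afterExecute(r, t)
        if (t != null && NonFatal(t)) {
          System.err.println(s"Uncaught non-fatal error in pool task: $t")
        }
      }
    }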
- conn.finishConnect(true) } } ) } @@ -246,16 +304,16 @@ private[nio] class ConnectionManager( handleConnectExecutor.execute(new Runnable { override def run() { try { - conn.callOnExceptionCallback(e) + conn.callOnExceptionCallbacks(e) } catch { // ignore exceptions - case e: Exception => logDebug("Ignoring exception", e) + case NonFatal(e) => logDebug("Ignoring exception", e) } try { conn.close() } catch { // ignore exceptions - case e: Exception => logDebug("Ignoring exception", e) + case NonFatal(e) => logDebug("Ignoring exception", e) } } }) @@ -448,7 +506,7 @@ private[nio] class ConnectionManager( messageStatuses.values.filter(_.connectionManagerId == sendingConnectionManagerId) .foreach(status => { logInfo("Notifying " + status) - status.markDone(None) + status.failWithoutAck() }) messageStatuses.retain((i, status) => { @@ -477,7 +535,7 @@ private[nio] class ConnectionManager( for (s <- messageStatuses.values if s.connectionManagerId == sendingConnectionManagerId) { logInfo("Notifying " + s) - s.markDone(None) + s.failWithoutAck() } messageStatuses.retain((i, status) => { @@ -492,7 +550,7 @@ private[nio] class ConnectionManager( } } - def handleConnectionError(connection: Connection, e: Exception) { + def handleConnectionError(connection: Connection, e: Throwable) { logInfo("Handling connection error on connection to " + connection.getRemoteConnectionManagerId()) removeConnection(connection) @@ -510,9 +568,17 @@ private[nio] class ConnectionManager( val runnable = new Runnable() { val creationTime = System.currentTimeMillis def run() { - logDebug("Handler thread delay is " + (System.currentTimeMillis - creationTime) + " ms") - handleMessage(connectionManagerId, message, connection) - logDebug("Handling delay is " + (System.currentTimeMillis - creationTime) + " ms") + try { + logDebug("Handler thread delay is " + (System.currentTimeMillis - creationTime) + " ms") + handleMessage(connectionManagerId, message, connection) + logDebug("Handling delay is " + (System.currentTimeMillis - creationTime) + " ms") + } catch { + case NonFatal(e) => { + logError("Error when handling messages from " + + connection.getRemoteConnectionManagerId(), e) + connection.callOnExceptionCallbacks(e) + } + } } } handleMessageExecutor.execute(runnable) @@ -651,7 +717,7 @@ private[nio] class ConnectionManager( messageStatuses.get(bufferMessage.ackId) match { case Some(status) => { messageStatuses -= bufferMessage.ackId - status.markDone(Some(message)) + status.success(message) } case None => { /** @@ -691,9 +757,7 @@ private[nio] class ConnectionManager( } catch { case e: Exception => { logError(s"Exception was thrown while processing message", e) - val m = Message.createBufferMessage(bufferMessage.id) - m.hasError = true - ackMessage = Some(m) + ackMessage = Some(Message.createErrorMessage(e, bufferMessage.id)) } } finally { sendMessage(connectionManagerId, ackMessage.getOrElse { @@ -770,6 +834,12 @@ private[nio] class ConnectionManager( val newConnectionId = new ConnectionId(id, idCount.getAndIncrement.intValue) val newConnection = new SendingConnection(inetSocketAddress, selector, connectionManagerId, newConnectionId, securityManager) + newConnection.onException { + case (conn, e) => { + logError("Exception while sending message.", e) + reportSendingMessageFailure(message.id, e) + } + } logTrace("creating new sending connection: " + newConnectionId) registerRequests.enqueue(newConnection) @@ -782,13 +852,36 @@ private[nio] class ConnectionManager( "connectionid: " + connection.connectionId) if (authEnabled) { 
- checkSendAuthFirst(connectionManagerId, connection) + try { + checkSendAuthFirst(connectionManagerId, connection) + } catch { + case NonFatal(e) => { + reportSendingMessageFailure(message.id, e) + } + } } logDebug("Sending [" + message + "] to [" + connectionManagerId + "]") connection.send(message) wakeupSelector() } + private def reportSendingMessageFailure(messageId: Int, e: Throwable): Unit = { + // need to tell sender it failed + messageStatuses.synchronized { + val s = messageStatuses.get(messageId) + s match { + case Some(msgStatus) => { + messageStatuses -= messageId + logInfo("Notifying " + msgStatus.connectionManagerId) + msgStatus.failure(e) + } + case None => { + logError("no messageStatus for failed message id: " + messageId) + } + } + } + } + private def wakeupSelector() { selector.wakeup() } @@ -807,9 +900,11 @@ private[nio] class ConnectionManager( override def run(): Unit = { messageStatuses.synchronized { messageStatuses.remove(message.id).foreach ( s => { - promise.failure( - new IOException("sendMessageReliably failed because ack " + - s"was not received within $ackTimeout sec")) + val e = new IOException("sendMessageReliably failed because ack " + + s"was not received within $ackTimeout sec") + if (!promise.tryFailure(e)) { + logWarning("Ignore error because promise is completed", e) + } }) } } @@ -817,15 +912,27 @@ private[nio] class ConnectionManager( val status = new MessageStatus(message, connectionManagerId, s => { timeoutTask.cancel() - s.ackMessage match { - case None => // Indicates a failure where we either never sent or never got ACK'd - promise.failure(new IOException("sendMessageReliably failed without being ACK'd")) - case Some(ackMessage) => + s match { + case scala.util.Failure(e) => + // Indicates a failure where we either never sent or never got ACK'd + if (!promise.tryFailure(e)) { + logWarning("Ignore error because promise is completed", e) + } + case scala.util.Success(ackMessage) => if (ackMessage.hasError) { - promise.failure( - new IOException("sendMessageReliably failed with ACK that signalled a remote error")) + val errorMsgByteBuf = ackMessage.asInstanceOf[BufferMessage].buffers.head + val errorMsgBytes = new Array[Byte](errorMsgByteBuf.limit()) + errorMsgByteBuf.get(errorMsgBytes) + val errorMsg = new String(errorMsgBytes, "utf-8") + val e = new IOException( + s"sendMessageReliably failed with ACK that signalled a remote error: $errorMsg") + if (!promise.tryFailure(e)) { + logWarning("Ignore error because promise is completed", e) + } } else { - promise.success(ackMessage) + if (!promise.trySuccess(ackMessage)) { + logWarning("Drop ackMessage because promise is completed") + } } } }) diff --git a/core/src/main/scala/org/apache/spark/network/nio/Message.scala b/core/src/main/scala/org/apache/spark/network/nio/Message.scala index 0b874c2891255..3ad04591da658 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/Message.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/Message.scala @@ -22,6 +22,7 @@ import java.nio.ByteBuffer import scala.collection.mutable.ArrayBuffer +import org.apache.spark.util.Utils private[nio] abstract class Message(val typ: Long, val id: Int) { var senderAddress: InetSocketAddress = null @@ -84,6 +85,19 @@ private[nio] object Message { createBufferMessage(new Array[ByteBuffer](0), ackId) } + /** + * Create a "negative acknowledgment" to notify a sender that an error occurred + * while processing its message. 
The exception's stacktrace will be formatted + * as a string, serialized into a byte array, and sent as the message payload. + */ + def createErrorMessage(exception: Exception, ackId: Int): BufferMessage = { + val exceptionString = Utils.exceptionString(exception) + val serializedExceptionString = ByteBuffer.wrap(exceptionString.getBytes("utf-8")) + val errorMessage = createBufferMessage(serializedExceptionString, ackId) + errorMessage.hasError = true + errorMessage + } + def create(header: MessageChunkHeader): Message = { val newMessage: Message = header.typ match { case BUFFER_MESSAGE => new BufferMessage(header.id, diff --git a/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala index b389b9a2022c6..5add4fc433fb3 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala @@ -151,17 +151,14 @@ final class NioBlockTransferService(conf: SparkConf, securityManager: SecurityMa } catch { case e: Exception => { logError("Exception handling buffer message", e) - val errorMessage = Message.createBufferMessage(msg.id) - errorMessage.hasError = true - Some(errorMessage) + Some(Message.createErrorMessage(e, msg.id)) } } case otherMessage: Any => - logError("Unknown type message received: " + otherMessage) - val errorMessage = Message.createBufferMessage(msg.id) - errorMessage.hasError = true - Some(errorMessage) + val errorMsg = s"Received unknown message type: ${otherMessage.getClass.getName}" + logError(errorMsg) + Some(Message.createErrorMessage(new UnsupportedOperationException(errorMsg), msg.id)) } } diff --git a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala index b62f3fbdc4a15..9f9f10b7ebc3a 100644 --- a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala @@ -24,14 +24,11 @@ import scala.concurrent.ExecutionContext.Implicits.global import scala.reflect.ClassTag import org.apache.spark.{ComplexFutureAction, FutureAction, Logging} -import org.apache.spark.annotation.Experimental /** - * :: Experimental :: * A set of asynchronous RDD actions available through an implicit conversion. * Import `org.apache.spark.SparkContext._` at the top of your program to use these functions. */ -@Experimental class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Logging { /** @@ -78,16 +75,18 @@ class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Loggi // greater than totalParts because we actually cap it at totalParts in runJob. var numPartsToTry = 1 if (partsScanned > 0) { - // If we didn't find any rows after the first iteration, just try all partitions next. + // If we didn't find any rows after the previous iteration, quadruple and retry. // Otherwise, interpolate the number of partitions we need to try, but overestimate it - // by 50%. + // by 50%. We also cap the estimation in the end. 
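The estimation rule described in the comment above can be read as a small pure function. This is only a sketch; Spark keeps the logic inline, and the helper name is ours:

    // Quadruple the scan while nothing has been found; otherwise interpolate with a
    // 50% overestimate, staying within [1, 4 * partsScanned].
    def nextPartsToTry(partsScanned: Int, resultsSoFar: Int, num: Int): Int = {
      if (partsScanned == 0) {
        1
      } else if (resultsSoFar == 0) {
        partsScanned * 4
      } else {
        val interpolated = (1.5 * num * partsScanned / resultsSoFar).toInt - partsScanned
        math.min(math.max(1, interpolated), partsScanned * 4)
      }
    }

    nextPartsToTry(partsScanned = 4, resultsSoFar = 0, num = 10)  // 16 (nothing found yet)
    nextPartsToTry(partsScanned = 4, resultsSoFar = 2, num = 10)  // 16 (26 capped at 4 * 4)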
if (results.size == 0) { - numPartsToTry = totalParts - 1 + numPartsToTry = partsScanned * 4 } else { - numPartsToTry = (1.5 * num * partsScanned / results.size).toInt + // the left side of max is >=1 whenever partsScanned >= 2 + numPartsToTry = Math.max(1, + (1.5 * num * partsScanned / results.size).toInt - partsScanned) + numPartsToTry = Math.min(numPartsToTry, partsScanned * 4) } } - numPartsToTry = math.max(0, numPartsToTry) // guard against negative num of partitions val left = num - results.size val p = partsScanned until math.min(partsScanned + numPartsToTry, totalParts) diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala index 6b63eb23e9ee1..775141775e06c 100644 --- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala @@ -132,27 +132,47 @@ class HadoopRDD[K, V]( // used to build JobTracker ID private val createTime = new Date() + private val shouldCloneJobConf = sc.conf.get("spark.hadoop.cloneConf", "false").toBoolean + // Returns a JobConf that will be used on slaves to obtain input splits for Hadoop reads. protected def getJobConf(): JobConf = { val conf: Configuration = broadcastedConf.value.value - if (conf.isInstanceOf[JobConf]) { - // A user-broadcasted JobConf was provided to the HadoopRDD, so always use it. - conf.asInstanceOf[JobConf] - } else if (HadoopRDD.containsCachedMetadata(jobConfCacheKey)) { - // getJobConf() has been called previously, so there is already a local cache of the JobConf - // needed by this RDD. - HadoopRDD.getCachedMetadata(jobConfCacheKey).asInstanceOf[JobConf] - } else { - // Create a JobConf that will be cached and used across this RDD's getJobConf() calls in the - // local process. The local cache is accessed through HadoopRDD.putCachedMetadata(). - // The caching helps minimize GC, since a JobConf can contain ~10KB of temporary objects. - // Synchronize to prevent ConcurrentModificationException (Spark-1097, Hadoop-10456). + if (shouldCloneJobConf) { + // Hadoop Configuration objects are not thread-safe, which may lead to various problems if + // one job modifies a configuration while another reads it (SPARK-2546). This problem occurs + // somewhat rarely because most jobs treat the configuration as though it's immutable. One + // solution, implemented here, is to clone the Configuration object. Unfortunately, this + // clone can be very expensive. To avoid unexpected performance regressions for workloads and + // Hadoop versions that do not suffer from these thread-safety issues, this cloning is + // disabled by default. HadoopRDD.CONFIGURATION_INSTANTIATION_LOCK.synchronized { + logDebug("Cloning Hadoop Configuration") val newJobConf = new JobConf(conf) - initLocalJobConfFuncOpt.map(f => f(newJobConf)) - HadoopRDD.putCachedMetadata(jobConfCacheKey, newJobConf) + if (!conf.isInstanceOf[JobConf]) { + initLocalJobConfFuncOpt.map(f => f(newJobConf)) + } newJobConf } + } else { + if (conf.isInstanceOf[JobConf]) { + logDebug("Re-using user-broadcasted JobConf") + conf.asInstanceOf[JobConf] + } else if (HadoopRDD.containsCachedMetadata(jobConfCacheKey)) { + logDebug("Re-using cached JobConf") + HadoopRDD.getCachedMetadata(jobConfCacheKey).asInstanceOf[JobConf] + } else { + // Create a JobConf that will be cached and used across this RDD's getJobConf() calls in the + // local process. The local cache is accessed through HadoopRDD.putCachedMetadata(). 
+ // The caching helps minimize GC, since a JobConf can contain ~10KB of temporary objects. + // Synchronize to prevent ConcurrentModificationException (SPARK-1097, HADOOP-10456). + HadoopRDD.CONFIGURATION_INSTANTIATION_LOCK.synchronized { + logDebug("Creating new JobConf and caching it for later re-use") + val newJobConf = new JobConf(conf) + initLocalJobConfFuncOpt.map(f => f(newJobConf)) + HadoopRDD.putCachedMetadata(jobConfCacheKey, newJobConf) + newJobConf + } + } } } @@ -196,7 +216,7 @@ class HadoopRDD[K, V]( val jobConf = getJobConf() val inputFormat = getInputFormat(jobConf) HadoopRDD.addLocalConfiguration(new SimpleDateFormat("yyyyMMddHHmm").format(createTime), - context.getStageId, theSplit.index, context.getAttemptId.toInt, jobConf) + context.stageId, theSplit.index, context.attemptId.toInt, jobConf) reader = inputFormat.getRecordReader(split.inputSplit.value, jobConf, Reporter.NULL) // Register an on-task-completion callback to close the input stream. @@ -276,7 +296,10 @@ class HadoopRDD[K, V]( } private[spark] object HadoopRDD extends Logging { - /** Constructing Configuration objects is not threadsafe, use this lock to serialize. */ + /** + * Configuration's constructor is not threadsafe (see SPARK-1097 and HADOOP-10456). + * Therefore, we synchronize on this lock before calling new JobConf() or new Configuration(). + */ val CONFIGURATION_INSTANTIATION_LOCK = new Object() /** diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala index 0d97506450a7f..ac96de86dd6d4 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala @@ -956,9 +956,9 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) val writeShard = (context: TaskContext, iter: Iterator[(K,V)]) => { // Hadoop wants a 32-bit task attempt ID, so if ours is bigger than Int.MaxValue, roll it // around by taking a mod. We expect that no task will be attempted 2 billion times. - val attemptNumber = (context.getAttemptId % Int.MaxValue).toInt + val attemptNumber = (context.attemptId % Int.MaxValue).toInt /* "reduce task" */ - val attemptId = newTaskAttemptID(jobtrackerID, stageId, isMap = false, context.getPartitionId, + val attemptId = newTaskAttemptID(jobtrackerID, stageId, isMap = false, context.partitionId, attemptNumber) val hadoopContext = newTaskAttemptContext(wrappedConf.value, attemptId) val format = outfmt.newInstance @@ -1027,15 +1027,13 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) val writeToFile = (context: TaskContext, iter: Iterator[(K, V)]) => { // Hadoop wants a 32-bit task attempt ID, so if ours is bigger than Int.MaxValue, roll it // around by taking a mod. We expect that no task will be attempted 2 billion times. 
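The "roll it around by taking a mod" mentioned in the comment above is plain integer arithmetic; a tiny illustration (the helper name is ours, not Spark's):

    // Fold a 64-bit Spark attempt id into the 32-bit attempt number Hadoop expects.
    def toHadoopAttemptNumber(attemptId: Long): Int = (attemptId % Int.MaxValue).toInt

    toHadoopAttemptNumber(7L)                      // 7
    toHadoopAttemptNumber(3L * Int.MaxValue + 7L)  // 7 again after wrapping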
- val attemptNumber = (context.getAttemptId % Int.MaxValue).toInt + val attemptNumber = (context.attemptId % Int.MaxValue).toInt - writer.setup(context.getStageId, context.getPartitionId, attemptNumber) + writer.setup(context.stageId, context.partitionId, attemptNumber) writer.open() try { - var count = 0 while (iter.hasNext) { val record = iter.next() - count += 1 writer.write(record._1.asInstanceOf[AnyRef], record._2.asInstanceOf[AnyRef]) } } finally { diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index 2aba40d152e3e..71cabf61d4ee0 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -1079,15 +1079,17 @@ abstract class RDD[T: ClassTag]( // greater than totalParts because we actually cap it at totalParts in runJob. var numPartsToTry = 1 if (partsScanned > 0) { - // If we didn't find any rows after the previous iteration, quadruple and retry. Otherwise, + // If we didn't find any rows after the previous iteration, quadruple and retry. Otherwise, // interpolate the number of partitions we need to try, but overestimate it by 50%. + // We also cap the estimation in the end. if (buf.size == 0) { numPartsToTry = partsScanned * 4 } else { - numPartsToTry = (1.5 * num * partsScanned / buf.size).toInt + // the left side of max is >=1 whenever partsScanned >= 2 + numPartsToTry = Math.max((1.5 * num * partsScanned / buf.size).toInt - partsScanned, 1) + numPartsToTry = Math.min(numPartsToTry, partsScanned * 4) } } - numPartsToTry = math.max(0, numPartsToTry) // guard against negative num of partitions val left = num - buf.size val p = partsScanned until math.min(partsScanned + numPartsToTry, totalParts) diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 788eb1ff4e455..f81fa6d8089fc 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -633,14 +633,14 @@ class DAGScheduler( val rdd = job.finalStage.rdd val split = rdd.partitions(job.partitions(0)) val taskContext = - new TaskContext(job.finalStage.id, job.partitions(0), 0, true) - TaskContext.setTaskContext(taskContext) + new TaskContextImpl(job.finalStage.id, job.partitions(0), 0, true) + TaskContextHelper.setTaskContext(taskContext) try { val result = job.func(taskContext, rdd.iterator(split, taskContext)) job.listener.taskSucceeded(0, result) } finally { taskContext.markTaskCompleted() - TaskContext.unset() + TaskContextHelper.unset() } } catch { case e: Exception => diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala index c6e47c84a0cb2..2552d03d18d06 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala @@ -22,7 +22,7 @@ import java.nio.ByteBuffer import scala.collection.mutable.HashMap -import org.apache.spark.TaskContext +import org.apache.spark.{TaskContextHelper, TaskContextImpl, TaskContext} import org.apache.spark.executor.TaskMetrics import org.apache.spark.serializer.SerializerInstance import org.apache.spark.util.ByteBufferInputStream @@ -45,8 +45,8 @@ import org.apache.spark.util.Utils private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) extends Serializable { final def run(attemptId: Long): T = { - context = new 
TaskContext(stageId, partitionId, attemptId, false) - TaskContext.setTaskContext(context) + context = new TaskContextImpl(stageId, partitionId, attemptId, false) + TaskContextHelper.setTaskContext(context) context.taskMetrics.hostname = Utils.localHostName() taskThread = Thread.currentThread() if (_killed) { @@ -56,7 +56,7 @@ private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) ex runTask(context) } finally { context.markTaskCompleted() - TaskContext.unset() + TaskContextHelper.unset() } } @@ -70,7 +70,7 @@ private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) ex var metrics: Option[TaskMetrics] = None // Task context, to be initialized in run(). - @transient protected var context: TaskContext = _ + @transient protected var context: TaskContextImpl = _ // The actual Thread on which the task is running, if any. Initialized in run(). @volatile @transient private var taskThread: Thread = _ diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala index ed209d195ec9d..8c7de75600b5f 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala @@ -51,7 +51,8 @@ private[spark] class SparkDeploySchedulerBackend( conf.get("spark.driver.host"), conf.get("spark.driver.port"), CoarseGrainedSchedulerBackend.ACTOR_NAME) - val args = Seq(driverUrl, "{{EXECUTOR_ID}}", "{{HOSTNAME}}", "{{CORES}}", "{{WORKER_URL}}") + val args = Seq(driverUrl, "{{EXECUTOR_ID}}", "{{HOSTNAME}}", "{{CORES}}", "{{APP_ID}}", + "{{WORKER_URL}}") val extraJavaOpts = sc.conf.getOption("spark.executor.extraJavaOptions") .map(Utils.splitCommandString).getOrElse(Seq.empty) val classPathEntries = sc.conf.getOption("spark.executor.extraClassPath").toSeq.flatMap { cp => diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala index 90828578cd88f..d7f88de4b40aa 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala @@ -150,17 +150,17 @@ private[spark] class CoarseMesosSchedulerBackend( if (uri == null) { val runScript = new File(executorSparkHome, "./bin/spark-class").getCanonicalPath command.setValue( - "\"%s\" org.apache.spark.executor.CoarseGrainedExecutorBackend %s %s %s %d".format( - runScript, driverUrl, offer.getSlaveId.getValue, offer.getHostname, numCores)) + "\"%s\" org.apache.spark.executor.CoarseGrainedExecutorBackend %s %s %s %d %s".format( + runScript, driverUrl, offer.getSlaveId.getValue, offer.getHostname, numCores, appId)) } else { // Grab everything to the first '.'. We'll use that and '*' to // glob the directory "correctly". 
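The "first '.'" comment above is easier to follow with a concrete value; the URI below is purely hypothetical:

    // Deriving the glob prefix used in the Mesos sandbox "cd" command.
    val uri = "http://example.com/dist/spark-1.2.0-bin-hadoop2.4.tgz"  // hypothetical
    val basename = uri.split('/').last.split('.').head                 // "spark-1"
    val cdPrefix = s"cd $basename*; "                                  // globs the unpacked dir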
val basename = uri.split('/').last.split('.').head command.setValue( ("cd %s*; " + - "./bin/spark-class org.apache.spark.executor.CoarseGrainedExecutorBackend %s %s %s %d") + "./bin/spark-class org.apache.spark.executor.CoarseGrainedExecutorBackend %s %s %s %d %s") .format(basename, driverUrl, offer.getSlaveId.getValue, - offer.getHostname, numCores)) + offer.getHostname, numCores, appId)) command.addUris(CommandInfo.URI.newBuilder().setValue(uri)) } command.build() diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala index 6a06257ed0c08..088f06e389d83 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala @@ -457,16 +457,18 @@ private[spark] class BlockManagerInfo( if (_blocks.containsKey(blockId)) { // The block exists on the slave already. - val originalLevel: StorageLevel = _blocks.get(blockId).storageLevel + val blockStatus: BlockStatus = _blocks.get(blockId) + val originalLevel: StorageLevel = blockStatus.storageLevel + val originalMemSize: Long = blockStatus.memSize if (originalLevel.useMemory) { - _remainingMem += memSize + _remainingMem += originalMemSize } } if (storageLevel.isValid) { /* isValid means it is either stored in-memory, on-disk or on-Tachyon. - * But the memSize here indicates the data size in or dropped from memory, + * The memSize here indicates the data size in or dropped from memory, * tachyonSize here indicates the data size in or dropped from Tachyon, * and the diskSize here indicates the data size in or dropped to disk. * They can be both larger than 0, when a block is dropped from memory to disk. @@ -493,7 +495,6 @@ private[spark] class BlockManagerInfo( val blockStatus: BlockStatus = _blocks.get(blockId) _blocks.remove(blockId) if (blockStatus.storageLevel.useMemory) { - _remainingMem += blockStatus.memSize logInfo("Removed %s on %s in memory (size: %s, free: %s)".format( blockId, blockManagerId.hostPort, Utils.bytesToString(blockStatus.memSize), Utils.bytesToString(_remainingMem))) diff --git a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala index f0006b42aee4f..32e6b15bb0999 100644 --- a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala @@ -21,6 +21,7 @@ import java.text.SimpleDateFormat import java.util.{Locale, Date} import scala.xml.Node + import org.apache.spark.Logging /** Utility functions for generating XML pages with spark content. */ @@ -169,6 +170,7 @@ private[spark] object UIUtils extends Logging { refreshInterval: Option[Int] = None): Seq[Node] = { val appName = activeTab.appName + val shortAppName = if (appName.length < 36) appName else appName.take(32) + "..." val header = activeTab.headerTabs.map { tab =>

        {tab.name}
@@ -187,7 +189,9 @@ private[spark] object UIUtils extends Logging {
    @@ -216,8 +220,10 @@ private[spark] object UIUtils extends Logging {

          {title}

    diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala index 2987dc04494a5..f0e43fbf70976 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala @@ -71,19 +71,19 @@ private[ui] class ExecutorTable(stageId: Int, stageAttemptId: Int, parent: JobPr {k} {executorIdToAddress.getOrElse(k, "CANNOT FIND ADDRESS")} - {UIUtils.formatDuration(v.taskTime)} + {UIUtils.formatDuration(v.taskTime)} {v.failedTasks + v.succeededTasks} {v.failedTasks} {v.succeededTasks} - + {Utils.bytesToString(v.inputBytes)} - + {Utils.bytesToString(v.shuffleRead)} - + {Utils.bytesToString(v.shuffleWrite)} - + {Utils.bytesToString(v.memoryBytesSpilled)} - + {Utils.bytesToString(v.diskBytesSpilled)} } diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressPage.scala index a82f71ed08475..1e02f1225d344 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressPage.scala @@ -29,7 +29,7 @@ private[ui] class JobProgressPage(parent: JobProgressTab) extends WebUIPage("") private val live = parent.live private val sc = parent.sc private val listener = parent.listener - private lazy val isFairScheduler = parent.isFairScheduler + private def isFairScheduler = parent.isFairScheduler def render(request: HttpServletRequest): Seq[Node] = { listener.synchronized { diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index db01be596e073..2414e4c65237e 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -103,7 +103,7 @@ private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") { val taskHeaders: Seq[String] = Seq( - "Index", "ID", "Attempt", "Status", "Locality Level", "Executor", + "Index", "ID", "Attempt", "Status", "Locality Level", "Executor ID / Host", "Launch Time", "Duration", "GC Time", "Accumulators") ++ {if (hasInput) Seq("Input") else Nil} ++ {if (hasShuffleRead) Seq("Shuffle Read") else Nil} ++ @@ -282,7 +282,7 @@ private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") { } {info.status} {info.taskLocality} - {info.host} + {info.executorId} / {info.host} {UIUtils.formatDate(new Date(info.launchTime))} {formatDuration} diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala index 2e67310594784..4ee7f08ab47a2 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala @@ -176,9 +176,9 @@ private[ui] class StageTableBase( {makeProgressBar(stageData.numActiveTasks, stageData.completedIndices.size, stageData.numFailedTasks, s.numTasks)} - {inputReadWithUnit} - {shuffleReadWithUnit} - {shuffleWriteWithUnit} + {inputReadWithUnit} + {shuffleReadWithUnit} + {shuffleWriteWithUnit} } /** Render an HTML row that represents a stage */ diff --git a/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala b/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala index 716591c9ed449..83489ca0679ee 100644 --- a/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala +++ 
b/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala @@ -58,9 +58,9 @@ private[ui] class StoragePage(parent: StorageTab) extends WebUIPage("") { {rdd.numCachedPartitions} {"%.0f%%".format(rdd.numCachedPartitions * 100.0 / rdd.numPartitions)} - {Utils.bytesToString(rdd.memSize)} - {Utils.bytesToString(rdd.tachyonSize)} - {Utils.bytesToString(rdd.diskSize)} + {Utils.bytesToString(rdd.memSize)} + {Utils.bytesToString(rdd.tachyonSize)} + {Utils.bytesToString(rdd.diskSize)} // scalastyle:on } diff --git a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala index e2d32c859bbda..f41c8d0315cb3 100644 --- a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala +++ b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala @@ -77,7 +77,7 @@ private[spark] object AkkaUtils extends Logging { val logAkkaConfig = if (conf.getBoolean("spark.akka.logAkkaConfig", false)) "on" else "off" - val akkaHeartBeatPauses = conf.getInt("spark.akka.heartbeat.pauses", 600) + val akkaHeartBeatPauses = conf.getInt("spark.akka.heartbeat.pauses", 6000) val akkaFailureDetector = conf.getDouble("spark.akka.failure-detector.threshold", 300.0) val akkaHeartBeatInterval = conf.getInt("spark.akka.heartbeat.interval", 1000) diff --git a/core/src/main/scala/org/apache/spark/util/FileLogger.scala b/core/src/main/scala/org/apache/spark/util/FileLogger.scala index 6d1fc05a15d2c..fdc73f08261a6 100644 --- a/core/src/main/scala/org/apache/spark/util/FileLogger.scala +++ b/core/src/main/scala/org/apache/spark/util/FileLogger.scala @@ -51,12 +51,27 @@ private[spark] class FileLogger( def this( logDir: String, sparkConf: SparkConf, - compress: Boolean = false, - overwrite: Boolean = true) = { + compress: Boolean, + overwrite: Boolean) = { this(logDir, sparkConf, SparkHadoopUtil.get.newConfiguration(sparkConf), compress = compress, overwrite = overwrite) } + def this( + logDir: String, + sparkConf: SparkConf, + compress: Boolean) = { + this(logDir, sparkConf, SparkHadoopUtil.get.newConfiguration(sparkConf), compress = compress, + overwrite = true) + } + + def this( + logDir: String, + sparkConf: SparkConf) = { + this(logDir, sparkConf, SparkHadoopUtil.get.newConfiguration(sparkConf), compress = false, + overwrite = true) + } + private val dateFormat = new ThreadLocal[SimpleDateFormat]() { override def initialValue(): SimpleDateFormat = new SimpleDateFormat("yyyy/MM/dd HH:mm:ss") } diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 3d307b3c16d3e..0aeff6455b3fe 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -168,6 +168,20 @@ private[spark] object Utils extends Logging { private val shutdownDeletePaths = new scala.collection.mutable.HashSet[String]() private val shutdownDeleteTachyonPaths = new scala.collection.mutable.HashSet[String]() + // Add a shutdown hook to delete the temp dirs when the JVM exits + Runtime.getRuntime.addShutdownHook(new Thread("delete Spark temp dirs") { + override def run(): Unit = Utils.logUncaughtExceptions { + logDebug("Shutdown hook called") + shutdownDeletePaths.foreach { dirPath => + try { + Utils.deleteRecursively(new File(dirPath)) + } catch { + case e: Exception => logError(s"Exception while deleting Spark temp dir: $dirPath", e) + } + } + } + }) + // Register the path to be deleted via shutdown hook def registerShutdownDeleteDir(file: File) { val absolutePath = 
file.getAbsolutePath() @@ -252,34 +266,47 @@ private[spark] object Utils extends Logging { } registerShutdownDeleteDir(dir) - - // Add a shutdown hook to delete the temp dir when the JVM exits - Runtime.getRuntime.addShutdownHook(new Thread("delete Spark temp dir " + dir) { - override def run() { - // Attempt to delete if some patch which is parent of this is not already registered. - if (! hasRootAsShutdownDeleteDir(dir)) Utils.deleteRecursively(dir) - } - }) dir } - /** Copy all data from an InputStream to an OutputStream */ + /** Copy all data from an InputStream to an OutputStream. NIO way of file stream to file stream + * copying is disabled by default unless explicitly set transferToEnabled as true, + * the parameter transferToEnabled should be configured by spark.file.transferTo = [true|false]. + */ def copyStream(in: InputStream, out: OutputStream, - closeStreams: Boolean = false): Long = + closeStreams: Boolean = false, + transferToEnabled: Boolean = false): Long = { var count = 0L try { - if (in.isInstanceOf[FileInputStream] && out.isInstanceOf[FileOutputStream]) { + if (in.isInstanceOf[FileInputStream] && out.isInstanceOf[FileOutputStream] + && transferToEnabled) { // When both streams are File stream, use transferTo to improve copy performance. val inChannel = in.asInstanceOf[FileInputStream].getChannel() val outChannel = out.asInstanceOf[FileOutputStream].getChannel() + val initialPos = outChannel.position() val size = inChannel.size() // In case transferTo method transferred less data than we have required. while (count < size) { count += inChannel.transferTo(count, size - count, outChannel) } + + // Check the position after transferTo loop to see if it is in the right position and + // give user information if not. + // Position will not be increased to the expected length after calling transferTo in + // kernel version 2.6.32, this issue can be seen in + // https://bugs.openjdk.java.net/browse/JDK-7052359 + // This will lead to stream corruption issue when using sort-based shuffle (SPARK-3948). + val finalPos = outChannel.position() + assert(finalPos == initialPos + size, + s""" + |Current position $finalPos do not equal to expected position ${initialPos + size} + |after transferTo, please check your kernel version to see if it is 2.6.32, + |this is a kernel bug which will lead to unexpected behavior when using transferTo. + |You can set spark.file.transferTo = false to disable this NIO feature. + """.stripMargin) } else { val buf = new Array[Byte](8192) var n = 0 @@ -334,7 +361,7 @@ private[spark] object Utils extends Logging { val targetFile = new File(targetDir, filename) val uri = new URI(url) val fileOverwrite = conf.getBoolean("spark.files.overwrite", defaultValue = false) - uri.getScheme match { + Option(uri.getScheme).getOrElse("file") match { case "http" | "https" | "ftp" => logInfo("Fetching " + url + " to " + tempFile) @@ -368,7 +395,7 @@ private[spark] object Utils extends Logging { } } Files.move(tempFile, targetFile) - case "file" | null => + case "file" => // In the case of a local file, copy the local file to the target directory. // Note the difference between uri vs url. 
val sourceFile = if (uri.isAbsolute) new File(uri) else new File(url) @@ -666,15 +693,30 @@ private[spark] object Utils extends Logging { */ def deleteRecursively(file: File) { if (file != null) { - if (file.isDirectory() && !isSymlink(file)) { - for (child <- listFilesSafely(file)) { - deleteRecursively(child) + try { + if (file.isDirectory && !isSymlink(file)) { + var savedIOException: IOException = null + for (child <- listFilesSafely(file)) { + try { + deleteRecursively(child) + } catch { + // In case of multiple exceptions, only last one will be thrown + case ioe: IOException => savedIOException = ioe + } + } + if (savedIOException != null) { + throw savedIOException + } + shutdownDeletePaths.synchronized { + shutdownDeletePaths.remove(file.getAbsolutePath) + } } - } - if (!file.delete()) { - // Delete can also fail if the file simply did not exist - if (file.exists()) { - throw new IOException("Failed to delete: " + file.getAbsolutePath) + } finally { + if (!file.delete()) { + // Delete can also fail if the file simply did not exist + if (file.exists()) { + throw new IOException("Failed to delete: " + file.getAbsolutePath) + } } } } @@ -706,14 +748,14 @@ private[spark] object Utils extends Logging { /** * Determines if a directory contains any files newer than cutoff seconds. - * + * * @param dir must be the path to a directory, or IllegalArgumentException is thrown * @param cutoff measured in seconds. Returns true if there are any files or directories in the * given directory whose last modified time is later than this many seconds ago */ def doesDirectoryContainAnyNewFiles(dir: File, cutoff: Long): Boolean = { if (!dir.isDirectory) { - throw new IllegalArgumentException("$dir is not a directory!") + throw new IllegalArgumentException(s"$dir is not a directory!") } val filesAndDirs = dir.listFiles() val cutoffTimeInMillis = System.currentTimeMillis - (cutoff * 1000) @@ -1347,16 +1389,17 @@ private[spark] object Utils extends Logging { if (uri.getPath == null) { throw new IllegalArgumentException(s"Given path is malformed: $uri") } - uri.getScheme match { - case windowsDrive(d) if windows => + + Option(uri.getScheme) match { + case Some(windowsDrive(d)) if windows => new URI("file:/" + uri.toString.stripPrefix("/")) - case null => + case None => // Preserve fragments for HDFS file name substitution (denoted by "#") // For instance, in "abc.py#xyz.py", "xyz.py" is the name observed by the application val fragment = uri.getFragment val part = new File(uri.getPath).toURI new URI(part.getScheme, part.getPath, fragment) - case _ => + case Some(other) => uri } } @@ -1378,15 +1421,64 @@ private[spark] object Utils extends Logging { } else { paths.split(",").filter { p => val formattedPath = if (windows) formatWindowsPath(p) else p - new URI(formattedPath).getScheme match { + val uri = new URI(formattedPath) + Option(uri.getScheme).getOrElse("file") match { case windowsDrive(d) if windows => false - case "local" | "file" | null => false + case "local" | "file" => false case _ => true } } } } + /** + * Load default Spark properties from the given file. If no file is provided, + * use the common defaults file. This mutates state in the given SparkConf and + * in this JVM's system properties if the config specified in the file is not + * already set. Return the path of the properties file used. 
+ */ + def loadDefaultSparkProperties(conf: SparkConf, filePath: String = null): String = { + val path = Option(filePath).getOrElse(getDefaultPropertiesFile()) + Option(path).foreach { confFile => + getPropertiesFromFile(confFile).filter { case (k, v) => + k.startsWith("spark.") + }.foreach { case (k, v) => + conf.setIfMissing(k, v) + sys.props.getOrElseUpdate(k, v) + } + } + path + } + + /** Load properties present in the given file. */ + def getPropertiesFromFile(filename: String): Map[String, String] = { + val file = new File(filename) + require(file.exists(), s"Properties file $file does not exist") + require(file.isFile(), s"Properties file $file is not a normal file") + + val inReader = new InputStreamReader(new FileInputStream(file), "UTF-8") + try { + val properties = new Properties() + properties.load(inReader) + properties.stringPropertyNames().map(k => (k, properties(k).trim)).toMap + } catch { + case e: IOException => + throw new SparkException(s"Failed when loading Spark properties from $filename", e) + } finally { + inReader.close() + } + } + + /** Return the path of the default Spark properties file. */ + def getDefaultPropertiesFile(env: Map[String, String] = sys.env): String = { + env.get("SPARK_CONF_DIR") + .orElse(env.get("SPARK_HOME").map { t => s"$t${File.separator}conf" }) + .map { t => new File(s"$t${File.separator}spark-defaults.conf")} + .filter(_.isFile) + .map(_.getAbsolutePath) + .orNull + } + /** Return a nice string representation of the exception, including the stack trace. */ def exceptionString(e: Exception): String = { if (e == null) "" else exceptionString(getFormattedClassName(e), e.getMessage, e.getStackTrace) diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala index 644fa36818647..d1b06d14acbd2 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala @@ -93,6 +93,7 @@ private[spark] class ExternalSorter[K, V, C]( private val conf = SparkEnv.get.conf private val spillingEnabled = conf.getBoolean("spark.shuffle.spill", true) private val fileBufferSize = conf.getInt("spark.shuffle.file.buffer.kb", 32) * 1024 + private val transferToEnabled = conf.getBoolean("spark.file.transferTo", true) // Size of object batches when reading/writing from serializers. 
// @@ -705,10 +706,10 @@ private[spark] class ExternalSorter[K, V, C]( var out: FileOutputStream = null var in: FileInputStream = null try { - out = new FileOutputStream(outputFile) + out = new FileOutputStream(outputFile, true) for (i <- 0 until numPartitions) { in = new FileInputStream(partitionWriters(i).fileSegment().file) - val size = org.apache.spark.util.Utils.copyStream(in, out, false) + val size = org.apache.spark.util.Utils.copyStream(in, out, false, transferToEnabled) in.close() in = null lengths(i) = size diff --git a/core/src/test/java/org/apache/spark/JavaAPISuite.java b/core/src/test/java/org/apache/spark/JavaAPISuite.java index 4a078435447e5..3190148fb5f43 100644 --- a/core/src/test/java/org/apache/spark/JavaAPISuite.java +++ b/core/src/test/java/org/apache/spark/JavaAPISuite.java @@ -20,6 +20,7 @@ import java.io.*; import java.net.URI; import java.util.*; +import java.util.concurrent.*; import scala.Tuple2; import scala.Tuple3; @@ -29,6 +30,7 @@ import com.google.common.collect.Iterators; import com.google.common.collect.Lists; import com.google.common.collect.Maps; +import com.google.common.base.Throwables; import com.google.common.base.Optional; import com.google.common.base.Charsets; import com.google.common.io.Files; @@ -43,10 +45,7 @@ import org.junit.Before; import org.junit.Test; -import org.apache.spark.api.java.JavaDoubleRDD; -import org.apache.spark.api.java.JavaPairRDD; -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.api.java.*; import org.apache.spark.api.java.function.*; import org.apache.spark.executor.TaskMetrics; import org.apache.spark.partial.BoundedDouble; @@ -776,7 +775,7 @@ public void persist() { @Test public void iterator() { JavaRDD rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5), 2); - TaskContext context = new TaskContext(0, 0, 0L, false, new TaskMetrics()); + TaskContext context = new TaskContextImpl(0, 0, 0L, false, new TaskMetrics()); Assert.assertEquals(1, rdd.iterator(rdd.partitions().get(0), context).next().intValue()); } @@ -1308,6 +1307,92 @@ public void collectUnderlyingScalaRDD() { Assert.assertEquals(data.size(), collected.length); } + private static final class BuggyMapFunction implements Function { + + @Override + public T call(T x) throws Exception { + throw new IllegalStateException("Custom exception!"); + } + } + + @Test + public void collectAsync() throws Exception { + List data = Arrays.asList(1, 2, 3, 4, 5); + JavaRDD rdd = sc.parallelize(data, 1); + JavaFutureAction> future = rdd.collectAsync(); + List result = future.get(); + Assert.assertEquals(data, result); + Assert.assertFalse(future.isCancelled()); + Assert.assertTrue(future.isDone()); + Assert.assertEquals(1, future.jobIds().size()); + } + + @Test + public void foreachAsync() throws Exception { + List data = Arrays.asList(1, 2, 3, 4, 5); + JavaRDD rdd = sc.parallelize(data, 1); + JavaFutureAction future = rdd.foreachAsync( + new VoidFunction() { + @Override + public void call(Integer integer) throws Exception { + // intentionally left blank. 
+ } + } + ); + future.get(); + Assert.assertFalse(future.isCancelled()); + Assert.assertTrue(future.isDone()); + Assert.assertEquals(1, future.jobIds().size()); + } + + @Test + public void countAsync() throws Exception { + List data = Arrays.asList(1, 2, 3, 4, 5); + JavaRDD rdd = sc.parallelize(data, 1); + JavaFutureAction future = rdd.countAsync(); + long count = future.get(); + Assert.assertEquals(data.size(), count); + Assert.assertFalse(future.isCancelled()); + Assert.assertTrue(future.isDone()); + Assert.assertEquals(1, future.jobIds().size()); + } + + @Test + public void testAsyncActionCancellation() throws Exception { + List data = Arrays.asList(1, 2, 3, 4, 5); + JavaRDD rdd = sc.parallelize(data, 1); + JavaFutureAction future = rdd.foreachAsync(new VoidFunction() { + @Override + public void call(Integer integer) throws Exception { + Thread.sleep(10000); // To ensure that the job won't finish before it's cancelled. + } + }); + future.cancel(true); + Assert.assertTrue(future.isCancelled()); + Assert.assertTrue(future.isDone()); + try { + future.get(2000, TimeUnit.MILLISECONDS); + Assert.fail("Expected future.get() for cancelled job to throw CancellationException"); + } catch (CancellationException ignored) { + // pass + } + } + + @Test + public void testAsyncActionErrorWrapping() throws Exception { + List data = Arrays.asList(1, 2, 3, 4, 5); + JavaRDD rdd = sc.parallelize(data, 1); + JavaFutureAction future = rdd.map(new BuggyMapFunction()).countAsync(); + try { + future.get(2, TimeUnit.SECONDS); + Assert.fail("Expected future.get() for failed job to throw ExcecutionException"); + } catch (ExecutionException ee) { + Assert.assertTrue(Throwables.getStackTraceAsString(ee).contains("Custom exception!")); + } + Assert.assertTrue(future.isDone()); + } + + /** * Test for SPARK-3647. This test needs to use the maven-built assembly to trigger the issue, * since that's the only artifact where Guava classes have been relocated. diff --git a/core/src/test/java/org/apache/spark/util/JavaTaskCompletionListenerImpl.java b/core/src/test/java/org/apache/spark/util/JavaTaskCompletionListenerImpl.java index 0944bf8cd5c71..e9ec700e32e15 100644 --- a/core/src/test/java/org/apache/spark/util/JavaTaskCompletionListenerImpl.java +++ b/core/src/test/java/org/apache/spark/util/JavaTaskCompletionListenerImpl.java @@ -30,8 +30,8 @@ public class JavaTaskCompletionListenerImpl implements TaskCompletionListener { public void onTaskCompletion(TaskContext context) { context.isCompleted(); context.isInterrupted(); - context.getStageId(); - context.getPartitionId(); + context.stageId(); + context.partitionId(); context.isRunningLocally(); context.addTaskCompletionListener(this); } diff --git a/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala b/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala index d735010d7c9d5..c0735f448d193 100644 --- a/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala @@ -66,7 +66,7 @@ class CacheManagerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar // in blockManager.put is a losing battle. You have been warned. 
blockManager = sc.env.blockManager cacheManager = sc.env.cacheManager - val context = new TaskContext(0, 0, 0) + val context = new TaskContextImpl(0, 0, 0) val computeValue = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY) val getValue = blockManager.get(RDDBlockId(rdd.id, split.index)) assert(computeValue.toList === List(1, 2, 3, 4)) @@ -81,7 +81,7 @@ class CacheManagerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar } whenExecuting(blockManager) { - val context = new TaskContext(0, 0, 0) + val context = new TaskContextImpl(0, 0, 0) val value = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY) assert(value.toList === List(5, 6, 7)) } @@ -94,7 +94,7 @@ class CacheManagerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar } whenExecuting(blockManager) { - val context = new TaskContext(0, 0, 0, true) + val context = new TaskContextImpl(0, 0, 0, true) val value = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY) assert(value.toList === List(1, 2, 3, 4)) } @@ -102,7 +102,7 @@ class CacheManagerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar test("verify task metrics updated correctly") { cacheManager = sc.env.cacheManager - val context = new TaskContext(0, 0, 0) + val context = new TaskContextImpl(0, 0, 0) cacheManager.getOrCompute(rdd3, split, context, StorageLevel.MEMORY_ONLY) assert(context.taskMetrics.updatedBlocks.getOrElse(Seq()).size === 2) } diff --git a/core/src/test/scala/org/apache/spark/FileServerSuite.scala b/core/src/test/scala/org/apache/spark/FileServerSuite.scala index 7e18f45de7b5b..a8867020e457d 100644 --- a/core/src/test/scala/org/apache/spark/FileServerSuite.scala +++ b/core/src/test/scala/org/apache/spark/FileServerSuite.scala @@ -20,7 +20,6 @@ package org.apache.spark import java.io._ import java.util.jar.{JarEntry, JarOutputStream} -import com.google.common.io.Files import org.scalatest.FunSuite import org.apache.spark.SparkContext._ @@ -41,8 +40,7 @@ class FileServerSuite extends FunSuite with LocalSparkContext { override def beforeAll() { super.beforeAll() - tmpDir = Files.createTempDir() - tmpDir.deleteOnExit() + tmpDir = Utils.createTempDir() val testTempDir = new File(tmpDir, "test") testTempDir.mkdir() diff --git a/core/src/test/scala/org/apache/spark/FileSuite.scala b/core/src/test/scala/org/apache/spark/FileSuite.scala index 4a53d25012ad9..a2b74c4419d46 100644 --- a/core/src/test/scala/org/apache/spark/FileSuite.scala +++ b/core/src/test/scala/org/apache/spark/FileSuite.scala @@ -21,7 +21,6 @@ import java.io.{File, FileWriter} import scala.io.Source -import com.google.common.io.Files import org.apache.hadoop.io._ import org.apache.hadoop.io.compress.DefaultCodec import org.apache.hadoop.mapred.{JobConf, FileAlreadyExistsException, FileSplit, TextInputFormat, TextOutputFormat} @@ -39,8 +38,7 @@ class FileSuite extends FunSuite with LocalSparkContext { override def beforeEach() { super.beforeEach() - tempDir = Files.createTempDir() - tempDir.deleteOnExit() + tempDir = Utils.createTempDir() } override def afterEach() { diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala index 1fef79ad1001f..cbc0bd178d894 100644 --- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala @@ -146,7 +146,7 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext { val masterTracker = new 
MapOutputTrackerMaster(conf) val actorSystem = ActorSystem("test") val actorRef = TestActorRef[MapOutputTrackerMasterActor]( - new MapOutputTrackerMasterActor(masterTracker, newConf))(actorSystem) + Props(new MapOutputTrackerMasterActor(masterTracker, newConf)))(actorSystem) val masterActor = actorRef.underlyingActor // Frame size should be ~123B, and no exception should be thrown @@ -164,7 +164,7 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext { val masterTracker = new MapOutputTrackerMaster(conf) val actorSystem = ActorSystem("test") val actorRef = TestActorRef[MapOutputTrackerMasterActor]( - new MapOutputTrackerMasterActor(masterTracker, newConf))(actorSystem) + Props(new MapOutputTrackerMasterActor(masterTracker, newConf)))(actorSystem) val masterActor = actorRef.underlyingActor // Frame size should be ~1.1MB, and MapOutputTrackerMasterActor should throw exception. diff --git a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala new file mode 100644 index 0000000000000..31edad1c56c73 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark + +import org.scalatest.FunSuite + +import org.apache.hadoop.io.BytesWritable + +class SparkContextSuite extends FunSuite { + //Regression test for SPARK-3121 + test("BytesWritable implicit conversion is correct") { + val bytesWritable = new BytesWritable() + val inputArray = (1 to 10).map(_.toByte).toArray + bytesWritable.set(inputArray, 0, 10) + bytesWritable.set(inputArray, 0, 5) + + val converter = SparkContext.bytesWritableConverter() + val byteArray = converter.convert(bytesWritable) + assert(byteArray.length === 5) + + bytesWritable.set(inputArray, 0, 0) + val byteArray2 = converter.convert(bytesWritable) + assert(byteArray2.length === 0) + } +} diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index 4cba90e8f2afe..1cdf50d5c08c7 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -26,7 +26,6 @@ import org.apache.spark.deploy.SparkSubmit._ import org.apache.spark.util.Utils import org.scalatest.FunSuite import org.scalatest.Matchers -import com.google.common.io.Files class SparkSubmitSuite extends FunSuite with Matchers { def beforeAll() { @@ -332,7 +331,7 @@ class SparkSubmitSuite extends FunSuite with Matchers { } def forConfDir(defaults: Map[String, String]) (f: String => Unit) = { - val tmpDir = Files.createTempDir() + val tmpDir = Utils.createTempDir() val defaultsConf = new File(tmpDir.getAbsolutePath, "spark-defaults.conf") val writer = new OutputStreamWriter(new FileOutputStream(defaultsConf)) diff --git a/core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala b/core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala index 39ab53cf0b5b1..5e2592e8d2e8d 100644 --- a/core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala +++ b/core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala @@ -26,14 +26,12 @@ import org.apache.spark.SparkConf class ExecutorRunnerTest extends FunSuite { test("command includes appId") { - def f(s:String) = new File(s) + val appId = "12345-worker321-9876" val sparkHome = sys.props.getOrElse("spark.test.home", fail("spark.test.home is not set!")) val appDesc = new ApplicationDescription("app name", Some(8), 500, - Command("foo", Seq(), Map(), Seq(), Seq(), Seq()), "appUiUrl") - val appId = "12345-worker321-9876" - val er = new ExecutorRunner(appId, 1, appDesc, 8, 500, null, "blah", "worker321", f(sparkHome), - f("ooga"), "blah", new SparkConf, ExecutorState.RUNNING) - + Command("foo", Seq(appId), Map(), Seq(), Seq(), Seq()), "appUiUrl") + val er = new ExecutorRunner(appId, 1, appDesc, 8, 500, null, "blah", "worker321", + new File(sparkHome), new File("ooga"), "blah", new SparkConf, ExecutorState.RUNNING) assert(er.getCommandSeq.last === appId) } } diff --git a/core/src/test/scala/org/apache/spark/deploy/worker/WorkerArgumentsTest.scala b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerArgumentsTest.scala new file mode 100644 index 0000000000000..1a28a9a187cd7 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerArgumentsTest.scala @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +package org.apache.spark.deploy.worker + +import org.apache.spark.SparkConf +import org.scalatest.FunSuite + + +class WorkerArgumentsTest extends FunSuite { + + test("Memory can't be set to 0 when cmd line args leave off M or G") { + val conf = new SparkConf + val args = Array("-m", "10000", "spark://localhost:0000 ") + intercept[IllegalStateException] { + new WorkerArguments(args, conf) + } + } + + + test("Memory can't be set to 0 when SPARK_WORKER_MEMORY env property leaves off M or G") { + val args = Array("spark://localhost:0000 ") + + class MySparkConf extends SparkConf(false) { + override def getenv(name: String) = { + if (name == "SPARK_WORKER_MEMORY") "50000" + else super.getenv(name) + } + + override def clone: SparkConf = { + new MySparkConf().setAll(settings) + } + } + val conf = new MySparkConf() + intercept[IllegalStateException] { + new WorkerArguments(args, conf) + } + } + + test("Memory correctly set when SPARK_WORKER_MEMORY env property appends G") { + val args = Array("spark://localhost:0000 ") + + class MySparkConf extends SparkConf(false) { + override def getenv(name: String) = { + if (name == "SPARK_WORKER_MEMORY") "5G" + else super.getenv(name) + } + + override def clone: SparkConf = { + new MySparkConf().setAll(settings) + } + } + val conf = new MySparkConf() + val workerArgs = new WorkerArguments(args, conf) + assert(workerArgs.memory === 5120) + } + + test("Memory correctly set from args with M appended to memory value") { + val conf = new SparkConf + val args = Array("-m", "10000M", "spark://localhost:0000 ") + + val workerArgs = new WorkerArguments(args, conf) + assert(workerArgs.memory === 10000) + + } + +} diff --git a/core/src/test/scala/org/apache/spark/input/WholeTextFileRecordReaderSuite.scala b/core/src/test/scala/org/apache/spark/input/WholeTextFileRecordReaderSuite.scala index d5ebfb3f3fae1..12d1c7b2faba6 100644 --- a/core/src/test/scala/org/apache/spark/input/WholeTextFileRecordReaderSuite.scala +++ b/core/src/test/scala/org/apache/spark/input/WholeTextFileRecordReaderSuite.scala @@ -23,8 +23,6 @@ import java.io.FileOutputStream import scala.collection.immutable.IndexedSeq -import com.google.common.io.Files - import org.scalatest.BeforeAndAfterAll import org.scalatest.FunSuite @@ -66,9 +64,7 @@ class WholeTextFileRecordReaderSuite extends FunSuite with BeforeAndAfterAll { * 3) Does the contents be the same. 
*/ test("Correctness of WholeTextFileRecordReader.") { - - val dir = Files.createTempDir() - dir.deleteOnExit() + val dir = Utils.createTempDir() println(s"Local disk address is ${dir.toString}.") WholeTextFileRecordReaderSuite.files.foreach { case (filename, contents) => diff --git a/core/src/test/scala/org/apache/spark/network/nio/ConnectionManagerSuite.scala b/core/src/test/scala/org/apache/spark/network/nio/ConnectionManagerSuite.scala index 9f49587cdc670..b70734dfe37cf 100644 --- a/core/src/test/scala/org/apache/spark/network/nio/ConnectionManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/network/nio/ConnectionManagerSuite.scala @@ -27,6 +27,7 @@ import scala.language.postfixOps import org.scalatest.FunSuite import org.apache.spark.{SecurityManager, SparkConf} +import org.apache.spark.util.Utils /** * Test the ConnectionManager with various security settings. @@ -236,7 +237,7 @@ class ConnectionManagerSuite extends FunSuite { val manager = new ConnectionManager(0, conf, securityManager) val managerServer = new ConnectionManager(0, conf, securityManager) managerServer.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { - throw new Exception + throw new Exception("Custom exception text") }) val size = 10 * 1024 * 1024 @@ -246,9 +247,10 @@ class ConnectionManagerSuite extends FunSuite { val future = manager.sendMessageReliably(managerServer.id, bufferMessage) - intercept[IOException] { + val exception = intercept[IOException] { Await.result(future, 1 second) } + assert(Utils.exceptionString(exception).contains("Custom exception text")) manager.stop() managerServer.stop() diff --git a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala index 75b01191901b8..3620e251cc139 100644 --- a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala @@ -24,13 +24,14 @@ import org.apache.hadoop.util.Progressable import scala.collection.mutable.{ArrayBuffer, HashSet} import scala.util.Random -import com.google.common.io.Files import org.apache.hadoop.conf.{Configurable, Configuration} import org.apache.hadoop.mapreduce.{JobContext => NewJobContext, OutputCommitter => NewOutputCommitter, OutputFormat => NewOutputFormat, RecordWriter => NewRecordWriter, TaskAttemptContext => NewTaskAttempContext} import org.apache.spark.{Partitioner, SharedSparkContext} import org.apache.spark.SparkContext._ +import org.apache.spark.util.Utils + import org.scalatest.FunSuite class PairRDDFunctionsSuite extends FunSuite with SharedSparkContext { @@ -381,14 +382,16 @@ class PairRDDFunctionsSuite extends FunSuite with SharedSparkContext { } test("zero-partition RDD") { - val emptyDir = Files.createTempDir() - emptyDir.deleteOnExit() - val file = sc.textFile(emptyDir.getAbsolutePath) - assert(file.partitions.size == 0) - assert(file.collect().toList === Nil) - // Test that a shuffle on the file works, because this used to be a bug - assert(file.map(line => (line, 1)).reduceByKey(_ + _).collect().toList === Nil) - emptyDir.delete() + val emptyDir = Utils.createTempDir() + try { + val file = sc.textFile(emptyDir.getAbsolutePath) + assert(file.partitions.isEmpty) + assert(file.collect().toList === Nil) + // Test that a shuffle on the file works, because this used to be a bug + assert(file.map(line => (line, 1)).reduceByKey(_ + _).collect().toList === Nil) + } finally { + Utils.deleteRecursively(emptyDir) + } } test("keys and values") 
{ diff --git a/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala index be972c5e97a7e..271a90c6646bb 100644 --- a/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala @@ -174,7 +174,7 @@ class PipedRDDSuite extends FunSuite with SharedSparkContext { } val hadoopPart1 = generateFakeHadoopPartition() val pipedRdd = new PipedRDD(nums, "printenv " + varName) - val tContext = new TaskContext(0, 0, 0) + val tContext = new TaskContextImpl(0, 0, 0) val rddIter = pipedRdd.compute(hadoopPart1, tContext) val arr = rddIter.toArray assert(arr(0) == "/some/path") diff --git a/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala index 3efa85431876b..abc300fcffaf9 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala @@ -20,7 +20,6 @@ package org.apache.spark.scheduler import scala.collection.mutable import scala.io.Source -import com.google.common.io.Files import org.apache.hadoop.fs.{FileStatus, Path} import org.json4s.jackson.JsonMethods._ import org.scalatest.{BeforeAndAfter, FunSuite} @@ -51,8 +50,7 @@ class EventLoggingListenerSuite extends FunSuite with BeforeAndAfter { private var logDirPath: Path = _ before { - testDir = Files.createTempDir() - testDir.deleteOnExit() + testDir = Utils.createTempDir() logDirPath = Utils.getFilePath(testDir, "spark-events") } diff --git a/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala index 48114feee6233..e05f373392d4a 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala @@ -19,7 +19,6 @@ package org.apache.spark.scheduler import java.io.{File, PrintWriter} -import com.google.common.io.Files import org.json4s.jackson.JsonMethods._ import org.scalatest.{BeforeAndAfter, FunSuite} @@ -39,8 +38,7 @@ class ReplayListenerSuite extends FunSuite with BeforeAndAfter { private var testDir: File = _ before { - testDir = Files.createTempDir() - testDir.deleteOnExit() + testDir = Utils.createTempDir() } after { diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala index faba5508c906c..561a5e9cd90c4 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala @@ -51,7 +51,7 @@ class TaskContextSuite extends FunSuite with BeforeAndAfter with LocalSparkConte } test("all TaskCompletionListeners should be called even if some fail") { - val context = new TaskContext(0, 0, 0) + val context = new TaskContextImpl(0, 0, 0) val listener = mock(classOf[TaskCompletionListener]) context.addTaskCompletionListener(_ => throw new Exception("blah")) context.addTaskCompletionListener(listener) diff --git a/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala index e4522e00a622d..bc5c74c126b74 100644 --- a/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala +++ 
b/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala @@ -19,22 +19,13 @@ package org.apache.spark.storage import java.io.{File, FileWriter} -import org.apache.spark.network.nio.NioBlockTransferService -import org.apache.spark.shuffle.hash.HashShuffleManager - -import scala.collection.mutable import scala.language.reflectiveCalls -import akka.actor.Props -import com.google.common.io.Files import org.mockito.Mockito.{mock, when} import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach, FunSuite} import org.apache.spark.SparkConf -import org.apache.spark.scheduler.LiveListenerBus -import org.apache.spark.serializer.JavaSerializer -import org.apache.spark.util.{AkkaUtils, Utils} -import org.apache.spark.executor.ShuffleWriteMetrics +import org.apache.spark.util.Utils class DiskBlockManagerSuite extends FunSuite with BeforeAndAfterEach with BeforeAndAfterAll { private val testConf = new SparkConf(false) @@ -48,10 +39,8 @@ class DiskBlockManagerSuite extends FunSuite with BeforeAndAfterEach with Before override def beforeAll() { super.beforeAll() - rootDir0 = Files.createTempDir() - rootDir0.deleteOnExit() - rootDir1 = Files.createTempDir() - rootDir1.deleteOnExit() + rootDir0 = Utils.createTempDir() + rootDir1 = Utils.createTempDir() rootDirs = rootDir0.getAbsolutePath + "," + rootDir1.getAbsolutePath } diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala index 809bd70929656..a8c049d749015 100644 --- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.storage -import org.apache.spark.TaskContext +import org.apache.spark.{TaskContextImpl, TaskContext} import org.apache.spark.network.{BlockFetchingListener, BlockTransferService} import org.mockito.Mockito._ @@ -62,7 +62,7 @@ class ShuffleBlockFetcherIteratorSuite extends FunSuite { ) val iterator = new ShuffleBlockFetcherIterator( - new TaskContext(0, 0, 0), + new TaskContextImpl(0, 0, 0), transfer, blockManager, blocksByAddress, @@ -120,7 +120,7 @@ class ShuffleBlockFetcherIteratorSuite extends FunSuite { ) val iterator = new ShuffleBlockFetcherIterator( - new TaskContext(0, 0, 0), + new TaskContextImpl(0, 0, 0), transfer, blockManager, blocksByAddress, @@ -169,7 +169,7 @@ class ShuffleBlockFetcherIteratorSuite extends FunSuite { (bmId, Seq((blId1, 1L), (blId2, 1L)))) val iterator = new ShuffleBlockFetcherIterator( - new TaskContext(0, 0, 0), + new TaskContextImpl(0, 0, 0), transfer, blockManager, blocksByAddress, diff --git a/core/src/test/scala/org/apache/spark/util/FileLoggerSuite.scala b/core/src/test/scala/org/apache/spark/util/FileLoggerSuite.scala index c3dd156b40514..72466a3aa1130 100644 --- a/core/src/test/scala/org/apache/spark/util/FileLoggerSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/FileLoggerSuite.scala @@ -21,7 +21,6 @@ import java.io.{File, IOException} import scala.io.Source -import com.google.common.io.Files import org.apache.hadoop.fs.Path import org.scalatest.{BeforeAndAfter, FunSuite} @@ -44,7 +43,7 @@ class FileLoggerSuite extends FunSuite with BeforeAndAfter { private var logDirPathString: String = _ before { - testDir = Files.createTempDir() + testDir = Utils.createTempDir() logDirPath = Utils.getFilePath(testDir, "test-file-logger") logDirPathString = logDirPath.toString } @@ -75,13 +74,13 
@@ class FileLoggerSuite extends FunSuite with BeforeAndAfter { test("Logging when directory already exists") { // Create the logging directory multiple times - new FileLogger(logDirPathString, new SparkConf, overwrite = true).start() - new FileLogger(logDirPathString, new SparkConf, overwrite = true).start() - new FileLogger(logDirPathString, new SparkConf, overwrite = true).start() + new FileLogger(logDirPathString, new SparkConf, compress = false, overwrite = true).start() + new FileLogger(logDirPathString, new SparkConf, compress = false, overwrite = true).start() + new FileLogger(logDirPathString, new SparkConf, compress = false, overwrite = true).start() // If overwrite is not enabled, an exception should be thrown intercept[IOException] { - new FileLogger(logDirPathString, new SparkConf, overwrite = false).start() + new FileLogger(logDirPathString, new SparkConf, compress = false, overwrite = false).start() } } diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala index e63d9d085e385..ea7ef0524d1e1 100644 --- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala @@ -27,6 +27,8 @@ import com.google.common.base.Charsets import com.google.common.io.Files import org.scalatest.FunSuite +import org.apache.spark.SparkConf + class UtilsSuite extends FunSuite { test("bytesToString") { @@ -112,7 +114,7 @@ class UtilsSuite extends FunSuite { } test("reading offset bytes of a file") { - val tmpDir2 = Files.createTempDir() + val tmpDir2 = Utils.createTempDir() tmpDir2.deleteOnExit() val f1Path = tmpDir2 + "/f1" val f1 = new FileOutputStream(f1Path) @@ -141,7 +143,7 @@ class UtilsSuite extends FunSuite { } test("reading offset bytes across multiple files") { - val tmpDir = Files.createTempDir() + val tmpDir = Utils.createTempDir() tmpDir.deleteOnExit() val files = (1 to 3).map(i => new File(tmpDir, i.toString)) Files.write("0123456789", files(0), Charsets.UTF_8) @@ -308,4 +310,45 @@ class UtilsSuite extends FunSuite { } } + test("deleteRecursively") { + val tempDir1 = Utils.createTempDir() + assert(tempDir1.exists()) + Utils.deleteRecursively(tempDir1) + assert(!tempDir1.exists()) + + val tempDir2 = Utils.createTempDir() + val tempFile1 = new File(tempDir2, "foo.txt") + Files.touch(tempFile1) + assert(tempFile1.exists()) + Utils.deleteRecursively(tempFile1) + assert(!tempFile1.exists()) + + val tempDir3 = new File(tempDir2, "subdir") + assert(tempDir3.mkdir()) + val tempFile2 = new File(tempDir3, "bar.txt") + Files.touch(tempFile2) + assert(tempFile2.exists()) + Utils.deleteRecursively(tempDir2) + assert(!tempDir2.exists()) + assert(!tempDir3.exists()) + assert(!tempFile2.exists()) + } + + test("loading properties from file") { + val outFile = File.createTempFile("test-load-spark-properties", "test") + try { + System.setProperty("spark.test.fileNameLoadB", "2") + Files.write("spark.test.fileNameLoadA true\n" + + "spark.test.fileNameLoadB 1\n", outFile, Charsets.UTF_8) + val properties = Utils.getPropertiesFromFile(outFile.getAbsolutePath) + properties + .filter { case (k, v) => k.startsWith("spark.")} + .foreach { case (k, v) => sys.props.getOrElseUpdate(k, v)} + val sparkConf = new SparkConf + assert(sparkConf.getBoolean("spark.test.fileNameLoadA", false) === true) + assert(sparkConf.getInt("spark.test.fileNameLoadB", 1) === 2) + } finally { + outFile.delete() + } + } } diff --git a/dev/run-tests b/dev/run-tests index 4be2baaf48cd1..f47fcf66ff7e7 
100755 --- a/dev/run-tests +++ b/dev/run-tests @@ -42,7 +42,7 @@ function handle_error () { elif [ "$AMPLAB_JENKINS_BUILD_PROFILE" = "hadoop2.0" ]; then export SBT_MAVEN_PROFILES_ARGS="-Dhadoop.version=2.0.0-mr1-cdh4.1.1" elif [ "$AMPLAB_JENKINS_BUILD_PROFILE" = "hadoop2.2" ]; then - export SBT_MAVEN_PROFILES_ARGS="-Pyarn -Dhadoop.version=2.2.0" + export SBT_MAVEN_PROFILES_ARGS="-Pyarn -Phadoop-2.2 -Dhadoop.version=2.2.0" elif [ "$AMPLAB_JENKINS_BUILD_PROFILE" = "hadoop2.3" ]; then export SBT_MAVEN_PROFILES_ARGS="-Pyarn -Phadoop-2.3 -Dhadoop.version=2.3.0" fi diff --git a/dev/scalastyle b/dev/scalastyle index efb5f291ea3b7..c3b356bcb3c06 100755 --- a/dev/scalastyle +++ b/dev/scalastyle @@ -26,6 +26,8 @@ echo -e "q\n" | sbt/sbt -Pyarn -Phadoop-2.2 -Dhadoop.version=2.2.0 yarn/scalasty >> scalastyle.txt ERRORS=$(cat scalastyle.txt | grep -e "\") +rm scalastyle.txt + if test ! -z "$ERRORS"; then echo -e "Scalastyle checks failed at following occurrences:\n$ERRORS" exit 1 diff --git a/docs/README.md b/docs/README.md index 79708c3df9106..d2d58e435d4c4 100644 --- a/docs/README.md +++ b/docs/README.md @@ -25,8 +25,7 @@ installing via the Ruby Gem dependency manager. Since the exact HTML output varies between versions of Jekyll and its dependencies, we list specific versions here in some cases: - $ sudo gem install jekyll -v 1.4.3 - $ sudo gem uninstall kramdown -v 1.4.1 + $ sudo gem install jekyll $ sudo gem install jekyll-redirect-from Execute `jekyll` from the `docs/` directory. Compiling the site with Jekyll will create a directory @@ -54,19 +53,19 @@ phase, use the following sytax: // supported languages too. {% endhighlight %} -## API Docs (Scaladoc and Epydoc) +## API Docs (Scaladoc and Sphinx) You can build just the Spark scaladoc by running `sbt/sbt doc` from the SPARK_PROJECT_ROOT directory. -Similarly, you can build just the PySpark epydoc by running `epydoc --config epydoc.conf` from the -SPARK_PROJECT_ROOT/pyspark directory. Documentation is only generated for classes that are listed as +Similarly, you can build just the PySpark docs by running `make html` from the +SPARK_PROJECT_ROOT/python/docs directory. Documentation is only generated for classes that are listed as public in `__init__.py`. When you run `jekyll` in the `docs` directory, it will also copy over the scaladoc for the various Spark subprojects into the `docs` directory (and then also into the `_site` directory). We use a jekyll plugin to run `sbt/sbt doc` before building the site so if you haven't run it (recently) it may take some time as it generates all of the scaladoc. The jekyll plugin also generates the -PySpark docs using [epydoc](http://epydoc.sourceforge.net/). +PySpark docs [Sphinx](http://sphinx-doc.org/). NOTE: To skip the step of building and copying over the Scala and Python API docs, run `SKIP_API=1 jekyll`. diff --git a/docs/_config.yml b/docs/_config.yml index 7bc3a78e2d265..f4bf242ac191b 100644 --- a/docs/_config.yml +++ b/docs/_config.yml @@ -8,6 +8,9 @@ gems: kramdown: entity_output: numeric +include: + - _static + # These allow the documentation to be updated with nerw releases # of Spark, Scala, and Mesos. SPARK_VERSION: 1.0.0-SNAPSHOT diff --git a/docs/_plugins/copy_api_dirs.rb b/docs/_plugins/copy_api_dirs.rb index 3b02e090aec28..4566a2fff562b 100644 --- a/docs/_plugins/copy_api_dirs.rb +++ b/docs/_plugins/copy_api_dirs.rb @@ -63,19 +63,20 @@ puts "cp -r " + source + "/. " + dest cp_r(source + "/.", dest) - # Build Epydoc for Python - puts "Moving to python directory and building epydoc." 
- cd("../python") - puts `epydoc --config epydoc.conf` + # Build Sphinx docs for Python - puts "Moving back into docs dir." - cd("../docs") + puts "Moving to python/docs directory and building sphinx." + cd("../python/docs") + puts `make html` + + puts "Moving back into home dir." + cd("../../") puts "Making directory api/python" - mkdir_p "api/python" + mkdir_p "docs/api/python" - puts "cp -r ../python/docs/. api/python" - cp_r("../python/docs/.", "api/python") + puts "cp -r python/docs/_build/html/. docs/api/python" + cp_r("python/docs/_build/html/.", "docs/api/python") cd("..") end diff --git a/docs/configuration.md b/docs/configuration.md index 1c33855365170..96fa1377ec399 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -103,6 +103,14 @@ of the most common options to set are: (e.g. 512m, 2g). + + spark.driver.memory + 512m + + Amount of memory to use for the driver process, i.e. where SparkContext is initialized. + (e.g. 512m, 2g). + + spark.serializer org.apache.spark.serializer.
    JavaSerializer @@ -153,14 +161,6 @@ Apart from these, the following properties are also available, and may be useful #### Runtime Environment - - - - - @@ -357,7 +357,7 @@ Apart from these, the following properties are also available, and may be useful @@ -619,6 +619,15 @@ Apart from these, the following properties are also available, and may be useful output directories. We recommend that users do not disable this except if trying to achieve compatibility with previous versions of Spark. Simply use Hadoop's FileSystem API to delete output directories by hand. + + + + + @@ -717,7 +726,7 @@ Apart from these, the following properties are also available, and may be useful - + @@ -885,7 +894,7 @@ Apart from these, the following properties are also available, and may be useful to wait for before scheduling begins. Specified as a double between 0 and 1. Regardless of whether the minimum ratio of resources has been reached, the maximum amount of time it will wait before scheduling begins is controlled by config - spark.scheduler.maxRegisteredResourcesWaitingTime + spark.scheduler.maxRegisteredResourcesWaitingTime. diff --git a/docs/mllib-clustering.md b/docs/mllib-clustering.md index d10bd63746629..7978e934fb36b 100644 --- a/docs/mllib-clustering.md +++ b/docs/mllib-clustering.md @@ -69,7 +69,7 @@ println("Within Set Sum of Squared Errors = " + WSSSE) All of MLlib's methods use Java-friendly types, so you can import and call them there the same way you do in Scala. The only caveat is that the methods take Scala RDD objects, while the Spark Java API uses a separate `JavaRDD` class. You can convert a Java RDD to a Scala one by -calling `.rdd()` on your `JavaRDD` object. A standalone application example +calling `.rdd()` on your `JavaRDD` object. A self-contained application example that is equivalent to the provided example in Scala is given below: {% highlight java %} @@ -113,12 +113,6 @@ public class KMeansExample { } } {% endhighlight %} - -In order to run the above standalone application, follow the instructions -provided in the [Standalone -Applications](quick-start.html#standalone-applications) section of the Spark -quick-start guide. Be sure to also include *spark-mllib* to your build file as -a dependency.
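As a point of reference for the self-contained MLlib application the clustering guide mentions above, a minimal Scala sketch might look like the following; the input path, cluster count, and iteration count are placeholders rather than anything specified in this patch:

{% highlight scala %}
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.clustering.KMeans
import org.apache.spark.mllib.linalg.Vectors

object KMeansApp {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("KMeansApp"))
    // One whitespace-separated feature vector per line (placeholder path).
    val parsedData = sc.textFile("data/mllib/kmeans_data.txt")
      .map(line => Vectors.dense(line.split(' ').map(_.toDouble)))
      .cache()
    // Train with k = 2 clusters and 20 iterations, then report the clustering cost.
    val clusters = KMeans.train(parsedData, 2, 20)
    println("Within Set Sum of Squared Errors = " + clusters.computeCost(parsedData))
    sc.stop()
  }
}
{% endhighlight %}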
    @@ -153,3 +147,9 @@ print("Within Set Sum of Squared Error = " + str(WSSSE))
    + +In order to run the above application, follow the instructions +provided in the [Self-Contained Applications](quick-start.html#self-contained-applications) +section of the Spark +Quick Start guide. Be sure to also include *spark-mllib* to your build file as +a dependency. diff --git a/docs/mllib-collaborative-filtering.md b/docs/mllib-collaborative-filtering.md index d5c539db791be..2094963392295 100644 --- a/docs/mllib-collaborative-filtering.md +++ b/docs/mllib-collaborative-filtering.md @@ -110,7 +110,7 @@ val model = ALS.trainImplicit(ratings, rank, numIterations, alpha) All of MLlib's methods use Java-friendly types, so you can import and call them there the same way you do in Scala. The only caveat is that the methods take Scala RDD objects, while the Spark Java API uses a separate `JavaRDD` class. You can convert a Java RDD to a Scala one by -calling `.rdd()` on your `JavaRDD` object. A standalone application example +calling `.rdd()` on your `JavaRDD` object. A self-contained application example that is equivalent to the provided example in Scala is given bellow: {% highlight java %} @@ -184,12 +184,6 @@ public class CollaborativeFiltering { } } {% endhighlight %} - -In order to run the above standalone application, follow the instructions -provided in the [Standalone -Applications](quick-start.html#standalone-applications) section of the Spark -quick-start guide. Be sure to also include *spark-mllib* to your build file as -a dependency.
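For the collaborative-filtering example discussed above, the Scala form is roughly the sketch below; the input file, rank, iteration count, and regularization value are illustrative assumptions:

{% highlight scala %}
import org.apache.spark.SparkContext
import org.apache.spark.mllib.recommendation.{ALS, Rating}

// Sketch only: assumes an existing SparkContext `sc` and an input file of
// "user,product,rating" lines (placeholder path).
def trainRecommender(sc: SparkContext): Unit = {
  val ratings = sc.textFile("data/mllib/als/test.data").map { line =>
    val Array(user, product, rating) = line.split(',')
    Rating(user.toInt, product.toInt, rating.toDouble)
  }
  // rank = 10 latent factors, 20 iterations, regularization parameter 0.01
  val model = ALS.train(ratings, 10, 20, 0.01)
  // Predict the rating of product 2 by user 1.
  println(model.predict(1, 2))
}
{% endhighlight %}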
    @@ -229,6 +223,12 @@ model = ALS.trainImplicit(ratings, rank, numIterations, alpha = 0.01)
    +In order to run the above application, follow the instructions +provided in the [Self-Contained Applications](quick-start.html#self-contained-applications) +section of the Spark +Quick Start guide. Be sure to also include *spark-mllib* to your build file as +a dependency. + ## Tutorial The [training exercises](https://databricks-training.s3.amazonaws.com/index.html) from the Spark Summit 2014 include a hands-on tutorial for diff --git a/docs/mllib-dimensionality-reduction.md b/docs/mllib-dimensionality-reduction.md index 21cb35b4270ca..870fed6cc5024 100644 --- a/docs/mllib-dimensionality-reduction.md +++ b/docs/mllib-dimensionality-reduction.md @@ -121,9 +121,9 @@ public class SVD { The same code applies to `IndexedRowMatrix` if `U` is defined as an `IndexedRowMatrix`. -In order to run the above standalone application, follow the instructions -provided in the [Standalone -Applications](quick-start.html#standalone-applications) section of the Spark +In order to run the above application, follow the instructions +provided in the [Self-Contained +Applications](quick-start.html#self-contained-applications) section of the Spark quick-start guide. Be sure to also include *spark-mllib* to your build file as a dependency. @@ -200,10 +200,11 @@ public class PCA { } {% endhighlight %} -In order to run the above standalone application, follow the instructions -provided in the [Standalone -Applications](quick-start.html#standalone-applications) section of the Spark -quick-start guide. Be sure to also include *spark-mllib* to your build file as -a dependency. + +In order to run the above application, follow the instructions +provided in the [Self-Contained Applications](quick-start.html#self-contained-applications) +section of the Spark +quick-start guide. Be sure to also include *spark-mllib* to your build file as +a dependency. diff --git a/docs/mllib-linear-methods.md b/docs/mllib-linear-methods.md index d31bec3e1bd01..bc914a1899801 100644 --- a/docs/mllib-linear-methods.md +++ b/docs/mllib-linear-methods.md @@ -247,7 +247,7 @@ val modelL1 = svmAlg.run(training) All of MLlib's methods use Java-friendly types, so you can import and call them there the same way you do in Scala. The only caveat is that the methods take Scala RDD objects, while the Spark Java API uses a separate `JavaRDD` class. You can convert a Java RDD to a Scala one by -calling `.rdd()` on your `JavaRDD` object. A standalone application example +calling `.rdd()` on your `JavaRDD` object. A self-contained application example that is equivalent to the provided example in Scala is given bellow: {% highlight java %} @@ -323,9 +323,9 @@ svmAlg.optimizer() final SVMModel modelL1 = svmAlg.run(training.rdd()); {% endhighlight %} -In order to run the above standalone application, follow the instructions -provided in the [Standalone -Applications](quick-start.html#standalone-applications) section of the Spark +In order to run the above application, follow the instructions +provided in the [Self-Contained +Applications](quick-start.html#self-contained-applications) section of the Spark quick-start guide. Be sure to also include *spark-mllib* to your build file as a dependency. @@ -482,12 +482,6 @@ public class LinearRegression { } } {% endhighlight %} - -In order to run the above standalone application, follow the instructions -provided in the [Standalone -Applications](quick-start.html#standalone-applications) section of the Spark -quick-start guide. Be sure to also include *spark-mllib* to your build file as -a dependency.
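Similarly, the linear-regression workflow the guide describes can be sketched in Scala as follows; the data path and iteration count are placeholders:

{% highlight scala %}
import org.apache.spark.SparkContext
import org.apache.spark.SparkContext._
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.{LabeledPoint, LinearRegressionWithSGD}

// Sketch only: assumes an existing SparkContext `sc` and input lines of the
// form "label,feature1 feature2 ..." (placeholder path).
def trainLinearModel(sc: SparkContext): Unit = {
  val parsedData = sc.textFile("data/mllib/ridge-data/lpsa.data").map { line =>
    val parts = line.split(',')
    LabeledPoint(parts(0).toDouble, Vectors.dense(parts(1).split(' ').map(_.toDouble)))
  }.cache()
  val model = LinearRegressionWithSGD.train(parsedData, 100)
  // Mean squared error over the training set.
  val mse = parsedData.map { p =>
    val err = p.label - model.predict(p.features)
    err * err
  }.mean()
  println("Mean Squared Error = " + mse)
}
{% endhighlight %}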
    @@ -519,6 +513,12 @@ print("Mean Squared Error = " + str(MSE))
    +In order to run the above application, follow the instructions +provided in the [Self-Contained Applications](quick-start.html#self-contained-applications) +section of the Spark +quick-start guide. Be sure to also include *spark-mllib* to your build file as +a dependency. + ## Streaming linear regression When data arrive in a streaming fashion, it is useful to fit regression models online, diff --git a/docs/monitoring.md b/docs/monitoring.md index d07ec4a57a2cc..e3f81a76acdbb 100644 --- a/docs/monitoring.md +++ b/docs/monitoring.md @@ -77,6 +77,13 @@ follows: one implementation, provided by Spark, which looks for application logs stored in the file system. + + + + + diff --git a/docs/programming-guide.md b/docs/programming-guide.md index 8e8cc1dd983f8..18420afb27e3c 100644 --- a/docs/programming-guide.md +++ b/docs/programming-guide.md @@ -211,17 +211,17 @@ For a complete list of options, run `pyspark --help`. Behind the scenes, It is also possible to launch the PySpark shell in [IPython](http://ipython.org), the enhanced Python interpreter. PySpark works with IPython 1.0.0 and later. To -use IPython, set the `PYSPARK_PYTHON` variable to `ipython` when running `bin/pyspark`: +use IPython, set the `PYSPARK_DRIVER_PYTHON` variable to `ipython` when running `bin/pyspark`: {% highlight bash %} -$ PYSPARK_PYTHON=ipython ./bin/pyspark +$ PYSPARK_DRIVER_PYTHON=ipython ./bin/pyspark {% endhighlight %} -You can customize the `ipython` command by setting `PYSPARK_PYTHON_OPTS`. For example, to launch +You can customize the `ipython` command by setting `PYSPARK_DRIVER_PYTHON_OPTS`. For example, to launch the [IPython Notebook](http://ipython.org/notebook.html) with PyLab plot support: {% highlight bash %} -$ PYSPARK_PYTHON=ipython PYSPARK_PYTHON_OPTS="notebook --pylab inline" ./bin/pyspark +$ PYSPARK_DRIVER_PYTHON=ipython PYSPARK_DRIVER_PYTHON_OPTS="notebook --pylab inline" ./bin/pyspark {% endhighlight %} diff --git a/docs/quick-start.md b/docs/quick-start.md index 23313d8aa6152..6236de0e1f2c4 100644 --- a/docs/quick-start.md +++ b/docs/quick-start.md @@ -8,7 +8,7 @@ title: Quick Start This tutorial provides a quick introduction to using Spark. We will first introduce the API through Spark's interactive shell (in Python or Scala), -then show how to write standalone applications in Java, Scala, and Python. +then show how to write applications in Java, Scala, and Python. See the [programming guide](programming-guide.html) for a more complete reference. To follow along with this guide, first download a packaged release of Spark from the @@ -215,8 +215,8 @@ a cluster, as described in the [programming guide](programming-guide.html#initia -# Standalone Applications -Now say we wanted to write a standalone application using the Spark API. We will walk through a +# Self-Contained Applications +Now say we wanted to write a self-contained application using the Spark API. We will walk through a simple application in both Scala (with SBT), Java (with Maven), and Python.
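The Scala variant of the simple application that the Quick Start walks through amounts to something like this sketch; the file path is a placeholder, and the guide itself builds the project with SBT:

{% highlight scala %}
import org.apache.spark.{SparkConf, SparkContext}

object SimpleApp {
  def main(args: Array[String]): Unit = {
    val logFile = "README.md"  // placeholder: any text file reachable by the application
    val sc = new SparkContext(new SparkConf().setAppName("Simple Application"))
    val logData = sc.textFile(logFile).cache()
    val numAs = logData.filter(_.contains("a")).count()
    val numBs = logData.filter(_.contains("b")).count()
    println(s"Lines with a: $numAs, Lines with b: $numBs")
    sc.stop()
  }
}
{% endhighlight %}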
    @@ -387,7 +387,7 @@ Lines with a: 46, Lines with b: 23
    -Now we will show how to write a standalone application using the Python API (PySpark). +Now we will show how to write an application using the Python API (PySpark). As an example, we'll create a simple Spark application, `SimpleApp.py`: diff --git a/docs/streaming-programming-guide.md b/docs/streaming-programming-guide.md index 5c21e912ea160..8bbba88b31978 100644 --- a/docs/streaming-programming-guide.md +++ b/docs/streaming-programming-guide.md @@ -212,6 +212,67 @@ The complete code can be found in the Spark Streaming example [JavaNetworkWordCount]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/java/org/apache/spark/examples/streaming/JavaNetworkWordCount.java).
    +
    +
    +First, we import StreamingContext, which is the main entry point for all streaming functionality. We create a local StreamingContext with two execution threads, and batch interval of 1 second. + +{% highlight python %} +from pyspark import SparkContext +from pyspark.streaming import StreamingContext + +# Create a local StreamingContext with two working thread and batch interval of 1 second +sc = SparkContext("local[2]", "NetworkWordCount") +ssc = StreamingContext(sc, 1) +{% endhighlight %} + +Using this context, we can create a DStream that represents streaming data from a TCP +source hostname, e.g. `localhost`, and port, e.g. `9999` + +{% highlight python %} +# Create a DStream that will connect to hostname:port, like localhost:9999 +lines = ssc.socketTextStream("localhost", 9999) +{% endhighlight %} + +This `lines` DStream represents the stream of data that will be received from the data +server. Each record in this DStream is a line of text. Next, we want to split the lines by +space into words. + +{% highlight python %} +# Split each line into words +words = lines.flatMap(lambda line: line.split(" ")) +{% endhighlight %} + +`flatMap` is a one-to-many DStream operation that creates a new DStream by +generating multiple new records from each record in the source DStream. In this case, +each line will be split into multiple words and the stream of words is represented as the +`words` DStream. Next, we want to count these words. + +{% highlight python %} +# Count each word in each batch +pairs = words.map(lambda word: (word, 1)) +wordCounts = pairs.reduceByKey(lambda x, y: x + y) + +# Print the first ten elements of each RDD generated in this DStream to the console +wordCounts.pprint() +{% endhighlight %} + +The `words` DStream is further mapped (one-to-one transformation) to a DStream of `(word, +1)` pairs, which is then reduced to get the frequency of words in each batch of data. +Finally, `wordCounts.pprint()` will print a few of the counts generated every second. + +Note that when these lines are executed, Spark Streaming only sets up the computation it +will perform when it is started, and no real processing has started yet. To start the processing +after all the transformations have been setup, we finally call + +{% highlight python %} +ssc.start() # Start the computation +ssc.awaitTermination() # Wait for the computation to terminate +{% endhighlight %} + +The complete code can be found in the Spark Streaming example +[NetworkWordCount]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/python/streaming/network_wordcount.py). +
    +
    @@ -236,6 +297,11 @@ $ ./bin/run-example streaming.NetworkWordCount localhost 9999 $ ./bin/run-example streaming.JavaNetworkWordCount localhost 9999 {% endhighlight %} +
    +{% highlight bash %} +$ ./bin/spark-submit examples/src/main/python/streaming/network_wordcount.py localhost 9999 +{% endhighlight %} +
    @@ -259,8 +325,11 @@ hello world
    Property NameDefaultMeaning
    spark.executor.memory512m - Amount of memory to use per executor process, in the same format as JVM memory strings - (e.g. 512m, 2g). -
    spark.executor.extraJavaOptions (none)spark.ui.port 4040 - Port for your application's dashboard, which shows memory and workload data + Port for your application's dashboard, which shows memory and workload data.
    spark.hadoop.cloneConffalseIf set to true, clones a new Hadoop Configuration object for each task. This + option should be enabled to work around Configuration thread-safety issues (see + SPARK-2546 for more details). + This is disabled by default in order to avoid unexpected performance regressions for jobs that + are not affected by these issues.
    spark.executor.heartbeatInterval 10000
    spark.akka.heartbeat.pauses6006000 This is set to a larger value to disable failure detector that comes inbuilt akka. It can be enabled again, if you plan to use this feature (Not recommended). Acceptable heart beat pause @@ -872,8 +881,8 @@ Apart from these, the following properties are also available, and may be useful spark.scheduler.revive.interval 1000 - The interval length for the scheduler to revive the worker resource offers to run tasks. - (in milliseconds) + The interval length for the scheduler to revive the worker resource offers to run tasks + (in milliseconds).
    spark.history.fs.logDirectory(none) + Directory that contains application event logs to be loaded by the history server +
    spark.history.fs.updateInterval 10 +
    + +
    {% highlight bash %} -# TERMINAL 2: RUNNING NetworkWordCount or JavaNetworkWordCount +# TERMINAL 2: RUNNING NetworkWordCount $ ./bin/run-example streaming.NetworkWordCount localhost 9999 ... @@ -271,6 +340,37 @@ Time: 1357008430000 ms (world,1) ... {% endhighlight %} +
    + +
    +{% highlight bash %} +# TERMINAL 2: RUNNING JavaNetworkWordCount + +$ ./bin/run-example streaming.JavaNetworkWordCount localhost 9999 +... +------------------------------------------- +Time: 1357008430000 ms +------------------------------------------- +(hello,1) +(world,1) +... +{% endhighlight %} +
    +
    +{% highlight bash %} +# TERMINAL 2: RUNNING network_wordcount.py + +$ ./bin/spark-submit examples/src/main/python/streaming/network_wordcount.py localhost 9999 +... +------------------------------------------- +Time: 2014-10-14 15:25:21 +------------------------------------------- +(hello,1) +(world,1) +... +{% endhighlight %} +
    +
    @@ -398,9 +498,34 @@ JavaSparkContext sc = ... //existing JavaSparkContext JavaStreamingContext ssc = new JavaStreamingContext(sc, new Duration(1000)); {% endhighlight %}
    +
    + +A [StreamingContext](api/python/pyspark.streaming.html#pyspark.streaming.StreamingContext) object can be created from a [SparkContext](api/python/pyspark.html#pyspark.SparkContext) object. + +{% highlight python %} +from pyspark import SparkContext +from pyspark.streaming import StreamingContext + +sc = SparkContext(master, appName) +ssc = StreamingContext(sc, 1) +{% endhighlight %} + +The `appName` parameter is a name for your application to show on the cluster UI. +`master` is a [Spark, Mesos or YARN cluster URL](submitting-applications.html#master-urls), +or a special __"local[\*]"__ string to run in local mode. In practice, when running on a cluster, +you will not want to hardcode `master` in the program, +but rather [launch the application with `spark-submit`](submitting-applications.html) and +receive it there. However, for local testing and unit tests, you can pass "local[\*]" to run Spark Streaming +in-process (detects the number of cores in the local system). + +The batch interval must be set based on the latency requirements of your application +and available cluster resources. See the [Performance Tuning](#setting-the-right-batch-size) +section for more details. +
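A point worth noting for readers of the Python snippet above: unlike the Scala and Java APIs, where the batch interval is given as a `Duration`, the Python `StreamingContext` takes the interval as a plain number of seconds. A minimal local sketch (the app name and the 10-second interval are illustrative, not taken from this patch):

{% highlight python %}
from pyspark import SparkConf, SparkContext
from pyspark.streaming import StreamingContext

# Illustrative only: local mode with one worker thread per core and a
# 10-second batch interval (the second argument is in seconds).
conf = SparkConf().setMaster("local[*]").setAppName("BatchIntervalSketch")
sc = SparkContext(conf=conf)
ssc = StreamingContext(sc, 10)
{% endhighlight %}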
After a context is defined, you have to do the following steps. + 1. Define the input sources. 1. Set up the streaming computations. 1. Start the receiving and processing of data using `streamingContext.start()`. @@ -483,6 +608,9 @@ methods for creating DStreams from files and Akka actors as input sources.
    streamingContext.fileStream(dataDirectory);
    +
    + streamingContext.textFileStream(dataDirectory) +
    Spark Streaming will monitor the directory `dataDirectory` and process any files created in that directory (files written in nested directories not supported). Note that @@ -494,7 +622,7 @@ methods for creating DStreams from files and Akka actors as input sources. For simple text files, there is an easier method `streamingContext.textFileStream(dataDirectory)`. And file streams do not require running a receiver, hence does not require allocating cores. -- **Streams based on Custom Actors:** DStreams can be created with data streams received through Akka actors by using `streamingContext.actorStream(actorProps, actor-name)`. See the [Custom Receiver Guide](#implementing-and-using-a-custom-actor-based-receiver) for more details. +- **Streams based on Custom Actors:** DStreams can be created with data streams received through Akka actors by using `streamingContext.actorStream(actorProps, actor-name)`. See the [Custom Receiver Guide](streaming-custom-receivers.html#implementing-and-using-a-custom-actor-based-receiver) for more details. - **Queue of RDDs as a Stream:** For testing a Spark Streaming application with test data, one can also create a DStream based on a queue of RDDs, using `streamingContext.queueStream(queueOfRDDs)`. Each RDD pushed into the queue will be treated as a batch of data in the DStream, and processed like a stream. @@ -684,13 +812,30 @@ This is applied on a DStream containing words (say, the `pairs` DStream containi JavaPairDStream runningCounts = pairs.updateStateByKey(updateFunction); {% endhighlight %} + +
    + +{% highlight python %} +def updateFunction(newValues, runningCount): + if runningCount is None: + runningCount = 0 + return sum(newValues, runningCount) # add the new values with the previous running count to get the new count +{% endhighlight %} + +This is applied on a DStream containing words (say, the `pairs` DStream containing `(word, +1)` pairs in the [earlier example](#a-quick-example)). + +{% highlight python %} +runningCounts = pairs.updateStateByKey(updateFunction) +{% endhighlight %} +
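One caveat the hunk above does not spell out, so treat this as an editorial sketch rather than patch content: `updateStateByKey` maintains state across batches and therefore requires a checkpoint directory to be configured on the context before it is used. Assuming the `ssc` and `pairs` from the quick example:

{% highlight python %}
# The directory name is illustrative; any HDFS-compatible path works.
ssc.checkpoint("checkpoint")  # required before calling updateStateByKey

runningCounts = pairs.updateStateByKey(updateFunction)
{% endhighlight %}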
The update function will be called for each word, with `newValues` having a sequence of 1's (from the `(word, 1)` pairs) and the `runningCount` having the previous count. For the complete Python code, take a look at the example -[StatefulNetworkWordCount]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala). +[stateful_network_wordcount.py]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/python/streaming/stateful_network_wordcount.py). #### Transform Operation {:.no_toc} @@ -732,6 +877,15 @@ JavaPairDStream cleanedDStream = wordCounts.transform( }); {% endhighlight %} + +
    + +{% highlight python %} +spamInfoRDD = sc.pickleFile(...) # RDD containing spam information + +# join data stream with spam information to do data cleaning +cleanedDStream = wordCounts.transform(lambda rdd: rdd.join(spamInfoRDD).filter(...)) +{% endhighlight %}
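Since the snippet above elides the spam-filtering details, here is a hedged, invented variant of the same `transform` pattern; the blacklist contents and all names are illustrative, and it assumes the `sc` and the `wordCounts` DStream from the earlier example:

{% highlight python %}
# Hypothetical static RDD of words to filter out, computed once on the driver.
blacklistRDD = sc.parallelize([("spam", True), ("junk", True)])

def dropBlacklisted(rdd):
    # rdd holds (word, count) pairs for one batch; keep words that have
    # no matching entry in the blacklist.
    return (rdd.leftOuterJoin(blacklistRDD)
               .filter(lambda kv: kv[1][1] is None)
               .map(lambda kv: (kv[0], kv[1][0])))

cleanedCounts = wordCounts.transform(dropBlacklisted)
{% endhighlight %}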
    @@ -793,6 +947,14 @@ Function2 reduceFunc = new Function2 windowedWordCounts = pairs.reduceByKeyAndWindow(reduceFunc, new Duration(30000), new Duration(10000)); {% endhighlight %} + +
    + +{% highlight python %} +# Reduce last 30 seconds of data, every 10 seconds +windowedWordCounts = pairs.reduceByKeyAndWindow(lambda x, y: x + y, lambda x, y: x - y, 30, 10) +{% endhighlight %} +
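A note on the inverse-reduce form shown above, added here as an editorial sketch rather than patch content: because the window is maintained incrementally, checkpointing must be enabled, and the window length and slide interval are plain seconds in Python. Assuming the `ssc` and `pairs` from the quick example:

{% highlight python %}
# Checkpointing is required when an inverse reduce function is supplied;
# the directory name is illustrative.
ssc.checkpoint("checkpoint")

# Reduce the last 30 seconds of data every 10 seconds, adding values that
# enter the window and subtracting values that leave it.
windowedWordCounts = pairs.reduceByKeyAndWindow(
    lambda x, y: x + y,   # reduce function
    lambda x, y: x - y,   # inverse reduce function
    30,                   # window length in seconds
    10)                   # slide interval in seconds
{% endhighlight %}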
    @@ -860,6 +1022,7 @@ see [DStream](api/scala/index.html#org.apache.spark.streaming.dstream.DStream) and [PairDStreamFunctions](api/scala/index.html#org.apache.spark.streaming.dstream.PairDStreamFunctions). For the Java API, see [JavaDStream](api/java/index.html?org/apache/spark/streaming/api/java/JavaDStream.html) and [JavaPairDStream](api/java/index.html?org/apache/spark/streaming/api/java/JavaPairDStream.html). +For the Python API, see [DStream](api/python/pyspark.streaming.html#pyspark.streaming.DStream) *** @@ -872,9 +1035,12 @@ Currently, the following output operations are defined: - + + This is useful for development and debugging. +
+ (Called pprint() in Python.) +
    +
    + +{% highlight scala %} dstream.foreachRDD(rdd => { val connection = createNewConnection() // executed at the driver rdd.foreach(record => { connection.send(record) // executed at the worker }) }) +{% endhighlight %} + +
    +
    + +{% highlight python %} +def sendRecord(rdd): + connection = createNewConnection() # executed at the driver + rdd.foreach(lambda record: connection.send(record)) + connection.close() + +dstream.foreachRDD(sendRecord) +{% endhighlight %} + +
    +
    - This is incorrect as this requires the connection object to be serialized and sent from the driver to the worker. Such connection objects are rarely transferrable across machines. This error may manifest as serialization errors (connection object not serializable), initialization errors (connection object needs to be initialized at the workers), etc. The correct solution is to create the connection object at the worker. + This is incorrect as this requires the connection object to be serialized and sent from the driver to the worker. Such connection objects are rarely transferrable across machines. This error may manifest as serialization errors (connection object not serializable), initialization errors (connection object needs to be initialized at the workers), etc. The correct solution is to create the connection object at the worker. - However, this can lead to another common mistake - creating a new connection for every record. For example, +
    +
    + +{% highlight scala %} dstream.foreachRDD(rdd => { rdd.foreach(record => { val connection = createNewConnection() @@ -933,9 +1123,28 @@ For example (in Scala), connection.close() }) }) +{% endhighlight %} - Typically, creating a connection object has time and resource overheads. Therefore, creating and destroying a connection object for each record can incur unnecessarily high overheads and can significantly reduce the overall throughput of the system. A better solution is to use `rdd.foreachPartition` - create a single connection object and send all the records in a RDD partition using that connection. +
    +
    + +{% highlight python %} +def sendRecord(record): + connection = createNewConnection() + connection.send(record) + connection.close() + +dstream.foreachRDD(lambda rdd: rdd.foreach(sendRecord)) +{% endhighlight %} +
    +
    + + Typically, creating a connection object has time and resource overheads. Therefore, creating and destroying a connection object for each record can incur unnecessarily high overheads and can significantly reduce the overall throughput of the system. A better solution is to use `rdd.foreachPartition` - create a single connection object and send all the records in a RDD partition using that connection. + +
    +
    +{% highlight scala %} dstream.foreachRDD(rdd => { rdd.foreachPartition(partitionOfRecords => { val connection = createNewConnection() @@ -943,13 +1152,31 @@ For example (in Scala), connection.close() }) }) +{% endhighlight %} +
    + +
    +{% highlight python %} +def sendPartition(iter): + connection = createNewConnection() + for record in iter: + connection.send(record) + connection.close() + +dstream.foreachRDD(lambda rdd: rdd.foreachPartition(sendPartition)) +{% endhighlight %} +
    +
- This amortizes the connection creation overheads over many records. + This amortizes the connection creation overheads over many records. - Finally, this can be further optimized by reusing connection objects across multiple RDDs/batches. One can maintain a static pool of connection objects that can be reused as RDDs of multiple batches are pushed to the external system, thus further reducing the overheads. - +
    +
    +{% highlight scala %} dstream.foreachRDD(rdd => { rdd.foreachPartition(partitionOfRecords => { // ConnectionPool is a static, lazily initialized pool of connections @@ -958,8 +1185,25 @@ For example (in Scala), ConnectionPool.returnConnection(connection) // return to the pool for future reuse }) }) +{% endhighlight %} +
    - Note that the connections in the pool should be lazily created on demand and timed out if not used for a while. This achieves the most efficient sending of data to external systems. +
    +{% highlight python %} +def sendPartition(iter): + # ConnectionPool is a static, lazily initialized pool of connections + connection = ConnectionPool.getConnection() + for record in iter: + connection.send(record) + # return to the pool for future reuse + ConnectionPool.returnConnection(connection) + +dstream.foreachRDD(lambda rdd: rdd.foreachPartition(sendPartition)) +{% endhighlight %} +
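The `ConnectionPool` used above is assumed rather than defined anywhere in this patch. A minimal sketch of what such a lazily initialized, per-executor pool could look like; the raw-socket transport and the host/port are purely illustrative:

{% highlight python %}
import socket

class ConnectionPool(object):
    """Toy pool: one per executor process, connections created lazily on first use."""
    _pool = []

    @classmethod
    def getConnection(cls):
        if cls._pool:
            return cls._pool.pop()
        # Hypothetical endpoint of the external system.
        return socket.create_connection(("localhost", 9999))

    @classmethod
    def returnConnection(cls, conn):
        # A production pool would also expire idle connections after a timeout.
        cls._pool.append(conn)
{% endhighlight %}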
    +
    + +Note that the connections in the pool should be lazily created on demand and timed out if not used for a while. This achieves the most efficient sending of data to external systems. ##### Other points to remember: @@ -1376,6 +1620,44 @@ You can also explicitly create a `JavaStreamingContext` from the checkpoint data the computation by using `new JavaStreamingContext(checkpointDirectory)`. +
+ +This behavior is made simple by using `StreamingContext.getOrCreate`. This is used as follows. + +{% highlight python %} +# Function to create and set up a new StreamingContext +def functionToCreateContext(): + sc = SparkContext(...) # new context + ssc = StreamingContext(...) + lines = ssc.socketTextStream(...) # create DStreams + ... + ssc.checkpoint(checkpointDirectory) # set checkpoint directory + return ssc + +# Get StreamingContext from checkpoint data or create a new one +context = StreamingContext.getOrCreate(checkpointDirectory, functionToCreateContext) + +# Do additional setup on context that needs to be done, +# irrespective of whether it is being started or restarted +context. ... + +# Start the context +context.start() +context.awaitTermination() +{% endhighlight %} + +If the `checkpointDirectory` exists, then the context will be recreated from the checkpoint data. +If the directory does not exist (i.e., running for the first time), +then the function `functionToCreateContext` will be called to create a new +context and set up the DStreams. See the Python example +[recoverable_network_wordcount.py]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/python/streaming/recoverable_network_wordcount.py). +This example appends the word counts of network data into a file. + +You can also explicitly create a `StreamingContext` from the checkpoint data and start the + computation by using `StreamingContext.getOrCreate(checkpointDirectory, None)`. + +
    + **Note**: If Spark Streaming and/or the Spark Streaming program is recompiled, @@ -1572,7 +1854,11 @@ package and renamed for better clarity. [TwitterUtils](api/java/index.html?org/apache/spark/streaming/twitter/TwitterUtils.html), [ZeroMQUtils](api/java/index.html?org/apache/spark/streaming/zeromq/ZeroMQUtils.html), and [MQTTUtils](api/java/index.html?org/apache/spark/streaming/mqtt/MQTTUtils.html) + - Python docs + * [StreamingContext](api/python/pyspark.streaming.html#pyspark.streaming.StreamingContext) + * [DStream](api/python/pyspark.streaming.html#pyspark.streaming.DStream) * More examples in [Scala]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/scala/org/apache/spark/examples/streaming) and [Java]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/java/org/apache/spark/examples/streaming) + and [Python] ({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/python/streaming) * [Paper](http://www.eecs.berkeley.edu/Pubs/TechRpts/2012/EECS-2012-259.pdf) and [video](http://youtu.be/g171ndOHgJ0) describing Spark Streaming. diff --git a/ec2/spark_ec2.py b/ec2/spark_ec2.py index 941dfb988b9fb..0d6b82b4944f3 100755 --- a/ec2/spark_ec2.py +++ b/ec2/spark_ec2.py @@ -32,6 +32,7 @@ import tempfile import time import urllib2 +import warnings from optparse import OptionParser from sys import stderr import boto @@ -61,8 +62,8 @@ def parse_args(): "-s", "--slaves", type="int", default=1, help="Number of slaves to launch (default: %default)") parser.add_option( - "-w", "--wait", type="int", default=120, - help="Seconds to wait for nodes to start (default: %default)") + "-w", "--wait", type="int", + help="DEPRECATED (no longer necessary) - Seconds to wait for nodes to start") parser.add_option( "-k", "--key-pair", help="Key pair to use on instances") @@ -195,18 +196,6 @@ def get_or_make_group(conn, name): return conn.create_security_group(name, "Spark EC2 group") -# Wait for a set of launched instances to exit the "pending" state -# (i.e. either to start running or to fail and be terminated) -def wait_for_instances(conn, instances): - while True: - for i in instances: - i.update() - if len([i for i in instances if i.state == 'pending']) > 0: - time.sleep(5) - else: - return - - # Check whether a given EC2 instance object is in a state we consider active, # i.e. not terminating or terminated. We count both stopping and stopped as # active since we can restart stopped clusters. @@ -594,7 +583,7 @@ def setup_cluster(conn, master_nodes, slave_nodes, opts, deploy_ssh_key): # NOTE: We should clone the repository before running deploy_files to # prevent ec2-variables.sh from being overwritten - ssh(master, opts, "rm -rf spark-ec2 && git clone https://github.com/mesos/spark-ec2.git -b v3") + ssh(master, opts, "rm -rf spark-ec2 && git clone https://github.com/mesos/spark-ec2.git -b v4") print "Deploying files to master..." deploy_files(conn, "deploy.generic", opts, master_nodes, slave_nodes, modules) @@ -619,14 +608,64 @@ def setup_spark_cluster(master, opts): print "Ganglia started at http://%s:5080/ganglia" % master -# Wait for a whole cluster (masters, slaves and ZooKeeper) to start up -def wait_for_cluster(conn, wait_secs, master_nodes, slave_nodes): - print "Waiting for instances to start up..." - time.sleep(5) - wait_for_instances(conn, master_nodes) - wait_for_instances(conn, slave_nodes) - print "Waiting %d more seconds..." % wait_secs - time.sleep(wait_secs) +def is_ssh_available(host, opts): + "Checks if SSH is available on the host." 
+ try: + with open(os.devnull, 'w') as devnull: + ret = subprocess.check_call( + ssh_command(opts) + ['-t', '-t', '-o', 'ConnectTimeout=3', + '%s@%s' % (opts.user, host), stringify_command('true')], + stdout=devnull, + stderr=devnull + ) + return ret == 0 + except subprocess.CalledProcessError as e: + return False + + +def is_cluster_ssh_available(cluster_instances, opts): + for i in cluster_instances: + if not is_ssh_available(host=i.ip_address, opts=opts): + return False + else: + return True + + +def wait_for_cluster_state(cluster_instances, cluster_state, opts): + """ + cluster_instances: a list of boto.ec2.instance.Instance + cluster_state: a string representing the desired state of all the instances in the cluster + value can be 'ssh-ready' or a valid value from boto.ec2.instance.InstanceState such as + 'running', 'terminated', etc. + (would be nice to replace this with a proper enum: http://stackoverflow.com/a/1695250) + """ + sys.stdout.write( + "Waiting for all instances in cluster to enter '{s}' state.".format(s=cluster_state) + ) + sys.stdout.flush() + + num_attempts = 0 + + while True: + time.sleep(3 * num_attempts) + + for i in cluster_instances: + s = i.update() # capture output to suppress print to screen in newer versions of boto + + if cluster_state == 'ssh-ready': + if all(i.state == 'running' for i in cluster_instances) and \ + is_cluster_ssh_available(cluster_instances, opts): + break + else: + if all(i.state == cluster_state for i in cluster_instances): + break + + num_attempts += 1 + + sys.stdout.write(".") + sys.stdout.flush() + + sys.stdout.write("\n") # Get number of local disks available for a given EC2 instance type. @@ -868,6 +907,16 @@ def real_main(): (opts, action, cluster_name) = parse_args() # Input parameter validation + if opts.wait is not None: + # NOTE: DeprecationWarnings are silent in 2.7+ by default. + # To show them, run Python with the -Wdefault switch. + # See: https://docs.python.org/3.5/whatsnew/2.7.html + warnings.warn( + "This option is deprecated and has no effect. 
" + "spark-ec2 automatically waits as long as necessary for clusters to startup.", + DeprecationWarning + ) + if opts.ebs_vol_num > 8: print >> stderr, "ebs-vol-num cannot be greater than 8" sys.exit(1) @@ -890,7 +939,11 @@ def real_main(): (master_nodes, slave_nodes) = get_existing_cluster(conn, opts, cluster_name) else: (master_nodes, slave_nodes) = launch_cluster(conn, opts, cluster_name) - wait_for_cluster(conn, opts.wait, master_nodes, slave_nodes) + wait_for_cluster_state( + cluster_instances=(master_nodes + slave_nodes), + cluster_state='ssh-ready', + opts=opts + ) setup_cluster(conn, master_nodes, slave_nodes, opts, True) elif action == "destroy": @@ -919,7 +972,11 @@ def real_main(): else: group_names = [opts.security_group_prefix + "-master", opts.security_group_prefix + "-slaves"] - + wait_for_cluster_state( + cluster_instances=(master_nodes + slave_nodes), + cluster_state='terminated', + opts=opts + ) attempt = 1 while attempt <= 3: print "Attempt %d" % attempt @@ -1019,7 +1076,11 @@ def real_main(): for inst in master_nodes: if inst.state not in ["shutting-down", "terminated"]: inst.start() - wait_for_cluster(conn, opts.wait, master_nodes, slave_nodes) + wait_for_cluster_state( + cluster_instances=(master_nodes + slave_nodes), + cluster_state='ssh-ready', + opts=opts + ) setup_cluster(conn, master_nodes, slave_nodes, opts, False) else: diff --git a/examples/pom.xml b/examples/pom.xml index 54e13c57520dd..be6544e515ab5 100644 --- a/examples/pom.xml +++ b/examples/pom.xml @@ -120,114 +120,114 @@ spark-streaming-mqtt_${scala.binary.version}${project.version} - + org.apache.hbase hbase-common ${hbase.version} - - asm - asm - - - org.jboss.netty - netty - - - io.netty - netty - - - commons-logging - commons-logging - - - org.jruby - jruby-complete - + + asm + asm + + + org.jboss.netty + netty + + + io.netty + netty + + + commons-logging + commons-logging + + + org.jruby + jruby-complete + - - + + org.apache.hbase hbase-client ${hbase.version} - - asm - asm - - - org.jboss.netty - netty - - - io.netty - netty - - - commons-logging - commons-logging - - - org.jruby - jruby-complete - + + asm + asm + + + org.jboss.netty + netty + + + io.netty + netty + + + commons-logging + commons-logging + + + org.jruby + jruby-complete + - - + + org.apache.hbase hbase-server ${hbase.version} - - asm - asm - - - org.jboss.netty - netty - - - io.netty - netty - - - commons-logging - commons-logging - - - org.jruby - jruby-complete - + + asm + asm + + + org.jboss.netty + netty + + + io.netty + netty + + + commons-logging + commons-logging + + + org.jruby + jruby-complete + - - + + org.apache.hbase hbase-protocol ${hbase.version} - - asm - asm - - - org.jboss.netty - netty - - - io.netty - netty - - - commons-logging - commons-logging - - - org.jruby - jruby-complete - + + asm + asm + + + org.jboss.netty + netty + + + io.netty + netty + + + commons-logging + commons-logging + + + org.jruby + jruby-complete + - + org.eclipse.jetty jetty-server diff --git a/examples/src/main/python/sql.py b/examples/src/main/python/sql.py index eefa022f1927c..d2c5ca48c6cb8 100644 --- a/examples/src/main/python/sql.py +++ b/examples/src/main/python/sql.py @@ -48,7 +48,7 @@ # A JSON dataset is pointed to by path. # The path can be either a single text file or a directory storing text files. 
- path = os.environ['SPARK_HOME'] + "examples/src/main/resources/people.json" + path = os.path.join(os.environ['SPARK_HOME'], "examples/src/main/resources/people.json") # Create a SchemaRDD from the file(s) pointed to by path people = sqlContext.jsonFile(path) # root diff --git a/examples/src/main/python/streaming/hdfs_wordcount.py b/examples/src/main/python/streaming/hdfs_wordcount.py new file mode 100644 index 0000000000000..40faff0ccc7db --- /dev/null +++ b/examples/src/main/python/streaming/hdfs_wordcount.py @@ -0,0 +1,49 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" + Counts words in new text files created in the given directory + Usage: hdfs_wordcount.py + is the directory that Spark Streaming will use to find and read new text files. + + To run this on your local machine on directory `localdir`, run this example + $ bin/spark-submit examples/src/main/python/streaming/network_wordcount.py localdir + + Then create a text file in `localdir` and the words in the file will get counted. +""" + +import sys + +from pyspark import SparkContext +from pyspark.streaming import StreamingContext + +if __name__ == "__main__": + if len(sys.argv) != 2: + print >> sys.stderr, "Usage: hdfs_wordcount.py " + exit(-1) + + sc = SparkContext(appName="PythonStreamingHDFSWordCount") + ssc = StreamingContext(sc, 1) + + lines = ssc.textFileStream(sys.argv[1]) + counts = lines.flatMap(lambda line: line.split(" "))\ + .map(lambda x: (x, 1))\ + .reduceByKey(lambda a, b: a+b) + counts.pprint() + + ssc.start() + ssc.awaitTermination() diff --git a/examples/src/main/python/streaming/network_wordcount.py b/examples/src/main/python/streaming/network_wordcount.py new file mode 100644 index 0000000000000..cfa9c1ff5bfbc --- /dev/null +++ b/examples/src/main/python/streaming/network_wordcount.py @@ -0,0 +1,48 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" + Counts words in UTF8 encoded, '\n' delimited text received from the network every second. 
+ Usage: network_wordcount.py + and describe the TCP server that Spark Streaming would connect to receive data. + + To run this on your local machine, you need to first run a Netcat server + `$ nc -lk 9999` + and then run the example + `$ bin/spark-submit examples/src/main/python/streaming/network_wordcount.py localhost 9999` +""" + +import sys + +from pyspark import SparkContext +from pyspark.streaming import StreamingContext + +if __name__ == "__main__": + if len(sys.argv) != 3: + print >> sys.stderr, "Usage: network_wordcount.py " + exit(-1) + sc = SparkContext(appName="PythonStreamingNetworkWordCount") + ssc = StreamingContext(sc, 1) + + lines = ssc.socketTextStream(sys.argv[1], int(sys.argv[2])) + counts = lines.flatMap(lambda line: line.split(" "))\ + .map(lambda word: (word, 1))\ + .reduceByKey(lambda a, b: a+b) + counts.pprint() + + ssc.start() + ssc.awaitTermination() diff --git a/examples/src/main/python/streaming/recoverable_network_wordcount.py b/examples/src/main/python/streaming/recoverable_network_wordcount.py new file mode 100644 index 0000000000000..fc6827c82bf9b --- /dev/null +++ b/examples/src/main/python/streaming/recoverable_network_wordcount.py @@ -0,0 +1,80 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" + Counts words in text encoded with UTF8 received from the network every second. + + Usage: recoverable_network_wordcount.py + and describe the TCP server that Spark Streaming would connect to receive + data. directory to HDFS-compatible file system which checkpoint data + file to which the word counts will be appended + + To run this on your local machine, you need to first run a Netcat server + `$ nc -lk 9999` + + and then run the example + `$ bin/spark-submit examples/src/main/python/streaming/recoverable_network_wordcount.py \ + localhost 9999 ~/checkpoint/ ~/out` + + If the directory ~/checkpoint/ does not exist (e.g. running for the first time), it will create + a new StreamingContext (will print "Creating new context" to the console). Otherwise, if + checkpoint data exists in ~/checkpoint/, then it will create StreamingContext from + the checkpoint data. +""" + +import os +import sys + +from pyspark import SparkContext +from pyspark.streaming import StreamingContext + + +def createContext(host, port, outputPath): + # If you do not see this printed, that means the StreamingContext has been loaded + # from the new checkpoint + print "Creating new context" + if os.path.exists(outputPath): + os.remove(outputPath) + sc = SparkContext(appName="PythonStreamingRecoverableNetworkWordCount") + ssc = StreamingContext(sc, 1) + + # Create a socket stream on target ip:port and count the + # words in input stream of \n delimited text (eg. 
generated by 'nc') + lines = ssc.socketTextStream(host, port) + words = lines.flatMap(lambda line: line.split(" ")) + wordCounts = words.map(lambda x: (x, 1)).reduceByKey(lambda x, y: x + y) + + def echo(time, rdd): + counts = "Counts at time %s %s" % (time, rdd.collect()) + print counts + print "Appending to " + os.path.abspath(outputPath) + with open(outputPath, 'a') as f: + f.write(counts + "\n") + + wordCounts.foreachRDD(echo) + return ssc + +if __name__ == "__main__": + if len(sys.argv) != 5: + print >> sys.stderr, "Usage: recoverable_network_wordcount.py "\ + " " + exit(-1) + host, port, checkpoint, output = sys.argv[1:] + ssc = StreamingContext.getOrCreate(checkpoint, + lambda: createContext(host, int(port), output)) + ssc.start() + ssc.awaitTermination() diff --git a/examples/src/main/python/streaming/stateful_network_wordcount.py b/examples/src/main/python/streaming/stateful_network_wordcount.py new file mode 100644 index 0000000000000..18a9a5a452ffb --- /dev/null +++ b/examples/src/main/python/streaming/stateful_network_wordcount.py @@ -0,0 +1,57 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" + Counts words in UTF8 encoded, '\n' delimited text received from the + network every second. + + Usage: stateful_network_wordcount.py + and describe the TCP server that Spark Streaming + would connect to receive data. 
+ + To run this on your local machine, you need to first run a Netcat server + `$ nc -lk 9999` + and then run the example + `$ bin/spark-submit examples/src/main/python/streaming/stateful_network_wordcount.py \ + localhost 9999` +""" + +import sys + +from pyspark import SparkContext +from pyspark.streaming import StreamingContext + +if __name__ == "__main__": + if len(sys.argv) != 3: + print >> sys.stderr, "Usage: stateful_network_wordcount.py " + exit(-1) + sc = SparkContext(appName="PythonStreamingStatefulNetworkWordCount") + ssc = StreamingContext(sc, 1) + ssc.checkpoint("checkpoint") + + def updateFunc(new_values, last_sum): + return sum(new_values) + (last_sum or 0) + + lines = ssc.socketTextStream(sys.argv[1], int(sys.argv[2])) + running_counts = lines.flatMap(lambda line: line.split(" "))\ + .map(lambda word: (word, 1))\ + .updateStateByKey(updateFunc) + + running_counts.pprint() + + ssc.start() + ssc.awaitTermination() diff --git a/examples/src/main/scala/org/apache/spark/examples/graphx/Analytics.scala b/examples/src/main/scala/org/apache/spark/examples/graphx/Analytics.scala index c4317a6aec798..45527d9382fd0 100644 --- a/examples/src/main/scala/org/apache/spark/examples/graphx/Analytics.scala +++ b/examples/src/main/scala/org/apache/spark/examples/graphx/Analytics.scala @@ -46,17 +46,6 @@ object Analytics extends Logging { } val options = mutable.Map(optionsList: _*) - def pickPartitioner(v: String): PartitionStrategy = { - // TODO: Use reflection rather than listing all the partitioning strategies here. - v match { - case "RandomVertexCut" => RandomVertexCut - case "EdgePartition1D" => EdgePartition1D - case "EdgePartition2D" => EdgePartition2D - case "CanonicalRandomVertexCut" => CanonicalRandomVertexCut - case _ => throw new IllegalArgumentException("Invalid PartitionStrategy: " + v) - } - } - val conf = new SparkConf() .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") .set("spark.kryo.registrator", "org.apache.spark.graphx.GraphKryoRegistrator") @@ -67,7 +56,7 @@ object Analytics extends Logging { sys.exit(1) } val partitionStrategy: Option[PartitionStrategy] = options.remove("partStrategy") - .map(pickPartitioner(_)) + .map(PartitionStrategy.fromString(_)) val edgeStorageLevel = options.remove("edgeStorageLevel") .map(StorageLevel.fromString(_)).getOrElse(StorageLevel.MEMORY_ONLY) val vertexStorageLevel = options.remove("vertexStorageLevel") @@ -107,7 +96,7 @@ object Analytics extends Logging { if (!outFname.isEmpty) { logWarning("Saving pageranks of pages to " + outFname) - pr.map{case (id, r) => id + "\t" + r}.saveAsTextFile(outFname) + pr.map { case (id, r) => id + "\t" + r }.saveAsTextFile(outFname) } sc.stop() @@ -129,7 +118,7 @@ object Analytics extends Logging { val graph = partitionStrategy.foldLeft(unpartitionedGraph)(_.partitionBy(_)) val cc = ConnectedComponents.run(graph) - println("Components: " + cc.vertices.map{ case (vid,data) => data}.distinct()) + println("Components: " + cc.vertices.map { case (vid, data) => data }.distinct()) sc.stop() case "triangles" => @@ -147,7 +136,7 @@ object Analytics extends Logging { minEdgePartitions = numEPart, edgeStorageLevel = edgeStorageLevel, vertexStorageLevel = vertexStorageLevel) - // TriangleCount requires the graph to be partitioned + // TriangleCount requires the graph to be partitioned .partitionBy(partitionStrategy.getOrElse(RandomVertexCut)).cache() val triangles = TriangleCount.run(graph) println("Triangles: " + triangles.vertices.map { diff --git 
a/examples/src/main/scala/org/apache/spark/examples/mllib/AbstractParams.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/AbstractParams.scala new file mode 100644 index 0000000000000..ae6057758d6fc --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/AbstractParams.scala @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.mllib + +import scala.reflect.runtime.universe._ + +/** + * Abstract class for parameter case classes. + * This overrides the [[toString]] method to print all case class fields by name and value. + * @tparam T Concrete parameter class. + */ +abstract class AbstractParams[T: TypeTag] { + + private def tag: TypeTag[T] = typeTag[T] + + /** + * Finds all case class fields in concrete class instance, and outputs them in JSON-style format: + * { + * [field name]:\t[field value]\n + * [field name]:\t[field value]\n + * ... + * } + */ + override def toString: String = { + val tpe = tag.tpe + val allAccessors = tpe.declarations.collect { + case m: MethodSymbol if m.isCaseAccessor => m + } + val mirror = runtimeMirror(getClass.getClassLoader) + val instanceMirror = mirror.reflect(this) + allAccessors.map { f => + val paramName = f.name.toString + val fieldMirror = instanceMirror.reflectField(f) + val paramValue = fieldMirror.get + s" $paramName:\t$paramValue" + }.mkString("{\n", ",\n", "\n}") + } +} diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/BinaryClassification.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/BinaryClassification.scala index a6f78d2441db1..1edd2432a0352 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/BinaryClassification.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/BinaryClassification.scala @@ -55,7 +55,7 @@ object BinaryClassification { stepSize: Double = 1.0, algorithm: Algorithm = LR, regType: RegType = L2, - regParam: Double = 0.1) + regParam: Double = 0.1) extends AbstractParams[Params] def main(args: Array[String]) { val defaultParams = Params() diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/Correlations.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/Correlations.scala index d6b2fe430e5a4..e49129c4e7844 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/Correlations.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/Correlations.scala @@ -35,6 +35,7 @@ import org.apache.spark.{SparkConf, SparkContext} object Correlations { case class Params(input: String = "data/mllib/sample_linear_regression_data.txt") + extends AbstractParams[Params] def main(args: Array[String]) { diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/CosineSimilarity.scala 
b/examples/src/main/scala/org/apache/spark/examples/mllib/CosineSimilarity.scala new file mode 100644 index 0000000000000..cb1abbd18fd4d --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/CosineSimilarity.scala @@ -0,0 +1,108 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.mllib + +import scopt.OptionParser + +import org.apache.spark.SparkContext._ +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.linalg.distributed.{MatrixEntry, RowMatrix} +import org.apache.spark.{SparkConf, SparkContext} + +/** + * Compute the similar columns of a matrix, using cosine similarity. + * + * The input matrix must be stored in row-oriented dense format, one line per row with its entries + * separated by space. For example, + * {{{ + * 0.5 1.0 + * 2.0 3.0 + * 4.0 5.0 + * }}} + * represents a 3-by-2 matrix, whose first row is (0.5, 1.0). + * + * Example invocation: + * + * bin/run-example mllib.CosineSimilarity \ + * --threshold 0.1 data/mllib/sample_svm_data.txt + */ +object CosineSimilarity { + case class Params(inputFile: String = null, threshold: Double = 0.1) + extends AbstractParams[Params] + + def main(args: Array[String]) { + val defaultParams = Params() + + val parser = new OptionParser[Params]("CosineSimilarity") { + head("CosineSimilarity: an example app.") + opt[Double]("threshold") + .required() + .text(s"threshold similarity: to tradeoff computation vs quality estimate") + .action((x, c) => c.copy(threshold = x)) + arg[String]("") + .required() + .text(s"input file, one row per line, space-separated") + .action((x, c) => c.copy(inputFile = x)) + note( + """ + |For example, the following command runs this app on a dataset: + | + | ./bin/spark-submit --class org.apache.spark.examples.mllib.CosineSimilarity \ + | examplesjar.jar \ + | --threshold 0.1 data/mllib/sample_svm_data.txt + """.stripMargin) + } + + parser.parse(args, defaultParams).map { params => + run(params) + } getOrElse { + System.exit(1) + } + } + + def run(params: Params) { + val conf = new SparkConf().setAppName("CosineSimilarity") + val sc = new SparkContext(conf) + + // Load and parse the data file. + val rows = sc.textFile(params.inputFile).map { line => + val values = line.split(' ').map(_.toDouble) + Vectors.dense(values) + }.cache() + val mat = new RowMatrix(rows) + + // Compute similar columns perfectly, with brute force. 
+ val exact = mat.columnSimilarities() + + // Compute similar columns with estimation using DIMSUM + val approx = mat.columnSimilarities(params.threshold) + + val exactEntries = exact.entries.map { case MatrixEntry(i, j, u) => ((i, j), u) } + val approxEntries = approx.entries.map { case MatrixEntry(i, j, v) => ((i, j), v) } + val MAE = exactEntries.leftOuterJoin(approxEntries).values.map { + case (u, Some(v)) => + math.abs(u - v) + case (u, None) => + math.abs(u) + }.mean() + + println(s"Average absolute error in estimate is: $MAE") + + sc.stop() + } +} diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala index 4adc91d2fbe65..0890e6263e165 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala @@ -62,7 +62,7 @@ object DecisionTreeRunner { minInfoGain: Double = 0.0, numTrees: Int = 1, featureSubsetStrategy: String = "auto", - fracTest: Double = 0.2) + fracTest: Double = 0.2) extends AbstractParams[Params] def main(args: Array[String]) { val defaultParams = Params() @@ -138,9 +138,11 @@ object DecisionTreeRunner { def run(params: Params) { - val conf = new SparkConf().setAppName("DecisionTreeRunner") + val conf = new SparkConf().setAppName(s"DecisionTreeRunner with $params") val sc = new SparkContext(conf) + println(s"DecisionTreeRunner with parameters:\n$params") + // Load training data and cache it. val origExamples = params.dataFormat match { case "dense" => MLUtils.loadLabeledPoints(sc, params.input).cache() @@ -187,9 +189,10 @@ object DecisionTreeRunner { // Create training, test sets. val splits = if (params.testInput != "") { // Load testInput. + val numFeatures = examples.take(1)(0).features.size val origTestExamples = params.dataFormat match { case "dense" => MLUtils.loadLabeledPoints(sc, params.testInput) - case "libsvm" => MLUtils.loadLibSVMFile(sc, params.testInput) + case "libsvm" => MLUtils.loadLibSVMFile(sc, params.testInput, numFeatures) } params.algo match { case Classification => { @@ -235,7 +238,10 @@ object DecisionTreeRunner { minInstancesPerNode = params.minInstancesPerNode, minInfoGain = params.minInfoGain) if (params.numTrees == 1) { + val startTime = System.nanoTime() val model = DecisionTree.train(training, strategy) + val elapsedTime = (System.nanoTime() - startTime) / 1e9 + println(s"Training time: $elapsedTime seconds") if (model.numNodes < 20) { println(model.toDebugString) // Print full model. } else { @@ -259,8 +265,11 @@ object DecisionTreeRunner { } else { val randomSeed = Utils.random.nextInt() if (params.algo == Classification) { + val startTime = System.nanoTime() val model = RandomForest.trainClassifier(training, strategy, params.numTrees, params.featureSubsetStrategy, randomSeed) + val elapsedTime = (System.nanoTime() - startTime) / 1e9 + println(s"Training time: $elapsedTime seconds") if (model.totalNumNodes < 30) { println(model.toDebugString) // Print full model. 
} else { @@ -275,8 +284,11 @@ object DecisionTreeRunner { println(s"Test accuracy = $testAccuracy") } if (params.algo == Regression) { + val startTime = System.nanoTime() val model = RandomForest.trainRegressor(training, strategy, params.numTrees, params.featureSubsetStrategy, randomSeed) + val elapsedTime = (System.nanoTime() - startTime) / 1e9 + println(s"Training time: $elapsedTime seconds") if (model.totalNumNodes < 30) { println(model.toDebugString) // Print full model. } else { diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DenseKMeans.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DenseKMeans.scala index 89dfa26c2299c..11e35598baf50 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/DenseKMeans.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DenseKMeans.scala @@ -44,7 +44,7 @@ object DenseKMeans { input: String = null, k: Int = -1, numIterations: Int = 10, - initializationMode: InitializationMode = Parallel) + initializationMode: InitializationMode = Parallel) extends AbstractParams[Params] def main(args: Array[String]) { val defaultParams = Params() diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/LinearRegression.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/LinearRegression.scala index 05b7d66f8dffd..e1f9622350135 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/LinearRegression.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/LinearRegression.scala @@ -47,7 +47,7 @@ object LinearRegression extends App { numIterations: Int = 100, stepSize: Double = 1.0, regType: RegType = L2, - regParam: Double = 0.1) + regParam: Double = 0.1) extends AbstractParams[Params] val defaultParams = Params() diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala index 98aaedb9d7dc9..fc6678013b932 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala @@ -55,7 +55,7 @@ object MovieLensALS { rank: Int = 10, numUserBlocks: Int = -1, numProductBlocks: Int = -1, - implicitPrefs: Boolean = false) + implicitPrefs: Boolean = false) extends AbstractParams[Params] def main(args: Array[String]) { val defaultParams = Params() diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/MultivariateSummarizer.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/MultivariateSummarizer.scala index 4532512c01f84..6e4e2d07f284b 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/MultivariateSummarizer.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/MultivariateSummarizer.scala @@ -36,6 +36,7 @@ import org.apache.spark.{SparkConf, SparkContext} object MultivariateSummarizer { case class Params(input: String = "data/mllib/sample_linear_regression_data.txt") + extends AbstractParams[Params] def main(args: Array[String]) { diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/SampledRDDs.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/SampledRDDs.scala index f01b8266e3fe3..663c12734af68 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/SampledRDDs.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/SampledRDDs.scala @@ -33,6 +33,7 @@ import org.apache.spark.SparkContext._ object SampledRDDs { case class Params(input: String = 
"data/mllib/sample_binary_classification_data.txt") + extends AbstractParams[Params] def main(args: Array[String]) { val defaultParams = Params() diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/SparseNaiveBayes.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/SparseNaiveBayes.scala index 952fa2a5109a4..f1ff4e6911f5e 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/SparseNaiveBayes.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/SparseNaiveBayes.scala @@ -37,7 +37,7 @@ object SparseNaiveBayes { input: String = null, minPartitions: Int = 0, numFeatures: Int = -1, - lambda: Double = 1.0) + lambda: Double = 1.0) extends AbstractParams[Params] def main(args: Array[String]) { val defaultParams = Params() diff --git a/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala b/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala index e26f213e8afa8..0c52ef8ed96ac 100644 --- a/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala +++ b/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala @@ -28,7 +28,7 @@ object HiveFromSpark { val sparkConf = new SparkConf().setAppName("HiveFromSpark") val sc = new SparkContext(sparkConf) - // A local hive context creates an instance of the Hive Metastore in process, storing the + // A local hive context creates an instance of the Hive Metastore in process, storing // the warehouse data in the current directory. This location can be overridden by // specifying a second parameter to the constructor. val hiveContext = new HiveContext(sc) diff --git a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala index 33235d150b4a5..13943ed5442b9 100644 --- a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala +++ b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala @@ -17,103 +17,141 @@ package org.apache.spark.streaming.flume -import scala.collection.JavaConversions._ -import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer} - -import java.net.InetSocketAddress +import java.net.{InetSocketAddress, ServerSocket} import java.nio.ByteBuffer import java.nio.charset.Charset +import scala.collection.JavaConversions._ +import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer} +import scala.concurrent.duration._ +import scala.language.postfixOps + import org.apache.avro.ipc.NettyTransceiver import org.apache.avro.ipc.specific.SpecificRequestor +import org.apache.flume.source.avro import org.apache.flume.source.avro.{AvroFlumeEvent, AvroSourceProtocol} +import org.jboss.netty.channel.ChannelPipeline +import org.jboss.netty.channel.socket.SocketChannel +import org.jboss.netty.channel.socket.nio.NioClientSocketChannelFactory +import org.jboss.netty.handler.codec.compression._ +import org.scalatest.{BeforeAndAfter, FunSuite, Matchers} +import org.scalatest.concurrent.Eventually._ +import org.apache.spark.{Logging, SparkConf} import org.apache.spark.storage.StorageLevel -import org.apache.spark.streaming.{TestOutputStream, StreamingContext, TestSuiteBase} -import org.apache.spark.streaming.util.ManualClock +import org.apache.spark.streaming.{Milliseconds, StreamingContext, TestOutputStream} +import org.apache.spark.streaming.scheduler.{StreamingListener, StreamingListenerReceiverStarted} import 
org.apache.spark.util.Utils -import org.jboss.netty.channel.ChannelPipeline -import org.jboss.netty.channel.socket.nio.NioClientSocketChannelFactory -import org.jboss.netty.channel.socket.SocketChannel -import org.jboss.netty.handler.codec.compression._ +class FlumeStreamSuite extends FunSuite with BeforeAndAfter with Matchers with Logging { + val conf = new SparkConf().setMaster("local[4]").setAppName("FlumeStreamSuite") + + var ssc: StreamingContext = null + var transceiver: NettyTransceiver = null -class FlumeStreamSuite extends TestSuiteBase { + after { + if (ssc != null) { + ssc.stop() + } + if (transceiver != null) { + transceiver.close() + } + } test("flume input stream") { - runFlumeStreamTest(false) + testFlumeStream(testCompression = false) } test("flume input compressed stream") { - runFlumeStreamTest(true) + testFlumeStream(testCompression = true) + } + + /** Run test on flume stream */ + private def testFlumeStream(testCompression: Boolean): Unit = { + val input = (1 to 100).map { _.toString } + val testPort = findFreePort() + val outputBuffer = startContext(testPort, testCompression) + writeAndVerify(input, testPort, outputBuffer, testCompression) + } + + /** Find a free port */ + private def findFreePort(): Int = { + Utils.startServiceOnPort(23456, (trialPort: Int) => { + val socket = new ServerSocket(trialPort) + socket.close() + (null, trialPort) + })._2 } - - def runFlumeStreamTest(enableDecompression: Boolean) { - // Set up the streaming context and input streams - val ssc = new StreamingContext(conf, batchDuration) - val (flumeStream, testPort) = - Utils.startServiceOnPort(9997, (trialPort: Int) => { - val dstream = FlumeUtils.createStream( - ssc, "localhost", trialPort, StorageLevel.MEMORY_AND_DISK, enableDecompression) - (dstream, trialPort) - }) + /** Setup and start the streaming context */ + private def startContext( + testPort: Int, testCompression: Boolean): (ArrayBuffer[Seq[SparkFlumeEvent]]) = { + ssc = new StreamingContext(conf, Milliseconds(200)) + val flumeStream = FlumeUtils.createStream( + ssc, "localhost", testPort, StorageLevel.MEMORY_AND_DISK, testCompression) val outputBuffer = new ArrayBuffer[Seq[SparkFlumeEvent]] with SynchronizedBuffer[Seq[SparkFlumeEvent]] val outputStream = new TestOutputStream(flumeStream, outputBuffer) outputStream.register() ssc.start() + outputBuffer + } - val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] - val input = Seq(1, 2, 3, 4, 5) - Thread.sleep(1000) - val transceiver = new NettyTransceiver(new InetSocketAddress("localhost", testPort)) - var client: AvroSourceProtocol = null - - if (enableDecompression) { - client = SpecificRequestor.getClient( - classOf[AvroSourceProtocol], - new NettyTransceiver(new InetSocketAddress("localhost", testPort), - new CompressionChannelFactory(6))) - } else { - client = SpecificRequestor.getClient( - classOf[AvroSourceProtocol], transceiver) - } + /** Send data to the flume receiver and verify whether the data was received */ + private def writeAndVerify( + input: Seq[String], + testPort: Int, + outputBuffer: ArrayBuffer[Seq[SparkFlumeEvent]], + enableCompression: Boolean + ) { + val testAddress = new InetSocketAddress("localhost", testPort) - for (i <- 0 until input.size) { + val inputEvents = input.map { item => val event = new AvroFlumeEvent - event.setBody(ByteBuffer.wrap(input(i).toString.getBytes("utf-8"))) + event.setBody(ByteBuffer.wrap(item.getBytes("UTF-8"))) event.setHeaders(Map[CharSequence, CharSequence]("test" -> "header")) - client.append(event) - Thread.sleep(500) 
- clock.addToTime(batchDuration.milliseconds) + event } - Thread.sleep(1000) - - val startTime = System.currentTimeMillis() - while (outputBuffer.size < input.size && System.currentTimeMillis() - startTime < maxWaitTimeMillis) { - logInfo("output.size = " + outputBuffer.size + ", input.size = " + input.size) - Thread.sleep(100) + eventually(timeout(10 seconds), interval(100 milliseconds)) { + // if last attempted transceiver had succeeded, close it + if (transceiver != null) { + transceiver.close() + transceiver = null + } + + // Create transceiver + transceiver = { + if (enableCompression) { + new NettyTransceiver(testAddress, new CompressionChannelFactory(6)) + } else { + new NettyTransceiver(testAddress) + } + } + + // Create Avro client with the transceiver + val client = SpecificRequestor.getClient(classOf[AvroSourceProtocol], transceiver) + client should not be null + + // Send data + val status = client.appendBatch(inputEvents.toList) + status should be (avro.Status.OK) } - Thread.sleep(1000) - val timeTaken = System.currentTimeMillis() - startTime - assert(timeTaken < maxWaitTimeMillis, "Operation timed out after " + timeTaken + " ms") - logInfo("Stopping context") - ssc.stop() - - val decoder = Charset.forName("UTF-8").newDecoder() - - assert(outputBuffer.size === input.length) - for (i <- 0 until outputBuffer.size) { - assert(outputBuffer(i).size === 1) - val str = decoder.decode(outputBuffer(i).head.event.getBody) - assert(str.toString === input(i).toString) - assert(outputBuffer(i).head.event.getHeaders.get("test") === "header") + + val decoder = Charset.forName("UTF-8").newDecoder() + eventually(timeout(10 seconds), interval(100 milliseconds)) { + val outputEvents = outputBuffer.flatten.map { _.event } + outputEvents.foreach { + event => + event.getHeaders.get("test") should be("header") + } + val output = outputEvents.map(event => decoder.decode(event.getBody()).toString) + output should be (input) } } - class CompressionChannelFactory(compressionLevel: Int) extends NioClientSocketChannelFactory { + /** Class to create socket channel with compression */ + private class CompressionChannelFactory(compressionLevel: Int) extends NioClientSocketChannelFactory { override def newChannel(pipeline: ChannelPipeline): SocketChannel = { val encoder = new ZlibEncoder(compressionLevel) pipeline.addFirst("deflater", encoder) diff --git a/mllib/pom.xml b/mllib/pom.xml index a5eeef88e9d62..696e9396f627c 100644 --- a/mllib/pom.xml +++ b/mllib/pom.xml @@ -57,7 +57,7 @@ org.scalanlp breeze_${scala.binary.version} - 0.9 + 0.10 diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index e9f41758581e3..9a100170b75c6 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -18,6 +18,7 @@ package org.apache.spark.mllib.api.python import java.io.OutputStream +import java.util.{ArrayList => JArrayList} import scala.collection.JavaConverters._ import scala.language.existentials @@ -27,8 +28,11 @@ import net.razorvine.pickle._ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.api.java.{JavaRDD, JavaSparkContext} +import org.apache.spark.api.python.{PythonRDD, SerDeUtil} import org.apache.spark.mllib.classification._ import org.apache.spark.mllib.clustering._ +import org.apache.spark.mllib.feature.Word2Vec +import org.apache.spark.mllib.feature.Word2VecModel 
import org.apache.spark.mllib.optimization._ import org.apache.spark.mllib.linalg._ import org.apache.spark.mllib.random.{RandomRDDs => RG} @@ -42,9 +46,9 @@ import org.apache.spark.mllib.stat.{MultivariateStatisticalSummary, Statistics} import org.apache.spark.mllib.stat.correlation.CorrelationNames import org.apache.spark.mllib.util.MLUtils import org.apache.spark.rdd.RDD +import org.apache.spark.storage.StorageLevel import org.apache.spark.util.Utils - /** * :: DeveloperApi :: * The Java stubs necessary for the Python mllib bindings. @@ -287,6 +291,59 @@ class PythonMLLibAPI extends Serializable { ALS.trainImplicit(ratingsJRDD.rdd, rank, iterations, lambda, blocks, alpha) } + /** + * Java stub for Python mllib Word2Vec fit(). This stub returns a + * handle to the Java object instead of the content of the Java object. + * Extra care needs to be taken in the Python code to ensure it gets freed on + * exit; see the Py4J documentation. + * @param dataJRDD input JavaRDD + * @param vectorSize size of vector + * @param learningRate initial learning rate + * @param numPartitions number of partitions + * @param numIterations number of iterations + * @param seed initial seed for the random generator + * @return a handle to the Java Word2VecModelWrapper instance on the Python side + */ + def trainWord2Vec( + dataJRDD: JavaRDD[java.util.ArrayList[String]], + vectorSize: Int, + learningRate: Double, + numPartitions: Int, + numIterations: Int, + seed: Long): Word2VecModelWrapper = { + val data = dataJRDD.rdd.persist(StorageLevel.MEMORY_AND_DISK_SER) + val word2vec = new Word2Vec() + .setVectorSize(vectorSize) + .setLearningRate(learningRate) + .setNumPartitions(numPartitions) + .setNumIterations(numIterations) + .setSeed(seed) + val model = word2vec.fit(data) + data.unpersist() + new Word2VecModelWrapper(model) + } + + private[python] class Word2VecModelWrapper(model: Word2VecModel) { + def transform(word: String): Vector = { + model.transform(word) + } + + def findSynonyms(word: String, num: Int): java.util.List[java.lang.Object] = { + val vec = transform(word) + findSynonyms(vec, num) + } + + def findSynonyms(vector: Vector, num: Int): java.util.List[java.lang.Object] = { + val result = model.findSynonyms(vector, num) + val similarity = Vectors.dense(result.map(_._2)) + val words = result.map(_._1) + val ret = new java.util.LinkedList[java.lang.Object]() + ret.add(words) + ret.add(similarity) + ret + } + } + /** + * Java stub for Python mllib DecisionTree.train(). + * This stub returns a handle to the Java object instead of the content of the Java object.
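For reference, the trainWord2Vec stub above is only a thin Py4J-facing wrapper: it forwards its parameters to the public org.apache.spark.mllib.feature.Word2Vec estimator and returns the fitted model to Python behind a handle. A rough sketch of the equivalent direct Scala usage is shown below; the `corpus` RDD and the helper name are hypothetical placeholders, not part of this patch, and the model ignores rare words, so a non-trivial corpus is assumed.

    import org.apache.spark.mllib.feature.{Word2Vec, Word2VecModel}
    import org.apache.spark.rdd.RDD

    // Fit a Word2Vec model the same way the Python stub does (sketch only).
    def fitWord2Vec(corpus: RDD[Seq[String]]): Word2VecModel = {
      new Word2Vec()
        .setVectorSize(100)      // corresponds to the stub's vectorSize parameter
        .setLearningRate(0.025)  // corresponds to learningRate
        .setNumPartitions(1)
        .setNumIterations(1)
        .setSeed(42L)
        .fit(corpus)
    }

    // model.transform("spark") returns the learned Vector for a single word;
    // model.findSynonyms("spark", 5) returns Array[(String, Double)] of (word, cosineSimilarity).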
@@ -584,13 +641,24 @@ private[spark] object SerDe extends Serializable { } } + var initialized = false + // This should be called before trying to serialize any of the above classes + // In cluster mode, this should be put in the closure def initialize(): Unit = { - new DenseVectorPickler().register() - new DenseMatrixPickler().register() - new SparseVectorPickler().register() - new LabeledPointPickler().register() - new RatingPickler().register() + SerDeUtil.initialize() + synchronized { + if (!initialized) { + new DenseVectorPickler().register() + new DenseMatrixPickler().register() + new SparseVectorPickler().register() + new LabeledPointPickler().register() + new RatingPickler().register() + initialized = true + } + } } + // will not be called automatically on executors + initialize() def dumps(obj: AnyRef): Array[Byte] = { new Pickler().dumps(obj) @@ -604,4 +672,33 @@ private[spark] object SerDe extends Serializable { def asTupleRDD(rdd: RDD[Array[Any]]): RDD[(Int, Int)] = { rdd.map(x => (x(0).asInstanceOf[Int], x(1).asInstanceOf[Int])) } + + /** + * Convert an RDD of Java objects to an RDD of serialized Python objects that is usable by + * PySpark. + */ + def javaToPython(jRDD: JavaRDD[Any]): JavaRDD[Array[Byte]] = { + jRDD.rdd.mapPartitions { iter => + initialize() // make sure it is called on the executor + new PythonRDD.AutoBatchedPickler(iter) + } + } + + /** + * Convert an RDD of serialized Python objects to an RDD of objects that is usable by PySpark. + */ + def pythonToJava(pyRDD: JavaRDD[Array[Byte]], batched: Boolean): JavaRDD[Any] = { + pyRDD.rdd.mapPartitions { iter => + initialize() // make sure it is called on the executor + val unpickle = new Unpickler + iter.flatMap { row => + val obj = unpickle.loads(row) + if (batched) { + obj.asInstanceOf[JArrayList[_]].asScala + } else { + Seq(obj) + } + } + }.toJavaRDD() + } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Normalizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Normalizer.scala index 3afb47767281c..4734251127bb4 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Normalizer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Normalizer.scala @@ -17,7 +17,7 @@ package org.apache.spark.mllib.feature -import breeze.linalg.{DenseVector => BDV, SparseVector => BSV} +import breeze.linalg.{DenseVector => BDV, SparseVector => BSV, norm => brzNorm} import org.apache.spark.annotation.Experimental import org.apache.spark.mllib.linalg.{Vector, Vectors} @@ -47,7 +47,7 @@ class Normalizer(p: Double) extends VectorTransformer { * @return normalized vector. If the norm of the input is zero, it will return the input vector. */ override def transform(vector: Vector): Vector = { - var norm = vector.toBreeze.norm(p) + var norm = brzNorm(vector.toBreeze, p) if (norm != 0.0) { // For dense vector, we've to allocate new memory for new output vector.
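The Normalizer change above goes with the Breeze upgrade from 0.9 to 0.10 in mllib/pom.xml: the p-norm is now computed through the standalone breeze.linalg.norm function rather than a .norm method on the vector itself. A minimal sketch of the new call pattern, with values chosen purely for illustration:

    import breeze.linalg.{DenseVector => BDV, norm => brzNorm}

    val v = BDV(3.0, 4.0)
    val l2 = brzNorm(v, 2.0) // 5.0 -- what transform() computes when p = 2
    val l1 = brzNorm(v, 1.0) // 7.0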
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala index fc1444705364a..d321994c2a651 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala @@ -67,7 +67,7 @@ private case class VocabWord( class Word2Vec extends Serializable with Logging { private var vectorSize = 100 - private var startingAlpha = 0.025 + private var learningRate = 0.025 private var numPartitions = 1 private var numIterations = 1 private var seed = Utils.random.nextLong() @@ -84,7 +84,7 @@ class Word2Vec extends Serializable with Logging { * Sets initial learning rate (default: 0.025). */ def setLearningRate(learningRate: Double): this.type = { - this.startingAlpha = learningRate + this.learningRate = learningRate this } @@ -286,7 +286,7 @@ class Word2Vec extends Serializable with Logging { val syn0Global = Array.fill[Float](vocabSize * vectorSize)((initRandom.nextFloat() - 0.5f) / vectorSize) val syn1Global = new Array[Float](vocabSize * vectorSize) - var alpha = startingAlpha + var alpha = learningRate for (k <- 1 to numIterations) { val partial = newSentences.mapPartitionsWithIndex { case (idx, iter) => val random = new XORShiftRandom(seed ^ ((idx + 1) << 16) ^ ((-k - 1) << 8)) @@ -300,8 +300,8 @@ class Word2Vec extends Serializable with Logging { lwc = wordCount // TODO: discount by iteration? alpha = - startingAlpha * (1 - numPartitions * wordCount.toDouble / (trainWordsCount + 1)) - if (alpha < startingAlpha * 0.0001) alpha = startingAlpha * 0.0001 + learningRate * (1 - numPartitions * wordCount.toDouble / (trainWordsCount + 1)) + if (alpha < learningRate * 0.0001) alpha = learningRate * 0.0001 logInfo("wordCount = " + wordCount + ", alpha = " + alpha) } wc += sentence.size @@ -437,7 +437,7 @@ class Word2VecModel private[mllib] ( * Find synonyms of a word * @param word a word * @param num number of synonyms to find - * @return array of (word, similarity) + * @return array of (word, cosineSimilarity) */ def findSynonyms(word: String, num: Int): Array[(String, Double)] = { val vector = transform(word) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala index 8380058cf9b41..ec2d481dccc22 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala @@ -111,7 +111,10 @@ class RowMatrix( */ def computeGramianMatrix(): Matrix = { val n = numCols().toInt - val nt: Int = n * (n + 1) / 2 + checkNumColumns(n) + // Computes n*(n+1)/2, avoiding overflow in the multiplication. + // This succeeds when n <= 65535, which is checked above + val nt: Int = if (n % 2 == 0) ((n / 2) * (n + 1)) else (n * ((n + 1) / 2)) // Compute the upper triangular part of the gram matrix. val GU = rows.treeAggregate(new BDV[Double](new Array[Double](nt)))( @@ -123,6 +126,16 @@ class RowMatrix( RowMatrix.triuToFull(n, GU.data) } + private def checkNumColumns(cols: Int): Unit = { + if (cols > 65535) { + throw new IllegalArgumentException(s"Argument with more than 65535 cols: $cols") + } + if (cols > 10000) { + val mem = cols * cols * 8 + logWarning(s"$cols columns will require at least $mem bytes of memory!") + } + } + /** * Computes singular value decomposition of this matrix. Denote this matrix by A (m x n). 
This * will compute matrices U, S, V such that A ~= U * S * V', where S contains the leading k @@ -301,12 +314,7 @@ class RowMatrix( */ def computeCovariance(): Matrix = { val n = numCols().toInt - - if (n > 10000) { - val mem = n * n * java.lang.Double.SIZE / java.lang.Byte.SIZE - logWarning(s"The number of columns $n is greater than 10000! " + - s"We need at least $mem bytes of memory.") - } + checkNumColumns(n) val (m, mean) = rows.treeAggregate[(Long, BDV[Double])]((0L, BDV.zeros[Double](n)))( seqOp = (s: (Long, BDV[Double]), v: Vector) => (s._1 + 1L, s._2 += v.toBreeze), diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index b311d10023894..6737a2f4176c2 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -19,6 +19,8 @@ package org.apache.spark.mllib.tree import scala.collection.JavaConverters._ import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer + import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.JavaRDD @@ -532,6 +534,14 @@ object DecisionTree extends Serializable with Logging { Some(mutableNodeToFeatures.toMap) } + // array of nodes to train indexed by node index in group + val nodes = new Array[Node](numNodes) + nodesForGroup.foreach { case (treeIndex, nodesForTree) => + nodesForTree.foreach { node => + nodes(treeToNodeToIndexInfo(treeIndex)(node.id).nodeIndexInGroup) = node + } + } + // Calculate best splits for all nodes in the group timer.start("chooseSplits") @@ -568,7 +578,7 @@ object DecisionTree extends Serializable with Logging { // find best split for each node val (split: Split, stats: InformationGainStats, predict: Predict) = - binsToBestSplit(aggStats, splits, featuresForNode) + binsToBestSplit(aggStats, splits, featuresForNode, nodes(nodeIndex)) (nodeIndex, (split, stats, predict)) }.collectAsMap() @@ -587,17 +597,30 @@ object DecisionTree extends Serializable with Logging { // Extract info for this node. Create children if not leaf. 
val isLeaf = (stats.gain <= 0) || (Node.indexToLevel(nodeIndex) == metadata.maxDepth) assert(node.id == nodeIndex) - node.predict = predict.predict + node.predict = predict node.isLeaf = isLeaf node.stats = Some(stats) + node.impurity = stats.impurity logDebug("Node = " + node) if (!isLeaf) { node.split = Some(split) - node.leftNode = Some(Node.emptyNode(Node.leftChildIndex(nodeIndex))) - node.rightNode = Some(Node.emptyNode(Node.rightChildIndex(nodeIndex))) - nodeQueue.enqueue((treeIndex, node.leftNode.get)) - nodeQueue.enqueue((treeIndex, node.rightNode.get)) + val childIsLeaf = (Node.indexToLevel(nodeIndex) + 1) == metadata.maxDepth + val leftChildIsLeaf = childIsLeaf || (stats.leftImpurity == 0.0) + val rightChildIsLeaf = childIsLeaf || (stats.rightImpurity == 0.0) + node.leftNode = Some(Node(Node.leftChildIndex(nodeIndex), + stats.leftPredict, stats.leftImpurity, leftChildIsLeaf)) + node.rightNode = Some(Node(Node.rightChildIndex(nodeIndex), + stats.rightPredict, stats.rightImpurity, rightChildIsLeaf)) + + // enqueue left child and right child if they are not leaves + if (!leftChildIsLeaf) { + nodeQueue.enqueue((treeIndex, node.leftNode.get)) + } + if (!rightChildIsLeaf) { + nodeQueue.enqueue((treeIndex, node.rightNode.get)) + } + logDebug("leftChildIndex = " + node.leftNode.get.id + ", impurity = " + stats.leftImpurity) logDebug("rightChildIndex = " + node.rightNode.get.id + @@ -617,7 +640,8 @@ object DecisionTree extends Serializable with Logging { private def calculateGainForSplit( leftImpurityCalculator: ImpurityCalculator, rightImpurityCalculator: ImpurityCalculator, - metadata: DecisionTreeMetadata): InformationGainStats = { + metadata: DecisionTreeMetadata, + impurity: Double): InformationGainStats = { val leftCount = leftImpurityCalculator.count val rightCount = rightImpurityCalculator.count @@ -630,11 +654,6 @@ object DecisionTree extends Serializable with Logging { val totalCount = leftCount + rightCount - val parentNodeAgg = leftImpurityCalculator.copy - parentNodeAgg.add(rightImpurityCalculator) - - val impurity = parentNodeAgg.calculate() - val leftImpurity = leftImpurityCalculator.calculate() // Note: This equals 0 if count = 0 val rightImpurity = rightImpurityCalculator.calculate() @@ -649,7 +668,18 @@ object DecisionTree extends Serializable with Logging { return InformationGainStats.invalidInformationGainStats } - new InformationGainStats(gain, impurity, leftImpurity, rightImpurity) + // calculate left and right predict + val leftPredict = calculatePredict(leftImpurityCalculator) + val rightPredict = calculatePredict(rightImpurityCalculator) + + new InformationGainStats(gain, impurity, leftImpurity, rightImpurity, + leftPredict, rightPredict) + } + + private def calculatePredict(impurityCalculator: ImpurityCalculator): Predict = { + val predict = impurityCalculator.predict + val prob = impurityCalculator.prob(predict) + new Predict(predict, prob) } /** @@ -657,17 +687,17 @@ object DecisionTree extends Serializable with Logging { * Note that this function is called only once for each node. 
* @param leftImpurityCalculator left node aggregates for a split * @param rightImpurityCalculator right node aggregates for a split - * @return predict value for current node + * @return predict value and impurity for current node */ - private def calculatePredict( + private def calculatePredictImpurity( leftImpurityCalculator: ImpurityCalculator, - rightImpurityCalculator: ImpurityCalculator): Predict = { + rightImpurityCalculator: ImpurityCalculator): (Predict, Double) = { val parentNodeAgg = leftImpurityCalculator.copy parentNodeAgg.add(rightImpurityCalculator) - val predict = parentNodeAgg.predict - val prob = parentNodeAgg.prob(predict) + val predict = calculatePredict(parentNodeAgg) + val impurity = parentNodeAgg.calculate() - new Predict(predict, prob) + (predict, impurity) } /** @@ -678,10 +708,16 @@ object DecisionTree extends Serializable with Logging { private def binsToBestSplit( binAggregates: DTStatsAggregator, splits: Array[Array[Split]], - featuresForNode: Option[Array[Int]]): (Split, InformationGainStats, Predict) = { + featuresForNode: Option[Array[Int]], + node: Node): (Split, InformationGainStats, Predict) = { - // calculate predict only once - var predict: Option[Predict] = None + // calculate predict and impurity if current node is top node + val level = Node.indexToLevel(node.id) + var predictWithImpurity: Option[(Predict, Double)] = if (level == 0) { + None + } else { + Some((node.predict, node.impurity)) + } // For each (feature, split), calculate the gain, and select the best (feature, split). val (bestSplit, bestSplitStats) = @@ -708,9 +744,10 @@ object DecisionTree extends Serializable with Logging { val leftChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, splitIdx) val rightChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, numSplits) rightChildStats.subtract(leftChildStats) - predict = Some(predict.getOrElse(calculatePredict(leftChildStats, rightChildStats))) + predictWithImpurity = Some(predictWithImpurity.getOrElse( + calculatePredictImpurity(leftChildStats, rightChildStats))) val gainStats = calculateGainForSplit(leftChildStats, - rightChildStats, binAggregates.metadata) + rightChildStats, binAggregates.metadata, predictWithImpurity.get._2) (splitIdx, gainStats) }.maxBy(_._2.gain) (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats) @@ -722,9 +759,10 @@ object DecisionTree extends Serializable with Logging { Range(0, numSplits).map { splitIndex => val leftChildStats = binAggregates.getImpurityCalculator(leftChildOffset, splitIndex) val rightChildStats = binAggregates.getImpurityCalculator(rightChildOffset, splitIndex) - predict = Some(predict.getOrElse(calculatePredict(leftChildStats, rightChildStats))) + predictWithImpurity = Some(predictWithImpurity.getOrElse( + calculatePredictImpurity(leftChildStats, rightChildStats))) val gainStats = calculateGainForSplit(leftChildStats, - rightChildStats, binAggregates.metadata) + rightChildStats, binAggregates.metadata, predictWithImpurity.get._2) (splitIndex, gainStats) }.maxBy(_._2.gain) (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats) @@ -794,9 +832,10 @@ object DecisionTree extends Serializable with Logging { val rightChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, lastCategory) rightChildStats.subtract(leftChildStats) - predict = Some(predict.getOrElse(calculatePredict(leftChildStats, rightChildStats))) + predictWithImpurity = Some(predictWithImpurity.getOrElse( + calculatePredictImpurity(leftChildStats, rightChildStats))) 
val gainStats = calculateGainForSplit(leftChildStats, - rightChildStats, binAggregates.metadata) + rightChildStats, binAggregates.metadata, predictWithImpurity.get._2) (splitIndex, gainStats) }.maxBy(_._2.gain) val categoriesForSplit = @@ -807,9 +846,7 @@ object DecisionTree extends Serializable with Logging { } }.maxBy(_._2.gain) - assert(predict.isDefined, "must calculate predict for each node") - - (bestSplit, bestSplitStats, predict.get) + (bestSplit, bestSplitStats, predictWithImpurity.get._1) } /** @@ -874,32 +911,39 @@ object DecisionTree extends Serializable with Logging { // Iterate over all features. var featureIndex = 0 while (featureIndex < numFeatures) { - val numSplits = metadata.numSplits(featureIndex) - val numBins = metadata.numBins(featureIndex) if (metadata.isContinuous(featureIndex)) { - val numSamples = sampledInput.length + val featureSamples = sampledInput.map(lp => lp.features(featureIndex)) + val featureSplits = findSplitsForContinuousFeature(featureSamples, + metadata, featureIndex) + + val numSplits = featureSplits.length + val numBins = numSplits + 1 + logDebug(s"featureIndex = $featureIndex, numSplits = $numSplits") splits(featureIndex) = new Array[Split](numSplits) bins(featureIndex) = new Array[Bin](numBins) - val featureSamples = sampledInput.map(lp => lp.features(featureIndex)).sorted - val stride: Double = numSamples.toDouble / metadata.numBins(featureIndex) - logDebug("stride = " + stride) - for (splitIndex <- 0 until numSplits) { - val sampleIndex = splitIndex * stride.toInt - // Set threshold halfway in between 2 samples. - val threshold = (featureSamples(sampleIndex) + featureSamples(sampleIndex + 1)) / 2.0 + + var splitIndex = 0 + while (splitIndex < numSplits) { + val threshold = featureSplits(splitIndex) splits(featureIndex)(splitIndex) = new Split(featureIndex, threshold, Continuous, List()) + splitIndex += 1 } bins(featureIndex)(0) = new Bin(new DummyLowSplit(featureIndex, Continuous), splits(featureIndex)(0), Continuous, Double.MinValue) - for (splitIndex <- 1 until numSplits) { + + splitIndex = 1 + while (splitIndex < numSplits) { bins(featureIndex)(splitIndex) = new Bin(splits(featureIndex)(splitIndex - 1), splits(featureIndex)(splitIndex), Continuous, Double.MinValue) + splitIndex += 1 } bins(featureIndex)(numSplits) = new Bin(splits(featureIndex)(numSplits - 1), new DummyHighSplit(featureIndex, Continuous), Continuous, Double.MinValue) } else { + val numSplits = metadata.numSplits(featureIndex) + val numBins = metadata.numBins(featureIndex) // Categorical feature val featureArity = metadata.featureArity(featureIndex) if (metadata.isUnordered(featureIndex)) { @@ -976,4 +1020,77 @@ object DecisionTree extends Serializable with Logging { categories } + /** + * Find splits for a continuous feature + * NOTE: Returned number of splits is set based on `featureSamples` and + * could be different from the specified `numSplits`. + * The `numSplits` attribute in the `DecisionTreeMetadata` class will be set accordingly. 
+ * @param featureSamples feature values of each sample + * @param metadata decision tree metadata + * NOTE: `metadata.numBins` will be changed accordingly + * if there are not enough splits to be found + * @param featureIndex feature index to find splits + * @return array of splits + */ + private[tree] def findSplitsForContinuousFeature( + featureSamples: Array[Double], + metadata: DecisionTreeMetadata, + featureIndex: Int): Array[Double] = { + require(metadata.isContinuous(featureIndex), + "findSplitsForContinuousFeature can only be used to find splits for a continuous feature.") + + val splits = { + val numSplits = metadata.numSplits(featureIndex) + + // get the count for each distinct value + val valueCountMap = featureSamples.foldLeft(Map.empty[Double, Int]) { (m, x) => + m + ((x, m.getOrElse(x, 0) + 1)) + } + // sort distinct values + val valueCounts = valueCountMap.toSeq.sortBy(_._1).toArray + + // if there are not enough possible splits, or just enough, return all distinct values as splits + val possibleSplits = valueCounts.length + if (possibleSplits <= numSplits) { + valueCounts.map(_._1) + } else { + // stride between splits + val stride: Double = featureSamples.length.toDouble / (numSplits + 1) + logDebug("stride = " + stride) + + // iterate over `valueCounts` to find splits + val splits = new ArrayBuffer[Double] + var index = 1 + // currentCount: sum of counts of values that have been visited + var currentCount = valueCounts(0)._2 + // targetCount: target value for `currentCount`. + // If `currentCount` is the closest value to `targetCount`, + // then the current value is a split threshold. + // After finding a split threshold, `targetCount` is increased by stride. + var targetCount = stride + while (index < valueCounts.length) { + val previousCount = currentCount + currentCount += valueCounts(index)._2 + val previousGap = math.abs(previousCount - targetCount) + val currentGap = math.abs(currentCount - targetCount) + // If adding the count of the current value to currentCount + // makes the gap between currentCount and targetCount smaller, + // the previous value is a split threshold. + if (previousGap < currentGap) { + splits.append(valueCounts(index - 1)._1) + targetCount += stride + } + index += 1 + } + + splits.toArray + } + } + + assert(splits.length > 0) + // set the number of splits accordingly + metadata.setNumSplits(featureIndex, splits.length) + + splits + } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala index fa7a26f17c3ca..ebbd8e0257209 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala @@ -176,6 +176,8 @@ private class RandomForest ( timer.stop("findBestSplits") } + baggedInput.unpersist() + timer.stop("total") logInfo("Internal timing for DecisionTree:") diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala index 55f422dff0d71..ce8825cc03229 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala @@ -64,12 +64,6 @@ private[tree] class DTStatsAggregator( numBins.scanLeft(0)((total, nBins) => total + statsSize * nBins) } - /** - * Indicator for each feature of whether that feature is an unordered feature. - * TODO: Is Array[Boolean] any faster?
- */ - def isUnordered(featureIndex: Int): Boolean = metadata.isUnordered(featureIndex) - /** * Total number of elements stored in this aggregator */ @@ -128,21 +122,13 @@ private[tree] class DTStatsAggregator( * Pre-compute feature offset for use with [[featureUpdate]]. * For ordered features only. */ - def getFeatureOffset(featureIndex: Int): Int = { - require(!isUnordered(featureIndex), - s"DTStatsAggregator.getFeatureOffset is for ordered features only, but was called" + - s" for unordered feature $featureIndex.") - featureOffsets(featureIndex) - } + def getFeatureOffset(featureIndex: Int): Int = featureOffsets(featureIndex) /** * Pre-compute feature offset for use with [[featureUpdate]]. * For unordered features only. */ def getLeftRightFeatureOffsets(featureIndex: Int): (Int, Int) = { - require(isUnordered(featureIndex), - s"DTStatsAggregator.getLeftRightFeatureOffsets is for unordered features only," + - s" but was called for ordered feature $featureIndex.") val baseOffset = featureOffsets(featureIndex) (baseOffset, baseOffset + (numBins(featureIndex) >> 1) * statsSize) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala index 212dce25236e0..5bc0f2635c6b1 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala @@ -19,6 +19,7 @@ package org.apache.spark.mllib.tree.impl import scala.collection.mutable +import org.apache.spark.Logging import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.configuration.Algo._ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ @@ -75,6 +76,17 @@ private[tree] class DecisionTreeMetadata( numBins(featureIndex) - 1 } + + /** + * Set number of splits for a continuous feature. + * For a continuous feature, number of bins is number of splits plus 1. + */ + def setNumSplits(featureIndex: Int, numSplits: Int) { + require(isContinuous(featureIndex), + s"Only number of bin for a continuous feature can be set.") + numBins(featureIndex) = numSplits + 1 + } + /** * Indicates if feature subsampling is being used. */ @@ -82,7 +94,7 @@ private[tree] class DecisionTreeMetadata( } -private[tree] object DecisionTreeMetadata { +private[tree] object DecisionTreeMetadata extends Logging { /** * Construct a [[DecisionTreeMetadata]] instance for this dataset and parameters. @@ -103,6 +115,10 @@ private[tree] object DecisionTreeMetadata { } val maxPossibleBins = math.min(strategy.maxBins, numExamples).toInt + if (maxPossibleBins < strategy.maxBins) { + logWarning(s"DecisionTree reducing maxBins from ${strategy.maxBins} to $maxPossibleBins" + + s" (= number of training instances)") + } // We check the number of bins here against maxPossibleBins. 
// This needs to be checked here instead of in Strategy since maxPossibleBins can be modified diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala index a89e71e115806..9a50ecb550c38 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala @@ -26,13 +26,17 @@ import org.apache.spark.annotation.DeveloperApi * @param impurity current node impurity * @param leftImpurity left node impurity * @param rightImpurity right node impurity + * @param leftPredict left node predict + * @param rightPredict right node predict */ @DeveloperApi class InformationGainStats( val gain: Double, val impurity: Double, val leftImpurity: Double, - val rightImpurity: Double) extends Serializable { + val rightImpurity: Double, + val leftPredict: Predict, + val rightPredict: Predict) extends Serializable { override def toString = { "gain = %f, impurity = %f, left impurity = %f, right impurity = %f" @@ -58,5 +62,6 @@ private[tree] object InformationGainStats { * denote that current split doesn't satisfies minimum info gain or * minimum number of instances per node. */ - val invalidInformationGainStats = new InformationGainStats(Double.MinValue, -1.0, -1.0, -1.0) + val invalidInformationGainStats = new InformationGainStats(Double.MinValue, -1.0, -1.0, -1.0, + new Predict(0.0, 0.0), new Predict(0.0, 0.0)) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala index 56c3e25d9285f..2179da8dbe03e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala @@ -32,7 +32,8 @@ import org.apache.spark.mllib.linalg.Vector * * @param id integer node id, from 1 * @param predict predicted value at the node - * @param isLeaf whether the leaf is a node + * @param impurity current node impurity + * @param isLeaf whether the node is a leaf * @param split split to calculate left and right nodes * @param leftNode left child * @param rightNode right child @@ -41,7 +42,8 @@ import org.apache.spark.mllib.linalg.Vector @DeveloperApi class Node ( val id: Int, - var predict: Double, + var predict: Predict, + var impurity: Double, var isLeaf: Boolean, var split: Option[Split], var leftNode: Option[Node], @@ -49,7 +51,7 @@ class Node ( var stats: Option[InformationGainStats]) extends Serializable with Logging { override def toString = "id = " + id + ", isLeaf = " + isLeaf + ", predict = " + predict + ", " + - "split = " + split + ", stats = " + stats + "impurity = " + impurity + "split = " + split + ", stats = " + stats /** * build the left node and right nodes if not leaf @@ -62,6 +64,7 @@ class Node ( logDebug("id = " + id + ", split = " + split) logDebug("stats = " + stats) logDebug("predict = " + predict) + logDebug("impurity = " + impurity) if (!isLeaf) { leftNode = Some(nodes(Node.leftChildIndex(id))) rightNode = Some(nodes(Node.rightChildIndex(id))) @@ -77,7 +80,7 @@ class Node ( */ def predict(features: Vector) : Double = { if (isLeaf) { - predict + predict.predict } else{ if (split.get.featureType == Continuous) { if (features(split.get.feature) <= split.get.threshold) { @@ -109,7 +112,7 @@ class Node ( } else { Some(rightNode.get.deepCopy()) } - new Node(id, predict, isLeaf, split, leftNodeCopy, rightNodeCopy, 
stats) + new Node(id, predict, impurity, isLeaf, split, leftNodeCopy, rightNodeCopy, stats) } /** @@ -154,7 +157,7 @@ class Node ( } val prefix: String = " " * indentFactor if (isLeaf) { - prefix + s"Predict: $predict\n" + prefix + s"Predict: ${predict.predict}\n" } else { prefix + s"If ${splitToString(split.get, left=true)}\n" + leftNode.get.subtreeToString(indentFactor + 1) + @@ -170,7 +173,27 @@ private[tree] object Node { /** * Return a node with the given node id (but nothing else set). */ - def emptyNode(nodeIndex: Int): Node = new Node(nodeIndex, 0, false, None, None, None, None) + def emptyNode(nodeIndex: Int): Node = new Node(nodeIndex, new Predict(Double.MinValue), -1.0, + false, None, None, None, None) + + /** + * Construct a node with nodeIndex, predict, impurity and isLeaf parameters. + * This is used in `DecisionTree.findBestSplits` to construct child nodes + * after finding the best splits for parent nodes. + * Other fields are set at next level. + * @param nodeIndex integer node id, from 1 + * @param predict predicted value at the node + * @param impurity current node impurity + * @param isLeaf whether the node is a leaf + * @return new node instance + */ + def apply( + nodeIndex: Int, + predict: Predict, + impurity: Double, + isLeaf: Boolean): Node = { + new Node(nodeIndex, predict, impurity, isLeaf, None, None, None, None) + } /** * Return the index of the left child of this node. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala index d8476b5cd7bc7..004838ee5ba0e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala @@ -17,12 +17,15 @@ package org.apache.spark.mllib.tree.model +import org.apache.spark.annotation.DeveloperApi + /** * Predicted value for a node * @param predict predicted value * @param prob probability of the label (classification only) */ -private[tree] class Predict( +@DeveloperApi +class Predict( val predict: Double, val prob: Double = 0.0) extends Serializable { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/RandomForestModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/RandomForestModel.scala index 4d66d6d81caa5..6a22e2abe59bd 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/RandomForestModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/RandomForestModel.scala @@ -82,9 +82,9 @@ class RandomForestModel(val trees: Array[DecisionTreeModel], val algo: Algo) ext */ override def toString: String = algo match { case Classification => - s"RandomForestModel classifier with $numTrees trees" + s"RandomForestModel classifier with $numTrees trees and $totalNumNodes total nodes" case Regression => - s"RandomForestModel regressor with $numTrees trees" + s"RandomForestModel regressor with $numTrees trees and $totalNumNodes total nodes" case _ => throw new IllegalArgumentException( s"RandomForestModel given unknown algo parameter: $algo.") } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/NormalizerSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/NormalizerSuite.scala index fb76dccfdf79e..2bf9d9816ae45 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/feature/NormalizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/NormalizerSuite.scala @@ -19,6 +19,8 @@ package org.apache.spark.mllib.feature import org.scalatest.FunSuite 
+import breeze.linalg.{norm => brzNorm} + import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vectors} import org.apache.spark.mllib.util.LocalSparkContext import org.apache.spark.mllib.util.TestingUtils._ @@ -50,10 +52,10 @@ class NormalizerSuite extends FunSuite with LocalSparkContext { assert((data1, data1RDD.collect()).zipped.forall((v1, v2) => v1 ~== v2 absTol 1E-5)) - assert(data1(0).toBreeze.norm(1) ~== 1.0 absTol 1E-5) - assert(data1(2).toBreeze.norm(1) ~== 1.0 absTol 1E-5) - assert(data1(3).toBreeze.norm(1) ~== 1.0 absTol 1E-5) - assert(data1(4).toBreeze.norm(1) ~== 1.0 absTol 1E-5) + assert(brzNorm(data1(0).toBreeze, 1) ~== 1.0 absTol 1E-5) + assert(brzNorm(data1(2).toBreeze, 1) ~== 1.0 absTol 1E-5) + assert(brzNorm(data1(3).toBreeze, 1) ~== 1.0 absTol 1E-5) + assert(brzNorm(data1(4).toBreeze, 1) ~== 1.0 absTol 1E-5) assert(data1(0) ~== Vectors.sparse(3, Seq((0, -0.465116279), (1, 0.53488372))) absTol 1E-5) assert(data1(1) ~== Vectors.dense(0.0, 0.0, 0.0) absTol 1E-5) @@ -77,10 +79,10 @@ class NormalizerSuite extends FunSuite with LocalSparkContext { assert((data2, data2RDD.collect()).zipped.forall((v1, v2) => v1 ~== v2 absTol 1E-5)) - assert(data2(0).toBreeze.norm(2) ~== 1.0 absTol 1E-5) - assert(data2(2).toBreeze.norm(2) ~== 1.0 absTol 1E-5) - assert(data2(3).toBreeze.norm(2) ~== 1.0 absTol 1E-5) - assert(data2(4).toBreeze.norm(2) ~== 1.0 absTol 1E-5) + assert(brzNorm(data2(0).toBreeze, 2) ~== 1.0 absTol 1E-5) + assert(brzNorm(data2(2).toBreeze, 2) ~== 1.0 absTol 1E-5) + assert(brzNorm(data2(3).toBreeze, 2) ~== 1.0 absTol 1E-5) + assert(brzNorm(data2(4).toBreeze, 2) ~== 1.0 absTol 1E-5) assert(data2(0) ~== Vectors.sparse(3, Seq((0, -0.65617871), (1, 0.75460552))) absTol 1E-5) assert(data2(1) ~== Vectors.dense(0.0, 0.0, 0.0) absTol 1E-5) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala index a48ed71a1c5fc..8fc5e111bbc17 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala @@ -26,7 +26,7 @@ import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.configuration.Algo._ import org.apache.spark.mllib.tree.configuration.FeatureType._ -import org.apache.spark.mllib.tree.configuration.Strategy +import org.apache.spark.mllib.tree.configuration.{QuantileStrategy, Strategy} import org.apache.spark.mllib.tree.impl.{BaggedPoint, DecisionTreeMetadata, TreePoint} import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Variance} import org.apache.spark.mllib.tree.model.{InformationGainStats, DecisionTreeModel, Node} @@ -102,6 +102,72 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(List(3.0, 2.0, 0.0).toSeq === l.toSeq) } + test("find splits for a continuous feature") { + // find splits for normal case + { + val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0, + Map(), Set(), + Array(6), Gini, QuantileStrategy.Sort, + 0, 0, 0.0, 0, 0 + ) + val featureSamples = Array.fill(200000)(math.random) + val splits = DecisionTree.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0) + assert(splits.length === 5) + assert(fakeMetadata.numSplits(0) === 5) + assert(fakeMetadata.numBins(0) === 6) + // check returned splits are distinct + assert(splits.distinct.length === splits.length) + } + + // find splits should not return identical splits + // when there 
are not enough split candidates, reduce the number of splits in metadata + { + val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0, + Map(), Set(), + Array(5), Gini, QuantileStrategy.Sort, + 0, 0, 0.0, 0, 0 + ) + val featureSamples = Array(1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 3).map(_.toDouble) + val splits = DecisionTree.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0) + assert(splits.length === 3) + assert(fakeMetadata.numSplits(0) === 3) + assert(fakeMetadata.numBins(0) === 4) + // check returned splits are distinct + assert(splits.distinct.length === splits.length) + } + + // find splits when most samples close to the minimum + { + val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0, + Map(), Set(), + Array(3), Gini, QuantileStrategy.Sort, + 0, 0, 0.0, 0, 0 + ) + val featureSamples = Array(2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 4, 5).map(_.toDouble) + val splits = DecisionTree.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0) + assert(splits.length === 2) + assert(fakeMetadata.numSplits(0) === 2) + assert(fakeMetadata.numBins(0) === 3) + assert(splits(0) === 2.0) + assert(splits(1) === 3.0) + } + + // find splits when most samples close to the maximum + { + val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0, + Map(), Set(), + Array(3), Gini, QuantileStrategy.Sort, + 0, 0, 0.0, 0, 0 + ) + val featureSamples = Array(0, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2).map(_.toDouble) + val splits = DecisionTree.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0) + assert(splits.length === 1) + assert(fakeMetadata.numSplits(0) === 1) + assert(fakeMetadata.numBins(0) === 2) + assert(splits(0) === 1.0) + } + } + test("Multiclass classification with unordered categorical features:" + " split and bin calculations") { val arr = DecisionTreeSuite.generateCategoricalDataPoints() @@ -253,7 +319,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val stats = rootNode.stats.get assert(stats.gain > 0) - assert(rootNode.predict === 1) + assert(rootNode.predict.predict === 1) assert(stats.impurity > 0.2) } @@ -282,7 +348,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val stats = rootNode.stats.get assert(stats.gain > 0) - assert(rootNode.predict === 0.6) + assert(rootNode.predict.predict === 0.6) assert(stats.impurity > 0.2) } @@ -352,7 +418,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(stats.gain === 0) assert(stats.leftImpurity === 0) assert(stats.rightImpurity === 0) - assert(rootNode.predict === 1) + assert(rootNode.predict.predict === 1) } test("Binary classification stump with fixed label 0 for Entropy") { @@ -377,7 +443,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(stats.gain === 0) assert(stats.leftImpurity === 0) assert(stats.rightImpurity === 0) - assert(rootNode.predict === 0) + assert(rootNode.predict.predict === 0) } test("Binary classification stump with fixed label 1 for Entropy") { @@ -402,7 +468,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(stats.gain === 0) assert(stats.leftImpurity === 0) assert(stats.rightImpurity === 0) - assert(rootNode.predict === 1) + assert(rootNode.predict.predict === 1) } test("Second level node building with vs. 
without groups") { @@ -471,7 +537,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(stats1.impurity === stats2.impurity) assert(stats1.leftImpurity === stats2.leftImpurity) assert(stats1.rightImpurity === stats2.rightImpurity) - assert(children1(i).predict === children2(i).predict) + assert(children1(i).predict.predict === children2(i).predict.predict) } } @@ -646,7 +712,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val model = DecisionTree.train(rdd, strategy) assert(model.topNode.isLeaf) - assert(model.topNode.predict == 0.0) + assert(model.topNode.predict.predict == 0.0) val predicts = rdd.map(p => model.predict(p.features)).collect() predicts.foreach { predict => assert(predict == 0.0) @@ -693,7 +759,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val model = DecisionTree.train(input, strategy) assert(model.topNode.isLeaf) - assert(model.topNode.predict == 0.0) + assert(model.topNode.predict.predict == 0.0) val predicts = input.map(p => model.predict(p.features)).collect() predicts.foreach { predict => assert(predict == 0.0) @@ -705,6 +771,92 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val gain = rootNode.stats.get assert(gain == InformationGainStats.invalidInformationGainStats) } + + test("Avoid aggregation on the last level") { + val arr = new Array[LabeledPoint](4) + arr(0) = new LabeledPoint(0.0, Vectors.dense(1.0, 0.0, 0.0)) + arr(1) = new LabeledPoint(1.0, Vectors.dense(0.0, 1.0, 1.0)) + arr(2) = new LabeledPoint(0.0, Vectors.dense(2.0, 0.0, 0.0)) + arr(3) = new LabeledPoint(1.0, Vectors.dense(0.0, 2.0, 1.0)) + val input = sc.parallelize(arr) + + val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 1, + numClassesForClassification = 2, categoricalFeaturesInfo = Map(0 -> 3)) + val metadata = DecisionTreeMetadata.buildMetadata(input, strategy) + val (splits, bins) = DecisionTree.findSplitsBins(input, metadata) + + val treeInput = TreePoint.convertToTreeRDD(input, bins, metadata) + val baggedInput = BaggedPoint.convertToBaggedRDDWithoutSampling(treeInput) + + val topNode = Node.emptyNode(nodeIndex = 1) + assert(topNode.predict.predict === Double.MinValue) + assert(topNode.impurity === -1.0) + assert(topNode.isLeaf === false) + + val nodesForGroup = Map((0, Array(topNode))) + val treeToNodeToIndexInfo = Map((0, Map( + (topNode.id, new RandomForest.NodeIndexInfo(0, None)) + ))) + val nodeQueue = new mutable.Queue[(Int, Node)]() + DecisionTree.findBestSplits(baggedInput, metadata, Array(topNode), + nodesForGroup, treeToNodeToIndexInfo, splits, bins, nodeQueue) + + // don't enqueue leaf nodes into node queue + assert(nodeQueue.isEmpty) + + // set impurity and predict for topNode + assert(topNode.predict.predict !== Double.MinValue) + assert(topNode.impurity !== -1.0) + + // set impurity and predict for child nodes + assert(topNode.leftNode.get.predict.predict === 0.0) + assert(topNode.rightNode.get.predict.predict === 1.0) + assert(topNode.leftNode.get.impurity === 0.0) + assert(topNode.rightNode.get.impurity === 0.0) + } + + test("Avoid aggregation if impurity is 0.0") { + val arr = new Array[LabeledPoint](4) + arr(0) = new LabeledPoint(0.0, Vectors.dense(1.0, 0.0, 0.0)) + arr(1) = new LabeledPoint(1.0, Vectors.dense(0.0, 1.0, 1.0)) + arr(2) = new LabeledPoint(0.0, Vectors.dense(2.0, 0.0, 0.0)) + arr(3) = new LabeledPoint(1.0, Vectors.dense(0.0, 2.0, 1.0)) + val input = sc.parallelize(arr) + + val strategy = new Strategy(algo = Classification, impurity = Gini, 
maxDepth = 5, + numClassesForClassification = 2, categoricalFeaturesInfo = Map(0 -> 3)) + val metadata = DecisionTreeMetadata.buildMetadata(input, strategy) + val (splits, bins) = DecisionTree.findSplitsBins(input, metadata) + + val treeInput = TreePoint.convertToTreeRDD(input, bins, metadata) + val baggedInput = BaggedPoint.convertToBaggedRDDWithoutSampling(treeInput) + + val topNode = Node.emptyNode(nodeIndex = 1) + assert(topNode.predict.predict === Double.MinValue) + assert(topNode.impurity === -1.0) + assert(topNode.isLeaf === false) + + val nodesForGroup = Map((0, Array(topNode))) + val treeToNodeToIndexInfo = Map((0, Map( + (topNode.id, new RandomForest.NodeIndexInfo(0, None)) + ))) + val nodeQueue = new mutable.Queue[(Int, Node)]() + DecisionTree.findBestSplits(baggedInput, metadata, Array(topNode), + nodesForGroup, treeToNodeToIndexInfo, splits, bins, nodeQueue) + + // don't enqueue a node into node queue if its impurity is 0.0 + assert(nodeQueue.isEmpty) + + // set impurity and predict for topNode + assert(topNode.predict.predict !== Double.MinValue) + assert(topNode.impurity !== -1.0) + + // set impurity and predict for child nodes + assert(topNode.leftNode.get.predict.predict === 0.0) + assert(topNode.rightNode.get.predict.predict === 1.0) + assert(topNode.leftNode.get.impurity === 0.0) + assert(topNode.rightNode.get.impurity === 0.0) + } } object DecisionTreeSuite { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala index 20d372dc1d3ca..6b13765b98f41 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala @@ -93,8 +93,9 @@ class RandomForestSuite extends FunSuite with LocalSparkContext { val categoricalFeaturesInfo = Map.empty[Int, Int] val numTrees = 1 - val strategy = new Strategy(algo = Regression, impurity = Variance, maxDepth = 2, - numClassesForClassification = 2, categoricalFeaturesInfo = categoricalFeaturesInfo) + val strategy = new Strategy(algo = Regression, impurity = Variance, + maxDepth = 2, maxBins = 10, numClassesForClassification = 2, + categoricalFeaturesInfo = categoricalFeaturesInfo) val rf = RandomForest.trainRegressor(rdd, strategy, numTrees = numTrees, featureSubsetStrategy = "auto", seed = 123) @@ -173,6 +174,22 @@ class RandomForestSuite extends FunSuite with LocalSparkContext { checkFeatureSubsetStrategy(numTrees = 2, "onethird", (numFeatures / 3.0).ceil.toInt) } + test("alternating categorical and continuous features with multiclass labels to test indexing") { + val arr = new Array[LabeledPoint](4) + arr(0) = new LabeledPoint(0.0, Vectors.dense(1.0, 0.0, 0.0, 3.0, 1.0)) + arr(1) = new LabeledPoint(1.0, Vectors.dense(0.0, 1.0, 1.0, 1.0, 2.0)) + arr(2) = new LabeledPoint(0.0, Vectors.dense(2.0, 0.0, 0.0, 6.0, 3.0)) + arr(3) = new LabeledPoint(2.0, Vectors.dense(0.0, 2.0, 1.0, 3.0, 2.0)) + val categoricalFeaturesInfo = Map(0 -> 3, 2 -> 2, 4 -> 4) + val input = sc.parallelize(arr) + + val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 5, + numClassesForClassification = 3, categoricalFeaturesInfo = categoricalFeaturesInfo) + val model = RandomForest.trainClassifier(input, strategy, numTrees = 2, + featureSubsetStrategy = "sqrt", seed = 12345) + RandomForestSuite.validateClassifier(model, arr, 1.0) + } + } object RandomForestSuite { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala 
b/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala index 8ef2bb1bf6a78..0dbe766b4d917 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala @@ -67,8 +67,7 @@ class MLUtilsSuite extends FunSuite with LocalSparkContext { |0 |0 2:4.0 4:5.0 6:6.0 """.stripMargin - val tempDir = Files.createTempDir() - tempDir.deleteOnExit() + val tempDir = Utils.createTempDir() val file = new File(tempDir.getPath, "part-00000") Files.write(lines, file, Charsets.US_ASCII) val path = tempDir.toURI.toString @@ -100,7 +99,7 @@ class MLUtilsSuite extends FunSuite with LocalSparkContext { LabeledPoint(1.1, Vectors.sparse(3, Seq((0, 1.23), (2, 4.56)))), LabeledPoint(0.0, Vectors.dense(1.01, 2.02, 3.03)) ), 2) - val tempDir = Files.createTempDir() + val tempDir = Utils.createTempDir() val outputDir = new File(tempDir, "output") MLUtils.saveAsLibSVMFile(examples, outputDir.toURI.toString) val lines = outputDir.listFiles() @@ -166,7 +165,7 @@ class MLUtilsSuite extends FunSuite with LocalSparkContext { Vectors.sparse(2, Array(1), Array(-1.0)), Vectors.dense(0.0, 1.0) ), 2) - val tempDir = Files.createTempDir() + val tempDir = Utils.createTempDir() val outputDir = new File(tempDir, "vectors") val path = outputDir.toURI.toString vectors.saveAsTextFile(path) @@ -181,7 +180,7 @@ class MLUtilsSuite extends FunSuite with LocalSparkContext { LabeledPoint(0.0, Vectors.sparse(2, Array(1), Array(-1.0))), LabeledPoint(1.0, Vectors.dense(0.0, 1.0)) ), 2) - val tempDir = Files.createTempDir() + val tempDir = Utils.createTempDir() val outputDir = new File(tempDir, "points") val path = outputDir.toURI.toString points.saveAsTextFile(path) diff --git a/pom.xml b/pom.xml index 34d77e330348e..cf975f2d723bd 100644 --- a/pom.xml +++ b/pom.xml @@ -119,16 +119,16 @@ 0.18.1 shaded-protobuf org.spark-project.akka - 2.2.3-shaded-protobuf + 2.3.4-spark 1.7.5 1.2.17 - 2.3.0 + 1.0.4 2.4.1 ${hadoop.version} 1.4.0 0.98.5-hadoop2 3.4.5 - 0.12.0 + 0.12.0-protobuf-2.5 1.4.3 1.2.3 8.1.14.v20131031 @@ -224,6 +224,18 @@ false + + + spark-staging + Spark Staging Repository + https://oss.sonatype.org/content/repositories/orgspark-project-1085 + + true + + + false + + @@ -1144,7 +1156,7 @@ - + hadoop-2.4 2.4.0 diff --git a/project/MimaBuild.scala b/project/MimaBuild.scala index 39f8ba4745737..d919b18e09855 100644 --- a/project/MimaBuild.scala +++ b/project/MimaBuild.scala @@ -32,7 +32,7 @@ object MimaBuild { ProblemFilters.exclude[MissingMethodProblem](fullName), // Sometimes excluded methods have default arguments and // they are translated into public methods/fields($default$) in generated - // bytecode. It is not possible to exhustively list everything. + // bytecode. It is not possible to exhaustively list everything. // But this should be okay.
ProblemFilters.exclude[MissingMethodProblem](fullName+"$default$2"), ProblemFilters.exclude[MissingMethodProblem](fullName+"$default$1"), diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index d499302124461..c58666af84f24 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -50,7 +50,22 @@ object MimaExcludes { "org.apache.spark.mllib.stat.MultivariateStatisticalSummary.normL2"), // MapStatus should be private[spark] ProblemFilters.exclude[IncompatibleTemplateDefProblem]( - "org.apache.spark.scheduler.MapStatus") + "org.apache.spark.scheduler.MapStatus"), + // TaskContext was promoted to Abstract class + ProblemFilters.exclude[AbstractClassProblem]( + "org.apache.spark.TaskContext") + ) ++ Seq( + // Adding new methods to the JavaRDDLike trait: + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.api.java.JavaRDDLike.takeAsync"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.api.java.JavaRDDLike.foreachPartitionAsync"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.api.java.JavaRDDLike.countAsync"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.api.java.JavaRDDLike.foreachAsync"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.api.java.JavaRDDLike.collectAsync") ) case v if v.startsWith("1.1") => diff --git a/project/plugins.sbt b/project/plugins.sbt index 8096c61414660..678f5ed1ba610 100644 --- a/project/plugins.sbt +++ b/project/plugins.sbt @@ -17,7 +17,7 @@ addSbtPlugin("com.github.mpeltonen" % "sbt-idea" % "1.6.0") addSbtPlugin("net.virtual-void" % "sbt-dependency-graph" % "0.7.4") -addSbtPlugin("org.scalastyle" %% "scalastyle-sbt-plugin" % "0.4.0") +addSbtPlugin("org.scalastyle" %% "scalastyle-sbt-plugin" % "0.5.0") addSbtPlugin("com.typesafe" % "sbt-mima-plugin" % "0.1.6") diff --git a/project/spark-style/src/main/scala/org/apache/spark/scalastyle/SparkSpaceAfterCommentStartChecker.scala b/project/spark-style/src/main/scala/org/apache/spark/scalastyle/SparkSpaceAfterCommentStartChecker.scala deleted file mode 100644 index 80d3faa3fe749..0000000000000 --- a/project/spark-style/src/main/scala/org/apache/spark/scalastyle/SparkSpaceAfterCommentStartChecker.scala +++ /dev/null @@ -1,58 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - - -package org.apache.spark.scalastyle - -import java.util.regex.Pattern - -import org.scalastyle.{PositionError, ScalariformChecker, ScalastyleError} -import scalariform.lexer.{MultiLineComment, ScalaDocComment, SingleLineComment, Token} -import scalariform.parser.CompilationUnit - -class SparkSpaceAfterCommentStartChecker extends ScalariformChecker { - val errorKey: String = "insert.a.single.space.after.comment.start.and.before.end" - - private def multiLineCommentRegex(comment: Token) = - Pattern.compile( """/\*\S+.*""", Pattern.DOTALL).matcher(comment.text.trim).matches() || - Pattern.compile( """/\*.*\S\*/""", Pattern.DOTALL).matcher(comment.text.trim).matches() - - private def scalaDocPatternRegex(comment: Token) = - Pattern.compile( """/\*\*\S+.*""", Pattern.DOTALL).matcher(comment.text.trim).matches() || - Pattern.compile( """/\*\*.*\S\*/""", Pattern.DOTALL).matcher(comment.text.trim).matches() - - private def singleLineCommentRegex(comment: Token): Boolean = - comment.text.trim.matches( """//\S+.*""") && !comment.text.trim.matches( """///+""") - - override def verify(ast: CompilationUnit): List[ScalastyleError] = { - ast.tokens - .filter(hasComment) - .map { - _.associatedWhitespaceAndComments.comments.map { - case x: SingleLineComment if singleLineCommentRegex(x.token) => Some(x.token.offset) - case x: MultiLineComment if multiLineCommentRegex(x.token) => Some(x.token.offset) - case x: ScalaDocComment if scalaDocPatternRegex(x.token) => Some(x.token.offset) - case _ => None - }.flatten - }.flatten.map(PositionError(_)) - } - - - private def hasComment(x: Token) = - x.associatedWhitespaceAndComments != null && !x.associatedWhitespaceAndComments.comments.isEmpty - -} diff --git a/python/.gitignore b/python/.gitignore index 80b361ffbd51c..52128cf844a79 100644 --- a/python/.gitignore +++ b/python/.gitignore @@ -1,5 +1,5 @@ *.pyc -docs/ +docs/_build/ pyspark.egg-info build/ dist/ diff --git a/python/docs/conf.py b/python/docs/conf.py index c368cf81a003b..e58d97ae6a746 100644 --- a/python/docs/conf.py +++ b/python/docs/conf.py @@ -55,9 +55,9 @@ # built documents. # # The short X.Y version. -version = '1.1' +version = '1.2-SNAPSHOT' # The full version, including alpha/beta/rc tags. -release = '' +release = '1.2-SNAPSHOT' # The language for content autogenerated by Sphinx. Refer to documentation # for a list of supported languages. @@ -102,7 +102,7 @@ # The theme to use for HTML and HTML Help pages. See the documentation for # a list of builtin themes. -html_theme = 'default' +html_theme = 'nature' # Theme options are theme-specific and customize the look and feel of a theme # further. For a list of options available for each theme, see the @@ -121,7 +121,7 @@ # The name of an image file (relative to this directory) to place at the top # of the sidebar. -#html_logo = None +html_logo = "../../docs/img/spark-logo-hd.png" # The name of an image file (within the static path) to use as favicon of the # docs. This file should be a Windows icon file (.ico) being 16x16 or 32x32 @@ -131,7 +131,7 @@ # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". -html_static_path = ['_static'] +#html_static_path = ['_static'] # Add any extra paths that contain custom files (such as robots.txt or # .htaccess) here, relative to this directory. 
These files are copied @@ -154,10 +154,10 @@ #html_additional_pages = {} # If false, no module index is generated. -#html_domain_indices = True +html_domain_indices = False # If false, no index is generated. -#html_use_index = True +html_use_index = False # If true, the index is split into individual pages for each letter. #html_split_index = False diff --git a/python/docs/epytext.py b/python/docs/epytext.py index 61d731bff570d..19fefbfc057a4 100644 --- a/python/docs/epytext.py +++ b/python/docs/epytext.py @@ -5,7 +5,7 @@ (r"L{([\w.()]+)}", r":class:`\1`"), (r"[LC]{(\w+\.\w+)\(\)}", r":func:`\1`"), (r"C{([\w.()]+)}", r":class:`\1`"), - (r"[IBCM]{(.+)}", r"`\1`"), + (r"[IBCM]{([^}]+)}", r"`\1`"), ('pyspark.rdd.RDD', 'RDD'), ) diff --git a/python/docs/index.rst b/python/docs/index.rst index 25b3f9bd93e63..703bef644de28 100644 --- a/python/docs/index.rst +++ b/python/docs/index.rst @@ -3,7 +3,7 @@ You can adapt this file completely to your liking, but it should at least contain the root `toctree` directive. -Welcome to PySpark API reference! +Welcome to Spark Python API Docs! =================================== Contents: @@ -13,6 +13,7 @@ Contents: pyspark pyspark.sql + pyspark.streaming pyspark.mllib @@ -24,14 +25,12 @@ Core classes: Main entry point for Spark functionality. :class:`pyspark.RDD` - + A Resilient Distributed Dataset (RDD), the basic abstraction in Spark. Indices and tables ================== -* :ref:`genindex` -* :ref:`modindex` * :ref:`search` diff --git a/python/docs/make.bat b/python/docs/make.bat index adad44fd7536a..c011e82b4a35a 100644 --- a/python/docs/make.bat +++ b/python/docs/make.bat @@ -1,242 +1,6 @@ @ECHO OFF -REM Command file for Sphinx documentation +rem This is the entry point for running Sphinx documentation. To avoid polluting the +rem environment, it just launches a new cmd to do the real work. -if "%SPHINXBUILD%" == "" ( - set SPHINXBUILD=sphinx-build -) -set BUILDDIR=_build -set ALLSPHINXOPTS=-d %BUILDDIR%/doctrees %SPHINXOPTS% . -set I18NSPHINXOPTS=%SPHINXOPTS% . -if NOT "%PAPER%" == "" ( - set ALLSPHINXOPTS=-D latex_paper_size=%PAPER% %ALLSPHINXOPTS% - set I18NSPHINXOPTS=-D latex_paper_size=%PAPER% %I18NSPHINXOPTS% -) - -if "%1" == "" goto help - -if "%1" == "help" ( - :help - echo.Please use `make ^` where ^ is one of - echo. html to make standalone HTML files - echo. dirhtml to make HTML files named index.html in directories - echo. singlehtml to make a single large HTML file - echo. pickle to make pickle files - echo. json to make JSON files - echo. htmlhelp to make HTML files and a HTML help project - echo. qthelp to make HTML files and a qthelp project - echo. devhelp to make HTML files and a Devhelp project - echo. epub to make an epub - echo. latex to make LaTeX files, you can set PAPER=a4 or PAPER=letter - echo. text to make text files - echo. man to make manual pages - echo. texinfo to make Texinfo files - echo. gettext to make PO message catalogs - echo. changes to make an overview over all changed/added/deprecated items - echo. xml to make Docutils-native XML files - echo. pseudoxml to make pseudoxml-XML files for display purposes - echo. linkcheck to check all external links for integrity - echo. doctest to run all doctests embedded in the documentation if enabled - goto end -) - -if "%1" == "clean" ( - for /d %%i in (%BUILDDIR%\*) do rmdir /q /s %%i - del /q /s %BUILDDIR%\* - goto end -) - - -%SPHINXBUILD% 2> nul -if errorlevel 9009 ( - echo. - echo.The 'sphinx-build' command was not found. 
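The one-character change to python/docs/epytext.py above replaces the greedy (.+) with ([^}]+), so a line carrying several epytext inline markers is converted marker by marker instead of being swallowed by one match. A self-contained illustration of the difference (only the changed rule is shown):

import re

old_pat, repl = r"[IBCM]{(.+)}", r"`\1`"     # greedy: runs to the last closing brace
new_pat = r"[IBCM]{([^}]+)}"                 # stops at the first closing brace

line = "B{foo} and C{bar}"
print(re.sub(old_pat, repl, line))   # -> `foo} and C{bar`
print(re.sub(new_pat, repl, line))   # -> `foo` and `bar`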
Make sure you have Sphinx - echo.installed, then set the SPHINXBUILD environment variable to point - echo.to the full path of the 'sphinx-build' executable. Alternatively you - echo.may add the Sphinx directory to PATH. - echo. - echo.If you don't have Sphinx installed, grab it from - echo.http://sphinx-doc.org/ - exit /b 1 -) - -if "%1" == "html" ( - %SPHINXBUILD% -b html %ALLSPHINXOPTS% %BUILDDIR%/html - if errorlevel 1 exit /b 1 - echo. - echo.Build finished. The HTML pages are in %BUILDDIR%/html. - goto end -) - -if "%1" == "dirhtml" ( - %SPHINXBUILD% -b dirhtml %ALLSPHINXOPTS% %BUILDDIR%/dirhtml - if errorlevel 1 exit /b 1 - echo. - echo.Build finished. The HTML pages are in %BUILDDIR%/dirhtml. - goto end -) - -if "%1" == "singlehtml" ( - %SPHINXBUILD% -b singlehtml %ALLSPHINXOPTS% %BUILDDIR%/singlehtml - if errorlevel 1 exit /b 1 - echo. - echo.Build finished. The HTML pages are in %BUILDDIR%/singlehtml. - goto end -) - -if "%1" == "pickle" ( - %SPHINXBUILD% -b pickle %ALLSPHINXOPTS% %BUILDDIR%/pickle - if errorlevel 1 exit /b 1 - echo. - echo.Build finished; now you can process the pickle files. - goto end -) - -if "%1" == "json" ( - %SPHINXBUILD% -b json %ALLSPHINXOPTS% %BUILDDIR%/json - if errorlevel 1 exit /b 1 - echo. - echo.Build finished; now you can process the JSON files. - goto end -) - -if "%1" == "htmlhelp" ( - %SPHINXBUILD% -b htmlhelp %ALLSPHINXOPTS% %BUILDDIR%/htmlhelp - if errorlevel 1 exit /b 1 - echo. - echo.Build finished; now you can run HTML Help Workshop with the ^ -.hhp project file in %BUILDDIR%/htmlhelp. - goto end -) - -if "%1" == "qthelp" ( - %SPHINXBUILD% -b qthelp %ALLSPHINXOPTS% %BUILDDIR%/qthelp - if errorlevel 1 exit /b 1 - echo. - echo.Build finished; now you can run "qcollectiongenerator" with the ^ -.qhcp project file in %BUILDDIR%/qthelp, like this: - echo.^> qcollectiongenerator %BUILDDIR%\qthelp\pyspark.qhcp - echo.To view the help file: - echo.^> assistant -collectionFile %BUILDDIR%\qthelp\pyspark.ghc - goto end -) - -if "%1" == "devhelp" ( - %SPHINXBUILD% -b devhelp %ALLSPHINXOPTS% %BUILDDIR%/devhelp - if errorlevel 1 exit /b 1 - echo. - echo.Build finished. - goto end -) - -if "%1" == "epub" ( - %SPHINXBUILD% -b epub %ALLSPHINXOPTS% %BUILDDIR%/epub - if errorlevel 1 exit /b 1 - echo. - echo.Build finished. The epub file is in %BUILDDIR%/epub. - goto end -) - -if "%1" == "latex" ( - %SPHINXBUILD% -b latex %ALLSPHINXOPTS% %BUILDDIR%/latex - if errorlevel 1 exit /b 1 - echo. - echo.Build finished; the LaTeX files are in %BUILDDIR%/latex. - goto end -) - -if "%1" == "latexpdf" ( - %SPHINXBUILD% -b latex %ALLSPHINXOPTS% %BUILDDIR%/latex - cd %BUILDDIR%/latex - make all-pdf - cd %BUILDDIR%/.. - echo. - echo.Build finished; the PDF files are in %BUILDDIR%/latex. - goto end -) - -if "%1" == "latexpdfja" ( - %SPHINXBUILD% -b latex %ALLSPHINXOPTS% %BUILDDIR%/latex - cd %BUILDDIR%/latex - make all-pdf-ja - cd %BUILDDIR%/.. - echo. - echo.Build finished; the PDF files are in %BUILDDIR%/latex. - goto end -) - -if "%1" == "text" ( - %SPHINXBUILD% -b text %ALLSPHINXOPTS% %BUILDDIR%/text - if errorlevel 1 exit /b 1 - echo. - echo.Build finished. The text files are in %BUILDDIR%/text. - goto end -) - -if "%1" == "man" ( - %SPHINXBUILD% -b man %ALLSPHINXOPTS% %BUILDDIR%/man - if errorlevel 1 exit /b 1 - echo. - echo.Build finished. The manual pages are in %BUILDDIR%/man. - goto end -) - -if "%1" == "texinfo" ( - %SPHINXBUILD% -b texinfo %ALLSPHINXOPTS% %BUILDDIR%/texinfo - if errorlevel 1 exit /b 1 - echo. - echo.Build finished. 
The Texinfo files are in %BUILDDIR%/texinfo. - goto end -) - -if "%1" == "gettext" ( - %SPHINXBUILD% -b gettext %I18NSPHINXOPTS% %BUILDDIR%/locale - if errorlevel 1 exit /b 1 - echo. - echo.Build finished. The message catalogs are in %BUILDDIR%/locale. - goto end -) - -if "%1" == "changes" ( - %SPHINXBUILD% -b changes %ALLSPHINXOPTS% %BUILDDIR%/changes - if errorlevel 1 exit /b 1 - echo. - echo.The overview file is in %BUILDDIR%/changes. - goto end -) - -if "%1" == "linkcheck" ( - %SPHINXBUILD% -b linkcheck %ALLSPHINXOPTS% %BUILDDIR%/linkcheck - if errorlevel 1 exit /b 1 - echo. - echo.Link check complete; look for any errors in the above output ^ -or in %BUILDDIR%/linkcheck/output.txt. - goto end -) - -if "%1" == "doctest" ( - %SPHINXBUILD% -b doctest %ALLSPHINXOPTS% %BUILDDIR%/doctest - if errorlevel 1 exit /b 1 - echo. - echo.Testing of doctests in the sources finished, look at the ^ -results in %BUILDDIR%/doctest/output.txt. - goto end -) - -if "%1" == "xml" ( - %SPHINXBUILD% -b xml %ALLSPHINXOPTS% %BUILDDIR%/xml - if errorlevel 1 exit /b 1 - echo. - echo.Build finished. The XML files are in %BUILDDIR%/xml. - goto end -) - -if "%1" == "pseudoxml" ( - %SPHINXBUILD% -b pseudoxml %ALLSPHINXOPTS% %BUILDDIR%/pseudoxml - if errorlevel 1 exit /b 1 - echo. - echo.Build finished. The pseudo-XML files are in %BUILDDIR%/pseudoxml. - goto end -) - -:end +cmd /V /E /C %~dp0make2.bat %* diff --git a/python/docs/make2.bat b/python/docs/make2.bat new file mode 100644 index 0000000000000..7bcaeafad13d7 --- /dev/null +++ b/python/docs/make2.bat @@ -0,0 +1,243 @@ +@ECHO OFF + +REM Command file for Sphinx documentation + + +if "%SPHINXBUILD%" == "" ( + set SPHINXBUILD=sphinx-build +) +set BUILDDIR=_build +set ALLSPHINXOPTS=-d %BUILDDIR%/doctrees %SPHINXOPTS% . +set I18NSPHINXOPTS=%SPHINXOPTS% . +if NOT "%PAPER%" == "" ( + set ALLSPHINXOPTS=-D latex_paper_size=%PAPER% %ALLSPHINXOPTS% + set I18NSPHINXOPTS=-D latex_paper_size=%PAPER% %I18NSPHINXOPTS% +) + +if "%1" == "" goto help + +if "%1" == "help" ( + :help + echo.Please use `make ^` where ^ is one of + echo. html to make standalone HTML files + echo. dirhtml to make HTML files named index.html in directories + echo. singlehtml to make a single large HTML file + echo. pickle to make pickle files + echo. json to make JSON files + echo. htmlhelp to make HTML files and a HTML help project + echo. qthelp to make HTML files and a qthelp project + echo. devhelp to make HTML files and a Devhelp project + echo. epub to make an epub + echo. latex to make LaTeX files, you can set PAPER=a4 or PAPER=letter + echo. text to make text files + echo. man to make manual pages + echo. texinfo to make Texinfo files + echo. gettext to make PO message catalogs + echo. changes to make an overview over all changed/added/deprecated items + echo. xml to make Docutils-native XML files + echo. pseudoxml to make pseudoxml-XML files for display purposes + echo. linkcheck to check all external links for integrity + echo. doctest to run all doctests embedded in the documentation if enabled + goto end +) + +if "%1" == "clean" ( + for /d %%i in (%BUILDDIR%\*) do rmdir /q /s %%i + del /q /s %BUILDDIR%\* + goto end +) + + +%SPHINXBUILD% 2> nul +if errorlevel 9009 ( + echo. + echo.The 'sphinx-build' command was not found. Make sure you have Sphinx + echo.installed, then set the SPHINXBUILD environment variable to point + echo.to the full path of the 'sphinx-build' executable. Alternatively you + echo.may add the Sphinx directory to PATH. + echo. 
+ echo.If you don't have Sphinx installed, grab it from + echo.http://sphinx-doc.org/ + exit /b 1 +) + +if "%1" == "html" ( + %SPHINXBUILD% -b html %ALLSPHINXOPTS% %BUILDDIR%/html + if errorlevel 1 exit /b 1 + echo. + echo.Build finished. The HTML pages are in %BUILDDIR%/html. + goto end +) + +if "%1" == "dirhtml" ( + %SPHINXBUILD% -b dirhtml %ALLSPHINXOPTS% %BUILDDIR%/dirhtml + if errorlevel 1 exit /b 1 + echo. + echo.Build finished. The HTML pages are in %BUILDDIR%/dirhtml. + goto end +) + +if "%1" == "singlehtml" ( + %SPHINXBUILD% -b singlehtml %ALLSPHINXOPTS% %BUILDDIR%/singlehtml + if errorlevel 1 exit /b 1 + echo. + echo.Build finished. The HTML pages are in %BUILDDIR%/singlehtml. + goto end +) + +if "%1" == "pickle" ( + %SPHINXBUILD% -b pickle %ALLSPHINXOPTS% %BUILDDIR%/pickle + if errorlevel 1 exit /b 1 + echo. + echo.Build finished; now you can process the pickle files. + goto end +) + +if "%1" == "json" ( + %SPHINXBUILD% -b json %ALLSPHINXOPTS% %BUILDDIR%/json + if errorlevel 1 exit /b 1 + echo. + echo.Build finished; now you can process the JSON files. + goto end +) + +if "%1" == "htmlhelp" ( + %SPHINXBUILD% -b htmlhelp %ALLSPHINXOPTS% %BUILDDIR%/htmlhelp + if errorlevel 1 exit /b 1 + echo. + echo.Build finished; now you can run HTML Help Workshop with the ^ +.hhp project file in %BUILDDIR%/htmlhelp. + goto end +) + +if "%1" == "qthelp" ( + %SPHINXBUILD% -b qthelp %ALLSPHINXOPTS% %BUILDDIR%/qthelp + if errorlevel 1 exit /b 1 + echo. + echo.Build finished; now you can run "qcollectiongenerator" with the ^ +.qhcp project file in %BUILDDIR%/qthelp, like this: + echo.^> qcollectiongenerator %BUILDDIR%\qthelp\pyspark.qhcp + echo.To view the help file: + echo.^> assistant -collectionFile %BUILDDIR%\qthelp\pyspark.ghc + goto end +) + +if "%1" == "devhelp" ( + %SPHINXBUILD% -b devhelp %ALLSPHINXOPTS% %BUILDDIR%/devhelp + if errorlevel 1 exit /b 1 + echo. + echo.Build finished. + goto end +) + +if "%1" == "epub" ( + %SPHINXBUILD% -b epub %ALLSPHINXOPTS% %BUILDDIR%/epub + if errorlevel 1 exit /b 1 + echo. + echo.Build finished. The epub file is in %BUILDDIR%/epub. + goto end +) + +if "%1" == "latex" ( + %SPHINXBUILD% -b latex %ALLSPHINXOPTS% %BUILDDIR%/latex + if errorlevel 1 exit /b 1 + echo. + echo.Build finished; the LaTeX files are in %BUILDDIR%/latex. + goto end +) + +if "%1" == "latexpdf" ( + %SPHINXBUILD% -b latex %ALLSPHINXOPTS% %BUILDDIR%/latex + cd %BUILDDIR%/latex + make all-pdf + cd %BUILDDIR%/.. + echo. + echo.Build finished; the PDF files are in %BUILDDIR%/latex. + goto end +) + +if "%1" == "latexpdfja" ( + %SPHINXBUILD% -b latex %ALLSPHINXOPTS% %BUILDDIR%/latex + cd %BUILDDIR%/latex + make all-pdf-ja + cd %BUILDDIR%/.. + echo. + echo.Build finished; the PDF files are in %BUILDDIR%/latex. + goto end +) + +if "%1" == "text" ( + %SPHINXBUILD% -b text %ALLSPHINXOPTS% %BUILDDIR%/text + if errorlevel 1 exit /b 1 + echo. + echo.Build finished. The text files are in %BUILDDIR%/text. + goto end +) + +if "%1" == "man" ( + %SPHINXBUILD% -b man %ALLSPHINXOPTS% %BUILDDIR%/man + if errorlevel 1 exit /b 1 + echo. + echo.Build finished. The manual pages are in %BUILDDIR%/man. + goto end +) + +if "%1" == "texinfo" ( + %SPHINXBUILD% -b texinfo %ALLSPHINXOPTS% %BUILDDIR%/texinfo + if errorlevel 1 exit /b 1 + echo. + echo.Build finished. The Texinfo files are in %BUILDDIR%/texinfo. + goto end +) + +if "%1" == "gettext" ( + %SPHINXBUILD% -b gettext %I18NSPHINXOPTS% %BUILDDIR%/locale + if errorlevel 1 exit /b 1 + echo. + echo.Build finished. The message catalogs are in %BUILDDIR%/locale. 
+ goto end +) + +if "%1" == "changes" ( + %SPHINXBUILD% -b changes %ALLSPHINXOPTS% %BUILDDIR%/changes + if errorlevel 1 exit /b 1 + echo. + echo.The overview file is in %BUILDDIR%/changes. + goto end +) + +if "%1" == "linkcheck" ( + %SPHINXBUILD% -b linkcheck %ALLSPHINXOPTS% %BUILDDIR%/linkcheck + if errorlevel 1 exit /b 1 + echo. + echo.Link check complete; look for any errors in the above output ^ +or in %BUILDDIR%/linkcheck/output.txt. + goto end +) + +if "%1" == "doctest" ( + %SPHINXBUILD% -b doctest %ALLSPHINXOPTS% %BUILDDIR%/doctest + if errorlevel 1 exit /b 1 + echo. + echo.Testing of doctests in the sources finished, look at the ^ +results in %BUILDDIR%/doctest/output.txt. + goto end +) + +if "%1" == "xml" ( + %SPHINXBUILD% -b xml %ALLSPHINXOPTS% %BUILDDIR%/xml + if errorlevel 1 exit /b 1 + echo. + echo.Build finished. The XML files are in %BUILDDIR%/xml. + goto end +) + +if "%1" == "pseudoxml" ( + %SPHINXBUILD% -b pseudoxml %ALLSPHINXOPTS% %BUILDDIR%/pseudoxml + if errorlevel 1 exit /b 1 + echo. + echo.Build finished. The pseudo-XML files are in %BUILDDIR%/pseudoxml. + goto end +) + +:end diff --git a/python/docs/pyspark.mllib.rst b/python/docs/pyspark.mllib.rst index e95d19e97f151..4548b8739ed91 100644 --- a/python/docs/pyspark.mllib.rst +++ b/python/docs/pyspark.mllib.rst @@ -20,6 +20,14 @@ pyspark.mllib.clustering module :undoc-members: :show-inheritance: +pyspark.mllib.feature module +------------------------------- + +.. automodule:: pyspark.mllib.feature + :members: + :undoc-members: + :show-inheritance: + pyspark.mllib.linalg module --------------------------- diff --git a/python/docs/pyspark.rst b/python/docs/pyspark.rst index a68bd62433085..e81be3b6cb796 100644 --- a/python/docs/pyspark.rst +++ b/python/docs/pyspark.rst @@ -7,8 +7,9 @@ Subpackages .. toctree:: :maxdepth: 1 - pyspark.mllib pyspark.sql + pyspark.streaming + pyspark.mllib Contents -------- diff --git a/python/docs/pyspark.streaming.rst b/python/docs/pyspark.streaming.rst new file mode 100644 index 0000000000000..5024d694b668f --- /dev/null +++ b/python/docs/pyspark.streaming.rst @@ -0,0 +1,10 @@ +pyspark.streaming module +================== + +Module contents +--------------- + +.. automodule:: pyspark.streaming + :members: + :undoc-members: + :show-inheritance: diff --git a/python/pyspark/__init__.py b/python/pyspark/__init__.py index 1a2e774738fe7..e39e6514d77a1 100644 --- a/python/pyspark/__init__.py +++ b/python/pyspark/__init__.py @@ -20,33 +20,21 @@ Public classes: - - L{SparkContext} + - :class:`SparkContext`: Main entry point for Spark functionality. - - L{RDD} + - L{RDD} A Resilient Distributed Dataset (RDD), the basic abstraction in Spark. - - L{Broadcast} + - L{Broadcast} A broadcast variable that gets reused across tasks. - - L{Accumulator} + - L{Accumulator} An "add-only" shared variable that tasks can only add values to. - - L{SparkConf} + - L{SparkConf} For configuring Spark. - - L{SparkFiles} + - L{SparkFiles} Access files shipped with jobs. - - L{StorageLevel} + - L{StorageLevel} Finer-grained cache persistence levels. -Spark SQL: - - L{SQLContext} - Main entry point for SQL functionality. - - L{SchemaRDD} - A Resilient Distributed Dataset (RDD) with Schema information for the data contained. In - addition to normal RDD operations, SchemaRDDs also support SQL. - - L{Row} - A Row of data returned by a Spark SQL query. - -Hive: - - L{HiveContext} - Main entry point for accessing data stored in Apache Hive.. 
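A large share of the Python changes from here on are mechanical conversions of docstrings from epydoc markup (@param, L{...}) to the Sphinx field-list and role syntax that the new python/docs build renders. Purely as a reference for the target style, a sketch with hypothetical names (connect, url and timeout are made up for this example):

def connect(url, timeout=10):
    """
    Open a connection to the given endpoint.

    :param url: endpoint to connect to, e.g. "spark://host:port"
    :param timeout: seconds to wait before giving up (default: 10)
    :return: an open connection handle

    See :class:`SparkContext` for the real entry point to Spark.
    """
    raise NotImplementedError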
""" # The following block allows us to import python's random instead of mllib.random for scripts in diff --git a/python/pyspark/conf.py b/python/pyspark/conf.py index b64875a3f495a..dc7cd0bce56f3 100644 --- a/python/pyspark/conf.py +++ b/python/pyspark/conf.py @@ -83,11 +83,11 @@ def __init__(self, loadDefaults=True, _jvm=None, _jconf=None): """ Create a new Spark configuration. - @param loadDefaults: whether to load values from Java system + :param loadDefaults: whether to load values from Java system properties (True by default) - @param _jvm: internal parameter used to pass a handle to the + :param _jvm: internal parameter used to pass a handle to the Java VM; does not need to be set by users - @param _jconf: Optionally pass in an existing SparkConf handle + :param _jconf: Optionally pass in an existing SparkConf handle to use its parameters """ if _jconf: @@ -139,7 +139,7 @@ def setAll(self, pairs): """ Set multiple parameters, passed as a list of key-value pairs. - @param pairs: list of key-value pairs to set + :param pairs: list of key-value pairs to set """ for (k, v) in pairs: self._jconf.set(k, v) diff --git a/python/pyspark/context.py b/python/pyspark/context.py index a45d79d6424c7..8d27ccb95f82c 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -29,7 +29,7 @@ from pyspark.files import SparkFiles from pyspark.java_gateway import launch_gateway from pyspark.serializers import PickleSerializer, BatchedSerializer, UTF8Deserializer, \ - PairDeserializer, CompressedSerializer + PairDeserializer, CompressedSerializer, AutoBatchedSerializer from pyspark.storagelevel import StorageLevel from pyspark.rdd import RDD from pyspark.traceback_utils import CallSite, first_spark_call @@ -67,27 +67,28 @@ class SparkContext(object): _default_batch_size_for_serialized_input = 10 def __init__(self, master=None, appName=None, sparkHome=None, pyFiles=None, - environment=None, batchSize=1024, serializer=PickleSerializer(), conf=None, - gateway=None): + environment=None, batchSize=0, serializer=PickleSerializer(), conf=None, + gateway=None, jsc=None): """ Create a new SparkContext. At least the master and app name should be set, either through the named parameters here or through C{conf}. - @param master: Cluster URL to connect to + :param master: Cluster URL to connect to (e.g. mesos://host:port, spark://host:port, local[4]). - @param appName: A name for your job, to display on the cluster web UI. - @param sparkHome: Location where Spark is installed on cluster nodes. - @param pyFiles: Collection of .zip or .py files to send to the cluster + :param appName: A name for your job, to display on the cluster web UI. + :param sparkHome: Location where Spark is installed on cluster nodes. + :param pyFiles: Collection of .zip or .py files to send to the cluster and add to PYTHONPATH. These can be paths on the local file system or HDFS, HTTP, HTTPS, or FTP URLs. - @param environment: A dictionary of environment variables to set on + :param environment: A dictionary of environment variables to set on worker nodes. - @param batchSize: The number of Python objects represented as a single - Java object. Set 1 to disable batching or -1 to use an - unlimited batch size. - @param serializer: The serializer for RDDs. - @param conf: A L{SparkConf} object setting Spark properties. - @param gateway: Use an existing gateway and JVM, otherwise a new JVM + :param batchSize: The number of Python objects represented as a single + Java object. 
Set 1 to disable batching, 0 to automatically choose + the batch size based on object sizes, or -1 to use an unlimited + batch size + :param serializer: The serializer for RDDs. + :param conf: A L{SparkConf} object setting Spark properties. + :param gateway: Use an existing gateway and JVM, otherwise a new JVM will be instantiated. @@ -103,20 +104,22 @@ def __init__(self, master=None, appName=None, sparkHome=None, pyFiles=None, SparkContext._ensure_initialized(self, gateway=gateway) try: self._do_init(master, appName, sparkHome, pyFiles, environment, batchSize, serializer, - conf) + conf, jsc) except: # If an error occurs, clean up in order to allow future SparkContext creation: self.stop() raise def _do_init(self, master, appName, sparkHome, pyFiles, environment, batchSize, serializer, - conf): + conf, jsc): self.environment = environment or {} self._conf = conf or SparkConf(_jvm=self._jvm) self._batchSize = batchSize # -1 represents an unlimited batch size self._unbatched_serializer = serializer if batchSize == 1: self.serializer = self._unbatched_serializer + elif batchSize == 0: + self.serializer = AutoBatchedSerializer(self._unbatched_serializer) else: self.serializer = BatchedSerializer(self._unbatched_serializer, batchSize) @@ -151,7 +154,7 @@ def _do_init(self, master, appName, sparkHome, pyFiles, environment, batchSize, self.environment[varName] = v # Create the Java SparkContext through Py4J - self._jsc = self._initialize_context(self._conf._jconf) + self._jsc = jsc or self._initialize_context(self._conf._jconf) # Create a single Accumulator in Java that we'll send all our updates through; # they will be passed back to us through a TCP server @@ -212,8 +215,6 @@ def _ensure_initialized(cls, instance=None, gateway=None): SparkContext._gateway = gateway or launch_gateway() SparkContext._jvm = SparkContext._gateway.jvm SparkContext._writeToFile = SparkContext._jvm.PythonRDD.writeToFile - SparkContext._jvm.SerDeUtil.initialize() - SparkContext._jvm.SerDe.initialize() if instance: if (SparkContext._active_spark_context and @@ -417,16 +418,16 @@ def sequenceFile(self, path, keyClass=None, valueClass=None, keyConverter=None, 3. If this fails, the fallback is to call 'toString' on each key and value 4. C{PickleSerializer} is used to deserialize pickled objects on the Python side - @param path: path to sequncefile - @param keyClass: fully qualified classname of key Writable class + :param path: path to sequncefile + :param keyClass: fully qualified classname of key Writable class (e.g. "org.apache.hadoop.io.Text") - @param valueClass: fully qualified classname of value Writable class + :param valueClass: fully qualified classname of value Writable class (e.g. "org.apache.hadoop.io.LongWritable") - @param keyConverter: - @param valueConverter: - @param minSplits: minimum splits in dataset + :param keyConverter: + :param valueConverter: + :param minSplits: minimum splits in dataset (default min(2, sc.defaultParallelism)) - @param batchSize: The number of Python objects represented as a single + :param batchSize: The number of Python objects represented as a single Java object. (default sc._default_batch_size_for_serialized_input) """ minSplits = minSplits or min(self.defaultParallelism, 2) @@ -446,18 +447,18 @@ def newAPIHadoopFile(self, path, inputFormatClass, keyClass, valueClass, keyConv A Hadoop configuration can be passed in as a Python dict. 
This will be converted into a Configuration in Java - @param path: path to Hadoop file - @param inputFormatClass: fully qualified classname of Hadoop InputFormat + :param path: path to Hadoop file + :param inputFormatClass: fully qualified classname of Hadoop InputFormat (e.g. "org.apache.hadoop.mapreduce.lib.input.TextInputFormat") - @param keyClass: fully qualified classname of key Writable class + :param keyClass: fully qualified classname of key Writable class (e.g. "org.apache.hadoop.io.Text") - @param valueClass: fully qualified classname of value Writable class + :param valueClass: fully qualified classname of value Writable class (e.g. "org.apache.hadoop.io.LongWritable") - @param keyConverter: (None by default) - @param valueConverter: (None by default) - @param conf: Hadoop configuration, passed in as a dict + :param keyConverter: (None by default) + :param valueConverter: (None by default) + :param conf: Hadoop configuration, passed in as a dict (None by default) - @param batchSize: The number of Python objects represented as a single + :param batchSize: The number of Python objects represented as a single Java object. (default sc._default_batch_size_for_serialized_input) """ jconf = self._dictToJavaMap(conf) @@ -476,17 +477,17 @@ def newAPIHadoopRDD(self, inputFormatClass, keyClass, valueClass, keyConverter=N This will be converted into a Configuration in Java. The mechanism is the same as for sc.sequenceFile. - @param inputFormatClass: fully qualified classname of Hadoop InputFormat + :param inputFormatClass: fully qualified classname of Hadoop InputFormat (e.g. "org.apache.hadoop.mapreduce.lib.input.TextInputFormat") - @param keyClass: fully qualified classname of key Writable class + :param keyClass: fully qualified classname of key Writable class (e.g. "org.apache.hadoop.io.Text") - @param valueClass: fully qualified classname of value Writable class + :param valueClass: fully qualified classname of value Writable class (e.g. "org.apache.hadoop.io.LongWritable") - @param keyConverter: (None by default) - @param valueConverter: (None by default) - @param conf: Hadoop configuration, passed in as a dict + :param keyConverter: (None by default) + :param valueConverter: (None by default) + :param conf: Hadoop configuration, passed in as a dict (None by default) - @param batchSize: The number of Python objects represented as a single + :param batchSize: The number of Python objects represented as a single Java object. (default sc._default_batch_size_for_serialized_input) """ jconf = self._dictToJavaMap(conf) @@ -507,18 +508,18 @@ def hadoopFile(self, path, inputFormatClass, keyClass, valueClass, keyConverter= A Hadoop configuration can be passed in as a Python dict. This will be converted into a Configuration in Java. - @param path: path to Hadoop file - @param inputFormatClass: fully qualified classname of Hadoop InputFormat + :param path: path to Hadoop file + :param inputFormatClass: fully qualified classname of Hadoop InputFormat (e.g. "org.apache.hadoop.mapred.TextInputFormat") - @param keyClass: fully qualified classname of key Writable class + :param keyClass: fully qualified classname of key Writable class (e.g. "org.apache.hadoop.io.Text") - @param valueClass: fully qualified classname of value Writable class + :param valueClass: fully qualified classname of value Writable class (e.g. 
"org.apache.hadoop.io.LongWritable") - @param keyConverter: (None by default) - @param valueConverter: (None by default) - @param conf: Hadoop configuration, passed in as a dict + :param keyConverter: (None by default) + :param valueConverter: (None by default) + :param conf: Hadoop configuration, passed in as a dict (None by default) - @param batchSize: The number of Python objects represented as a single + :param batchSize: The number of Python objects represented as a single Java object. (default sc._default_batch_size_for_serialized_input) """ jconf = self._dictToJavaMap(conf) @@ -537,17 +538,17 @@ def hadoopRDD(self, inputFormatClass, keyClass, valueClass, keyConverter=None, This will be converted into a Configuration in Java. The mechanism is the same as for sc.sequenceFile. - @param inputFormatClass: fully qualified classname of Hadoop InputFormat + :param inputFormatClass: fully qualified classname of Hadoop InputFormat (e.g. "org.apache.hadoop.mapred.TextInputFormat") - @param keyClass: fully qualified classname of key Writable class + :param keyClass: fully qualified classname of key Writable class (e.g. "org.apache.hadoop.io.Text") - @param valueClass: fully qualified classname of value Writable class + :param valueClass: fully qualified classname of value Writable class (e.g. "org.apache.hadoop.io.LongWritable") - @param keyConverter: (None by default) - @param valueConverter: (None by default) - @param conf: Hadoop configuration, passed in as a dict + :param keyConverter: (None by default) + :param valueConverter: (None by default) + :param conf: Hadoop configuration, passed in as a dict (None by default) - @param batchSize: The number of Python objects represented as a single + :param batchSize: The number of Python objects represented as a single Java object. (default sc._default_batch_size_for_serialized_input) """ jconf = self._dictToJavaMap(conf) diff --git a/python/pyspark/mllib/classification.py b/python/pyspark/mllib/classification.py index a765b1c4f7d87..e295c9d0954d9 100644 --- a/python/pyspark/mllib/classification.py +++ b/python/pyspark/mllib/classification.py @@ -21,7 +21,7 @@ from numpy import array from pyspark import SparkContext, PickleSerializer -from pyspark.mllib.linalg import SparseVector, _convert_to_vector +from pyspark.mllib.linalg import SparseVector, _convert_to_vector, _to_java_object_rdd from pyspark.mllib.regression import LabeledPoint, LinearModel, _regression_train_wrapper @@ -79,15 +79,15 @@ def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0, """ Train a logistic regression model on the given data. - @param data: The training data. - @param iterations: The number of iterations (default: 100). - @param step: The step parameter used in SGD + :param data: The training data. + :param iterations: The number of iterations (default: 100). + :param step: The step parameter used in SGD (default: 1.0). - @param miniBatchFraction: Fraction of data to be used for each SGD + :param miniBatchFraction: Fraction of data to be used for each SGD iteration. - @param initialWeights: The initial weights (default: None). - @param regParam: The regularizer parameter (default: 1.0). - @param regType: The type of regularizer used for training + :param initialWeights: The initial weights (default: None). + :param regParam: The regularizer parameter (default: 1.0). + :param regType: The type of regularizer used for training our model. 
:Allowed values: @@ -151,15 +151,15 @@ def train(cls, data, iterations=100, step=1.0, regParam=1.0, """ Train a support vector machine on the given data. - @param data: The training data. - @param iterations: The number of iterations (default: 100). - @param step: The step parameter used in SGD + :param data: The training data. + :param iterations: The number of iterations (default: 100). + :param step: The step parameter used in SGD (default: 1.0). - @param regParam: The regularizer parameter (default: 1.0). - @param miniBatchFraction: Fraction of data to be used for each SGD + :param regParam: The regularizer parameter (default: 1.0). + :param miniBatchFraction: Fraction of data to be used for each SGD iteration. - @param initialWeights: The initial weights (default: None). - @param regType: The type of regularizer used for training + :param initialWeights: The initial weights (default: None). + :param regType: The type of regularizer used for training our model. :Allowed values: @@ -238,13 +238,13 @@ def train(cls, data, lambda_=1.0): classification. By making every vector a 0-1 vector, it can also be used as Bernoulli NB (U{http://tinyurl.com/p7c96j6}). - @param data: RDD of NumPy vectors, one per element, where the first + :param data: RDD of NumPy vectors, one per element, where the first coordinate is the label and the rest is the feature vector (e.g. a count vector). - @param lambda_: The smoothing parameter + :param lambda_: The smoothing parameter """ sc = data.context - jlist = sc._jvm.PythonMLLibAPI().trainNaiveBayes(data._to_java_object_rdd(), lambda_) + jlist = sc._jvm.PythonMLLibAPI().trainNaiveBayes(_to_java_object_rdd(data), lambda_) labels, pi, theta = PickleSerializer().loads(str(sc._jvm.SerDe.dumps(jlist))) return NaiveBayesModel(labels.toArray(), pi.toArray(), numpy.array(theta)) diff --git a/python/pyspark/mllib/clustering.py b/python/pyspark/mllib/clustering.py index 12c56022717a5..5ee7997104d21 100644 --- a/python/pyspark/mllib/clustering.py +++ b/python/pyspark/mllib/clustering.py @@ -17,7 +17,7 @@ from pyspark import SparkContext from pyspark.serializers import PickleSerializer, AutoBatchedSerializer -from pyspark.mllib.linalg import SparseVector, _convert_to_vector +from pyspark.mllib.linalg import SparseVector, _convert_to_vector, _to_java_object_rdd __all__ = ['KMeansModel', 'KMeans'] @@ -85,7 +85,7 @@ def train(cls, rdd, k, maxIterations=100, runs=1, initializationMode="k-means||" # cache serialized data to avoid objects over head in JVM cached = rdd.map(_convert_to_vector)._reserialize(AutoBatchedSerializer(ser)).cache() model = sc._jvm.PythonMLLibAPI().trainKMeansModel( - cached._to_java_object_rdd(), k, maxIterations, runs, initializationMode) + _to_java_object_rdd(cached), k, maxIterations, runs, initializationMode) bytes = sc._jvm.SerDe.dumps(model.clusterCenters()) centers = ser.loads(str(bytes)) return KMeansModel([c.toArray() for c in centers]) diff --git a/python/pyspark/mllib/feature.py b/python/pyspark/mllib/feature.py new file mode 100644 index 0000000000000..b5a3f22c6907e --- /dev/null +++ b/python/pyspark/mllib/feature.py @@ -0,0 +1,194 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
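The classification and clustering wrappers above keep their public signatures; only the hand-off to the JVM changed. As a usage reference for the keyword parameters these docstrings describe, a hedged sketch (assumes a live SparkContext bound to sc; "l2" is one of the regularizer names I believe the Allowed values list enumerates):

from pyspark.mllib.classification import LogisticRegressionWithSGD
from pyspark.mllib.regression import LabeledPoint

data = sc.parallelize([
    LabeledPoint(0.0, [0.0, 1.0]),
    LabeledPoint(1.0, [1.0, 0.0]),
])
model = LogisticRegressionWithSGD.train(
    data, iterations=10, step=1.0, regParam=0.1, regType="l2")
print(model.predict([1.0, 0.0]))   # most likely 1 for this toy data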
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +Python package for feature in MLlib. +""" +from pyspark.serializers import PickleSerializer, AutoBatchedSerializer +from pyspark.mllib.linalg import _convert_to_vector, _to_java_object_rdd + +__all__ = ['Word2Vec', 'Word2VecModel'] + + +class Word2VecModel(object): + """ + class for Word2Vec model + """ + def __init__(self, sc, java_model): + """ + :param sc: Spark context + :param java_model: Handle to Java model object + """ + self._sc = sc + self._java_model = java_model + + def __del__(self): + self._sc._gateway.detach(self._java_model) + + def transform(self, word): + """ + :param word: a word + :return: vector representation of word + + Transforms a word to its vector representation + + Note: local use only + """ + # TODO: make transform usable in RDD operations from python side + result = self._java_model.transform(word) + return PickleSerializer().loads(str(self._sc._jvm.SerDe.dumps(result))) + + def findSynonyms(self, x, num): + """ + :param x: a word or a vector representation of word + :param num: number of synonyms to find + :return: array of (word, cosineSimilarity) + + Find synonyms of a word + + Note: local use only + """ + # TODO: make findSynonyms usable in RDD operations from python side + ser = PickleSerializer() + if type(x) == str: + jlist = self._java_model.findSynonyms(x, num) + else: + bytes = bytearray(ser.dumps(_convert_to_vector(x))) + vec = self._sc._jvm.SerDe.loads(bytes) + jlist = self._java_model.findSynonyms(vec, num) + words, similarity = ser.loads(str(self._sc._jvm.SerDe.dumps(jlist))) + return zip(words, similarity) + + +class Word2Vec(object): + """ + Word2Vec creates vector representation of words in a text corpus. + The algorithm first constructs a vocabulary from the corpus + and then learns vector representation of words in the vocabulary. + The vector representation can be used as features in + natural language processing and machine learning algorithms. + + We used skip-gram model in our implementation and hierarchical softmax + method to train the model. The variable names in the implementation + matches the original C implementation. + For original C implementation, see https://code.google.com/p/word2vec/ + For research papers, see + Efficient Estimation of Word Representations in Vector Space + and + Distributed Representations of Words and Phrases and their Compositionality. 
+ + >>> sentence = "a b " * 100 + "a c " * 10 + >>> localDoc = [sentence, sentence] + >>> doc = sc.parallelize(localDoc).map(lambda line: line.split(" ")) + >>> model = Word2Vec().setVectorSize(10).setSeed(42L).fit(doc) + >>> syms = model.findSynonyms("a", 2) + >>> str(syms[0][0]) + 'b' + >>> str(syms[1][0]) + 'c' + >>> len(syms) + 2 + >>> vec = model.transform("a") + >>> len(vec) + 10 + >>> syms = model.findSynonyms(vec, 2) + >>> str(syms[0][0]) + 'b' + >>> str(syms[1][0]) + 'c' + >>> len(syms) + 2 + """ + def __init__(self): + """ + Construct Word2Vec instance + """ + self.vectorSize = 100 + self.learningRate = 0.025 + self.numPartitions = 1 + self.numIterations = 1 + self.seed = 42L + + def setVectorSize(self, vectorSize): + """ + Sets vector size (default: 100). + """ + self.vectorSize = vectorSize + return self + + def setLearningRate(self, learningRate): + """ + Sets initial learning rate (default: 0.025). + """ + self.learningRate = learningRate + return self + + def setNumPartitions(self, numPartitions): + """ + Sets number of partitions (default: 1). Use a small number for accuracy. + """ + self.numPartitions = numPartitions + return self + + def setNumIterations(self, numIterations): + """ + Sets number of iterations (default: 1), which should be smaller than or equal to number of + partitions. + """ + self.numIterations = numIterations + return self + + def setSeed(self, seed): + """ + Sets random seed. + """ + self.seed = seed + return self + + def fit(self, data): + """ + Computes the vector representation of each word in vocabulary. + + :param data: training data. RDD of subtype of Iterable[String] + :return: python Word2VecModel instance + """ + sc = data.context + ser = PickleSerializer() + vectorSize = self.vectorSize + learningRate = self.learningRate + numPartitions = self.numPartitions + numIterations = self.numIterations + seed = self.seed + + model = sc._jvm.PythonMLLibAPI().trainWord2Vec( + _to_java_object_rdd(data), vectorSize, + learningRate, numPartitions, numIterations, seed) + return Word2VecModel(sc, model) + + +def _test(): + import doctest + from pyspark import SparkContext + globs = globals().copy() + globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2) + (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) + globs['sc'].stop() + if failure_count: + exit(-1) + +if __name__ == "__main__": + _test() diff --git a/python/pyspark/mllib/linalg.py b/python/pyspark/mllib/linalg.py index 51014a8ceb785..773d8d393805d 100644 --- a/python/pyspark/mllib/linalg.py +++ b/python/pyspark/mllib/linalg.py @@ -29,6 +29,8 @@ import numpy as np +from pyspark.serializers import AutoBatchedSerializer, PickleSerializer + __all__ = ['Vector', 'DenseVector', 'SparseVector', 'Vectors'] @@ -50,6 +52,17 @@ def fast_pickle_array(ar): _have_scipy = False +# this will call the MLlib version of pythonToJava() +def _to_java_object_rdd(rdd): + """ Return an JavaRDD of Object by unpickling + + It will convert each Python object into Java object by Pyrolite, whenever the + RDD is serialized in batch or not. + """ + rdd = rdd._reserialize(AutoBatchedSerializer(PickleSerializer())) + return rdd.ctx._jvm.SerDe.pythonToJava(rdd._jrdd, True) + + def _convert_to_vector(l): if isinstance(l, Vector): return l @@ -238,8 +251,8 @@ def __init__(self, size, *args): (index, value) pairs, or two separate arrays of indices and values (sorted by index). - @param size: Size of the vector. 
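One recurring pattern in this patch: the old RDD method _to_java_object_rdd() is replaced everywhere by the module-level helper defined in linalg.py above, which re-serializes with AutoBatchedSerializer and then lets the JVM-side SerDe unpickle into Java objects. A sketch of the call pattern the MLlib wrappers use (it mirrors the KMeans code above and needs a live SparkContext bound to sc):

from pyspark.serializers import AutoBatchedSerializer, PickleSerializer
from pyspark.mllib.linalg import _convert_to_vector, _to_java_object_rdd

vectors = sc.parallelize([[1.0, 0.0, 3.0], [0.0, 2.0, 0.0]])
# batch-pickle once and cache, to cut object overhead in the JVM
cached = vectors.map(_convert_to_vector) \
                ._reserialize(AutoBatchedSerializer(PickleSerializer())) \
                .cache()
jrdd = _to_java_object_rdd(cached)   # a JavaRDD[Object] that PythonMLLibAPI can consume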
- @param args: Non-zero entries, as a dictionary, list of tupes, + :param size: Size of the vector. + :param args: Non-zero entries, as a dictionary, list of tupes, or two sorted lists containing indices and values. >>> print SparseVector(4, {1: 1.0, 3: 5.5}) @@ -458,8 +471,8 @@ def sparse(size, *args): (index, value) pairs, or two separate arrays of indices and values (sorted by index). - @param size: Size of the vector. - @param args: Non-zero entries, as a dictionary, list of tupes, + :param size: Size of the vector. + :param args: Non-zero entries, as a dictionary, list of tupes, or two sorted lists containing indices and values. >>> print Vectors.sparse(4, {1: 1.0, 3: 5.5}) diff --git a/python/pyspark/mllib/random.py b/python/pyspark/mllib/random.py index a787e4dea2c55..73baba4ace5f6 100644 --- a/python/pyspark/mllib/random.py +++ b/python/pyspark/mllib/random.py @@ -32,7 +32,7 @@ def serialize(f): @wraps(f) def func(sc, *a, **kw): jrdd = f(sc, *a, **kw) - return RDD(sc._jvm.PythonRDD.javaToPython(jrdd), sc, + return RDD(sc._jvm.SerDe.javaToPython(jrdd), sc, BatchedSerializer(PickleSerializer(), 1024)) return func diff --git a/python/pyspark/mllib/recommendation.py b/python/pyspark/mllib/recommendation.py index 59c1c5ff0ced0..17f96b8700bd7 100644 --- a/python/pyspark/mllib/recommendation.py +++ b/python/pyspark/mllib/recommendation.py @@ -18,6 +18,7 @@ from pyspark import SparkContext from pyspark.serializers import PickleSerializer, AutoBatchedSerializer from pyspark.rdd import RDD +from pyspark.mllib.linalg import _to_java_object_rdd __all__ = ['MatrixFactorizationModel', 'ALS'] @@ -77,9 +78,9 @@ def predictAll(self, user_product): first = tuple(map(int, first)) assert all(type(x) is int for x in first), "user and product in user_product shoul be int" sc = self._context - tuplerdd = sc._jvm.SerDe.asTupleRDD(user_product._to_java_object_rdd().rdd()) + tuplerdd = sc._jvm.SerDe.asTupleRDD(_to_java_object_rdd(user_product).rdd()) jresult = self._java_model.predict(tuplerdd).toJavaRDD() - return RDD(sc._jvm.PythonRDD.javaToPython(jresult), sc, + return RDD(sc._jvm.SerDe.javaToPython(jresult), sc, AutoBatchedSerializer(PickleSerializer())) @@ -97,7 +98,7 @@ def _prepare(cls, ratings): # serialize them by AutoBatchedSerializer before cache to reduce the # objects overhead in JVM cached = ratings._reserialize(AutoBatchedSerializer(PickleSerializer())).cache() - return cached._to_java_object_rdd() + return _to_java_object_rdd(cached) @classmethod def train(cls, ratings, rank, iterations=5, lambda_=0.01, blocks=-1): diff --git a/python/pyspark/mllib/regression.py b/python/pyspark/mllib/regression.py index 54f34a98337ca..93e17faf5cd51 100644 --- a/python/pyspark/mllib/regression.py +++ b/python/pyspark/mllib/regression.py @@ -19,8 +19,8 @@ from numpy import array from pyspark import SparkContext -from pyspark.mllib.linalg import SparseVector, _convert_to_vector from pyspark.serializers import PickleSerializer, AutoBatchedSerializer +from pyspark.mllib.linalg import SparseVector, _convert_to_vector, _to_java_object_rdd __all__ = ['LabeledPoint', 'LinearModel', 'LinearRegressionModel', 'RidgeRegressionModel', 'LinearRegressionWithSGD', 'LassoWithSGD', 'RidgeRegressionWithSGD'] @@ -31,8 +31,8 @@ class LabeledPoint(object): """ The features and labels of a data point. - @param label: Label for this data point. - @param features: Vector of features for this point (NumPy array, list, + :param label: Label for this data point. 
+ :param features: Vector of features for this point (NumPy array, list, pyspark.mllib.linalg.SparseVector, or scipy.sparse column matrix) """ @@ -131,7 +131,7 @@ def _regression_train_wrapper(sc, train_func, modelClass, data, initial_weights) # use AutoBatchedSerializer before cache to reduce the memory # overhead in JVM cached = data._reserialize(AutoBatchedSerializer(ser)).cache() - ans = train_func(cached._to_java_object_rdd(), initial_bytes) + ans = train_func(_to_java_object_rdd(cached), initial_bytes) assert len(ans) == 2, "JVM call result had unexpected length" weights = ser.loads(str(ans[0])) return modelClass(weights, ans[1]) @@ -145,15 +145,15 @@ def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0, """ Train a linear regression model on the given data. - @param data: The training data. - @param iterations: The number of iterations (default: 100). - @param step: The step parameter used in SGD + :param data: The training data. + :param iterations: The number of iterations (default: 100). + :param step: The step parameter used in SGD (default: 1.0). - @param miniBatchFraction: Fraction of data to be used for each SGD + :param miniBatchFraction: Fraction of data to be used for each SGD iteration. - @param initialWeights: The initial weights (default: None). - @param regParam: The regularizer parameter (default: 1.0). - @param regType: The type of regularizer used for training + :param initialWeights: The initial weights (default: None). + :param regParam: The regularizer parameter (default: 1.0). + :param regType: The type of regularizer used for training our model. :Allowed values: diff --git a/python/pyspark/mllib/stat.py b/python/pyspark/mllib/stat.py index b9de0909a6fb1..a6019dadf781c 100644 --- a/python/pyspark/mllib/stat.py +++ b/python/pyspark/mllib/stat.py @@ -22,6 +22,7 @@ from functools import wraps from pyspark import PickleSerializer +from pyspark.mllib.linalg import _to_java_object_rdd __all__ = ['MultivariateStatisticalSummary', 'Statistics'] @@ -106,7 +107,7 @@ def colStats(rdd): array([ 2., 0., 0., -2.]) """ sc = rdd.ctx - jrdd = rdd._to_java_object_rdd() + jrdd = _to_java_object_rdd(rdd) cStats = sc._jvm.PythonMLLibAPI().colStats(jrdd) return MultivariateStatisticalSummary(sc, cStats) @@ -162,14 +163,14 @@ def corr(x, y=None, method=None): if type(y) == str: raise TypeError("Use 'method=' to specify method name.") - jx = x._to_java_object_rdd() + jx = _to_java_object_rdd(x) if not y: resultMat = sc._jvm.PythonMLLibAPI().corr(jx, method) bytes = sc._jvm.SerDe.dumps(resultMat) ser = PickleSerializer() return ser.loads(str(bytes)).toArray() else: - jy = y._to_java_object_rdd() + jy = _to_java_object_rdd(y) return sc._jvm.PythonMLLibAPI().corr(jx, jy, method) diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py index 5c20e100e144f..463faf7b6f520 100644 --- a/python/pyspark/mllib/tests.py +++ b/python/pyspark/mllib/tests.py @@ -25,7 +25,11 @@ from numpy import array, array_equal if sys.version_info[:2] <= (2, 6): - import unittest2 as unittest + try: + import unittest2 as unittest + except ImportError: + sys.stderr.write('Please install unittest2 to test with Python 2.6 or earlier') + sys.exit(1) else: import unittest diff --git a/python/pyspark/mllib/tree.py b/python/pyspark/mllib/tree.py index 5d7abfb96b7fe..64ee79d83e849 100644 --- a/python/pyspark/mllib/tree.py +++ b/python/pyspark/mllib/tree.py @@ -19,7 +19,7 @@ from pyspark import SparkContext, RDD from pyspark.serializers import BatchedSerializer, PickleSerializer -from 
pyspark.mllib.linalg import Vector, _convert_to_vector +from pyspark.mllib.linalg import Vector, _convert_to_vector, _to_java_object_rdd from pyspark.mllib.regression import LabeledPoint __all__ = ['DecisionTreeModel', 'DecisionTree'] @@ -61,8 +61,8 @@ def predict(self, x): return self._sc.parallelize([]) if not isinstance(first[0], Vector): x = x.map(_convert_to_vector) - jPred = self._java_model.predict(x._to_java_object_rdd()).toJavaRDD() - jpyrdd = self._sc._jvm.PythonRDD.javaToPython(jPred) + jPred = self._java_model.predict(_to_java_object_rdd(x)).toJavaRDD() + jpyrdd = self._sc._jvm.SerDe.javaToPython(jPred) return RDD(jpyrdd, self._sc, BatchedSerializer(ser, 1024)) else: @@ -104,7 +104,7 @@ def _train(data, type, numClasses, categoricalFeaturesInfo, first = data.first() assert isinstance(first, LabeledPoint), "the data should be RDD of LabeledPoint" sc = data.context - jrdd = data._to_java_object_rdd() + jrdd = _to_java_object_rdd(data) cfiMap = MapConverter().convert(categoricalFeaturesInfo, sc._gateway._gateway_client) model = sc._jvm.PythonMLLibAPI().trainDecisionTreeModel( @@ -153,9 +153,9 @@ def trainClassifier(data, numClasses, categoricalFeaturesInfo, DecisionTreeModel classifier of depth 1 with 3 nodes >>> print model.toDebugString(), # it already has newline DecisionTreeModel classifier of depth 1 with 3 nodes - If (feature 0 <= 0.5) + If (feature 0 <= 0.0) Predict: 0.0 - Else (feature 0 > 0.5) + Else (feature 0 > 0.0) Predict: 1.0 >>> model.predict(array([1.0])) > 0 True diff --git a/python/pyspark/mllib/util.py b/python/pyspark/mllib/util.py index 8233d4e81f1ca..84b39a48619d2 100644 --- a/python/pyspark/mllib/util.py +++ b/python/pyspark/mllib/util.py @@ -19,7 +19,7 @@ import warnings from pyspark.rdd import RDD -from pyspark.serializers import BatchedSerializer, PickleSerializer +from pyspark.serializers import AutoBatchedSerializer, PickleSerializer from pyspark.mllib.linalg import Vectors, SparseVector, _convert_to_vector from pyspark.mllib.regression import LabeledPoint @@ -77,10 +77,10 @@ def loadLibSVMFile(sc, path, numFeatures=-1, minPartitions=None, multiclass=None method parses each line into a LabeledPoint, where the feature indices are converted to zero-based. - @param sc: Spark context - @param path: file or directory path in any Hadoop-supported file + :param sc: Spark context + :param path: file or directory path in any Hadoop-supported file system URI - @param numFeatures: number of features, which will be determined + :param numFeatures: number of features, which will be determined from the input data if a nonpositive value is given. This is useful when the dataset is already split into multiple files and you @@ -88,7 +88,7 @@ def loadLibSVMFile(sc, path, numFeatures=-1, minPartitions=None, multiclass=None features may not present in certain files, which leads to inconsistent feature dimensions. - @param minPartitions: min number of partitions + :param minPartitions: min number of partitions @return: labeled data stored as an RDD of LabeledPoint >>> from tempfile import NamedTemporaryFile @@ -126,8 +126,8 @@ def saveAsLibSVMFile(data, dir): """ Save labeled data in LIBSVM format. 
- @param data: an RDD of LabeledPoint to be saved - @param dir: directory to save the data + :param data: an RDD of LabeledPoint to be saved + :param dir: directory to save the data >>> from tempfile import NamedTemporaryFile >>> from fileinput import input @@ -149,10 +149,10 @@ def loadLabeledPoints(sc, path, minPartitions=None): """ Load labeled points saved using RDD.saveAsTextFile. - @param sc: Spark context - @param path: file or directory path in any Hadoop-supported file + :param sc: Spark context + :param path: file or directory path in any Hadoop-supported file system URI - @param minPartitions: min number of partitions + :param minPartitions: min number of partitions @return: labeled data stored as an RDD of LabeledPoint >>> from tempfile import NamedTemporaryFile @@ -174,8 +174,8 @@ def loadLabeledPoints(sc, path, minPartitions=None): """ minPartitions = minPartitions or min(sc.defaultParallelism, 2) jrdd = sc._jvm.PythonMLLibAPI().loadLabeledPoints(sc._jsc, path, minPartitions) - jpyrdd = sc._jvm.PythonRDD.javaToPython(jrdd) - return RDD(jpyrdd, sc, BatchedSerializer(PickleSerializer())) + jpyrdd = sc._jvm.SerDe.javaToPython(jrdd) + return RDD(jpyrdd, sc, AutoBatchedSerializer(PickleSerializer())) def _test(): diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index e77669aad76b6..15be4bfec92f9 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -752,7 +752,7 @@ def max(self, key=None): """ Find the maximum item in this RDD. - @param key: A function used to generate key for comparing + :param key: A function used to generate key for comparing >>> rdd = sc.parallelize([1.0, 5.0, 43.0, 10.0]) >>> rdd.max() @@ -768,7 +768,7 @@ def min(self, key=None): """ Find the minimum item in this RDD. - @param key: A function used to generate key for comparing + :param key: A function used to generate key for comparing >>> rdd = sc.parallelize([2.0, 5.0, 43.0, 10.0]) >>> rdd.min() @@ -1070,10 +1070,13 @@ def take(self, num): # If we didn't find any rows after the previous iteration, # quadruple and retry. Otherwise, interpolate the number of # partitions we need to try, but overestimate it by 50%. + # We also cap the estimation in the end. if len(items) == 0: numPartsToTry = partsScanned * 4 else: - numPartsToTry = int(1.5 * num * partsScanned / len(items)) + # the first paramter of max is >=1 whenever partsScanned >= 2 + numPartsToTry = int(1.5 * num * partsScanned / len(items)) - partsScanned + numPartsToTry = min(max(numPartsToTry, 1), partsScanned * 4) left = num - len(items) @@ -1115,9 +1118,9 @@ def saveAsNewAPIHadoopDataset(self, conf, keyConverter=None, valueConverter=None converted for output using either user specified converters or, by default, L{org.apache.spark.api.python.JavaToWritableConverter}. - @param conf: Hadoop job configuration, passed in as a dict - @param keyConverter: (None by default) - @param valueConverter: (None by default) + :param conf: Hadoop job configuration, passed in as a dict + :param keyConverter: (None by default) + :param valueConverter: (None by default) """ jconf = self.ctx._dictToJavaMap(conf) pickledRDD = self._toPickleSerialization() @@ -1135,16 +1138,16 @@ def saveAsNewAPIHadoopFile(self, path, outputFormatClass, keyClass=None, valueCl C{conf} is applied on top of the base Hadoop conf associated with the SparkContext of this RDD to create a merged Hadoop MapReduce job configuration for saving the data. 
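The take() change above both subtracts the partitions already scanned from the interpolated estimate and caps the next step at four times the partitions scanned so far, so one sparse first pass can no longer trigger a huge scan. A small worked example of the new arithmetic (pure Python, no Spark needed; the function name is mine):

def next_parts_to_try(num, parts_scanned, items_found):
    if items_found == 0:
        return parts_scanned * 4
    # interpolate the partitions likely needed (overestimating by 50%),
    # keep only the additional ones, and cap the growth at 4x
    estimate = int(1.5 * num * parts_scanned / items_found) - parts_scanned
    return min(max(estimate, 1), parts_scanned * 4)

print(next_parts_to_try(100, 4, 10))  # interpolation asks for 56 more, capped to 16
print(next_parts_to_try(100, 4, 0))   # nothing found yet, so quadruple: 16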
- @param path: path to Hadoop file - @param outputFormatClass: fully qualified classname of Hadoop OutputFormat + :param path: path to Hadoop file + :param outputFormatClass: fully qualified classname of Hadoop OutputFormat (e.g. "org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat") - @param keyClass: fully qualified classname of key Writable class + :param keyClass: fully qualified classname of key Writable class (e.g. "org.apache.hadoop.io.IntWritable", None by default) - @param valueClass: fully qualified classname of value Writable class + :param valueClass: fully qualified classname of value Writable class (e.g. "org.apache.hadoop.io.Text", None by default) - @param keyConverter: (None by default) - @param valueConverter: (None by default) - @param conf: Hadoop job configuration, passed in as a dict (None by default) + :param keyConverter: (None by default) + :param valueConverter: (None by default) + :param conf: Hadoop job configuration, passed in as a dict (None by default) """ jconf = self.ctx._dictToJavaMap(conf) pickledRDD = self._toPickleSerialization() @@ -1161,9 +1164,9 @@ def saveAsHadoopDataset(self, conf, keyConverter=None, valueConverter=None): converted for output using either user specified converters or, by default, L{org.apache.spark.api.python.JavaToWritableConverter}. - @param conf: Hadoop job configuration, passed in as a dict - @param keyConverter: (None by default) - @param valueConverter: (None by default) + :param conf: Hadoop job configuration, passed in as a dict + :param keyConverter: (None by default) + :param valueConverter: (None by default) """ jconf = self.ctx._dictToJavaMap(conf) pickledRDD = self._toPickleSerialization() @@ -1182,17 +1185,17 @@ def saveAsHadoopFile(self, path, outputFormatClass, keyClass=None, valueClass=No C{conf} is applied on top of the base Hadoop conf associated with the SparkContext of this RDD to create a merged Hadoop MapReduce job configuration for saving the data. - @param path: path to Hadoop file - @param outputFormatClass: fully qualified classname of Hadoop OutputFormat + :param path: path to Hadoop file + :param outputFormatClass: fully qualified classname of Hadoop OutputFormat (e.g. "org.apache.hadoop.mapred.SequenceFileOutputFormat") - @param keyClass: fully qualified classname of key Writable class + :param keyClass: fully qualified classname of key Writable class (e.g. "org.apache.hadoop.io.IntWritable", None by default) - @param valueClass: fully qualified classname of value Writable class + :param valueClass: fully qualified classname of value Writable class (e.g. "org.apache.hadoop.io.Text", None by default) - @param keyConverter: (None by default) - @param valueConverter: (None by default) - @param conf: (None by default) - @param compressionCodecClass: (None by default) + :param keyConverter: (None by default) + :param valueConverter: (None by default) + :param conf: (None by default) + :param compressionCodecClass: (None by default) """ jconf = self.ctx._dictToJavaMap(conf) pickledRDD = self._toPickleSerialization() @@ -1212,8 +1215,8 @@ def saveAsSequenceFile(self, path, compressionCodecClass=None): 1. Pyrolite is used to convert pickled Python RDD into RDD of Java objects. 2. Keys and values of this Java RDD are converted to Writables and written out. 
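For the save methods whose docstrings are converted above, a hedged usage sketch; `sc` and the output paths are assumptions, and the OutputFormat/Writable class names are the ones quoted in the docstrings themselves:

# Hedged usage sketch; requires a running Hadoop-compatible filesystem.
rdd = sc.parallelize([(1, u"a"), (2, u"b")])

# Keys/values are converted to Writables via JavaToWritableConverter by default.
rdd.saveAsNewAPIHadoopFile(
    "/tmp/seq-output",
    "org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat",
    keyClass="org.apache.hadoop.io.IntWritable",
    valueClass="org.apache.hadoop.io.Text")

# Or let saveAsSequenceFile infer the Writable types from the data.
rdd.saveAsSequenceFile("/tmp/seq-output-2")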
- @param path: path to sequence file - @param compressionCodecClass: (None by default) + :param path: path to sequence file + :param compressionCodecClass: (None by default) """ pickledRDD = self._toPickleSerialization() batched = isinstance(pickledRDD._jrdd_deserializer, BatchedSerializer) @@ -2009,7 +2012,7 @@ def countApproxDistinct(self, relativeSD=0.05): of The Art Cardinality Estimation Algorithm", available here. - @param relativeSD Relative accuracy. Smaller values create + :param relativeSD: Relative accuracy. Smaller values create counters that require more space. It must be greater than 0.000017. diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index 099fa54cf2bd7..08a0f0d8ffb3e 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -114,6 +114,9 @@ def __ne__(self, other): def __repr__(self): return "<%s object>" % self.__class__.__name__ + def __hash__(self): + return hash(str(self)) + class FramedSerializer(Serializer): @@ -220,7 +223,7 @@ class AutoBatchedSerializer(BatchedSerializer): Choose the size of batch automatically based on the size of object """ - def __init__(self, serializer, bestSize=1 << 20): + def __init__(self, serializer, bestSize=1 << 16): BatchedSerializer.__init__(self, serializer, -1) self.bestSize = bestSize @@ -247,7 +250,7 @@ def __eq__(self, other): other.serializer == self.serializer) def __str__(self): - return "BatchedSerializer<%s>" % str(self.serializer) + return "AutoBatchedSerializer<%s>" % str(self.serializer) class CartesianDeserializer(FramedSerializer): diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py index c13c4fe9cc66f..b9e80769aa965 100644 --- a/python/pyspark/sql.py +++ b/python/pyspark/sql.py @@ -15,28 +15,38 @@ # limitations under the License. # +""" +public classes of Spark SQL: + + - L{SQLContext} + Main entry point for SQL functionality. + - L{SchemaRDD} + A Resilient Distributed Dataset (RDD) with Schema information for the data contained. In + addition to normal RDD operations, SchemaRDDs also support SQL. + - L{Row} + A Row of data returned by a Spark SQL query. + - L{HiveContext} + Main entry point for accessing data stored in Apache Hive.. 
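The serializers.py hunk above lowers AutoBatchedSerializer's target chunk size from 1 MB (1 << 20) to 64 KB (1 << 16) and fixes its __str__. A conceptual sketch of the batching heuristic being tuned (not the pyspark.serializers source):

# Conceptual sketch: an auto-batching serializer grows its batch while each
# serialized chunk stays under a target size, and backs off when chunks get
# much larger than that target.
import pickle
from itertools import islice

def auto_batched_dumps(iterator, best_size=1 << 16):
    iterator = iter(iterator)
    batch = 1
    while True:
        items = list(islice(iterator, batch))
        if not items:
            return
        chunk = pickle.dumps(items, 2)
        yield chunk
        if len(chunk) < best_size:
            batch *= 2              # chunks are small: batch more per dump
        elif len(chunk) > best_size * 10 and batch > 1:
            batch //= 2             # chunks got large: back off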
+""" -import sys -import types import itertools -import warnings import decimal import datetime import keyword import warnings +import json from array import array from operator import itemgetter +from itertools import imap + +from py4j.protocol import Py4JError +from py4j.java_collections import ListConverter, MapConverter from pyspark.rdd import RDD from pyspark.serializers import BatchedSerializer, PickleSerializer, CloudPickleSerializer from pyspark.storagelevel import StorageLevel from pyspark.traceback_utils import SCCallSiteSync -from itertools import chain, ifilter, imap - -from py4j.protocol import Py4JError -from py4j.java_collections import ListConverter, MapConverter - __all__ = [ "StringType", "BinaryType", "BooleanType", "TimestampType", "DecimalType", @@ -62,6 +72,18 @@ def __eq__(self, other): def __ne__(self, other): return not self.__eq__(other) + @classmethod + def typeName(cls): + return cls.__name__[:-4].lower() + + def jsonValue(self): + return self.typeName() + + def json(self): + return json.dumps(self.jsonValue(), + separators=(',', ':'), + sort_keys=True) + class PrimitiveTypeSingleton(type): @@ -205,6 +227,16 @@ def __repr__(self): return "ArrayType(%s,%s)" % (self.elementType, str(self.containsNull).lower()) + def jsonValue(self): + return {"type": self.typeName(), + "elementType": self.elementType.jsonValue(), + "containsNull": self.containsNull} + + @classmethod + def fromJson(cls, json): + return ArrayType(_parse_datatype_json_value(json["elementType"]), + json["containsNull"]) + class MapType(DataType): @@ -245,6 +277,18 @@ def __repr__(self): return "MapType(%s,%s,%s)" % (self.keyType, self.valueType, str(self.valueContainsNull).lower()) + def jsonValue(self): + return {"type": self.typeName(), + "keyType": self.keyType.jsonValue(), + "valueType": self.valueType.jsonValue(), + "valueContainsNull": self.valueContainsNull} + + @classmethod + def fromJson(cls, json): + return MapType(_parse_datatype_json_value(json["keyType"]), + _parse_datatype_json_value(json["valueType"]), + json["valueContainsNull"]) + class StructField(DataType): @@ -283,6 +327,17 @@ def __repr__(self): return "StructField(%s,%s,%s)" % (self.name, self.dataType, str(self.nullable).lower()) + def jsonValue(self): + return {"name": self.name, + "type": self.dataType.jsonValue(), + "nullable": self.nullable} + + @classmethod + def fromJson(cls, json): + return StructField(json["name"], + _parse_datatype_json_value(json["type"]), + json["nullable"]) + class StructType(DataType): @@ -312,42 +367,30 @@ def __repr__(self): return ("StructType(List(%s))" % ",".join(str(field) for field in self.fields)) + def jsonValue(self): + return {"type": self.typeName(), + "fields": [f.jsonValue() for f in self.fields]} -def _parse_datatype_list(datatype_list_string): - """Parses a list of comma separated data types.""" - index = 0 - datatype_list = [] - start = 0 - depth = 0 - while index < len(datatype_list_string): - if depth == 0 and datatype_list_string[index] == ",": - datatype_string = datatype_list_string[start:index].strip() - datatype_list.append(_parse_datatype_string(datatype_string)) - start = index + 1 - elif datatype_list_string[index] == "(": - depth += 1 - elif datatype_list_string[index] == ")": - depth -= 1 - - index += 1 + @classmethod + def fromJson(cls, json): + return StructType([StructField.fromJson(f) for f in json["fields"]]) - # Handle the last data type - datatype_string = datatype_list_string[start:index].strip() - datatype_list.append(_parse_datatype_string(datatype_string)) - 
return datatype_list +_all_primitive_types = dict((v.typeName(), v) + for v in globals().itervalues() + if type(v) is PrimitiveTypeSingleton and + v.__base__ == PrimitiveType) -_all_primitive_types = dict((k, v) for k, v in globals().iteritems() - if type(v) is PrimitiveTypeSingleton and v.__base__ == PrimitiveType) +_all_complex_types = dict((v.typeName(), v) + for v in [ArrayType, MapType, StructType]) -def _parse_datatype_string(datatype_string): - """Parses the given data type string. +def _parse_datatype_json_string(json_string): + """Parses the given data type JSON string. >>> def check_datatype(datatype): - ... scala_datatype = sqlCtx._ssql_ctx.parseDataType(str(datatype)) - ... python_datatype = _parse_datatype_string( - ... scala_datatype.toString()) + ... scala_datatype = sqlCtx._ssql_ctx.parseDataType(datatype.json()) + ... python_datatype = _parse_datatype_json_string(scala_datatype.json()) ... return datatype == python_datatype >>> all(check_datatype(cls()) for cls in _all_primitive_types.values()) True @@ -385,51 +428,14 @@ def _parse_datatype_string(datatype_string): >>> check_datatype(complex_maptype) True """ - index = datatype_string.find("(") - if index == -1: - # It is a primitive type. - index = len(datatype_string) - type_or_field = datatype_string[:index] - rest_part = datatype_string[index + 1:len(datatype_string) - 1].strip() - - if type_or_field in _all_primitive_types: - return _all_primitive_types[type_or_field]() - - elif type_or_field == "ArrayType": - last_comma_index = rest_part.rfind(",") - containsNull = True - if rest_part[last_comma_index + 1:].strip().lower() == "false": - containsNull = False - elementType = _parse_datatype_string( - rest_part[:last_comma_index].strip()) - return ArrayType(elementType, containsNull) - - elif type_or_field == "MapType": - last_comma_index = rest_part.rfind(",") - valueContainsNull = True - if rest_part[last_comma_index + 1:].strip().lower() == "false": - valueContainsNull = False - keyType, valueType = _parse_datatype_list( - rest_part[:last_comma_index].strip()) - return MapType(keyType, valueType, valueContainsNull) - - elif type_or_field == "StructField": - first_comma_index = rest_part.find(",") - name = rest_part[:first_comma_index].strip() - last_comma_index = rest_part.rfind(",") - nullable = True - if rest_part[last_comma_index + 1:].strip().lower() == "false": - nullable = False - dataType = _parse_datatype_string( - rest_part[first_comma_index + 1:last_comma_index].strip()) - return StructField(name, dataType, nullable) - - elif type_or_field == "StructType": - # rest_part should be in the format like - # List(StructField(field1,IntegerType,false)). - field_list_string = rest_part[rest_part.find("(") + 1:-1] - fields = _parse_datatype_list(field_list_string) - return StructType(fields) + return _parse_datatype_json_value(json.loads(json_string)) + + +def _parse_datatype_json_value(json_value): + if type(json_value) is unicode and json_value in _all_primitive_types.keys(): + return _all_primitive_types[json_value]() + else: + return _all_complex_types[json_value["type"]].fromJson(json_value) # Mapping Python types to Spark SQL DateType @@ -899,8 +905,8 @@ class SQLContext(object): def __init__(self, sparkContext, sqlContext=None): """Create a new SQLContext. - @param sparkContext: The SparkContext to wrap. - @param sqlContext: An optional JVM Scala SQLContext. If set, we do not instatiate a new + :param sparkContext: The SparkContext to wrap. + :param sqlContext: An optional JVM Scala SQLContext. 
If set, we do not instatiate a new SQLContext in the JVM, instead we make all calls to this object. >>> srdd = sqlCtx.inferSchema(rdd) @@ -983,7 +989,7 @@ def registerFunction(self, name, f, returnType=StringType()): self._sc.pythonExec, broadcast_vars, self._sc._javaAccumulator, - str(returnType)) + returnType.json()) def inferSchema(self, rdd): """Infer and apply a schema to an RDD of L{Row}. @@ -1119,7 +1125,7 @@ def applySchema(self, rdd, schema): batched = isinstance(rdd._jrdd_deserializer, BatchedSerializer) jrdd = self._pythonToJava(rdd._jrdd, batched) - srdd = self._ssql_ctx.applySchemaToPythonRDD(jrdd.rdd(), str(schema)) + srdd = self._ssql_ctx.applySchemaToPythonRDD(jrdd.rdd(), schema.json()) return SchemaRDD(srdd.toJavaSchemaRDD(), self) def registerRDDAsTable(self, rdd, tableName): @@ -1209,7 +1215,7 @@ def jsonFile(self, path, schema=None): if schema is None: srdd = self._ssql_ctx.jsonFile(path) else: - scala_datatype = self._ssql_ctx.parseDataType(str(schema)) + scala_datatype = self._ssql_ctx.parseDataType(schema.json()) srdd = self._ssql_ctx.jsonFile(path, scala_datatype) return SchemaRDD(srdd.toJavaSchemaRDD(), self) @@ -1279,7 +1285,7 @@ def func(iterator): if schema is None: srdd = self._ssql_ctx.jsonRDD(jrdd.rdd()) else: - scala_datatype = self._ssql_ctx.parseDataType(str(schema)) + scala_datatype = self._ssql_ctx.parseDataType(schema.json()) srdd = self._ssql_ctx.jsonRDD(jrdd.rdd(), scala_datatype) return SchemaRDD(srdd.toJavaSchemaRDD(), self) @@ -1325,8 +1331,8 @@ class HiveContext(SQLContext): def __init__(self, sparkContext, hiveContext=None): """Create a new HiveContext. - @param sparkContext: The SparkContext to wrap. - @param hiveContext: An optional JVM Scala HiveContext. If set, we do not instatiate a new + :param sparkContext: The SparkContext to wrap. + :param hiveContext: An optional JVM Scala HiveContext. If set, we do not instatiate a new HiveContext in the JVM, instead we make all calls to this object. """ SQLContext.__init__(self, sparkContext) @@ -1660,7 +1666,7 @@ def saveAsTable(self, tableName): def schema(self): """Returns the schema of this SchemaRDD (represented by a L{StructType}).""" - return _parse_datatype_string(self._jschema_rdd.baseSchemaRDD().schema().toString()) + return _parse_datatype_json_string(self._jschema_rdd.baseSchemaRDD().schema().json()) def schemaString(self): """Returns the output schema in the tree format.""" diff --git a/python/epydoc.conf b/python/pyspark/streaming/__init__.py similarity index 55% rename from python/epydoc.conf rename to python/pyspark/streaming/__init__.py index 8593e08deda19..d2644a1d4ffab 100644 --- a/python/epydoc.conf +++ b/python/pyspark/streaming/__init__.py @@ -1,5 +1,3 @@ -[epydoc] # Epydoc section marker (required by ConfigParser) - # # Licensed to the Apache Software Foundation (ASF) under one or more # contributor license agreements. See the NOTICE file distributed with @@ -17,22 +15,7 @@ # limitations under the License. # -# Information about the project. -name: Spark 1.0.0 Python API Docs -url: http://spark.apache.org - -# The list of modules to document. Modules can be named using -# dotted names, module filenames, or package directory names. -# This option may be repeated. 
-modules: pyspark - -# Write html output to the directory "apidocs" -output: html -target: docs/ - -private: no +from pyspark.streaming.context import StreamingContext +from pyspark.streaming.dstream import DStream -exclude: pyspark.cloudpickle pyspark.worker pyspark.join - pyspark.java_gateway pyspark.examples pyspark.shell pyspark.tests - pyspark.rddsampler pyspark.daemon - pyspark.mllib.tests pyspark.shuffle +__all__ = ['StreamingContext', 'DStream'] diff --git a/python/pyspark/streaming/context.py b/python/pyspark/streaming/context.py new file mode 100644 index 0000000000000..dc9dc41121935 --- /dev/null +++ b/python/pyspark/streaming/context.py @@ -0,0 +1,325 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import os +import sys + +from py4j.java_collections import ListConverter +from py4j.java_gateway import java_import, JavaObject + +from pyspark import RDD, SparkConf +from pyspark.serializers import UTF8Deserializer, CloudPickleSerializer +from pyspark.context import SparkContext +from pyspark.storagelevel import StorageLevel +from pyspark.streaming.dstream import DStream +from pyspark.streaming.util import TransformFunction, TransformFunctionSerializer + +__all__ = ["StreamingContext"] + + +def _daemonize_callback_server(): + """ + Hack Py4J to daemonize callback server + + The thread of callback server has daemon=False, it will block the driver + from exiting if it's not shutdown. The following code replace `start()` + of CallbackServer with a new version, which set daemon=True for this + thread. + + Also, it will update the port number (0) with real port + """ + # TODO: create a patch for Py4J + import socket + import py4j.java_gateway + logger = py4j.java_gateway.logger + from py4j.java_gateway import Py4JNetworkError + from threading import Thread + + def start(self): + """Starts the CallbackServer. This method should be called by the + client instead of run().""" + self.server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + self.server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, + 1) + try: + self.server_socket.bind((self.address, self.port)) + if not self.port: + # update port with real port + self.port = self.server_socket.getsockname()[1] + except Exception as e: + msg = 'An error occurred while trying to start the callback server: %s' % e + logger.exception(msg) + raise Py4JNetworkError(msg) + + # Maybe thread needs to be cleanup up? + self.thread = Thread(target=self.run) + self.thread.daemon = True + self.thread.start() + + py4j.java_gateway.CallbackServer.start = start + + +class StreamingContext(object): + """ + Main entry point for Spark Streaming functionality. A StreamingContext + represents the connection to a Spark cluster, and can be used to create + L{DStream} various input sources. 
It can be from an existing L{SparkContext}. + After creating and transforming DStreams, the streaming computation can + be started and stopped using `context.start()` and `context.stop()`, + respectively. `context.awaitTransformation()` allows the current thread + to wait for the termination of the context by `stop()` or by an exception. + """ + _transformerSerializer = None + + def __init__(self, sparkContext, batchDuration=None, jssc=None): + """ + Create a new StreamingContext. + + @param sparkContext: L{SparkContext} object. + @param batchDuration: the time interval (in seconds) at which streaming + data will be divided into batches + """ + + self._sc = sparkContext + self._jvm = self._sc._jvm + self._jssc = jssc or self._initialize_context(self._sc, batchDuration) + + def _initialize_context(self, sc, duration): + self._ensure_initialized() + return self._jvm.JavaStreamingContext(sc._jsc, self._jduration(duration)) + + def _jduration(self, seconds): + """ + Create Duration object given number of seconds + """ + return self._jvm.Duration(int(seconds * 1000)) + + @classmethod + def _ensure_initialized(cls): + SparkContext._ensure_initialized() + gw = SparkContext._gateway + + java_import(gw.jvm, "org.apache.spark.streaming.*") + java_import(gw.jvm, "org.apache.spark.streaming.api.java.*") + java_import(gw.jvm, "org.apache.spark.streaming.api.python.*") + + # start callback server + # getattr will fallback to JVM, so we cannot test by hasattr() + if "_callback_server" not in gw.__dict__: + _daemonize_callback_server() + # use random port + gw._start_callback_server(0) + # gateway with real port + gw._python_proxy_port = gw._callback_server.port + # get the GatewayServer object in JVM by ID + jgws = JavaObject("GATEWAY_SERVER", gw._gateway_client) + # update the port of CallbackClient with real port + gw.jvm.PythonDStream.updatePythonGatewayPort(jgws, gw._python_proxy_port) + + # register serializer for TransformFunction + # it happens before creating SparkContext when loading from checkpointing + cls._transformerSerializer = TransformFunctionSerializer( + SparkContext._active_spark_context, CloudPickleSerializer(), gw) + + @classmethod + def getOrCreate(cls, checkpointPath, setupFunc): + """ + Either recreate a StreamingContext from checkpoint data or create a new StreamingContext. + If checkpoint data exists in the provided `checkpointPath`, then StreamingContext will be + recreated from the checkpoint data. If the data does not exist, then the provided setupFunc + will be used to create a JavaStreamingContext. + + @param checkpointPath Checkpoint directory used in an earlier JavaStreamingContext program + @param setupFunc Function to create a new JavaStreamingContext and setup DStreams + """ + # TODO: support checkpoint in HDFS + if not os.path.exists(checkpointPath) or not os.listdir(checkpointPath): + ssc = setupFunc() + ssc.checkpoint(checkpointPath) + return ssc + + cls._ensure_initialized() + gw = SparkContext._gateway + + try: + jssc = gw.jvm.JavaStreamingContext(checkpointPath) + except Exception: + print >>sys.stderr, "failed to load StreamingContext from checkpoint" + raise + + jsc = jssc.sparkContext() + conf = SparkConf(_jconf=jsc.getConf()) + sc = SparkContext(conf=conf, gateway=gw, jsc=jsc) + # update ctx in serializer + SparkContext._active_spark_context = sc + cls._transformerSerializer.ctx = sc + return StreamingContext(sc, None, jssc) + + @property + def sparkContext(self): + """ + Return SparkContext which is associated with this StreamingContext. 
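A hedged usage sketch of the new StreamingContext entry points shown above (getOrCreate, start, awaitTermination); the application name, socket address, and checkpoint directory are assumptions:

# Hedged sketch: build a fresh StreamingContext via setupFunc, or recover one
# from existing checkpoint data if the directory already holds checkpoints.
from pyspark import SparkContext
from pyspark.streaming import StreamingContext

def create_context():
    sc = SparkContext(appName="CheckpointedWordCount")   # app name is an assumption
    ssc = StreamingContext(sc, batchDuration=1)
    lines = ssc.socketTextStream("localhost", 9999)       # source is an assumption
    lines.count().pprint()
    return ssc

ssc = StreamingContext.getOrCreate("/tmp/streaming-checkpoint", create_context)
ssc.start()
ssc.awaitTermination()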
+ """ + return self._sc + + def start(self): + """ + Start the execution of the streams. + """ + self._jssc.start() + + def awaitTermination(self, timeout=None): + """ + Wait for the execution to stop. + @param timeout: time to wait in seconds + """ + if timeout is None: + self._jssc.awaitTermination() + else: + self._jssc.awaitTermination(int(timeout * 1000)) + + def stop(self, stopSparkContext=True, stopGraceFully=False): + """ + Stop the execution of the streams, with option of ensuring all + received data has been processed. + + @param stopSparkContext: Stop the associated SparkContext or not + @param stopGracefully: Stop gracefully by waiting for the processing + of all received data to be completed + """ + self._jssc.stop(stopSparkContext, stopGraceFully) + if stopSparkContext: + self._sc.stop() + + def remember(self, duration): + """ + Set each DStreams in this context to remember RDDs it generated + in the last given duration. DStreams remember RDDs only for a + limited duration of time and releases them for garbage collection. + This method allows the developer to specify how to long to remember + the RDDs (if the developer wishes to query old data outside the + DStream computation). + + @param duration: Minimum duration (in seconds) that each DStream + should remember its RDDs + """ + self._jssc.remember(self._jduration(duration)) + + def checkpoint(self, directory): + """ + Sets the context to periodically checkpoint the DStream operations for master + fault-tolerance. The graph will be checkpointed every batch interval. + + @param directory: HDFS-compatible directory where the checkpoint data + will be reliably stored + """ + self._jssc.checkpoint(directory) + + def socketTextStream(self, hostname, port, storageLevel=StorageLevel.MEMORY_AND_DISK_SER_2): + """ + Create an input from TCP source hostname:port. Data is received using + a TCP socket and receive byte is interpreted as UTF8 encoded ``\\n`` delimited + lines. + + @param hostname: Hostname to connect to for receiving data + @param port: Port to connect to for receiving data + @param storageLevel: Storage level to use for storing the received objects + """ + jlevel = self._sc._getJavaStorageLevel(storageLevel) + return DStream(self._jssc.socketTextStream(hostname, port, jlevel), self, + UTF8Deserializer()) + + def textFileStream(self, directory): + """ + Create an input stream that monitors a Hadoop-compatible file system + for new files and reads them as text files. Files must be wrriten to the + monitored directory by "moving" them from another location within the same + file system. File names starting with . are ignored. + """ + return DStream(self._jssc.textFileStream(directory), self, UTF8Deserializer()) + + def _check_serializers(self, rdds): + # make sure they have same serializer + if len(set(rdd._jrdd_deserializer for rdd in rdds)) > 1: + for i in range(len(rdds)): + # reset them to sc.serializer + rdds[i] = rdds[i]._reserialize() + + def queueStream(self, rdds, oneAtATime=True, default=None): + """ + Create an input stream from an queue of RDDs or list. In each batch, + it will process either one or all of the RDDs returned by the queue. + + NOTE: changes to the queue after the stream is created will not be recognized. + + @param rdds: Queue of RDDs + @param oneAtATime: pick one rdd each time or pick all of them once. 
+ @param default: The default rdd if no more in rdds + """ + if default and not isinstance(default, RDD): + default = self._sc.parallelize(default) + + if not rdds and default: + rdds = [rdds] + + if rdds and not isinstance(rdds[0], RDD): + rdds = [self._sc.parallelize(input) for input in rdds] + self._check_serializers(rdds) + + jrdds = ListConverter().convert([r._jrdd for r in rdds], + SparkContext._gateway._gateway_client) + queue = self._jvm.PythonDStream.toRDDQueue(jrdds) + if default: + default = default._reserialize(rdds[0]._jrdd_deserializer) + jdstream = self._jssc.queueStream(queue, oneAtATime, default._jrdd) + else: + jdstream = self._jssc.queueStream(queue, oneAtATime) + return DStream(jdstream, self, rdds[0]._jrdd_deserializer) + + def transform(self, dstreams, transformFunc): + """ + Create a new DStream in which each RDD is generated by applying + a function on RDDs of the DStreams. The order of the JavaRDDs in + the transform function parameter will be the same as the order + of corresponding DStreams in the list. + """ + jdstreams = ListConverter().convert([d._jdstream for d in dstreams], + SparkContext._gateway._gateway_client) + # change the final serializer to sc.serializer + func = TransformFunction(self._sc, + lambda t, *rdds: transformFunc(rdds).map(lambda x: x), + *[d._jrdd_deserializer for d in dstreams]) + jfunc = self._jvm.TransformFunction(func) + jdstream = self._jssc.transform(jdstreams, jfunc) + return DStream(jdstream, self, self._sc.serializer) + + def union(self, *dstreams): + """ + Create a unified DStream from multiple DStreams of the same + type and same slide duration. + """ + if not dstreams: + raise ValueError("should have at least one DStream to union") + if len(dstreams) == 1: + return dstreams[0] + if len(set(s._jrdd_deserializer for s in dstreams)) > 1: + raise ValueError("All DStreams should have same serializer") + if len(set(s._slideDuration for s in dstreams)) > 1: + raise ValueError("All DStreams should have same slide duration") + first = dstreams[0] + jrest = ListConverter().convert([d._jdstream for d in dstreams[1:]], + SparkContext._gateway._gateway_client) + return DStream(self._jssc.union(first._jdstream, jrest), self, first._jrdd_deserializer) diff --git a/python/pyspark/streaming/dstream.py b/python/pyspark/streaming/dstream.py new file mode 100644 index 0000000000000..0826ddc56e844 --- /dev/null +++ b/python/pyspark/streaming/dstream.py @@ -0,0 +1,623 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
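A hedged sketch of queueStream, union, and the stop/await calls defined above, handy for local experiments; `sc` and `ssc` are assumed to already exist:

# Hedged sketch: each parallelized list becomes one micro-batch.
rdds = [sc.parallelize([1, 2, 3]), sc.parallelize([4, 5, 6])]
stream = ssc.queueStream(rdds)
stream.map(lambda x: x * 10).pprint()

# union() merges DStreams of the same type and slide duration.
merged = ssc.union(stream, ssc.queueStream([sc.parallelize([7, 8])]))
merged.pprint()

ssc.start()
ssc.awaitTermination(5)                 # wait at most 5 seconds
ssc.stop(stopSparkContext=False)        # keep the SparkContext alive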
+# + +from itertools import chain, ifilter, imap +import operator +import time +from datetime import datetime + +from py4j.protocol import Py4JJavaError + +from pyspark import RDD +from pyspark.storagelevel import StorageLevel +from pyspark.streaming.util import rddToFileName, TransformFunction +from pyspark.rdd import portable_hash +from pyspark.resultiterable import ResultIterable + +__all__ = ["DStream"] + + +class DStream(object): + """ + A Discretized Stream (DStream), the basic abstraction in Spark Streaming, + is a continuous sequence of RDDs (of the same type) representing a + continuous stream of data (see L{RDD} in the Spark core documentation + for more details on RDDs). + + DStreams can either be created from live data (such as, data from TCP + sockets, Kafka, Flume, etc.) using a L{StreamingContext} or it can be + generated by transforming existing DStreams using operations such as + `map`, `window` and `reduceByKeyAndWindow`. While a Spark Streaming + program is running, each DStream periodically generates a RDD, either + from live data or by transforming the RDD generated by a parent DStream. + + DStreams internally is characterized by a few basic properties: + - A list of other DStreams that the DStream depends on + - A time interval at which the DStream generates an RDD + - A function that is used to generate an RDD after each time interval + """ + def __init__(self, jdstream, ssc, jrdd_deserializer): + self._jdstream = jdstream + self._ssc = ssc + self._sc = ssc._sc + self._jrdd_deserializer = jrdd_deserializer + self.is_cached = False + self.is_checkpointed = False + + def context(self): + """ + Return the StreamingContext associated with this DStream + """ + return self._ssc + + def count(self): + """ + Return a new DStream in which each RDD has a single element + generated by counting each RDD of this DStream. + """ + return self.mapPartitions(lambda i: [sum(1 for _ in i)]).reduce(operator.add) + + def filter(self, f): + """ + Return a new DStream containing only the elements that satisfy predicate. + """ + def func(iterator): + return ifilter(f, iterator) + return self.mapPartitions(func, True) + + def flatMap(self, f, preservesPartitioning=False): + """ + Return a new DStream by applying a function to all elements of + this DStream, and then flattening the results + """ + def func(s, iterator): + return chain.from_iterable(imap(f, iterator)) + return self.mapPartitionsWithIndex(func, preservesPartitioning) + + def map(self, f, preservesPartitioning=False): + """ + Return a new DStream by applying a function to each element of DStream. + """ + def func(iterator): + return imap(f, iterator) + return self.mapPartitions(func, preservesPartitioning) + + def mapPartitions(self, f, preservesPartitioning=False): + """ + Return a new DStream in which each RDD is generated by applying + mapPartitions() to each RDDs of this DStream. + """ + def func(s, iterator): + return f(iterator) + return self.mapPartitionsWithIndex(func, preservesPartitioning) + + def mapPartitionsWithIndex(self, f, preservesPartitioning=False): + """ + Return a new DStream in which each RDD is generated by applying + mapPartitionsWithIndex() to each RDDs of this DStream. + """ + return self.transform(lambda rdd: rdd.mapPartitionsWithIndex(f, preservesPartitioning)) + + def reduce(self, func): + """ + Return a new DStream in which each RDD has a single element + generated by reducing each RDD of this DStream. 
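A hedged sketch chaining the element-wise operations defined above (filter, map, reduce) over a queue-backed stream; `sc` and `ssc` are assumed to exist:

# Hedged sketch: three micro-batches of integers, filtered, squared, and summed.
nums = ssc.queueStream([sc.parallelize(range(10)) for _ in range(3)])

evens = nums.filter(lambda x: x % 2 == 0)      # keep even numbers
squares = evens.map(lambda x: x * x)           # element-wise map
totals = squares.reduce(lambda a, b: a + b)    # one value per batch
totals.pprint()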
+ """ + return self.map(lambda x: (None, x)).reduceByKey(func, 1).map(lambda x: x[1]) + + def reduceByKey(self, func, numPartitions=None): + """ + Return a new DStream by applying reduceByKey to each RDD. + """ + if numPartitions is None: + numPartitions = self._sc.defaultParallelism + return self.combineByKey(lambda x: x, func, func, numPartitions) + + def combineByKey(self, createCombiner, mergeValue, mergeCombiners, + numPartitions=None): + """ + Return a new DStream by applying combineByKey to each RDD. + """ + if numPartitions is None: + numPartitions = self._sc.defaultParallelism + + def func(rdd): + return rdd.combineByKey(createCombiner, mergeValue, mergeCombiners, numPartitions) + return self.transform(func) + + def partitionBy(self, numPartitions, partitionFunc=portable_hash): + """ + Return a copy of the DStream in which each RDD are partitioned + using the specified partitioner. + """ + return self.transform(lambda rdd: rdd.partitionBy(numPartitions, partitionFunc)) + + def foreachRDD(self, func): + """ + Apply a function to each RDD in this DStream. + """ + if func.func_code.co_argcount == 1: + old_func = func + func = lambda t, rdd: old_func(rdd) + jfunc = TransformFunction(self._sc, func, self._jrdd_deserializer) + api = self._ssc._jvm.PythonDStream + api.callForeachRDD(self._jdstream, jfunc) + + def pprint(self): + """ + Print the first ten elements of each RDD generated in this DStream. + """ + def takeAndPrint(time, rdd): + taken = rdd.take(11) + print "-------------------------------------------" + print "Time: %s" % time + print "-------------------------------------------" + for record in taken[:10]: + print record + if len(taken) > 10: + print "..." + print + + self.foreachRDD(takeAndPrint) + + def mapValues(self, f): + """ + Return a new DStream by applying a map function to the value of + each key-value pairs in this DStream without changing the key. + """ + map_values_fn = lambda (k, v): (k, f(v)) + return self.map(map_values_fn, preservesPartitioning=True) + + def flatMapValues(self, f): + """ + Return a new DStream by applying a flatmap function to the value + of each key-value pairs in this DStream without changing the key. + """ + flat_map_fn = lambda (k, v): ((k, x) for x in f(v)) + return self.flatMap(flat_map_fn, preservesPartitioning=True) + + def glom(self): + """ + Return a new DStream in which RDD is generated by applying glom() + to RDD of this DStream. + """ + def func(iterator): + yield list(iterator) + return self.mapPartitions(func) + + def cache(self): + """ + Persist the RDDs of this DStream with the default storage level + (C{MEMORY_ONLY_SER}). + """ + self.is_cached = True + self.persist(StorageLevel.MEMORY_ONLY_SER) + return self + + def persist(self, storageLevel): + """ + Persist the RDDs of this DStream with the given storage level + """ + self.is_cached = True + javaStorageLevel = self._sc._getJavaStorageLevel(storageLevel) + self._jdstream.persist(javaStorageLevel) + return self + + def checkpoint(self, interval): + """ + Enable periodic checkpointing of RDDs of this DStream + + @param interval: time in seconds, after each period of that, generated + RDD will be checkpointed + """ + self.is_checkpointed = True + self._jdstream.checkpoint(self._ssc._jduration(interval)) + return self + + def groupByKey(self, numPartitions=None): + """ + Return a new DStream by applying groupByKey on each RDD. 
+ """ + if numPartitions is None: + numPartitions = self._sc.defaultParallelism + return self.transform(lambda rdd: rdd.groupByKey(numPartitions)) + + def countByValue(self): + """ + Return a new DStream in which each RDD contains the counts of each + distinct value in each RDD of this DStream. + """ + return self.map(lambda x: (x, None)).reduceByKey(lambda x, y: None).count() + + def saveAsTextFiles(self, prefix, suffix=None): + """ + Save each RDD in this DStream as at text file, using string + representation of elements. + """ + def saveAsTextFile(t, rdd): + path = rddToFileName(prefix, suffix, t) + try: + rdd.saveAsTextFile(path) + except Py4JJavaError as e: + # after recovered from checkpointing, the foreachRDD may + # be called twice + if 'FileAlreadyExistsException' not in str(e): + raise + return self.foreachRDD(saveAsTextFile) + + # TODO: uncomment this until we have ssc.pickleFileStream() + # def saveAsPickleFiles(self, prefix, suffix=None): + # """ + # Save each RDD in this DStream as at binary file, the elements are + # serialized by pickle. + # """ + # def saveAsPickleFile(t, rdd): + # path = rddToFileName(prefix, suffix, t) + # try: + # rdd.saveAsPickleFile(path) + # except Py4JJavaError as e: + # # after recovered from checkpointing, the foreachRDD may + # # be called twice + # if 'FileAlreadyExistsException' not in str(e): + # raise + # return self.foreachRDD(saveAsPickleFile) + + def transform(self, func): + """ + Return a new DStream in which each RDD is generated by applying a function + on each RDD of this DStream. + + `func` can have one argument of `rdd`, or have two arguments of + (`time`, `rdd`) + """ + if func.func_code.co_argcount == 1: + oldfunc = func + func = lambda t, rdd: oldfunc(rdd) + assert func.func_code.co_argcount == 2, "func should take one or two arguments" + return TransformedDStream(self, func) + + def transformWith(self, func, other, keepSerializer=False): + """ + Return a new DStream in which each RDD is generated by applying a function + on each RDD of this DStream and 'other' DStream. + + `func` can have two arguments of (`rdd_a`, `rdd_b`) or have three + arguments of (`time`, `rdd_a`, `rdd_b`) + """ + if func.func_code.co_argcount == 2: + oldfunc = func + func = lambda t, a, b: oldfunc(a, b) + assert func.func_code.co_argcount == 3, "func should take two or three arguments" + jfunc = TransformFunction(self._sc, func, self._jrdd_deserializer, other._jrdd_deserializer) + dstream = self._sc._jvm.PythonTransformed2DStream(self._jdstream.dstream(), + other._jdstream.dstream(), jfunc) + jrdd_serializer = self._jrdd_deserializer if keepSerializer else self._sc.serializer + return DStream(dstream.asJavaDStream(), self._ssc, jrdd_serializer) + + def repartition(self, numPartitions): + """ + Return a new DStream with an increased or decreased level of parallelism. + """ + return self.transform(lambda rdd: rdd.repartition(numPartitions)) + + @property + def _slideDuration(self): + """ + Return the slideDuration in seconds of this DStream + """ + return self._jdstream.dstream().slideDuration().milliseconds() / 1000.0 + + def union(self, other): + """ + Return a new DStream by unifying data of another DStream with this DStream. + + @param other: Another DStream having the same interval (i.e., slideDuration) + as this DStream. 
+ """ + if self._slideDuration != other._slideDuration: + raise ValueError("the two DStream should have same slide duration") + return self.transformWith(lambda a, b: a.union(b), other, True) + + def cogroup(self, other, numPartitions=None): + """ + Return a new DStream by applying 'cogroup' between RDDs of this + DStream and `other` DStream. + + Hash partitioning is used to generate the RDDs with `numPartitions` partitions. + """ + if numPartitions is None: + numPartitions = self._sc.defaultParallelism + return self.transformWith(lambda a, b: a.cogroup(b, numPartitions), other) + + def join(self, other, numPartitions=None): + """ + Return a new DStream by applying 'join' between RDDs of this DStream and + `other` DStream. + + Hash partitioning is used to generate the RDDs with `numPartitions` + partitions. + """ + if numPartitions is None: + numPartitions = self._sc.defaultParallelism + return self.transformWith(lambda a, b: a.join(b, numPartitions), other) + + def leftOuterJoin(self, other, numPartitions=None): + """ + Return a new DStream by applying 'left outer join' between RDDs of this DStream and + `other` DStream. + + Hash partitioning is used to generate the RDDs with `numPartitions` + partitions. + """ + if numPartitions is None: + numPartitions = self._sc.defaultParallelism + return self.transformWith(lambda a, b: a.leftOuterJoin(b, numPartitions), other) + + def rightOuterJoin(self, other, numPartitions=None): + """ + Return a new DStream by applying 'right outer join' between RDDs of this DStream and + `other` DStream. + + Hash partitioning is used to generate the RDDs with `numPartitions` + partitions. + """ + if numPartitions is None: + numPartitions = self._sc.defaultParallelism + return self.transformWith(lambda a, b: a.rightOuterJoin(b, numPartitions), other) + + def fullOuterJoin(self, other, numPartitions=None): + """ + Return a new DStream by applying 'full outer join' between RDDs of this DStream and + `other` DStream. + + Hash partitioning is used to generate the RDDs with `numPartitions` + partitions. + """ + if numPartitions is None: + numPartitions = self._sc.defaultParallelism + return self.transformWith(lambda a, b: a.fullOuterJoin(b, numPartitions), other) + + def _jtime(self, timestamp): + """ Convert datetime or unix_timestamp into Time + """ + if isinstance(timestamp, datetime): + timestamp = time.mktime(timestamp.timetuple()) + return self._sc._jvm.Time(long(timestamp * 1000)) + + def slice(self, begin, end): + """ + Return all the RDDs between 'begin' to 'end' (both included) + + `begin`, `end` could be datetime.datetime() or unix_timestamp + """ + jrdds = self._jdstream.slice(self._jtime(begin), self._jtime(end)) + return [RDD(jrdd, self._sc, self._jrdd_deserializer) for jrdd in jrdds] + + def _validate_window_param(self, window, slide): + duration = self._jdstream.dstream().slideDuration().milliseconds() + if int(window * 1000) % duration != 0: + raise ValueError("windowDuration must be multiple of the slide duration (%d ms)" + % duration) + if slide and int(slide * 1000) % duration != 0: + raise ValueError("slideDuration must be multiple of the slide duration (%d ms)" + % duration) + + def window(self, windowDuration, slideDuration=None): + """ + Return a new DStream in which each RDD contains all the elements in seen in a + sliding window of time over this DStream. 
+ + @param windowDuration: width of the window; must be a multiple of this DStream's + batching interval + @param slideDuration: sliding interval of the window (i.e., the interval after which + the new DStream will generate RDDs); must be a multiple of this + DStream's batching interval + """ + self._validate_window_param(windowDuration, slideDuration) + d = self._ssc._jduration(windowDuration) + if slideDuration is None: + return DStream(self._jdstream.window(d), self._ssc, self._jrdd_deserializer) + s = self._ssc._jduration(slideDuration) + return DStream(self._jdstream.window(d, s), self._ssc, self._jrdd_deserializer) + + def reduceByWindow(self, reduceFunc, invReduceFunc, windowDuration, slideDuration): + """ + Return a new DStream in which each RDD has a single element generated by reducing all + elements in a sliding window over this DStream. + + if `invReduceFunc` is not None, the reduction is done incrementally + using the old window's reduced value : + + 1. reduce the new values that entered the window (e.g., adding new counts) + + 2. "inverse reduce" the old values that left the window (e.g., subtracting old counts) + This is more efficient than `invReduceFunc` is None. + + @param reduceFunc: associative reduce function + @param invReduceFunc: inverse reduce function of `reduceFunc` + @param windowDuration: width of the window; must be a multiple of this DStream's + batching interval + @param slideDuration: sliding interval of the window (i.e., the interval after which + the new DStream will generate RDDs); must be a multiple of this + DStream's batching interval + """ + keyed = self.map(lambda x: (1, x)) + reduced = keyed.reduceByKeyAndWindow(reduceFunc, invReduceFunc, + windowDuration, slideDuration, 1) + return reduced.map(lambda (k, v): v) + + def countByWindow(self, windowDuration, slideDuration): + """ + Return a new DStream in which each RDD has a single element generated + by counting the number of elements in a window over this DStream. + windowDuration and slideDuration are as defined in the window() operation. + + This is equivalent to window(windowDuration, slideDuration).count(), + but will be more efficient if window is large. + """ + return self.map(lambda x: 1).reduceByWindow(operator.add, operator.sub, + windowDuration, slideDuration) + + def countByValueAndWindow(self, windowDuration, slideDuration, numPartitions=None): + """ + Return a new DStream in which each RDD contains the count of distinct elements in + RDDs in a sliding window over this DStream. + + @param windowDuration: width of the window; must be a multiple of this DStream's + batching interval + @param slideDuration: sliding interval of the window (i.e., the interval after which + the new DStream will generate RDDs); must be a multiple of this + DStream's batching interval + @param numPartitions: number of partitions of each RDD in the new DStream. + """ + keyed = self.map(lambda x: (x, 1)) + counted = keyed.reduceByKeyAndWindow(operator.add, operator.sub, + windowDuration, slideDuration, numPartitions) + return counted.filter(lambda (k, v): v > 0).count() + + def groupByKeyAndWindow(self, windowDuration, slideDuration, numPartitions=None): + """ + Return a new DStream by applying `groupByKey` over a sliding window. + Similar to `DStream.groupByKey()`, but applies it over a sliding window. 
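A hedged sketch of the window operations above, assuming a batch interval that evenly divides the chosen window and slide durations (as enforced by _validate_window_param); the source and durations are assumptions:

# Hedged sketch: a 30-second window sliding every 10 seconds.
events = ssc.socketTextStream("localhost", 9999)

events.countByWindow(30, 10).pprint()    # incremental count via add/"inverse" subtract
events.window(30, 10).count().pprint()   # equivalent result, but re-counts the whole window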
+ + @param windowDuration: width of the window; must be a multiple of this DStream's + batching interval + @param slideDuration: sliding interval of the window (i.e., the interval after which + the new DStream will generate RDDs); must be a multiple of this + DStream's batching interval + @param numPartitions: Number of partitions of each RDD in the new DStream. + """ + ls = self.mapValues(lambda x: [x]) + grouped = ls.reduceByKeyAndWindow(lambda a, b: a.extend(b) or a, lambda a, b: a[len(b):], + windowDuration, slideDuration, numPartitions) + return grouped.mapValues(ResultIterable) + + def reduceByKeyAndWindow(self, func, invFunc, windowDuration, slideDuration=None, + numPartitions=None, filterFunc=None): + """ + Return a new DStream by applying incremental `reduceByKey` over a sliding window. + + The reduced value of over a new window is calculated using the old window's reduce value : + 1. reduce the new values that entered the window (e.g., adding new counts) + 2. "inverse reduce" the old values that left the window (e.g., subtracting old counts) + + `invFunc` can be None, then it will reduce all the RDDs in window, could be slower + than having `invFunc`. + + @param reduceFunc: associative reduce function + @param invReduceFunc: inverse function of `reduceFunc` + @param windowDuration: width of the window; must be a multiple of this DStream's + batching interval + @param slideDuration: sliding interval of the window (i.e., the interval after which + the new DStream will generate RDDs); must be a multiple of this + DStream's batching interval + @param numPartitions: number of partitions of each RDD in the new DStream. + @param filterFunc: function to filter expired key-value pairs; + only pairs that satisfy the function are retained + set this to null if you do not want to filter + """ + self._validate_window_param(windowDuration, slideDuration) + if numPartitions is None: + numPartitions = self._sc.defaultParallelism + + reduced = self.reduceByKey(func, numPartitions) + + def reduceFunc(t, a, b): + b = b.reduceByKey(func, numPartitions) + r = a.union(b).reduceByKey(func, numPartitions) if a else b + if filterFunc: + r = r.filter(filterFunc) + return r + + def invReduceFunc(t, a, b): + b = b.reduceByKey(func, numPartitions) + joined = a.leftOuterJoin(b, numPartitions) + return joined.mapValues(lambda (v1, v2): invFunc(v1, v2) if v2 is not None else v1) + + jreduceFunc = TransformFunction(self._sc, reduceFunc, reduced._jrdd_deserializer) + if invReduceFunc: + jinvReduceFunc = TransformFunction(self._sc, invReduceFunc, reduced._jrdd_deserializer) + else: + jinvReduceFunc = None + if slideDuration is None: + slideDuration = self._slideDuration + dstream = self._sc._jvm.PythonReducedWindowedDStream(reduced._jdstream.dstream(), + jreduceFunc, jinvReduceFunc, + self._ssc._jduration(windowDuration), + self._ssc._jduration(slideDuration)) + return DStream(dstream.asJavaDStream(), self._ssc, self._sc.serializer) + + def updateStateByKey(self, updateFunc, numPartitions=None): + """ + Return a new "state" DStream where the state for each key is updated by applying + the given function on the previous state of the key and the new values of the key. + + @param updateFunc: State update function. If this function returns None, then + corresponding state key-value pair will be eliminated. 
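A hedged sketch of updateStateByKey for a running per-word count; the checkpoint directory and socket source are assumptions, and stateful streams require checkpointing to be enabled:

# Hedged sketch: maintain a running count per word across batches.
ssc.checkpoint("/tmp/streaming-checkpoint")   # path is an assumption

pairs = (ssc.socketTextStream("localhost", 9999)
            .flatMap(lambda line: line.split(" "))
            .map(lambda word: (word, 1)))

def update(new_values, last_sum):
    # new_values: values seen for the key in this batch; last_sum: previous
    # state or None for a new key. Returning None would drop the key.
    return sum(new_values) + (last_sum or 0)

running = pairs.updateStateByKey(update)
running.pprint()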
+ """ + if numPartitions is None: + numPartitions = self._sc.defaultParallelism + + def reduceFunc(t, a, b): + if a is None: + g = b.groupByKey(numPartitions).mapValues(lambda vs: (list(vs), None)) + else: + g = a.cogroup(b, numPartitions) + g = g.mapValues(lambda (va, vb): (list(vb), list(va)[0] if len(va) else None)) + state = g.mapValues(lambda (vs, s): updateFunc(vs, s)) + return state.filter(lambda (k, v): v is not None) + + jreduceFunc = TransformFunction(self._sc, reduceFunc, + self._sc.serializer, self._jrdd_deserializer) + dstream = self._sc._jvm.PythonStateDStream(self._jdstream.dstream(), jreduceFunc) + return DStream(dstream.asJavaDStream(), self._ssc, self._sc.serializer) + + +class TransformedDStream(DStream): + """ + TransformedDStream is an DStream generated by an Python function + transforming each RDD of an DStream to another RDDs. + + Multiple continuous transformations of DStream can be combined into + one transformation. + """ + def __init__(self, prev, func): + self._ssc = prev._ssc + self._sc = self._ssc._sc + self._jrdd_deserializer = self._sc.serializer + self.is_cached = False + self.is_checkpointed = False + self._jdstream_val = None + + if (isinstance(prev, TransformedDStream) and + not prev.is_cached and not prev.is_checkpointed): + prev_func = prev.func + self.func = lambda t, rdd: func(t, prev_func(t, rdd)) + self.prev = prev.prev + else: + self.prev = prev + self.func = func + + @property + def _jdstream(self): + if self._jdstream_val is not None: + return self._jdstream_val + + jfunc = TransformFunction(self._sc, self.func, self.prev._jrdd_deserializer) + dstream = self._sc._jvm.PythonTransformedDStream(self.prev._jdstream.dstream(), jfunc) + self._jdstream_val = dstream.asJavaDStream() + return self._jdstream_val diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py new file mode 100644 index 0000000000000..a8d876d0fa3b3 --- /dev/null +++ b/python/pyspark/streaming/tests.py @@ -0,0 +1,545 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +import os +from itertools import chain +import time +import operator +import unittest +import tempfile + +from pyspark.context import SparkConf, SparkContext, RDD +from pyspark.streaming.context import StreamingContext + + +class PySparkStreamingTestCase(unittest.TestCase): + + timeout = 10 # seconds + duration = 1 + + def setUp(self): + class_name = self.__class__.__name__ + conf = SparkConf().set("spark.default.parallelism", 1) + self.sc = SparkContext(appName=class_name, conf=conf) + self.sc.setCheckpointDir("/tmp") + # TODO: decrease duration to speed up tests + self.ssc = StreamingContext(self.sc, self.duration) + + def tearDown(self): + self.ssc.stop() + + def wait_for(self, result, n): + start_time = time.time() + while len(result) < n and time.time() - start_time < self.timeout: + time.sleep(0.01) + if len(result) < n: + print "timeout after", self.timeout + + def _take(self, dstream, n): + """ + Return the first `n` elements in the stream (will start and stop). + """ + results = [] + + def take(_, rdd): + if rdd and len(results) < n: + results.extend(rdd.take(n - len(results))) + + dstream.foreachRDD(take) + + self.ssc.start() + self.wait_for(results, n) + return results + + def _collect(self, dstream, n, block=True): + """ + Collect each RDDs into the returned list. + + :return: list, which will have the collected items. + """ + result = [] + + def get_output(_, rdd): + if rdd and len(result) < n: + r = rdd.collect() + if r: + result.append(r) + + dstream.foreachRDD(get_output) + + if not block: + return result + + self.ssc.start() + self.wait_for(result, n) + return result + + def _test_func(self, input, func, expected, sort=False, input2=None): + """ + @param input: dataset for the test. This should be list of lists. + @param func: wrapped function. This function should return PythonDStream object. + @param expected: expected output for this testcase. + """ + if not isinstance(input[0], RDD): + input = [self.sc.parallelize(d, 1) for d in input] + input_stream = self.ssc.queueStream(input) + if input2 and not isinstance(input2[0], RDD): + input2 = [self.sc.parallelize(d, 1) for d in input2] + input_stream2 = self.ssc.queueStream(input2) if input2 is not None else None + + # Apply test function to stream. 
+ if input2: + stream = func(input_stream, input_stream2) + else: + stream = func(input_stream) + + result = self._collect(stream, len(expected)) + if sort: + self._sort_result_based_on_key(result) + self._sort_result_based_on_key(expected) + self.assertEqual(expected, result) + + def _sort_result_based_on_key(self, outputs): + """Sort the list based on first value.""" + for output in outputs: + output.sort(key=lambda x: x[0]) + + +class BasicOperationTests(PySparkStreamingTestCase): + + def test_map(self): + """Basic operation test for DStream.map.""" + input = [range(1, 5), range(5, 9), range(9, 13)] + + def func(dstream): + return dstream.map(str) + expected = map(lambda x: map(str, x), input) + self._test_func(input, func, expected) + + def test_flatMap(self): + """Basic operation test for DStream.faltMap.""" + input = [range(1, 5), range(5, 9), range(9, 13)] + + def func(dstream): + return dstream.flatMap(lambda x: (x, x * 2)) + expected = map(lambda x: list(chain.from_iterable((map(lambda y: [y, y * 2], x)))), + input) + self._test_func(input, func, expected) + + def test_filter(self): + """Basic operation test for DStream.filter.""" + input = [range(1, 5), range(5, 9), range(9, 13)] + + def func(dstream): + return dstream.filter(lambda x: x % 2 == 0) + expected = map(lambda x: filter(lambda y: y % 2 == 0, x), input) + self._test_func(input, func, expected) + + def test_count(self): + """Basic operation test for DStream.count.""" + input = [range(5), range(10), range(20)] + + def func(dstream): + return dstream.count() + expected = map(lambda x: [len(x)], input) + self._test_func(input, func, expected) + + def test_reduce(self): + """Basic operation test for DStream.reduce.""" + input = [range(1, 5), range(5, 9), range(9, 13)] + + def func(dstream): + return dstream.reduce(operator.add) + expected = map(lambda x: [reduce(operator.add, x)], input) + self._test_func(input, func, expected) + + def test_reduceByKey(self): + """Basic operation test for DStream.reduceByKey.""" + input = [[("a", 1), ("a", 1), ("b", 1), ("b", 1)], + [("", 1), ("", 1), ("", 1), ("", 1)], + [(1, 1), (1, 1), (2, 1), (2, 1), (3, 1)]] + + def func(dstream): + return dstream.reduceByKey(operator.add) + expected = [[("a", 2), ("b", 2)], [("", 4)], [(1, 2), (2, 2), (3, 1)]] + self._test_func(input, func, expected, sort=True) + + def test_mapValues(self): + """Basic operation test for DStream.mapValues.""" + input = [[("a", 2), ("b", 2), ("c", 1), ("d", 1)], + [("", 4), (1, 1), (2, 2), (3, 3)], + [(1, 1), (2, 1), (3, 1), (4, 1)]] + + def func(dstream): + return dstream.mapValues(lambda x: x + 10) + expected = [[("a", 12), ("b", 12), ("c", 11), ("d", 11)], + [("", 14), (1, 11), (2, 12), (3, 13)], + [(1, 11), (2, 11), (3, 11), (4, 11)]] + self._test_func(input, func, expected, sort=True) + + def test_flatMapValues(self): + """Basic operation test for DStream.flatMapValues.""" + input = [[("a", 2), ("b", 2), ("c", 1), ("d", 1)], + [("", 4), (1, 1), (2, 1), (3, 1)], + [(1, 1), (2, 1), (3, 1), (4, 1)]] + + def func(dstream): + return dstream.flatMapValues(lambda x: (x, x + 10)) + expected = [[("a", 2), ("a", 12), ("b", 2), ("b", 12), + ("c", 1), ("c", 11), ("d", 1), ("d", 11)], + [("", 4), ("", 14), (1, 1), (1, 11), (2, 1), (2, 11), (3, 1), (3, 11)], + [(1, 1), (1, 11), (2, 1), (2, 11), (3, 1), (3, 11), (4, 1), (4, 11)]] + self._test_func(input, func, expected) + + def test_glom(self): + """Basic operation test for DStream.glom.""" + input = [range(1, 5), range(5, 9), range(9, 13)] + rdds = [self.sc.parallelize(r, 2) for 
r in input] + + def func(dstream): + return dstream.glom() + expected = [[[1, 2], [3, 4]], [[5, 6], [7, 8]], [[9, 10], [11, 12]]] + self._test_func(rdds, func, expected) + + def test_mapPartitions(self): + """Basic operation test for DStream.mapPartitions.""" + input = [range(1, 5), range(5, 9), range(9, 13)] + rdds = [self.sc.parallelize(r, 2) for r in input] + + def func(dstream): + def f(iterator): + yield sum(iterator) + return dstream.mapPartitions(f) + expected = [[3, 7], [11, 15], [19, 23]] + self._test_func(rdds, func, expected) + + def test_countByValue(self): + """Basic operation test for DStream.countByValue.""" + input = [range(1, 5) * 2, range(5, 7) + range(5, 9), ["a", "a", "b", ""]] + + def func(dstream): + return dstream.countByValue() + expected = [[4], [4], [3]] + self._test_func(input, func, expected) + + def test_groupByKey(self): + """Basic operation test for DStream.groupByKey.""" + input = [[(1, 1), (2, 1), (3, 1), (4, 1)], + [(1, 1), (1, 1), (1, 1), (2, 1), (2, 1), (3, 1)], + [("a", 1), ("a", 1), ("b", 1), ("", 1), ("", 1), ("", 1)]] + + def func(dstream): + return dstream.groupByKey().mapValues(list) + + expected = [[(1, [1]), (2, [1]), (3, [1]), (4, [1])], + [(1, [1, 1, 1]), (2, [1, 1]), (3, [1])], + [("a", [1, 1]), ("b", [1]), ("", [1, 1, 1])]] + self._test_func(input, func, expected, sort=True) + + def test_combineByKey(self): + """Basic operation test for DStream.combineByKey.""" + input = [[(1, 1), (2, 1), (3, 1), (4, 1)], + [(1, 1), (1, 1), (1, 1), (2, 1), (2, 1), (3, 1)], + [("a", 1), ("a", 1), ("b", 1), ("", 1), ("", 1), ("", 1)]] + + def func(dstream): + def add(a, b): + return a + str(b) + return dstream.combineByKey(str, add, add) + expected = [[(1, "1"), (2, "1"), (3, "1"), (4, "1")], + [(1, "111"), (2, "11"), (3, "1")], + [("a", "11"), ("b", "1"), ("", "111")]] + self._test_func(input, func, expected, sort=True) + + def test_repartition(self): + input = [range(1, 5), range(5, 9)] + rdds = [self.sc.parallelize(r, 2) for r in input] + + def func(dstream): + return dstream.repartition(1).glom() + expected = [[[1, 2, 3, 4]], [[5, 6, 7, 8]]] + self._test_func(rdds, func, expected) + + def test_union(self): + input1 = [range(3), range(5), range(6)] + input2 = [range(3, 6), range(5, 6)] + + def func(d1, d2): + return d1.union(d2) + + expected = [range(6), range(6), range(6)] + self._test_func(input1, func, expected, input2=input2) + + def test_cogroup(self): + input = [[(1, 1), (2, 1), (3, 1)], + [(1, 1), (1, 1), (1, 1), (2, 1)], + [("a", 1), ("a", 1), ("b", 1), ("", 1), ("", 1)]] + input2 = [[(1, 2)], + [(4, 1)], + [("a", 1), ("a", 1), ("b", 1), ("", 1), ("", 2)]] + + def func(d1, d2): + return d1.cogroup(d2).mapValues(lambda vs: tuple(map(list, vs))) + + expected = [[(1, ([1], [2])), (2, ([1], [])), (3, ([1], []))], + [(1, ([1, 1, 1], [])), (2, ([1], [])), (4, ([], [1]))], + [("a", ([1, 1], [1, 1])), ("b", ([1], [1])), ("", ([1, 1], [1, 2]))]] + self._test_func(input, func, expected, sort=True, input2=input2) + + def test_join(self): + input = [[('a', 1), ('b', 2)]] + input2 = [[('b', 3), ('c', 4)]] + + def func(a, b): + return a.join(b) + + expected = [[('b', (2, 3))]] + self._test_func(input, func, expected, True, input2) + + def test_left_outer_join(self): + input = [[('a', 1), ('b', 2)]] + input2 = [[('b', 3), ('c', 4)]] + + def func(a, b): + return a.leftOuterJoin(b) + + expected = [[('a', (1, None)), ('b', (2, 3))]] + self._test_func(input, func, expected, True, input2) + + def test_right_outer_join(self): + input = [[('a', 1), ('b', 2)]] + input2 = 
[[('b', 3), ('c', 4)]] + + def func(a, b): + return a.rightOuterJoin(b) + + expected = [[('b', (2, 3)), ('c', (None, 4))]] + self._test_func(input, func, expected, True, input2) + + def test_full_outer_join(self): + input = [[('a', 1), ('b', 2)]] + input2 = [[('b', 3), ('c', 4)]] + + def func(a, b): + return a.fullOuterJoin(b) + + expected = [[('a', (1, None)), ('b', (2, 3)), ('c', (None, 4))]] + self._test_func(input, func, expected, True, input2) + + def test_update_state_by_key(self): + + def updater(vs, s): + if not s: + s = [] + s.extend(vs) + return s + + input = [[('k', i)] for i in range(5)] + + def func(dstream): + return dstream.updateStateByKey(updater) + + expected = [[0], [0, 1], [0, 1, 2], [0, 1, 2, 3], [0, 1, 2, 3, 4]] + expected = [[('k', v)] for v in expected] + self._test_func(input, func, expected) + + +class WindowFunctionTests(PySparkStreamingTestCase): + + timeout = 20 + + def test_window(self): + input = [range(1), range(2), range(3), range(4), range(5)] + + def func(dstream): + return dstream.window(3, 1).count() + + expected = [[1], [3], [6], [9], [12], [9], [5]] + self._test_func(input, func, expected) + + def test_count_by_window(self): + input = [range(1), range(2), range(3), range(4), range(5)] + + def func(dstream): + return dstream.countByWindow(3, 1) + + expected = [[1], [3], [6], [9], [12], [9], [5]] + self._test_func(input, func, expected) + + def test_count_by_window_large(self): + input = [range(1), range(2), range(3), range(4), range(5), range(6)] + + def func(dstream): + return dstream.countByWindow(5, 1) + + expected = [[1], [3], [6], [10], [15], [20], [18], [15], [11], [6]] + self._test_func(input, func, expected) + + def test_count_by_value_and_window(self): + input = [range(1), range(2), range(3), range(4), range(5), range(6)] + + def func(dstream): + return dstream.countByValueAndWindow(5, 1) + + expected = [[1], [2], [3], [4], [5], [6], [6], [6], [6], [6]] + self._test_func(input, func, expected) + + def test_group_by_key_and_window(self): + input = [[('a', i)] for i in range(5)] + + def func(dstream): + return dstream.groupByKeyAndWindow(3, 1).mapValues(list) + + expected = [[('a', [0])], [('a', [0, 1])], [('a', [0, 1, 2])], [('a', [1, 2, 3])], + [('a', [2, 3, 4])], [('a', [3, 4])], [('a', [4])]] + self._test_func(input, func, expected) + + def test_reduce_by_invalid_window(self): + input1 = [range(3), range(5), range(1), range(6)] + d1 = self.ssc.queueStream(input1) + self.assertRaises(ValueError, lambda: d1.reduceByKeyAndWindow(None, None, 0.1, 0.1)) + self.assertRaises(ValueError, lambda: d1.reduceByKeyAndWindow(None, None, 1, 0.1)) + + +class StreamingContextTests(PySparkStreamingTestCase): + + duration = 0.1 + + def _add_input_stream(self): + inputs = map(lambda x: range(1, x), range(101)) + stream = self.ssc.queueStream(inputs) + self._collect(stream, 1, block=False) + + def test_stop_only_streaming_context(self): + self._add_input_stream() + self.ssc.start() + self.ssc.stop(False) + self.assertEqual(len(self.sc.parallelize(range(5), 5).glom().collect()), 5) + + def test_stop_multiple_times(self): + self._add_input_stream() + self.ssc.start() + self.ssc.stop() + self.ssc.stop() + + def test_queue_stream(self): + input = [range(i + 1) for i in range(3)] + dstream = self.ssc.queueStream(input) + result = self._collect(dstream, 3) + self.assertEqual(input, result) + + def test_text_file_stream(self): + d = tempfile.mkdtemp() + self.ssc = StreamingContext(self.sc, self.duration) + dstream2 = self.ssc.textFileStream(d).map(int) + result = 
self._collect(dstream2, 2, block=False) + self.ssc.start() + for name in ('a', 'b'): + time.sleep(1) + with open(os.path.join(d, name), "w") as f: + f.writelines(["%d\n" % i for i in range(10)]) + self.wait_for(result, 2) + self.assertEqual([range(10), range(10)], result) + + def test_union(self): + input = [range(i + 1) for i in range(3)] + dstream = self.ssc.queueStream(input) + dstream2 = self.ssc.queueStream(input) + dstream3 = self.ssc.union(dstream, dstream2) + result = self._collect(dstream3, 3) + expected = [i * 2 for i in input] + self.assertEqual(expected, result) + + def test_transform(self): + dstream1 = self.ssc.queueStream([[1]]) + dstream2 = self.ssc.queueStream([[2]]) + dstream3 = self.ssc.queueStream([[3]]) + + def func(rdds): + rdd1, rdd2, rdd3 = rdds + return rdd2.union(rdd3).union(rdd1) + + dstream = self.ssc.transform([dstream1, dstream2, dstream3], func) + + self.assertEqual([2, 3, 1], self._take(dstream, 3)) + + +class CheckpointTests(PySparkStreamingTestCase): + + def setUp(self): + pass + + def test_get_or_create(self): + inputd = tempfile.mkdtemp() + outputd = tempfile.mkdtemp() + "/" + + def updater(vs, s): + return sum(vs, s or 0) + + def setup(): + conf = SparkConf().set("spark.default.parallelism", 1) + sc = SparkContext(conf=conf) + ssc = StreamingContext(sc, 0.5) + dstream = ssc.textFileStream(inputd).map(lambda x: (x, 1)) + wc = dstream.updateStateByKey(updater) + wc.map(lambda x: "%s,%d" % x).saveAsTextFiles(outputd + "test") + wc.checkpoint(.5) + return ssc + + cpd = tempfile.mkdtemp("test_streaming_cps") + self.ssc = ssc = StreamingContext.getOrCreate(cpd, setup) + ssc.start() + + def check_output(n): + while not os.listdir(outputd): + time.sleep(0.1) + time.sleep(1) # make sure mtime is larger than the previous one + with open(os.path.join(inputd, str(n)), 'w') as f: + f.writelines(["%d\n" % i for i in range(10)]) + + while True: + p = os.path.join(outputd, max(os.listdir(outputd))) + if '_SUCCESS' not in os.listdir(p): + # not finished + time.sleep(0.01) + continue + ordd = ssc.sparkContext.textFile(p).map(lambda line: line.split(",")) + d = ordd.values().map(int).collect() + if not d: + time.sleep(0.01) + continue + self.assertEqual(10, len(d)) + s = set(d) + self.assertEqual(1, len(s)) + m = s.pop() + if n > m: + continue + self.assertEqual(n, m) + break + + check_output(1) + check_output(2) + ssc.stop(True, True) + + time.sleep(1) + self.ssc = ssc = StreamingContext.getOrCreate(cpd, setup) + ssc.start() + check_output(3) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/pyspark/streaming/util.py b/python/pyspark/streaming/util.py new file mode 100644 index 0000000000000..86ee5aa04f252 --- /dev/null +++ b/python/pyspark/streaming/util.py @@ -0,0 +1,128 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# + +import time +from datetime import datetime +import traceback + +from pyspark import SparkContext, RDD + + +class TransformFunction(object): + """ + This class wraps a function RDD[X] -> RDD[Y] that was passed to + DStream.transform(), allowing it to be called from Java via Py4J's + callback server. + + Java calls this function with a sequence of JavaRDDs and this function + returns a single JavaRDD pointer back to Java. + """ + _emptyRDD = None + + def __init__(self, ctx, func, *deserializers): + self.ctx = ctx + self.func = func + self.deserializers = deserializers + + def call(self, milliseconds, jrdds): + try: + if self.ctx is None: + self.ctx = SparkContext._active_spark_context + if not self.ctx or not self.ctx._jsc: + # stopped + return + + # extend deserializers with the first one + sers = self.deserializers + if len(sers) < len(jrdds): + sers += (sers[0],) * (len(jrdds) - len(sers)) + + rdds = [RDD(jrdd, self.ctx, ser) if jrdd else None + for jrdd, ser in zip(jrdds, sers)] + t = datetime.fromtimestamp(milliseconds / 1000.0) + r = self.func(t, *rdds) + if r: + return r._jrdd + except Exception: + traceback.print_exc() + + def __repr__(self): + return "TransformFunction(%s)" % self.func + + class Java: + implements = ['org.apache.spark.streaming.api.python.PythonTransformFunction'] + + +class TransformFunctionSerializer(object): + """ + This class implements a serializer for PythonTransformFunction Java + objects. + + This is necessary because the Java PythonTransformFunction objects are + actually Py4J references to Python objects and thus are not directly + serializable. When Java needs to serialize a PythonTransformFunction, + it uses this class to invoke Python, which returns the serialized function + as a byte array. + """ + def __init__(self, ctx, serializer, gateway=None): + self.ctx = ctx + self.serializer = serializer + self.gateway = gateway or self.ctx._gateway + self.gateway.jvm.PythonDStream.registerSerializer(self) + + def dumps(self, id): + try: + func = self.gateway.gateway_property.pool[id] + return bytearray(self.serializer.dumps((func.func, func.deserializers))) + except Exception: + traceback.print_exc() + + def loads(self, bytes): + try: + f, deserializers = self.serializer.loads(str(bytes)) + return TransformFunction(self.ctx, f, *deserializers) + except Exception: + traceback.print_exc() + + def __repr__(self): + return "TransformFunctionSerializer(%s)" % self.serializer + + class Java: + implements = ['org.apache.spark.streaming.api.python.PythonTransformFunctionSerializer'] + + +def rddToFileName(prefix, suffix, timestamp): + """ + Return string prefix-time(.suffix) + + >>> rddToFileName("spark", None, 12345678910) + 'spark-12345678910' + >>> rddToFileName("spark", "tmp", 12345678910) + 'spark-12345678910.tmp' + """ + if isinstance(timestamp, datetime): + seconds = time.mktime(timestamp.timetuple()) + timestamp = long(seconds * 1000) + timestamp.microsecond / 1000 + if suffix is None: + return prefix + "-" + str(timestamp) + else: + return prefix + "-" + str(timestamp) + "." 
+ suffix + + +if __name__ == "__main__": + import doctest + doctest.testmod() diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 7f05d48ade2b3..f5ccf31abb3fa 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -34,7 +34,11 @@ from platform import python_implementation if sys.version_info[:2] <= (2, 6): - import unittest2 as unittest + try: + import unittest2 as unittest + except ImportError: + sys.stderr.write('Please install unittest2 to test with Python 2.6 or earlier') + sys.exit(1) else: import unittest @@ -679,6 +683,12 @@ def test_udf(self): [row] = self.sqlCtx.sql("SELECT twoArgs('test', 1)").collect() self.assertEqual(row[0], 5) + def test_udf2(self): + self.sqlCtx.registerFunction("strlen", lambda string: len(string)) + self.sqlCtx.inferSchema(self.sc.parallelize([Row(a="test")])).registerTempTable("test") + [res] = self.sqlCtx.sql("SELECT strlen(a) FROM test WHERE strlen(a) > 1").collect() + self.assertEqual(u"4", res[0]) + def test_broadcast_in_udf(self): bar = {"a": "aa", "b": "bb", "c": "abc"} foo = self.sc.broadcast(bar) diff --git a/python/run-tests b/python/run-tests index c713861eb77bb..80acd002ab7eb 100755 --- a/python/run-tests +++ b/python/run-tests @@ -25,16 +25,17 @@ FWDIR="$(cd "`dirname "$0"`"; cd ../; pwd)" cd "$FWDIR/python" FAILED=0 +LOG_FILE=unit-tests.log -rm -f unit-tests.log +rm -f $LOG_FILE # Remove the metastore and warehouse directory created by the HiveContext tests in Spark SQL rm -rf metastore warehouse function run_test() { - echo "Running test: $1" + echo "Running test: $1" | tee -a $LOG_FILE - SPARK_TESTING=1 time "$FWDIR"/bin/pyspark $1 2>&1 | tee -a unit-tests.log + SPARK_TESTING=1 time "$FWDIR"/bin/pyspark $1 2>&1 | tee -a $LOG_FILE FAILED=$((PIPESTATUS[0]||$FAILED)) @@ -69,6 +70,7 @@ function run_mllib_tests() { echo "Run mllib tests ..." run_test "pyspark/mllib/classification.py" run_test "pyspark/mllib/clustering.py" + run_test "pyspark/mllib/feature.py" run_test "pyspark/mllib/linalg.py" run_test "pyspark/mllib/random.py" run_test "pyspark/mllib/recommendation.py" @@ -79,6 +81,12 @@ function run_mllib_tests() { run_test "pyspark/mllib/tests.py" } +function run_streaming_tests() { + echo "Run streaming tests ..." + run_test "pyspark/streaming/util.py" + run_test "pyspark/streaming/tests.py" +} + echo "Running PySpark tests. Output is in python/unit-tests.log." export PYSPARK_PYTHON="python" @@ -94,6 +102,7 @@ $PYSPARK_PYTHON --version run_core_tests run_sql_tests run_mllib_tests +run_streaming_tests # Try to test with PyPy if [ $(which pypy) ]; then @@ -103,6 +112,7 @@ if [ $(which pypy) ]; then run_core_tests run_sql_tests + run_streaming_tests fi if [[ $FAILED == 0 ]]; then diff --git a/repl/src/main/scala/org/apache/spark/repl/SparkIMain.scala b/repl/src/main/scala/org/apache/spark/repl/SparkIMain.scala index 6ddb6accd696b..646c68e60c2e9 100644 --- a/repl/src/main/scala/org/apache/spark/repl/SparkIMain.scala +++ b/repl/src/main/scala/org/apache/spark/repl/SparkIMain.scala @@ -84,9 +84,11 @@ import org.apache.spark.util.Utils * @author Moez A. 
Abdel-Gawad * @author Lex Spoon */ - class SparkIMain(initialSettings: Settings, val out: JPrintWriter) - extends SparkImports with Logging { - imain => + class SparkIMain( + initialSettings: Settings, + val out: JPrintWriter, + propagateExceptions: Boolean = false) + extends SparkImports with Logging { imain => val conf = new SparkConf() @@ -816,6 +818,10 @@ import org.apache.spark.util.Utils val resultName = FixedSessionNames.resultName def bindError(t: Throwable) = { + // Immediately throw the exception if we are asked to propagate them + if (propagateExceptions) { + throw unwrap(t) + } if (!bindExceptions) // avoid looping if already binding throw t diff --git a/repl/src/test/scala/org/apache/spark/repl/ExecutorClassLoaderSuite.scala b/repl/src/test/scala/org/apache/spark/repl/ExecutorClassLoaderSuite.scala index 3e2ee7541f40d..6a79e76a34db8 100644 --- a/repl/src/test/scala/org/apache/spark/repl/ExecutorClassLoaderSuite.scala +++ b/repl/src/test/scala/org/apache/spark/repl/ExecutorClassLoaderSuite.scala @@ -23,8 +23,6 @@ import java.net.{URL, URLClassLoader} import org.scalatest.BeforeAndAfterAll import org.scalatest.FunSuite -import com.google.common.io.Files - import org.apache.spark.{SparkConf, TestUtils} import org.apache.spark.util.Utils @@ -39,10 +37,8 @@ class ExecutorClassLoaderSuite extends FunSuite with BeforeAndAfterAll { override def beforeAll() { super.beforeAll() - tempDir1 = Files.createTempDir() - tempDir1.deleteOnExit() - tempDir2 = Files.createTempDir() - tempDir2.deleteOnExit() + tempDir1 = Utils.createTempDir() + tempDir2 = Utils.createTempDir() url1 = "file://" + tempDir1 urls2 = List(tempDir2.toURI.toURL).toArray childClassNames.foreach(TestUtils.createCompiledClass(_, tempDir1, "1")) diff --git a/repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala b/repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala index c8763eb277052..91c9c52c3c98a 100644 --- a/repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala +++ b/repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala @@ -22,7 +22,6 @@ import java.net.URLClassLoader import scala.collection.mutable.ArrayBuffer -import com.google.common.io.Files import org.scalatest.FunSuite import org.apache.spark.SparkContext import org.apache.commons.lang3.StringEscapeUtils @@ -190,8 +189,7 @@ class ReplSuite extends FunSuite { } test("interacting with files") { - val tempDir = Files.createTempDir() - tempDir.deleteOnExit() + val tempDir = Utils.createTempDir() val out = new FileWriter(tempDir + "/input") out.write("Hello world!\n") out.write("What's up?\n") diff --git a/scalastyle-config.xml b/scalastyle-config.xml index c54f8b72ebf42..0ff521706c71a 100644 --- a/scalastyle-config.xml +++ b/scalastyle-config.xml @@ -141,5 +141,5 @@ - + diff --git a/sql/README.md b/sql/README.md index 31f9152344086..c84534da9a3d3 100644 --- a/sql/README.md +++ b/sql/README.md @@ -44,38 +44,37 @@ Type in expressions to have them evaluated. Type :help for more information. scala> val query = sql("SELECT * FROM (SELECT * FROM src) a") -query: org.apache.spark.sql.ExecutedQuery = -SELECT * FROM (SELECT * FROM src) a -=== Query Plan === -Project [key#6:0.0,value#7:0.1] - HiveTableScan [key#6,value#7], (MetastoreRelation default, src, None), None +query: org.apache.spark.sql.SchemaRDD = +== Query Plan == +== Physical Plan == +HiveTableScan [key#10,value#11], (MetastoreRelation default, src, None), None ``` Query results are RDDs and can be operated as such. 
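Because a `SchemaRDD` is itself an RDD of `Row` objects, ordinary RDD transformations compose with it before anything is materialized. The snippet below is a minimal sketch of that point, assuming the `query` value from the console session above; the exact `Row` accessors available can differ slightly between versions, so treat it as illustrative rather than canonical.

```scala
// Illustrative only: treat the query result as an RDD[Row] and chain standard
// RDD transformations; nothing runs until an action such as count() is called.
val keyValuePairs = query
  .map(row => (row(0), row(1)))             // Row is indexable; row(i) returns Any
  .filter { case (key, _) => key != null }  // drop rows with a null key

println(keyValuePairs.count())              // triggers the HiveTableScan
```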
``` scala> query.collect() -res8: Array[org.apache.spark.sql.execution.Row] = Array([238,val_238], [86,val_86], [311,val_311]... +res2: Array[org.apache.spark.sql.Row] = Array([238,val_238], [86,val_86], [311,val_311], [27,val_27]... ``` You can also build further queries on top of these RDDs using the query DSL. ``` -scala> query.where('key === 100).toRdd.collect() -res11: Array[org.apache.spark.sql.execution.Row] = Array([100,val_100], [100,val_100]) +scala> query.where('key === 100).collect() +res3: Array[org.apache.spark.sql.Row] = Array([100,val_100], [100,val_100]) ``` -From the console you can even write rules that transform query plans. For example, the above query has redundant project operators that aren't doing anything. This redundancy can be eliminated using the `transform` function that is available on all [`TreeNode`](http://databricks.github.io/catalyst/latest/api/#catalyst.trees.TreeNode) objects. +From the console you can even write rules that transform query plans. For example, the above query has redundant project operators that aren't doing anything. This redundancy can be eliminated using the `transform` function that is available on all [`TreeNode`](https://github.com/apache/spark/blob/master/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala) objects. ```scala -scala> query.logicalPlan -res1: catalyst.plans.logical.LogicalPlan = -Project {key#0,value#1} - Project {key#0,value#1} +scala> query.queryExecution.analyzed +res4: org.apache.spark.sql.catalyst.plans.logical.LogicalPlan = +Project [key#10,value#11] + Project [key#10,value#11] MetastoreRelation default, src, None -scala> query.logicalPlan transform { +scala> query.queryExecution.analyzed transform { | case Project(projectList, child) if projectList == child.output => child | } -res2: catalyst.plans.logical.LogicalPlan = -Project {key#0,value#1} +res5: res17: org.apache.spark.sql.catalyst.plans.logical.LogicalPlan = +Project [key#10,value#11] MetastoreRelation default, src, None ``` diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index b3ae8e6779700..3d4296f9d7068 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst -import java.sql.Timestamp +import java.sql.{Date, Timestamp} import org.apache.spark.sql.catalyst.expressions.{GenericRow, Attribute, AttributeReference} import org.apache.spark.sql.catalyst.plans.logical.LocalRelation @@ -77,8 +77,9 @@ object ScalaReflection { val Schema(valueDataType, valueNullable) = schemaFor(valueType) Schema(MapType(schemaFor(keyType).dataType, valueDataType, valueContainsNull = valueNullable), nullable = true) - case t if t <:< typeOf[String] => Schema(StringType, nullable = true) + case t if t <:< typeOf[String] => Schema(StringType, nullable = true) case t if t <:< typeOf[Timestamp] => Schema(TimestampType, nullable = true) + case t if t <:< typeOf[Date] => Schema(DateType, nullable = true) case t if t <:< typeOf[BigDecimal] => Schema(DecimalType, nullable = true) case t if t <:< typeOf[java.lang.Integer] => Schema(IntegerType, nullable = true) case t if t <:< typeOf[java.lang.Long] => Schema(LongType, nullable = true) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SparkSQLParser.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SparkSQLParser.scala new file mode 100644 index 0000000000000..04467342e6ab5 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SparkSQLParser.scala @@ -0,0 +1,186 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst + +import scala.language.implicitConversions +import scala.util.parsing.combinator.lexical.StdLexical +import scala.util.parsing.combinator.syntactical.StandardTokenParsers +import scala.util.parsing.combinator.{PackratParsers, RegexParsers} +import scala.util.parsing.input.CharArrayReader.EofCh + +import org.apache.spark.sql.catalyst.plans.logical._ + +private[sql] abstract class AbstractSparkSQLParser + extends StandardTokenParsers with PackratParsers { + + def apply(input: String): LogicalPlan = phrase(start)(new lexical.Scanner(input)) match { + case Success(plan, _) => plan + case failureOrError => sys.error(failureOrError.toString) + } + + protected case class Keyword(str: String) + + protected def start: Parser[LogicalPlan] + + // Returns the whole input string + protected lazy val wholeInput: Parser[String] = new Parser[String] { + def apply(in: Input): ParseResult[String] = + Success(in.source.toString, in.drop(in.source.length())) + } + + // Returns the rest of the input string that are not parsed yet + protected lazy val restInput: Parser[String] = new Parser[String] { + def apply(in: Input): ParseResult[String] = + Success( + in.source.subSequence(in.offset, in.source.length()).toString, + in.drop(in.source.length())) + } +} + +class SqlLexical(val keywords: Seq[String]) extends StdLexical { + case class FloatLit(chars: String) extends Token { + override def toString = chars + } + + reserved ++= keywords.flatMap(w => allCaseVersions(w)) + + delimiters += ( + "@", "*", "+", "-", "<", "=", "<>", "!=", "<=", ">=", ">", "/", "(", ")", + ",", ";", "%", "{", "}", ":", "[", "]", "." + ) + + override lazy val token: Parser[Token] = + ( identChar ~ (identChar | digit).* ^^ + { case first ~ rest => processIdent((first :: rest).mkString) } + | rep1(digit) ~ ('.' ~> digit.*).? ^^ { + case i ~ None => NumericLit(i.mkString) + case i ~ Some(d) => FloatLit(i.mkString + "." 
+ d.mkString) + } + | '\'' ~> chrExcept('\'', '\n', EofCh).* <~ '\'' ^^ + { case chars => StringLit(chars mkString "") } + | '"' ~> chrExcept('"', '\n', EofCh).* <~ '"' ^^ + { case chars => StringLit(chars mkString "") } + | EofCh ^^^ EOF + | '\'' ~> failure("unclosed string literal") + | '"' ~> failure("unclosed string literal") + | delim + | failure("illegal character") + ) + + override def identChar = letter | elem('_') + + override def whitespace: Parser[Any] = + ( whitespaceChar + | '/' ~ '*' ~ comment + | '/' ~ '/' ~ chrExcept(EofCh, '\n').* + | '#' ~ chrExcept(EofCh, '\n').* + | '-' ~ '-' ~ chrExcept(EofCh, '\n').* + | '/' ~ '*' ~ failure("unclosed comment") + ).* + + /** Generate all variations of upper and lower case of a given string */ + def allCaseVersions(s: String, prefix: String = ""): Stream[String] = { + if (s == "") { + Stream(prefix) + } else { + allCaseVersions(s.tail, prefix + s.head.toLower) ++ + allCaseVersions(s.tail, prefix + s.head.toUpper) + } + } +} + +/** + * The top level Spark SQL parser. This parser recognizes syntaxes that are available for all SQL + * dialects supported by Spark SQL, and delegates all the other syntaxes to the `fallback` parser. + * + * @param fallback A function that parses an input string to a logical plan + */ +private[sql] class SparkSQLParser(fallback: String => LogicalPlan) extends AbstractSparkSQLParser { + + // A parser for the key-value part of the "SET [key = [value ]]" syntax + private object SetCommandParser extends RegexParsers { + private val key: Parser[String] = "(?m)[^=]+".r + + private val value: Parser[String] = "(?m).*$".r + + private val pair: Parser[LogicalPlan] = + (key ~ ("=".r ~> value).?).? ^^ { + case None => SetCommand(None) + case Some(k ~ v) => SetCommand(Some(k.trim -> v.map(_.trim))) + } + + def apply(input: String): LogicalPlan = parseAll(pair, input) match { + case Success(plan, _) => plan + case x => sys.error(x.toString) + } + } + + protected val AS = Keyword("AS") + protected val CACHE = Keyword("CACHE") + protected val LAZY = Keyword("LAZY") + protected val SET = Keyword("SET") + protected val TABLE = Keyword("TABLE") + protected val SOURCE = Keyword("SOURCE") + protected val UNCACHE = Keyword("UNCACHE") + + protected implicit def asParser(k: Keyword): Parser[String] = + lexical.allCaseVersions(k.str).map(x => x : Parser[String]).reduce(_ | _) + + private val reservedWords: Seq[String] = + this + .getClass + .getMethods + .filter(_.getReturnType == classOf[Keyword]) + .map(_.invoke(this).asInstanceOf[Keyword].str) + + override val lexical = new SqlLexical(reservedWords) + + override protected lazy val start: Parser[LogicalPlan] = + cache | uncache | set | shell | source | others + + private lazy val cache: Parser[LogicalPlan] = + CACHE ~> LAZY.? ~ (TABLE ~> ident) ~ (AS ~> restInput).? ^^ { + case isLazy ~ tableName ~ plan => + CacheTableCommand(tableName, plan.map(fallback), isLazy.isDefined) + } + + private lazy val uncache: Parser[LogicalPlan] = + UNCACHE ~ TABLE ~> ident ^^ { + case tableName => UncacheTableCommand(tableName) + } + + private lazy val set: Parser[LogicalPlan] = + SET ~> restInput ^^ { + case input => SetCommandParser(input) + } + + private lazy val shell: Parser[LogicalPlan] = + "!" 
~> restInput ^^ { + case input => ShellCommand(input.trim) + } + + private lazy val source: Parser[LogicalPlan] = + SOURCE ~> restInput ^^ { + case input => SourceCommand(input.trim) + } + + private lazy val others: Parser[LogicalPlan] = + wholeInput ^^ { + case input => fallback(input) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala index 99f83244735e1..d594c64b2a512 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala @@ -20,10 +20,6 @@ package org.apache.spark.sql.catalyst import java.lang.reflect.Method import scala.language.implicitConversions -import scala.util.parsing.combinator.lexical.StdLexical -import scala.util.parsing.combinator.syntactical.StandardTokenParsers -import scala.util.parsing.combinator.PackratParsers -import scala.util.parsing.input.CharArrayReader.EofCh import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ @@ -41,31 +37,7 @@ import org.apache.spark.sql.catalyst.types._ * This is currently included mostly for illustrative purposes. Users wanting more complete support * for a SQL like language should checkout the HiveQL support in the sql/hive sub-project. */ -class SqlParser extends StandardTokenParsers with PackratParsers { - - def apply(input: String): LogicalPlan = { - // Special-case out set commands since the value fields can be - // complex to handle without RegexParsers. Also this approach - // is clearer for the several possible cases of set commands. - if (input.trim.toLowerCase.startsWith("set")) { - input.trim.drop(3).split("=", 2).map(_.trim) match { - case Array("") => // "set" - SetCommand(None, None) - case Array(key) => // "set key" - SetCommand(Some(key), None) - case Array(key, value) => // "set key=value" - SetCommand(Some(key), Some(value)) - } - } else { - phrase(query)(new lexical.Scanner(input)) match { - case Success(r, x) => r - case x => sys.error(x.toString) - } - } - } - - protected case class Keyword(str: String) - +class SqlParser extends AbstractSparkSQLParser { protected implicit def asParser(k: Keyword): Parser[String] = lexical.allCaseVersions(k.str).map(x => x : Parser[String]).reduce(_ | _) @@ -79,10 +51,13 @@ class SqlParser extends StandardTokenParsers with PackratParsers { protected val BETWEEN = Keyword("BETWEEN") protected val BY = Keyword("BY") protected val CACHE = Keyword("CACHE") + protected val CASE = Keyword("CASE") protected val CAST = Keyword("CAST") protected val COUNT = Keyword("COUNT") protected val DESC = Keyword("DESC") protected val DISTINCT = Keyword("DISTINCT") + protected val ELSE = Keyword("ELSE") + protected val END = Keyword("END") protected val EXCEPT = Keyword("EXCEPT") protected val FALSE = Keyword("FALSE") protected val FIRST = Keyword("FIRST") @@ -99,7 +74,6 @@ class SqlParser extends StandardTokenParsers with PackratParsers { protected val IS = Keyword("IS") protected val JOIN = Keyword("JOIN") protected val LAST = Keyword("LAST") - protected val LAZY = Keyword("LAZY") protected val LEFT = Keyword("LEFT") protected val LIKE = Keyword("LIKE") protected val LIMIT = Keyword("LIMIT") @@ -124,22 +98,23 @@ class SqlParser extends StandardTokenParsers with PackratParsers { protected val SUBSTRING = Keyword("SUBSTRING") protected val SUM = Keyword("SUM") protected val TABLE = Keyword("TABLE") + protected val THEN = Keyword("THEN") 
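// Note (illustrative): the CASE, ELSE, END, THEN and WHEN keywords introduced in this
// change back the new case-expression production added to `function` further down in
// this file, so a query along the lines of
//   SELECT key, CASE WHEN key > 100 THEN 'big' ELSE 'small' END FROM src
// is parsed into a CaseWhen expression rather than being rejected.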
protected val TIMESTAMP = Keyword("TIMESTAMP") protected val TRUE = Keyword("TRUE") - protected val UNCACHE = Keyword("UNCACHE") protected val UNION = Keyword("UNION") protected val UPPER = Keyword("UPPER") + protected val WHEN = Keyword("WHEN") protected val WHERE = Keyword("WHERE") // Use reflection to find the reserved words defined in this class. protected val reservedWords = - this.getClass + this + .getClass .getMethods .filter(_.getReturnType == classOf[Keyword]) - .filter(_.toString.contains("org.apache.spark.sql.catalyst.SqlParser.".toCharArray)) - .map{ m : Method => m.invoke(this).asInstanceOf[Keyword].str} - + .map{_.invoke(this).asInstanceOf[Keyword].str} override val lexical = new SqlLexical(reservedWords) + println(reservedWords) protected def assignAliases(exprs: Seq[Expression]): Seq[NamedExpression] = { exprs.zipWithIndex.map { @@ -148,86 +123,68 @@ class SqlParser extends StandardTokenParsers with PackratParsers { } } - protected lazy val query: Parser[LogicalPlan] = ( - select * ( - UNION ~ ALL ^^^ { (q1: LogicalPlan, q2: LogicalPlan) => Union(q1, q2) } | - INTERSECT ^^^ { (q1: LogicalPlan, q2: LogicalPlan) => Intersect(q1, q2) } | - EXCEPT ^^^ { (q1: LogicalPlan, q2: LogicalPlan) => Except(q1, q2)} | - UNION ~ opt(DISTINCT) ^^^ { (q1: LogicalPlan, q2: LogicalPlan) => Distinct(Union(q1, q2)) } + protected lazy val start: Parser[LogicalPlan] = + ( select * + ( UNION ~ ALL ^^^ { (q1: LogicalPlan, q2: LogicalPlan) => Union(q1, q2) } + | INTERSECT ^^^ { (q1: LogicalPlan, q2: LogicalPlan) => Intersect(q1, q2) } + | EXCEPT ^^^ { (q1: LogicalPlan, q2: LogicalPlan) => Except(q1, q2)} + | UNION ~ DISTINCT.? ^^^ { (q1: LogicalPlan, q2: LogicalPlan) => Distinct(Union(q1, q2)) } ) - | insert | cache | unCache - ) + | insert + ) protected lazy val select: Parser[LogicalPlan] = - SELECT ~> opt(DISTINCT) ~ projections ~ - opt(from) ~ opt(filter) ~ - opt(grouping) ~ - opt(having) ~ - opt(orderBy) ~ - opt(limit) <~ opt(";") ^^ { - case d ~ p ~ r ~ f ~ g ~ h ~ o ~ l => - val base = r.getOrElse(NoRelation) - val withFilter = f.map(f => Filter(f, base)).getOrElse(base) - val withProjection = - g.map {g => - Aggregate(g, assignAliases(p), withFilter) - }.getOrElse(Project(assignAliases(p), withFilter)) - val withDistinct = d.map(_ => Distinct(withProjection)).getOrElse(withProjection) - val withHaving = h.map(h => Filter(h, withDistinct)).getOrElse(withDistinct) - val withOrder = o.map(o => Sort(o, withHaving)).getOrElse(withHaving) - val withLimit = l.map { l => Limit(l, withOrder) }.getOrElse(withOrder) - withLimit - } + SELECT ~> DISTINCT.? ~ + repsep(projection, ",") ~ + (FROM ~> relations).? ~ + (WHERE ~> expression).? ~ + (GROUP ~ BY ~> rep1sep(expression, ",")).? ~ + (HAVING ~> expression).? ~ + (ORDER ~ BY ~> ordering).? ~ + (LIMIT ~> expression).? 
^^ { + case d ~ p ~ r ~ f ~ g ~ h ~ o ~ l => + val base = r.getOrElse(NoRelation) + val withFilter = f.map(f => Filter(f, base)).getOrElse(base) + val withProjection = g + .map(Aggregate(_, assignAliases(p), withFilter)) + .getOrElse(Project(assignAliases(p), withFilter)) + val withDistinct = d.map(_ => Distinct(withProjection)).getOrElse(withProjection) + val withHaving = h.map(Filter(_, withDistinct)).getOrElse(withDistinct) + val withOrder = o.map(Sort(_, withHaving)).getOrElse(withHaving) + val withLimit = l.map(Limit(_, withOrder)).getOrElse(withOrder) + withLimit + } protected lazy val insert: Parser[LogicalPlan] = - INSERT ~> opt(OVERWRITE) ~ inTo ~ select <~ opt(";") ^^ { - case o ~ r ~ s => - val overwrite: Boolean = o.getOrElse("") == "OVERWRITE" - InsertIntoTable(r, Map[String, Option[String]](), s, overwrite) - } - - protected lazy val cache: Parser[LogicalPlan] = - CACHE ~> opt(LAZY) ~ (TABLE ~> ident) ~ opt(AS ~> select) <~ opt(";") ^^ { - case isLazy ~ tableName ~ plan => - CacheTableCommand(tableName, plan, isLazy.isDefined) - } - - protected lazy val unCache: Parser[LogicalPlan] = - UNCACHE ~ TABLE ~> ident <~ opt(";") ^^ { - case tableName => UncacheTableCommand(tableName) + INSERT ~> OVERWRITE.? ~ (INTO ~> relation) ~ select ^^ { + case o ~ r ~ s => InsertIntoTable(r, Map.empty[String, Option[String]], s, o.isDefined) } - protected lazy val projections: Parser[Seq[Expression]] = repsep(projection, ",") - protected lazy val projection: Parser[Expression] = - expression ~ (opt(AS) ~> opt(ident)) ^^ { - case e ~ None => e - case e ~ Some(a) => Alias(e, a)() + expression ~ (AS.? ~> ident.?) ^^ { + case e ~ a => a.fold(e)(Alias(e, _)()) } - protected lazy val from: Parser[LogicalPlan] = FROM ~> relations - - protected lazy val inTo: Parser[LogicalPlan] = INTO ~> relation - // Based very loosely on the MySQL Grammar. // http://dev.mysql.com/doc/refman/5.0/en/join.html protected lazy val relations: Parser[LogicalPlan] = - relation ~ "," ~ relation ^^ { case r1 ~ _ ~ r2 => Join(r1, r2, Inner, None) } | - relation + ( relation ~ ("," ~> relation) ^^ { case r1 ~ r2 => Join(r1, r2, Inner, None) } + | relation + ) protected lazy val relation: Parser[LogicalPlan] = - joinedRelation | - relationFactor + joinedRelation | relationFactor protected lazy val relationFactor: Parser[LogicalPlan] = - ident ~ (opt(AS) ~> opt(ident)) ^^ { - case tableName ~ alias => UnresolvedRelation(None, tableName, alias) - } | - "(" ~> query ~ ")" ~ opt(AS) ~ ident ^^ { case s ~ _ ~ _ ~ a => Subquery(a, s) } + ( ident ~ (opt(AS) ~> opt(ident)) ^^ { + case tableName ~ alias => UnresolvedRelation(None, tableName, alias) + } + | ("(" ~> start <~ ")") ~ (AS.? ~> ident) ^^ { case s ~ a => Subquery(a, s) } + ) protected lazy val joinedRelation: Parser[LogicalPlan] = - relationFactor ~ opt(joinType) ~ JOIN ~ relationFactor ~ opt(joinConditions) ^^ { - case r1 ~ jt ~ _ ~ r2 ~ cond => + relationFactor ~ joinType.? ~ (JOIN ~> relationFactor) ~ joinConditions.? 
^^ { + case r1 ~ jt ~ r2 ~ cond => Join(r1, r2, joinType = jt.getOrElse(Inner), cond) } @@ -235,151 +192,145 @@ class SqlParser extends StandardTokenParsers with PackratParsers { ON ~> expression protected lazy val joinType: Parser[JoinType] = - INNER ^^^ Inner | - LEFT ~ SEMI ^^^ LeftSemi | - LEFT ~ opt(OUTER) ^^^ LeftOuter | - RIGHT ~ opt(OUTER) ^^^ RightOuter | - FULL ~ opt(OUTER) ^^^ FullOuter - - protected lazy val filter: Parser[Expression] = WHERE ~ expression ^^ { case _ ~ e => e } - - protected lazy val orderBy: Parser[Seq[SortOrder]] = - ORDER ~> BY ~> ordering + ( INNER ^^^ Inner + | LEFT ~ SEMI ^^^ LeftSemi + | LEFT ~ OUTER.? ^^^ LeftOuter + | RIGHT ~ OUTER.? ^^^ RightOuter + | FULL ~ OUTER.? ^^^ FullOuter + ) protected lazy val ordering: Parser[Seq[SortOrder]] = - rep1sep(singleOrder, ",") | - rep1sep(expression, ",") ~ opt(direction) ^^ { - case exps ~ None => exps.map(SortOrder(_, Ascending)) - case exps ~ Some(d) => exps.map(SortOrder(_, d)) - } + ( rep1sep(singleOrder, ",") + | rep1sep(expression, ",") ~ direction.? ^^ { + case exps ~ d => exps.map(SortOrder(_, d.getOrElse(Ascending))) + } + ) protected lazy val singleOrder: Parser[SortOrder] = - expression ~ direction ^^ { case e ~ o => SortOrder(e,o) } + expression ~ direction ^^ { case e ~ o => SortOrder(e, o) } protected lazy val direction: Parser[SortDirection] = - ASC ^^^ Ascending | - DESC ^^^ Descending - - protected lazy val grouping: Parser[Seq[Expression]] = - GROUP ~> BY ~> rep1sep(expression, ",") - - protected lazy val having: Parser[Expression] = - HAVING ~> expression - - protected lazy val limit: Parser[Expression] = - LIMIT ~> expression + ( ASC ^^^ Ascending + | DESC ^^^ Descending + ) - protected lazy val expression: Parser[Expression] = orExpression + protected lazy val expression: Parser[Expression] = + orExpression protected lazy val orExpression: Parser[Expression] = - andExpression * (OR ^^^ { (e1: Expression, e2: Expression) => Or(e1,e2) }) + andExpression * (OR ^^^ { (e1: Expression, e2: Expression) => Or(e1, e2) }) protected lazy val andExpression: Parser[Expression] = - comparisonExpression * (AND ^^^ { (e1: Expression, e2: Expression) => And(e1,e2) }) + comparisonExpression * (AND ^^^ { (e1: Expression, e2: Expression) => And(e1, e2) }) protected lazy val comparisonExpression: Parser[Expression] = - termExpression ~ "=" ~ termExpression ^^ { case e1 ~ _ ~ e2 => EqualTo(e1, e2) } | - termExpression ~ "<" ~ termExpression ^^ { case e1 ~ _ ~ e2 => LessThan(e1, e2) } | - termExpression ~ "<=" ~ termExpression ^^ { case e1 ~ _ ~ e2 => LessThanOrEqual(e1, e2) } | - termExpression ~ ">" ~ termExpression ^^ { case e1 ~ _ ~ e2 => GreaterThan(e1, e2) } | - termExpression ~ ">=" ~ termExpression ^^ { case e1 ~ _ ~ e2 => GreaterThanOrEqual(e1, e2) } | - termExpression ~ "!=" ~ termExpression ^^ { case e1 ~ _ ~ e2 => Not(EqualTo(e1, e2)) } | - termExpression ~ "<>" ~ termExpression ^^ { case e1 ~ _ ~ e2 => Not(EqualTo(e1, e2)) } | - termExpression ~ BETWEEN ~ termExpression ~ AND ~ termExpression ^^ { - case e ~ _ ~ el ~ _ ~ eu => And(GreaterThanOrEqual(e, el), LessThanOrEqual(e, eu)) - } | - termExpression ~ RLIKE ~ termExpression ^^ { case e1 ~ _ ~ e2 => RLike(e1, e2) } | - termExpression ~ REGEXP ~ termExpression ^^ { case e1 ~ _ ~ e2 => RLike(e1, e2) } | - termExpression ~ LIKE ~ termExpression ^^ { case e1 ~ _ ~ e2 => Like(e1, e2) } | - termExpression ~ IN ~ "(" ~ rep1sep(termExpression, ",") <~ ")" ^^ { - case e1 ~ _ ~ _ ~ e2 => In(e1, e2) - } | - termExpression ~ NOT ~ IN ~ "(" ~ 
rep1sep(termExpression, ",") <~ ")" ^^ { - case e1 ~ _ ~ _ ~ _ ~ e2 => Not(In(e1, e2)) - } | - termExpression <~ IS ~ NULL ^^ { case e => IsNull(e) } | - termExpression <~ IS ~ NOT ~ NULL ^^ { case e => IsNotNull(e) } | - NOT ~> termExpression ^^ {e => Not(e)} | - termExpression + ( termExpression ~ ("=" ~> termExpression) ^^ { case e1 ~ e2 => EqualTo(e1, e2) } + | termExpression ~ ("<" ~> termExpression) ^^ { case e1 ~ e2 => LessThan(e1, e2) } + | termExpression ~ ("<=" ~> termExpression) ^^ { case e1 ~ e2 => LessThanOrEqual(e1, e2) } + | termExpression ~ (">" ~> termExpression) ^^ { case e1 ~ e2 => GreaterThan(e1, e2) } + | termExpression ~ (">=" ~> termExpression) ^^ { case e1 ~ e2 => GreaterThanOrEqual(e1, e2) } + | termExpression ~ ("!=" ~> termExpression) ^^ { case e1 ~ e2 => Not(EqualTo(e1, e2)) } + | termExpression ~ ("<>" ~> termExpression) ^^ { case e1 ~ e2 => Not(EqualTo(e1, e2)) } + | termExpression ~ (BETWEEN ~> termExpression) ~ (AND ~> termExpression) ^^ { + case e ~ el ~ eu => And(GreaterThanOrEqual(e, el), LessThanOrEqual(e, eu)) + } + | termExpression ~ (RLIKE ~> termExpression) ^^ { case e1 ~ e2 => RLike(e1, e2) } + | termExpression ~ (REGEXP ~> termExpression) ^^ { case e1 ~ e2 => RLike(e1, e2) } + | termExpression ~ (LIKE ~> termExpression) ^^ { case e1 ~ e2 => Like(e1, e2) } + | termExpression ~ (IN ~ "(" ~> rep1sep(termExpression, ",")) <~ ")" ^^ { + case e1 ~ e2 => In(e1, e2) + } + | termExpression ~ (NOT ~ IN ~ "(" ~> rep1sep(termExpression, ",")) <~ ")" ^^ { + case e1 ~ e2 => Not(In(e1, e2)) + } + | termExpression <~ IS ~ NULL ^^ { case e => IsNull(e) } + | termExpression <~ IS ~ NOT ~ NULL ^^ { case e => IsNotNull(e) } + | NOT ~> termExpression ^^ {e => Not(e)} + | termExpression + ) protected lazy val termExpression: Parser[Expression] = - productExpression * ( - "+" ^^^ { (e1: Expression, e2: Expression) => Add(e1,e2) } | - "-" ^^^ { (e1: Expression, e2: Expression) => Subtract(e1,e2) } ) + productExpression * + ( "+" ^^^ { (e1: Expression, e2: Expression) => Add(e1, e2) } + | "-" ^^^ { (e1: Expression, e2: Expression) => Subtract(e1, e2) } + ) protected lazy val productExpression: Parser[Expression] = - baseExpression * ( - "*" ^^^ { (e1: Expression, e2: Expression) => Multiply(e1,e2) } | - "/" ^^^ { (e1: Expression, e2: Expression) => Divide(e1,e2) } | - "%" ^^^ { (e1: Expression, e2: Expression) => Remainder(e1,e2) } - ) + baseExpression * + ( "*" ^^^ { (e1: Expression, e2: Expression) => Multiply(e1, e2) } + | "/" ^^^ { (e1: Expression, e2: Expression) => Divide(e1, e2) } + | "%" ^^^ { (e1: Expression, e2: Expression) => Remainder(e1, e2) } + ) protected lazy val function: Parser[Expression] = - SUM ~> "(" ~> expression <~ ")" ^^ { case exp => Sum(exp) } | - SUM ~> "(" ~> DISTINCT ~> expression <~ ")" ^^ { case exp => SumDistinct(exp) } | - COUNT ~> "(" ~ "*" <~ ")" ^^ { case _ => Count(Literal(1)) } | - COUNT ~> "(" ~ expression <~ ")" ^^ { case dist ~ exp => Count(exp) } | - COUNT ~> "(" ~> DISTINCT ~> expression <~ ")" ^^ { case exp => CountDistinct(exp :: Nil) } | - APPROXIMATE ~> COUNT ~> "(" ~> DISTINCT ~> expression <~ ")" ^^ { - case exp => ApproxCountDistinct(exp) - } | - APPROXIMATE ~> "(" ~> floatLit ~ ")" ~ COUNT ~ "(" ~ DISTINCT ~ expression <~ ")" ^^ { - case s ~ _ ~ _ ~ _ ~ _ ~ e => ApproxCountDistinct(e, s.toDouble) - } | - FIRST ~> "(" ~> expression <~ ")" ^^ { case exp => First(exp) } | - LAST ~> "(" ~> expression <~ ")" ^^ { case exp => Last(exp) } | - AVG ~> "(" ~> expression <~ ")" ^^ { case exp => Average(exp) } | - MIN ~> "(" ~> 
expression <~ ")" ^^ { case exp => Min(exp) } | - MAX ~> "(" ~> expression <~ ")" ^^ { case exp => Max(exp) } | - UPPER ~> "(" ~> expression <~ ")" ^^ { case exp => Upper(exp) } | - LOWER ~> "(" ~> expression <~ ")" ^^ { case exp => Lower(exp) } | - IF ~> "(" ~> expression ~ "," ~ expression ~ "," ~ expression <~ ")" ^^ { - case c ~ "," ~ t ~ "," ~ f => If(c,t,f) - } | - (SUBSTR | SUBSTRING) ~> "(" ~> expression ~ "," ~ expression <~ ")" ^^ { - case s ~ "," ~ p => Substring(s,p,Literal(Integer.MAX_VALUE)) - } | - (SUBSTR | SUBSTRING) ~> "(" ~> expression ~ "," ~ expression ~ "," ~ expression <~ ")" ^^ { - case s ~ "," ~ p ~ "," ~ l => Substring(s,p,l) - } | - SQRT ~> "(" ~> expression <~ ")" ^^ { case exp => Sqrt(exp) } | - ABS ~> "(" ~> expression <~ ")" ^^ { case exp => Abs(exp) } | - ident ~ "(" ~ repsep(expression, ",") <~ ")" ^^ { - case udfName ~ _ ~ exprs => UnresolvedFunction(udfName, exprs) - } + ( SUM ~> "(" ~> expression <~ ")" ^^ { case exp => Sum(exp) } + | SUM ~> "(" ~> DISTINCT ~> expression <~ ")" ^^ { case exp => SumDistinct(exp) } + | COUNT ~ "(" ~> "*" <~ ")" ^^ { case _ => Count(Literal(1)) } + | COUNT ~ "(" ~> expression <~ ")" ^^ { case exp => Count(exp) } + | COUNT ~> "(" ~> DISTINCT ~> expression <~ ")" ^^ { case exp => CountDistinct(exp :: Nil) } + | APPROXIMATE ~ COUNT ~ "(" ~ DISTINCT ~> expression <~ ")" ^^ + { case exp => ApproxCountDistinct(exp) } + | APPROXIMATE ~> "(" ~> floatLit ~ ")" ~ COUNT ~ "(" ~ DISTINCT ~ expression <~ ")" ^^ + { case s ~ _ ~ _ ~ _ ~ _ ~ e => ApproxCountDistinct(e, s.toDouble) } + | FIRST ~ "(" ~> expression <~ ")" ^^ { case exp => First(exp) } + | LAST ~ "(" ~> expression <~ ")" ^^ { case exp => Last(exp) } + | AVG ~ "(" ~> expression <~ ")" ^^ { case exp => Average(exp) } + | MIN ~ "(" ~> expression <~ ")" ^^ { case exp => Min(exp) } + | MAX ~ "(" ~> expression <~ ")" ^^ { case exp => Max(exp) } + | UPPER ~ "(" ~> expression <~ ")" ^^ { case exp => Upper(exp) } + | LOWER ~ "(" ~> expression <~ ")" ^^ { case exp => Lower(exp) } + | IF ~ "(" ~> expression ~ ("," ~> expression) ~ ("," ~> expression) <~ ")" ^^ + { case c ~ t ~ f => If(c, t, f) } + | CASE ~> expression.? ~ (WHEN ~> expression ~ (THEN ~> expression)).* ~ + (ELSE ~> expression).? 
<~ END ^^ { + case casePart ~ altPart ~ elsePart => + val altExprs = altPart.flatMap { case whenExpr ~ thenExpr => + Seq(casePart.fold(whenExpr)(EqualTo(_, whenExpr)), thenExpr) + } + CaseWhen(altExprs ++ elsePart.toList) + } + | (SUBSTR | SUBSTRING) ~ "(" ~> expression ~ ("," ~> expression) <~ ")" ^^ + { case s ~ p => Substring(s, p, Literal(Integer.MAX_VALUE)) } + | (SUBSTR | SUBSTRING) ~ "(" ~> expression ~ ("," ~> expression) ~ ("," ~> expression) <~ ")" ^^ + { case s ~ p ~ l => Substring(s, p, l) } + | SQRT ~ "(" ~> expression <~ ")" ^^ { case exp => Sqrt(exp) } + | ABS ~ "(" ~> expression <~ ")" ^^ { case exp => Abs(exp) } + | ident ~ ("(" ~> repsep(expression, ",")) <~ ")" ^^ + { case udfName ~ exprs => UnresolvedFunction(udfName, exprs) } + ) protected lazy val cast: Parser[Expression] = - CAST ~> "(" ~> expression ~ AS ~ dataType <~ ")" ^^ { case exp ~ _ ~ t => Cast(exp, t) } + CAST ~ "(" ~> expression ~ (AS ~> dataType) <~ ")" ^^ { case exp ~ t => Cast(exp, t) } protected lazy val literal: Parser[Literal] = - numericLit ^^ { - case i if i.toLong > Int.MaxValue => Literal(i.toLong) - case i => Literal(i.toInt) - } | - NULL ^^^ Literal(null, NullType) | - floatLit ^^ {case f => Literal(f.toDouble) } | - stringLit ^^ {case s => Literal(s, StringType) } + ( numericLit ^^ { + case i if i.toLong > Int.MaxValue => Literal(i.toLong) + case i => Literal(i.toInt) + } + | NULL ^^^ Literal(null, NullType) + | floatLit ^^ {case f => Literal(f.toDouble) } + | stringLit ^^ {case s => Literal(s, StringType) } + ) protected lazy val floatLit: Parser[String] = elem("decimal", _.isInstanceOf[lexical.FloatLit]) ^^ (_.chars) protected lazy val baseExpression: PackratParser[Expression] = - expression ~ "[" ~ expression <~ "]" ^^ { - case base ~ _ ~ ordinal => GetItem(base, ordinal) - } | - (expression <~ ".") ~ ident ^^ { - case base ~ fieldName => GetField(base, fieldName) - } | - TRUE ^^^ Literal(true, BooleanType) | - FALSE ^^^ Literal(false, BooleanType) | - cast | - "(" ~> expression <~ ")" | - function | - "-" ~> literal ^^ UnaryMinus | - dotExpressionHeader | - ident ^^ UnresolvedAttribute | - "*" ^^^ Star(None) | - literal + ( expression ~ ("[" ~> expression <~ "]") ^^ + { case base ~ ordinal => GetItem(base, ordinal) } + | (expression <~ ".") ~ ident ^^ + { case base ~ fieldName => GetField(base, fieldName) } + | TRUE ^^^ Literal(true, BooleanType) + | FALSE ^^^ Literal(false, BooleanType) + | cast + | "(" ~> expression <~ ")" + | function + | "-" ~> literal ^^ UnaryMinus + | dotExpressionHeader + | ident ^^ UnresolvedAttribute + | "*" ^^^ Star(None) + | literal + ) protected lazy val dotExpressionHeader: Parser[Expression] = (ident <~ ".") ~ ident ~ rep("." ~> ident) ^^ { @@ -389,55 +340,3 @@ class SqlParser extends StandardTokenParsers with PackratParsers { protected lazy val dataType: Parser[DataType] = STRING ^^^ StringType | TIMESTAMP ^^^ TimestampType } - -class SqlLexical(val keywords: Seq[String]) extends StdLexical { - case class FloatLit(chars: String) extends Token { - override def toString = chars - } - - reserved ++= keywords.flatMap(w => allCaseVersions(w)) - - delimiters += ( - "@", "*", "+", "-", "<", "=", "<>", "!=", "<=", ">=", ">", "/", "(", ")", - ",", ";", "%", "{", "}", ":", "[", "]", "." - ) - - override lazy val token: Parser[Token] = ( - identChar ~ rep( identChar | digit ) ^^ - { case first ~ rest => processIdent(first :: rest mkString "") } - | rep1(digit) ~ opt('.' 
~> rep(digit)) ^^ { - case i ~ None => NumericLit(i mkString "") - case i ~ Some(d) => FloatLit(i.mkString("") + "." + d.mkString("")) - } - | '\'' ~ rep( chrExcept('\'', '\n', EofCh) ) ~ '\'' ^^ - { case '\'' ~ chars ~ '\'' => StringLit(chars mkString "") } - | '\"' ~ rep( chrExcept('\"', '\n', EofCh) ) ~ '\"' ^^ - { case '\"' ~ chars ~ '\"' => StringLit(chars mkString "") } - | EofCh ^^^ EOF - | '\'' ~> failure("unclosed string literal") - | '\"' ~> failure("unclosed string literal") - | delim - | failure("illegal character") - ) - - override def identChar = letter | elem('_') - - override def whitespace: Parser[Any] = rep( - whitespaceChar - | '/' ~ '*' ~ comment - | '/' ~ '/' ~ rep( chrExcept(EofCh, '\n') ) - | '#' ~ rep( chrExcept(EofCh, '\n') ) - | '-' ~ '-' ~ rep( chrExcept(EofCh, '\n') ) - | '/' ~ '*' ~ failure("unclosed comment") - ) - - /** Generate all variations of upper and lower case of a given string */ - def allCaseVersions(s: String, prefix: String = ""): Stream[String] = { - if (s == "") { - Stream(prefix) - } else { - allCaseVersions(s.tail, prefix + s.head.toLower) ++ - allCaseVersions(s.tail, prefix + s.head.toUpper) - } - } -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index f83e2d25f2bca..82553063145b8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -63,7 +63,8 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool typeCoercionRules ++ extendedRules : _*), Batch("Check Analysis", Once, - CheckResolution), + CheckResolution, + CheckAggregation), Batch("AnalysisOperators", fixedPoint, EliminateAnalysisOperators) ) @@ -80,6 +81,7 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool case p if !p.resolved && p.childrenResolved => throw new TreeNodeException(p, "Unresolved plan found") } match { + // As a backstop, use the root node to check that the entire plan tree is resolved. case p if !p.resolved => throw new TreeNodeException(p, "Unresolved plan in tree") case p => p @@ -87,6 +89,32 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool } } + /** + * Checks for non-aggregated attributes with aggregation + */ + object CheckAggregation extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = { + plan.transform { + case aggregatePlan @ Aggregate(groupingExprs, aggregateExprs, child) => + def isValidAggregateExpression(expr: Expression): Boolean = expr match { + case _: AggregateExpression => true + case e: Attribute => groupingExprs.contains(e) + case e if groupingExprs.contains(e) => true + case e if e.references.isEmpty => true + case e => e.children.forall(isValidAggregateExpression) + } + + aggregateExprs.foreach { e => + if (!isValidAggregateExpression(e)) { + throw new TreeNodeException(plan, s"Expression not in GROUP BY: $e") + } + } + + aggregatePlan + } + } + } + /** * Replaces [[UnresolvedRelation]]s with concrete relations from the catalog. 
*/ @@ -212,7 +240,6 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool Filter(evaluatedCondition.toAttribute, aggregate.copy(aggregateExpressions = aggExprsWithHaving))) } - } protected def containsAggregate(condition: Expression): Boolean = diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index 79e5283e86a37..7c480de107e7f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -220,20 +220,39 @@ trait HiveTypeCoercion { case a: BinaryArithmetic if a.right.dataType == StringType => a.makeCopy(Array(a.left, Cast(a.right, DoubleType))) + // we should cast all timestamp/date/string compare into string compare + case p: BinaryPredicate if p.left.dataType == StringType + && p.right.dataType == DateType => + p.makeCopy(Array(p.left, Cast(p.right, StringType))) + case p: BinaryPredicate if p.left.dataType == DateType + && p.right.dataType == StringType => + p.makeCopy(Array(Cast(p.left, StringType), p.right)) case p: BinaryPredicate if p.left.dataType == StringType && p.right.dataType == TimestampType => - p.makeCopy(Array(Cast(p.left, TimestampType), p.right)) + p.makeCopy(Array(p.left, Cast(p.right, StringType))) case p: BinaryPredicate if p.left.dataType == TimestampType && p.right.dataType == StringType => - p.makeCopy(Array(p.left, Cast(p.right, TimestampType))) + p.makeCopy(Array(Cast(p.left, StringType), p.right)) + case p: BinaryPredicate if p.left.dataType == TimestampType + && p.right.dataType == DateType => + p.makeCopy(Array(Cast(p.left, StringType), Cast(p.right, StringType))) + case p: BinaryPredicate if p.left.dataType == DateType + && p.right.dataType == TimestampType => + p.makeCopy(Array(Cast(p.left, StringType), Cast(p.right, StringType))) case p: BinaryPredicate if p.left.dataType == StringType && p.right.dataType != StringType => p.makeCopy(Array(Cast(p.left, DoubleType), p.right)) case p: BinaryPredicate if p.left.dataType != StringType && p.right.dataType == StringType => p.makeCopy(Array(p.left, Cast(p.right, DoubleType))) - case i @ In(a,b) if a.dataType == TimestampType && b.forall(_.dataType == StringType) => - i.makeCopy(Array(a,b.map(Cast(_,TimestampType)))) + case i @ In(a, b) if a.dataType == DateType && b.forall(_.dataType == StringType) => + i.makeCopy(Array(Cast(a, StringType), b)) + case i @ In(a, b) if a.dataType == TimestampType && b.forall(_.dataType == StringType) => + i.makeCopy(Array(Cast(a, StringType), b)) + case i @ In(a, b) if a.dataType == DateType && b.forall(_.dataType == TimestampType) => + i.makeCopy(Array(Cast(a, StringType), b.map(Cast(_, StringType)))) + case i @ In(a, b) if a.dataType == TimestampType && b.forall(_.dataType == DateType) => + i.makeCopy(Array(Cast(a, StringType), b.map(Cast(_, StringType)))) case Sum(e) if e.dataType == StringType => Sum(Cast(e, DoubleType)) @@ -283,6 +302,8 @@ trait HiveTypeCoercion { // Skip if the type is boolean type already. Note that this extra cast should be removed // by optimizer.SimplifyCasts. case Cast(e, BooleanType) if e.dataType == BooleanType => e + // DateType should be null if be cast to boolean. 
+ case Cast(e, BooleanType) if e.dataType == DateType => Cast(e, BooleanType) // If the data type is not boolean and is being cast boolean, turn it into a comparison // with the numeric value, i.e. x != 0. This will coerce the type into numeric type. case Cast(e, BooleanType) if e.dataType != BooleanType => Not(EqualTo(e, Literal(0))) @@ -348,8 +369,11 @@ trait HiveTypeCoercion { case e if !e.childrenResolved => e // Decimal and Double remain the same - case d: Divide if d.dataType == DoubleType => d - case d: Divide if d.dataType == DecimalType => d + case d: Divide if d.resolved && d.dataType == DoubleType => d + case d: Divide if d.resolved && d.dataType == DecimalType => d + + case Divide(l, r) if l.dataType == DecimalType => Divide(l, Cast(r, DecimalType)) + case Divide(l, r) if r.dataType == DecimalType => Divide(Cast(l, DecimalType), r) case Divide(l, r) => Divide(Cast(l, DoubleType), Cast(r, DoubleType)) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala index 67570a6f73c36..77d84e1687e1b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala @@ -88,7 +88,7 @@ case class Star( mapFunction: Attribute => Expression = identity[Attribute]) extends Attribute with trees.LeafNode[Expression] { - override def name = throw new UnresolvedException(this, "exprId") + override def name = throw new UnresolvedException(this, "name") override def exprId = throw new UnresolvedException(this, "exprId") override def dataType = throw new UnresolvedException(this, "dataType") override def nullable = throw new UnresolvedException(this, "nullable") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index deb622c39faf5..75b6e37c2a1f9 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst -import java.sql.Timestamp +import java.sql.{Date, Timestamp} import scala.language.implicitConversions @@ -119,6 +119,7 @@ package object dsl { implicit def floatToLiteral(f: Float) = Literal(f) implicit def doubleToLiteral(d: Double) = Literal(d) implicit def stringToLiteral(s: String) = Literal(s) + implicit def dateToLiteral(d: Date) = Literal(d) implicit def decimalToLiteral(d: BigDecimal) = Literal(d) implicit def timestampToLiteral(t: Timestamp) = Literal(t) implicit def binaryToLiteral(a: Array[Byte]) = Literal(a) @@ -174,6 +175,9 @@ package object dsl { /** Creates a new AttributeReference of type string */ def string = AttributeReference(s, StringType, nullable = true)() + /** Creates a new AttributeReference of type date */ + def date = AttributeReference(s, DateType, nullable = true)() + /** Creates a new AttributeReference of type decimal */ def decimal = AttributeReference(s, DecimalType, nullable = true)() diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala index c3a08bbdb6bc7..2b4969b7cfec0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala +++ 
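The Division rule now short-circuits only for already-resolved Divide expressions and gains explicit Decimal handling: if either operand is Decimal the other side is cast to Decimal, otherwise both sides are cast to Double. A toy sketch of just that widening decision (the Toy* names are placeholders, not Catalyst types):

object DivideWideningSketch extends App {
  sealed trait ToyType
  case object ToyDecimal extends ToyType
  case object ToyDouble  extends ToyType
  case object ToyString  extends ToyType

  def divisionType(left: ToyType, right: ToyType): ToyType = (left, right) match {
    case (ToyDouble, ToyDouble)            => ToyDouble  // already fractional, left alone
    case (ToyDecimal, _) | (_, ToyDecimal) => ToyDecimal // cast the other operand to Decimal
    case _                                 => ToyDouble  // cast both operands to Double
  }

  println(divisionType(ToyString, ToyDecimal)) // ToyDecimal
  println(divisionType(ToyString, ToyString))  // ToyDouble
}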
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala @@ -17,19 +17,26 @@ package org.apache.spark.sql.catalyst.expressions +import org.apache.spark.sql.catalyst.analysis.Star + protected class AttributeEquals(val a: Attribute) { override def hashCode() = a.exprId.hashCode() - override def equals(other: Any) = other match { - case otherReference: AttributeEquals => a.exprId == otherReference.a.exprId - case otherAttribute => false + override def equals(other: Any) = (a, other.asInstanceOf[AttributeEquals].a) match { + case (a1: AttributeReference, a2: AttributeReference) => a1.exprId == a2.exprId + case (a1, a2) => a1 == a2 } } object AttributeSet { - /** Constructs a new [[AttributeSet]] given a sequence of [[Attribute Attributes]]. */ - def apply(baseSet: Seq[Attribute]) = { - new AttributeSet(baseSet.map(new AttributeEquals(_)).toSet) - } + def apply(a: Attribute) = + new AttributeSet(Set(new AttributeEquals(a))) + + /** Constructs a new [[AttributeSet]] given a sequence of [[Expression Expressions]]. */ + def apply(baseSet: Seq[Expression]) = + new AttributeSet( + baseSet + .flatMap(_.references) + .map(new AttributeEquals(_)).toSet) } /** @@ -103,4 +110,6 @@ class AttributeSet private (val baseSet: Set[AttributeEquals]) // We must force toSeq to not be strict otherwise we end up with a [[Stream]] that captures all // sorts of things in its closure. override def toSeq: Seq[Attribute] = baseSet.map(_.a).toArray.toSeq + + override def toString = "{" + baseSet.map(_.a).mkString(", ") + "}" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index f626d09f037bc..8e5ee12e314bf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -17,18 +17,21 @@ package org.apache.spark.sql.catalyst.expressions -import java.sql.Timestamp +import java.sql.{Date, Timestamp} import java.text.{DateFormat, SimpleDateFormat} +import org.apache.spark.Logging +import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.types._ /** Cast the child expression to the target data type. 
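AttributeSet can now be built from arbitrary expressions by flattening their references, and AttributeEquals compares AttributeReferences purely by exprId. A standalone sketch of that id-based identity (FakeAttr and AttrEq are illustrative stand-ins, not the real classes):

final case class FakeAttr(name: String, exprId: Long)

final class AttrEq(val a: FakeAttr) {
  override def hashCode(): Int = a.exprId.hashCode()
  override def equals(other: Any): Boolean = other match {
    case o: AttrEq => a.exprId == o.a.exprId
    case _         => false
  }
}

object AttrSetSketch extends App {
  val set = Set(
    new AttrEq(FakeAttr("a", exprId = 1)),
    new AttrEq(FakeAttr("A", exprId = 1)), // same exprId as "a": treated as the same attribute
    new AttrEq(FakeAttr("b", exprId = 2)))
  println(set.size)                        // 2
}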
*/ -case class Cast(child: Expression, dataType: DataType) extends UnaryExpression { +case class Cast(child: Expression, dataType: DataType) extends UnaryExpression with Logging { override def foldable = child.foldable override def nullable = (child.dataType, dataType) match { case (StringType, _: NumericType) => true case (StringType, TimestampType) => true + case (StringType, DateType) => true case _ => child.nullable } @@ -42,6 +45,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression { // UDFToString private[this] def castToString: Any => Any = child.dataType match { case BinaryType => buildCast[Array[Byte]](_, new String(_, "UTF-8")) + case DateType => buildCast[Date](_, dateToString) case TimestampType => buildCast[Timestamp](_, timestampToString) case _ => buildCast[Any](_, _.toString) } @@ -56,7 +60,10 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression { case StringType => buildCast[String](_, _.length() != 0) case TimestampType => - buildCast[Timestamp](_, b => b.getTime() != 0 || b.getNanos() != 0) + buildCast[Timestamp](_, t => t.getTime() != 0 || t.getNanos() != 0) + case DateType => + // Hive would return null when cast from date to boolean + buildCast[Date](_, d => null) case LongType => buildCast[Long](_, _ != 0) case IntegerType => @@ -95,6 +102,8 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression { buildCast[Short](_, s => new Timestamp(s)) case ByteType => buildCast[Byte](_, b => new Timestamp(b)) + case DateType => + buildCast[Date](_, d => new Timestamp(d.getTime)) // TimestampWritable.decimalToTimestamp case DecimalType => buildCast[BigDecimal](_, d => decimalToTimestamp(d)) @@ -130,7 +139,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression { // Converts Timestamp to string according to Hive TimestampWritable convention private[this] def timestampToString(ts: Timestamp): String = { val timestampString = ts.toString - val formatted = Cast.threadLocalDateFormat.get.format(ts) + val formatted = Cast.threadLocalTimestampFormat.get.format(ts) if (timestampString.length > 19 && timestampString.substring(19) != ".0") { formatted + timestampString.substring(19) @@ -139,6 +148,39 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression { } } + // Converts Timestamp to string according to Hive TimestampWritable convention + private[this] def timestampToDateString(ts: Timestamp): String = { + Cast.threadLocalDateFormat.get.format(ts) + } + + // DateConverter + private[this] def castToDate: Any => Any = child.dataType match { + case StringType => + buildCast[String](_, s => + try Date.valueOf(s) catch { case _: java.lang.IllegalArgumentException => null } + ) + case TimestampType => + // throw valid precision more than seconds, according to Hive. + // Timestamp.nanos is in 0 to 999,999,999, no more than a second. + buildCast[Timestamp](_, t => new Date(Math.floor(t.getTime / 1000.0).toLong * 1000)) + // Hive throws this exception as a Semantic Exception + // It is never possible to compare result when hive return with exception, so we can return null + // NULL is more reasonable here, since the query itself obeys the grammar. 
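castToDate parses strings with java.sql.Date.valueOf and swallows the IllegalArgumentException, so malformed input becomes null instead of failing the query, matching the lenient behavior the comments describe. Just that parsing path in isolation (a free-standing function, not the private helper itself):

import java.sql.Date

object StringToDateSketch extends App {
  def castToDate(s: String): Date =
    try Date.valueOf(s)
    catch { case _: IllegalArgumentException => null }

  println(castToDate("2014-10-01")) // 2014-10-01
  println(castToDate("not a date")) // null
}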
+ case _ => _ => null + } + + // Date cannot be cast to long, according to hive + private[this] def dateToLong(d: Date) = null + + // Date cannot be cast to double, according to hive + private[this] def dateToDouble(d: Date) = null + + // Converts Date to string according to Hive DateWritable convention + private[this] def dateToString(d: Date): String = { + Cast.threadLocalDateFormat.get.format(d) + } + + // LongConverter private[this] def castToLong: Any => Any = child.dataType match { case StringType => buildCast[String](_, s => try s.toLong catch { @@ -146,6 +188,8 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression { }) case BooleanType => buildCast[Boolean](_, b => if (b) 1L else 0L) + case DateType => + buildCast[Date](_, d => dateToLong(d)) case TimestampType => buildCast[Timestamp](_, t => timestampToLong(t)) case DecimalType => @@ -154,6 +198,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression { b => x.numeric.asInstanceOf[Numeric[Any]].toLong(b) } + // IntConverter private[this] def castToInt: Any => Any = child.dataType match { case StringType => buildCast[String](_, s => try s.toInt catch { @@ -161,6 +206,8 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression { }) case BooleanType => buildCast[Boolean](_, b => if (b) 1 else 0) + case DateType => + buildCast[Date](_, d => dateToLong(d)) case TimestampType => buildCast[Timestamp](_, t => timestampToLong(t).toInt) case DecimalType => @@ -169,6 +216,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression { b => x.numeric.asInstanceOf[Numeric[Any]].toInt(b) } + // ShortConverter private[this] def castToShort: Any => Any = child.dataType match { case StringType => buildCast[String](_, s => try s.toShort catch { @@ -176,6 +224,8 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression { }) case BooleanType => buildCast[Boolean](_, b => if (b) 1.toShort else 0.toShort) + case DateType => + buildCast[Date](_, d => dateToLong(d)) case TimestampType => buildCast[Timestamp](_, t => timestampToLong(t).toShort) case DecimalType => @@ -184,6 +234,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression { b => x.numeric.asInstanceOf[Numeric[Any]].toInt(b).toShort } + // ByteConverter private[this] def castToByte: Any => Any = child.dataType match { case StringType => buildCast[String](_, s => try s.toByte catch { @@ -191,6 +242,8 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression { }) case BooleanType => buildCast[Boolean](_, b => if (b) 1.toByte else 0.toByte) + case DateType => + buildCast[Date](_, d => dateToLong(d)) case TimestampType => buildCast[Timestamp](_, t => timestampToLong(t).toByte) case DecimalType => @@ -199,6 +252,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression { b => x.numeric.asInstanceOf[Numeric[Any]].toInt(b).toByte } + // DecimalConverter private[this] def castToDecimal: Any => Any = child.dataType match { case StringType => buildCast[String](_, s => try BigDecimal(s.toDouble) catch { @@ -206,6 +260,8 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression { }) case BooleanType => buildCast[Boolean](_, b => if (b) BigDecimal(1) else BigDecimal(0)) + case DateType => + buildCast[Date](_, d => dateToDouble(d)) case TimestampType => // Note that we lose precision here. 
buildCast[Timestamp](_, t => BigDecimal(timestampToDouble(t))) @@ -213,6 +269,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression { b => BigDecimal(x.numeric.asInstanceOf[Numeric[Any]].toDouble(b)) } + // DoubleConverter private[this] def castToDouble: Any => Any = child.dataType match { case StringType => buildCast[String](_, s => try s.toDouble catch { @@ -220,6 +277,8 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression { }) case BooleanType => buildCast[Boolean](_, b => if (b) 1d else 0d) + case DateType => + buildCast[Date](_, d => dateToDouble(d)) case TimestampType => buildCast[Timestamp](_, t => timestampToDouble(t)) case DecimalType => @@ -228,6 +287,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression { b => x.numeric.asInstanceOf[Numeric[Any]].toDouble(b) } + // FloatConverter private[this] def castToFloat: Any => Any = child.dataType match { case StringType => buildCast[String](_, s => try s.toFloat catch { @@ -235,6 +295,8 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression { }) case BooleanType => buildCast[Boolean](_, b => if (b) 1f else 0f) + case DateType => + buildCast[Date](_, d => dateToDouble(d)) case TimestampType => buildCast[Timestamp](_, t => timestampToDouble(t).toFloat) case DecimalType => @@ -245,17 +307,18 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression { private[this] lazy val cast: Any => Any = dataType match { case dt if dt == child.dataType => identity[Any] - case StringType => castToString - case BinaryType => castToBinary - case DecimalType => castToDecimal + case StringType => castToString + case BinaryType => castToBinary + case DecimalType => castToDecimal + case DateType => castToDate case TimestampType => castToTimestamp - case BooleanType => castToBoolean - case ByteType => castToByte - case ShortType => castToShort - case IntegerType => castToInt - case FloatType => castToFloat - case LongType => castToLong - case DoubleType => castToDouble + case BooleanType => castToBoolean + case ByteType => castToByte + case ShortType => castToShort + case IntegerType => castToInt + case FloatType => castToFloat + case LongType => castToLong + case DoubleType => castToDouble } override def eval(input: Row): Any = { @@ -267,6 +330,13 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression { object Cast { // `SimpleDateFormat` is not thread-safe. private[sql] val threadLocalDateFormat = new ThreadLocal[DateFormat] { + override def initialValue() = { + new SimpleDateFormat("yyyy-MM-dd") + } + } + + // `SimpleDateFormat` is not thread-safe. 
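Because SimpleDateFormat is not thread-safe, the companion object now holds one formatter per thread for dates and another for timestamps. The same idiom standalone (ThreadLocalFormats is just an illustrative holder):

import java.text.{DateFormat, SimpleDateFormat}

object ThreadLocalFormats {
  val dateFormat: ThreadLocal[DateFormat] = new ThreadLocal[DateFormat] {
    override def initialValue(): DateFormat = new SimpleDateFormat("yyyy-MM-dd")
  }
  val timestampFormat: ThreadLocal[DateFormat] = new ThreadLocal[DateFormat] {
    override def initialValue(): DateFormat = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss")
  }
}

// ThreadLocalFormats.dateFormat.get.format(java.sql.Date.valueOf("2014-10-01")) == "2014-10-01",
// and every thread that calls get receives its own formatter instance.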
+ private[sql] val threadLocalTimestampFormat = new ThreadLocal[DateFormat] { override def initialValue() = { new SimpleDateFormat("yyyy-MM-dd HH:mm:ss") } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala index ef1d12531f109..e7e81a21fdf03 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala @@ -39,6 +39,8 @@ class InterpretedProjection(expressions: Seq[Expression]) extends Projection { } new GenericRow(outputArray) } + + override def toString = s"Row => [${exprArray.mkString(",")}]" } /** @@ -137,6 +139,9 @@ class JoinedRow extends Row { def getString(i: Int): String = if (i < row1.size) row1.getString(i) else row2.getString(i - row1.size) + override def getAs[T](i: Int): T = + if (i < row1.size) row1.getAs[T](i) else row2.getAs[T](i - row1.size) + def copy() = { val totalSize = row1.size + row2.size val copiedValues = new Array[Any](totalSize) @@ -226,6 +231,9 @@ class JoinedRow2 extends Row { def getString(i: Int): String = if (i < row1.size) row1.getString(i) else row2.getString(i - row1.size) + override def getAs[T](i: Int): T = + if (i < row1.size) row1.getAs[T](i) else row2.getAs[T](i - row1.size) + def copy() = { val totalSize = row1.size + row2.size val copiedValues = new Array[Any](totalSize) @@ -309,6 +317,9 @@ class JoinedRow3 extends Row { def getString(i: Int): String = if (i < row1.size) row1.getString(i) else row2.getString(i - row1.size) + override def getAs[T](i: Int): T = + if (i < row1.size) row1.getAs[T](i) else row2.getAs[T](i - row1.size) + def copy() = { val totalSize = row1.size + row2.size val copiedValues = new Array[Any](totalSize) @@ -392,6 +403,9 @@ class JoinedRow4 extends Row { def getString(i: Int): String = if (i < row1.size) row1.getString(i) else row2.getString(i - row1.size) + override def getAs[T](i: Int): T = + if (i < row1.size) row1.getAs[T](i) else row2.getAs[T](i - row1.size) + def copy() = { val totalSize = row1.size + row2.size val copiedValues = new Array[Any](totalSize) @@ -475,6 +489,9 @@ class JoinedRow5 extends Row { def getString(i: Int): String = if (i < row1.size) row1.getString(i) else row2.getString(i - row1.size) + override def getAs[T](i: Int): T = + if (i < row1.size) row1.getAs[T](i) else row2.getAs[T](i - row1.size) + def copy() = { val totalSize = row1.size + row2.size val copiedValues = new Array[Any](totalSize) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala index d68a4fabeac77..d00ec39774c35 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala @@ -64,6 +64,7 @@ trait Row extends Seq[Any] with Serializable { def getShort(i: Int): Short def getByte(i: Int): Byte def getString(i: Int): String + def getAs[T](i: Int): T = apply(i).asInstanceOf[T] override def toString() = s"[${this.mkString(",")}]" @@ -118,6 +119,7 @@ object EmptyRow extends Row { def getShort(i: Int): Short = throw new UnsupportedOperationException def getByte(i: Int): Byte = throw new UnsupportedOperationException def getString(i: Int): String = throw new UnsupportedOperationException + override def getAs[T](i: Int): T = throw new 
UnsupportedOperationException def copy() = this } @@ -217,19 +219,19 @@ class GenericMutableRow(size: Int) extends GenericRow(size) with MutableRow { /** No-arg constructor for serialization. */ def this() = this(0) - override def setBoolean(ordinal: Int,value: Boolean): Unit = { values(ordinal) = value } - override def setByte(ordinal: Int,value: Byte): Unit = { values(ordinal) = value } - override def setDouble(ordinal: Int,value: Double): Unit = { values(ordinal) = value } - override def setFloat(ordinal: Int,value: Float): Unit = { values(ordinal) = value } - override def setInt(ordinal: Int,value: Int): Unit = { values(ordinal) = value } - override def setLong(ordinal: Int,value: Long): Unit = { values(ordinal) = value } - override def setString(ordinal: Int,value: String): Unit = { values(ordinal) = value } + override def setBoolean(ordinal: Int, value: Boolean): Unit = { values(ordinal) = value } + override def setByte(ordinal: Int, value: Byte): Unit = { values(ordinal) = value } + override def setDouble(ordinal: Int, value: Double): Unit = { values(ordinal) = value } + override def setFloat(ordinal: Int, value: Float): Unit = { values(ordinal) = value } + override def setInt(ordinal: Int, value: Int): Unit = { values(ordinal) = value } + override def setLong(ordinal: Int, value: Long): Unit = { values(ordinal) = value } + override def setString(ordinal: Int, value: String): Unit = { values(ordinal) = value } override def setNullAt(i: Int): Unit = { values(i) = null } - override def setShort(ordinal: Int,value: Short): Unit = { values(ordinal) = value } + override def setShort(ordinal: Int, value: Short): Unit = { values(ordinal) = value } - override def update(ordinal: Int,value: Any): Unit = { values(ordinal) = value } + override def update(ordinal: Int, value: Any): Unit = { values(ordinal) = value } override def copy() = new GenericRow(values.clone()) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala similarity index 97% rename from sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificRow.scala rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala index 9cbab3d5d0d0d..570379c533e1f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificRow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala @@ -233,9 +233,9 @@ final class SpecificMutableRow(val values: Array[MutableValue]) extends MutableR override def iterator: Iterator[Any] = values.map(_.boxed).iterator - def setString(ordinal: Int, value: String) = update(ordinal, value) + override def setString(ordinal: Int, value: String) = update(ordinal, value) - def getString(ordinal: Int) = apply(ordinal).asInstanceOf[String] + override def getString(ordinal: Int) = apply(ordinal).asInstanceOf[String] override def setInt(ordinal: Int, value: Int): Unit = { val currentValue = values(ordinal).asInstanceOf[MutableInt] @@ -306,4 +306,8 @@ final class SpecificMutableRow(val values: Array[MutableValue]) extends MutableR override def getByte(i: Int): Byte = { values(i).asInstanceOf[MutableByte].value } + + override def getAs[T](i: Int): T = { + values(i).boxed.asInstanceOf[T] + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/WrapDynamic.scala 
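getAs[T] gives typed access for values that have no dedicated accessor (such as java.sql.Date after this patch); the JoinedRow overrides reuse the usual split-index logic, reading from the left row below row1.size and from the right row above it. The indexing in isolation (a toy wrapper, not the Row trait):

class JoinedSeq(left: Seq[Any], right: Seq[Any]) {
  def getAs[T](i: Int): T =
    (if (i < left.size) left(i) else right(i - left.size)).asInstanceOf[T]
}

object JoinedSeqSketch extends App {
  val row = new JoinedSeq(Seq(1, "x"), Seq(2.0))
  println(row.getAs[String](1)) // "x"  (index 1 in the left row)
  println(row.getAs[Double](2)) // 2.0  (index 2 - 2 = 0 in the right row)
}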
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/WrapDynamic.scala index 1eb55715794a7..1a4ac06c7a79d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/WrapDynamic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/WrapDynamic.scala @@ -24,9 +24,7 @@ import org.apache.spark.sql.catalyst.types.DataType /** * The data type representing [[DynamicRow]] values. */ -case object DynamicType extends DataType { - def simpleString: String = "dynamic" -} +case object DynamicType extends DataType /** * Wrap a [[Row]] as a [[DynamicRow]]. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala index 78a0c55e4bbe5..ba240233cae61 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.expressions -import java.sql.Timestamp +import java.sql.{Date, Timestamp} import org.apache.spark.sql.catalyst.types._ @@ -33,6 +33,7 @@ object Literal { case b: Boolean => Literal(b, BooleanType) case d: BigDecimal => Literal(d, DecimalType) case t: Timestamp => Literal(t, TimestampType) + case d: Date => Literal(d, DateType) case a: Array[Byte] => Literal(a, BinaryType) case null => Literal(null, NullType) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala index e5a958d599393..d023db44d8543 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala @@ -57,6 +57,8 @@ abstract class NamedExpression extends Expression { abstract class Attribute extends NamedExpression { self: Product => + override def references = AttributeSet(this) + def withNullability(newNullability: Boolean): Attribute def withQualifiers(newQualifiers: Seq[String]): Attribute def withName(newName: String): Attribute @@ -116,8 +118,6 @@ case class AttributeReference(name: String, dataType: DataType, nullable: Boolea (val exprId: ExprId = NamedExpression.newExprId, val qualifiers: Seq[String] = Nil) extends Attribute with trees.LeafNode[Expression] { - override def references = AttributeSet(this :: Nil) - override def equals(other: Any) = other match { case ar: AttributeReference => exprId == ar.exprId && dataType == ar.dataType case _ => false diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 329af332d0fa1..1e22b2d03c672 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -17,11 +17,11 @@ package org.apache.spark.sql.catalyst.expressions +import scala.collection.immutable.HashSet import org.apache.spark.sql.catalyst.analysis.UnresolvedException import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.types.BooleanType - object InterpretedPredicate { def apply(expression: Expression, inputSchema: Seq[Attribute]): (Row => Boolean) = 
apply(BindReferences.bindReference(expression, inputSchema)) @@ -95,6 +95,23 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate { } } +/** + * Optimized version of In clause, when all filter values of In clause are + * static. + */ +case class InSet(value: Expression, hset: HashSet[Any], child: Seq[Expression]) + extends Predicate { + + def children = child + + def nullable = true // TODO: Figure out correct nullability semantics of IN. + override def toString = s"$value INSET ${hset.mkString("(", ",", ")")}" + + override def eval(input: Row): Any = { + hset.contains(value.eval(input)) + } +} + case class And(left: Expression, right: Expression) extends BinaryPredicate { def symbol = "&&" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index a4133feae8166..3693b41404fd6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.optimizer +import scala.collection.immutable.HashSet import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.Inner import org.apache.spark.sql.catalyst.plans.FullOuter @@ -38,7 +39,8 @@ object Optimizer extends RuleExecutor[LogicalPlan] { BooleanSimplification, SimplifyFilters, SimplifyCasts, - SimplifyCaseConversionExpressions) :: + SimplifyCaseConversionExpressions, + OptimizeIn) :: Batch("Filter Pushdown", FixedPoint(100), UnionPushdown, CombineFilters, @@ -273,6 +275,20 @@ object ConstantFolding extends Rule[LogicalPlan] { } } +/** + * Replaces [[In (value, seq[Literal])]] with optimized version[[InSet (value, HashSet[Literal])]] + * which is much faster + */ +object OptimizeIn extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case q: LogicalPlan => q transformExpressionsDown { + case In(v, list) if !list.exists(!_.isInstanceOf[Literal]) => + val hSet = list.map(e => e.eval(null)) + InSet(v, HashSet() ++ hSet, v +: list) + } + } +} + /** * Simplifies boolean expressions where the answer can be determined without evaluating both sides. * Note that this rule can eliminate expressions that might otherwise have been evaluated and thus @@ -299,6 +315,18 @@ object BooleanSimplification extends Rule[LogicalPlan] { case (_, _) => or } + case not @ Not(exp) => + exp match { + case Literal(true, BooleanType) => Literal(false) + case Literal(false, BooleanType) => Literal(true) + case GreaterThan(l, r) => LessThanOrEqual(l, r) + case GreaterThanOrEqual(l, r) => LessThan(l, r) + case LessThan(l, r) => GreaterThanOrEqual(l, r) + case LessThanOrEqual(l, r) => GreaterThan(l, r) + case Not(e) => e + case _ => not + } + // Turn "if (true) a else b" into "a", and if (false) a else b" into "b". 
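OptimizeIn rewrites IN over an all-literal list into the new InSet predicate, turning the per-row scan over list expressions into a single hash lookup against a set built once at optimization time. A sketch of the difference, with plain Scala values standing in for expressions:

import scala.collection.immutable.HashSet

object InVsInSetSketch extends App {
  val literals: Seq[Any] = Seq(1, 2, 3)

  // Unoptimized In: every list element is checked per input row.
  def evalIn(value: Any): Boolean = literals.exists(_ == value)

  // Optimized InSet: the HashSet is materialized once, lookups are constant time.
  val hset: HashSet[Any] = HashSet(literals: _*)
  def evalInSet(value: Any): Boolean = hset.contains(value)

  println(evalIn(2))    // true
  println(evalInSet(2)) // true
  println(evalInSet(5)) // false
}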
case e @ If(Literal(v, _), trueValue, falseValue) => if (v == true) trueValue else falseValue } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index af9e4d86e995a..dcbbb62c0aca4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -31,6 +31,25 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy */ def outputSet: AttributeSet = AttributeSet(output) + /** + * All Attributes that appear in expressions from this operator. Note that this set does not + * include attributes that are implicitly referenced by being passed through to the output tuple. + */ + def references: AttributeSet = AttributeSet(expressions.flatMap(_.references)) + + /** + * The set of all attributes that are input to this operator by its children. + */ + def inputSet: AttributeSet = + AttributeSet(children.flatMap(_.asInstanceOf[QueryPlan[PlanType]].output)) + + /** + * Attributes that are referenced by expressions but not provided by this nodes children. + * Subclasses should override this method if they produce attributes internally as it is used by + * assertions designed to prevent the construction of invalid plans. + */ + def missingInput: AttributeSet = references -- inputSet + /** * Runs [[transform]] with `rule` on all expressions present in this query operator. * Users should not expect a specific directionality. If a specific directionality is needed, @@ -132,4 +151,8 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy /** Prints out the schema in the tree format */ def printSchema(): Unit = println(schemaString) + + protected def statePrefix = if (missingInput.nonEmpty && children.nonEmpty) "!" else "" + + override def simpleString = statePrefix + super.simpleString } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index 627ec3e139ea6..882e9c6110089 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -53,12 +53,6 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { sizeInBytes = children.map(_.statistics).map(_.sizeInBytes).product) } - /** - * Returns the set of attributes that this node takes as - * input from its children. - */ - lazy val inputSet: AttributeSet = AttributeSet(children.flatMap(_.output)) - /** * Returns true if this expression and all its children have been resolved to a specific schema * and false if it still contains any unresolved placeholders. Implementations of LogicalPlan @@ -68,6 +62,8 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { */ lazy val resolved: Boolean = !expressions.exists(!_.resolved) && childrenResolved + override protected def statePrefix = if (!resolved) "'" else super.statePrefix + /** * Returns true if all its children of this query plan have been resolved. */ @@ -144,21 +140,12 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { // struct fields. val options = input.flatMap { option => // If the first part of the desired name matches a qualifier for this possible match, drop it. 
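missingInput is the new sanity check on query plans: the attributes an operator references that none of its children supply, with simpleString prefixing such operators with "!" (and unresolved logical plans with a quote). The bookkeeping reduces to set arithmetic, shown here with plain sets of attribute names:

object MissingInputSketch extends App {
  val references = Set("a", "b", "c") // attributes used by this operator's expressions
  val inputSet   = Set("a", "b")      // attributes produced by the children
  val missingInput = references -- inputSet

  val statePrefix = if (missingInput.nonEmpty) "!" else ""
  println(s"${statePrefix}Project [c]") // prints "!Project [c]", flagging the malformed operator
}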
- val remainingParts = { - if (option==null) { - throw new IllegalStateException( - "Null member of input attributes found when resolving %s from inputs %s" - .format(name, input.mkString("[",",","]"))) - } -// assert(option != null) - assert(option.qualifiers != null) - assert(parts != null) + val remainingParts = if (option.qualifiers.find(resolver(_, parts.head)).nonEmpty && parts.size > 1) { parts.drop(1) } else { parts } - } if (resolver(option.name, remainingParts.head)) { // Preserve the case of the user's attribute reference. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index f8e9930ac270d..14b03c7445c13 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -138,11 +138,6 @@ case class Aggregate( child: LogicalPlan) extends UnaryNode { - /** The set of all AttributeReferences required for this aggregation. */ - def references = - AttributeSet( - groupingExpressions.flatMap(_.references) ++ aggregateExpressions.flatMap(_.references)) - override def output = aggregateExpressions.map(_.toAttribute) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala index 9a3848cfc6b62..b8ba2ee428a20 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala @@ -39,9 +39,9 @@ case class NativeCommand(cmd: String) extends Command { } /** - * Commands of the form "SET (key) (= value)". + * Commands of the form "SET [key [= value] ]". */ -case class SetCommand(key: Option[String], value: Option[String]) extends Command { +case class SetCommand(kv: Option[(String, Option[String])]) extends Command { override def output = Seq( AttributeReference("", StringType, nullable = false)()) } @@ -81,3 +81,14 @@ case class DescribeCommand( AttributeReference("data_type", StringType, nullable = false)(), AttributeReference("comment", StringType, nullable = false)()) } + +/** + * Returned for the "! 
shellCommand" command + */ +case class ShellCommand(cmd: String) extends Command + + +/** + * Returned for the "SOURCE file" command + */ +case class SourceCommand(filePath: String) extends Command diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala index ac043d4dd8eb9..0cf139ebde417 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala @@ -17,73 +17,127 @@ package org.apache.spark.sql.catalyst.types -import java.sql.Timestamp +import java.sql.{Date, Timestamp} -import scala.math.Numeric.{FloatAsIfIntegral, BigDecimalAsIfIntegral, DoubleAsIfIntegral} +import scala.math.Numeric.{BigDecimalAsIfIntegral, DoubleAsIfIntegral, FloatAsIfIntegral} import scala.reflect.ClassTag -import scala.reflect.runtime.universe.{typeTag, TypeTag, runtimeMirror} +import scala.reflect.runtime.universe.{TypeTag, runtimeMirror, typeTag} import scala.util.parsing.combinator.RegexParsers +import org.json4s.JsonAST.JValue +import org.json4s._ +import org.json4s.JsonDSL._ +import org.json4s.jackson.JsonMethods._ + import org.apache.spark.sql.catalyst.ScalaReflectionLock import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression} import org.apache.spark.util.Utils -/** - * Utility functions for working with DataTypes. - */ -object DataType extends RegexParsers { - protected lazy val primitiveType: Parser[DataType] = - "StringType" ^^^ StringType | - "FloatType" ^^^ FloatType | - "IntegerType" ^^^ IntegerType | - "ByteType" ^^^ ByteType | - "ShortType" ^^^ ShortType | - "DoubleType" ^^^ DoubleType | - "LongType" ^^^ LongType | - "BinaryType" ^^^ BinaryType | - "BooleanType" ^^^ BooleanType | - "DecimalType" ^^^ DecimalType | - "TimestampType" ^^^ TimestampType - - protected lazy val arrayType: Parser[DataType] = - "ArrayType" ~> "(" ~> dataType ~ "," ~ boolVal <~ ")" ^^ { - case tpe ~ _ ~ containsNull => ArrayType(tpe, containsNull) - } - protected lazy val mapType: Parser[DataType] = - "MapType" ~> "(" ~> dataType ~ "," ~ dataType ~ "," ~ boolVal <~ ")" ^^ { - case t1 ~ _ ~ t2 ~ _ ~ valueContainsNull => MapType(t1, t2, valueContainsNull) +object DataType { + def fromJson(json: String): DataType = parseDataType(parse(json)) + + private object JSortedObject { + def unapplySeq(value: JValue): Option[List[(String, JValue)]] = value match { + case JObject(seq) => Some(seq.toList.sortBy(_._1)) + case _ => None } + } + + // NOTE: Map fields must be sorted in alphabetical order to keep consistent with the Python side. 
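DataType now serializes to JSON (json and prettyJson on instances, DataType.fromJson to parse), with the old case-class-string parser kept only as a deprecated fallback. A usage sketch built from the signatures in this diff; the schema and field names are example values, and it assumes the catalyst types are on the classpath:

import org.apache.spark.sql.catalyst.types._

object DataTypeJsonSketch extends App {
  val schema = StructType(Seq(
    StructField("a", IntegerType, nullable = true),
    StructField("tags", ArrayType(StringType, containsNull = true), nullable = false)))

  val json   = schema.json               // compact JSON string
  val parsed = DataType.fromJson(json)   // parse it back into a DataType
  println(parsed == schema)              // true: the representation round-trips
}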
+ private def parseDataType(json: JValue): DataType = json match { + case JString(name) => + PrimitiveType.nameToType(name) + + case JSortedObject( + ("containsNull", JBool(n)), + ("elementType", t: JValue), + ("type", JString("array"))) => + ArrayType(parseDataType(t), n) + + case JSortedObject( + ("keyType", k: JValue), + ("type", JString("map")), + ("valueContainsNull", JBool(n)), + ("valueType", v: JValue)) => + MapType(parseDataType(k), parseDataType(v), n) + + case JSortedObject( + ("fields", JArray(fields)), + ("type", JString("struct"))) => + StructType(fields.map(parseStructField)) + } + + private def parseStructField(json: JValue): StructField = json match { + case JSortedObject( + ("name", JString(name)), + ("nullable", JBool(nullable)), + ("type", dataType: JValue)) => + StructField(name, parseDataType(dataType), nullable) + } + + @deprecated("Use DataType.fromJson instead") + def fromCaseClassString(string: String): DataType = CaseClassStringParser(string) + + private object CaseClassStringParser extends RegexParsers { + protected lazy val primitiveType: Parser[DataType] = + ( "StringType" ^^^ StringType + | "FloatType" ^^^ FloatType + | "IntegerType" ^^^ IntegerType + | "ByteType" ^^^ ByteType + | "ShortType" ^^^ ShortType + | "DoubleType" ^^^ DoubleType + | "LongType" ^^^ LongType + | "BinaryType" ^^^ BinaryType + | "BooleanType" ^^^ BooleanType + | "DecimalType" ^^^ DecimalType + | "TimestampType" ^^^ TimestampType + ) + + protected lazy val arrayType: Parser[DataType] = + "ArrayType" ~> "(" ~> dataType ~ "," ~ boolVal <~ ")" ^^ { + case tpe ~ _ ~ containsNull => ArrayType(tpe, containsNull) + } + + protected lazy val mapType: Parser[DataType] = + "MapType" ~> "(" ~> dataType ~ "," ~ dataType ~ "," ~ boolVal <~ ")" ^^ { + case t1 ~ _ ~ t2 ~ _ ~ valueContainsNull => MapType(t1, t2, valueContainsNull) + } - protected lazy val structField: Parser[StructField] = - ("StructField(" ~> "[a-zA-Z0-9_]*".r) ~ ("," ~> dataType) ~ ("," ~> boolVal <~ ")") ^^ { - case name ~ tpe ~ nullable => + protected lazy val structField: Parser[StructField] = + ("StructField(" ~> "[a-zA-Z0-9_]*".r) ~ ("," ~> dataType) ~ ("," ~> boolVal <~ ")") ^^ { + case name ~ tpe ~ nullable => StructField(name, tpe, nullable = nullable) - } + } - protected lazy val boolVal: Parser[Boolean] = - "true" ^^^ true | - "false" ^^^ false + protected lazy val boolVal: Parser[Boolean] = + ( "true" ^^^ true + | "false" ^^^ false + ) - protected lazy val structType: Parser[DataType] = - "StructType\\([A-zA-z]*\\(".r ~> repsep(structField, ",") <~ "))" ^^ { - case fields => new StructType(fields) - } + protected lazy val structType: Parser[DataType] = + "StructType\\([A-zA-z]*\\(".r ~> repsep(structField, ",") <~ "))" ^^ { + case fields => new StructType(fields) + } - protected lazy val dataType: Parser[DataType] = - arrayType | - mapType | - structType | - primitiveType + protected lazy val dataType: Parser[DataType] = + ( arrayType + | mapType + | structType + | primitiveType + ) + + /** + * Parses a string representation of a DataType. + * + * TODO: Generate parser as pickler... + */ + def apply(asString: String): DataType = parseAll(dataType, asString) match { + case Success(result, _) => result + case failure: NoSuccess => + throw new IllegalArgumentException(s"Unsupported dataType: $asString, $failure") + } - /** - * Parses a string representation of a DataType. - * - * TODO: Generate parser as pickler... 
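parseDataType pattern-matches on JSON objects whose fields are first sorted by key (the JSortedObject extractor), because JSON guarantees no field order. The sorting trick shown with plain pairs, no json4s required:

object SortedFieldsSketch extends App {
  def sortFields(obj: List[(String, Any)]): List[(String, Any)] = obj.sortBy(_._1)

  val arrayJson = List("type" -> "array", "containsNull" -> true, "elementType" -> "string")

  // After sorting, the match can rely on a fixed order: containsNull, elementType, type.
  sortFields(arrayJson) match {
    case List(("containsNull", n), ("elementType", t), ("type", "array")) =>
      println(s"ArrayType(elementType = $t, containsNull = $n)")
    case _ =>
      println("not an array type")
  }
}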
- */ - def apply(asString: String): DataType = parseAll(dataType, asString) match { - case Success(result, _) => result - case failure: NoSuccess => sys.error(s"Unsupported dataType: $asString, $failure") } protected[types] def buildFormattedString( @@ -111,15 +165,19 @@ abstract class DataType { def isPrimitive: Boolean = false - def simpleString: String -} + def typeName: String = this.getClass.getSimpleName.stripSuffix("$").dropRight(4).toLowerCase + + private[sql] def jsonValue: JValue = typeName -case object NullType extends DataType { - def simpleString: String = "null" + def json: String = compact(render(jsonValue)) + + def prettyJson: String = pretty(render(jsonValue)) } +case object NullType extends DataType + object NativeType { - def all = Seq( + val all = Seq( IntegerType, BooleanType, LongType, DoubleType, FloatType, ShortType, ByteType, StringType) def unapply(dt: DataType): Boolean = all.contains(dt) @@ -139,6 +197,12 @@ trait PrimitiveType extends DataType { override def isPrimitive = true } +object PrimitiveType { + private[sql] val all = Seq(DecimalType, TimestampType, BinaryType) ++ NativeType.all + + private[sql] val nameToType = all.map(t => t.typeName -> t).toMap +} + abstract class NativeType extends DataType { private[sql] type JvmType @transient private[sql] val tag: TypeTag[JvmType] @@ -154,7 +218,6 @@ case object StringType extends NativeType with PrimitiveType { private[sql] type JvmType = String @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] } private[sql] val ordering = implicitly[Ordering[JvmType]] - def simpleString: String = "string" } case object BinaryType extends NativeType with PrimitiveType { @@ -166,17 +229,15 @@ case object BinaryType extends NativeType with PrimitiveType { val res = x(i).compareTo(y(i)) if (res != 0) return res } - return x.length - y.length + x.length - y.length } } - def simpleString: String = "binary" } case object BooleanType extends NativeType with PrimitiveType { private[sql] type JvmType = Boolean @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] } private[sql] val ordering = implicitly[Ordering[JvmType]] - def simpleString: String = "boolean" } case object TimestampType extends NativeType { @@ -187,8 +248,16 @@ case object TimestampType extends NativeType { private[sql] val ordering = new Ordering[JvmType] { def compare(x: Timestamp, y: Timestamp) = x.compareTo(y) } +} + +case object DateType extends NativeType { + private[sql] type JvmType = Date + + @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] } - def simpleString: String = "timestamp" + private[sql] val ordering = new Ordering[JvmType] { + def compare(x: Date, y: Date) = x.compareTo(y) + } } abstract class NumericType extends NativeType with PrimitiveType { @@ -222,7 +291,6 @@ case object LongType extends IntegralType { private[sql] val numeric = implicitly[Numeric[Long]] private[sql] val integral = implicitly[Integral[Long]] private[sql] val ordering = implicitly[Ordering[JvmType]] - def simpleString: String = "long" } case object IntegerType extends IntegralType { @@ -231,7 +299,6 @@ case object IntegerType extends IntegralType { private[sql] val numeric = implicitly[Numeric[Int]] private[sql] val integral = implicitly[Integral[Int]] private[sql] val ordering = implicitly[Ordering[JvmType]] - def simpleString: String = "integer" } case object ShortType extends IntegralType { @@ -240,7 +307,6 @@ case object ShortType extends IntegralType { 
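simpleString is gone; typeName is now derived mechanically from the class name by dropping the Scala object '$', stripping the trailing "Type", and lower-casing, and those names are what the JSON format and PrimitiveType.nameToType key on. The derivation reproduced standalone:

object TypeNameSketch extends App {
  def typeName(simpleClassName: String): String =
    simpleClassName.stripSuffix("$").dropRight(4).toLowerCase

  println(typeName("IntegerType$"))   // "integer"
  println(typeName("TimestampType$")) // "timestamp"
  println(typeName("DateType$"))      // "date"
}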
private[sql] val numeric = implicitly[Numeric[Short]] private[sql] val integral = implicitly[Integral[Short]] private[sql] val ordering = implicitly[Ordering[JvmType]] - def simpleString: String = "short" } case object ByteType extends IntegralType { @@ -249,7 +315,6 @@ case object ByteType extends IntegralType { private[sql] val numeric = implicitly[Numeric[Byte]] private[sql] val integral = implicitly[Integral[Byte]] private[sql] val ordering = implicitly[Ordering[JvmType]] - def simpleString: String = "byte" } /** Matcher for any expressions that evaluate to [[FractionalType]]s */ @@ -271,7 +336,6 @@ case object DecimalType extends FractionalType { private[sql] val fractional = implicitly[Fractional[BigDecimal]] private[sql] val ordering = implicitly[Ordering[JvmType]] private[sql] val asIntegral = BigDecimalAsIfIntegral - def simpleString: String = "decimal" } case object DoubleType extends FractionalType { @@ -281,7 +345,6 @@ case object DoubleType extends FractionalType { private[sql] val fractional = implicitly[Fractional[Double]] private[sql] val ordering = implicitly[Ordering[JvmType]] private[sql] val asIntegral = DoubleAsIfIntegral - def simpleString: String = "double" } case object FloatType extends FractionalType { @@ -291,7 +354,6 @@ case object FloatType extends FractionalType { private[sql] val fractional = implicitly[Fractional[Float]] private[sql] val ordering = implicitly[Ordering[JvmType]] private[sql] val asIntegral = FloatAsIfIntegral - def simpleString: String = "float" } object ArrayType { @@ -309,11 +371,14 @@ object ArrayType { case class ArrayType(elementType: DataType, containsNull: Boolean) extends DataType { private[sql] def buildFormattedString(prefix: String, builder: StringBuilder): Unit = { builder.append( - s"${prefix}-- element: ${elementType.simpleString} (containsNull = ${containsNull})\n") + s"$prefix-- element: ${elementType.typeName} (containsNull = $containsNull)\n") DataType.buildFormattedString(elementType, s"$prefix |", builder) } - def simpleString: String = "array" + override private[sql] def jsonValue = + ("type" -> typeName) ~ + ("elementType" -> elementType.jsonValue) ~ + ("containsNull" -> containsNull) } /** @@ -325,9 +390,15 @@ case class ArrayType(elementType: DataType, containsNull: Boolean) extends DataT case class StructField(name: String, dataType: DataType, nullable: Boolean) { private[sql] def buildFormattedString(prefix: String, builder: StringBuilder): Unit = { - builder.append(s"${prefix}-- ${name}: ${dataType.simpleString} (nullable = ${nullable})\n") + builder.append(s"$prefix-- $name: ${dataType.typeName} (nullable = $nullable)\n") DataType.buildFormattedString(dataType, s"$prefix |", builder) } + + private[sql] def jsonValue: JValue = { + ("name" -> name) ~ + ("type" -> dataType.jsonValue) ~ + ("nullable" -> nullable) + } } object StructType { @@ -348,8 +419,7 @@ case class StructType(fields: Seq[StructField]) extends DataType { * have a name matching the given name, `null` will be returned. 
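The jsonValue bodies compose JSON with the json4s DSL this file now imports: the ~ operator chains key/value pairs into a JObject, and compact(render(...)) turns it into a string. The same construction for an array type, standalone (assuming json4s-jackson on the classpath, as the imports above do):

import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods.{compact, render}

object JsonDslSketch extends App {
  val arrayTypeJson =
    ("type" -> "array") ~
    ("elementType" -> "string") ~
    ("containsNull" -> true)

  println(compact(render(arrayTypeJson)))
  // {"type":"array","elementType":"string","containsNull":true}
}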
*/ def apply(name: String): StructField = { - nameToField.get(name).getOrElse( - throw new IllegalArgumentException(s"Field ${name} does not exist.")) + nameToField.getOrElse(name, throw new IllegalArgumentException(s"Field $name does not exist.")) } /** @@ -358,7 +428,7 @@ case class StructType(fields: Seq[StructField]) extends DataType { */ def apply(names: Set[String]): StructType = { val nonExistFields = names -- fieldNamesSet - if (!nonExistFields.isEmpty) { + if (nonExistFields.nonEmpty) { throw new IllegalArgumentException( s"Field ${nonExistFields.mkString(",")} does not exist.") } @@ -384,7 +454,9 @@ case class StructType(fields: Seq[StructField]) extends DataType { fields.foreach(field => field.buildFormattedString(prefix, builder)) } - def simpleString: String = "struct" + override private[sql] def jsonValue = + ("type" -> typeName) ~ + ("fields" -> fields.map(_.jsonValue)) } object MapType { @@ -407,12 +479,16 @@ case class MapType( valueType: DataType, valueContainsNull: Boolean) extends DataType { private[sql] def buildFormattedString(prefix: String, builder: StringBuilder): Unit = { - builder.append(s"${prefix}-- key: ${keyType.simpleString}\n") - builder.append(s"${prefix}-- value: ${valueType.simpleString} " + - s"(valueContainsNull = ${valueContainsNull})\n") + builder.append(s"$prefix-- key: ${keyType.typeName}\n") + builder.append(s"$prefix-- value: ${valueType.typeName} " + + s"(valueContainsNull = $valueContainsNull)\n") DataType.buildFormattedString(keyType, s"$prefix |", builder) DataType.buildFormattedString(valueType, s"$prefix |", builder) } - def simpleString: String = "map" + override private[sql] def jsonValue: JValue = + ("type" -> typeName) ~ + ("keyType" -> keyType.jsonValue) ~ + ("valueType" -> valueType.jsonValue) ~ + ("valueContainsNull" -> valueContainsNull) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index 5809a108ff62e..7b45738c4fc95 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -19,10 +19,11 @@ package org.apache.spark.sql.catalyst.analysis import org.scalatest.{BeforeAndAfter, FunSuite} -import org.apache.spark.sql.catalyst.expressions.AttributeReference +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference} import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.catalyst.types.IntegerType +import org.apache.spark.sql.catalyst.types._ class AnalysisSuite extends FunSuite with BeforeAndAfter { val caseSensitiveCatalog = new SimpleCatalog(true) @@ -33,6 +34,12 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter { new Analyzer(caseInsensitiveCatalog, EmptyFunctionRegistry, caseSensitive = false) val testRelation = LocalRelation(AttributeReference("a", IntegerType, nullable = true)()) + val testRelation2 = LocalRelation( + AttributeReference("a", StringType)(), + AttributeReference("b", StringType)(), + AttributeReference("c", DoubleType)(), + AttributeReference("d", DecimalType)(), + AttributeReference("e", ShortType)()) before { caseSensitiveCatalog.registerTable(None, "TaBlE", testRelation) @@ -74,7 +81,7 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter { val e = 
intercept[RuntimeException] { caseSensitiveAnalyze(UnresolvedRelation(None, "tAbLe", None)) } - assert(e.getMessage === "Table Not Found: tAbLe") + assert(e.getMessage == "Table Not Found: tAbLe") assert( caseSensitiveAnalyze(UnresolvedRelation(None, "TaBlE", None)) === @@ -106,4 +113,31 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter { } assert(e.getMessage().toLowerCase.contains("unresolved plan")) } + + test("divide should be casted into fractional types") { + val testRelation2 = LocalRelation( + AttributeReference("a", StringType)(), + AttributeReference("b", StringType)(), + AttributeReference("c", DoubleType)(), + AttributeReference("d", DecimalType)(), + AttributeReference("e", ShortType)()) + + val expr0 = 'a / 2 + val expr1 = 'a / 'b + val expr2 = 'a / 'c + val expr3 = 'a / 'd + val expr4 = 'e / 'e + val plan = caseInsensitiveAnalyze(Project( + Alias(expr0, s"Analyzer($expr0)")() :: + Alias(expr1, s"Analyzer($expr1)")() :: + Alias(expr2, s"Analyzer($expr2)")() :: + Alias(expr3, s"Analyzer($expr3)")() :: + Alias(expr4, s"Analyzer($expr4)")() :: Nil, testRelation2)) + val pl = plan.asInstanceOf[Project].projectList + assert(pl(0).dataType == DoubleType) + assert(pl(1).dataType == DoubleType) + assert(pl(2).dataType == DoubleType) + assert(pl(3).dataType == DecimalType) + assert(pl(4).dataType == DoubleType) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala index 63931af4bac3d..6dc5942023f9e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala @@ -17,7 +17,9 @@ package org.apache.spark.sql.catalyst.expressions -import java.sql.Timestamp +import java.sql.{Date, Timestamp} + +import scala.collection.immutable.HashSet import org.scalatest.FunSuite import org.scalatest.Matchers._ @@ -25,6 +27,7 @@ import org.scalautils.TripleEqualsSupport.Spread import org.apache.spark.sql.catalyst.types._ + /* Implicit conversions */ import org.apache.spark.sql.catalyst.dsl.expressions._ @@ -145,6 +148,24 @@ class ExpressionEvaluationSuite extends FunSuite { checkEvaluation(In(Literal(1), Seq(Literal(1), Literal(2))) && In(Literal(2), Seq(Literal(1), Literal(2))), true) } + test("INSET") { + val hS = HashSet[Any]() + 1 + 2 + val nS = HashSet[Any]() + 1 + 2 + null + val one = Literal(1) + val two = Literal(2) + val three = Literal(3) + val nl = Literal(null) + val s = Seq(one, two) + val nullS = Seq(one, two, null) + checkEvaluation(InSet(one, hS, one +: s), true) + checkEvaluation(InSet(two, hS, two +: s), true) + checkEvaluation(InSet(two, nS, two +: nullS), true) + checkEvaluation(InSet(nl, nS, nl +: nullS), true) + checkEvaluation(InSet(three, hS, three +: s), false) + checkEvaluation(InSet(three, nS, three +: nullS), false) + checkEvaluation(InSet(one, hS, one +: s) && InSet(two, hS, two +: s), true) + } + test("MaxOf") { checkEvaluation(MaxOf(1, 2), 2) checkEvaluation(MaxOf(2, 1), 2) @@ -231,8 +252,11 @@ class ExpressionEvaluationSuite extends FunSuite { test("data type casting") { - val sts = "1970-01-01 00:00:01.1" - val ts = Timestamp.valueOf(sts) + val sd = "1970-01-01" + val d = Date.valueOf(sd) + val sts = sd + " 00:00:02" + val nts = sts + ".1" + val ts = Timestamp.valueOf(nts) checkEvaluation("abdef" cast StringType, "abdef") 
checkEvaluation("abdef" cast DecimalType, null) @@ -245,8 +269,15 @@ class ExpressionEvaluationSuite extends FunSuite { checkEvaluation(Cast(Literal(1.toDouble) cast TimestampType, DoubleType), 1.toDouble) checkEvaluation(Cast(Literal(1.toDouble) cast TimestampType, DoubleType), 1.toDouble) - checkEvaluation(Cast(Literal(sts) cast TimestampType, StringType), sts) + checkEvaluation(Cast(Literal(sd) cast DateType, StringType), sd) + checkEvaluation(Cast(Literal(d) cast StringType, DateType), d) + checkEvaluation(Cast(Literal(nts) cast TimestampType, StringType), nts) checkEvaluation(Cast(Literal(ts) cast StringType, TimestampType), ts) + // all convert to string type to check + checkEvaluation( + Cast(Cast(Literal(nts) cast TimestampType, DateType), StringType), sd) + checkEvaluation( + Cast(Cast(Literal(ts) cast DateType, TimestampType), StringType), sts) checkEvaluation(Cast("abdef" cast BinaryType, StringType), "abdef") @@ -295,6 +326,12 @@ class ExpressionEvaluationSuite extends FunSuite { checkEvaluation(Cast(Literal(null, IntegerType), ShortType), null) } + test("date") { + val d1 = Date.valueOf("1970-01-01") + val d2 = Date.valueOf("1970-01-02") + checkEvaluation(Literal(d1) < Literal(d2), true) + } + test("timestamp") { val ts1 = new Timestamp(12) val ts2 = new Timestamp(123) @@ -302,6 +339,17 @@ class ExpressionEvaluationSuite extends FunSuite { checkEvaluation(Literal(ts1) < Literal(ts2), true) } + test("date casting") { + val d = Date.valueOf("1970-01-01") + checkEvaluation(Cast(d, ShortType), null) + checkEvaluation(Cast(d, IntegerType), null) + checkEvaluation(Cast(d, LongType), null) + checkEvaluation(Cast(d, FloatType), null) + checkEvaluation(Cast(d, DoubleType), null) + checkEvaluation(Cast(d, StringType), "1970-01-01") + checkEvaluation(Cast(Cast(d, TimestampType), StringType), "1970-01-01 00:00:00") + } + test("timestamp casting") { val millis = 15 * 1000 + 2 val seconds = millis * 1000 + 2 diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedEvaluationSuite.scala index 245a2e148030c..ef3114fd4dbab 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedEvaluationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedEvaluationSuite.scala @@ -15,9 +15,8 @@ * limitations under the License. */ -package org.apache.spark.sql.catalyst.optimizer +package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen._ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedMutableEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedMutableEvaluationSuite.scala index 887aabb1d5fb4..275ea2627ebcd 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedMutableEvaluationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedMutableEvaluationSuite.scala @@ -15,9 +15,8 @@ * limitations under the License. 
*/ -package org.apache.spark.sql.catalyst.optimizer +package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen._ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala new file mode 100644 index 0000000000000..97a78ec971c39 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import scala.collection.immutable.HashSet +import org.apache.spark.sql.catalyst.analysis.{EliminateAnalysisOperators, UnresolvedAttribute} +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.rules.RuleExecutor +import org.apache.spark.sql.catalyst.types._ + +// For implicit conversions +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.dsl.expressions._ + +class OptimizeInSuite extends PlanTest { + + object Optimize extends RuleExecutor[LogicalPlan] { + val batches = + Batch("AnalysisNodes", Once, + EliminateAnalysisOperators) :: + Batch("ConstantFolding", Once, + ConstantFolding, + BooleanSimplification, + OptimizeIn) :: Nil + } + + val testRelation = LocalRelation('a.int, 'b.int, 'c.int) + + test("OptimizedIn test: In clause optimized to InSet") { + val originalQuery = + testRelation + .where(In(UnresolvedAttribute("a"), Seq(Literal(1),Literal(2)))) + .analyze + + val optimized = Optimize(originalQuery.analyze) + val correctAnswer = + testRelation + .where(InSet(UnresolvedAttribute("a"), HashSet[Any]()+1+2, + UnresolvedAttribute("a") +: Seq(Literal(1),Literal(2)))) + .analyze + + comparePlans(optimized, correctAnswer) + } + + test("OptimizedIn test: In clause not optimized in case filter has attributes") { + val originalQuery = + testRelation + .where(In(UnresolvedAttribute("a"), Seq(Literal(1),Literal(2), UnresolvedAttribute("b")))) + .analyze + + val optimized = Optimize(originalQuery.analyze) + val correctAnswer = + testRelation + .where(In(UnresolvedAttribute("a"), Seq(Literal(1),Literal(2), UnresolvedAttribute("b")))) + .analyze + + comparePlans(optimized, correctAnswer) + } +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/DataType.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/DataType.java index 37b4c8ffcba0b..37e88d72b9172 100644 --- 
a/sql/core/src/main/java/org/apache/spark/sql/api/java/DataType.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/DataType.java @@ -44,6 +44,11 @@ public abstract class DataType { */ public static final BooleanType BooleanType = new BooleanType(); + /** + * Gets the DateType object. + */ + public static final DateType DateType = new DateType(); + /** * Gets the TimestampType object. */ diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/DateType.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/DateType.java new file mode 100644 index 0000000000000..6677793baa365 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/DateType.java @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.api.java; + +/** + * The data type representing java.sql.Date values. + * + * {@code DateType} is represented by the singleton object {@link DataType#DateType}. + */ +public class DateType extends DataType { + protected DateType() {} +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/CacheManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/CacheManager.scala index 3bf7382ac67a6..5ab2b5316ab10 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/CacheManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/CacheManager.scala @@ -22,7 +22,7 @@ import java.util.concurrent.locks.ReentrantReadWriteLock import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.columnar.InMemoryRelation import org.apache.spark.storage.StorageLevel -import org.apache.spark.storage.StorageLevel.MEMORY_ONLY +import org.apache.spark.storage.StorageLevel.MEMORY_AND_DISK /** Holds a cached logical plan and its data */ private case class CachedData(plan: LogicalPlan, cachedRepresentation: InMemoryRelation) @@ -74,10 +74,14 @@ private[sql] trait CacheManager { cachedData.clear() } - /** Caches the data produced by the logical representation of the given schema rdd. */ + /** + * Caches the data produced by the logical representation of the given schema rdd. Unlike + * `RDD.cache()`, the default storage level is set to be `MEMORY_AND_DISK` because recomputing + * the in-memory columnar representation of the underlying table is expensive. 
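(Editorial sketch, not part of the patch: how the new caching default behaves from the public API. The SQLContext instance and the "people" table name are assumed.)

// Cached tables now spill to disk instead of being recomputed when they do not fit in memory.
sqlContext.cacheTable("people")                          // columnar cache, MEMORY_AND_DISK by default
sqlContext.sql("SELECT COUNT(*) FROM people").collect()  // served from the in-memory columnar form
sqlContext.uncacheTable("people")
// Code inside org.apache.spark.sql can still request a different level through the
// private[sql] cacheQuery(query, storageLevel) shown in this hunk.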
+ */ private[sql] def cacheQuery( query: SchemaRDD, - storageLevel: StorageLevel = MEMORY_ONLY): Unit = writeLock { + storageLevel: StorageLevel = MEMORY_AND_DISK): Unit = writeLock { val planToCache = query.queryExecution.optimizedPlan if (lookupCachedData(planToCache).nonEmpty) { logWarning("Asked to cache already cached data.") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala index f6f4cf3b80d41..07e6e2eccddf4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala @@ -35,6 +35,7 @@ private[spark] object SQLConf { val PARQUET_BINARY_AS_STRING = "spark.sql.parquet.binaryAsString" val PARQUET_CACHE_METADATA = "spark.sql.parquet.cacheMetadata" val PARQUET_COMPRESSION = "spark.sql.parquet.compression.codec" + val COLUMN_NAME_OF_CORRUPT_RECORD = "spark.sql.columnNameOfCorruptRecord" // This is only used for the thriftserver val THRIFTSERVER_POOL = "spark.sql.thriftserver.scheduler.pool" @@ -131,6 +132,9 @@ private[sql] trait SQLConf { private[spark] def inMemoryPartitionPruning: Boolean = getConf(IN_MEMORY_PARTITION_PRUNING, "false").toBoolean + private[spark] def columnNameOfCorruptRecord: String = + getConf(COLUMN_NAME_OF_CORRUPT_RECORD, "_corrupt_record") + /** ********************** SQLConf functionality methods ************ */ /** Set Spark SQL configuration properties. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 1f6ba851891ac..23e7b2d270777 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -22,6 +22,7 @@ import scala.reflect.runtime.universe.TypeTag import org.apache.hadoop.conf.Configuration +import org.apache.spark.SparkContext import org.apache.spark.annotation.{AlphaComponent, DeveloperApi, Experimental} import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.ScalaReflection @@ -31,12 +32,11 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.optimizer.Optimizer import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.RuleExecutor +import org.apache.spark.sql.catalyst.types.DataType import org.apache.spark.sql.columnar.InMemoryRelation -import org.apache.spark.sql.execution._ -import org.apache.spark.sql.execution.SparkStrategies +import org.apache.spark.sql.execution.{SparkStrategies, _} import org.apache.spark.sql.json._ import org.apache.spark.sql.parquet.ParquetRelation -import org.apache.spark.{Logging, SparkContext} /** * :: AlphaComponent :: @@ -66,13 +66,17 @@ class SQLContext(@transient val sparkContext: SparkContext) @transient protected[sql] lazy val analyzer: Analyzer = new Analyzer(catalog, functionRegistry, caseSensitive = true) + @transient protected[sql] val optimizer = Optimizer - @transient - protected[sql] def parser = new catalyst.SqlParser - protected[sql] def parseSql(sql: String): LogicalPlan = parser(sql) + @transient + protected[sql] val sqlParser = { + val fallback = new catalyst.SqlParser + new catalyst.SparkSQLParser(fallback(_)) + } + protected[sql] def parseSql(sql: String): LogicalPlan = sqlParser(sql) protected[sql] def executeSql(sql: String): this.QueryExecution = executePlan(parseSql(sql)) protected[sql] def executePlan(plan: LogicalPlan): this.QueryExecution = new this.QueryExecution { val 
logical = plan } @@ -196,9 +200,12 @@ class SQLContext(@transient val sparkContext: SparkContext) */ @Experimental def jsonRDD(json: RDD[String], schema: StructType): SchemaRDD = { + val columnNameOfCorruptJsonRecord = columnNameOfCorruptRecord val appliedSchema = - Option(schema).getOrElse(JsonRDD.nullTypeToStringType(JsonRDD.inferSchema(json, 1.0))) - val rowRDD = JsonRDD.jsonStringToRow(json, appliedSchema) + Option(schema).getOrElse( + JsonRDD.nullTypeToStringType( + JsonRDD.inferSchema(json, 1.0, columnNameOfCorruptJsonRecord))) + val rowRDD = JsonRDD.jsonStringToRow(json, appliedSchema, columnNameOfCorruptJsonRecord) applySchema(rowRDD, appliedSchema) } @@ -207,8 +214,11 @@ class SQLContext(@transient val sparkContext: SparkContext) */ @Experimental def jsonRDD(json: RDD[String], samplingRatio: Double): SchemaRDD = { - val appliedSchema = JsonRDD.nullTypeToStringType(JsonRDD.inferSchema(json, samplingRatio)) - val rowRDD = JsonRDD.jsonStringToRow(json, appliedSchema) + val columnNameOfCorruptJsonRecord = columnNameOfCorruptRecord + val appliedSchema = + JsonRDD.nullTypeToStringType( + JsonRDD.inferSchema(json, samplingRatio, columnNameOfCorruptJsonRecord)) + val rowRDD = JsonRDD.jsonStringToRow(json, appliedSchema, columnNameOfCorruptJsonRecord) applySchema(rowRDD, appliedSchema) } @@ -410,8 +420,7 @@ class SQLContext(@transient val sparkContext: SparkContext) * It is only used by PySpark. */ private[sql] def parseDataType(dataTypeString: String): DataType = { - val parser = org.apache.spark.sql.catalyst.types.DataType - parser(dataTypeString) + DataType.fromJson(dataTypeString) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala index 594bf8ffc20e1..948122d42f0e1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala @@ -360,7 +360,7 @@ class SchemaRDD( join: Boolean = false, outer: Boolean = false, alias: Option[String] = None) = - new SchemaRDD(sqlContext, Generate(generator, join, outer, None, logicalPlan)) + new SchemaRDD(sqlContext, Generate(generator, join, outer, alias, logicalPlan)) /** * Returns this RDD as a SchemaRDD. Intended primarily to force the invocation of the implicit diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala index b36d8b7438283..6b585e2fa314d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala @@ -123,7 +123,7 @@ private[sql] trait SchemaRDDLike { * @group schema */ @Experimental - def saveAsTable(tableName: String): RDD[Row] = + def saveAsTable(tableName: String): Unit = sqlContext.executePlan(CreateTableAsSelect(None, tableName, logicalPlan)).toRdd /** Returns the schema as a string in the tree format. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala index c006c4330ff66..f8171c3be3207 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala @@ -148,8 +148,12 @@ class JavaSQLContext(val sqlContext: SQLContext) extends UDFRegistration { * It goes through the entire dataset once to determine the schema. 
*/ def jsonRDD(json: JavaRDD[String]): JavaSchemaRDD = { - val appliedScalaSchema = JsonRDD.nullTypeToStringType(JsonRDD.inferSchema(json.rdd, 1.0)) - val scalaRowRDD = JsonRDD.jsonStringToRow(json.rdd, appliedScalaSchema) + val columnNameOfCorruptJsonRecord = sqlContext.columnNameOfCorruptRecord + val appliedScalaSchema = + JsonRDD.nullTypeToStringType( + JsonRDD.inferSchema(json.rdd, 1.0, columnNameOfCorruptJsonRecord)) + val scalaRowRDD = + JsonRDD.jsonStringToRow(json.rdd, appliedScalaSchema, columnNameOfCorruptJsonRecord) val logicalPlan = LogicalRDD(appliedScalaSchema.toAttributes, scalaRowRDD)(sqlContext) new JavaSchemaRDD(sqlContext, logicalPlan) @@ -162,10 +166,14 @@ class JavaSQLContext(val sqlContext: SQLContext) extends UDFRegistration { */ @Experimental def jsonRDD(json: JavaRDD[String], schema: StructType): JavaSchemaRDD = { + val columnNameOfCorruptJsonRecord = sqlContext.columnNameOfCorruptRecord val appliedScalaSchema = Option(asScalaDataType(schema)).getOrElse( - JsonRDD.nullTypeToStringType(JsonRDD.inferSchema(json.rdd, 1.0))).asInstanceOf[SStructType] - val scalaRowRDD = JsonRDD.jsonStringToRow(json.rdd, appliedScalaSchema) + JsonRDD.nullTypeToStringType( + JsonRDD.inferSchema( + json.rdd, 1.0, columnNameOfCorruptJsonRecord))).asInstanceOf[SStructType] + val scalaRowRDD = JsonRDD.jsonStringToRow( + json.rdd, appliedScalaSchema, columnNameOfCorruptJsonRecord) val logicalPlan = LogicalRDD(appliedScalaSchema.toAttributes, scalaRowRDD)(sqlContext) new JavaSchemaRDD(sqlContext, logicalPlan) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/java/Row.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/java/Row.scala index e9d04ce7aae4c..df01411f60a05 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/java/Row.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/java/Row.scala @@ -22,6 +22,7 @@ import scala.collection.convert.Wrappers.{JListWrapper, JMapWrapper} import scala.collection.JavaConversions import scala.math.BigDecimal +import org.apache.spark.api.java.JavaUtils.mapAsSerializableJavaMap import org.apache.spark.sql.catalyst.expressions.{Row => ScalaRow} /** @@ -114,7 +115,7 @@ object Row { // they are actually accessed. 
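(Editorial sketch of the corrupt-record handling introduced above; `sc`, `sqlContext`, and the column name "_malformed" are assumptions, and the exact inferred schema is only what this patch suggests.)

sqlContext.setConf("spark.sql.columnNameOfCorruptRecord", "_malformed") // default: _corrupt_record
val input = sc.parallelize(Seq("""{"a": 1}""", """{"a": """))           // second record is malformed
val records = sqlContext.jsonRDD(input)
records.printSchema()  // the corrupt-record column is expected to appear as a StringType field
                       // holding the original text of rows that failed to parse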
case row: ScalaRow => new Row(row) case map: scala.collection.Map[_, _] => - JavaConversions.mapAsJavaMap( + mapAsSerializableJavaMap( map.map { case (key, value) => (toJavaValue(key), toJavaValue(value)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala index c9faf0852142a..538dd5b734664 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala @@ -92,6 +92,9 @@ private[sql] class FloatColumnAccessor(buffer: ByteBuffer) private[sql] class StringColumnAccessor(buffer: ByteBuffer) extends NativeColumnAccessor(buffer, STRING) +private[sql] class DateColumnAccessor(buffer: ByteBuffer) + extends NativeColumnAccessor(buffer, DATE) + private[sql] class TimestampColumnAccessor(buffer: ByteBuffer) extends NativeColumnAccessor(buffer, TIMESTAMP) @@ -118,6 +121,7 @@ private[sql] object ColumnAccessor { case BYTE.typeId => new ByteColumnAccessor(dup) case SHORT.typeId => new ShortColumnAccessor(dup) case STRING.typeId => new StringColumnAccessor(dup) + case DATE.typeId => new DateColumnAccessor(dup) case TIMESTAMP.typeId => new TimestampColumnAccessor(dup) case BINARY.typeId => new BinaryColumnAccessor(dup) case GENERIC.typeId => new GenericColumnAccessor(dup) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala index 2e61a981375aa..300cef15bf8a4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala @@ -107,6 +107,8 @@ private[sql] class FloatColumnBuilder extends NativeColumnBuilder(new FloatColum private[sql] class StringColumnBuilder extends NativeColumnBuilder(new StringColumnStats, STRING) +private[sql] class DateColumnBuilder extends NativeColumnBuilder(new DateColumnStats, DATE) + private[sql] class TimestampColumnBuilder extends NativeColumnBuilder(new TimestampColumnStats, TIMESTAMP) @@ -151,6 +153,7 @@ private[sql] object ColumnBuilder { case STRING.typeId => new StringColumnBuilder case BINARY.typeId => new BinaryColumnBuilder case GENERIC.typeId => new GenericColumnBuilder + case DATE.typeId => new DateColumnBuilder case TIMESTAMP.typeId => new TimestampColumnBuilder }).asInstanceOf[ColumnBuilder] diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala index 203a714e03c97..b34ab255d084a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.columnar -import java.sql.Timestamp +import java.sql.{Date, Timestamp} import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.expressions.{AttributeMap, Attribute, AttributeReference} @@ -190,6 +190,24 @@ private[sql] class StringColumnStats extends ColumnStats { def collectedStatistics = Row(lower, upper, nullCount) } +private[sql] class DateColumnStats extends ColumnStats { + var upper: Date = null + var lower: Date = null + var nullCount = 0 + + override def gatherStats(row: Row, ordinal: Int) { + if (!row.isNullAt(ordinal)) { + val value = row(ordinal).asInstanceOf[Date] + if (upper == null || value.compareTo(upper) > 0) upper = value + if 
(lower == null || value.compareTo(lower) < 0) lower = value + } else { + nullCount += 1 + } + } + + def collectedStatistics = Row(lower, upper, nullCount) +} + private[sql] class TimestampColumnStats extends ColumnStats { var upper: Timestamp = null var lower: Timestamp = null diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala index 198b5756676aa..ab66c85c4f242 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.columnar import java.nio.ByteBuffer -import java.sql.Timestamp +import java.sql.{Date, Timestamp} import scala.reflect.runtime.universe.TypeTag @@ -335,7 +335,26 @@ private[sql] object STRING extends NativeColumnType(StringType, 7, 8) { } } -private[sql] object TIMESTAMP extends NativeColumnType(TimestampType, 8, 12) { +private[sql] object DATE extends NativeColumnType(DateType, 8, 8) { + override def extract(buffer: ByteBuffer) = { + val date = new Date(buffer.getLong()) + date + } + + override def append(v: Date, buffer: ByteBuffer): Unit = { + buffer.putLong(v.getTime) + } + + override def getField(row: Row, ordinal: Int) = { + row(ordinal).asInstanceOf[Date] + } + + override def setField(row: MutableRow, ordinal: Int, value: Date): Unit = { + row(ordinal) = value + } +} + +private[sql] object TIMESTAMP extends NativeColumnType(TimestampType, 9, 12) { override def extract(buffer: ByteBuffer) = { val timestamp = new Timestamp(buffer.getLong()) timestamp.setNanos(buffer.getInt()) @@ -376,7 +395,7 @@ private[sql] sealed abstract class ByteArrayColumnType[T <: DataType]( } } -private[sql] object BINARY extends ByteArrayColumnType[BinaryType.type](9, 16) { +private[sql] object BINARY extends ByteArrayColumnType[BinaryType.type](10, 16) { override def setField(row: MutableRow, ordinal: Int, value: Array[Byte]): Unit = { row(ordinal) = value } @@ -387,7 +406,7 @@ private[sql] object BINARY extends ByteArrayColumnType[BinaryType.type](9, 16) { // Used to process generic objects (all types other than those listed above). Objects should be // serialized first before appending to the column `ByteBuffer`, and is also extracted as serialized // byte array. 
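(Editorial sketch mirroring the DATE column type added above: a java.sql.Date is encoded as its epoch-millisecond value in 8 bytes, exactly as append/extract do.)

import java.nio.ByteBuffer
import java.sql.Date

val buffer = ByteBuffer.allocate(8)
val d = Date.valueOf("1970-01-02")
buffer.putLong(d.getTime)                  // append: store the date as a long
buffer.flip()
val decoded = new Date(buffer.getLong())   // extract: rebuild the Date from the long
assert(decoded.getTime == d.getTime)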
-private[sql] object GENERIC extends ByteArrayColumnType[DataType](10, 16) { +private[sql] object GENERIC extends ByteArrayColumnType[DataType](11, 16) { override def setField(row: MutableRow, ordinal: Int, value: Array[Byte]): Unit = { row(ordinal) = SparkSqlSerializer.deserialize[Any](value) } @@ -407,6 +426,7 @@ private[sql] object ColumnType { case ShortType => SHORT case StringType => STRING case BinaryType => BINARY + case DateType => DATE case TimestampType => TIMESTAMP case _ => GENERIC } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala index 4f79173a26f88..22ab0e2613f21 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala @@ -38,7 +38,7 @@ private[sql] object InMemoryRelation { new InMemoryRelation(child.output, useCompression, batchSize, storageLevel, child)() } -private[sql] case class CachedBatch(buffers: Array[ByteBuffer], stats: Row) +private[sql] case class CachedBatch(buffers: Array[Array[Byte]], stats: Row) private[sql] case class InMemoryRelation( output: Seq[Attribute], @@ -91,7 +91,7 @@ private[sql] case class InMemoryRelation( val stats = Row.fromSeq( columnBuilders.map(_.columnStats.collectedStatistics).foldLeft(Seq.empty[Any])(_ ++ _)) - CachedBatch(columnBuilders.map(_.build()), stats) + CachedBatch(columnBuilders.map(_.build().array()), stats) } def hasNext = rowIterator.hasNext @@ -238,8 +238,9 @@ private[sql] case class InMemoryColumnarTableScan( def cachedBatchesToRows(cacheBatches: Iterator[CachedBatch]) = { val rows = cacheBatches.flatMap { cachedBatch => // Build column accessors - val columnAccessors = - requestedColumnIndices.map(cachedBatch.buffers(_)).map(ColumnAccessor(_)) + val columnAccessors = requestedColumnIndices.map { batch => + ColumnAccessor(ByteBuffer.wrap(cachedBatch.buffers(batch))) + } // Extract rows via column accessors new Iterator[Row] { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala index c386fd121c5de..38877c28de3a8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala @@ -39,7 +39,8 @@ case class Generate( child: SparkPlan) extends UnaryNode { - protected def generatorOutput: Seq[Attribute] = { + // This must be a val since the generator output expr ids are not preserved by serialization. 
+ protected val generatorOutput: Seq[Attribute] = { if (join && outer) { generator.output.map(_.withNullability(true)) } else { @@ -62,7 +63,7 @@ case class Generate( newProjection(child.output ++ nullValues, child.output) val joinProjection = - newProjection(child.output ++ generator.output, child.output ++ generator.output) + newProjection(child.output ++ generatorOutput, child.output ++ generatorOutput) val joinedRow = new JoinedRow iter.flatMap {row => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 5c16d0c624128..79e4ddb8c4f5d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.types._ import org.apache.spark.sql.columnar.{InMemoryRelation, InMemoryColumnarTableScan} import org.apache.spark.sql.parquet._ + private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { self: SQLContext#SparkPlanner => @@ -34,13 +35,12 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { // Find left semi joins where at least some predicates can be evaluated by matching join keys case ExtractEquiJoinKeys(LeftSemi, leftKeys, rightKeys, condition, left, right) => - val semiJoin = execution.LeftSemiJoinHash( + val semiJoin = joins.LeftSemiJoinHash( leftKeys, rightKeys, planLater(left), planLater(right)) condition.map(Filter(_, semiJoin)).getOrElse(semiJoin) :: Nil // no predicate can be evaluated by matching hash keys case logical.Join(left, right, LeftSemi, condition) => - execution.LeftSemiJoinBNL( - planLater(left), planLater(right), condition) :: Nil + joins.LeftSemiJoinBNL(planLater(left), planLater(right), condition) :: Nil case _ => Nil } } @@ -50,13 +50,13 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { * evaluated by matching hash keys. * * This strategy applies a simple optimization based on the estimates of the physical sizes of - * the two join sides. When planning a [[execution.BroadcastHashJoin]], if one side has an + * the two join sides. When planning a [[joins.BroadcastHashJoin]], if one side has an * estimated physical size smaller than the user-settable threshold * [[org.apache.spark.sql.SQLConf.AUTO_BROADCASTJOIN_THRESHOLD]], the planner would mark it as the * ''build'' relation and mark the other relation as the ''stream'' side. The build table will be * ''broadcasted'' to all of the executors involved in the join, as a * [[org.apache.spark.broadcast.Broadcast]] object. If both estimates exceed the threshold, they - * will instead be used to decide the build side in a [[execution.ShuffledHashJoin]]. + * will instead be used to decide the build side in a [[joins.ShuffledHashJoin]]. 
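(Editorial sketch of the threshold described in this scaladoc; the table names are assumptions, and the expected plan is only what the strategy above implies.)

sqlContext.setConf("spark.sql.autoBroadcastJoinThreshold", (10 * 1024 * 1024).toString)
val joined = sqlContext.sql("SELECT * FROM big_table b JOIN small_table s ON b.key = s.key")
println(joined.queryExecution.executedPlan) // expect joins.BroadcastHashJoin when small_table's
                                            // estimated size is below the threshold
// Any value <= 0 (e.g. -1) disables broadcasting and the planner falls back to
// joins.ShuffledHashJoin, choosing the smaller side as the build side.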
*/ object HashJoin extends Strategy with PredicateHelper { @@ -66,8 +66,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { left: LogicalPlan, right: LogicalPlan, condition: Option[Expression], - side: BuildSide) = { - val broadcastHashJoin = execution.BroadcastHashJoin( + side: joins.BuildSide) = { + val broadcastHashJoin = execution.joins.BroadcastHashJoin( leftKeys, rightKeys, side, planLater(left), planLater(right)) condition.map(Filter(_, broadcastHashJoin)).getOrElse(broadcastHashJoin) :: Nil } @@ -76,27 +76,26 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right) if sqlContext.autoBroadcastJoinThreshold > 0 && right.statistics.sizeInBytes <= sqlContext.autoBroadcastJoinThreshold => - makeBroadcastHashJoin(leftKeys, rightKeys, left, right, condition, BuildRight) + makeBroadcastHashJoin(leftKeys, rightKeys, left, right, condition, joins.BuildRight) case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right) if sqlContext.autoBroadcastJoinThreshold > 0 && left.statistics.sizeInBytes <= sqlContext.autoBroadcastJoinThreshold => - makeBroadcastHashJoin(leftKeys, rightKeys, left, right, condition, BuildLeft) + makeBroadcastHashJoin(leftKeys, rightKeys, left, right, condition, joins.BuildLeft) case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right) => val buildSide = if (right.statistics.sizeInBytes <= left.statistics.sizeInBytes) { - BuildRight + joins.BuildRight } else { - BuildLeft + joins.BuildLeft } - val hashJoin = - execution.ShuffledHashJoin( - leftKeys, rightKeys, buildSide, planLater(left), planLater(right)) + val hashJoin = joins.ShuffledHashJoin( + leftKeys, rightKeys, buildSide, planLater(left), planLater(right)) condition.map(Filter(_, hashJoin)).getOrElse(hashJoin) :: Nil case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right) => - execution.HashOuterJoin( + joins.HashOuterJoin( leftKeys, rightKeys, joinType, condition, planLater(left), planLater(right)) :: Nil case _ => Nil @@ -164,8 +163,12 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case logical.Join(left, right, joinType, condition) => val buildSide = - if (right.statistics.sizeInBytes <= left.statistics.sizeInBytes) BuildRight else BuildLeft - execution.BroadcastNestedLoopJoin( + if (right.statistics.sizeInBytes <= left.statistics.sizeInBytes) { + joins.BuildRight + } else { + joins.BuildLeft + } + joins.BroadcastNestedLoopJoin( planLater(left), planLater(right), buildSide, joinType, condition) :: Nil case _ => Nil } @@ -174,10 +177,10 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { object CartesianProduct extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case logical.Join(left, right, _, None) => - execution.CartesianProduct(planLater(left), planLater(right)) :: Nil + execution.joins.CartesianProduct(planLater(left), planLater(right)) :: Nil case logical.Join(left, right, Inner, Some(condition)) => execution.Filter(condition, - execution.CartesianProduct(planLater(left), planLater(right))) :: Nil + execution.joins.CartesianProduct(planLater(left), planLater(right))) :: Nil case _ => Nil } } @@ -274,9 +277,10 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { execution.Sample(fraction, withReplacement, seed, planLater(child)) :: Nil case 
SparkLogicalPlan(alreadyPlanned) => alreadyPlanned :: Nil case logical.LocalRelation(output, data) => + val nPartitions = if (data.isEmpty) 1 else numPartitions PhysicalRDD( output, - RDDConversions.productToRowRdd(sparkContext.parallelize(data, numPartitions))) :: Nil + RDDConversions.productToRowRdd(sparkContext.parallelize(data, nPartitions))) :: Nil case logical.Limit(IntegerLiteral(limit), child) => execution.Limit(limit, planLater(child)) :: Nil case Unions(unionChildren) => @@ -291,7 +295,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { execution.PhysicalRDD(Nil, singleRowRdd) :: Nil case logical.Repartition(expressions, child) => execution.Exchange(HashPartitioning(expressions, numPartitions), planLater(child)) :: Nil - case e @ EvaluatePython(udf, child) => + case e @ EvaluatePython(udf, child, _) => BatchPythonEvaluation(udf, e.output, planLater(child)) :: Nil case LogicalRDD(output, rdd) => PhysicalRDD(output, rdd) :: Nil case _ => Nil @@ -300,8 +304,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case class CommandStrategy(context: SQLContext) extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { - case logical.SetCommand(key, value) => - Seq(execution.SetCommand(key, value, plan.output)(context)) + case logical.SetCommand(kv) => + Seq(execution.SetCommand(kv, plan.output)(context)) case logical.ExplainCommand(logicalPlan, extended) => Seq(execution.ExplainCommand(logicalPlan, plan.output, extended)(context)) case logical.CacheTableCommand(tableName, optPlan, isLazy) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala index d49633c24ad4d..5859eba408ee1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala @@ -48,29 +48,28 @@ trait Command { * :: DeveloperApi :: */ @DeveloperApi -case class SetCommand( - key: Option[String], value: Option[String], output: Seq[Attribute])( +case class SetCommand(kv: Option[(String, Option[String])], output: Seq[Attribute])( @transient context: SQLContext) extends LeafNode with Command with Logging { - override protected lazy val sideEffectResult: Seq[Row] = (key, value) match { - // Set value for key k. - case (Some(k), Some(v)) => - if (k == SQLConf.Deprecated.MAPRED_REDUCE_TASKS) { + override protected lazy val sideEffectResult: Seq[Row] = kv match { + // Set value for the key. + case Some((key, Some(value))) => + if (key == SQLConf.Deprecated.MAPRED_REDUCE_TASKS) { logWarning(s"Property ${SQLConf.Deprecated.MAPRED_REDUCE_TASKS} is deprecated, " + s"automatically converted to ${SQLConf.SHUFFLE_PARTITIONS} instead.") - context.setConf(SQLConf.SHUFFLE_PARTITIONS, v) - Seq(Row(s"${SQLConf.SHUFFLE_PARTITIONS}=$v")) + context.setConf(SQLConf.SHUFFLE_PARTITIONS, value) + Seq(Row(s"${SQLConf.SHUFFLE_PARTITIONS}=$value")) } else { - context.setConf(k, v) - Seq(Row(s"$k=$v")) + context.setConf(key, value) + Seq(Row(s"$key=$value")) } - // Query the value bound to key k. - case (Some(k), _) => + // Query the value bound to the key. + case Some((key, None)) => // TODO (lian) This is just a workaround to make the Simba ODBC driver work. // Should remove this once we get the ODBC driver updated. 
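(Editorial sketch of the three SET forms handled by the refactored SetCommand; a SQLContext named `sqlContext` is assumed.)

sqlContext.sql("SET spark.sql.shuffle.partitions=10")         // Some((key, Some(value))): set a value
sqlContext.sql("SET spark.sql.shuffle.partitions").collect()  // Some((key, None)): query one key
sqlContext.sql("SET").collect()                               // None: list every pair in the SQLConf
// The deprecated mapred.reduce.tasks key is still accepted but is rewritten to
// spark.sql.shuffle.partitions with a warning, as the code above shows.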
- if (k == "-v") { + if (key == "-v") { val hiveJars = Seq( "hive-exec-0.12.0.jar", "hive-service-0.12.0.jar", @@ -84,23 +83,20 @@ case class SetCommand( Row("system:java.class.path=" + hiveJars), Row("system:sun.java.command=shark.SharkServer2")) } else { - if (k == SQLConf.Deprecated.MAPRED_REDUCE_TASKS) { + if (key == SQLConf.Deprecated.MAPRED_REDUCE_TASKS) { logWarning(s"Property ${SQLConf.Deprecated.MAPRED_REDUCE_TASKS} is deprecated, " + s"showing ${SQLConf.SHUFFLE_PARTITIONS} instead.") Seq(Row(s"${SQLConf.SHUFFLE_PARTITIONS}=${context.numShufflePartitions}")) } else { - Seq(Row(s"$k=${context.getConf(k, "")}")) + Seq(Row(s"$key=${context.getConf(key, "")}")) } } // Query all key-value pairs that are set in the SQLConf of the context. - case (None, None) => + case _ => context.getAllConfs.map { case (k, v) => Row(s"$k=$v") }.toSeq - - case _ => - throw new IllegalArgumentException() } override def otherCopyArgs = context :: Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala index a9535a750bcd7..61be5ed2db65c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala @@ -24,6 +24,7 @@ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.SparkContext._ import org.apache.spark.sql.{SchemaRDD, Row} import org.apache.spark.sql.catalyst.trees.TreeNodeRef +import org.apache.spark.sql.catalyst.types._ /** * :: DeveloperApi :: @@ -56,6 +57,23 @@ package object debug { case _ => } } + + def typeCheck(): Unit = { + val plan = query.queryExecution.executedPlan + val visited = new collection.mutable.HashSet[TreeNodeRef]() + val debugPlan = plan transform { + case s: SparkPlan if !visited.contains(new TreeNodeRef(s)) => + visited += new TreeNodeRef(s) + TypeCheck(s) + } + try { + println(s"Results returned: ${debugPlan.execute().count()}") + } catch { + case e: Exception => + def unwrap(e: Throwable): Throwable = if (e.getCause == null) e else unwrap(e.getCause) + println(s"Deepest Error: ${unwrap(e)}") + } + } } private[sql] case class DebugNode(child: SparkPlan) extends UnaryNode { @@ -115,4 +133,71 @@ package object debug { } } } + + /** + * :: DeveloperApi :: + * Helper functions for checking that runtime types match a given schema. + */ + @DeveloperApi + object TypeCheck { + def typeCheck(data: Any, schema: DataType): Unit = (data, schema) match { + case (null, _) => + + case (row: Row, StructType(fields)) => + row.zip(fields.map(_.dataType)).foreach { case(d,t) => typeCheck(d,t) } + case (s: Seq[_], ArrayType(elemType, _)) => + s.foreach(typeCheck(_, elemType)) + case (m: Map[_, _], MapType(keyType, valueType, _)) => + m.keys.foreach(typeCheck(_, keyType)) + m.values.foreach(typeCheck(_, valueType)) + + case (_: Long, LongType) => + case (_: Int, IntegerType) => + case (_: String, StringType) => + case (_: Float, FloatType) => + case (_: Byte, ByteType) => + case (_: Short, ShortType) => + case (_: Boolean, BooleanType) => + case (_: Double, DoubleType) => + + case (d, t) => sys.error(s"Invalid data found: got $d (${d.getClass}) expected $t") + } + } + + /** + * :: DeveloperApi :: + * Augments SchemaRDDs with debug methods. + */ + @DeveloperApi + private[sql] case class TypeCheck(child: SparkPlan) extends SparkPlan { + import TypeCheck._ + + override def nodeName = "" + + /* Only required when defining this class in a REPL. 
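(Editorial sketch of the new typeCheck() helper, assuming the same implicit enrichment that already exposes debug() on SchemaRDD; `results` is any SchemaRDD.)

import org.apache.spark.sql.execution.debug._

// Wraps every physical operator in a TypeCheck node, runs the query, and prints the
// deepest error if a row's runtime types do not match the declared schema.
results.typeCheck()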
+ override def makeCopy(args: Array[Object]): this.type = + TypeCheck(args(0).asInstanceOf[SparkPlan]).asInstanceOf[this.type] + */ + + def output = child.output + + def children = child :: Nil + + def execute() = { + child.execute().map { row => + try typeCheck(row, child.schema) catch { + case e: Exception => + sys.error( + s""" + |ERROR WHEN TYPE CHECKING QUERY + |============================== + |$e + |======== BAD TREE ============ + |$child + """.stripMargin) + } + row + } + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala deleted file mode 100644 index 2890a563bed48..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala +++ /dev/null @@ -1,624 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution - -import java.util.{HashMap => JavaHashMap} - -import scala.concurrent.ExecutionContext.Implicits.global -import scala.concurrent._ -import scala.concurrent.duration._ - -import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans._ -import org.apache.spark.sql.catalyst.plans.physical._ -import org.apache.spark.util.collection.CompactBuffer - -@DeveloperApi -sealed abstract class BuildSide - -@DeveloperApi -case object BuildLeft extends BuildSide - -@DeveloperApi -case object BuildRight extends BuildSide - -trait HashJoin { - self: SparkPlan => - - val leftKeys: Seq[Expression] - val rightKeys: Seq[Expression] - val buildSide: BuildSide - val left: SparkPlan - val right: SparkPlan - - lazy val (buildPlan, streamedPlan) = buildSide match { - case BuildLeft => (left, right) - case BuildRight => (right, left) - } - - lazy val (buildKeys, streamedKeys) = buildSide match { - case BuildLeft => (leftKeys, rightKeys) - case BuildRight => (rightKeys, leftKeys) - } - - def output = left.output ++ right.output - - @transient lazy val buildSideKeyGenerator = newProjection(buildKeys, buildPlan.output) - @transient lazy val streamSideKeyGenerator = - newMutableProjection(streamedKeys, streamedPlan.output) - - def joinIterators(buildIter: Iterator[Row], streamIter: Iterator[Row]): Iterator[Row] = { - // TODO: Use Spark's HashMap implementation. 
- - val hashTable = new java.util.HashMap[Row, CompactBuffer[Row]]() - var currentRow: Row = null - - // Create a mapping of buildKeys -> rows - while (buildIter.hasNext) { - currentRow = buildIter.next() - val rowKey = buildSideKeyGenerator(currentRow) - if (!rowKey.anyNull) { - val existingMatchList = hashTable.get(rowKey) - val matchList = if (existingMatchList == null) { - val newMatchList = new CompactBuffer[Row]() - hashTable.put(rowKey, newMatchList) - newMatchList - } else { - existingMatchList - } - matchList += currentRow.copy() - } - } - - new Iterator[Row] { - private[this] var currentStreamedRow: Row = _ - private[this] var currentHashMatches: CompactBuffer[Row] = _ - private[this] var currentMatchPosition: Int = -1 - - // Mutable per row objects. - private[this] val joinRow = new JoinedRow2 - - private[this] val joinKeys = streamSideKeyGenerator() - - override final def hasNext: Boolean = - (currentMatchPosition != -1 && currentMatchPosition < currentHashMatches.size) || - (streamIter.hasNext && fetchNext()) - - override final def next() = { - val ret = buildSide match { - case BuildRight => joinRow(currentStreamedRow, currentHashMatches(currentMatchPosition)) - case BuildLeft => joinRow(currentHashMatches(currentMatchPosition), currentStreamedRow) - } - currentMatchPosition += 1 - ret - } - - /** - * Searches the streamed iterator for the next row that has at least one match in hashtable. - * - * @return true if the search is successful, and false if the streamed iterator runs out of - * tuples. - */ - private final def fetchNext(): Boolean = { - currentHashMatches = null - currentMatchPosition = -1 - - while (currentHashMatches == null && streamIter.hasNext) { - currentStreamedRow = streamIter.next() - if (!joinKeys(currentStreamedRow).anyNull) { - currentHashMatches = hashTable.get(joinKeys.currentValue) - } - } - - if (currentHashMatches == null) { - false - } else { - currentMatchPosition = 0 - true - } - } - } - } -} - -/** - * :: DeveloperApi :: - * Performs a hash based outer join for two child relations by shuffling the data using - * the join keys. This operator requires loading the associated partition in both side into memory. - */ -@DeveloperApi -case class HashOuterJoin( - leftKeys: Seq[Expression], - rightKeys: Seq[Expression], - joinType: JoinType, - condition: Option[Expression], - left: SparkPlan, - right: SparkPlan) extends BinaryNode { - - override def outputPartitioning: Partitioning = joinType match { - case LeftOuter => left.outputPartitioning - case RightOuter => right.outputPartitioning - case FullOuter => UnknownPartitioning(left.outputPartitioning.numPartitions) - case x => throw new Exception(s"HashOuterJoin should not take $x as the JoinType") - } - - override def requiredChildDistribution = - ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil - - override def output = { - joinType match { - case LeftOuter => - left.output ++ right.output.map(_.withNullability(true)) - case RightOuter => - left.output.map(_.withNullability(true)) ++ right.output - case FullOuter => - left.output.map(_.withNullability(true)) ++ right.output.map(_.withNullability(true)) - case x => - throw new Exception(s"HashOuterJoin should not take $x as the JoinType") - } - } - - @transient private[this] lazy val DUMMY_LIST = Seq[Row](null) - @transient private[this] lazy val EMPTY_LIST = Seq.empty[Row] - - // TODO we need to rewrite all of the iterators with our own implementation instead of the Scala - // iterator for performance purpose. 
- - private[this] def leftOuterIterator( - key: Row, leftIter: Iterable[Row], rightIter: Iterable[Row]): Iterator[Row] = { - val joinedRow = new JoinedRow() - val rightNullRow = new GenericRow(right.output.length) - val boundCondition = - condition.map(newPredicate(_, left.output ++ right.output)).getOrElse((row: Row) => true) - - leftIter.iterator.flatMap { l => - joinedRow.withLeft(l) - var matched = false - (if (!key.anyNull) rightIter.collect { case r if (boundCondition(joinedRow.withRight(r))) => - matched = true - joinedRow.copy - } else { - Nil - }) ++ DUMMY_LIST.filter(_ => !matched).map( _ => { - // DUMMY_LIST.filter(_ => !matched) is a tricky way to add additional row, - // as we don't know whether we need to append it until finish iterating all of the - // records in right side. - // If we didn't get any proper row, then append a single row with empty right - joinedRow.withRight(rightNullRow).copy - }) - } - } - - private[this] def rightOuterIterator( - key: Row, leftIter: Iterable[Row], rightIter: Iterable[Row]): Iterator[Row] = { - val joinedRow = new JoinedRow() - val leftNullRow = new GenericRow(left.output.length) - val boundCondition = - condition.map(newPredicate(_, left.output ++ right.output)).getOrElse((row: Row) => true) - - rightIter.iterator.flatMap { r => - joinedRow.withRight(r) - var matched = false - (if (!key.anyNull) leftIter.collect { case l if (boundCondition(joinedRow.withLeft(l))) => - matched = true - joinedRow.copy - } else { - Nil - }) ++ DUMMY_LIST.filter(_ => !matched).map( _ => { - // DUMMY_LIST.filter(_ => !matched) is a tricky way to add additional row, - // as we don't know whether we need to append it until finish iterating all of the - // records in left side. - // If we didn't get any proper row, then append a single row with empty left. - joinedRow.withLeft(leftNullRow).copy - }) - } - } - - private[this] def fullOuterIterator( - key: Row, leftIter: Iterable[Row], rightIter: Iterable[Row]): Iterator[Row] = { - val joinedRow = new JoinedRow() - val leftNullRow = new GenericRow(left.output.length) - val rightNullRow = new GenericRow(right.output.length) - val boundCondition = - condition.map(newPredicate(_, left.output ++ right.output)).getOrElse((row: Row) => true) - - if (!key.anyNull) { - // Store the positions of records in right, if one of its associated row satisfy - // the join condition. - val rightMatchedSet = scala.collection.mutable.Set[Int]() - leftIter.iterator.flatMap[Row] { l => - joinedRow.withLeft(l) - var matched = false - rightIter.zipWithIndex.collect { - // 1. For those matched (satisfy the join condition) records with both sides filled, - // append them directly - - case (r, idx) if (boundCondition(joinedRow.withRight(r)))=> { - matched = true - // if the row satisfy the join condition, add its index into the matched set - rightMatchedSet.add(idx) - joinedRow.copy - } - } ++ DUMMY_LIST.filter(_ => !matched).map( _ => { - // 2. For those unmatched records in left, append additional records with empty right. - - // DUMMY_LIST.filter(_ => !matched) is a tricky way to add additional row, - // as we don't know whether we need to append it until finish iterating all - // of the records in right side. - // If we didn't get any proper row, then append a single row with empty right. - joinedRow.withRight(rightNullRow).copy - }) - } ++ rightIter.zipWithIndex.collect { - // 3. For those unmatched records in right, append additional records with empty left. 
- - // Re-visiting the records in right, and append additional row with empty left, if its not - // in the matched set. - case (r, idx) if (!rightMatchedSet.contains(idx)) => { - joinedRow(leftNullRow, r).copy - } - } - } else { - leftIter.iterator.map[Row] { l => - joinedRow(l, rightNullRow).copy - } ++ rightIter.iterator.map[Row] { r => - joinedRow(leftNullRow, r).copy - } - } - } - - private[this] def buildHashTable( - iter: Iterator[Row], keyGenerator: Projection): JavaHashMap[Row, CompactBuffer[Row]] = { - val hashTable = new JavaHashMap[Row, CompactBuffer[Row]]() - while (iter.hasNext) { - val currentRow = iter.next() - val rowKey = keyGenerator(currentRow) - - var existingMatchList = hashTable.get(rowKey) - if (existingMatchList == null) { - existingMatchList = new CompactBuffer[Row]() - hashTable.put(rowKey, existingMatchList) - } - - existingMatchList += currentRow.copy() - } - - hashTable - } - - def execute() = { - left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) => - // TODO this probably can be replaced by external sort (sort merged join?) - // Build HashMap for current partition in left relation - val leftHashTable = buildHashTable(leftIter, newProjection(leftKeys, left.output)) - // Build HashMap for current partition in right relation - val rightHashTable = buildHashTable(rightIter, newProjection(rightKeys, right.output)) - - import scala.collection.JavaConversions._ - val boundCondition = - condition.map(newPredicate(_, left.output ++ right.output)).getOrElse((row: Row) => true) - joinType match { - case LeftOuter => leftHashTable.keysIterator.flatMap { key => - leftOuterIterator(key, leftHashTable.getOrElse(key, EMPTY_LIST), - rightHashTable.getOrElse(key, EMPTY_LIST)) - } - case RightOuter => rightHashTable.keysIterator.flatMap { key => - rightOuterIterator(key, leftHashTable.getOrElse(key, EMPTY_LIST), - rightHashTable.getOrElse(key, EMPTY_LIST)) - } - case FullOuter => (leftHashTable.keySet ++ rightHashTable.keySet).iterator.flatMap { key => - fullOuterIterator(key, - leftHashTable.getOrElse(key, EMPTY_LIST), - rightHashTable.getOrElse(key, EMPTY_LIST)) - } - case x => throw new Exception(s"HashOuterJoin should not take $x as the JoinType") - } - } - } -} - -/** - * :: DeveloperApi :: - * Performs an inner hash join of two child relations by first shuffling the data using the join - * keys. - */ -@DeveloperApi -case class ShuffledHashJoin( - leftKeys: Seq[Expression], - rightKeys: Seq[Expression], - buildSide: BuildSide, - left: SparkPlan, - right: SparkPlan) extends BinaryNode with HashJoin { - - override def outputPartitioning: Partitioning = left.outputPartitioning - - override def requiredChildDistribution = - ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil - - def execute() = { - buildPlan.execute().zipPartitions(streamedPlan.execute()) { - (buildIter, streamIter) => joinIterators(buildIter, streamIter) - } - } -} - -/** - * :: DeveloperApi :: - * Build the right table's join keys into a HashSet, and iteratively go through the left - * table, to find the if join keys are in the Hash set. 
- */ -@DeveloperApi -case class LeftSemiJoinHash( - leftKeys: Seq[Expression], - rightKeys: Seq[Expression], - left: SparkPlan, - right: SparkPlan) extends BinaryNode with HashJoin { - - val buildSide = BuildRight - - override def requiredChildDistribution = - ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil - - override def output = left.output - - def execute() = { - buildPlan.execute().zipPartitions(streamedPlan.execute()) { (buildIter, streamIter) => - val hashSet = new java.util.HashSet[Row]() - var currentRow: Row = null - - // Create a Hash set of buildKeys - while (buildIter.hasNext) { - currentRow = buildIter.next() - val rowKey = buildSideKeyGenerator(currentRow) - if (!rowKey.anyNull) { - val keyExists = hashSet.contains(rowKey) - if (!keyExists) { - hashSet.add(rowKey) - } - } - } - - val joinKeys = streamSideKeyGenerator() - streamIter.filter(current => { - !joinKeys(current).anyNull && hashSet.contains(joinKeys.currentValue) - }) - } - } -} - - -/** - * :: DeveloperApi :: - * Performs an inner hash join of two child relations. When the output RDD of this operator is - * being constructed, a Spark job is asynchronously started to calculate the values for the - * broadcasted relation. This data is then placed in a Spark broadcast variable. The streamed - * relation is not shuffled. - */ -@DeveloperApi -case class BroadcastHashJoin( - leftKeys: Seq[Expression], - rightKeys: Seq[Expression], - buildSide: BuildSide, - left: SparkPlan, - right: SparkPlan) extends BinaryNode with HashJoin { - - override def outputPartitioning: Partitioning = streamedPlan.outputPartitioning - - override def requiredChildDistribution = - UnspecifiedDistribution :: UnspecifiedDistribution :: Nil - - @transient - val broadcastFuture = future { - sparkContext.broadcast(buildPlan.executeCollect()) - } - - def execute() = { - val broadcastRelation = Await.result(broadcastFuture, 5.minute) - - streamedPlan.execute().mapPartitions { streamedIter => - joinIterators(broadcastRelation.value.iterator, streamedIter) - } - } -} - -/** - * :: DeveloperApi :: - * Using BroadcastNestedLoopJoin to calculate left semi join result when there's no join keys - * for hash join. - */ -@DeveloperApi -case class LeftSemiJoinBNL( - streamed: SparkPlan, broadcast: SparkPlan, condition: Option[Expression]) - extends BinaryNode { - // TODO: Override requiredChildDistribution. 
- - override def outputPartitioning: Partitioning = streamed.outputPartitioning - - def output = left.output - - /** The Streamed Relation */ - def left = streamed - /** The Broadcast relation */ - def right = broadcast - - @transient lazy val boundCondition = - InterpretedPredicate( - condition - .map(c => BindReferences.bindReference(c, left.output ++ right.output)) - .getOrElse(Literal(true))) - - def execute() = { - val broadcastedRelation = - sparkContext.broadcast(broadcast.execute().map(_.copy()).collect().toIndexedSeq) - - streamed.execute().mapPartitions { streamedIter => - val joinedRow = new JoinedRow - - streamedIter.filter(streamedRow => { - var i = 0 - var matched = false - - while (i < broadcastedRelation.value.size && !matched) { - val broadcastedRow = broadcastedRelation.value(i) - if (boundCondition(joinedRow(streamedRow, broadcastedRow))) { - matched = true - } - i += 1 - } - matched - }) - } - } -} - -/** - * :: DeveloperApi :: - */ -@DeveloperApi -case class CartesianProduct(left: SparkPlan, right: SparkPlan) extends BinaryNode { - def output = left.output ++ right.output - - def execute() = { - val leftResults = left.execute().map(_.copy()) - val rightResults = right.execute().map(_.copy()) - - leftResults.cartesian(rightResults).mapPartitions { iter => - val joinedRow = new JoinedRow - iter.map(r => joinedRow(r._1, r._2)) - } - } -} - -/** - * :: DeveloperApi :: - */ -@DeveloperApi -case class BroadcastNestedLoopJoin( - left: SparkPlan, - right: SparkPlan, - buildSide: BuildSide, - joinType: JoinType, - condition: Option[Expression]) extends BinaryNode { - // TODO: Override requiredChildDistribution. - - /** BuildRight means the right relation <=> the broadcast relation. */ - val (streamed, broadcast) = buildSide match { - case BuildRight => (left, right) - case BuildLeft => (right, left) - } - - override def outputPartitioning: Partitioning = streamed.outputPartitioning - - override def output = { - joinType match { - case LeftOuter => - left.output ++ right.output.map(_.withNullability(true)) - case RightOuter => - left.output.map(_.withNullability(true)) ++ right.output - case FullOuter => - left.output.map(_.withNullability(true)) ++ right.output.map(_.withNullability(true)) - case _ => - left.output ++ right.output - } - } - - @transient lazy val boundCondition = - InterpretedPredicate( - condition - .map(c => BindReferences.bindReference(c, left.output ++ right.output)) - .getOrElse(Literal(true))) - - def execute() = { - val broadcastedRelation = - sparkContext.broadcast(broadcast.execute().map(_.copy()).collect().toIndexedSeq) - - /** All rows that either match both-way, or rows from streamed joined with nulls. */ - val matchesOrStreamedRowsWithNulls = streamed.execute().mapPartitions { streamedIter => - val matchedRows = new CompactBuffer[Row] - // TODO: Use Spark's BitSet. - val includedBroadcastTuples = - new scala.collection.mutable.BitSet(broadcastedRelation.value.size) - val joinedRow = new JoinedRow - val leftNulls = new GenericMutableRow(left.output.size) - val rightNulls = new GenericMutableRow(right.output.size) - - streamedIter.foreach { streamedRow => - var i = 0 - var streamRowMatched = false - - while (i < broadcastedRelation.value.size) { - // TODO: One bitset per partition instead of per row. 
- val broadcastedRow = broadcastedRelation.value(i) - buildSide match { - case BuildRight if boundCondition(joinedRow(streamedRow, broadcastedRow)) => - matchedRows += joinedRow(streamedRow, broadcastedRow).copy() - streamRowMatched = true - includedBroadcastTuples += i - case BuildLeft if boundCondition(joinedRow(broadcastedRow, streamedRow)) => - matchedRows += joinedRow(broadcastedRow, streamedRow).copy() - streamRowMatched = true - includedBroadcastTuples += i - case _ => - } - i += 1 - } - - (streamRowMatched, joinType, buildSide) match { - case (false, LeftOuter | FullOuter, BuildRight) => - matchedRows += joinedRow(streamedRow, rightNulls).copy() - case (false, RightOuter | FullOuter, BuildLeft) => - matchedRows += joinedRow(leftNulls, streamedRow).copy() - case _ => - } - } - Iterator((matchedRows, includedBroadcastTuples)) - } - - val includedBroadcastTuples = matchesOrStreamedRowsWithNulls.map(_._2) - val allIncludedBroadcastTuples = - if (includedBroadcastTuples.count == 0) { - new scala.collection.mutable.BitSet(broadcastedRelation.value.size) - } else { - includedBroadcastTuples.reduce(_ ++ _) - } - - val leftNulls = new GenericMutableRow(left.output.size) - val rightNulls = new GenericMutableRow(right.output.size) - /** Rows from broadcasted joined with nulls. */ - val broadcastRowsWithNulls: Seq[Row] = { - val buf: CompactBuffer[Row] = new CompactBuffer() - var i = 0 - val rel = broadcastedRelation.value - while (i < rel.length) { - if (!allIncludedBroadcastTuples.contains(i)) { - (joinType, buildSide) match { - case (RightOuter | FullOuter, BuildRight) => buf += new JoinedRow(leftNulls, rel(i)) - case (LeftOuter | FullOuter, BuildLeft) => buf += new JoinedRow(rel(i), rightNulls) - case _ => - } - } - i += 1 - } - buf.toSeq - } - - // TODO: Breaks lineage. - sparkContext.union( - matchesOrStreamedRowsWithNulls.flatMap(_._1), sparkContext.makeRDD(broadcastRowsWithNulls)) - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala new file mode 100644 index 0000000000000..8fd35880eedfe --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.joins + +import scala.concurrent._ +import scala.concurrent.duration._ +import scala.concurrent.ExecutionContext.Implicits.global + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.sql.catalyst.expressions.{Row, Expression} +import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, UnspecifiedDistribution} +import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} + +/** + * :: DeveloperApi :: + * Performs an inner hash join of two child relations. When the output RDD of this operator is + * being constructed, a Spark job is asynchronously started to calculate the values for the + * broadcasted relation. This data is then placed in a Spark broadcast variable. The streamed + * relation is not shuffled. + */ +@DeveloperApi +case class BroadcastHashJoin( + leftKeys: Seq[Expression], + rightKeys: Seq[Expression], + buildSide: BuildSide, + left: SparkPlan, + right: SparkPlan) + extends BinaryNode with HashJoin { + + override def outputPartitioning: Partitioning = streamedPlan.outputPartitioning + + override def requiredChildDistribution = + UnspecifiedDistribution :: UnspecifiedDistribution :: Nil + + @transient + private val broadcastFuture = future { + val input: Array[Row] = buildPlan.executeCollect() + val hashed = HashedRelation(input.iterator, buildSideKeyGenerator, input.length) + sparkContext.broadcast(hashed) + } + + override def execute() = { + val broadcastRelation = Await.result(broadcastFuture, 5.minute) + + streamedPlan.execute().mapPartitions { streamedIter => + hashJoin(streamedIter, broadcastRelation.value) + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala new file mode 100644 index 0000000000000..36aad13778bd2 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala @@ -0,0 +1,144 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.joins + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.physical.Partitioning +import org.apache.spark.sql.catalyst.plans.{FullOuter, JoinType, LeftOuter, RightOuter} +import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} +import org.apache.spark.util.collection.CompactBuffer + +/** + * :: DeveloperApi :: + */ +@DeveloperApi +case class BroadcastNestedLoopJoin( + left: SparkPlan, + right: SparkPlan, + buildSide: BuildSide, + joinType: JoinType, + condition: Option[Expression]) extends BinaryNode { + // TODO: Override requiredChildDistribution. 
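The BroadcastHashJoin operator above starts collecting and hashing its build side inside a future as soon as the operator is constructed, and only blocks on that result (with a five-minute timeout) when execute() runs. A minimal sketch of that pattern using just scala.concurrent; the object name, the sleep, and the Map standing in for the hashed relation are invented for illustration:

    import scala.concurrent._
    import scala.concurrent.duration._
    import scala.concurrent.ExecutionContext.Implicits.global

    object AsyncBuildSideSketch {
      def main(args: Array[String]): Unit = {
        // Start the expensive build-side work eagerly, off the calling thread.
        val broadcastFuture = Future {
          Thread.sleep(50)                      // stand-in for executeCollect() + HashedRelation(...)
          Map("k1" -> "v1", "k2" -> "v2")       // stand-in for the broadcast hashed relation
        }
        // Other work can proceed here; block only when the value is actually needed.
        val hashed = Await.result(broadcastFuture, 5.minutes)
        println(hashed("k1"))
      }
    }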
+ + /** BuildRight means the right relation <=> the broadcast relation. */ + private val (streamed, broadcast) = buildSide match { + case BuildRight => (left, right) + case BuildLeft => (right, left) + } + + override def outputPartitioning: Partitioning = streamed.outputPartitioning + + override def output = { + joinType match { + case LeftOuter => + left.output ++ right.output.map(_.withNullability(true)) + case RightOuter => + left.output.map(_.withNullability(true)) ++ right.output + case FullOuter => + left.output.map(_.withNullability(true)) ++ right.output.map(_.withNullability(true)) + case _ => + left.output ++ right.output + } + } + + @transient private lazy val boundCondition = + InterpretedPredicate( + condition + .map(c => BindReferences.bindReference(c, left.output ++ right.output)) + .getOrElse(Literal(true))) + + override def execute() = { + val broadcastedRelation = + sparkContext.broadcast(broadcast.execute().map(_.copy()).collect().toIndexedSeq) + + /** All rows that either match both-way, or rows from streamed joined with nulls. */ + val matchesOrStreamedRowsWithNulls = streamed.execute().mapPartitions { streamedIter => + val matchedRows = new CompactBuffer[Row] + // TODO: Use Spark's BitSet. + val includedBroadcastTuples = + new scala.collection.mutable.BitSet(broadcastedRelation.value.size) + val joinedRow = new JoinedRow + val leftNulls = new GenericMutableRow(left.output.size) + val rightNulls = new GenericMutableRow(right.output.size) + + streamedIter.foreach { streamedRow => + var i = 0 + var streamRowMatched = false + + while (i < broadcastedRelation.value.size) { + // TODO: One bitset per partition instead of per row. + val broadcastedRow = broadcastedRelation.value(i) + buildSide match { + case BuildRight if boundCondition(joinedRow(streamedRow, broadcastedRow)) => + matchedRows += joinedRow(streamedRow, broadcastedRow).copy() + streamRowMatched = true + includedBroadcastTuples += i + case BuildLeft if boundCondition(joinedRow(broadcastedRow, streamedRow)) => + matchedRows += joinedRow(broadcastedRow, streamedRow).copy() + streamRowMatched = true + includedBroadcastTuples += i + case _ => + } + i += 1 + } + + (streamRowMatched, joinType, buildSide) match { + case (false, LeftOuter | FullOuter, BuildRight) => + matchedRows += joinedRow(streamedRow, rightNulls).copy() + case (false, RightOuter | FullOuter, BuildLeft) => + matchedRows += joinedRow(leftNulls, streamedRow).copy() + case _ => + } + } + Iterator((matchedRows, includedBroadcastTuples)) + } + + val includedBroadcastTuples = matchesOrStreamedRowsWithNulls.map(_._2) + val allIncludedBroadcastTuples = + if (includedBroadcastTuples.count == 0) { + new scala.collection.mutable.BitSet(broadcastedRelation.value.size) + } else { + includedBroadcastTuples.reduce(_ ++ _) + } + + val leftNulls = new GenericMutableRow(left.output.size) + val rightNulls = new GenericMutableRow(right.output.size) + /** Rows from broadcasted joined with nulls. */ + val broadcastRowsWithNulls: Seq[Row] = { + val buf: CompactBuffer[Row] = new CompactBuffer() + var i = 0 + val rel = broadcastedRelation.value + while (i < rel.length) { + if (!allIncludedBroadcastTuples.contains(i)) { + (joinType, buildSide) match { + case (RightOuter | FullOuter, BuildRight) => buf += new JoinedRow(leftNulls, rel(i)) + case (LeftOuter | FullOuter, BuildLeft) => buf += new JoinedRow(rel(i), rightNulls) + case _ => + } + } + i += 1 + } + buf.toSeq + } + + // TODO: Breaks lineage. 
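In the execute() above, each streamed partition produces a bit set of the broadcast-row indices it matched, and reducing those bit sets with `_ ++ _` yields every broadcast row that matched somewhere; the remaining rows are the ones that get null-padded for the outer side. A self-contained sketch of that bookkeeping with plain Scala collections (the sizes and indices are made up):

    import scala.collection.mutable

    object UnmatchedBroadcastSketch {
      def main(args: Array[String]): Unit = {
        val broadcastSize = 5
        // One bit set per streamed partition: indices of broadcast rows matched there.
        val perPartition = Seq(mutable.BitSet(0, 2), mutable.BitSet(2, 3))
        val matchedAnywhere = perPartition.reduce(_ ++ _)
        // Broadcast rows matched by no partition are joined with nulls in outer joins.
        val needNullPadding = (0 until broadcastSize).filterNot(matchedAnywhere.contains)
        println(needNullPadding)   // Vector(1, 4)
      }
    }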
+ sparkContext.union( + matchesOrStreamedRowsWithNulls.flatMap(_._1), sparkContext.makeRDD(broadcastRowsWithNulls)) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala new file mode 100644 index 0000000000000..76c14c02aab34 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.joins + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.sql.catalyst.expressions.JoinedRow +import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} + +/** + * :: DeveloperApi :: + */ +@DeveloperApi +case class CartesianProduct(left: SparkPlan, right: SparkPlan) extends BinaryNode { + override def output = left.output ++ right.output + + override def execute() = { + val leftResults = left.execute().map(_.copy()) + val rightResults = right.execute().map(_.copy()) + + leftResults.cartesian(rightResults).mapPartitions { iter => + val joinedRow = new JoinedRow + iter.map(r => joinedRow(r._1, r._2)) + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala new file mode 100644 index 0000000000000..4012d757d5f9a --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala @@ -0,0 +1,103 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.joins + +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.util.collection.CompactBuffer + + +trait HashJoin { + self: SparkPlan => + + val leftKeys: Seq[Expression] + val rightKeys: Seq[Expression] + val buildSide: BuildSide + val left: SparkPlan + val right: SparkPlan + + protected lazy val (buildPlan, streamedPlan) = buildSide match { + case BuildLeft => (left, right) + case BuildRight => (right, left) + } + + protected lazy val (buildKeys, streamedKeys) = buildSide match { + case BuildLeft => (leftKeys, rightKeys) + case BuildRight => (rightKeys, leftKeys) + } + + override def output = left.output ++ right.output + + @transient protected lazy val buildSideKeyGenerator: Projection = + newProjection(buildKeys, buildPlan.output) + + @transient protected lazy val streamSideKeyGenerator: () => MutableProjection = + newMutableProjection(streamedKeys, streamedPlan.output) + + protected def hashJoin(streamIter: Iterator[Row], hashedRelation: HashedRelation): Iterator[Row] = + { + new Iterator[Row] { + private[this] var currentStreamedRow: Row = _ + private[this] var currentHashMatches: CompactBuffer[Row] = _ + private[this] var currentMatchPosition: Int = -1 + + // Mutable per row objects. + private[this] val joinRow = new JoinedRow2 + + private[this] val joinKeys = streamSideKeyGenerator() + + override final def hasNext: Boolean = + (currentMatchPosition != -1 && currentMatchPosition < currentHashMatches.size) || + (streamIter.hasNext && fetchNext()) + + override final def next() = { + val ret = buildSide match { + case BuildRight => joinRow(currentStreamedRow, currentHashMatches(currentMatchPosition)) + case BuildLeft => joinRow(currentHashMatches(currentMatchPosition), currentStreamedRow) + } + currentMatchPosition += 1 + ret + } + + /** + * Searches the streamed iterator for the next row that has at least one match in hashtable. + * + * @return true if the search is successful, and false if the streamed iterator runs out of + * tuples. + */ + private final def fetchNext(): Boolean = { + currentHashMatches = null + currentMatchPosition = -1 + + while (currentHashMatches == null && streamIter.hasNext) { + currentStreamedRow = streamIter.next() + if (!joinKeys(currentStreamedRow).anyNull) { + currentHashMatches = hashedRelation.get(joinKeys.currentValue) + } + } + + if (currentHashMatches == null) { + false + } else { + currentMatchPosition = 0 + true + } + } + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala new file mode 100644 index 0000000000000..b73041d306b36 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala @@ -0,0 +1,222 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.joins + +import java.util.{HashMap => JavaHashMap} + +import scala.collection.JavaConversions._ + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Partitioning, UnknownPartitioning} +import org.apache.spark.sql.catalyst.plans.{FullOuter, JoinType, LeftOuter, RightOuter} +import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} +import org.apache.spark.util.collection.CompactBuffer + +/** + * :: DeveloperApi :: + * Performs a hash based outer join for two child relations by shuffling the data using + * the join keys. This operator requires loading the associated partition in both side into memory. + */ +@DeveloperApi +case class HashOuterJoin( + leftKeys: Seq[Expression], + rightKeys: Seq[Expression], + joinType: JoinType, + condition: Option[Expression], + left: SparkPlan, + right: SparkPlan) extends BinaryNode { + + override def outputPartitioning: Partitioning = joinType match { + case LeftOuter => left.outputPartitioning + case RightOuter => right.outputPartitioning + case FullOuter => UnknownPartitioning(left.outputPartitioning.numPartitions) + case x => throw new Exception(s"HashOuterJoin should not take $x as the JoinType") + } + + override def requiredChildDistribution = + ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil + + override def output = { + joinType match { + case LeftOuter => + left.output ++ right.output.map(_.withNullability(true)) + case RightOuter => + left.output.map(_.withNullability(true)) ++ right.output + case FullOuter => + left.output.map(_.withNullability(true)) ++ right.output.map(_.withNullability(true)) + case x => + throw new Exception(s"HashOuterJoin should not take $x as the JoinType") + } + } + + @transient private[this] lazy val DUMMY_LIST = Seq[Row](null) + @transient private[this] lazy val EMPTY_LIST = Seq.empty[Row] + + // TODO we need to rewrite all of the iterators with our own implementation instead of the Scala + // iterator for performance purpose. + + private[this] def leftOuterIterator( + key: Row, leftIter: Iterable[Row], rightIter: Iterable[Row]): Iterator[Row] = { + val joinedRow = new JoinedRow() + val rightNullRow = new GenericRow(right.output.length) + val boundCondition = + condition.map(newPredicate(_, left.output ++ right.output)).getOrElse((row: Row) => true) + + leftIter.iterator.flatMap { l => + joinedRow.withLeft(l) + var matched = false + (if (!key.anyNull) rightIter.collect { case r if (boundCondition(joinedRow.withRight(r))) => + matched = true + joinedRow.copy + } else { + Nil + }) ++ DUMMY_LIST.filter(_ => !matched).map( _ => { + // DUMMY_LIST.filter(_ => !matched) is a tricky way to add additional row, + // as we don't know whether we need to append it until finish iterating all of the + // records in right side. 
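The `DUMMY_LIST.filter(_ => !matched)` idiom used in these outer-join iterators appends exactly one null-padded row when, and only when, the preceding collect produced no match. A stripped-down, single-row illustration (the helper name and types are invented; Option stands in for the null-padded side):

    object DummyListSketch {
      private val DUMMY_LIST = Seq(())   // a single placeholder element, like DUMMY_LIST above

      // Join one left value against candidate right values; pad with None if nothing matched.
      def leftOuterOne(left: Int, rights: Seq[Int]): Seq[(Int, Option[Int])] = {
        var matched = false
        val hits: Seq[(Int, Option[Int])] = rights.collect { case r if r == left =>
          matched = true
          (left, Some(r))
        }
        // Emits one padded row only if the collect above never set the flag.
        hits ++ DUMMY_LIST.filter(_ => !matched).map(_ => (left, None))
      }

      def main(args: Array[String]): Unit = {
        println(leftOuterOne(1, Seq(1, 1, 2)))  // List((1,Some(1)), (1,Some(1)))
        println(leftOuterOne(3, Seq(1, 2)))     // List((3,None))
      }
    }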
+ // If we didn't get any proper row, then append a single row with empty right + joinedRow.withRight(rightNullRow).copy + }) + } + } + + private[this] def rightOuterIterator( + key: Row, leftIter: Iterable[Row], rightIter: Iterable[Row]): Iterator[Row] = { + val joinedRow = new JoinedRow() + val leftNullRow = new GenericRow(left.output.length) + val boundCondition = + condition.map(newPredicate(_, left.output ++ right.output)).getOrElse((row: Row) => true) + + rightIter.iterator.flatMap { r => + joinedRow.withRight(r) + var matched = false + (if (!key.anyNull) leftIter.collect { case l if (boundCondition(joinedRow.withLeft(l))) => + matched = true + joinedRow.copy + } else { + Nil + }) ++ DUMMY_LIST.filter(_ => !matched).map( _ => { + // DUMMY_LIST.filter(_ => !matched) is a tricky way to add additional row, + // as we don't know whether we need to append it until finish iterating all of the + // records in left side. + // If we didn't get any proper row, then append a single row with empty left. + joinedRow.withLeft(leftNullRow).copy + }) + } + } + + private[this] def fullOuterIterator( + key: Row, leftIter: Iterable[Row], rightIter: Iterable[Row]): Iterator[Row] = { + val joinedRow = new JoinedRow() + val leftNullRow = new GenericRow(left.output.length) + val rightNullRow = new GenericRow(right.output.length) + val boundCondition = + condition.map(newPredicate(_, left.output ++ right.output)).getOrElse((row: Row) => true) + + if (!key.anyNull) { + // Store the positions of records in right, if one of its associated row satisfy + // the join condition. + val rightMatchedSet = scala.collection.mutable.Set[Int]() + leftIter.iterator.flatMap[Row] { l => + joinedRow.withLeft(l) + var matched = false + rightIter.zipWithIndex.collect { + // 1. For those matched (satisfy the join condition) records with both sides filled, + // append them directly + + case (r, idx) if (boundCondition(joinedRow.withRight(r)))=> { + matched = true + // if the row satisfy the join condition, add its index into the matched set + rightMatchedSet.add(idx) + joinedRow.copy + } + } ++ DUMMY_LIST.filter(_ => !matched).map( _ => { + // 2. For those unmatched records in left, append additional records with empty right. + + // DUMMY_LIST.filter(_ => !matched) is a tricky way to add additional row, + // as we don't know whether we need to append it until finish iterating all + // of the records in right side. + // If we didn't get any proper row, then append a single row with empty right. + joinedRow.withRight(rightNullRow).copy + }) + } ++ rightIter.zipWithIndex.collect { + // 3. For those unmatched records in right, append additional records with empty left. + + // Re-visiting the records in right, and append additional row with empty left, if its not + // in the matched set. 
+ case (r, idx) if (!rightMatchedSet.contains(idx)) => { + joinedRow(leftNullRow, r).copy + } + } + } else { + leftIter.iterator.map[Row] { l => + joinedRow(l, rightNullRow).copy + } ++ rightIter.iterator.map[Row] { r => + joinedRow(leftNullRow, r).copy + } + } + } + + private[this] def buildHashTable( + iter: Iterator[Row], keyGenerator: Projection): JavaHashMap[Row, CompactBuffer[Row]] = { + val hashTable = new JavaHashMap[Row, CompactBuffer[Row]]() + while (iter.hasNext) { + val currentRow = iter.next() + val rowKey = keyGenerator(currentRow) + + var existingMatchList = hashTable.get(rowKey) + if (existingMatchList == null) { + existingMatchList = new CompactBuffer[Row]() + hashTable.put(rowKey, existingMatchList) + } + + existingMatchList += currentRow.copy() + } + + hashTable + } + + override def execute() = { + left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) => + // TODO this probably can be replaced by external sort (sort merged join?) + // Build HashMap for current partition in left relation + val leftHashTable = buildHashTable(leftIter, newProjection(leftKeys, left.output)) + // Build HashMap for current partition in right relation + val rightHashTable = buildHashTable(rightIter, newProjection(rightKeys, right.output)) + val boundCondition = + condition.map(newPredicate(_, left.output ++ right.output)).getOrElse((row: Row) => true) + joinType match { + case LeftOuter => leftHashTable.keysIterator.flatMap { key => + leftOuterIterator(key, leftHashTable.getOrElse(key, EMPTY_LIST), + rightHashTable.getOrElse(key, EMPTY_LIST)) + } + case RightOuter => rightHashTable.keysIterator.flatMap { key => + rightOuterIterator(key, leftHashTable.getOrElse(key, EMPTY_LIST), + rightHashTable.getOrElse(key, EMPTY_LIST)) + } + case FullOuter => (leftHashTable.keySet ++ rightHashTable.keySet).iterator.flatMap { key => + fullOuterIterator(key, + leftHashTable.getOrElse(key, EMPTY_LIST), + rightHashTable.getOrElse(key, EMPTY_LIST)) + } + case x => throw new Exception(s"HashOuterJoin should not take $x as the JoinType") + } + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala new file mode 100644 index 0000000000000..38b8993b03f82 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala @@ -0,0 +1,109 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.joins + +import java.util.{HashMap => JavaHashMap} + +import org.apache.spark.sql.catalyst.expressions.{Projection, Row} +import org.apache.spark.util.collection.CompactBuffer + + +/** + * Interface for a hashed relation by some key. 
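The buildHashTable helper above groups a partition's rows into a java.util.HashMap keyed by the projected join key, appending each row to a per-key buffer. Roughly the same shape, with ArrayBuffer standing in for CompactBuffer and plain tuples standing in for rows (names and sample data are illustrative only):

    import java.util.{HashMap => JavaHashMap}
    import scala.collection.mutable.ArrayBuffer

    object BuildHashTableSketch {
      def buildHashTable[K, V](rows: Iterator[(K, V)]): JavaHashMap[K, ArrayBuffer[V]] = {
        val hashTable = new JavaHashMap[K, ArrayBuffer[V]]()
        while (rows.hasNext) {
          val (key, value) = rows.next()
          var matchList = hashTable.get(key)
          if (matchList == null) {
            // First time we see this key: create its buffer.
            matchList = new ArrayBuffer[V]()
            hashTable.put(key, matchList)
          }
          matchList += value
        }
        hashTable
      }

      def main(args: Array[String]): Unit = {
        val table = buildHashTable(Iterator(1 -> "a", 2 -> "b", 1 -> "c"))
        println(table)   // roughly {1=ArrayBuffer(a, c), 2=ArrayBuffer(b)}
      }
    }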
Use [[HashedRelation.apply]] to create a concrete + * object. + */ +private[joins] sealed trait HashedRelation { + def get(key: Row): CompactBuffer[Row] +} + + +/** + * A general [[HashedRelation]] backed by a hash map that maps the key into a sequence of values. + */ +private[joins] final class GeneralHashedRelation(hashTable: JavaHashMap[Row, CompactBuffer[Row]]) + extends HashedRelation with Serializable { + + override def get(key: Row) = hashTable.get(key) +} + + +/** + * A specialized [[HashedRelation]] that maps key into a single value. This implementation + * assumes the key is unique. + */ +private[joins] final class UniqueKeyHashedRelation(hashTable: JavaHashMap[Row, Row]) + extends HashedRelation with Serializable { + + override def get(key: Row) = { + val v = hashTable.get(key) + if (v eq null) null else CompactBuffer(v) + } + + def getValue(key: Row): Row = hashTable.get(key) +} + + +// TODO(rxin): a version of [[HashedRelation]] backed by arrays for consecutive integer keys. + + +private[joins] object HashedRelation { + + def apply( + input: Iterator[Row], + keyGenerator: Projection, + sizeEstimate: Int = 64): HashedRelation = { + + // TODO: Use Spark's HashMap implementation. + val hashTable = new JavaHashMap[Row, CompactBuffer[Row]](sizeEstimate) + var currentRow: Row = null + + // Whether the join key is unique. If the key is unique, we can convert the underlying + // hash map into one specialized for this. + var keyIsUnique = true + + // Create a mapping of buildKeys -> rows + while (input.hasNext) { + currentRow = input.next() + val rowKey = keyGenerator(currentRow) + if (!rowKey.anyNull) { + val existingMatchList = hashTable.get(rowKey) + val matchList = if (existingMatchList == null) { + val newMatchList = new CompactBuffer[Row]() + hashTable.put(rowKey, newMatchList) + newMatchList + } else { + keyIsUnique = false + existingMatchList + } + matchList += currentRow.copy() + } + } + + if (keyIsUnique) { + val uniqHashTable = new JavaHashMap[Row, Row](hashTable.size) + val iter = hashTable.entrySet().iterator() + while (iter.hasNext) { + val entry = iter.next() + uniqHashTable.put(entry.getKey, entry.getValue()(0)) + } + new UniqueKeyHashedRelation(uniqHashTable) + } else { + new GeneralHashedRelation(hashTable) + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala new file mode 100644 index 0000000000000..60003d1900d85 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package org.apache.spark.sql.execution.joins
+
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.physical.Partitioning
+import org.apache.spark.sql.execution.{BinaryNode, SparkPlan}
+
+/**
+ * :: DeveloperApi ::
+ * Uses BroadcastNestedLoopJoin to compute the left semi join result when there are no join
+ * keys for a hash join.
+ */
+@DeveloperApi
+case class LeftSemiJoinBNL(
+    streamed: SparkPlan, broadcast: SparkPlan, condition: Option[Expression])
+  extends BinaryNode {
+  // TODO: Override requiredChildDistribution.
+
+  override def outputPartitioning: Partitioning = streamed.outputPartitioning
+
+  override def output = left.output
+
+  /** The Streamed Relation */
+  override def left = streamed
+  /** The Broadcast relation */
+  override def right = broadcast
+
+  @transient private lazy val boundCondition =
+    InterpretedPredicate(
+      condition
+        .map(c => BindReferences.bindReference(c, left.output ++ right.output))
+        .getOrElse(Literal(true)))
+
+  override def execute() = {
+    val broadcastedRelation =
+      sparkContext.broadcast(broadcast.execute().map(_.copy()).collect().toIndexedSeq)
+
+    streamed.execute().mapPartitions { streamedIter =>
+      val joinedRow = new JoinedRow
+
+      streamedIter.filter(streamedRow => {
+        var i = 0
+        var matched = false
+
+        while (i < broadcastedRelation.value.size && !matched) {
+          val broadcastedRow = broadcastedRelation.value(i)
+          if (boundCondition(joinedRow(streamedRow, broadcastedRow))) {
+            matched = true
+          }
+          i += 1
+        }
+        matched
+      })
+    }
+  }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala
new file mode 100644
index 0000000000000..ea7babf3be948
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala
@@ -0,0 +1,67 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.joins
+
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.sql.catalyst.expressions.{Expression, Row}
+import org.apache.spark.sql.catalyst.plans.physical.ClusteredDistribution
+import org.apache.spark.sql.execution.{BinaryNode, SparkPlan}
+
+/**
+ * :: DeveloperApi ::
+ * Builds the right table's join keys into a HashSet, then iterates over the left table
+ * to check whether each row's join keys are in the hash set.
+ */ +@DeveloperApi +case class LeftSemiJoinHash( + leftKeys: Seq[Expression], + rightKeys: Seq[Expression], + left: SparkPlan, + right: SparkPlan) extends BinaryNode with HashJoin { + + override val buildSide = BuildRight + + override def requiredChildDistribution = + ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil + + override def output = left.output + + override def execute() = { + buildPlan.execute().zipPartitions(streamedPlan.execute()) { (buildIter, streamIter) => + val hashSet = new java.util.HashSet[Row]() + var currentRow: Row = null + + // Create a Hash set of buildKeys + while (buildIter.hasNext) { + currentRow = buildIter.next() + val rowKey = buildSideKeyGenerator(currentRow) + if (!rowKey.anyNull) { + val keyExists = hashSet.contains(rowKey) + if (!keyExists) { + hashSet.add(rowKey) + } + } + } + + val joinKeys = streamSideKeyGenerator() + streamIter.filter(current => { + !joinKeys(current).anyNull && hashSet.contains(joinKeys.currentValue) + }) + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala new file mode 100644 index 0000000000000..418c1c23e5546 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.joins + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Partitioning} +import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} + +/** + * :: DeveloperApi :: + * Performs an inner hash join of two child relations by first shuffling the data using the join + * keys. 
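Per partition, LeftSemiJoinHash above only needs the distinct, non-null build-side keys, so it fills a java.util.HashSet and then filters the streamed rows against it. Reduced to ordinary collections (integer keys and a single partition, purely for illustration):

    object SemiJoinSketch {
      def main(args: Array[String]): Unit = {
        val rightKeys = Seq(1, 2, 2, 5)                    // build side: only the keys are kept
        val left      = Seq((1, "a"), (3, "b"), (5, "c"))  // streamed side: (key, row)

        val keySet = new java.util.HashSet[Int]()
        rightKeys.foreach(k => keySet.add(k))              // duplicates collapse; values are irrelevant

        // A streamed row survives if its key appears at least once on the build side.
        val result = left.filter { case (k, _) => keySet.contains(k) }
        println(result)                                    // List((1,a), (5,c))
      }
    }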
+ */ +@DeveloperApi +case class ShuffledHashJoin( + leftKeys: Seq[Expression], + rightKeys: Seq[Expression], + buildSide: BuildSide, + left: SparkPlan, + right: SparkPlan) + extends BinaryNode with HashJoin { + + override def outputPartitioning: Partitioning = left.outputPartitioning + + override def requiredChildDistribution = + ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil + + override def execute() = { + buildPlan.execute().zipPartitions(streamedPlan.execute()) { (buildIter, streamIter) => + val hashed = HashedRelation(buildIter, buildSideKeyGenerator) + hashJoin(streamIter, hashed) + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/package.scala new file mode 100644 index 0000000000000..7f2ab1765b28f --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/package.scala @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import org.apache.spark.annotation.DeveloperApi + +/** + * :: DeveloperApi :: + * Physical execution operators for join operations. + */ +package object joins { + + @DeveloperApi + sealed abstract class BuildSide + + @DeveloperApi + case object BuildRight extends BuildSide + + @DeveloperApi + case object BuildLeft extends BuildSide + +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala index 0977da3e8577c..be729e5d244b0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala @@ -105,13 +105,21 @@ private[spark] object ExtractPythonUdfs extends Rule[LogicalPlan] { } } +object EvaluatePython { + def apply(udf: PythonUDF, child: LogicalPlan) = + new EvaluatePython(udf, child, AttributeReference("pythonUDF", udf.dataType)()) +} + /** * :: DeveloperApi :: * Evaluates a [[PythonUDF]], appending the result to the end of the input tuple. 
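The BuildSide markers defined in the joins package object above are pattern-matched by the hash-join operators to decide which child becomes the build (hashed) plan and which is streamed. A toy version of that selection, with strings standing in for physical plans:

    object BuildSideSketch {
      sealed abstract class Side
      case object BuildLeft extends Side
      case object BuildRight extends Side

      def main(args: Array[String]): Unit = {
        val (left, right) = ("leftPlan", "rightPlan")
        // Mirrors HashJoin: (buildPlan, streamedPlan) depends on the chosen side.
        def split(side: Side): (String, String) = side match {
          case BuildLeft  => (left, right)
          case BuildRight => (right, left)
        }
        println(split(BuildRight))   // (rightPlan,leftPlan): hash the right child, stream the left
      }
    }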
*/ @DeveloperApi -case class EvaluatePython(udf: PythonUDF, child: LogicalPlan) extends logical.UnaryNode { - val resultAttribute = AttributeReference("pythonUDF", udf.dataType, nullable=true)() +case class EvaluatePython( + udf: PythonUDF, + child: LogicalPlan, + resultAttribute: AttributeReference) + extends logical.UnaryNode { def output = child.output :+ resultAttribute } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala index 0f27fd13e7379..61ee960aad9d2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala @@ -20,7 +20,9 @@ package org.apache.spark.sql.json import scala.collection.Map import scala.collection.convert.Wrappers.{JMapWrapper, JListWrapper} import scala.math.BigDecimal +import java.sql.Timestamp +import com.fasterxml.jackson.core.JsonProcessingException import com.fasterxml.jackson.databind.ObjectMapper import org.apache.spark.rdd.RDD @@ -34,16 +36,19 @@ private[sql] object JsonRDD extends Logging { private[sql] def jsonStringToRow( json: RDD[String], - schema: StructType): RDD[Row] = { - parseJson(json).map(parsed => asRow(parsed, schema)) + schema: StructType, + columnNameOfCorruptRecords: String): RDD[Row] = { + parseJson(json, columnNameOfCorruptRecords).map(parsed => asRow(parsed, schema)) } private[sql] def inferSchema( json: RDD[String], - samplingRatio: Double = 1.0): StructType = { + samplingRatio: Double = 1.0, + columnNameOfCorruptRecords: String): StructType = { require(samplingRatio > 0, s"samplingRatio ($samplingRatio) should be greater than 0") val schemaData = if (samplingRatio > 0.99) json else json.sample(false, samplingRatio, 1) - val allKeys = parseJson(schemaData).map(allKeysWithValueTypes).reduce(_ ++ _) + val allKeys = + parseJson(schemaData, columnNameOfCorruptRecords).map(allKeysWithValueTypes).reduce(_ ++ _) createSchema(allKeys) } @@ -273,7 +278,9 @@ private[sql] object JsonRDD extends Logging { case atom => atom } - private def parseJson(json: RDD[String]): RDD[Map[String, Any]] = { + private def parseJson( + json: RDD[String], + columnNameOfCorruptRecords: String): RDD[Map[String, Any]] = { // According to [Jackson-72: https://jira.codehaus.org/browse/JACKSON-72], // ObjectMapper will not return BigDecimal when // "DeserializationFeature.USE_BIG_DECIMAL_FOR_FLOATS" is disabled @@ -288,12 +295,16 @@ private[sql] object JsonRDD extends Logging { // For example: for {"key": 1, "key":2}, we will get "key"->2. 
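The try/catch added around mapper.readValue below keeps malformed records instead of failing the job, storing the raw text under the column named by the new columnNameOfCorruptRecords parameter. A standalone sketch of that fallback against Jackson directly; the column name used here is only a placeholder:

    import com.fasterxml.jackson.core.JsonProcessingException
    import com.fasterxml.jackson.databind.ObjectMapper

    object CorruptRecordSketch {
      def main(args: Array[String]): Unit = {
        val mapper = new ObjectMapper()
        val corruptColumn = "_corrupt_record"   // placeholder name for this example

        def parse(record: String): Map[String, Any] =
          try {
            mapper.readValue(record, classOf[Object]) match {
              case m: java.util.Map[_, _] => Map("parsed" -> m)
              case other                  => Map("parsed" -> other)
            }
          } catch {
            // Malformed input: keep the raw record under the corrupt-record column.
            case _: JsonProcessingException => Map(corruptColumn -> record)
          }

        println(parse("""{"a": 1}"""))   // Map(parsed -> {a=1})
        println(parse("""{"a": """))     // Map(_corrupt_record -> {"a": )
      }
    }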
val mapper = new ObjectMapper() iter.flatMap { record => - val parsed = mapper.readValue(record, classOf[Object]) match { - case map: java.util.Map[_, _] => scalafy(map).asInstanceOf[Map[String, Any]] :: Nil - case list: java.util.List[_] => scalafy(list).asInstanceOf[Seq[Map[String, Any]]] - } + try { + val parsed = mapper.readValue(record, classOf[Object]) match { + case map: java.util.Map[_, _] => scalafy(map).asInstanceOf[Map[String, Any]] :: Nil + case list: java.util.List[_] => scalafy(list).asInstanceOf[Seq[Map[String, Any]]] + } - parsed + parsed + } catch { + case e: JsonProcessingException => Map(columnNameOfCorruptRecords -> record) :: Nil + } } }) } @@ -361,6 +372,14 @@ private[sql] object JsonRDD extends Logging { } } + private def toTimestamp(value: Any): Timestamp = { + value match { + case value: java.lang.Integer => new Timestamp(value.asInstanceOf[Int].toLong) + case value: java.lang.Long => new Timestamp(value) + case value: java.lang.String => Timestamp.valueOf(value) + } + } + private[json] def enforceCorrectType(value: Any, desiredType: DataType): Any ={ if (value == null) { null @@ -377,6 +396,7 @@ private[sql] object JsonRDD extends Logging { case ArrayType(elementType, _) => value.asInstanceOf[Seq[Any]].map(enforceCorrectType(_, elementType)) case struct: StructType => asRow(value.asInstanceOf[Map[String, Any]], struct) + case TimestampType => toTimestamp(value) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/package.scala index f513eae9c2d13..e98d151286818 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/package.scala @@ -165,6 +165,16 @@ package object sql { @DeveloperApi val TimestampType = catalyst.types.TimestampType + /** + * :: DeveloperApi :: + * + * The data type representing `java.sql.Date` values. + * + * @group dataType + */ + @DeveloperApi + val DateType = catalyst.types.DateType + /** * :: DeveloperApi :: * diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala index ffb732347d30a..5c6fa78ae3895 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala @@ -289,9 +289,9 @@ case class InsertIntoParquetTable( def writeShard(context: TaskContext, iter: Iterator[Row]): Int = { // Hadoop wants a 32-bit task attempt ID, so if ours is bigger than Int.MaxValue, roll it // around by taking a mod. We expect that no task will be attempted 2 billion times. 
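The comment above refers to the wrap-around applied below: Spark's 64-bit attempt id is reduced modulo Int.MaxValue so it fits into the 32-bit attempt number Hadoop expects. In isolation:

    object AttemptIdSketch {
      def main(args: Array[String]): Unit = {
        val attemptId: Long = 3000000000L                  // larger than Int.MaxValue (2147483647)
        val attemptNumber = (attemptId % Int.MaxValue).toInt
        println(attemptNumber)                             // 852516353 -- fits in a signed 32-bit int
      }
    }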
- val attemptNumber = (context.getAttemptId % Int.MaxValue).toInt + val attemptNumber = (context.attemptId % Int.MaxValue).toInt /* "reduce task" */ - val attemptId = newTaskAttemptID(jobtrackerID, stageId, isMap = false, context.getPartitionId, + val attemptId = newTaskAttemptID(jobtrackerID, stageId, isMap = false, context.partitionId, attemptNumber) val hadoopContext = newTaskAttemptContext(wrappedConf.value, attemptId) val format = new AppendingParquetOutputFormat(taskIdOffset) @@ -331,13 +331,21 @@ private[parquet] class AppendingParquetOutputFormat(offset: Int) // override to choose output filename so not overwrite existing ones override def getDefaultWorkFile(context: TaskAttemptContext, extension: String): Path = { - val taskId: TaskID = context.getTaskAttemptID.getTaskID + val taskId: TaskID = getTaskAttemptID(context).getTaskID val partition: Int = taskId.getId val filename = s"part-r-${partition + offset}.parquet" val committer: FileOutputCommitter = getOutputCommitter(context).asInstanceOf[FileOutputCommitter] new Path(committer.getWorkPath, filename) } + + // The TaskAttemptContext is a class in hadoop-1 but is an interface in hadoop-2. + // The signatures of the method TaskAttemptContext.getTaskAttemptID for the both versions + // are the same, so the method calls are source-compatible but NOT binary-compatible because + // the opcode of method call for class is INVOKEVIRTUAL and for interface is INVOKEINTERFACE. + private def getTaskAttemptID(context: TaskAttemptContext): TaskAttemptID = { + context.getClass.getMethod("getTaskAttemptID").invoke(context).asInstanceOf[TaskAttemptID] + } } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala index 2941b9793597f..e6389cf77a4c9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql.parquet import java.io.IOException +import scala.util.Try + import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.hadoop.mapreduce.Job @@ -323,14 +325,14 @@ private[parquet] object ParquetTypesConverter extends Logging { } def convertFromString(string: String): Seq[Attribute] = { - DataType(string) match { + Try(DataType.fromJson(string)).getOrElse(DataType.fromCaseClassString(string)) match { case s: StructType => s.toAttributes case other => sys.error(s"Can convert $string to row") } } def convertToString(schema: Seq[Attribute]): String = { - StructType.fromAttributes(schema).toString + StructType.fromAttributes(schema).json } def writeMetaData(attributes: Seq[Attribute], origPath: Path, conf: Configuration): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/types/util/DataTypeConversions.scala b/sql/core/src/main/scala/org/apache/spark/sql/types/util/DataTypeConversions.scala index 77353f4eb0227..e44cb08309523 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/types/util/DataTypeConversions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/types/util/DataTypeConversions.scala @@ -41,6 +41,7 @@ protected[sql] object DataTypeConversions { case StringType => JDataType.StringType case BinaryType => JDataType.BinaryType case BooleanType => JDataType.BooleanType + case DateType => JDataType.DateType case TimestampType => JDataType.TimestampType case DecimalType => JDataType.DecimalType case DoubleType => 
JDataType.DoubleType @@ -80,6 +81,8 @@ protected[sql] object DataTypeConversions { BinaryType case booleanType: org.apache.spark.sql.api.java.BooleanType => BooleanType + case dateType: org.apache.spark.sql.api.java.DateType => + DateType case timestampType: org.apache.spark.sql.api.java.TimestampType => TimestampType case decimalType: org.apache.spark.sql.api.java.DecimalType => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala index 1e624f97004f5..444bc95009c31 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql import org.apache.spark.sql.TestData._ import org.apache.spark.sql.columnar.{InMemoryColumnarTableScan, InMemoryRelation} import org.apache.spark.sql.test.TestSQLContext._ -import org.apache.spark.storage.RDDBlockId +import org.apache.spark.storage.{StorageLevel, RDDBlockId} case class BigData(s: String) @@ -55,10 +55,10 @@ class CachedTableSuite extends QueryTest { test("too big for memory") { val data = "*" * 10000 - sparkContext.parallelize(1 to 1000000, 1).map(_ => BigData(data)).registerTempTable("bigData") - cacheTable("bigData") - assert(table("bigData").count() === 1000000L) - uncacheTable("bigData") + sparkContext.parallelize(1 to 200000, 1).map(_ => BigData(data)).registerTempTable("bigData") + table("bigData").persist(StorageLevel.MEMORY_AND_DISK) + assert(table("bigData").count() === 200000L) + table("bigData").unpersist() } test("calling .cache() should use in-memory columnar caching") { @@ -69,7 +69,7 @@ class CachedTableSuite extends QueryTest { test("calling .unpersist() should drop in-memory columnar cache") { table("testData").cache() table("testData").count() - table("testData").unpersist(true) + table("testData").unpersist(blocking = true) assertCached(table("testData"), 0) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataTypeSuite.scala index 8fb59c5830f6d..100ecb45e9e88 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataTypeSuite.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql import org.scalatest.FunSuite +import org.apache.spark.sql.catalyst.types.DataType + class DataTypeSuite extends FunSuite { test("construct an ArrayType") { @@ -55,4 +57,30 @@ class DataTypeSuite extends FunSuite { struct(Set("b", "d", "e", "f")) } } + + def checkDataTypeJsonRepr(dataType: DataType): Unit = { + test(s"JSON - $dataType") { + assert(DataType.fromJson(dataType.json) === dataType) + } + } + + checkDataTypeJsonRepr(BooleanType) + checkDataTypeJsonRepr(ByteType) + checkDataTypeJsonRepr(ShortType) + checkDataTypeJsonRepr(IntegerType) + checkDataTypeJsonRepr(LongType) + checkDataTypeJsonRepr(FloatType) + checkDataTypeJsonRepr(DoubleType) + checkDataTypeJsonRepr(DecimalType) + checkDataTypeJsonRepr(TimestampType) + checkDataTypeJsonRepr(StringType) + checkDataTypeJsonRepr(BinaryType) + checkDataTypeJsonRepr(ArrayType(DoubleType, true)) + checkDataTypeJsonRepr(ArrayType(StringType, false)) + checkDataTypeJsonRepr(MapType(IntegerType, StringType, true)) + checkDataTypeJsonRepr(MapType(IntegerType, ArrayType(DoubleType), false)) + checkDataTypeJsonRepr( + StructType(Seq( + StructField("a", IntegerType, nullable = true), + StructField("b", ArrayType(DoubleType), 
nullable = false)))) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala index d001abb7e1fcc..45e58afe9d9a2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala @@ -147,6 +147,14 @@ class DslQuerySuite extends QueryTest { (1, 1, 1, 2) :: Nil) } + test("SPARK-3858 generator qualifiers are discarded") { + checkAnswer( + arrayData.as('ad) + .generate(Explode("data" :: Nil, 'data), alias = Some("ex")) + .select("ex.data".attr), + Seq(1, 2, 3, 2, 3, 4).map(Seq(_))) + } + test("average") { checkAnswer( testData2.groupBy()(avg('a)), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index 6c7697ece8c56..07f4d2946c1b5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -24,6 +24,7 @@ import org.apache.spark.sql.TestData._ import org.apache.spark.sql.catalyst.plans.JoinType import org.apache.spark.sql.catalyst.plans.{LeftOuter, RightOuter, FullOuter, Inner, LeftSemi} import org.apache.spark.sql.execution._ +import org.apache.spark.sql.execution.joins._ import org.apache.spark.sql.test.TestSQLContext import org.apache.spark.sql.test.TestSQLContext._ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 6fb6cb8db0c8f..15f6ba4f72bbd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -19,7 +19,8 @@ package org.apache.spark.sql import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.execution.{ShuffledHashJoin, BroadcastHashJoin} +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.execution.joins.BroadcastHashJoin import org.apache.spark.sql.test._ import org.scalatest.BeforeAndAfterAll import java.util.TimeZone @@ -42,7 +43,6 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { TimeZone.setDefault(origZone) } - test("SPARK-3176 Added Parser of SQL ABS()") { checkAnswer( sql("SELECT ABS(-1.3)"), @@ -61,7 +61,6 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { 4) } - test("SPARK-2041 column name equals tablename") { checkAnswer( sql("SELECT tableName FROM tableName"), @@ -680,9 +679,45 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { sql("SELECT CAST(TRUE AS STRING), CAST(FALSE AS STRING) FROM testData LIMIT 1"), ("true", "false") :: Nil) } - + test("SPARK-3371 Renaming a function expression with group by gives error") { registerFunction("len", (s: String) => s.length) checkAnswer( - sql("SELECT len(value) as temp FROM testData WHERE key = 1 group by len(value)"), 1)} + sql("SELECT len(value) as temp FROM testData WHERE key = 1 group by len(value)"), 1) + } + + test("SPARK-3813 CASE a WHEN b THEN c [WHEN d THEN e]* [ELSE f] END") { + checkAnswer( + sql("SELECT CASE key WHEN 1 THEN 1 ELSE 0 END FROM testData WHERE key = 1 group by key"), 1) + } + + test("SPARK-3813 CASE WHEN a THEN b [WHEN c THEN d]* [ELSE e] END") { + checkAnswer( + sql("SELECT CASE WHEN key = 1 THEN 1 ELSE 2 END FROM testData WHERE key = 1 group by key"), 1) + } + + test("throw errors for non-aggregate 
attributes with aggregation") { + def checkAggregation(query: String, isInvalidQuery: Boolean = true) { + val logicalPlan = sql(query).queryExecution.logical + + if (isInvalidQuery) { + val e = intercept[TreeNodeException[LogicalPlan]](sql(query).queryExecution.analyzed) + assert( + e.getMessage.startsWith("Expression not in GROUP BY"), + "Non-aggregate attribute(s) not detected\n" + logicalPlan) + } else { + // Should not throw + sql(query).queryExecution.analyzed + } + } + + checkAggregation("SELECT key, COUNT(*) FROM testData") + checkAggregation("SELECT COUNT(key), COUNT(*) FROM testData", false) + + checkAggregation("SELECT value, COUNT(*) FROM testData GROUP BY key") + checkAggregation("SELECT COUNT(value), SUM(key) FROM testData GROUP BY key", false) + + checkAggregation("SELECT key + 2, COUNT(*) FROM testData GROUP BY key + 1") + checkAggregation("SELECT key + 1 + 1, COUNT(*) FROM testData GROUP BY key + 1", false) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala index e24c521d24c7a..bfa9ea416266d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql -import java.sql.Timestamp +import java.sql.{Date, Timestamp} import org.scalatest.FunSuite @@ -34,6 +34,7 @@ case class ReflectData( byteField: Byte, booleanField: Boolean, decimalField: BigDecimal, + date: Date, timestampField: Timestamp, seqInt: Seq[Int]) @@ -76,7 +77,7 @@ case class ComplexReflectData( class ScalaReflectionRelationSuite extends FunSuite { test("query case class RDD") { val data = ReflectData("a", 1, 1L, 1.toFloat, 1.toDouble, 1.toShort, 1.toByte, true, - BigDecimal(1), new Timestamp(12345), Seq(1,2,3)) + BigDecimal(1), new Date(12345), new Timestamp(12345), Seq(1,2,3)) val rdd = sparkContext.parallelize(data :: Nil) rdd.registerTempTable("reflectData") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala index 0cdbb3167ce36..6bdf741134e2f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala @@ -30,6 +30,7 @@ class ColumnStatsSuite extends FunSuite { testColumnStats(classOf[FloatColumnStats], FLOAT, Row(Float.MaxValue, Float.MinValue, 0)) testColumnStats(classOf[DoubleColumnStats], DOUBLE, Row(Double.MaxValue, Double.MinValue, 0)) testColumnStats(classOf[StringColumnStats], STRING, Row(null, null, 0)) + testColumnStats(classOf[DateColumnStats], DATE, Row(null, null, 0)) testColumnStats(classOf[TimestampColumnStats], TIMESTAMP, Row(null, null, 0)) def testColumnStats[T <: NativeType, U <: ColumnStats]( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala index 4fb1ecf1d532b..3f3f35d50188b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.columnar import java.nio.ByteBuffer -import java.sql.Timestamp +import java.sql.{Date, Timestamp} import org.scalatest.FunSuite @@ -33,8 +33,8 @@ class ColumnTypeSuite extends 
FunSuite with Logging { test("defaultSize") { val checks = Map( - INT -> 4, SHORT -> 2, LONG -> 8, BYTE -> 1, DOUBLE -> 8, FLOAT -> 4, - BOOLEAN -> 1, STRING -> 8, TIMESTAMP -> 12, BINARY -> 16, GENERIC -> 16) + INT -> 4, SHORT -> 2, LONG -> 8, BYTE -> 1, DOUBLE -> 8, FLOAT -> 4, BOOLEAN -> 1, + STRING -> 8, DATE -> 8, TIMESTAMP -> 12, BINARY -> 16, GENERIC -> 16) checks.foreach { case (columnType, expectedSize) => assertResult(expectedSize, s"Wrong defaultSize for $columnType") { @@ -64,6 +64,7 @@ class ColumnTypeSuite extends FunSuite with Logging { checkActualSize(FLOAT, Float.MaxValue, 4) checkActualSize(BOOLEAN, true, 1) checkActualSize(STRING, "hello", 4 + "hello".getBytes("utf-8").length) + checkActualSize(DATE, new Date(0L), 8) checkActualSize(TIMESTAMP, new Timestamp(0L), 12) val binary = Array.fill[Byte](4)(0: Byte) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala index 38b04dd959f70..a1f21219eaf2f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.columnar import scala.collection.immutable.HashSet import scala.util.Random -import java.sql.Timestamp +import java.sql.{Date, Timestamp} import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.expressions.GenericMutableRow @@ -50,6 +50,7 @@ object ColumnarTestUtils { case STRING => Random.nextString(Random.nextInt(32)) case BOOLEAN => Random.nextBoolean() case BINARY => randomBytes(Random.nextInt(32)) + case DATE => new Date(Random.nextLong()) case TIMESTAMP => val timestamp = new Timestamp(Random.nextLong()) timestamp.setNanos(Random.nextInt(999999999)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala index 6c9a9ab6c3418..21906e3fdcc6f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala @@ -41,7 +41,9 @@ object TestNullableColumnAccessor { class NullableColumnAccessorSuite extends FunSuite { import ColumnarTestUtils._ - Seq(INT, LONG, SHORT, BOOLEAN, BYTE, STRING, DOUBLE, FLOAT, BINARY, GENERIC, TIMESTAMP).foreach { + Seq( + INT, LONG, SHORT, BOOLEAN, BYTE, STRING, DOUBLE, FLOAT, BINARY, GENERIC, DATE, TIMESTAMP + ).foreach { testNullableColumnAccessor(_) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala index f54a21eb4fbb1..cb73f3da81e24 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala @@ -37,7 +37,9 @@ object TestNullableColumnBuilder { class NullableColumnBuilderSuite extends FunSuite { import ColumnarTestUtils._ - Seq(INT, LONG, SHORT, BOOLEAN, BYTE, STRING, DOUBLE, FLOAT, BINARY, GENERIC, TIMESTAMP).foreach { + Seq( + INT, LONG, SHORT, BOOLEAN, BYTE, STRING, DOUBLE, FLOAT, BINARY, GENERIC, DATE, TIMESTAMP + ).foreach { testNullableColumnBuilder(_) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala index 69e0adbd3ee0d..f53acc8c9f718 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala @@ -67,10 +67,11 @@ class PartitionBatchPruningSuite extends FunSuite with BeforeAndAfterAll with Be checkBatchPruning("i > 8 AND i <= 21", 9 to 21, 2, 3) checkBatchPruning("i < 2 OR i > 99", Seq(1, 100), 2, 2) checkBatchPruning("i < 2 OR (i > 78 AND i < 92)", Seq(1) ++ (79 to 91), 3, 4) + checkBatchPruning("NOT (i < 88)", 88 to 100, 1, 2) // With unsupported predicate checkBatchPruning("i < 12 AND i IS NOT NULL", 1 to 11, 1, 2) - checkBatchPruning("NOT (i < 88)", 88 to 100, 5, 10) + checkBatchPruning(s"NOT (i in (${(1 to 30).mkString(",")}))", 31 to 100, 5, 10) def checkBatchPruning( filter: String, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index bfbf431a11913..f14ffca0e4d35 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -19,10 +19,11 @@ package org.apache.spark.sql.execution import org.scalatest.FunSuite +import org.apache.spark.sql.{SQLConf, execution} import org.apache.spark.sql.TestData._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans._ -import org.apache.spark.sql.{SQLConf, execution} +import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, ShuffledHashJoin} import org.apache.spark.sql.test.TestSQLContext._ import org.apache.spark.sql.test.TestSQLContext.planner._ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala new file mode 100644 index 0000000000000..87c28c334d228 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.debug + +import org.scalatest.FunSuite + +import org.apache.spark.sql.TestData._ +import org.apache.spark.sql.test.TestSQLContext._ + +class DebuggingSuite extends FunSuite { + test("SchemaRDD.debug()") { + testData.debug() + } + + test("SchemaRDD.typeCheck()") { + testData.typeCheck() + } +} \ No newline at end of file diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala new file mode 100644 index 0000000000000..2aad01ded1acf --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.joins + +import org.scalatest.FunSuite + +import org.apache.spark.sql.catalyst.expressions.{Projection, Row} +import org.apache.spark.util.collection.CompactBuffer + + +class HashedRelationSuite extends FunSuite { + + // Key is simply the record itself + private val keyProjection = new Projection { + override def apply(row: Row): Row = row + } + + test("GeneralHashedRelation") { + val data = Array(Row(0), Row(1), Row(2), Row(2)) + val hashed = HashedRelation(data.iterator, keyProjection) + assert(hashed.isInstanceOf[GeneralHashedRelation]) + + assert(hashed.get(data(0)) == CompactBuffer[Row](data(0))) + assert(hashed.get(data(1)) == CompactBuffer[Row](data(1))) + assert(hashed.get(Row(10)) === null) + + val data2 = CompactBuffer[Row](data(2)) + data2 += data(2) + assert(hashed.get(data(2)) == data2) + } + + test("UniqueKeyHashedRelation") { + val data = Array(Row(0), Row(1), Row(2)) + val hashed = HashedRelation(data.iterator, keyProjection) + assert(hashed.isInstanceOf[UniqueKeyHashedRelation]) + + assert(hashed.get(data(0)) == CompactBuffer[Row](data(0))) + assert(hashed.get(data(1)) == CompactBuffer[Row](data(1))) + assert(hashed.get(data(2)) == CompactBuffer[Row](data(2))) + assert(hashed.get(Row(10)) === null) + + val uniqHashed = hashed.asInstanceOf[UniqueKeyHashedRelation] + assert(uniqHashed.getValue(data(0)) == data(0)) + assert(uniqHashed.getValue(data(1)) == data(1)) + assert(uniqHashed.getValue(data(2)) == data(2)) + assert(uniqHashed.getValue(Row(10)) == null) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala index 685e788207725..7bb08f1b513ce 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala @@ -21,8 +21,12 @@ import org.apache.spark.sql.catalyst.types._ import org.apache.spark.sql.catalyst.util._ import 
org.apache.spark.sql.json.JsonRDD.{enforceCorrectType, compatibleType} import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.SQLConf +import org.apache.spark.sql.test.TestSQLContext import org.apache.spark.sql.test.TestSQLContext._ +import java.sql.Timestamp + class JsonSuite extends QueryTest { import TestJsonData._ TestJsonData @@ -50,6 +54,12 @@ class JsonSuite extends QueryTest { val doubleNumber: Double = 1.7976931348623157E308d checkTypePromotion(doubleNumber.toDouble, enforceCorrectType(doubleNumber, DoubleType)) checkTypePromotion(BigDecimal(doubleNumber), enforceCorrectType(doubleNumber, DecimalType)) + + checkTypePromotion(new Timestamp(intNumber), enforceCorrectType(intNumber, TimestampType)) + checkTypePromotion(new Timestamp(intNumber.toLong), + enforceCorrectType(intNumber.toLong, TimestampType)) + val strDate = "2014-09-30 12:34:56" + checkTypePromotion(Timestamp.valueOf(strDate), enforceCorrectType(strDate, TimestampType)) } test("Get compatible type") { @@ -636,7 +646,65 @@ class JsonSuite extends QueryTest { ("str_a_1", null, null) :: ("str_a_2", null, null) :: (null, "str_b_3", null) :: - ("str_a_4", "str_b_4", "str_c_4") ::Nil + ("str_a_4", "str_b_4", "str_c_4") :: Nil + ) + } + + test("Corrupt records") { + // Test if we can query corrupt records. + val oldColumnNameOfCorruptRecord = TestSQLContext.columnNameOfCorruptRecord + TestSQLContext.setConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD, "_unparsed") + + val jsonSchemaRDD = jsonRDD(corruptRecords) + jsonSchemaRDD.registerTempTable("jsonTable") + + val schema = StructType( + StructField("_unparsed", StringType, true) :: + StructField("a", StringType, true) :: + StructField("b", StringType, true) :: + StructField("c", StringType, true) :: Nil) + + assert(schema === jsonSchemaRDD.schema) + + // In HiveContext, backticks should be used to access columns starting with a underscore. 
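For reference, the corrupt-record support exercised by the new "Corrupt records" test can be driven end to end with only the APIs that appear in this hunk (TestSQLContext.setConf, SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD, jsonRDD, registerTempTable, sql); a minimal sketch, with a hypothetical table name:

    import org.apache.spark.sql.SQLConf
    import org.apache.spark.sql.test.TestSQLContext
    import org.apache.spark.sql.test.TestSQLContext._

    // Route records that fail JSON parsing into a string column named "_unparsed".
    TestSQLContext.setConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD, "_unparsed")

    val records = TestSQLContext.sparkContext.parallelize(
      """{"a":1}""" ::      // well-formed
      """{"a":1""" :: Nil)  // malformed: missing closing brace
    jsonRDD(records).registerTempTable("mixedJson")

    // Well-formed rows have a NULL "_unparsed"; malformed rows keep their raw text there.
    sql("SELECT _unparsed FROM mixedJson WHERE _unparsed IS NOT NULL").collect()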
+ checkAnswer( + sql( + """ + |SELECT a, b, c, _unparsed + |FROM jsonTable + """.stripMargin), + (null, null, null, "{") :: + (null, null, null, "") :: + (null, null, null, """{"a":1, b:2}""") :: + (null, null, null, """{"a":{, b:3}""") :: + ("str_a_4", "str_b_4", "str_c_4", null) :: + (null, null, null, "]") :: Nil ) + + checkAnswer( + sql( + """ + |SELECT a, b, c + |FROM jsonTable + |WHERE _unparsed IS NULL + """.stripMargin), + ("str_a_4", "str_b_4", "str_c_4") :: Nil + ) + + checkAnswer( + sql( + """ + |SELECT _unparsed + |FROM jsonTable + |WHERE _unparsed IS NOT NULL + """.stripMargin), + Seq("{") :: + Seq("") :: + Seq("""{"a":1, b:2}""") :: + Seq("""{"a":{, b:3}""") :: + Seq("]") :: Nil + ) + + TestSQLContext.setConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD, oldColumnNameOfCorruptRecord) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/TestJsonData.scala b/sql/core/src/test/scala/org/apache/spark/sql/json/TestJsonData.scala index fc833b8b54e4c..eaca9f0508a12 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/json/TestJsonData.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/json/TestJsonData.scala @@ -143,4 +143,13 @@ object TestJsonData { """[{"a":"str_a_2"}, {"b":"str_b_3"}]""" :: """{"b":"str_b_4", "a":"str_a_4", "c":"str_c_4"}""" :: """[]""" :: Nil) + + val corruptRecords = + TestSQLContext.sparkContext.parallelize( + """{""" :: + """""" :: + """{"a":1, b:2}""" :: + """{"a":{, b:3}""" :: + """{"b":"str_b_4", "a":"str_a_4", "c":"str_c_4"}""" :: + """]""" :: Nil) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala index 07adf731405af..25e41ecf28e2e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala @@ -789,7 +789,7 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA assert(result3(0)(1) === "the answer") Utils.deleteRecursively(tmpdir) } - + test("Querying on empty parquet throws exception (SPARK-3536)") { val tmpdir = Utils.createTempDir() Utils.deleteRecursively(tmpdir) @@ -798,4 +798,18 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA assert(result1.size === 0) Utils.deleteRecursively(tmpdir) } + + test("DataType string parser compatibility") { + val schema = StructType(List( + StructField("c1", IntegerType, false), + StructField("c2", BinaryType, false))) + + val fromCaseClassString = ParquetTypesConverter.convertFromString(schema.toString) + val fromJson = ParquetTypesConverter.convertFromString(schema.json) + + (fromCaseClassString, fromJson).zipped.foreach { (a, b) => + assert(a.name == b.name) + assert(a.dataType === b.dataType) + } + } } diff --git a/sql/hbase/pom.xml b/sql/hbase/pom.xml index 23fd46af6c7f8..5f0812a69448b 100644 --- a/sql/hbase/pom.xml +++ b/sql/hbase/pom.xml @@ -64,118 +64,118 @@ test - org.apache.spark - spark-sql_${scala.binary.version} - ${project.version} - + org.apache.spark + spark-sql_${scala.binary.version} + ${project.version} + org.apache.hbase hbase-common ${hbase.version} - - asm - asm - - - org.jboss.netty - netty - - - io.netty - netty - - - commons-logging - commons-logging - - - org.jruby - jruby-complete - + + asm + asm + + + org.jboss.netty + netty + + + io.netty + netty + + + commons-logging + commons-logging + + + org.jruby + jruby-complete + - - + + org.apache.hbase hbase-client ${hbase.version} - - asm 
- asm - - - org.jboss.netty - netty - - - io.netty - netty - - - commons-logging - commons-logging - - - org.jruby - jruby-complete - + + asm + asm + + + org.jboss.netty + netty + + + io.netty + netty + + + commons-logging + commons-logging + + + org.jruby + jruby-complete + - - + + org.apache.hbase hbase-server ${hbase.version} - - asm - asm - - - org.jboss.netty - netty - - - io.netty - netty - - - commons-logging - commons-logging - - - org.jruby - jruby-complete - + + asm + asm + + + org.jboss.netty + netty + + + io.netty + netty + + + commons-logging + commons-logging + + + org.jruby + jruby-complete + - - + + org.apache.hbase hbase-protocol ${hbase.version} - - asm - asm - - - org.jboss.netty - netty - - - io.netty - netty - - - commons-logging - commons-logging - - - org.jruby - jruby-complete - + + asm + asm + + + org.jboss.netty + netty + + + io.netty + netty + + + commons-logging + commons-logging + + + org.jruby + jruby-complete + - + org.apache.hbase hbase-testing-util diff --git a/sql/hbase/src/main/scala/org/apache/spark/sql/hbase/old/DataTypeUtils.scala b/sql/hbase/src/main/scala/org/apache/spark/sql/hbase/old/DataTypeUtils.scala index 55b7dd3ac7518..79ad498e54d3f 100644 --- a/sql/hbase/src/main/scala/org/apache/spark/sql/hbase/old/DataTypeUtils.scala +++ b/sql/hbase/src/main/scala/org/apache/spark/sql/hbase/old/DataTypeUtils.scala @@ -20,7 +20,6 @@ import java.io.{DataOutputStream, ByteArrayOutputStream, DataInputStream, ByteAr import java.math.BigDecimal import org.apache.hadoop.hbase.util.Bytes -import org.apache.log4j.Logger import org.apache.spark.sql import org.apache.spark.sql.catalyst.expressions.Row import org.apache.spark.sql.catalyst.types._ @@ -30,14 +29,13 @@ import org.apache.spark.sql.catalyst.types._ * Created by sboesch on 10/9/14. 
*/ object DataTypeUtils { - val logger = Logger.getLogger(getClass.getName) def cmp(str1: Option[HBaseRawType], str2: Option[HBaseRawType]) = { if (str1.isEmpty && str2.isEmpty) 0 else if (str1.isEmpty) -2 else if (str2.isEmpty) 2 else { - var ix = 0 + val ix = 0 val s1arr = str1.get val s2arr = str2.get var retval: Option[Int] = None diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala index 910174a153768..accf61576b804 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala @@ -172,7 +172,7 @@ private[thriftserver] class SparkSQLOperationManager(hiveContext: HiveContext) result = hiveContext.sql(statement) logDebug(result.queryExecution.toString()) result.queryExecution.logical match { - case SetCommand(Some(key), Some(value)) if (key == SQLConf.THRIFTSERVER_POOL) => + case SetCommand(Some((SQLConf.THRIFTSERVER_POOL, Some(value)))) => sessionToActivePool(parentSession) = value logInfo(s"Setting spark.scheduler.pool=$value for future statements in this session.") case _ => diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala index 3475c2c9db080..8a72e9d2aef57 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala @@ -30,7 +30,7 @@ import java.util.concurrent.atomic.AtomicInteger import org.apache.hadoop.hive.conf.HiveConf.ConfVars import org.scalatest.{BeforeAndAfterAll, FunSuite} -import org.apache.spark.Logging +import org.apache.spark.{SparkException, Logging} import org.apache.spark.sql.catalyst.util.getTempFilePath class CliSuite extends FunSuite with BeforeAndAfterAll with Logging { @@ -62,9 +62,14 @@ class CliSuite extends FunSuite with BeforeAndAfterAll with Logging { def captureOutput(source: String)(line: String) { buffer += s"$source> $line" - if (line.contains(expectedAnswers(next.get()))) { - if (next.incrementAndGet() == expectedAnswers.size) { - foundAllExpectedAnswers.trySuccess(()) + // If we haven't found all expected answers... + if (next.get() < expectedAnswers.size) { + // If another expected answer is found... + if (line.startsWith(expectedAnswers(next.get()))) { + // If all expected answers have been found... 
+ if (next.incrementAndGet() == expectedAnswers.size) { + foundAllExpectedAnswers.trySuccess(()) + } } } } @@ -73,11 +78,6 @@ class CliSuite extends FunSuite with BeforeAndAfterAll with Logging { val process = (Process(command) #< queryStream).run( ProcessLogger(captureOutput("stdout"), captureOutput("stderr"))) - Future { - val exitValue = process.exitValue() - logInfo(s"Spark SQL CLI process exit value: $exitValue") - } - try { Await.result(foundAllExpectedAnswers.future, timeout) } catch { case cause: Throwable => @@ -96,6 +96,7 @@ class CliSuite extends FunSuite with BeforeAndAfterAll with Logging { |End CliSuite failure output |=========================== """.stripMargin, cause) + throw cause } finally { warehousePath.delete() metastorePath.delete() @@ -107,7 +108,7 @@ class CliSuite extends FunSuite with BeforeAndAfterAll with Logging { val dataFilePath = Thread.currentThread().getContextClassLoader.getResource("data/files/small_kv.txt") - runCliWithin(1.minute)( + runCliWithin(3.minute)( "CREATE TABLE hive_test(key INT, val STRING);" -> "OK", "SHOW TABLES;" @@ -118,7 +119,7 @@ class CliSuite extends FunSuite with BeforeAndAfterAll with Logging { -> "Time taken: ", "SELECT COUNT(*) FROM hive_test;" -> "5", - "DROP TABLE hive_test" + "DROP TABLE hive_test;" -> "Time taken: " ) } diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala index 38977ff162097..e3b4e45a3d68c 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala @@ -17,17 +17,17 @@ package org.apache.spark.sql.hive.thriftserver -import scala.collection.mutable.ArrayBuffer -import scala.concurrent.ExecutionContext.Implicits.global -import scala.concurrent.duration._ -import scala.concurrent.{Await, Future, Promise} -import scala.sys.process.{Process, ProcessLogger} - import java.io.File import java.net.ServerSocket import java.sql.{DriverManager, Statement} import java.util.concurrent.TimeoutException +import scala.collection.mutable.ArrayBuffer +import scala.concurrent.duration._ +import scala.concurrent.{Await, Promise} +import scala.sys.process.{Process, ProcessLogger} +import scala.util.Try + import org.apache.hadoop.hive.conf.HiveConf.ConfVars import org.apache.hive.jdbc.HiveDriver import org.scalatest.FunSuite @@ -41,25 +41,25 @@ import org.apache.spark.sql.catalyst.util.getTempFilePath class HiveThriftServer2Suite extends FunSuite with Logging { Class.forName(classOf[HiveDriver].getCanonicalName) - private val listeningHost = "localhost" - private val listeningPort = { - // Let the system to choose a random available port to avoid collision with other parallel - // builds. 
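Both the old and the relocated versions of this port setup use the same idiom: bind a ServerSocket to port 0 so the OS picks a free ephemeral port, read it back, and close the socket before the Thrift server is started. A self-contained sketch of that helper:

    import java.net.ServerSocket

    // Ask the OS for any free port and release it immediately; a small race window
    // remains (another process could grab it), which is acceptable for tests.
    def freePort(): Int = {
      val socket = new ServerSocket(0)
      try socket.getLocalPort finally socket.close()
    }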
- val socket = new ServerSocket(0) - val port = socket.getLocalPort - socket.close() - port - } - - private val warehousePath = getTempFilePath("warehouse") - private val metastorePath = getTempFilePath("metastore") - private val metastoreJdbcUri = s"jdbc:derby:;databaseName=$metastorePath;create=true" - - def startThriftServerWithin(timeout: FiniteDuration = 30.seconds)(f: Statement => Unit) { - val serverScript = "../../sbin/start-thriftserver.sh".split("/").mkString(File.separator) + def startThriftServerWithin(timeout: FiniteDuration = 1.minute)(f: Statement => Unit) { + val startScript = "../../sbin/start-thriftserver.sh".split("/").mkString(File.separator) + val stopScript = "../../sbin/stop-thriftserver.sh".split("/").mkString(File.separator) + + val warehousePath = getTempFilePath("warehouse") + val metastorePath = getTempFilePath("metastore") + val metastoreJdbcUri = s"jdbc:derby:;databaseName=$metastorePath;create=true" + val listeningHost = "localhost" + val listeningPort = { + // Let the system to choose a random available port to avoid collision with other parallel + // builds. + val socket = new ServerSocket(0) + val port = socket.getLocalPort + socket.close() + port + } val command = - s"""$serverScript + s"""$startScript | --master local | --hiveconf hive.root.logger=INFO,console | --hiveconf ${ConfVars.METASTORECONNECTURLKEY}=$metastoreJdbcUri @@ -68,29 +68,40 @@ class HiveThriftServer2Suite extends FunSuite with Logging { | --hiveconf ${ConfVars.HIVE_SERVER2_THRIFT_PORT}=$listeningPort """.stripMargin.split("\\s+").toSeq - val serverStarted = Promise[Unit]() + val serverRunning = Promise[Unit]() val buffer = new ArrayBuffer[String]() + val LOGGING_MARK = + s"starting ${HiveThriftServer2.getClass.getCanonicalName.stripSuffix("$")}, logging to " + var logTailingProcess: Process = null + var logFilePath: String = null - def captureOutput(source: String)(line: String) { - buffer += s"$source> $line" + def captureLogOutput(line: String): Unit = { + buffer += line if (line.contains("ThriftBinaryCLIService listening on")) { - serverStarted.success(()) + serverRunning.success(()) } } - val process = Process(command).run( - ProcessLogger(captureOutput("stdout"), captureOutput("stderr"))) - - Future { - val exitValue = process.exitValue() - logInfo(s"Spark SQL Thrift server process exit value: $exitValue") + def captureThriftServerOutput(source: String)(line: String): Unit = { + if (line.startsWith(LOGGING_MARK)) { + logFilePath = line.drop(LOGGING_MARK.length).trim + // Ensure that the log file is created so that the `tail' command won't fail + Try(new File(logFilePath).createNewFile()) + logTailingProcess = Process(s"/usr/bin/env tail -f $logFilePath") + .run(ProcessLogger(captureLogOutput, _ => ())) + } } + // Resets SPARK_TESTING to avoid loading Log4J configurations in testing class paths + Process(command, None, "SPARK_TESTING" -> "0").run(ProcessLogger( + captureThriftServerOutput("stdout"), + captureThriftServerOutput("stderr"))) + val jdbcUri = s"jdbc:hive2://$listeningHost:$listeningPort/" val user = System.getProperty("user.name") try { - Await.result(serverStarted.future, timeout) + Await.result(serverRunning.future, timeout) val connection = DriverManager.getConnection(jdbcUri, user, "") val statement = connection.createStatement() @@ -122,10 +133,15 @@ class HiveThriftServer2Suite extends FunSuite with Logging { |End HiveThriftServer2Suite failure output |========================================= """.stripMargin, cause) + throw cause } finally { warehousePath.delete() 
metastorePath.delete() - process.destroy() + Process(stopScript).run().exitValue() + // The `spark-daemon.sh' script uses kill, which is not synchronous, have to wait for a while. + Thread.sleep(3.seconds.toMillis) + Option(logTailingProcess).map(_.destroy()) + Option(logFilePath).map(new File(_).delete()) } } diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala index 35e9c9939d4b7..463888551a359 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala @@ -343,6 +343,13 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "ct_case_insensitive", "database_location", "database_properties", + "date_2", + "date_3", + "date_4", + "date_comparison", + "date_join1", + "date_serde", + "date_udf", "decimal_1", "decimal_4", "decimal_join", @@ -604,8 +611,10 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "part_inherit_tbl_props", "part_inherit_tbl_props_empty", "part_inherit_tbl_props_with_star", + "partition_date", "partition_schema1", "partition_serde_format", + "partition_type_check", "partition_varchar1", "partition_wise_fileformat4", "partition_wise_fileformat5", @@ -904,6 +913,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "union7", "union8", "union9", + "union_date", "union_lateralview", "union_ppr", "union_remove_11", diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/ExtendedHiveQlParser.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/ExtendedHiveQlParser.scala index c5844e92eaaa9..430ffb29989ea 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/ExtendedHiveQlParser.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/ExtendedHiveQlParser.scala @@ -18,118 +18,50 @@ package org.apache.spark.sql.hive import scala.language.implicitConversions -import scala.util.parsing.combinator.syntactical.StandardTokenParsers -import scala.util.parsing.combinator.PackratParsers + import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.catalyst.SqlLexical +import org.apache.spark.sql.catalyst.{AbstractSparkSQLParser, SqlLexical} /** - * A parser that recognizes all HiveQL constructs together with several Spark SQL specific - * extensions like CACHE TABLE and UNCACHE TABLE. + * A parser that recognizes all HiveQL constructs together with Spark SQL specific extensions. */ -private[hive] class ExtendedHiveQlParser extends StandardTokenParsers with PackratParsers { - - def apply(input: String): LogicalPlan = { - // Special-case out set commands since the value fields can be - // complex to handle without RegexParsers. Also this approach - // is clearer for the several possible cases of set commands. 
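The hand-rolled special cases being removed here (SET parsed by string splitting before the combinator parser ran) are presumably taken over by the shared SparkSQLParser, and the HiveQl change later in this patch composes the two parsers. A schematic of that wiring, assuming only the (package-private) constructors shown in this diff:

    import org.apache.spark.sql.catalyst.SparkSQLParser

    // Generic Spark SQL statements (e.g. SET) are matched by SparkSQLParser first;
    // anything it does not recognize falls back to the HiveQL parser.
    val fallback = new ExtendedHiveQlParser
    val hqlParser = new SparkSQLParser(fallback(_))

    hqlParser("SET spark.sql.shuffle.partitions=10")  // handled by SparkSQLParser
    hqlParser("SELECT key FROM src")                  // falls through to HiveQL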
- if (input.trim.toLowerCase.startsWith("set")) { - input.trim.drop(3).split("=", 2).map(_.trim) match { - case Array("") => // "set" - SetCommand(None, None) - case Array(key) => // "set key" - SetCommand(Some(key), None) - case Array(key, value) => // "set key=value" - SetCommand(Some(key), Some(value)) - } - } else if (input.trim.startsWith("!")) { - ShellCommand(input.drop(1)) - } else { - phrase(query)(new lexical.Scanner(input)) match { - case Success(r, x) => r - case x => sys.error(x.toString) - } - } - } - - protected case class Keyword(str: String) - - protected val ADD = Keyword("ADD") - protected val AS = Keyword("AS") - protected val CACHE = Keyword("CACHE") - protected val DFS = Keyword("DFS") - protected val FILE = Keyword("FILE") - protected val JAR = Keyword("JAR") - protected val LAZY = Keyword("LAZY") - protected val SET = Keyword("SET") - protected val SOURCE = Keyword("SOURCE") - protected val TABLE = Keyword("TABLE") - protected val UNCACHE = Keyword("UNCACHE") - +private[hive] class ExtendedHiveQlParser extends AbstractSparkSQLParser { protected implicit def asParser(k: Keyword): Parser[String] = lexical.allCaseVersions(k.str).map(x => x : Parser[String]).reduce(_ | _) - protected def allCaseConverse(k: String): Parser[String] = - lexical.allCaseVersions(k).map(x => x : Parser[String]).reduce(_ | _) + protected val ADD = Keyword("ADD") + protected val DFS = Keyword("DFS") + protected val FILE = Keyword("FILE") + protected val JAR = Keyword("JAR") - protected val reservedWords = - this.getClass + private val reservedWords = + this + .getClass .getMethods .filter(_.getReturnType == classOf[Keyword]) .map(_.invoke(this).asInstanceOf[Keyword].str) override val lexical = new SqlLexical(reservedWords) - protected lazy val query: Parser[LogicalPlan] = - cache | uncache | addJar | addFile | dfs | source | hiveQl + protected lazy val start: Parser[LogicalPlan] = dfs | addJar | addFile | hiveQl protected lazy val hiveQl: Parser[LogicalPlan] = restInput ^^ { - case statement => HiveQl.createPlan(statement.trim()) - } - - // Returns the whole input string - protected lazy val wholeInput: Parser[String] = new Parser[String] { - def apply(in: Input) = - Success(in.source.toString, in.drop(in.source.length())) - } - - // Returns the rest of the input string that are not parsed yet - protected lazy val restInput: Parser[String] = new Parser[String] { - def apply(in: Input) = - Success( - in.source.subSequence(in.offset, in.source.length).toString, - in.drop(in.source.length())) - } - - protected lazy val cache: Parser[LogicalPlan] = - CACHE ~> opt(LAZY) ~ (TABLE ~> ident) ~ opt(AS ~> hiveQl) ^^ { - case isLazy ~ tableName ~ plan => - CacheTableCommand(tableName, plan, isLazy.isDefined) - } - - protected lazy val uncache: Parser[LogicalPlan] = - UNCACHE ~ TABLE ~> ident ^^ { - case tableName => UncacheTableCommand(tableName) + case statement => HiveQl.createPlan(statement.trim) } - protected lazy val addJar: Parser[LogicalPlan] = - ADD ~ JAR ~> restInput ^^ { - case jar => AddJar(jar.trim()) + protected lazy val dfs: Parser[LogicalPlan] = + DFS ~> wholeInput ^^ { + case command => NativeCommand(command.trim) } - protected lazy val addFile: Parser[LogicalPlan] = + private lazy val addFile: Parser[LogicalPlan] = ADD ~ FILE ~> restInput ^^ { - case file => AddFile(file.trim()) + case input => AddFile(input.trim) } - protected lazy val dfs: Parser[LogicalPlan] = - DFS ~> wholeInput ^^ { - case command => NativeCommand(command.trim()) - } - - protected lazy val source: Parser[LogicalPlan] 
= - SOURCE ~> restInput ^^ { - case file => SourceCommand(file.trim()) + private lazy val addJar: Parser[LogicalPlan] = + ADD ~ JAR ~> restInput ^^ { + case input => AddJar(input.trim) } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index fad3b39f81413..8b5a90159e1bb 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.hive import java.io.{BufferedReader, File, InputStreamReader, PrintStream} -import java.sql.Timestamp +import java.sql.{Date, Timestamp} import java.util.{ArrayList => JArrayList} import scala.collection.JavaConversions._ @@ -34,6 +34,7 @@ import org.apache.hadoop.hive.ql.processors._ import org.apache.hadoop.hive.ql.session.SessionState import org.apache.hadoop.hive.ql.stats.StatsSetupConst import org.apache.hadoop.hive.serde2.io.TimestampWritable +import org.apache.hadoop.hive.serde2.io.DateWritable import org.apache.spark.SparkContext import org.apache.spark.rdd.RDD @@ -357,7 +358,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { protected val primitiveTypes = Seq(StringType, IntegerType, LongType, DoubleType, FloatType, BooleanType, ByteType, - ShortType, DecimalType, TimestampType, BinaryType) + ShortType, DecimalType, DateType, TimestampType, BinaryType) protected[sql] def toHiveString(a: (Any, DataType)): String = a match { case (struct: Row, StructType(fields)) => @@ -372,6 +373,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { toHiveStructString((key, kType)) + ":" + toHiveStructString((value, vType)) }.toSeq.sorted.mkString("{", ",", "}") case (null, _) => "NULL" + case (d: Date, DateType) => new DateWritable(d).toString case (t: Timestamp, TimestampType) => new TimestampWritable(t).toString case (bin: Array[Byte], BinaryType) => new String(bin, "UTF-8") case (other, tpe) if primitiveTypes contains tpe => other.toString diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala index d633c42c6bd67..1977618b4c9f2 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala @@ -39,6 +39,7 @@ private[hive] trait HiveInspectors { case c: Class[_] if c == classOf[hiveIo.HiveDecimalWritable] => DecimalType case c: Class[_] if c == classOf[hiveIo.ByteWritable] => ByteType case c: Class[_] if c == classOf[hiveIo.ShortWritable] => ShortType + case c: Class[_] if c == classOf[hiveIo.DateWritable] => DateType case c: Class[_] if c == classOf[hiveIo.TimestampWritable] => TimestampType case c: Class[_] if c == classOf[hadoopIo.Text] => StringType case c: Class[_] if c == classOf[hadoopIo.IntWritable] => IntegerType @@ -49,6 +50,7 @@ private[hive] trait HiveInspectors { // java class case c: Class[_] if c == classOf[java.lang.String] => StringType + case c: Class[_] if c == classOf[java.sql.Date] => DateType case c: Class[_] if c == classOf[java.sql.Timestamp] => TimestampType case c: Class[_] if c == classOf[HiveDecimal] => DecimalType case c: Class[_] if c == classOf[java.math.BigDecimal] => DecimalType @@ -93,6 +95,7 @@ private[hive] trait HiveInspectors { System.arraycopy(b.getBytes(), 0, bytes, 0, b.getLength) bytes } + case d: hiveIo.DateWritable => d.get case t: hiveIo.TimestampWritable => 
t.getTimestamp case b: hiveIo.HiveDecimalWritable => BigDecimal(b.getHiveDecimal().bigDecimalValue()) case list: java.util.List[_] => list.map(unwrap) @@ -108,6 +111,7 @@ private[hive] trait HiveInspectors { case str: String => str case p: java.math.BigDecimal => p case p: Array[Byte] => p + case p: java.sql.Date => p case p: java.sql.Timestamp => p } @@ -147,6 +151,7 @@ private[hive] trait HiveInspectors { case l: Byte => l: java.lang.Byte case b: BigDecimal => new HiveDecimal(b.underlying()) case b: Array[Byte] => b + case d: java.sql.Date => d case t: java.sql.Timestamp => t case s: Seq[_] => seqAsJavaList(s.map(wrap)) case m: Map[_,_] => @@ -173,6 +178,7 @@ private[hive] trait HiveInspectors { case ByteType => PrimitiveObjectInspectorFactory.javaByteObjectInspector case NullType => PrimitiveObjectInspectorFactory.javaVoidObjectInspector case BinaryType => PrimitiveObjectInspectorFactory.javaByteArrayObjectInspector + case DateType => PrimitiveObjectInspectorFactory.javaDateObjectInspector case TimestampType => PrimitiveObjectInspectorFactory.javaTimestampObjectInspector case DecimalType => PrimitiveObjectInspectorFactory.javaHiveDecimalObjectInspector case StructType(fields) => @@ -211,6 +217,8 @@ private[hive] trait HiveInspectors { case _: JavaBinaryObjectInspector => BinaryType case _: WritableHiveDecimalObjectInspector => DecimalType case _: JavaHiveDecimalObjectInspector => DecimalType + case _: WritableDateObjectInspector => DateType + case _: JavaDateObjectInspector => DateType case _: WritableTimestampObjectInspector => TimestampType case _: JavaTimestampObjectInspector => TimestampType case _: WritableVoidObjectInspector => NullType @@ -238,6 +246,7 @@ private[hive] trait HiveInspectors { case ShortType => shortTypeInfo case StringType => stringTypeInfo case DecimalType => decimalTypeInfo + case DateType => dateTypeInfo case TimestampType => timestampTypeInfo case NullType => voidTypeInfo } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index cc0605b0adb35..75a19656af110 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -19,31 +19,28 @@ package org.apache.spark.sql.hive import scala.util.parsing.combinator.RegexParsers -import org.apache.hadoop.hive.metastore.api.{FieldSchema, StorageDescriptor, SerDeInfo} -import org.apache.hadoop.hive.metastore.api.{Table => TTable, Partition => TPartition} +import org.apache.hadoop.hive.metastore.api.{FieldSchema, SerDeInfo, StorageDescriptor, Partition => TPartition, Table => TTable} import org.apache.hadoop.hive.ql.metadata.{Hive, Partition, Table} import org.apache.hadoop.hive.ql.plan.TableDesc import org.apache.hadoop.hive.ql.stats.StatsSetupConst import org.apache.hadoop.hive.serde2.Deserializer -import org.apache.spark.annotation.DeveloperApi import org.apache.spark.Logging +import org.apache.spark.annotation.DeveloperApi import org.apache.spark.sql.SQLContext -import org.apache.spark.sql.catalyst.analysis.{EliminateAnalysisOperators, Catalog} +import org.apache.spark.sql.catalyst.analysis.Catalog import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.catalyst.types._ -import org.apache.spark.sql.columnar.InMemoryRelation 
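The HiveInspectors additions above teach wrap/unwrap about java.sql.Date and Hive's DateWritable; the conversion they perform amounts to the round trip below (a sketch using the same two classes):

    import java.sql.Date
    import org.apache.hadoop.hive.serde2.io.DateWritable

    val catalystValue: Date = Date.valueOf("2011-01-01")
    val hiveValue = new DateWritable(catalystValue)  // wrap: Catalyst value -> Hive writable
    val back: Date = hiveValue.get                   // unwrap, as in `case d: hiveIo.DateWritable => d.get`
    // back represents the same calendar day, 2011-01-01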
-import org.apache.spark.sql.hive.execution.HiveTableScan import org.apache.spark.util.Utils /* Implicit conversions */ import scala.collection.JavaConversions._ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with Logging { - import HiveMetastoreTypes._ + import org.apache.spark.sql.hive.HiveMetastoreTypes._ /** Connection to hive metastore. Usages should lock on `this`. */ protected[hive] val client = Hive.get(hive.hiveconf) @@ -137,10 +134,8 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with def castChildOutput(p: InsertIntoTable, table: MetastoreRelation, child: LogicalPlan) = { val childOutputDataTypes = child.output.map(_.dataType) - // Only check attributes, not partitionKeys since they are always strings. - // TODO: Fully support inserting into partitioned tables. val tableOutputDataTypes = - table.attributes.map(_.dataType) ++ table.partitionKeys.map(_.dataType) + (table.attributes ++ table.partitionKeys).take(child.output.length).map(_.dataType) if (childOutputDataTypes == tableOutputDataTypes) { p @@ -191,6 +186,7 @@ object HiveMetastoreTypes extends RegexParsers { "binary" ^^^ BinaryType | "boolean" ^^^ BooleanType | "decimal" ^^^ DecimalType | + "date" ^^^ DateType | "timestamp" ^^^ TimestampType | "varchar\\((\\d+)\\)".r ^^^ StringType @@ -240,6 +236,7 @@ object HiveMetastoreTypes extends RegexParsers { case LongType => "bigint" case BinaryType => "binary" case BooleanType => "boolean" + case DateType => "date" case DecimalType => "decimal" case TimestampType => "timestamp" case NullType => "void" @@ -308,7 +305,7 @@ private[hive] case class MetastoreRelation val partitionKeys = hiveQlTable.getPartitionKeys.map(_.toAttribute) /** Non-partitionKey attributes */ - val attributes = table.getSd.getCols.map(_.toAttribute) + val attributes = hiveQlTable.getCols.map(_.toAttribute) val output = attributes ++ partitionKeys } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index 32c9175f181bb..2b599157d15d3 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -17,10 +17,13 @@ package org.apache.spark.sql.hive +import java.sql.Date + import org.apache.hadoop.hive.ql.lib.Node import org.apache.hadoop.hive.ql.parse._ import org.apache.hadoop.hive.ql.plan.PlanUtils +import org.apache.spark.sql.catalyst.SparkSQLParser import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans._ @@ -38,10 +41,6 @@ import scala.collection.JavaConversions._ */ private[hive] case object NativePlaceholder extends Command -private[hive] case class ShellCommand(cmd: String) extends Command - -private[hive] case class SourceCommand(filePath: String) extends Command - private[hive] case class AddFile(filePath: String) extends Command private[hive] case class AddJar(path: String) extends Command @@ -126,9 +125,11 @@ private[hive] object HiveQl { "TOK_CREATETABLE", "TOK_DESCTABLE" ) ++ nativeCommands - - // It parses hive sql query along with with several Spark SQL specific extensions - protected val hiveSqlParser = new ExtendedHiveQlParser + + protected val hqlParser = { + val fallback = new ExtendedHiveQlParser + new SparkSQLParser(fallback(_)) + } /** * A set of implicit transformations that allow Hive ASTNodes to be rewritten by transformations @@ -218,7 +219,7 @@ private[hive] object 
HiveQl { def getAst(sql: String): ASTNode = ParseUtils.findRootNonNullToken((new ParseDriver).parse(sql)) /** Returns a LogicalPlan for a given HiveQL string. */ - def parseSql(sql: String): LogicalPlan = hiveSqlParser(sql) + def parseSql(sql: String): LogicalPlan = hqlParser(sql) /** Creates LogicalPlan for a given HiveQL string. */ def createPlan(sql: String) = { @@ -318,6 +319,7 @@ private[hive] object HiveQl { case Token("TOK_STRING", Nil) => StringType case Token("TOK_FLOAT", Nil) => FloatType case Token("TOK_DOUBLE", Nil) => DoubleType + case Token("TOK_DATE", Nil) => DateType case Token("TOK_TIMESTAMP", Nil) => TimestampType case Token("TOK_BINARY", Nil) => BinaryType case Token("TOK_LIST", elementType :: Nil) => ArrayType(nodeToDataType(elementType)) @@ -639,7 +641,7 @@ private[hive] object HiveQl { def nodeToRelation(node: Node): LogicalPlan = node match { case Token("TOK_SUBQUERY", query :: Token(alias, Nil) :: Nil) => - Subquery(alias, nodeToPlan(query)) + Subquery(cleanIdentifier(alias), nodeToPlan(query)) case Token(laterViewToken(isOuter), selectClause :: relationClause :: Nil) => val Token("TOK_SELECT", @@ -925,6 +927,8 @@ private[hive] object HiveQl { Cast(nodeToExpr(arg), DecimalType) case Token("TOK_FUNCTION", Token("TOK_TIMESTAMP", Nil) :: arg :: Nil) => Cast(nodeToExpr(arg), TimestampType) + case Token("TOK_FUNCTION", Token("TOK_DATE", Nil) :: arg :: Nil) => + Cast(nodeToExpr(arg), DateType) /* Arithmetic */ case Token("-", child :: Nil) => UnaryMinus(nodeToExpr(child)) @@ -1048,6 +1052,9 @@ private[hive] object HiveQl { case ast: ASTNode if ast.getType == HiveParser.StringLiteral => Literal(BaseSemanticAnalyzer.unescapeSQLString(ast.getText)) + case ast: ASTNode if ast.getType == HiveParser.TOK_DATELITERAL => + Literal(Date.valueOf(ast.getText.substring(1, ast.getText.length - 1))) + case a: ASTNode => throw new NotImplementedError( s"""No parse rules for ASTNode type: ${a.getType}, text: ${a.getText} : diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala index 508d8239c7628..5c66322f1ed99 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala @@ -167,10 +167,10 @@ private[hive] trait HiveStrategies { database.get, tableName, query, - InsertIntoHiveTable(_: MetastoreRelation, - Map(), - query, - true)(hiveContext)) :: Nil + InsertIntoHiveTable(_: MetastoreRelation, + Map(), + query, + overwrite = true)(hiveContext)) :: Nil case _ => Nil } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala index 84fafcde63d05..0de29d5cffd0e 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.hive import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{Path, PathFilter} +import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.metastore.api.hive_metastoreConstants._ import org.apache.hadoop.hive.ql.exec.Utilities import org.apache.hadoop.hive.ql.metadata.{Partition => HivePartition, Table => HiveTable} @@ -52,7 +53,8 @@ private[hive] class HadoopTableReader( @transient attributes: Seq[Attribute], @transient relation: MetastoreRelation, - @transient sc: HiveContext) + @transient sc: HiveContext, + 
@transient hiveExtraConf: HiveConf) extends TableReader { // Choose the minimum number of splits. If mapred.map.tasks is set, then use that unless @@ -63,7 +65,7 @@ class HadoopTableReader( // TODO: set aws s3 credentials. private val _broadcastedHiveConf = - sc.sparkContext.broadcast(new SerializableWritable(sc.hiveconf)) + sc.sparkContext.broadcast(new SerializableWritable(hiveExtraConf)) def broadcastedHiveConf = _broadcastedHiveConf diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala index a4354c1379c63..9a9e2eda6bcd4 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala @@ -31,6 +31,7 @@ import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe import org.apache.hadoop.hive.serde2.avro.AvroSerDe import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.util.Utils import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.plans.logical.{CacheTableCommand, LogicalPlan, NativeCommand} import org.apache.spark.sql.catalyst.util._ @@ -71,11 +72,14 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { setConf("javax.jdo.option.ConnectionURL", s"jdbc:derby:;databaseName=$metastorePath;create=true") setConf("hive.metastore.warehouse.dir", warehousePath) + Utils.registerShutdownDeleteDir(new File(warehousePath)) + Utils.registerShutdownDeleteDir(new File(metastorePath)) } val testTempDir = File.createTempFile("testTempFiles", "spark.hive.tmp") testTempDir.delete() testTempDir.mkdir() + Utils.registerShutdownDeleteDir(testTempDir) // For some hive test case which contain ${system:test.tmp.dir} System.setProperty("test.tmp.dir", testTempDir.getCanonicalPath) @@ -121,8 +125,7 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { val hiveFilesTemp = File.createTempFile("catalystHiveFiles", "") hiveFilesTemp.delete() hiveFilesTemp.mkdir() - hiveFilesTemp.deleteOnExit() - + Utils.registerShutdownDeleteDir(hiveFilesTemp) val inRepoTests = if (System.getProperty("user.dir").endsWith("sql" + File.separator + "hive")) { new File("src" + File.separator + "test" + File.separator + "resources" + File.separator) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala index 577ca928b43b6..5b83b77d80a22 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala @@ -64,8 +64,14 @@ case class HiveTableScan( BindReferences.bindReference(pred, relation.partitionKeys) } + // Create a local copy of hiveconf,so that scan specific modifications should not impact + // other queries @transient - private[this] val hadoopReader = new HadoopTableReader(attributes, relation, context) + private[this] val hiveExtraConf = new HiveConf(context.hiveconf) + + @transient + private[this] val hadoopReader = + new HadoopTableReader(attributes, relation, context, hiveExtraConf) private[this] def castFromString(value: String, dataType: DataType) = { Cast(Literal(value), dataType).eval(null) @@ -80,10 +86,14 @@ case class HiveTableScan( ColumnProjectionUtils.appendReadColumnIDs(hiveConf, neededColumnIDs) ColumnProjectionUtils.appendReadColumnNames(hiveConf, attributes.map(_.name)) + val tableDesc = relation.tableDesc + val deserializer = 
tableDesc.getDeserializerClass.newInstance + deserializer.initialize(hiveConf, tableDesc.getProperties) + // Specifies types and object inspectors of columns to be scanned. val structOI = ObjectInspectorUtils .getStandardObjectInspector( - relation.tableDesc.getDeserializer.getObjectInspector, + deserializer.getObjectInspector, ObjectInspectorCopyOption.JAVA) .asInstanceOf[StructObjectInspector] @@ -97,7 +107,7 @@ case class HiveTableScan( hiveConf.set(serdeConstants.LIST_COLUMNS, relation.attributes.map(_.name).mkString(",")) } - addColumnMetadataToConf(context.hiveconf) + addColumnMetadataToConf(hiveExtraConf) /** * Prunes partitions not involve the query plan. diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala index f8b4e898ec41d..f0785d8882636 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala @@ -69,33 +69,36 @@ case class InsertIntoHiveTable( * Wraps with Hive types based on object inspector. * TODO: Consolidate all hive OI/data interface code. */ - protected def wrap(a: (Any, ObjectInspector)): Any = a match { - case (s: String, oi: JavaHiveVarcharObjectInspector) => - new HiveVarchar(s, s.size) - - case (bd: BigDecimal, oi: JavaHiveDecimalObjectInspector) => - new HiveDecimal(bd.underlying()) - - case (row: Row, oi: StandardStructObjectInspector) => - val struct = oi.create() - row.zip(oi.getAllStructFieldRefs: Seq[StructField]).foreach { - case (data, field) => - oi.setStructFieldData(struct, field, wrap(data, field.getFieldObjectInspector)) + protected def wrapperFor(oi: ObjectInspector): Any => Any = oi match { + case _: JavaHiveVarcharObjectInspector => + (o: Any) => new HiveVarchar(o.asInstanceOf[String], o.asInstanceOf[String].size) + + case _: JavaHiveDecimalObjectInspector => + (o: Any) => new HiveDecimal(o.asInstanceOf[BigDecimal].underlying()) + + case soi: StandardStructObjectInspector => + val wrappers = soi.getAllStructFieldRefs.map(ref => wrapperFor(ref.getFieldObjectInspector)) + (o: Any) => { + val struct = soi.create() + (soi.getAllStructFieldRefs, wrappers, o.asInstanceOf[Row]).zipped.foreach { + (field, wrapper, data) => soi.setStructFieldData(struct, field, wrapper(data)) + } + struct } - struct - case (s: Seq[_], oi: ListObjectInspector) => - val wrappedSeq = s.map(wrap(_, oi.getListElementObjectInspector)) - seqAsJavaList(wrappedSeq) + case loi: ListObjectInspector => + val wrapper = wrapperFor(loi.getListElementObjectInspector) + (o: Any) => seqAsJavaList(o.asInstanceOf[Seq[_]].map(wrapper)) - case (m: Map[_, _], oi: MapObjectInspector) => - val keyOi = oi.getMapKeyObjectInspector - val valueOi = oi.getMapValueObjectInspector - val wrappedMap = m.map { case (key, value) => wrap(key, keyOi) -> wrap(value, valueOi) } - mapAsJavaMap(wrappedMap) + case moi: MapObjectInspector => + val keyWrapper = wrapperFor(moi.getMapKeyObjectInspector) + val valueWrapper = wrapperFor(moi.getMapValueObjectInspector) + (o: Any) => mapAsJavaMap(o.asInstanceOf[Map[_, _]].map { case (key, value) => + keyWrapper(key) -> valueWrapper(value) + }) - case (obj, _) => - obj + case _ => + identity[Any] } def saveAsHiveFile( @@ -103,7 +106,7 @@ case class InsertIntoHiveTable( valueClass: Class[_], fileSinkConf: FileSinkDesc, conf: SerializableWritable[JobConf], - writerContainer: SparkHiveWriterContainer) { + 
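The wrapperFor rewrite above replaces a per-row pattern match over (value, ObjectInspector) pairs with converters that are resolved once per column and then applied as plain Any => Any functions inside the row loop. A stripped-down illustration of that design choice, independent of Hive's object inspectors:

    // Resolve a converter per column up front, then apply it per row.
    sealed trait ColType
    case object StringCol extends ColType
    case object DecimalCol extends ColType

    def converterFor(t: ColType): Any => Any = t match {
      case StringCol  => (v: Any) => v.asInstanceOf[String].trim
      case DecimalCol => (v: Any) => v.asInstanceOf[BigDecimal].underlying()
    }

    val schema = Array[ColType](StringCol, DecimalCol)
    val converters = schema.map(converterFor)  // resolved once, outside the row loop

    def convertRow(row: Array[Any]): Array[Any] =
      Array.tabulate(row.length) { i =>
        if (row(i) == null) null else converters(i)(row(i))
      }

The per-row cost is then a null check plus one closure call, which is what the rewritten writeToFile loop further down relies on.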
writerContainer: SparkHiveWriterContainer): Unit = { assert(valueClass != null, "Output value class not set") conf.value.setOutputValueClass(valueClass) @@ -122,7 +125,7 @@ case class InsertIntoHiveTable( writerContainer.commitJob() // Note that this function is executed on executor side - def writeToFile(context: TaskContext, iterator: Iterator[Row]) { + def writeToFile(context: TaskContext, iterator: Iterator[Row]): Unit = { val serializer = newSerializer(fileSinkConf.getTableInfo) val standardOI = ObjectInspectorUtils .getStandardObjectInspector( @@ -131,6 +134,7 @@ case class InsertIntoHiveTable( .asInstanceOf[StructObjectInspector] val fieldOIs = standardOI.getAllStructFieldRefs.map(_.getFieldObjectInspector).toArray + val wrappers = fieldOIs.map(wrapperFor) val outputData = new Array[Any](fieldOIs.length) // Hadoop wants a 32-bit task attempt ID, so if ours is bigger than Int.MaxValue, roll it @@ -141,13 +145,13 @@ case class InsertIntoHiveTable( iterator.foreach { row => var i = 0 while (i < fieldOIs.length) { - // TODO (lian) avoid per row dynamic dispatching and pattern matching cost in `wrap` - outputData(i) = wrap(row(i), fieldOIs(i)) + outputData(i) = if (row.isNullAt(i)) null else wrappers(i)(row(i)) i += 1 } - val writer = writerContainer.getLocalFileWriter(row) - writer.write(serializer.serialize(outputData, standardOI)) + writerContainer + .getLocalFileWriter(row) + .write(serializer.serialize(outputData, standardOI)) } writerContainer.close() @@ -207,7 +211,7 @@ case class InsertIntoHiveTable( // Report error if any static partition appears after a dynamic partition val isDynamic = partitionColumnNames.map(partitionSpec(_).isEmpty) - isDynamic.init.zip(isDynamic.tail).find(_ == (true, false)).foreach { _ => + if (isDynamic.init.zip(isDynamic.tail).contains((true, false))) { throw new SparkException(ErrorMsg.PARTITION_DYN_STA_ORDER.getMsg) } } diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFIntegerToString.java b/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFIntegerToString.java new file mode 100644 index 0000000000000..6c4f378bc5471 --- /dev/null +++ b/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFIntegerToString.java @@ -0,0 +1,26 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.hive.execution; + +import org.apache.hadoop.hive.ql.exec.UDF; + +public class UDFIntegerToString extends UDF { + public String evaluate(Integer i) { + return i.toString(); + } +} diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFListListInt.java b/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFListListInt.java new file mode 100644 index 0000000000000..d2d39a8c4dc28 --- /dev/null +++ b/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFListListInt.java @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.execution; + +import org.apache.hadoop.hive.ql.exec.UDF; + +import java.util.List; + +public class UDFListListInt extends UDF { + /** + * + * @param obj + * SQL schema: array> + * Java Type: List> + * @return + */ + public long evaluate(Object obj) { + if (obj == null) { + return 0l; + } + List listList = (List) obj; + long retVal = 0; + for (List aList : listList) { + @SuppressWarnings("unchecked") + List list = (List) aList; + @SuppressWarnings("unchecked") + Integer someInt = (Integer) list.get(1); + try { + retVal += (long) (someInt.intValue()); + } catch (NullPointerException e) { + System.out.println(e); + } + } + return retVal; + } +} diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFListString.java b/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFListString.java new file mode 100644 index 0000000000000..efd34df293c88 --- /dev/null +++ b/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFListString.java @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package org.apache.spark.sql.hive.execution;
+
+import org.apache.hadoop.hive.ql.exec.UDF;
+
+import java.util.List;
+import org.apache.commons.lang.StringUtils;
+
+public class UDFListString extends UDF {
+
+ public String evaluate(Object a) {
+ if (a == null) {
+ return null;
+ }
+ @SuppressWarnings("unchecked")
+ List<String> s = (List<String>) a;
+
+ return StringUtils.join(s, ',');
+ }
+
+
+}
diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFStringString.java b/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFStringString.java
new file mode 100644
index 0000000000000..a369188d471e8
--- /dev/null
+++ b/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFStringString.java
@@ -0,0 +1,26 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive.execution;
+
+import org.apache.hadoop.hive.ql.exec.UDF;
+
+public class UDFStringString extends UDF {
+ public String evaluate(String s1, String s2) {
+ return s1 + " " + s2;
+ }
+}
diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFTwoListList.java b/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFTwoListList.java
new file mode 100644
index 0000000000000..0165591a7ce78
--- /dev/null
+++ b/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFTwoListList.java
@@ -0,0 +1,28 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +package org.apache.spark.sql.hive.execution; + +import org.apache.hadoop.hive.ql.exec.UDF; + +public class UDFTwoListList extends UDF { + public String evaluate(Object o1, Object o2) { + UDFListListInt udf = new UDFListListInt(); + + return String.format("%s, %s", udf.evaluate(o1), udf.evaluate(o2)); + } +} diff --git a/sql/hive/src/test/resources/golden/date_1-0-23edf29bf7376c70d5ecf12720f4b1eb b/sql/hive/src/test/resources/golden/date_1-0-23edf29bf7376c70d5ecf12720f4b1eb new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/date_1-1-4ebe3571c13a8b0c03096fbd972b7f1b b/sql/hive/src/test/resources/golden/date_1-1-4ebe3571c13a8b0c03096fbd972b7f1b new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/date_1-10-d964bec7e5632091ab5cb6f6786dbbf9 b/sql/hive/src/test/resources/golden/date_1-10-d964bec7e5632091ab5cb6f6786dbbf9 new file mode 100644 index 0000000000000..8fb5edae63c6f --- /dev/null +++ b/sql/hive/src/test/resources/golden/date_1-10-d964bec7e5632091ab5cb6f6786dbbf9 @@ -0,0 +1 @@ +2011-01-01 1 diff --git a/sql/hive/src/test/resources/golden/date_1-11-480c5f024a28232b7857be327c992509 b/sql/hive/src/test/resources/golden/date_1-11-480c5f024a28232b7857be327c992509 new file mode 100644 index 0000000000000..5a368ab170261 --- /dev/null +++ b/sql/hive/src/test/resources/golden/date_1-11-480c5f024a28232b7857be327c992509 @@ -0,0 +1 @@ +2012-01-01 2011-01-01 2011-01-01 00:00:00 2011-01-01 2011-01-01 diff --git a/sql/hive/src/test/resources/golden/date_1-12-4c0ed7fcb75770d8790575b586bf14f4 b/sql/hive/src/test/resources/golden/date_1-12-4c0ed7fcb75770d8790575b586bf14f4 new file mode 100644 index 0000000000000..edb4b1f84001b --- /dev/null +++ b/sql/hive/src/test/resources/golden/date_1-12-4c0ed7fcb75770d8790575b586bf14f4 @@ -0,0 +1 @@ +NULL NULL NULL NULL NULL NULL NULL diff --git a/sql/hive/src/test/resources/golden/date_1-13-44fc74c1993062c0a9522199ff27fea b/sql/hive/src/test/resources/golden/date_1-13-44fc74c1993062c0a9522199ff27fea new file mode 100644 index 0000000000000..2af0b9ed3a68c --- /dev/null +++ b/sql/hive/src/test/resources/golden/date_1-13-44fc74c1993062c0a9522199ff27fea @@ -0,0 +1 @@ +true true true true true true true true true true diff --git a/sql/hive/src/test/resources/golden/date_1-14-4855a66124b16d1d0d003235995ac06b b/sql/hive/src/test/resources/golden/date_1-14-4855a66124b16d1d0d003235995ac06b new file mode 100644 index 0000000000000..d8dfbf60007bd --- /dev/null +++ b/sql/hive/src/test/resources/golden/date_1-14-4855a66124b16d1d0d003235995ac06b @@ -0,0 +1 @@ +2001-01-28 2001-02-28 2001-03-28 2001-04-28 2001-05-28 2001-06-28 2001-07-28 2001-08-28 2001-09-28 2001-10-28 2001-11-28 2001-12-28 diff --git a/sql/hive/src/test/resources/golden/date_1-15-8bc190dba0f641840b5e1e198a14c55b b/sql/hive/src/test/resources/golden/date_1-15-8bc190dba0f641840b5e1e198a14c55b new file mode 100644 index 0000000000000..4f6a1bc4273e0 --- /dev/null +++ b/sql/hive/src/test/resources/golden/date_1-15-8bc190dba0f641840b5e1e198a14c55b @@ -0,0 +1 @@ +true true true true true true true true true true true true diff --git a/sql/hive/src/test/resources/golden/date_1-16-23edf29bf7376c70d5ecf12720f4b1eb b/sql/hive/src/test/resources/golden/date_1-16-23edf29bf7376c70d5ecf12720f4b1eb new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/date_1-2-abdce0c0d14d3fc7441b7c134b02f99a b/sql/hive/src/test/resources/golden/date_1-2-abdce0c0d14d3fc7441b7c134b02f99a 
new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/date_1-3-df16364a220ff96a6ea1cd478cbc1d0b b/sql/hive/src/test/resources/golden/date_1-3-df16364a220ff96a6ea1cd478cbc1d0b new file mode 100644 index 0000000000000..963bc42fdee07 --- /dev/null +++ b/sql/hive/src/test/resources/golden/date_1-3-df16364a220ff96a6ea1cd478cbc1d0b @@ -0,0 +1 @@ +2011-01-01 diff --git a/sql/hive/src/test/resources/golden/date_1-4-d964bec7e5632091ab5cb6f6786dbbf9 b/sql/hive/src/test/resources/golden/date_1-4-d964bec7e5632091ab5cb6f6786dbbf9 new file mode 100644 index 0000000000000..8fb5edae63c6f --- /dev/null +++ b/sql/hive/src/test/resources/golden/date_1-4-d964bec7e5632091ab5cb6f6786dbbf9 @@ -0,0 +1 @@ +2011-01-01 1 diff --git a/sql/hive/src/test/resources/golden/date_1-5-5e70fc74158fbfca38134174360de12d b/sql/hive/src/test/resources/golden/date_1-5-5e70fc74158fbfca38134174360de12d new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/date_1-6-df16364a220ff96a6ea1cd478cbc1d0b b/sql/hive/src/test/resources/golden/date_1-6-df16364a220ff96a6ea1cd478cbc1d0b new file mode 100644 index 0000000000000..963bc42fdee07 --- /dev/null +++ b/sql/hive/src/test/resources/golden/date_1-6-df16364a220ff96a6ea1cd478cbc1d0b @@ -0,0 +1 @@ +2011-01-01 diff --git a/sql/hive/src/test/resources/golden/date_1-7-d964bec7e5632091ab5cb6f6786dbbf9 b/sql/hive/src/test/resources/golden/date_1-7-d964bec7e5632091ab5cb6f6786dbbf9 new file mode 100644 index 0000000000000..8fb5edae63c6f --- /dev/null +++ b/sql/hive/src/test/resources/golden/date_1-7-d964bec7e5632091ab5cb6f6786dbbf9 @@ -0,0 +1 @@ +2011-01-01 1 diff --git a/sql/hive/src/test/resources/golden/date_1-8-1d5c58095cd52ea539d869f2ab1ab67d b/sql/hive/src/test/resources/golden/date_1-8-1d5c58095cd52ea539d869f2ab1ab67d new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/date_1-9-df16364a220ff96a6ea1cd478cbc1d0b b/sql/hive/src/test/resources/golden/date_1-9-df16364a220ff96a6ea1cd478cbc1d0b new file mode 100644 index 0000000000000..963bc42fdee07 --- /dev/null +++ b/sql/hive/src/test/resources/golden/date_1-9-df16364a220ff96a6ea1cd478cbc1d0b @@ -0,0 +1 @@ +2011-01-01 diff --git a/sql/hive/src/test/resources/golden/date_2-3-eedb73e0a622c2ab760b524f395dd4ba b/sql/hive/src/test/resources/golden/date_2-3-eedb73e0a622c2ab760b524f395dd4ba new file mode 100644 index 0000000000000..db973ab292d5b --- /dev/null +++ b/sql/hive/src/test/resources/golden/date_2-3-eedb73e0a622c2ab760b524f395dd4ba @@ -0,0 +1,137 @@ +2010-10-20 7291 +2010-10-20 3198 +2010-10-20 3014 +2010-10-20 2630 +2010-10-20 1610 +2010-10-20 1599 +2010-10-20 1531 +2010-10-20 1142 +2010-10-20 1064 +2010-10-20 897 +2010-10-20 361 +2010-10-21 7291 +2010-10-21 3198 +2010-10-21 3014 +2010-10-21 2646 +2010-10-21 2630 +2010-10-21 1610 +2010-10-21 1599 +2010-10-21 1531 +2010-10-21 1142 +2010-10-21 1064 +2010-10-21 897 +2010-10-21 361 +2010-10-22 3198 +2010-10-22 3014 +2010-10-22 2646 +2010-10-22 2630 +2010-10-22 1610 +2010-10-22 1599 +2010-10-22 1531 +2010-10-22 1142 +2010-10-22 1064 +2010-10-22 897 +2010-10-22 361 +2010-10-23 7274 +2010-10-23 5917 +2010-10-23 5904 +2010-10-23 5832 +2010-10-23 3171 +2010-10-23 3085 +2010-10-23 2932 +2010-10-23 1805 +2010-10-23 650 +2010-10-23 426 +2010-10-23 384 +2010-10-23 272 +2010-10-24 7282 +2010-10-24 3198 +2010-10-24 3014 +2010-10-24 2646 +2010-10-24 2630 +2010-10-24 2571 +2010-10-24 2254 +2010-10-24 1610 +2010-10-24 1599 +2010-10-24 1531 +2010-10-24 897 
+2010-10-24 361 +2010-10-25 7291 +2010-10-25 3198 +2010-10-25 3014 +2010-10-25 2646 +2010-10-25 2630 +2010-10-25 1610 +2010-10-25 1599 +2010-10-25 1531 +2010-10-25 1142 +2010-10-25 1064 +2010-10-25 897 +2010-10-25 361 +2010-10-26 7291 +2010-10-26 3198 +2010-10-26 3014 +2010-10-26 2662 +2010-10-26 2646 +2010-10-26 2630 +2010-10-26 1610 +2010-10-26 1599 +2010-10-26 1531 +2010-10-26 1142 +2010-10-26 1064 +2010-10-26 897 +2010-10-26 361 +2010-10-27 7291 +2010-10-27 3198 +2010-10-27 3014 +2010-10-27 2630 +2010-10-27 1610 +2010-10-27 1599 +2010-10-27 1531 +2010-10-27 1142 +2010-10-27 1064 +2010-10-27 897 +2010-10-27 361 +2010-10-28 7291 +2010-10-28 3198 +2010-10-28 3014 +2010-10-28 2646 +2010-10-28 2630 +2010-10-28 1610 +2010-10-28 1599 +2010-10-28 1531 +2010-10-28 1142 +2010-10-28 1064 +2010-10-28 897 +2010-10-28 361 +2010-10-29 7291 +2010-10-29 3198 +2010-10-29 3014 +2010-10-29 2646 +2010-10-29 2630 +2010-10-29 1610 +2010-10-29 1599 +2010-10-29 1531 +2010-10-29 1142 +2010-10-29 1064 +2010-10-29 897 +2010-10-29 361 +2010-10-30 5917 +2010-10-30 5904 +2010-10-30 3171 +2010-10-30 3085 +2010-10-30 2932 +2010-10-30 2018 +2010-10-30 1805 +2010-10-30 650 +2010-10-30 426 +2010-10-30 384 +2010-10-30 272 +2010-10-31 7282 +2010-10-31 3198 +2010-10-31 2571 +2010-10-31 1610 +2010-10-31 1599 +2010-10-31 1531 +2010-10-31 897 +2010-10-31 361 diff --git a/sql/hive/src/test/resources/golden/date_2-4-3618dfde8da7c26f03bca72970db9ef7 b/sql/hive/src/test/resources/golden/date_2-4-3618dfde8da7c26f03bca72970db9ef7 new file mode 100644 index 0000000000000..1b0ea7b9eec84 --- /dev/null +++ b/sql/hive/src/test/resources/golden/date_2-4-3618dfde8da7c26f03bca72970db9ef7 @@ -0,0 +1,137 @@ +2010-10-31 361 +2010-10-31 897 +2010-10-31 1531 +2010-10-31 1599 +2010-10-31 1610 +2010-10-31 2571 +2010-10-31 3198 +2010-10-31 7282 +2010-10-30 272 +2010-10-30 384 +2010-10-30 426 +2010-10-30 650 +2010-10-30 1805 +2010-10-30 2018 +2010-10-30 2932 +2010-10-30 3085 +2010-10-30 3171 +2010-10-30 5904 +2010-10-30 5917 +2010-10-29 361 +2010-10-29 897 +2010-10-29 1064 +2010-10-29 1142 +2010-10-29 1531 +2010-10-29 1599 +2010-10-29 1610 +2010-10-29 2630 +2010-10-29 2646 +2010-10-29 3014 +2010-10-29 3198 +2010-10-29 7291 +2010-10-28 361 +2010-10-28 897 +2010-10-28 1064 +2010-10-28 1142 +2010-10-28 1531 +2010-10-28 1599 +2010-10-28 1610 +2010-10-28 2630 +2010-10-28 2646 +2010-10-28 3014 +2010-10-28 3198 +2010-10-28 7291 +2010-10-27 361 +2010-10-27 897 +2010-10-27 1064 +2010-10-27 1142 +2010-10-27 1531 +2010-10-27 1599 +2010-10-27 1610 +2010-10-27 2630 +2010-10-27 3014 +2010-10-27 3198 +2010-10-27 7291 +2010-10-26 361 +2010-10-26 897 +2010-10-26 1064 +2010-10-26 1142 +2010-10-26 1531 +2010-10-26 1599 +2010-10-26 1610 +2010-10-26 2630 +2010-10-26 2646 +2010-10-26 2662 +2010-10-26 3014 +2010-10-26 3198 +2010-10-26 7291 +2010-10-25 361 +2010-10-25 897 +2010-10-25 1064 +2010-10-25 1142 +2010-10-25 1531 +2010-10-25 1599 +2010-10-25 1610 +2010-10-25 2630 +2010-10-25 2646 +2010-10-25 3014 +2010-10-25 3198 +2010-10-25 7291 +2010-10-24 361 +2010-10-24 897 +2010-10-24 1531 +2010-10-24 1599 +2010-10-24 1610 +2010-10-24 2254 +2010-10-24 2571 +2010-10-24 2630 +2010-10-24 2646 +2010-10-24 3014 +2010-10-24 3198 +2010-10-24 7282 +2010-10-23 272 +2010-10-23 384 +2010-10-23 426 +2010-10-23 650 +2010-10-23 1805 +2010-10-23 2932 +2010-10-23 3085 +2010-10-23 3171 +2010-10-23 5832 +2010-10-23 5904 +2010-10-23 5917 +2010-10-23 7274 +2010-10-22 361 +2010-10-22 897 +2010-10-22 1064 +2010-10-22 1142 +2010-10-22 1531 +2010-10-22 1599 +2010-10-22 1610 +2010-10-22 2630 
+2010-10-22 2646 +2010-10-22 3014 +2010-10-22 3198 +2010-10-21 361 +2010-10-21 897 +2010-10-21 1064 +2010-10-21 1142 +2010-10-21 1531 +2010-10-21 1599 +2010-10-21 1610 +2010-10-21 2630 +2010-10-21 2646 +2010-10-21 3014 +2010-10-21 3198 +2010-10-21 7291 +2010-10-20 361 +2010-10-20 897 +2010-10-20 1064 +2010-10-20 1142 +2010-10-20 1531 +2010-10-20 1599 +2010-10-20 1610 +2010-10-20 2630 +2010-10-20 3014 +2010-10-20 3198 +2010-10-20 7291 diff --git a/sql/hive/src/test/resources/golden/date_2-5-fe9bebfc8994ddd8d7cd0208c1f0af3c b/sql/hive/src/test/resources/golden/date_2-5-fe9bebfc8994ddd8d7cd0208c1f0af3c new file mode 100644 index 0000000000000..0f2a6f7a99237 --- /dev/null +++ b/sql/hive/src/test/resources/golden/date_2-5-fe9bebfc8994ddd8d7cd0208c1f0af3c @@ -0,0 +1,12 @@ +2010-10-20 11 +2010-10-21 12 +2010-10-22 11 +2010-10-23 12 +2010-10-24 12 +2010-10-25 12 +2010-10-26 13 +2010-10-27 11 +2010-10-28 12 +2010-10-29 12 +2010-10-30 11 +2010-10-31 8 diff --git a/sql/hive/src/test/resources/golden/date_2-6-f4edce7cb20f325e8b69e787b2ae8882 b/sql/hive/src/test/resources/golden/date_2-6-f4edce7cb20f325e8b69e787b2ae8882 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/date_3-3-4cf49e71b636df754871a675f9e4e24 b/sql/hive/src/test/resources/golden/date_3-3-4cf49e71b636df754871a675f9e4e24 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/date_3-4-e009f358964f6d1236cfc03283e2b06f b/sql/hive/src/test/resources/golden/date_3-4-e009f358964f6d1236cfc03283e2b06f new file mode 100644 index 0000000000000..66d2220d06de2 --- /dev/null +++ b/sql/hive/src/test/resources/golden/date_3-4-e009f358964f6d1236cfc03283e2b06f @@ -0,0 +1 @@ +1 2011-01-01 diff --git a/sql/hive/src/test/resources/golden/date_3-5-c26de4559926ddb0127d2dc5ea154774 b/sql/hive/src/test/resources/golden/date_3-5-c26de4559926ddb0127d2dc5ea154774 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/date_4-0-b84f7e931d710dcbe3c5126d998285a8 b/sql/hive/src/test/resources/golden/date_4-0-b84f7e931d710dcbe3c5126d998285a8 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/date_4-1-6272f5e518f6a20bc96a5870ff315c4f b/sql/hive/src/test/resources/golden/date_4-1-6272f5e518f6a20bc96a5870ff315c4f new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/date_4-2-4a0e7bde447ef616b98e0f55d2886de0 b/sql/hive/src/test/resources/golden/date_4-2-4a0e7bde447ef616b98e0f55d2886de0 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/date_4-3-a23faa56b5d3ca9063a21f72b4278b00 b/sql/hive/src/test/resources/golden/date_4-3-a23faa56b5d3ca9063a21f72b4278b00 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/date_4-4-bee09a7384666043621f68297cee2e68 b/sql/hive/src/test/resources/golden/date_4-4-bee09a7384666043621f68297cee2e68 new file mode 100644 index 0000000000000..b61affde4ffce --- /dev/null +++ b/sql/hive/src/test/resources/golden/date_4-4-bee09a7384666043621f68297cee2e68 @@ -0,0 +1 @@ +2011-01-01 2011-01-01 diff --git a/sql/hive/src/test/resources/golden/date_4-5-b84f7e931d710dcbe3c5126d998285a8 b/sql/hive/src/test/resources/golden/date_4-5-b84f7e931d710dcbe3c5126d998285a8 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git 
a/sql/hive/src/test/resources/golden/date_comparison-0-69eec445bd045c9dc899fafa348d8495 b/sql/hive/src/test/resources/golden/date_comparison-0-69eec445bd045c9dc899fafa348d8495 new file mode 100644 index 0000000000000..c508d5366f70b --- /dev/null +++ b/sql/hive/src/test/resources/golden/date_comparison-0-69eec445bd045c9dc899fafa348d8495 @@ -0,0 +1 @@ +false diff --git a/sql/hive/src/test/resources/golden/date_comparison-1-fcc400871a502009c8680509e3869ec1 b/sql/hive/src/test/resources/golden/date_comparison-1-fcc400871a502009c8680509e3869ec1 new file mode 100644 index 0000000000000..c508d5366f70b --- /dev/null +++ b/sql/hive/src/test/resources/golden/date_comparison-1-fcc400871a502009c8680509e3869ec1 @@ -0,0 +1 @@ +false diff --git a/sql/hive/src/test/resources/golden/date_comparison-10-a9f2560c273163e11306d4f1dd1d9d54 b/sql/hive/src/test/resources/golden/date_comparison-10-a9f2560c273163e11306d4f1dd1d9d54 new file mode 100644 index 0000000000000..c508d5366f70b --- /dev/null +++ b/sql/hive/src/test/resources/golden/date_comparison-10-a9f2560c273163e11306d4f1dd1d9d54 @@ -0,0 +1 @@ +false diff --git a/sql/hive/src/test/resources/golden/date_comparison-11-4a7bac9ddcf40db6329faaec8e426543 b/sql/hive/src/test/resources/golden/date_comparison-11-4a7bac9ddcf40db6329faaec8e426543 new file mode 100644 index 0000000000000..27ba77ddaf615 --- /dev/null +++ b/sql/hive/src/test/resources/golden/date_comparison-11-4a7bac9ddcf40db6329faaec8e426543 @@ -0,0 +1 @@ +true diff --git a/sql/hive/src/test/resources/golden/date_comparison-2-b8598a4d0c948c2ddcf3eeef0abf2264 b/sql/hive/src/test/resources/golden/date_comparison-2-b8598a4d0c948c2ddcf3eeef0abf2264 new file mode 100644 index 0000000000000..27ba77ddaf615 --- /dev/null +++ b/sql/hive/src/test/resources/golden/date_comparison-2-b8598a4d0c948c2ddcf3eeef0abf2264 @@ -0,0 +1 @@ +true diff --git a/sql/hive/src/test/resources/golden/date_comparison-3-14d35f266be9cceb11a2ae09ec8b3835 b/sql/hive/src/test/resources/golden/date_comparison-3-14d35f266be9cceb11a2ae09ec8b3835 new file mode 100644 index 0000000000000..c508d5366f70b --- /dev/null +++ b/sql/hive/src/test/resources/golden/date_comparison-3-14d35f266be9cceb11a2ae09ec8b3835 @@ -0,0 +1 @@ +false diff --git a/sql/hive/src/test/resources/golden/date_comparison-4-c8865b14d53f2c2496fb69ee8191bf37 b/sql/hive/src/test/resources/golden/date_comparison-4-c8865b14d53f2c2496fb69ee8191bf37 new file mode 100644 index 0000000000000..27ba77ddaf615 --- /dev/null +++ b/sql/hive/src/test/resources/golden/date_comparison-4-c8865b14d53f2c2496fb69ee8191bf37 @@ -0,0 +1 @@ +true diff --git a/sql/hive/src/test/resources/golden/date_comparison-5-f2c907e64da8166a731ddc0ed19bad6c b/sql/hive/src/test/resources/golden/date_comparison-5-f2c907e64da8166a731ddc0ed19bad6c new file mode 100644 index 0000000000000..27ba77ddaf615 --- /dev/null +++ b/sql/hive/src/test/resources/golden/date_comparison-5-f2c907e64da8166a731ddc0ed19bad6c @@ -0,0 +1 @@ +true diff --git a/sql/hive/src/test/resources/golden/date_comparison-6-5606505a92bad10023ad9a3ef77eacc9 b/sql/hive/src/test/resources/golden/date_comparison-6-5606505a92bad10023ad9a3ef77eacc9 new file mode 100644 index 0000000000000..c508d5366f70b --- /dev/null +++ b/sql/hive/src/test/resources/golden/date_comparison-6-5606505a92bad10023ad9a3ef77eacc9 @@ -0,0 +1 @@ +false diff --git a/sql/hive/src/test/resources/golden/date_comparison-7-47913d4aaf0d468ab3764cc3bfd68eb b/sql/hive/src/test/resources/golden/date_comparison-7-47913d4aaf0d468ab3764cc3bfd68eb new file mode 100644 index 
0000000000000..27ba77ddaf615 --- /dev/null +++ b/sql/hive/src/test/resources/golden/date_comparison-7-47913d4aaf0d468ab3764cc3bfd68eb @@ -0,0 +1 @@ +true diff --git a/sql/hive/src/test/resources/golden/date_comparison-8-1e5ce4f833b6fba45618437c8fb7643c b/sql/hive/src/test/resources/golden/date_comparison-8-1e5ce4f833b6fba45618437c8fb7643c new file mode 100644 index 0000000000000..c508d5366f70b --- /dev/null +++ b/sql/hive/src/test/resources/golden/date_comparison-8-1e5ce4f833b6fba45618437c8fb7643c @@ -0,0 +1 @@ +false diff --git a/sql/hive/src/test/resources/golden/date_comparison-9-bcd987341fc1c38047a27d29dac6ae7c b/sql/hive/src/test/resources/golden/date_comparison-9-bcd987341fc1c38047a27d29dac6ae7c new file mode 100644 index 0000000000000..27ba77ddaf615 --- /dev/null +++ b/sql/hive/src/test/resources/golden/date_comparison-9-bcd987341fc1c38047a27d29dac6ae7c @@ -0,0 +1 @@ +true diff --git a/sql/hive/src/test/resources/golden/date_join1-3-f71c7be760fb4de4eff8225f2c6614b2 b/sql/hive/src/test/resources/golden/date_join1-3-f71c7be760fb4de4eff8225f2c6614b2 new file mode 100644 index 0000000000000..b7305b903edca --- /dev/null +++ b/sql/hive/src/test/resources/golden/date_join1-3-f71c7be760fb4de4eff8225f2c6614b2 @@ -0,0 +1,22 @@ +1064 2010-10-20 1064 2010-10-20 +1142 2010-10-21 1142 2010-10-21 +1599 2010-10-22 1599 2010-10-22 +361 2010-10-23 361 2010-10-23 +897 2010-10-24 897 2010-10-24 +1531 2010-10-25 1531 2010-10-25 +1610 2010-10-26 1610 2010-10-26 +3198 2010-10-27 3198 2010-10-27 +1064 2010-10-28 1064 2010-10-28 +1142 2010-10-29 1142 2010-10-29 +1064 2000-11-20 1064 2000-11-20 +1142 2000-11-21 1142 2000-11-21 +1599 2000-11-22 1599 2000-11-22 +361 2000-11-23 361 2000-11-23 +897 2000-11-24 897 2000-11-24 +1531 2000-11-25 1531 2000-11-25 +1610 2000-11-26 1610 2000-11-26 +3198 2000-11-27 3198 2000-11-27 +1064 2000-11-28 1064 2000-11-28 +1142 2000-11-28 1064 2000-11-28 +1064 2000-11-28 1142 2000-11-28 +1142 2000-11-28 1142 2000-11-28 diff --git a/sql/hive/src/test/resources/golden/date_join1-4-70b9b49c55699fe94cfde069f5d197c b/sql/hive/src/test/resources/golden/date_join1-4-70b9b49c55699fe94cfde069f5d197c new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/date_serde-10-d80e681519dcd8f5078c5602bb5befa9 b/sql/hive/src/test/resources/golden/date_serde-10-d80e681519dcd8f5078c5602bb5befa9 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/date_serde-11-29540200936bba47f17553547b409af7 b/sql/hive/src/test/resources/golden/date_serde-11-29540200936bba47f17553547b409af7 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/date_serde-12-c3c3275658b89d31fc504db31ae9f99c b/sql/hive/src/test/resources/golden/date_serde-12-c3c3275658b89d31fc504db31ae9f99c new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/date_serde-13-6c546456c81e635b6753e1552fac9129 b/sql/hive/src/test/resources/golden/date_serde-13-6c546456c81e635b6753e1552fac9129 new file mode 100644 index 0000000000000..9f2238d57d6f5 --- /dev/null +++ b/sql/hive/src/test/resources/golden/date_serde-13-6c546456c81e635b6753e1552fac9129 @@ -0,0 +1 @@ +2010-10-20 1064 diff --git a/sql/hive/src/test/resources/golden/date_serde-14-f8ba18cc7b0225b4022299c44d435101 b/sql/hive/src/test/resources/golden/date_serde-14-f8ba18cc7b0225b4022299c44d435101 new file mode 100644 index 0000000000000..9f2238d57d6f5 --- /dev/null +++ 
b/sql/hive/src/test/resources/golden/date_serde-14-f8ba18cc7b0225b4022299c44d435101 @@ -0,0 +1 @@ +2010-10-20 1064 diff --git a/sql/hive/src/test/resources/golden/date_serde-15-66fadc9bcea7d107a610758aa6f50ff3 b/sql/hive/src/test/resources/golden/date_serde-15-66fadc9bcea7d107a610758aa6f50ff3 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/date_serde-16-1bd3345b46f77e17810978e56f9f7c6b b/sql/hive/src/test/resources/golden/date_serde-16-1bd3345b46f77e17810978e56f9f7c6b new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/date_serde-17-a0df43062f8ab676ef728c9968443f12 b/sql/hive/src/test/resources/golden/date_serde-17-a0df43062f8ab676ef728c9968443f12 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/date_serde-18-b50ecc72ce9018ab12fb17568fef038a b/sql/hive/src/test/resources/golden/date_serde-18-b50ecc72ce9018ab12fb17568fef038a new file mode 100644 index 0000000000000..9f2238d57d6f5 --- /dev/null +++ b/sql/hive/src/test/resources/golden/date_serde-18-b50ecc72ce9018ab12fb17568fef038a @@ -0,0 +1 @@ +2010-10-20 1064 diff --git a/sql/hive/src/test/resources/golden/date_serde-19-28f1cf92bdd6b2e5d328cd9d10f828b6 b/sql/hive/src/test/resources/golden/date_serde-19-28f1cf92bdd6b2e5d328cd9d10f828b6 new file mode 100644 index 0000000000000..9f2238d57d6f5 --- /dev/null +++ b/sql/hive/src/test/resources/golden/date_serde-19-28f1cf92bdd6b2e5d328cd9d10f828b6 @@ -0,0 +1 @@ +2010-10-20 1064 diff --git a/sql/hive/src/test/resources/golden/date_serde-20-588516368d8c1533cb7bfb2157fd58c1 b/sql/hive/src/test/resources/golden/date_serde-20-588516368d8c1533cb7bfb2157fd58c1 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/date_serde-21-dfe166fe053468e738dca23ebe043091 b/sql/hive/src/test/resources/golden/date_serde-21-dfe166fe053468e738dca23ebe043091 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/date_serde-22-45240a488fb708e432d2f45b74ef7e63 b/sql/hive/src/test/resources/golden/date_serde-22-45240a488fb708e432d2f45b74ef7e63 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/date_serde-23-1742a51e4967a8d263572d890cd8d4a8 b/sql/hive/src/test/resources/golden/date_serde-23-1742a51e4967a8d263572d890cd8d4a8 new file mode 100644 index 0000000000000..9f2238d57d6f5 --- /dev/null +++ b/sql/hive/src/test/resources/golden/date_serde-23-1742a51e4967a8d263572d890cd8d4a8 @@ -0,0 +1 @@ +2010-10-20 1064 diff --git a/sql/hive/src/test/resources/golden/date_serde-24-14fd49bd6fee907c1699f7b4e26685b b/sql/hive/src/test/resources/golden/date_serde-24-14fd49bd6fee907c1699f7b4e26685b new file mode 100644 index 0000000000000..9f2238d57d6f5 --- /dev/null +++ b/sql/hive/src/test/resources/golden/date_serde-24-14fd49bd6fee907c1699f7b4e26685b @@ -0,0 +1 @@ +2010-10-20 1064 diff --git a/sql/hive/src/test/resources/golden/date_serde-25-a199cf185184a25190d65c123d0694ee b/sql/hive/src/test/resources/golden/date_serde-25-a199cf185184a25190d65c123d0694ee new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/date_serde-26-c5fa68d9aff36f22e5edc1b54332d0ab b/sql/hive/src/test/resources/golden/date_serde-26-c5fa68d9aff36f22e5edc1b54332d0ab new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git 
a/sql/hive/src/test/resources/golden/date_serde-27-4d86c79f858866acec3c37f6598c2638 b/sql/hive/src/test/resources/golden/date_serde-27-4d86c79f858866acec3c37f6598c2638 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/date_serde-28-16a41fc9e0f51eb417c763bae8e9cadb b/sql/hive/src/test/resources/golden/date_serde-28-16a41fc9e0f51eb417c763bae8e9cadb new file mode 100644 index 0000000000000..9f2238d57d6f5 --- /dev/null +++ b/sql/hive/src/test/resources/golden/date_serde-28-16a41fc9e0f51eb417c763bae8e9cadb @@ -0,0 +1 @@ +2010-10-20 1064 diff --git a/sql/hive/src/test/resources/golden/date_serde-29-bd1cb09aacd906527b0bbf43bbded812 b/sql/hive/src/test/resources/golden/date_serde-29-bd1cb09aacd906527b0bbf43bbded812 new file mode 100644 index 0000000000000..9f2238d57d6f5 --- /dev/null +++ b/sql/hive/src/test/resources/golden/date_serde-29-bd1cb09aacd906527b0bbf43bbded812 @@ -0,0 +1 @@ +2010-10-20 1064 diff --git a/sql/hive/src/test/resources/golden/date_serde-30-7c80741f9f485729afc68609c55423a0 b/sql/hive/src/test/resources/golden/date_serde-30-7c80741f9f485729afc68609c55423a0 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/date_serde-31-da36cd1654aee055cb3650133c9d11f b/sql/hive/src/test/resources/golden/date_serde-31-da36cd1654aee055cb3650133c9d11f new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/date_serde-32-bb2f76bd307ed616a3c797f8dd45a8d1 b/sql/hive/src/test/resources/golden/date_serde-32-bb2f76bd307ed616a3c797f8dd45a8d1 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/date_serde-33-a742813b024e6dcfb4a358aa4e9fcdb6 b/sql/hive/src/test/resources/golden/date_serde-33-a742813b024e6dcfb4a358aa4e9fcdb6 new file mode 100644 index 0000000000000..9f2238d57d6f5 --- /dev/null +++ b/sql/hive/src/test/resources/golden/date_serde-33-a742813b024e6dcfb4a358aa4e9fcdb6 @@ -0,0 +1 @@ +2010-10-20 1064 diff --git a/sql/hive/src/test/resources/golden/date_serde-34-6485841336c097895ad5b34f42c0745f b/sql/hive/src/test/resources/golden/date_serde-34-6485841336c097895ad5b34f42c0745f new file mode 100644 index 0000000000000..9f2238d57d6f5 --- /dev/null +++ b/sql/hive/src/test/resources/golden/date_serde-34-6485841336c097895ad5b34f42c0745f @@ -0,0 +1 @@ +2010-10-20 1064 diff --git a/sql/hive/src/test/resources/golden/date_serde-35-8651a7c351cbc07fb1af6193f6885de8 b/sql/hive/src/test/resources/golden/date_serde-35-8651a7c351cbc07fb1af6193f6885de8 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/date_serde-36-36e6041f53433482631018410bb62a99 b/sql/hive/src/test/resources/golden/date_serde-36-36e6041f53433482631018410bb62a99 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/date_serde-37-3ddfd8ecb28991aeed588f1ea852c427 b/sql/hive/src/test/resources/golden/date_serde-37-3ddfd8ecb28991aeed588f1ea852c427 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/date_serde-38-e6167e27465514356c557a77d956ea46 b/sql/hive/src/test/resources/golden/date_serde-38-e6167e27465514356c557a77d956ea46 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/date_serde-39-c1e17c93582656c12970c37bac153bf2 b/sql/hive/src/test/resources/golden/date_serde-39-c1e17c93582656c12970c37bac153bf2 new file mode 100644 index 
0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/date_serde-40-4a17944b9ec8999bb20c5ba5d4cb877c b/sql/hive/src/test/resources/golden/date_serde-40-4a17944b9ec8999bb20c5ba5d4cb877c new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/date_serde-8-cace4f60a08342f58fbe816a9c3a73cf b/sql/hive/src/test/resources/golden/date_serde-8-cace4f60a08342f58fbe816a9c3a73cf new file mode 100644 index 0000000000000..16c03e7276fec --- /dev/null +++ b/sql/hive/src/test/resources/golden/date_serde-8-cace4f60a08342f58fbe816a9c3a73cf @@ -0,0 +1,137 @@ +Baltimore New York 2010-10-20 -30.0 1064 +Baltimore New York 2010-10-20 23.0 1142 +Baltimore New York 2010-10-20 6.0 1599 +Chicago New York 2010-10-20 42.0 361 +Chicago New York 2010-10-20 24.0 897 +Chicago New York 2010-10-20 15.0 1531 +Chicago New York 2010-10-20 -6.0 1610 +Chicago New York 2010-10-20 -2.0 3198 +Baltimore New York 2010-10-21 17.0 1064 +Baltimore New York 2010-10-21 105.0 1142 +Baltimore New York 2010-10-21 28.0 1599 +Chicago New York 2010-10-21 142.0 361 +Chicago New York 2010-10-21 77.0 897 +Chicago New York 2010-10-21 53.0 1531 +Chicago New York 2010-10-21 -5.0 1610 +Chicago New York 2010-10-21 51.0 3198 +Baltimore New York 2010-10-22 -12.0 1064 +Baltimore New York 2010-10-22 54.0 1142 +Baltimore New York 2010-10-22 18.0 1599 +Chicago New York 2010-10-22 2.0 361 +Chicago New York 2010-10-22 24.0 897 +Chicago New York 2010-10-22 16.0 1531 +Chicago New York 2010-10-22 -6.0 1610 +Chicago New York 2010-10-22 -11.0 3198 +Baltimore New York 2010-10-23 18.0 272 +Baltimore New York 2010-10-23 -10.0 1805 +Baltimore New York 2010-10-23 6.0 3171 +Chicago New York 2010-10-23 3.0 384 +Chicago New York 2010-10-23 32.0 426 +Chicago New York 2010-10-23 1.0 650 +Chicago New York 2010-10-23 11.0 3085 +Baltimore New York 2010-10-24 12.0 1599 +Baltimore New York 2010-10-24 20.0 2571 +Chicago New York 2010-10-24 10.0 361 +Chicago New York 2010-10-24 113.0 897 +Chicago New York 2010-10-24 -5.0 1531 +Chicago New York 2010-10-24 -17.0 1610 +Chicago New York 2010-10-24 -3.0 3198 +Baltimore New York 2010-10-25 -25.0 1064 +Baltimore New York 2010-10-25 92.0 1142 +Baltimore New York 2010-10-25 106.0 1599 +Chicago New York 2010-10-25 31.0 361 +Chicago New York 2010-10-25 -1.0 897 +Chicago New York 2010-10-25 43.0 1531 +Chicago New York 2010-10-25 6.0 1610 +Chicago New York 2010-10-25 -16.0 3198 +Baltimore New York 2010-10-26 -22.0 1064 +Baltimore New York 2010-10-26 123.0 1142 +Baltimore New York 2010-10-26 90.0 1599 +Chicago New York 2010-10-26 12.0 361 +Chicago New York 2010-10-26 0.0 897 +Chicago New York 2010-10-26 29.0 1531 +Chicago New York 2010-10-26 -17.0 1610 +Chicago New York 2010-10-26 6.0 3198 +Baltimore New York 2010-10-27 -18.0 1064 +Baltimore New York 2010-10-27 49.0 1142 +Baltimore New York 2010-10-27 92.0 1599 +Chicago New York 2010-10-27 148.0 361 +Chicago New York 2010-10-27 -11.0 897 +Chicago New York 2010-10-27 70.0 1531 +Chicago New York 2010-10-27 8.0 1610 +Chicago New York 2010-10-27 21.0 3198 +Baltimore New York 2010-10-28 -4.0 1064 +Baltimore New York 2010-10-28 -14.0 1142 +Baltimore New York 2010-10-28 -14.0 1599 +Chicago New York 2010-10-28 2.0 361 +Chicago New York 2010-10-28 2.0 897 +Chicago New York 2010-10-28 -11.0 1531 +Chicago New York 2010-10-28 3.0 1610 +Chicago New York 2010-10-28 -18.0 3198 +Baltimore New York 2010-10-29 -24.0 1064 +Baltimore New York 2010-10-29 21.0 1142 +Baltimore New York 2010-10-29 -2.0 1599 +Chicago New York 2010-10-29 -12.0 361 
+Chicago New York 2010-10-29 -11.0 897 +Chicago New York 2010-10-29 15.0 1531 +Chicago New York 2010-10-29 -18.0 1610 +Chicago New York 2010-10-29 -4.0 3198 +Baltimore New York 2010-10-30 14.0 272 +Baltimore New York 2010-10-30 -1.0 1805 +Baltimore New York 2010-10-30 5.0 3171 +Chicago New York 2010-10-30 -6.0 384 +Chicago New York 2010-10-30 -10.0 426 +Chicago New York 2010-10-30 -5.0 650 +Chicago New York 2010-10-30 -5.0 3085 +Baltimore New York 2010-10-31 -1.0 1599 +Baltimore New York 2010-10-31 -14.0 2571 +Chicago New York 2010-10-31 -25.0 361 +Chicago New York 2010-10-31 -18.0 897 +Chicago New York 2010-10-31 -4.0 1531 +Chicago New York 2010-10-31 -22.0 1610 +Chicago New York 2010-10-31 -15.0 3198 +Cleveland New York 2010-10-30 -23.0 2018 +Cleveland New York 2010-10-30 -12.0 2932 +Cleveland New York 2010-10-29 -4.0 2630 +Cleveland New York 2010-10-29 -19.0 2646 +Cleveland New York 2010-10-29 -12.0 3014 +Cleveland New York 2010-10-28 3.0 2630 +Cleveland New York 2010-10-28 -6.0 2646 +Cleveland New York 2010-10-28 1.0 3014 +Cleveland New York 2010-10-27 16.0 2630 +Cleveland New York 2010-10-27 27.0 3014 +Cleveland New York 2010-10-26 4.0 2630 +Cleveland New York 2010-10-26 -27.0 2646 +Cleveland New York 2010-10-26 -11.0 2662 +Cleveland New York 2010-10-26 13.0 3014 +Cleveland New York 2010-10-25 -4.0 2630 +Cleveland New York 2010-10-25 81.0 2646 +Cleveland New York 2010-10-25 42.0 3014 +Cleveland New York 2010-10-24 5.0 2254 +Cleveland New York 2010-10-24 -11.0 2630 +Cleveland New York 2010-10-24 -20.0 2646 +Cleveland New York 2010-10-24 -9.0 3014 +Cleveland New York 2010-10-23 -21.0 2932 +Cleveland New York 2010-10-22 1.0 2630 +Cleveland New York 2010-10-22 -25.0 2646 +Cleveland New York 2010-10-22 -3.0 3014 +Cleveland New York 2010-10-21 3.0 2630 +Cleveland New York 2010-10-21 29.0 2646 +Cleveland New York 2010-10-21 72.0 3014 +Cleveland New York 2010-10-20 -8.0 2630 +Cleveland New York 2010-10-20 -15.0 3014 +Washington New York 2010-10-23 -25.0 5832 +Washington New York 2010-10-23 -21.0 5904 +Washington New York 2010-10-23 -18.0 5917 +Washington New York 2010-10-30 -27.0 5904 +Washington New York 2010-10-30 -16.0 5917 +Washington New York 2010-10-20 -2.0 7291 +Washington New York 2010-10-21 22.0 7291 +Washington New York 2010-10-23 -16.0 7274 +Washington New York 2010-10-24 -26.0 7282 +Washington New York 2010-10-25 9.0 7291 +Washington New York 2010-10-26 4.0 7291 +Washington New York 2010-10-27 26.0 7291 +Washington New York 2010-10-28 45.0 7291 +Washington New York 2010-10-29 1.0 7291 +Washington New York 2010-10-31 -18.0 7282 diff --git a/sql/hive/src/test/resources/golden/date_serde-9-436c3c61cc4278b54ac79c53c88ff422 b/sql/hive/src/test/resources/golden/date_serde-9-436c3c61cc4278b54ac79c53c88ff422 new file mode 100644 index 0000000000000..0f2a6f7a99237 --- /dev/null +++ b/sql/hive/src/test/resources/golden/date_serde-9-436c3c61cc4278b54ac79c53c88ff422 @@ -0,0 +1,12 @@ +2010-10-20 11 +2010-10-21 12 +2010-10-22 11 +2010-10-23 12 +2010-10-24 12 +2010-10-25 12 +2010-10-26 13 +2010-10-27 11 +2010-10-28 12 +2010-10-29 12 +2010-10-30 11 +2010-10-31 8 diff --git a/sql/hive/src/test/resources/golden/date_udf-0-84604a42a5d7f2842f1eec10c689d447 b/sql/hive/src/test/resources/golden/date_udf-0-84604a42a5d7f2842f1eec10c689d447 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/date_udf-1-5e8136f6a6503ae9bef9beca80fada13 b/sql/hive/src/test/resources/golden/date_udf-1-5e8136f6a6503ae9bef9beca80fada13 new file mode 100644 index 
0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/date_udf-10-988ad9744096a29a3672a2d4c121299b b/sql/hive/src/test/resources/golden/date_udf-10-988ad9744096a29a3672a2d4c121299b new file mode 100644 index 0000000000000..83c33400edb47 --- /dev/null +++ b/sql/hive/src/test/resources/golden/date_udf-10-988ad9744096a29a3672a2d4c121299b @@ -0,0 +1 @@ +0 3333 -3333 -3332 3332 diff --git a/sql/hive/src/test/resources/golden/date_udf-11-a5100dd42201b5bc035a9d684cc21bdc b/sql/hive/src/test/resources/golden/date_udf-11-a5100dd42201b5bc035a9d684cc21bdc new file mode 100644 index 0000000000000..4a2462bb3929b --- /dev/null +++ b/sql/hive/src/test/resources/golden/date_udf-11-a5100dd42201b5bc035a9d684cc21bdc @@ -0,0 +1 @@ +NULL 2011 5 6 6 18 2011-05-06 diff --git a/sql/hive/src/test/resources/golden/date_udf-12-eb7280a1f191344a99eaa0f805e8faff b/sql/hive/src/test/resources/golden/date_udf-12-eb7280a1f191344a99eaa0f805e8faff new file mode 100644 index 0000000000000..19497254f8f7e --- /dev/null +++ b/sql/hive/src/test/resources/golden/date_udf-12-eb7280a1f191344a99eaa0f805e8faff @@ -0,0 +1 @@ +2011-05-11 2011-04-26 diff --git a/sql/hive/src/test/resources/golden/date_udf-13-cc99e4f14fd092994b006ee7ebe4fc92 b/sql/hive/src/test/resources/golden/date_udf-13-cc99e4f14fd092994b006ee7ebe4fc92 new file mode 100644 index 0000000000000..977f0d24c58cc --- /dev/null +++ b/sql/hive/src/test/resources/golden/date_udf-13-cc99e4f14fd092994b006ee7ebe4fc92 @@ -0,0 +1 @@ +0 3333 -3333 -3333 3333 diff --git a/sql/hive/src/test/resources/golden/date_udf-14-a6a5ce5134cc1125355a4bdf0a73d97 b/sql/hive/src/test/resources/golden/date_udf-14-a6a5ce5134cc1125355a4bdf0a73d97 new file mode 100644 index 0000000000000..44d1f45e4eb73 --- /dev/null +++ b/sql/hive/src/test/resources/golden/date_udf-14-a6a5ce5134cc1125355a4bdf0a73d97 @@ -0,0 +1 @@ +1970-01-01 08:00:00 1969-12-31 16:00:00 2013-06-19 07:00:00 2013-06-18 17:00:00 diff --git a/sql/hive/src/test/resources/golden/date_udf-15-d031ee50c119d7c6acafd53543dbd0c4 b/sql/hive/src/test/resources/golden/date_udf-15-d031ee50c119d7c6acafd53543dbd0c4 new file mode 100644 index 0000000000000..645b71d8d61e7 --- /dev/null +++ b/sql/hive/src/test/resources/golden/date_udf-15-d031ee50c119d7c6acafd53543dbd0c4 @@ -0,0 +1 @@ +true true true true diff --git a/sql/hive/src/test/resources/golden/date_udf-16-dc59f69e1685e8d923b187ec50d80f06 b/sql/hive/src/test/resources/golden/date_udf-16-dc59f69e1685e8d923b187ec50d80f06 new file mode 100644 index 0000000000000..51863e9a14e4b --- /dev/null +++ b/sql/hive/src/test/resources/golden/date_udf-16-dc59f69e1685e8d923b187ec50d80f06 @@ -0,0 +1 @@ +2010-10-20 diff --git a/sql/hive/src/test/resources/golden/date_udf-17-7d046d4efc568049cf3792470b6feab9 b/sql/hive/src/test/resources/golden/date_udf-17-7d046d4efc568049cf3792470b6feab9 new file mode 100644 index 0000000000000..4043ee1cbdd40 --- /dev/null +++ b/sql/hive/src/test/resources/golden/date_udf-17-7d046d4efc568049cf3792470b6feab9 @@ -0,0 +1 @@ +2010-10-31 diff --git a/sql/hive/src/test/resources/golden/date_udf-18-84604a42a5d7f2842f1eec10c689d447 b/sql/hive/src/test/resources/golden/date_udf-18-84604a42a5d7f2842f1eec10c689d447 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/date_udf-19-5e8136f6a6503ae9bef9beca80fada13 b/sql/hive/src/test/resources/golden/date_udf-19-5e8136f6a6503ae9bef9beca80fada13 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git 
a/sql/hive/src/test/resources/golden/date_udf-2-10e337c34d1e82a360b8599988f4b266 b/sql/hive/src/test/resources/golden/date_udf-2-10e337c34d1e82a360b8599988f4b266 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/date_udf-20-10e337c34d1e82a360b8599988f4b266 b/sql/hive/src/test/resources/golden/date_udf-20-10e337c34d1e82a360b8599988f4b266 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/date_udf-3-29e406e613c0284b3e16a8943a4d31bd b/sql/hive/src/test/resources/golden/date_udf-3-29e406e613c0284b3e16a8943a4d31bd new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/date_udf-4-23653315213f578856ab5c3bd80c0264 b/sql/hive/src/test/resources/golden/date_udf-4-23653315213f578856ab5c3bd80c0264 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/date_udf-5-891fd92a4787b9789f6d1f51c1eddc8a b/sql/hive/src/test/resources/golden/date_udf-5-891fd92a4787b9789f6d1f51c1eddc8a new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/date_udf-6-3473c118d20783eafb456043a2ee5d5b b/sql/hive/src/test/resources/golden/date_udf-6-3473c118d20783eafb456043a2ee5d5b new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/date_udf-7-9fb5165824e161074565e7500959c1b2 b/sql/hive/src/test/resources/golden/date_udf-7-9fb5165824e161074565e7500959c1b2 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/date_udf-8-badfe833681362092fc6345f888b1c21 b/sql/hive/src/test/resources/golden/date_udf-8-badfe833681362092fc6345f888b1c21 new file mode 100644 index 0000000000000..18d17ea11b53e --- /dev/null +++ b/sql/hive/src/test/resources/golden/date_udf-8-badfe833681362092fc6345f888b1c21 @@ -0,0 +1 @@ +1304665200 2011 5 6 6 18 2011-05-06 diff --git a/sql/hive/src/test/resources/golden/date_udf-9-a8cbb039661d796beaa0d1564c58c563 b/sql/hive/src/test/resources/golden/date_udf-9-a8cbb039661d796beaa0d1564c58c563 new file mode 100644 index 0000000000000..19497254f8f7e --- /dev/null +++ b/sql/hive/src/test/resources/golden/date_udf-9-a8cbb039661d796beaa0d1564c58c563 @@ -0,0 +1 @@ +2011-05-11 2011-04-26 diff --git a/sql/hive/src/test/resources/golden/partition_date-0-7ec1f3a845e2c49191460e15af30aa30 b/sql/hive/src/test/resources/golden/partition_date-0-7ec1f3a845e2c49191460e15af30aa30 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/partition_date-1-916193405ce5e020dcd32c58325db6fe b/sql/hive/src/test/resources/golden/partition_date-1-916193405ce5e020dcd32c58325db6fe new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/partition_date-10-a8dde9c0b5746dd770c9c262d23ffb10 b/sql/hive/src/test/resources/golden/partition_date-10-a8dde9c0b5746dd770c9c262d23ffb10 new file mode 100644 index 0000000000000..7ed6ff82de6bc --- /dev/null +++ b/sql/hive/src/test/resources/golden/partition_date-10-a8dde9c0b5746dd770c9c262d23ffb10 @@ -0,0 +1 @@ +5 diff --git a/sql/hive/src/test/resources/golden/partition_date-11-fdface2fb6eef67f15bb7d0de2294957 b/sql/hive/src/test/resources/golden/partition_date-11-fdface2fb6eef67f15bb7d0de2294957 new file mode 100644 index 0000000000000..b4de394767536 --- /dev/null +++ b/sql/hive/src/test/resources/golden/partition_date-11-fdface2fb6eef67f15bb7d0de2294957 @@ -0,0 +1 @@ +11 
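
The golden files added throughout this change record the expected output of individual statements in the Hive compatibility tests: each file name combines the test case name, the statement's index within that case, and what appears to be a hash of the statement itself. As a minimal, hypothetical sketch of how such a case is declared (the suite name and query below are illustrative only and are not part of this change), the pattern follows `createQueryTest` from `HiveComparisonTest`, which `HiveQuerySuite` also uses later in this diff:

package org.apache.spark.sql.hive.execution

// Hypothetical example only: a comparison test whose expected answers would be
// checked against golden resource files like the date_* and partition_date-*
// files added in this change (one answer file per executed statement).
class DateQuerySuite extends HiveComparisonTest {
  createQueryTest("date_1 cast string literal to date",
    "SELECT CAST('2011-01-01' AS DATE) FROM src LIMIT 1")
}
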
diff --git a/sql/hive/src/test/resources/golden/partition_date-12-9b945f8ece6e09ad28c866ff3a10cc24 b/sql/hive/src/test/resources/golden/partition_date-12-9b945f8ece6e09ad28c866ff3a10cc24 new file mode 100644 index 0000000000000..64bb6b746dcea --- /dev/null +++ b/sql/hive/src/test/resources/golden/partition_date-12-9b945f8ece6e09ad28c866ff3a10cc24 @@ -0,0 +1 @@ +30 diff --git a/sql/hive/src/test/resources/golden/partition_date-13-b7cb91c7c459798078a79071d329dbf b/sql/hive/src/test/resources/golden/partition_date-13-b7cb91c7c459798078a79071d329dbf new file mode 100644 index 0000000000000..573541ac9702d --- /dev/null +++ b/sql/hive/src/test/resources/golden/partition_date-13-b7cb91c7c459798078a79071d329dbf @@ -0,0 +1 @@ +0 diff --git a/sql/hive/src/test/resources/golden/partition_date-14-e4366325f3a0c4a8e92be59f4de73fce b/sql/hive/src/test/resources/golden/partition_date-14-e4366325f3a0c4a8e92be59f4de73fce new file mode 100644 index 0000000000000..573541ac9702d --- /dev/null +++ b/sql/hive/src/test/resources/golden/partition_date-14-e4366325f3a0c4a8e92be59f4de73fce @@ -0,0 +1 @@ +0 diff --git a/sql/hive/src/test/resources/golden/partition_date-15-a062a6e87867d8c8cfbdad97bedcbe5f b/sql/hive/src/test/resources/golden/partition_date-15-a062a6e87867d8c8cfbdad97bedcbe5f new file mode 100644 index 0000000000000..209e3ef4b6247 --- /dev/null +++ b/sql/hive/src/test/resources/golden/partition_date-15-a062a6e87867d8c8cfbdad97bedcbe5f @@ -0,0 +1 @@ +20 diff --git a/sql/hive/src/test/resources/golden/partition_date-16-22a5627d9ac112665eae01d07a91c89c b/sql/hive/src/test/resources/golden/partition_date-16-22a5627d9ac112665eae01d07a91c89c new file mode 100644 index 0000000000000..f599e28b8ab0d --- /dev/null +++ b/sql/hive/src/test/resources/golden/partition_date-16-22a5627d9ac112665eae01d07a91c89c @@ -0,0 +1 @@ +10 diff --git a/sql/hive/src/test/resources/golden/partition_date-17-b9ce94ef93cb16d629af7d7f8ee637e b/sql/hive/src/test/resources/golden/partition_date-17-b9ce94ef93cb16d629af7d7f8ee637e new file mode 100644 index 0000000000000..209e3ef4b6247 --- /dev/null +++ b/sql/hive/src/test/resources/golden/partition_date-17-b9ce94ef93cb16d629af7d7f8ee637e @@ -0,0 +1 @@ +20 diff --git a/sql/hive/src/test/resources/golden/partition_date-18-72c6e9a4e0b434cef67144825346c687 b/sql/hive/src/test/resources/golden/partition_date-18-72c6e9a4e0b434cef67144825346c687 new file mode 100644 index 0000000000000..f599e28b8ab0d --- /dev/null +++ b/sql/hive/src/test/resources/golden/partition_date-18-72c6e9a4e0b434cef67144825346c687 @@ -0,0 +1 @@ +10 diff --git a/sql/hive/src/test/resources/golden/partition_date-19-44e5165eb210559e420105073bc96125 b/sql/hive/src/test/resources/golden/partition_date-19-44e5165eb210559e420105073bc96125 new file mode 100644 index 0000000000000..209e3ef4b6247 --- /dev/null +++ b/sql/hive/src/test/resources/golden/partition_date-19-44e5165eb210559e420105073bc96125 @@ -0,0 +1 @@ +20 diff --git a/sql/hive/src/test/resources/golden/partition_date-2-e2e70ac0f4e0ea987b49b86f73d819c9 b/sql/hive/src/test/resources/golden/partition_date-2-e2e70ac0f4e0ea987b49b86f73d819c9 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/partition_date-20-7ec1f3a845e2c49191460e15af30aa30 b/sql/hive/src/test/resources/golden/partition_date-20-7ec1f3a845e2c49191460e15af30aa30 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/partition_date-3-c938b08f57d588926a5d5fbfa4531012 
b/sql/hive/src/test/resources/golden/partition_date-3-c938b08f57d588926a5d5fbfa4531012 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/partition_date-4-a93eff99ce43bb939ec1d6464c0ef0b3 b/sql/hive/src/test/resources/golden/partition_date-4-a93eff99ce43bb939ec1d6464c0ef0b3 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/partition_date-5-a855aba47876561fd4fb095e09580686 b/sql/hive/src/test/resources/golden/partition_date-5-a855aba47876561fd4fb095e09580686 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/partition_date-6-1405c311915f27b0cc616c83d39eaacc b/sql/hive/src/test/resources/golden/partition_date-6-1405c311915f27b0cc616c83d39eaacc new file mode 100644 index 0000000000000..051ca3d3c28e7 --- /dev/null +++ b/sql/hive/src/test/resources/golden/partition_date-6-1405c311915f27b0cc616c83d39eaacc @@ -0,0 +1,2 @@ +2000-01-01 +2013-08-08 diff --git a/sql/hive/src/test/resources/golden/partition_date-7-2ac950d8d5656549dd453e5464cb8530 b/sql/hive/src/test/resources/golden/partition_date-7-2ac950d8d5656549dd453e5464cb8530 new file mode 100644 index 0000000000000..24192eefd2caf --- /dev/null +++ b/sql/hive/src/test/resources/golden/partition_date-7-2ac950d8d5656549dd453e5464cb8530 @@ -0,0 +1,5 @@ +165 val_165 2000-01-01 2 +238 val_238 2000-01-01 2 +27 val_27 2000-01-01 2 +311 val_311 2000-01-01 2 +86 val_86 2000-01-01 2 diff --git a/sql/hive/src/test/resources/golden/partition_date-8-a425c11c12c9ce4c9c43d4fbccee5347 b/sql/hive/src/test/resources/golden/partition_date-8-a425c11c12c9ce4c9c43d4fbccee5347 new file mode 100644 index 0000000000000..60d3b2f4a4cd5 --- /dev/null +++ b/sql/hive/src/test/resources/golden/partition_date-8-a425c11c12c9ce4c9c43d4fbccee5347 @@ -0,0 +1 @@ +15 diff --git a/sql/hive/src/test/resources/golden/partition_date-9-aad6078a09b7bd8f5141437e86bb229f b/sql/hive/src/test/resources/golden/partition_date-9-aad6078a09b7bd8f5141437e86bb229f new file mode 100644 index 0000000000000..60d3b2f4a4cd5 --- /dev/null +++ b/sql/hive/src/test/resources/golden/partition_date-9-aad6078a09b7bd8f5141437e86bb229f @@ -0,0 +1 @@ +15 diff --git a/sql/hive/src/test/resources/golden/partition_type_check-12-7e053ba4f9dea1e74c1d04c557c3adac b/sql/hive/src/test/resources/golden/partition_type_check-12-7e053ba4f9dea1e74c1d04c557c3adac new file mode 100644 index 0000000000000..91ba621412d72 --- /dev/null +++ b/sql/hive/src/test/resources/golden/partition_type_check-12-7e053ba4f9dea1e74c1d04c557c3adac @@ -0,0 +1,6 @@ +1 11 2008-01-01 +2 12 2008-01-01 +3 13 2008-01-01 +7 17 2008-01-01 +8 18 2008-01-01 +8 28 2008-01-01 diff --git a/sql/hive/src/test/resources/golden/partition_type_check-13-45fb706ff448da1fe609c7ff76a80d4d b/sql/hive/src/test/resources/golden/partition_type_check-13-45fb706ff448da1fe609c7ff76a80d4d new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/union_date-6-f4d5c71145a9b7464685aa7d09cd4dfd b/sql/hive/src/test/resources/golden/union_date-6-f4d5c71145a9b7464685aa7d09cd4dfd new file mode 100644 index 0000000000000..7941f53d8d4c7 --- /dev/null +++ b/sql/hive/src/test/resources/golden/union_date-6-f4d5c71145a9b7464685aa7d09cd4dfd @@ -0,0 +1,40 @@ +1064 2000-11-20 +1064 2000-11-20 +1142 2000-11-21 +1142 2000-11-21 +1599 2000-11-22 +1599 2000-11-22 +361 2000-11-23 +361 2000-11-23 +897 2000-11-24 +897 2000-11-24 +1531 2000-11-25 +1531 2000-11-25 +1610 2000-11-26 +1610 2000-11-26 
+3198 2000-11-27 +3198 2000-11-27 +1064 2000-11-28 +1064 2000-11-28 +1142 2000-11-28 +1142 2000-11-28 +1064 2010-10-20 +1064 2010-10-20 +1142 2010-10-21 +1142 2010-10-21 +1599 2010-10-22 +1599 2010-10-22 +361 2010-10-23 +361 2010-10-23 +897 2010-10-24 +897 2010-10-24 +1531 2010-10-25 +1531 2010-10-25 +1610 2010-10-26 +1610 2010-10-26 +3198 2010-10-27 +3198 2010-10-27 +1064 2010-10-28 +1064 2010-10-28 +1142 2010-10-29 +1142 2010-10-29 diff --git a/sql/hive/src/test/resources/golden/union_date-7-a0bade1c77338d4f72962389a1f5bea2 b/sql/hive/src/test/resources/golden/union_date-7-a0bade1c77338d4f72962389a1f5bea2 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/union_date-8-21306adbd8be8ad75174ad9d3e42b73c b/sql/hive/src/test/resources/golden/union_date-8-21306adbd8be8ad75174ad9d3e42b73c new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala index a35c40efdc207..14e791fe0f0ee 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala @@ -24,7 +24,7 @@ import scala.reflect.ClassTag import org.apache.spark.sql.{SQLConf, QueryTest} import org.apache.spark.sql.catalyst.plans.logical.NativeCommand -import org.apache.spark.sql.execution.{BroadcastHashJoin, ShuffledHashJoin} +import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, ShuffledHashJoin} import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.hive.test.TestHive._ diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index 2e282a9ade40c..3e100775e4981 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -22,6 +22,7 @@ import scala.util.Try import org.apache.hadoop.hive.conf.HiveConf.ConfVars import org.apache.spark.SparkException +import org.apache.spark.sql.catalyst.plans.logical.Project import org.apache.spark.sql.hive._ import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.hive.test.TestHive._ @@ -675,6 +676,41 @@ class HiveQuerySuite extends HiveComparisonTest { sql("SELECT * FROM boom").queryExecution.analyzed } + test("SPARK-3810: PreInsertionCasts static partitioning support") { + val analyzedPlan = { + loadTestTable("srcpart") + sql("DROP TABLE IF EXISTS withparts") + sql("CREATE TABLE withparts LIKE srcpart") + sql("INSERT INTO TABLE withparts PARTITION(ds='1', hr='2') SELECT key, value FROM src") + .queryExecution.analyzed + } + + assertResult(1, "Duplicated project detected\n" + analyzedPlan) { + analyzedPlan.collect { + case _: Project => () + }.size + } + } + + test("SPARK-3810: PreInsertionCasts dynamic partitioning support") { + val analyzedPlan = { + loadTestTable("srcpart") + sql("DROP TABLE IF EXISTS withparts") + sql("CREATE TABLE withparts LIKE srcpart") + sql("SET hive.exec.dynamic.partition.mode=nonstrict") + + sql("CREATE TABLE IF NOT EXISTS withparts LIKE srcpart") + sql("INSERT INTO TABLE withparts PARTITION(ds, hr) SELECT key, value FROM src") + .queryExecution.analyzed + } + + assertResult(1, "Duplicated project detected\n" + analyzedPlan) { + analyzedPlan.collect { + case _: Project => () + }.size + } + } 
+ test("parse HQL set commands") { // Adapted from its SQL counterpart. val testKey = "spark.sql.key.usedfortestonly" @@ -766,6 +802,9 @@ class HiveQuerySuite extends HiveComparisonTest { clear() } + createQueryTest("select from thrift based table", + "SELECT * from src_thrift") + // Put tests that depend on specific Hive settings before these last two test, // since they modify /clear stuff. } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala index e4324e9528f9b..872f28d514efe 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala @@ -17,33 +17,37 @@ package org.apache.spark.sql.hive.execution -import java.io.{DataOutput, DataInput} +import java.io.{DataInput, DataOutput} import java.util import java.util.Properties -import org.apache.spark.util.Utils - -import scala.collection.JavaConversions._ - import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.hive.serde2.{SerDeStats, AbstractSerDe} -import org.apache.hadoop.io.Writable -import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspectorFactory, ObjectInspector} - -import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory import org.apache.hadoop.hive.ql.udf.generic.GenericUDF import org.apache.hadoop.hive.ql.udf.generic.GenericUDF.DeferredObject - -import org.apache.spark.sql.Row +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory +import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspector, ObjectInspectorFactory} +import org.apache.hadoop.hive.serde2.{AbstractSerDe, SerDeStats} +import org.apache.hadoop.io.Writable +import org.apache.spark.sql.{QueryTest, Row} import org.apache.spark.sql.hive.test.TestHive -import org.apache.spark.sql.hive.test.TestHive._ + +import org.apache.spark.util.Utils + +import scala.collection.JavaConversions._ case class Fields(f1: Int, f2: Int, f3: Int, f4: Int, f5: Int) +// Case classes for the custom UDF's. +case class IntegerCaseClass(i: Int) +case class ListListIntCaseClass(lli: Seq[(Int, Int, Int)]) +case class StringCaseClass(s: String) +case class ListStringCaseClass(l: Seq[String]) + /** * A test suite for Hive custom UDFs. 
*/ -class HiveUdfSuite extends HiveComparisonTest { +class HiveUdfSuite extends QueryTest { + import TestHive._ test("spark sql udf test that returns a struct") { registerFunction("getStruct", (_: Int) => Fields(1, 2, 3, 4, 5)) @@ -81,7 +85,84 @@ class HiveUdfSuite extends HiveComparisonTest { } test("SPARK-2693 udaf aggregates test") { - assert(sql("SELECT percentile(key,1) FROM src").first === sql("SELECT max(key) FROM src").first) + checkAnswer(sql("SELECT percentile(key,1) FROM src LIMIT 1"), + sql("SELECT max(key) FROM src").collect().toSeq) + } + + test("UDFIntegerToString") { + val testData = TestHive.sparkContext.parallelize( + IntegerCaseClass(1) :: IntegerCaseClass(2) :: Nil) + testData.registerTempTable("integerTable") + + sql(s"CREATE TEMPORARY FUNCTION testUDFIntegerToString AS '${classOf[UDFIntegerToString].getName}'") + checkAnswer( + sql("SELECT testUDFIntegerToString(i) FROM integerTable"), //.collect(), + Seq(Seq("1"), Seq("2"))) + sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFIntegerToString") + + TestHive.reset() + } + + test("UDFListListInt") { + val testData = TestHive.sparkContext.parallelize( + ListListIntCaseClass(Nil) :: + ListListIntCaseClass(Seq((1, 2, 3))) :: + ListListIntCaseClass(Seq((4, 5, 6), (7, 8, 9))) :: Nil) + testData.registerTempTable("listListIntTable") + + sql(s"CREATE TEMPORARY FUNCTION testUDFListListInt AS '${classOf[UDFListListInt].getName}'") + checkAnswer( + sql("SELECT testUDFListListInt(lli) FROM listListIntTable"), //.collect(), + Seq(Seq(0), Seq(2), Seq(13))) + sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFListListInt") + + TestHive.reset() + } + + test("UDFListString") { + val testData = TestHive.sparkContext.parallelize( + ListStringCaseClass(Seq("a", "b", "c")) :: + ListStringCaseClass(Seq("d", "e")) :: Nil) + testData.registerTempTable("listStringTable") + + sql(s"CREATE TEMPORARY FUNCTION testUDFListString AS '${classOf[UDFListString].getName}'") + checkAnswer( + sql("SELECT testUDFListString(l) FROM listStringTable"), //.collect(), + Seq(Seq("a,b,c"), Seq("d,e"))) + sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFListString") + + TestHive.reset() + } + + test("UDFStringString") { + val testData = TestHive.sparkContext.parallelize( + StringCaseClass("world") :: StringCaseClass("goodbye") :: Nil) + testData.registerTempTable("stringTable") + + sql(s"CREATE TEMPORARY FUNCTION testStringStringUdf AS '${classOf[UDFStringString].getName}'") + checkAnswer( + sql("SELECT testStringStringUdf(\"hello\", s) FROM stringTable"), //.collect(), + Seq(Seq("hello world"), Seq("hello goodbye"))) + sql("DROP TEMPORARY FUNCTION IF EXISTS testStringStringUdf") + + TestHive.reset() + } + + test("UDFTwoListList") { + val testData = TestHive.sparkContext.parallelize( + ListListIntCaseClass(Nil) :: + ListListIntCaseClass(Seq((1, 2, 3))) :: + ListListIntCaseClass(Seq((4, 5, 6), (7, 8, 9))) :: + Nil) + testData.registerTempTable("TwoListTable") + + sql(s"CREATE TEMPORARY FUNCTION testUDFTwoListList AS '${classOf[UDFTwoListList].getName}'") + checkAnswer( + sql("SELECT testUDFTwoListList(lli, lli) FROM TwoListTable"), //.collect(), + Seq(Seq("0, 0"), Seq("2, 2"), Seq("13, 13"))) + sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFTwoListList") + + TestHive.reset() } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 3647bb1c4ce7d..fbe6ac765c009 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala 
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -68,5 +68,11 @@ class SQLQuerySuite extends QueryTest { checkAnswer( sql("SELECT k FROM (SELECT `key` AS `k` FROM src) a"), sql("SELECT `key` FROM src").collect().toSeq) - } + } + + test("SPARK-3834 Backticks not correctly handled in subquery aliases") { + checkAnswer( + sql("SELECT a.key FROM (SELECT key FROM src) `a`"), + sql("SELECT `key` FROM src").collect().toSeq) + } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala index a6184de4e83c1..2a7004e56ef53 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala @@ -167,7 +167,7 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T new JavaPairDStream(dstream.flatMap(fn)(cm))(fakeClassTag[K2], fakeClassTag[V2]) } - /** + /** * Return a new DStream in which each RDD is generated by applying mapPartitions() to each RDDs * of this DStream. Applying mapPartitions() to an RDD applies a function to each partition * of the RDD. diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala new file mode 100644 index 0000000000000..213dff6a76354 --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala @@ -0,0 +1,316 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.streaming.api.python + +import java.io.{ObjectInputStream, ObjectOutputStream} +import java.lang.reflect.Proxy +import java.util.{ArrayList => JArrayList, List => JList} +import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ +import scala.language.existentials + +import py4j.GatewayServer + +import org.apache.spark.api.java._ +import org.apache.spark.api.python._ +import org.apache.spark.rdd.RDD +import org.apache.spark.storage.StorageLevel +import org.apache.spark.streaming.{Interval, Duration, Time} +import org.apache.spark.streaming.dstream._ +import org.apache.spark.streaming.api.java._ + + +/** + * Interface for Python callback function which is used to transform RDDs + */ +private[python] trait PythonTransformFunction { + def call(time: Long, rdds: JList[_]): JavaRDD[Array[Byte]] +} + +/** + * Interface for Python Serializer to serialize PythonTransformFunction + */ +private[python] trait PythonTransformFunctionSerializer { + def dumps(id: String): Array[Byte] + def loads(bytes: Array[Byte]): PythonTransformFunction +} + +/** + * Wraps a PythonTransformFunction (which is a Python object accessed through Py4J) + * so that it looks like a Scala function and can be transparently serialized and + * deserialized by Java. + */ +private[python] class TransformFunction(@transient var pfunc: PythonTransformFunction) + extends function.Function2[JList[JavaRDD[_]], Time, JavaRDD[Array[Byte]]] { + + def apply(rdd: Option[RDD[_]], time: Time): Option[RDD[Array[Byte]]] = { + Option(pfunc.call(time.milliseconds, List(rdd.map(JavaRDD.fromRDD(_)).orNull).asJava)) + .map(_.rdd) + } + + def apply(rdd: Option[RDD[_]], rdd2: Option[RDD[_]], time: Time): Option[RDD[Array[Byte]]] = { + val rdds = List(rdd.map(JavaRDD.fromRDD(_)).orNull, rdd2.map(JavaRDD.fromRDD(_)).orNull).asJava + Option(pfunc.call(time.milliseconds, rdds)).map(_.rdd) + } + + // for function.Function2 + def call(rdds: JList[JavaRDD[_]], time: Time): JavaRDD[Array[Byte]] = { + pfunc.call(time.milliseconds, rdds) + } + + private def writeObject(out: ObjectOutputStream): Unit = { + val bytes = PythonTransformFunctionSerializer.serialize(pfunc) + out.writeInt(bytes.length) + out.write(bytes) + } + + private def readObject(in: ObjectInputStream): Unit = { + val length = in.readInt() + val bytes = new Array[Byte](length) + in.readFully(bytes) + pfunc = PythonTransformFunctionSerializer.deserialize(bytes) + } +} + +/** + * Helpers for PythonTransformFunctionSerializer + * + * PythonTransformFunctionSerializer is logically a singleton that's happens to be + * implemented as a Python object. 
+ */ +private[python] object PythonTransformFunctionSerializer { + + /** + * A serializer in Python, used to serialize PythonTransformFunction + */ + private var serializer: PythonTransformFunctionSerializer = _ + + /* + * Register a serializer from Python, should be called during initialization + */ + def register(ser: PythonTransformFunctionSerializer): Unit = { + serializer = ser + } + + def serialize(func: PythonTransformFunction): Array[Byte] = { + assert(serializer != null, "Serializer has not been registered!") + // get the id of PythonTransformFunction in py4j + val h = Proxy.getInvocationHandler(func.asInstanceOf[Proxy]) + val f = h.getClass().getDeclaredField("id") + f.setAccessible(true) + val id = f.get(h).asInstanceOf[String] + serializer.dumps(id) + } + + def deserialize(bytes: Array[Byte]): PythonTransformFunction = { + assert(serializer != null, "Serializer has not been registered!") + serializer.loads(bytes) + } +} + +/** + * Helper functions, which are called from Python via Py4J. + */ +private[python] object PythonDStream { + + /** + * can not access PythonTransformFunctionSerializer.register() via Py4j + * Py4JError: PythonTransformFunctionSerializerregister does not exist in the JVM + */ + def registerSerializer(ser: PythonTransformFunctionSerializer): Unit = { + PythonTransformFunctionSerializer.register(ser) + } + + /** + * Update the port of callback client to `port` + */ + def updatePythonGatewayPort(gws: GatewayServer, port: Int): Unit = { + val cl = gws.getCallbackClient + val f = cl.getClass.getDeclaredField("port") + f.setAccessible(true) + f.setInt(cl, port) + } + + /** + * helper function for DStream.foreachRDD(), + * cannot be `foreachRDD`, it will confusing py4j + */ + def callForeachRDD(jdstream: JavaDStream[Array[Byte]], pfunc: PythonTransformFunction) { + val func = new TransformFunction((pfunc)) + jdstream.dstream.foreachRDD((rdd, time) => func(Some(rdd), time)) + } + + /** + * convert list of RDD into queue of RDDs, for ssc.queueStream() + */ + def toRDDQueue(rdds: JArrayList[JavaRDD[Array[Byte]]]): java.util.Queue[JavaRDD[Array[Byte]]] = { + val queue = new java.util.LinkedList[JavaRDD[Array[Byte]]] + rdds.forall(queue.add(_)) + queue + } +} + +/** + * Base class for PythonDStream with some common methods + */ +private[python] abstract class PythonDStream( + parent: DStream[_], + @transient pfunc: PythonTransformFunction) + extends DStream[Array[Byte]] (parent.ssc) { + + val func = new TransformFunction(pfunc) + + override def dependencies = List(parent) + + override def slideDuration: Duration = parent.slideDuration + + val asJavaDStream = JavaDStream.fromDStream(this) +} + +/** + * Transformed DStream in Python. + */ +private[python] class PythonTransformedDStream ( + parent: DStream[_], + @transient pfunc: PythonTransformFunction) + extends PythonDStream(parent, pfunc) { + + override def compute(validTime: Time): Option[RDD[Array[Byte]]] = { + val rdd = parent.getOrCompute(validTime) + if (rdd.isDefined) { + func(rdd, validTime) + } else { + None + } + } +} + +/** + * Transformed from two DStreams in Python. 
+ */ +private[python] class PythonTransformed2DStream( + parent: DStream[_], + parent2: DStream[_], + @transient pfunc: PythonTransformFunction) + extends DStream[Array[Byte]] (parent.ssc) { + + val func = new TransformFunction(pfunc) + + override def dependencies = List(parent, parent2) + + override def slideDuration: Duration = parent.slideDuration + + override def compute(validTime: Time): Option[RDD[Array[Byte]]] = { + val empty: RDD[_] = ssc.sparkContext.emptyRDD + val rdd1 = parent.getOrCompute(validTime).getOrElse(empty) + val rdd2 = parent2.getOrCompute(validTime).getOrElse(empty) + func(Some(rdd1), Some(rdd2), validTime) + } + + val asJavaDStream = JavaDStream.fromDStream(this) +} + +/** + * similar to StateDStream + */ +private[python] class PythonStateDStream( + parent: DStream[Array[Byte]], + @transient reduceFunc: PythonTransformFunction) + extends PythonDStream(parent, reduceFunc) { + + super.persist(StorageLevel.MEMORY_ONLY) + override val mustCheckpoint = true + + override def compute(validTime: Time): Option[RDD[Array[Byte]]] = { + val lastState = getOrCompute(validTime - slideDuration) + val rdd = parent.getOrCompute(validTime) + if (rdd.isDefined) { + func(lastState, rdd, validTime) + } else { + lastState + } + } +} + +/** + * similar to ReducedWindowedDStream + */ +private[python] class PythonReducedWindowedDStream( + parent: DStream[Array[Byte]], + @transient preduceFunc: PythonTransformFunction, + @transient pinvReduceFunc: PythonTransformFunction, + _windowDuration: Duration, + _slideDuration: Duration) + extends PythonDStream(parent, preduceFunc) { + + super.persist(StorageLevel.MEMORY_ONLY) + override val mustCheckpoint = true + + val invReduceFunc = new TransformFunction(pinvReduceFunc) + + def windowDuration: Duration = _windowDuration + override def slideDuration: Duration = _slideDuration + override def parentRememberDuration: Duration = rememberDuration + windowDuration + + override def compute(validTime: Time): Option[RDD[Array[Byte]]] = { + val currentTime = validTime + val current = new Interval(currentTime - windowDuration, currentTime) + val previous = current - slideDuration + + // _____________________________ + // | previous window _________|___________________ + // |___________________| current window | --------------> Time + // |_____________________________| + // + // |________ _________| |________ _________| + // | | + // V V + // old RDDs new RDDs + // + val previousRDD = getOrCompute(previous.endTime) + + // for small window, reduce once will be better than twice + if (pinvReduceFunc != null && previousRDD.isDefined + && windowDuration >= slideDuration * 5) { + + // subtract the values from old RDDs + val oldRDDs = parent.slice(previous.beginTime + parent.slideDuration, current.beginTime) + val subtracted = if (oldRDDs.size > 0) { + invReduceFunc(previousRDD, Some(ssc.sc.union(oldRDDs)), validTime) + } else { + previousRDD + } + + // add the RDDs of the reduced values in "new time steps" + val newRDDs = parent.slice(previous.endTime + parent.slideDuration, current.endTime) + if (newRDDs.size > 0) { + func(subtracted, Some(ssc.sc.union(newRDDs)), validTime) + } else { + subtracted + } + } else { + // Get the RDDs of the reduced values in current window + val currentRDDs = parent.slice(current.beginTime + parent.slideDuration, current.endTime) + if (currentRDDs.size > 0) { + func(None, Some(ssc.sc.union(currentRDDs)), validTime) + } else { + None + } + } + } +} diff --git a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala 
b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala index 8511390cb1ad5..e5592e52b0d2d 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala @@ -231,8 +231,7 @@ class CheckpointSuite extends TestSuiteBase { // failure, are re-processed or not. test("recovery with file input stream") { // Set up the streaming context and input streams - val testDir = Files.createTempDir() - testDir.deleteOnExit() + val testDir = Utils.createTempDir() var ssc = new StreamingContext(master, framework, Seconds(1)) ssc.checkpoint(checkpointDir) val fileStream = ssc.textFileStream(testDir.toString) diff --git a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala index 952a74fd5f6de..fa04fa326e370 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala @@ -18,8 +18,6 @@ package org.apache.spark.streaming import akka.actor.Actor -import akka.actor.IO -import akka.actor.IOManager import akka.actor.Props import akka.util.ByteString @@ -98,8 +96,7 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { conf.set("spark.streaming.clock", "org.apache.spark.streaming.util.SystemClock") // Set up the streaming context and input streams - val testDir = Files.createTempDir() - testDir.deleteOnExit() + val testDir = Utils.createTempDir() val ssc = new StreamingContext(conf, batchDuration) val fileStream = ssc.textFileStream(testDir.toString) val outputBuffer = new ArrayBuffer[Seq[String]] with SynchronizedBuffer[Seq[String]] @@ -144,59 +141,6 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { conf.set("spark.streaming.clock", "org.apache.spark.streaming.util.ManualClock") } - // TODO: This test works in IntelliJ but not through SBT - ignore("actor input stream") { - // Start the server - val testServer = new TestServer() - val port = testServer.port - testServer.start() - - // Set up the streaming context and input streams - val ssc = new StreamingContext(conf, batchDuration) - val networkStream = ssc.actorStream[String](Props(new TestActor(port)), "TestActor", - // Had to pass the local value of port to prevent from closing over entire scope - StorageLevel.MEMORY_AND_DISK) - val outputBuffer = new ArrayBuffer[Seq[String]] with SynchronizedBuffer[Seq[String]] - val outputStream = new TestOutputStream(networkStream, outputBuffer) - def output = outputBuffer.flatMap(x => x) - outputStream.register() - ssc.start() - - // Feed data to the server to send to the network receiver - val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] - val input = 1 to 9 - val expectedOutput = input.map(x => x.toString) - Thread.sleep(1000) - for (i <- 0 until input.size) { - testServer.send(input(i).toString) - Thread.sleep(500) - clock.addToTime(batchDuration.milliseconds) - } - Thread.sleep(1000) - logInfo("Stopping server") - testServer.stop() - logInfo("Stopping context") - ssc.stop() - - // Verify whether data received was as expected - logInfo("--------------------------------") - logInfo("output.size = " + outputBuffer.size) - logInfo("output") - outputBuffer.foreach(x => logInfo("[" + x.mkString(",") + "]")) - logInfo("expected output.size = " + expectedOutput.size) - logInfo("expected output") - expectedOutput.foreach(x => logInfo("[" + x.mkString(",") 
+ "]")) - logInfo("--------------------------------") - - // Verify whether all the elements received are as expected - // (whether the elements were received one in each interval is not verified) - assert(output.size === expectedOutput.size) - for (i <- 0 until output.size) { - assert(output(i) === expectedOutput(i)) - } - } - - test("multi-thread receiver") { // set up the test receiver val numThreads = 10 @@ -378,22 +322,6 @@ class TestServer(portToBind: Int = 0) extends Logging { def port = serverSocket.getLocalPort } -/** This is an actor for testing actor input stream */ -class TestActor(port: Int) extends Actor with ActorHelper { - - def bytesToString(byteString: ByteString) = byteString.utf8String - - override def preStart(): Unit = { - @deprecated("suppress compile time deprecation warning", "1.0.0") - val unit = IOManager(context.system).connect(new InetSocketAddress(port)) - } - - def receive = { - case IO.Read(socket, bytes) => - store(bytesToString(bytes)) - } -} - /** This is a receiver to test multiple threads inserting data using block generator */ class MultiThreadTestReceiver(numThreads: Int, numRecordsPerThread: Int) extends Receiver[Int](StorageLevel.MEMORY_ONLY_SER) with Logging { diff --git a/streaming/src/test/scala/org/apache/spark/streaming/MasterFailureTest.scala b/streaming/src/test/scala/org/apache/spark/streaming/MasterFailureTest.scala index c53c01706083a..5dbb7232009eb 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/MasterFailureTest.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/MasterFailureTest.scala @@ -352,8 +352,7 @@ class FileGeneratingThread(input: Seq[String], testDir: Path, interval: Long) extends Thread with Logging { override def run() { - val localTestDir = Files.createTempDir() - localTestDir.deleteOnExit() + val localTestDir = Utils.createTempDir() var fs = testDir.getFileSystem(new Configuration()) val maxTries = 3 try { diff --git a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala index 759baacaa4308..9327ff4822699 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala @@ -24,12 +24,12 @@ import scala.collection.mutable.SynchronizedBuffer import scala.reflect.ClassTag import org.scalatest.{BeforeAndAfter, FunSuite} -import com.google.common.io.Files import org.apache.spark.streaming.dstream.{DStream, InputDStream, ForEachDStream} import org.apache.spark.streaming.util.ManualClock import org.apache.spark.{SparkConf, Logging} import org.apache.spark.rdd.RDD +import org.apache.spark.util.Utils /** * This is a input stream just for the testsuites. 
This is equivalent to a checkpointable, @@ -120,9 +120,8 @@ trait TestSuiteBase extends FunSuite with BeforeAndAfter with Logging { // Directory where the checkpoint data will be saved lazy val checkpointDir = { - val dir = Files.createTempDir() + val dir = Utils.createTempDir() logDebug(s"checkpointDir: $dir") - dir.deleteOnExit() dir.toString } diff --git a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index 5a20532315e59..5c7bca4541222 100644 --- a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -122,7 +122,7 @@ private[spark] class Client( * ApplicationReport#getClientToken is renamed `getClientToAMToken` in the stable API. */ override def getClientToken(report: ApplicationReport): String = - Option(report.getClientToken).getOrElse("") + Option(report.getClientToken).map(_.toString).getOrElse("") } object Client { diff --git a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala index 6c93d8582330b..abd37834ed3cc 100644 --- a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala +++ b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala @@ -43,7 +43,7 @@ private[yarn] class YarnAllocationHandler( args: ApplicationMasterArguments, preferredNodes: collection.Map[String, collection.Set[SplitInfo]], securityMgr: SecurityManager) - extends YarnAllocator(conf, sparkConf, args, preferredNodes, securityMgr) { + extends YarnAllocator(conf, sparkConf, appAttemptId, args, preferredNodes, securityMgr) { private val lastResponseId = new AtomicInteger() private val releaseList: CopyOnWriteArrayList[ContainerId] = new CopyOnWriteArrayList() diff --git a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala index 6ecac6eae6e03..0efac4ea63702 100644 --- a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala +++ b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala @@ -23,6 +23,7 @@ import scala.collection.JavaConversions._ import scala.collection.mutable.{HashMap, ListBuffer, Map} import scala.util.{Try, Success, Failure} +import com.google.common.base.Objects import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs._ import org.apache.hadoop.fs.permission.FsPermission @@ -64,12 +65,12 @@ private[spark] trait ClientBase extends Logging { s"memory capability of the cluster ($maxMem MB per container)") val executorMem = args.executorMemory + executorMemoryOverhead if (executorMem > maxMem) { - throw new IllegalArgumentException(s"Required executor memory (${args.executorMemory}" + + throw new IllegalArgumentException(s"Required executor memory (${args.executorMemory}" + s"+$executorMemoryOverhead MB) is above the max threshold ($maxMem MB) of this cluster!") } val amMem = args.amMemory + amMemoryOverhead if (amMem > maxMem) { - throw new IllegalArgumentException(s"Required AM memory (${args.amMemory}" + + throw new IllegalArgumentException(s"Required AM memory (${args.amMemory}" + s"+$amMemoryOverhead MB) is above the max threshold ($maxMem MB) of this cluster!") } logInfo("Will allocate AM container, with %d MB memory including %d MB overhead".format( @@ -142,7 +143,8 @@ private[spark] trait 
ClientBase extends Logging { val nns = getNameNodesToAccess(sparkConf) + dst obtainTokensForNamenodes(nns, hadoopConf, credentials) - val replication = sparkConf.getInt("spark.yarn.submit.file.replication", 3).toShort + val replication = sparkConf.getInt("spark.yarn.submit.file.replication", + fs.getDefaultReplication(dst)).toShort val localResources = HashMap[String, LocalResource]() FileSystem.mkdirs(fs, dst, new FsPermission(STAGING_DIR_PERMISSION)) @@ -771,15 +773,17 @@ private[spark] object ClientBase extends Logging { private def compareFs(srcFs: FileSystem, destFs: FileSystem): Boolean = { val srcUri = srcFs.getUri() val dstUri = destFs.getUri() - if (srcUri.getScheme() == null) { - return false - } - if (!srcUri.getScheme().equals(dstUri.getScheme())) { + if (srcUri.getScheme() == null || srcUri.getScheme() != dstUri.getScheme()) { return false } + var srcHost = srcUri.getHost() var dstHost = dstUri.getHost() - if ((srcHost != null) && (dstHost != null)) { + + // In HA or when using viewfs, the host part of the URI may not actually be a host, but the + // name of the HDFS namespace. Those names won't resolve, so avoid even trying if they + // match. + if (srcHost != null && dstHost != null && srcHost != dstHost) { try { srcHost = InetAddress.getByName(srcHost).getCanonicalHostName() dstHost = InetAddress.getByName(dstHost).getCanonicalHostName() @@ -787,19 +791,9 @@ private[spark] object ClientBase extends Logging { case e: UnknownHostException => return false } - if (!srcHost.equals(dstHost)) { - return false - } - } else if (srcHost == null && dstHost != null) { - return false - } else if (srcHost != null && dstHost == null) { - return false - } - if (srcUri.getPort() != dstUri.getPort()) { - false - } else { - true } + + Objects.equal(srcHost, dstHost) && srcUri.getPort() == dstUri.getPort() } } diff --git a/yarn/common/src/test/scala/org/apache/spark/deploy/yarn/ClientBaseSuite.scala b/yarn/common/src/test/scala/org/apache/spark/deploy/yarn/ClientBaseSuite.scala index 9bd916100dd2c..17b79ae1d82c4 100644 --- a/yarn/common/src/test/scala/org/apache/spark/deploy/yarn/ClientBaseSuite.scala +++ b/yarn/common/src/test/scala/org/apache/spark/deploy/yarn/ClientBaseSuite.scala @@ -20,13 +20,10 @@ package org.apache.spark.deploy.yarn import java.io.File import java.net.URI -import com.google.common.io.Files import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.apache.hadoop.mapreduce.MRJobConfig -import org.apache.hadoop.yarn.conf.YarnConfiguration import org.apache.hadoop.yarn.api.ApplicationConstants.Environment -import org.apache.hadoop.yarn.api.protocolrecords.GetNewApplicationResponse import org.apache.hadoop.yarn.api.records._ import org.apache.hadoop.yarn.conf.YarnConfiguration import org.mockito.Matchers._ @@ -117,7 +114,7 @@ class ClientBaseSuite extends FunSuite with Matchers { doReturn(new Path("/")).when(client).copyFileToRemote(any(classOf[Path]), any(classOf[Path]), anyShort(), anyBoolean()) - val tempDir = Files.createTempDir() + val tempDir = Utils.createTempDir() try { client.prepareLocalResources(tempDir.getAbsolutePath()) sparkConf.getOption(ClientBase.CONF_SPARK_USER_JAR) should be (Some(USER)) diff --git a/yarn/stable/pom.xml b/yarn/stable/pom.xml index 97eb0548e77c3..fe55d70ccc370 100644 --- a/yarn/stable/pom.xml +++ b/yarn/stable/pom.xml @@ -41,4 +41,55 @@ + + + + hadoop-2.2 + + 1.9 + + + + org.mortbay.jetty + jetty + 6.1.26 + + + org.mortbay.jetty + servlet-api + + + test + + + com.sun.jersey + jersey-core + ${jersey.version} + 
test + + + com.sun.jersey + jersey-json + ${jersey.version} + test + + + stax + stax-api + + + + + com.sun.jersey + jersey-server + ${jersey.version} + test + + + + +
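
The `compareFs` rewrite in `ClientBase.scala` earlier in this patch reduces the "same file system" check to three conditions: equal schemes, equal canonicalized hosts (skipping DNS resolution when the host strings already match, since HA/viewfs namespace names would not resolve), and equal ports. A self-contained sketch of that logic using plain `java.net.URI` follows; `sameFileSystem` is an illustrative name, not the patch's method.

```scala
import java.net.{InetAddress, URI, UnknownHostException}

def sameFileSystem(src: URI, dst: URI): Boolean = {
  if (src.getScheme == null || src.getScheme != dst.getScheme) {
    return false
  }
  var srcHost = src.getHost
  var dstHost = dst.getHost
  // Only resolve when the host strings differ: identical strings already match,
  // and HA/viewfs namespace "hosts" would fail DNS resolution anyway.
  if (srcHost != null && dstHost != null && srcHost != dstHost) {
    try {
      srcHost = InetAddress.getByName(srcHost).getCanonicalHostName
      dstHost = InetAddress.getByName(dstHost).getCanonicalHostName
    } catch {
      case _: UnknownHostException => return false
    }
  }
  srcHost == dstHost && src.getPort == dst.getPort
}

// Example: two paths on the same HDFS instance compare as the same file system.
assert(sameFileSystem(new URI("hdfs://nn1:8020/tmp/a"), new URI("hdfs://nn1:8020/user/b")))
```
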
    <tr><th>Output Operation</th><th>Meaning</th></tr>
    <tr>
      <td> print() </td>
      <td> Prints first ten elements of every batch of data in a DStream on the driver.
-     This is useful for development and debugging.
      </td>
    </tr>
    <tr>
      <td> saveAsObjectFiles(prefix, [suffix]) </td>