diff --git a/NOTICE b/NOTICE index 6ec240efbf12..876d6068c6c7 100644 --- a/NOTICE +++ b/NOTICE @@ -5,6 +5,31 @@ This product includes software developed at The Apache Software Foundation (http://www.apache.org/). +Export Control Notice +--------------------- + +This distribution includes cryptographic software. The country in which you currently reside may have +restrictions on the import, possession, use, and/or re-export to another country, of encryption software. +BEFORE using any encryption software, please check your country's laws, regulations and policies concerning +the import, possession, or use, and re-export of encryption software, to see if this is permitted. See + for more information. + +The U.S. Government Department of Commerce, Bureau of Industry and Security (BIS), has classified this +software as Export Commodity Control Number (ECCN) 5D002.C.1, which includes information security software +using or performing cryptographic functions with asymmetric algorithms. The form and manner of this Apache +Software Foundation distribution makes it eligible for export under the License Exception ENC Technology +Software Unrestricted (TSU) exception (see the BIS Export Administration Regulations, Section 740.13) for +both object code and source code. + +The following provides more details on the included cryptographic software: + +This software uses Apache Commons Crypto (https://commons.apache.org/proper/commons-crypto/) to +support authentication, and encryption and decryption of data sent across the network between +services. + +This software includes Bouncy Castle (http://bouncycastle.org/) to support the jets3t library. + + ======================================================================== Common Development and Distribution License 1.0 ======================================================================== diff --git a/R/pkg/R/context.R b/R/pkg/R/context.R index 443c2ff8f9ac..25e2d158bfef 100644 --- a/R/pkg/R/context.R +++ b/R/pkg/R/context.R @@ -138,11 +138,10 @@ parallelize <- function(sc, coll, numSlices = 1) { sizeLimit <- getMaxAllocationLimit(sc) objectSize <- object.size(coll) + len <- length(coll) # For large objects we make sure the size of each slice is also smaller than sizeLimit - numSerializedSlices <- max(numSlices, ceiling(objectSize / sizeLimit)) - if (numSerializedSlices > length(coll)) - numSerializedSlices <- length(coll) + numSerializedSlices <- min(len, max(numSlices, ceiling(objectSize / sizeLimit))) # Generate the slice ids to put each row # For instance, for numSerializedSlices of 22, length of 50 @@ -153,8 +152,8 @@ parallelize <- function(sc, coll, numSlices = 1) { splits <- if (numSerializedSlices > 0) { unlist(lapply(0: (numSerializedSlices - 1), function(x) { # nolint start - start <- trunc((x * length(coll)) / numSerializedSlices) - end <- trunc(((x + 1) * length(coll)) / numSerializedSlices) + start <- trunc((as.numeric(x) * len) / numSerializedSlices) + end <- trunc(((as.numeric(x) + 1) * len) / numSerializedSlices) # nolint end rep(start, end - start) })) diff --git a/R/pkg/inst/tests/testthat/test_basic.R b/R/pkg/inst/tests/testthat/test_basic.R index 243f5f029828..80df3d8ce6e5 100644 --- a/R/pkg/inst/tests/testthat/test_basic.R +++ b/R/pkg/inst/tests/testthat/test_basic.R @@ -18,9 +18,9 @@ context("basic tests for CRAN") test_that("create DataFrame from list or data.frame", { - tryCatch( checkJavaVersion(), + tryCatch(checkJavaVersion(), error = function(e) { skip("error on Java check") }, - warning = function(e) { skip("warning on 
Java check") } ) + warning = function(e) { skip("warning on Java check") }) sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE, sparkConfig = sparkRTestConfig) @@ -54,9 +54,9 @@ test_that("create DataFrame from list or data.frame", { }) test_that("spark.glm and predict", { - tryCatch( checkJavaVersion(), + tryCatch(checkJavaVersion(), error = function(e) { skip("error on Java check") }, - warning = function(e) { skip("warning on Java check") } ) + warning = function(e) { skip("warning on Java check") }) sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE, sparkConfig = sparkRTestConfig) diff --git a/R/pkg/tests/fulltests/test_context.R b/R/pkg/tests/fulltests/test_context.R index f0d0a5114f89..288a2714a554 100644 --- a/R/pkg/tests/fulltests/test_context.R +++ b/R/pkg/tests/fulltests/test_context.R @@ -240,3 +240,10 @@ test_that("add and get file to be downloaded with Spark job on every node", { unlink(path, recursive = TRUE) sparkR.session.stop() }) + +test_that("SPARK-25234: parallelize should not have integer overflow", { + sc <- sparkR.sparkContext(master = sparkRTestMaster) + # 47000 * 47000 exceeds integer range + parallelize(sc, 1:47000, 47000) + sparkR.session.stop() +}) diff --git a/assembly/README b/assembly/README index d5dafab47741..affd281a1385 100644 --- a/assembly/README +++ b/assembly/README @@ -9,4 +9,4 @@ This module is off by default. To activate it specify the profile in the command If you need to build an assembly for a different version of Hadoop the hadoop-version system property needs to be set as in this example: - -Dhadoop.version=2.7.3 + -Dhadoop.version=2.7.7 diff --git a/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java b/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java index 91497b949221..34e4bb5912dc 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java +++ b/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java @@ -209,7 +209,7 @@ public String keyFactoryAlgorithm() { * (128 bits by default), which is not generally the case with user passwords. */ public int keyFactoryIterations() { - return conf.getInt("spark.networy.crypto.keyFactoryIterations", 1024); + return conf.getInt("spark.network.crypto.keyFactoryIterations", 1024); } /** diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java index dc36809d8911..0d069125dc60 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java @@ -66,7 +66,7 @@ public int compare(PackedRecordPointer left, PackedRecordPointer right) { */ private int usableCapacity = 0; - private int initialSize; + private final int initialSize; ShuffleInMemorySorter(MemoryConsumer consumer, int initialSize, boolean useRadixSort) { this.consumer = consumer; @@ -95,12 +95,20 @@ public int numRecords() { } public void reset() { + // Reset `pos` here so that `spill` triggered by the below `allocateArray` will be no-op. + pos = 0; if (consumer != null) { consumer.freeArray(array); + // As `array` has been released, we should set it to `null` to avoid accessing it before + // `allocateArray` returns. 
`usableCapacity` is also set to `0` to avoid any codes writing + // data to `ShuffleInMemorySorter` when `array` is `null` (e.g., in + // ShuffleExternalSorter.growPointerArrayIfNecessary, we may try to access + // `ShuffleInMemorySorter` when `allocateArray` throws SparkOutOfMemoryError). + array = null; + usableCapacity = 0; array = consumer.allocateArray(initialSize); usableCapacity = getUsableCapacity(); } - pos = 0; } public void expandPointerArray(LongArray newArray) { diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index a1ee2f7d1b11..8bc0ff7936da 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -586,8 +586,9 @@ class BytesToString extends org.apache.spark.api.java.function.Function[Array[By */ private[spark] class PythonAccumulatorV2( @transient private val serverHost: String, - private val serverPort: Int) - extends CollectionAccumulator[Array[Byte]] { + private val serverPort: Int, + private val secretToken: String) + extends CollectionAccumulator[Array[Byte]] with Logging{ Utils.checkHost(serverHost) @@ -602,12 +603,17 @@ private[spark] class PythonAccumulatorV2( private def openSocket(): Socket = synchronized { if (socket == null || socket.isClosed) { socket = new Socket(serverHost, serverPort) + logInfo(s"Connected to AccumulatorServer at host: $serverHost port: $serverPort") + // send the secret just for the initial authentication when opening a new connection + socket.getOutputStream.write(secretToken.getBytes(StandardCharsets.UTF_8)) } socket } // Need to override so the types match with PythonFunction - override def copyAndReset(): PythonAccumulatorV2 = new PythonAccumulatorV2(serverHost, serverPort) + override def copyAndReset(): PythonAccumulatorV2 = { + new PythonAccumulatorV2(serverHost, serverPort, secretToken) + } override def merge(other: AccumulatorV2[Array[Byte], JList[Array[Byte]]]): Unit = synchronized { val otherPythonAccumulator = other.asInstanceOf[PythonAccumulatorV2] diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala index 177295fb7af0..dec1b5fafbaa 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala @@ -31,7 +31,6 @@ import scala.util.control.NonFatal import com.google.common.primitives.Longs import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileStatus, FileSystem, Path, PathFilter} -import org.apache.hadoop.fs.permission.FsAction import org.apache.hadoop.mapred.JobConf import org.apache.hadoop.security.{Credentials, UserGroupInformation} import org.apache.hadoop.security.token.{Token, TokenIdentifier} @@ -378,28 +377,6 @@ class SparkHadoopUtil extends Logging { buffer.toString } - private[spark] def checkAccessPermission(status: FileStatus, mode: FsAction): Boolean = { - val perm = status.getPermission - val ugi = UserGroupInformation.getCurrentUser - - if (ugi.getShortUserName == status.getOwner) { - if (perm.getUserAction.implies(mode)) { - return true - } - } else if (ugi.getGroupNames.contains(status.getGroup)) { - if (perm.getGroupAction.implies(mode)) { - return true - } - } else if (perm.getOtherAction.implies(mode)) { - return true - } - - logDebug(s"Permission denied: user=${ugi.getShortUserName}, " + - 
s"path=${status.getPath}:${status.getOwner}:${status.getGroup}" + - s"${if (status.isDirectory) "d" else "-"}$perm") - false - } - def serialize(creds: Credentials): Array[Byte] = { val byteStream = new ByteArrayOutputStream val dataStream = new DataOutputStream(byteStream) diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala index ace6d9e00c83..23572ebf58f0 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala @@ -19,19 +19,19 @@ package org.apache.spark.deploy.history import java.io.{File, FileNotFoundException, IOException} import java.util.{Date, ServiceLoader, UUID} -import java.util.concurrent.{ExecutorService, TimeUnit} +import java.util.concurrent.{ConcurrentHashMap, ExecutorService, Future, TimeUnit} import java.util.zip.{ZipEntry, ZipOutputStream} import scala.collection.JavaConverters._ import scala.collection.mutable +import scala.concurrent.ExecutionException import scala.util.Try import scala.xml.Node import com.fasterxml.jackson.annotation.JsonIgnore import com.google.common.io.ByteStreams import com.google.common.util.concurrent.MoreExecutors -import org.apache.hadoop.fs.{FileStatus, Path} -import org.apache.hadoop.fs.permission.FsAction +import org.apache.hadoop.fs.{FileStatus, FileSystem, Path} import org.apache.hadoop.hdfs.DistributedFileSystem import org.apache.hadoop.hdfs.protocol.HdfsConstants import org.apache.hadoop.security.AccessControlException @@ -111,7 +111,8 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) "; groups with admin permissions" + HISTORY_UI_ADMIN_ACLS_GROUPS.toString) private val hadoopConf = SparkHadoopUtil.get.newConfiguration(conf) - private val fs = new Path(logDir).getFileSystem(hadoopConf) + // Visible for testing + private[history] val fs: FileSystem = new Path(logDir).getFileSystem(hadoopConf) // Used by check event thread and clean log thread. // Scheduled thread pool size must be one, otherwise it will have concurrent issues about fs @@ -155,6 +156,25 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) new HistoryServerDiskManager(conf, path, listing, clock) } + private val blacklist = new ConcurrentHashMap[String, Long] + + // Visible for testing + private[history] def isBlacklisted(path: Path): Boolean = { + blacklist.containsKey(path.getName) + } + + private def blacklist(path: Path): Unit = { + blacklist.put(path.getName, clock.getTimeMillis()) + } + + /** + * Removes expired entries in the blacklist, according to the provided `expireTimeInSeconds`. + */ + private def clearBlacklist(expireTimeInSeconds: Long): Unit = { + val expiredThreshold = clock.getTimeMillis() - expireTimeInSeconds * 1000 + blacklist.asScala.retain((_, creationTime) => creationTime >= expiredThreshold) + } + private val activeUIs = new mutable.HashMap[(String, Option[String]), LoadedAppUI]() /** @@ -412,7 +432,7 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) // reading a garbage file is safe, but we would log an error which can be scary to // the end-user. 
!entry.getPath().getName().startsWith(".") && - SparkHadoopUtil.get.checkAccessPermission(entry, FsAction.READ) + !isBlacklisted(entry.getPath) } .filter { entry => try { @@ -446,32 +466,37 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) logDebug(s"New/updated attempts found: ${updated.size} ${updated.map(_.getPath)}") } - val tasks = updated.map { entry => + val tasks = updated.flatMap { entry => try { - replayExecutor.submit(new Runnable { + val task: Future[Unit] = replayExecutor.submit(new Runnable { override def run(): Unit = mergeApplicationListing(entry, newLastScanTime) - }) + }, Unit) + Some(task -> entry.getPath) } catch { // let the iteration over the updated entries break, since an exception on // replayExecutor.submit (..) indicates the ExecutorService is unable // to take any more submissions at this time case e: Exception => logError(s"Exception while submitting event log for replay", e) - null + None } - }.filter(_ != null) + } pendingReplayTasksCount.addAndGet(tasks.size) // Wait for all tasks to finish. This makes sure that checkForLogs // is not scheduled again while some tasks are already running in // the replayExecutor. - tasks.foreach { task => + tasks.foreach { case (task, path) => try { task.get() } catch { case e: InterruptedException => throw e + case e: ExecutionException if e.getCause.isInstanceOf[AccessControlException] => + // We don't have read permissions on the log file + logWarning(s"Unable to read log $path", e.getCause) + blacklist(path) case e: Exception => logError("Exception while merging application listings", e) } finally { @@ -694,6 +719,8 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) listing.delete(classOf[LogInfo], log.logPath) } } + // Clean the blacklist from the expired entries. 
+ clearBlacklist(CLEAN_INTERVAL_S) } /** @@ -871,13 +898,17 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) } private def deleteLog(log: Path): Unit = { - try { - fs.delete(log, true) - } catch { - case _: AccessControlException => - logInfo(s"No permission to delete $log, ignoring.") - case ioe: IOException => - logError(s"IOException in cleaning $log", ioe) + if (isBlacklisted(log)) { + logDebug(s"Skipping deleting $log as we don't have permissions on it.") + } else { + try { + fs.delete(log, true) + } catch { + case _: AccessControlException => + logInfo(s"No permission to delete $log, ignoring.") + case ioe: IOException => + logError(s"IOException in cleaning $log", ioe) + } } } diff --git a/core/src/test/java/org/apache/spark/memory/TestMemoryConsumer.java b/core/src/test/java/org/apache/spark/memory/TestMemoryConsumer.java index db91329c94cb..1b7739cce6fb 100644 --- a/core/src/test/java/org/apache/spark/memory/TestMemoryConsumer.java +++ b/core/src/test/java/org/apache/spark/memory/TestMemoryConsumer.java @@ -19,6 +19,8 @@ import java.io.IOException; +import org.apache.spark.unsafe.memory.MemoryBlock; + public class TestMemoryConsumer extends MemoryConsumer { public TestMemoryConsumer(TaskMemoryManager memoryManager, MemoryMode mode) { super(memoryManager, 1024L, mode); @@ -43,6 +45,11 @@ void free(long size) { used -= size; taskMemoryManager.releaseExecutionMemory(size, this); } + + public void freePage(MemoryBlock page) { + used -= page.size(); + taskMemoryManager.freePage(page, this); + } } diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkHadoopUtilSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkHadoopUtilSuite.scala deleted file mode 100644 index ab24a76e20a3..000000000000 --- a/core/src/test/scala/org/apache/spark/deploy/SparkHadoopUtilSuite.scala +++ /dev/null @@ -1,97 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.deploy - -import java.security.PrivilegedExceptionAction - -import scala.util.Random - -import org.apache.hadoop.fs.FileStatus -import org.apache.hadoop.fs.permission.{FsAction, FsPermission} -import org.apache.hadoop.security.UserGroupInformation -import org.scalatest.Matchers - -import org.apache.spark.SparkFunSuite - -class SparkHadoopUtilSuite extends SparkFunSuite with Matchers { - test("check file permission") { - import FsAction._ - val testUser = s"user-${Random.nextInt(100)}" - val testGroups = Array(s"group-${Random.nextInt(100)}") - val testUgi = UserGroupInformation.createUserForTesting(testUser, testGroups) - - testUgi.doAs(new PrivilegedExceptionAction[Void] { - override def run(): Void = { - val sparkHadoopUtil = new SparkHadoopUtil - - // If file is owned by user and user has access permission - var status = fileStatus(testUser, testGroups.head, READ_WRITE, READ_WRITE, NONE) - sparkHadoopUtil.checkAccessPermission(status, READ) should be(true) - sparkHadoopUtil.checkAccessPermission(status, WRITE) should be(true) - - // If file is owned by user but user has no access permission - status = fileStatus(testUser, testGroups.head, NONE, READ_WRITE, NONE) - sparkHadoopUtil.checkAccessPermission(status, READ) should be(false) - sparkHadoopUtil.checkAccessPermission(status, WRITE) should be(false) - - val otherUser = s"test-${Random.nextInt(100)}" - val otherGroup = s"test-${Random.nextInt(100)}" - - // If file is owned by user's group and user's group has access permission - status = fileStatus(otherUser, testGroups.head, NONE, READ_WRITE, NONE) - sparkHadoopUtil.checkAccessPermission(status, READ) should be(true) - sparkHadoopUtil.checkAccessPermission(status, WRITE) should be(true) - - // If file is owned by user's group but user's group has no access permission - status = fileStatus(otherUser, testGroups.head, READ_WRITE, NONE, NONE) - sparkHadoopUtil.checkAccessPermission(status, READ) should be(false) - sparkHadoopUtil.checkAccessPermission(status, WRITE) should be(false) - - // If file is owned by other user and this user has access permission - status = fileStatus(otherUser, otherGroup, READ_WRITE, READ_WRITE, READ_WRITE) - sparkHadoopUtil.checkAccessPermission(status, READ) should be(true) - sparkHadoopUtil.checkAccessPermission(status, WRITE) should be(true) - - // If file is owned by other user but this user has no access permission - status = fileStatus(otherUser, otherGroup, READ_WRITE, READ_WRITE, NONE) - sparkHadoopUtil.checkAccessPermission(status, READ) should be(false) - sparkHadoopUtil.checkAccessPermission(status, WRITE) should be(false) - - null - } - }) - } - - private def fileStatus( - owner: String, - group: String, - userAction: FsAction, - groupAction: FsAction, - otherAction: FsAction): FileStatus = { - new FileStatus(0L, - false, - 0, - 0L, - 0L, - 0L, - new FsPermission(userAction, groupAction, otherAction), - owner, - group, - null) - } -} diff --git a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala index 787de59edf46..bf2b04484a42 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala @@ -29,9 +29,11 @@ import scala.language.postfixOps import com.google.common.io.{ByteStreams, Files} import org.apache.hadoop.fs.{FileStatus, Path} import org.apache.hadoop.hdfs.DistributedFileSystem 
+import org.apache.hadoop.security.AccessControlException import org.json4s.jackson.JsonMethods._ -import org.mockito.Matchers.any -import org.mockito.Mockito.{doReturn, mock, spy, verify} +import org.mockito.ArgumentMatcher +import org.mockito.Matchers.{any, argThat} +import org.mockito.Mockito.{doReturn, doThrow, mock, spy, verify, when} import org.scalatest.BeforeAndAfter import org.scalatest.Matchers import org.scalatest.concurrent.Eventually._ @@ -774,6 +776,42 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc assert(new File(testDir.toURI).listFiles().size === validLogCount) } + test("SPARK-24948: blacklist files we don't have read permission on") { + val clock = new ManualClock(1533132471) + val provider = new FsHistoryProvider(createTestConf(), clock) + val accessDenied = newLogFile("accessDenied", None, inProgress = false) + writeFile(accessDenied, true, None, + SparkListenerApplicationStart("accessDenied", Some("accessDenied"), 1L, "test", None)) + val accessGranted = newLogFile("accessGranted", None, inProgress = false) + writeFile(accessGranted, true, None, + SparkListenerApplicationStart("accessGranted", Some("accessGranted"), 1L, "test", None), + SparkListenerApplicationEnd(5L)) + val mockedFs = spy(provider.fs) + doThrow(new AccessControlException("Cannot read accessDenied file")).when(mockedFs).open( + argThat(new ArgumentMatcher[Path]() { + override def matches(path: Any): Boolean = { + path.asInstanceOf[Path].getName.toLowerCase == "accessdenied" + } + })) + val mockedProvider = spy(provider) + when(mockedProvider.fs).thenReturn(mockedFs) + updateAndCheck(mockedProvider) { list => + list.size should be(1) + } + writeFile(accessDenied, true, None, + SparkListenerApplicationStart("accessDenied", Some("accessDenied"), 1L, "test", None), + SparkListenerApplicationEnd(5L)) + // Doing 2 times in order to check the blacklist filter too + updateAndCheck(mockedProvider) { list => + list.size should be(1) + } + val accessDeniedPath = new Path(accessDenied.getPath) + assert(mockedProvider.isBlacklisted(accessDeniedPath)) + clock.advance(24 * 60 * 60 * 1000 + 1) // add a bit more than 1d + mockedProvider.cleanLogs() + assert(!mockedProvider.isBlacklisted(accessDeniedPath)) + } + /** * Asks the provider to check for logs and calls a function to perform checks on the updated * app list. Example: diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/ShuffleExternalSorterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/ShuffleExternalSorterSuite.scala new file mode 100644 index 000000000000..b9f0e873375b --- /dev/null +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/ShuffleExternalSorterSuite.scala @@ -0,0 +1,111 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.shuffle.sort + +import java.lang.{Long => JLong} + +import org.mockito.Mockito.when +import org.scalatest.mockito.MockitoSugar + +import org.apache.spark._ +import org.apache.spark.executor.{ShuffleWriteMetrics, TaskMetrics} +import org.apache.spark.memory._ +import org.apache.spark.unsafe.Platform + +class ShuffleExternalSorterSuite extends SparkFunSuite with LocalSparkContext with MockitoSugar { + + test("nested spill should be no-op") { + val conf = new SparkConf() + .setMaster("local[1]") + .setAppName("ShuffleExternalSorterSuite") + .set("spark.testing", "true") + .set("spark.testing.memory", "1600") + .set("spark.memory.fraction", "1") + sc = new SparkContext(conf) + + val memoryManager = UnifiedMemoryManager(conf, 1) + + var shouldAllocate = false + + // Mock `TaskMemoryManager` to allocate free memory when `shouldAllocate` is true. + // This will trigger a nested spill and expose issues if we don't handle this case properly. + val taskMemoryManager = new TaskMemoryManager(memoryManager, 0) { + override def acquireExecutionMemory(required: Long, consumer: MemoryConsumer): Long = { + // ExecutionMemoryPool.acquireMemory will wait until there are 400 bytes for a task to use. + // So we leave 400 bytes for the task. + if (shouldAllocate && + memoryManager.maxHeapMemory - memoryManager.executionMemoryUsed > 400) { + val acquireExecutionMemoryMethod = + memoryManager.getClass.getMethods.filter(_.getName == "acquireExecutionMemory").head + acquireExecutionMemoryMethod.invoke( + memoryManager, + JLong.valueOf( + memoryManager.maxHeapMemory - memoryManager.executionMemoryUsed - 400), + JLong.valueOf(1L), // taskAttemptId + MemoryMode.ON_HEAP + ).asInstanceOf[java.lang.Long] + } + super.acquireExecutionMemory(required, consumer) + } + } + val taskContext = mock[TaskContext] + val taskMetrics = new TaskMetrics + when(taskContext.taskMetrics()).thenReturn(taskMetrics) + val sorter = new ShuffleExternalSorter( + taskMemoryManager, + sc.env.blockManager, + taskContext, + 100, // initialSize - This will require ShuffleInMemorySorter to acquire at least 800 bytes + 1, // numPartitions + conf, + new ShuffleWriteMetrics) + val inMemSorter = { + val field = sorter.getClass.getDeclaredField("inMemSorter") + field.setAccessible(true) + field.get(sorter).asInstanceOf[ShuffleInMemorySorter] + } + // Allocate memory to make the next "insertRecord" call triggers a spill. + val bytes = new Array[Byte](1) + while (inMemSorter.hasSpaceForAnotherRecord) { + sorter.insertRecord(bytes, Platform.BYTE_ARRAY_OFFSET, 1, 0) + } + + // This flag will make the mocked TaskMemoryManager acquire free memory released by spill to + // trigger a nested spill. + shouldAllocate = true + + // Should throw `SparkOutOfMemoryError` as there is no enough memory: `ShuffleInMemorySorter` + // will try to acquire 800 bytes but there are only 400 bytes available. + // + // Before the fix, a nested spill may use a released page and this causes two tasks access the + // same memory page. When a task reads memory written by another task, many types of failures + // may happen. Here are some examples we have seen: + // + // - JVM crash. (This is easy to reproduce in the unit test as we fill newly allocated and + // deallocated memory with 0xa5 and 0x5a bytes which usually points to an invalid memory + // address) + // - java.lang.IllegalArgumentException: Comparison method violates its general contract! 
+ // - java.lang.NullPointerException + // at org.apache.spark.memory.TaskMemoryManager.getPage(TaskMemoryManager.java:384) + // - java.lang.UnsupportedOperationException: Cannot grow BufferHolder by size -536870912 + // because the size after growing exceeds size limitation 2147483632 + intercept[SparkOutOfMemoryError] { + sorter.insertRecord(bytes, Platform.BYTE_ARRAY_OFFSET, 1, 0) + } + } +} diff --git a/dev/deps/spark-deps-hadoop-2.7 b/dev/deps/spark-deps-hadoop-2.7 index df2be777ff5a..6a3b38b91981 100644 --- a/dev/deps/spark-deps-hadoop-2.7 +++ b/dev/deps/spark-deps-hadoop-2.7 @@ -66,21 +66,21 @@ gson-2.2.4.jar guava-14.0.1.jar guice-3.0.jar guice-servlet-3.0.jar -hadoop-annotations-2.7.3.jar -hadoop-auth-2.7.3.jar -hadoop-client-2.7.3.jar -hadoop-common-2.7.3.jar -hadoop-hdfs-2.7.3.jar -hadoop-mapreduce-client-app-2.7.3.jar -hadoop-mapreduce-client-common-2.7.3.jar -hadoop-mapreduce-client-core-2.7.3.jar -hadoop-mapreduce-client-jobclient-2.7.3.jar -hadoop-mapreduce-client-shuffle-2.7.3.jar -hadoop-yarn-api-2.7.3.jar -hadoop-yarn-client-2.7.3.jar -hadoop-yarn-common-2.7.3.jar -hadoop-yarn-server-common-2.7.3.jar -hadoop-yarn-server-web-proxy-2.7.3.jar +hadoop-annotations-2.7.7.jar +hadoop-auth-2.7.7.jar +hadoop-client-2.7.7.jar +hadoop-common-2.7.7.jar +hadoop-hdfs-2.7.7.jar +hadoop-mapreduce-client-app-2.7.7.jar +hadoop-mapreduce-client-common-2.7.7.jar +hadoop-mapreduce-client-core-2.7.7.jar +hadoop-mapreduce-client-jobclient-2.7.7.jar +hadoop-mapreduce-client-shuffle-2.7.7.jar +hadoop-yarn-api-2.7.7.jar +hadoop-yarn-client-2.7.7.jar +hadoop-yarn-common-2.7.7.jar +hadoop-yarn-server-common-2.7.7.jar +hadoop-yarn-server-web-proxy-2.7.7.jar hk2-api-2.4.0-b34.jar hk2-locator-2.4.0-b34.jar hk2-utils-2.4.0-b34.jar @@ -121,6 +121,7 @@ jersey-media-jaxb-2.22.2.jar jersey-server-2.22.2.jar jets3t-0.9.4.jar jetty-6.1.26.jar +jetty-sslengine-6.1.26.jar jetty-util-6.1.26.jar jline-2.12.1.jar joda-time-2.9.3.jar diff --git a/docs/building-spark.md b/docs/building-spark.md index c391255a9159..9f78c04a5c8e 100644 --- a/docs/building-spark.md +++ b/docs/building-spark.md @@ -67,7 +67,7 @@ Examples: ./build/mvn -Pyarn -DskipTests clean package # Apache Hadoop 2.7.X and later - ./build/mvn -Pyarn -Phadoop-2.7 -Dhadoop.version=2.7.3 -DskipTests clean package + ./build/mvn -Pyarn -Phadoop-2.7 -Dhadoop.version=2.7.7 -DskipTests clean package ## Building With Hive and JDBC Support diff --git a/docs/mllib-evaluation-metrics.md b/docs/mllib-evaluation-metrics.md index 7f277543d2e9..ac398fb7fce7 100644 --- a/docs/mllib-evaluation-metrics.md +++ b/docs/mllib-evaluation-metrics.md @@ -462,13 +462,13 @@ $$rel_D(r) = \begin{cases}1 & \text{if $r \in D$}, \\ 0 & \text{otherwise}.\end{ Normalized Discounted Cumulative Gain $NDCG(k)=\frac{1}{M} \sum_{i=0}^{M-1} {\frac{1}{IDCG(D_i, k)}\sum_{j=0}^{n-1} - \frac{rel_{D_i}(R_i(j))}{\text{ln}(j+1)}} \\ + \frac{rel_{D_i}(R_i(j))}{\text{ln}(j+2)}} \\ \text{Where} \\ \hspace{5 mm} n = \text{min}\left(\text{max}\left(|R_i|,|D_i|\right),k\right) \\ - \hspace{5 mm} IDCG(D, k) = \sum_{j=0}^{\text{min}(\left|D\right|, k) - 1} \frac{1}{\text{ln}(j+1)}$ + \hspace{5 mm} IDCG(D, k) = \sum_{j=0}^{\text{min}(\left|D\right|, k) - 1} \frac{1}{\text{ln}(j+2)}$ - NDCG at k is a + NDCG at k is a measure of how many of the first k recommended documents are in the set of true relevant documents averaged across all users. In contrast to precision at k, this metric takes into account the order of the recommendations (documents are assumed to be in order of decreasing relevance). 
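The corrected formulas above discount the zero-based rank j by ln(j + 2), so the first recommended document (j = 0) is divided by ln 2 rather than by ln 1 = 0. A minimal Scala sketch of DCG/NDCG at k under that zero-based convention (illustrative helpers with hypothetical names, not MLlib's RankingMetrics implementation):

    // Discounted cumulative gain of the top-k recommendations, 0-based positions.
    def dcgAt(recommended: Seq[String], relevant: Set[String], k: Int): Double =
      recommended.take(k).zipWithIndex.map { case (doc, j) =>
        val rel = if (relevant.contains(doc)) 1.0 else 0.0
        rel / math.log(j + 2)  // position j = 0 is discounted by ln(2), not ln(1) = 0
      }.sum

    // Normalize by the ideal DCG, where the first min(|D|, k) slots are all relevant.
    def ndcgAt(recommended: Seq[String], relevant: Set[String], k: Int): Double = {
      val idcg = (0 until math.min(relevant.size, k)).map(j => 1.0 / math.log(j + 2)).sum
      if (idcg == 0.0) 0.0 else dcgAt(recommended, relevant, k) / idcg
    }

For example, ndcgAt(Seq("a", "b", "c"), Set("a", "c"), 3) compares the DCG of that ranking against an ideal ranking that places both relevant documents first.
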
diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceRDD.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceRDD.scala index 498e344ea39f..ca05b9e8d39b 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceRDD.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceRDD.scala @@ -124,8 +124,6 @@ private[kafka010] class KafkaSourceRDD( thePart: Partition, context: TaskContext): Iterator[ConsumerRecord[Array[Byte], Array[Byte]]] = { val sourcePartition = thePart.asInstanceOf[KafkaSourceRDDPartition] - val topic = sourcePartition.offsetRange.topic - val kafkaPartition = sourcePartition.offsetRange.partition val consumer = KafkaDataConsumer.acquire( sourcePartition.offsetRange.topicPartition, executorKafkaParams, reuseKafkaConsumer) @@ -138,6 +136,7 @@ private[kafka010] class KafkaSourceRDD( if (range.fromOffset == range.untilOffset) { logInfo(s"Beginning offset ${range.fromOffset} is the same as ending offset " + s"skipping ${range.topic} ${range.partition}") + consumer.release() Iterator.empty } else { val underlying = new NextIterator[ConsumerRecord[Array[Byte], Array[Byte]]]() { diff --git a/pom.xml b/pom.xml index 9d29b3a1a8a8..1ce3900da81e 100644 --- a/pom.xml +++ b/pom.xml @@ -132,7 +132,7 @@ 1.4.4 nohive 1.6.0 - 9.3.20.v20170531 + 9.3.24.v20180605 3.1.0 0.8.4 2.4.0 @@ -2678,7 +2678,7 @@ hadoop-2.7 - 2.7.3 + 2.7.7 2.7.1 diff --git a/python/pyspark/accumulators.py b/python/pyspark/accumulators.py index 6ef8cf53cc74..bc0be07bfb36 100644 --- a/python/pyspark/accumulators.py +++ b/python/pyspark/accumulators.py @@ -228,20 +228,49 @@ class _UpdateRequestHandler(SocketServer.StreamRequestHandler): def handle(self): from pyspark.accumulators import _accumulatorRegistry - while not self.server.server_shutdown: - # Poll every 1 second for new data -- don't block in case of shutdown. - r, _, _ = select.select([self.rfile], [], [], 1) - if self.rfile in r: - num_updates = read_int(self.rfile) - for _ in range(num_updates): - (aid, update) = pickleSer._read_with_length(self.rfile) - _accumulatorRegistry[aid] += update - # Write a byte in acknowledgement - self.wfile.write(struct.pack("!b", 1)) + auth_token = self.server.auth_token + + def poll(func): + while not self.server.server_shutdown: + # Poll every 1 second for new data -- don't block in case of shutdown. 
+ r, _, _ = select.select([self.rfile], [], [], 1) + if self.rfile in r: + if func(): + break + + def accum_updates(): + num_updates = read_int(self.rfile) + for _ in range(num_updates): + (aid, update) = pickleSer._read_with_length(self.rfile) + _accumulatorRegistry[aid] += update + # Write a byte in acknowledgement + self.wfile.write(struct.pack("!b", 1)) + return False + + def authenticate_and_accum_updates(): + received_token = self.rfile.read(len(auth_token)) + if isinstance(received_token, bytes): + received_token = received_token.decode("utf-8") + if (received_token == auth_token): + accum_updates() + # we've authenticated, we can break out of the first loop now + return True + else: + raise Exception( + "The value of the provided token to the AccumulatorServer is not correct.") + + # first we keep polling till we've received the authentication token + poll(authenticate_and_accum_updates) + # now we've authenticated, don't need to check for the token anymore + poll(accum_updates) class AccumulatorServer(SocketServer.TCPServer): + def __init__(self, server_address, RequestHandlerClass, auth_token): + SocketServer.TCPServer.__init__(self, server_address, RequestHandlerClass) + self.auth_token = auth_token + """ A simple TCP server that intercepts shutdown() in order to interrupt our continuous polling on the handler. @@ -254,9 +283,9 @@ def shutdown(self): self.server_close() -def _start_update_server(): +def _start_update_server(auth_token): """Start a TCP server to receive accumulator updates in a daemon thread, and returns it""" - server = AccumulatorServer(("localhost", 0), _UpdateRequestHandler) + server = AccumulatorServer(("localhost", 0), _UpdateRequestHandler, auth_token) thread = threading.Thread(target=server.serve_forever) thread.daemon = True thread.start() diff --git a/python/pyspark/context.py b/python/pyspark/context.py index 880559719d3a..667d0a310161 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -183,9 +183,10 @@ def _do_init(self, master, appName, sparkHome, pyFiles, environment, batchSize, # Create a single Accumulator in Java that we'll send all our updates through; # they will be passed back to us through a TCP server - self._accumulatorServer = accumulators._start_update_server() + auth_token = self._gateway.gateway_parameters.auth_token + self._accumulatorServer = accumulators._start_update_server(auth_token) (host, port) = self._accumulatorServer.server_address - self._javaAccumulator = self._jvm.PythonAccumulatorV2(host, port) + self._javaAccumulator = self._jvm.PythonAccumulatorV2(host, port, auth_token) self._jsc.sc().register(self._javaAccumulator) self.pythonExec = os.environ.get("PYSPARK_PYTHON", 'python') diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index 04b07e6a0548..a444fe0f98a4 100755 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -3673,12 +3673,12 @@ def setParams(self, inputCol=None, size=None, handleInvalid="error"): @since("2.3.0") def getSize(self): """ Gets size param, the size of vectors in `inputCol`.""" - self.getOrDefault(self.size) + return self.getOrDefault(self.size) @since("2.3.0") def setSize(self, value): """ Sets size param, the size of vectors in `inputCol`.""" - self._set(size=value) + return self._set(size=value) if __name__ == "__main__": diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index 1af2b91da900..49912d20da1d 100755 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -678,6 +678,23 @@ def 
test_string_indexer_handle_invalid(self): expected2 = [Row(id=0, indexed=0.0), Row(id=1, indexed=1.0)] self.assertEqual(actual2, expected2) + def test_vector_size_hint(self): + df = self.spark.createDataFrame( + [(0, Vectors.dense([0.0, 10.0, 0.5])), + (1, Vectors.dense([1.0, 11.0, 0.5, 0.6])), + (2, Vectors.dense([2.0, 12.0]))], + ["id", "vector"]) + + sizeHint = VectorSizeHint( + inputCol="vector", + handleInvalid="skip") + sizeHint.setSize(3) + self.assertEqual(sizeHint.getSize(), 3) + + output = sizeHint.transform(df).head().vector + expected = DenseVector([0.0, 10.0, 0.5]) + self.assertEqual(output, expected) + class HasInducedError(Params): diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/RecordBinaryComparator.java b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/RecordBinaryComparator.java index bb77b5bf6de2..40c2cc806e87 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/RecordBinaryComparator.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/RecordBinaryComparator.java @@ -22,12 +22,10 @@ public final class RecordBinaryComparator extends RecordComparator { - // TODO(jiangxb) Add test suite for this. @Override public int compare( Object leftObj, long leftOff, int leftLen, Object rightObj, long rightOff, int rightLen) { int i = 0; - int res = 0; // If the arrays have different length, the longer one is larger. if (leftLen != rightLen) { @@ -40,27 +38,33 @@ public int compare( // check if stars align and we can get both offsets to be aligned if ((leftOff % 8) == (rightOff % 8)) { while ((leftOff + i) % 8 != 0 && i < leftLen) { - res = (Platform.getByte(leftObj, leftOff + i) & 0xff) - - (Platform.getByte(rightObj, rightOff + i) & 0xff); - if (res != 0) return res; + final int v1 = Platform.getByte(leftObj, leftOff + i) & 0xff; + final int v2 = Platform.getByte(rightObj, rightOff + i) & 0xff; + if (v1 != v2) { + return v1 > v2 ? 1 : -1; + } i += 1; } } // for architectures that support unaligned accesses, chew it up 8 bytes at a time if (Platform.unaligned() || (((leftOff + i) % 8 == 0) && ((rightOff + i) % 8 == 0))) { while (i <= leftLen - 8) { - res = (int) ((Platform.getLong(leftObj, leftOff + i) - - Platform.getLong(rightObj, rightOff + i)) % Integer.MAX_VALUE); - if (res != 0) return res; + final long v1 = Platform.getLong(leftObj, leftOff + i); + final long v2 = Platform.getLong(rightObj, rightOff + i); + if (v1 != v2) { + return v1 > v2 ? 1 : -1; + } i += 8; } } // this will finish off the unaligned comparisons, or do the entire aligned comparison // whichever is needed. while (i < leftLen) { - res = (Platform.getByte(leftObj, leftOff + i) & 0xff) - - (Platform.getByte(rightObj, rightOff + i) & 0xff); - if (res != 0) return res; + final int v1 = Platform.getByte(leftObj, leftOff + i) & 0xff; + final int v2 = Platform.getByte(rightObj, rightOff + i) & 0xff; + if (v1 != v2) { + return v1 > v2 ? 1 : -1; + } i += 1; } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 5963c1467e25..531f3d468185 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -1704,6 +1704,7 @@ class Analyzer( def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { case p if !p.resolved => p // Skip unresolved nodes. 
+ case ab: AnalysisBarrier => apply(ab.child) case p: LogicalPlan if p.resolved => val childrenOutput = p.children.flatMap(c => c.output).groupBy(_.exprId).flatMap { case (exprId, attributes) => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala index 055ebf6c0da5..5aa0325678d7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala @@ -404,14 +404,15 @@ abstract class HashExpression[E] extends Expression { input: String, result: String, fields: Array[StructField]): String = { + val tmpInput = ctx.freshName("input") val fieldsHash = fields.zipWithIndex.map { case (field, index) => - nullSafeElementHash(input, index.toString, field.nullable, field.dataType, result, ctx) + nullSafeElementHash(tmpInput, index.toString, field.nullable, field.dataType, result, ctx) } val hashResultType = ctx.javaType(dataType) - ctx.splitExpressions( + val code = ctx.splitExpressions( expressions = fieldsHash, funcName = "computeHashForStruct", - arguments = Seq("InternalRow" -> input, hashResultType -> result), + arguments = Seq("InternalRow" -> tmpInput, hashResultType -> result), returnType = hashResultType, makeSplitFunction = body => s""" @@ -419,6 +420,10 @@ abstract class HashExpression[E] extends Expression { |return $result; """.stripMargin, foldFunctions = _.map(funcCall => s"$result = $funcCall;").mkString("\n")) + s""" + |final InternalRow $tmpInput = $input; + |$code + """.stripMargin } @tailrec @@ -769,10 +774,11 @@ case class HiveHash(children: Seq[Expression]) extends HashExpression[Int] { input: String, result: String, fields: Array[StructField]): String = { + val tmpInput = ctx.freshName("input") val childResult = ctx.freshName("childResult") val fieldsHash = fields.zipWithIndex.map { case (field, index) => val computeFieldHash = nullSafeElementHash( - input, index.toString, field.nullable, field.dataType, childResult, ctx) + tmpInput, index.toString, field.nullable, field.dataType, childResult, ctx) s""" |$childResult = 0; |$computeFieldHash @@ -780,10 +786,10 @@ case class HiveHash(children: Seq[Expression]) extends HashExpression[Int] { """.stripMargin } - s"${ctx.JAVA_INT} $childResult = 0;\n" + ctx.splitExpressions( + val code = ctx.splitExpressions( expressions = fieldsHash, funcName = "computeHashForStruct", - arguments = Seq("InternalRow" -> input, ctx.JAVA_INT -> result), + arguments = Seq("InternalRow" -> tmpInput, ctx.JAVA_INT -> result), returnType = ctx.JAVA_INT, makeSplitFunction = body => s""" @@ -792,6 +798,11 @@ case class HiveHash(children: Seq[Expression]) extends HashExpression[Int] { |return $result; """.stripMargin, foldFunctions = _.map(funcCall => s"$result = $funcCall;").mkString("\n")) + s""" + |final InternalRow $tmpInput = $input; + |${ctx.JAVA_INT} $childResult = 0; + |$code + """.stripMargin } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala index 34161f0f03f4..e4766e4b02c6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala @@ -628,7 +628,7 @@ case class JsonToStructs( {"a":1,"b":2} > SELECT _FUNC_(named_struct('time', 
to_timestamp('2015-08-26', 'yyyy-MM-dd')), map('timestampFormat', 'dd/MM/yyyy')); {"time":"26/08/2015"} - > SELECT _FUNC_(array(named_struct('a', 1, 'b', 2)); + > SELECT _FUNC_(array(named_struct('a', 1, 'b', 2))); [{"a":1,"b":2}] > SELECT _FUNC_(map('a', named_struct('b', 1))); {"a":{"b":1}} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index d7612e30b4c5..791dd7156d7c 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -1653,10 +1653,9 @@ case class Left(str: Expression, len: Expression, child: Expression) extends Run * A function that returns the char length of the given string expression or * number of bytes of the given binary expression. */ +// scalastyle:off line.size.limit @ExpressionDescription( - usage = "_FUNC_(expr) - Returns the character length of string data or number of bytes of " + - "binary data. The length of string data includes the trailing spaces. The length of binary " + - "data includes binary zeros.", + usage = "_FUNC_(expr) - Returns the character length of string data or number of bytes of binary data. The length of string data includes the trailing spaces. The length of binary data includes binary zeros.", examples = """ Examples: > SELECT _FUNC_('Spark SQL '); @@ -1666,6 +1665,7 @@ case class Left(str: Expression, len: Expression, child: Expression) extends Run > SELECT CHARACTER_LENGTH('Spark SQL '); 10 """) +// scalastyle:on line.size.limit case class Length(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { override def dataType: DataType = IntegerType override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(StringType, BinaryType)) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala index 625ff38943fa..b025b85d9f60 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala @@ -650,18 +650,18 @@ class DateTimeUtilsSuite extends SparkFunSuite { assert(daysToMillis(16800, TimeZoneGMT) === c.getTimeInMillis) // There are some days are skipped entirely in some timezone, skip them here. 
- val skipped_days = Map[String, Int]( - "Kwajalein" -> 8632, - "Pacific/Apia" -> 15338, - "Pacific/Enderbury" -> 9131, - "Pacific/Fakaofo" -> 15338, - "Pacific/Kiritimati" -> 9131, - "Pacific/Kwajalein" -> 8632, - "MIT" -> 15338) + val skipped_days = Map[String, Set[Int]]( + "Kwajalein" -> Set(8632), + "Pacific/Apia" -> Set(15338), + "Pacific/Enderbury" -> Set(9130, 9131), + "Pacific/Fakaofo" -> Set(15338), + "Pacific/Kiritimati" -> Set(9130, 9131), + "Pacific/Kwajalein" -> Set(8632), + "MIT" -> Set(15338)) for (tz <- DateTimeTestUtils.ALL_TIMEZONES) { - val skipped = skipped_days.getOrElse(tz.getID, Int.MinValue) + val skipped = skipped_days.getOrElse(tz.getID, Set.empty) (-20000 to 20000).foreach { d => - if (d != skipped) { + if (!skipped.contains(d)) { assert(millisToDays(daysToMillis(d, tz), tz) === d, s"Round trip of ${d} did not work in tz ${tz}") } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java index c7c4c7b3e771..c8cf44b51df7 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java @@ -20,8 +20,8 @@ import java.io.IOException; import org.apache.spark.SparkEnv; +import org.apache.spark.TaskContext; import org.apache.spark.internal.config.package$; -import org.apache.spark.memory.TaskMemoryManager; import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.catalyst.expressions.UnsafeProjection; import org.apache.spark.sql.catalyst.expressions.UnsafeRow; @@ -82,7 +82,7 @@ public static boolean supportsAggregationBufferSchema(StructType schema) { * @param emptyAggregationBuffer the default value for new keys (a "zero" of the agg. function) * @param aggregationBufferSchema the schema of the aggregation buffer, used for row conversion. * @param groupingKeySchema the schema of the grouping key, used for row conversion. - * @param taskMemoryManager the memory manager used to allocate our Unsafe memory structures. + * @param taskContext the current task context. * @param initialCapacity the initial capacity of the map (a sizing hint to avoid re-hashing). * @param pageSizeBytes the data page size, in bytes; limits the maximum record size. */ @@ -90,19 +90,26 @@ public UnsafeFixedWidthAggregationMap( InternalRow emptyAggregationBuffer, StructType aggregationBufferSchema, StructType groupingKeySchema, - TaskMemoryManager taskMemoryManager, + TaskContext taskContext, int initialCapacity, long pageSizeBytes) { this.aggregationBufferSchema = aggregationBufferSchema; this.currentAggregationBuffer = new UnsafeRow(aggregationBufferSchema.length()); this.groupingKeyProjection = UnsafeProjection.create(groupingKeySchema); this.groupingKeySchema = groupingKeySchema; - this.map = - new BytesToBytesMap(taskMemoryManager, initialCapacity, pageSizeBytes, true); + this.map = new BytesToBytesMap( + taskContext.taskMemoryManager(), initialCapacity, pageSizeBytes, true); // Initialize the buffer for aggregation value final UnsafeProjection valueProjection = UnsafeProjection.create(aggregationBufferSchema); this.emptyAggregationBuffer = valueProjection.apply(emptyAggregationBuffer).getBytes(); + + // Register a cleanup task with TaskContext to ensure that memory is guaranteed to be freed at + // the end of the task. 
This is necessary to avoid memory leaks in when the downstream operator + // does not fully consume the aggregation map's output (e.g. aggregate followed by limit). + taskContext.addTaskCompletionListener(context -> { + free(); + }); } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala index 58780a3ed08b..14e9d89a1e41 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -93,7 +93,8 @@ class SparkSession private( // If there is no active SparkSession, uses the default SQL conf. Otherwise, use the session's. SQLConf.setSQLConfGetter(() => { - SparkSession.getActiveSession.map(_.sessionState.conf).getOrElse(SQLConf.getFallbackConf) + SparkSession.getActiveSession.filterNot(_.sparkContext.isStopped).map(_.sessionState.conf) + .getOrElse(SQLConf.getFallbackConf) }) /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index ce3c68810f3b..c0fbaf60c416 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -324,7 +324,7 @@ case class HashAggregateExec( initialBuffer, bufferSchema, groupingKeySchema, - TaskContext.get().taskMemoryManager(), + TaskContext.get(), 1024 * 16, // initial capacity TaskContext.get().taskMemoryManager().pageSizeBytes ) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala index 9dc334c1ead3..c1911235f8df 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala @@ -166,7 +166,7 @@ class TungstenAggregationIterator( initialAggregationBuffer, StructType.fromAttributes(aggregateFunctions.flatMap(_.aggBufferAttributes)), StructType.fromAttributes(groupingExpressions.map(_.toAttribute)), - TaskContext.get().taskMemoryManager(), + TaskContext.get(), 1024 * 16, // initial capacity TaskContext.get().taskMemoryManager().pageSizeBytes ) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzePartitionCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzePartitionCommand.scala index 5b54b2270b5e..18fefa0a6f19 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzePartitionCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzePartitionCommand.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.command import org.apache.spark.sql.{AnalysisException, Column, Row, SparkSession} import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.{NoSuchPartitionException, UnresolvedAttribute} -import org.apache.spark.sql.catalyst.catalog.{CatalogTable, CatalogTableType} +import org.apache.spark.sql.catalyst.catalog.{CatalogTable, CatalogTableType, ExternalCatalogUtils} import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.catalyst.expressions.{And, EqualTo, Literal} import 
org.apache.spark.sql.execution.datasources.PartitioningUtils @@ -140,7 +140,13 @@ case class AnalyzePartitionCommand( val df = tableDf.filter(Column(filter)).groupBy(partitionColumns: _*).count() df.collect().map { r => - val partitionColumnValues = partitionColumns.indices.map(r.get(_).toString) + val partitionColumnValues = partitionColumns.indices.map { i => + if (r.isNullAt(i)) { + ExternalCatalogUtils.DEFAULT_PARTITION_NAME + } else { + r.get(i).toString + } + } val spec = tableMeta.partitionColumnNames.zip(partitionColumnValues).toMap val count = BigInt(r.getLong(partitionColumns.size)) (spec, count) diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/execution/sort/RecordBinaryComparatorSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/execution/sort/RecordBinaryComparatorSuite.java new file mode 100644 index 000000000000..97f3dc588ecc --- /dev/null +++ b/sql/core/src/test/java/test/org/apache/spark/sql/execution/sort/RecordBinaryComparatorSuite.java @@ -0,0 +1,322 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package test.org.apache.spark.sql.execution.sort; + +import org.apache.spark.SparkConf; +import org.apache.spark.memory.TaskMemoryManager; +import org.apache.spark.memory.TestMemoryConsumer; +import org.apache.spark.memory.TestMemoryManager; +import org.apache.spark.sql.catalyst.expressions.UnsafeArrayData; +import org.apache.spark.sql.catalyst.expressions.UnsafeRow; +import org.apache.spark.sql.execution.RecordBinaryComparator; +import org.apache.spark.unsafe.Platform; +import org.apache.spark.unsafe.UnsafeAlignedOffset; +import org.apache.spark.unsafe.array.LongArray; +import org.apache.spark.unsafe.memory.MemoryBlock; +import org.apache.spark.unsafe.types.UTF8String; +import org.apache.spark.util.collection.unsafe.sort.*; + +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +/** + * Test the RecordBinaryComparator, which compares two UnsafeRows by their binary form. + */ +public class RecordBinaryComparatorSuite { + + private final TaskMemoryManager memoryManager = new TaskMemoryManager( + new TestMemoryManager(new SparkConf().set("spark.memory.offHeap.enabled", "false")), 0); + private final TestMemoryConsumer consumer = new TestMemoryConsumer(memoryManager); + + private final int uaoSize = UnsafeAlignedOffset.getUaoSize(); + + private MemoryBlock dataPage; + private long pageCursor; + + private LongArray array; + private int pos; + + @Before + public void beforeEach() { + // Only compare between two input rows. 
+ array = consumer.allocateArray(2); + pos = 0; + + dataPage = memoryManager.allocatePage(4096, consumer); + pageCursor = dataPage.getBaseOffset(); + } + + @After + public void afterEach() { + consumer.freePage(dataPage); + dataPage = null; + pageCursor = 0; + + consumer.freeArray(array); + array = null; + pos = 0; + } + + private void insertRow(UnsafeRow row) { + Object recordBase = row.getBaseObject(); + long recordOffset = row.getBaseOffset(); + int recordLength = row.getSizeInBytes(); + + Object baseObject = dataPage.getBaseObject(); + assert(pageCursor + recordLength <= dataPage.getBaseOffset() + dataPage.size()); + long recordAddress = memoryManager.encodePageNumberAndOffset(dataPage, pageCursor); + UnsafeAlignedOffset.putSize(baseObject, pageCursor, recordLength); + pageCursor += uaoSize; + Platform.copyMemory(recordBase, recordOffset, baseObject, pageCursor, recordLength); + pageCursor += recordLength; + + assert(pos < 2); + array.set(pos, recordAddress); + pos++; + } + + private int compare(int index1, int index2) { + Object baseObject = dataPage.getBaseObject(); + + long recordAddress1 = array.get(index1); + long baseOffset1 = memoryManager.getOffsetInPage(recordAddress1) + uaoSize; + int recordLength1 = UnsafeAlignedOffset.getSize(baseObject, baseOffset1 - uaoSize); + + long recordAddress2 = array.get(index2); + long baseOffset2 = memoryManager.getOffsetInPage(recordAddress2) + uaoSize; + int recordLength2 = UnsafeAlignedOffset.getSize(baseObject, baseOffset2 - uaoSize); + + return binaryComparator.compare(baseObject, baseOffset1, recordLength1, baseObject, + baseOffset2, recordLength2); + } + + private final RecordComparator binaryComparator = new RecordBinaryComparator(); + + // Compute the most compact size for UnsafeRow's backing data. + private int computeSizeInBytes(int originalSize) { + // All the UnsafeRows in this suite contain less than 64 columns, so the bitSetSize shall + // always be 8. + return 8 + (originalSize + 7) / 8 * 8; + } + + // Compute the relative offset of variable-length values. + private long relativeOffset(int numFields) { + // All the UnsafeRows in this suite contain less than 64 columns, so the bitSetSize shall + // always be 8. 
+ return 8 + numFields * 8L; + } + + @Test + public void testBinaryComparatorForSingleColumnRow() throws Exception { + int numFields = 1; + + UnsafeRow row1 = new UnsafeRow(numFields); + byte[] data1 = new byte[100]; + row1.pointTo(data1, computeSizeInBytes(numFields * 8)); + row1.setInt(0, 11); + + UnsafeRow row2 = new UnsafeRow(numFields); + byte[] data2 = new byte[100]; + row2.pointTo(data2, computeSizeInBytes(numFields * 8)); + row2.setInt(0, 42); + + insertRow(row1); + insertRow(row2); + + assert(compare(0, 0) == 0); + assert(compare(0, 1) < 0); + } + + @Test + public void testBinaryComparatorForMultipleColumnRow() throws Exception { + int numFields = 5; + + UnsafeRow row1 = new UnsafeRow(numFields); + byte[] data1 = new byte[100]; + row1.pointTo(data1, computeSizeInBytes(numFields * 8)); + for (int i = 0; i < numFields; i++) { + row1.setDouble(i, i * 3.14); + } + + UnsafeRow row2 = new UnsafeRow(numFields); + byte[] data2 = new byte[100]; + row2.pointTo(data2, computeSizeInBytes(numFields * 8)); + for (int i = 0; i < numFields; i++) { + row2.setDouble(i, 198.7 / (i + 1)); + } + + insertRow(row1); + insertRow(row2); + + assert(compare(0, 0) == 0); + assert(compare(0, 1) < 0); + } + + @Test + public void testBinaryComparatorForArrayColumn() throws Exception { + int numFields = 1; + + UnsafeRow row1 = new UnsafeRow(numFields); + byte[] data1 = new byte[100]; + UnsafeArrayData arrayData1 = UnsafeArrayData.fromPrimitiveArray(new int[]{11, 42, -1}); + row1.pointTo(data1, computeSizeInBytes(numFields * 8 + arrayData1.getSizeInBytes())); + row1.setLong(0, (relativeOffset(numFields) << 32) | (long) arrayData1.getSizeInBytes()); + Platform.copyMemory(arrayData1.getBaseObject(), arrayData1.getBaseOffset(), data1, + row1.getBaseOffset() + relativeOffset(numFields), arrayData1.getSizeInBytes()); + + UnsafeRow row2 = new UnsafeRow(numFields); + byte[] data2 = new byte[100]; + UnsafeArrayData arrayData2 = UnsafeArrayData.fromPrimitiveArray(new int[]{22}); + row2.pointTo(data2, computeSizeInBytes(numFields * 8 + arrayData2.getSizeInBytes())); + row2.setLong(0, (relativeOffset(numFields) << 32) | (long) arrayData2.getSizeInBytes()); + Platform.copyMemory(arrayData2.getBaseObject(), arrayData2.getBaseOffset(), data2, + row2.getBaseOffset() + relativeOffset(numFields), arrayData2.getSizeInBytes()); + + insertRow(row1); + insertRow(row2); + + assert(compare(0, 0) == 0); + assert(compare(0, 1) > 0); + } + + @Test + public void testBinaryComparatorForMixedColumns() throws Exception { + int numFields = 4; + + UnsafeRow row1 = new UnsafeRow(numFields); + byte[] data1 = new byte[100]; + UTF8String str1 = UTF8String.fromString("Milk tea"); + row1.pointTo(data1, computeSizeInBytes(numFields * 8 + str1.numBytes())); + row1.setInt(0, 11); + row1.setDouble(1, 3.14); + row1.setInt(2, -1); + row1.setLong(3, (relativeOffset(numFields) << 32) | (long) str1.numBytes()); + Platform.copyMemory(str1.getBaseObject(), str1.getBaseOffset(), data1, + row1.getBaseOffset() + relativeOffset(numFields), str1.numBytes()); + + UnsafeRow row2 = new UnsafeRow(numFields); + byte[] data2 = new byte[100]; + UTF8String str2 = UTF8String.fromString("Java"); + row2.pointTo(data2, computeSizeInBytes(numFields * 8 + str2.numBytes())); + row2.setInt(0, 11); + row2.setDouble(1, 3.14); + row2.setInt(2, -1); + row2.setLong(3, (relativeOffset(numFields) << 32) | (long) str2.numBytes()); + Platform.copyMemory(str2.getBaseObject(), str2.getBaseOffset(), data2, + row2.getBaseOffset() + relativeOffset(numFields), str2.numBytes()); + + 
insertRow(row1); + insertRow(row2); + + assert(compare(0, 0) == 0); + assert(compare(0, 1) > 0); + } + + @Test + public void testBinaryComparatorForNullColumns() throws Exception { + int numFields = 3; + + UnsafeRow row1 = new UnsafeRow(numFields); + byte[] data1 = new byte[100]; + row1.pointTo(data1, computeSizeInBytes(numFields * 8)); + for (int i = 0; i < numFields; i++) { + row1.setNullAt(i); + } + + UnsafeRow row2 = new UnsafeRow(numFields); + byte[] data2 = new byte[100]; + row2.pointTo(data2, computeSizeInBytes(numFields * 8)); + for (int i = 0; i < numFields - 1; i++) { + row2.setNullAt(i); + } + row2.setDouble(numFields - 1, 3.14); + + insertRow(row1); + insertRow(row2); + + assert(compare(0, 0) == 0); + assert(compare(0, 1) > 0); + } + + @Test + public void testBinaryComparatorWhenSubtractionIsDivisibleByMaxIntValue() throws Exception { + int numFields = 1; + + UnsafeRow row1 = new UnsafeRow(numFields); + byte[] data1 = new byte[100]; + row1.pointTo(data1, computeSizeInBytes(numFields * 8)); + row1.setLong(0, 11); + + UnsafeRow row2 = new UnsafeRow(numFields); + byte[] data2 = new byte[100]; + row2.pointTo(data2, computeSizeInBytes(numFields * 8)); + row2.setLong(0, 11L + Integer.MAX_VALUE); + + insertRow(row1); + insertRow(row2); + + assert(compare(0, 1) < 0); + } + + @Test + public void testBinaryComparatorWhenSubtractionCanOverflowLongValue() throws Exception { + int numFields = 1; + + UnsafeRow row1 = new UnsafeRow(numFields); + byte[] data1 = new byte[100]; + row1.pointTo(data1, computeSizeInBytes(numFields * 8)); + row1.setLong(0, Long.MIN_VALUE); + + UnsafeRow row2 = new UnsafeRow(numFields); + byte[] data2 = new byte[100]; + row2.pointTo(data2, computeSizeInBytes(numFields * 8)); + row2.setLong(0, 1); + + insertRow(row1); + insertRow(row2); + + assert(compare(0, 1) < 0); + } + + @Test + public void testBinaryComparatorWhenOnlyTheLastColumnDiffers() throws Exception { + int numFields = 4; + + UnsafeRow row1 = new UnsafeRow(numFields); + byte[] data1 = new byte[100]; + row1.pointTo(data1, computeSizeInBytes(numFields * 8)); + row1.setInt(0, 11); + row1.setDouble(1, 3.14); + row1.setInt(2, -1); + row1.setLong(3, 0); + + UnsafeRow row2 = new UnsafeRow(numFields); + byte[] data2 = new byte[100]; + row2.pointTo(data2, computeSizeInBytes(numFields * 8)); + row2.setInt(0, 11); + row2.setDouble(1, 3.14); + row2.setInt(2, -1); + row2.setLong(3, 1); + + insertRow(row1); + insertRow(row2); + + assert(compare(0, 1) < 0); + } +} diff --git a/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out index 14a69128ffb4..1a1d5a3efe29 100644 --- a/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out @@ -24,7 +24,7 @@ Extended Usage: {"a":1,"b":2} > SELECT to_json(named_struct('time', to_timestamp('2015-08-26', 'yyyy-MM-dd')), map('timestampFormat', 'dd/MM/yyyy')); {"time":"26/08/2015"} - > SELECT to_json(array(named_struct('a', 1, 'b', 2)); + > SELECT to_json(array(named_struct('a', 1, 'b', 2))); [{"a":1,"b":2}] > SELECT to_json(map('a', named_struct('b', 1))); {"a":{"b":1}} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 9f8d337ead09..3640f6a4d0c6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -2300,4 +2300,10 
@@ class DataFrameSuite extends QueryTest with SharedSQLContext { checkAnswer(aggPlusFilter1, aggPlusFilter2.collect()) } } + + test("SPARK-25051: fix nullabilities of outer join attributes doesn't stop on AnalysisBarrier") { + val df1 = spark.range(4).selectExpr("id", "cast(id as string) as name") + val df2 = spark.range(3).selectExpr("id") + assert(df1.join(df2, Seq("id"), "left_outer").where(df2("id").isNull).collect().length == 1) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/LocalSparkSession.scala b/sql/core/src/test/scala/org/apache/spark/sql/LocalSparkSession.scala index cbef1c782831..6b90f20a94fa 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/LocalSparkSession.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/LocalSparkSession.scala @@ -36,19 +36,14 @@ trait LocalSparkSession extends BeforeAndAfterEach with BeforeAndAfterAll { self override def afterEach() { try { - resetSparkContext() + LocalSparkSession.stop(spark) SparkSession.clearActiveSession() SparkSession.clearDefaultSession() + spark = null } finally { super.afterEach() } } - - def resetSparkContext(): Unit = { - LocalSparkSession.stop(spark) - spark = null - } - } object LocalSparkSession { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 5aabc98b714c..2ff47f9e2a12 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -2806,4 +2806,24 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { checkAnswer(df, Seq(Row(3, 99, 1))) } } + + test("SPARK-25084: 'distribute by' on multiple columns may lead to codegen issue") { + withView("spark_25084") { + val count = 1000 + val df = spark.range(count) + val columns = (0 until 400).map{ i => s"id as id$i" } + val distributeExprs = (0 until 100).map(c => s"id$c").mkString(",") + df.selectExpr(columns : _*).createTempView("spark_25084") + assert( + spark.sql(s"select * from spark_25084 distribute by ($distributeExprs)").count === count) + } + } + + test("SPARK-25144 'distinct' causes memory leak") { + val ds = List(Foo(Some("bar"))).toDS + val result = ds.flatMap(_.bar).distinct + result.rdd.isEmpty + } } + +case class Foo(bar: Option[String]) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala index b11e79853205..0e7209a78b25 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala @@ -198,6 +198,24 @@ class StatisticsCollectionSuite extends StatisticsCollectionTestBase with Shared } } + test("SPARK-25028: column stats collection for null partitioning columns") { + val table = "analyze_partition_with_null" + withTempDir { dir => + withTable(table) { + sql(s""" + |CREATE TABLE $table (value string, name string) + |USING PARQUET + |PARTITIONED BY (name) + |LOCATION '${dir.toURI}'""".stripMargin) + val df = Seq(("a", null), ("b", null)).toDF("value", "name") + df.write.mode("overwrite").insertInto(table) + sql(s"ANALYZE TABLE $table PARTITION (name) COMPUTE STATISTICS") + val partitions = spark.sessionState.catalog.listPartitions(TableIdentifier(table)) + assert(partitions.head.stats.get.rowCount.get == 2) + } + } + } + test("number format in statistics") { val numbers = Seq( BigInt(0) -> (("0.0 B", "0")), diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala index 3e31d22e15c0..5c15ecd42fa0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala @@ -23,6 +23,7 @@ import scala.collection.mutable import scala.util.{Random, Try} import scala.util.control.NonFatal +import org.mockito.Mockito._ import org.scalatest.Matchers import org.apache.spark.{SparkConf, SparkFunSuite, TaskContext, TaskContextImpl} @@ -54,6 +55,8 @@ class UnsafeFixedWidthAggregationMapSuite private var memoryManager: TestMemoryManager = null private var taskMemoryManager: TaskMemoryManager = null + private var taskContext: TaskContext = null + def testWithMemoryLeakDetection(name: String)(f: => Unit) { def cleanup(): Unit = { if (taskMemoryManager != null) { @@ -67,6 +70,8 @@ class UnsafeFixedWidthAggregationMapSuite val conf = new SparkConf().set(MEMORY_OFFHEAP_ENABLED.key, "false") memoryManager = new TestMemoryManager(conf) taskMemoryManager = new TaskMemoryManager(memoryManager, 0) + taskContext = mock(classOf[TaskContext]) + when(taskContext.taskMemoryManager()).thenReturn(taskMemoryManager) TaskContext.setTaskContext(new TaskContextImpl( stageId = 0, @@ -111,7 +116,7 @@ class UnsafeFixedWidthAggregationMapSuite emptyAggregationBuffer, aggBufferSchema, groupKeySchema, - taskMemoryManager, + taskContext, 1024, // initial capacity, PAGE_SIZE_BYTES ) @@ -124,7 +129,7 @@ class UnsafeFixedWidthAggregationMapSuite emptyAggregationBuffer, aggBufferSchema, groupKeySchema, - taskMemoryManager, + taskContext, 1024, // initial capacity PAGE_SIZE_BYTES ) @@ -151,7 +156,7 @@ class UnsafeFixedWidthAggregationMapSuite emptyAggregationBuffer, aggBufferSchema, groupKeySchema, - taskMemoryManager, + taskContext, 128, // initial capacity PAGE_SIZE_BYTES ) @@ -176,7 +181,7 @@ class UnsafeFixedWidthAggregationMapSuite emptyAggregationBuffer, aggBufferSchema, groupKeySchema, - taskMemoryManager, + taskContext, 128, // initial capacity PAGE_SIZE_BYTES ) @@ -223,7 +228,7 @@ class UnsafeFixedWidthAggregationMapSuite emptyAggregationBuffer, aggBufferSchema, groupKeySchema, - taskMemoryManager, + taskContext, 128, // initial capacity PAGE_SIZE_BYTES ) @@ -263,7 +268,7 @@ class UnsafeFixedWidthAggregationMapSuite emptyAggregationBuffer, StructType(Nil), StructType(Nil), - taskMemoryManager, + taskContext, 128, // initial capacity PAGE_SIZE_BYTES ) @@ -307,7 +312,7 @@ class UnsafeFixedWidthAggregationMapSuite emptyAggregationBuffer, aggBufferSchema, groupKeySchema, - taskMemoryManager, + taskContext, 128, // initial capacity pageSize ) @@ -344,7 +349,7 @@ class UnsafeFixedWidthAggregationMapSuite emptyAggregationBuffer, aggBufferSchema, groupKeySchema, - taskMemoryManager, + taskContext, 128, // initial capacity pageSize ) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfGetterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfGetterSuite.scala new file mode 100644 index 000000000000..bb79d3a84e5a --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfGetterSuite.scala @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.internal + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.{LocalSparkSession, SparkSession} + +class SQLConfGetterSuite extends SparkFunSuite with LocalSparkSession { + + test("SPARK-25076: SQLConf should not be retrieved from a stopped SparkSession") { + spark = SparkSession.builder().master("local").getOrCreate() + assert(SQLConf.get eq spark.sessionState.conf, + "SQLConf.get should get the conf from the active spark session.") + spark.stop() + assert(SQLConf.get eq SQLConf.getFallbackConf, + "SQLConf.get should not get conf from a stopped spark session.") + } +}
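
The HashAggregateExec and TungstenAggregationIterator hunks above pass the whole TaskContext (rather than just its TaskMemoryManager) into the aggregation map, so the map can register the free() completion listener shown at the top of this section; the SPARK-25144 "distinct" test exercises exactly that leak. A minimal standalone sketch of the same cleanup pattern in Scala, using only the public TaskContext and TaskCompletionListener APIs; the ScratchBuffer class and newTaskScopedBuffer helper are hypothetical names, not Spark code:

    import org.apache.spark.TaskContext
    import org.apache.spark.util.TaskCompletionListener

    object TaskScopedCleanupSketch {
      // Hypothetical task-scoped resource standing in for the aggregation hash map.
      final class ScratchBuffer {
        def free(): Unit = { /* release pages back to the task memory manager */ }
      }

      // Register cleanup with the owning task, so the buffer is freed even when the
      // downstream operator (e.g. a LIMIT) never exhausts this operator's output.
      def newTaskScopedBuffer(): ScratchBuffer = {
        val buffer = new ScratchBuffer
        val ctx = TaskContext.get()   // null when called outside a running task
        if (ctx != null) {
          ctx.addTaskCompletionListener(new TaskCompletionListener {
            override def onTaskCompletion(context: TaskContext): Unit = buffer.free()
          })
        }
        buffer
      }
    }

Passing the TaskContext is what makes this registration possible: a TaskMemoryManager alone can allocate and free pages, but only the TaskContext exposes the completion hook that guarantees the cleanup runs when the task ends.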
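The array and string tests in RecordBinaryComparatorSuite build variable-length columns by hand: with fewer than 64 columns the row starts with one 8-byte null-bitset word followed by an 8-byte slot per field, and a variable-length field's slot holds a single long whose high 32 bits are the value's offset relative to the row base and whose low 32 bits are its size in bytes. A small sketch of that bookkeeping; the helper names below are ours, not part of the suite:

    object UnsafeRowLayoutSketch {
      // For rows with fewer than 64 columns the null bitset occupies a single 8-byte word,
      // so the variable-length region starts right after the fixed-width slots.
      def fixedWidthRegionSize(numFields: Int): Long = 8L + numFields * 8L

      // Pack the relative offset and byte size of a variable-length value into one long,
      // mirroring `(relativeOffset(numFields) << 32) | (long) sizeInBytes` in the tests.
      def packOffsetAndSize(relativeOffset: Long, sizeInBytes: Int): Long =
        (relativeOffset << 32) | sizeInBytes.toLong

      def main(args: Array[String]): Unit = {
        val offset = fixedWidthRegionSize(numFields = 1)   // 16: bitset word + one field slot
        assert(packOffsetAndSize(offset, sizeInBytes = 24) == ((16L << 32) | 24L))
      }
    }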
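The two overflow tests in the suite (a subtraction divisible by Integer.MAX_VALUE, and a subtraction that overflows a long) guard against comparators that derive their result from a raw subtraction. The sketch below only illustrates the pitfall and an overflow-free alternative; it is not the actual RecordBinaryComparator implementation:

    object ComparatorOverflowSketch {
      // Deriving the sign from a raw subtraction silently breaks on overflow.
      def subtractionCompare(a: Long, b: Long): Int = {
        val diff = a - b
        if (diff > 0) 1 else if (diff < 0) -1 else 0
      }

      // Reducing the difference modulo Int.MaxValue (or narrowing it to Int) is also unsafe.
      def moduloCompare(a: Long, b: Long): Int = ((a - b) % Integer.MAX_VALUE).toInt

      // An overflow-free comparison avoids both failure modes.
      def safeCompare(a: Long, b: Long): Int = java.lang.Long.compare(a, b)

      def main(args: Array[String]): Unit = {
        // Long.MinValue - 1L wraps around to Long.MaxValue, so the sign flips.
        assert(subtractionCompare(Long.MinValue, 1L) > 0)        // wrong: claims MinValue > 1
        // A difference of exactly Integer.MAX_VALUE collapses to "equal".
        assert(moduloCompare(11L + Integer.MAX_VALUE, 11L) == 0) // wrong: claims equality
        assert(safeCompare(Long.MinValue, 1L) < 0)               // correct ordering
        assert(safeCompare(11L + Integer.MAX_VALUE, 11L) > 0)    // correct ordering
      }
    }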