[SPARK-45856] Move ArtifactManager from Spark Connect into SparkSession (sql/core)
vicennial committed Nov 9, 2023
commit c36640b645453cb8e97c228e4b5e59f26aaf6cc4
@@ -206,20 +206,6 @@ object Connect {
.intConf
.createWithDefault(1024)

val CONNECT_COPY_FROM_LOCAL_TO_FS_ALLOW_DEST_LOCAL =
buildStaticConf("spark.connect.copyFromLocalToFs.allowDestLocal")
.internal()
.doc("""
|Allow `spark.copyFromLocalToFs` destination to be local file system
| path on spark driver node when
|`spark.connect.copyFromLocalToFs.allowDestLocal` is true.
|This will allow user to overwrite arbitrary file on spark
|driver node we should only enable it for testing purpose.
|""".stripMargin)
.version("3.5.0")
.booleanConf
.createWithDefault(false)

val CONNECT_UI_STATEMENT_LIMIT =
buildStaticConf("spark.sql.connect.ui.retainedStatements")
.doc("The number of statements kept in the Spark Connect UI history.")
@@ -941,7 +941,7 @@ class SparkConnectPlanner(
command = fun.getCommand.toByteArray,
// Empty environment variables
envVars = Maps.newHashMap(),
pythonIncludes = sessionHolder.artifactManager.getSparkConnectPythonIncludes.asJava,
pythonIncludes = sessionHolder.artifactManager.getPythonIncludes.asJava,
pythonExec = pythonExec,
pythonVer = fun.getPythonVer,
// Empty broadcast variables
@@ -995,7 +995,7 @@ class SparkConnectPlanner(

private def transformCachedLocalRelation(rel: proto.CachedLocalRelation): LogicalPlan = {
val blockManager = session.sparkContext.env.blockManager
val blockId = CacheId(sessionHolder.userId, sessionHolder.sessionId, rel.getHash)
val blockId = CacheId(sessionHolder.session.sessionUUID, rel.getHash)
val bytes = blockManager.getLocalBytes(blockId)
bytes
.map { blockData =>
@@ -1013,7 +1013,7 @@
.getOrElse {
throw InvalidPlanInput(
s"Not found any cached local relation with the hash: ${blockId.hash} in " +
s"the session ${blockId.sessionId} for the user id ${blockId.userId}.")
s"the session with sessionUUID ${blockId.sessionUUID}.")
}
}

@@ -1626,7 +1626,7 @@ class SparkConnectPlanner(
command = fun.getCommand.toByteArray,
// Empty environment variables
envVars = Maps.newHashMap(),
pythonIncludes = sessionHolder.artifactManager.getSparkConnectPythonIncludes.asJava,
pythonIncludes = sessionHolder.artifactManager.getPythonIncludes.asJava,
pythonExec = pythonExec,
pythonVer = fun.getPythonVer,
// Empty broadcast variables
@@ -27,18 +27,16 @@ import scala.jdk.CollectionConverters._
import com.google.common.base.Ticker
import com.google.common.cache.CacheBuilder

import org.apache.spark.{JobArtifactSet, SparkException, SparkSQLException}
import org.apache.spark.{SparkException, SparkSQLException}
import org.apache.spark.internal.Logging
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.connect.artifact.SparkConnectArtifactManager
import org.apache.spark.sql.connect.common.InvalidPlanInput
import org.apache.spark.sql.connect.planner.PythonStreamingQueryListener
import org.apache.spark.sql.connect.planner.StreamingForeachBatchHelper
import org.apache.spark.sql.connect.service.SessionHolder.{ERROR_CACHE_SIZE, ERROR_CACHE_TIMEOUT_SEC}
import org.apache.spark.sql.streaming.StreamingQueryListener
import org.apache.spark.util.SystemClock
import org.apache.spark.util.Utils

// Unique key identifying session by combination of user, and session id
case class SessionKey(userId: String, sessionId: String)
@@ -159,7 +157,7 @@ case class SessionHolder(userId: String, sessionId: String, session: SparkSessio
interruptedIds.toSeq
}

private[connect] lazy val artifactManager = new SparkConnectArtifactManager(this)
private[connect] def artifactManager = session.artifactManager

/**
* Add an artifact to this SparkConnect session.
@@ -231,27 +229,13 @@ case class SessionHolder(userId: String, sessionId: String, session: SparkSessio
eventManager.postClosed()
}

/**
* Execute a block of code using this session's classloader.
* @param f
* @tparam T
*/
def withContextClassLoader[T](f: => T): T = {
// Needed for deserializing and evaluating the UDF on the driver
Utils.withContextClassLoader(classloader) {
JobArtifactSet.withActiveJobArtifactState(artifactManager.state) {
f
}
}
}

/**
* Execute a block of code with this session as the active SparkConnect session.
* @param f
* @tparam T
*/
def withSession[T](f: SparkSession => T): T = {
withContextClassLoader {
artifactManager.withResources {
session.withActive {
f(session)
}
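With this change, the session-scoped artifact state travels with the SparkSession rather than with a Connect-specific manager, so any code routed through `withSession` automatically sees the session's artifacts. A minimal usage sketch under that assumption; the `holder` value is hypothetical and only `withSession` itself comes from the patch:

    // Hypothetical: `holder` is a SessionHolder obtained from the Connect session manager.
    val rowCount: Long = holder.withSession { spark =>
      // ArtifactManager.withResources installs this session's classloader and
      // job-artifact state, so classes added as session artifacts resolve here.
      spark.range(100).count()
    }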
@@ -30,8 +30,8 @@ import io.grpc.stub.StreamObserver
import org.apache.spark.connect.proto
import org.apache.spark.connect.proto.{AddArtifactsRequest, AddArtifactsResponse}
import org.apache.spark.connect.proto.AddArtifactsResponse.ArtifactSummary
import org.apache.spark.sql.connect.artifact.SparkConnectArtifactManager
import org.apache.spark.sql.connect.artifact.util.ArtifactUtils
import org.apache.spark.sql.artifact.ArtifactManager
import org.apache.spark.sql.artifact.util.ArtifactUtils
import org.apache.spark.util.Utils

/**
@@ -102,7 +102,7 @@ class SparkConnectAddArtifactsHandler(val responseObserver: StreamObserver[AddAr
// summary and it is up to the client to decide whether to retry sending the artifact.
if (artifact.getCrcStatus.contains(true)) {
if (artifact.path.startsWith(
SparkConnectArtifactManager.forwardToFSPrefix + File.separator)) {
ArtifactManager.forwardToFSPrefix + File.separator)) {
holder.artifactManager.uploadArtifactToFs(artifact.path, artifact.stagedPath)
} else {
addStagedArtifactToArtifactManager(artifact)
@@ -33,7 +33,7 @@ class SparkConnectArtifactStatusesHandler(
.getOrCreateIsolatedSession(userId, sessionId)
.session
val blockManager = session.sparkContext.env.blockManager
blockManager.getStatus(CacheId(userId, sessionId, hash)).isDefined
blockManager.getStatus(CacheId(session.sessionUUID, hash)).isDefined
}

def handle(request: proto.ArtifactStatusesRequest): Unit = {
@@ -162,7 +162,7 @@ class SparkConnectSessionHolderSuite extends SharedSparkSession {
SimplePythonFunction(
command = fcn(sparkPythonPath),
envVars = mutable.Map("PYTHONPATH" -> sparkPythonPath).asJava,
pythonIncludes = sessionHolder.artifactManager.getSparkConnectPythonIncludes.asJava,
pythonIncludes = sessionHolder.artifactManager.getPythonIncludes.asJava,
pythonExec = IntegratedUDFTestUtils.pythonExec,
pythonVer = IntegratedUDFTestUtils.pythonVer,
broadcastVars = Lists.newArrayList(),
4 changes: 2 additions & 2 deletions core/src/main/scala/org/apache/spark/storage/BlockId.scala
@@ -190,8 +190,8 @@ class UnrecognizedBlockId(name: String)
extends SparkException(s"Failed to parse $name into a block ID")

@DeveloperApi
case class CacheId(userId: String, sessionId: String, hash: String) extends BlockId {
override def name: String = s"cache_${userId}_${sessionId}_$hash"
case class CacheId(sessionUUID: String, hash: String) extends BlockId {
override def name: String = s"cache_${sessionUUID}_$hash"
}

@DeveloperApi
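The cache block key now carries the owning SparkSession's UUID instead of the (userId, sessionId) pair, so the block name loses one component. A small sketch of the resulting names, using the case class from this hunk (the UUID and hash values are made up):

    import org.apache.spark.storage.CacheId

    // Before this change: CacheId(userId, sessionId, hash).name == "cache_<userId>_<sessionId>_<hash>"
    val id = CacheId(sessionUUID = "0f8fad5b-d9cb-469f-a165-70867728950e", hash = "abc123")
    assert(id.name == "cache_0f8fad5b-d9cb-469f-a165-70867728950e_abc123")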
@@ -2048,10 +2048,10 @@ private[spark] class BlockManager(
*
* @return The number of blocks removed.
*/
def removeCache(userId: String, sessionId: String): Int = {
logDebug(s"Removing cache of user id = $userId in the session $sessionId")
def removeCache(sessionUUID: String): Int = {
logDebug(s"Removing cache of spark session with UUID: $sessionUUID")
val blocksToRemove = blockInfoManager.entries.map(_._1).collect {
case cid: CacheId if cid.userId == userId && cid.sessionId == sessionId => cid
case cid: CacheId if cid.sessionUUID == sessionUUID => cid
}
blocksToRemove.foreach { blockId => removeBlock(blockId) }
blocksToRemove.size
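Because blocks are keyed only by session UUID, a whole session's cached local relations can now be dropped without knowing the user id. A hedged sketch, written as if it runs inside Spark's own packages; the `env`/`removeCache` calls mirror the internal usage in this diff and are not public API:

    // Drop every cached local-relation block owned by one SparkSession.
    // `spark` is an active SparkSession; its sessionUUID is what CacheId names encode.
    val removed: Int = spark.sparkContext.env.blockManager.removeCache(spark.sessionUUID)
    // `removed` is the number of CacheId blocks deleted for this session.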
@@ -38,7 +38,7 @@ class ClassificationTestsOnConnect(ClassificationTestsMixin, unittest.TestCase):
def setUp(self) -> None:
self.spark = (
SparkSession.builder.remote("local[2]")
.config("spark.connect.copyFromLocalToFs.allowDestLocal", "true")
.config("spark.sql.artifact.copyFromLocalToFs.allowDestLocal", "true")
.getOrCreate()
)

2 changes: 1 addition & 1 deletion python/pyspark/ml/tests/connect/test_connect_pipeline.py
@@ -30,7 +30,7 @@ class PipelineTestsOnConnect(PipelineTestsMixin, unittest.TestCase):
def setUp(self) -> None:
self.spark = (
SparkSession.builder.remote("local[2]")
.config("spark.connect.copyFromLocalToFs.allowDestLocal", "true")
.config("spark.sql.artifact.copyFromLocalToFs.allowDestLocal", "true")
.getOrCreate()
)

2 changes: 1 addition & 1 deletion python/pyspark/ml/tests/connect/test_connect_tuning.py
@@ -30,7 +30,7 @@ class CrossValidatorTestsOnConnect(CrossValidatorTestsMixin, unittest.TestCase):
def setUp(self) -> None:
self.spark = (
SparkSession.builder.remote("local[2]")
.config("spark.connect.copyFromLocalToFs.allowDestLocal", "true")
.config("spark.sql.artifact.copyFromLocalToFs.allowDestLocal", "true")
.getOrCreate()
)

2 changes: 1 addition & 1 deletion python/pyspark/sql/tests/connect/client/test_artifact.py
@@ -183,7 +183,7 @@ def setUpClass(cls):
@classmethod
def conf(cls):
conf = super().conf()
conf.set("spark.connect.copyFromLocalToFs.allowDestLocal", "true")
conf.set("spark.sql.artifact.copyFromLocalToFs.allowDestLocal", "true")
return conf

def test_basic_requests(self):
@@ -4531,6 +4531,20 @@ object SQLConf {
.booleanConf
.createWithDefault(false)

val ARTIFACT_COPY_FROM_LOCAL_TO_FS_ALLOW_DEST_LOCAL =
buildConf("spark.sql.artifact.copyFromLocalToFs.allowDestLocal")
.internal()
.doc("""
|Allow `spark.copyFromLocalToFs` destination to be local file system
| path on spark driver node when
|`spark.sql.artifact.copyFromLocalToFs.allowDestLocal` is true.
|This will allow user to overwrite arbitrary file on spark
|driver node we should only enable it for testing purpose.
|""".stripMargin)
.version("4.0.0")
.booleanConf
.createWithDefault(false)

/**
* Holds information about keys that have been deprecated.
*
@@ -5414,6 +5428,8 @@ class SQLConf extends Serializable with Logging with SqlApiConf {
def legacyRaiseErrorWithoutErrorClass: Boolean =
getConf(SQLConf.LEGACY_RAISE_ERROR_WITHOUT_ERROR_CLASS)

def allowDestLocalFs: Boolean = getConf(ARTIFACT_COPY_FROM_LOCAL_TO_FS_ALLOW_DEST_LOCAL)

/** ********************** SQLConf functionality methods ************ */

/** Set Spark SQL configuration properties. */
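The static Connect config is superseded by an internal SQL conf with the same semantics under the `spark.sql.artifact.*` prefix, read through the `allowDestLocalFs` getter added above. A brief sketch of checking it on the driver; only the config key and getter name come from this diff, the surrounding code is illustrative:

    import org.apache.spark.sql.internal.SQLConf

    // Testing-only escape hatch: allow copyFromLocalToFs to target a local path on the driver,
    // typically set at session build time, e.g.
    //   SparkSession.builder().config("spark.sql.artifact.copyFromLocalToFs.allowDestLocal", "true")
    if (SQLConf.get.allowDestLocalFs) {
      // the destination may be a local filesystem path on the driver node
    }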
11 changes: 11 additions & 0 deletions sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -33,6 +33,7 @@ import org.apache.spark.internal.Logging
import org.apache.spark.internal.config.{ConfigEntry, EXECUTOR_ALLOW_SPARK_CONTEXT}
import org.apache.spark.rdd.RDD
import org.apache.spark.scheduler.{SparkListener, SparkListenerApplicationEnd}
import org.apache.spark.sql.artifact.ArtifactManager
import org.apache.spark.sql.catalog.Catalog
import org.apache.spark.sql.catalyst._
import org.apache.spark.sql.catalyst.analysis.{NameParameterizedQuery, PosParameterizedQuery, UnresolvedRelation}
@@ -242,6 +243,16 @@ class SparkSession private(
@Unstable
def streams: StreamingQueryManager = sessionState.streamingQueryManager

/**
* Returns an `ArtifactManager` that supports adding, managing and using session-scoped artifacts
* (jars, classfiles, etc).
*
* @since 3.5.1
Inline review comment (Member):
This should be 4.0.0 because this PR is for Apache Spark 4.0.0, @vicennial .

*/
@Experimental
@Unstable
private[sql] def artifactManager: ArtifactManager = sessionState.artifactManager

/**
* Start a new session with isolated SQL configurations, temporary tables, registered
* functions are isolated, but sharing the underlying `SparkContext` and cached data.
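The new accessor makes the session-scoped ArtifactManager a property of SparkSession itself, which is what allows SessionHolder above to simply delegate to it. A minimal sketch of code inside the `org.apache.spark.sql` package using it; only `getPythonIncludes` and `withResources` appear elsewhere in this diff, the rest is illustrative:

    // Session-scoped artifact state: Python includes registered for this session...
    val includes = spark.artifactManager.getPythonIncludes

    // ...and a scope in which the session's artifacts (jars, classfiles) are resolvable.
    spark.artifactManager.withResources {
      spark.range(10).count()
    }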