scala client, without tests
juliuszsompolski committed Jul 20, 2023
commit a297d45eac191c2f450da81c0c90e95db09301b6
@@ -613,16 +613,30 @@ class SparkSession private[sql] (
/**
* Interrupt all operations of this session currently running on the connected server.
*
* TODO/WIP: Currently it will interrupt the Spark Jobs running on the server, triggered from
* ExecutePlan requests. If an operation is not running a Spark Job, it becomes a no-op and the
* operation will continue afterwards, possibly with more Spark Jobs.
*
juliuszsompolski (Contributor, Author) commented:
This has actually been fixed by #41315
(now the execution runs in a different thread, and the interrupt interrupts that thread, not only the Spark Jobs).

* @since 3.5.0
*/
def interruptAll(): Unit = {
client.interruptAll()
}

/**
* Interrupt all operations of this session with the given operation tag.
*
* @since 3.5.0
*/
def interruptTag(tag: String): Unit = {
client.interruptTag(tag)
}

/**
* Interrupt an operation of this session with the given operationId.
*
* @since 3.5.0
*/
def interruptOperation(operationId: String): Unit = {
client.interruptOperation(operationId)
}

/**
* Synonym for `close()`.
*
@@ -641,6 +655,50 @@ class SparkSession private[sql] (
allocator.close()
SparkSession.onSessionClose(this)
}

/**
* Add a tag to be assigned to all the operations started by this thread in this session.
*
* @param tag
* The tag to be added. Cannot contain ',' (comma) character or be an empty string.
*
* @since 3.5.0
*/
def addTag(tag: String): Unit = {
client.addTag(tag)
}

/**
* Remove a tag previously added to be assigned to all the operations started by this thread in
* this session. Noop if such a tag was not added earlier.
*
* @param tag
* The tag to be removed. Cannot contain ',' (comma) character or be an empty string.
*
* @since 3.5.0
*/
def removeTag(tag: String): Unit = {
client.removeTag(tag)
}

/**
* Get the tags that are currently set to be assigned to all the operations started by this
* thread.
*
* @since 3.5.0
*/
def getTags(): Set[String] = {
client.getTags()
}

/**
* Clear the current thread's operation tags.
*
* @since 3.5.0
*/
def clearTags(): Unit = {
client.clearTags()
}
}
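
For illustration, a minimal usage sketch (not part of this diff) of the new tagging and interrupt API, assuming a Spark Connect server reachable at sc://localhost and the remote()/getOrCreate() builder of the Connect client; the object name TagInterruptExample and the tag value are made up for the example.

  import org.apache.spark.sql.SparkSession

  object TagInterruptExample {
    def main(args: Array[String]): Unit = {
      val spark = SparkSession.builder().remote("sc://localhost").getOrCreate()

      // Tag every operation started by this thread in this session.
      spark.addTag("nightly-etl")
      val result = spark.range(1000000L).selectExpr("sum(id)")

      // From another thread sharing the session, cancel everything tagged "nightly-etl".
      new Thread(() => spark.interruptTag("nightly-etl")).start()

      // collect() may fail with an interruption error if the cancel arrives in time.
      result.collect()
      spark.clearTags()
    }
  }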

// The minimal builder needed to create a spark session.
@@ -21,11 +21,15 @@ import java.net.URI
import java.util.UUID
import java.util.concurrent.Executor

import scala.collection.JavaConverters._
import scala.collection.mutable

import com.google.protobuf.ByteString
import io.grpc._

import org.apache.spark.connect.proto
import org.apache.spark.connect.proto.UserContext
import org.apache.spark.sql.connect.common.ProtoUtils
import org.apache.spark.sql.connect.common.config.ConnectCommon

/**
@@ -76,6 +80,7 @@ private[sql] class SparkConnectClient(
.setUserContext(userContext)
.setSessionId(sessionId)
.setClientType(userAgent)
.addAllTags(tags.get.toSeq.asJava)
.build()
bstub.executePlan(request)
}
@@ -195,6 +200,57 @@ private[sql] class SparkConnectClient(
bstub.interrupt(request)
}

private[sql] def interruptTag(tag: String): proto.InterruptResponse = {
val builder = proto.InterruptRequest.newBuilder()
val request = builder
.setUserContext(userContext)
.setSessionId(sessionId)
.setClientType(userAgent)
.setInterruptType(proto.InterruptRequest.InterruptType.INTERRUPT_TYPE_TAG)
.setOperationTag(tag)
.build()
bstub.interrupt(request)
}

private[sql] def interruptOperation(id: String): proto.InterruptResponse = {
val builder = proto.InterruptRequest.newBuilder()
val request = builder
.setUserContext(userContext)
.setSessionId(sessionId)
.setClientType(userAgent)
.setInterruptType(proto.InterruptRequest.InterruptType.INTERRUPT_TYPE_OPERATION_ID)
.setOperationId(id)
.build()
bstub.interrupt(request)
}

private[this] val tags = new InheritableThreadLocal[mutable.Set[String]] {
override def childValue(parent: mutable.Set[String]): mutable.Set[String] = {
// Note: make a clone such that changes in the parent tags aren't reflected in
// those of the children threads.
parent.clone()
}
override protected def initialValue(): mutable.Set[String] = new mutable.HashSet[String]()
}
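
As a side note (not from the diff), a self-contained sketch of why childValue clones the parent set: without the clone the child thread would share the parent's mutable set by reference, and later tag changes would leak between the two threads. The demo object name is made up.

  import scala.collection.mutable

  object InheritableTagsDemo {
    private val tags = new InheritableThreadLocal[mutable.Set[String]] {
      // Snapshot the parent's tags when the child thread is created.
      override def childValue(parent: mutable.Set[String]): mutable.Set[String] = parent.clone()
      override protected def initialValue(): mutable.Set[String] = new mutable.HashSet[String]()
    }

    def main(args: Array[String]): Unit = {
      tags.get += "parent-tag"
      val child = new Thread(() => {
        tags.get += "child-tag"
        println(s"child sees: ${tags.get}") // contains parent-tag and child-tag
      })
      child.start()
      child.join()
      println(s"parent sees: ${tags.get}") // still only parent-tag
    }
  }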

private[sql] def addTag(tag: String): Unit = {
ProtoUtils.throwIfInvalidTag(tag)
tags.get += tag
}

private[sql] def removeTag(tag: String): Unit = {
ProtoUtils.throwIfInvalidTag(tag)
tags.get.remove(tag)
}

private[sql] def getTags(): Set[String] = {
tags.get.toSet
}

private[sql] def clearTags(): Unit = {
tags.get.clear()
}

def copy(): SparkConnectClient = configuration.toSparkConnectClient

/**
@@ -40,6 +40,7 @@ private[sql] class SparkResult[T](
extends AutoCloseable
with Cleanable { self =>

private[this] var opId: String = null
Contributor commented:

is using Option[String] more idiomatic?

juliuszsompolski (Contributor, Author) replied:
Yeah, but I adapted to the surrounding style of

  private[this] var opId: String = null
  private[this] var numRecords: Int = 0
  private[this] var structType: StructType = _
  private[this] var arrowSchema: pojo.Schema = _
  private[this] var nextResultIndex: Int = 0

From

  def schema: StructType = {
    if (structType == null) {
      processResponses(stopOnSchema = true)
    }
    structType
  }

it looks like I can assume _ is null, so I could make it _ and maybe that's more idiomatic.
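
(For context, not part of the thread: for a var of a reference type, Scala's '_' default initializer is indeed null, so the two declarations below are equivalent; the class name is made up.)

  class DefaultInit {
    private[this] var explicitNull: String = null // initialized explicitly to null
    private[this] var underscoreInit: String = _  // '_' also yields null for reference types
  }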

private[this] var numRecords: Int = 0
private[this] var structType: StructType = _
private[this] var arrowSchema: pojo.Schema = _
@@ -79,6 +80,19 @@ private[sql] class SparkResult[T](
var stop = false
while (!stop && responses.hasNext) {
val response = responses.next()

// Save and validate operationId
if (opId == null) {
opId = response.getOperationId
}
if (opId != response.getOperationId) {
// backwards compatibility:
// response from an old server without operationId field would have getOperationId == "".
throw new IllegalStateException(
"Received response with wrong operationId. " +
s"Expected '$opId' but received '${response.getOperationId}'.")
}

if (response.hasSchema) {
// The original schema should arrive before ArrowBatches.
structType =
@@ -148,6 +162,15 @@ private[sql] class SparkResult[T](
structType
}

/**
* @return
* the operationId of the result.
*/
def operationId: String = {
processResponses(stopOnFirstNonEmptyResponse = true)
opId
}

/**
* Create an Array with the contents of the result.
*/
@@ -307,6 +307,7 @@ message ExecutePlanRequest {
}

// Tags to tag the given execution with.
// Tags cannot contain ',' character and cannot be empty strings.
// Used by Interrupt with interrupt.tag.
repeated string tags = 7;
}
@@ -81,4 +81,28 @@ private[connect] object ProtoUtils {
private def createString(prefix: String, size: Int): String = {
s"$prefix[truncated(size=${format.format(size)})]"
}

// Because Spark Connect operation tags are also set as SparkContext Job tags, they cannot contain
// SparkContext.SPARK_JOB_TAGS_SEP
private val SPARK_JOB_TAGS_SEP = ',' // SparkContext.SPARK_JOB_TAGS_SEP

/**
* Validate if a tag for ExecutePlanRequest.tags is valid. Throw IllegalArgumentException if
* not.
*/
def throwIfInvalidTag(tag: String): Unit = {
// Same format rules apply to Spark Connect execution tags as to SparkContext job tags,
// because the Spark Connect job tag is also used as part of SparkContext job tag.
// See SparkContext.throwIfInvalidTag and ExecuteHolder.tagToSparkJobTag
if (tag == null) {
throw new IllegalArgumentException("Spark Connect execution tag cannot be null.")
}
if (tag.contains(SPARK_JOB_TAGS_SEP)) {
throw new IllegalArgumentException(
s"Spark Connect execution tag cannot contain '$SPARK_JOB_TAGS_SEP'.")
}
if (tag.isEmpty) {
throw new IllegalArgumentException("Spark Connect execution tag cannot be an empty string.")
}
}
}
@@ -21,7 +21,9 @@ import scala.collection.mutable

import io.grpc.stub.StreamObserver

import org.apache.spark.connect.proto
import org.apache.spark.internal.Logging
import org.apache.spark.sql.connect.service.ExecuteHolder

/**
* This StreamObserver is running on the execution thread. Execution pushes responses to it, it
@@ -40,7 +42,9 @@ import org.apache.spark.internal.Logging
* @see
* attachConsumer
*/
private[connect] class ExecuteResponseObserver[T]() extends StreamObserver[T] with Logging {
private[connect] class ExecuteResponseObserver[T](val executeHolder: ExecuteHolder)
extends StreamObserver[T]
with Logging {

/**
* Cached responses produced by the execution. Map from response index -> response. Response
@@ -77,7 +81,9 @@ private[connect] class ExecuteResponseObserver[T]() extends StreamObserver[T] wi
throw new IllegalStateException("Stream onNext can't be called after stream completed")
}
lastProducedIndex += 1
responses += ((lastProducedIndex, CachedStreamResponse[T](r, lastProducedIndex)))
val processedResponse = setCommonResponseFields(r)
responses +=
((lastProducedIndex, CachedStreamResponse[T](processedResponse, lastProducedIndex)))
logDebug(s"Saved response with index=$lastProducedIndex")
notifyAll()
}
@@ -158,4 +164,19 @@ private[connect] class ExecuteResponseObserver[T]() extends StreamObserver[T] wi
i -= 1
}
}

/**
* Make sure that common fields that should be set in every response are populated.
*/
private def setCommonResponseFields(response: T): T = {
response match {
case executePlanResponse: proto.ExecutePlanResponse =>
executePlanResponse
.toBuilder()
.setSessionId(executeHolder.sessionHolder.sessionId)
.setOperationId(executeHolder.operationId)
.build()
.asInstanceOf[T]
}
}
}
@@ -19,9 +19,9 @@ package org.apache.spark.sql.connect.service

import scala.collection.JavaConverters._

import org.apache.spark.SparkContext
import org.apache.spark.connect.proto
import org.apache.spark.internal.Logging
import org.apache.spark.sql.connect.common.ProtoUtils
import org.apache.spark.sql.connect.execution.{ExecuteGrpcResponseSender, ExecuteResponseObserver, ExecuteThreadRunner}
import org.apache.spark.util.SystemClock

@@ -40,15 +40,20 @@ private[connect] class ExecuteHolder(
s"Session_${sessionHolder.sessionId}_" +
s"Request_${operationId}"

val userDefinedTags: Seq[String] = request.getTagsList().asScala.toSeq.map { tag =>
throwIfInvalidTag(tag)
tag
}
val userDefinedTags: Set[String] = request
.getTagsList()
.asScala
.toSeq
.map { tag =>
ProtoUtils.throwIfInvalidTag(tag)
tag
}
.toSet

val session = sessionHolder.session

val responseObserver: ExecuteResponseObserver[proto.ExecutePlanResponse] =
new ExecuteResponseObserver[proto.ExecutePlanResponse]()
new ExecuteResponseObserver[proto.ExecutePlanResponse](this)

val eventsManager: ExecuteEventsManager = ExecuteEventsManager(this, new SystemClock())

@@ -98,23 +103,12 @@ private[connect] class ExecuteHolder(
runner.interrupt()
}

/**
* Spark Connect tags are also added as SparkContext job tags, but to make the tag unique, they
* need to be combined with userId and sessionId.
*/
def tagToSparkJobTag(tag: String): String = {

Review comment:
@juliuszsompolski input tag isn't used for output, which doesn't look intended

juliuszsompolski (Contributor, Author) replied:
Thanks for spotting!
@HyukjinKwon could you maybe piggyback changing it to something like

    "SparkConnect_Execute_" +
      s"User_${sessionHolder.userId}_Session_${sessionHolder.sessionId}_Tag_${tag}"

onto #42120 ?

Member replied:
sure

"SparkConnectUserDefinedTag_" +
s"User_${sessionHolder.userId}_Session_${sessionHolder.sessionId}"
}

private def throwIfInvalidTag(tag: String) = {
// Same format rules apply to Spark Connect execution tags as to SparkContext job tags.
// see SparkContext.throwIfInvalidTag.
if (tag == null) {
throw new IllegalArgumentException("Spark Connect execution tag cannot be null.")
}
if (tag.contains(SparkContext.SPARK_JOB_TAGS_SEP)) {
throw new IllegalArgumentException(
s"Spark Connect execution tag cannot contain '${SparkContext.SPARK_JOB_TAGS_SEP}'.")
}
if (tag.isEmpty) {
throw new IllegalArgumentException("Spark Connect execution tag cannot be an empty string.")
}
}
}