@@ -22,7 +22,6 @@ import scala.collection.mutable
 
 import com.google.common.collect.{Lists, Maps}
 import com.google.protobuf.{Any => ProtoAny, ByteString}
-import io.grpc.stub.StreamObserver
 import org.apache.commons.lang3.exception.ExceptionUtils
 
 import org.apache.spark.{Partition, SparkEnv, TaskContext}
@@ -2083,7 +2082,7 @@ class SparkConnectPlanner(val session: SparkSession) {
       command: proto.Command,
       userId: String,
       sessionId: String,
-      responseObserver: StreamObserver[ExecutePlanResponse]): Unit = {
+      responses: mutable.ArrayBuffer[ExecutePlanResponse] = mutable.ArrayBuffer.empty): Unit = {
     command.getCommandTypeCase match {
       case proto.Command.CommandTypeCase.REGISTER_FUNCTION =>
         handleRegisterUserDefinedFunction(command.getRegisterFunction)
@@ -2096,30 +2095,30 @@
       case proto.Command.CommandTypeCase.EXTENSION =>
         handleCommandPlugin(command.getExtension)
       case proto.Command.CommandTypeCase.SQL_COMMAND =>
-        handleSqlCommand(command.getSqlCommand, sessionId, responseObserver)
+        handleSqlCommand(command.getSqlCommand, sessionId, responses)
       case proto.Command.CommandTypeCase.WRITE_STREAM_OPERATION_START =>
         handleWriteStreamOperationStart(
           command.getWriteStreamOperationStart,
           userId,
           sessionId,
-          responseObserver)
+          responses)
       case proto.Command.CommandTypeCase.STREAMING_QUERY_COMMAND =>
-        handleStreamingQueryCommand(command.getStreamingQueryCommand, sessionId, responseObserver)
+        handleStreamingQueryCommand(command.getStreamingQueryCommand, sessionId, responses)
       case proto.Command.CommandTypeCase.STREAMING_QUERY_MANAGER_COMMAND =>
         handleStreamingQueryManagerCommand(
           command.getStreamingQueryManagerCommand,
           sessionId,
-          responseObserver)
+          responses)
       case proto.Command.CommandTypeCase.GET_RESOURCES_COMMAND =>
-        handleGetResourcesCommand(sessionId, responseObserver)
+        handleGetResourcesCommand(sessionId, responses)
       case _ => throw new UnsupportedOperationException(s"$command not supported.")
     }
   }
 
   def handleSqlCommand(
       getSqlCommand: SqlCommand,
       sessionId: String,
-      responseObserver: StreamObserver[ExecutePlanResponse]): Unit = {
+      responses: mutable.ArrayBuffer[ExecutePlanResponse]): Unit = {
     // Eagerly execute commands of the provided SQL string.
     val df = session.sql(
       getSqlCommand.getSql,
@@ -2180,15 +2179,15 @@ class SparkConnectPlanner(val session: SparkSession) {
             .putAllArgs(getSqlCommand.getArgsMap)))
     }
     // Exactly one SQL Command Result Batch
-    responseObserver.onNext(
+    responses +=
       ExecutePlanResponse
         .newBuilder()
         .setSessionId(sessionId)
         .setSqlCommandResult(result)
-        .build())
+        .build()
 
     // Send Metrics
-    responseObserver.onNext(SparkConnectStreamHandler.createMetricsResponse(sessionId, df))
+    responses += SparkConnectStreamHandler.createMetricsResponse(sessionId, df)
   }
 
   private def handleRegisterUserDefinedFunction(
@@ -2408,7 +2407,7 @@ class SparkConnectPlanner(val session: SparkSession) {
       writeOp: WriteStreamOperationStart,
       userId: String,
       sessionId: String,
-      responseObserver: StreamObserver[ExecutePlanResponse]): Unit = {
+      responses: mutable.ArrayBuffer[ExecutePlanResponse]): Unit = {
     val plan = transformRelation(writeOp.getInput)
     val dataset = Dataset.ofRows(session, logicalPlan = plan)
 
@@ -2473,18 +2472,18 @@ class SparkConnectPlanner(val session: SparkSession) {
       .setName(Option(query.name).getOrElse(""))
       .build()
 
-    responseObserver.onNext(
+    responses +=
       ExecutePlanResponse
         .newBuilder()
         .setSessionId(sessionId)
         .setWriteStreamOperationStartResult(result)
-        .build())
+        .build()
   }
 
   def handleStreamingQueryCommand(
       command: StreamingQueryCommand,
       sessionId: String,
-      responseObserver: StreamObserver[ExecutePlanResponse]): Unit = {
+      responses: mutable.ArrayBuffer[ExecutePlanResponse]): Unit = {
 
     val id = command.getQueryId.getId
     val runId = command.getQueryId.getRunId
@@ -2589,12 +2588,12 @@ class SparkConnectPlanner(val session: SparkSession) {
         throw new IllegalArgumentException("Missing command in StreamingQueryCommand")
     }
 
-    responseObserver.onNext(
+    responses +=
       ExecutePlanResponse
         .newBuilder()
         .setSessionId(sessionId)
         .setStreamingQueryCommandResult(respBuilder.build())
-        .build())
+        .build()
   }
 
   private def buildStreamingQueryInstance(query: StreamingQuery): StreamingQueryInstance = {
@@ -2615,7 +2614,7 @@ class SparkConnectPlanner(val session: SparkSession) {
   def handleStreamingQueryManagerCommand(
       command: StreamingQueryManagerCommand,
       sessionId: String,
-      responseObserver: StreamObserver[ExecutePlanResponse]): Unit = {
+      responses: mutable.ArrayBuffer[ExecutePlanResponse]): Unit = {
 
     val respBuilder = StreamingQueryManagerCommandResult.newBuilder()
 
@@ -2650,18 +2649,18 @@ class SparkConnectPlanner(val session: SparkSession) {
         throw new IllegalArgumentException("Missing command in StreamingQueryManagerCommand")
     }
 
-    responseObserver.onNext(
+    responses +=
       ExecutePlanResponse
         .newBuilder()
         .setSessionId(sessionId)
         .setStreamingQueryManagerCommandResult(respBuilder.build())
-        .build())
+        .build()
   }
 
   def handleGetResourcesCommand(
       sessionId: String,
-      responseObserver: StreamObserver[proto.ExecutePlanResponse]): Unit = {
-    responseObserver.onNext(
+      responses: mutable.ArrayBuffer[ExecutePlanResponse]): Unit = {
+    responses +=
       proto.ExecutePlanResponse
         .newBuilder()
         .setSessionId(sessionId)
@@ -2679,7 +2678,7 @@ class SparkConnectPlanner(val session: SparkSession) {
               .toMap
               .asJava)
             .build())
-        .build())
+        .build()
   }
 
   private val emptyLocalRelation = LocalRelation(
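
(Reviewer sketch, not part of the diff.) A minimal illustration of the call pattern this change enables: the planner now appends every ExecutePlanResponse to a caller-supplied mutable.ArrayBuffer (defaulting to a fresh empty buffer), and the caller forwards the buffered responses to the gRPC stream itself. The helper object, the observer parameter, and the literal userId/sessionId values below are illustrative assumptions, not code from this PR.

import scala.collection.mutable

import io.grpc.stub.StreamObserver

import org.apache.spark.connect.proto
import org.apache.spark.connect.proto.ExecutePlanResponse
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.connect.planner.SparkConnectPlanner

// Hypothetical helper showing the buffer-based flow: process() fills the
// buffer instead of calling onNext, and the caller drains it afterwards.
object CommandRunnerSketch {
  def runCommand(
      session: SparkSession,
      command: proto.Command,
      observer: StreamObserver[ExecutePlanResponse]): Unit = {
    val responses = mutable.ArrayBuffer.empty[ExecutePlanResponse]
    new SparkConnectPlanner(session).process(
      command = command,
      userId = "userId", // illustrative values only
      sessionId = "sessionId",
      responses = responses)
    // Forward the buffered responses in order, then complete the stream.
    responses.foreach(observer.onNext(_))
    observer.onCompleted()
  }
}
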
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.connect.service
 
 import scala.collection.JavaConverters._
+import scala.collection.mutable
 import scala.util.control.NonFatal
 
 import com.google.protobuf.ByteString
@@ -113,11 +114,13 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[ExecutePlanResp
   private def handleCommand(session: SparkSession, request: ExecutePlanRequest): Unit = {
     val command = request.getPlan.getCommand
     val planner = new SparkConnectPlanner(session)
+    val responses = mutable.ArrayBuffer.empty[ExecutePlanResponse]
     planner.process(
       command = command,
       userId = request.getUserContext.getUserId,
       sessionId = request.getSessionId,
-      responseObserver = responseObserver)
+      responses = responses)
+    responses.foreach(responseObserver.onNext(_))
     responseObserver.onCompleted()
   }
 }
@@ -20,11 +20,9 @@ package org.apache.spark.sql.connect.planner
 import scala.collection.JavaConverters._
 
 import com.google.protobuf.ByteString
-import io.grpc.stub.StreamObserver
 
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.connect.proto
-import org.apache.spark.connect.proto.ExecutePlanResponse
 import org.apache.spark.connect.proto.Expression.{Alias, ExpressionString, UnresolvedStar}
 import org.apache.spark.sql.{AnalysisException, Dataset, Row}
 import org.apache.spark.sql.catalyst.InternalRow
@@ -44,18 +42,12 @@ import org.apache.spark.unsafe.types.UTF8String
  */
 trait SparkConnectPlanTest extends SharedSparkSession {
 
-  class MockObserver extends StreamObserver[proto.ExecutePlanResponse] {
-    override def onNext(value: ExecutePlanResponse): Unit = {}
-    override def onError(t: Throwable): Unit = {}
-    override def onCompleted(): Unit = {}
-  }
-
   def transform(rel: proto.Relation): logical.LogicalPlan = {
     new SparkConnectPlanner(spark).transformRelation(rel)
   }
 
   def transform(cmd: proto.Command): Unit = {
-    new SparkConnectPlanner(spark).process(cmd, "clientId", "sessionId", new MockObserver())
+    new SparkConnectPlanner(spark).process(cmd, "clientId", "sessionId")
   }
 
   def readRel: proto.Relation =
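
(Reviewer sketch, not part of the diff.) With the StreamObserver parameter gone, a test can rely on the default empty buffer, as the suites above now do, or pass its own buffer and assert on what was collected. The suite name, the use of an empty Command, and the import paths below are assumptions made for illustration only.

import scala.collection.mutable

import org.apache.spark.connect.proto
import org.apache.spark.connect.proto.ExecutePlanResponse
import org.apache.spark.sql.connect.planner.SparkConnectPlanner
import org.apache.spark.sql.test.SharedSparkSession

// Hypothetical suite, written only to illustrate the new API shape.
class BufferedResponsesSketchSuite extends SharedSparkSession {

  test("planner appends responses to a caller-supplied buffer") {
    val buffer = mutable.ArrayBuffer.empty[ExecutePlanResponse]
    // A Command with no command type set falls through to the `case _` branch
    // of process(), so an UnsupportedOperationException is expected.
    val emptyCommand = proto.Command.getDefaultInstance

    intercept[UnsupportedOperationException] {
      new SparkConnectPlanner(spark).process(emptyCommand, "clientId", "sessionId", buffer)
    }
    // Nothing was handled, so nothing was buffered.
    assert(buffer.isEmpty)
  }
}
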
@@ -195,7 +195,7 @@ class SparkConnectPluginRegistrySuite extends SharedSparkSession with SparkConne
         .build()))
       .build()
 
-    new SparkConnectPlanner(spark).process(plan, "clientId", "sessionId", new MockObserver())
+    new SparkConnectPlanner(spark).process(plan, "clientId", "sessionId")
     assert(spark.sparkContext.getLocalProperty("testingProperty").equals("Martin"))
   }
 }