Add row for command and add test
Lingkai Kong committed Aug 14, 2023
commit 2e3252d61fc4ea046efaff134178b031ca73a20a
@@ -126,10 +126,10 @@ private[execution] class SparkConnectPlanExecution(executeHolder: ExecuteHolder)
 
     dataframe.queryExecution.executedPlan match {
       case LocalTableScanExec(_, rows) =>
-        executePlan.eventsManager.postFinished(Some(totalNumRows))
         converter(rows.iterator).foreach { case (bytes, count) =>
           sendBatch(bytes, count)
         }
+        executePlan.eventsManager.postFinished(Some(totalNumRows))
       case _ =>
         SQLExecution.withNewExecutionId(dataframe.queryExecution, Some("collectArrow")) {
           val rows = dataframe.queryExecution.executedPlan.execute()
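Note on the hunk above: the Finished event for the LocalTableScanExec fast path is now posted only after every Arrow batch has been handed to the response observer, so a listener can never see a finished operation while result batches are still being sent. A minimal self-contained sketch of that ordering contract, with hypothetical stand-ins (`EventsSink`, `streamLocalResult`) rather than the real Spark Connect types:

```scala
object FinishedEventOrdering {
  // Hypothetical stand-in for the events manager used in the hunk above.
  trait EventsSink {
    def postFinished(producedRowCount: Option[Long]): Unit
  }

  // Send every Arrow batch first, then declare the operation finished.
  def streamLocalResult(
      batches: Iterator[(Array[Byte], Long)],
      sendBatch: (Array[Byte], Long) => Unit,
      events: EventsSink,
      totalNumRows: Long): Unit = {
    batches.foreach { case (bytes, count) => sendBatch(bytes, count) }
    events.postFinished(Some(totalNumRows)) // moved after the send loop, as in the diff
  }
}
```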
@@ -2480,7 +2480,7 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging {
             .putAllArgs(getSqlCommand.getArgsMap)
             .addAllPosArgs(getSqlCommand.getPosArgsList)))
     }
-    executeHolder.eventsManager.postFinished()
+    executeHolder.eventsManager.postFinished(Some(rows.size))
     // Exactly one SQL Command Result Batch
     responseObserver.onNext(
       ExecutePlanResponse
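Note on the hunk above: `rows` holds the eagerly collected result of the SQL command, so the Finished event now reports a concrete row count instead of none. That matches the expectations added to the grid test below: judging from those values, a plain query like `select 1` is shipped back as a relation, so nothing is collected here and the count is `Some(0L)`, while an eagerly executed command like `show databases` yields its result rows, hence `Some(1L)`. A self-contained sketch of that rule, with hypothetical types standing in for the planner's internals:

```scala
object SqlCommandRowCount {
  // Hypothetical stand-ins for the two outcomes of handling a SqlCommand.
  sealed trait SqlOutcome
  case class EagerCommandResult(rows: Seq[String]) extends SqlOutcome // e.g. SHOW DATABASES
  case object ReturnedAsRelation extends SqlOutcome                   // e.g. SELECT 1

  def producedRowCount(outcome: SqlOutcome): Option[Long] = outcome match {
    case EagerCommandResult(rows) => Some(rows.size.toLong) // Some(1L): one database listed
    case ReturnedAsRelation       => Some(0L)               // no rows collected on the server
  }
}
```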
@@ -39,6 +39,7 @@ import org.apache.spark.connect.proto
 import org.apache.spark.connect.proto.CreateDataFrameViewCommand
 import org.apache.spark.internal.Logging
 import org.apache.spark.scheduler.{SparkListener, SparkListenerEvent}
+import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection}
 import org.apache.spark.sql.connect.common.DataTypeProtoConverter
 import org.apache.spark.sql.connect.config.Connect
 import org.apache.spark.sql.connect.dsl.MockRemoteSession
@@ -49,13 +50,18 @@ import org.apache.spark.sql.connect.service.{ExecuteHolder, ExecuteStatus, Sessi
 import org.apache.spark.sql.connector.catalog.InMemoryPartitionTableCatalog
 import org.apache.spark.sql.streaming.StreamingQuery
 import org.apache.spark.sql.test.SharedSparkSession
-import org.apache.spark.sql.types.IntegerType
+import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.UTF8String
 import org.apache.spark.util.Utils
 
 /**
  * Testing Connect Service implementation.
  */
-class SparkConnectServiceSuite extends SharedSparkSession with MockitoSugar with Logging {
+class SparkConnectServiceSuite
+    extends SharedSparkSession
+    with MockitoSugar
+    with Logging
+    with SparkConnectPlanTest {
 
   private def sparkSessionHolder = SessionHolder.forTesting(spark)
   private def DEFAULT_UUID = UUID.fromString("89ea6117-1f45-4c03-ae27-f47c6aded093")
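Note on the hunk above: `SparkConnectPlanTest` is mixed in for its `createLocalRelationProto` helper, which the new SPARK-44776 test below uses to build its input relation; the widened `types._` import and `UTF8String` cover the schema and row construction in that test. A signature sketch inferred from the call site, not the real definition (the actual helper in `SparkConnectPlanTest` may differ):

```scala
import org.apache.spark.connect.proto
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.types.StructType

// Inferred shape: packs the given rows (typically as a serialized Arrow batch)
// into a proto Relation that the planner turns into a LocalTableScanExec.
def createLocalRelationProto(
    schema: StructType,
    data: Seq[InternalRow]): proto.Relation = ???
```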
@@ -238,6 +244,79 @@ class SparkConnectServiceSuite extends SharedSparkSession with MockitoSugar with
     }
   }
 
+  test("SPARK-44776: LocalTableScanExec") {
+    withEvents { verifyEvents =>
+      // TODO(SPARK-44121) Re-enable Arrow-based connect tests in Java 21
+      assume(SystemUtils.isJavaVersionAtMost(JavaVersion.JAVA_17))
+      val instance = new SparkConnectService(false)
+      val connect = new MockRemoteSession()
+      val context = proto.UserContext
+        .newBuilder()
+        .setUserId("c1")
+        .build()
+
+      val rows = (0L to 5L).map { i =>
+        new GenericInternalRow(Array(i, UTF8String.fromString("" + (i - 1 + 'a').toChar)))
+      }
+
+      val schema = StructType(Array(StructField("id", LongType), StructField("data", StringType)))
+      val inputRows = rows.map { row =>
+        val proj = UnsafeProjection.create(schema)
+        proj(row).copy()
+      }
+
+      val localRelation = createLocalRelationProto(schema, inputRows)
+      val plan = proto.Plan
+        .newBuilder()
+        .setRoot(localRelation)
+        .build()
+
+      val request = proto.ExecutePlanRequest
+        .newBuilder()
+        .setPlan(plan)
+        .setUserContext(context)
+        .setSessionId(UUID.randomUUID.toString())
+        .build()
+
+      // Execute plan.
+      @volatile var done = false
+      val responses = mutable.Buffer.empty[proto.ExecutePlanResponse]
+      instance.executePlan(
+        request,
+        new StreamObserver[proto.ExecutePlanResponse] {
+          override def onNext(v: proto.ExecutePlanResponse): Unit = {
+            responses += v
+            verifyEvents.onNext(v)
+          }
+
+          override def onError(throwable: Throwable): Unit = {
+            verifyEvents.onError(throwable)
+            throw throwable
+          }
+
+          override def onCompleted(): Unit = {
+            done = true
+          }
+        })
+      verifyEvents.onCompleted(Some(6))
+      // The current implementation is expected to be blocking. This is here to make sure it is.
+      assert(done)
+
+      // Schema + one partition of Arrow batches + metrics
+      assert(responses.size == 3)
+
+      // Make sure the first response is schema only
+      val head = responses.head
+      assert(head.hasSchema && !head.hasArrowBatch && !head.hasMetrics)
+
+      // Make sure the last response is metrics only
+      val last = responses.last
+      assert(last.hasMetrics && !last.hasSchema && !last.hasArrowBatch)
+    }
+  }
+
   test("SPARK-44657: Arrow batches respect max batch size limit") {
     // Set 10 KiB as the batch size limit
     val batchSize = 10 * 1024
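The new test pins both the event stream (a Finished event carrying `Some(6)`, checked via `verifyEvents.onCompleted`) and the response shape for a local scan: exactly three messages, a schema-only header, one Arrow batch carrying the six rows, and a metrics-only trailer. An equivalent single assertion over the same `responses` buffer, assuming the middle message carries exactly the one Arrow batch:

```scala
// Sketch: pin the full (schema, arrowBatch, metrics) shape in one assertion.
assert(
  responses.map(r => (r.hasSchema, r.hasArrowBatch, r.hasMetrics)) ==
    Seq((true, false, false), (false, true, false), (false, false, true)))
```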
@@ -301,13 +380,20 @@ class SparkConnectServiceSuite extends SharedSparkSession with MockitoSugar with
 
   gridTest("SPARK-43923: commands send events")(
     Seq(
-      proto.Command
+      (
+        proto.Command
           .newBuilder()
           .setSqlCommand(proto.SqlCommand.newBuilder().setSql("select 1").build()),
-      proto.Command
+        Some(0L)
+      ),
+      (
+        proto.Command
           .newBuilder()
-          .setSqlCommand(proto.SqlCommand.newBuilder().setSql("show tables").build()),
-      proto.Command
+          .setSqlCommand(proto.SqlCommand.newBuilder().setSql("show databases").build()),
+        Some(1L)
+      ),
+      (
+        proto.Command
         .newBuilder()
         .setWriteOperation(
           proto.WriteOperation
@@ -316,7 +402,10 @@ class SparkConnectServiceSuite extends SharedSparkSession with MockitoSugar with
               proto.Relation.newBuilder().setSql(proto.SQL.newBuilder().setQuery("select 1")))
             .setPath(Utils.createTempDir().getAbsolutePath)
             .setMode(proto.WriteOperation.SaveMode.SAVE_MODE_OVERWRITE)),
-      proto.Command
+        None
+      ),
+      (
+        proto.Command
         .newBuilder()
         .setWriteOperationV2(
           proto.WriteOperationV2
@@ -325,26 +414,36 @@ class SparkConnectServiceSuite extends SharedSparkSession with MockitoSugar with
               proto.Range.newBuilder().setStart(0).setEnd(2).setStep(1L)))
           .setTableName("testcat.testtable")
           .setMode(proto.WriteOperationV2.Mode.MODE_CREATE)),
-      proto.Command
+        None
+      ),
+      (
+        proto.Command
         .newBuilder()
         .setCreateDataframeView(
           CreateDataFrameViewCommand
             .newBuilder()
             .setName("testview")
             .setInput(
               proto.Relation.newBuilder().setSql(proto.SQL.newBuilder().setQuery("select 1")))),
-      proto.Command
+        None
+      ),
+      (proto.Command
         .newBuilder()
         .setGetResourcesCommand(proto.GetResourcesCommand.newBuilder()),
-      proto.Command
+        None),
+      (
+        proto.Command
         .newBuilder()
         .setExtension(
           protobuf.Any.pack(
             proto.ExamplePluginCommand
               .newBuilder()
               .setCustomField("SPARK-43923")
               .build())),
-      proto.Command
+        None
+      ),
+      (
+        proto.Command
         .newBuilder()
         .setWriteStreamOperationStart(
           proto.WriteStreamOperationStart
@@ -365,7 +464,10 @@ class SparkConnectServiceSuite extends SharedSparkSession with MockitoSugar with
             .putOptions("checkpointLocation", Utils.createTempDir().getAbsolutePath)
             .setPath("test-path")
             .build()),
-      proto.Command
+        None
+      ),
+      (
+        proto.Command
         .newBuilder()
         .setStreamingQueryCommand(
           proto.StreamingQueryCommand
@@ -377,12 +479,18 @@ class SparkConnectServiceSuite extends SharedSparkSession with MockitoSugar with
               .setRunId(DEFAULT_UUID.toString)
               .build())
           .setStop(true)),
-      proto.Command
+        None
+      ),
+      (
+        proto.Command
         .newBuilder()
         .setStreamingQueryManagerCommand(proto.StreamingQueryManagerCommand
           .newBuilder()
           .setListListeners(true)),
-      proto.Command
+        None
+      ),
+      (
+        proto.Command
         .newBuilder()
         .setRegisterFunction(
           proto.CommonInlineUserDefinedFunction
@@ -395,7 +503,11 @@ class SparkConnectServiceSuite extends SharedSparkSession with MockitoSugar with
             .setOutputType(DataTypeProtoConverter.toConnectProtoType(IntegerType))
             .setCommand(ByteString.copyFrom("command".getBytes()))
             .setPythonVer("3.10")
-            .build())))) { command =>
+            .build())),
+        None
+      )
+    )
+  ) { case (command, producedNumRows) =>
     val sessionId = UUID.randomUUID.toString()
     withCommandTest(sessionId) { verifyEvents =>
       val instance = new SparkConnectService(false)
@@ -435,7 +547,7 @@ class SparkConnectServiceSuite extends SharedSparkSession with MockitoSugar with
             done = true
           }
         })
-      verifyEvents.onCompleted()
+      verifyEvents.onCompleted(producedNumRows)
       // The current implementation is expected to be blocking.
      // This is here to make sure it is.
       assert(done)
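With this change the grid entries are (command, expected produced-row count) pairs rather than bare commands: only the two SQL commands promise a concrete count, and every other command passes `None`, making no row-count assertion. For reference, a minimal self-contained sketch of a gridTest-style helper with the shape assumed here (the real helper lives in the shared test utilities and may differ):

```scala
import org.scalatest.funsuite.AnyFunSuite

// Registers one test per parameter, so each (command, producedNumRows) pair
// above becomes its own named test case.
class GridSuiteExample extends AnyFunSuite {
  def gridTest[A](namePrefix: String)(params: Seq[A])(testFun: A => Unit): Unit =
    params.zipWithIndex.foreach { case (param, i) =>
      test(s"$namePrefix (grid = $i)")(testFun(param))
    }

  // Usage: three independent test cases, one per input value.
  gridTest("doubles its input")(Seq(1, 2, 3)) { n =>
    assert(n * 2 == n + n)
  }
}
```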