spark/connect/base.proto
@@ -967,6 +967,34 @@ message FetchErrorDetailsResponse {
}
}

message BuildResourceProfileRequest {
Review comment (Member): This would need @grundprinzip and @hvanhovell's review.

// (Required)
//
// The session_id specifies a spark session for a user id (which is specified
// by user_context.user_id). The session_id is set by the client to be able to
// collate streaming responses from different queries within the dedicated session.
// The id should be a UUID string of the format `00112233-4455-6677-8899-aabbccddeeff`
string session_id = 1;

// (Required) User context
UserContext user_context = 2;

// (Required) The ResourceProfile to be built on the server-side.
ResourceProfile profile = 3;
}

// Response to building resource profile.
message BuildResourceProfileResponse {
string session_id = 1;

// Server-side generated idempotency key that the client can use to assert that the server side
// session has not changed.
string server_side_session_id = 4;

// (Required) Server-side generated resource profile id.
int32 profile_id = 5;
}
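For illustration only (not part of the diff): a minimal sketch of how a client could populate this request and read back the profile id, assuming the classes generated from these proto files are importable as `pyspark.sql.connect.proto` and that a gRPC stub for SparkConnectService is already set up.

import pyspark.sql.connect.proto as pb  # assumed name of the generated proto module

# Build the request: session id, user context, and the ResourceProfile to register.
request = pb.BuildResourceProfileRequest(
    session_id="00112233-4455-6677-8899-aabbccddeeff",
    user_context=pb.UserContext(user_id="user_a"),
    profile=pb.ResourceProfile(
        executor_resources={
            "cores": pb.ExecutorResourceRequest(resource_name="cores", amount=4),
        },
        task_resources={
            "cpus": pb.TaskResourceRequest(resource_name="cpus", amount=1.0),
        },
    ),
)

# The server registers the profile in its ResourceProfileManager and returns its id:
# response = stub.BuildResourceProfile(request)
# profile_id = response.profile_id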

// Main interface for the SparkConnect service.
service SparkConnectService {

@@ -1011,5 +1039,7 @@ service SparkConnectService {

// FetchErrorDetails retrieves the matched exception with details based on a provided error id.
rpc FetchErrorDetails(FetchErrorDetailsRequest) returns (FetchErrorDetailsResponse) {}

// Build a ResourceProfile on the server side and get back its profile id.
rpc BuildResourceProfile(BuildResourceProfileRequest) returns (BuildResourceProfileResponse) {}
Review comment (Contributor): Why does this need to be an extra RPC and not just a command?

Reply (Contributor Author): Hi @grundprinzip, really good suggestion. I have pushed a new commit that moves this to a command.

}
spark/connect/common.proto
@@ -46,3 +46,38 @@ message ResourceInformation {
// (Required) An array of strings describing the addresses of the resource.
repeated string addresses = 2;
}

// An executor resource request.
message ExecutorResourceRequest {
// (Required) resource name.
string resource_name = 1;

// (Required) The amount of the resource being requested.
int64 amount = 2;

// Optional script used to discover the resources.
optional string discovery_script = 3;

// Optional vendor, required for some cluster managers.
optional string vendor = 4;
}

// A task resource request.
message TaskResourceRequest {
// (Required) resource name.
string resource_name = 1;

// (Required) The amount of the resource being requested, as a double to support
// fractional resource requests.
double amount = 2;
}

message ResourceProfile {
// (Optional) Resource requests for executors. Mapped from the resource name
// (e.g., cores, memory, CPU) to its specific request.
map<string, ExecutorResourceRequest> executor_resources = 1;

// (Optional) Resource requests for tasks. Mapped from the resource name
// (e.g., cores, memory, CPU) to its specific request.
map<string, TaskResourceRequest> task_resources = 2;
}
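These messages mirror the request classes in pyspark.resource. As a rough sketch of the client-side conversion (the helper below and the generated module name pyspark.sql.connect.proto are assumptions for illustration, not code from this PR):

import pyspark.sql.connect.proto as pb  # assumed name of the generated proto module

def to_proto_profile(executor_reqs, task_reqs):
    # executor_reqs / task_reqs: dicts of pyspark.resource ExecutorResourceRequest /
    # TaskResourceRequest keyed by resource name, as held by a pyspark ResourceProfile.
    return pb.ResourceProfile(
        executor_resources={
            name: pb.ExecutorResourceRequest(
                resource_name=r.resourceName,
                amount=int(r.amount),
                discovery_script=r.discoveryScript,
                vendor=r.vendor,
            )
            for name, r in executor_reqs.items()
        },
        task_resources={
            name: pb.TaskResourceRequest(resource_name=r.resourceName, amount=r.amount)
            for name, r in task_reqs.items()
        },
    )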
spark/connect/relations.proto
@@ -23,6 +23,7 @@ import "google/protobuf/any.proto";
import "spark/connect/expressions.proto";
import "spark/connect/types.proto";
import "spark/connect/catalog.proto";
import "spark/connect/common.proto";

option java_multiple_files = true;
option java_package = "org.apache.spark.connect.proto";
@@ -892,6 +893,9 @@ message MapPartitions {

// (Optional) Whether to use barrier mode execution or not.
optional bool is_barrier = 3;

// (Optional) ResourceProfile id used for stage-level scheduling.
optional int32 profile_id = 4;
Review comment (Contributor): Shouldn't these be UUIDs? Just to make sure the difference is intentional and there is no off-by-one?

Reply (Contributor Author, @wbo4958, Mar 11, 2024):

Hi @grundprinzip, I'm not quite following you about "by uuids".

The basic implementation is:

  1. The client creates a ResourceProfile.
  2. The first time the profile id of that ResourceProfile is accessed, the client asks the server to create the ResourceProfile and add it to the server-side ResourceProfileManager; the server returns the profile id, which the client then sets on its ResourceProfile.
  3. The internal mapInPandas/mapInArrow just carries the ResourceProfile id, and the server side can look the ResourceProfile up in the ResourceProfileManager by that id (see the sketch after this list).
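A minimal sketch of that flow from the PySpark side, assuming a Spark Connect session and the stage-level scheduling profile parameter of mapInPandas (the connect URL is a placeholder):

from pyspark.sql import SparkSession
from pyspark.resource import ExecutorResourceRequests, ResourceProfileBuilder, TaskResourceRequests

spark = SparkSession.builder.remote("sc://localhost").getOrCreate()

# Step 1: the client creates a ResourceProfile.
ereqs = ExecutorResourceRequests().cores(4).memory("6g")
treqs = TaskResourceRequests().cpus(2)
rp = ResourceProfileBuilder().require(ereqs).require(treqs).build

def passthrough(iterator):
    for pdf in iterator:
        yield pdf

df = spark.range(10)
# Steps 2 and 3: resolving rp.id registers the profile on the server the first time,
# and the MapPartitions relation only carries that id to the server side.
df.mapInPandas(passthrough, df.schema, profile=rp).show()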

Review comment (Contributor): Why do we need an extra RPC for that? Can't you attach the resource profile directly to this call?

Reply (Contributor Author): Hi @grundprinzip, the user still needs to know the exact ResourceProfile id; if we attach the resource profile to this call, it seems we cannot get the id back from it.

}

message GroupMap {
SparkConnectPlanner.scala
@@ -543,21 +543,27 @@ class SparkConnectPlanner(
case proto.CommonInlineUserDefinedFunction.FunctionCase.PYTHON_UDF =>
val pythonUdf = transformPythonUDF(commonUdf)
val isBarrier = if (rel.hasIsBarrier) rel.getIsBarrier else false
val profile = if (rel.hasProfileId) {
val profileId = rel.getProfileId
Some(session.sparkContext.resourceProfileManager.resourceProfileFromId(profileId))
} else {
None
}
pythonUdf.evalType match {
case PythonEvalType.SQL_MAP_PANDAS_ITER_UDF =>
logical.MapInPandas(
pythonUdf,
DataTypeUtils.toAttributes(pythonUdf.dataType.asInstanceOf[StructType]),
baseRel,
isBarrier,
None)
profile)
case PythonEvalType.SQL_MAP_ARROW_ITER_UDF =>
logical.MapInArrow(
pythonUdf,
DataTypeUtils.toAttributes(pythonUdf.dataType.asInstanceOf[StructType]),
baseRel,
isBarrier,
None)
profile)
case _ =>
throw InvalidPlanInput(
s"Function with EvalType: ${pythonUdf.evalType} is not supported")
SparkConnectBuildResourceProfileHandler.scala (new file)
@@ -0,0 +1,75 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.connect.service

import scala.jdk.CollectionConverters.MapHasAsScala

import io.grpc.stub.StreamObserver

import org.apache.spark.connect.proto
import org.apache.spark.internal.Logging
import org.apache.spark.resource.{ExecutorResourceRequest, ResourceProfile, TaskResourceProfile, TaskResourceRequest}

class SparkConnectBuildResourceProfileHandler(
responseObserver: StreamObserver[proto.BuildResourceProfileResponse])
extends Logging {

/**
* Transform the Spark Connect ResourceProfile into a Spark ResourceProfile.
* @param rp
* Spark connect ResourceProfile
* @return
* the Spark ResourceProfile
*/
private def transformResourceProfile(rp: proto.ResourceProfile): ResourceProfile = {
val ereqs = rp.getExecutorResourcesMap.asScala.map { case (name, res) =>
name -> new ExecutorResourceRequest(
res.getResourceName,
res.getAmount,
res.getDiscoveryScript,
res.getVendor)
}.toMap
val treqs = rp.getTaskResourcesMap.asScala.map { case (name, res) =>
name -> new TaskResourceRequest(res.getResourceName, res.getAmount)
}.toMap

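// If only task requirements were sent, build a TaskResourceProfile so the default
// executor resources are reused; otherwise build a full ResourceProfile.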
if (ereqs.isEmpty) {
new TaskResourceProfile(treqs)
} else {
new ResourceProfile(ereqs, treqs)
}
}

def handle(request: proto.BuildResourceProfileRequest): Unit = {
val holder = SparkConnectService
.getOrCreateIsolatedSession(request.getUserContext.getUserId, request.getSessionId)

val rp = transformResourceProfile(request.getProfile)

val session = holder.session
session.sparkContext.resourceProfileManager.addResourceProfile(rp)
Review comment (Contributor): How are these cleaned up?

Reply (Contributor Author): Yeah, neither ResourceProfile nor ResourceProfileManager has any cleanup today. If you think we need cleanup, we can file another PR for it.


val builder = proto.BuildResourceProfileResponse.newBuilder()
builder.setProfileId(rp.id)
builder.setSessionId(request.getSessionId)
builder.setServerSideSessionId(holder.serverSessionId)
responseObserver.onNext(builder.build())
responseObserver.onCompleted()
}

}
SparkConnectService.scala
@@ -33,7 +33,7 @@ import org.apache.commons.lang3.StringUtils

import org.apache.spark.{SparkContext, SparkEnv}
import org.apache.spark.connect.proto
import org.apache.spark.connect.proto.{AddArtifactsRequest, AddArtifactsResponse, SparkConnectServiceGrpc}
import org.apache.spark.connect.proto.{AddArtifactsRequest, AddArtifactsResponse, BuildResourceProfileRequest, BuildResourceProfileResponse, SparkConnectServiceGrpc}
import org.apache.spark.connect.proto.SparkConnectServiceGrpc.AsyncService
import org.apache.spark.internal.Logging
import org.apache.spark.internal.config.UI.UI_ENABLED
@@ -227,6 +227,20 @@ class SparkConnectService(debug: Boolean) extends AsyncService with BindableServ
}
}

override def buildResourceProfile(
request: BuildResourceProfileRequest,
responseObserver: StreamObserver[BuildResourceProfileResponse]): Unit = {
try {
new SparkConnectBuildResourceProfileHandler(responseObserver).handle(request)
} catch {
ErrorUtils.handleError(
"buildResourceProfile",
observer = responseObserver,
userId = request.getUserContext.getUserId,
sessionId = request.getSessionId)
}
}

private def methodWithCustomMarshallers(methodDesc: MethodDescriptor[MessageLite, MessageLite])
: MethodDescriptor[MessageLite, MessageLite] = {
val recursionLimit =
1 change: 1 addition & 0 deletions dev/check_pyspark_custom_errors.py
@@ -176,6 +176,7 @@ def check_pyspark_custom_errors(target_paths, exclude_paths):
TARGET_PATHS = ["python/pyspark/sql"]
EXCLUDE_PATHS = [
"python/pyspark/sql/tests",
"python/pyspark/sql/connect/resource",
"python/pyspark/sql/connect/proto",
]

2 changes: 2 additions & 0 deletions dev/sparktestsupport/modules.py
@@ -554,6 +554,7 @@ def __hash__(self):
"pyspark.resource.profile",
# unittests
"pyspark.resource.tests.test_resources",
"pyspark.resource.tests.test_connect_resources",
Review comment (Contributor, @zhengruifeng, Mar 7, 2024): This test is for Spark Connect; I think we should move it to the pyspark_connect module? Or maybe we can move its test cases into pyspark.sql.tests.connect.test_resources?

Reply (Contributor Author):
This pull request already includes pyspark.sql.tests.connect.test_resources to test the general mapInPandas/mapInArrow functionality with ResourceProfile. On the other hand, pyspark.resource.tests.test_connect_resources is specifically for testing special cases like creating a ResourceProfile before establishing a remote session. Therefore, it seems appropriate to keep the tests in their respective locations.

],
)

@@ -1027,6 +1028,7 @@ def __hash__(self):
"pyspark.sql.tests.connect.test_parity_pandas_udf_scalar",
"pyspark.sql.tests.connect.test_parity_pandas_udf_grouped_agg",
"pyspark.sql.tests.connect.test_parity_pandas_udf_window",
"pyspark.sql.tests.connect.test_resources",
],
excluded_python_implementations=[
"PyPy" # Skip these tests under PyPy since they require numpy, pandas, and pyarrow and
38 changes: 29 additions & 9 deletions python/pyspark/resource/profile.py
@@ -14,7 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#

from threading import RLock
from typing import overload, Dict, Union, Optional

from py4j.java_gateway import JavaObject
@@ -37,6 +37,9 @@ class ResourceProfile:

.. versionadded:: 3.1.0

.. versionchanged:: 4.0.0
Supports Spark Connect.

Notes
-----
This API is evolving.
@@ -99,6 +102,11 @@ def __init__(
_exec_req: Optional[Dict[str, ExecutorResourceRequest]] = None,
_task_req: Optional[Dict[str, TaskResourceRequest]] = None,
):
# profile id
self._id: Optional[int] = None
# lock to protect _id
self._lock = RLock()

if _java_resource_profile is not None:
self._java_resource_profile = _java_resource_profile
else:
Expand All @@ -114,14 +122,23 @@ def id(self) -> int:
int
A unique id of this :class:`ResourceProfile`
"""
+        with self._lock:
+            if self._id is None:
+                if self._java_resource_profile is not None:
+                    self._id = self._java_resource_profile.id()
+                else:
+                    from pyspark.sql import is_remote
+
-        if self._java_resource_profile is not None:
-            return self._java_resource_profile.id()
-        else:
-            raise RuntimeError(
-                "SparkContext must be created to get the id, get the id "
-                "after adding the ResourceProfile to an RDD"
-            )
+                    if is_remote():
+                        from pyspark.sql.connect.resource.profile import ResourceProfile
+
+                        rp = ResourceProfile(
+                            self._executor_resource_requests, self._task_resource_requests
+                        )
+                        self._id = rp.id
+                    else:
+                        raise RuntimeError("SparkContext must be created to get the profile id.")
+            return self._id

@property
def taskResources(self) -> Dict[str, TaskResourceRequest]:
@@ -185,7 +202,10 @@ def __init__(self) -> None:

# TODO: ignore[attr-defined] will be removed, once SparkContext is inlined
_jvm = SparkContext._jvm
if _jvm is not None:

from pyspark.sql import is_remote

if _jvm is not None and not is_remote():
self._jvm = _jvm
self._java_resource_profile_builder = (
_jvm.org.apache.spark.resource.ResourceProfileBuilder()
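A small sketch of the lazy id behavior introduced above (illustrative only; the remote URL is a placeholder):

from pyspark.resource import ResourceProfileBuilder, TaskResourceRequests

rp = ResourceProfileBuilder().require(TaskResourceRequests().cpus(2)).build

# Without a SparkContext or a remote session, the id cannot be resolved yet:
# rp.id  -> RuntimeError("SparkContext must be created to get the profile id.")

# Under a Spark Connect session, the first access registers the profile on the
# server and caches the returned id; later accesses reuse the cached value.
# spark = SparkSession.builder.remote("sc://localhost").getOrCreate()
# first = rp.id
# assert rp.id == first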
8 changes: 6 additions & 2 deletions python/pyspark/resource/requests.py
@@ -164,9 +164,11 @@ def __init__(
_requests: Optional[Dict[str, ExecutorResourceRequest]] = None,
):
from pyspark import SparkContext
from pyspark.sql import is_remote

_jvm = _jvm or SparkContext._jvm
if _jvm is not None:

if _jvm is not None and not is_remote():
self._java_executor_resource_requests = (
_jvm.org.apache.spark.resource.ExecutorResourceRequests()
)
@@ -460,9 +462,11 @@ def __init__(
_requests: Optional[Dict[str, TaskResourceRequest]] = None,
):
from pyspark import SparkContext
from pyspark.sql import is_remote

_jvm = _jvm or SparkContext._jvm
if _jvm is not None:

if _jvm is not None and not is_remote():
self._java_task_resource_requests: Optional[
JavaObject
] = _jvm.org.apache.spark.resource.TaskResourceRequests()