@@ -375,6 +375,9 @@ message ExecutePlanResponse {
// Response type informing if the stream is complete in reattachable execution.
ResultComplete result_complete = 14;

// Response for command that creates ResourceProfile.
CreateResourceProfileCommandResult create_resource_profile_command_result = 17;

// Support arbitrary result objects.
google.protobuf.Any extension = 999;
}
@@ -1069,4 +1072,3 @@ service SparkConnectService {
// FetchErrorDetails retrieves the matched exception with details based on a provided error id.
rpc FetchErrorDetails(FetchErrorDetailsRequest) returns (FetchErrorDetailsResponse) {}
}

@@ -44,6 +44,7 @@ message Command {
CommonInlineUserDefinedTableFunction register_table_function = 10;
StreamingQueryListenerBusCommand streaming_query_listener_bus_command = 11;
CommonInlineUserDefinedDataSource register_data_source = 12;
CreateResourceProfileCommand create_resource_profile_command = 13;

// This field is used to mark extensions to the protocol. When plugins generate arbitrary
// Commands they can add them here. During the planning the correct resolution is done.
@@ -468,3 +469,15 @@ message GetResourcesCommand { }
message GetResourcesCommandResult {
map<string, ResourceInformation> resources = 1;
}

// Command to create ResourceProfile
message CreateResourceProfileCommand {
// (Required) The ResourceProfile to be built on the server-side.
ResourceProfile profile = 1;
}

// Response for command 'CreateResourceProfileCommand'.
message CreateResourceProfileCommandResult {
// (Required) Server-side generated resource profile id.
int32 profile_id = 1;
}
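
For orientation, here is a minimal sketch of how a Spark Connect client could assemble this command from the generated Python protos and read the id back from the result. It is illustrative only: the `pyspark.sql.connect.proto` re-export path and the resource names/amounts are assumptions, and the transport step (ExecutePlan) is elided.

```python
# Sketch only: assumes the classes generated from these .proto files are
# re-exported as pyspark.sql.connect.proto.
import pyspark.sql.connect.proto as pb2

# Describe the profile the server should register (illustrative GPU request).
profile = pb2.ResourceProfile(
    executor_resources={
        "gpu": pb2.ExecutorResourceRequest(resource_name="gpu", amount=1, vendor="nvidia")
    },
    task_resources={"gpu": pb2.TaskResourceRequest(resource_name="gpu", amount=0.5)},
)

command = pb2.Command(
    create_resource_profile_command=pb2.CreateResourceProfileCommand(profile=profile)
)

# The command travels through ExecutePlan; the matching response carries the
# server-generated id in create_resource_profile_command_result.profile_id.
def read_profile_id(response: pb2.ExecutePlanResponse) -> int:
    return response.create_resource_profile_command_result.profile_id
```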
@@ -46,3 +46,38 @@ message ResourceInformation {
// (Required) An array of strings describing the addresses of the resource.
repeated string addresses = 2;
}

// An executor resource request.
message ExecutorResourceRequest {
// (Required) resource name.
string resource_name = 1;

// (Required) The amount of the resource being requested.
int64 amount = 2;

// Optional script used to discover the resources.
optional string discovery_script = 3;

// Optional vendor, required for some cluster managers.
optional string vendor = 4;
}

// A task resource request.
message TaskResourceRequest {
// (Required) resource name.
string resource_name = 1;

// (Required) The amount of the resource being requested, as a double to
// support fractional resource requests.
double amount = 2;
}

message ResourceProfile {
// (Optional) Resource requests for executors. Mapped from the resource name
// (e.g., cores, memory, CPU) to its specific request.
map<string, ExecutorResourceRequest> executor_resources = 1;

// (Optional) Resource requests for tasks. Mapped from the resource name
// (e.g., cores, memory, CPU) to its specific request.
map<string, TaskResourceRequest> task_resources = 2;
}
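
These messages mirror the existing `pyspark.resource` request classes. As a rough usage sketch, an equivalent profile built through the public API could look like the following; the core/memory/GPU amounts and the discovery script path are illustrative only.

```python
# Sketch of building the same structure with the public pyspark.resource API.
from pyspark.resource import (
    ExecutorResourceRequests,
    ResourceProfileBuilder,
    TaskResourceRequests,
)

executor_reqs = (
    ExecutorResourceRequests()
    .cores(2)
    .memory("4g")
    .resource("gpu", 1, discoveryScript="/opt/spark/getGpus.sh", vendor="nvidia")
)
task_reqs = TaskResourceRequests().cpus(1).resource("gpu", 0.5)

# Note: build is a property on ResourceProfileBuilder, not a method.
rp = ResourceProfileBuilder().require(executor_reqs).require(task_reqs).build
```

Under Spark Connect, the builder path skips the JVM-backed objects (see the `is_remote` guards later in this diff), and the profile is only materialized on the server when its id is first needed.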
@@ -23,6 +23,7 @@ import "google/protobuf/any.proto";
import "spark/connect/expressions.proto";
import "spark/connect/types.proto";
import "spark/connect/catalog.proto";
import "spark/connect/common.proto";

option java_multiple_files = true;
option java_package = "org.apache.spark.connect.proto";
@@ -893,6 +894,9 @@ message MapPartitions {

// (Optional) Whether to use barrier mode execution or not.
optional bool is_barrier = 3;

// (Optional) ResourceProfile id used for the stage level scheduling.
optional int32 profile_id = 4;
grundprinzip (Contributor) commented:

Shouldn't these be uuids? Just to make sure that we have an intentional difference and no off-by-one?

wbo4958 (Contributor, Author) replied on Mar 11, 2024:

Hi @grundprinzip,

I'm not quite following you about "by uuids".

The basic implementation is:

  1. The client creates a ResourceProfile.
  2. When the profile id of that ResourceProfile is accessed for the first time, the client asks the server to create the ResourceProfile and add it to the ResourceProfileManager; the server returns the profile id, which the client then sets on its local ResourceProfile.
  3. The internal mapInPandas/mapInArrow only carries the ResourceProfile id, and the server side can look the ResourceProfile up in the ResourceProfileManager by that id.

grundprinzip (Contributor) replied:

Why do we need an extra RPC for that? Can't you attach the resource profile directly to this call?

wbo4958 (Contributor, Author) replied:

Hi @grundprinzip, the user still needs to know the exact ResourceProfile id; if we attached the resource profile directly to this call, it seems we could not get the generated id back from it.
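
To make the flow described above concrete, here is a minimal end-to-end sketch from the client's point of view. It assumes `mapInPandas` accepts the `profile` argument added by the related stage-level scheduling work; the remote URL, resource amounts, and schema are illustrative.

```python
# Sketch only: illustrative ResourceProfile usage under Spark Connect.
from pyspark.resource import ResourceProfileBuilder, TaskResourceRequests
from pyspark.sql import SparkSession

spark = SparkSession.builder.remote("sc://localhost").getOrCreate()

# 1. The client creates a ResourceProfile (no server interaction yet).
rp = ResourceProfileBuilder().require(TaskResourceRequests().cpus(2)).build

def double_values(iterator):
    for pdf in iterator:
        yield pdf.assign(v=pdf.v * 2)

df = spark.range(10).withColumnRenamed("id", "v")

# 2./3. On first access of rp.id the client registers the profile via
# CreateResourceProfileCommand; the MapPartitions relation in the plan then
# carries only the returned profile_id.
df.mapInPandas(double_values, "v long", profile=rp).show()
```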

}

message GroupMap {
@@ -32,16 +32,14 @@ import org.apache.commons.lang3.exception.ExceptionUtils
import org.apache.spark.{Partition, SparkEnv, TaskContext}
import org.apache.spark.api.python.{PythonEvalType, SimplePythonFunction}
import org.apache.spark.connect.proto
import org.apache.spark.connect.proto.{ExecutePlanResponse, SqlCommand, StreamingQueryCommand, StreamingQueryCommandResult, StreamingQueryInstanceId, WriteStreamOperationStart, WriteStreamOperationStartResult}
import org.apache.spark.connect.proto.{CreateResourceProfileCommand, ExecutePlanResponse, SqlCommand, StreamingForeachFunction, StreamingQueryCommand, StreamingQueryCommandResult, StreamingQueryInstanceId, StreamingQueryManagerCommand, StreamingQueryManagerCommandResult, WriteStreamOperationStart, WriteStreamOperationStartResult}
import org.apache.spark.connect.proto.ExecutePlanResponse.SqlCommandResult
import org.apache.spark.connect.proto.Parse.ParseFormat
import org.apache.spark.connect.proto.StreamingForeachFunction
import org.apache.spark.connect.proto.StreamingQueryManagerCommand
import org.apache.spark.connect.proto.StreamingQueryManagerCommandResult
import org.apache.spark.connect.proto.StreamingQueryManagerCommandResult.StreamingQueryInstance
import org.apache.spark.connect.proto.WriteStreamOperationStart.TriggerCase
import org.apache.spark.internal.Logging
import org.apache.spark.ml.{functions => MLFunctions}
import org.apache.spark.resource.{ExecutorResourceRequest, ResourceProfile, TaskResourceProfile, TaskResourceRequest}
import org.apache.spark.sql.{Column, Dataset, Encoders, ForeachWriter, Observation, RelationalGroupedDataset, SparkSession}
import org.apache.spark.sql.avro.{AvroDataToCatalyst, CatalystDataToAvro}
import org.apache.spark.sql.catalyst.{expressions, AliasIdentifier, FunctionIdentifier}
@@ -544,21 +542,27 @@ class SparkConnectPlanner(
case proto.CommonInlineUserDefinedFunction.FunctionCase.PYTHON_UDF =>
val pythonUdf = transformPythonUDF(commonUdf)
val isBarrier = if (rel.hasIsBarrier) rel.getIsBarrier else false
val profile = if (rel.hasProfileId) {
val profileId = rel.getProfileId
Some(session.sparkContext.resourceProfileManager.resourceProfileFromId(profileId))
} else {
None
}
pythonUdf.evalType match {
case PythonEvalType.SQL_MAP_PANDAS_ITER_UDF =>
logical.MapInPandas(
pythonUdf,
DataTypeUtils.toAttributes(pythonUdf.dataType.asInstanceOf[StructType]),
baseRel,
isBarrier,
None)
profile)
case PythonEvalType.SQL_MAP_ARROW_ITER_UDF =>
logical.MapInArrow(
pythonUdf,
DataTypeUtils.toAttributes(pythonUdf.dataType.asInstanceOf[StructType]),
baseRel,
isBarrier,
None)
profile)
case _ =>
throw InvalidPlanInput(
s"Function with EvalType: ${pythonUdf.evalType} is not supported")
@@ -2531,6 +2535,11 @@ class SparkConnectPlanner(
responseObserver)
case proto.Command.CommandTypeCase.GET_RESOURCES_COMMAND =>
handleGetResourcesCommand(responseObserver)
case proto.Command.CommandTypeCase.CREATE_RESOURCE_PROFILE_COMMAND =>
handleCreateResourceProfileCommand(
command.getCreateResourceProfileCommand,
responseObserver)

case _ => throw new UnsupportedOperationException(s"$command not supported.")
}
}
@@ -3327,6 +3336,43 @@
.build())
}

def handleCreateResourceProfileCommand(
createResourceProfileCommand: CreateResourceProfileCommand,
responseObserver: StreamObserver[proto.ExecutePlanResponse]): Unit = {
val rp = createResourceProfileCommand.getProfile
val ereqs = rp.getExecutorResourcesMap.asScala.map { case (name, res) =>
name -> new ExecutorResourceRequest(
res.getResourceName,
res.getAmount,
res.getDiscoveryScript,
res.getVendor)
}.toMap
val treqs = rp.getTaskResourcesMap.asScala.map { case (name, res) =>
name -> new TaskResourceRequest(res.getResourceName, res.getAmount)
}.toMap

// Create the ResourceProfile and add it to the ResourceProfileManager
val profile = if (ereqs.isEmpty) {
new TaskResourceProfile(treqs)
} else {
new ResourceProfile(ereqs, treqs)
}
session.sparkContext.resourceProfileManager.addResourceProfile(profile)

executeHolder.eventsManager.postFinished()
responseObserver.onNext(
proto.ExecutePlanResponse
.newBuilder()
.setSessionId(sessionId)
.setServerSideSessionId(sessionHolder.serverSessionId)
.setCreateResourceProfileCommandResult(
proto.CreateResourceProfileCommandResult
.newBuilder()
.setProfileId(profile.id)
.build())
.build())
}

private val emptyLocalRelation = LocalRelation(
output = AttributeReference("value", StringType, false)() :: Nil,
data = Seq.empty)
1 change: 1 addition & 0 deletions dev/check_pyspark_custom_errors.py
@@ -176,6 +176,7 @@ def check_pyspark_custom_errors(target_paths, exclude_paths):
TARGET_PATHS = ["python/pyspark/sql"]
EXCLUDE_PATHS = [
"python/pyspark/sql/tests",
"python/pyspark/sql/connect/resource",
"python/pyspark/sql/connect/proto",
]

2 changes: 2 additions & 0 deletions dev/sparktestsupport/modules.py
@@ -569,6 +569,7 @@ def __hash__(self):
"pyspark.resource.profile",
# unittests
"pyspark.resource.tests.test_resources",
"pyspark.resource.tests.test_connect_resources",
zhengruifeng (Contributor) commented on Mar 7, 2024:

This test is for Spark Connect; I think we should move it to the pyspark_connect module?

Or maybe we can move its test cases into pyspark.sql.tests.connect.test_resources?

wbo4958 (Contributor, Author) replied:

This pull request already includes pyspark.sql.tests.connect.test_resources to test the general mapInPandas/mapInArrow functionality with ResourceProfile. On the other hand, pyspark.resource.tests.test_connect_resources is specifically for testing special cases like creating a ResourceProfile before establishing a remote session. Therefore, it seems appropriate to keep the tests in their respective locations.

],
)

@@ -1057,6 +1058,7 @@ def __hash__(self):
"pyspark.sql.tests.connect.test_parity_pandas_udf_scalar",
"pyspark.sql.tests.connect.test_parity_pandas_udf_grouped_agg",
"pyspark.sql.tests.connect.test_parity_pandas_udf_window",
"pyspark.sql.tests.connect.test_resources",
],
excluded_python_implementations=[
"PyPy" # Skip these tests under PyPy since they require numpy, pandas, and pyarrow and
40 changes: 31 additions & 9 deletions python/pyspark/resource/profile.py
@@ -14,7 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#

from threading import RLock
from typing import overload, Dict, Union, Optional

from py4j.java_gateway import JavaObject
@@ -37,6 +37,9 @@ class ResourceProfile:

.. versionadded:: 3.1.0

.. versionchanged:: 4.0.0
Supports Spark Connect.

Notes
-----
This API is evolving.
@@ -99,6 +102,11 @@
_exec_req: Optional[Dict[str, ExecutorResourceRequest]] = None,
_task_req: Optional[Dict[str, TaskResourceRequest]] = None,
):
# profile id
self._id: Optional[int] = None
# lock to protect _id
self._lock = RLock()

if _java_resource_profile is not None:
self._java_resource_profile = _java_resource_profile
else:
@@ -114,14 +122,25 @@ def id(self) -> int:
int
A unique id of this :class:`ResourceProfile`
"""
with self._lock:
if self._id is None:
if self._java_resource_profile is not None:
self._id = self._java_resource_profile.id()
else:
from pyspark.sql import is_remote

if self._java_resource_profile is not None:
return self._java_resource_profile.id()
else:
raise RuntimeError(
"SparkContext must be created to get the id, get the id "
"after adding the ResourceProfile to an RDD"
)
if is_remote():
from pyspark.sql.connect.resource.profile import ResourceProfile

# Utilize the connect ResourceProfile to create Spark ResourceProfile
# on the server and get the profile ID.
rp = ResourceProfile(
self._executor_resource_requests, self._task_resource_requests
)
self._id = rp.id
else:
raise RuntimeError("SparkContext must be created to get the profile id.")
return self._id

@property
def taskResources(self) -> Dict[str, TaskResourceRequest]:
@@ -185,7 +204,10 @@ def __init__(self) -> None:

# TODO: ignore[attr-defined] will be removed, once SparkContext is inlined
_jvm = SparkContext._jvm
if _jvm is not None:

from pyspark.sql import is_remote

if _jvm is not None and not is_remote():
self._jvm = _jvm
self._java_resource_profile_builder = (
_jvm.org.apache.spark.resource.ResourceProfileBuilder()
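
The `pyspark.sql.connect.resource.profile.ResourceProfile` imported above is not shown in this diff. Purely as a hypothetical sketch of its role: it converts the request dictionaries into the proto `ResourceProfile` from common.proto and has the server register it, exposing the returned id. The class shape, names, and the `register` callback below are assumptions; the actual command-sending step through the remote session is elided.

```python
# Hypothetical sketch of the connect-side ResourceProfile; not the actual file.
from typing import Callable, Dict, Optional

import pyspark.sql.connect.proto as pb2  # assumed re-export of the generated protos
from pyspark.resource.requests import ExecutorResourceRequest, TaskResourceRequest


class ConnectResourceProfile:  # the real class is simply named ResourceProfile
    def __init__(
        self,
        exec_reqs: Optional[Dict[str, ExecutorResourceRequest]] = None,
        task_reqs: Optional[Dict[str, TaskResourceRequest]] = None,
        register: Optional[Callable[[pb2.ResourceProfile], int]] = None,
    ):
        # Translate the pyspark.resource request objects into proto messages.
        self._proto = pb2.ResourceProfile(
            executor_resources={
                name: pb2.ExecutorResourceRequest(
                    resource_name=r.resourceName,
                    amount=int(r.amount),
                    discovery_script=r.discoveryScript,
                    vendor=r.vendor,
                )
                for name, r in (exec_reqs or {}).items()
            },
            task_resources={
                name: pb2.TaskResourceRequest(resource_name=r.resourceName, amount=r.amount)
                for name, r in (task_reqs or {}).items()
            },
        )
        # register stands in for sending CreateResourceProfileCommand through the
        # active remote session; it returns the server-generated profile_id.
        self._id: Optional[int] = None if register is None else register(self._proto)

    @property
    def id(self) -> int:
        assert self._id is not None, "profile was not registered on the server"
        return self._id
```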
8 changes: 6 additions & 2 deletions python/pyspark/resource/requests.py
@@ -164,9 +164,11 @@ def __init__(
_requests: Optional[Dict[str, ExecutorResourceRequest]] = None,
):
from pyspark import SparkContext
from pyspark.sql import is_remote

_jvm = _jvm or SparkContext._jvm
if _jvm is not None:

if _jvm is not None and not is_remote():
self._java_executor_resource_requests = (
_jvm.org.apache.spark.resource.ExecutorResourceRequests()
)
@@ -460,9 +462,11 @@ def __init__(
_requests: Optional[Dict[str, TaskResourceRequest]] = None,
):
from pyspark import SparkContext
from pyspark.sql import is_remote

_jvm = _jvm or SparkContext._jvm
if _jvm is not None:

if _jvm is not None and not is_remote():
self._java_task_resource_requests: Optional[
JavaObject
] = _jvm.org.apache.spark.resource.TaskResourceRequests()