Implement DataFrameQueryContext in Spark Connect
HyukjinKwon committed Jun 17, 2024
commit 853fe82228f488934257087aa5a8f9f1a792e7fd
spark/connect/expressions.proto
@@ -54,6 +54,8 @@ message Expression {
google.protobuf.Any extension = 999;
}

// (Optional) Keeps the origin information for this expression, such as the stacktrace.
Origin origin = 18;

// Expression for the OVER clause or WINDOW clause.
message Window {
@@ -407,3 +409,18 @@ message NamedArgumentExpression {
// (Required) The value expression of the named argument.
Expression value = 2;
}

message Origin {
// (Required) Indicates the origin type.
oneof function {
PythonOrigin python_origin = 1;
}
}

message PythonOrigin {
Contributor:

Please don't name stuff Python/Scala if it is not language-specific.

Member Author:

I actually intend this to be language-specific. For example, the Scala side could have a stacktrace chain.

For Python specifically, this is just a string for now.

Member Author:

I mean, I don't mind combining it for now if you believe it won't be language-specific.

Contributor:

Why can't the Python code be modeled as a StackTraceElement? What's the difference?

Member Author:

StackTraceElement has JDK-dedicated methods and fields (e.g., classLoaderName). I think we should have a dedicated one for individual languages.

While DataFrameQueryContext / SQLQueryContext have common information (for now), I think we will end up having some language-specific, dedicated information for both APIs in the future.

However, I am open to having a common one. There is a way to have a common one (and, e.g., throw an exception if that information doesn't make sense in some languages).

// (Required) Name of the origin, for example, the name of the function.
string fragment = 1;

// (Required) Callsite to show to end users, for example, stacktrace.
string call_site = 2;
}
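
To make fragment and call_site concrete, here is a minimal, hypothetical sketch of how a Python client could derive the two values before attaching an Origin to an outgoing expression. The helper name and the frame-filtering heuristic are illustrative, not the actual PySpark implementation.

import traceback

def capture_origin(fragment: str) -> tuple[str, str]:
    # Walk the stack from the innermost frame outward and keep the first
    # frame outside pyspark itself, assumed to be the user's code.
    for frame in reversed(traceback.extract_stack()):
        if "pyspark" not in frame.filename:
            return fragment, f"{frame.filename}:{frame.lineno}"
    return fragment, "<unknown>"

# e.g. ("divide", "/home/user/job.py:42")
fragment, call_site = capture_origin("divide")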
SparkConnectPlanner.scala
@@ -44,7 +44,7 @@ import org.apache.spark.internal.{Logging, MDC}
import org.apache.spark.internal.LogKeys.{DATAFRAME_ID, SESSION_ID}
import org.apache.spark.ml.{functions => MLFunctions}
import org.apache.spark.resource.{ExecutorResourceRequest, ResourceProfile, TaskResourceProfile, TaskResourceRequest}
import org.apache.spark.sql.{Column, Dataset, Encoders, ForeachWriter, Observation, RelationalGroupedDataset, SparkSession}
import org.apache.spark.sql.{withOrigin, Column, Dataset, Encoders, ForeachWriter, Observation, RelationalGroupedDataset, SparkSession}
import org.apache.spark.sql.avro.{AvroDataToCatalyst, CatalystDataToAvro}
import org.apache.spark.sql.catalyst.{expressions, AliasIdentifier, FunctionIdentifier, QueryPlanningTracker}
import org.apache.spark.sql.catalyst.analysis.{GlobalTempView, LocalTempView, MultiAlias, NameParameterizedQuery, PosParameterizedQuery, UnresolvedAlias, UnresolvedAttribute, UnresolvedDataFrameStar, UnresolvedDeserializer, UnresolvedExtractValue, UnresolvedFunction, UnresolvedRegex, UnresolvedRelation, UnresolvedStar}
@@ -57,6 +57,7 @@ import org.apache.spark.sql.catalyst.plans.{Cross, FullOuter, Inner, JoinType, L
import org.apache.spark.sql.catalyst.plans.logical
import org.apache.spark.sql.catalyst.plans.logical.{AppendColumns, CoGroup, CollectMetrics, CommandResult, Deduplicate, DeduplicateWithinWatermark, DeserializeToObject, Except, FlatMapGroupsWithState, Intersect, JoinWith, LocalRelation, LogicalGroupState, LogicalPlan, MapGroups, MapPartitions, Project, Sample, SerializeFromObject, Sort, SubqueryAlias, TypedFilter, Union, Unpivot, UnresolvedHint}
import org.apache.spark.sql.catalyst.streaming.InternalOutputModes
import org.apache.spark.sql.catalyst.trees.PySparkCurrentOrigin
import org.apache.spark.sql.catalyst.types.DataTypeUtils
import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, CharVarcharUtils}
import org.apache.spark.sql.connect.common.{DataTypeProtoConverter, ForeachWriterPacket, InvalidPlanInput, LiteralValueProtoConverter, StorageLevelProtoConverter, StreamingListenerPacket, UdfPacket}
@@ -1471,7 +1472,20 @@ class SparkConnectPlanner(
* Catalyst expression
*/
@DeveloperApi
def transformExpression(exp: proto.Expression): Expression = {
def transformExpression(exp: proto.Expression): Expression = if (exp.hasOrigin) {
try {
PySparkCurrentOrigin.set(
exp.getOrigin.getPythonOrigin.getFragment,
exp.getOrigin.getPythonOrigin.getCallSite)
withOrigin { doTransformExpression(exp) }
} finally {
PySparkCurrentOrigin.clear()
}
} else {
doTransformExpression(exp)
}

private def doTransformExpression(exp: proto.Expression): Expression = {
exp.getExprTypeCase match {
case proto.Expression.ExprTypeCase.LITERAL => transformLiteral(exp.getLiteral)
case proto.Expression.ExprTypeCase.UNRESOLVED_ATTRIBUTE =>
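The wrapper above follows a set-then-always-clear pattern around a thread-local holder (PySparkCurrentOrigin), so a failure during transformation cannot leak the origin into later expressions. A minimal Python sketch of the same pattern; the names are illustrative, not the actual PySparkCurrentOrigin API:

import threading
from typing import Callable, Optional, Tuple

_current_origin = threading.local()

def with_origin(fragment: str, call_site: str, transform: Callable[[], object]) -> object:
    # Stash the origin for this thread, run the transformation, and
    # always clear the slot in the finally block.
    _current_origin.value = (fragment, call_site)
    try:
        return transform()
    finally:
        _current_origin.value = None

def current_origin() -> Optional[Tuple[str, str]]:
    return getattr(_current_origin, "value", None)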
ErrorUtils.scala
@@ -35,7 +35,7 @@ import org.apache.commons.lang3.exception.ExceptionUtils
import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods

import org.apache.spark.{SparkEnv, SparkException, SparkThrowable}
import org.apache.spark.{QueryContextType, SparkEnv, SparkException, SparkThrowable}
import org.apache.spark.api.python.PythonException
import org.apache.spark.connect.proto.FetchErrorDetailsResponse
import org.apache.spark.internal.{Logging, MDC}
@@ -118,15 +118,27 @@ private[connect] object ErrorUtils extends Logging {
sparkThrowableBuilder.setErrorClass(sparkThrowable.getErrorClass)
}
for (queryCtx <- sparkThrowable.getQueryContext) {
sparkThrowableBuilder.addQueryContexts(
FetchErrorDetailsResponse.QueryContext
.newBuilder()
val builder = FetchErrorDetailsResponse.QueryContext
.newBuilder()
val context = if (queryCtx.contextType() == QueryContextType.SQL) {
builder
.setContextType(FetchErrorDetailsResponse.QueryContext.ContextType.SQL)
Contributor:

Did we never set this before?

Member Author:

Yeah, so it has always been SQLQueryContext by default.

.setObjectType(queryCtx.objectType())
.setObjectName(queryCtx.objectName())
.setStartIndex(queryCtx.startIndex())
.setStopIndex(queryCtx.stopIndex())
.setFragment(queryCtx.fragment())
.build())
.setSummary(queryCtx.summary())
Contributor:

Same?

Member Author:

Yes, we did not have the QueryContext.summary() API before this change.

.build()
} else {
Contributor (@grundprinzip, Jun 18, 2024):

Is this really an unconditional else?

Member Author:

For now, yes, because we only have QueryContextType.SQL and QueryContextType.DataFrame.

builder
.setContextType(FetchErrorDetailsResponse.QueryContext.ContextType.DATAFRAME)
.setFragment(queryCtx.fragment())
.setCallSite(queryCtx.callSite())
.setSummary(queryCtx.summary())
.build()
}
sparkThrowableBuilder.addQueryContexts(context)
}
if (sparkThrowable.getSqlState != null) {
sparkThrowableBuilder.setSqlState(sparkThrowable.getSqlState)
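For reference, a hypothetical client-side view of the message the DataFrame branch above produces; pb2 is assumed to be the generated Spark Connect proto module, and the field values are made up:

import pyspark.sql.connect.proto as pb2  # assumption: generated proto package

ctx = pb2.FetchErrorDetailsResponse.QueryContext(
    context_type=pb2.FetchErrorDetailsResponse.QueryContext.DATAFRAME,
    fragment="divide",
    call_site="/home/user/job.py:42",
    summary='== DataFrame ==\n"divide" was called from /home/user/job.py:42',
)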
51 changes: 37 additions & 14 deletions python/pyspark/errors/exceptions/captured.py
@@ -166,7 +166,14 @@ def getQueryContext(self) -> List[BaseQueryContext]:
if self._origin is not None and is_instance_of(
gw, self._origin, "org.apache.spark.SparkThrowable"
):
return [QueryContext(q) for q in self._origin.getQueryContext()]
contexts: List[BaseQueryContext] = []
for q in self._origin.getQueryContext():
if q.contextType().toString() == "SQL":
contexts.append(SQLQueryContext(q))
else:
contexts.append(DataFrameQueryContext(q))

return contexts
else:
return []

@@ -379,17 +386,12 @@ class UnknownException(CapturedException, BaseUnknownException):
"""


class QueryContext(BaseQueryContext):
class SQLQueryContext(BaseQueryContext):
Contributor:

Do we consider this a private / developer API?

Member Author:

Only the parent class QueryContext is an API (at pyspark.errors.QueryContext) for now. This is at least consistent with the Scala side.

def __init__(self, q: "JavaObject"):
self._q = q

def contextType(self) -> QueryContextType:
context_type = self._q.contextType().toString()
assert context_type in ("SQL", "DataFrame")
if context_type == "DataFrame":
return QueryContextType.DataFrame
else:
return QueryContextType.SQL
return QueryContextType.SQL

def objectType(self) -> str:
return str(self._q.objectType())
@@ -409,13 +411,34 @@ def fragment(self) -> str:
def callSite(self) -> str:
return str(self._q.callSite())

def pysparkFragment(self) -> Optional[str]: # type: ignore[return]
if self.contextType() == QueryContextType.DataFrame:
return str(self._q.pysparkFragment())
def summary(self) -> str:
return str(self._q.summary())


class DataFrameQueryContext(BaseQueryContext):
Contributor:

Isn't the type annotation wrong here?

Contributor:

I see this is for the classic side.

def __init__(self, q: "JavaObject"):
self._q = q

def contextType(self) -> QueryContextType:
return QueryContextType.DataFrame

def objectType(self) -> str:
return str(self._q.objectType())

def objectName(self) -> str:
return str(self._q.objectName())

def pysparkCallSite(self) -> Optional[str]: # type: ignore[return]
if self.contextType() == QueryContextType.DataFrame:
return str(self._q.pysparkCallSite())
def startIndex(self) -> int:
return int(self._q.startIndex())

def stopIndex(self) -> int:
return int(self._q.stopIndex())

def fragment(self) -> str:
return str(self._q.fragment())

def callSite(self) -> str:
return str(self._q.callSite())

def summary(self) -> str:
return str(self._q.summary())
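
A short usage sketch of the dispatch above, assuming ANSI mode is enabled so the division actually raises, and assuming QueryContextType is importable from pyspark.errors as elsewhere in this PR:

from pyspark.errors import QueryContextType
from pyspark.errors.exceptions.captured import CapturedException
from pyspark.sql.functions import lit

try:
    spark.range(1).select(lit(1) / lit(0)).collect()  # assumes an active `spark` session
except CapturedException as e:
    for ctx in e.getQueryContext():
        if ctx.contextType() == QueryContextType.DataFrame:
            print(ctx.fragment(), ctx.callSite())
        else:
            print(ctx.summary())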
83 changes: 75 additions & 8 deletions python/pyspark/errors/exceptions/connect.py
@@ -91,7 +91,10 @@ def convert_exception(
)
query_contexts = []
for query_context in resp.errors[resp.root_error_idx].spark_throwable.query_contexts:
query_contexts.append(QueryContext(query_context))
if query_context.context_type == pb2.FetchErrorDetailsResponse.QueryContext.SQL:
query_contexts.append(SQLQueryContext(query_context))
else:
query_contexts.append(DataFrameQueryContext(query_context))

if "org.apache.spark.sql.catalyst.parser.ParseException" in classes:
return ParseException(
@@ -430,17 +433,12 @@ class SparkNoSuchElementException(SparkConnectGrpcException, BaseNoSuchElementEx
"""


class QueryContext(BaseQueryContext):
class SQLQueryContext(BaseQueryContext):
def __init__(self, q: pb2.FetchErrorDetailsResponse.QueryContext):
self._q = q

def contextType(self) -> QueryContextType:
context_type = self._q.context_type

if int(context_type) == QueryContextType.DataFrame.value:
return QueryContextType.DataFrame
else:
return QueryContextType.SQL
return QueryContextType.SQL

def objectType(self) -> str:
return str(self._q.object_type)
@@ -457,6 +455,75 @@ def stopIndex(self) -> int:
def fragment(self) -> str:
return str(self._q.fragment)

def callSite(self) -> str:
raise UnsupportedOperationException(
"",
error_class="UNSUPPORTED_CALL.WITHOUT_SUGGESTION",
message_parameters={"className": "SQLQueryContext", "methodName": "callSite"},
sql_state="0A000",
server_stacktrace=None,
display_server_stacktrace=False,
query_contexts=[],
)

def summary(self) -> str:
return str(self._q.summary)


class DataFrameQueryContext(BaseQueryContext):
def __init__(self, q: pb2.FetchErrorDetailsResponse.QueryContext):
self._q = q

def contextType(self) -> QueryContextType:
return QueryContextType.DataFrame

def objectType(self) -> str:
raise UnsupportedOperationException(
"",
error_class="UNSUPPORTED_CALL.WITHOUT_SUGGESTION",
message_parameters={"className": "DataFrameQueryContext", "methodName": "objectType"},
sql_state="0A000",
server_stacktrace=None,
display_server_stacktrace=False,
query_contexts=[],
)

def objectName(self) -> str:
raise UnsupportedOperationException(
"",
error_class="UNSUPPORTED_CALL.WITHOUT_SUGGESTION",
message_parameters={"className": "DataFrameQueryContext", "methodName": "objectName"},
sql_state="0A000",
server_stacktrace=None,
display_server_stacktrace=False,
query_contexts=[],
)

def startIndex(self) -> int:
raise UnsupportedOperationException(
"",
error_class="UNSUPPORTED_CALL.WITHOUT_SUGGESTION",
message_parameters={"className": "DataFrameQueryContext", "methodName": "startIndex"},
sql_state="0A000",
server_stacktrace=None,
display_server_stacktrace=False,
query_contexts=[],
)

def stopIndex(self) -> int:
raise UnsupportedOperationException(
"",
error_class="UNSUPPORTED_CALL.WITHOUT_SUGGESTION",
message_parameters={"className": "DataFrameQueryContext", "methodName": "stopIndex"},
sql_state="0A000",
server_stacktrace=None,
display_server_stacktrace=False,
query_contexts=[],
)

def fragment(self) -> str:
return str(self._q.fragment)

def callSite(self) -> str:
return str(self._q.call_site)
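
Because the non-applicable accessors raise UNSUPPORTED_CALL, callers should branch on contextType() before touching type-specific fields. An illustrative helper, using the BaseQueryContext and QueryContextType names already imported in this module:

def describe(ctx: BaseQueryContext) -> str:
    # Touch only the fields that are valid for the given context type.
    if ctx.contextType() == QueryContextType.DataFrame:
        return f"{ctx.fragment()} called at {ctx.callSite()}"
    return f"{ctx.objectType()} {ctx.objectName()} [{ctx.startIndex()}, {ctx.stopIndex()}]: {ctx.fragment()}"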
