From cb175425d1f1e1d75fa03b5a214013acc0288420 Mon Sep 17 00:00:00 2001
From: Ruifeng Zheng
Date: Tue, 11 Jun 2024 10:25:43 +0800
Subject: [PATCH 1/2] init

---
 python/pyspark/sql/connect/dataframe.py | 69 ++++++++++++++++++-------
 1 file changed, 50 insertions(+), 19 deletions(-)

diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py
index baac1523c709..f2705ec7ad71 100644
--- a/python/pyspark/sql/connect/dataframe.py
+++ b/python/pyspark/sql/connect/dataframe.py
@@ -262,7 +262,9 @@ def agg(self, *exprs: Union[Column, Dict[str, str]]) -> ParentDataFrame:
         return self.groupBy().agg(*exprs)
 
     def alias(self, alias: str) -> ParentDataFrame:
-        return DataFrame(plan.SubqueryAlias(self._plan, alias), session=self._session)
+        res = DataFrame(plan.SubqueryAlias(self._plan, alias), session=self._session)
+        res._cached_schema = self._cached_schema
+        return res
 
     def colRegex(self, colName: str) -> Column:
         from pyspark.sql.connect.column import Column as ConnectColumn
@@ -314,10 +316,12 @@ def coalesce(self, numPartitions: int) -> ParentDataFrame:
                 error_class="VALUE_NOT_POSITIVE",
                 message_parameters={"arg_name": "numPartitions", "arg_value": str(numPartitions)},
             )
-        return DataFrame(
+        res = DataFrame(
             plan.Repartition(self._plan, num_partitions=numPartitions, shuffle=False),
             self._session,
         )
+        res._cached_schema = self._cached_schema
+        return res
 
     @overload
     def repartition(self, numPartitions: int, *cols: "ColumnOrName") -> ParentDataFrame:
@@ -340,12 +344,12 @@ def repartition(  # type: ignore[misc]
                 },
             )
             if len(cols) == 0:
-                return DataFrame(
+                res = DataFrame(
                     plan.Repartition(self._plan, numPartitions, shuffle=True),
                     self._session,
                 )
             else:
-                return DataFrame(
+                res = DataFrame(
                     plan.RepartitionByExpression(
                         self._plan, numPartitions, [F._to_col(c) for c in cols]
                     ),
@@ -353,7 +357,7 @@ def repartition(  # type: ignore[misc]
                 )
         elif isinstance(numPartitions, (str, Column)):
             cols = (numPartitions,) + cols
-            return DataFrame(
+            res = DataFrame(
                 plan.RepartitionByExpression(self._plan, None, [F._to_col(c) for c in cols]),
                 self.sparkSession,
             )
@@ -366,6 +370,9 @@ def repartition(  # type: ignore[misc]
                 },
             )
 
+        res._cached_schema = self._cached_schema
+        return res
+
     @overload
     def repartitionByRange(self, numPartitions: int, *cols: "ColumnOrName") -> ParentDataFrame:
         ...
@@ -392,14 +399,14 @@ def repartitionByRange(  # type: ignore[misc]
                     message_parameters={"item": "cols"},
                 )
             else:
-                return DataFrame(
+                res = DataFrame(
                     plan.RepartitionByExpression(
                         self._plan, numPartitions, [F._sort_col(c) for c in cols]
                     ),
                     self.sparkSession,
                 )
         elif isinstance(numPartitions, (str, Column)):
-            return DataFrame(
+            res = DataFrame(
                 plan.RepartitionByExpression(
                     self._plan, None, [F._sort_col(c) for c in [numPartitions] + list(cols)]
                 ),
@@ -414,6 +421,9 @@ def repartitionByRange(  # type: ignore[misc]
                 },
             )
 
+        res._cached_schema = self._cached_schema
+        return res
+
     def dropDuplicates(self, *subset: Union[str, List[str]]) -> ParentDataFrame:
         # Acceptable args should be str, ... or a single List[str]
         # So if subset length is 1, it can be either single str, or a list of str
@@ -422,20 +432,23 @@ def dropDuplicates(self, *subset: Union[str, List[str]]) -> ParentDataFrame:
             assert all(isinstance(c, str) for c in subset)
 
         if not subset:
-            return DataFrame(
+            res = DataFrame(
                 plan.Deduplicate(child=self._plan, all_columns_as_keys=True), session=self._session
             )
         elif len(subset) == 1 and isinstance(subset[0], list):
-            return DataFrame(
+            res = DataFrame(
                 plan.Deduplicate(child=self._plan, column_names=subset[0]),
                 session=self._session,
            )
         else:
-            return DataFrame(
+            res = DataFrame(
                 plan.Deduplicate(child=self._plan, column_names=cast(List[str], subset)),
                 session=self._session,
             )
 
+        res._cached_schema = self._cached_schema
+        return res
+
     drop_duplicates = dropDuplicates
 
     def dropDuplicatesWithinWatermark(self, *subset: Union[str, List[str]]) -> ParentDataFrame:
@@ -466,9 +479,11 @@ def dropDuplicatesWithinWatermark(self, *subset: Union[str, List[str]]) -> Paren
             )
 
     def distinct(self) -> ParentDataFrame:
-        return DataFrame(
+        res = DataFrame(
             plan.Deduplicate(child=self._plan, all_columns_as_keys=True), session=self._session
         )
+        res._cached_schema = self._cached_schema
+        return res
 
     @overload
     def drop(self, cols: "ColumnOrName") -> ParentDataFrame:
@@ -499,7 +514,9 @@ def filter(self, condition: Union[Column, str]) -> ParentDataFrame:
             expr = F.expr(condition)
         else:
             expr = condition
-        return DataFrame(plan.Filter(child=self._plan, filter=expr), session=self._session)
+        res = DataFrame(plan.Filter(child=self._plan, filter=expr), session=self._session)
+        res._cached_schema = self._cached_schema
+        return res
 
     def first(self) -> Optional[Row]:
         return self.head()
@@ -709,7 +726,9 @@ def _convert_col(df: ParentDataFrame, col: "ColumnOrName") -> Column:
         )
 
     def limit(self, n: int) -> ParentDataFrame:
-        return DataFrame(plan.Limit(child=self._plan, limit=n), session=self._session)
+        res = DataFrame(plan.Limit(child=self._plan, limit=n), session=self._session)
+        res._cached_schema = self._cached_schema
+        return res
 
     def tail(self, num: int) -> List[Row]:
         return DataFrame(plan.Tail(child=self._plan, limit=num), session=self._session).collect()
@@ -766,7 +785,7 @@ def sort(
         *cols: Union[int, str, Column, List[Union[int, str, Column]]],
         **kwargs: Any,
     ) -> ParentDataFrame:
-        return DataFrame(
+        res = DataFrame(
             plan.Sort(
                 self._plan,
                 columns=self._sort_cols(cols, kwargs),
@@ -774,6 +793,8 @@ def sort(
             ),
             session=self._session,
         )
+        res._cached_schema = self._cached_schema
+        return res
 
     orderBy = sort
 
@@ -782,7 +803,7 @@ def sortWithinPartitions(
         *cols: Union[int, str, Column, List[Union[int, str, Column]]],
         **kwargs: Any,
     ) -> ParentDataFrame:
-        return DataFrame(
+        res = DataFrame(
             plan.Sort(
                 self._plan,
                 columns=self._sort_cols(cols, kwargs),
@@ -790,6 +811,8 @@ def sortWithinPartitions(
             ),
             session=self._session,
         )
+        res._cached_schema = self._cached_schema
+        return res
 
     def sample(
         self,
@@ -837,7 +860,7 @@ def sample(
 
         seed = int(seed) if seed is not None else random.randint(0, sys.maxsize)
 
-        return DataFrame(
+        res = DataFrame(
             plan.Sample(
                 child=self._plan,
                 lower_bound=0.0,
@@ -847,6 +870,8 @@ def sample(
             ),
             session=self._session,
         )
+        res._cached_schema = self._cached_schema
+        return res
 
     def withColumnRenamed(self, existing: str, new: str) -> ParentDataFrame:
         return self.withColumnsRenamed({existing: new})
@@ -1050,10 +1075,12 @@ def hint(
                     },
                 )
 
-        return DataFrame(
+        res = DataFrame(
             plan.Hint(self._plan, name, [F.lit(p) for p in list(parameters)]),
             session=self._session,
         )
+        res._cached_schema = self._cached_schema
+        return res
 
     def randomSplit(
         self,
@@ -1094,6 +1121,7 @@ def randomSplit(
                 ),
                 session=self._session,
             )
+            samplePlan._cached_schema = self._cached_schema
             splits.append(samplePlan)
             j += 1
 
@@ -1118,9 +1146,9 @@ def observe(
             )
 
         if isinstance(observation, Observation):
-            return observation._on(self, *exprs)
+            res = observation._on(self, *exprs)
         elif isinstance(observation, str):
-            return DataFrame(
+            res = DataFrame(
                 plan.CollectMetrics(self._plan, observation, list(exprs)),
                 self._session,
             )
@@ -1133,6 +1161,9 @@ def observe(
                 },
             )
 
+        res._cached_schema = self._cached_schema
+        return res
+
     def show(self, n: int = 20, truncate: Union[bool, int] = True, vertical: bool = False) -> None:
         print(self._show_string(n, truncate, vertical))

From aa3175367cd99e251f2214da1728934b8b6b7cfc Mon Sep 17 00:00:00 2001
From: Ruifeng Zheng
Date: Wed, 12 Jun 2024 15:57:56 +0800
Subject: [PATCH 2/2] add test

---
 .../test_connect_dataframe_property.py        | 35 +++++++++++++++++++
 1 file changed, 35 insertions(+)

diff --git a/python/pyspark/sql/tests/connect/test_connect_dataframe_property.py b/python/pyspark/sql/tests/connect/test_connect_dataframe_property.py
index 4a7e1e1ea760..c712e5d6efcb 100644
--- a/python/pyspark/sql/tests/connect/test_connect_dataframe_property.py
+++ b/python/pyspark/sql/tests/connect/test_connect_dataframe_property.py
@@ -20,6 +20,9 @@
 from pyspark.sql.types import StructType, StructField, StringType, IntegerType, LongType, DoubleType
 from pyspark.sql.utils import is_remote
 
+from pyspark.sql import functions as SF
+from pyspark.sql.connect import functions as CF
+
 from pyspark.sql.tests.connect.test_connect_basic import SparkConnectSQLTestCase
 from pyspark.testing.sqlutils import (
     have_pandas,
@@ -393,6 +396,38 @@ def test_cached_schema_set_op(self):
         # cannot infer when schemas mismatch
         self.assertTrue(cdf1.intersectAll(cdf3)._cached_schema is None)
 
+    def test_cached_schema_in_chain_op(self):
+        data = [(1, 1.0), (2, 2.0), (1, 3.0), (2, 4.0)]
+
+        cdf = self.connect.createDataFrame(data, ("id", "v1"))
+        sdf = self.spark.createDataFrame(data, ("id", "v1"))
+
+        cdf1 = cdf.withColumn("v2", CF.lit(1))
+        sdf1 = sdf.withColumn("v2", SF.lit(1))
+
+        self.assertTrue(cdf1._cached_schema is None)
+        # trigger analysis of cdf1.schema
+        self.assertEqual(cdf1.schema, sdf1.schema)
+        self.assertTrue(cdf1._cached_schema is not None)
+
+        cdf2 = cdf1.where(cdf1.v2 > 0)
+        sdf2 = sdf1.where(sdf1.v2 > 0)
+        self.assertEqual(cdf1._cached_schema, cdf2._cached_schema)
+
+        cdf3 = cdf2.repartition(10)
+        sdf3 = sdf2.repartition(10)
+        self.assertEqual(cdf1._cached_schema, cdf3._cached_schema)
+
+        cdf4 = cdf3.distinct()
+        sdf4 = sdf3.distinct()
+        self.assertEqual(cdf1._cached_schema, cdf4._cached_schema)
+
+        cdf5 = cdf4.sample(fraction=0.5)
+        sdf5 = sdf4.sample(fraction=0.5)
+        self.assertEqual(cdf1._cached_schema, cdf5._cached_schema)
+
+        self.assertEqual(cdf5.schema, sdf5.schema)
+
 
 if __name__ == "__main__":
     from pyspark.sql.tests.connect.test_connect_dataframe_property import *  # noqa: F401