69 changes: 50 additions & 19 deletions python/pyspark/sql/connect/dataframe.py
@@ -262,7 +262,9 @@ def agg(self, *exprs: Union[Column, Dict[str, str]]) -> ParentDataFrame:
return self.groupBy().agg(*exprs)

def alias(self, alias: str) -> ParentDataFrame:
return DataFrame(plan.SubqueryAlias(self._plan, alias), session=self._session)
res = DataFrame(plan.SubqueryAlias(self._plan, alias), session=self._session)
res._cached_schema = self._cached_schema
return res

def colRegex(self, colName: str) -> Column:
from pyspark.sql.connect.column import Column as ConnectColumn
@@ -314,10 +316,12 @@ def coalesce(self, numPartitions: int) -> ParentDataFrame:
error_class="VALUE_NOT_POSITIVE",
message_parameters={"arg_name": "numPartitions", "arg_value": str(numPartitions)},
)
return DataFrame(
res = DataFrame(
plan.Repartition(self._plan, num_partitions=numPartitions, shuffle=False),
self._session,
)
res._cached_schema = self._cached_schema
return res

@overload
def repartition(self, numPartitions: int, *cols: "ColumnOrName") -> ParentDataFrame:
@@ -340,20 +344,20 @@ def repartition( # type: ignore[misc]
},
)
if len(cols) == 0:
return DataFrame(
res = DataFrame(
plan.Repartition(self._plan, numPartitions, shuffle=True),
self._session,
)
else:
return DataFrame(
res = DataFrame(
plan.RepartitionByExpression(
self._plan, numPartitions, [F._to_col(c) for c in cols]
),
self.sparkSession,
)
elif isinstance(numPartitions, (str, Column)):
cols = (numPartitions,) + cols
return DataFrame(
res = DataFrame(
plan.RepartitionByExpression(self._plan, None, [F._to_col(c) for c in cols]),
self.sparkSession,
)
@@ -366,6 +370,9 @@ def repartition( # type: ignore[misc]
},
)

res._cached_schema = self._cached_schema
return res

@overload
def repartitionByRange(self, numPartitions: int, *cols: "ColumnOrName") -> ParentDataFrame:
...
@@ -392,14 +399,14 @@ def repartitionByRange( # type: ignore[misc]
message_parameters={"item": "cols"},
)
else:
return DataFrame(
res = DataFrame(
plan.RepartitionByExpression(
self._plan, numPartitions, [F._sort_col(c) for c in cols]
),
self.sparkSession,
)
elif isinstance(numPartitions, (str, Column)):
return DataFrame(
res = DataFrame(
plan.RepartitionByExpression(
self._plan, None, [F._sort_col(c) for c in [numPartitions] + list(cols)]
),
@@ -414,6 +421,9 @@ def repartitionByRange( # type: ignore[misc]
},
)

res._cached_schema = self._cached_schema
return res

def dropDuplicates(self, *subset: Union[str, List[str]]) -> ParentDataFrame:
# Acceptable args should be str, ... or a single List[str]
# So if subset length is 1, it can be either single str, or a list of str
@@ -422,20 +432,23 @@ def dropDuplicates(self, *subset: Union[str, List[str]]) -> ParentDataFrame:
assert all(isinstance(c, str) for c in subset)

if not subset:
return DataFrame(
res = DataFrame(
plan.Deduplicate(child=self._plan, all_columns_as_keys=True), session=self._session
)
elif len(subset) == 1 and isinstance(subset[0], list):
return DataFrame(
res = DataFrame(
plan.Deduplicate(child=self._plan, column_names=subset[0]),
session=self._session,
)
else:
return DataFrame(
res = DataFrame(
plan.Deduplicate(child=self._plan, column_names=cast(List[str], subset)),
session=self._session,
)

res._cached_schema = self._cached_schema
return res

drop_duplicates = dropDuplicates

def dropDuplicatesWithinWatermark(self, *subset: Union[str, List[str]]) -> ParentDataFrame:
@@ -466,9 +479,11 @@ def dropDuplicatesWithinWatermark(self, *subset: Union[str, List[str]]) -> ParentDataFrame:
)

def distinct(self) -> ParentDataFrame:
return DataFrame(
res = DataFrame(
plan.Deduplicate(child=self._plan, all_columns_as_keys=True), session=self._session
)
res._cached_schema = self._cached_schema
return res

@overload
def drop(self, cols: "ColumnOrName") -> ParentDataFrame:
@@ -499,7 +514,9 @@ def filter(self, condition: Union[Column, str]) -> ParentDataFrame:
expr = F.expr(condition)
else:
expr = condition
return DataFrame(plan.Filter(child=self._plan, filter=expr), session=self._session)
res = DataFrame(plan.Filter(child=self._plan, filter=expr), session=self._session)
res._cached_schema = self._cached_schema
return res

def first(self) -> Optional[Row]:
return self.head()
@@ -709,7 +726,9 @@ def _convert_col(df: ParentDataFrame, col: "ColumnOrName") -> Column:
)

def limit(self, n: int) -> ParentDataFrame:
return DataFrame(plan.Limit(child=self._plan, limit=n), session=self._session)
res = DataFrame(plan.Limit(child=self._plan, limit=n), session=self._session)
res._cached_schema = self._cached_schema
return res

def tail(self, num: int) -> List[Row]:
return DataFrame(plan.Tail(child=self._plan, limit=num), session=self._session).collect()
@@ -766,14 +785,16 @@ def sort(
*cols: Union[int, str, Column, List[Union[int, str, Column]]],
**kwargs: Any,
) -> ParentDataFrame:
return DataFrame(
res = DataFrame(
plan.Sort(
self._plan,
columns=self._sort_cols(cols, kwargs),
is_global=True,
),
session=self._session,
)
res._cached_schema = self._cached_schema
return res

orderBy = sort

@@ -782,14 +803,16 @@ def sortWithinPartitions(
*cols: Union[int, str, Column, List[Union[int, str, Column]]],
**kwargs: Any,
) -> ParentDataFrame:
return DataFrame(
res = DataFrame(
plan.Sort(
self._plan,
columns=self._sort_cols(cols, kwargs),
is_global=False,
),
session=self._session,
)
res._cached_schema = self._cached_schema
return res

def sample(
self,
@@ -837,7 +860,7 @@ def sample(

seed = int(seed) if seed is not None else random.randint(0, sys.maxsize)

return DataFrame(
res = DataFrame(
plan.Sample(
child=self._plan,
lower_bound=0.0,
@@ -847,6 +870,8 @@ def sample(
),
session=self._session,
)
res._cached_schema = self._cached_schema
return res

def withColumnRenamed(self, existing: str, new: str) -> ParentDataFrame:
return self.withColumnsRenamed({existing: new})
@@ -1050,10 +1075,12 @@ def hint(
},
)

return DataFrame(
res = DataFrame(
plan.Hint(self._plan, name, [F.lit(p) for p in list(parameters)]),
session=self._session,
)
res._cached_schema = self._cached_schema
return res

def randomSplit(
self,
@@ -1094,6 +1121,7 @@ def randomSplit(
),
session=self._session,
)
samplePlan._cached_schema = self._cached_schema
splits.append(samplePlan)
j += 1

@@ -1118,9 +1146,9 @@ def observe(
)

if isinstance(observation, Observation):
return observation._on(self, *exprs)
res = observation._on(self, *exprs)
elif isinstance(observation, str):
return DataFrame(
res = DataFrame(
plan.CollectMetrics(self._plan, observation, list(exprs)),
self._session,
)
@@ -1133,6 +1161,9 @@ def observe(
},
)

res._cached_schema = self._cached_schema
return res

def show(self, n: int = 20, truncate: Union[bool, int] = True, vertical: bool = False) -> None:
print(self._show_string(n, truncate, vertical))

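The diff above repeats the same two-line idiom in every schema-preserving method: build the child DataFrame, copy the parent's `_cached_schema` onto it, and return it. Below is a minimal sketch of that idiom factored into a standalone helper; the helper name `_with_parent_schema` is hypothetical, since the PR inlines the pattern at each call site rather than extracting it.

```python
from pyspark.sql.connect.dataframe import DataFrame


def _with_parent_schema(parent: DataFrame, child: DataFrame) -> DataFrame:
    # Copy the parent's cached schema (possibly None) onto a child produced by a
    # schema-preserving operation (alias, filter, limit, sort, sample, distinct,
    # repartition, hint, ...), so that a later child.schema access can be served
    # locally instead of issuing another Analyze request to the Connect server.
    child._cached_schema = parent._cached_schema
    return child
```

If the parent has not been analyzed yet, `_cached_schema` is still `None` and the copy is a harmless no-op; the pattern is only applied to operations whose output schema is guaranteed to equal the input schema.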
python/pyspark/sql/tests/connect/test_connect_dataframe_property.py
@@ -20,6 +20,9 @@
from pyspark.sql.types import StructType, StructField, StringType, IntegerType, LongType, DoubleType
from pyspark.sql.utils import is_remote

from pyspark.sql import functions as SF
from pyspark.sql.connect import functions as CF

from pyspark.sql.tests.connect.test_connect_basic import SparkConnectSQLTestCase
from pyspark.testing.sqlutils import (
have_pandas,
@@ -393,6 +396,38 @@ def test_cached_schema_set_op(self):
# cannot infer when schemas mismatch
self.assertTrue(cdf1.intersectAll(cdf3)._cached_schema is None)

def test_cached_schema_in_chain_op(self):
data = [(1, 1.0), (2, 2.0), (1, 3.0), (2, 4.0)]

cdf = self.connect.createDataFrame(data, ("id", "v1"))
sdf = self.spark.createDataFrame(data, ("id", "v1"))

cdf1 = cdf.withColumn("v2", CF.lit(1))
sdf1 = sdf.withColumn("v2", SF.lit(1))

self.assertTrue(cdf1._cached_schema is None)
# trigger analysis of cdf1.schema
self.assertEqual(cdf1.schema, sdf1.schema)
self.assertTrue(cdf1._cached_schema is not None)

cdf2 = cdf1.where(cdf1.v2 > 0)
sdf2 = sdf1.where(sdf1.v2 > 0)
self.assertEqual(cdf1._cached_schema, cdf2._cached_schema)

cdf3 = cdf2.repartition(10)
sdf3 = sdf2.repartition(10)
self.assertEqual(cdf1._cached_schema, cdf3._cached_schema)

cdf4 = cdf3.distinct()
sdf4 = sdf3.distinct()
self.assertEqual(cdf1._cached_schema, cdf4._cached_schema)

cdf5 = cdf4.sample(fraction=0.5)
sdf5 = sdf4.sample(fraction=0.5)
self.assertEqual(cdf1._cached_schema, cdf5._cached_schema)

self.assertEqual(cdf5.schema, sdf5.schema)


if __name__ == "__main__":
from pyspark.sql.tests.connect.test_connect_dataframe_property import * # noqa: F401
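For context, a small end-to-end sketch of the behaviour the new test exercises, assuming a reachable Spark Connect server (the `sc://localhost:15002` address is a placeholder): once one DataFrame in a chain has been analyzed, every schema-preserving descendant answers `.schema` from the propagated cache instead of issuing further Analyze round trips.

```python
from pyspark.sql import SparkSession

# Placeholder endpoint; any running Spark Connect server works.
spark = SparkSession.builder.remote("sc://localhost:15002").getOrCreate()

df = spark.createDataFrame([(1, 1.0), (2, 2.0), (1, 3.0), (2, 4.0)], ("id", "v1"))
df1 = df.withColumn("v2", df.v1 * 2)

_ = df1.schema  # one Analyze RPC; the result is cached on df1

# Each of these operations preserves the schema, and with this change also
# propagates the cached schema to the new DataFrame.
df2 = df1.where(df1.v2 > 0).repartition(10).distinct().sample(fraction=0.5)
assert df2._cached_schema == df1._cached_schema
```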