support time travel
sungwy committed Apr 13, 2024
commit 9a26b7b0d0030819c2e996b6b306f2c211eecaeb
22 changes: 11 additions & 11 deletions pyiceberg/table/__init__.py
@@ -3423,7 +3423,7 @@ def _readable_metrics_struct(bound_type: PrimitiveType) -> pa.StructType:
             schema=entries_schema,
         )

-    def partitions(self) -> "pa.Table":
+    def partitions(self, snapshot_id: Optional[int] = None) -> "pa.Table":
         import pyarrow as pa

         from pyiceberg.io.pyarrow import schema_to_pyarrow
@@ -3495,16 +3495,16 @@ def update_partitions_map(
                 raise ValueError(f"Unknown DataFileContent ({file.content})")

         partitions_map: Dict[Tuple[str, Any], Any] = {}
-        if snapshot := self.tbl.metadata.current_snapshot():
-            for manifest in snapshot.manifests(self.tbl.io):
-                for entry in manifest.fetch_manifest_entry(io=self.tbl.io):
-                    partition = entry.data_file.partition
-                    partition_record_dict = {
-                        field.name: partition[pos]
-                        for pos, field in enumerate(self.tbl.metadata.specs()[manifest.partition_spec_id].fields)
-                    }
-                    entry_snapshot = self.tbl.snapshot_by_id(entry.snapshot_id) if entry.snapshot_id is not None else None
-                    update_partitions_map(partitions_map, entry.data_file, partition_record_dict, entry_snapshot)
+        snapshot = self._get_snapshot(snapshot_id)
+        for manifest in snapshot.manifests(self.tbl.io):
+            for entry in manifest.fetch_manifest_entry(io=self.tbl.io):
+                partition = entry.data_file.partition
+                partition_record_dict = {
+                    field.name: partition[pos]
+                    for pos, field in enumerate(self.tbl.metadata.specs()[manifest.partition_spec_id].fields)
+                }
+                entry_snapshot = self.tbl.snapshot_by_id(entry.snapshot_id) if entry.snapshot_id is not None else None
+                update_partitions_map(partitions_map, entry.data_file, partition_record_dict, entry_snapshot)

         return pa.Table.from_pylist(
             partitions_map.values(),
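The new code path resolves its snapshot through `self._get_snapshot(snapshot_id)`, a helper this hunk does not show. A minimal sketch of what the call site implies, assuming the helper falls back to the current snapshot when `snapshot_id` is omitted and raises for IDs it cannot resolve; it reuses only the `self.tbl.snapshot_by_id` and `self.tbl.metadata.current_snapshot` calls already visible in the diff, and the error messages are hypothetical:

    def _get_snapshot(self, snapshot_id: Optional[int] = None) -> Snapshot:
        # An explicit snapshot ID must resolve against the table, or we fail loudly.
        if snapshot_id is not None:
            if snapshot := self.tbl.snapshot_by_id(snapshot_id):
                return snapshot
            raise ValueError(f"Cannot find snapshot with ID {snapshot_id}")
        # No ID given: keep the pre-change default of the current snapshot.
        if snapshot := self.tbl.metadata.current_snapshot():
            return snapshot
        raise ValueError("Table has no snapshots")

One behavioral nuance under this sketch: the old code silently skipped the loop (returning an empty table) when the table had no current snapshot, whereas a raising helper would surface that case as an error.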
24 changes: 13 additions & 11 deletions tests/integration/test_inspect_table.py
@@ -376,15 +376,17 @@ def test_inspect_partitions_partitioned(spark: SparkSession, session_catalog: Catalog) -> None:
"""
)

df = session_catalog.load_table(identifier).inspect.partitions()

lhs = df.to_pandas()
rhs = spark.table(f"{identifier}.partitions").toPandas()
def check_pyiceberg_df_equals_spark_df(df: pa.Table, spark_df: DataFrame) -> None:
lhs = df.to_pandas().sort_values('spec_id')
rhs = spark_df.toPandas().sort_values('spec_id')
for column in df.column_names:
for left, right in zip(lhs[column].to_list(), rhs[column].to_list()):
if column == "partition":
right = right.asDict()
assert left == right, f"Difference in column {column}: {left} != {right}"

lhs.sort_values('spec_id', inplace=True)
rhs.sort_values('spec_id', inplace=True)
for column in df.column_names:
for left, right in zip(lhs[column].to_list(), rhs[column].to_list()):
if column == "partition":
right = right.asDict()
assert left == right, f"Difference in column {column}: {left} != {right}"
tbl = session_catalog.load_table(identifier)
for snapshot in tbl.metadata.snapshots:
df = tbl.inspect.partitions(snapshot_id=snapshot.snapshot_id)
spark_df = spark.sql(f"SELECT * FROM {identifier}.partitions VERSION AS OF {snapshot.snapshot_id}")
check_pyiceberg_df_equals_spark_df(df, spark_df)
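Taken together, the two diffs let `inspect.partitions` time-travel the same way the Spark test's `VERSION AS OF` query does. A minimal usage sketch (the catalog name and table identifier are hypothetical; assumes `tbl.metadata.snapshots` is ordered oldest first):

    from pyiceberg.catalog import load_catalog

    catalog = load_catalog("default")        # hypothetical catalog name
    tbl = catalog.load_table("db.my_table")  # hypothetical identifier

    # Default behavior, unchanged: partitions of the current snapshot.
    current_partitions = tbl.inspect.partitions()

    # Time travel: partitions as of the table's oldest snapshot.
    oldest = tbl.metadata.snapshots[0]
    past_partitions = tbl.inspect.partitions(snapshot_id=oldest.snapshot_id)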