diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py index 2e26a4ccc2..88a7bd00c2 100644 --- a/pyiceberg/table/__init__.py +++ b/pyiceberg/table/__init__.py @@ -734,6 +734,7 @@ def upsert( when_not_matched_insert_all: bool = True, case_sensitive: bool = True, branch: str | None = MAIN_BRANCH, + snapshot_properties: dict[str, str] = EMPTY_DICT, ) -> UpsertResult: """Shorthand API for performing an upsert to an iceberg table. @@ -745,6 +746,7 @@ def upsert( when_not_matched_insert_all: Bool indicating new rows to be inserted that do not match any existing rows in the table case_sensitive: Bool indicating if the match should be case-sensitive branch: Branch Reference to run the upsert operation + snapshot_properties: Custom properties to be added to the snapshot summary To learn more about the identifier-field-ids: https://iceberg.apache.org/spec/#identifier-field-ids @@ -861,12 +863,13 @@ def upsert( rows_to_update, overwrite_filter=Or(*overwrite_predicates) if len(overwrite_predicates) > 1 else overwrite_predicates[0], branch=branch, + snapshot_properties=snapshot_properties, ) if when_not_matched_insert_all: insert_row_cnt = len(rows_to_insert) if rows_to_insert: - self.append(rows_to_insert, branch=branch) + self.append(rows_to_insert, branch=branch, snapshot_properties=snapshot_properties) return UpsertResult(rows_updated=update_row_cnt, rows_inserted=insert_row_cnt) @@ -1327,6 +1330,7 @@ def upsert( when_not_matched_insert_all: bool = True, case_sensitive: bool = True, branch: str | None = MAIN_BRANCH, + snapshot_properties: dict[str, str] = EMPTY_DICT, ) -> UpsertResult: """Shorthand API for performing an upsert to an iceberg table. @@ -1338,6 +1342,7 @@ def upsert( when_not_matched_insert_all: Bool indicating new rows to be inserted that do not match any existing rows in the table case_sensitive: Bool indicating if the match should be case-sensitive branch: Branch Reference to run the upsert operation + snapshot_properties: Custom properties to be added to the snapshot summary To learn more about the identifier-field-ids: https://iceberg.apache.org/spec/#identifier-field-ids @@ -1371,6 +1376,7 @@ def upsert( when_not_matched_insert_all=when_not_matched_insert_all, case_sensitive=case_sensitive, branch=branch, + snapshot_properties=snapshot_properties, ) def append(self, df: pa.Table, snapshot_properties: dict[str, str] = EMPTY_DICT, branch: str | None = MAIN_BRANCH) -> None: diff --git a/tests/table/test_upsert.py b/tests/table/test_upsert.py index 891d4bbac7..35a3a11926 100644 --- a/tests/table/test_upsert.py +++ b/tests/table/test_upsert.py @@ -834,3 +834,54 @@ def test_stage_only_upsert(catalog: Catalog) -> None: assert operations == ["append", "append", "append"] # both subsequent parent id should be the first snapshot id assert parent_snapshot_id == [None, current_snapshot, current_snapshot] + + +def test_upsert_snapshot_properties(catalog: Catalog) -> None: + """Test that snapshot_properties are applied to snapshots created by upsert.""" + identifier = "default.test_upsert_snapshot_properties" + _drop_table(catalog, identifier) + + schema = Schema( + NestedField(1, "city", StringType(), required=True), + NestedField(2, "population", IntegerType(), required=True), + identifier_field_ids=[1], + ) + + tbl = catalog.create_table(identifier, schema=schema) + arrow_schema = pa.schema( + [ + pa.field("city", pa.string(), nullable=False), + pa.field("population", pa.int32(), nullable=False), + ] + ) + + # Initial data + df = pa.Table.from_pylist( + [{"city": "Amsterdam", "population": 921402}], + schema=arrow_schema, + ) + tbl.append(df) + initial_snapshot_count = len(list(tbl.snapshots())) + + # Upsert with snapshot_properties (both update and insert) + df = pa.Table.from_pylist( + [ + {"city": "Amsterdam", "population": 950000}, # Update + {"city": "Berlin", "population": 3432000}, # Insert + ], + schema=arrow_schema, + ) + result = tbl.upsert(df, snapshot_properties={"test_prop": "test_value"}) + + assert result.rows_updated == 1 + assert result.rows_inserted == 1 + + # Verify properties are on the snapshots created by upsert + snapshots = list(tbl.snapshots()) + # Upsert should have created additional snapshots (overwrite + append) + assert len(snapshots) > initial_snapshot_count + + # Check that all new snapshots have the snapshot_properties + for snapshot in snapshots[initial_snapshot_count:]: + assert snapshot.summary is not None + assert snapshot.summary.additional_properties.get("test_prop") == "test_value"