Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
add merge_append
  • Loading branch information
HonahX committed Jun 3, 2024
commit cbb8cecee9a226a5b6568e36316f13d24e9acc3c
49 changes: 49 additions & 0 deletions pyiceberg/table/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -428,6 +428,44 @@ def append(self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT)
for data_file in data_files:
update_snapshot.append_data_file(data_file)

def merge_append(self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT) -> None:
    """
    Shorthand API for appending a PyArrow table to a table transaction.

    Unlike `append`, the resulting commit may merge existing small manifests
    into larger ones (subject to the table's manifest-merge properties).

    Args:
        df: The Arrow dataframe that will be appended to the table
        snapshot_properties: Custom properties to be added to the snapshot summary

    Raises:
        ModuleNotFoundError: If PyArrow is not installed.
        ValueError: If `df` is not a PyArrow table, or the table has partition
            transforms that cannot be written with PyArrow.
    """
    try:
        import pyarrow as pa
    except ModuleNotFoundError as e:
        raise ModuleNotFoundError("For writes PyArrow needs to be installed") from e

    if not isinstance(df, pa.Table):
        raise ValueError(f"Expected PyArrow table, got: {df}")

    # All partition transforms must be expressible as PyArrow operations to write.
    if unsupported_partitions := [
        field for field in self.table_metadata.spec().fields if not field.transform.supports_pyarrow_transform
    ]:
        raise ValueError(
            f"Not all partition types are supported for writes. Following partitions cannot be written using pyarrow: {unsupported_partitions}."
        )

    _check_schema_compatible(self._table.schema(), other_schema=df.schema)
    # cast if the two schemas are compatible but not equal
    table_arrow_schema = self._table.schema().as_arrow()
    if table_arrow_schema != df.schema:
        df = df.cast(table_arrow_schema)

    with self.update_snapshot(snapshot_properties=snapshot_properties).merge_append() as update_snapshot:
        # skip writing data files if the dataframe is empty
        if df.shape[0] > 0:
            data_files = _dataframe_to_data_files(
                table_metadata=self._table.metadata, write_uuid=update_snapshot.commit_uuid, df=df, io=self._table.io
            )
            for data_file in data_files:
                update_snapshot.append_data_file(data_file)

def overwrite(
self, df: pa.Table, overwrite_filter: BooleanExpression = ALWAYS_TRUE, snapshot_properties: Dict[str, str] = EMPTY_DICT
) -> None:
Expand Down Expand Up @@ -1352,6 +1390,17 @@ def append(self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT)
with self.transaction() as tx:
tx.append(df=df, snapshot_properties=snapshot_properties)

def merge_append(self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT) -> None:
    """
    Shorthand API for appending a PyArrow table to the table.

    Opens a transaction and delegates to `Transaction.merge_append`, which may
    merge existing small manifests as part of the commit.

    Args:
        df: The Arrow dataframe that will be appended to the table
        snapshot_properties: Custom properties to be added to the snapshot summary
    """
    with self.transaction() as tx:
        tx.merge_append(df=df, snapshot_properties=snapshot_properties)

def overwrite(
self, df: pa.Table, overwrite_filter: BooleanExpression = ALWAYS_TRUE, snapshot_properties: Dict[str, str] = EMPTY_DICT
) -> None:
Expand Down
20 changes: 10 additions & 10 deletions tests/integration/test_writes/test_writes.py
Original file line number Diff line number Diff line change
Expand Up @@ -876,7 +876,7 @@ def table_write_subset_of_schema(session_catalog: Catalog, arrow_table_with_null
@pytest.mark.integration
@pytest.mark.parametrize("format_version", [1, 2])
def test_merge_manifest_min_count_to_merge(
spark: SparkSession, session_catalog: Catalog, arrow_table_with_null: pa.Table, format_version: int
session_catalog: Catalog, arrow_table_with_null: pa.Table, format_version: int
) -> None:
tbl_a = _create_table(
session_catalog,
Expand All @@ -898,19 +898,19 @@ def test_merge_manifest_min_count_to_merge(
)

# tbl_a should merge all manifests into 1
tbl_a.append(arrow_table_with_null)
tbl_a.append(arrow_table_with_null)
tbl_a.append(arrow_table_with_null)
tbl_a.merge_append(arrow_table_with_null)
tbl_a.merge_append(arrow_table_with_null)
tbl_a.merge_append(arrow_table_with_null)

# tbl_b should not merge any manifests because the target size is too small
tbl_b.append(arrow_table_with_null)
tbl_b.append(arrow_table_with_null)
tbl_b.append(arrow_table_with_null)
tbl_b.merge_append(arrow_table_with_null)
tbl_b.merge_append(arrow_table_with_null)
tbl_b.merge_append(arrow_table_with_null)

# tbl_c should not merge any manifests because merging is disabled
tbl_c.append(arrow_table_with_null)
tbl_c.append(arrow_table_with_null)
tbl_c.append(arrow_table_with_null)
tbl_c.merge_append(arrow_table_with_null)
tbl_c.merge_append(arrow_table_with_null)
tbl_c.merge_append(arrow_table_with_null)

assert len(tbl_a.current_snapshot().manifests(tbl_a.io)) == 1 # type: ignore
assert len(tbl_b.current_snapshot().manifests(tbl_b.io)) == 3 # type: ignore
Expand Down