Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Added warning for unsupported Transforms
  • Loading branch information
vinjai committed Jul 4, 2024
commit 793c99f7adb4d6f87feadcb395b84e4f724a02a6
44 changes: 26 additions & 18 deletions pyiceberg/table/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2727,19 +2727,23 @@ def _dataframe_to_data_files(
if len(table_metadata.spec().fields) > 0:
partitions = _determine_partitions(spec=table_metadata.spec(), schema=table_metadata.schema(), arrow_table=df)

write_partitions = (
[
TablePartition(
partition_key=partition.partition_key,
arrow_table_partition=_sort_table_by_sort_order(
arrow_table=partition.arrow_table_partition, schema=table_metadata.schema(), sort_order=sort_order
),
)
for partition in partitions
]
if sort_order and not sort_order.is_unsorted
else partitions
)
if sort_order and not sort_order.is_unsorted:
try:
write_partitions = [
TablePartition(
partition_key=partition.partition_key,
arrow_table_partition=_sort_table_by_sort_order(
arrow_table=partition.arrow_table_partition, schema=table_metadata.schema(), sort_order=sort_order
),
)
for partition in partitions
]
except Exception as exc:
warnings.warn(f"Failed to sort table with error: {exc}")
sort_order = UNSORTED_SORT_ORDER
write_partitions = partitions
else:
write_partitions = partitions

yield from write_file(
io=io,
Expand All @@ -2758,11 +2762,15 @@ def _dataframe_to_data_files(
]),
)
else:
write_df = (
_sort_table_by_sort_order(arrow_table=df, schema=table_metadata.schema(), sort_order=sort_order)
if sort_order and not sort_order.is_unsorted
else df
)
if sort_order and not sort_order.is_unsorted:
try:
write_df = _sort_table_by_sort_order(arrow_table=df, schema=table_metadata.schema(), sort_order=sort_order)
except Exception as exc:
warnings.warn(f"Failed to sort table with error: {exc}")
sort_order = UNSORTED_SORT_ORDER
write_df = df
else:
write_df = df

yield from write_file(
io=io,
Expand Down
71 changes: 66 additions & 5 deletions tests/integration/test_writes/test_sorted_writes.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,32 +212,93 @@ def test_query_null_append_partitioned_multi_sort(
SortOrder(SortField(source_id=5, transform=BucketTransform(2))),
SortOrder(SortField(source_id=8, transform=BucketTransform(2))),
SortOrder(SortField(source_id=9, transform=BucketTransform(2))),
SortOrder(SortField(source_id=10, transform=BucketTransform(2))),
SortOrder(SortField(source_id=4, transform=TruncateTransform(2))),
SortOrder(SortField(source_id=5, transform=TruncateTransform(2))),
],
)
def test_invalid_sort_transform(
    session_catalog: Catalog, spark: SparkSession, arrow_table_with_null: pa.Table, sort_order: SortOrder
) -> None:
    """Appending to an unpartitioned table with an unsupported sort transform must
    warn and fall back to an unsorted write (sort-order-id 0) instead of failing.
    """
    import re

    # Build a table name unique to this parametrized sort order; strip every
    # character that is not valid in an identifier so the name is always legal.
    table_identifier = f"""default.arrow_table_invalid_sort_transform_{'_'.join([f"__{re.sub(r'[^A-Za-z0-9_]', '', str(field.transform))}_{field.source_id}_{field.direction}_{str(field.null_order)}__".replace(' ', '') for field in sort_order.fields])}"""

    tbl = _create_table(
        session_catalog=session_catalog,
        identifier=table_identifier,
        properties={"format-version": "1"},
        schema=TABLE_SCHEMA,
        sort_order=sort_order,
    )

    # The write path cannot apply these transforms via pyarrow, so the append
    # must emit a UserWarning and proceed unsorted rather than raise.
    with pytest.warns(
        UserWarning,
        match="Not all sort transforms are supported for writes. Following sort orders cannot be written using pyarrow: *",
    ):
        tbl.append(arrow_table_with_null)

    files_df = spark.sql(
        f"""
        SELECT *
        FROM {table_identifier}.files
        """
    )

    # Sort-order-id 0 is the reserved "unsorted" id in the Iceberg spec; every
    # written data file must have been demoted to it.
    assert [row.sort_order_id for row in files_df.select("sort_order_id").distinct().collect()] == [
        0
    ], "Expected Sort Order Id to be set as 0 (Unsorted) in the manifest file"


@pytest.mark.integration
@pytest.mark.parametrize(
    "sort_order",
    [
        SortOrder(*[
            SortField(source_id=1, transform=IdentityTransform()),
            SortField(source_id=4, transform=BucketTransform(2)),
        ]),
        SortOrder(SortField(source_id=5, transform=BucketTransform(2))),
        SortOrder(SortField(source_id=8, transform=BucketTransform(2))),
        SortOrder(SortField(source_id=9, transform=BucketTransform(2))),
        SortOrder(SortField(source_id=4, transform=TruncateTransform(2))),
        SortOrder(SortField(source_id=5, transform=TruncateTransform(2))),
    ],
)
def test_invalid_sort_transform_partitioned(
    session_catalog: Catalog, spark: SparkSession, arrow_table_with_null: pa.Table, sort_order: SortOrder
) -> None:
    """Appending to a partitioned table with an unsupported sort transform must
    warn and fall back to an unsorted write (sort-order-id 0) instead of failing.
    """
    import re

    # Build a table name unique to this parametrized sort order; strip every
    # character that is not valid in an identifier so the name is always legal.
    table_identifier = f"""default.arrow_table_invalid_sort_transform_partitioned_{'_'.join([f"__{re.sub(r'[^A-Za-z0-9_]', '', str(field.transform))}_{field.source_id}_{field.direction}_{str(field.null_order)}__".replace(' ', '') for field in sort_order.fields])}"""

    tbl = _create_table(
        session_catalog=session_catalog,
        identifier=table_identifier,
        properties={"format-version": "1"},
        schema=TABLE_SCHEMA,
        sort_order=sort_order,
        partition_spec=PartitionSpec(
            PartitionField(source_id=10, field_id=1001, transform=IdentityTransform(), name="identity_date")
        ),
    )

    # Previously this raised ValueError; the supported behaviour is now a
    # UserWarning followed by an unsorted write.
    with pytest.warns(
        UserWarning,
        match="Not all sort transforms are supported for writes. Following sort orders cannot be written using pyarrow: *",
    ):
        tbl.append(arrow_table_with_null)

    files_df = spark.sql(
        f"""
        SELECT *
        FROM {table_identifier}.files
        """
    )

    # Sort-order-id 0 is the reserved "unsorted" id in the Iceberg spec; every
    # written data file must have been demoted to it.
    assert [row.sort_order_id for row in files_df.select("sort_order_id").distinct().collect()] == [
        0
    ], "Expected Sort Order Id to be set as 0 (Unsorted) in the manifest file"


@pytest.mark.integration
@pytest.mark.parametrize(
Expand Down