todo: sort with pyarrow_transform vals
sungwy committed May 31, 2024
commit c30a57cfe93aaf979df949f801929ecf10079601
7 changes: 4 additions & 3 deletions pyiceberg/table/__init__.py
@@ -392,10 +392,11 @@ def append(self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT)
         if not isinstance(df, pa.Table):
             raise ValueError(f"Expected PyArrow table, got: {df}")
 
-        supported_transforms = {IdentityTransform}
-        if not all(type(field.transform) in supported_transforms for field in self.table_metadata.spec().fields):
+        if unsupported_partitions := [
+            field for field in self.table_metadata.spec().fields if not field.transform.supports_pyarrow_transform
+        ]:
             raise ValueError(
-                f"All transforms are not supported, expected: {supported_transforms}, but get: {[str(field) for field in self.table_metadata.spec().fields if field.transform not in supported_transforms]}."
+                f"Not all partition types are supported for writes. Following partitions cannot be written using pyarrow: {unsupported_partitions}."
             )
 
         _check_schema_compatible(self._table.schema(), other_schema=df.schema)
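The rewritten guard in `append` collects every partition field whose transform lacks a PyArrow implementation and reports them all in one error, instead of comparing against a hardcoded set of supported transform classes. A minimal sketch of the same walrus-operator pattern, using hypothetical stand-in classes rather than the real PyIceberg types:

```python
from dataclasses import dataclass


@dataclass
class StubTransform:
    supports_pyarrow_transform: bool


@dataclass
class StubPartitionField:
    name: str
    transform: StubTransform


fields = [
    StubPartitionField("id_identity", StubTransform(True)),
    StubPartitionField("ts_bucket", StubTransform(False)),
]

# `:=` binds the list of offending fields and tests its truthiness in one
# expression, so the message can name every unsupported partition at once.
if unsupported_partitions := [f.name for f in fields if not f.transform.supports_pyarrow_transform]:
    print(f"Following partitions cannot be written using pyarrow: {unsupported_partitions}")
```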
18 changes: 18 additions & 0 deletions pyiceberg/transforms.py
@@ -178,6 +178,10 @@ def __eq__(self, other: Any) -> bool:
             return self.root == other.root
         return False
 
+    @property
+    def supports_pyarrow_transform(self) -> bool:
+        return False
+
 
 class BucketTransform(Transform[S, int]):
     """Base Transform class to transform a value into a bucket partition value.
@@ -352,6 +356,13 @@ def dedup_name(self) -> str:
     def preserves_order(self) -> bool:
         return True
 
+    @abstractmethod
+    def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array], pa.Array]": ...
+
+    @property
+    def supports_pyarrow_transform(self) -> bool:
+        return True
+
 
 class YearTransform(TimeTransform[S]):
     """Transforms a datetime value into a year value.
@@ -652,6 +663,13 @@ def __repr__(self) -> str:
         """Return the string representation of the IdentityTransform class."""
         return "IdentityTransform()"
 
+    def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array], pa.Array]":
+        return lambda v: v
+
+    @property
+    def supports_pyarrow_transform(self) -> bool:
+        return True
+
 
 class TruncateTransform(Transform[S, S]):
     """A transform for truncating a value to a specified width.
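With these additions, `supports_pyarrow_transform` defaults to `False` on the base `Transform` and is overridden to `True` where a `pyarrow_transform` exists; `IdentityTransform`'s PyArrow path is simply a pass-through. A short usage sketch, assuming this branch of PyIceberg is installed (`LongType` is just an example source type):

```python
import pyarrow as pa

from pyiceberg.transforms import BucketTransform, IdentityTransform
from pyiceberg.types import LongType

arr = pa.array([1, 2, 2, None], type=pa.int64())

# The identity transform's PyArrow path is `lambda v: v`: a no-op pass-through.
transform_fn = IdentityTransform().pyarrow_transform(LongType())
assert transform_fn(arr).equals(arr)

# The new property is what Table.append consults before writing.
assert IdentityTransform().supports_pyarrow_transform
assert not BucketTransform(num_buckets=16).supports_pyarrow_transform
```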
31 changes: 26 additions & 5 deletions tests/integration/test_writes/test_partitioned_writes.py
@@ -17,10 +17,11 @@
 # pylint:disable=redefined-outer-name
 
 
+from typing import Any
+
 import pyarrow as pa
 import pytest
 from pyspark.sql import SparkSession
-from typing import Any
 
 from pyiceberg.catalog import Catalog
 from pyiceberg.exceptions import NoSuchTableError
@@ -390,13 +391,24 @@ def test_unsupported_transform(
 
 
 @pytest.mark.integration
-@pytest.mark.parametrize('transform', [YearTransform(), MonthTransform(), DayTransform()])
 @pytest.mark.parametrize(
-    "part_col", ["date", "timestamp", "timestamptz"]
+    "transform,expected_rows",
+    [
+        pytest.param(YearTransform(), 2, id="year_transform"),
+        pytest.param(MonthTransform(), 3, id="month_transform"),
+        pytest.param(DayTransform(), 3, id="day_transform"),
+    ],
 )
+@pytest.mark.parametrize("part_col", ["date", "timestamp", "timestamptz"])
 @pytest.mark.parametrize("format_version", [1, 2])
 def test_append_ymd_transform_partitioned(
-    session_catalog: Catalog, spark: SparkSession, arrow_table_with_null: pa.Table, transform: Transform[Any, Any], part_col: str, format_version: int
+    session_catalog: Catalog,
+    spark: SparkSession,
+    arrow_table_with_null: pa.Table,
+    transform: Transform[Any, Any],
+    expected_rows: int,
+    part_col: str,
+    format_version: int,
 ) -> None:
     # Given
     identifier = f"default.arrow_table_v{format_version}_with_ymd_transform_partitioned_on_col_{part_col}"
@@ -420,4 +432,13 @@ def test_append_ymd_transform_partitioned(
     assert df.count() == 3, f"Expected 3 total rows for {identifier}"
     for col in TEST_DATA_WITH_NULL.keys():
         assert df.where(f"{col} is not null").count() == 2, f"Expected 2 non-null rows for {col}"
-        assert df.where(f"{col} is null").count() == 1, f"Expected 1 null row for {col} is null"
+        assert df.where(f"{col} is null").count() == 1, f"Expected 1 null row for {col} is null"
+
+    assert tbl.inspect.partitions().num_rows == expected_rows
+    files_df = spark.sql(
+        f"""
+            SELECT *
+            FROM {identifier}.files
+        """
+    )
+    assert files_df.count() == expected_rows
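The two new assertions pin down the partition layout produced by the write: `tbl.inspect.partitions()` should contain one row per distinct partition value, and the table's `files` metadata table should match. With three input rows (one of them null), a year transform collapses the two non-null values into a single partition while month and day keep them apart. The dates below are hypothetical stand-ins for the fixture values, chosen only to match the parametrized `expected_rows`:

```python
from datetime import date

# Two non-null dates sharing a year but not a month, plus one null row
# (null partition values form their own partition).
rows = [date(2023, 1, 1), None, date(2023, 3, 1)]


def year_key(d): return d.year if d else None
def month_key(d): return (d.year, d.month) if d else None
def day_key(d): return (d.year, d.month, d.day) if d else None


assert len({year_key(d) for d in rows}) == 2   # year_transform: expected_rows=2
assert len({month_key(d) for d in rows}) == 3  # month_transform: expected_rows=3
assert len({day_key(d) for d in rows}) == 3    # day_transform: expected_rows=3
```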
28 changes: 9 additions & 19 deletions tests/test_transforms.py
@@ -1847,7 +1847,7 @@ def arrow_table_date_timestamps() -> "pa.Table":
     )
 
 
-@pytest.mark.parametrize('transform', [YearTransform(), MonthTransform(), DayTransform()])
+@pytest.mark.parametrize('transform', [YearTransform(), MonthTransform(), DayTransform(), HourTransform()])
 @pytest.mark.parametrize(
     "source_col, source_type", [("date", DateType()), ("timestamp", TimestampType()), ("timestamptz", TimestamptzType())]
 )
@@ -1857,21 +1857,11 @@ def test_ymd_pyarrow_transforms(
     source_type: PrimitiveType,
     transform: Transform[Any, Any],
 ) -> None:
-    assert transform.pyarrow_transform(source_type)(arrow_table_date_timestamps[source_col]).to_pylist() == [
-        transform.transform(source_type)(_to_partition_representation(source_type, v))
-        for v in arrow_table_date_timestamps[source_col].to_pylist()
-    ]
-
-
-@pytest.mark.parametrize("source_col, source_type", [("timestamp", TimestampType()), ("timestamptz", TimestamptzType())])
-def test_hour_pyarrow_transforms(arrow_table_date_timestamps: "pa.Table", source_col: str, source_type: PrimitiveType) -> None:
-    assert HourTransform().pyarrow_transform(source_type)(arrow_table_date_timestamps[source_col]).to_pylist() == [
-        HourTransform().transform(source_type)(_to_partition_representation(source_type, v))
-        for v in arrow_table_date_timestamps[source_col].to_pylist()
-    ]
-
-
-def test_hour_pyarrow_transforms_throws_with_dates(arrow_table_date_timestamps: "pa.Table") -> None:
-    # HourTransform is not supported for DateType
-    with pytest.raises(ValueError):
-        HourTransform().pyarrow_transform(DateType())(arrow_table_date_timestamps["date"])
+    if transform.can_transform(source_type):
+        assert transform.pyarrow_transform(source_type)(arrow_table_date_timestamps[source_col]).to_pylist() == [
+            transform.transform(source_type)(_to_partition_representation(source_type, v))
+            for v in arrow_table_date_timestamps[source_col].to_pylist()
+        ]
+    else:
+        with pytest.raises(ValueError):
+            transform.pyarrow_transform(DateType())(arrow_table_date_timestamps[source_col])
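Folding the hour cases into the parametrized test works because the branch on `Transform.can_transform` separates valid pairings from invalid ones: an hour value can be derived from a timestamp but not from a plain date. A quick illustration, assuming a PyIceberg build that exposes these transforms and types:

```python
from pyiceberg.transforms import DayTransform, HourTransform
from pyiceberg.types import DateType, TimestampType

# Day granularity applies to both dates and timestamps...
assert DayTransform().can_transform(DateType())
assert DayTransform().can_transform(TimestampType())

# ...but an hour cannot be extracted from a date, so for that pairing the
# test's `else` branch expects pyarrow_transform to raise ValueError.
assert HourTransform().can_transform(TimestampType())
assert not HourTransform().can_transform(DateType())
```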