adopt review feedback: more tests, refactoring, stricter checks
sungwy committed Jul 11, 2024
commit ce643f670f04e7ee4db0363cebdc5457e3323e56
51 changes: 30 additions & 21 deletions pyiceberg/io/pyarrow.py
@@ -1297,40 +1297,49 @@ def to_requested_schema(


class ArrowProjectionVisitor(SchemaWithPartnerVisitor[pa.Array, Optional[pa.Array]]):
file_schema: Schema
_file_schema: Schema
_include_field_ids: bool
_downcast_ns_timestamp_to_us: bool

def __init__(self, file_schema: Schema, downcast_ns_timestamp_to_us: bool = False, include_field_ids: bool = False) -> None:
self.file_schema = file_schema
self._file_schema = file_schema
self._include_field_ids = include_field_ids
self.downcast_ns_timestamp_to_us = downcast_ns_timestamp_to_us
self._downcast_ns_timestamp_to_us = downcast_ns_timestamp_to_us

def _cast_if_needed(self, field: NestedField, values: pa.Array) -> pa.Array:
file_field = self.file_schema.find_field(field.field_id)
file_field = self._file_schema.find_field(field.field_id)

if field.field_type.is_primitive:
if field.field_type != file_field.field_type:
return values.cast(
schema_to_pyarrow(promote(file_field.field_type, field.field_type), include_field_ids=self._include_field_ids)
)
elif (target_type := schema_to_pyarrow(field.field_type, include_field_ids=self._include_field_ids)) != values.type:
# Downcasting of nanoseconds to microseconds
if (
pa.types.is_timestamp(target_type)
and target_type.unit == "us"
and pa.types.is_timestamp(values.type)
and values.type.unit == "ns"
):
if (target_type.tz == "UTC" and values.type.tz in UTC_ALIASES) or (not target_type.tz and not values.type.tz):
return values.cast(target_type, safe=False)
if (
pa.types.is_timestamp(target_type)
and target_type.unit == "us"
and pa.types.is_timestamp(values.type)
and values.type.unit in {"s", "ms", "us"}
):
if (target_type.tz == "UTC" and values.type.tz in UTC_ALIASES) or (not target_type.tz and not values.type.tz):
return values.cast(target_type)
if field.field_type == TimestampType():
Review comment (sungwy, Collaborator Author): stricter and clearer field_type-driven checks (see the sketch after this diff).

# Downcasting of nanoseconds to microseconds
if (
pa.types.is_timestamp(target_type)
and not target_type.tz
and pa.types.is_timestamp(values.type)
and not values.type.tz
):
if target_type.unit == "us" and values.type.unit == "ns" and self._downcast_ns_timestamp_to_us:
return values.cast(target_type, safe=False)
elif target_type.unit == "us" and values.type.unit in {"s", "ms"}:
return values.cast(target_type)
raise ValueError(f"Unsupported schema projection from {values.type} to {target_type}")
elif field.field_type == TimestamptzType():
if (
pa.types.is_timestamp(target_type)
and target_type.tz == "UTC"
and pa.types.is_timestamp(values.type)
and values.type.tz in UTC_ALIASES
):
if target_type.unit == "us" and values.type.unit == "ns" and self._downcast_ns_timestamp_to_us:
return values.cast(target_type, safe=False)
elif target_type.unit == "us" and values.type.unit in {"s", "ms", "us"}:
return values.cast(target_type)
raise ValueError(f"Unsupported schema projection from {values.type} to {target_type}")
return values

def _construct_field(self, field: NestedField, arrow_type: pa.DataType) -> pa.Field:
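For context on the timestamp branches above, here is a minimal sketch of the PyArrow cast behaviour the visitor relies on. It assumes only that pyarrow is installed; the values and variable names are illustrative and not taken from this PR.

import pyarrow as pa

# One value with a non-zero nanosecond component (interpreted as ns since the epoch).
ns_array = pa.array([1_690_000_000_123_456_789], type=pa.timestamp("ns"))

# A safe cast (the default) refuses to drop sub-microsecond precision.
try:
    ns_array.cast(pa.timestamp("us"))
except pa.ArrowInvalid as exc:
    print(f"safe cast rejected: {exc}")

# With safe=False the nanoseconds are truncated; this is the path taken when
# _downcast_ns_timestamp_to_us is enabled.
print(ns_array.cast(pa.timestamp("us"), safe=False))

# Upcasting coarser units (s, ms) to microseconds is lossless, so no safe=False is needed.
seconds_array = pa.array([1_690_000_000], type=pa.timestamp("s"))
print(seconds_array.cast(pa.timestamp("us")))
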
116 changes: 114 additions & 2 deletions tests/conftest.py
@@ -2382,10 +2382,122 @@ def arrow_table_date_timestamps() -> "pa.Table":


@pytest.fixture(scope="session")
def arrow_table_date_timestamps_schema() -> Schema:
"""Pyarrow table Schema with only date, timestamp and timestamptz values."""
def table_date_timestamps_schema() -> Schema:
"""Iceberg table Schema with only date, timestamp and timestamptz values."""
return Schema(
NestedField(field_id=1, name="date", field_type=DateType(), required=False),
NestedField(field_id=2, name="timestamp", field_type=TimestampType(), required=False),
NestedField(field_id=3, name="timestamptz", field_type=TimestamptzType(), required=False),
)


@pytest.fixture(scope="session")
def arrow_table_schema_with_all_timestamp_precisions() -> "pa.Schema":
"""Pyarrow Schema with all supported timestamp types."""
import pyarrow as pa

return pa.schema([
("timestamp_s", pa.timestamp(unit="s")),
("timestamptz_s", pa.timestamp(unit="s", tz="UTC")),
("timestamp_ms", pa.timestamp(unit="ms")),
("timestamptz_ms", pa.timestamp(unit="ms", tz="UTC")),
("timestamp_us", pa.timestamp(unit="us")),
("timestamptz_us", pa.timestamp(unit="us", tz="UTC")),
("timestamp_ns", pa.timestamp(unit="ns")),
("timestamptz_ns", pa.timestamp(unit="ns", tz="UTC")),
("timestamptz_us_etc_utc", pa.timestamp(unit="us", tz="Etc/UTC")),
("timestamptz_ns_z", pa.timestamp(unit="ns", tz="Z")),
("timestamptz_s_0000", pa.timestamp(unit="s", tz="+00:00")),
])


@pytest.fixture(scope="session")
def arrow_table_with_all_timestamp_precisions(arrow_table_schema_with_all_timestamp_precisions: "pa.Schema") -> "pa.Table":
"""Pyarrow table with all supported timestamp types."""
import pandas as pd
import pyarrow as pa

test_data = pd.DataFrame({
"timestamp_s": [datetime(2023, 1, 1, 19, 25, 00), None, datetime(2023, 3, 1, 19, 25, 00)],
"timestamptz_s": [
datetime(2023, 1, 1, 19, 25, 00, tzinfo=timezone.utc),
None,
datetime(2023, 3, 1, 19, 25, 00, tzinfo=timezone.utc),
],
"timestamp_ms": [datetime(2023, 1, 1, 19, 25, 00), None, datetime(2023, 3, 1, 19, 25, 00)],
"timestamptz_ms": [
datetime(2023, 1, 1, 19, 25, 00, tzinfo=timezone.utc),
None,
datetime(2023, 3, 1, 19, 25, 00, tzinfo=timezone.utc),
],
"timestamp_us": [datetime(2023, 1, 1, 19, 25, 00), None, datetime(2023, 3, 1, 19, 25, 00)],
"timestamptz_us": [
datetime(2023, 1, 1, 19, 25, 00, tzinfo=timezone.utc),
None,
datetime(2023, 3, 1, 19, 25, 00, tzinfo=timezone.utc),
],
"timestamp_ns": [
pd.Timestamp(year=2024, month=7, day=11, hour=3, minute=30, second=0, microsecond=12, nanosecond=6),
None,
pd.Timestamp(year=2024, month=7, day=11, hour=3, minute=30, second=0, microsecond=12, nanosecond=7),
],
"timestamptz_ns": [
datetime(2023, 1, 1, 19, 25, 00, tzinfo=timezone.utc),
None,
datetime(2023, 3, 1, 19, 25, 00, tzinfo=timezone.utc),
],
"timestamptz_us_etc_utc": [
datetime(2023, 1, 1, 19, 25, 00, tzinfo=timezone.utc),
None,
datetime(2023, 3, 1, 19, 25, 00, tzinfo=timezone.utc),
],
"timestamptz_ns_z": [
pd.Timestamp(year=2024, month=7, day=11, hour=3, minute=30, second=0, microsecond=12, nanosecond=6, tz="UTC"),
None,
pd.Timestamp(year=2024, month=7, day=11, hour=3, minute=30, second=0, microsecond=12, nanosecond=7, tz="UTC"),
],
"timestamptz_s_0000": [
datetime(2023, 1, 1, 19, 25, 1, tzinfo=timezone.utc),
None,
datetime(2023, 3, 1, 19, 25, 1, tzinfo=timezone.utc),
],
})
return pa.Table.from_pandas(test_data, schema=arrow_table_schema_with_all_timestamp_precisions)


@pytest.fixture(scope="session")
def arrow_table_schema_with_all_microseconds_timestamp_precisions() -> "pa.Schema":
"""Pyarrow Schema with all microseconds timestamp."""
import pyarrow as pa

return pa.schema([
("timestamp_s", pa.timestamp(unit="us")),
("timestamptz_s", pa.timestamp(unit="us", tz="UTC")),
("timestamp_ms", pa.timestamp(unit="us")),
("timestamptz_ms", pa.timestamp(unit="us", tz="UTC")),
("timestamp_us", pa.timestamp(unit="us")),
("timestamptz_us", pa.timestamp(unit="us", tz="UTC")),
("timestamp_ns", pa.timestamp(unit="us")),
("timestamptz_ns", pa.timestamp(unit="us", tz="UTC")),
("timestamptz_us_etc_utc", pa.timestamp(unit="us", tz="UTC")),
("timestamptz_ns_z", pa.timestamp(unit="us", tz="UTC")),
("timestamptz_s_0000", pa.timestamp(unit="us", tz="UTC")),
])


@pytest.fixture(scope="session")
def table_schema_with_all_microseconds_timestamp_precision() -> Schema:
"""Iceberg table Schema with only date, timestamp and timestamptz values."""
return Schema(
NestedField(field_id=1, name="timestamp_s", field_type=TimestampType(), required=False),
NestedField(field_id=2, name="timestamptz_s", field_type=TimestamptzType(), required=False),
NestedField(field_id=3, name="timestamp_ms", field_type=TimestampType(), required=False),
NestedField(field_id=4, name="timestamptz_ms", field_type=TimestamptzType(), required=False),
NestedField(field_id=5, name="timestamp_us", field_type=TimestampType(), required=False),
NestedField(field_id=6, name="timestamptz_us", field_type=TimestamptzType(), required=False),
NestedField(field_id=7, name="timestamp_ns", field_type=TimestampType(), required=False),
NestedField(field_id=8, name="timestamptz_ns", field_type=TimestamptzType(), required=False),
NestedField(field_id=9, name="timestamptz_us_etc_utc", field_type=TimestamptzType(), required=False),
NestedField(field_id=10, name="timestamptz_ns_z", field_type=TimestamptzType(), required=False),
NestedField(field_id=11, name="timestamptz_s_0000", field_type=TimestamptzType(), required=False),
)
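
The new fixtures above are designed to pair up: arrow_table_with_all_timestamp_precisions supplies input data covering every unit and UTC alias, while the *_all_microseconds_* fixtures describe the form that data is expected to take after a write/read round trip. A hypothetical usage sketch follows (only the fixture names come from the PR; the test body is illustrative):

import pyarrow as pa

def test_mixed_precisions_normalise_to_microseconds(
    arrow_table_with_all_timestamp_precisions: pa.Table,
    arrow_table_schema_with_all_microseconds_timestamp_precisions: pa.Schema,
) -> None:
    # Casting the mixed-precision input to the all-microseconds schema models
    # what a scan is expected to return; safe=False permits truncating the
    # nanosecond columns.
    expected = arrow_table_with_all_timestamp_precisions.cast(
        arrow_table_schema_with_all_microseconds_timestamp_precisions, safe=False
    )
    assert expected.schema == arrow_table_schema_with_all_microseconds_timestamp_precisions
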
1 change: 1 addition & 0 deletions tests/integration/test_add_files.py
@@ -518,6 +518,7 @@ def test_add_files_with_large_and_regular_schema(spark: SparkSession, session_ca
assert table_schema == arrow_schema_large


@pytest.mark.integration
def test_timestamp_tz_ns_downcast_on_read(session_catalog: Catalog, format_version: int, mocker: MockerFixture) -> None:
nanoseconds_schema_iceberg = Schema(NestedField(1, "quux", TimestamptzType()))

14 changes: 7 additions & 7 deletions tests/integration/test_writes/test_partitioned_writes.py
@@ -461,15 +461,15 @@ def test_append_transform_partition_verify_partitions_count(
session_catalog: Catalog,
spark: SparkSession,
arrow_table_date_timestamps: pa.Table,
arrow_table_date_timestamps_schema: Schema,
table_date_timestamps_schema: Schema,
transform: Transform[Any, Any],
expected_partitions: Set[Any],
format_version: int,
) -> None:
# Given
part_col = "timestamptz"
identifier = f"default.arrow_table_v{format_version}_with_{str(transform)}_transform_partitioned_on_col_{part_col}"
nested_field = arrow_table_date_timestamps_schema.find_field(part_col)
nested_field = table_date_timestamps_schema.find_field(part_col)
partition_spec = PartitionSpec(
PartitionField(source_id=nested_field.field_id, field_id=1001, transform=transform, name=part_col),
)
@@ -481,7 +481,7 @@ def test_append_transform_partition_verify_partitions_count(
properties={"format-version": str(format_version)},
data=[arrow_table_date_timestamps],
partition_spec=partition_spec,
schema=arrow_table_date_timestamps_schema,
schema=table_date_timestamps_schema,
)

# Then
@@ -510,20 +510,20 @@ def test_append_multiple_partitions(
session_catalog: Catalog,
spark: SparkSession,
arrow_table_date_timestamps: pa.Table,
arrow_table_date_timestamps_schema: Schema,
table_date_timestamps_schema: Schema,
format_version: int,
) -> None:
# Given
identifier = f"default.arrow_table_v{format_version}_with_multiple_partitions"
partition_spec = PartitionSpec(
PartitionField(
source_id=arrow_table_date_timestamps_schema.find_field("date").field_id,
source_id=table_date_timestamps_schema.find_field("date").field_id,
field_id=1001,
transform=YearTransform(),
name="date_year",
),
PartitionField(
source_id=arrow_table_date_timestamps_schema.find_field("timestamptz").field_id,
source_id=table_date_timestamps_schema.find_field("timestamptz").field_id,
field_id=1000,
transform=HourTransform(),
name="timestamptz_hour",
@@ -537,7 +537,7 @@ def test_append_multiple_partitions(
properties={"format-version": str(format_version)},
data=[arrow_table_date_timestamps],
partition_spec=partition_spec,
schema=arrow_table_date_timestamps_schema,
schema=table_date_timestamps_schema,
)

# Then
85 changes: 14 additions & 71 deletions tests/integration/test_writes/test_writes.py
@@ -18,7 +18,7 @@
import math
import os
import time
from datetime import date, datetime, timezone
from datetime import date, datetime
from pathlib import Path
from typing import Any, Dict
from urllib.parse import urlparse
@@ -979,88 +979,31 @@ def table_write_subset_of_schema(session_catalog: Catalog, arrow_table_with_null
@pytest.mark.integration
@pytest.mark.parametrize("format_version", [1, 2])
def test_write_all_timestamp_precision(
mocker: MockerFixture, spark: SparkSession, session_catalog: Catalog, format_version: int
mocker: MockerFixture,
spark: SparkSession,
session_catalog: Catalog,
format_version: int,
arrow_table_schema_with_all_timestamp_precisions: pa.Schema,
arrow_table_with_all_timestamp_precisions: pa.Table,
arrow_table_schema_with_all_microseconds_timestamp_precisions: pa.Schema,
) -> None:
identifier = "default.table_all_timestamp_precision"
arrow_table_schema_with_all_timestamp_precisions = pa.schema([
("timestamp_s", pa.timestamp(unit="s")),
("timestamptz_s", pa.timestamp(unit="s", tz="UTC")),
("timestamp_ms", pa.timestamp(unit="ms")),
("timestamptz_ms", pa.timestamp(unit="ms", tz="UTC")),
("timestamp_us", pa.timestamp(unit="us")),
("timestamptz_us", pa.timestamp(unit="us", tz="UTC")),
("timestamp_ns", pa.timestamp(unit="ns")),
("timestamptz_ns", pa.timestamp(unit="ns", tz="UTC")),
("timestamptz_us_etc_utc", pa.timestamp(unit="us", tz="Etc/UTC")),
("timestamptz_us_z", pa.timestamp(unit="us", tz="Z")),
])
TEST_DATA_WITH_NULL = pd.DataFrame({
"timestamp_s": [datetime(2023, 1, 1, 19, 25, 00), None, datetime(2023, 3, 1, 19, 25, 00)],
"timestamptz_s": [
datetime(2023, 1, 1, 19, 25, 00, tzinfo=timezone.utc),
None,
datetime(2023, 3, 1, 19, 25, 00, tzinfo=timezone.utc),
],
"timestamp_ms": [datetime(2023, 1, 1, 19, 25, 00), None, datetime(2023, 3, 1, 19, 25, 00)],
"timestamptz_ms": [
datetime(2023, 1, 1, 19, 25, 00, tzinfo=timezone.utc),
None,
datetime(2023, 3, 1, 19, 25, 00, tzinfo=timezone.utc),
],
"timestamp_us": [datetime(2023, 1, 1, 19, 25, 00), None, datetime(2023, 3, 1, 19, 25, 00)],
"timestamptz_us": [
datetime(2023, 1, 1, 19, 25, 00, tzinfo=timezone.utc),
None,
datetime(2023, 3, 1, 19, 25, 00, tzinfo=timezone.utc),
],
"timestamp_ns": [
pd.Timestamp(year=2024, month=7, day=11, hour=3, minute=30, second=0, microsecond=12, nanosecond=6),
None,
pd.Timestamp(year=2024, month=7, day=11, hour=3, minute=30, second=0, microsecond=12, nanosecond=7),
],
"timestamptz_ns": [
datetime(2023, 1, 1, 19, 25, 00, tzinfo=timezone.utc),
None,
datetime(2023, 3, 1, 19, 25, 00, tzinfo=timezone.utc),
],
"timestamptz_us_etc_utc": [
datetime(2023, 1, 1, 19, 25, 00, tzinfo=timezone.utc),
None,
datetime(2023, 3, 1, 19, 25, 00, tzinfo=timezone.utc),
],
"timestamptz_us_z": [
datetime(2023, 1, 1, 19, 25, 00, tzinfo=timezone.utc),
None,
datetime(2023, 3, 1, 19, 25, 00, tzinfo=timezone.utc),
],
})
input_arrow_table = pa.Table.from_pandas(TEST_DATA_WITH_NULL, schema=arrow_table_schema_with_all_timestamp_precisions)
mocker.patch.dict(os.environ, values={"PYICEBERG_DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE": "True"})

tbl = _create_table(
session_catalog,
identifier,
{"format-version": format_version},
data=[input_arrow_table],
data=[arrow_table_with_all_timestamp_precisions],
schema=arrow_table_schema_with_all_timestamp_precisions,
)
tbl.overwrite(input_arrow_table)
tbl.overwrite(arrow_table_with_all_timestamp_precisions)
written_arrow_table = tbl.scan().to_arrow()

expected_schema_in_all_us = pa.schema([
("timestamp_s", pa.timestamp(unit="us")),
("timestamptz_s", pa.timestamp(unit="us", tz="UTC")),
("timestamp_ms", pa.timestamp(unit="us")),
("timestamptz_ms", pa.timestamp(unit="us", tz="UTC")),
("timestamp_us", pa.timestamp(unit="us")),
("timestamptz_us", pa.timestamp(unit="us", tz="UTC")),
("timestamp_ns", pa.timestamp(unit="us")),
("timestamptz_ns", pa.timestamp(unit="us", tz="UTC")),
("timestamptz_us_etc_utc", pa.timestamp(unit="us", tz="UTC")),
("timestamptz_us_z", pa.timestamp(unit="us", tz="UTC")),
])
assert written_arrow_table.schema == expected_schema_in_all_us
assert written_arrow_table == input_arrow_table.cast(expected_schema_in_all_us, safe=False)
assert written_arrow_table.schema == arrow_table_schema_with_all_microseconds_timestamp_precisions
assert written_arrow_table == arrow_table_with_all_timestamp_precisions.cast(
arrow_table_schema_with_all_microseconds_timestamp_precisions, safe=False
)
lhs = spark.table(f"{identifier}").toPandas()
rhs = written_arrow_table.to_pandas()
