Merged
Changes from 1 commit
tests
sungwy committed Jul 11, 2024
commit fd5313848c736aa97e04ed05e22ea5e301be3264
2 changes: 1 addition & 1 deletion pyiceberg/io/pyarrow.py
@@ -937,7 +937,7 @@ def primitive(self, primitive: pa.DataType) -> PrimitiveType:
             else:
                 raise TypeError(f"Unsupported precision for timestamp type: {primitive.unit}")

-            if primitive.tz in {"UTC", "+00:00", "Etc/UTC"}:
+            if primitive.tz in {"UTC", "+00:00", "Etc/UTC", "Z"}:
                 return TimestamptzType()
             elif primitive.tz is None:
                 return TimestampType()
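For reference, the one-line change above makes the PyArrow-to-Iceberg type conversion treat "Z" as another spelling of UTC, alongside "UTC", "+00:00", and "Etc/UTC". A minimal sketch of just that membership check, using plain PyArrow (the helper name and alias set below are illustrative, not PyIceberg API):

import pyarrow as pa

# Same alias set as the diff above, with "Z" now included.
UTC_ALIASES = {"UTC", "+00:00", "Etc/UTC", "Z"}

def is_utc_timestamptz(dtype: pa.DataType) -> bool:
    # True when the type is a timestamp whose tz string is one of the UTC aliases.
    return pa.types.is_timestamp(dtype) and dtype.tz in UTC_ALIASES

assert is_utc_timestamptz(pa.timestamp("us", tz="Z"))        # previously fell through this check
assert is_utc_timestamptz(pa.timestamp("us", tz="Etc/UTC"))
assert not is_utc_timestamptz(pa.timestamp("us"))            # tz=None maps to a plain timestamp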
8 changes: 0 additions & 8 deletions pyiceberg/table/__init__.py
@@ -529,10 +529,6 @@ def append(self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT)
             )

         _check_schema_compatible(self._table.schema(), other_schema=df.schema)
-        # cast if the two schemas are compatible but not equal
-        table_arrow_schema = self._table.schema().as_arrow()
-        if table_arrow_schema != df.schema:
-            df = df.cast(table_arrow_schema)
Comment from sungwy (Collaborator, Author): Removed this cast: to_requested_schema should be responsible for casting the types to their desired schema, instead of casting it here. (A sketch of the deferred cast follows this file's diff.)

         manifest_merge_enabled = PropertyUtil.property_as_bool(
             self.table_metadata.properties,
@@ -588,10 +584,6 @@ def overwrite(
             )

         _check_schema_compatible(self._table.schema(), other_schema=df.schema)
-        # cast if the two schemas are compatible but not equal
-        table_arrow_schema = self._table.schema().as_arrow()
-        if table_arrow_schema != df.schema:
-            df = df.cast(table_arrow_schema)

         self.delete(delete_filter=overwrite_filter, snapshot_properties=snapshot_properties)

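The blocks removed from append and overwrite above eagerly cast the incoming Arrow table to the Iceberg table's Arrow schema whenever the two schemas were compatible but not identical. A hedged sketch of that now-removed behavior, in plain PyArrow (the helper and the example schemas are illustrative only; per the comment above, the cast is now expected to happen later in the write path, in to_requested_schema):

import pyarrow as pa

def eager_cast_if_needed(df: pa.Table, table_arrow_schema: pa.Schema) -> pa.Table:
    # What the deleted block did: cast a compatible-but-not-equal table up front.
    # After this change, append/overwrite leave df untouched at this point.
    if table_arrow_schema != df.schema:
        return df.cast(table_arrow_schema)
    return df

# Illustrative schemas: same field, different timestamp granularity.
table_arrow_schema = pa.schema([("ts", pa.timestamp("us", tz="UTC"))])
df = pa.table({"ts": pa.array([1_700_000_000_000], type=pa.timestamp("ms", tz="UTC"))})

assert eager_cast_if_needed(df, table_arrow_schema).schema == table_arrow_schema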
37 changes: 31 additions & 6 deletions tests/integration/test_writes/test_writes.py
@@ -23,6 +23,7 @@
 from typing import Any, Dict
 from urllib.parse import urlparse

+import pandas as pd
 import pyarrow as pa
 import pyarrow.parquet as pq
 import pytest
@@ -968,7 +969,9 @@ def table_write_subset_of_schema(session_catalog: Catalog, arrow_table_with_null

 @pytest.mark.integration
 @pytest.mark.parametrize("format_version", [1, 2])
-def test_write_all_timestamp_precision(mocker: MockerFixture, session_catalog: Catalog, format_version: int) -> None:
+def test_write_all_timestamp_precision(
+    mocker: MockerFixture, spark: SparkSession, session_catalog: Catalog, format_version: int
+) -> None:
     identifier = "default.table_all_timestamp_precision"
     arrow_table_schema_with_all_timestamp_precisions = pa.schema([
         ("timestamp_s", pa.timestamp(unit="s")),
@@ -980,8 +983,9 @@ def test_write_all_timestamp_precision(mocker: MockerFixture, session_catalog: C
         ("timestamp_ns", pa.timestamp(unit="ns")),
         ("timestamptz_ns", pa.timestamp(unit="ns", tz="UTC")),
         ("timestamptz_us_etc_utc", pa.timestamp(unit="us", tz="Etc/UTC")),
+        ("timestamptz_us_z", pa.timestamp(unit="us", tz="Z")),
     ])
-    TEST_DATA_WITH_NULL = {
+    TEST_DATA_WITH_NULL = pd.DataFrame({
         "timestamp_s": [datetime(2023, 1, 1, 19, 25, 00), None, datetime(2023, 3, 1, 19, 25, 00)],
         "timestamptz_s": [
             datetime(2023, 1, 1, 19, 25, 00, tzinfo=timezone.utc),
@@ -1000,7 +1004,11 @@ def test_write_all_timestamp_precision(mocker: MockerFixture, session_catalog: C
             None,
             datetime(2023, 3, 1, 19, 25, 00, tzinfo=timezone.utc),
         ],
-        "timestamp_ns": [datetime(2023, 1, 1, 19, 25, 00), None, datetime(2023, 3, 1, 19, 25, 00)],
+        "timestamp_ns": [
+            pd.Timestamp(year=2024, month=7, day=11, hour=3, minute=30, second=0, microsecond=12, nanosecond=6),
+            None,
+            pd.Timestamp(year=2024, month=7, day=11, hour=3, minute=30, second=0, microsecond=12, nanosecond=7),
+        ],
         "timestamptz_ns": [
             datetime(2023, 1, 1, 19, 25, 00, tzinfo=timezone.utc),
             None,
@@ -1011,8 +1019,13 @@ def test_write_all_timestamp_precision(mocker: MockerFixture, session_catalog: C
             None,
             datetime(2023, 3, 1, 19, 25, 00, tzinfo=timezone.utc),
         ],
-    }
-    input_arrow_table = pa.Table.from_pydict(TEST_DATA_WITH_NULL, schema=arrow_table_schema_with_all_timestamp_precisions)
+        "timestamptz_us_z": [
+            datetime(2023, 1, 1, 19, 25, 00, tzinfo=timezone.utc),
+            None,
+            datetime(2023, 3, 1, 19, 25, 00, tzinfo=timezone.utc),
+        ],
+    })
+    input_arrow_table = pa.Table.from_pandas(TEST_DATA_WITH_NULL, schema=arrow_table_schema_with_all_timestamp_precisions)
     mocker.patch.dict(os.environ, values={"PYICEBERG_DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE": "True"})

     tbl = _create_table(
@@ -1035,9 +1048,21 @@ def test_write_all_timestamp_precision(mocker: MockerFixture, session_catalog: C
         ("timestamp_ns", pa.timestamp(unit="us")),
         ("timestamptz_ns", pa.timestamp(unit="us", tz="UTC")),
         ("timestamptz_us_etc_utc", pa.timestamp(unit="us", tz="UTC")),
+        ("timestamptz_us_z", pa.timestamp(unit="us", tz="UTC")),
     ])
     assert written_arrow_table.schema == expected_schema_in_all_us
-    assert written_arrow_table == input_arrow_table.cast(expected_schema_in_all_us)
+    assert written_arrow_table == input_arrow_table.cast(expected_schema_in_all_us, safe=False)
+    lhs = spark.table(f"{identifier}").toPandas()
+    rhs = written_arrow_table.to_pandas()
+
+    for column in written_arrow_table.column_names:
+        for left, right in zip(lhs[column].to_list(), rhs[column].to_list()):
+            if pd.isnull(left):
+                assert pd.isnull(right)
+            else:
+                # Check only up to microsecond precision, since the Spark-loaded dtype is timezone-unaware
+                # and supports up to microsecond precision
+                assert left.timestamp() == right.timestamp(), f"Difference in column {column}: {left} != {right}"


 @pytest.mark.integration
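For context on the safe=False cast and the microsecond-level comparison added above: Iceberg stores timestamps at microsecond precision, so the nanosecond input column can only be compared against the written data once the trailing nanoseconds are dropped, and the Spark-loaded values come back as timezone-unaware microsecond timestamps. A small self-contained sketch of that truncation (the values mirror the new TEST_DATA_WITH_NULL entries; no Spark session or catalog is involved):

from datetime import datetime

import pandas as pd
import pyarrow as pa

# A nanosecond-precision value like the ones added to TEST_DATA_WITH_NULL.
ns_value = pd.Timestamp(year=2024, month=7, day=11, hour=3, minute=30, second=0, microsecond=12, nanosecond=6)
ns_table = pa.table({"timestamp_ns": pa.array([ns_value], type=pa.timestamp("ns"))})
us_schema = pa.schema([("timestamp_ns", pa.timestamp("us"))])

# A safe cast refuses to drop the trailing nanoseconds ...
try:
    ns_table.cast(us_schema)
except pa.ArrowInvalid:
    pass  # lossy ns -> us cast is rejected by default

# ... while safe=False truncates to microseconds, matching what the written table holds
# when PYICEBERG_DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE is enabled.
truncated = ns_table.cast(us_schema, safe=False)["timestamp_ns"][0].as_py()
assert truncated == datetime(2024, 7, 11, 3, 30, 0, 12)  # the 6 ns were dropped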