Make linter happy
Fokko committed Jul 10, 2024
commit ee293a1692666b5c41112e81f520e5882bde937a
16 changes: 12 additions & 4 deletions pyiceberg/io/pyarrow.py
@@ -1274,19 +1274,27 @@ def project_batches(


 def to_requested_schema(
-    requested_schema: Schema, file_schema: Schema, batch: pa.RecordBatch, downcast_ns_timestamp_to_us: bool = False, include_field_ids: bool = False
+    requested_schema: Schema,
+    file_schema: Schema,
+    batch: pa.RecordBatch,
+    downcast_ns_timestamp_to_us: bool = False,
+    include_field_ids: bool = False,
 ) -> pa.RecordBatch:
     # We could re-use some of these visitors
     struct_array = visit_with_partner(
-        requested_schema, batch, ArrowProjectionVisitor(file_schema, downcast_ns_timestamp_to_us, include_field_ids), ArrowAccessor(file_schema)
+        requested_schema,
+        batch,
+        ArrowProjectionVisitor(file_schema, downcast_ns_timestamp_to_us, include_field_ids),
+        ArrowAccessor(file_schema),
     )
     return pa.RecordBatch.from_struct_array(struct_array)


 class ArrowProjectionVisitor(SchemaWithPartnerVisitor[pa.Array, Optional[pa.Array]]):
     file_schema: Schema
     _include_field_ids: bool

-    def __init__(self, file_schema: Schema, downcast_ns_timestamp_to_us: bool = False, include_field_ids: bool = False) -> None:
+    def __init__(self, file_schema: Schema, downcast_ns_timestamp_to_us: bool = False, include_field_ids: bool = False) -> None:
         self.file_schema = file_schema
         self._include_field_ids = include_field_ids
         self.downcast_ns_timestamp_to_us = downcast_ns_timestamp_to_us
@@ -1967,7 +1975,7 @@ def write_parquet(task: WriteTask) -> DataFile:
                 file_schema=table_schema,
                 batch=batch,
                 downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us,
-                include_field_ids=True
+                include_field_ids=True,
             )
             for batch in task.record_batches
         ]
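For context, a minimal sketch of how the reformatted to_requested_schema (a module-level function in pyiceberg/io/pyarrow.py, per the diff above) is driven from a write path, modeled on the write_parquet hunk. The wrapper name project_batches_for_write is illustrative only; inside pyiceberg the equivalent logic lives in write_parquet and takes file_schema, table_schema, and the task from its enclosing scope rather than as parameters.

from typing import List

import pyarrow as pa

from pyiceberg.io.pyarrow import to_requested_schema
from pyiceberg.schema import Schema


def project_batches_for_write(
    file_schema: Schema,
    table_schema: Schema,
    record_batches: List[pa.RecordBatch],
    downcast_ns_timestamp_to_us: bool = False,
) -> List[pa.RecordBatch]:
    # Illustrative helper (not part of pyiceberg): project each Arrow batch
    # onto the requested Iceberg schema before writing, optionally downcasting
    # nanosecond timestamps and embedding Iceberg field IDs, as the write path
    # in the hunk above does.
    return [
        to_requested_schema(
            requested_schema=file_schema,
            file_schema=table_schema,
            batch=batch,
            downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us,
            include_field_ids=True,
        )
        for batch in record_batches
    ]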
1 change: 1 addition & 0 deletions tests/integration/test_add_files.py
@@ -517,6 +517,7 @@ def test_add_files_with_large_and_regular_schema(spark: SparkSession, session_ca
     table_schema = tbl.scan().to_arrow().schema
     assert table_schema == arrow_schema_large

+
 def test_timestamp_tz_ns_downcast_on_read(session_catalog: Catalog, format_version: int, mocker: MockerFixture) -> None:
     nanoseconds_schema_iceberg = Schema(NestedField(1, "quux", TimestamptzType()))
