Reuse Table schema
Fokko committed Jul 10, 2024
commit 4ca513b732b7ecc79d2369d18d0d2028ae0b8e84
28 changes: 21 additions & 7 deletions pyiceberg/io/pyarrow.py
@@ -1251,30 +1251,44 @@ def project_batches(
         total_row_count += len(batch)


-def to_requested_schema(requested_schema: Schema, file_schema: Schema, batch: pa.RecordBatch) -> pa.RecordBatch:
+def to_requested_schema(
+    requested_schema: Schema, file_schema: Schema, batch: pa.RecordBatch, include_field_ids: bool = False
+) -> pa.RecordBatch:
     # We could re-use some of these visitors
-    struct_array = visit_with_partner(requested_schema, batch, ArrowProjectionVisitor(file_schema), ArrowAccessor(file_schema))
+    struct_array = visit_with_partner(
+        requested_schema, batch, ArrowProjectionVisitor(file_schema, include_field_ids), ArrowAccessor(file_schema)
+    )
     return pa.RecordBatch.from_struct_array(struct_array)
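The new `include_field_ids` flag matters because PyArrow's Parquet writer picks up each field's id from the `PARQUET:field_id` metadata key on the Arrow field and writes it into the Parquet schema, which is how Iceberg readers resolve columns by id rather than by name. A minimal pyarrow-only sketch of that convention (the file name and ids below are illustrative, not part of this change):

```python
import pyarrow as pa
import pyarrow.parquet as pq

# Arrow fields tagged with the "PARQUET:field_id" metadata key get their
# id written into the Parquet schema by pyarrow's writer.
schema = pa.schema([
    pa.field("id", pa.int64(), nullable=False, metadata={"PARQUET:field_id": "1"}),
    pa.field("name", pa.string(), metadata={"PARQUET:field_id": "2"}),
])
table = pa.table({"id": [1, 2], "name": ["a", "b"]}, schema=schema)
pq.write_table(table, "example.parquet")

# Reading the file back surfaces the ids again as field metadata.
print(pq.read_schema("example.parquet").field("id").metadata)
```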


 class ArrowProjectionVisitor(SchemaWithPartnerVisitor[pa.Array, Optional[pa.Array]]):
     file_schema: Schema

-    def __init__(self, file_schema: Schema):
+    def __init__(self, file_schema: Schema, include_field_ids: bool = False) -> None:
         self.file_schema = file_schema
+        self._include_field_ids = include_field_ids

     def _cast_if_needed(self, field: NestedField, values: pa.Array) -> pa.Array:
         file_field = self.file_schema.find_field(field.field_id)
         if field.field_type.is_primitive and field.field_type != file_field.field_type:
-            return values.cast(schema_to_pyarrow(promote(file_field.field_type, field.field_type), include_field_ids=False))
+            return values.cast(
+                schema_to_pyarrow(promote(file_field.field_type, field.field_type), include_field_ids=self._include_field_ids)
+            )

         return values
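To illustrate the promotion path above: when the data file stored a column with a narrower type than the current table schema requests, `_cast_if_needed` casts the values to the promoted Arrow type, and threading `self._include_field_ids` through keeps the cast's target type consistent with the rest of the projected batch. A sketch, assuming pyiceberg's public import paths at the time of this commit:

```python
import pyarrow as pa

from pyiceberg.io.pyarrow import schema_to_pyarrow
from pyiceberg.schema import promote
from pyiceberg.types import IntegerType, LongType

# The data file stored an `int` column, but the table schema has since
# promoted the field to `long`, so the values are cast on read.
file_values = pa.array([1, 2, 3], type=pa.int32())
target_type = schema_to_pyarrow(promote(IntegerType(), LongType()), include_field_ids=False)
print(file_values.cast(target_type).type)  # int64
```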

     def _construct_field(self, field: NestedField, arrow_type: pa.DataType) -> pa.Field:
+        metadata = {}
+        if field.doc:
+            metadata[PYARROW_FIELD_DOC_KEY] = field.doc
+        if self._include_field_ids:
+            metadata[PYARROW_PARQUET_FIELD_ID_KEY] = str(field.field_id)
Collaborator: Ah good catch on this one as well 👍


         return pa.field(
             name=field.name,
             type=arrow_type,
             nullable=field.optional,
-            metadata={DOC: field.doc} if field.doc is not None else None,
+            metadata=metadata,
         )

     def schema(self, schema: Schema, schema_partner: Optional[pa.Array], struct_result: Optional[pa.Array]) -> Optional[pa.Array]:
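For reference, a standalone sketch of what `_construct_field` now emits: the doc string and, when requested, the field id both land in the Arrow field's metadata. pyiceberg defines the key names as the `PYARROW_FIELD_DOC_KEY` and `PYARROW_PARQUET_FIELD_ID_KEY` constants; the literals below stand in for them:

```python
import pyarrow as pa

doc = "unique row id"   # stands in for field.doc
field_id = 1            # stands in for field.field_id
include_field_ids = True

metadata = {}
if doc:
    metadata["doc"] = doc                         # PYARROW_FIELD_DOC_KEY
if include_field_ids:
    metadata["PARQUET:field_id"] = str(field_id)  # PYARROW_PARQUET_FIELD_ID_KEY

field = pa.field(name="id", type=pa.int64(), nullable=False, metadata=metadata)
print(field.metadata)  # {b'doc': b'unique row id', b'PARQUET:field_id': b'1'}
```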
@@ -1904,14 +1918,14 @@ def write_parquet(task: WriteTask) -> DataFile:
         file_schema = table_schema

     batches = [
-        to_requested_schema(requested_schema=file_schema, file_schema=table_schema, batch=batch)
+        to_requested_schema(requested_schema=file_schema, file_schema=table_schema, batch=batch, include_field_ids=True)
         for batch in task.record_batches
     ]
     arrow_table = pa.Table.from_batches(batches)
     file_path = f'{table_metadata.location}/data/{task.generate_data_file_path("parquet")}'
     fo = io.new_output(file_path)
     with fo.create(overwrite=True) as fos:
-        with pq.ParquetWriter(fos, schema=file_schema.as_arrow(), **parquet_writer_kwargs) as writer:
+        with pq.ParquetWriter(fos, schema=arrow_table.schema, **parquet_writer_kwargs) as writer:
             writer.write(arrow_table, row_group_size=row_group_size)
     statistics = data_file_statistics_from_parquet_metadata(
         parquet_metadata=writer.writer.metadata,
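A note on the writer change in the second hunk: the schema passed to `pq.ParquetWriter` is what actually ends up in the Parquet file, so constructing the writer from the schema of the table that is about to be written, rather than re-deriving it from `file_schema`, guarantees the two cannot drift apart, field-id metadata included. A minimal sketch of the pattern (file name illustrative):

```python
import pyarrow as pa
import pyarrow.parquet as pq

schema = pa.schema([pa.field("id", pa.int64(), metadata={"PARQUET:field_id": "1"})])
arrow_table = pa.table({"id": [1, 2, 3]}, schema=schema)

# Constructing the writer from the table's own schema guarantees that
# write() sees exactly the schema it expects, field ids included.
with pq.ParquetWriter("out.parquet", schema=arrow_table.schema) as writer:
    writer.write(arrow_table)
```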