Skip to content
Prev Previous commit
Next Next commit
feat: implement RollingManifestWriter
  • Loading branch information
felixscherz committed Aug 5, 2024
commit 159999002dbac0e53232f8e01fe8be45c40254a7
230 changes: 169 additions & 61 deletions pyiceberg/manifest.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,23 +21,23 @@
from copy import copy
from enum import Enum
from types import TracebackType
from typing import (
Any,
Callable,
Dict,
Iterator,
List,
Literal,
Optional,
Type,
)
from typing import Any, Generator
from typing import Callable
from typing import Dict
from typing import Iterator
from typing import List
from typing import Literal
from typing import Optional
from typing import Type

from pydantic_core import to_json

from pyiceberg.avro.file import AvroFile, AvroOutputFile
from pyiceberg.conversions import to_bytes
from pyiceberg.exceptions import ValidationError
from pyiceberg.io import FileIO, InputFile, OutputFile
from pyiceberg.io import FileIO
from pyiceberg.io import InputFile
from pyiceberg.io import OutputFile
from pyiceberg.partitioning import PartitionSpec
from pyiceberg.schema import Schema
from pyiceberg.typedef import Record, TableVersion
Expand All @@ -53,6 +53,7 @@
StringType,
StructType,
)
from pyiceberg.typedef import EMPTY_DICT

UNASSIGNED_SEQ = -1
DEFAULT_BLOCK_SIZE = 67108864 # 64 * 1024 * 1024
Expand Down Expand Up @@ -102,7 +103,9 @@ def __repr__(self) -> str:

DATA_FILE_TYPE: Dict[int, StructType] = {
1: StructType(
NestedField(field_id=100, name="file_path", field_type=StringType(), required=True, doc="Location URI with FS scheme"),
NestedField(
field_id=100, name="file_path", field_type=StringType(), required=True, doc="Location URI with FS scheme"
),
NestedField(
field_id=101,
name="file_format",
Expand All @@ -117,9 +120,15 @@ def __repr__(self) -> str:
required=True,
doc="Partition data tuple, schema based on the partition spec",
),
NestedField(field_id=103, name="record_count", field_type=LongType(), required=True, doc="Number of records in the file"),
NestedField(
field_id=104, name="file_size_in_bytes", field_type=LongType(), required=True, doc="Total file size in bytes"
field_id=103, name="record_count", field_type=LongType(), required=True, doc="Number of records in the file"
),
NestedField(
field_id=104,
name="file_size_in_bytes",
field_type=LongType(),
required=True,
doc="Total file size in bytes",
),
NestedField(
field_id=105,
Expand Down Expand Up @@ -172,7 +181,11 @@ def __repr__(self) -> str:
doc="Map of column id to upper bound",
),
NestedField(
field_id=131, name="key_metadata", field_type=BinaryType(), required=False, doc="Encryption key metadata blob"
field_id=131,
name="key_metadata",
field_type=BinaryType(),
required=False,
doc="Encryption key metadata blob",
),
NestedField(
field_id=132,
Expand All @@ -192,7 +205,9 @@ def __repr__(self) -> str:
doc="File format name: avro, orc, or parquet",
initial_default=DataFileContent.DATA,
),
NestedField(field_id=100, name="file_path", field_type=StringType(), required=True, doc="Location URI with FS scheme"),
NestedField(
field_id=100, name="file_path", field_type=StringType(), required=True, doc="Location URI with FS scheme"
),
NestedField(
field_id=101,
name="file_format",
Expand All @@ -207,9 +222,15 @@ def __repr__(self) -> str:
required=True,
doc="Partition data tuple, schema based on the partition spec",
),
NestedField(field_id=103, name="record_count", field_type=LongType(), required=True, doc="Number of records in the file"),
NestedField(
field_id=104, name="file_size_in_bytes", field_type=LongType(), required=True, doc="Total file size in bytes"
field_id=103, name="record_count", field_type=LongType(), required=True, doc="Number of records in the file"
),
NestedField(
field_id=104,
name="file_size_in_bytes",
field_type=LongType(),
required=True,
doc="Total file size in bytes",
),
NestedField(
field_id=108,
Expand Down Expand Up @@ -254,7 +275,11 @@ def __repr__(self) -> str:
doc="Map of column id to upper bound",
),
NestedField(
field_id=131, name="key_metadata", field_type=BinaryType(), required=False, doc="Encryption key metadata blob"
field_id=131,
name="key_metadata",
field_type=BinaryType(),
required=False,
doc="Encryption key metadata blob",
),
NestedField(
field_id=132,
Expand Down Expand Up @@ -282,28 +307,34 @@ def __repr__(self) -> str:


def data_file_with_partition(partition_type: StructType, format_version: TableVersion) -> StructType:
    """Return the data-file struct for ``format_version`` with a concrete partition type.

    The fields of ``partition_type`` are copied (id, name, type and required
    flag only) into a new struct, and that struct replaces the type of field
    102 ("partition") in ``DATA_FILE_TYPE[format_version]``. All other fields
    are passed through unchanged.
    """
    copied_partition_fields = [
        NestedField(
            field_id=source.field_id,
            name=source.name,
            field_type=source.field_type,
            required=source.required,
        )
        for source in partition_type.fields
    ]
    partition_struct = StructType(*copied_partition_fields)

    rebuilt_fields = []
    for candidate in DATA_FILE_TYPE[format_version].fields:
        if candidate.field_id != 102:
            rebuilt_fields.append(candidate)
            continue
        # Field 102 is the partition tuple; swap in the spec-specific struct.
        rebuilt_fields.append(
            NestedField(
                field_id=102,
                name="partition",
                field_type=partition_struct,
                required=True,
                doc="Partition data tuple, schema based on the partition spec",
            )
        )
    return StructType(*rebuilt_fields)


class DataFile(Record):
Expand Down Expand Up @@ -384,14 +415,18 @@ def __eq__(self, other: Any) -> bool:
),
}

# Struct form (via ``as_struct()``) of each manifest-entry schema, keyed by format version.
MANIFEST_ENTRY_SCHEMAS_STRUCT = {format_version: schema.as_struct() for format_version, schema in MANIFEST_ENTRY_SCHEMAS.items()}
MANIFEST_ENTRY_SCHEMAS_STRUCT = {
format_version: schema.as_struct() for format_version, schema in MANIFEST_ENTRY_SCHEMAS.items()
}


def manifest_entry_schema_with_data_file(format_version: TableVersion, data_file: StructType) -> Schema:
    """Return the manifest-entry schema for ``format_version`` using ``data_file`` as field 2.

    Every field of ``MANIFEST_ENTRY_SCHEMAS[format_version]`` is kept as is,
    except the field with id 2, which is rebuilt as a required "data_file"
    field whose type is the supplied struct.
    """
    swapped_fields = []
    for existing in MANIFEST_ENTRY_SCHEMAS[format_version].fields:
        if existing.field_id == 2:
            swapped_fields.append(NestedField(2, "data_file", data_file, required=True))
        else:
            swapped_fields.append(existing)
    return Schema(*swapped_fields)


class ManifestEntry(Record):
Expand Down Expand Up @@ -499,7 +534,9 @@ def update(self, value: Any) -> None:
self._min = min(self._min, value)


def construct_partition_summaries(spec: PartitionSpec, schema: Schema, partitions: List[Record]) -> List[PartitionFieldSummary]:
def construct_partition_summaries(
spec: PartitionSpec, schema: Schema, partitions: List[Record]
) -> List[PartitionFieldSummary]:
types = [field.field_type for field in spec.partition_type(schema).fields]
field_stats = [PartitionFieldStats(field_type) for field_type in types]
for partition_keys in partitions:
Expand All @@ -523,7 +560,9 @@ def construct_partition_summaries(spec: PartitionSpec, schema: Schema, partition
NestedField(512, "added_rows_count", LongType(), required=False),
NestedField(513, "existing_rows_count", LongType(), required=False),
NestedField(514, "deleted_rows_count", LongType(), required=False),
NestedField(507, "partitions", ListType(508, PARTITION_FIELD_SUMMARY_TYPE, element_required=True), required=False),
NestedField(
507, "partitions", ListType(508, PARTITION_FIELD_SUMMARY_TYPE, element_required=True), required=False
),
NestedField(519, "key_metadata", BinaryType(), required=False),
),
2: Schema(
Expand All @@ -540,12 +579,16 @@ def construct_partition_summaries(spec: PartitionSpec, schema: Schema, partition
NestedField(512, "added_rows_count", LongType(), required=True),
NestedField(513, "existing_rows_count", LongType(), required=True),
NestedField(514, "deleted_rows_count", LongType(), required=True),
NestedField(507, "partitions", ListType(508, PARTITION_FIELD_SUMMARY_TYPE, element_required=True), required=False),
NestedField(
507, "partitions", ListType(508, PARTITION_FIELD_SUMMARY_TYPE, element_required=True), required=False
),
NestedField(519, "key_metadata", BinaryType(), required=False),
),
}

# Struct form (via ``as_struct()``) of each manifest-list file schema, keyed by format version.
MANIFEST_LIST_FILE_STRUCTS = {format_version: schema.as_struct() for format_version, schema in MANIFEST_LIST_FILE_SCHEMAS.items()}
MANIFEST_LIST_FILE_STRUCTS = {
format_version: schema.as_struct() for format_version, schema in MANIFEST_LIST_FILE_SCHEMAS.items()
}


POSITIONAL_DELETE_SCHEMA = Schema(
Expand Down Expand Up @@ -669,7 +712,9 @@ def _inherit_from_manifest(entry: ManifestEntry, manifest: ManifestFile) -> Mani

# in v1 tables, the file sequence number is not persisted and can be safely defaulted to 0
# in v2 tables, the file sequence number should be inherited iff the entry status is ADDED
if entry.file_sequence_number is None and (manifest.sequence_number == 0 or entry.status == ManifestEntryStatus.ADDED):
if entry.file_sequence_number is None and (
manifest.sequence_number == 0 or entry.status == ManifestEntryStatus.ADDED
):
# Only available in V2, always 0 in V1
entry.file_sequence_number = manifest.sequence_number

Expand Down Expand Up @@ -842,17 +887,74 @@ def existing(self, entry: ManifestEntry) -> ManifestWriter:


class RollingManifestWriter:
    """Write manifest entries across several manifest files, rolling over on size or row targets.

    Entries are forwarded to a writer pulled from ``supplier``. Before each
    entry is written, the current writer is checked against the targets; once
    either is reached the writer is closed, its manifest file is recorded, and
    the next writer is pulled from ``supplier``.

    Args:
        supplier: Generator yielding a fresh, not-yet-entered ``ManifestWriter``
            every time it is advanced.
        target_file_size_in_bytes: Roll over once ``len()`` of the current
            writer's output file reaches this value.
        target_number_of_rows: Roll over once this many entries have been
            added to the current writer.
    """

    _closed: bool
    _supplier: Generator[ManifestWriter, None, None]
    _manifest_files: list[ManifestFile]
    _target_file_size_in_bytes: int
    _target_number_of_rows: int
    _current_writer: Optional[ManifestWriter]
    _current_file_rows: int

    def __init__(
        self,
        supplier: Generator[ManifestWriter, None, None],
        target_file_size_in_bytes: int,
        target_number_of_rows: int,
    ) -> None:
        self._closed = False
        self._manifest_files = []
        self._supplier = supplier
        self._target_file_size_in_bytes = target_file_size_in_bytes
        self._target_number_of_rows = target_number_of_rows
        self._current_writer = None
        self._current_file_rows = 0

    @property
    def closed(self) -> bool:
        """Whether the rolling writer has been closed (read-only view of the internal flag)."""
        return self._closed

    def __enter__(self) -> RollingManifestWriter:
        """Open the rolling writer and eagerly open the first underlying writer."""
        # _get_current_writer() already enters every writer it pulls from the
        # supplier, so it must not be entered a second time here.
        self._get_current_writer()
        return self

    def __exit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_value: Optional[BaseException],
        traceback: Optional[TracebackType],
    ) -> None:
        """Close the rolling writer, finishing the writer that is still open (if any)."""
        # Set the flag first so the writer counts as closed even if finishing
        # the underlying writer raises.
        self._closed = True
        # Reuse _close_current_writer so the last manifest file is recorded and
        # the underlying writer is exited exactly once.
        self._close_current_writer(exc_type, exc_value, traceback)

    def _get_current_writer(self) -> ManifestWriter:
        """Return the writer that should receive the next entry, rolling over first if needed."""
        if self._should_roll_to_new_file():
            self._close_current_writer()
        if self._current_writer is None:
            # Raises StopIteration if the supplier has no more writers to give.
            fresh_writer = next(self._supplier)
            fresh_writer.__enter__()
            self._current_writer = fresh_writer
        return self._current_writer

    def _should_roll_to_new_file(self) -> bool:
        """Return True when the open writer has reached the row or size target."""
        if self._current_writer is None:
            return False
        return (
            self._current_file_rows >= self._target_number_of_rows
            # Relies on the writer's output file supporting len().
            or len(self._current_writer._output_file) >= self._target_file_size_in_bytes
        )

    def _close_current_writer(
        self,
        exc_type: Optional[Type[BaseException]] = None,
        exc_value: Optional[BaseException] = None,
        traceback: Optional[TracebackType] = None,
    ) -> None:
        """Exit the open writer, record its manifest file and reset the per-file state.

        Does nothing when no writer is open. The exception arguments are
        forwarded to the underlying writer's ``__exit__``.
        """
        finished_writer = self._current_writer
        if finished_writer is None:
            return
        # Clear the per-file state up front so a failure below cannot lead to
        # the same writer being exited twice.
        self._current_writer = None
        self._current_file_rows = 0
        finished_writer.__exit__(exc_type, exc_value, traceback)
        self._manifest_files.append(finished_writer.to_manifest_file())

    def to_manifest_files(self) -> list[ManifestFile]:
        """Return the manifest files written so far, closing the rolling writer if still open."""
        # NOTE(review): a stricter contract was suggested in review (raise
        # RuntimeError when the writer has not been closed yet, mirroring the
        # Java implementation). The lenient behaviour is kept here so existing
        # callers that skip the ``with`` block keep working.
        self._close_current_writer()
        self._closed = True
        return self._manifest_files

    def add_entry(self, entry: ManifestEntry) -> RollingManifestWriter:
        """Add ``entry`` to the current manifest file, rolling to a new file when a target is hit.

        Raises:
            RuntimeError: If the rolling writer has already been closed.
        """
        if self._closed:
            raise RuntimeError("Cannot add entry to closed manifest writer")
        self._get_current_writer().add_entry(entry)
        # Count rows per file; without this the row target is never reached.
        self._current_file_rows += 1
        return self


class ManifestWriterV1(ManifestWriter):
Expand Down Expand Up @@ -962,7 +1064,11 @@ def __init__(self, output_file: OutputFile, snapshot_id: int, parent_snapshot_id
super().__init__(
format_version=1,
output_file=output_file,
meta={"snapshot-id": str(snapshot_id), "parent-snapshot-id": str(parent_snapshot_id), "format-version": "1"},
meta={
"snapshot-id": str(snapshot_id),
"parent-snapshot-id": str(parent_snapshot_id),
"format-version": "1",
},
)

def prepare_manifest(self, manifest_file: ManifestFile) -> ManifestFile:
Expand All @@ -975,7 +1081,9 @@ class ManifestListWriterV2(ManifestListWriter):
_commit_snapshot_id: int
_sequence_number: int

def __init__(self, output_file: OutputFile, snapshot_id: int, parent_snapshot_id: Optional[int], sequence_number: int):
def __init__(
self, output_file: OutputFile, snapshot_id: int, parent_snapshot_id: Optional[int], sequence_number: int
):
super().__init__(
format_version=2,
output_file=output_file,
Expand Down