Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Added comments
  • Loading branch information
vinjai committed Jun 30, 2024
commit bd600d607e7fe485aafe18cc21a251e5958066f8
11 changes: 10 additions & 1 deletion pyiceberg/table/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3780,6 +3780,15 @@ def _determine_partitions(spec: PartitionSpec, schema: Schema, arrow_table: pa.T


def _sort_table_by_sort_order(arrow_table: pa.Table, schema: Schema, sort_order: SortOrder) -> pa.Table:
"""
Sorts an Arrow Table using Iceberg Sort Order.
The sort implementation is stable.

@param arrow_table: Arrow Table that needs to be sorted
@param schema: Schema of the Iceberg Table
@param sort_order: Sort Order of the Iceberg Table
@return: Sorted Arrow Table
"""
import pyarrow as pa

from pyiceberg.utils.arrow_sorting import convert_sort_field_to_pyarrow_sort_options, get_sort_indices_arrow_table
Expand All @@ -3806,5 +3815,5 @@ def _sort_table_by_sort_order(arrow_table: pa.Table, schema: Schema, sort_order:
for sort_field in sort_order.fields
]

sort_indices = get_sort_indices_arrow_table(tbl=sort_values_generated, sort_seq=arrow_sort_options)
sort_indices = get_sort_indices_arrow_table(arrow_table=sort_values_generated, sort_seq=arrow_sort_options)
return arrow_table.take(sort_indices)
25 changes: 19 additions & 6 deletions pyiceberg/utils/arrow_sorting.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,12 @@ def __init__(self, sort_direction: str = "ascending", null_order: str = "at_end"


def convert_sort_field_to_pyarrow_sort_options(sort_field: SortField) -> PyArrowSortOptions:
"""
Convert an Iceberg Table Sort Field to Arrow Sort Options

@param sort_field: Source Iceberg Sort Field to be converted
@return: Returns SortField as PyArrow Sort Options
"""
pyarrow_sort_direction = {SortDirection.ASC: "ascending", SortDirection.DESC: "descending"}
pyarrow_null_ordering = {NullOrder.NULLS_LAST: "at_end", NullOrder.NULLS_FIRST: "at_start"}
return PyArrowSortOptions(
Expand All @@ -45,16 +51,23 @@ def convert_sort_field_to_pyarrow_sort_options(sort_field: SortField) -> PyArrow
)


def get_sort_indices_arrow_table(tbl: pa.Table, sort_seq: List[Tuple[str, PyArrowSortOptions]]) -> List[int]:
def get_sort_indices_arrow_table(arrow_table: pa.Table, sort_seq: List[Tuple[str, PyArrowSortOptions]]) -> List[int]:
Copy link
Contributor Author

@vinjai vinjai Jul 6, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just wanted to clarify why there is a separate implementation of sort_indices instead of using the one provided by pyarrow.
This is because pyarrow's sort_indices (via SortOptions) supports only a single null placement that is shared across all sort keys.
More details here:

The Iceberg spec, on the other hand, allows each sort key to have its own null ordering: https://iceberg.apache.org/spec/#sort-orders

This function implements that behavior by sorting on one key at a time, from the last key to the first, and relying on the stability of pyarrow's sort_indices algorithm.


We can raise another issue to improve the performance of this function.


In the future, if pyarrow's sort_indices supports a different null ordering per key, we can mark this function as obsolete and keep the implementation clean in the Iceberg table append and overwrite methods.

"""
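Compute the row indices that sort a PyArrow Table by a given sort sequence.

@param arrow_table: Input table to be sorted
@param sort_seq: Sequence of (column name, PyArrowSortOptions) pairs, in sort priority order
@return: List of row indices that, when taken from the input table, yield the sorted table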
"""
import pyarrow as pa

index_column_name = "__idx__pyarrow_sort__"
cols = set(tbl.column_names)
cols = set(arrow_table.column_names)

while index_column_name in cols:
index_column_name = f"{index_column_name}_1"

table: pa.Table = tbl.add_column(0, index_column_name, [list(range(len(tbl)))])
sorted_table: pa.Table = arrow_table.add_column(0, index_column_name, [list(range(len(arrow_table)))])

for col_name, _ in sort_seq:
if col_name not in cols:
Expand All @@ -63,10 +76,10 @@ def get_sort_indices_arrow_table(tbl: pa.Table, sort_seq: List[Tuple[str, PyArro
)

for col_name, sort_options in sort_seq[::-1]:
table = table.take(
sorted_table = sorted_table.take(
pa.compute.sort_indices(
table, sort_keys=[(col_name, sort_options.sort_direction)], null_placement=sort_options.null_order
sorted_table, sort_keys=[(col_name, sort_options.sort_direction)], null_placement=sort_options.null_order
)
)

return table[index_column_name].to_pylist()
return sorted_table[index_column_name].to_pylist()