Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Refactored and added more tests for sorting
  • Loading branch information
vinjai committed Jun 29, 2024
commit 64b8975ebbc6dc294faf57808bb804f4064e07b6
6 changes: 3 additions & 3 deletions pyiceberg/table/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,10 +136,10 @@
StructType,
transform_dict_value_to_str,
)
from pyiceberg.utils.arrow_sorting import PyArrowSortOptions
from pyiceberg.utils.concurrent import ExecutorFactory
from pyiceberg.utils.datetime import datetime_to_millis
from pyiceberg.utils.singleton import _convert_to_hashable_type
from pyiceberg.utils.sorting import PyArrowSortOptions

if TYPE_CHECKING:
import daft
Expand Down Expand Up @@ -3782,7 +3782,7 @@ def _determine_partitions(spec: PartitionSpec, schema: Schema, arrow_table: pa.T
def _sort_table_by_sort_order(arrow_table: pa.Table, schema: Schema, sort_order: SortOrder) -> pa.Table:
import pyarrow as pa

from pyiceberg.utils.sorting import convert_sort_field_to_pyarrow_sort_options, get_sort_indices_arrow_table
from pyiceberg.utils.arrow_sorting import convert_sort_field_to_pyarrow_sort_options, get_sort_indices_arrow_table

sort_columns: List[Tuple[SortField, NestedField]] = [
(sort_field, schema.find_field(sort_field.source_id)) for sort_field in sort_order.fields
Expand All @@ -3801,5 +3801,5 @@ def _sort_table_by_sort_order(arrow_table: pa.Table, schema: Schema, sort_order:
for sort_field in sort_order.fields
]

sort_indices = get_sort_indices_arrow_table(tbl=sort_values_generated, sort_seq=arrow_sort_options).to_pylist()
sort_indices = get_sort_indices_arrow_table(tbl=sort_values_generated, sort_seq=arrow_sort_options)
return arrow_table.take(sort_indices)
23 changes: 20 additions & 3 deletions pyiceberg/utils/sorting.py → pyiceberg/utils/arrow_sorting.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,20 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint:disable=redefined-outer-name
from typing import List, Tuple

import pyarrow as pa
Expand Down Expand Up @@ -28,7 +45,7 @@ def convert_sort_field_to_pyarrow_sort_options(sort_field: SortField) -> PyArrow
)


def get_sort_indices_arrow_table(tbl: pa.Table, sort_seq: List[Tuple[str, PyArrowSortOptions]]) -> pa.Array:
def get_sort_indices_arrow_table(tbl: pa.Table, sort_seq: List[Tuple[str, PyArrowSortOptions]]) -> List[int]:
import pyarrow as pa

index_column_name = "__idx__pyarrow_sort__"
Expand All @@ -37,7 +54,7 @@ def get_sort_indices_arrow_table(tbl: pa.Table, sort_seq: List[Tuple[str, PyArro
while index_column_name in cols:
index_column_name = f"{index_column_name}_1"

table = tbl.add_column(0, index_column_name, [list(range(len(tbl)))])
table: pa.Table = tbl.add_column(0, index_column_name, [list(range(len(tbl)))])

for col_name, _ in sort_seq:
if col_name not in cols:
Expand All @@ -52,4 +69,4 @@ def get_sort_indices_arrow_table(tbl: pa.Table, sort_seq: List[Tuple[str, PyArro
)
)

return table[index_column_name]
return table[index_column_name].to_pylist()
4 changes: 2 additions & 2 deletions tests/integration/test_writes/test_sorted_writes.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def test_query_null_sort(
"sort_col_tuple_3", [("int", "bool", "string"), ("long", "float", "double"), ("date", "timestamp", "timestamptz")]
)
@pytest.mark.parametrize("sort_direction_tuple_3", [(SortDirection.ASC, SortDirection.DESC, SortDirection.DESC)])
@pytest.mark.parametrize("sort_null_ordering_tuple_3", [(NullOrder.NULLS_FIRST, NullOrder.NULLS_FIRST, NullOrder.NULLS_LAST)])
@pytest.mark.parametrize("sort_null_ordering_tuple_3", [(NullOrder.NULLS_FIRST, NullOrder.NULLS_FIRST, NullOrder.NULLS_LAST),(NullOrder.NULLS_FIRST, NullOrder.NULLS_FIRST, NullOrder.NULLS_FIRST)])
@pytest.mark.parametrize("format_version", [1, 2])
def test_query_null_multi_sort(
session_catalog: Catalog,
Expand Down Expand Up @@ -150,4 +150,4 @@ def test_query_null_multi_sort(
assert sorted_df.shape[0] == 3, f"Expected 3 total rows for {sorted_table_identifier}"
assert sorted_df.equals(
query_sorted_df
), f"Expected sorted dataframe for v{format_version} on col: {sort_options_list}, got {sorted_df}"
), f"Expected sorted dataframe for v{format_version} on col: {sort_options_list}, got {sorted_df}"
60 changes: 60 additions & 0 deletions tests/utils/test_arrow_sorting.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint:disable=redefined-outer-name
from typing import List, Tuple

import pyarrow as pa
import pytest

from pyiceberg.utils.arrow_sorting import PyArrowSortOptions, get_sort_indices_arrow_table


@pytest.fixture
def example_arrow_table_for_sort() -> pa.Table:
return pa.table({
"column1": [5, None, 3, 1, 1, None, 3],
"column2": ["b", "a", None, "c", "c", "d", "m"],
"column3": [10.5, None, 5.1, None, 2.5, 7.3, 3.3],
})


@pytest.mark.parametrize(
"sort_keys, expected",
[
(
[
("column1", PyArrowSortOptions("ascending", "at_end")),
("column2", PyArrowSortOptions("ascending", "at_start")),
("column3", PyArrowSortOptions("descending", "at_end")),
],
[4, 3, 2, 6, 0, 1, 5],
)
],
)
def test_get_sort_indices_arrow_table(
example_arrow_table_for_sort: pa.Table, sort_keys: List[Tuple[str, PyArrowSortOptions]], expected: List[int]
) -> None:
sorted_indices = get_sort_indices_arrow_table(example_arrow_table_for_sort, sort_keys)
assert sorted_indices == expected, "Table sort not in expected form"


@pytest.mark.parametrize("sort_keys, expected", [([("column1", PyArrowSortOptions())], [3, 4, 2, 6, 0, 1, 5])])
def test_stability_get_sort_indices_arrow_table(
example_arrow_table_for_sort: pa.Table, sort_keys: List[Tuple[str, PyArrowSortOptions]], expected: pa.Table
) -> None:
sorted_indices = get_sort_indices_arrow_table(example_arrow_table_for_sort, sort_keys)
assert sorted_indices == expected, "Arrow Table sort is not stable"