Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
fix hash function locations
  • Loading branch information
jqin61 committed Apr 5, 2024
commit d64654882e66bd97723e4e20db6509a139617a95
24 changes: 14 additions & 10 deletions pyiceberg/expressions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,6 +351,11 @@ def __eq__(self, other: Any) -> bool:
@abstractmethod
def as_unbound(self) -> Type[UnboundPredicate[Any]]: ...


def __hash__(self) -> int:
"""Return hash value of the BoundPredicate class.

The value is the hash of the predicate's string form, ``str(self)``.
"""
return hash(str(self))


class UnboundPredicate(Generic[L], Unbound[BooleanExpression], BooleanExpression, ABC):
term: UnboundTerm[Any]
Expand All @@ -369,6 +374,10 @@ def bind(self, schema: Schema, case_sensitive: bool = True) -> BooleanExpression
@abstractmethod
def as_bound(self) -> Type[BoundPredicate[L]]: ...

def __hash__(self) -> int:
"""Return hash value of the UnboundPredicate class.

The value is the hash of the predicate's string form, ``str(self)``.
"""
return hash(str(self))


class UnaryPredicate(UnboundPredicate[Any], ABC):
def bind(self, schema: Schema, case_sensitive: bool = True) -> BoundUnaryPredicate[Any]:
Expand All @@ -383,9 +392,7 @@ def __repr__(self) -> str:
@abstractmethod
def as_bound(self) -> Type[BoundUnaryPredicate[Any]]: ...

def __hash__(self) -> int:
"""Return hash value of the UnaryPredicate class."""
return hash(str(self))



class BoundUnaryPredicate(BoundPredicate[L], ABC):
Expand Down Expand Up @@ -416,10 +423,6 @@ def __invert__(self) -> BoundNotNull[L]:
def as_unbound(self) -> Type[IsNull]:
return IsNull

def __hash__(self) -> int:
"""Return hash value of the BoundIsNull class."""
return hash(str(self))


class BoundNotNull(BoundUnaryPredicate[L]):
def __new__(cls, term: BoundTerm[L]): # type: ignore # pylint: disable=W0221
Expand Down Expand Up @@ -733,6 +736,10 @@ def __repr__(self) -> str:
@abstractmethod
def as_unbound(self) -> Type[LiteralPredicate[L]]: ...

def __hash__(self) -> int:
"""Return hash value of the BoundLiteralPredicate class.

The value is the hash of the predicate's string form, ``str(self)``.
"""
return hash(str(self))


class BoundEqualTo(BoundLiteralPredicate[L]):
def __invert__(self) -> BoundNotEqualTo[L]:
Expand All @@ -743,9 +750,6 @@ def __invert__(self) -> BoundNotEqualTo[L]:
def as_unbound(self) -> Type[EqualTo[L]]:
return EqualTo

def __hash__(self) -> int:
"""Return hash value of the BoundEqualTo class."""
return hash(str(self))


class BoundNotEqualTo(BoundLiteralPredicate[L]):
Expand Down
40 changes: 20 additions & 20 deletions tests/table/test_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import re
import uuid
from copy import copy
from typing import Any, Dict, Set, Union
from typing import Any, Dict, Set, Union, List

import pyarrow as pa
import pytest
Expand Down Expand Up @@ -1349,51 +1349,51 @@ def test__bind_and_validate_static_overwrite_filter_predicate_fails_to_bind_due_


@pytest.mark.parametrize(
"pred, raises, is_null_preds, eq_to_preds",
"expr, raises, is_null_preds, eq_to_preds",
[
(EqualTo(Reference("foo"), "hello"), False, {}, {EqualTo(Reference("foo"), "hello")}),
(IsNull(Reference("foo")), False, {IsNull(Reference("foo"))}, {}),
(EqualTo(Reference("foo"), "hello"), False, [], [EqualTo(Reference("foo"), "hello")]),
(IsNull(Reference("foo")), False, [IsNull(Reference("foo"))], []),
(
And(IsNull(Reference("foo")), EqualTo(Reference("boo"), "hello")),
False,
{IsNull(Reference("foo"))},
{EqualTo(Reference("boo"), "hello")},
[IsNull(Reference("foo"))],
[EqualTo(Reference("boo"), "hello")],
),
(NotNull, True, {}, {}),
(NotEqualTo, True, {}, {}),
(LessThan(Reference("foo"), 5), True, {}, {}),
(Or(IsNull(Reference("foo")), EqualTo(Reference("foo"), "hello")), True, {}, {}),
(NotNull, True, [], []),
(NotEqualTo, True, [], []),
(LessThan(Reference("foo"), 5), True, [], []),
(Or(IsNull(Reference("foo")), EqualTo(Reference("foo"), "hello")), True, [], []),
(
And(EqualTo(Reference("foo"), "hello"), And(IsNull(Reference("baz")), EqualTo(Reference("boo"), "hello"))),
False,
{IsNull(Reference("baz"))},
{EqualTo(Reference("foo"), "hello"), EqualTo(Reference("boo"), "hello")},
[IsNull(Reference("baz"))],
[EqualTo(Reference("foo"), "hello"), EqualTo(Reference("boo"), "hello")],
),
# Below are crowd-crush tests: a given field may only be matched against one and the same literal, or only against null - not against different literals, and not against both a literal and null.
# A false crush: when IsNull/EqualTo predicates are duplicated, the collector should deduplicate them.
(
And(EqualTo(Reference("foo"), "hello"), EqualTo(Reference("foo"), "hello")),
False,
{},
{EqualTo(Reference("foo"), "hello")},
[],
[EqualTo(Reference("foo"), "hello")],
),
# When crush happens
(
And(EqualTo(Reference("foo"), "hello"), EqualTo(Reference("foo"), "bye")),
True,
{},
{EqualTo(Reference("foo"), "hello"), EqualTo(Reference("foo"), "bye")},
[],
[EqualTo(Reference("foo"), "hello"), EqualTo(Reference("foo"), "bye")],
),
(And(EqualTo(Reference("foo"), "hello"), IsNull(Reference("foo"))), True, {IsNull(Reference("foo"))}, {}),
(And(EqualTo(Reference("foo"), "hello"), IsNull(Reference("foo"))), True, [IsNull(Reference("foo"))], []),
],
)
def test__validate_static_overwrite_filter_expr_type(
pred: Union[IsNull, EqualTo[Any]], raises: bool, is_null_preds: Set[IsNull], eq_to_preds: Set[EqualTo[L]]
expr: Union[IsNull, EqualTo[Any]], raises: bool, is_null_preds: List[IsNull], eq_to_preds: List[EqualTo[L]]
) -> None:
if raises:
with pytest.raises(ValueError):
res = _validate_static_overwrite_filter_expr_type(pred)
res = _validate_static_overwrite_filter_expr_type(expr)
else:
res = _validate_static_overwrite_filter_expr_type(pred)
res = _validate_static_overwrite_filter_expr_type(expr)
assert {str(e) for e in res[0]} == {str(e) for e in is_null_preds}
assert {str(e) for e in res[1]} == {str(e) for e in eq_to_preds}