Skip to content
Merged
114 changes: 113 additions & 1 deletion python/paddle/compat/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@

from __future__ import annotations

from typing import TYPE_CHECKING, Any, NamedTuple
from typing import TYPE_CHECKING, Any, Literal, NamedTuple

from typing_extensions import overload

import paddle
from paddle import _C_ops
Expand Down Expand Up @@ -47,6 +49,7 @@
'split',
'min',
'max',
'unique',
'median',
'nanmedian',
'seed',
Expand Down Expand Up @@ -837,6 +840,115 @@ def sort(
return SortRetType(values=outputs, indices=indices)


@overload
def unique(
input: Tensor,
sorted: bool = ...,
return_inverse: Literal[True] = ...,
return_counts: Literal[True] = ...,
dim: int | None = ...,
) -> tuple[Tensor, Tensor, Tensor]: ...


@overload
def unique(
input: Tensor,
sorted: bool = ...,
return_inverse: Literal[False] = ...,
return_counts: Literal[True] = ...,
dim: int | None = ...,
) -> tuple[Tensor, Tensor]: ...


@overload
def unique(
input: Tensor,
sorted: bool = ...,
return_inverse: Literal[True] = ...,
return_counts: Literal[False] = ...,
dim: int | None = ...,
) -> tuple[Tensor, Tensor]: ...


@overload
def unique(
input: Tensor,
sorted: bool = ...,
return_inverse: Literal[False] = ...,
return_counts: Literal[False] = ...,
dim: int | None = ...,
) -> Tensor: ...


@ForbidKeywordsDecorator(
illegal_keys={"x", "axis"},
func_name="paddle.compat.unique",
correct_name="paddle.unique",
)
def unique(
input,
sorted=True,
return_inverse=False,
return_counts=False,
dim=None,
):
r"""
Returns the unique elements of `input` in ascending order.

Args:
input(Tensor): The input tensor, it's data type should be float32, float64, int32, int64.
sorted(bool, optional): Does not affect the return result, same as PyTorch.
return_inverse(bool, optional): If True, also return the indices for where elements in
the original input ended up in the returned unique tensor.
return_counts(bool, optional): If True, also return the counts for each unique element.
dim(int, optional): The axis to apply unique. If None, the input will be flattened.
Default: None.

Returns:
tuple (output, inverse_indices, counts). `output` is the unique tensor for `input`. \
`inverse_indices` is provided only if `return_inverse` \
is True. `counts` is provided only if `return_counts` is True.

Examples:
.. code-block:: python

>>> import paddle

>>> x = paddle.to_tensor([2, 3, 3, 1, 5, 3])
>>> unique = paddle.compat.unique(x)
>>> print(unique)
Tensor(shape=[4], dtype=int64, place=Place(cpu), stop_gradient=True,
[1, 2, 3, 5])

>>> _, inverse_indices, counts = paddle.compat.unique(x, return_inverse=True, return_counts=True)
>>> print(inverse_indices)
Tensor(shape=[6], dtype=int64, place=Place(cpu), stop_gradient=True,
[1, 2, 2, 0, 3, 2])
>>> print(counts)
Tensor(shape=[4], dtype=int64, place=Place(cpu), stop_gradient=True,
[1, 1, 3, 1])

>>> x = paddle.to_tensor([[2, 1, 3], [3, 0, 1], [2, 1, 3]])
>>> unique = paddle.compat.unique(x)
>>> print(unique)
Tensor(shape=[4], dtype=int64, place=Place(cpu), stop_gradient=True,
[0, 1, 2, 3])

>>> unique = paddle.compat.unique(x, dim=0)
>>> print(unique)
Tensor(shape=[2, 3], dtype=int64, place=Place(cpu), stop_gradient=True,
[[2, 1, 3],
[3, 0, 1]])
"""
return paddle.unique(
input,
return_inverse=return_inverse,
return_counts=return_counts,
axis=dim,
sorted=sorted,
)


@ForbidKeywordsDecorator(
illegal_keys={"x", "num_or_sections", "axis", "name"},
func_name="paddle.compat.split",
Expand Down
11 changes: 11 additions & 0 deletions python/paddle/tensor/manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -3672,6 +3672,7 @@ def unique(
return_counts: Literal[True] = ...,
axis: int | None = ...,
dtype: DTypeLike = ...,
sorted: bool = ...,
name: str | None = ...,
) -> tuple[Tensor, Tensor, Tensor, Tensor]: ...

Expand All @@ -3684,6 +3685,7 @@ def unique(
return_counts: Literal[True] = ...,
axis: int | None = ...,
dtype: DTypeLike = ...,
sorted: bool = ...,
name: str | None = ...,
) -> tuple[Tensor, Tensor, Tensor]: ...

Expand All @@ -3696,6 +3698,7 @@ def unique(
return_counts: Literal[True] = ...,
axis: int | None = ...,
dtype: DTypeLike = ...,
sorted: bool = ...,
name: str | None = ...,
) -> tuple[Tensor, Tensor, Tensor]: ...

Expand All @@ -3708,6 +3711,7 @@ def unique(
return_counts: Literal[False] = ...,
axis: int | None = ...,
dtype: DTypeLike = ...,
sorted: bool = ...,
name: str | None = ...,
) -> tuple[Tensor, Tensor, Tensor]: ...

Expand All @@ -3720,6 +3724,7 @@ def unique(
return_counts: Literal[True] = ...,
axis: int | None = ...,
dtype: DTypeLike = ...,
sorted: bool = ...,
name: str | None = ...,
) -> tuple[Tensor, Tensor]: ...

Expand All @@ -3732,6 +3737,7 @@ def unique(
return_counts: Literal[False] = ...,
axis: int | None = ...,
dtype: DTypeLike = ...,
sorted: bool = ...,
name: str | None = ...,
) -> tuple[Tensor, Tensor]: ...

Expand All @@ -3744,6 +3750,7 @@ def unique(
return_counts: Literal[False] = ...,
axis: int | None = ...,
dtype: DTypeLike = ...,
sorted: bool = ...,
name: str | None = ...,
) -> tuple[Tensor, Tensor]: ...

Expand All @@ -3756,6 +3763,7 @@ def unique(
return_counts: Literal[False] = ...,
axis: int | None = ...,
dtype: DTypeLike = ...,
sorted: bool = ...,
name: str | None = ...,
) -> Tensor: ...

Expand All @@ -3768,6 +3776,7 @@ def unique(
return_counts: bool = False,
axis: int | None = ...,
dtype: DTypeLike = ...,
sorted: bool = ...,
name: str | None = ...,
) -> Tensor | tuple[Tensor, ...]: ...

Expand All @@ -3779,6 +3788,7 @@ def unique(
return_counts=False,
axis=None,
dtype="int64",
sorted=True,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个目前应该无法实现 API完全一致吧,参数顺序不同。

group_norm也有这个问题。不过这个没法判断数据类型了,只能新增加一个paddle.compat.unique了?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

好的那增加paddle.compat.unique

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@zhwesky2010 已增加paddle.compat.unique CI完成,PaConvert PaddlePaddle/PaConvert#758 修改测试通过

name=None,
):
r"""
Expand All @@ -3795,6 +3805,7 @@ def unique(
Default: None.
dtype(str|paddle.dtype|np.dtype, optional): The date type of `indices` or `inverse` tensor: int32 or int64.
Default: int64.
sorted(bool, optional): Does not affect the return result, same as PyTorch.
name(str|None, optional): Name for the operation. For more information, please refer to
:ref:`api_guide_Name`. Default: None.

Expand Down
85 changes: 85 additions & 0 deletions test/legacy_test/test_compat_unique.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np

import paddle
from paddle import base


class TestCompatUniqueAPI(unittest.TestCase):
def test_basic(self):
paddle.disable_static()
x = paddle.to_tensor([2, 3, 3, 1, 5, 3])
result = paddle.compat.unique(x)
expected = paddle.to_tensor([1, 2, 3, 5], dtype='int64')
np.testing.assert_allclose(result.numpy(), expected.numpy())

_, inverse_indices, counts = paddle.compat.unique(
x, return_inverse=True, return_counts=True
)
expected_indices = paddle.to_tensor([1, 2, 2, 0, 3, 2], dtype='int64')
expected_counts = paddle.to_tensor([1, 1, 3, 1], dtype='int64')
np.testing.assert_allclose(
inverse_indices.numpy(), expected_indices.numpy()
)
np.testing.assert_allclose(counts.numpy(), expected_counts.numpy())

x = paddle.to_tensor([[2, 1, 3], [3, 0, 1], [2, 1, 3]])
result = paddle.compat.unique(x)
expected = paddle.to_tensor([0, 1, 2, 3], dtype='int64')
np.testing.assert_allclose(result.numpy(), expected.numpy())
paddle.enable_static()

def test_static(self):
paddle.enable_static()

with paddle.static.program_guard(
paddle.static.Program(), paddle.static.Program()
):
x = paddle.static.data(name='input', shape=[6], dtype='int64')
out, inverse_indices, counts = paddle.compat.unique(
x, return_inverse=True, return_counts=True
)

exe = base.Executor(base.CPUPlace())
x_data = np.array([2, 3, 3, 1, 5, 3], dtype='int64')
result = exe.run(
feed={'input': x_data},
fetch_list=[out, inverse_indices, counts],
)

np.testing.assert_allclose(result[1], [1, 2, 2, 0, 3, 2])
np.testing.assert_allclose(result[2], [1, 1, 3, 1])

with paddle.static.program_guard(
paddle.static.Program(), paddle.static.Program()
):
x = paddle.static.data(name='input', shape=[3, 3], dtype='int64')
out = paddle.compat.unique(x)

exe = base.Executor(base.CPUPlace())
x_data = np.array([[2, 1, 3], [3, 0, 1], [2, 1, 3]], dtype='int64')
result = exe.run(feed={'input': x_data}, fetch_list=[out])

expected = np.array([0, 1, 2, 3], dtype='int64')
np.testing.assert_allclose(result[0], expected)

paddle.disable_static()


if __name__ == '__main__':
unittest.main()
45 changes: 45 additions & 0 deletions test/legacy_test/test_unique.py
Original file line number Diff line number Diff line change
Expand Up @@ -500,5 +500,50 @@ def test_dygraph_api_out(self):
np.testing.assert_allclose(out.numpy(), expected_out)


class TestUniqueAPI_Compatibility(unittest.TestCase):
def setUp(self):
self.x_np = np.random.random(size=[3, 5]).astype("float32")
self.place = (
core.CUDAPlace(0)
if core.is_compiled_with_cuda()
else core.CPUPlace()
)

def test_dygraph(self):
paddle.disable_static()
out = paddle.unique(paddle.to_tensor(self.x_np))
expected_out = np.unique(self.x_np)
np.testing.assert_allclose(out.numpy(), expected_out)

def test_static(self):
paddle.enable_static()
x = paddle.static.data(name='x1', shape=[-1, 5], dtype='float32')
out1 = paddle.unique(x)
out2 = paddle.unique(x=x)
exe = paddle.static.Executor(self.place)
res = exe.run(
feed={
'x1': self.x_np.reshape(3, 5),
},
fetch_list=[out1, out2],
)
expected_out = np.unique(self.x_np)
for result in res:
np.testing.assert_array_equal(result, expected_out)
paddle.disable_static()

def test_dygraph_sorted(self):
paddle.disable_static()
out = paddle.unique(paddle.to_tensor(self.x_np), sorted=True)
expected_out = np.unique(self.x_np)
np.testing.assert_allclose(out.numpy(), expected_out)

def test_dygraph_axis(self):
paddle.disable_static()
out = paddle.unique(paddle.to_tensor(self.x_np), sorted=True, axis=1)
expected_out = np.unique(self.x_np, axis=1)
np.testing.assert_allclose(out.numpy(), expected_out)


if __name__ == "__main__":
unittest.main()
Loading