Merged
2 changes: 1 addition & 1 deletion python/paddle/distributed/auto_parallel/high_level_api.py
@@ -479,7 +479,7 @@ def to_distributed(
... )

... def forward(self, x):
... x = paddle.incubate.nn.functional.swiglu(
... x = paddle.nn.functional.swiglu(
... self.gate_proj(x), self.up_proj(x)
... )
... out = self.down_proj(x)
2 changes: 1 addition & 1 deletion python/paddle/distributed/auto_parallel/static/pir_pass.py
@@ -1592,7 +1592,7 @@ def fuse_attention_ffn_qkv_pass(
if add_gate is not None and add_up is not None:
fused_o = paddle.add(fused_o, fused_bias)
fused_o.get_defining_op().copy_attrs_from(add_gate)
out = paddle.incubate.nn.functional.swiglu(fused_o)
out = paddle.nn.functional.swiglu(fused_o)
out.get_defining_op().copy_attrs_from(pat[-1])
pat[-1].result(0).replace_all_uses_with(out)

@@ -756,7 +756,7 @@ def build(self):
def apply(hidden_states, gate_weight, up_weight, down_weight):
gate = paddle.matmul(hidden_states, gate_weight)
up = paddle.matmul(hidden_states, up_weight)
tmp = paddle.incubate.nn.functional.swiglu(gate, up)
tmp = paddle.nn.functional.swiglu(gate, up)
out = paddle.matmul(tmp, down_weight)
return out

3 changes: 2 additions & 1 deletion python/paddle/incubate/nn/functional/__init__.py
@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.


from .batched_gemm import batched_gemm
from .blha_get_max_len import blha_get_max_len
from .block_multihead_attention import (
@@ -109,12 +110,12 @@
"masked_multihead_attention",
"blha_get_max_len",
"block_multihead_attention",
"swiglu",
"moe_combine",
"expand_modality_expert_id",
"cal_aux_loss",
"build_src_rank_and_local_expert_id",
"int_bincount",
"swiglu",
Contributor: Since paddle.incubate.nn.functional.swiglu is no longer recommended going forward, it can be removed from the __all__ list here.

Contributor Author: Done.

"fused_rms_norm_ext",
"moe_gate_dispatch",
"moe_gate_dispatch_permute",
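The review thread above suggests dropping "swiglu" from this module's __all__. A small illustration of what that change means, assuming a build where the suggestion has been applied (everything else below is standard Python wildcard-import semantics):

```python
# Illustration only: assumes "swiglu" has been dropped from
# paddle.incubate.nn.functional.__all__, as the reviewer suggests.
import paddle.incubate.nn.functional as incubate_F

# Wildcard imports follow __all__, so `from paddle.incubate.nn.functional import *`
# would no longer bring swiglu into scope once it is removed from the list ...
print("swiglu" in incubate_F.__all__)  # expected: False after the suggested change

# ... while the attribute itself stays importable during the deprecation window.
print(callable(incubate_F.swiglu))  # expected: True
```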
16 changes: 8 additions & 8 deletions python/paddle/incubate/nn/functional/swiglu.py
@@ -16,13 +16,20 @@
from typing import TYPE_CHECKING

from paddle import _C_ops
from paddle.utils import deprecated

from ....framework import LayerHelper, in_dynamic_or_pir_mode
from ....framework import in_dynamic_or_pir_mode

if TYPE_CHECKING:
from paddle import Tensor


@deprecated(
since="3.3.0",
update_to="paddle.nn.functional.swiglu",
level=1,
reason="paddle.incubate.nn.functional.swiglu will be removed in future. Please use paddle.nn.functional.swiglu instead.",
)
def swiglu(
x: Tensor, y: Tensor | None = None, name: str | None = None
) -> Tensor:
@@ -56,10 +63,3 @@ def swiglu(
"""
if in_dynamic_or_pir_mode():
return _C_ops.swiglu(x, y)
else:
helper = LayerHelper("swiglu", **locals())
out = helper.create_variable_for_type_inference(dtype=x.dtype)
helper.append_op(
type="swiglu", inputs={"x": x, "y": y}, outputs={"out": out}
)
return out
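For context, a minimal sketch of how the deprecation surfaces to callers on a build that includes this change; the exact warning category and message text come from paddle.utils.deprecated, so treat them as an assumption here:

```python
import warnings

import paddle

x = paddle.randn([4, 8], dtype="float32")

# Recommended entry point after this change: no deprecation warning.
out_new = paddle.nn.functional.swiglu(x)

# Legacy entry point: the @deprecated decorator should emit a warning on use.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    out_old = paddle.incubate.nn.functional.swiglu(x)

print(any("swiglu" in str(w.message) for w in caught))  # expected: True
print(paddle.allclose(out_new, out_old).item())  # both paths call the same kernel
```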
2 changes: 2 additions & 0 deletions python/paddle/nn/functional/__init__.py
@@ -46,6 +46,7 @@
softplus,
softshrink,
softsign,
swiglu,
swish,
tanh,
tanh_,
@@ -207,6 +208,7 @@
'softsign',
'sigmoid',
'silu',
'swiglu',
'swish',
'mish',
'tanh',
35 changes: 35 additions & 0 deletions python/paddle/nn/functional/activation.py
@@ -1820,3 +1820,38 @@ def gumbel_softmax(
attrs={'temperature': temperature, 'hard': hard, 'axis': axis},
)
return out


def swiglu(
x: Tensor, y: Tensor | None = None, name: str | None = None
) -> Tensor:
"""
This function performs SwiGLU activation to the input Tensor.

.. math::

out = silu(x) * y when y is not None
out = silu(xs[0]) * xs[1] when y is None, where xs = paddle.chunk(x, 2, axis=-1)

Args:
x (Tensor): The first input Tensor of SwiGLU.
y (Tensor, optional): The second input Tensor of SwiGLU. Default: None.
name (str, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None.

Returns:
A Tensor with the same data type with x and y.

Examples:
.. code-block:: python

>>> import paddle
>>> import paddle.nn.functional as F
>>> x = paddle.to_tensor([1, 2], dtype='float32')
>>> out1, out2 = F.swiglu(x), F.swiglu(x, x)
>>> print(out1, out2)
Tensor(shape=[1], dtype=float32, place=Place(cpu), stop_gradient=True,
[1.46211720]) Tensor(shape=[2], dtype=float32, place=Place(cpu), stop_gradient=True,
[0.73105860, 3.52318811])
"""
if in_dynamic_or_pir_mode():
return _C_ops.swiglu(x, y)
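As a cross-check of the documented semantics, here is a small NumPy reference sketch (not Paddle's kernel) that reproduces the docstring example above:

```python
import numpy as np


def swiglu_reference(x, y=None):
    # silu(x) * y; when y is None, split x in half along the last axis first.
    if y is None:
        x, y = np.split(x, 2, axis=-1)
    return x / (1.0 + np.exp(-x)) * y  # silu(x) = x * sigmoid(x)


x = np.array([1.0, 2.0], dtype="float32")
print(swiglu_reference(x))     # ~[1.4621172]
print(swiglu_reference(x, x))  # ~[0.7310586, 3.5231883]
```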
@@ -233,9 +233,7 @@ def __init__(
)

def forward(self, x):
x = paddle.incubate.nn.functional.swiglu(
self.gate_proj(x), self.up_proj(x)
)
x = paddle.nn.functional.swiglu(self.gate_proj(x), self.up_proj(x))
out = self.down_proj(x)
return out

2 changes: 1 addition & 1 deletion test/ir/pir/cinn/llama_test_model.py
@@ -21,7 +21,7 @@
import paddle
import paddle.nn.functional as F
from paddle import nn
from paddle.incubate.nn.functional import swiglu
from paddle.nn.functional import swiglu

sys.path.append(dirname(__file__))

2 changes: 1 addition & 1 deletion test/ir/pir/cinn/symbolic/test_llama_group_swiglu.py
@@ -35,7 +35,7 @@ def __init__(self):
super().__init__()

def forward(self, x, y):
out = paddle.incubate.nn.functional.swiglu(x, y)
out = paddle.nn.functional.swiglu(x, y)

return out

2 changes: 1 addition & 1 deletion test/legacy_test/test_fused_swiglu_weighted_bwd_op.py
@@ -17,7 +17,7 @@
import numpy as np

import paddle
from paddle.incubate.nn.functional import swiglu
from paddle.nn.functional import swiglu


class TestFusedWeightedSwigluBwd(unittest.TestCase):
20 changes: 13 additions & 7 deletions test/legacy_test/test_swiglu.py
@@ -31,7 +31,7 @@
DistTensorSpec,
TensorDistAttr,
)
from paddle.incubate.nn.functional import swiglu as fused_swiglu_impl
from paddle.nn.functional import swiglu as fused_swiglu_impl


def swiglu(x, y, out_grad):
@@ -71,13 +71,13 @@ def swiglu(x, y, out_grad):
return ret


def fused_swiglu(x, y, out_grad):
def fused_swiglu(x, y, out_grad, swiglu_func=fused_swiglu_impl):
x = x.detach().clone()
x.stop_gradient = False
if y is not None:
y = y.detach().clone()
y.stop_gradient = False
out = fused_swiglu_impl(x, y)
out = swiglu_func(x, y)
out.backward(out_grad)

output_dtype = x.dtype
@@ -106,14 +106,20 @@ def fused_swiglu(x, y, out_grad):


class TestSwiGLUDygraph(unittest.TestCase):
def fused_swiglu(self, x, y, out_grad):
return fused_swiglu(x, y, out_grad)

def fused_swiglu_impl(self, x, y=None):
return fused_swiglu_impl(x, y)

def check_dygraph_impl(self, device, shape, dtype):
x = paddle.randn(shape, dtype=dtype)
y = paddle.randn(shape, dtype=dtype)
out_grad = paddle.randn(shape, dtype=dtype)

ret1 = swiglu(x, y, out_grad)
ret2 = fused_swiglu(x, y, out_grad)
ret3 = fused_swiglu(paddle.concat([x, y], axis=-1), None, out_grad)
ret2 = self.fused_swiglu(x, y, out_grad)
ret3 = self.fused_swiglu(paddle.concat([x, y], axis=-1), None, out_grad)

atol, rtol = tol_map[dtype]
err_msg = (
@@ -152,8 +158,8 @@ def check_static_graph(self, shape, dtype="float32"):
shape=[*shape[:-1], shape[-1] * 2],
dtype=dtype,
)
out1 = fused_swiglu_impl(x, y)
out2 = fused_swiglu_impl(concated_x)
out1 = self.fused_swiglu_impl(x, y)
out2 = self.fused_swiglu_impl(concated_x)

concated_x_np = np.random.random(concated_x.shape).astype(dtype)
x_np, y_np = np.split(concated_x_np, 2, axis=-1)
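The new swiglu_func argument and the overridable fused_swiglu / fused_swiglu_impl hooks let the same checks run against another implementation. A hypothetical subclass (not part of this diff) that would exercise the deprecated incubate entry point could look like this:

```python
import paddle


class TestSwiGLUDygraphIncubate(TestSwiGLUDygraph):
    # Route every call through the deprecated incubate API instead of
    # paddle.nn.functional.swiglu; all inherited checks stay unchanged.
    def fused_swiglu(self, x, y, out_grad):
        return fused_swiglu(
            x, y, out_grad, swiglu_func=paddle.incubate.nn.functional.swiglu
        )

    def fused_swiglu_impl(self, x, y=None):
        return paddle.incubate.nn.functional.swiglu(x, y)
```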
4 changes: 2 additions & 2 deletions test/prim/pir_prim/test_prim_sub_graph_dynamic_shape.py
@@ -120,11 +120,11 @@ def batch_norm_net4(x, y, z):


def swiglu_net1(x, y):
return paddle.incubate.nn.functional.swiglu(x, y)
return paddle.nn.functional.swiglu(x, y)


def swiglu_net2(x):
return paddle.incubate.nn.functional.swiglu(x)
return paddle.nn.functional.swiglu(x)


def squared_l2_norm_net(x):
@@ -159,11 +159,11 @@ def sum_net5(x):


def swiglu_net1(x, y):
return paddle.incubate.nn.functional.swiglu(x, y)
return paddle.nn.functional.swiglu(x, y)


def swiglu_net2(x):
return paddle.incubate.nn.functional.swiglu(x)
return paddle.nn.functional.swiglu(x)


def swish_net(x):
2 changes: 1 addition & 1 deletion test/xpu/test_swiglu_op_xpu.py
@@ -18,7 +18,7 @@

import paddle
import paddle.nn.functional as F
from paddle.incubate.nn.functional import swiglu as fused_swiglu_impl
from paddle.nn.functional import swiglu as fused_swiglu_impl


def swiglu(x, y, out_grad):