Merged
Migrate swiglu from incubate
DanielSun11 committed Dec 9, 2025
commit d9ab2b343407f805bbeb02d9c04647e856403cd6
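This commit exposes the existing SwiGLU op under paddle.nn.functional in addition to paddle.incubate.nn.functional. A minimal sketch of the intended call sites after the migration (assuming the incubate alias stays importable, which the updated test below still relies on):

import paddle
from paddle.nn.functional import swiglu                                # new public entry point
from paddle.incubate.nn.functional import swiglu as incubate_swiglu   # previous location

x = paddle.randn([4, 8], dtype='float32')
y = paddle.randn([4, 8], dtype='float32')

# Both entry points compute silu(x) * y and should agree numerically.
assert paddle.allclose(swiglu(x, y), incubate_swiglu(x, y)).item()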
1 change: 1 addition & 0 deletions python/paddle/nn/functional/__init__.py
@@ -46,6 +46,7 @@
softplus,
softshrink,
softsign,
swiglu, # noqa: F401
swish,
tanh,
tanh_,
42 changes: 42 additions & 0 deletions python/paddle/nn/functional/activation.py
@@ -1820,3 +1820,45 @@ def gumbel_softmax(
attrs={'temperature': temperature, 'hard': hard, 'axis': axis},
)
return out


def swiglu(
x: Tensor, y: Tensor | None = None, name: str | None = None
) -> Tensor:
"""
This function applies the SwiGLU activation to the input Tensor.

.. math::

out = silu(x) * y, \quad \text{when } y \text{ is not None}

out = silu(xs[0]) * xs[1], \quad \text{when } y \text{ is None, where } xs = \text{paddle.chunk}(x, 2, \text{axis}=-1)

Args:
x (Tensor): The first input Tensor of SwiGLU.
y (Tensor, optional): The second input Tensor of SwiGLU. Default: None.
name (str, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None.

Returns:
A Tensor with the same data type as x and y.

Examples:
.. code-block:: python

>>> import paddle
>>> import paddle.nn.functional as F
>>> x = paddle.to_tensor([1, 2], dtype='float32')
>>> out1, out2 = F.swiglu(x), F.swiglu(x, x)
>>> print(out1, out2)
Tensor(shape=[1], dtype=float32, place=Place(cpu), stop_gradient=True,
[1.46211720]) Tensor(shape=[2], dtype=float32, place=Place(cpu), stop_gradient=True,
[0.73105860, 3.52318811])
"""
if in_dynamic_or_pir_mode():
return _C_ops.swiglu(x, y)
else:
helper = LayerHelper("swiglu", **locals())
out = helper.create_variable_for_type_inference(dtype=x.dtype)
helper.append_op(
type="swiglu", inputs={"x": x, "y": y}, outputs={"out": out}
)
return out
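For reference, the behaviour documented in the docstring above can be reproduced from existing ops. A minimal sketch (a reference composition, not the fused kernel) that checks both call forms against paddle.nn.functional.silu and paddle.chunk:

import paddle
import paddle.nn.functional as F

def swiglu_reference(x, y=None):
    # When y is omitted, split x along the last axis, as the docstring describes.
    if y is None:
        x, y = paddle.chunk(x, 2, axis=-1)
    return F.silu(x) * y

x = paddle.randn([2, 6], dtype='float32')
y = paddle.randn([2, 6], dtype='float32')
assert paddle.allclose(F.swiglu(x, y), swiglu_reference(x, y)).item()
assert paddle.allclose(
    F.swiglu(paddle.concat([x, y], axis=-1)),
    swiglu_reference(paddle.concat([x, y], axis=-1)),
).item()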
47 changes: 41 additions & 6 deletions test/legacy_test/test_swiglu.py
@@ -32,6 +32,7 @@
TensorDistAttr,
)
from paddle.incubate.nn.functional import swiglu as fused_swiglu_impl
from paddle.nn.functional import swiglu as swiglu_activation


def swiglu(x, y, out_grad):
@@ -71,13 +72,13 @@ def swiglu(x, y, out_grad):
return ret


def fused_swiglu(x, y, out_grad):
def fused_swiglu(x, y, out_grad, swiglu_func=fused_swiglu_impl):
x = x.detach().clone()
x.stop_gradient = False
if y is not None:
y = y.detach().clone()
y.stop_gradient = False
out = fused_swiglu_impl(x, y)
out = swiglu_func(x, y)
out.backward(out_grad)

output_dtype = x.dtype
@@ -106,14 +107,20 @@ def fused_swiglu(x, y, out_grad):


class TestSwiGLUDygraph(unittest.TestCase):
def fused_swiglu(self, x, y, out_grad):
return fused_swiglu(x, y, out_grad)

def fused_swiglu_impl(self, x, y=None):
return fused_swiglu_impl(x, y)

def check_dygraph_impl(self, device, shape, dtype):
x = paddle.randn(shape, dtype=dtype)
y = paddle.randn(shape, dtype=dtype)
out_grad = paddle.randn(shape, dtype=dtype)

ret1 = swiglu(x, y, out_grad)
ret2 = fused_swiglu(x, y, out_grad)
ret3 = fused_swiglu(paddle.concat([x, y], axis=-1), None, out_grad)
ret2 = self.fused_swiglu(x, y, out_grad)
ret3 = self.fused_swiglu(paddle.concat([x, y], axis=-1), None, out_grad)

atol, rtol = tol_map[dtype]
err_msg = (
@@ -152,8 +159,8 @@ def check_static_graph(self, shape, dtype="float32"):
shape=[*shape[:-1], shape[-1] * 2],
dtype=dtype,
)
out1 = fused_swiglu_impl(x, y)
out2 = fused_swiglu_impl(concated_x)
out1 = self.fused_swiglu_impl(x, y)
out2 = self.fused_swiglu_impl(concated_x)

concated_x_np = np.random.random(concated_x.shape).astype(dtype)
x_np, y_np = np.split(concated_x_np, 2, axis=-1)
@@ -179,6 +186,14 @@ def test_main(self):
self.check_main([4, 101])


class TestNNActivationSwiGLUDygraph(TestSwiGLUDygraph):
def fused_swiglu(self, x, y, out_grad):
return fused_swiglu(x, y, out_grad, swiglu_func=swiglu_activation)

def fused_swiglu_impl(self, x, y=None):
return swiglu_activation(x, y)


class TestSwigluOp(OpTest):
def config(self):
self.x_shape = (8, 128)
@@ -215,6 +230,26 @@ def test_check_grad(self):
)


class TestNNActivationSwigluOp(TestSwigluOp):
def setUp(self):
self.config()
self.op_type = "swiglu"
self.prim_op_type = "comp"
self.python_api = swiglu_activation
self.public_python_api = swiglu_activation
x = np.random.uniform(-1, 1, self.x_shape).astype("float64")
y = np.random.uniform(-1, 1, self.x_shape).astype("float64")
out_grad = np.random.uniform(-1, 1, self.x_shape).astype("float64")
res = swiglu(x, y, out_grad)
self.inputs = {'x': x, 'y': y}
self.outputs = {'out': res[0].numpy()}
self.placements = {
'x': [dist.Shard(1)],
'y': [dist.Shard(1)],
'out': [dist.Shard(1)],
}


class TestSwigluOp2(TestSwigluOp):
def setUp(self):
self.config()
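The test refactor above threads the API under test through a swiglu_func keyword so the same forward/backward checks cover both entry points. A stripped-down sketch of that pattern outside the Paddle test harness (run_swiglu is a hypothetical stand-in for the fused_swiglu helper in test_swiglu.py):

import paddle
from paddle.incubate.nn.functional import swiglu as fused_swiglu_impl
from paddle.nn.functional import swiglu as swiglu_activation

def run_swiglu(x, y, out_grad, swiglu_func=fused_swiglu_impl):
    # Mirror fused_swiglu: run forward and backward through the chosen
    # entry point and collect the output and both input gradients.
    x = x.detach().clone()
    x.stop_gradient = False
    y = y.detach().clone()
    y.stop_gradient = False
    out = swiglu_func(x, y)
    out.backward(out_grad)
    return out, x.grad, y.grad

x, y, g = (paddle.randn([4, 16], dtype='float32') for _ in range(3))
for old, new in zip(run_swiglu(x, y, g),
                    run_swiglu(x, y, g, swiglu_func=swiglu_activation)):
    assert paddle.allclose(old, new).item()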