python/paddle/distributed/auto_parallel/high_level_api.py (1 addition, 1 deletion)
@@ -479,7 +479,7 @@ def to_distributed(
... )

... def forward(self, x):
- ... x = paddle.incubate.nn.functional.swiglu(
+ ... x = paddle.nn.functional.swiglu(
... self.gate_proj(x), self.up_proj(x)
... )
... out = self.down_proj(x)
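The hunk above only swaps the import path in the docstring's example MLP. For reference, here is a minimal runnable sketch of that SwiGLU feed-forward block written against the new paddle.nn.functional.swiglu entry point; the SwiGLUFFN name and the layer sizes are illustrative and not part of the diff:

    import paddle
    import paddle.nn.functional as F
    from paddle import nn

    # Illustrative SwiGLU feed-forward block (sizes made up), mirroring the
    # gate_proj / up_proj / down_proj pattern shown in the docstring above.
    class SwiGLUFFN(nn.Layer):
        def __init__(self, hidden_size=64, intermediate_size=256):
            super().__init__()
            self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias_attr=False)
            self.up_proj = nn.Linear(hidden_size, intermediate_size, bias_attr=False)
            self.down_proj = nn.Linear(intermediate_size, hidden_size, bias_attr=False)

        def forward(self, x):
            x = F.swiglu(self.gate_proj(x), self.up_proj(x))
            return self.down_proj(x)

    out = SwiGLUFFN()(paddle.randn([2, 8, 64]))
    print(out.shape)  # [2, 8, 64]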
python/paddle/distributed/auto_parallel/static/pir_pass.py (1 addition, 1 deletion)
@@ -1592,7 +1592,7 @@ def fuse_attention_ffn_qkv_pass(
if add_gate is not None and add_up is not None:
fused_o = paddle.add(fused_o, fused_bias)
fused_o.get_defining_op().copy_attrs_from(add_gate)
- out = paddle.incubate.nn.functional.swiglu(fused_o)
+ out = paddle.nn.functional.swiglu(fused_o)
out.get_defining_op().copy_attrs_from(pat[-1])
pat[-1].result(0).replace_all_uses_with(out)

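The pass can call swiglu with a single tensor because fusing gate_proj and up_proj amounts to concatenating their weight matrices along the output dimension, and the one-tensor form splits that concatenated activation back in half. A small sketch of the equivalence the rewrite relies on (tensor names and shapes here are illustrative assumptions):

    import paddle
    import paddle.nn.functional as F

    hidden = paddle.randn([4, 16])
    gate_w = paddle.randn([16, 32])
    up_w = paddle.randn([16, 32])

    # Unfused reference: two matmuls, then the two-tensor swiglu form.
    ref = F.swiglu(paddle.matmul(hidden, gate_w), paddle.matmul(hidden, up_w))

    # Fused form: one matmul against the concatenated weight, then the
    # one-tensor swiglu form, which chunks the last axis in half.
    fused = F.swiglu(paddle.matmul(hidden, paddle.concat([gate_w, up_w], axis=-1)))

    assert paddle.allclose(ref, fused, rtol=1e-4, atol=1e-6)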
@@ -756,7 +756,7 @@ def build(self):
def apply(hidden_states, gate_weight, up_weight, down_weight):
gate = paddle.matmul(hidden_states, gate_weight)
up = paddle.matmul(hidden_states, up_weight)
- tmp = paddle.incubate.nn.functional.swiglu(gate, up)
+ tmp = paddle.nn.functional.swiglu(gate, up)
out = paddle.matmul(tmp, down_weight)
return out

python/paddle/incubate/nn/functional/__init__.py (1 addition, 1 deletion)
@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.


from .batched_gemm import batched_gemm
from .blha_get_max_len import blha_get_max_len
from .block_multihead_attention import (
@@ -109,7 +110,6 @@
"masked_multihead_attention",
"blha_get_max_len",
"block_multihead_attention",
"swiglu",
"moe_combine",
"expand_modality_expert_id",
"cal_aux_loss",
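Dropping "swiglu" from __all__ only removes it from the package's wildcard exports; assuming the module-level import in this __init__ is kept (the deprecated wrapper in swiglu.py below still exists), the symbol remains importable directly:

    import paddle

    # No longer part of the advertised public surface of the incubate package...
    print("swiglu" in paddle.incubate.nn.functional.__all__)  # expected: False

    # ...but still reachable directly; calling it goes through the deprecated
    # wrapper shown in swiglu.py below.
    from paddle.incubate.nn.functional import swiglu as legacy_swiglu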
python/paddle/incubate/nn/functional/swiglu.py (8 additions, 8 deletions)
@@ -16,13 +16,20 @@
from typing import TYPE_CHECKING

from paddle import _C_ops
+ from paddle.utils import deprecated

- from ....framework import LayerHelper, in_dynamic_or_pir_mode
+ from ....framework import in_dynamic_or_pir_mode

if TYPE_CHECKING:
from paddle import Tensor


+ @deprecated(
+ since="3.3.0",
+ update_to="paddle.nn.functional.swiglu",
+ level=1,
+ reason="paddle.incubate.nn.functional.swiglu will be removed in future. Please use paddle.nn.functional.swiglu instead.",
+ )
def swiglu(
x: Tensor, y: Tensor | None = None, name: str | None = None
) -> Tensor:
@@ -56,10 +63,3 @@ def swiglu(
"""
if in_dynamic_or_pir_mode():
return _C_ops.swiglu(x, y)
- else:
- helper = LayerHelper("swiglu", **locals())
- out = helper.create_variable_for_type_inference(dtype=x.dtype)
- helper.append_op(
- type="swiglu", inputs={"x": x, "y": y}, outputs={"out": out}
- )
- return out
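Net effect of this file: the old static-graph LayerHelper branch is removed and the incubate entry point becomes a deprecated alias for the same kernel. A minimal sketch of what a caller should now observe, assuming the @deprecated wrapper with level=1 emits a standard Python warning on use rather than raising:

    import warnings

    import paddle
    import paddle.nn.functional as F

    x = paddle.randn([4, 8])
    y = paddle.randn([4, 8])

    out_new = F.swiglu(x, y)  # preferred entry point

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        out_old = paddle.incubate.nn.functional.swiglu(x, y)  # deprecated alias

    assert paddle.allclose(out_new, out_old)
    print([str(w.message) for w in caught])  # expected to point at paddle.nn.functional.swiglu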
python/paddle/nn/functional/__init__.py (2 additions, 0 deletions)
@@ -46,6 +46,7 @@
softplus,
softshrink,
softsign,
swiglu,
swish,
tanh,
tanh_,
@@ -207,6 +208,7 @@
'softsign',
'sigmoid',
'silu',
'swiglu',
'swish',
'mish',
'tanh',
python/paddle/nn/functional/activation.py (35 additions, 0 deletions)
@@ -1820,3 +1820,38 @@ def gumbel_softmax(
attrs={'temperature': temperature, 'hard': hard, 'axis': axis},
)
return out


def swiglu(
x: Tensor, y: Tensor | None = None, name: str | None = None
) -> Tensor:
"""
This function applies the SwiGLU activation to the input Tensor.

.. math::

out = silu(x) * y when y is not None
out = silu(xs[0]) * xs[1] when y is None, where xs = paddle.chunk(x, 2, axis=-1)

Args:
x (Tensor): The first input Tensor of SwiGLU.
y (Tensor, optional): The second input Tensor of SwiGLU. Default: None.
name (str, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None.

Returns:
A Tensor with the same data type as x and y.

Examples:
.. code-block:: python

>>> import paddle
>>> import paddle.nn.functional as F
>>> x = paddle.to_tensor([1, 2], dtype='float32')
>>> out1, out2 = F.swiglu(x), F.swiglu(x, x)
>>> print(out1, out2)
Tensor(shape=[1], dtype=float32, place=Place(cpu), stop_gradient=True,
[1.46211720]) Tensor(shape=[2], dtype=float32, place=Place(cpu), stop_gradient=True,
[0.73105860, 3.52318811])
"""
if in_dynamic_or_pir_mode():
return _C_ops.swiglu(x, y)
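The docstring's formula can be sanity-checked against the unfused ops. This snippet is an illustration, not part of the diff; tolerances are kept loose to allow for floating-point differences in the fused kernel:

    import paddle
    import paddle.nn.functional as F

    x = paddle.randn([2, 6])
    y = paddle.randn([2, 6])

    # Two-tensor form: silu(x) * y.
    assert paddle.allclose(F.swiglu(x, y), F.silu(x) * y, rtol=1e-4, atol=1e-6)

    # One-tensor form: split the last axis in half, then silu(first) * second.
    gate, up = paddle.chunk(x, 2, axis=-1)
    assert paddle.allclose(F.swiglu(x), F.silu(gate) * up, rtol=1e-4, atol=1e-6)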
@@ -233,9 +233,7 @@ def __init__(
)

def forward(self, x):
- x = paddle.incubate.nn.functional.swiglu(
- self.gate_proj(x), self.up_proj(x)
- )
+ x = paddle.nn.functional.swiglu(self.gate_proj(x), self.up_proj(x))
out = self.down_proj(x)
return out

test/ir/pir/cinn/llama_test_model.py (1 addition, 1 deletion)
@@ -21,7 +21,7 @@
import paddle
import paddle.nn.functional as F
from paddle import nn
- from paddle.incubate.nn.functional import swiglu
+ from paddle.nn.functional import swiglu

sys.path.append(dirname(__file__))

test/ir/pir/cinn/symbolic/test_llama_group_swiglu.py (1 addition, 1 deletion)
@@ -35,7 +35,7 @@ def __init__(self):
super().__init__()

def forward(self, x, y):
- out = paddle.incubate.nn.functional.swiglu(x, y)
+ out = paddle.nn.functional.swiglu(x, y)

return out

test/legacy_test/test_fused_swiglu_weighted_bwd_op.py (1 addition, 1 deletion)
@@ -17,7 +17,7 @@
import numpy as np

import paddle
- from paddle.incubate.nn.functional import swiglu
+ from paddle.nn.functional import swiglu


class TestFusedWeightedSwigluBwd(unittest.TestCase):
test/legacy_test/test_swiglu.py (13 additions, 7 deletions)
@@ -31,7 +31,7 @@
DistTensorSpec,
TensorDistAttr,
)
- from paddle.incubate.nn.functional import swiglu as fused_swiglu_impl
+ from paddle.nn.functional import swiglu as fused_swiglu_impl


def swiglu(x, y, out_grad):
@@ -71,13 +71,13 @@ def swiglu(x, y, out_grad):
return ret


- def fused_swiglu(x, y, out_grad):
+ def fused_swiglu(x, y, out_grad, swiglu_func=fused_swiglu_impl):
x = x.detach().clone()
x.stop_gradient = False
if y is not None:
y = y.detach().clone()
y.stop_gradient = False
- out = fused_swiglu_impl(x, y)
+ out = swiglu_func(x, y)
out.backward(out_grad)

output_dtype = x.dtype
@@ -106,14 +106,20 @@ def fused_swiglu(x, y, out_grad):


class TestSwiGLUDygraph(unittest.TestCase):
+ def fused_swiglu(self, x, y, out_grad):
+ return fused_swiglu(x, y, out_grad)
+
+ def fused_swiglu_impl(self, x, y=None):
+ return fused_swiglu_impl(x, y)
+
def check_dygraph_impl(self, device, shape, dtype):
x = paddle.randn(shape, dtype=dtype)
y = paddle.randn(shape, dtype=dtype)
out_grad = paddle.randn(shape, dtype=dtype)

ret1 = swiglu(x, y, out_grad)
- ret2 = fused_swiglu(x, y, out_grad)
- ret3 = fused_swiglu(paddle.concat([x, y], axis=-1), None, out_grad)
+ ret2 = self.fused_swiglu(x, y, out_grad)
+ ret3 = self.fused_swiglu(paddle.concat([x, y], axis=-1), None, out_grad)

atol, rtol = tol_map[dtype]
err_msg = (
@@ -152,8 +158,8 @@ def check_static_graph(self, shape, dtype="float32"):
shape=[*shape[:-1], shape[-1] * 2],
dtype=dtype,
)
- out1 = fused_swiglu_impl(x, y)
- out2 = fused_swiglu_impl(concated_x)
+ out1 = self.fused_swiglu_impl(x, y)
+ out2 = self.fused_swiglu_impl(concated_x)

concated_x_np = np.random.random(concated_x.shape).astype(dtype)
x_np, y_np = np.split(concated_x_np, 2, axis=-1)
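The new fused_swiglu/fused_swiglu_impl hooks (together with the swiglu_func parameter above) make these checks reusable for other swiglu entry points. A hypothetical subclass that would live in the same test file and re-run everything against the deprecated incubate alias, not part of the diff:

    import paddle

    class TestSwiGLUDygraphLegacyAPI(TestSwiGLUDygraph):
        # Route the fused path through the deprecated incubate entry point.
        def fused_swiglu(self, x, y, out_grad):
            return fused_swiglu(
                x, y, out_grad, swiglu_func=paddle.incubate.nn.functional.swiglu
            )

        def fused_swiglu_impl(self, x, y=None):
            return paddle.incubate.nn.functional.swiglu(x, y)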
test/prim/pir_prim/test_prim_sub_graph_dynamic_shape.py (2 additions, 2 deletions)
@@ -120,11 +120,11 @@ def batch_norm_net4(x, y, z):


def swiglu_net1(x, y):
- return paddle.incubate.nn.functional.swiglu(x, y)
+ return paddle.nn.functional.swiglu(x, y)


def swiglu_net2(x):
- return paddle.incubate.nn.functional.swiglu(x)
+ return paddle.nn.functional.swiglu(x)


def squared_l2_norm_net(x):
@@ -159,11 +159,11 @@ def sum_net5(x):


def swiglu_net1(x, y):
- return paddle.incubate.nn.functional.swiglu(x, y)
+ return paddle.nn.functional.swiglu(x, y)


def swiglu_net2(x):
- return paddle.incubate.nn.functional.swiglu(x)
+ return paddle.nn.functional.swiglu(x)


def swish_net(x):
test/xpu/test_swiglu_op_xpu.py (1 addition, 1 deletion)
@@ -18,7 +18,7 @@

import paddle
import paddle.nn.functional as F
- from paddle.incubate.nn.functional import swiglu as fused_swiglu_impl
+ from paddle.nn.functional import swiglu as fused_swiglu_impl


def swiglu(x, y, out_grad):