Support pp on xpu #1014

Open · wants to merge 9 commits into develop
16 changes: 16 additions & 0 deletions ppfleetx/models/language_model/gpt/dygraph/hybrid_model.py
@@ -44,6 +44,8 @@
from ppfleetx.distributed.apis import env
from ppfleetx.utils.log import logger

import numpy as np


def get_attr(layer, name):
    if getattr(layer, name, None) is not None:
@@ -285,6 +287,10 @@ def core_attn(self, q, k, v, attn_mask=None):
        # scale dot product attention
        product = paddle.matmul(
            x=q, y=k, transpose_y=True) * self.head_dim**-0.5

        # softmax_mask_fuse_upper_triangle is not supported if paddle is not compiled with cuda/rocm
        if not paddle.is_compiled_with_cuda():
            attn_mask = get_triangle_upper_mask(product, attn_mask)

        if attn_mask is not None:
            product = product + attn_mask
@@ -1592,3 +1598,13 @@ def forward(self, input_ids=None, **model_kwargs):
        else:
            raise ValueError(f'Not support {decoding_strategy} strategy yet!')
        return ret


def get_triangle_upper_mask(x, mask):
    # Build a causal (upper-triangular) additive mask as a fallback when the
    # fused softmax_mask_fuse_upper_triangle kernel is unavailable (e.g. on XPU).
    if mask is not None:
        return mask
    mask = paddle.full_like(x, -np.inf)
    mask.stop_gradient = True
    mask = paddle.triu(mask, diagonal=1)
    mask.stop_gradient = True
    return mask
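
For context, a minimal sketch of how this fallback slots into the attention computation on builds without CUDA/ROCm, reusing get_triangle_upper_mask from the hunk above (the fused-kernel call and the plain-softmax path are assumptions for illustration; the real core_attn in hybrid_model.py may differ in detail):

import paddle
import paddle.nn.functional as F

def core_attn_sketch(q, k, v, head_dim, attn_mask=None):
    # Scaled dot-product scores with shape [batch, heads, seq_len, seq_len].
    product = paddle.matmul(x=q, y=k, transpose_y=True) * head_dim**-0.5

    if paddle.is_compiled_with_cuda():
        # CUDA/ROCm builds can use the fused causal-mask + softmax kernel.
        weights = paddle.incubate.softmax_mask_fuse_upper_triangle(product)
    else:
        # On XPU/CPU the causal mask is materialized explicitly and added
        # to the scores before a plain softmax.
        attn_mask = get_triangle_upper_mask(product, attn_mask)
        weights = F.softmax(product + attn_mask)

    return paddle.matmul(weights, v)
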
4 changes: 2 additions & 2 deletions ppfleetx/optims/grad_clip.py
@@ -13,9 +13,9 @@
# limitations under the License.

import paddle
from paddle.fluid.clip import ClipGradByGlobalNorm
from paddle.nn.clip import ClipGradByGlobalNorm

from paddle.fluid.clip import ClipGradBase, _squared_l2_norm
from paddle.nn.clip import ClipGradBase, _squared_l2_norm
from paddle.fluid.dygraph import base as imperative_base
from paddle.fluid import core, layers
from paddle.distributed import collective
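
For reference, paddle.nn.clip is the module behind the public paddle.nn.ClipGradByGlobalNorm API that the diff switches to, so downstream usage of the clip strategy stays the same; a minimal usage sketch (the model and hyperparameters below are placeholders, not taken from ppfleetx):

import paddle

# Placeholder two-layer model, used only to illustrate the clipping setup.
model = paddle.nn.Sequential(
    paddle.nn.Linear(8, 8), paddle.nn.ReLU(), paddle.nn.Linear(8, 1))

# Clip the global L2 norm of all parameter gradients to 1.0 before each step.
clip = paddle.nn.ClipGradByGlobalNorm(clip_norm=1.0)
opt = paddle.optimizer.AdamW(
    learning_rate=1e-3, parameters=model.parameters(), grad_clip=clip)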