complete BlockDiagonalGappy/BlockDiagonalPadded support on CK (#1108)
qianfengz authored Sep 20, 2024
1 parent d3948b5 commit 8c24081
Showing 1 changed file with 15 additions and 3 deletions.
18 changes: 15 additions & 3 deletions xformers/ops/fmha/ck.py
@@ -23,6 +23,7 @@
     BlockDiagonalCausalWithOffsetPaddedKeysMask,
     BlockDiagonalGappyKeysMask,
     BlockDiagonalMask,
+    BlockDiagonalPaddedKeysMask,
     LowerTriangularFromBottomRightLocalAttentionMask,
     LowerTriangularFromBottomRightMask,
     LowerTriangularMask,
@@ -48,7 +49,8 @@ def _get_seqlen_info(
 ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], int, int]:
     attn_bias = inp.attn_bias
     if isinstance(
-        attn_bias, (BlockDiagonalMask, BlockDiagonalCausalWithOffsetPaddedKeysMask)
+        attn_bias,
+        (BlockDiagonalMask, BlockDiagonalPaddedKeysMask, BlockDiagonalGappyKeysMask),
     ):
         attn_bias.k_seqinfo.to(inp.query.device)
         attn_bias.q_seqinfo.to(inp.query.device)
@@ -159,6 +161,7 @@ class FwOp(AttentionFwOpBase):
         BlockDiagonalCausalWithOffsetGappyKeysMask,
         BlockDiagonalCausalWithOffsetPaddedKeysMask,
         BlockDiagonalGappyKeysMask,
+        BlockDiagonalPaddedKeysMask,
         attn_bias.BlockDiagonalCausalFromBottomRightMask,
         attn_bias.BlockDiagonalCausalLocalAttentionMask,
         BlockDiagonalCausalLocalAttentionFromBottomRightMask,
@@ -167,6 +170,7 @@ class FwOp(AttentionFwOpBase):
     SUPPORTS_DROPOUT = True
     SUPPORTS_CUSTOM_SCALE = True
     SUPPORTS_DIFFERENT_VALUE_EMBED = True
+    SUPPORTS_PARTIAL = True
     SUPPORTS_BMGHK = True
     NAME = "ckF"

@@ -273,7 +277,11 @@ def apply_bmhk(
             seqlen_k=(
                 inp.attn_bias.k_seqinfo.seqlen
                 if isinstance(
-                    inp.attn_bias, BlockDiagonalCausalWithOffsetPaddedKeysMask
+                    inp.attn_bias,
+                    (
+                        BlockDiagonalGappyKeysMask,
+                        BlockDiagonalPaddedKeysMask,
+                    ),
                 )
                 else None
             ),
@@ -417,7 +425,11 @@ def apply(cls, ctx: Context, inp: Inputs, grad: torch.Tensor) -> Gradients:
             seqlen_k=(
                 inp.attn_bias.k_seqinfo.seqlen
                 if isinstance(
-                    inp.attn_bias, BlockDiagonalCausalWithOffsetPaddedKeysMask
+                    inp.attn_bias,
+                    (
+                        BlockDiagonalGappyKeysMask,
+                        BlockDiagonalPaddedKeysMask,
+                    ),
                 )
                 else None
             ),
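For context, below is a minimal sketch of how the bias types this commit completes support for might be driven through the CK ops. It is not part of the commit: the sequence lengths, padding, head count, head dimension, and dtype are illustrative assumptions; it assumes BlockDiagonalPaddedKeysMask.from_seqlens accepts (q_seqlen, kv_padding, kv_seqlen) like its causal sibling; and it requires a ROCm build of xformers in which the CK operators are registered.

import torch

from xformers.ops import fmha
from xformers.ops.fmha.attn_bias import BlockDiagonalPaddedKeysMask

# Two variable-length query sequences packed into one "batch", each attending
# to a per-sequence KV buffer padded to a fixed stride (all values illustrative).
q_seqlen = [3, 5]    # tokens per query sequence
kv_seqlen = [4, 6]   # valid keys per sequence
kv_padding = 8       # padded KV length per sequence

attn_bias = BlockDiagonalPaddedKeysMask.from_seqlens(
    q_seqlen=q_seqlen, kv_padding=kv_padding, kv_seqlen=kv_seqlen
)

num_heads, head_dim = 8, 64
device, dtype = "cuda", torch.float16  # "cuda" maps to the HIP device on ROCm

q = torch.randn(1, sum(q_seqlen), num_heads, head_dim, device=device, dtype=dtype)
k = torch.randn(1, kv_padding * len(kv_seqlen), num_heads, head_dim, device=device, dtype=dtype)
v = torch.randn_like(k)

# Route explicitly through the CK forward/backward ops touched by this commit;
# with no gradients required here, only the forward op is actually exercised.
out = fmha.memory_efficient_attention(
    q, k, v, attn_bias=attn_bias, op=(fmha.ck.FwOp, fmha.ck.BwOp)
)
# A BlockDiagonalGappyKeysMask bias can be passed the same way once constructed.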
