Restore name of arguments in flash op (fairinternal/xformers#1205)
__original_commit__ = fairinternal/xformers@4eda873
danthe3rd authored and xFormers Bot committed Aug 22, 2024
1 parent 57227c6 commit 6d2200c
Showing 1 changed file with 35 additions and 35 deletions.
xformers/ops/fmha/flash.py (35 additions, 35 deletions)
@@ -87,11 +87,11 @@ def _flash_fwd(
     query: torch.Tensor,
     key: torch.Tensor,
     value: torch.Tensor,
-    cu_seq_lens_q: Optional[torch.Tensor],
-    cu_seq_lens_k: Optional[torch.Tensor],
+    cu_seqlens_q: Optional[torch.Tensor],
+    cu_seqlens_k: Optional[torch.Tensor],
     seqused_k: Optional[torch.Tensor],
-    max_seq_len_q: int,
-    max_seq_len_k: int,
+    max_seqlen_q: int,
+    max_seqlen_k: int,
     p: float,
     softmax_scale: float,
     is_causal: bool,
@@ -112,10 +112,10 @@ def _flash_fwd(
             query,
             key,
             value,
-            cu_seq_lens_q,  # cum_seq_q
-            cu_seq_lens_k,  # cum_seq_k
-            max_seq_len_q,  # max_q
-            max_seq_len_k,  # max_k
+            cu_seqlens_q,  # cum_seq_q
+            cu_seqlens_k,  # cum_seq_k
+            max_seqlen_q,  # max_q
+            max_seqlen_k,  # max_k
             p,  # dropout_p
             is_causal,
             return_debug_mask=False,
@@ -128,8 +128,8 @@ def _flash_fwd(
         rng_state = torch.stack([philox_seed, philox_offset])
         return attention, logsumexp, rng_state
     else:
-        if cu_seq_lens_q is None:
-            assert cu_seq_lens_k is None
+        if cu_seqlens_q is None:
+            assert cu_seqlens_k is None
             assert seqused_k is None
             (
                 out,
@@ -170,14 +170,14 @@ def _flash_fwd(
                 key,
                 value,
                 None,  # out
-                cu_seq_lens_q,
-                cu_seq_lens_k,
+                cu_seqlens_q,
+                cu_seqlens_k,
                 seqused_k,
                 None,  # leftpad_k_
                 block_tables,
                 None,  # alibi_slopes
-                max_seq_len_q,
-                max_seq_len_k,
+                max_seqlen_q,
+                max_seqlen_k,
                 p,
                 softmax_scale,
                 False,
@@ -195,11 +195,11 @@ def _flash_fwd_abstract(
     query,
     key,
     value,
-    cu_seq_lens_q,
-    cu_seq_lens_k,
+    cu_seqlens_q,
+    cu_seqlens_k,
     seqused_k,
-    max_seq_len_q,
-    max_seq_len_k,
+    max_seqlen_q,
+    max_seqlen_k,
     p,
     softmax_scale,
     is_causal,
@@ -209,16 +209,16 @@ def _flash_fwd_abstract(
     block_tables,
 ):
     out = torch.empty_like(query)
-    if cu_seq_lens_q is None:
+    if cu_seqlens_q is None:
         B, M, H, K = query.shape
         lse_shape = [B, H, M]
     else:
         M, H, K = query.shape
-        B = cu_seq_lens_q.shape[0] - 1
+        B = cu_seqlens_q.shape[0] - 1
         if VARLEN_LSE_PACKED:
             lse_shape = [H, M]
         else:
-            lse_shape = [B, H, max_seq_len_q]
+            lse_shape = [B, H, max_seqlen_q]
     softmax_lse = torch.empty(lse_shape, device=query.device, dtype=torch.float32)
     rng_state = torch.empty([2], device=query.device, dtype=torch.int64)
     return out, softmax_lse, rng_state
@@ -236,10 +236,10 @@ def _flash_bwd(
     value: torch.Tensor,
     out: torch.Tensor,
     lse: torch.Tensor,
-    cu_seq_lens_q: torch.Tensor,
-    cu_seq_lens_k: torch.Tensor,
-    max_seq_len_q: int,
-    max_seq_len_k: int,
+    cu_seqlens_q: torch.Tensor,
+    cu_seqlens_k: torch.Tensor,
+    max_seqlen_q: int,
+    max_seqlen_k: int,
     p: float,
     softmax_scale: float,
     is_causal: bool,
@@ -262,10 +262,10 @@ def _flash_bwd(
             value,
             out,
             lse,
-            cu_seq_lens_q,
-            cu_seq_lens_k,
-            max_seq_len_q,
-            max_seq_len_k,
+            cu_seqlens_q,
+            cu_seqlens_k,
+            max_seqlen_q,
+            max_seqlen_k,
             p,
             is_causal,
             philox_seed,
@@ -276,8 +276,8 @@ def _flash_bwd(
         )
     else:
         dq, dk, dv = _create_dq_dk_dv(grads_share_storage, query, key, value)
-        if cu_seq_lens_k is None:
-            assert cu_seq_lens_q is None
+        if cu_seqlens_k is None:
+            assert cu_seqlens_q is None
             _C_flashattention.bwd(
                 grad,
                 query,
@@ -310,11 +310,11 @@ def _flash_bwd(
                 dq,
                 dk,
                 dv,
-                cu_seq_lens_q,
-                cu_seq_lens_k,
+                cu_seqlens_q,
+                cu_seqlens_k,
                 None,  # alibi_slopes
-                max_seq_len_q,
-                max_seq_len_k,
+                max_seqlen_q,
+                max_seqlen_k,
                 p,
                 softmax_scale,
                 False,  # zero_tensors
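
Editor's note: the diff only renames parameters (cu_seq_lens_q back to cu_seqlens_q, max_seq_len_q back to max_seqlen_q, and so on), but for an op registered through torch.library such names are part of the op's schema, so keyword-argument call sites and traced graphs depend on them. The sketch below is illustrative only and is not the xformers registration: the op name "mylib::toy_flash_fwd", its toy body, and its reduced argument list are assumptions made for the example.

# Minimal sketch of why custom-op parameter names matter, assuming a toy op.
from typing import Optional

import torch


@torch.library.custom_op("mylib::toy_flash_fwd", mutates_args=())
def toy_flash_fwd(
    query: torch.Tensor,
    cu_seqlens_q: Optional[torch.Tensor],
    max_seqlen_q: int,
) -> torch.Tensor:
    # Stand-in body; a real op would dispatch to the flash-attention kernel.
    return query.clone()


@toy_flash_fwd.register_fake
def _(query, cu_seqlens_q, max_seqlen_q):
    # Shape-only implementation used for tracing/compilation (analogous to
    # _flash_fwd_abstract in the diff above).
    return torch.empty_like(query)


q = torch.randn(2, 8, 4, 16)
# Callers may bind by keyword; the names must match the registered schema,
# so renaming cu_seqlens_q would break this call site.
out = torch.ops.mylib.toy_flash_fwd(q, cu_seqlens_q=None, max_seqlen_q=8)

In the same way, restoring the original names in _flash_fwd and _flash_bwd keeps the registered schema compatible with existing keyword-based callers.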
