diff --git a/tests/test_mem_eff_attention.py b/tests/test_mem_eff_attention.py
index e1103c40a..0491f9e5c 100644
--- a/tests/test_mem_eff_attention.py
+++ b/tests/test_mem_eff_attention.py
@@ -3312,7 +3312,6 @@ def _merge_attentions_ref(attn_split, lse_split):
 
 
 @sm80_or_better_only
-@skip_if_rocm  # rocm doesn't support backward yet
 @pytest.mark.parametrize(
     "bias_t",
     [None, fmha.attn_bias.LowerTriangularMask, fmha.attn_bias.BlockDiagonalMask],
diff --git a/xformers/csrc/attention/hip_fmha/attention_backward_generic_ck_tiled.cpp b/xformers/csrc/attention/hip_fmha/attention_backward_generic_ck_tiled.cpp
index b470f5990..2bc96fa7e 100644
--- a/xformers/csrc/attention/hip_fmha/attention_backward_generic_ck_tiled.cpp
+++ b/xformers/csrc/attention/hip_fmha/attention_backward_generic_ck_tiled.cpp
@@ -547,6 +547,86 @@ efficient_attention_backward_ck(
   return std::make_tuple(grad_q, grad_k, grad_v, grad_bias);
 }
 
+std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
+efficient_attention_backward_ck_meta(
+    const at::Tensor& grad_out,
+    const at::Tensor& query,
+    const at::Tensor& key,
+    const at::Tensor& value,
+    const c10::optional<at::Tensor>& bias, // additive attention bias
+    // (Mode 1MHK only) [b+1]: cu_seqlens_q[b] contains the
+    // position of the first query token for batch $b
+    const c10::optional<at::Tensor>& seqstart_q,
+    // (Mode 1MHK only) [b+1]: cu_seqlens_k[b] contains the
+    // position of the first key token for batch $b
+    const c10::optional<at::Tensor>& seqstart_k,
+    // (Mode 1MHK only) Maximum sequence length across batches
+    const c10::optional<int64_t> max_seqlen_q_,
+    // (Mode 1MHK only) Maximum sequence length across batches
+    const c10::optional<int64_t> max_seqlen_k_,
+    const c10::optional<at::Tensor>& seqlen_k,
+    const at::Tensor& logsumexp,
+    const at::Tensor& out,
+    double dropout_p, // dropout probability
+    int64_t rng_seed, // seed used for generating random numbers for dropout
+    int64_t rng_offset, // offset into random number sequence
+    int64_t custom_mask_type,
+    const c10::optional<double> scale,
+    const c10::optional<int64_t> window_size) {
+  int64_t B = query.size(0);
+  int64_t M = query.size(1);
+  int64_t N = key.size(1);
+  int64_t Hq = query.size(2);
+  int64_t Hkv = key.size(2);
+  int64_t K = query.size(3);
+  int64_t Kv = value.size(3);
+
+  auto opts = query.options();
+
+  at::Tensor grad_q, grad_k, grad_v, grad_bias;
+
+  if (query.size(1) == key.size(1) && query.size(3) == value.size(3) &&
+      query.size(2) == key.size(2) &&
+      query.storage().is_alias_of(key.storage()) &&
+      query.storage().is_alias_of(value.storage())) {
+    // Create one big contiguous chunk for grad_q, grad_k, grad_v
+    // This is because q, k and v usually come from a single
+    // output of a linear layer that is chunked.
+    // Creating the gradients with the right layout saves us
+    // a `torch.cat` call in the backward pass
+    at::Tensor chunk = at::empty({B, M, 3, Hq, K}, opts);
+    grad_q = chunk.select(2, 0);
+    grad_k = chunk.select(2, 1);
+    grad_v = chunk.select(2, 2);
+  } else if (
+      key.size(3) == value.size(3) &&
+      key.storage().is_alias_of(value.storage())) {
+    // Create one big contiguous chunk for grad_k, grad_v
+    // This is because k and v usually come from a single
+    // output of a linear layer that is chunked.
+    // Creating the gradients with the right layout saves us
+    // a `torch.cat` call in the backward pass
+    at::Tensor chunk = at::empty({B, N, 2, Hkv, Kv}, opts);
+    grad_k = chunk.select(2, 0);
+    grad_v = chunk.select(2, 1);
+
+    grad_q = at::empty_strided(query.sizes(), query.strides(), query.options());
+  } else {
+    grad_q = at::empty_strided(query.sizes(), query.strides(), query.options());
+    grad_k = at::empty_strided(key.sizes(), key.strides(), key.options());
+    grad_v = at::empty_strided(value.sizes(), value.strides(), value.options());
+  }
+
+  const bool bias_requires_grad = bias.has_value() && bias->requires_grad();
+  // even though it is an output, grad_bias is required to use the same
+  // data-type as bias in CK-FlashAttn
+  if (bias_requires_grad) {
+    grad_bias =
+        at::empty_strided(bias->sizes(), bias->strides(), bias->options());
+  }
+  return std::make_tuple(grad_q, grad_k, grad_v, grad_bias);
+}
+
 } // namespace
 
 TORCH_LIBRARY_IMPL(xformers, CUDA, m) {
@@ -554,3 +634,9 @@ TORCH_LIBRARY_IMPL(xformers, CUDA, m) {
       TORCH_SELECTIVE_NAME("xformers::efficient_attention_backward_ck"),
       TORCH_FN(efficient_attention_backward_ck));
 }
+
+TORCH_LIBRARY_IMPL(xformers, Meta, m) {
+  m.impl(
+      TORCH_SELECTIVE_NAME("xformers::efficient_attention_backward_ck"),
+      TORCH_FN(efficient_attention_backward_ck_meta));
+}
diff --git a/xformers/csrc/attention/hip_fmha/attention_forward_generic_ck_tiled.cpp b/xformers/csrc/attention/hip_fmha/attention_forward_generic_ck_tiled.cpp
index 4bbfe71ad..b17c036ae 100644
--- a/xformers/csrc/attention/hip_fmha/attention_forward_generic_ck_tiled.cpp
+++ b/xformers/csrc/attention/hip_fmha/attention_forward_generic_ck_tiled.cpp
@@ -374,6 +374,55 @@ efficient_attention_forward_ck(
   return std::make_tuple(out, logsumexp, philox_seed, philox_offset);
 }
 
+/*
+  There are 2 modes for using this function.
+  (Mode BMHK) With all the heads having the same seqlen
+  (Mode 1MHK) `batch=1` with all tokens across batches concatenated
+*/
+std::tuple<at::Tensor, at::Tensor, int64_t, int64_t>
+efficient_attention_forward_ck_meta(
+    const at::Tensor& query, // [b, seqlen, num_heads_q, K]
+    const at::Tensor& key, // [b, seqlen, num_heads_kv, K]
+    const at::Tensor& value, // [b, seqlen, num_heads_kv, Kv]
+    const c10::optional<at::Tensor>& bias, // [b, num_heads_q, seqlen, seqlen]
+    // (Mode 1MHK only) [b+1]: cu_seqlens_q[b] contains the
+    // position of the first query token for batch $b
+    const c10::optional<at::Tensor>& seqstart_q,
+    // (Mode 1MHK only) [b+1]: cu_seqlen_k[b] contains the
+    // position of the first key token for batch $b
+    const c10::optional<at::Tensor>& seqstart_k,
+    // (Mode 1MHK only) Maximum sequence length across batches
+    const c10::optional<int64_t> max_seqlen_q_,
+    double dropout_p, // attention matrix dropout probability
+    bool compute_logsumexp,
+    int64_t custom_mask_type,
+    c10::optional<double> scale,
+    const c10::optional<at::Tensor>& seqlen_k,
+    const c10::optional<int64_t> window_size) {
+  int64_t B = query.size(0);
+  int64_t M = query.size(1);
+  int64_t N = key.size(1);
+  int64_t Hq = query.size(-2);
+  int64_t Hkv = key.size(-2);
+  int64_t K = query.size(-1);
+  int64_t Kv = value.size(-1);
+  auto opts = query.options();
+  at::Tensor logsumexp;
+  at::Tensor out = at::empty({B, M, Hq, Kv}, opts);
+  int64_t philox_seed = 0;
+  int64_t philox_offset = 0;
+  if (!seqstart_q.has_value()) { // input is batched
+    if (compute_logsumexp) {
+      logsumexp = at::empty({B, Hq, M}, opts.dtype(at::kFloat));
+    }
+  } else {
+    if (compute_logsumexp) {
+      logsumexp = at::empty({1, Hq, M}, opts.dtype(at::kFloat));
+    }
+  }
+  return std::make_tuple(out, logsumexp, philox_seed, philox_offset);
+}
+
 } // namespace
 
 TORCH_LIBRARY_IMPL(xformers, CUDA, m) {
@@ -381,3 +430,9 @@ TORCH_LIBRARY_IMPL(xformers, CUDA, m) {
       TORCH_SELECTIVE_NAME("xformers::efficient_attention_forward_ck"),
       TORCH_FN(efficient_attention_forward_ck));
 }
+
+TORCH_LIBRARY_IMPL(xformers, Meta, m) {
+  m.impl(
+      TORCH_SELECTIVE_NAME("xformers::efficient_attention_forward_ck"),
+      TORCH_FN(efficient_attention_forward_ck_meta));
+}
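
Reviewer note (not part of the patch): a Meta-dispatch registration lets these custom ops run on `device="meta"` tensors, so shape and dtype propagation works without launching a GPU kernel (e.g. under FakeTensorMode or torch.compile tracing). The sketch below is a minimal, hypothetical check of the forward registration; it assumes an xformers build that registers the HIP/CK ops, uses made-up dimensions, and passes arguments in the order of the C++ signature shown in this diff.

# Sketch only: assumes torch.ops.xformers.efficient_attention_forward_ck is
# registered (i.e. xformers was built with the HIP/CK extension).
import torch
import xformers.ops  # noqa: F401  # importing is assumed to load the C++ op library

B, M, Hq, Hkv, K, Kv = 2, 128, 8, 8, 64, 64
q = torch.empty(B, M, Hq, K, device="meta", dtype=torch.float16)
k = torch.empty(B, M, Hkv, K, device="meta", dtype=torch.float16)
v = torch.empty(B, M, Hkv, Kv, device="meta", dtype=torch.float16)

# Argument order mirrors efficient_attention_forward_ck_meta above:
# query, key, value, bias, seqstart_q, seqstart_k, max_seqlen_q_,
# dropout_p, compute_logsumexp, custom_mask_type, scale, seqlen_k, window_size
out, lse, seed, offset = torch.ops.xformers.efficient_attention_forward_ck(
    q, k, v, None, None, None, None, 0.0, True, 0, None, None, None
)
assert out.device.type == "meta" and out.shape == (B, M, Hq, Kv)
assert lse.shape == (B, Hq, M)  # BMHK mode with compute_logsumexp=True

The backward registration for xformers::efficient_attention_backward_ck can be exercised the same way with meta-device gradients.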