Add torch compile support for ck attention op (#1085)
jianyuh authored Aug 26, 2024
1 parent 616e5bd commit 00aba59
Showing 3 changed files with 141 additions and 1 deletion.
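The substance of the change: shape-only "meta" implementations of the CK forward and backward attention ops are added and registered under the Meta dispatch key, so torch.compile / FakeTensor tracing can infer output shapes and strides without launching a ROCm kernel, and the ROCm backward test skip is dropped. A minimal sketch of what this enables, assuming a ROCm build of xformers where the CK ops are selected (the function name, shapes, and dtype below are illustrative):

import torch
import xformers.ops as xops

def attention(q, k, v):
    # On ROCm builds of xformers this dispatches to the CK attention kernels.
    return xops.memory_efficient_attention(q, k, v)

compiled_attention = torch.compile(attention)

# [batch, seqlen, num_heads, head_dim] layout expected by memory_efficient_attention
q, k, v = (
    torch.randn(2, 1024, 8, 64, device="cuda", dtype=torch.float16, requires_grad=True)
    for _ in range(3)
)
out = compiled_attention(q, k, v)  # forward traces through the meta implementation
out.sum().backward()               # backward now compiles too (hence the removed skip below)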
1 change: 0 additions & 1 deletion tests/test_mem_eff_attention.py
@@ -3312,7 +3312,6 @@ def _merge_attentions_ref(attn_split, lse_split):


@sm80_or_better_only
@skip_if_rocm # rocm doesn't support backward yet
@pytest.mark.parametrize(
"bias_t",
[None, fmha.attn_bias.LowerTriangularMask, fmha.attn_bias.BlockDiagonalMask],
@@ -547,10 +547,96 @@ efficient_attention_backward_ck(
return std::make_tuple(grad_q, grad_k, grad_v, grad_bias);
}

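// Shape-only ("meta") implementation of efficient_attention_backward_ck:
// allocates grad_q/grad_k/grad_v (and grad_bias) with the sizes and strides
// the CK kernel would produce, but performs no computation. It is registered
// under the Meta dispatch key below so torch.compile / FakeTensor tracing can
// propagate shapes through the backward op.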
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
efficient_attention_backward_ck_meta(
const at::Tensor& grad_out,
const at::Tensor& query,
const at::Tensor& key,
const at::Tensor& value,
const c10::optional<at::Tensor>& bias, // additive attention bias
// (Mode 1MHK only) [b+1]: cu_seqlens_q[b] contains the
// position of the first query token for batch $b
const c10::optional<at::Tensor>& seqstart_q,
// (Mode 1MHK only) [b+1]: cu_seqlens_k[b] contains the
// position of the first key token for batch $b
const c10::optional<at::Tensor>& seqstart_k,
// (Mode 1MHK only) Maximum sequence length across batches
const c10::optional<int64_t> max_seqlen_q_,
// (Mode 1MHK only) Maximum sequence length across batches
const c10::optional<int64_t> max_seqlen_k_,
const c10::optional<at::Tensor>& seqlen_k,
const at::Tensor& logsumexp,
const at::Tensor& out,
double dropout_p, // dropout probability
int64_t rng_seed, // seed used for generating random numbers for dropout
int64_t rng_offset, // offset into random number sequence
int64_t custom_mask_type,
const c10::optional<double> scale,
const c10::optional<int64_t> window_size) {
int64_t B = query.size(0);
int64_t M = query.size(1);
int64_t N = key.size(1);
int64_t Hq = query.size(2);
int64_t Hkv = key.size(2);
int64_t K = query.size(3);
int64_t Kv = value.size(3);

auto opts = query.options();

at::Tensor grad_q, grad_k, grad_v, grad_bias;

if (query.size(1) == key.size(1) && query.size(3) == value.size(3) &&
query.size(2) == key.size(2) &&
query.storage().is_alias_of(key.storage()) &&
query.storage().is_alias_of(value.storage())) {
// Create one big contiguous chunk for grad_q, grad_k, grad_v
// This is because q, k and v usually come from a single
// output of a linear layer that is chunked.
// Creating the gradients with the right layout saves us
// a `torch.cat` call in the backward pass
at::Tensor chunk = at::empty({B, M, 3, Hq, K}, opts);
grad_q = chunk.select(2, 0);
grad_k = chunk.select(2, 1);
grad_v = chunk.select(2, 2);
} else if (
key.size(3) == value.size(3) &&
key.storage().is_alias_of(value.storage())) {
// Create one big contiguous chunk for grad_k, grad_v
// This is because k and v usually come from a single
// output of a linear layer that is chunked.
// Creating the gradients with the right layout saves us
// a `torch.cat` call in the backward pass
at::Tensor chunk = at::empty({B, N, 2, Hkv, Kv}, opts);
grad_k = chunk.select(2, 0);
grad_v = chunk.select(2, 1);

grad_q = at::empty_strided(query.sizes(), query.strides(), query.options());
} else {
grad_q = at::empty_strided(query.sizes(), query.strides(), query.options());
grad_k = at::empty_strided(key.sizes(), key.strides(), key.options());
grad_v = at::empty_strided(value.sizes(), value.strides(), value.options());
}

const bool bias_requires_grad = bias.has_value() && bias->requires_grad();
// even though it is an output, grad_bias is required to have the same
// data type as bias in CK-FlashAttn
if (bias_requires_grad) {
grad_bias =
at::empty_strided(bias->sizes(), bias->strides(), bias->options());
}
return std::make_tuple(grad_q, grad_k, grad_v, grad_bias);
}

} // namespace

TORCH_LIBRARY_IMPL(xformers, CUDA, m) {
m.impl(
TORCH_SELECTIVE_NAME("xformers::efficient_attention_backward_ck"),
TORCH_FN(efficient_attention_backward_ck));
}

TORCH_LIBRARY_IMPL(xformers, Meta, m) {
m.impl(
TORCH_SELECTIVE_NAME("xformers::efficient_attention_backward_ck"),
TORCH_FN(efficient_attention_backward_ck_meta));
}
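TORCH_LIBRARY_IMPL(xformers, Meta, m) makes efficient_attention_backward_ck_meta the implementation used when the op is called on meta (shape-only) tensors, which is how torch.compile traces it. The chunked grad_q/grad_k/grad_v allocation matters because the traced graph must see the same strides the real kernel produces. A rough Python mirror of that layout logic, for illustration only (the helper name backward_meta_shapes and the packed_qkv flag are stand-ins for the storage-aliasing check above; the bias and packed-kv branches are omitted):

import torch

def backward_meta_shapes(query, key, value, packed_qkv=False):
    # Mirrors efficient_attention_backward_ck_meta's layout choices on the
    # meta device; no computation happens here.
    B, M, Hq, K = query.shape
    opts = dict(device="meta", dtype=query.dtype)
    if packed_qkv and query.shape == key.shape and K == value.shape[-1]:
        # q, k and v were chunked out of one linear-layer output, so the
        # gradients are packed into a single [B, M, 3, Hq, K] buffer too,
        # saving a torch.cat in the backward pass.
        chunk = torch.empty(B, M, 3, Hq, K, **opts)
        return chunk.select(2, 0), chunk.select(2, 1), chunk.select(2, 2)
    # Otherwise each gradient simply mirrors its input's size and stride.
    return tuple(torch.empty_strided(t.shape, t.stride(), **opts) for t in (query, key, value))

q = k = v = torch.empty(2, 1024, 8, 64, device="meta", dtype=torch.float16)
dq, dk, dv = backward_meta_shapes(q, k, v, packed_qkv=True)
assert dq.shape == (2, 1024, 8, 64)
assert dq.stride() == (1024 * 3 * 8 * 64, 3 * 8 * 64, 64, 1)  # strides of a packed chunk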
@@ -374,10 +374,65 @@ efficient_attention_forward_ck(
return std::make_tuple(out, logsumexp, philox_seed, philox_offset);
}

/*
There are 2 modes for using this function.
(Mode BMHK) With all the heads having the same seqlen
(Mode 1MHK) `batch=1` with all tokens across batches concatenated
*/
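// Shape-only ("meta") counterpart of efficient_attention_forward_ck:
// allocates `out` and (optionally) `logsumexp` with the correct sizes and
// dtypes but launches no kernel, so torch.compile / FakeTensor tracing can
// propagate shapes through the forward op.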
std::tuple<at::Tensor, at::Tensor, int64_t, int64_t>
efficient_attention_forward_ck_meta(
const at::Tensor& query, // [b, seqlen, num_heads_q, K]
const at::Tensor& key, // [b, seqlen, num_heads_kv, K]
const at::Tensor& value, // [b, seqlen, num_heads_kv, Kv]
const c10::optional<at::Tensor>& bias, // [b, num_heads_q, seqlen, seqlen]
// (Mode 1MHK only) [b+1]: cu_seqlens_q[b] contains the
// position of the first query token for batch $b
const c10::optional<at::Tensor>& seqstart_q,
// (Mode 1MHK only) [b+1]: cu_seqlens_k[b] contains the
// position of the first key token for batch $b
const c10::optional<at::Tensor>& seqstart_k,
// (Mode 1MHK only) Maximum sequence length across batches
const c10::optional<int64_t> max_seqlen_q_,
double dropout_p, // attention matrix dropout probability
bool compute_logsumexp,
int64_t custom_mask_type,
c10::optional<double> scale,
const c10::optional<at::Tensor>& seqlen_k,
const c10::optional<int64_t> window_size) {
int64_t B = query.size(0);
int64_t M = query.size(1);
int64_t N = key.size(1);
int64_t Hq = query.size(-2);
int64_t Hkv = key.size(-2);
int64_t K = query.size(-1);
int64_t Kv = value.size(-1);
auto opts = query.options();
at::Tensor logsumexp;
at::Tensor out = at::empty({B, M, Hq, Kv}, opts);
int64_t philox_seed = 0;
int64_t philox_offset = 0;
if (!seqstart_q.has_value()) { // input is batched
if (compute_logsumexp) {
logsumexp = at::empty({B, Hq, M}, opts.dtype(at::kFloat));
}
} else {
if (compute_logsumexp) {
logsumexp = at::empty({1, Hq, M}, opts.dtype(at::kFloat));
}
}
return std::make_tuple(out, logsumexp, philox_seed, philox_offset);
}

} // namespace

TORCH_LIBRARY_IMPL(xformers, CUDA, m) {
m.impl(
TORCH_SELECTIVE_NAME("xformers::efficient_attention_forward_ck"),
TORCH_FN(efficient_attention_forward_ck));
}

TORCH_LIBRARY_IMPL(xformers, Meta, m) {
m.impl(
TORCH_SELECTIVE_NAME("xformers::efficient_attention_forward_ck"),
TORCH_FN(efficient_attention_forward_ck_meta));
}
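As with the backward op, the Meta registration lets torch.compile trace the forward call on shape-only tensors. The one shape subtlety is logsumexp: in batched (BMHK) mode it is [B, Hq, M], while in packed varlen (1MHK) mode, where seqstart_q is provided and all tokens are concatenated into batch 1, it is [1, Hq, M]. A rough Python rendering of that contract, for illustration only (the helper name forward_meta_shapes and the varlen flag are stand-ins for the seqstart_q check above):

import torch

def forward_meta_shapes(query, value, compute_logsumexp=True, varlen=False):
    # Mirrors efficient_attention_forward_ck_meta: produce output placeholders
    # on the meta device; no attention is actually computed.
    B, M, Hq = query.shape[0], query.shape[1], query.shape[-2]
    Kv = value.shape[-1]
    out = torch.empty(B, M, Hq, Kv, device="meta", dtype=query.dtype)
    logsumexp = None
    if compute_logsumexp:
        # Packed varlen (1MHK) inputs are concatenated into batch=1.
        lse_batch = 1 if varlen else B
        logsumexp = torch.empty(lse_batch, Hq, M, device="meta", dtype=torch.float32)
    return out, logsumexp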
