Skip to content

Commit

Permalink
bwaccf32: Accumulate in f32 for bw
Browse files Browse the repository at this point in the history
ghstack-source-id: f713589d43273c6785ba6e3ae92e0974ef8ccfba
Pull Request resolved: #467
  • Loading branch information
danthe3rd committed Dec 6, 2022
1 parent 4c06c79 commit 1924b19
Show file tree
Hide file tree
Showing 5 changed files with 285 additions and 102 deletions.
27 changes: 5 additions & 22 deletions tests/test_mem_eff_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import math
import random
from dataclasses import dataclass
from typing import Any, Sequence, Type
Expand Down Expand Up @@ -257,25 +256,6 @@ def bmk2bmhk(tensor, num_heads: int) -> torch.Tensor:
)


def backward_error_atol(k, kv_len, q_len, dtype):
atol = 2e-4 + 2e-6 * k * kv_len * math.sqrt(q_len)
rtol = 1e-4
if dtype is torch.half:
atol = 5e-2
rtol = 2e-2
# TODO: Implement f32 accumulation for bw
# Longer sequences mean we iterate more and errors accumulate
atol *= 1.4 ** (max(q_len, kv_len) // 64)
if dtype is torch.bfloat16:
# I've seen (out=-1.9 and ref=-1.0 with flash)
atol = 0.5
rtol = 0.1
# TODO: Implement f32 accumulation for bw
# Longer sequences mean we iterate more and errors accumulate
atol *= 1.4 ** (max(q_len, kv_len) // 64)
return atol, rtol


@pytest.mark.parametrize("fmt", ["BMK", "BMHK"])
@pytest.mark.parametrize("packed", [False, True])
@pytest.mark.parametrize(
Expand Down Expand Up @@ -627,7 +607,9 @@ def test_backward(
del grad_out
del ref

atol, rtol = backward_error_atol(k, kv_len, q_len, dtype)
atol = op.BACKWARD_ERROR_ATOL[dtype]
rtol = op.BACKWARD_ERROR_RTOL[dtype]

grads_ref = []
grads_name = []
if qkv is None:
Expand Down Expand Up @@ -943,7 +925,8 @@ def test_custom_scale(op_device_dtype_B_Mq_Mkv_H_K_Kv):

atol = op.FORWARD_ERROR_ATOL[dtype]
assert_allclose(out.float(), ref.float(), atol=atol)
atol, rtol = backward_error_atol(k, kv_len, q_len, dtype)
atol = op.BACKWARD_ERROR_ATOL[dtype]
rtol = op.BACKWARD_ERROR_RTOL[dtype]
assert_allclose(grad_q, ref_grad_q, atol=atol, rtol=rtol)
assert_allclose(grad_k, ref_grad_k, atol=atol, rtol=rtol)
assert_allclose(grad_v, ref_grad_v, atol=atol, rtol=rtol)
Original file line number Diff line number Diff line change
Expand Up @@ -104,9 +104,11 @@ mem_efficient_attention_backward_cutlass(

int64_t B = query.size(0);
int64_t M = query.size(1);
int64_t Mkv = key.size(1);
int64_t N = key.size(1);
int64_t nH = query.size(2);
int64_t K = query.size(3);
int64_t Kv = value.size(3);

// It does not make sense to use that in practice,
// but let's still make sure we are correct
Expand All @@ -133,6 +135,7 @@ mem_efficient_attention_backward_cutlass(
grad_k = grad_kv_needs_init ? at::zeros_like(key) : at::empty_like(key);
grad_v = grad_kv_needs_init ? at::zeros_like(value) : at::empty_like(value);
}
at::Tensor workspace;

auto launchKernel = [&](auto _k, int computeCapability) {
using Kernel = decltype(_k);
Expand Down Expand Up @@ -205,6 +208,12 @@ mem_efficient_attention_backward_cutlass(
ASSIGN_CHECK_OVERFLOW(p.k_strideH, key.stride(2));
ASSIGN_CHECK_OVERFLOW(p.v_strideH, value.stride(2));

int64_t size_bytes = p.workspace_size();
if (size_bytes) {
workspace =
at::empty({size_bytes}, query.options().dtype(at::ScalarType::Byte));
p.workspace = (float*)workspace.data_ptr();
}
Kernel::check_supported(p);

constexpr auto kernel_fn = attention_kernel_backward_batched<Kernel>;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,11 @@ constexpr CUTLASS_HOST_DEVICE integer ceil_div(integer n, integer m) {
return (n + m - 1) / m;
}

template <typename integer>
constexpr CUTLASS_HOST_DEVICE integer align_up(integer n, integer m) {
return ((n + m - 1) / m) * m;
}

////////////////////////////////////////////////////////////////////////////////
// Determine the type of GEMM we do (TensorCores or not, Shapes ...)
// TODO: Maybe we could rely on Cutlass's DefaultGemm templates
Expand Down
Loading

0 comments on commit 1924b19

Please sign in to comment.