Add specialized kernel for cutlassB / K<96
ghstack-source-id: fe298b6c7f897aff850298e301a07289e019e7b1
Pull Request resolved: https://github.com/fairinternal/xformers/pull/455

__original_commit__ = fairinternal/xformers@81e708e30a9b4dc44c1b1c19d2cfde5989f52221
danthe3rd authored and xFormers Bot committed Feb 2, 2023
1 parent 615175f commit 82d5881
Showing 6 changed files with 140 additions and 65 deletions.
8 changes: 4 additions & 4 deletions xformers/csrc/attention/cuda/fmha/kernel_backward.h
@@ -830,9 +830,9 @@ struct AttentionBackwardKernel {
};
static void print_size() {
// Field size
-#define FSZ(f) int((sizeof(((SharedStorage*)0)->f)))
+#define FSZ(f) int((sizeof(((SharedStoragePrologue*)0)->f)))

printf("Total smem: %d bytes\n", int(sizeof(SharedStorage)));
printf("Total smem: %d bytes\n", int(sizeof(SharedStoragePrologue)));
printf(" persistent: %db\n", FSZ(persistent));
printf(" mm_qk_k: %db\n", FSZ(persistent.mm_qk_k));
printf(" p1: %db\n", FSZ(p1));
@@ -968,8 +968,8 @@ struct AttentionBackwardKernel {
} p6;
};
static void print_size() {
-#define FIELD_SIZEOF(f) int((sizeof(((SharedStorage*)0)->f)))
-printf("Total smem: %d bytes\n", int(sizeof(SharedStorage)));
+#define FIELD_SIZEOF(f) int((sizeof(((SharedStorageNoPrologue*)0)->f)))
+printf("Total smem: %d bytes\n", int(sizeof(SharedStorageNoPrologue)));
printf(" persistent: %db\n", FIELD_SIZEOF(persistent));
printf(" p1: %db\n", FIELD_SIZEOF(p1));
printf(" p2: %db\n", FIELD_SIZEOF(p2));
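For reference, the FSZ / FIELD_SIZEOF macros touched here rely on the classic member-sizeof idiom: sizeof is an unevaluated operand, so "dereferencing" a null SharedStoragePrologue* never touches memory and only asks the compiler for the member's size. The hunk points each print_size() at its own enclosing storage type (SharedStoragePrologue vs. SharedStorageNoPrologue) rather than the generic SharedStorage name. A standalone sketch of the idiom, with illustrative names that are not taken from the xformers sources:

#include <cstdio>

// Illustrative stand-in for SharedStoragePrologue / SharedStorageNoPrologue.
struct DemoSharedStorage {
  float persistent[1024];
  char p1[2048];
};

// sizeof() does not evaluate its operand, so the null-pointer "access" below
// is only a compile-time query for the member's size.
#define DEMO_FIELD_SIZEOF(f) int((sizeof(((DemoSharedStorage*)0)->f)))

int main() {
  std::printf("Total: %d bytes\n", int(sizeof(DemoSharedStorage)));
  std::printf("  persistent: %db\n", DEMO_FIELD_SIZEOF(persistent));
  std::printf("  p1: %db\n", DEMO_FIELD_SIZEOF(p1));
  return 0;
}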
46 changes: 28 additions & 18 deletions xformers/csrc/attention/cuda/fmha/kernels/cutlassB.h
@@ -69,7 +69,7 @@ __global__ void __launch_bounds__(
AttentionBackwardKernel<cutlass::arch::Sm50, cutlass::half_t, false, true, false, 64, 64, 65536>::kMinBlocksPerSm)
fmha_cutlassB_f16_notaligned_64x64_k65536_dropout_sm50(typename AttentionBackwardKernel<cutlass::arch::Sm50, cutlass::half_t, false, true, false, 64, 64, 65536>::Params p);

-template <typename T> void dispatch_cutlassB_f16_sm50(T cb) {
+template <typename T> void dispatch_cutlassB_f16_sm50(T cb, int cc) {
cb(AttentionBackwardKernel<cutlass::arch::Sm50, cutlass::half_t, true, false, false, 64, 64, 32>(), fmha_cutlassB_f16_aligned_64x64_k32_sm50);
cb(AttentionBackwardKernel<cutlass::arch::Sm50, cutlass::half_t, true, false, false, 64, 64, 64>(), fmha_cutlassB_f16_aligned_64x64_k64_sm50);
cb(AttentionBackwardKernel<cutlass::arch::Sm50, cutlass::half_t, true, false, false, 64, 64, 128>(), fmha_cutlassB_f16_aligned_64x64_k128_sm50);
@@ -154,7 +154,7 @@ __global__ void __launch_bounds__(
AttentionBackwardKernel<cutlass::arch::Sm50, float, false, true, false, 64, 64, 65536>::kMinBlocksPerSm)
fmha_cutlassB_f32_notaligned_64x64_k65536_dropout_sm50(typename AttentionBackwardKernel<cutlass::arch::Sm50, float, false, true, false, 64, 64, 65536>::Params p);

-template <typename T> void dispatch_cutlassB_f32_sm50(T cb) {
+template <typename T> void dispatch_cutlassB_f32_sm50(T cb, int cc) {
cb(AttentionBackwardKernel<cutlass::arch::Sm50, float, true, false, false, 64, 64, 32>(), fmha_cutlassB_f32_aligned_64x64_k32_sm50);
cb(AttentionBackwardKernel<cutlass::arch::Sm50, float, true, false, false, 64, 64, 64>(), fmha_cutlassB_f32_aligned_64x64_k64_sm50);
cb(AttentionBackwardKernel<cutlass::arch::Sm50, float, true, false, false, 64, 64, 128>(), fmha_cutlassB_f32_aligned_64x64_k128_sm50);
@@ -271,7 +271,7 @@ __global__ void __launch_bounds__(
AttentionBackwardKernel<cutlass::arch::Sm70, cutlass::half_t, false, true, false, 64, 64, 65536>::kMinBlocksPerSm)
fmha_cutlassB_f16_notaligned_64x64_k65536_dropout_sm70(typename AttentionBackwardKernel<cutlass::arch::Sm70, cutlass::half_t, false, true, false, 64, 64, 65536>::Params p);

-template <typename T> void dispatch_cutlassB_f16_sm70(T cb) {
+template <typename T> void dispatch_cutlassB_f16_sm70(T cb, int cc) {
cb(AttentionBackwardKernel<cutlass::arch::Sm70, cutlass::half_t, true, false, false, 64, 64, 32>(), fmha_cutlassB_f16_aligned_64x64_k32_sm70);
cb(AttentionBackwardKernel<cutlass::arch::Sm70, cutlass::half_t, true, false, false, 64, 64, 64>(), fmha_cutlassB_f16_aligned_64x64_k64_sm70);
cb(AttentionBackwardKernel<cutlass::arch::Sm70, cutlass::half_t, true, false, false, 128, 64, 128>(), fmha_cutlassB_f16_aligned_128x64_k128_sm70);
@@ -364,7 +364,7 @@ __global__ void __launch_bounds__(
AttentionBackwardKernel<cutlass::arch::Sm70, float, false, true, false, 64, 64, 65536>::kMinBlocksPerSm)
fmha_cutlassB_f32_notaligned_64x64_k65536_dropout_sm70(typename AttentionBackwardKernel<cutlass::arch::Sm70, float, false, true, false, 64, 64, 65536>::Params p);

-template <typename T> void dispatch_cutlassB_f32_sm70(T cb) {
+template <typename T> void dispatch_cutlassB_f32_sm70(T cb, int cc) {
cb(AttentionBackwardKernel<cutlass::arch::Sm70, float, true, false, false, 64, 64, 32>(), fmha_cutlassB_f32_aligned_64x64_k32_sm70);
cb(AttentionBackwardKernel<cutlass::arch::Sm70, float, true, false, false, 64, 64, 64>(), fmha_cutlassB_f32_aligned_64x64_k64_sm70);
cb(AttentionBackwardKernel<cutlass::arch::Sm70, float, true, false, false, 64, 64, 128>(), fmha_cutlassB_f32_aligned_64x64_k128_sm70);
@@ -481,7 +481,7 @@ __global__ void __launch_bounds__(
AttentionBackwardKernel<cutlass::arch::Sm75, cutlass::half_t, false, true, false, 64, 64, 65536>::kMinBlocksPerSm)
fmha_cutlassB_f16_notaligned_64x64_k65536_dropout_sm75(typename AttentionBackwardKernel<cutlass::arch::Sm75, cutlass::half_t, false, true, false, 64, 64, 65536>::Params p);

-template <typename T> void dispatch_cutlassB_f16_sm75(T cb) {
+template <typename T> void dispatch_cutlassB_f16_sm75(T cb, int cc) {
cb(AttentionBackwardKernel<cutlass::arch::Sm75, cutlass::half_t, true, false, false, 64, 64, 32>(), fmha_cutlassB_f16_aligned_64x64_k32_sm75);
cb(AttentionBackwardKernel<cutlass::arch::Sm75, cutlass::half_t, true, false, false, 64, 64, 64>(), fmha_cutlassB_f16_aligned_64x64_k64_sm75);
cb(AttentionBackwardKernel<cutlass::arch::Sm75, cutlass::half_t, true, false, false, 128, 64, 128>(), fmha_cutlassB_f16_aligned_128x64_k128_sm75);
@@ -574,7 +574,7 @@ __global__ void __launch_bounds__(
AttentionBackwardKernel<cutlass::arch::Sm75, float, false, true, false, 64, 64, 65536>::kMinBlocksPerSm)
fmha_cutlassB_f32_notaligned_64x64_k65536_dropout_sm75(typename AttentionBackwardKernel<cutlass::arch::Sm75, float, false, true, false, 64, 64, 65536>::Params p);

-template <typename T> void dispatch_cutlassB_f32_sm75(T cb) {
+template <typename T> void dispatch_cutlassB_f32_sm75(T cb, int cc) {
cb(AttentionBackwardKernel<cutlass::arch::Sm75, float, true, false, false, 64, 64, 32>(), fmha_cutlassB_f32_aligned_64x64_k32_sm75);
cb(AttentionBackwardKernel<cutlass::arch::Sm75, float, true, false, false, 64, 64, 64>(), fmha_cutlassB_f32_aligned_64x64_k64_sm75);
cb(AttentionBackwardKernel<cutlass::arch::Sm75, float, true, false, false, 64, 64, 128>(), fmha_cutlassB_f32_aligned_64x64_k128_sm75);
@@ -602,6 +602,10 @@ __global__ void __launch_bounds__(
AttentionBackwardKernel<cutlass::arch::Sm80, cutlass::bfloat16_t, true, false, true, 64, 64, 64>::kNumThreads,
AttentionBackwardKernel<cutlass::arch::Sm80, cutlass::bfloat16_t, true, false, true, 64, 64, 64>::kMinBlocksPerSm)
fmha_cutlassB_bf16_aligned_64x64_k64_sm80(typename AttentionBackwardKernel<cutlass::arch::Sm80, cutlass::bfloat16_t, true, false, true, 64, 64, 64>::Params p);
+__global__ void __launch_bounds__(
+AttentionBackwardKernel<cutlass::arch::Sm80, cutlass::bfloat16_t, true, false, true, 128, 64, 96>::kNumThreads,
+AttentionBackwardKernel<cutlass::arch::Sm80, cutlass::bfloat16_t, true, false, true, 128, 64, 96>::kMinBlocksPerSm)
+fmha_cutlassB_bf16_aligned_128x64_k96_sm80(typename AttentionBackwardKernel<cutlass::arch::Sm80, cutlass::bfloat16_t, true, false, true, 128, 64, 96>::Params p);
__global__ void __launch_bounds__(
AttentionBackwardKernel<cutlass::arch::Sm80, cutlass::bfloat16_t, true, false, true, 128, 128, 128>::kNumThreads,
AttentionBackwardKernel<cutlass::arch::Sm80, cutlass::bfloat16_t, true, false, true, 128, 128, 128>::kMinBlocksPerSm)
@@ -643,9 +647,10 @@ __global__ void __launch_bounds__(
AttentionBackwardKernel<cutlass::arch::Sm80, cutlass::bfloat16_t, true, true, false, 64, 64, 65536>::kMinBlocksPerSm)
fmha_cutlassB_bf16_aligned_64x64_k65536_dropout_sm80(typename AttentionBackwardKernel<cutlass::arch::Sm80, cutlass::bfloat16_t, true, true, false, 64, 64, 65536>::Params p);

-template <typename T> void dispatch_cutlassB_bf16_sm80(T cb) {
+template <typename T> void dispatch_cutlassB_bf16_sm80(T cb, int cc) {
cb(AttentionBackwardKernel<cutlass::arch::Sm80, cutlass::bfloat16_t, true, false, true, 64, 64, 32>(), fmha_cutlassB_bf16_aligned_64x64_k32_sm80);
cb(AttentionBackwardKernel<cutlass::arch::Sm80, cutlass::bfloat16_t, true, false, true, 64, 64, 64>(), fmha_cutlassB_bf16_aligned_64x64_k64_sm80);
+if (cc == 86 || cc == 89) cb(AttentionBackwardKernel<cutlass::arch::Sm80, cutlass::bfloat16_t, true, false, true, 128, 64, 96>(), fmha_cutlassB_bf16_aligned_128x64_k96_sm80);
cb(AttentionBackwardKernel<cutlass::arch::Sm80, cutlass::bfloat16_t, true, false, true, 128, 128, 128>(), fmha_cutlassB_bf16_aligned_128x128_k128_sm80);
cb(AttentionBackwardKernel<cutlass::arch::Sm80, cutlass::bfloat16_t, true, false, false, 64, 64, 128>(), fmha_cutlassB_bf16_aligned_64x64_k128_sm80);
cb(AttentionBackwardKernel<cutlass::arch::Sm80, cutlass::bfloat16_t, true, false, false, 128, 64, 65536>(), fmha_cutlassB_bf16_aligned_128x64_k65536_sm80);
@@ -667,6 +672,10 @@ __global__ void __launch_bounds__(
AttentionBackwardKernel<cutlass::arch::Sm80, cutlass::half_t, true, false, true, 64, 64, 64>::kNumThreads,
AttentionBackwardKernel<cutlass::arch::Sm80, cutlass::half_t, true, false, true, 64, 64, 64>::kMinBlocksPerSm)
fmha_cutlassB_f16_aligned_64x64_k64_sm80(typename AttentionBackwardKernel<cutlass::arch::Sm80, cutlass::half_t, true, false, true, 64, 64, 64>::Params p);
+__global__ void __launch_bounds__(
+AttentionBackwardKernel<cutlass::arch::Sm80, cutlass::half_t, true, false, true, 128, 64, 96>::kNumThreads,
+AttentionBackwardKernel<cutlass::arch::Sm80, cutlass::half_t, true, false, true, 128, 64, 96>::kMinBlocksPerSm)
+fmha_cutlassB_f16_aligned_128x64_k96_sm80(typename AttentionBackwardKernel<cutlass::arch::Sm80, cutlass::half_t, true, false, true, 128, 64, 96>::Params p);
__global__ void __launch_bounds__(
AttentionBackwardKernel<cutlass::arch::Sm80, cutlass::half_t, true, false, true, 128, 128, 128>::kNumThreads,
AttentionBackwardKernel<cutlass::arch::Sm80, cutlass::half_t, true, false, true, 128, 128, 128>::kMinBlocksPerSm)
@@ -708,9 +717,10 @@ __global__ void __launch_bounds__(
AttentionBackwardKernel<cutlass::arch::Sm80, cutlass::half_t, true, true, false, 64, 64, 65536>::kMinBlocksPerSm)
fmha_cutlassB_f16_aligned_64x64_k65536_dropout_sm80(typename AttentionBackwardKernel<cutlass::arch::Sm80, cutlass::half_t, true, true, false, 64, 64, 65536>::Params p);

-template <typename T> void dispatch_cutlassB_f16_sm80(T cb) {
+template <typename T> void dispatch_cutlassB_f16_sm80(T cb, int cc) {
cb(AttentionBackwardKernel<cutlass::arch::Sm80, cutlass::half_t, true, false, true, 64, 64, 32>(), fmha_cutlassB_f16_aligned_64x64_k32_sm80);
cb(AttentionBackwardKernel<cutlass::arch::Sm80, cutlass::half_t, true, false, true, 64, 64, 64>(), fmha_cutlassB_f16_aligned_64x64_k64_sm80);
+if (cc == 86 || cc == 89) cb(AttentionBackwardKernel<cutlass::arch::Sm80, cutlass::half_t, true, false, true, 128, 64, 96>(), fmha_cutlassB_f16_aligned_128x64_k96_sm80);
cb(AttentionBackwardKernel<cutlass::arch::Sm80, cutlass::half_t, true, false, true, 128, 128, 128>(), fmha_cutlassB_f16_aligned_128x128_k128_sm80);
cb(AttentionBackwardKernel<cutlass::arch::Sm80, cutlass::half_t, true, false, false, 64, 64, 128>(), fmha_cutlassB_f16_aligned_64x64_k128_sm80);
cb(AttentionBackwardKernel<cutlass::arch::Sm80, cutlass::half_t, true, false, false, 128, 64, 65536>(), fmha_cutlassB_f16_aligned_128x64_k65536_sm80);
@@ -773,7 +783,7 @@ __global__ void __launch_bounds__(
AttentionBackwardKernel<cutlass::arch::Sm80, float, true, true, false, 64, 64, 65536>::kMinBlocksPerSm)
fmha_cutlassB_f32_aligned_64x64_k65536_dropout_sm80(typename AttentionBackwardKernel<cutlass::arch::Sm80, float, true, true, false, 64, 64, 65536>::Params p);

-template <typename T> void dispatch_cutlassB_f32_sm80(T cb) {
+template <typename T> void dispatch_cutlassB_f32_sm80(T cb, int cc) {
cb(AttentionBackwardKernel<cutlass::arch::Sm80, float, true, false, false, 64, 64, 32>(), fmha_cutlassB_f32_aligned_64x64_k32_sm80);
cb(AttentionBackwardKernel<cutlass::arch::Sm80, float, true, false, false, 64, 64, 64>(), fmha_cutlassB_f32_aligned_64x64_k64_sm80);
cb(AttentionBackwardKernel<cutlass::arch::Sm80, float, true, false, false, 128, 64, 128>(), fmha_cutlassB_f32_aligned_128x64_k128_sm80);
@@ -793,31 +803,31 @@ template <typename DT, typename T>
void dispatch_cutlassB(T cb, int cc = 0) {

if (std::is_same<DT, cutlass::half_t>::value && 50 <= cc && cc < 70) {
-dispatch_cutlassB_f16_sm50(cb);
+dispatch_cutlassB_f16_sm50(cb, cc);
}
if (std::is_same<DT, float>::value && 50 <= cc && cc < 70) {
-dispatch_cutlassB_f32_sm50(cb);
+dispatch_cutlassB_f32_sm50(cb, cc);
}
if (std::is_same<DT, cutlass::half_t>::value && 70 <= cc && cc < 75) {
-dispatch_cutlassB_f16_sm70(cb);
+dispatch_cutlassB_f16_sm70(cb, cc);
}
if (std::is_same<DT, float>::value && 70 <= cc && cc < 75) {
-dispatch_cutlassB_f32_sm70(cb);
+dispatch_cutlassB_f32_sm70(cb, cc);
}
if (std::is_same<DT, cutlass::half_t>::value && 75 <= cc && cc < 80) {
-dispatch_cutlassB_f16_sm75(cb);
+dispatch_cutlassB_f16_sm75(cb, cc);
}
if (std::is_same<DT, float>::value && 75 <= cc && cc < 80) {
-dispatch_cutlassB_f32_sm75(cb);
+dispatch_cutlassB_f32_sm75(cb, cc);
}
if (std::is_same<DT, cutlass::bfloat16_t>::value && 80 <= cc && cc < 90) {
-dispatch_cutlassB_bf16_sm80(cb);
+dispatch_cutlassB_bf16_sm80(cb, cc);
}
if (std::is_same<DT, cutlass::half_t>::value && 80 <= cc && cc < 90) {
-dispatch_cutlassB_f16_sm80(cb);
+dispatch_cutlassB_f16_sm80(cb, cc);
}
if (std::is_same<DT, float>::value && 80 <= cc && cc < 90) {
-dispatch_cutlassB_f32_sm80(cb);
+dispatch_cutlassB_f32_sm80(cb, cc);
}
}
#endif // XFORMERS_MEM_EFF_ATTENTION_DISABLE_BACKWARD
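Every generated dispatch_cutlassB_*_sm* helper now takes the compute capability cc alongside the callback, which lets individual registrations be gated per architecture: the new 128x64, K<=96 kernels above are only offered when cc is 86 or 89, presumably because those parts have a smaller shared-memory budget than sm80 and the wider 128x128 / K=128 tiles are a poor fit there. A minimal sketch of how a caller might consume the dispatcher follows; the include path and the kMaxK constant are assumptions for illustration, not guaranteed by this diff:

#include <cstdio>
#include "cutlassB.h"  // assumed include path for the generated dispatcher

// Enumerate the kernels compiled for this dtype and count the ones whose
// maximum supported head dimension covers the problem. `kMaxK` is assumed to
// mirror the last template argument of AttentionBackwardKernel.
template <typename scalar_t>
int count_candidate_kernels(int cc, int head_dim) {
  int candidates = 0;
  dispatch_cutlassB<scalar_t>(
      [&](auto kernel, auto kernel_fn) {
        using KernelT = decltype(kernel);
        (void)kernel_fn;  // the matching __global__ entry point
        if (head_dim <= KernelT::kMaxK) {
          ++candidates;
        }
      },
      cc);
  std::printf("%d candidate kernels for head_dim=%d on sm%d\n",
              candidates, head_dim, cc);
  return candidates;
}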
@@ -0,0 +1,24 @@
#ifndef XFORMERS_MEM_EFF_ATTENTION_DISABLE_BACKWARD
// This file is auto-generated. See "generate_kernels.py"
#include "../kernel_backward.h"

__global__ void __launch_bounds__(
AttentionBackwardKernel<cutlass::arch::Sm80, cutlass::bfloat16_t, true, false, true, 128, 64, 96>::kNumThreads,
AttentionBackwardKernel<cutlass::arch::Sm80, cutlass::bfloat16_t, true, false, true, 128, 64, 96>::kMinBlocksPerSm)
fmha_cutlassB_bf16_aligned_128x64_k96_sm80(typename AttentionBackwardKernel<cutlass::arch::Sm80, cutlass::bfloat16_t, true, false, true, 128, 64, 96>::Params p) {
#ifdef __CUDA_ARCH__
#if __CUDA_ARCH__ >= 800
#if __CUDA_ARCH__ < 900
if (!p.advance_to_block()) {
return;
}
AttentionBackwardKernel<cutlass::arch::Sm80, cutlass::bfloat16_t, true, false, true, 128, 64, 96>::attention_kernel(p);
return;
#endif
#endif
printf(
"FATAL: kernel `fmha_cutlassB_bf16_aligned_128x64_k96_sm80` is for sm80-sm90, but was built for sm%d\n",
int(__CUDA_ARCH__ + 0) / 10);
#endif
}
#endif // XFORMERS_MEM_EFF_ATTENTION_DISABLE_BACKWARD
@@ -0,0 +1,24 @@
#ifndef XFORMERS_MEM_EFF_ATTENTION_DISABLE_BACKWARD
// This file is auto-generated. See "generate_kernels.py"
#include "../kernel_backward.h"

__global__ void __launch_bounds__(
AttentionBackwardKernel<cutlass::arch::Sm80, cutlass::half_t, true, false, true, 128, 64, 96>::kNumThreads,
AttentionBackwardKernel<cutlass::arch::Sm80, cutlass::half_t, true, false, true, 128, 64, 96>::kMinBlocksPerSm)
fmha_cutlassB_f16_aligned_128x64_k96_sm80(typename AttentionBackwardKernel<cutlass::arch::Sm80, cutlass::half_t, true, false, true, 128, 64, 96>::Params p) {
#ifdef __CUDA_ARCH__
#if __CUDA_ARCH__ >= 800
#if __CUDA_ARCH__ < 900
if (!p.advance_to_block()) {
return;
}
AttentionBackwardKernel<cutlass::arch::Sm80, cutlass::half_t, true, false, true, 128, 64, 96>::attention_kernel(p);
return;
#endif
#endif
printf(
"FATAL: kernel `fmha_cutlassB_f16_aligned_128x64_k96_sm80` is for sm80-sm90, but was built for sm%d\n",
int(__CUDA_ARCH__ + 0) / 10);
#endif
}
#endif // XFORMERS_MEM_EFF_ATTENTION_DISABLE_BACKWARD
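Both new translation units follow the same auto-generated pattern: the kernel body is compiled only when the target architecture lies in [sm80, sm90), and any other build gets a stub that prints a fatal message at run time. A hedged sketch of a host-side launcher for one of these kernels; grid sizing and the use of dynamic shared memory here are assumptions for illustration, not the actual xformers launch path:

#include <cstdio>
#include <cuda_runtime.h>

// Kernel is one of the AttentionBackwardKernel instantiations and kernel_fn
// the matching generated __global__ entry point declared in cutlassB.h.
template <typename Kernel, typename KernelFn>
void launch_cutlassB_kernel(KernelFn kernel_fn, typename Kernel::Params p,
                            dim3 grid, size_t smem_bytes, cudaStream_t stream) {
  if (smem_bytes > 48 * 1024) {
    // Opt in to more than the default 48 KB of shared memory per block.
    cudaFuncSetAttribute(
        (const void*)kernel_fn,
        cudaFuncAttributeMaxDynamicSharedMemorySize,
        int(smem_bytes));
  }
  // kNumThreads matches the first __launch_bounds__ argument of the kernel.
  kernel_fn<<<grid, Kernel::kNumThreads, smem_bytes, stream>>>(p);
  cudaError_t err = cudaGetLastError();
  if (err != cudaSuccess) {
    std::printf("launch failed: %s\n", cudaGetErrorString(err));
  }
}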
