From 0bad001ddd56c080524d37c84ff58d9cd030ebfd Mon Sep 17 00:00:00 2001
From: danthe3rd
Date: Tue, 29 Nov 2022 12:03:43 +0000
Subject: [PATCH] Improve build time

ghstack-source-id: a083a9494486298191eea001ff480a82af6966c7
Pull Request resolved: https://github.com/facebookresearch/xformers/pull/539
---
 .../attention_backward_generic.cu             |   7 ++
 .../attention_forward_generic.cu              |  57 +++++++++
 .../csrc/cuda/mem_eff_attention/debug_utils.h |   2 +-
 .../mem_eff_attention/gemm_kernel_utils.h     | 114 +++++-------------
 .../cuda/mem_eff_attention/kernel_backward.h  |  14 +--
 .../cuda/mem_eff_attention/kernel_forward.h   |   7 --
 .../cuda/mem_eff_attention/mma_from_smem.h    |   4 +-
 7 files changed, 104 insertions(+), 101 deletions(-)

diff --git a/xformers/components/attention/csrc/cuda/mem_eff_attention/attention_backward_generic.cu b/xformers/components/attention/csrc/cuda/mem_eff_attention/attention_backward_generic.cu
index dc63b0124..f3fb29ff5 100644
--- a/xformers/components/attention/csrc/cuda/mem_eff_attention/attention_backward_generic.cu
+++ b/xformers/components/attention/csrc/cuda/mem_eff_attention/attention_backward_generic.cu
@@ -1,3 +1,10 @@
+#include
+#include
+#include
+#include
+#include
+#include
+
 #include "kernel_backward.h"
 
 #define DISPATCH_MAXK(func) \
diff --git a/xformers/components/attention/csrc/cuda/mem_eff_attention/attention_forward_generic.cu b/xformers/components/attention/csrc/cuda/mem_eff_attention/attention_forward_generic.cu
index aea452109..620a4f58e 100644
--- a/xformers/components/attention/csrc/cuda/mem_eff_attention/attention_forward_generic.cu
+++ b/xformers/components/attention/csrc/cuda/mem_eff_attention/attention_forward_generic.cu
@@ -1,3 +1,9 @@
+#include
+#include
+#include
+#include
+#include
+
 #include "kernel_forward.h"
 
 #define DISPATCH_BLOCKSIZE(VALUE_HEAD_DIM, FN) \
@@ -62,6 +68,57 @@
 }
 
 namespace {
+template <typename scalar_t>
+struct TypeTraits;
+
+template <>
+struct TypeTraits<cutlass::half_t> {
+  using scalar_t = cutlass::half_t;
+
+  static constexpr __host__ at::ScalarType atScalarType() {
+    return at::ScalarType::Half;
+  }
+  template <int nDim>
+  static __host__ at::PackedTensorAccessor32<scalar_t, nDim> packed_accessor(
+      at::Tensor const& tensor) {
+    return at::PackedTensorAccessor32<scalar_t, nDim>(
+        (scalar_t*)(tensor.data_ptr()),
+        tensor.sizes().data(),
+        tensor.strides().data());
+  }
+};
+
+template <>
+struct TypeTraits<cutlass::bfloat16_t> {
+  using scalar_t = cutlass::bfloat16_t;
+
+  static constexpr __host__ at::ScalarType atScalarType() {
+    return at::ScalarType::BFloat16;
+  }
+  template <int nDim>
+  static __host__ at::PackedTensorAccessor32<scalar_t, nDim> packed_accessor(
+      at::Tensor const& tensor) {
+    return at::PackedTensorAccessor32<scalar_t, nDim>(
+        (scalar_t*)(tensor.data_ptr()),
+        tensor.sizes().data(),
+        tensor.strides().data());
+  }
+};
+
+template <>
+struct TypeTraits<float> {
+  using scalar_t = float;
+
+  static constexpr __host__ at::ScalarType atScalarType() {
+    return at::ScalarType::Float;
+  }
+  template <int nDim>
+  static __host__ at::PackedTensorAccessor32<scalar_t, nDim> packed_accessor(
+      at::Tensor const& tensor) {
+    return tensor.packed_accessor32<scalar_t, nDim>();
+  }
+};
+
 /* There are 2 modes for using this function.
 (Mode BMHK) With all the heads having the same seqlen
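
Note on the TypeTraits hunk above: it gives the host-side launcher one place to map a CUTLASS element type to its at::ScalarType and to view an untyped at::Tensor through a typed, 32-bit-indexed accessor, and moving it out of gemm_kernel_utils.h keeps that header buildable without ATen. A minimal usage sketch, not taken from the patch (the helper name, the tensor name and the rank 4 are illustrative assumptions):

    #include <ATen/ATen.h>
    #include <c10/util/Exception.h>
    #include <cutlass/half.h>

    // Hypothetical helper: check the dtype, then build a typed accessor of the
    // rank the kernel expects from a half-precision CUDA tensor.
    void use_half_accessor(at::Tensor const& query) {
      using traits = TypeTraits<cutlass::half_t>; // from the anonymous namespace above
      TORCH_CHECK(query.scalar_type() == traits::atScalarType(), "expected fp16 input");
      auto acc = traits::packed_accessor<4>(query);
      // acc is an at::PackedTensorAccessor32<cutlass::half_t, 4>; a kernel receiving
      // it can index acc[b][m][h][k] and read the storage as cutlass::half_t.
    }
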
diff --git a/xformers/components/attention/csrc/cuda/mem_eff_attention/debug_utils.h b/xformers/components/attention/csrc/cuda/mem_eff_attention/debug_utils.h
index 8e4826611..6c348f1a1 100644
--- a/xformers/components/attention/csrc/cuda/mem_eff_attention/debug_utils.h
+++ b/xformers/components/attention/csrc/cuda/mem_eff_attention/debug_utils.h
@@ -23,7 +23,7 @@
   if (blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 &&        \
       threadIdx.x == PRINT_LANE_ID && threadIdx.y == PRINT_WARP_ID && \
       threadIdx.z == 0) {                                             \
-    printf(msg "\n", __VA_ARGS__);                                    \
+    printf(msg "\n", ##__VA_ARGS__);                                  \
   }
 
 struct __string_view {
   char const* data;
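
The one-character change above matters when the debug macro is invoked with a format string and no variadic arguments: plain __VA_ARGS__ leaves a dangling comma behind (printf(msg "\n", );), which does not compile, while the GNU/NVCC ## extension drops the comma when the pack is empty. A standalone reproduction (DBG_PRINTF is a made-up stand-in for the macro in debug_utils.h):

    #include <cstdio>

    // ## before __VA_ARGS__ swallows the preceding comma when no varargs are passed.
    #define DBG_PRINTF(msg, ...) printf(msg "\n", ##__VA_ARGS__)

    int main() {
      DBG_PRINTF("kernel start");  // would fail to compile with a plain __VA_ARGS__
      DBG_PRINTF("value=%d", 42);
      return 0;
    }
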
We require compute capability >= 50"); \ } \ } -#define CHECK_NOSPARSE_CONTIGUOUS_CUDA(TENSOR) \ - TORCH_CHECK(TENSOR.is_cuda(), #TENSOR " must be a CUDA tensor"); \ - TORCH_CHECK(!TENSOR.is_sparse(), #TENSOR " must be a dense tensor"); \ - TORCH_CHECK(TENSOR.is_contiguous()); +#define CHECK_NOSPARSE_CONTIGUOUS_CUDA(TENSOR) \ + XFORMERS_CHECK(TENSOR.is_cuda(), #TENSOR " must be a CUDA tensor"); \ + XFORMERS_CHECK(!TENSOR.is_sparse(), #TENSOR " must be a dense tensor"); \ + XFORMERS_CHECK(TENSOR.is_contiguous()); -#define CHECK_NOSPARSE_LASTCONTIGUOUS_CUDA(TENSOR) \ - TORCH_CHECK(TENSOR.is_cuda(), #TENSOR " must be a CUDA tensor"); \ - TORCH_CHECK(!TENSOR.is_sparse(), #TENSOR " must be a dense tensor"); \ - TORCH_CHECK( \ +#define CHECK_NOSPARSE_LASTCONTIGUOUS_CUDA(TENSOR) \ + XFORMERS_CHECK(TENSOR.is_cuda(), #TENSOR " must be a CUDA tensor"); \ + XFORMERS_CHECK(!TENSOR.is_sparse(), #TENSOR " must be a dense tensor"); \ + XFORMERS_CHECK( \ TENSOR.stride(-1) == 1, #TENSOR ": last dimension must be contiguous"); -#ifdef HAS_PYTORCH +#ifdef TORCH_CHECK #define CHECK_ALIGNED_PTR(PTR, ALIGNMENT) \ - TORCH_CHECK(uint64_t(PTR) % ALIGNMENT == 0, #PTR " is not correctly aligned") + XFORMERS_CHECK( \ + uint64_t(PTR) % ALIGNMENT == 0, #PTR " is not correctly aligned") #define XFORMERS_CHECK TORCH_CHECK #elif defined(__CUDACC_RTC__) #define CHECK_ALIGNED_PTR(PTR, ALIGNMENT) \ @@ -77,6 +78,7 @@ return false; \ } #else +#include #define CHECK_ALIGNED_PTR(PTR, ALIGNMENT) \ if (!(uint64_t(PTR) % ALIGNMENT == 0)) { \ std::cerr << #PTR " is not correctly aligned\n"; \ @@ -89,67 +91,15 @@ } #endif -#define ASSIGN_CHECK_OVERFLOW(A, B) \ - { \ - A = B; \ - TORCH_CHECK(B < std::numeric_limits::max(), #B " overflows"); \ +#define ASSIGN_CHECK_OVERFLOW(A, B) \ + { \ + A = B; \ + XFORMERS_CHECK( \ + B < std::numeric_limits::max(), #B " overflows"); \ } namespace gemm_kernel_utils { -#ifdef HAS_PYTORCH -template -struct TypeTraits; - -template <> -struct TypeTraits { - using scalar_t = cutlass::half_t; - - static constexpr __host__ at::ScalarType atScalarType() { - return at::ScalarType::Half; - } - template - static __host__ at::PackedTensorAccessor32 packed_accessor( - at::Tensor const& tensor) { - return at::PackedTensorAccessor32( - (scalar_t*)(tensor.data_ptr()), - tensor.sizes().data(), - tensor.strides().data()); - } -}; - -template <> -struct TypeTraits { - using scalar_t = cutlass::bfloat16_t; - - static constexpr __host__ at::ScalarType atScalarType() { - return at::ScalarType::BFloat16; - } - template - static __host__ at::PackedTensorAccessor32 packed_accessor( - at::Tensor const& tensor) { - return at::PackedTensorAccessor32( - (scalar_t*)(tensor.data_ptr()), - tensor.sizes().data(), - tensor.strides().data()); - } -}; - -template <> -struct TypeTraits { - using scalar_t = float; - - static constexpr __host__ at::ScalarType atScalarType() { - return at::ScalarType::Float; - } - template - static __host__ at::PackedTensorAccessor32 packed_accessor( - at::Tensor const& tensor) { - return tensor.packed_accessor32(); - } -}; -#endif - template constexpr CUTLASS_HOST_DEVICE integer ceil_div(integer n, integer m) { return (n + m - 1) / m; diff --git a/xformers/components/attention/csrc/cuda/mem_eff_attention/kernel_backward.h b/xformers/components/attention/csrc/cuda/mem_eff_attention/kernel_backward.h index 7eb73b255..3d2908704 100644 --- a/xformers/components/attention/csrc/cuda/mem_eff_attention/kernel_backward.h +++ b/xformers/components/attention/csrc/cuda/mem_eff_attention/kernel_backward.h @@ -1,15 
diff --git a/xformers/components/attention/csrc/cuda/mem_eff_attention/kernel_backward.h b/xformers/components/attention/csrc/cuda/mem_eff_attention/kernel_backward.h
index 7eb73b255..3d2908704 100644
--- a/xformers/components/attention/csrc/cuda/mem_eff_attention/kernel_backward.h
+++ b/xformers/components/attention/csrc/cuda/mem_eff_attention/kernel_backward.h
@@ -1,15 +1,10 @@
 #pragma once
 
-#include
-#include
 #include
 #include
 #include
-#include
-#include
-
 #include "cutlass/gemm/gemm.h"
 #include "cutlass/layout/matrix.h"
 #include "cutlass/layout/vector.h"
@@ -676,18 +671,19 @@ struct AttentionBackwardKernel {
     }
   };
 
-  static void __host__ check_supported(Params const& p) {
+  static bool __host__ check_supported(Params const& p) {
     CHECK_ALIGNED_PTR(p.query_ptr, kMinimumAlignment);
     CHECK_ALIGNED_PTR(p.key_ptr, kMinimumAlignment);
     CHECK_ALIGNED_PTR(p.value_ptr, kMinimumAlignment);
     CHECK_ALIGNED_PTR(p.output_ptr, kMinimumAlignment);
     CHECK_ALIGNED_PTR(p.grad_output_ptr, kMinimumAlignment);
-    TORCH_CHECK(
+    XFORMERS_CHECK(
         p.q_strideH % kMinimumAlignment == 0, "query is not correctly aligned");
-    TORCH_CHECK(
+    XFORMERS_CHECK(
         p.k_strideH % kMinimumAlignment == 0, "key is not correctly aligned");
-    TORCH_CHECK(
+    XFORMERS_CHECK(
         p.v_strideH % kMinimumAlignment == 0, "value is not correctly aligned");
+    return true;
   }
 
   static CUTLASS_DEVICE void kernel(Params& p_) {
diff --git a/xformers/components/attention/csrc/cuda/mem_eff_attention/kernel_forward.h b/xformers/components/attention/csrc/cuda/mem_eff_attention/kernel_forward.h
index 910d845a1..2023924cb 100644
--- a/xformers/components/attention/csrc/cuda/mem_eff_attention/kernel_forward.h
+++ b/xformers/components/attention/csrc/cuda/mem_eff_attention/kernel_forward.h
@@ -1,10 +1,3 @@
-#ifdef HAS_PYTORCH
-#include
-#include
-#include
-#include
-#endif
-
 #include
 #include
diff --git a/xformers/components/attention/csrc/cuda/mem_eff_attention/mma_from_smem.h b/xformers/components/attention/csrc/cuda/mem_eff_attention/mma_from_smem.h
index e610db3c7..1f55f1434 100644
--- a/xformers/components/attention/csrc/cuda/mem_eff_attention/mma_from_smem.h
+++ b/xformers/components/attention/csrc/cuda/mem_eff_attention/mma_from_smem.h
@@ -384,7 +384,7 @@ class MmaPipelinedFromSharedMemory : public MmaBaseFromSharedMemory<
   // but not supported as it worsens perf: older gpus < sm80 don't
   // support async tranfers and have to waste registers
   CUTLASS_DEVICE
-  bool set_prologue_done(bool value) {}
+  void set_prologue_done(bool value) {}
   CUTLASS_DEVICE
   static void prologue(
       typename Base::SharedStorage& shared_storage,
@@ -695,7 +695,7 @@ class MmaMultistageFromSharedMemory : public MmaBaseFromSharedMemory<
   }
 
   CUTLASS_DEVICE
-  bool set_prologue_done(bool value) {
+  void set_prologue_done(bool value) {
     prologue_done_ = value;
   }
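
Taken together, the patch keeps the ATen/torch headers out of kernel_forward.h, kernel_backward.h and the CUTLASS-side headers, so only the two .cu launchers pay for them, which is presumably where the "Improve build time" gain comes from. For context, the dispatch macros from gemm_kernel_utils.h are meant to be nested in such a launcher so each (dtype, arch) pair instantiates the kernel template once. A schematic with assumed names (the arch-dispatch macro is shown here as DISPATCH_ARCHTAG and the kernel call is invented); note also that DISPATCH_TYPES inspects a tensor named `query` in the enclosing scope rather than its `tensor` argument, so a variable with that name must be in scope where it is used:

    // Schematic host-side launcher; the names below are illustrative, not from the patch.
    void launch_attention(at::Tensor const& query, int computeCapability) {
      DISPATCH_TYPES(query, ([&] {
        // scalar_t is now concrete: float, cutlass::half_t or cutlass::bfloat16_t.
        DISPATCH_ARCHTAG(computeCapability, ([&] {
          // ArchTag is now a concrete cutlass::arch::SmXX tag; instantiate and run
          // the kernel here, e.g. AttentionKernel<scalar_t, ArchTag, ...>::launch(...).
        }));
      }));
    }
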