Commit

Improve build time
ghstack-source-id: a083a9494486298191eea001ff480a82af6966c7
Pull Request resolved: #539
danthe3rd committed Nov 29, 2022
1 parent c733c99 commit 0bad001
Showing 7 changed files with 104 additions and 101 deletions.
==== File 1 of 7 ====
@@ -1,3 +1,10 @@
#include <ATen/ScalarOps.h>
#include <ATen/Tensor.h>
#include <ATen/TensorOperators.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/library.h>

#include "kernel_backward.h"

#define DISPATCH_MAXK(func) \
==== File 2 of 7 ====
@@ -1,3 +1,9 @@
#include <ATen/ScalarOps.h>
#include <ATen/Tensor.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/library.h>

#include "kernel_forward.h"

#define DISPATCH_BLOCKSIZE(VALUE_HEAD_DIM, FN) \
@@ -62,6 +68,57 @@
}

namespace {
template <typename scalar_t>
struct TypeTraits;

template <>
struct TypeTraits<cutlass::half_t> {
using scalar_t = cutlass::half_t;

static constexpr __host__ at::ScalarType atScalarType() {
return at::ScalarType::Half;
}
template <int nDim>
static __host__ at::PackedTensorAccessor32<scalar_t, nDim> packed_accessor(
at::Tensor const& tensor) {
return at::PackedTensorAccessor32<scalar_t, nDim>(
(scalar_t*)(tensor.data_ptr()),
tensor.sizes().data(),
tensor.strides().data());
}
};

template <>
struct TypeTraits<cutlass::bfloat16_t> {
using scalar_t = cutlass::bfloat16_t;

static constexpr __host__ at::ScalarType atScalarType() {
return at::ScalarType::BFloat16;
}
template <int nDim>
static __host__ at::PackedTensorAccessor32<scalar_t, nDim> packed_accessor(
at::Tensor const& tensor) {
return at::PackedTensorAccessor32<scalar_t, nDim>(
(scalar_t*)(tensor.data_ptr()),
tensor.sizes().data(),
tensor.strides().data());
}
};

template <>
struct TypeTraits<float> {
using scalar_t = float;

static constexpr __host__ at::ScalarType atScalarType() {
return at::ScalarType::Float;
}
template <int nDim>
static __host__ at::PackedTensorAccessor32<scalar_t, nDim> packed_accessor(
at::Tensor const& tensor) {
return tensor.packed_accessor32<scalar_t, nDim>();
}
};

/*
There are 2 modes for using this function.
(Mode BMHK) With all the heads having the same seqlen
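As an aside on the build-time idea visible in the two .cu files above: the ATen/torch includes now live in the translation units, and the PyTorch-specific TypeTraits helpers are added here instead of in the shared gemm_kernel_utils header (they are deleted from it later in this diff). A stripped-down standalone sketch of that layout follows; the file names and types in it are hypothetical, not repository code.

// ---- kernel.h (hypothetical header, included by many .cu files) ----
// Device-side code only; it does not pull in ATen/torch headers, so every
// translation unit that includes it stays cheap to compile.
#pragma once

template <typename scalar_t>
struct MyKernel {
  // ... CUTLASS-based device code that never touches at::Tensor ...
};

// ---- my_op.cu (hypothetical PyTorch-facing translation unit) ----
// The heavy ATen/torch includes appear only here, once per library, instead
// of once per file that includes the kernel header.
#include <ATen/ATen.h>
#include <torch/library.h>

#include "kernel.h"

at::Tensor my_op(const at::Tensor& query) {
  // dtype dispatch and MyKernel<...> instantiation would happen here.
  return query;
}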
==== File 3 of 7 ====
@@ -23,7 +23,7 @@
if (blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && \
threadIdx.x == PRINT_LANE_ID && threadIdx.y == PRINT_WARP_ID && \
threadIdx.z == 0) { \
printf(msg "\n", __VA_ARGS__); \
printf(msg "\n", ##__VA_ARGS__); \
}
struct __string_view {
char const* data;
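The only change in the debug-printing file above is swapping __VA_ARGS__ for ##__VA_ARGS__. A minimal standalone sketch (not repository code) of why that matters:

#include <cstdio>

// With a plain __VA_ARGS__, a call that passes no variadic arguments leaves a
// dangling comma behind: PRINT_PLAIN("hi") would expand to printf("hi" "\n", );
#define PRINT_PLAIN(msg, ...) printf(msg "\n", __VA_ARGS__)

// The ## form (a GNU extension also accepted by Clang, MSVC and NVCC) drops
// the preceding comma when the variadic list is empty, so both calls below
// compile.
#define PRINT(msg, ...) printf(msg "\n", ##__VA_ARGS__)

int main() {
  PRINT("no extra arguments"); // legal only with the ## form
  PRINT("value = %d", 42);     // legal with either form
  return 0;
}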
==== File 4 of 7 ====
@@ -5,20 +5,20 @@
////////////////////////////////////////////////////////////////////////////////
// Some helper functions
////////////////////////////////////////////////////////////////////////////////
#define DISPATCH_TYPES(tensor, func) \
{ \
if (query.scalar_type() == at::ScalarType::Float) { \
using scalar_t = float; \
func(); \
} else if (query.scalar_type() == at::ScalarType::Half) { \
using scalar_t = cutlass::half_t; \
func(); \
} else if (query.scalar_type() == at::ScalarType::BFloat16) { \
using scalar_t = cutlass::bfloat16_t; \
func(); \
} else { \
TORCH_CHECK(false, "Only fp32, half & bf16 supported at the moment"); \
} \
#define DISPATCH_TYPES(tensor, func) \
{ \
if (query.scalar_type() == at::ScalarType::Float) { \
using scalar_t = float; \
func(); \
} else if (query.scalar_type() == at::ScalarType::Half) { \
using scalar_t = cutlass::half_t; \
func(); \
} else if (query.scalar_type() == at::ScalarType::BFloat16) { \
using scalar_t = cutlass::bfloat16_t; \
func(); \
} else { \
XFORMERS_CHECK(false, "Only fp32, half & bf16 supported at the moment"); \
} \
}

#define DISPATCH_BOOL(BOOL_V, BOOL_NAME, F) \
@@ -46,26 +46,27 @@
using ArchTag = cutlass::arch::Sm50; \
func(); \
} else { \
TORCH_CHECK( \
XFORMERS_CHECK( \
false, \
"Your device is too old. We require compute capability >= 50"); \
} \
}

#define CHECK_NOSPARSE_CONTIGUOUS_CUDA(TENSOR) \
TORCH_CHECK(TENSOR.is_cuda(), #TENSOR " must be a CUDA tensor"); \
TORCH_CHECK(!TENSOR.is_sparse(), #TENSOR " must be a dense tensor"); \
TORCH_CHECK(TENSOR.is_contiguous());
#define CHECK_NOSPARSE_CONTIGUOUS_CUDA(TENSOR) \
XFORMERS_CHECK(TENSOR.is_cuda(), #TENSOR " must be a CUDA tensor"); \
XFORMERS_CHECK(!TENSOR.is_sparse(), #TENSOR " must be a dense tensor"); \
XFORMERS_CHECK(TENSOR.is_contiguous());

#define CHECK_NOSPARSE_LASTCONTIGUOUS_CUDA(TENSOR) \
TORCH_CHECK(TENSOR.is_cuda(), #TENSOR " must be a CUDA tensor"); \
TORCH_CHECK(!TENSOR.is_sparse(), #TENSOR " must be a dense tensor"); \
TORCH_CHECK( \
#define CHECK_NOSPARSE_LASTCONTIGUOUS_CUDA(TENSOR) \
XFORMERS_CHECK(TENSOR.is_cuda(), #TENSOR " must be a CUDA tensor"); \
XFORMERS_CHECK(!TENSOR.is_sparse(), #TENSOR " must be a dense tensor"); \
XFORMERS_CHECK( \
TENSOR.stride(-1) == 1, #TENSOR ": last dimension must be contiguous");

#ifdef HAS_PYTORCH
#ifdef TORCH_CHECK
#define CHECK_ALIGNED_PTR(PTR, ALIGNMENT) \
TORCH_CHECK(uint64_t(PTR) % ALIGNMENT == 0, #PTR " is not correctly aligned")
XFORMERS_CHECK( \
uint64_t(PTR) % ALIGNMENT == 0, #PTR " is not correctly aligned")
#define XFORMERS_CHECK TORCH_CHECK
#elif defined(__CUDACC_RTC__)
#define CHECK_ALIGNED_PTR(PTR, ALIGNMENT) \
@@ -77,6 +78,7 @@
return false; \
}
#else
#include <iostream>
#define CHECK_ALIGNED_PTR(PTR, ALIGNMENT) \
if (!(uint64_t(PTR) % ALIGNMENT == 0)) { \
std::cerr << #PTR " is not correctly aligned\n"; \
@@ -89,67 +91,15 @@
}
#endif

#define ASSIGN_CHECK_OVERFLOW(A, B) \
{ \
A = B; \
TORCH_CHECK(B < std::numeric_limits<decltype(A)>::max(), #B " overflows"); \
#define ASSIGN_CHECK_OVERFLOW(A, B) \
{ \
A = B; \
XFORMERS_CHECK( \
B < std::numeric_limits<decltype(A)>::max(), #B " overflows"); \
}

namespace gemm_kernel_utils {

#ifdef HAS_PYTORCH
template <typename scalar_t>
struct TypeTraits;

template <>
struct TypeTraits<cutlass::half_t> {
using scalar_t = cutlass::half_t;

static constexpr __host__ at::ScalarType atScalarType() {
return at::ScalarType::Half;
}
template <int nDim>
static __host__ at::PackedTensorAccessor32<scalar_t, nDim> packed_accessor(
at::Tensor const& tensor) {
return at::PackedTensorAccessor32<scalar_t, nDim>(
(scalar_t*)(tensor.data_ptr()),
tensor.sizes().data(),
tensor.strides().data());
}
};

template <>
struct TypeTraits<cutlass::bfloat16_t> {
using scalar_t = cutlass::bfloat16_t;

static constexpr __host__ at::ScalarType atScalarType() {
return at::ScalarType::BFloat16;
}
template <int nDim>
static __host__ at::PackedTensorAccessor32<scalar_t, nDim> packed_accessor(
at::Tensor const& tensor) {
return at::PackedTensorAccessor32<scalar_t, nDim>(
(scalar_t*)(tensor.data_ptr()),
tensor.sizes().data(),
tensor.strides().data());
}
};

template <>
struct TypeTraits<float> {
using scalar_t = float;

static constexpr __host__ at::ScalarType atScalarType() {
return at::ScalarType::Float;
}
template <int nDim>
static __host__ at::PackedTensorAccessor32<scalar_t, nDim> packed_accessor(
at::Tensor const& tensor) {
return tensor.packed_accessor32<scalar_t, nDim>();
}
};
#endif

template <typename integer>
constexpr CUTLASS_HOST_DEVICE integer ceil_div(integer n, integer m) {
return (n + m - 1) / m;
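The recurring substitution in the header above is TORCH_CHECK to XFORMERS_CHECK: when TORCH_CHECK exists the new macro forwards to it, and in builds without PyTorch it is defined to report the failure and return false from the enclosing function, which is why validators such as check_supported in the next file now return bool. Below is a standalone sketch of that pattern; MY_CHECK and check_aligned are hypothetical stand-ins, not the repository's macros.

#include <cstdint>
#include <iostream>

#if defined(TORCH_CHECK)
// With PyTorch available, failed checks throw, exactly like TORCH_CHECK.
#define MY_CHECK(COND, ERR) TORCH_CHECK(COND, ERR)
#else
// Without PyTorch, failed checks report the error and make the enclosing
// function return false, so that function must itself return bool.
#define MY_CHECK(COND, ERR)     \
  if (!(COND)) {                \
    std::cerr << (ERR) << "\n"; \
    return false;               \
  }
#endif

// A validator in the same style as check_supported(): each failed check
// exits early; reaching the end means every check passed.
bool check_aligned(const void* ptr, uint64_t alignment) {
  MY_CHECK(
      uint64_t(ptr) % alignment == 0, "pointer is not correctly aligned");
  return true;
}

int main() {
  alignas(16) char buf[32];
  std::cout << check_aligned(buf, 16) << " "       // 1
            << check_aligned(buf + 1, 16) << "\n"; // 0, plus a message on stderr
  return 0;
}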
==== File 5 of 7 ====
@@ -1,15 +1,10 @@
#pragma once

#include <ATen/ATen.h>
#include <torch/library.h>
#include <cmath>
#include <vector>

#include <cuda_fp16.h>

#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

#include "cutlass/gemm/gemm.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/layout/vector.h"
@@ -676,18 +671,19 @@ struct AttentionBackwardKernel {
}
};

static void __host__ check_supported(Params const& p) {
static bool __host__ check_supported(Params const& p) {
CHECK_ALIGNED_PTR(p.query_ptr, kMinimumAlignment);
CHECK_ALIGNED_PTR(p.key_ptr, kMinimumAlignment);
CHECK_ALIGNED_PTR(p.value_ptr, kMinimumAlignment);
CHECK_ALIGNED_PTR(p.output_ptr, kMinimumAlignment);
CHECK_ALIGNED_PTR(p.grad_output_ptr, kMinimumAlignment);
TORCH_CHECK(
XFORMERS_CHECK(
p.q_strideH % kMinimumAlignment == 0, "query is not correctly aligned");
TORCH_CHECK(
XFORMERS_CHECK(
p.k_strideH % kMinimumAlignment == 0, "key is not correctly aligned");
TORCH_CHECK(
XFORMERS_CHECK(
p.v_strideH % kMinimumAlignment == 0, "value is not correctly aligned");
return true;
}

static CUTLASS_DEVICE void kernel(Params& p_) {
==== File 6 of 7 ====
@@ -1,10 +1,3 @@
#ifdef HAS_PYTORCH
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/library.h>
#endif

#include <cmath>
#include <vector>

==== File 7 of 7 ====
@@ -384,7 +384,7 @@ class MmaPipelinedFromSharedMemory : public MmaBaseFromSharedMemory<
// but not supported as it worsens perf: older gpus < sm80 don't
// support async transfers and have to waste registers
CUTLASS_DEVICE
bool set_prologue_done(bool value) {}
void set_prologue_done(bool value) {}
CUTLASS_DEVICE
static void prologue(
typename Base::SharedStorage& shared_storage,
@@ -695,7 +695,7 @@ class MmaMultistageFromSharedMemory : public MmaBaseFromSharedMemory<
}

CUTLASS_DEVICE
bool set_prologue_done(bool value) {
void set_prologue_done(bool value) {
prologue_done_ = value;
}

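The last file's two hunks change set_prologue_done from bool to void. The old signatures fell off the end of a non-void function without returning a value, which is undefined behavior and a -Wreturn-type warning, and the result is evidently unused, so declaring the setter void is the clean fix. A tiny standalone illustration (not repository code):

#include <iostream>

struct PrologueState {
  bool prologue_done_ = false;

  // Declared void: a setter has nothing to report. Declaring it bool and
  // falling off the end without a return statement, as before this change,
  // is undefined behavior in C++.
  void set_prologue_done(bool value) {
    prologue_done_ = value;
  }
};

int main() {
  PrologueState s;
  s.set_prologue_done(true);
  std::cout << std::boolalpha << s.prologue_done_ << "\n"; // prints: true
  return 0;
}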
